From 1db92dca36e5a531e9906ac92f03f25b05eb0c39 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 12 Nov 2025 14:27:12 +0800 Subject: [PATCH 1/3] add qwen3 attn --- KT-SFT/ktransformers/operators/attention.py | 142 +++++++++++++++++- .../optimize_rules/Qwen3Moe-sft-amx.yaml | 103 +++++++++++++ 2 files changed, 244 insertions(+), 1 deletion(-) create mode 100644 KT-SFT/ktransformers/optimize/optimize_rules/Qwen3Moe-sft-amx.yaml diff --git a/KT-SFT/ktransformers/operators/attention.py b/KT-SFT/ktransformers/operators/attention.py index 9dfdbdc6..8638e9f9 100644 --- a/KT-SFT/ktransformers/operators/attention.py +++ b/KT-SFT/ktransformers/operators/attention.py @@ -13,7 +13,7 @@ from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.models.modeling_llama import LlamaRotaryEmbedding from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb -from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention +from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention, Qwen3MoeRotaryEmbedding from typing import Optional, Tuple from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_loader import GGUFLoader @@ -21,6 +21,7 @@ import logging from transformers.configuration_utils import PretrainedConfig from transformers.cache_utils import Cache +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor try: @@ -943,3 +944,142 @@ def forward( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output).to(input_dtype) return attn_output, attn_weights + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention ): + def __init__(self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + chunck_size: int = 1000, + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, + **kwargs) + self.orig_module.__init__(self.orig_module.config, + orig_module.layer_idx) + self.rotary_emb = Qwen3MoeRotaryEmbedding(config) + self.chunck_size = chunck_size # TODO, generate chunck_size automatically. + + # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb + def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def forward(self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.Tensor], + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs + ): + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if position_embeddings is None: + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + cos, sin = position_embeddings + + query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin) + + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + self.config._attn_implementation = "flash_attention_2" + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights diff --git a/KT-SFT/ktransformers/optimize/optimize_rules/Qwen3Moe-sft-amx.yaml b/KT-SFT/ktransformers/optimize/optimize_rules/Qwen3Moe-sft-amx.yaml new file mode 100644 index 00000000..9f723484 --- /dev/null +++ b/KT-SFT/ktransformers/optimize/optimize_rules/Qwen3Moe-sft-amx.yaml @@ -0,0 +1,103 @@ +- match: + class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.RotaryEmbedding + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^lm_head$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearTorch" + prefill_op: "KLinearTorch" + +# - match: +# name: "^model\\.layers\\..*$" # regular expression +# class: torch.nn.Linear # only match modules matching name and class simultaneously +# replace: +# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types +# kwargs: +# generate_device: "cuda" +# prefill_device: "cuda" +# generate_op: "KLinearTorch" +# prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearTorch" + prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock + replace: + class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2 # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KSFTExpertsCPU" + out_device: "cuda" + backend: "AMXInt8" # or "AMXBF16" or "AMXInt8" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.KQwen3MoeAttention # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KQwen2MoeModel" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + class: 
ktransformers.models.modeling_qwen3_moe.Qwen3MoeRMSNorm + replace: + class: ktransformers.operators.layernorm.KQwen3MoeRMSNorm + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeMLP + replace: + class: ktransformers.operators.mlp.KQwen2MoeMLP + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KQwen3MoeModel" + kwargs: + per_layer_prefill_intput_threshold: 0 From bde90afbb94501dcfb5004a2a5be7586ec670908 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 13 Nov 2025 09:21:21 +0800 Subject: [PATCH 2/3] fix KQwen3MoeSparseMoeBlock --- KT-SFT/ktransformers/operators/experts.py | 121 ++++++++++++++++++ .../optimize_rules/Qwen3Moe-sft-amx.yaml | 23 +--- 2 files changed, 124 insertions(+), 20 deletions(-) diff --git a/KT-SFT/ktransformers/operators/experts.py b/KT-SFT/ktransformers/operators/experts.py index d695aaf8..19bbd64f 100644 --- a/KT-SFT/ktransformers/operators/experts.py +++ b/KT-SFT/ktransformers/operators/experts.py @@ -2071,3 +2071,124 @@ def moe_infer(self, x, topk_ids, topk_weight): .type(new_x.dtype) ) return final_out + + +class KQwen3MoeSparseMoeBlock(BaseInjectedModule, Qwen3MoeSparseMoeBlock): + def forward(self, hidden_states): + + orig_shape = hidden_states.shape + sequence_length = orig_shape[1] + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + router_logits = self.gate(hidden_states) + + if router_logits.device.type == "xpu": + from ipex_llm.transformers.models.common import moe_softmax_topk + selected_experts, routing_weights = moe_softmax_topk( + router_logits.half(), self.top_k, self.norm_topk_prob + ) + else: + routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + # only for generate phase + if sequence_length == 1 and hasattr(self.experts.generate_experts, + "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug + self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0], + routing_weights[0]) + # y_ = self.shared_expert(hidden_states).squeeze(0) + # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ + + y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0) + + # y += y_ + y.resize_(*orig_shape) + return y + + # y_ = self.shared_expert(hidden_states).squeeze(0) + # y_ = ( + # F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ + # ) + + if isinstance(self.experts, KExpertsBase): + y = self.moe_kexperts(hidden_states, selected_experts, routing_weights).view(*orig_shape).to( + device=hidden_states.device) + elif hidden_states.size(0) > 10: + # TODO may bugs here + y = ( + self.moe_infer(hidden_states, selected_experts, routing_weights) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + else: + # TODO may bugs here + y = ( + self.moe_infer_simple(hidden_states, selected_experts, routing_weights) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + # y += y_ + return y + + @maybe_no_grad() + def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: + outs = self.experts(x, 
topk_ids, topk_weight) + return outs + + @maybe_no_grad() + # TODO may bugs here + def moe_infer_simple( + self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor + ) -> torch.Tensor: + """ + x: [num_tokens, hidden_size] + topk_ids, topk_weight: [num_tokens, num_selected_experts] + """ + outs = torch.zeros_like(x) + for token_idx in range(topk_ids.size(0)): + for expert_idx in range(topk_ids.size(1)): + expert = self.experts[topk_ids[token_idx, expert_idx]] + outs[token_idx] += ( + expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] + ) + return outs + + @maybe_no_grad() + # TODO may bugs here + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert.forward(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out diff --git a/KT-SFT/ktransformers/optimize/optimize_rules/Qwen3Moe-sft-amx.yaml b/KT-SFT/ktransformers/optimize/optimize_rules/Qwen3Moe-sft-amx.yaml index 9f723484..c3d98287 100644 --- a/KT-SFT/ktransformers/optimize/optimize_rules/Qwen3Moe-sft-amx.yaml +++ b/KT-SFT/ktransformers/optimize/optimize_rules/Qwen3Moe-sft-amx.yaml @@ -39,9 +39,8 @@ prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" - class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock replace: - class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2 # mlp module with custom forward function + class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlock # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" @@ -49,7 +48,7 @@ - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: - class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" @@ -79,25 +78,9 @@ generate_device: "cpu" prefill_device: "cpu" -- match: - class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeRMSNorm - replace: - class: ktransformers.operators.layernorm.KQwen3MoeRMSNorm - kwargs: - generate_device: "cuda" - prefill_device: "cuda" - -- match: - class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeMLP - replace: - class: ktransformers.operators.mlp.KQwen2MoeMLP - kwargs: - generate_device: "cuda" - prefill_device: "cuda" - - match: name: "^model$" replace: class: "ktransformers.operators.models.KQwen3MoeModel" kwargs: - per_layer_prefill_intput_threshold: 0 + per_layer_prefill_intput_threshold: 0 \ No newline at end of file From 3a3ed5b16a9d42d195f544f0e5b6f272c7ddadb6 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 13 Nov 
2025 15:26:43 +0800 Subject: [PATCH 3/3] fix bug adapter for llamafactory --- .../models/modeling_qwen3_moe.py | 16 +- KT-SFT/ktransformers/operators/attention.py | 2 - KT-SFT/ktransformers/operators/models.py | 378 ++++++++++++++++++ .../optimize_rules/Qwen3Moe-sft-amx.yaml | 6 - 4 files changed, 386 insertions(+), 16 deletions(-) diff --git a/KT-SFT/ktransformers/models/modeling_qwen3_moe.py b/KT-SFT/ktransformers/models/modeling_qwen3_moe.py index 9576b0ac..5881189f 100644 --- a/KT-SFT/ktransformers/models/modeling_qwen3_moe.py +++ b/KT-SFT/ktransformers/models/modeling_qwen3_moe.py @@ -206,14 +206,14 @@ def forward( key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - # if self.config._attn_implementation != "eager": - # if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - # logger.warning_once( - # "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - # 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - # ) - # else: - # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/KT-SFT/ktransformers/operators/attention.py b/KT-SFT/ktransformers/operators/attention.py index 8638e9f9..30e7de92 100644 --- a/KT-SFT/ktransformers/operators/attention.py +++ b/KT-SFT/ktransformers/operators/attention.py @@ -998,7 +998,6 @@ def __init__(self, **kwargs) self.orig_module.__init__(self.orig_module.config, orig_module.layer_idx) - self.rotary_emb = Qwen3MoeRotaryEmbedding(config) self.chunck_size = chunck_size # TODO, generate chunck_size automatically. 
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb @@ -1058,7 +1057,6 @@ def forward(self, key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - self.config._attn_implementation = "flash_attention_2" if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( diff --git a/KT-SFT/ktransformers/operators/models.py b/KT-SFT/ktransformers/operators/models.py index 5ce1136a..941f0fb9 100644 --- a/KT-SFT/ktransformers/operators/models.py +++ b/KT-SFT/ktransformers/operators/models.py @@ -47,7 +47,15 @@ Qwen2MoeSparseMoeBlock, Qwen2MoeMLP, Qwen2MoeDecoderLayer, + Qwen2MoeRotaryEmbedding, ) + +from ktransformers.models.modeling_qwen3_moe import ( + Qwen3MoeSparseMoeBlock, + Qwen3MoeMLP, + Qwen3MoeDecoderLayer, +) + from ktransformers.models.modeling_deepseek import ( BaseModelOutputWithPast, DeepseekV2DecoderLayer, @@ -1376,3 +1384,373 @@ def _update_causal_mask( ) return causal_mask + + + +QWEN3MOE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. 
If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + +class KQwen3MoeModel(BaseInjectedModule): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen3MoeDecoderLayer`] + + Args: + config: Qwen3MoeConfig + """ + + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + device: str = "cuda", + per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill + transfer_map: dict = None, + **kwargs, + ): + BaseInjectedModule.__init__( + self, key, gguf_loader, config, orig_module, device, **kwargs + ) + self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold + self.transfer_map = transfer_map + self.stream_device_map = dict() + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.rotary_emb = Qwen2MoeRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + @add_start_docstrings_to_model_forward(QWEN3MOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + per_layer_prefill_intput_threshold: ( + int | None + ) = None, # if None or 0, close per-layer prefill + ) -> Union[Tuple, MoeModelOutputWithPast]: + # print(f'Total length of input_ids: {input_ids.size(1)}, {input_ids.size()}') + + if per_layer_prefill_intput_threshold is None: + per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold + per_layer_prefill_flag = False + seq_lenth = ( + inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1) + ) + if ( + per_layer_prefill_intput_threshold + and per_layer_prefill_intput_threshold < seq_lenth + ): + per_layer_prefill_flag = True + for layer in self.layers: + self.load_layer_to(layer, InferenceState.UNLOAD) + else: + pass + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + # use_legacy_cache = False + # if use_cache and not isinstance(past_key_values, Cache): + # use_legacy_cache = True + # past_key_values = DynamicCache.from_legacy_cache(past_key_values) + # logger.warning_once( + # "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + # "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" + # ) + + if inputs_embeds is None: + input_ids = input_ids.to("cpu") + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = inputs_embeds.to("cuda") + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + hidden_states = inputs_embeds + + # position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + # next_decoder_cache = None + + for i, decoder_layer in enumerate(self.layers): + # if self.transfer_map is not None and i in self.transfer_map: + # prev_stream = torch.cuda.current_stream() + # cur_device = self.transfer_map[i] + # if cur_device not in self.stream_device_map: + # self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) + # torch.cuda.set_device(cur_device) + # self.stream_device_map[cur_device].wait_stream(prev_stream) + # torch.cuda.set_stream(self.stream_device_map[cur_device]) + # hidden_states = hidden_states.to( + # self.transfer_map[i], non_blocking=True + # ) + # causal_mask = ( + # causal_mask.to(self.transfer_map[i], non_blocking=True) + # if causal_mask is not None + # else None + # ) + # position_ids = ( + # position_ids.to(self.transfer_map[i], non_blocking=True) + # if position_ids is not None + # else None + # ) + # cache_position = ( + # cache_position.to(self.transfer_map[i], non_blocking=True) + # if cache_position is not None + # else None + # ) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + # position_embeddings, + ) + else: + if per_layer_prefill_flag: + # print(f"to gpu") + self.load_layer_to(decoder_layer, InferenceState.PREFILL) + torch.cuda.empty_cache() + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + # position_embeddings=position_embeddings, + ) + if per_layer_prefill_flag: + # print(f"to cpu") + self.load_layer_to(decoder_layer, InferenceState.UNLOAD) + torch.cuda.empty_cache() + hidden_states = layer_outputs[0] + # 
use_cache=False + # if use_cache: + # next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits and layer_outputs[-1] is not None: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + if per_layer_prefill_flag: + per_layer_prefill_flag = False + for layer in self.layers: + self.load_layer_to(layer, InferenceState.GENERATE) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # next_cache = None + # if use_cache: + # next_cache = ( + # next_decoder_cache.to_legacy_cache() + # if use_legacy_cache + # else next_decoder_cache + # ) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attns, + all_router_logits, + ] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + def load_layer_to(self, layer: Qwen3MoeDecoderLayer, target: InferenceState): + assert isinstance( + layer, Qwen3MoeDecoderLayer + ), "module should be nn.ModuleList of decoder layers" + + # TODO Support restore to original device, not only cuda + device = "cpu" if target == InferenceState.UNLOAD else "cuda" + + # attn + layer.self_attn.q_proj.set_inference_mode(target) + layer.self_attn.k_proj.set_inference_mode(target) + layer.self_attn.v_proj.set_inference_mode(target) + layer.self_attn.o_proj.set_inference_mode(target) + layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(device) + + # mlp + if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock): + layer.mlp.gate.set_inference_mode(target) + layer.mlp.experts.set_inference_mode(target) + layer.mlp.shared_expert.gate_proj.set_inference_mode(target) + layer.mlp.shared_expert.up_proj.set_inference_mode(target) + layer.mlp.shared_expert.down_proj.set_inference_mode(target) + layer.mlp.shared_expert.act_fn.to(device) + layer.mlp.shared_expert_gate.to(device) + else: + layer.mlp.gate_proj.set_inference_mode(target) + layer.mlp.up_proj.set_inference_mode(target) + layer.mlp.down_proj.set_inference_mode(target) + layer.mlp.act_fn.to(device) + # layer norm + layer.input_layernorm.to(device) + layer.post_attention_layernorm.to(device) diff --git a/KT-SFT/ktransformers/optimize/optimize_rules/Qwen3Moe-sft-amx.yaml b/KT-SFT/ktransformers/optimize/optimize_rules/Qwen3Moe-sft-amx.yaml index c3d98287..b8eceb27 100644 --- a/KT-SFT/ktransformers/optimize/optimize_rules/Qwen3Moe-sft-amx.yaml +++ b/KT-SFT/ktransformers/optimize/optimize_rules/Qwen3Moe-sft-amx.yaml @@ -64,12 +64,6 @@ kwargs: generate_device: "cuda" prefill_device: "cuda" -- match: - name: "^model$" - replace: - class: "ktransformers.operators.models.KQwen2MoeModel" - kwargs: - per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace:
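
The attention rewrite in PATCH 1/3 reduces to grouped-query attention: repeat_kv duplicates each key/value head n_rep times so that the query and KV head counts match before the scaled dot-product, and eager_attention_forward then applies the causal mask and softmax. The following standalone sketch (not part of the patches; the head counts are hypothetical, the real values come from the Qwen3-MoE config) walks through that shape flow:

# Standalone illustration (not part of the patches): the GQA shape flow behind the
# repeat_kv + eager_attention_forward helpers added in PATCH 1/3, with made-up sizes.
import torch
import torch.nn.functional as F


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seqlen, head_dim) -> (batch, num_key_value_heads * n_rep, seqlen, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)


batch, seqlen, head_dim = 1, 8, 128
num_heads, num_kv_heads = 32, 4            # hypothetical GQA layout
n_rep = num_heads // num_kv_heads          # 8 query heads share each KV head

q = torch.randn(batch, num_heads, seqlen, head_dim)
k = torch.randn(batch, num_kv_heads, seqlen, head_dim)
v = torch.randn(batch, num_kv_heads, seqlen, head_dim)

k_rep, v_rep = repeat_kv(k, n_rep), repeat_kv(v, n_rep)       # now 32 KV heads each

scaling = head_dim ** -0.5
attn = torch.matmul(q, k_rep.transpose(2, 3)) * scaling       # [1, 32, 8, 8]
causal_mask = torch.full((seqlen, seqlen), float("-inf")).triu(1)
attn = F.softmax(attn + causal_mask, dim=-1, dtype=torch.float32).to(q.dtype)
out = torch.matmul(attn, v_rep).transpose(1, 2).contiguous()  # [1, 8, 32, 128]
print(out.shape)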
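
PATCH 2/3's KQwen3MoeSparseMoeBlock keeps the standard Qwen3-MoE routing: softmax over the gate logits, top-k selection, optional renormalization of the selected weights, then a weighted sum of the chosen experts' outputs per token. A standalone sketch of that routing math follows (dense nn.Linear "experts" stand in for the injected CPU/AMX experts; all sizes are made up):

# Standalone illustration (not part of the patches): the top-k router math that
# KQwen3MoeSparseMoeBlock.forward performs before dispatching tokens to the experts.
import torch
import torch.nn as nn

num_tokens, hidden_size, num_experts, top_k = 5, 64, 8, 2
norm_topk_prob = True                                   # Qwen3-MoE renormalizes the k winners

gate = nn.Linear(hidden_size, num_experts, bias=False)
experts = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_experts))
hidden_states = torch.randn(num_tokens, hidden_size)

with torch.no_grad():                                   # mirrors the @maybe_no_grad() inference path
    router_logits = gate(hidden_states)                 # [num_tokens, num_experts]
    routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    if norm_topk_prob:
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(hidden_states.dtype)

    # Weighted sum of each token's top-k expert outputs, as in moe_infer_simple.
    out = torch.zeros_like(hidden_states)
    for t in range(num_tokens):
        for j in range(top_k):
            expert = experts[int(selected_experts[t, j])]
            out[t] += expert(hidden_states[t]) * routing_weights[t, j]

print(out.shape)                                        # torch.Size([5, 64])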