diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 866aa536f19..e19b55c00b7 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -563,6 +563,10 @@ def prepare_tensors(self): gguf.MODEL_TENSOR.A_ENC_EMBD_POS, gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF, gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF, + # Kimi KDA conv weights should be F32 + gguf.MODEL_TENSOR.SSM_CONV1D_Q, + gguf.MODEL_TENSOR.SSM_CONV1D_K, + gguf.MODEL_TENSOR.SSM_CONV1D_V, ) ) or new_name[-7:] not in (".weight", ".lora_a", ".lora_b") @@ -2722,6 +2726,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] # skip other tensors +# KimiLinearModel is defined later in this file (line ~5140) as a TextModel subclass +# This old definition has been removed to avoid conflicts + + @ModelBase.register( "Llama4ForConditionalGeneration", "Llama4ForCausalLM", @@ -5108,8 +5116,298 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k), (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v), ] + + +@ModelBase.register("KimiLinearModel", "KimiLinearForCausalLM") +class KimiLinearModel(TextModel): + """Kimi-Linear model with hybrid MLA+KDA architecture""" + model_arch = gguf.MODEL_ARCH.KIMI_LINEAR + + _experts: list[dict[str, Tensor]] | None = None + + def set_gguf_parameters(self): + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + + # Use find_hparam for context length + # Kimi uses model_max_length + n_ctx = self.find_hparam(["max_position_embeddings", "model_max_length", "n_ctx", "n_positions"], optional=True) + if n_ctx is not None: + self.gguf_writer.add_context_length(n_ctx) else: - return [(self.map_tensor_name(name), data_torch)] + # Default to 4096 if not found + logger.warning("No context length found in config, defaulting to 4096") + self.gguf_writer.add_context_length(4096) + + self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + self.gguf_writer.add_file_type(self.ftype) + + # KDA & MLA params + # Get ssm_d_conv from linear_attn_config.short_conv_kernel_size or ssm_d_conv + linear_attn_config = self.hparams.get("linear_attn_config", {}) + ssm_d_conv = self.hparams.get("ssm_d_conv") or linear_attn_config.get("short_conv_kernel_size") + if ssm_d_conv is not None: + self.gguf_writer.add_ssm_conv_kernel(ssm_d_conv) + + # MLA params - use add_* methods that handle arch substitution + # Support both HuggingFace naming (q_lora_rank, kv_lora_rank) and internal naming (n_lora_q, n_lora_kv) + q_lora_rank = self.hparams.get("q_lora_rank", self.hparams.get("n_lora_q")) + kv_lora_rank = self.hparams.get("kv_lora_rank", self.hparams.get("n_lora_kv")) + + if q_lora_rank is not None: + self.gguf_writer.add_q_lora_rank(q_lora_rank) + if kv_lora_rank is not None: + self.gguf_writer.add_kv_lora_rank(kv_lora_rank) + + # MLA head dimensions + # Support HuggingFace naming: qk_nope_head_dim, qk_rope_head_dim, v_head_dim + qk_nope_head_dim = self.hparams.get("qk_nope_head_dim") + qk_rope_head_dim = self.hparams.get("qk_rope_head_dim", self.hparams.get("n_rot")) + v_head_dim = self.hparams.get("v_head_dim") + + # Calculate 
n_embd_head_k_mla = qk_nope_head_dim + qk_rope_head_dim + if "n_embd_head_k_mla" in self.hparams: + self.gguf_writer.add_key_length_mla(self.hparams["n_embd_head_k_mla"]) + elif qk_nope_head_dim is not None and qk_rope_head_dim is not None: + n_embd_head_k_mla = qk_nope_head_dim + qk_rope_head_dim + self.gguf_writer.add_key_length_mla(n_embd_head_k_mla) + + # n_embd_head_v_mla = v_head_dim + if "n_embd_head_v_mla" in self.hparams: + self.gguf_writer.add_value_length_mla(self.hparams["n_embd_head_v_mla"]) + elif v_head_dim is not None: + self.gguf_writer.add_value_length_mla(v_head_dim) + + # Rotation - use qk_rope_head_dim for Kimi + rope_dim = self.hparams.get("qk_rope_head_dim") or self.hparams.get("n_rot") + if rope_dim is not None: + self.gguf_writer.add_rope_dimension_count(rope_dim) + else: + # Default to head_dim + head_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(head_dim) + + self.gguf_writer.add_rope_freq_base(self.hparams.get("rope_theta", 10000.0)) + + # MoE params + n_experts = self.hparams.get("num_local_experts", self.hparams.get("num_experts")) + if n_experts is not None: + self.gguf_writer.add_expert_count(n_experts) + # Support both num_experts_per_tok and num_experts_per_token + n_experts_used = self.hparams.get("num_experts_per_tok", self.hparams.get("num_experts_per_token")) + if n_experts_used is not None: + self.gguf_writer.add_expert_used_count(n_experts_used) + + # moe_intermediate_size (1024 for Kimi) + moe_intermediate_size = self.hparams.get("moe_intermediate_size") + if moe_intermediate_size is not None: + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + + # num_shared_experts (1 for Kimi) + num_shared_experts = self.hparams.get("num_shared_experts") + if num_shared_experts is not None: + self.gguf_writer.add_expert_shared_count(num_shared_experts) + + # first_k_dense_replace (1 for Kimi - first layer uses dense MLP) + first_k_dense_replace = self.hparams.get("first_k_dense_replace") + if first_k_dense_replace is not None: + self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace) + + # Expert gating function (sigmoid for Kimi) + moe_router_activation_func = self.hparams.get("moe_router_activation_func", "sigmoid") + if moe_router_activation_func == "sigmoid": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + elif moe_router_activation_func == "softmax": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) + else: + logger.warning(f"Unknown moe_router_activation_func: {moe_router_activation_func}, defaulting to sigmoid") + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + + # Routed scaling factor (expert_weights_scale = 2.446 for Kimi) + routed_scaling_factor = self.hparams.get("routed_scaling_factor") + if routed_scaling_factor is not None: + self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) + + def set_vocab(self): + # Kimi uses TikToken tokenizer - load via transformers + from transformers import AutoTokenizer + + dir_model = self.dir_model + vocab_size = self.hparams["vocab_size"] + + logger.info(f"Loading TikToken tokenizer from {dir_model}") + tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) + + tokens: list[str] = [] + toktypes: list[int] = [] + + # Get tokenizer pre string + tokpre = self.get_vocab_base_pre(tokenizer) + + # Build vocab from tokenizer + merges = [] + vocab = {} + + # TikToken stores vocab in mergeable_ranks + 
if hasattr(tokenizer, 'mergeable_ranks'): + mergeable_ranks = tokenizer.mergeable_ranks + for token, rank in mergeable_ranks.items(): + vocab[self._token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + # Build merges + merged = self._bpe(mergeable_ranks, token, max_rank=rank) + if len(merged) == 2: + merges.append(' '.join(map(self._token_bytes_to_string, merged))) + else: + # Fallback: get vocab directly + vocab = {tok: idx for tok, idx in tokenizer.get_vocab().items()} + + # Get special tokens + added_vocab = {} + if hasattr(tokenizer, 'special_tokens'): + added_vocab = tokenizer.special_tokens + elif hasattr(tokenizer, 'added_tokens_encoder'): + added_vocab = tokenizer.added_tokens_encoder + + # Combine vocab + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()} + + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + elif i in added_vocab.values() if added_vocab else False: + tokens.append(reverse_vocab[i]) + toktypes.append(gguf.TokenType.CONTROL) + else: + tokens.append(reverse_vocab[i]) + toktypes.append(gguf.TokenType.NORMAL) + + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(dir_model, load_merges=False) + special_vocab.merges = merges + special_vocab.add_to_gguf(self.gguf_writer) + logger.info(f"Loaded {len(tokens)} tokens, {len(merges)} merges") + + @staticmethod + def _token_bytes_to_string(b: bytes) -> str: + """Convert bytes to string representation for tokenizer""" + return ''.join([chr(byte) if byte < 128 else f'<0x{byte:02X}>' for byte in b]) + + @staticmethod + def _bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]: + """BPE tokenization for merges extraction""" + parts = [bytes([b]) for b in token] + while True: + min_idx = None + min_rank = None + for i, pair in enumerate(zip(parts[:-1], parts[1:])): + rank = mergeable_ranks.get(pair[0] + pair[1]) + if rank is not None and (min_rank is None or rank < min_rank): + min_idx = i + min_rank = rank + if min_rank is None or (max_rank is not None and min_rank >= max_rank): + break + parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:] + return parts + + def prepare_tensors(self): + super().prepare_tensors() + if self._experts is not None: + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + logger.info(f"Processing {name}: shape before = {tuple(data_torch.shape)}") + + # Handle KDA conv1d weights + # HuggingFace/vLLM stores as [d_inner, d_conv] (2D), memory layout: conv_step changes fastest + # llama.cpp expects ggml ne = [d_conv, 1, d_inner, 1], memory layout: ne[0]=d_conv changes fastest + # GGUF reverses numpy shape when writing, so numpy (1, d_inner, 1, d_conv) -> ggml ne = [d_conv, 1, d_inner, 1] + # Memory layouts match: both have conv_step (d_conv) changing fastest + if name.endswith((".q_conv1d.weight", ".k_conv1d.weight", ".v_conv1d.weight")): + # HF shape: [d_inner, d_conv] e.g. 
[4096, 4] + # Target numpy shape: (1, d_inner, 1, d_conv) -> ggml ne = [d_conv, 1, d_inner, 1] + if data_torch.ndim == 2: + d_inner, d_conv = data_torch.shape + # Reshape to (1, d_inner, 1, d_conv) - memory layout preserved (d_conv fastest) + data_torch = data_torch.reshape(1, d_inner, 1, d_conv) + logger.info(f"Reshaped conv1d weight {name}: [d_inner={d_inner}, d_conv={d_conv}] -> numpy {tuple(data_torch.shape)} -> ggml ne=[{d_conv}, 1, {d_inner}, 1]") + elif data_torch.ndim == 3: + # Already 3D [d_inner, 1, d_conv] from unsqueeze + d_inner, _, d_conv = data_torch.shape + data_torch = data_torch.reshape(1, d_inner, 1, d_conv) + logger.info(f"Reshaped conv1d weight {name}: [d_inner={d_inner}, 1, d_conv={d_conv}] -> numpy {tuple(data_torch.shape)} -> ggml ne=[{d_conv}, 1, {d_inner}, 1]") + + # Handle A_log: HF stores as [1, 1, num_heads, 1] + # llama.cpp expects ggml ne = [1, num_heads, 1, 1] + # GGUF reverses numpy shape: numpy (1, 1, num_heads, 1) -> ggml ne = [1, num_heads, 1, 1] + # So no transformation needed! The shapes already match after GGUF reversal. + if name.endswith(".A_log"): + if data_torch.ndim == 4: + logger.info(f"A_log {name}: numpy {tuple(data_torch.shape)} -> ggml ne={list(reversed(data_torch.shape))}") + + # Kimi specific bias + if name.endswith("block_sparse_moe.gate.e_score_correction_bias"): + new_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_EXP_PROBS_B, bid) + return [(new_name, data_torch)] + + # process the experts separately + if name.find("block_sparse_moe.experts") != -1: + n_experts = self.hparams.get("num_local_experts", self.hparams.get("num_experts")) + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + # merge the experts into a single 3d tensor + tensors = [] + # w1: gate, w2: down, w3: up + for wid, tname in [("w1", gguf.MODEL_TENSOR.FFN_GATE_EXP), + ("w2", gguf.MODEL_TENSOR.FFN_DOWN_EXP), + ("w3", gguf.MODEL_TENSOR.FFN_UP_EXP)]: + datas: list[Tensor] = [] + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + new_name = self.format_tensor_name(tname, bid) + tensors.append((new_name, data_torch)) + return tensors + return [] + + mapped_name = self.map_tensor_name(name) + logger.info(f"Returning {mapped_name}: shape after = {tuple(data_torch.shape)}") + return [(mapped_name, data_torch)] + + def get_vocab_base(self) -> tuple[list[str], list[int], str]: + # This method is not used when set_vocab is overridden + # But adding it for completeness in case it's called elsewhere + logger.warning("get_vocab_base called, but set_vocab is already overridden") + vocab_size = self.hparams.get("vocab_size", 100) + tokens = [f"" for i in range(vocab_size)] + tokens[0] = "" + tokens[1] = "" + tokens[2] = "" + toktypes = [gguf.TokenType.NORMAL] * vocab_size + return tokens, toktypes, "gpt-2" + + # Note: set_vocab() is defined earlier in this class (around line 5144) @ModelBase.register("InternLM3ForCausalLM") diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 4dbca868bc7..3262c0e31ba 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -539,6 +539,7 @@ extern "C" { GGML_OP_FLASH_ATTN_BACK, GGML_OP_SSM_CONV, GGML_OP_SSM_SCAN, + GGML_OP_KDA_SCAN, GGML_OP_WIN_PART, GGML_OP_WIN_UNPART, GGML_OP_GET_REL_POS, @@ -2336,6 +2337,28 @@ 
extern "C" { struct ggml_tensor * C, struct ggml_tensor * ids); + // KDA (Kimi Delta Attention) scan + // Delta attention recurrence: + // h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t])) + // o[t] = q[t]^T @ h[t] + // Parameters: + // h: hidden state {head_dim, head_dim, n_head, n_seqs+} + // q: query {head_dim, n_head, n_seq_tokens, n_seqs} + // k: key {head_dim, n_head, n_seq_tokens, n_seqs} + // v: value {head_dim, n_head, n_seq_tokens, n_seqs} + // g: gate {head_dim, n_head, n_seq_tokens, n_seqs} + // beta: mixing {n_head, n_seq_tokens, n_seqs} + // ids: seq indices {n_seqs} + GGML_API struct ggml_tensor * ggml_kda_scan( + struct ggml_context * ctx, + struct ggml_tensor * h, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * ids); + // partition into non-overlapping windows with padding if needed // example: // a: 768 64 64 1 diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 3247af8bb03..7b40f1e8c2c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1962,6 +1962,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_ssm_scan(params, tensor); } break; + case GGML_OP_KDA_SCAN: + { + ggml_compute_forward_kda_scan(params, tensor); + } break; case GGML_OP_WIN_PART: { ggml_compute_forward_win_part(params, tensor); @@ -2320,6 +2324,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: + case GGML_OP_KDA_SCAN: case GGML_OP_RWKV_WKV6: case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_RWKV_WKV7: diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 2745fc54e15..32d492ca240 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8627,6 +8627,9 @@ static void ggml_compute_forward_ssm_conv_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; + static int conv_debug_count = 0; + bool do_conv_debug = false; // (ith == 0 && conv_debug_count++ < 3); + for (int i3 = 0; i3 < n_s; ++i3) { for (int i2 = 0; i2 < n_t; ++i2) { // {d_conv - 1 + n_t, d_inner, n_seqs} @@ -8647,6 +8650,13 @@ static void ggml_compute_forward_ssm_conv_f32( sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; } x[i1] = sumf; + + // Debug output + if (do_conv_debug && i1 == 0 && i2 == 0 && i3 == 0) { + fprintf(stderr, "DEBUG SSM_CONV: nc=%d, nr=%d, n_t=%d, n_s=%d\n", nc, nr, n_t, n_s); + fprintf(stderr, "DEBUG SSM_CONV: s[0..3]=%f,%f,%f,%f, c[0..3]=%f,%f,%f,%f, x[0]=%f\n", + s[0], s[1], s[2], s[3], c[0], c[1], c[2], c[3], x[0]); + } } } } @@ -8897,6 +8907,192 @@ void ggml_compute_forward_ssm_scan( } } +// ggml_compute_forward_kda_scan +// KDA (Kimi Delta Attention) recurrence: +// h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t])) +// o[t] = q[t]^T @ h[t] + +static void ggml_compute_forward_kda_scan_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // h {head_dim, head_dim, n_head, n_seqs+} + const ggml_tensor * src1 = dst->src[1]; // q {head_dim, n_head, n_seq_tokens, n_seqs} + const ggml_tensor * src2 = dst->src[2]; // k {head_dim, n_head, n_seq_tokens, n_seqs} + const ggml_tensor * src3 = dst->src[3]; // v {head_dim, n_head, n_seq_tokens, n_seqs} + const ggml_tensor * src4 = dst->src[4]; // g {head_dim, n_head, n_seq_tokens, n_seqs} + const ggml_tensor * src5 = dst->src[5]; // beta 
{n_head, n_seq_tokens, n_seqs} + const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs} + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t head_dim = src0->ne[0]; + const int64_t n_head = src1->ne[1]; + const int64_t n_seq_tokens = src1->ne[2]; + const int64_t n_seqs = src1->ne[3]; + + // Output offset for hidden state + const int64_t y_off = ggml_nelements(src1) * sizeof(float); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src2->nb[0] == sizeof(float)); + GGML_ASSERT(src3->nb[0] == sizeof(float)); + GGML_ASSERT(src4->nb[0] == sizeof(float)); + GGML_ASSERT(src5->nb[0] == sizeof(float)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); + + // Parallelize over heads + const int dh = (n_head + nth - 1) / nth; + const int ih0 = dh * ith; + const int ih1 = MIN(ih0 + dh, (int)n_head); + + const int32_t * ids = (const int32_t *) src6->data; + + // Temporary buffer for h @ k computation + float * hk_buf = (float *) malloc(head_dim * sizeof(float)); + + static int debug_count = 0; + bool do_debug = false; // (ith == 0 && debug_count++ < 20); + + for (int i3 = 0; i3 < n_seqs; ++i3) { + // Get initial hidden state for this sequence + const float * h0 = (const float *) ((const char *) src0->data + ids[i3] * src0->nb[3]); + // Output hidden state location + float * h_out = (float *) ((char *) dst->data + i3 * src0->nb[3] + y_off); + + for (int ih = ih0; ih < ih1; ++ih) { + // Per-head hidden state: [head_dim, head_dim] + // Copy initial state to output (will be updated in place) + const float * h_in = h0 + ih * head_dim * head_dim; + float * h = h_out + ih * head_dim * head_dim; + + // Copy initial state, but check for invalid values and clear if needed + bool need_clear = false; + for (int i = 0; i < head_dim * head_dim && !need_clear; ++i) { + if (!isfinite(h_in[i]) || fabsf(h_in[i]) > 1e6f) { + need_clear = true; + } + } + for (int i = 0; i < head_dim * head_dim; ++i) { + h[i] = need_clear ? 
0.0f : h_in[i]; + } + + for (int it = 0; it < n_seq_tokens; ++it) { + const float * q_raw = (const float *) ((const char *) src1->data + + it * src1->nb[2] + i3 * src1->nb[3]) + ih * head_dim; + const float * k_raw = (const float *) ((const char *) src2->data + + it * src2->nb[2] + i3 * src2->nb[3]) + ih * head_dim; + const float * v = (const float *) ((const char *) src3->data + + it * src3->nb[2] + i3 * src3->nb[3]) + ih * head_dim; + const float * g = (const float *) ((const char *) src4->data + + it * src4->nb[2] + i3 * src4->nb[3]) + ih * head_dim; + const float beta = ((const float *) ((const char *) src5->data + + it * src5->nb[1] + i3 * src5->nb[2]))[ih]; + + float * y = (float *) dst->data + + it * n_head * head_dim + i3 * n_seq_tokens * n_head * head_dim + ih * head_dim; + + // L2 normalize q and k (critical for KDA stability) + float q_norm = 0.0f, k_norm = 0.0f; + for (int i = 0; i < head_dim; ++i) { + q_norm += q_raw[i] * q_raw[i]; + k_norm += k_raw[i] * k_raw[i]; + } + q_norm = sqrtf(q_norm + 1e-6f); + k_norm = sqrtf(k_norm + 1e-6f); + + // Debug output + if (do_debug && ih == 0 && it == 0 && i3 == 0) { + fprintf(stderr, "DEBUG KDA: q_raw[0]=%f, k_raw[0]=%f, v[0]=%f, g[0]=%f, beta=%f\n", + q_raw[0], k_raw[0], v[0], g[0], beta); + fprintf(stderr, "DEBUG KDA: q_norm=%f, k_norm=%f, exp(g[0])=%f, scale=%f\n", + q_norm, k_norm, expf(g[0]), 1.0f / sqrtf((float)head_dim)); + } + + // Normalized q and k with scale = 1/sqrt(head_dim) + // Note: scale is applied only to q after L2 normalization + const float scale = 1.0f / sqrtf((float)head_dim); + float q[128], k[128]; // assume head_dim <= 128 + for (int i = 0; i < head_dim; ++i) { + // L2 normalize then scale q + q[i] = (q_raw[i] / q_norm) * scale; + // L2 normalize k (no scale) + k[i] = k_raw[i] / k_norm; + } + + // KDA recurrence: h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t])) + // Note: Apply decay first, then compute retrieval and update + + // Step 1: Apply decay to h first: h = h * exp(g) + for (int i = 0; i < head_dim; ++i) { + const float exp_gi = expf(g[i]); + for (int j = 0; j < head_dim; ++j) { + h[i * head_dim + j] *= exp_gi; + } + } + + // Step 2: Compute h^T @ k -> hk_buf [head_dim] + // hk_buf[j] = sum_i (h[i,j] * k[i]) which is column j of h dotted with k + for (int j = 0; j < head_dim; ++j) { + float sum = 0.0f; + for (int i = 0; i < head_dim; ++i) { + sum += h[i * head_dim + j] * k[i]; + } + hk_buf[j] = sum; + } + + // Step 3: Compute delta = beta * (v - hk) and update h + // h = h + outer(k, delta) where outer(k,delta)[i,j] = k[i] * delta[j] + for (int i = 0; i < head_dim; ++i) { + for (int j = 0; j < head_dim; ++j) { + const float delta_j = beta * (v[j] - hk_buf[j]); + h[i * head_dim + j] += k[i] * delta_j; + } + } + + // Step 4: Compute output y = h^T @ q -> [head_dim] + // vLLM: b_o = tl.sum(b_h * b_q[:, None], 0) means o[j] = sum_i(h[i,j] * q[i]) + for (int j = 0; j < head_dim; ++j) { + float sum = 0.0f; + for (int i = 0; i < head_dim; ++i) { + sum += h[i * head_dim + j] * q[i]; + } + y[j] = sum; + } + + // Debug output + if (do_debug && ih == 0 && it == 0 && i3 == 0) { + // Find max abs value in h for stability check + float h_max = 0.0f; + for (int i = 0; i < head_dim * head_dim; i++) { + if (fabsf(h[i]) > h_max) h_max = fabsf(h[i]); + } + fprintf(stderr, "DEBUG KDA: y[0]=%.6f, h_max=%.6f, exp(g[0])=%.6f\n", + y[0], h_max, expf(g[0])); + } + } + } + } + + free(hk_buf); +} + +void ggml_compute_forward_kda_scan( + const ggml_compute_params * params, + ggml_tensor * dst) { + switch 
(dst->src[0]->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_kda_scan_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_win_part static void ggml_compute_forward_win_part_f32( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 0fdfee79766..080cf6e090e 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -92,6 +92,7 @@ void ggml_compute_forward_flash_attn_back( struct ggml_tensor * dst); void ggml_compute_forward_ssm_conv(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_kda_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fa7e1e13a71..1b85c0c325d 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -41,6 +41,7 @@ #include "ggml-cuda/softmax.cuh" #include "ggml-cuda/ssm-conv.cuh" #include "ggml-cuda/ssm-scan.cuh" +#include "ggml-cuda/kda-scan.cuh" #include "ggml-cuda/sum.cuh" #include "ggml-cuda/sumrows.cuh" #include "ggml-cuda/mean.cuh" @@ -2692,6 +2693,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SSM_SCAN: ggml_cuda_op_ssm_scan(ctx, dst); break; + case GGML_OP_KDA_SCAN: + ggml_cuda_op_kda_scan(ctx, dst); + break; case GGML_OP_ARGSORT: ggml_cuda_op_argsort(ctx, dst); break; @@ -4503,6 +4507,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1; } } + case GGML_OP_KDA_SCAN: { + // KDA scan kernel supports head_dim 64 or 128 + const int64_t head_dim = op->src[0]->ne[0]; + return head_dim == 64 || head_dim == 128; + } case GGML_OP_SSM_CONV: { // assumes d_inner % threads == 0 return op->src[0]->ne[1] % 128 == 0; diff --git a/ggml/src/ggml-cuda/kda-scan.cu b/ggml/src/ggml-cuda/kda-scan.cu new file mode 100644 index 00000000000..5763f1cc90a --- /dev/null +++ b/ggml/src/ggml-cuda/kda-scan.cu @@ -0,0 +1,209 @@ +#include "kda-scan.cuh" + +// KDA (Kimi Delta Attention) scan CUDA kernel +// Recurrence: +// h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t])) +// o[t] = q[t]^T @ h[t] +// +// This kernel uses global memory for the hidden state to avoid shared memory limits. +// Each block processes one head for one sequence. 
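// [Reference sketch, not part of this patch] The comments above state the per-token
// KDA update; the function below is a minimal host-side, scalar restatement of that
// math for a single head, added only for illustration. The name kda_step_reference
// and the row-major h[i*head_dim + j] layout are assumptions; q is expected to be
// L2-normalized and scaled, k L2-normalized, matching the kernel below. Uses expf
// from <math.h>.
static void kda_step_reference(
        float * h,        // [head_dim, head_dim] recurrent state, updated in place
        const float * q,  // [head_dim] normalized + scaled query
        const float * k,  // [head_dim] normalized key
        const float * v,  // [head_dim] value
        const float * g,  // [head_dim] per-row log-decay
        float beta,       // mixing coefficient for this head/token
        float * y,        // [head_dim] output
        int head_dim) {
    // decay: h = diag(exp(g)) * h
    for (int i = 0; i < head_dim; ++i) {
        const float decay = expf(g[i]);
        for (int j = 0; j < head_dim; ++j) {
            h[i*head_dim + j] *= decay;
        }
    }
    // delta-rule update and output, column by column:
    // delta[j] = beta * (v[j] - sum_i h[i][j]*k[i]); h[i][j] += k[i]*delta[j]; y[j] = sum_i h[i][j]*q[i]
    for (int j = 0; j < head_dim; ++j) {
        float hk = 0.0f;
        for (int i = 0; i < head_dim; ++i) {
            hk += h[i*head_dim + j] * k[i];
        }
        const float delta = beta * (v[j] - hk);
        float o = 0.0f;
        for (int i = 0; i < head_dim; ++i) {
            h[i*head_dim + j] += k[i] * delta;   // rank-1 update of column j
            o += h[i*head_dim + j] * q[i];       // output uses the updated state
        }
        y[j] = o;
    }
}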
+ +__global__ void kda_scan_f32_kernel( + const float * __restrict__ src0, // h: [head_dim, head_dim, n_head, n_seqs+] + const float * __restrict__ src1, // q: [head_dim, n_head, n_seq_tokens, n_seqs] + const float * __restrict__ src2, // k: [head_dim, n_head, n_seq_tokens, n_seqs] + const float * __restrict__ src3, // v: [head_dim, n_head, n_seq_tokens, n_seqs] + const float * __restrict__ src4, // g: [head_dim, n_head, n_seq_tokens, n_seqs] + const float * __restrict__ src5, // beta: [n_head, n_seq_tokens, n_seqs] + const int32_t * __restrict__ src6, // ids: [n_seqs] + float * __restrict__ dst, + const int64_t head_dim, + const int64_t n_head, + const int64_t n_seq_tokens, + const int64_t n_seqs, + const int64_t y_off) // offset to state output in dst (in floats) +{ + // Each block handles one head for one sequence + const int seq_idx = blockIdx.x / n_head; + const int head_idx = blockIdx.x % n_head; + const int tid = threadIdx.x; + const int n_threads = blockDim.x; + + if (seq_idx >= n_seqs || head_idx >= n_head) return; + + // Get sequence ID for initial state + const int src_seq = src6[seq_idx]; + + // Shared memory for temporary buffers + extern __shared__ float smem[]; + float * hk_buf = smem; // [head_dim] - h @ k buffer + float * q_norm = smem + head_dim; // [head_dim] - normalized q + float * k_norm = q_norm + head_dim; // [head_dim] - normalized k + float * warp_sums = k_norm + head_dim; // [64] - for reductions + + // Pointers to input/output data for this head + const int64_t h_stride_head = head_dim * head_dim; + const int64_t h_stride_seq = h_stride_head * n_head; + const int64_t qkv_stride_head = head_dim; + const int64_t qkv_stride_token = head_dim * n_head; + const int64_t qkv_stride_seq = qkv_stride_token * n_seq_tokens; + const int64_t beta_stride_token = n_head; + const int64_t beta_stride_seq = beta_stride_token * n_seq_tokens; + + const float * h_in = src0 + src_seq * h_stride_seq + head_idx * h_stride_head; + float * h_out = dst + y_off + seq_idx * h_stride_seq + head_idx * h_stride_head; + float * y_out = dst + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head; + + // Copy initial state to output (we'll update in place) + for (int i = tid; i < head_dim * head_dim; i += n_threads) { + float val = h_in[i]; + if (!isfinite(val) || fabsf(val) > 1e6f) { + val = 0.0f; + } + h_out[i] = val; + } + __syncthreads(); + + const float scale = 1.0f / sqrtf((float)head_dim); + + // Process each token sequentially + for (int t = 0; t < n_seq_tokens; ++t) { + const float * q_raw = src1 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head; + const float * k_raw = src2 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head; + const float * v = src3 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head; + const float * g = src4 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head; + const float beta = src5[t * beta_stride_token + seq_idx * beta_stride_seq + head_idx]; + float * y = y_out + t * qkv_stride_token; + + // Step 1: L2 normalize q and k + float q_sq_sum = 0.0f, k_sq_sum = 0.0f; + for (int i = tid; i < head_dim; i += n_threads) { + q_sq_sum += q_raw[i] * q_raw[i]; + k_sq_sum += k_raw[i] * k_raw[i]; + } + + // Warp reduction + for (int offset = warpSize/2; offset > 0; offset /= 2) { + q_sq_sum += __shfl_down_sync(0xffffffff, q_sq_sum, offset); + k_sq_sum += __shfl_down_sync(0xffffffff, k_sq_sum, offset); + } + + // Cross-warp reduction + int warp_id = tid / warpSize; + int 
lane_id = tid % warpSize; + if (lane_id == 0 && warp_id < 32) { + warp_sums[warp_id] = q_sq_sum; + warp_sums[32 + warp_id] = k_sq_sum; + } + __syncthreads(); + + if (tid == 0) { + float total_q = 0.0f, total_k = 0.0f; + for (int i = 0; i < (n_threads + warpSize - 1) / warpSize; ++i) { + total_q += warp_sums[i]; + total_k += warp_sums[32 + i]; + } + warp_sums[0] = rsqrtf(total_q + 1e-6f) * scale; + warp_sums[1] = rsqrtf(total_k + 1e-6f); + } + __syncthreads(); + + float q_norm_factor = warp_sums[0]; + float k_norm_factor = warp_sums[1]; + + // Store normalized q and k + for (int i = tid; i < head_dim; i += n_threads) { + q_norm[i] = q_raw[i] * q_norm_factor; + k_norm[i] = k_raw[i] * k_norm_factor; + } + __syncthreads(); + + // KDA recurrence: h[t] = exp(g[t]) * h[t-1] + k[t]^T * (beta[t] * (v[t] - h[t-1] @ k[t])) + // Apply decay first, then compute retrieval and update + + // Step 2: Apply decay to h: h = h * exp(g) + for (int idx = tid; idx < head_dim * head_dim; idx += n_threads) { + int i = idx / head_dim; + float exp_gi = expf(g[i]); + h_out[idx] *= exp_gi; + } + __syncthreads(); + + // Step 3: Compute h^T @ k -> hk_buf + for (int j = tid; j < head_dim; j += n_threads) { + float sum = 0.0f; + for (int i = 0; i < head_dim; ++i) { + sum += h_out[i * head_dim + j] * k_norm[i]; + } + hk_buf[j] = sum; + } + __syncthreads(); + + // Step 4: Update h: h = h + outer(k, beta * (v - hk)) + for (int idx = tid; idx < head_dim * head_dim; idx += n_threads) { + int i = idx / head_dim; + int j = idx % head_dim; + float delta_j = beta * (v[j] - hk_buf[j]); + h_out[idx] += k_norm[i] * delta_j; + } + __syncthreads(); + + // Step 5: Compute output y = h^T @ q + for (int j = tid; j < head_dim; j += n_threads) { + float sum = 0.0f; + for (int i = 0; i < head_dim; ++i) { + sum += h_out[i * head_dim + j] * q_norm[i]; + } + y[j] = sum; + } + __syncthreads(); + } +} + +void ggml_cuda_op_kda_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // h + const ggml_tensor * src1 = dst->src[1]; // q + const ggml_tensor * src2 = dst->src[2]; // k + const ggml_tensor * src3 = dst->src[3]; // v + const ggml_tensor * src4 = dst->src[4]; // g + const ggml_tensor * src5 = dst->src[5]; // beta + const ggml_tensor * src6 = dst->src[6]; // ids + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src3->type == GGML_TYPE_F32); + GGML_ASSERT(src4->type == GGML_TYPE_F32); + GGML_ASSERT(src5->type == GGML_TYPE_F32); + GGML_ASSERT(src6->type == GGML_TYPE_I32); + + const int64_t head_dim = src0->ne[0]; + const int64_t n_head = src1->ne[1]; + const int64_t n_seq_tokens = src1->ne[2]; + const int64_t n_seqs = src1->ne[3]; + + // Output offset for hidden state (after attention output) - in floats + const int64_t y_off = ggml_nelements(src1); + + const float * h_d = (const float *)src0->data; + const float * q_d = (const float *)src1->data; + const float * k_d = (const float *)src2->data; + const float * v_d = (const float *)src3->data; + const float * g_d = (const float *)src4->data; + const float * beta_d = (const float *)src5->data; + const int32_t * ids_d = (const int32_t *)src6->data; + float * dst_d = (float *)dst->data; + + cudaStream_t stream = ctx.stream(); + + // Launch kernel: one block per (sequence, head) pair + const int n_blocks = n_seqs * n_head; + const int n_threads = 128; + + // Shared memory: hk_buf[head_dim] + q_norm[head_dim] + k_norm[head_dim] + warp_sums[64] + size_t 
smem_size = (3 * head_dim + 64) * sizeof(float); + + kda_scan_f32_kernel<<<n_blocks, n_threads, smem_size, stream>>>( + h_d, q_d, k_d, v_d, g_d, beta_d, ids_d, dst_d, + head_dim, n_head, n_seq_tokens, n_seqs, y_off); +} diff --git a/ggml/src/ggml-cuda/kda-scan.cuh b/ggml/src/ggml-cuda/kda-scan.cuh new file mode 100644 index 00000000000..55783fb82bc --- /dev/null +++ b/ggml/src/ggml-cuda/kda-scan.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_kda_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b99345a2e93..002102dde08 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -999,6 +999,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "FLASH_ATTN_BACK", "SSM_CONV", "SSM_SCAN", + "KDA_SCAN", "WIN_PART", "WIN_UNPART", "GET_REL_POS", @@ -1024,7 +1025,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1133,7 +1134,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5432,6 +5433,70 @@ struct ggml_tensor * ggml_ssm_scan( return result; } +// ggml_kda_scan + +struct ggml_tensor * ggml_kda_scan( + struct ggml_context * ctx, + struct ggml_tensor * h, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * ids) { + GGML_ASSERT(ggml_is_contiguous(h)); + GGML_ASSERT(ggml_is_contiguous(q)); + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(beta)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + { + const int64_t head_dim = h->ne[0]; + const int64_t n_head = q->ne[1]; + const int64_t n_seq_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + GGML_ASSERT(h->ne[0] == head_dim); + GGML_ASSERT(h->ne[1] == head_dim); + GGML_ASSERT(h->ne[2] == n_head); + GGML_ASSERT(q->ne[0] == head_dim); + GGML_ASSERT(k->ne[0] == head_dim); + GGML_ASSERT(v->ne[0] == head_dim); + GGML_ASSERT(g->ne[0] == head_dim); + GGML_ASSERT(ggml_are_same_shape(q, k)); + GGML_ASSERT(ggml_are_same_shape(q, v)); + GGML_ASSERT(ggml_are_same_shape(q, g)); + GGML_ASSERT(beta->ne[0] == n_head); + GGML_ASSERT(beta->ne[1] == n_seq_tokens); + GGML_ASSERT(beta->ne[2] == n_seqs); + GGML_ASSERT(ids->ne[0] == n_seqs); + GGML_ASSERT(ggml_is_vector(ids)); + } + + // Output: y (attention output) + updated hidden states + // y: {head_dim, n_head, n_seq_tokens, n_seqs} + // h_new: {head_dim, head_dim, n_head, n_seqs} + const int64_t head_dim = h->ne[0]; + const int64_t n_head = q->ne[1]; + const int64_t n_seq_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, + ggml_nelements(q) + head_dim * head_dim * n_head * n_seqs); + + result->op = GGML_OP_KDA_SCAN; + result->src[0] = h; + result->src[1] = q; + result->src[2] = k; + result->src[3] = v; + result->src[4] = g; + result->src[5] = beta; + result->src[6] = ids; + + return result; +} + // ggml_win_part struct ggml_tensor * ggml_win_part( diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 266d19f9dd7..6fb6b181c76 100644 --- a/gguf-py/gguf/constants.py +++
b/gguf-py/gguf/constants.py @@ -335,6 +335,7 @@ class GGUFType: ADAPTER = "adapter" IMATRIX = "imatrix" MMPROJ = "mmproj" # dummy, unused for now + KIMI = "kimi" class MODEL_ARCH(IntEnum): @@ -444,6 +445,7 @@ class MODEL_ARCH(IntEnum): MINIMAXM2 = auto() RND1 = auto() PANGU_EMBED = auto() + KIMI_LINEAR = auto() # Kimi-Linear (hybrid MLA+KDA) class VISION_PROJECTOR_TYPE(IntEnum): @@ -701,6 +703,17 @@ class MODEL_TENSOR(IntEnum): A_MMPROJ_FC = auto() A_MM_NORM_PRE = auto() A_MM_NORM_MID = auto() + # Kimi Linear KDA (using SSM_ prefix for consistency) + SSM_CONV1D_Q = auto() + SSM_CONV1D_K = auto() + SSM_CONV1D_V = auto() + SSM_F_A = auto() + SSM_F_B = auto() + SSM_BETA = auto() + SSM_A_LOG = auto() + SSM_G_A = auto() + SSM_G_B = auto() + SSM_DT_B = auto() # nextn/mtp NEXTN_EH_PROJ = auto() NEXTN_EMBED_TOKENS = auto() @@ -817,6 +830,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.COGVLM: "cogvlm", MODEL_ARCH.RND1: "rnd1", MODEL_ARCH.PANGU_EMBED: "pangu-embedded", + MODEL_ARCH.KIMI_LINEAR: "kimi-linear", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -1072,6 +1086,17 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc", MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre", MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid", + # Kimi Linear KDA (using SSM_ prefix for consistency) + MODEL_TENSOR.SSM_CONV1D_Q: "blk.{bid}.ssm_conv1d_q", + MODEL_TENSOR.SSM_CONV1D_K: "blk.{bid}.ssm_conv1d_k", + MODEL_TENSOR.SSM_CONV1D_V: "blk.{bid}.ssm_conv1d_v", + MODEL_TENSOR.SSM_F_A: "blk.{bid}.ssm_f_a", + MODEL_TENSOR.SSM_F_B: "blk.{bid}.ssm_f_b", + MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta", + MODEL_TENSOR.SSM_A_LOG: "blk.{bid}.ssm_a", + MODEL_TENSOR.SSM_G_A: "blk.{bid}.ssm_g_a", + MODEL_TENSOR.SSM_G_B: "blk.{bid}.ssm_g_b", + MODEL_TENSOR.SSM_DT_B: "blk.{bid}.ssm_dt", # NextN/MTP MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj", MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens", @@ -3071,6 +3096,45 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.KIMI_LINEAR: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.SSM_CONV1D_Q, + MODEL_TENSOR.SSM_CONV1D_K, + MODEL_TENSOR.SSM_CONV1D_V, + MODEL_TENSOR.SSM_F_A, + MODEL_TENSOR.SSM_F_B, + MODEL_TENSOR.SSM_BETA, + MODEL_TENSOR.SSM_A_LOG, + MODEL_TENSOR.SSM_G_A, + MODEL_TENSOR.SSM_G_B, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_DT_B, + MODEL_TENSOR.FFN_EXP_PROBS_B, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], # TODO } @@ -3374,6 +3438,10 @@ class VisionProjectorType: KEY_ATTENTION_CLAMP_KQV = Keys.Attention.CLAMP_KQV KEY_ATTENTION_LAYERNORM_EPS = Keys.Attention.LAYERNORM_EPS KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS +KEY_ATTENTION_Q_LORA_RANK = Keys.Attention.Q_LORA_RANK +KEY_ATTENTION_KV_LORA_RANK = Keys.Attention.KV_LORA_RANK +KEY_ATTENTION_KEY_LENGTH_MLA = Keys.Attention.KEY_LENGTH_MLA +KEY_ATTENTION_VALUE_LENGTH_MLA = Keys.Attention.VALUE_LENGTH_MLA # RoPE 
KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index a7b09739791..8a2aebac276 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -389,6 +389,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.expert_bias", # afmoe "model.layers.{bid}.feed_forward.expert_bias", # lfm2moe "model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2 + "model.layers.{bid}.block_sparse_moe.gate.e_score_correction_bias", # kimi ), # Feed-forward up @@ -450,6 +451,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4 "model.layers.{bid}.feed_forward.down_proj", "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan + "model.layers.{bid}.block_sparse_moe.shared_experts.up_proj", # kimi ), MODEL_TENSOR.FFN_UP_CHEXP: ( @@ -480,7 +482,6 @@ class TensorNameMap: "layers.{bid}.mlp.gate_proj", # qwen3-embedding "model.layers.{bid}.mlp.language_mlp.gate_proj", # cogvlm ), - MODEL_TENSOR.FFN_GATE_EXP: ( "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) @@ -496,10 +497,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2 "model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4 "model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan - ), - - MODEL_TENSOR.FFN_GATE_CHEXP: ( - "model.layers.{bid}.mlp.chunk_experts.gate_proj", # grovemoe + "model.layers.{bid}.block_sparse_moe.shared_experts.gate_proj", # kimi ), # Feed-forward down @@ -557,6 +555,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4 "model.layers.{bid}.shared_mlp.output_linear", # granitemoe "model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan + "model.layers.{bid}.block_sparse_moe.shared_experts.down_proj", # kimi ), MODEL_TENSOR.FFN_DOWN_CHEXP: ( @@ -738,6 +737,7 @@ class TensorNameMap: "model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid "model.layers.{bid}.linear_attn.norm", # qwen3next "backbone.layers.{bid}.mixer.norm", # mamba2 + "model.layers.{bid}.self_attn.o_norm", # kimi ), MODEL_TENSOR.SSM_OUT: ( @@ -1569,6 +1569,38 @@ class TensorNameMap: "audio.multi_modal_projector.ln_mid", # ultravox ), + # Kimi Linear KDA (using SSM_ prefix for consistency) + MODEL_TENSOR.SSM_CONV1D_Q: ( + "model.layers.{bid}.self_attn.q_conv1d", + ), + MODEL_TENSOR.SSM_CONV1D_K: ( + "model.layers.{bid}.self_attn.k_conv1d", + ), + MODEL_TENSOR.SSM_CONV1D_V: ( + "model.layers.{bid}.self_attn.v_conv1d", + ), + MODEL_TENSOR.SSM_F_A: ( + "model.layers.{bid}.self_attn.f_a_proj", + ), + MODEL_TENSOR.SSM_F_B: ( + "model.layers.{bid}.self_attn.f_b_proj", + ), + MODEL_TENSOR.SSM_BETA: ( + "model.layers.{bid}.self_attn.b_proj", + ), + MODEL_TENSOR.SSM_A_LOG: ( + "model.layers.{bid}.self_attn.A_log", + ), + MODEL_TENSOR.SSM_G_A: ( + "model.layers.{bid}.self_attn.g_a_proj", + ), + MODEL_TENSOR.SSM_G_B: ( + "model.layers.{bid}.self_attn.g_b_proj", + ), + MODEL_TENSOR.SSM_DT_B: ( + "model.layers.{bid}.self_attn.dt_bias", + ), + # NextN/MTP tensors for GLM4_MOE MODEL_TENSOR.NEXTN_EH_PROJ: ( "model.layers.{bid}.eh_proj", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 67c7807e092..194aff9e579 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -82,6 +82,7 @@ add_library(llama models/internlm2.cpp models/jais.cpp models/jamba.cpp + models/kimi-linear.cpp models/lfm2.cpp models/llada-moe.cpp models/llada.cpp diff --git a/src/llama-arch.cpp 
b/src/llama-arch.cpp index 8571a2e025a..fe61d524a87 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -111,6 +111,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_COGVLM, "cogvlm" }, { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, + { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -2492,6 +2493,54 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" }, }, }, + { + LLM_ARCH_KIMI_LINEAR, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + // Dense FFN (layer 0 only) + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + // MoE FFN (layers 1+) + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + // Shared experts + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + // KDA (using SSM_ enum prefix, keeping GGUF names for backward compat) + { LLM_TENSOR_SSM_CONV1D_Q, "blk.%d.ssm_conv1d_q" }, + { LLM_TENSOR_SSM_CONV1D_K, "blk.%d.ssm_conv1d_k" }, + { LLM_TENSOR_SSM_CONV1D_V, "blk.%d.ssm_conv1d_v" }, + { LLM_TENSOR_SSM_F_A, "blk.%d.ssm_f_a" }, + { LLM_TENSOR_SSM_F_B, "blk.%d.ssm_f_b" }, + { LLM_TENSOR_SSM_BETA, "blk.%d.ssm_beta" }, + { LLM_TENSOR_SSM_A_LOG, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_DT_B, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_G_A, "blk.%d.ssm_g_a" }, + { LLM_TENSOR_SSM_G_B, "blk.%d.ssm_g_b" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + // MLA + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + }, + }, { LLM_ARCH_RND1, { @@ -2713,6 +2762,17 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // Kimi KDA - Conv tensors are 4D [d_conv, 1, d_inner, 1], reshaped to 2D at runtime + {LLM_TENSOR_SSM_CONV1D_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_F_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_F_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_BETA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_A_LOG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_DT_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_SSM_G_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + 
{LLM_TENSOR_SSM_G_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} @@ -2773,6 +2833,7 @@ bool llm_arch_is_recurrent(const llm_arch & arch) { case LLM_ARCH_RWKV6QWEN2: case LLM_ARCH_RWKV7: case LLM_ARCH_ARWKV7: + case LLM_ARCH_KIMI_LINEAR: // KDA layers use delta attention with recurrent state return true; default: return false; @@ -2789,6 +2850,9 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: case LLM_ARCH_QWEN3NEXT: + // Kimi: Currently using recurrent-only mode since MLA doesn't use KV cache + // TODO: Enable hybrid when MLA KV caching is implemented + // case LLM_ARCH_KIMI_LINEAR: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index 150646478ae..7c5d2fa9ab4 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -115,6 +115,7 @@ enum llm_arch { LLM_ARCH_COGVLM, LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, + LLM_ARCH_KIMI_LINEAR, LLM_ARCH_UNKNOWN, }; @@ -383,6 +384,17 @@ enum llm_tensor { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next + // Kimi Linear KDA (using SSM_ prefix for consistency) + LLM_TENSOR_SSM_CONV1D_Q, // kimi: Q conv1d weight + LLM_TENSOR_SSM_CONV1D_K, // kimi: K conv1d weight + LLM_TENSOR_SSM_CONV1D_V, // kimi: V conv1d weight + LLM_TENSOR_SSM_F_A, // kimi: forget gate projection A + LLM_TENSOR_SSM_F_B, // kimi: forget gate projection B + LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient + LLM_TENSOR_SSM_A_LOG, // kimi: A_log (pre-converted in GGUF) + LLM_TENSOR_SSM_DT_B, // kimi: dt bias + LLM_TENSOR_SSM_G_A, // kimi: output gate projection A + LLM_TENSOR_SSM_G_B, // kimi: output gate projection B LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W2, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index e04f0fc4f9a..3278cf2ef84 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1387,7 +1387,7 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes() const { - if (model.arch == LLM_ARCH_QWEN3NEXT) { + if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR) { return std::max(8192u, 32u*model.n_tensors()); } return std::max(1024u, 8u*model.n_tensors()); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1d012e09aba..481d1f2543d 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1819,11 +1819,14 @@ ggml_tensor * llm_graph_context::build_rs( ggml_build_forward_expand(gf, output_states); // copy extra states which won't be changed further (between n_seqs and n_rs) - ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra); - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - states_extra, - ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s)))); + // Skip if there are no extra states to copy (n_rs == n_seqs) + if (n_rs > n_seqs) { + ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra); + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + states_extra, + ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s)))); + } return output_states; } diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 8cdbaf69fc0..88d266b8daf 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -133,6 +133,13 @@ uint32_t llama_hparams::n_embd_r() const { return n_embd * (n_shortconv_l_cache - 1); } + if (kda_head_dim != 0) { + // 
for Kimi KDA layers + // Conv state for Q, K, V: 3 * (d_conv - 1) * n_head * head_dim + const uint32_t d_inner = n_head() * kda_head_dim; // 32 * 128 = 4096 + return 3 * (kda_d_conv > 0 ? kda_d_conv - 1 : 3) * d_inner; + } + // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // Corresponds to Mamba's conv_states size @@ -145,6 +152,13 @@ uint32_t llama_hparams::n_embd_s() const { return n_embd * wkv_head_size; } + if (kda_head_dim != 0) { + // for Kimi KDA layers + // Full recurrent state: head_dim * head_dim * n_head + // h tensor shape for delta attention: [head_dim, head_dim, n_head] + return kda_head_dim * kda_head_dim * n_head(); // 128 * 128 * 32 = 524288 + } + // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } diff --git a/src/llama-hparams.h b/src/llama-hparams.h index c3a53be793f..d26be3442d8 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -133,6 +133,10 @@ struct llama_hparams { uint32_t ssm_dt_rank = 0; uint32_t ssm_n_group = 0; + // for Kimi Delta Attention (KDA) + uint32_t kda_head_dim = 0; // head_dim for KDA layers (128 for Kimi) + uint32_t kda_d_conv = 0; // conv kernel size for KDA (4 for Kimi) + // for hybrid state space models std::array recurrent_layer_arr; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c2a545531a9..58681fc5602 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2247,6 +2247,54 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_KIMI_LINEAR: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv, false); + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); + + // KDA (Delta Attention) parameters + hparams.kda_head_dim = 128; // linear_attn_config.head_dim + hparams.kda_d_conv = 4; // linear_attn_config.short_conv_kernel_size + + // MLA qk_rope_head_dim (for reference) + // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192 + + // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) + // MLA layers are at: 3, 7, 11, 15, 19, 23, 26 (7 MLA layers total) + // KDA layers are all others: 0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, 16, 17, 18, 20, 21, 22, 24, 25 (20 KDA layers) + // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + bool is_mla = (i == 3 || i == 7 || i == 11 || i == 15 || i == 19 || i == 23 || i == 26); + hparams.n_head_kv_arr[i] = is_mla ? 
hparams.n_head() : 0; + hparams.recurrent_layer_arr[i] = !is_mla; // KDA layers are recurrent + } + + // MoE parameters - Kimi uses moe_intermediate_size = 1024 + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + // Default values if not in GGUF + if (hparams.n_ff_exp == 0) hparams.n_ff_exp = 1024; // moe_intermediate_size + if (hparams.n_ff_shexp == 0) hparams.n_ff_shexp = 9216; // shared_expert_intermediate_size = intermediate_size + if (hparams.n_expert_shared == 0) hparams.n_expert_shared = 1; // num_shared_experts + if (hparams.n_layer_dense_lead == 0) hparams.n_layer_dense_lead = 1; // first_k_dense_replace + if (hparams.expert_weights_scale == 0.0f) hparams.expert_weights_scale = 2.446f; // routed_scaling_factor + + // MoE gating function - Kimi uses sigmoid (moe_router_activation_func: sigmoid) + if (hparams.expert_gating_func == 0) hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + + switch (hparams.n_layer) { + case 27: type = LLM_TYPE_48B; break; // Kimi-Linear-48B-A3B + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -6358,6 +6406,148 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); } } break; + case LLM_ARCH_KIMI_LINEAR: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // Check for KDA specific tensors to determine layer type or if it's a mixed model + // Assuming KDA layer if KDA tensors are present + + // KDA uses head_dim = 128 (from linear_attn_config.head_dim) + const int64_t n_embd_head_k_kda = 128; + const int64_t n_embd_head_v_kda = 128; + const int64_t ssm_d_conv = hparams.ssm_d_conv > 0 ? 
hparams.ssm_d_conv : 4; + + // Try loading KDA specific tensors (using SSM_ prefix) + // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) + // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_q_conv) { + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, TENSOR_NOT_REQUIRED); + } + + if (layer.ssm_q_conv) { + // KDA Layer - Conv1d weights may be 3D or 4D + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_k_conv) { + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); + } + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_v_conv) { + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head}, 0); + } + + // Conv bias may not exist in all models - make optional + layer.ssm_q_conv_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "bias", i), {n_embd_head_k_kda * n_head}, TENSOR_NOT_REQUIRED); + layer.ssm_k_conv_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "bias", i), {n_embd_head_k_kda * n_head}, TENSOR_NOT_REQUIRED); + layer.ssm_v_conv_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "bias", i), {n_embd_head_v_kda * n_head}, TENSOR_NOT_REQUIRED); + + // q, k, v projections + // Python: q_proj, k_proj, v_proj + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_v_kda * n_head}, 0); + + // KDA specific projections + // f_a_proj, f_b_proj + layer.ssm_f_a = create_tensor(tn(LLM_TENSOR_SSM_F_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); // head_dim + layer.ssm_f_b = create_tensor(tn(LLM_TENSOR_SSM_F_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); // projection_size + + // b_proj (beta mixing coefficient) + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), {n_embd, n_head}, 0); + + // A_log - Shape in GGUF: [1, num_heads, 1, 1] (4D) or [1, num_heads] (2D after quantization) + layer.ssm_a_log = create_tensor(tn(LLM_TENSOR_SSM_A_LOG, i), {1, n_head, 1, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_a_log) { + layer.ssm_a_log = create_tensor(tn(LLM_TENSOR_SSM_A_LOG, i), {1, n_head}, 0); + } + + // dt_bias - shape [n_embd_head_k_kda * n_head] = [4096] + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT_B, i), {n_embd_head_k_kda * n_head}, 0); + + // g_a_proj, g_b_proj (output gate) + layer.ssm_g_a = create_tensor(tn(LLM_TENSOR_SSM_G_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); + layer.ssm_g_b = create_tensor(tn(LLM_TENSOR_SSM_G_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); + + // o_norm (reusing SSM_NORM) + layer.ssm_o_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {n_embd_head_k_kda}, 0); // FusedRMSNormGated + layer.ssm_o_norm_b = create_tensor(tn(LLM_TENSOR_SSM_NORM, "bias", i), {n_embd_head_k_kda}, TENSOR_NOT_REQUIRED); + + // o_proj + layer.wo = 
create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v_kda * n_head, n_embd}, 0); + + } else { + // MLA Layer - use MLA-specific head dimensions + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla > 0 ? hparams.n_embd_head_k_mla : 192; + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla > 0 ? hparams.n_embd_head_v_mla : 128; + + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, TENSOR_NOT_REQUIRED); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (layer.attn_q_a_norm) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + } else { + // Kimi MLA without Q compression: wq = [n_embd, n_head * n_embd_head_k_mla] + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + } + + // Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA) + // Note: hparams.n_rot may be 72 (from conversion) but actual is 64 + const int64_t qk_rope_head_dim = 64; // From config: qk_rope_head_dim + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // MoE intermediate size (different from dense FFN) + const int64_t n_ff_exp = hparams.n_ff_exp > 0 ? hparams.n_ff_exp : 1024; + + // Kimi uses n_layer_dense_lead to determine which layers use dense FFN vs MoE + // first_k_dense_replace = 1 means layer 0 uses dense FFN, layers 1+ use MoE + if (i < (int) hparams.n_layer_dense_lead) { + // Dense FFN layer - use normal n_ff + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + // MoE layer - use n_ff_exp (1024) instead of n_ff (9216) + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared experts use moe_intermediate_size * num_shared_experts + // Kimi: shared_expert_intermediate_size = 1024 * 1 = 1024 + // Tensors are 2D: [n_embd, n_ff_shexp] or [n_ff_shexp, n_embd] + const int64_t n_ff_shexp_actual = n_ff_exp * (hparams.n_expert_shared > 0 ? 
hparams.n_expert_shared : 1);
+                        layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp_actual, n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED);
+
+                        // exp_probs_b (e_score_correction_bias in vLLM)
+                        // Try "bias" first (standard), then "weight" (for compatibility)
+                        layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+                        if (!layer.ffn_exp_probs_b) {
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "weight", i), {n_expert}, TENSOR_NOT_REQUIRED);
+                        }
+                    }
+                }
+            } break;
        case LLM_ARCH_COGVLM:
            {
                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -7522,6 +7712,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            {
                llm = std::make_unique(*this, params);
            } break;
+        case LLM_ARCH_KIMI_LINEAR:
+            {
+                llm = std::make_unique<llm_build_kimi_linear>(*this, params);
+            } break;
        default:
            GGML_ABORT("fatal error");
    }
@@ -7677,6 +7871,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_ARCTIC:
        case LLM_ARCH_DEEPSEEK:
        case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_KIMI_LINEAR:
        case LLM_ARCH_PLM:
        case LLM_ARCH_CHATGLM:
        case LLM_ARCH_GLM4:
diff --git a/src/llama-model.h b/src/llama-model.h
index f8342cf2cb1..b067b686d22 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -84,6 +84,7 @@ enum llm_type {
    LLM_TYPE_35B,
    LLM_TYPE_36B,
    LLM_TYPE_40B,
+    LLM_TYPE_48B,
    LLM_TYPE_65B,
    LLM_TYPE_70B,
    LLM_TYPE_120B,
@@ -404,6 +405,23 @@ struct llama_layer {
    struct ggml_tensor * ffn_act_beta = nullptr;
    struct ggml_tensor * ffn_act_eps = nullptr;
+    // Kimi Linear KDA (using ssm_ prefix for consistency)
+    // Note: ssm_dt_b already exists above (mamba bias), reused for Kimi dt_bias
+    struct ggml_tensor * ssm_q_conv = nullptr;
+    struct ggml_tensor * ssm_q_conv_b = nullptr;
+    struct ggml_tensor * ssm_k_conv = nullptr;
+    struct ggml_tensor * ssm_k_conv_b = nullptr;
+    struct ggml_tensor * ssm_v_conv = nullptr;
+    struct ggml_tensor * ssm_v_conv_b = nullptr;
+    struct ggml_tensor * ssm_f_a = nullptr;
+    struct ggml_tensor * ssm_f_b = nullptr;
+    struct ggml_tensor * ssm_beta = nullptr;
+    struct ggml_tensor * ssm_a_log = nullptr;
+    struct ggml_tensor * ssm_g_a = nullptr;
+    struct ggml_tensor * ssm_g_b = nullptr;
+    struct ggml_tensor * ssm_o_norm = nullptr;
+    struct ggml_tensor * ssm_o_norm_b = nullptr;
+
    struct llama_layer_posnet posnet;
    struct llama_layer_convnext convnext;
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 0b23eaef3a8..7b8bf6e5246 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -724,7 +724,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
    qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
    // sanity checks for models that have attention layers
-    if (qs.n_attention_wv != 0 && !is_clip_model)
+    // Skip this check for Kimi models which have hybrid KDA+MLA architecture
+    // (only MLA layers have attn_kv_b weights, KDA layers don't)
+    if (qs.n_attention_wv != 0 && !is_clip_model && model.arch != LLM_ARCH_KIMI_LINEAR)
    {
        const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
        // attention layers have a non-zero number of kv heads
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index
a73c4c448ba..caefa706ce1 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1738,26 +1738,34 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // read bpe merges and populate bpe ranks const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + + // Kimi-K2 uses custom tokenization without traditional BPE merges + const bool is_kimi_k2 = (tokenizer_pre == "kimi-k2"); + if (merges_keyidx == -1) { - throw std::runtime_error("cannot find tokenizer merges in model file\n"); - } + if (!is_kimi_k2) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + // Kimi-K2 doesn't need merges, skip + LLAMA_LOG_INFO("%s: Kimi-K2 tokenizer detected, skipping BPE merges\n", __func__); + } else { + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); - const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - for (int i = 0; i < n_merges; i++) { - const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); - //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + std::string first; + std::string second; - std::string first; - std::string second; + const size_t pos = word.find(' ', 1); - const size_t pos = word.find(' ', 1); + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } - if (pos != std::string::npos) { - first = word.substr(0, pos); - second = word.substr(pos + 1); + bpe_ranks.emplace(std::make_pair(first, second), i); } - - bpe_ranks.emplace(std::make_pair(first, second), i); } // default special tokens diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp new file mode 100644 index 00000000000..660cd06f0e9 --- /dev/null +++ b/src/models/kimi-linear.cpp @@ -0,0 +1,429 @@ +#include "models.h" + +llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params), model(model) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM) + // So we don't need inp_pos + + // Only use recurrent state input for KDA layers + // MLA layers use direct softmax attention without KV cache + auto * inp_rs = build_rs_inp(); + + // Input for MLA layers (no KV cache) + auto * inp_no_cache = build_attn_inp_no_cache(); + + // Output ids for selecting which tokens to output + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // Kimi dimension constants + const int64_t n_head = hparams.n_head(); + const int64_t head_dim = hparams.kda_head_dim > 0 ? hparams.kda_head_dim : 128; + const int64_t d_conv = hparams.kda_d_conv > 0 ? hparams.kda_d_conv : 4; + const int64_t d_inner = n_head * head_dim; // 32 * 128 = 4096 + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + // Verify batch consistency for recurrent layers + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + // MLA params + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla > 0 ? hparams.n_embd_head_k_mla : 192; + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla > 0 ? hparams.n_embd_head_v_mla : 128; + const int64_t kv_lora_rank = hparams.n_lora_kv > 0 ? 
hparams.n_lora_kv : 512;
+    // qk_rope_head_dim = 64 (from Kimi config), NOT hparams.n_rot (which is 72)
+    // Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim]
+    const int64_t n_embd_head_qk_rope = 64; // config.qk_rope_head_dim
+    const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; // 192 - 64 = 128
+
+    // Attention scale for KDA (1/sqrt(head_dim))
+    const float kq_scale_kda = 1.0f / sqrtf((float)head_dim);
+
+    // Attention scale for MLA
+    const float kq_scale_mla = 1.0f / sqrtf((float)n_embd_head_k_mla);
+
+    for (int il = 0; il < n_layer; ++il) {
+        const auto & layer = model.layers[il];
+        ggml_tensor * inpSA = inpL;
+
+        // Attention Norm
+        cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // Determine the layer type by checking which tensors exist:
+        // KDA layers have the ssm_a_log tensor, MLA layers have the wkv_a_mqa tensor
+        bool is_kda = (layer.ssm_a_log != nullptr);
+        bool is_mla = (layer.wkv_a_mqa != nullptr);
+
+        if (is_kda) {
+            // === KDA Layer (Kimi Delta Attention) with Recurrent State ===
+            // Reference: vLLM kda.py
+
+            const auto * mctx_cur = inp_rs->mctx;
+            const auto kv_head = mctx_cur->get_head();
+
+            // Get conv states from the r_l tensor (Q, K and V each have a separate state)
+            ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+            const int64_t conv_state_size = (d_conv - 1) * d_inner;
+            const int64_t n_embd_r_total  = 3 * conv_state_size; // Q + K + V
+            ggml_tensor * conv_state_all = build_rs(inp_rs, conv_states_all, hparams.n_embd_r(), n_seqs);
+
+            // conv_state_all is [n_embd_r_total, n_seqs]: for each sequence the Q state comes first
+            // (conv_state_size elements), then K, then V, i.e.
+            //   state[i + seq * n_embd_r_total] with i = conv_step + channel * (d_conv - 1) + {0, 1, 2} * conv_state_size for Q/K/V
+            // Each stream is viewed as [d_conv - 1, d_inner, n_seqs]:
+            //   nb1 = (d_conv - 1)   * element_size (stride between channels)
+            //   nb2 = n_embd_r_total * element_size (stride between sequences)
+            ggml_tensor * conv_state_q = ggml_view_3d(ctx0, conv_state_all, d_conv - 1, d_inner, n_seqs,
+                    (d_conv - 1)   * ggml_element_size(conv_state_all), // nb1: stride between channels
+                    n_embd_r_total * ggml_element_size(conv_state_all), // nb2: stride between seqs
+                    0);                                                 // offset for Q
+            ggml_tensor * conv_state_k = ggml_view_3d(ctx0, conv_state_all, d_conv - 1, d_inner, n_seqs,
+                    (d_conv - 1)   * ggml_element_size(conv_state_all),
+                    n_embd_r_total * ggml_element_size(conv_state_all),
+                    conv_state_size * ggml_element_size(conv_state_all)); // offset for K
+            ggml_tensor * conv_state_v = ggml_view_3d(ctx0, conv_state_all, d_conv - 1, d_inner, n_seqs,
+                    (d_conv - 1)   * ggml_element_size(conv_state_all),
+                    n_embd_r_total * ggml_element_size(conv_state_all),
+                    2 * conv_state_size * ggml_element_size(conv_state_all)); // offset for V
+
+            // Step 1: Q, K, V projections -> [d_inner, n_tokens]
+            ggml_tensor * q_proj = ggml_mul_mat(ctx0, layer.wq, cur);
+            ggml_tensor * k_proj = ggml_mul_mat(ctx0, layer.wk, cur);
+            ggml_tensor * v_proj = ggml_mul_mat(ctx0, layer.wv, cur);
+            cb(q_proj, "kda_q_proj", il);
+            cb(k_proj, "kda_k_proj", il);
+            cb(v_proj, "kda_v_proj", il);
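+
+            // Worked example of the packed conv-state layout (a sketch, assuming the default
+            // Kimi-Linear dimensions used above: n_head = 32, head_dim = 128 -> d_inner = 4096, d_conv = 4):
+            //   conv_state_size = (d_conv - 1) * d_inner = 3 * 4096 = 12288 elements per stream
+            //   n_embd_r_total  = 3 * conv_state_size    = 36864 elements per sequence
+            // so for sequence s, channel ch and conv slot c, the flat index into conv_state_all is
+            //   s * 36864 + {0, 12288, 24576}[stream] + ch * 3 + c      (stream = Q, K or V)
+            // hparams.n_embd_r() is expected to report the same 3 * (d_conv - 1) * d_inner total,
+            // otherwise build_rs() would hand back a differently sized state slice.
+
+            //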
Step 2: Causal Conv1d for Q + // Reshape input: {d_inner, n_tokens} -> {d_inner, n_seq_tokens, n_seqs} + ggml_tensor * q_3d = ggml_reshape_3d(ctx0, q_proj, d_inner, n_seq_tokens, n_seqs); + + // Concat Q conv state and current input: {d_conv-1 + n_seq_tokens, d_inner, n_seqs} + ggml_tensor * conv_q = ggml_concat(ctx0, conv_state_q, ggml_transpose(ctx0, q_3d), 0); + + // Save last (d_conv-1) columns back to Q conv state + ggml_tensor * last_conv_q = ggml_view_3d(ctx0, conv_q, d_conv - 1, d_inner, n_seqs, + conv_q->nb[1], conv_q->nb[2], n_seq_tokens * conv_q->nb[0]); + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv_q, + ggml_view_1d(ctx0, conv_states_all, conv_state_size * n_seqs, + kv_head * n_embd_r_total * ggml_element_size(conv_states_all)))); + + // Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner] + // GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv] + // vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step] + // ggml_ssm_conv computes: c[conv_step + channel * d_conv] + // GGUF layout: [d_conv, 1, d_inner] or [d_conv, 1, d_inner, 1] -> reshape to [d_conv, d_inner] + ggml_tensor * conv_weight = nullptr; + if (layer.ssm_q_conv) { + // Reshape conv weight from [d_conv, 1, d_inner, 1] to [d_conv, d_inner] for ggml_ssm_conv + // Cast to F32 if quantized (ggml_ssm_conv requires float weights) + ggml_tensor * q_conv_f32 = layer.ssm_q_conv; + if (q_conv_f32->type != GGML_TYPE_F32) { + q_conv_f32 = ggml_cast(ctx0, q_conv_f32, GGML_TYPE_F32); + } + conv_weight = ggml_reshape_2d(ctx0, q_conv_f32, d_conv, d_inner); + } + + // Apply conv1d + ggml_tensor * Qcur; + if (conv_weight) { + // Make conv_q contiguous for ggml_ssm_conv + conv_q = ggml_cont(ctx0, conv_q); + + // ggml_ssm_conv output: {d_inner, n_seq_tokens, n_seqs} + Qcur = ggml_ssm_conv(ctx0, conv_q, conv_weight); + // Reshape to 2D for bias add: {d_inner, n_tokens} + Qcur = ggml_reshape_2d(ctx0, Qcur, d_inner, n_tokens); + if (layer.ssm_q_conv_b) { + Qcur = ggml_add(ctx0, Qcur, layer.ssm_q_conv_b); + } + Qcur = ggml_silu(ctx0, Qcur); + } else { + GGML_ABORT("KDA layer missing Q conv weight"); + } + + // K conv1d (with separate K conv state) + ggml_tensor * Kcur; + if (layer.ssm_k_conv) { + ggml_tensor * k_3d = ggml_reshape_3d(ctx0, k_proj, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * conv_k = ggml_cont(ctx0, ggml_concat(ctx0, conv_state_k, ggml_transpose(ctx0, k_3d), 0)); + + // Save K conv state + ggml_tensor * last_conv_k = ggml_view_3d(ctx0, conv_k, d_conv - 1, d_inner, n_seqs, + conv_k->nb[1], conv_k->nb[2], n_seq_tokens * conv_k->nb[0]); + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv_k, + ggml_view_1d(ctx0, conv_states_all, conv_state_size * n_seqs, + (kv_head * n_embd_r_total + conv_state_size) * ggml_element_size(conv_states_all)))); + + ggml_tensor * k_conv_f32 = layer.ssm_k_conv; + if (k_conv_f32->type != GGML_TYPE_F32) { + k_conv_f32 = ggml_cast(ctx0, k_conv_f32, GGML_TYPE_F32); + } + ggml_tensor * k_conv_weight = ggml_reshape_2d(ctx0, k_conv_f32, d_conv, d_inner); + Kcur = ggml_ssm_conv(ctx0, conv_k, k_conv_weight); + Kcur = ggml_reshape_2d(ctx0, Kcur, d_inner, n_tokens); + if (layer.ssm_k_conv_b) { + Kcur = ggml_add(ctx0, Kcur, layer.ssm_k_conv_b); + } + Kcur = ggml_silu(ctx0, Kcur); + } else { + GGML_ABORT("KDA layer missing K conv weight"); + } + + // V conv1d (with separate V conv state) + ggml_tensor * Vcur; + if (layer.ssm_v_conv) { + ggml_tensor * v_3d = ggml_reshape_3d(ctx0, v_proj, 
d_inner, n_seq_tokens, n_seqs);
+                ggml_tensor * conv_v = ggml_cont(ctx0, ggml_concat(ctx0, conv_state_v, ggml_transpose(ctx0, v_3d), 0));
+
+                // Save V conv state
+                ggml_tensor * last_conv_v = ggml_view_3d(ctx0, conv_v, d_conv - 1, d_inner, n_seqs,
+                        conv_v->nb[1], conv_v->nb[2], n_seq_tokens * conv_v->nb[0]);
+                ggml_build_forward_expand(gf,
+                        ggml_cpy(ctx0, last_conv_v,
+                                ggml_view_1d(ctx0, conv_states_all, conv_state_size * n_seqs,
+                                        (kv_head * n_embd_r_total + 2 * conv_state_size) * ggml_element_size(conv_states_all))));
+
+                ggml_tensor * v_conv_f32 = layer.ssm_v_conv;
+                if (v_conv_f32->type != GGML_TYPE_F32) {
+                    v_conv_f32 = ggml_cast(ctx0, v_conv_f32, GGML_TYPE_F32);
+                }
+                ggml_tensor * v_conv_weight = ggml_reshape_2d(ctx0, v_conv_f32, d_conv, d_inner);
+                Vcur = ggml_ssm_conv(ctx0, conv_v, v_conv_weight);
+                Vcur = ggml_reshape_2d(ctx0, Vcur, d_inner, n_tokens);
+                if (layer.ssm_v_conv_b) {
+                    Vcur = ggml_add(ctx0, Vcur, layer.ssm_v_conv_b);
+                }
+                Vcur = ggml_silu(ctx0, Vcur);
+            } else {
+                GGML_ABORT("KDA layer missing V conv weight");
+            }
+
+            // Step 3: Compute g1 (forget gate)
+            // g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias)
+            ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur);
+            ggml_tensor * g1  = ggml_mul_mat(ctx0, layer.ssm_f_b, f_a);
+            g1 = ggml_add(ctx0, g1, layer.ssm_dt_b);
+            g1 = ggml_softplus(ctx0, g1);
+            g1 = ggml_reshape_3d(ctx0, g1, head_dim, n_head, n_tokens);
+
+            // A_log shape is [1, n_head] or [1, n_head, 1, 1], need to broadcast to [head_dim, n_head, n_tokens]
+            // First compute -exp(A_log), then reshape for broadcasting
+            ggml_tensor * A_neg_exp = ggml_neg(ctx0, ggml_exp(ctx0, layer.ssm_a_log));
+            // Reshape to [1, n_head, 1] for broadcasting with g1 [head_dim, n_head, n_tokens]
+            A_neg_exp = ggml_reshape_3d(ctx0, A_neg_exp, 1, n_head, 1);
+            g1 = ggml_mul(ctx0, g1, A_neg_exp);
+            cb(g1, "kda_g1", il);
+
+            // Step 4: Compute beta (mixing coefficient)
+            ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur);
+            beta = ggml_sigmoid(ctx0, beta);
+            cb(beta, "kda_beta", il);
+
+            // Step 5: Reshape for KDA recurrence
+            // {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs}
+            cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+            Qcur = ggml_cont(ctx0, ggml_reshape_4d(ctx0, Qcur, head_dim, n_head, n_seq_tokens, n_seqs));
+            Kcur = ggml_cont(ctx0, ggml_reshape_4d(ctx0, Kcur, head_dim, n_head, n_seq_tokens, n_seqs));
+            Vcur = ggml_cont(ctx0, ggml_reshape_4d(ctx0, Vcur, head_dim, n_head, n_seq_tokens, n_seqs));
+            g1   = ggml_cont(ctx0, ggml_reshape_4d(ctx0, g1,   head_dim, n_head, n_seq_tokens, n_seqs));
+            beta = ggml_cont(ctx0, ggml_reshape_3d(ctx0, beta, n_head, n_seq_tokens, n_seqs));
+
+            cb(Qcur, "kda_Q", il);
+            cb(Kcur, "kda_K", il);
+            cb(Vcur, "kda_V", il);
+
+            // Step 6: Get SSM state and compute KDA recurrence using ggml_kda_scan
+            ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
+
+            // Use build_rs with lambda pattern (like Mamba SSM scan)
+            auto get_kda_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
+                ggml_tensor * h_state = ggml_reshape_4d(ctx, states, head_dim, head_dim, n_head, mctx_cur->get_size());
+                // Call ggml_kda_scan which implements the correct KDA recurrence
+                return ggml_kda_scan(ctx, h_state, Qcur, Kcur, Vcur, g1, beta, ids);
+            };
+
+            ggml_tensor * y_kda = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs, get_kda_rows);
+            cb(y_kda, "kda_scan_out", il);
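+
+            // For reference, the per-head recurrence that ggml_kda_scan is assumed to implement
+            // (a sketch of the gated delta rule with per-channel decay; S is the [head_dim x head_dim]
+            // state slice of h_state for one head and sequence, q/k/v/g1/beta are the per-token
+            // slices built above, and g1_t <= 0 so exp(g1_t) is a decay in (0, 1]):
+            //
+            //   for each token t:
+            //       S   = S * exp(g1_t);                        // per key-channel decay
+            //       S  += beta_t * (v_t - S * k_t) * k_t^T;     // delta-rule (error-correcting) write
+            //       o_t = S * q_t;                              // read out with the query
+            //
+            // The fused op is assumed to return the o_t rows for all tokens followed by the final S of
+            // each sequence, which is why the output below is split into an attention part and a state
+            // part by size.
+
+            // Store updated state back
+            // y_kda contains: [attention_output (head_dim * n_head * n_seq_tokens * n_seqs), new_state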
(head_dim * head_dim * n_head * n_seqs)] + const int64_t attn_out_size = head_dim * n_head * n_seq_tokens * n_seqs; + const int64_t state_size = head_dim * head_dim * n_head; + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_kda, state_size * n_seqs, attn_out_size * ggml_element_size(y_kda)), + ggml_view_1d(ctx0, ssm_states_all, state_size * n_seqs, kv_head * state_size * ggml_element_size(ssm_states_all)))); + + // Extract attention output + ggml_tensor * attn_out = ggml_view_1d(ctx0, y_kda, attn_out_size, 0); + attn_out = ggml_reshape_3d(ctx0, attn_out, head_dim, n_head, n_seq_tokens * n_seqs); + cb(attn_out, "kda_attn_out", il); + + // Step 7: Output gating g2 = g_b(g_a(x)) + ggml_tensor * cur_2d = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + ggml_tensor * g_a = ggml_mul_mat(ctx0, layer.ssm_g_a, cur_2d); + ggml_tensor * g2 = ggml_mul_mat(ctx0, layer.ssm_g_b, g_a); + g2 = ggml_reshape_3d(ctx0, g2, head_dim, n_head, n_seq_tokens * n_seqs); + + // Step 8: Apply o_norm with sigmoid gating + // Note: Kimi model uses sigmoid gating, not SiLU (despite FusedRMSNormGated default being swish) + // Formula: output = RMSNorm(x) * sigmoid(g) + ggml_tensor * normed = build_norm(attn_out, layer.ssm_o_norm, layer.ssm_o_norm_b, LLM_NORM_RMS, il); + ggml_tensor * gate = ggml_sigmoid(ctx0, g2); + ggml_tensor * gated = ggml_mul(ctx0, normed, gate); + + // Step 9: Output projection + gated = ggml_cont_2d(ctx0, gated, d_inner, n_tokens); + cur = ggml_mul_mat(ctx0, layer.wo, gated); + cb(cur, "kda_out", il); + + + GGML_UNUSED(d_conv); + GGML_UNUSED(kq_scale_kda); + + } else if (is_mla) { + // === MLA Layer (Multi-head Latent Attention) without KV Cache === + // Reference: vLLM mla.py + // TODO: Implement proper KV caching for MLA (requires custom cache format) + + // Step 1: Q projection and reshape + // vLLM Kimi: q = q_proj(hidden_states), then view as [n_tokens, n_head, qk_head_dim] + // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM) + ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.wq, cur); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens); + cb(Qcur, "mla_Q", il); + + // Step 2: KV compression + // kv_lora = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens] + ggml_tensor * kv_lora = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur); + + // Split: kv_c = kv_lora[:kv_lora_rank], k_pe = kv_lora[kv_lora_rank:] + ggml_tensor * kv_c = ggml_view_2d(ctx0, kv_lora, kv_lora_rank, n_tokens, + ggml_row_size(kv_lora->type, kv_lora_rank + n_embd_head_qk_rope), 0); + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_lora, n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(kv_lora->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_lora->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_lora->type, kv_lora_rank)); + + // Note: Kimi MLA does NOT apply RoPE (rotary_emb=None in vLLM) + // k_pe is used directly without RoPE + + // Normalize kv_c + kv_c = build_norm(kv_c, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il); + + // KV decompression: kv = kv_b_proj(kv_c_normed) + ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_c); + const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla; + + // Split kv into k_nope and v + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, kv_per_head), + ggml_row_size(kv->type, kv_per_head * n_head), 0); + ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens, + 
ggml_row_size(kv->type, kv_per_head), + ggml_row_size(kv->type, kv_per_head * n_head), + ggml_row_size(kv->type, n_embd_head_qk_nope)); + k_nope = ggml_cont(ctx0, k_nope); + Vcur = ggml_cont(ctx0, Vcur); + + // Concatenate k_nope + k_pe (broadcast k_pe to all heads) + // K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens] + // and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads + k_pe = ggml_cont(ctx0, k_pe); + // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens] + ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens); + ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target); + ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, k_pe_repeated, 0); + cb(Kcur, "mla_K", il); + cb(Vcur, "mla_V", il); + + // Direct softmax attention (without KV cache) + // Use build_attn with inp_no_cache for proper mask handling + cur = build_attn(inp_no_cache, layer.wo, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il); + cb(cur, "mla_out", il); + + } else { + // Unknown layer type - this should not happen + GGML_ABORT("Kimi layer is neither KDA nor MLA - missing required tensors"); + } + + // On last layer, select only the output tokens + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Residual + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FFN Norm + cur = build_norm(ffn_inp, layer.ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // FFN / MoE + if (layer.ffn_gate_inp) { + // MoE layer + // Kimi uses moe_renormalize=True and routed_scaling_factor (stored as expert_weights_scale) = 2.446 + ggml_tensor * moe_out = build_moe_ffn(cur, layer.ffn_gate_inp, layer.ffn_up_exps, layer.ffn_gate_exps, layer.ffn_down_exps, + layer.ffn_exp_probs_b, hparams.n_expert, hparams.n_expert_used, + LLM_FFN_SILU, true, true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, il); + cb(moe_out, "ffn_moe_out", il); + + // Shared expert (if present) + if (layer.ffn_gate_shexp) { + ggml_tensor * ffn_shexp = build_ffn(cur, + layer.ffn_up_shexp, NULL, NULL, + layer.ffn_gate_shexp, NULL, NULL, + layer.ffn_down_shexp, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + } else if (layer.ffn_gate) { + // Dense FFN layer + cur = build_ffn(cur, layer.ffn_up, NULL, NULL, layer.ffn_gate, NULL, NULL, + layer.ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // No FFN - this should not happen in Kimi + GGML_ABORT("Kimi layer missing FFN tensors"); + } + + // Residual + cur = ggml_add(ctx0, cur, ffn_inp); + inpL = cur; + } + + // Final Norm + cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + // Output + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + + GGML_UNUSED(n_embd_head_qk_nope); +} diff --git a/src/models/models.h b/src/models/models.h index 7ba225b4784..93afc96c96e 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -283,6 +283,12 @@ struct llm_build_jamba : public llm_graph_context_mamba { llm_build_jamba(const llama_model & model, const llm_graph_params & 
params); }; +struct llm_build_kimi_linear : public llm_graph_context_mamba { + llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params); +private: + const llama_model & model; +}; + struct llm_build_lfm2 : public llm_graph_context { const llama_model & model;