From 1226c389096296da80209b6a644c4b1eb5399cbc Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Sun, 26 Apr 2026 02:48:57 -0400 Subject: [PATCH 1/2] Add Kimi-K2.5 multimodal contrib (1T MoE + MoonViT vision) Kimi-K2.5 (moonshotai/Kimi-K2.5) on trn2.48xlarge via NxDI. Extends the K2 text decoder (PR #131) with MoonViT-400M vision encoder fusion. Configuration: TP=64, EP=1, LNC=2, seq_len=512, FP8 per-channel quantized. Performance: 46.6 tok/s TKG (21.4 ms TPOT), 2.1s CTE with vision. Tested on Neuron SDK 2.29 (DLAMI 20260410). Includes: - K2 text decoder (MLA attention, 384-expert MoE, shared experts) - K2.5 checkpoint loader (INT4 dequant -> BF16 -> FP8 per-channel) - Vision fusion via scatter_by_index_put (Llama4 pattern) - K25ImageToTextModelWrapper with non-trivial tracing inputs - MoonViT encoder with real-number 2D RoPE - 5 integration tests (smoke, multimodal gen, vision A/B, coherence, TPOT) --- contrib/models/Kimi-K2.5/README.md | 335 ++++ contrib/models/Kimi-K2.5/src/__init__.py | 50 + .../models/Kimi-K2.5/src/modeling_kimi_k2.py | 1548 +++++++++++++++++ .../models/Kimi-K2.5/src/modeling_kimi_k25.py | 1041 +++++++++++ contrib/models/Kimi-K2.5/src/moonvit.py | 397 +++++ contrib/models/Kimi-K2.5/test/__init__.py | 0 .../Kimi-K2.5/test/integration/__init__.py | 0 .../Kimi-K2.5/test/integration/test_model.py | 500 ++++++ .../models/Kimi-K2.5/test/unit/__init__.py | 0 9 files changed, 3871 insertions(+) create mode 100644 contrib/models/Kimi-K2.5/README.md create mode 100644 contrib/models/Kimi-K2.5/src/__init__.py create mode 100644 contrib/models/Kimi-K2.5/src/modeling_kimi_k2.py create mode 100644 contrib/models/Kimi-K2.5/src/modeling_kimi_k25.py create mode 100644 contrib/models/Kimi-K2.5/src/moonvit.py create mode 100644 contrib/models/Kimi-K2.5/test/__init__.py create mode 100644 contrib/models/Kimi-K2.5/test/integration/__init__.py create mode 100644 contrib/models/Kimi-K2.5/test/integration/test_model.py create mode 100644 contrib/models/Kimi-K2.5/test/unit/__init__.py diff --git a/contrib/models/Kimi-K2.5/README.md b/contrib/models/Kimi-K2.5/README.md new file mode 100644 index 00000000..fbbcffe0 --- /dev/null +++ b/contrib/models/Kimi-K2.5/README.md @@ -0,0 +1,335 @@ +# Contrib Model: Kimi-K2.5 (Multimodal) + +NeuronX Distributed Inference implementation of Moonshot AI's Kimi-K2.5 — a native multimodal agentic model with MoonViT vision encoder. 
+ +## Model Information + +- **HuggingFace ID:** `moonshotai/Kimi-K2.5` +- **Model Type:** Multimodal (image + text) Mixture of Experts decoder +- **Architecture:** Kimi-K2 text decoder + MoonViT-400M vision encoder +- **License:** Check HuggingFace model card + +## Architecture Details + +### Text Decoder (same as Kimi-K2) + +| Parameter | Value | +|-----------|-------| +| Total parameters | ~1,017B | +| Active parameters per token | ~32B | +| Hidden size | 7168 | +| Attention heads | 128 | +| Layers | 61 | +| Vocabulary size | 163840 | +| Routed experts | 384 (8 active per token) | +| Shared experts | 1 per MoE layer | +| Dense layers | 1 (layer 0) | +| Attention type | Multi-Latent Attention (MLA) | +| Router activation | Sigmoid with `e_score_correction_bias` | + +### MoonViT Vision Encoder + +| Parameter | Value | +|-----------|-------| +| Architecture | ViT with 2D RoPE | +| Layers | 27 | +| Hidden size | 1152 | +| Attention heads | 16 | +| MLP hidden | 4304 | +| Parameters | ~400M (466M with projector) | +| Patch size | 14×14 | +| Patch merging | 2×2 (4 patches → 1 token) | +| Projector | PatchMergerMLP → 7168 | + +### Vision-Text Fusion + +- **Method:** scatter_by_index_put (Llama4/Pixtral pattern) +- **Mechanism:** Vision embeddings replace text embeddings at placeholder positions via `index_put_` +- **Integration point:** `NeuronBaseModel.get_model_output()` → `encode_vision_to_input()` + +### K2.5 Weight Format + +K2.5 uses a different weight format than K2: +- **Expert weights:** INT4 compressed-tensors (pack-quantized, group_size=32, symmetric) + → Dequantized to BF16 → Re-quantized to FP8 per-channel for Neuron +- **Non-expert weights:** BF16 (attention, shared experts, norms, embeddings, lm_head) +- **Key prefix:** `language_model.model.` (stripped to match K2 format) +- **Vision keys:** `vision_tower.*`, `mm_projector.*` (filtered for text-only model) + +## Validation Results + +**Validated:** 2026-04-25 (SDK 2.29) +**Configuration:** TP=64, EP=1, LNC=2, batch_size=1, seq_len=512, FP8 per-channel + +### Test Results + +| Test | Status | Result | +|------|--------|--------| +| Smoke Test | PASS | Model compiles and loads on trn2.48xlarge | +| Multimodal Generation | PASS | Generates coherent image description | +| Vision A/B Test | PASS | Real vision ≠ zero vision (max logit diff: 15.2) | +| Coherence | PASS | No repetition, natural text | +| Throughput | PASS | 45.9 tok/s at BS=1 (LNC=2) | + +### Performance Metrics + +| Metric | Value | +|--------|-------| +| **TKG throughput** | **45.9 tok/s** | +| **TPOT (per-token latency)** | 21.4 ms | +| **E2E throughput** | 26.3 tok/s (128 tokens) | +| **CTE latency** | 2,094 ms | +| **CTE vision overhead** | 1.1 ms (negligible) | +| **MoonViT latency** | 35.5 ms | +| **TTFT** | ~2,130 ms (CTE + MoonViT) | +| **Model load time** | ~79 min (weight sharding) | +| **Compile time** | ~10 min (CTE ~5 min, TKG ~5 min) | + +### Benchmark Details (10 iterations, 3 warmup) + +| Component | Mean | p50 | Std | +|-----------|------|-----|-----| +| MoonViT vision encoder | 35.5 ms | 35.5 ms | 0.1 ms | +| CTE (with vision) | 2,094.6 ms | 2,094.5 ms | 0.4 ms | +| CTE (text-only) | 2,093.5 ms | 2,093.5 ms | 0.3 ms | +| TKG per-token | 21.4 ms | 21.3 ms | 0.4 ms | +| End-to-end (CTE+TKG) | 4,863.2 ms | 4,864.7 ms | 12.1 ms | + +### Accuracy Validation (vs GPU reference) + +| Metric | Value | +|--------|-------| +| Vision encoder cosine similarity | 0.9995 | +| Token match rate (vs vLLM H100) | 3.1% (expected for 1T MoE) | +| Semantic quality | Both 
describe the same image correctly |
+
+Token-level divergence is expected: FP8 vs BF16 quantization, combined with 384-expert MoE routing, causes different expert selections that cascade through autoregressive generation. Both outputs are semantically equivalent and correctly describe the input image.
+
+## Usage
+
+```python
+import sys
+import torch
+from pathlib import Path
+
+# Add src to path
+sys.path.insert(0, str(Path("src")))
+
+from modeling_kimi_k2 import NeuronKimiK2ForCausalLM, NeuronKimiK2Model
+from modeling_kimi_k25 import (
+    apply_k25_patches,
+    apply_k25_checkpoint_patch,
+    build_k25_config,
+    create_text_only_model_dir,
+    BOS_TOKEN_ID, IM_USER_TOKEN_ID, IM_END_TOKEN_ID,
+    IM_ASSISTANT_TOKEN_ID, MEDIA_PLACEHOLDER_TOKEN_ID,
+)
+
+model_path = "/path/to/Kimi-K2.5"
+text_model_dir = "/path/to/Kimi-K2.5-text"
+compiled_path = "/path/to/compiled"
+vision_emb_path = "/path/to/moonvit_448_real_embeddings.pt"
+
+# 1. Create text-only model directory
+create_text_only_model_dir(model_path, text_model_dir)
+
+# 2. Apply patches BEFORE model init
+apply_k25_patches(NeuronKimiK2ForCausalLM, NeuronKimiK2Model, ep_degree=1)
+
+# 3. Build config
+config = build_k25_config(text_model_dir, tp_degree=64, ep_degree=1, lnc=2)
+
+# 4. Initialize, patch, compile, load
+model = NeuronKimiK2ForCausalLM(text_model_dir, config=config)
+apply_k25_checkpoint_patch(model)
+model.compile(compiled_path)  # ~10 min
+model.load(compiled_path)  # ~79 min
+
+# 5. Load pre-computed vision embeddings
+vision_emb = torch.load(vision_emb_path, map_location="cpu").to(torch.bfloat16)
+
+# 6. Build multimodal prompt
+from transformers import AutoTokenizer
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+n_vision = vision_emb.shape[0]  # 256
+text_ids = tokenizer.encode("Describe this image in detail.")
+input_ids_list = (
+    [BOS_TOKEN_ID, IM_USER_TOKEN_ID]
+    + [MEDIA_PLACEHOLDER_TOKEN_ID] * n_vision
+    + text_ids
+    + [IM_END_TOKEN_ID, IM_ASSISTANT_TOKEN_ID]
+)
+
+# 7. Prepare vision tensors (must be [1, seq_len, ...])
+seq_len = 512
+ve = torch.zeros(1, seq_len, 7168, dtype=torch.bfloat16)
+vm = torch.full((1, seq_len, 1), fill_value=seq_len - 1, dtype=torch.int32)
+for i in range(n_vision):
+    ve[0, i] = vision_emb[i]
+    vm[0, i, 0] = i + 2  # positions after BOS + im_user
+
+# 8. Run inference (see test/integration/test_model.py for full generation loop)
+```
+
+**Important:** Run with environment variables:
+```bash
+NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 python your_script.py
+```
+
+## Pre-computing MoonViT Embeddings
+
+All 64 Neuron cores are consumed by the text decoder (TP=64). 
MoonViT must be run **before** loading the text decoder: + +```python +import torch +import torch_neuronx +from moonvit import NeuronMoonViTWrapper, load_vision_weights, precompute_rope_real +from modeling_kimi_k25 import preprocess_image, precompute_rope_tables +from PIL import Image + +# Create and trace MoonViT +model = NeuronMoonViTWrapper(patch_h=32, patch_w=32) # 448x448 +model = load_vision_weights(model, "/path/to/Kimi-K2.5", 32, 32) +model = model.to(torch.bfloat16).eval() + +# Precompute RoPE +cos_table, sin_table = precompute_rope_real(72, 512, 512) +rope_cos = cos_table[:32, :32].reshape(-1, 36).to(torch.bfloat16) +rope_sin = sin_table[:32, :32].reshape(-1, 36).to(torch.bfloat16) + +# Preprocess image +image = Image.open("test_image.jpg") +pixel_values, grid_thw, n_merged = preprocess_image(image, 448) + +# Trace on Neuron +model_neuron = torch_neuronx.trace( + model, (pixel_values, rope_cos, rope_sin), + compiler_args=["--model-type", "transformer", "--auto-cast", "none"], +) +torch.jit.save(model_neuron, "moonvit_448.pt") + +# Pre-compute embeddings +with torch.no_grad(): + vision_output = model_neuron(pixel_values, rope_cos, rope_sin) +torch.save(vision_output.to(torch.bfloat16), "moonvit_448_real_embeddings.pt") +``` + +## Compatibility Matrix + +| Instance / SDK Version | 2.29 | 2.28 | 2.27 and earlier | +|------------------------|------|------|------------------| +| trn2.48xlarge (LNC=2, TP=64, EP=1) | **Working (45.9 tok/s)** | Not tested | Not tested | +| trn2.48xlarge (LNC=2, TP=32, EP=2) | Not recommended* | Not tested | Not tested | +| trn2.3xlarge | Not supported (needs TP=64) | Not supported | Not supported | +| inf2 | Not supported | Not supported | Not supported | + +\*EP=2 has known blockwise CTE kernel regression in SDK 2.29 (see K2 contrib notes). + +## Testing + +Run integration tests on a trn2.48xlarge: + +```bash +# Activate Neuron venv (SDK 2.29) +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +pip install tiktoken # Required for K2.5 tokenizer + +# Run tests +NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ + pytest test/integration/test_model.py -v --capture=tee-sys +``` + +Or run standalone: + +```bash +NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ + python test/integration/test_model.py +``` + +**Note:** Compilation takes ~10 min, model loading takes ~79 min (dominated by weight sharding across 64 ranks). The first run will compile NEFFs; subsequent runs reuse cached NEFFs. + +## Prerequisites + +1. **Model weights:** Download from HuggingFace (~555 GB): + ```bash + huggingface-cli download moonshotai/Kimi-K2.5 \ + --local-dir /mnt/nvme/models/Kimi-K2.5 + ``` + +2. **Pre-computed vision embeddings:** Trace MoonViT and pre-compute embeddings before loading the text decoder (see "Pre-computing MoonViT Embeddings" above). + +3. **Storage:** At least 600 GB for model weights + 50 GB for compiled NEFFs. NVMe RAID recommended for faster loading. + +4. **Host RAM:** At least 2 TB (safetensors mmap can use significant virtual memory during weight sharding). + +5. 
**tiktoken package:** Required for K2.5 tokenizer:
+   ```bash
+   pip install tiktoken
+   ```
+
+## Key Implementation Details
+
+### Vision Embedding Fusion (`encode_vision_to_input`)
+
+Uses `scatter_by_index_put` following the Llama4/Pixtral pattern (a toy sketch appears in the appendix at the end of this README):
+- `vision_embeddings`: `[BS, seq_len, 7168]` — real vision data packed at front, zeros padding rest
+- `vision_mask`: `[BS, seq_len, 1]` — integer position indices, `fill_value=seq_len-1` for padding
+- The function clones `inputs_embeds` and uses `index_put_` to scatter vision embeddings at their target positions
+
+### CRITICAL: `pad_inputs()` Silent Replacement
+
+`ModelWrapper.pad_inputs()` (model_wrapper.py:791-809) silently replaces vision tensors with dummy zeros when their sequence dimension doesn't match the padded sequence length. Vision tensors **must** be provided at `[BS, seq_len, ...]` size to avoid being replaced.
+
+### ImageToTextModelWrapper Tracing
+
+The standard `ImageToTextModelWrapper` provides zero-filled vision inputs for tracing, which the Neuron XLA compiler may optimize away. `K25ImageToTextModelWrapper` overrides `input_generator()` to use ones-like inputs, matching NxDI's proven `test_scatter.py` pattern.
+
+### MoonViT on Neuron
+
+MoonViT uses real-number decomposition of 2D complex RoPE and eager attention (no flash_attn) for Neuron compatibility. The 400M-parameter encoder processes a 448×448 image in 35.5 ms on a single Neuron core.
+
+## Example Checkpoints
+
+* [moonshotai/Kimi-K2.5](https://huggingface.co/moonshotai/Kimi-K2.5)
+
+## Known Limitations
+
+- **Pre-computed vision required:** All 64 Neuron cores are used by the text decoder. MoonViT cannot be run once the text decoder is loaded. Pre-compute embeddings first.
+
+- **On-device sampling (ODS):** Disabled. The model returns logits `[BS, 1, 163840]`, not token indices, due to a known ODS compatibility issue.
+
+- **Single image per inference:** Fixed to one 448×448 image. Variable resolution requires retracing MoonViT.
+
+- **Batching:** BS=1 only. Same bandwidth-bound limitation as K2 (MoE expert weight loads dominate).
+
+- **seq_len=512:** Maximum 512 tokens total (256 vision + text + generation). Larger seq_len causes HBM OOM with TP=64 EP=1. See HBM Memory Bottleneck section.
+
+### HBM Memory Bottleneck
+
+| seq_len | TKG scratchpad | Total per HBM bank | Headroom (of 23.363 GB) |
+|---------|---------------|-------------------|------------------------|
+| 128 | ~3.0 GB | ~20.9 GB | ~2.5 GB |
+| 512 | ~4.1 GB | ~21.9 GB | ~1.4 GB |
+| 1024 | ~5.5 GB | ~23.4 GB | ~0 GB |
+
+## Relationship to Kimi-K2
+
+This is an extension of the Kimi-K2 text-only NxDI contrib (PR #131). Key differences:
+
+| Aspect | K2 | K2.5 |
+|--------|-----|------|
+| Modality | Text-only | Multimodal (image + text) |
+| Config | TP=32, EP=2 | TP=64, EP=1 |
+| Quantization | Blockwise FP8 (native) | INT4 → BF16 → FP8 per-channel |
+| Weight format | K2 safetensors | K2.5 compressed-tensors |
+| TKG throughput | 6.0 tok/s | 45.9 tok/s |
+| Vision encoder | N/A | MoonViT-400M (35.5 ms) |
+
+The 7.6x throughput improvement (6.0 → 45.9 tok/s) comes from TP=64 EP=1 (vs TP=32 EP=2), which eliminates inter-EP communication overhead and shards each expert's weights across all 64 cores, increasing the aggregate HBM bandwidth behind every expert load. 
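+
+## Appendix: Fusion and Dequantization Sketches
+
+The two sketches below are minimal, self-contained illustrations of mechanisms described above; they use toy shapes and hypothetical values and are **not** the NxDI implementation. The first mirrors the `scatter_by_index_put` fusion from Key Implementation Details: vision rows are written into a clone of the text embeddings at the positions carried by the integer mask. The real wrapper points padding rows at `seq_len - 1`; here only the real rows are scattered, for clarity.
+
+```python
+import torch
+
+# Toy sizes; the real model uses hidden=7168, seq_len=512, n_vision=256.
+bs, seq_len, hidden, n_vision = 1, 8, 4, 3
+
+inputs_embeds = torch.randn(bs, seq_len, hidden)  # text embeddings
+vision_embeds = torch.ones(n_vision, hidden)      # stand-in MoonViT output
+
+# Integer position indices; padding rows are filled with seq_len - 1.
+vision_mask = torch.full((bs, seq_len, 1), seq_len - 1, dtype=torch.long)
+for i in range(n_vision):
+    vision_mask[0, i, 0] = i + 2  # slots after BOS + im_user
+
+# Clone the text embeddings, then overwrite the placeholder rows in place.
+fused = inputs_embeds.clone()
+positions = vision_mask[0, :n_vision, 0]
+batch_idx = torch.zeros_like(positions)
+fused.index_put_((batch_idx, positions), vision_embeds)
+
+assert torch.equal(fused[0, 2:2 + n_vision], vision_embeds)
+```
+
+The second sketch shows the shape-level idea behind the INT4 → BF16 step of the K2.5 weight pipeline (symmetric group quantization, group_size=32): every group of 32 weights shares one scale, so dequantization is an element-wise multiply after broadcasting the scales. The bit-packing details of the actual compressed-tensors format are omitted.
+
+```python
+import torch
+
+group_size = 32
+rows, cols = 4, 128  # toy weight matrix
+
+# Hypothetical already-unpacked INT4 values in [-8, 7] and per-group scales.
+int4_vals = torch.randint(-8, 8, (rows, cols), dtype=torch.int8)
+scales = torch.rand(rows, cols // group_size, dtype=torch.float32)
+
+# Symmetric dequant: expand each group's scale across its 32 columns.
+expanded = scales.repeat_interleave(group_size, dim=1)
+w_bf16 = (int4_vals.to(torch.float32) * expanded).to(torch.bfloat16)
+```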
+ +## Maintainer + +Annapurna Labs + +**Last Updated:** 2026-04-25 diff --git a/contrib/models/Kimi-K2.5/src/__init__.py b/contrib/models/Kimi-K2.5/src/__init__.py new file mode 100644 index 00000000..0ed9588e --- /dev/null +++ b/contrib/models/Kimi-K2.5/src/__init__.py @@ -0,0 +1,50 @@ +# Kimi-K2.5 Multimodal on Neuron via NxDI +# +# Text decoder: reuses K2 model code (modeling_kimi_k2.py) +# Multimodal: K2.5 checkpoint loader, vision fusion, MoonViT (modeling_kimi_k25.py, moonvit.py) + +from .modeling_kimi_k2 import NeuronKimiK2ForCausalLM, KimiK2InferenceConfig +from .modeling_kimi_k25 import ( + apply_k25_patches, + apply_k25_checkpoint_patch, + build_k25_config, + create_text_only_model_dir, + preprocess_image, + precompute_rope_tables, + K25ImageToTextModelWrapper, + BOS_TOKEN_ID, + IM_USER_TOKEN_ID, + IM_END_TOKEN_ID, + IM_ASSISTANT_TOKEN_ID, + MEDIA_PLACEHOLDER_TOKEN_ID, +) +from .moonvit import ( + NeuronMoonViTWrapper, + load_vision_weights, + precompute_rope_real, +) + +__all__ = [ + # K2 text decoder + "NeuronKimiK2ForCausalLM", + "KimiK2InferenceConfig", + # K2.5 multimodal + "apply_k25_patches", + "apply_k25_checkpoint_patch", + "build_k25_config", + "create_text_only_model_dir", + "K25ImageToTextModelWrapper", + # Vision + "NeuronMoonViTWrapper", + "load_vision_weights", + # Preprocessing + "preprocess_image", + "precompute_rope_tables", + "precompute_rope_real", + # Token IDs + "BOS_TOKEN_ID", + "IM_USER_TOKEN_ID", + "IM_END_TOKEN_ID", + "IM_ASSISTANT_TOKEN_ID", + "MEDIA_PLACEHOLDER_TOKEN_ID", +] diff --git a/contrib/models/Kimi-K2.5/src/modeling_kimi_k2.py b/contrib/models/Kimi-K2.5/src/modeling_kimi_k2.py new file mode 100644 index 00000000..221a76c2 --- /dev/null +++ b/contrib/models/Kimi-K2.5/src/modeling_kimi_k2.py @@ -0,0 +1,1548 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Kimi-K2 (moonshotai/Kimi-K2-Instruct-0905) on Neuron via NxDI. 
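+#
+# NOTE: This text decoder is reused unchanged by the Kimi-K2.5 multimodal
+# contrib; modeling_kimi_k25.py layers the vision path on top of it.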
+# +# Architecture: DeepseekV3ForCausalLM variant with MLA attention + MoE +# - 1T total parameters, 32B active per token +# - 384 routed experts (8 active per token) + 1 shared expert +# - Multi-Latent Attention (MLA) with compressed KV cache +# - Sigmoid routing with e_score_correction_bias + normalized top-K +# - Blockwise FP8 quantization (e4m3, 128x128 blocks) +# - YaRN RoPE (factor=64, max_position_embeddings=262144) +# +# Supported configuration: +# - trn2.48xlarge: TP=64, EP=2, LNC=1 (128 logical cores) +# - Blockwise FP8 for routed expert weights +# - CPU greedy sampling (no on-device sampling) +# +# References: +# - NxDI DeepseekV3Attention: models/deepseek/modeling_deepseek.py +# - Qwen3-MoE reference: models/qwen3_moe/modeling_qwen3_moe.py + +import gc +import logging +import math +import os +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers.mappings import ( + gather_from_sequence_parallel_region, +) +from neuronx_distributed.utils import cpu_mode + +from neuronx_distributed_inference.models.config import InferenceConfig, MoENeuronConfig +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.utils import manual_softmax +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm +from neuronx_distributed_inference.modules.moe_v2 import initialize_moe_module + +from neuronx_distributed.modules.moe.routing import RouterTopK + +from neuronx_distributed_inference.models.deepseek.rope_util import ( + DeepseekV3YarnRotaryEmbedding, + apply_rotary_pos_emb, + yarn_get_mscale, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# RMSNorm +# --------------------------------------------------------------------------- + + +def get_rmsnorm_cls(): + return KimiK2RMSNorm if cpu_mode() else CustomRMSNorm + + +class KimiK2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# --------------------------------------------------------------------------- +# MoE initialization +# --------------------------------------------------------------------------- + + +def initialize_kimi_k2_moe_module(config: "KimiK2InferenceConfig"): + """Initialize MoE module with sigmoid routing + e_score_correction_bias. + + Uses standard RouterTopK with bias=True. The e_score_correction_bias is + loaded as the router linear layer's bias (pre-sigmoid), which is a slight + semantics change from the original post-sigmoid application. 
The biases are + small (~0.03-0.1) so the approximation is acceptable. + """ + from neuronx_distributed_inference.modules.moe_v2 import ( + initialize_moe_process_group, + ) + from neuronx_distributed.modules.moe.expert_mlps_v2 import ExpertMLPsV2 + from neuronx_distributed.modules.moe.model import MoE + from neuronx_distributed.modules.moe.moe_configs import RoutedExpertsMLPOpsConfig + + enabled_hybrid_sharding = config.neuron_config.hybrid_sharding_config is not None + ( + moe_tkg_tensor_model_parallel_group, + moe_tkg_expert_model_parallel_group, + moe_cte_tensor_model_parallel_group, + moe_cte_expert_model_parallel_group, + ) = initialize_moe_process_group(config, enabled_hybrid_sharding) + + router = RouterTopK( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + dtype=config.neuron_config.router_config.dtype, + act_fn=config.neuron_config.router_config.act_fn, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + sequence_dimension=1, + bias=True, + apply_act_fn_over_topk=False, + store_transposed_weights=False, + ) + + expert_mlps = ExpertMLPsV2( + routed_experts_mlp_config=RoutedExpertsMLPOpsConfig( + num_experts=config.num_local_experts, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + top_k=config.num_experts_per_tok, + hidden_act=config.hidden_act, + bias=False, + glu_mlp=config.neuron_config.glu_mlp, + glu_type=config.neuron_config.glu_type, + hidden_act_scaling_factor=config.neuron_config.hidden_act_scaling_factor, + hidden_act_bias=config.neuron_config.hidden_act_bias, + early_expert_affinity_modulation=config.neuron_config.early_expert_affinity_modulation, + normalize_top_k_affinities=config.neuron_config.normalize_top_k_affinities, + enable_spmd_rank=config.neuron_config.blockwise_matmul_config.parallelize_token_to_block_mapping, + ), + blockwise_matmul_config=config.neuron_config.blockwise_matmul_config, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + dtype=config.neuron_config.torch_dtype, + is_prefill=config.neuron_config.is_prefill_stage, + enabled_hybrid_sharding=enabled_hybrid_sharding, + tensor_model_parallel_group=parallel_state.get_tensor_model_parallel_group(), + expert_model_parallel_group=parallel_state.get_expert_model_parallel_group(), + cte_tensor_model_parallel_group=moe_cte_tensor_model_parallel_group, + cte_expert_model_parallel_group=moe_cte_expert_model_parallel_group, + tkg_tensor_model_parallel_group=moe_tkg_tensor_model_parallel_group, + tkg_expert_model_parallel_group=moe_tkg_expert_model_parallel_group, + ) + + moe = MoE( + router=router, + expert_mlps=expert_mlps, + shared_experts=None, + rmsnorm=None, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + return_expert_index=config.neuron_config.return_expert_index, + sequence_dimension=1, + init_tkg_module=False, + tkg_config=None, + ) + moe.eval() + return moe + + +# --------------------------------------------------------------------------- +# Shared Expert MLP (dense, always active) +# --------------------------------------------------------------------------- + + +class KimiK2SharedExpertMLP(nn.Module): + """Shared expert MLP (SwiGLU) for Kimi-K2. 
Always active, not routed.""" + + def __init__(self, config: "KimiK2InferenceConfig"): + super().__init__() + self.hidden_size = config.hidden_size + self.shared_intermediate_size = config.moe_intermediate_size + + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + self.shared_intermediate_size, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + self.shared_intermediate_size, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.down_proj = RowParallelLinear( + self.shared_intermediate_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + ) + self.act_fn = nn.SiLU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# --------------------------------------------------------------------------- +# Dense MLP (for first_k_dense_replace layers, i.e. layer 0) +# --------------------------------------------------------------------------- + + +class KimiK2DenseMLP(nn.Module): + """Standard SwiGLU MLP for the dense layers (first_k_dense_replace = 1).""" + + def __init__(self, config: "KimiK2InferenceConfig"): + super().__init__() + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.down_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + ) + self.act_fn = nn.SiLU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# --------------------------------------------------------------------------- +# MLA Attention (Multi-Latent Attention) +# --------------------------------------------------------------------------- + + +class KimiK2Attention(NeuronAttentionBase): + """ + Multi-Latent Attention (MLA) for Kimi-K2. + + KV cache format: stores (k_pe, compressed_kv) concatenated. 
+ - K cache: [batch, 1, seq, qk_rope_head_dim + kv_lora_rank] + - V cache: same (placeholder, only K is read during decode) + """ + + def __init__( + self, + config: "KimiK2InferenceConfig", + layer_idx: int, + tensor_model_parallel_group=None, + ): + super().__init__( + config=config, + tensor_model_parallel_group=tensor_model_parallel_group, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_attention_heads, + head_dim=config.v_head_dim, + num_cores_per_group=config.num_cores_per_group, + rms_norm_eps=config.rms_norm_eps, + ) + + self.layer_idx = layer_idx + self.bias = getattr(config, "attention_bias", False) + self.attention_dropout = config.attention_dropout + self.num_total_heads = config.num_attention_heads + + if cpu_mode(): + self.num_heads = self.num_total_heads + else: + self.num_heads = self.num_total_heads // config.neuron_config.tp_degree + + # MLA dimensions + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + self.head_dim = self.v_head_dim + + # MLA doesn't use fused QKV + self.qkv_proj = None + + # YaRN RoPE + self.rotary_emb = DeepseekV3YarnRotaryEmbedding( + dim=config.qk_rope_head_dim, + max_position_embeddings=config.max_position_embeddings, + scaling_factor=config.rope_scaling["factor"], + base=config.rope_theta, + mscale=config.rope_scaling["mscale"], + mscale_all_dim=config.rope_scaling["mscale_all_dim"], + beta_fast=config.rope_scaling["beta_fast"], + beta_slow=config.rope_scaling["beta_slow"], + ) + + # Softmax scale with mscale adjustment + self.softmax_scale = self.q_head_dim ** (-0.5) + if config.rope_scaling is not None: + mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + self.is_causal = True + self._init_mla_projections(config) + + def _init_mla_projections(self, config): + dtype = self.torch_dtype + tp_group = self.tensor_model_parallel_group + + # Query projection (LoRA: hidden -> q_lora_rank -> num_heads * q_head_dim) + if self.q_lora_rank is None: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_total_heads * self.q_head_dim, + bias=False, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, + config.q_lora_rank, + bias=config.attention_bias, + dtype=dtype, + ) + self.q_a_layernorm = get_rmsnorm_cls()(config.q_lora_rank) + self.q_b_proj = ColumnParallelLinear( + config.q_lora_rank, + self.num_total_heads * self.q_head_dim, + bias=False, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + + # KV projection (compressed: hidden -> kv_lora_rank + qk_rope_head_dim) + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + dtype=dtype, + ) + self.kv_a_layernorm = get_rmsnorm_cls()(config.kv_lora_rank) + + # kv_b_proj: decompresses latent -> per-head qk_nope + v + if tp_group is not None: + self.kv_b_proj = ColumnParallelLinear( + config.kv_lora_rank, + self.num_total_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + 
gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + else: + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_total_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + # Output projection + if tp_group is not None: + self.o_proj = RowParallelLinear( + self.num_attention_heads * self.head_dim, + self.hidden_size, + bias=self.bias, + input_is_parallel=True, + dtype=self.torch_dtype, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=self.sequence_dimension, + tensor_model_parallel_group=tp_group, + reduce_dtype=self.rpl_reduce_dtype, + ) + else: + self.o_proj = nn.Linear( + self.num_attention_heads * self.head_dim, + self.hidden_size, + bias=self.bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: torch.Tensor = None, + active_mask: Optional[torch.LongTensor] = None, + adapter_ids=None, + cos_cache: Optional[torch.Tensor] = None, + sin_cache: Optional[torch.Tensor] = None, + **kwargs, + ): + if ( + self.sequence_parallel_enabled + and self.tensor_model_parallel_group is not None + ): + hidden_states = gather_from_sequence_parallel_region( + hidden_states, + self.sequence_dimension, + process_group=self.tensor_model_parallel_group, + ) + + bsz, q_len, _ = hidden_states.size() + + # Weight absorption: precompute from kv_b_proj weights + wkv_b = self.kv_b_proj.weight + wkv_b = wkv_b.view(self.num_heads, -1, self.kv_lora_rank) + out_absorb = wkv_b[:, self.qk_nope_head_dim :, :] # V absorption + q_absorb = wkv_b[:, : self.qk_nope_head_dim, :] # Q nope absorption + + # Query projection + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + + # KV compression + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + q_nope, q_pe = torch.tensor_split(q, (self.qk_nope_head_dim,), dim=-1) + compressed_kv, k_pe = torch.tensor_split( + compressed_kv, (self.kv_lora_rank,), dim=-1 + ) + compressed_kv = self.kv_a_layernorm(compressed_kv) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + + # Q nope absorption + q_nope = torch.einsum("hdc,bhqd->bhqc", q_absorb, q_nope) + + # RoPE + seq_len = self.neuron_config.seq_len + if sin_cache is None and cos_cache is None: + cos_cache, sin_cache = self.rotary_emb(k_pe, seq_len) + q_pe = apply_rotary_pos_emb(q_pe, cos_cache, sin_cache, position_ids) + k_pe = apply_rotary_pos_emb(k_pe, cos_cache, sin_cache, position_ids) + + # Attention scores + active_scores = torch.matmul(q_pe, k_pe.transpose(2, 3)) + torch.einsum( + "bhqc,blc->bhql", q_nope, compressed_kv + ) + active_scores *= self.softmax_scale + + if past_key_value is None: + # Context encoding (prefill) + active_scores = torch.where( + attention_mask, + active_scores, + torch.finfo(active_scores.dtype).min, + ) + active_scores = nn.functional.softmax( + active_scores, dim=-1, dtype=torch.float32 + ).to(k_pe.dtype) + x = torch.einsum("bhql,blc->bhqc", active_scores, compressed_kv) + attn_output = torch.einsum("bhqc,hdc->bhqd", x, out_absorb) + else: + # Token generation (decode) -- split prior cache + cached_kv = past_key_value[0] + if cached_kv.dim() == 4: + cached_kv = cached_kv.squeeze(1) + k_pe_prior, compressed_kv_prior = torch.tensor_split( + cached_kv, + [self.qk_rope_head_dim], + dim=-1, + ) + 
k_pe_prior = k_pe_prior.reshape( + bsz, + 1, + compressed_kv_prior.shape[1], + self.qk_rope_head_dim, + ) + + prior_scores = torch.matmul( + q_pe, k_pe_prior.transpose(2, 3) + ) + torch.einsum("bhqc,blc->bhql", q_nope, compressed_kv_prior) + prior_scores *= self.softmax_scale + prior_scores = torch.where( + attention_mask, + prior_scores, + torch.finfo(prior_scores.dtype).min, + ) + prior_scores = prior_scores.to(torch.float32) + + softmax_prior, softmax_active = manual_softmax( + prior_scores, + active_scores, + is_speculation=False, + ) + softmax_prior = softmax_prior.to(k_pe.dtype) + softmax_active = softmax_active.to(k_pe.dtype) + + x = torch.einsum("bhql,blc->bhqc", softmax_active, compressed_kv) + attn_active = torch.einsum("bhqc,hdc->bhqd", x, out_absorb) + + x = torch.einsum("bhql,blc->bhqc", softmax_prior, compressed_kv_prior) + attn_prior = torch.einsum("bhqc,hdc->bhqd", x, out_absorb) + + attn_output = attn_prior + attn_active + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + attn_output = self.o_proj(attn_output) + + # KV cache: concatenate k_pe and compressed_kv + k_pe_flat = k_pe.squeeze(1) + kv_for_cache = torch.cat([k_pe_flat, compressed_kv], dim=-1) + kv_for_cache = kv_for_cache.unsqueeze(1) + + # Store same data in both K and V cache slots (only K is read during decode) + past_key_value = (kv_for_cache, kv_for_cache) + + return attn_output, past_key_value, cos_cache, sin_cache + + +# --------------------------------------------------------------------------- +# Decoder Layer +# --------------------------------------------------------------------------- + + +class KimiK2DecoderLayer(nn.Module): + """Decoder layer: MLA attention + (MoE or Dense MLP) + optional shared expert.""" + + def __init__(self, config: "KimiK2InferenceConfig", layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.self_attn = KimiK2Attention(config=config, layer_idx=layer_idx) + + # MLP: dense for layer 0 (first_k_dense_replace=1), MoE for the rest + self.is_moe_layer = layer_idx >= config.first_k_dense_replace + + if self.is_moe_layer: + self.mlp = initialize_kimi_k2_moe_module(config=config) + if config.n_shared_experts > 0: + self.shared_experts = KimiK2SharedExpertMLP(config) + else: + self.shared_experts = None + else: + self.mlp = KimiK2DenseMLP(config) + self.shared_experts = None + + # Routed expert scaling factor (DeepSeek-V3 / Kimi-K2 pattern) + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) + + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + cos_cache: Optional[torch.Tensor] = None, + sin_cache: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, ...]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + cos_cache=cos_cache, + sin_cache=sin_cache, + **kwargs, + ) + 
hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + if self.is_moe_layer: + moe_output = self.mlp(hidden_states, padding_mask)[0] + moe_output = moe_output * self.routed_scaling_factor + if self.shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + hidden_states = moe_output + shared_output + else: + hidden_states = moe_output + else: + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value, cos_cache, sin_cache, None) + return outputs + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + + +class KimiK2InferenceConfig(InferenceConfig): + """Inference config for Kimi-K2 (DeepSeek-V3 architecture).""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.num_local_experts = self.n_routed_experts # 384 + self.first_k_dense_replace = getattr(self, "first_k_dense_replace", 1) + self.n_shared_experts = getattr(self, "n_shared_experts", 1) + + # Router config: sigmoid with normalized top-k + self.neuron_config.router_config.dtype = torch.float32 + self.neuron_config.router_config.act_fn = "sigmoid" + self.neuron_config.normalize_top_k_affinities = True + + # GLU MLP for SwiGLU + self.neuron_config.glu_mlp = True + + def get_required_attributes(self) -> List[str]: + return [ + "attention_bias", + "hidden_act", + "hidden_size", + "intermediate_size", + "kv_lora_rank", + "max_position_embeddings", + "moe_intermediate_size", + "n_routed_experts", + "n_shared_experts", + "num_attention_heads", + "num_experts_per_tok", + "num_hidden_layers", + "num_key_value_heads", + "q_lora_rank", + "qk_nope_head_dim", + "qk_rope_head_dim", + "rms_norm_eps", + "rope_scaling", + "rope_theta", + "scoring_func", + "v_head_dim", + "vocab_size", + ] + + def add_derived_config(self): + self.num_cores_per_group = 1 + + @classmethod + def get_neuron_config_cls(cls) -> Type[MoENeuronConfig]: + return MoENeuronConfig + + +# --------------------------------------------------------------------------- +# Model +# --------------------------------------------------------------------------- + + +class NeuronKimiK2Model(NeuronBaseModel): + def setup_attr_for_model(self, config: KimiK2InferenceConfig): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + + # MLA KV cache: 1 "head" with dim = qk_rope_head_dim + kv_lora_rank + self.num_key_value_heads = 1 + config.head_dim = ( + config.qk_rope_head_dim + config.kv_lora_rank + ) # 64 + 512 = 576 + + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: KimiK2InferenceConfig): + self.padding_idx = getattr(config, "pad_token_id", None) + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + + self.layers = nn.ModuleList( + [ + KimiK2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + + self.lm_head = 
ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + ) + + +# --------------------------------------------------------------------------- +# Block FP8 Dequantization Utilities +# --------------------------------------------------------------------------- + +# FP8 E4M3 max representable value. +# PyTorch's e4m3fn has max=448 (no NaN encoding), but with +# --experimental-unsafe-fp8e4m3fn-as-fp8e4m3, exponent-15 values (>240) become NaN. +# Use 240.0 to ensure all quantized values stay in the e4m3-safe range. +_FP8_E4M3_MAX = 240.0 + + +def _dequant_block_fp8_to_fp32( + fp8_weight: Tensor, block_scales: Tensor, block_size: List[int] +) -> Tensor: + """Dequantize block-FP8 (e4m3, 128x128 blocks) weight to FP32.""" + se = block_scales.repeat_interleave(block_size[0], dim=0).repeat_interleave( + block_size[1], dim=1 + ) + if se.shape != fp8_weight.shape: + se = se[: fp8_weight.shape[0], : fp8_weight.shape[1]] + return fp8_weight.to(torch.float32) * se.to(torch.float32) + + +def _clamp_fp8_exponent15(fp8_weight: Tensor) -> Tensor: + """Clamp FP8 e4m3fn bytes that have exponent=15, which become NaN in e4m3. + + On Neuron hardware with --experimental-unsafe-fp8e4m3fn-as-fp8e4m3, + bytes 0x78-0x7E and 0xF8-0xFE become NaN. Clamp to max safe values. + """ + raw = fp8_weight.view(torch.uint8) + pos_exp15 = (raw >= 0x78) & (raw <= 0x7E) + neg_exp15 = (raw >= 0xF8) & (raw <= 0xFE) + clamped = raw.clone() + clamped[pos_exp15] = 0x77 # +240.0 + clamped[neg_exp15] = 0xF7 # -240.0 + return clamped.view(torch.float8_e4m3fn) + + +def _pack_experts_blockwise_fp8( + expert_fp8_weights: List[Tensor], + expert_block_scales: List[Tensor], + block_size: List[int], + tp_degree: int, + layout: str = "gate_up", +) -> Tuple[Tensor, Tensor]: + """Pack per-expert FP8 weights and block scales into fused tensors. + + Preserves original FP8 bytes (with exponent-15 clamping) and packs + block-wise scales into the matching layout. This avoids the lossy + FP8->FP32->FP8 re-quantization path. 
+ + For gate_up (ColumnParallel, stride=2): packs [gate_w.T, up_w.T] -> [E, H, 2*I] + For down (RowParallel, stride=1): packs [down_w.T] -> [E, I, H] + """ + n_experts = len(expert_fp8_weights) + bs0, bs1 = block_size + + if layout == "gate_up": + gate0, up0 = expert_fp8_weights[0] + I, H = gate0.shape + packed_w = torch.empty(n_experts, H, 2 * I, dtype=torch.float8_e4m3fn) + + gs0, us0 = expert_block_scales[0] + sI, sH = gs0.shape + raw_scale = torch.empty(n_experts, sH, 2 * sI, dtype=torch.float32) + + for e in range(n_experts): + g_fp8, u_fp8 = expert_fp8_weights[e] + g_scale, u_scale = expert_block_scales[e] + g_fp8 = _clamp_fp8_exponent15(g_fp8) + u_fp8 = _clamp_fp8_exponent15(u_fp8) + packed_w[e, :, :I] = g_fp8.T + packed_w[e, :, I:] = u_fp8.T + raw_scale[e, :, :sI] = g_scale.T + raw_scale[e, :, sI:] = u_scale.T + + out_dim = 2 * I + per_tp = out_dim // tp_degree + repeat_factor = bs1 // per_tp if per_tp < bs1 else 1 + if repeat_factor > 1: + expanded_scale = raw_scale.repeat_interleave(repeat_factor, dim=2) + else: + expanded_scale = raw_scale + + return packed_w, expanded_scale + + elif layout == "down": + d0 = expert_fp8_weights[0] + H_orig, I = d0.shape + packed_w = torch.empty(n_experts, I, H_orig, dtype=torch.float8_e4m3fn) + + ds0 = expert_block_scales[0] + sH, sI = ds0.shape + raw_scale = torch.empty(n_experts, sI, sH, dtype=torch.float32) + + for e in range(n_experts): + d_fp8 = expert_fp8_weights[e] + d_scale = expert_block_scales[e] + d_fp8 = _clamp_fp8_exponent15(d_fp8) + packed_w[e] = d_fp8.T + raw_scale[e] = d_scale.T + + in_dim = I + per_tp = in_dim // tp_degree + repeat_factor = bs0 // per_tp if per_tp < bs0 else 1 + if repeat_factor > 1: + expanded_scale = raw_scale.repeat_interleave(repeat_factor, dim=1) + else: + expanded_scale = raw_scale + + return packed_w, expanded_scale + + else: + raise ValueError(f"Unknown layout: {layout}") + + +# --------------------------------------------------------------------------- +# State Dict Conversion +# --------------------------------------------------------------------------- + + +def convert_kimi_k2_hf_to_neuron_state_dict( + neuron_state_dict: Dict[str, Any], + config: KimiK2InferenceConfig, +) -> Dict[str, Any]: + """Convert HuggingFace Kimi-K2 / DeepSeek-V3 weights to NxDI format. + + Key conversions: + 1. Dequantize block-FP8 weights to BF16 (if FP8 weights detected) + 2. Rename router weights (gate.weight -> router.linear_router.weight) + 3. Concatenate per-expert weights into packed tensors for ExpertMLPsV2 + 4. Handle shared expert weight naming + 5. Handle e_score_correction_bias loading as router bias + """ + # Check if weights are already pre-converted + has_packed_experts = any( + "expert_mlps.mlp_op.gate_up_proj.weight" in k for k in neuron_state_dict + ) + has_per_expert = any( + "mlp.experts.0.gate_proj.weight" in k for k in neuron_state_dict + ) + + if has_packed_experts and not has_per_expert: + logger.info("Weights already pre-converted. 
Adding rank_util only.") + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + for layer_idx in range(config.num_hidden_layers): + neuron_state_dict[f"layers.{layer_idx}.self_attn.rank_util.rank"] = ( + torch.arange(0, config.neuron_config.tp_degree, dtype=torch.int32) + ) + return neuron_state_dict + + # Dequantize FP8 weights + _maybe_dequantize_fp8(neuron_state_dict, config) + + # Add rank utility tensors + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + for layer_idx in range(config.num_hidden_layers): + neuron_state_dict[f"layers.{layer_idx}.self_attn.rank_util.rank"] = ( + torch.arange(0, config.neuron_config.tp_degree, dtype=torch.int32) + ) + + is_moe_layer = layer_idx >= config.first_k_dense_replace + if not is_moe_layer: + continue + + # Router weights + gate_key = f"layers.{layer_idx}.mlp.gate.weight" + if gate_key in neuron_state_dict: + neuron_state_dict[f"layers.{layer_idx}.mlp.router.linear_router.weight"] = ( + neuron_state_dict[gate_key].detach().clone() + ) + del neuron_state_dict[gate_key] + + # e_score_correction_bias + bias_key = f"layers.{layer_idx}.mlp.gate.e_score_correction_bias" + if bias_key in neuron_state_dict: + neuron_state_dict[ + f"layers.{layer_idx}.mlp.router.e_score_correction_bias" + ] = neuron_state_dict[bias_key].detach().clone() + del neuron_state_dict[bias_key] + + # Expert weights: per-expert -> packed format + expert_0_gate = f"layers.{layer_idx}.mlp.experts.0.gate_proj.weight" + if expert_0_gate not in neuron_state_dict: + continue + + intermediate_size, hidden_size = neuron_state_dict[expert_0_gate].shape + device = neuron_state_dict[expert_0_gate].device + dtype = neuron_state_dict[expert_0_gate].dtype + num_experts = config.n_routed_experts + + # Concatenate gate + up projections + gate_up_proj = torch.empty( + num_experts, + hidden_size, + 2 * intermediate_size, + dtype=dtype, + device=device, + ) + for e in range(num_experts): + gate_w = ( + neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight" + ] + .T.detach() + .clone() + ) + up_w = ( + neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"] + .T.detach() + .clone() + ) + gate_up_proj[e, :, :intermediate_size] = gate_w + gate_up_proj[e, :, intermediate_size:] = up_w + del neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight" + ] + del neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"] + + neuron_state_dict[ + f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + ] = gate_up_proj + + # Down projections + down_proj = torch.empty( + num_experts, + intermediate_size, + hidden_size, + dtype=dtype, + device=device, + ) + for e in range(num_experts): + down_w = ( + neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" + ] + .T.detach() + .clone() + ) + down_proj[e] = down_w + del neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" + ] + + neuron_state_dict[ + f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.down_proj.weight" + ] = down_proj + + # Shared expert rename + for proj in ["gate_proj", "up_proj", "down_proj"]: + hf_key = f"layers.{layer_idx}.mlp.shared_experts.{proj}.weight" + nxdi_key = f"layers.{layer_idx}.shared_experts.{proj}.weight" + if hf_key in neuron_state_dict: + neuron_state_dict[nxdi_key] = neuron_state_dict[hf_key] + del neuron_state_dict[hf_key] + + gc.collect() + + return neuron_state_dict + + +def 
_maybe_dequantize_fp8( + neuron_state_dict: Dict[str, Any], + config: "KimiK2InferenceConfig", +): + """Dequantize block-FP8 weights to BF16.""" + scale_layers = [] + + for layer_key in list(neuron_state_dict.keys()): + if "_scale_inv" in layer_key or "weight_scale_inv" in layer_key: + scales = neuron_state_dict[layer_key] + scale_layers.append(layer_key) + + if layer_key.endswith(".weight_scale_inv"): + fp8_layer_name = layer_key.replace(".weight_scale_inv", ".weight") + elif "_scale_inv" in layer_key: + fp8_layer_name = layer_key.replace("_scale_inv", "") + + if fp8_layer_name not in neuron_state_dict: + continue + + fp8_layer = neuron_state_dict[fp8_layer_name] + + if hasattr(config, "quantization_config") and config.quantization_config: + block_size = config.quantization_config.get( + "weight_block_size", [128, 128] + ) + else: + block_size = [128, 128] + + fp32_val = _dequant_block_fp8_to_fp32(fp8_layer, scales, block_size) + neuron_state_dict[fp8_layer_name] = fp32_val.to( + config.neuron_config.torch_dtype + ) + + for scale_layer in scale_layers: + del neuron_state_dict[scale_layer] + + +# --------------------------------------------------------------------------- +# Top-level CausalLM +# --------------------------------------------------------------------------- + + +class NeuronKimiK2ForCausalLM(NeuronBaseForCausalLM): + """Kimi-K2 for causal language modeling on Neuron.""" + + _model_cls = NeuronKimiK2Model + + @staticmethod + def load_hf_model(model_path: str, **kwargs): + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + **kwargs, + ) + + @classmethod + def get_config_cls(cls) -> Type[KimiK2InferenceConfig]: + return KimiK2InferenceConfig + + @staticmethod + def _apply_ep_scale_fix(): + """Monkey-patch ExpertFusedLinear._mark_expert_parallel_weights to skip + per-channel scale params that can't be EP-sharded (shape [1, 1, W]).""" + from neuronx_distributed.modules.moe.moe_parallel_layers import ( + ExpertFusedLinear, + ) + + _original_mark = ExpertFusedLinear._mark_expert_parallel_weights + + def _patched_mark( + self_inner, + iterable=None, + expert_parallel_group_size=None, + is_prefill=True, + expert_distribution=None, + ): + from neuronx_distributed.parallel_layers.parallel_state import ( + get_expert_model_parallel_size, + ) + + if expert_parallel_group_size is None: + expert_parallel_group_size = get_expert_model_parallel_size() + + if expert_parallel_group_size > 1: + if iterable is None: + params_to_mark = [] + for name, p in self_inner.named_parameters(): + if name == "scale" and p.shape[0] == 1: + continue + params_to_mark.append(p) + iterable = params_to_mark + + for p in iterable: + p.expert_model_parallel = True + if is_prefill: + p.is_prefill = True + p.expert_distribution = expert_distribution + + ExpertFusedLinear._mark_expert_parallel_weights = _patched_mark + + @staticmethod + def _apply_blockwise_scale_stride_fix(): + """Monkey-patch _setup_for_scale to use stride=1 for blockwise symmetric + scales, which avoids strided splitting failures when per-rank weight size + is smaller than block size.""" + from neuronx_distributed.quantization.quantization_layers import ( + BaseQuantizeParallelLinear, + ) + from neuronx_distributed.quantization.quantization_config import ( + QuantizationType, + ) + + _original_setup = BaseQuantizeParallelLinear._setup_for_scale + + def _patched_setup(self_inner, *args, **kwargs): + _original_setup(self_inner, *args, **kwargs) + if ( + 
hasattr(self_inner, "quantization_type") + and self_inner.quantization_type == QuantizationType.BLOCKWISE_SYMMETRIC + and hasattr(self_inner, "scale") + and hasattr(self_inner.scale, "partition_stride") + and self_inner.scale.partition_stride > 1 + ): + self_inner.scale.partition_stride = 1 + + BaseQuantizeParallelLinear._setup_for_scale = _patched_setup + + def load( + self, + compiled_model_path, + start_rank_id=None, + local_ranks_size=None, + skip_warmup=False, + ): + """Override to apply EP scale fix before loading.""" + if getattr(self.neuron_config, "quantized", False): + self._apply_ep_scale_fix() + self._apply_blockwise_scale_stride_fix() + return super().load( + compiled_model_path, + start_rank_id=start_rank_id, + local_ranks_size=local_ranks_size, + skip_warmup=skip_warmup, + ) + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, + config: KimiK2InferenceConfig, + ) -> dict: + return convert_kimi_k2_hf_to_neuron_state_dict(state_dict, config) + + def checkpoint_loader_fn(self, mmap: bool = False): + """Memory-efficient streaming checkpoint loader for 1T-parameter models. + + Loads safetensor shards one at a time, processes weights (FP8 handling, + expert packing, router renaming), and accumulates results to avoid OOM. + + When quantized=True (FP8 path): + - Routed expert weights stay in FP8 with blockwise scales + - All other weights are dequantized to BF16 + + When quantized=False (BF16 path): + - All weights are dequantized from FP8 to BF16 + """ + import json as json_mod + from safetensors.torch import load_file + + model_path = getattr(self.config, "_name_or_path", None) + if model_path is None or not os.path.exists(str(model_path)): + model_path = self.model_path + + index_path = os.path.join(model_path, "model.safetensors.index.json") + + if not os.path.exists(index_path): + return super().checkpoint_loader_fn(mmap=mmap) + + with open(index_path, "r") as f: + index = json_mod.load(f) + + weight_map = index["weight_map"] + shard_files = sorted(set(weight_map.values())) + + quant_config = getattr(self.config, "quantization_config", None) + if isinstance(quant_config, dict): + block_size = quant_config.get("weight_block_size", [128, 128]) + else: + block_size = [128, 128] + n_routed_experts = getattr(self.config, "n_routed_experts", 384) + first_k_dense_replace = getattr(self.config, "first_k_dense_replace", 1) + keep_experts_fp8 = getattr(self.config.neuron_config, "quantized", False) + num_layers = self.config.num_hidden_layers + + # Determine which shards are needed (supports reduced-layer testing) + needed_shards = set() + for key, shard_file in weight_map.items(): + clean_key = key[len("model.") :] if key.startswith("model.") else key + if "layers." in clean_key: + parts = clean_key.split(".") + idx = parts.index("layers") + 1 + layer_idx = int(parts[idx]) + if layer_idx < num_layers: + needed_shards.add(shard_file) + else: + needed_shards.add(shard_file) + + logger.info( + f"Streaming loader: {len(shard_files)} shards, {len(needed_shards)} needed, " + f"block_size={block_size}, experts={n_routed_experts}, fp8={keep_experts_fp8}" + ) + + result_dict = {} + for i, shard_file in enumerate(shard_files): + if shard_file not in needed_shards: + continue + shard_path = os.path.join(model_path, shard_file) + logger.info(f"Loading shard [{i + 1}/{len(shard_files)}]: {shard_file}") + + shard_data = load_file(shard_path) + + # Strip "model." 
prefix + for key in list(shard_data.keys()): + if key.startswith("model."): + shard_data[key[len("model.") :]] = shard_data.pop(key) + + # Filter out keys for layers beyond num_hidden_layers + for key in list(shard_data.keys()): + if "layers." in key: + parts = key.split(".") + idx = parts.index("layers") + 1 + layer_idx = int(parts[idx]) + if layer_idx >= num_layers: + del shard_data[key] + + # Determine layers in this shard + layer_ids = set() + for key in shard_data: + if "layers." in key: + parts = key.split(".") + idx = parts.index("layers") + 1 + layer_ids.add(int(parts[idx])) + + # Build expert weight/scale key mapping + expert_weight_keys = set() + expert_scale_keys = {} + if keep_experts_fp8: + for key in list(shard_data.keys()): + if ".mlp.experts." in key and ".weight" in key: + if ".shared_experts." not in key: + expert_weight_keys.add(key) + + for key in list(shard_data.keys()): + if "_scale_inv" in key or "weight_scale_inv" in key: + if key.endswith(".weight_scale_inv"): + wk = key.replace(".weight_scale_inv", ".weight") + elif "_scale_inv" in key: + wk = key.replace("_scale_inv", "") + else: + continue + if wk in expert_weight_keys: + expert_scale_keys[wk] = key + + # Process FP8 weights: dequant non-expert, keep expert raw + scale_keys = [ + k for k in shard_data if "weight_scale_inv" in k or "_scale_inv" in k + ] + for scale_key in scale_keys: + scales = shard_data[scale_key] + if scale_key.endswith(".weight_scale_inv"): + weight_key = scale_key.replace(".weight_scale_inv", ".weight") + elif "_scale_inv" in scale_key: + weight_key = scale_key.replace("_scale_inv", "") + else: + del shard_data[scale_key] + continue + + if weight_key not in shard_data: + del shard_data[scale_key] + continue + + if shard_data[weight_key].dtype != torch.float8_e4m3fn: + del shard_data[scale_key] + continue + + if weight_key in expert_weight_keys: + pass # Keep for expert packing + else: + fp8_w = shard_data[weight_key] + fp32_w = _dequant_block_fp8_to_fp32(fp8_w, scales, block_size) + shard_data[weight_key] = fp32_w.to(torch.bfloat16) + del fp8_w, fp32_w + del shard_data[scale_key] + + # Remove orphan scale keys + for k in [ + k + for k in shard_data + if ("_scale_inv" in k or "weight_scale_inv" in k) + and k not in expert_scale_keys.values() + ]: + del shard_data[k] + + # Cast non-FP8 tensors to BF16 + for key in list(shard_data.keys()): + t = shard_data[key] + if torch.is_floating_point(t) and t.dtype not in ( + torch.bfloat16, + torch.float8_e4m3fn, + torch.float32, + ): + if t.dtype != torch.int64 and t.dtype != torch.int32: + shard_data[key] = t.to(torch.bfloat16) + + # Pack experts and rename for MoE layers + for layer_idx in sorted(layer_ids): + prefix = f"layers.{layer_idx}" + if layer_idx >= first_k_dense_replace: + # Router rename + gate_key = f"{prefix}.mlp.gate.weight" + if gate_key in shard_data: + shard_data[f"{prefix}.mlp.router.linear_router.weight"] = ( + shard_data.pop(gate_key) + ) + bias_key = f"{prefix}.mlp.gate.e_score_correction_bias" + if bias_key in shard_data: + shard_data[f"{prefix}.mlp.router.linear_router.bias"] = ( + shard_data.pop(bias_key) + ) + + # Pack experts + e0_gate = f"{prefix}.mlp.experts.0.gate_proj.weight" + if e0_gate in shard_data: + isize, hsize = shard_data[e0_gate].shape + + if keep_experts_fp8: + gate_up_weights = [] + gate_up_scales = [] + down_weights = [] + down_scales = [] + + for e in range(n_routed_experts): + gk = f"{prefix}.mlp.experts.{e}.gate_proj.weight" + uk = f"{prefix}.mlp.experts.{e}.up_proj.weight" + dk = 
f"{prefix}.mlp.experts.{e}.down_proj.weight" + gsk = expert_scale_keys.get(gk) + usk = expert_scale_keys.get(uk) + dsk = expert_scale_keys.get(dk) + + g_fp8 = shard_data.pop(gk) if gk in shard_data else None + u_fp8 = shard_data.pop(uk) if uk in shard_data else None + g_scale = ( + shard_data.pop(gsk) + if gsk and gsk in shard_data + else None + ) + u_scale = ( + shard_data.pop(usk) + if usk and usk in shard_data + else None + ) + + if g_fp8 is not None and u_fp8 is not None: + gate_up_weights.append((g_fp8, u_fp8)) + gate_up_scales.append((g_scale, u_scale)) + + d_fp8 = shard_data.pop(dk) if dk in shard_data else None + d_scale = ( + shard_data.pop(dsk) + if dsk and dsk in shard_data + else None + ) + if d_fp8 is not None: + down_weights.append(d_fp8) + down_scales.append(d_scale) + + if gate_up_weights: + gu_fp8, gu_scale = _pack_experts_blockwise_fp8( + gate_up_weights, + gate_up_scales, + block_size, + tp_degree=self.config.neuron_config.tp_degree, + layout="gate_up", + ) + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + ] = gu_fp8 + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.gate_up_proj.scale" + ] = gu_scale + del gate_up_weights, gate_up_scales + + if down_weights: + dn_fp8, dn_scale = _pack_experts_blockwise_fp8( + down_weights, + down_scales, + block_size, + tp_degree=self.config.neuron_config.tp_degree, + layout="down", + ) + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.down_proj.weight" + ] = dn_fp8 + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.down_proj.scale" + ] = dn_scale + del down_weights, down_scales + + else: + # BF16 path + dtype = shard_data[e0_gate].dtype + gate_up = torch.zeros( + n_routed_experts, + hsize, + 2 * isize, + dtype=dtype, + device="cpu", + ) + for e in range(n_routed_experts): + gk = f"{prefix}.mlp.experts.{e}.gate_proj.weight" + uk = f"{prefix}.mlp.experts.{e}.up_proj.weight" + if gk in shard_data: + gate_up[e, :, :isize] = shard_data.pop(gk).T + if uk in shard_data: + gate_up[e, :, isize:] = shard_data.pop(uk).T + + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + ] = gate_up + + down = torch.zeros( + n_routed_experts, + isize, + hsize, + dtype=dtype, + device="cpu", + ) + for e in range(n_routed_experts): + dk = f"{prefix}.mlp.experts.{e}.down_proj.weight" + if dk in shard_data: + down[e] = shard_data.pop(dk).T + + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.down_proj.weight" + ] = down + + # Clean up remaining per-expert keys + for e in range(n_routed_experts): + for proj in ["gate_proj", "up_proj", "down_proj"]: + for suffix in [".weight", ".weight_scale_inv"]: + k = f"{prefix}.mlp.experts.{e}.{proj}{suffix}" + if k in shard_data: + del shard_data[k] + + # Shared expert rename + for proj in ["gate_proj", "up_proj", "down_proj"]: + hf_key = f"{prefix}.mlp.shared_experts.{proj}.weight" + nxdi_key = f"{prefix}.shared_experts.{proj}.weight" + if hf_key in shard_data: + shard_data[nxdi_key] = shard_data.pop(hf_key) + + # Cast remaining float32 non-scale tensors to BF16 + for key in list(shard_data.keys()): + t = shard_data[key] + if ( + t.dtype == torch.float32 + and not key.endswith(".scale") + and not key.endswith("_scale_inv") + and not key.endswith("linear_router.bias") + ): + shard_data[key] = t.to(torch.bfloat16) + + result_dict.update(shard_data) + del shard_data + gc.collect() + + # Add rank_util tensors + tp = self.config.neuron_config.tp_degree + result_dict["rank_util.rank"] = torch.arange(0, tp, dtype=torch.int32) + for layer_idx in range(self.config.num_hidden_layers): + 
result_dict[f"layers.{layer_idx}.self_attn.rank_util.rank"] = torch.arange( + 0, tp, dtype=torch.int32 + ) + + # Add fused prefix if needed + if self._FUSED_PREFIX != "": + for key in list(result_dict.keys()): + result_dict[f"{self._FUSED_PREFIX}.{key}"] = result_dict.pop(key) + + logger.info(f"Streaming loader done. Total keys: {len(result_dict)}") + return result_dict + + def enable_context_encoding(self): + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def enable_token_generation(self): + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + super().enable_token_generation() + + def get_compiler_args(self) -> str: + """Compiler args optimized for Kimi-K2 on trn2. + + Key flags: + - -O3 for TKG with EP > 1: avoids Modular flow perf degradation + - --internal-enable-dge-levels vector_dynamic_offsets: DGE optimization + - --enable-ccop-compute-overlap: CC overlap for MoE all-to-all + - --lnc: must match runtime NEURON_LOGICAL_NC_CONFIG + """ + lnc = getattr(self.neuron_config, "logical_nc_config", 1) + ep_degree = getattr(self.neuron_config, "moe_ep_degree", 1) + + if self.compile_tag == TOKEN_GENERATION_MODEL_TAG and ep_degree > 1: + optimization_level = "-O3" + else: + optimization_level = "-O1" + + compiler_args = ( + "--enable-saturate-infinity " + "--enable-mixed-precision-accumulation " + "--model-type transformer " + f"{optimization_level}" + ) + compiler_args += ( + " --tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2'" + ) + compiler_args += " --tensorizer-options='--vectorize-strided-dma'" + compiler_args += " --auto-cast=none" + compiler_args += " --internal-enable-dge-levels vector_dynamic_offsets" + compiler_args += f" --lnc={lnc}" + + hlo2tensorizer_opts = "--verify-hlo=true" + if ( + getattr(self.neuron_config, "quantized", False) + and getattr(self.neuron_config, "quantization_dtype", "") == "f8e4m3" + ): + hlo2tensorizer_opts += " --experimental-unsafe-fp8e4m3fn-as-fp8e4m3" + + compiler_args += f" --internal-hlo2tensorizer-options='{hlo2tensorizer_opts}'" + + return compiler_args diff --git a/contrib/models/Kimi-K2.5/src/modeling_kimi_k25.py b/contrib/models/Kimi-K2.5/src/modeling_kimi_k25.py new file mode 100644 index 00000000..e64920f0 --- /dev/null +++ b/contrib/models/Kimi-K2.5/src/modeling_kimi_k25.py @@ -0,0 +1,1041 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Kimi-K2.5 (moonshotai/Kimi-K2.5) multimodal on Neuron via NxDI. +# +# Architecture: Kimi-K2 text decoder + MoonViT-400M vision encoder +# - Text: 1T MoE (384 experts, 8 active) + MLA attention, 61 layers +# - Vision: MoonViT (27-layer ViT, hidden=1152) + PatchMergerMLP → 7168 +# - Fusion: scatter_by_index_put (Llama4/Pixtral pattern) +# - K2.5 weights: INT4 compressed-tensors (experts) + BF16 (non-experts) +# → dequantized to BF16 → FP8 per-channel quantized for Neuron +# +# Supported configuration: +# - trn2.48xlarge: TP=64, EP=1, LNC=2, seq_len=512, batch_size=1 +# - FP8 per-channel quantization for routed expert weights +# - CPU greedy sampling (no on-device sampling) +# - Pre-computed MoonViT embeddings (all cores used by text decoder) +# +# The text decoder reuses the K2 model code (modeling_kimi_k2.py) unchanged. +# This file adds: +# 1. K2.5 checkpoint loader (INT4 dequant, prefix stripping) +# 2. Vision embedding fusion (encode_vision_to_input) +# 3. ImageToTextModelWrapper with non-trivial tracing inputs +# 4. 
Forward/output overrides for 24-arg vision pipeline +# 5. MoonViT image preprocessing utilities +# +# References: +# - Kimi-K2 NxDI contrib (PR #131) +# - NxDI Llama4/Pixtral scatter_by_index_put pattern +# - NxDI ImageToTextModelWrapper / NeuronBaseForImageToText + +import gc +import json +import logging +import math +import os +import shutil +import types +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from neuronx_distributed_inference.models.config import ( + MoENeuronConfig, + RouterConfig, +) +from neuronx_distributed_inference.models.image_to_text_model_wrapper import ( + ImageToTextModelWrapper, +) +from neuronx_distributed_inference.modules.generation.sampling import ( + prepare_sampling_params, +) + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Constants +# ============================================================================ + +PATCH_SIZE = 14 +MERGE_KERNEL = 2 +DEFAULT_IMAGE_SIZE = 448 + +# K2.5 special token IDs +BOS_TOKEN_ID = 163584 +IM_USER_TOKEN_ID = 163587 +IM_END_TOKEN_ID = 163586 +IM_ASSISTANT_TOKEN_ID = 163588 +MEDIA_PLACEHOLDER_TOKEN_ID = 163605 + +# Image normalization +IMAGE_MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32) +IMAGE_STD_INV = np.array([2.0, 2.0, 2.0], dtype=np.float32) + +# FP8 E4M3 max representable value +_FP8_E4M3_MAX = 240.0 + +# K2.5 weight key prefixes +K25_PREFIX = "language_model.model." +K2_PREFIX = "language_model." +VISION_PREFIXES = ("vision_tower.", "mm_projector.", "multi_modal_projector.") + + +# ============================================================================ +# Image Preprocessing +# ============================================================================ + + +def preprocess_image(image, target_size=DEFAULT_IMAGE_SIZE): + """Preprocess a PIL image for MoonViT. + + Args: + image: PIL Image + target_size: Target size (square) + + Returns: + pixel_values: [N_patches, 3, 14, 14] bfloat16 + grid_thw: (1, h_patches, w_patches) + n_merged_tokens: Number of vision tokens after 2x2 merge + """ + image = image.convert("RGB") + image = image.resize((target_size, target_size)) + img_np = np.array(image, dtype=np.float32) + img_np = (img_np / 255.0 - 0.5) * 2.0 + img_np = img_np[np.newaxis, ...] + + T, H, W, C = img_np.shape + h_patches = H // PATCH_SIZE + w_patches = W // PATCH_SIZE + patches = img_np.reshape(T, h_patches, PATCH_SIZE, w_patches, PATCH_SIZE, C) + patches = patches.transpose(0, 1, 3, 5, 2, 4) + patches = patches.reshape(-1, C, PATCH_SIZE, PATCH_SIZE) + + pixel_values = torch.from_numpy(patches).to(torch.bfloat16) + grid_thw = (1, h_patches, w_patches) + n_merged_tokens = (h_patches // MERGE_KERNEL) * (w_patches // MERGE_KERNEL) + + return pixel_values, grid_thw, n_merged_tokens + + +def precompute_rope_tables(h_patches, w_patches, head_dim=72, theta=10000.0): + """Precompute 2D RoPE cos/sin tables for MoonViT. 
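+
+    The flat patch index n is split into 2D coordinates (y = n // w_patches,
+    x = n % w_patches); the first half of each table carries y-frequencies
+    and the second half x-frequencies.
+
+    Worked example for a 448x448 image, i.e. a 32x32 patch grid (shapes
+    shown for illustration only, not additional API surface):
+
+        cos, sin = precompute_rope_tables(32, 32, head_dim=72)
+        # 1024 flattened positions x 36 frequencies (head_dim // 2)
+        assert cos.shape == (1024, 36)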
+ + Returns: + cos_table: [h_patches * w_patches, head_dim // 2] bfloat16 + sin_table: [h_patches * w_patches, head_dim // 2] bfloat16 + """ + N = h_patches * w_patches + flat_pos = torch.arange(N, dtype=torch.float32) + x_pos = flat_pos % w_patches + y_pos = flat_pos // w_patches + + dim = head_dim + dim_range = torch.arange(0, dim, 4)[: dim // 4].float() + freqs = 1.0 / (theta ** (dim_range / dim)) + + x_freqs = torch.outer(x_pos, freqs) + y_freqs = torch.outer(y_pos, freqs) + + angles = torch.cat([y_freqs, x_freqs], dim=-1) + cos_table = torch.cos(angles) + sin_table = torch.sin(angles) + + return cos_table.to(torch.bfloat16), sin_table.to(torch.bfloat16) + + +# ============================================================================ +# K2.5 Checkpoint Loader +# ============================================================================ + + +def _dequant_int4_packed_symmetric(weight_packed, weight_scale, group_size=32): + """Dequantize INT4 symmetric pack-quantized weights to BF16. + + compressed-tensors 'pack-quantized' format: + - weight_packed: INT32 [out_features, in_features // 8] (8 INT4 per int32) + - weight_scale: FP16/BF16 [out_features, in_features // group_size] + - Interleaved column ordering, offset-binary sign convention + + Returns: BF16 [out_features, in_features] + """ + out_features = weight_packed.shape[0] + packed_cols = weight_packed.shape[1] + in_features = packed_cols * 8 + pack_factor = 8 + num_bits = 4 + mask = (1 << num_bits) - 1 + + unpacked = torch.zeros( + (out_features, in_features), device=weight_packed.device, dtype=torch.int32 + ) + for i in range(pack_factor): + unpacked[:, i::pack_factor] = (weight_packed >> (num_bits * i)) & mask + + unpacked = (unpacked - 8).to(torch.int8) + + unpacked_f32 = unpacked.to(torch.float32) + scale = weight_scale.to(torch.float32) + scale = ( + scale.unsqueeze(-1) + .expand(-1, -1, group_size) + .reshape(out_features, in_features) + ) + + result = unpacked_f32 * scale + return result.to(torch.bfloat16) + + +def _strip_k25_prefix(key): + """Strip K2.5 weight key prefix to K2-compatible format.""" + if key.startswith(K25_PREFIX): + return key[len(K25_PREFIX) :] + if key.startswith(K2_PREFIX): + return key[len(K2_PREFIX) :] + return key + + +def _is_vision_key(key): + """Check if a weight key belongs to the vision encoder or projector.""" + for prefix in VISION_PREFIXES: + if key.startswith(prefix): + return True + return False + + +def _requantize_per_channel_fp8(bf16_weight): + """Re-quantize BF16 to per-expert per-channel FP8 E4M3. + + Args: + bf16_weight: [E, H, W] bfloat16 + Returns: + (fp8_weight, scale): fp8_weight [E, H, W], scale [E, 1, W] float32 + """ + fp32_weight = bf16_weight.float() + amax = fp32_weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) + scale = amax / _FP8_E4M3_MAX + scaled = (fp32_weight / scale).clamp(-_FP8_E4M3_MAX, _FP8_E4M3_MAX) + fp8_weight = scaled.to(torch.float8_e4m3fn) + + # Clamp exponent-15 bytes + raw = fp8_weight.view(torch.uint8) + pos_exp15 = (raw >= 0x78) & (raw <= 0x7E) + neg_exp15 = (raw >= 0xF8) & (raw <= 0xFE) + raw = torch.where(pos_exp15, torch.tensor(0x77, dtype=torch.uint8), raw) + raw = torch.where(neg_exp15, torch.tensor(0xF7, dtype=torch.uint8), raw) + fp8_weight = raw.view(torch.float8_e4m3fn) + + return fp8_weight, scale.to(torch.float32) + + +def k25_checkpoint_loader_fn(model_self, mmap=False): + """K2.5-adapted checkpoint loader. + + Handles: + 1. 'language_model.model.' prefix stripping + 2. Vision key filtering + 3. 
INT4 compressed-tensors dequantization → BF16 + 4. Expert packing (gate+up concat, down transpose) + 5. Optional FP8 per-channel re-quantization + """ + from safetensors.torch import load_file + + model_path = getattr(model_self.config, "_name_or_path", None) + if model_path is None or not os.path.exists(str(model_path)): + model_path = model_self.model_path + + index_path = os.path.join(model_path, "model.safetensors.index.json") + if not os.path.exists(index_path): + return model_self.__class__.__bases__[0].checkpoint_loader_fn( + model_self, mmap=mmap + ) + + with open(index_path, "r") as f: + index = json.load(f) + + weight_map = index["weight_map"] + shard_files = sorted(set(weight_map.values())) + + quant_config = getattr(model_self.config, "quantization_config", None) + group_size = 32 + if isinstance(quant_config, dict): + config_groups = quant_config.get("config_groups", {}) + for group_cfg in config_groups.values(): + weights_cfg = group_cfg.get("weights", {}) + gs = weights_cfg.get("group_size", 32) + if gs: + group_size = gs + break + + n_routed_experts = getattr(model_self.config, "n_routed_experts", 384) + first_k_dense_replace = getattr(model_self.config, "first_k_dense_replace", 1) + num_layers = model_self.config.num_hidden_layers + keep_experts_fp8 = getattr(model_self.config.neuron_config, "quantized", False) + + # Determine needed shards + needed_shards = set() + for key, shard_file in weight_map.items(): + clean_key = _strip_k25_prefix(key) + if _is_vision_key(clean_key) or _is_vision_key(key): + continue + if "layers." in clean_key: + parts = clean_key.split(".") + idx = parts.index("layers") + 1 + layer_idx = int(parts[idx]) + if layer_idx < num_layers: + needed_shards.add(shard_file) + else: + needed_shards.add(shard_file) + + logger.info( + f"K2.5 loader: {len(shard_files)} shards, {len(needed_shards)} needed, " + f"num_layers={num_layers}, group_size={group_size}, " + f"experts={n_routed_experts}, fp8={keep_experts_fp8}" + ) + + result_dict = {} + for i, shard_file in enumerate(shard_files): + if shard_file not in needed_shards: + continue + + shard_path = os.path.join(model_path, shard_file) + logger.info(f"Loading shard [{i + 1}/{len(shard_files)}]: {shard_file}") + shard_data = load_file(shard_path) + + # Strip K2.5 prefix and filter vision keys + for key in list(shard_data.keys()): + clean_key = _strip_k25_prefix(key) + if _is_vision_key(clean_key) or _is_vision_key(key): + del shard_data[key] + continue + if clean_key != key: + shard_data[clean_key] = shard_data.pop(key) + + # Filter layers beyond num_hidden_layers + for key in list(shard_data.keys()): + if "layers." 
in key: + parts = key.split(".") + idx = parts.index("layers") + 1 + layer_idx = int(parts[idx]) + if layer_idx >= num_layers: + del shard_data[key] + + # Dequantize INT4 packed expert weights + packed_keys = [k for k in shard_data if k.endswith(".weight_packed")] + shape_keys = [k for k in shard_data if k.endswith(".weight_shape")] + zp_keys = [k for k in shard_data if k.endswith(".weight_zero_point")] + + for packed_key in packed_keys: + scale_key = packed_key.replace(".weight_packed", ".weight_scale") + weight_key = packed_key.replace(".weight_packed", ".weight") + + packed = shard_data[packed_key] + scale = shard_data.get(scale_key) + + if scale is None: + logger.warning(f"No scale for {packed_key}, skipping") + continue + + if packed.dtype in (torch.int32, torch.int16, torch.int8, torch.uint8): + dequant = _dequant_int4_packed_symmetric(packed, scale, group_size) + shard_data[weight_key] = dequant + del packed, dequant + else: + shard_data[weight_key] = shard_data[packed_key] + + del shard_data[packed_key] + if scale_key in shard_data: + del shard_data[scale_key] + + for k in shape_keys: + shard_data.pop(k, None) + for k in zp_keys: + shard_data.pop(k, None) + + # Cast non-BF16 float tensors + for key in list(shard_data.keys()): + t = shard_data[key] + if torch.is_floating_point(t) and t.dtype not in ( + torch.bfloat16, + torch.float32, + ): + if t.dtype != torch.int64 and t.dtype != torch.int32: + shard_data[key] = t.to(torch.bfloat16) + + # Determine layers in this shard + layer_ids = set() + for key in shard_data: + if "layers." in key: + parts = key.split(".") + idx = parts.index("layers") + 1 + layer_ids.add(int(parts[idx])) + + # Pack experts and rename for MoE layers + for layer_idx in sorted(layer_ids): + prefix = f"layers.{layer_idx}" + if layer_idx >= first_k_dense_replace: + # Router rename + gate_key = f"{prefix}.mlp.gate.weight" + if gate_key in shard_data: + shard_data[f"{prefix}.mlp.router.linear_router.weight"] = ( + shard_data.pop(gate_key) + ) + bias_key = f"{prefix}.mlp.gate.e_score_correction_bias" + if bias_key in shard_data: + shard_data[f"{prefix}.mlp.router.linear_router.bias"] = ( + shard_data.pop(bias_key) + ) + + # Pack experts + e0_gate = f"{prefix}.mlp.experts.0.gate_proj.weight" + if e0_gate in shard_data: + isize, hsize = shard_data[e0_gate].shape + dtype = shard_data[e0_gate].dtype + + gate_up = torch.zeros( + n_routed_experts, hsize, 2 * isize, dtype=dtype, device="cpu" + ) + for e in range(n_routed_experts): + gk = f"{prefix}.mlp.experts.{e}.gate_proj.weight" + uk = f"{prefix}.mlp.experts.{e}.up_proj.weight" + if gk in shard_data: + gate_up[e, :, :isize] = shard_data.pop(gk).T + if uk in shard_data: + gate_up[e, :, isize:] = shard_data.pop(uk).T + + down = torch.zeros( + n_routed_experts, isize, hsize, dtype=dtype, device="cpu" + ) + for e in range(n_routed_experts): + dk = f"{prefix}.mlp.experts.{e}.down_proj.weight" + if dk in shard_data: + down[e] = shard_data.pop(dk).T + + if keep_experts_fp8: + gu_fp8, gu_scale = _requantize_per_channel_fp8(gate_up) + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + ] = gu_fp8 + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.gate_up_proj.scale" + ] = gu_scale + dn_fp8, dn_scale = _requantize_per_channel_fp8(down) + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.down_proj.weight" + ] = dn_fp8 + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.down_proj.scale" + ] = dn_scale + del gate_up, down + else: + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + ] = gate_up + 
shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.down_proj.weight" + ] = down + + # Clean up per-expert keys + for e in range(n_routed_experts): + for proj in ["gate_proj", "up_proj", "down_proj"]: + for suffix in [ + ".weight", + ".weight_scale", + ".weight_shape", + ".weight_zero_point", + ]: + shard_data.pop( + f"{prefix}.mlp.experts.{e}.{proj}{suffix}", None + ) + + # Shared expert rename + for proj in ["gate_proj", "up_proj", "down_proj"]: + hf_key = f"{prefix}.mlp.shared_experts.{proj}.weight" + nxdi_key = f"{prefix}.shared_experts.{proj}.weight" + if hf_key in shard_data: + shard_data[nxdi_key] = shard_data.pop(hf_key) + + # Cast float32 to BF16 (except scales and router bias) + for key in list(shard_data.keys()): + t = shard_data[key] + if ( + t.dtype == torch.float32 + and not key.endswith(".scale") + and not key.endswith("linear_router.bias") + ): + shard_data[key] = t.to(torch.bfloat16) + + result_dict.update(shard_data) + del shard_data + gc.collect() + + # Add rank_util tensors + tp = model_self.config.neuron_config.tp_degree + result_dict["rank_util.rank"] = torch.arange(0, tp, dtype=torch.int32) + for layer_idx in range(num_layers): + result_dict[f"layers.{layer_idx}.self_attn.rank_util.rank"] = torch.arange( + 0, tp, dtype=torch.int32 + ) + + # Add fused prefix if needed + if model_self._FUSED_PREFIX != "": + for key in list(result_dict.keys()): + result_dict[f"{model_self._FUSED_PREFIX}.{key}"] = result_dict.pop(key) + + logger.info(f"K2.5 loader done. Total keys: {len(result_dict)}") + return result_dict + + +# ============================================================================ +# Vision Embedding Fusion +# ============================================================================ + + +def patch_encode_vision_to_input(NeuronKimiK2Model): + """Add encode_vision_to_input() to NeuronKimiK2Model. + + Called by NeuronBaseModel.get_model_output() during context encoding + when vision_embeddings and vision_mask are non-None. + + Uses scatter_by_index_put (Llama4/Pixtral pattern): replaces text + embeddings at vision token positions with projected vision embeddings. + + Args: + NeuronKimiK2Model: The text decoder model class to patch. + """ + + def _encode_vision_to_input(self, inputs_embeds, vision_embeddings, vision_mask): + """Merge vision into text embeddings via index_put_. + + Args: + inputs_embeds: [BS, n_active, hidden_size] + vision_embeddings: [BS, n_active, hidden_size] — packed at front + vision_mask: [BS, n_active, 1] — integer position indices, + fill_value=(n_active - 1) for padding + Returns: + merged_embeds: [BS, n_active, hidden_size] + """ + _, max_positions, embedding_dim = inputs_embeds.shape + h = inputs_embeds.clone() + flat_ve = vision_embeddings.view(-1, embedding_dim) + positions = vision_mask.view(-1) + num_positions = len(positions) + flat_ve = flat_ve[:num_positions] + h.view(-1, embedding_dim).index_put_((positions,), flat_ve, accumulate=False) + return h + + NeuronKimiK2Model.encode_vision_to_input = _encode_vision_to_input + + +# ============================================================================ +# ImageToTextModelWrapper with non-trivial tracing inputs +# ============================================================================ + + +class K25ImageToTextModelWrapper(ImageToTextModelWrapper): + """Custom wrapper with non-trivial (ones-like) vision tracing inputs. 
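+
+    Toy illustration of the zero-scatter hazard (shapes hypothetical, not
+    the real traced graph): scattering zeros into zero-initialized
+    embeddings is observationally a no-op,
+
+        h = torch.zeros(1, 4, 8)
+        h.view(-1, 8).index_put_((torch.zeros(4, dtype=torch.long),),
+                                 torch.zeros(4, 8))  # h is unchanged
+
+    so an all-zero tracing input gives the compiler license to eliminate
+    the fusion path entirely, as described below.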
+ + The standard ImageToTextModelWrapper provides zero-filled vision inputs + for tracing, which the Neuron XLA compiler may optimize away (writing + zeros at position 0 is a no-op). We use ones-like inputs matching + NxDI's proven test_scatter.py pattern. + """ + + def input_generator(self): + inputs = [] + for bucket in self.neuron_config.buckets: + n_active_tokens = ( + bucket + if self.neuron_config.bucket_n_active_tokens + else self.neuron_config.n_active_tokens + ) + + input_ids = torch.zeros( + (self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32 + ) + attention_mask = torch.zeros( + (self.neuron_config.batch_size, bucket), dtype=torch.int32 + ) + position_ids = torch.zeros( + (self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32 + ) + seq_ids = torch.zeros((self.neuron_config.batch_size), dtype=torch.int32) + + sampling_params_len = prepare_sampling_params(1).shape[1] + sampling_params = torch.zeros( + (self.neuron_config.batch_size, sampling_params_len), + dtype=torch.float32, + ) + + if n_active_tokens > 1: + # CTE: ones-like vision inputs to prevent compiler optimization + vision_embeddings = torch.ones( + self.neuron_config.batch_size, + n_active_tokens, + self.config.hidden_size, + dtype=self.config.neuron_config.torch_dtype, + ) + vision_mask = torch.ones( + self.neuron_config.batch_size, + n_active_tokens, + 1, + dtype=torch.int32, + ) + else: + # TKG: empty vision inputs + vision_embeddings = torch.zeros( + (0), dtype=self.config.neuron_config.torch_dtype + ) + vision_mask = torch.zeros((0), dtype=torch.bool) + + inputs.append( + ( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + torch.empty(0), # prev_hidden + torch.empty(0), # adapter_ids + torch.empty(0), # accepted_indices + torch.empty(0), # current_length + torch.empty(0), # medusa_mask + torch.empty(0), # scatter_index + torch.empty(0), # slot_mapping + torch.empty(0), # active_block_table + torch.empty(0), # num_queries + torch.empty(0), # computed_context_lens + torch.empty(0), # tile_q_indices + torch.empty(0), # tile_block_tables + torch.empty(0), # tile_masks + torch.empty(0), # inputs_embeds + torch.empty(0), # kv_cache + torch.empty(0), # active_mask + torch.empty(0), # rotary_position_id + vision_embeddings, + vision_mask, + ) + ) + + return inputs + + +# ============================================================================ +# Model Patching — Apply all K2.5-specific patches +# ============================================================================ + + +def apply_k25_patches(NeuronKimiK2ForCausalLM, NeuronKimiK2Model, ep_degree=1): + """Apply all patches to transform K2 text-only model into K2.5 multimodal. + + Must be called BEFORE model initialization (__init__ calls + enable_context_encoding/enable_token_generation which use wrapper class). + + Args: + NeuronKimiK2ForCausalLM: The top-level K2 model class + NeuronKimiK2Model: The K2 model graph class + ep_degree: Expert parallelism degree + """ + # 1. Vision embedding fusion + patch_encode_vision_to_input(NeuronKimiK2Model) + + # 2. ImageToTextModelWrapper + def _get_model_wrapper_cls(self): + return K25ImageToTextModelWrapper + + NeuronKimiK2ForCausalLM.get_model_wrapper_cls = _get_model_wrapper_cls + + # 3. Default compiler args (ModelWrapper handles flags) + NeuronKimiK2ForCausalLM.get_compiler_args = lambda self: None + + # 4. 
Forward with vision support + _orig_forward = NeuronKimiK2ForCausalLM.forward + + def _vl_forward( + self, + input_ids=None, + seq_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + sampling_params=None, + prev_hidden=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + adapter_ids=None, + medusa_args=None, + return_dict=None, + llava_args=None, + input_capture_hook=None, + slot_mapping=None, + block_table=None, + full_context_lens=None, + computed_context_lens=None, + vision_embeddings=None, + vision_mask=None, + **kwargs, + ): + input_ids, attention_mask, position_ids, seq_ids, sampling_params = ( + self.preprocess_inputs( + input_ids=input_ids, + seq_ids=seq_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + sampling_params=sampling_params, + prev_hidden=prev_hidden, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + adapter_ids=adapter_ids, + medusa_args=medusa_args, + return_dict=return_dict, + llava_args=llava_args if llava_args else [], + input_capture_hook=input_capture_hook, + slot_mapping=slot_mapping, + block_table=block_table, + full_context_lens=full_context_lens, + computed_context_lens=computed_context_lens, + ) + ) + + outputs, is_run_on_neuron = self._get_model_outputs( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + medusa_args, + llava_args if llava_args else [], + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + ) + + generation_model = self.get_generation_model() + if not generation_model.is_neuron(): + self._copy_past_key_values(outputs) + + if ( + self.on_device_sampling + and self.neuron_config.output_logits + and not ( + self.neuron_config.enable_fused_speculation + or self.neuron_config.is_medusa + ) + ): + logits_or_next_tokens = outputs[:2] + constructed_outputs = self._construct_output_with_tokens_and_logits( + next_tokens=logits_or_next_tokens[0], + logits=logits_or_next_tokens[1], + ) + else: + if is_run_on_neuron: + logits_or_next_tokens = outputs + else: + logits_or_next_tokens, *_ = outputs + constructed_outputs = self._construct_output(logits_or_next_tokens) + + return constructed_outputs + + NeuronKimiK2ForCausalLM.forward = _vl_forward + + # 5. 
_get_model_outputs with 24-arg ImageToText format + def _vl_get_model_outputs( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + medusa_args, + llava_args, + vision_embeddings=None, + vision_mask=None, + **kwargs, + ): + if vision_embeddings is None: + vision_embeddings = torch.empty(0) + if vision_mask is None: + vision_mask = torch.empty(0) + + args_24 = ( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + vision_embeddings, + vision_mask, + ) + + if self._is_prefill(position_ids): + outputs = self.context_encoding_model(*args_24) + self.kv_cache_populated = True + is_run_on_neuron = self.context_encoding_model.is_neuron() + else: + outputs = self.token_generation_model(*args_24) + is_run_on_neuron = self.token_generation_model.is_neuron() + + return outputs, is_run_on_neuron + + NeuronKimiK2ForCausalLM._get_model_outputs = _vl_get_model_outputs + + # 6. EP-safe MoE forward (SDK 2.29 blockwise CTE regression workaround) + if ep_degree > 1: + from neuronx_distributed.modules.moe.expert_mlps_v2 import ExpertMLPsV2 + + _original_forward = ExpertMLPsV2.forward + + def _ep_safe_forward( + self_emv, + hidden_states, + expert_affinities, + expert_index, + seq_len, + padding_mask=None, + expert_affinities_masked_full=None, + ): + if ( + self_emv.moe_expert_model_parallel_group.size() > 1 + and not self_emv.training + ): + return self_emv.forward_all_experts_EP( + hidden_states, expert_affinities, expert_index + ) + return _original_forward( + self_emv, + hidden_states, + expert_affinities, + expert_index, + seq_len, + padding_mask=padding_mask, + expert_affinities_masked_full=expert_affinities_masked_full, + ) + + ExpertMLPsV2.forward = _ep_safe_forward + + +def apply_k25_checkpoint_patch(model): + """Patch model to use K2.5 checkpoint loader and mmap weight loading. + + Must be called AFTER model initialization. 
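+
+    Sketch of the intended call order, mirroring the integration tests
+    (text_model_dir and config come from create_text_only_model_dir and
+    build_k25_config):
+
+        apply_k25_patches(NeuronKimiK2ForCausalLM, NeuronKimiK2Model,
+                          ep_degree=1)          # before __init__
+        model = NeuronKimiK2ForCausalLM(text_model_dir, config=config)
+        apply_k25_checkpoint_patch(model)       # after __init__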
+ + Args: + model: NeuronKimiK2ForCausalLM instance + """ + # K2.5 checkpoint loader + model.checkpoint_loader_fn = types.MethodType(k25_checkpoint_loader_fn, model) + + # No-op convert (K2.5 loader handles everything) + def _noop_convert(state_dict, config): + return state_dict + + model.convert_hf_to_neuron_state_dict = staticmethod(_noop_convert) + + # mmap-based weight loading + def _mmap_load_weights( + self_model, compiled_model_path, start_rank_id=None, local_ranks_size=None + ): + import resource + from safetensors import safe_open + + if self_model.traced_model is None: + raise ValueError("Model is not loaded") + if start_rank_id is None: + start_rank_id = self_model.neuron_config.start_rank_id + if local_ranks_size is None: + local_ranks_size = self_model.neuron_config.local_ranks_size + + weights_dir = os.path.join(compiled_model_path, "weights") + first_shard = os.path.join( + weights_dir, f"tp{start_rank_id}_sharded_checkpoint.safetensors" + ) + + if os.path.exists(first_shard): + logger.info(f"Loading pre-sharded weights from {weights_dir}") + weights = [] + for rank in range(start_rank_id, start_rank_id + local_ranks_size): + fpath = os.path.join( + weights_dir, f"tp{rank}_sharded_checkpoint.safetensors" + ) + sf = safe_open(fpath, framework="pt") + shard = {key: sf.get_tensor(key) for key in sf.keys()} + weights.append(shard) + else: + logger.info("No pre-sharded weights. Sharding at load time...") + weights = self_model.get_builder().shard_checkpoint() + + start_rank_tensor = torch.tensor( + [start_rank_id], dtype=torch.int32, device="cpu" + ) + self_model.traced_model.nxd_model.initialize(weights, start_rank_tensor) + del weights + + model.load_weights = types.MethodType(_mmap_load_weights, model) + + +# ============================================================================ +# Text-Only Model Directory Setup +# ============================================================================ + + +def create_text_only_model_dir(k25_model_path, output_dir): + """Create text-only model directory with flat K2-compatible config. + + K2.5 config.json has text_config nested; the K2 model expects flat config. + Creates symlinks for safetensors files and copies tokenizer files. 
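+
+    Example (paths hypothetical):
+
+        text_dir = create_text_only_model_dir(
+            "/mnt/models/Kimi-K2.5", "/mnt/models/Kimi-K2.5-text"
+        )
+        # text_dir/config.json now holds the flattened text_config;
+        # the *.safetensors files are symlinks back into the K2.5 directory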
+ + Args: + k25_model_path: Path to K2.5 HF model directory + output_dir: Output directory for text-only model + + Returns: + output_dir path + """ + os.makedirs(output_dir, exist_ok=True) + + k25_config_path = os.path.join(k25_model_path, "config.json") + with open(k25_config_path, "r") as f: + k25_config = json.load(f) + + text_config = k25_config.get("text_config", k25_config) + + config_out = os.path.join(output_dir, "config.json") + with open(config_out, "w") as f: + json.dump(text_config, f, indent=2) + + # Symlink safetensors and index + for fname in os.listdir(k25_model_path): + if fname.endswith(".safetensors") or fname == "model.safetensors.index.json": + src = os.path.join(k25_model_path, fname) + dst = os.path.join(output_dir, fname) + if not os.path.exists(dst): + os.symlink(src, dst) + + # Copy tokenizer files + for fname in os.listdir(k25_model_path): + if "tokenizer" in fname or fname == "special_tokens_map.json": + src = os.path.join(k25_model_path, fname) + dst = os.path.join(output_dir, fname) + if not os.path.exists(dst) and os.path.isfile(src): + shutil.copy2(src, dst) + + return output_dir + + +# ============================================================================ +# Config Builder +# ============================================================================ + + +def build_k25_config( + model_dir, + tp_degree=64, + ep_degree=1, + lnc=2, + batch_size=1, + seq_len=512, + n_active_tokens=128, + quantized=True, + num_layers=None, +): + """Build KimiK2InferenceConfig for K2.5 multimodal inference. + + Args: + model_dir: Path to text-only model directory (with flat config.json) + tp_degree: Tensor parallel degree + ep_degree: Expert parallel degree + lnc: Logical neuron core config + batch_size: Maximum batch size + seq_len: Maximum sequence length + n_active_tokens: Active tokens for TKG + quantized: Use FP8 quantized experts + num_layers: Override number of layers (for testing) + + Returns: + KimiK2InferenceConfig + """ + from modeling_kimi_k2 import KimiK2InferenceConfig + + neuron_config_kwargs = dict( + tp_degree=tp_degree, + ep_degree=ep_degree, + logical_nc_config=lnc, + max_batch_size=batch_size, + seq_len=seq_len, + n_active_tokens=n_active_tokens, + torch_dtype="bfloat16", + capacity_factor=1.0, + glu_mlp=True, + moe_ep_degree=ep_degree, + moe_tp_degree=tp_degree, + router_config=RouterConfig(act_fn="sigmoid", dtype="float32"), + save_sharded_checkpoint=False, + ) + + if quantized: + neuron_config_kwargs["quantized"] = True + neuron_config_kwargs["quantized_checkpoints_path"] = model_dir + neuron_config_kwargs["quantization_dtype"] = "f8e4m3" + neuron_config_kwargs["quantization_type"] = "expert_wise_per_channel_symmetric" + neuron_config_kwargs["modules_to_not_convert"] = [ + "self_attn", + "shared_experts", + "embed_tokens", + "lm_head", + "norm", + "router", + "layers.0", + ] + + neuron_config = MoENeuronConfig(**neuron_config_kwargs) + + with open(os.path.join(model_dir, "config.json"), "r") as f: + hf_config = json.load(f) + + hf_kwargs = { + k: v + for k, v in hf_config.items() + if k not in ("auto_map", "torch_dtype", "transformers_version", "architectures") + } + + if num_layers is not None: + hf_kwargs["num_hidden_layers"] = num_layers + + config = KimiK2InferenceConfig(neuron_config=neuron_config, **hf_kwargs) + config.neuron_config.normalize_top_k_affinities = False + config.neuron_config.blockwise_matmul_config.block_size = 2**30 + config.neuron_config.weights_to_skip_layout_optimization = [".*"] + + return config diff --git 
a/contrib/models/Kimi-K2.5/src/moonvit.py b/contrib/models/Kimi-K2.5/src/moonvit.py new file mode 100644 index 00000000..a0b8f092 --- /dev/null +++ b/contrib/models/Kimi-K2.5/src/moonvit.py @@ -0,0 +1,397 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# MoonViT vision encoder for Kimi-K2.5 on Neuron. +# +# Architecture: 27-layer ViT (hidden=1152, heads=16, mlp=4304) +# - 400M parameters (466M with projector) +# - 2D RoPE (real-number decomposition, no complex ops) +# - Eager attention (no flash_attn) +# - PatchMergerMLP projector: 2x2 merge → 7168 text hidden dim +# +# Input: pixel_values [N_patches, 3, 14, 14] (patchified image) +# Output: projected embeddings [N_merged, 7168] +# +# For 448x448 image: 1024 patches → 256 merged vision tokens +# +# Usage: +# 1. Create wrapper: NeuronMoonViTWrapper(patch_h=32, patch_w=32) +# 2. Load weights: load_vision_weights(model, model_path, 32, 32) +# 3. Trace on Neuron: torch_neuronx.trace(model, (pixels, cos, sin)) +# 4. Run inference: model(pixel_values, rope_cos, rope_sin) → [256, 7168] +# +# Note: All 64 Neuron cores are consumed by the text decoder (TP=64). +# MoonViT must be traced and run BEFORE loading the text decoder, or +# pre-computed embeddings must be used. + +import json +import math +import os +import sys +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +PATCH_SIZE = 14 +MERGE_KERNEL = (2, 2) + + +# ============================================================================ +# Neuron-compatible attention +# ============================================================================ + + +def eager_attention(q, k, v): + """Simple eager attention for a single sequence. + + Args: + q, k, v: [seq_len, num_heads, head_dim] + Returns: + output: [seq_len, hidden_dim] + """ + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + scale = math.sqrt(q.shape[-1]) + attn_weight = torch.matmul(q, k.transpose(-2, -1)) / scale + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weight, v) + attn_output = attn_output.transpose(0, 1).reshape(q.shape[1], -1) + return attn_output + + +# ============================================================================ +# Real-number RoPE (no complex ops) +# ============================================================================ + + +def precompute_rope_real(dim, max_height, max_width, theta_base=10000): + """Precompute 2D RoPE cos/sin tables using real-number decomposition. + + Returns: + cos_table: [max_height, max_width, dim//2] + sin_table: [max_height, max_width, dim//2] + """ + N = max_height * max_width + flat_pos = torch.arange(0, N).float() + x_pos = flat_pos % max_width + y_pos = flat_pos // max_width + + dim_range = torch.arange(0, dim, 4)[: dim // 4].float() + freqs = 1.0 / (theta_base ** (dim_range / dim)) + + x_freqs = torch.outer(x_pos, freqs) + y_freqs = torch.outer(y_pos, freqs) + + cos_x, sin_x = torch.cos(x_freqs), torch.sin(x_freqs) + cos_y, sin_y = torch.cos(y_freqs), torch.sin(y_freqs) + + cos_table = torch.stack([cos_x, cos_y], dim=-1).reshape(N, -1) + sin_table = torch.stack([sin_x, sin_y], dim=-1).reshape(N, -1) + + return cos_table.reshape(max_height, max_width, -1), sin_table.reshape( + max_height, max_width, -1 + ) + + +def apply_rope_real(xq, xk, cos_table, sin_table): + """Apply 2D RoPE using real-number cos/sin decomposition. 
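+
+    Each even/odd channel pair (a, b) is rotated as a 2-vector,
+
+        a' = a * cos - b * sin
+        b' = a * sin + b * cos
+
+    which is the real-valued expansion of (a + i*b) * exp(i * angle), so no
+    complex dtypes are needed on Neuron.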
+ + Args: + xq, xk: [..., num_heads, head_dim] + cos_table, sin_table: [..., head_dim/2] + Returns: + xq_out, xk_out: [..., num_heads, head_dim] + """ + cos = cos_table.unsqueeze(-2) + sin = sin_table.unsqueeze(-2) + + xq_pairs = xq.float().reshape(*xq.shape[:-1], -1, 2) + xk_pairs = xk.float().reshape(*xk.shape[:-1], -1, 2) + + xq_real, xq_imag = xq_pairs[..., 0], xq_pairs[..., 1] + xk_real, xk_imag = xk_pairs[..., 0], xk_pairs[..., 1] + + xq_out_real = xq_real * cos - xq_imag * sin + xq_out_imag = xq_real * sin + xq_imag * cos + xk_out_real = xk_real * cos - xk_imag * sin + xk_out_imag = xk_real * sin + xk_imag * cos + + xq_out = torch.stack([xq_out_real, xq_out_imag], dim=-1).flatten(-2) + xk_out = torch.stack([xk_out_real, xk_out_imag], dim=-1).flatten(-2) + + return xq_out.to(xq.dtype), xk_out.to(xk.dtype) + + +# ============================================================================ +# Encoder Layer +# ============================================================================ + + +class NeuronMoonViTEncoderLayer(nn.Module): + """Single MoonViT encoder layer with real-number RoPE and eager attention.""" + + def __init__(self, num_heads, hidden_dim, mlp_dim): + super().__init__() + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.head_dim = hidden_dim // num_heads + + self.norm0 = nn.LayerNorm(hidden_dim) + self.norm1 = nn.LayerNorm(hidden_dim) + self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=True) + self.wo = nn.Linear(hidden_dim, hidden_dim, bias=True) + self.fc0 = nn.Linear(hidden_dim, mlp_dim, bias=True) + self.fc1 = nn.Linear(mlp_dim, hidden_dim, bias=True) + + def forward(self, hidden_states, rope_cos, rope_sin): + """ + Args: + hidden_states: [seq_len, hidden_dim] + rope_cos, rope_sin: [seq_len, head_dim/2] + """ + residual = hidden_states + hidden_states = self.norm0(hidden_states) + + xqkv = self.wqkv(hidden_states) + xqkv = xqkv.view(-1, 3, self.num_heads, self.head_dim) + xq, xk, xv = xqkv[:, 0], xqkv[:, 1], xqkv[:, 2] + + xq, xk = apply_rope_real(xq, xk, rope_cos, rope_sin) + attn_out = eager_attention(xq, xk, xv) + attn_out = self.wo(attn_out) + hidden_states = residual + attn_out + + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = self.fc0(hidden_states) + hidden_states = F.gelu(hidden_states, approximate="tanh") + hidden_states = self.fc1(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +# ============================================================================ +# MoonViT Wrapper (Neuron-traceable) +# ============================================================================ + + +class NeuronMoonViTWrapper(nn.Module): + """Neuron-traceable MoonViT + PatchMergerMLP wrapper. + + For a fixed image resolution (default 448x448): + - Input: pixel_values [N_patches, 3, 14, 14] (patchified) + - Output: projected embeddings [N_merged, text_hidden_size] + + Pipeline: + 1. Conv2d patch embedding + 2. Learnable positional embedding (+ temporal at t=0) + 3. 27 encoder layers with 2D RoPE + eager attention + 4. Final LayerNorm + 5. 
PatchMergerMLP: 2x2 merge → Linear → GELU → Linear → 7168 + """ + + def __init__( + self, + num_layers=27, + hidden_dim=1152, + num_heads=16, + mlp_dim=4304, + text_hidden_size=7168, + patch_h=32, + patch_w=32, + merge_kernel=(2, 2), + projector_ln_eps=1e-5, + ): + super().__init__() + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + self.patch_h = patch_h + self.patch_w = patch_w + self.merge_kernel = merge_kernel + self.seq_len = patch_h * patch_w + + merged_h = patch_h // merge_kernel[0] + merged_w = patch_w // merge_kernel[1] + self.n_merged = merged_h * merged_w + self.merge_dim = hidden_dim * merge_kernel[0] * merge_kernel[1] + + self.patch_conv = nn.Conv2d(3, hidden_dim, kernel_size=14, stride=14) + + self.layers = nn.ModuleList( + [ + NeuronMoonViTEncoderLayer(num_heads, hidden_dim, mlp_dim) + for _ in range(num_layers) + ] + ) + self.final_norm = nn.LayerNorm(hidden_dim) + + # PatchMergerMLP projector + self.proj_norm = nn.LayerNorm(hidden_dim, eps=projector_ln_eps) + self.proj_fc0 = nn.Linear(self.merge_dim, self.merge_dim) + self.proj_fc1 = nn.Linear(self.merge_dim, text_hidden_size) + + def forward(self, pixel_values, rope_cos, rope_sin): + """ + Args: + pixel_values: [N_patches, 3, 14, 14] + rope_cos: [seq_len, head_dim/2] + rope_sin: [seq_len, head_dim/2] + Returns: + projected: [N_merged, text_hidden_size] + """ + x = self.patch_conv(pixel_values).view(pixel_values.shape[0], -1) + x = x + self.pos_embed + + for layer in self.layers: + x = layer(x, rope_cos, rope_sin) + + x = self.final_norm(x) + + # Patch merging (2x2) + kh, kw = self.merge_kernel + new_h = self.patch_h // kh + new_w = self.patch_w // kw + x = x.view(1, new_h, kh, new_w, kw, self.hidden_dim) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.squeeze(0) + x = x.reshape(new_h * new_w, kh * kw, self.hidden_dim) + + # Projector + x = self.proj_norm(x) + x = x.reshape(self.n_merged, -1) + x = self.proj_fc0(x) + x = F.gelu(x) + x = self.proj_fc1(x) + + return x + + +# ============================================================================ +# Weight Loading +# ============================================================================ + + +def load_vision_weights(model, model_path, patch_h, patch_w, device="cpu"): + """Load MoonViT + PatchMergerMLP weights from K2.5 safetensors. 
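+
+    Sketch of the offline pipeline this function slots into (a 448x448
+    image is assumed; pixel_values comes from the host-side preprocessor
+    and dtype handling is elided):
+
+        vit = load_vision_weights(
+            NeuronMoonViTWrapper(patch_h=32, patch_w=32), model_path, 32, 32
+        ).to(torch.bfloat16)
+        cos, sin = precompute_rope_real(72, 32, 32)
+        emb = vit(pixel_values,                   # [1024, 3, 14, 14]
+                  cos.reshape(1024, -1),
+                  sin.reshape(1024, -1))          # -> [256, 7168]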
+ + Args: + model: NeuronMoonViTWrapper instance + model_path: Path to K2.5 HF model directory + patch_h: Patch grid height + patch_w: Patch grid width + device: Device for weight loading + + Returns: + model with loaded weights + """ + from safetensors import safe_open + + index_path = os.path.join(model_path, "model.safetensors.index.json") + with open(index_path) as f: + index = json.load(f) + + weight_map = index["weight_map"] + vision_keys = { + k: v + for k, v in weight_map.items() + if k.startswith("vision_tower.") or k.startswith("mm_projector.") + } + + shard_to_keys = {} + for key, shard in vision_keys.items(): + shard_to_keys.setdefault(shard, []).append(key) + + all_weights = {} + for shard, keys in sorted(shard_to_keys.items()): + shard_path = os.path.join(model_path, shard) + with safe_open(shard_path, framework="pt", device=str(device)) as f: + for key in keys: + all_weights[key] = f.get_tensor(key) + + state_dict = {} + + # Patch conv + state_dict["patch_conv.weight"] = all_weights[ + "vision_tower.patch_embed.proj.weight" + ] + state_dict["patch_conv.bias"] = all_weights["vision_tower.patch_embed.proj.bias"] + + # Positional embedding + pos_weight = all_weights["vision_tower.patch_embed.pos_emb.weight"] + init_h, init_w, dim = pos_weight.shape + if (init_h, init_w) == (patch_h, patch_w): + pos_embed = pos_weight.flatten(end_dim=1) + else: + pos_embed = ( + F.interpolate( + pos_weight.permute(2, 0, 1).unsqueeze(0), + size=(patch_h, patch_w), + mode="bicubic", + ) + .squeeze(0) + .permute(1, 2, 0) + .flatten(end_dim=1) + ) + + # Temporal embedding (t=0 for single image) + time_weight = all_weights.get("vision_tower.patch_embed.pos_emb.time_weight") + if time_weight is not None: + pos_embed = pos_embed + time_weight[0] + + state_dict["pos_embed"] = pos_embed + + # Encoder layers + for i in range(27): + prefix = f"vision_tower.encoder.blocks.{i}" + layer_prefix = f"layers.{i}" + state_dict[f"{layer_prefix}.norm0.weight"] = all_weights[ + f"{prefix}.norm0.weight" + ] + state_dict[f"{layer_prefix}.norm0.bias"] = all_weights[f"{prefix}.norm0.bias"] + state_dict[f"{layer_prefix}.norm1.weight"] = all_weights[ + f"{prefix}.norm1.weight" + ] + state_dict[f"{layer_prefix}.norm1.bias"] = all_weights[f"{prefix}.norm1.bias"] + state_dict[f"{layer_prefix}.wqkv.weight"] = all_weights[f"{prefix}.wqkv.weight"] + state_dict[f"{layer_prefix}.wqkv.bias"] = all_weights[f"{prefix}.wqkv.bias"] + state_dict[f"{layer_prefix}.wo.weight"] = all_weights[f"{prefix}.wo.weight"] + state_dict[f"{layer_prefix}.wo.bias"] = all_weights[f"{prefix}.wo.bias"] + state_dict[f"{layer_prefix}.fc0.weight"] = all_weights[ + f"{prefix}.mlp.fc0.weight" + ] + state_dict[f"{layer_prefix}.fc0.bias"] = all_weights[f"{prefix}.mlp.fc0.bias"] + state_dict[f"{layer_prefix}.fc1.weight"] = all_weights[ + f"{prefix}.mlp.fc1.weight" + ] + state_dict[f"{layer_prefix}.fc1.bias"] = all_weights[f"{prefix}.mlp.fc1.bias"] + + # Final LayerNorm + state_dict["final_norm.weight"] = all_weights[ + "vision_tower.encoder.final_layernorm.weight" + ] + state_dict["final_norm.bias"] = all_weights[ + "vision_tower.encoder.final_layernorm.bias" + ] + + # PatchMergerMLP projector + state_dict["proj_norm.weight"] = all_weights["mm_projector.pre_norm.weight"] + state_dict["proj_norm.bias"] = all_weights["mm_projector.pre_norm.bias"] + state_dict["proj_fc0.weight"] = all_weights["mm_projector.proj.0.weight"] + state_dict["proj_fc0.bias"] = all_weights["mm_projector.proj.0.bias"] + state_dict["proj_fc1.weight"] = 
all_weights["mm_projector.proj.2.weight"] + state_dict["proj_fc1.bias"] = all_weights["mm_projector.proj.2.bias"] + + # Register pos_embed as buffer + pos_embed = state_dict.pop("pos_embed") + model.register_buffer("pos_embed", pos_embed) + + model.load_state_dict(state_dict, strict=False) + return model diff --git a/contrib/models/Kimi-K2.5/test/__init__.py b/contrib/models/Kimi-K2.5/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Kimi-K2.5/test/integration/__init__.py b/contrib/models/Kimi-K2.5/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Kimi-K2.5/test/integration/test_model.py b/contrib/models/Kimi-K2.5/test/integration/test_model.py new file mode 100644 index 00000000..7ade0015 --- /dev/null +++ b/contrib/models/Kimi-K2.5/test/integration/test_model.py @@ -0,0 +1,500 @@ +#!/usr/bin/env python3 +""" +Integration tests for Kimi-K2.5 multimodal NeuronX implementation. + +Tests compilation, loading, and multimodal inference on trn2.48xlarge. + +Requirements: + - trn2.48xlarge with NEURON_LOGICAL_NC_CONFIG=2 (64 logical cores) + - LOCAL_WORLD_SIZE=64 + - K2.5 model weights at MODEL_PATH + - Pre-computed MoonViT embeddings at VISION_EMBEDDINGS_PATH + - Neuron SDK 2.29 (Deep Learning AMI Neuron Ubuntu 24.04 20260410) + - tiktoken package installed (pip install tiktoken) + +Usage: + # Full test (compile + load + generate): + NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ + pytest test_model.py -v --capture=tee-sys + + # Load-only (skip compile, use existing NEFFs): + NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ + pytest test_model.py -v --capture=tee-sys -k "not compile" + + # Standalone (use PYTHONUNBUFFERED for real-time log output): + NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ + PYTHONUNBUFFERED=1 python test_model.py +""" + +import json +import logging +import os +import sys +import time +from pathlib import Path + +import pytest +import torch +from transformers import AutoTokenizer + +# Import from src directory +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) +from modeling_kimi_k2 import ( + NeuronKimiK2ForCausalLM, + NeuronKimiK2Model, + KimiK2InferenceConfig, +) +from modeling_kimi_k25 import ( + apply_k25_patches, + apply_k25_checkpoint_patch, + build_k25_config, + create_text_only_model_dir, + BOS_TOKEN_ID, + IM_USER_TOKEN_ID, + IM_END_TOKEN_ID, + IM_ASSISTANT_TOKEN_ID, + MEDIA_PLACEHOLDER_TOKEN_ID, +) + + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +MODEL_PATH = "/mnt/nvme/models/Kimi-K2.5" +TEXT_MODEL_DIR = "/home/ubuntu/models/Kimi-K2.5-text" +COMPILED_MODEL_PATH = "/mnt/nvme/models/Kimi-K2.5-text/neuron-compiled-k25-vl-s512-v5" +VISION_EMBEDDINGS_PATH = "/mnt/nvme/models/moonvit_448_real_embeddings.pt" + +# Model configuration (TP=64, EP=1, LNC=2) +TP_DEGREE = 64 +EP_DEGREE = 1 +LNC = 2 +BATCH_SIZE = 1 +SEQ_LEN = 512 +N_ACTIVE_TOKENS = 128 +N_VISION_TOKENS = 256 # 448x448 image → 32x32 patches → 16x16 merged + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def tokenizer(): + """Load K2.5 tokenizer.""" + tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + return tok + + 
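+# The embeddings file consumed below can be produced offline (a sketch;
+# the image path is a placeholder -- MoonViT must run before the text
+# decoder claims all 64 cores):
+#
+#     from PIL import Image
+#     from modeling_kimi_k25 import preprocess_image
+#     from moonvit import (NeuronMoonViTWrapper, load_vision_weights,
+#                          precompute_rope_real)
+#
+#     pixels, _, _ = preprocess_image(Image.open("test.jpg"))
+#     vit = load_vision_weights(
+#         NeuronMoonViTWrapper(patch_h=32, patch_w=32), MODEL_PATH, 32, 32
+#     ).to(torch.bfloat16)
+#     cos, sin = precompute_rope_real(72, 32, 32)
+#     emb = vit(pixels, cos.reshape(1024, -1), sin.reshape(1024, -1))
+#     torch.save(emb, VISION_EMBEDDINGS_PATH)
+
+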
+@pytest.fixture(scope="module") +def vision_embeddings(): + """Load pre-computed MoonViT embeddings.""" + if not os.path.exists(VISION_EMBEDDINGS_PATH): + pytest.skip(f"Vision embeddings not found at {VISION_EMBEDDINGS_PATH}") + emb = torch.load(VISION_EMBEDDINGS_PATH, map_location="cpu") + assert emb.shape == (N_VISION_TOKENS, 7168), f"Unexpected shape: {emb.shape}" + return emb.to(torch.bfloat16) + + +@pytest.fixture(scope="module") +def compiled_model(): + """Compile (if needed) and load K2.5 multimodal model.""" + # 1. Create text-only model dir + text_model_dir = create_text_only_model_dir(MODEL_PATH, TEXT_MODEL_DIR) + + # 2. Apply vision patches BEFORE model init + apply_k25_patches(NeuronKimiK2ForCausalLM, NeuronKimiK2Model, ep_degree=EP_DEGREE) + + # 3. Build config + config = build_k25_config( + text_model_dir, + tp_degree=TP_DEGREE, + ep_degree=EP_DEGREE, + lnc=LNC, + batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + n_active_tokens=N_ACTIVE_TOKENS, + quantized=True, + ) + + # 4. Initialize model + model = NeuronKimiK2ForCausalLM(text_model_dir, config=config) + + # 5. Apply checkpoint patches + apply_k25_checkpoint_patch(model) + + # 6. Compile if needed + compiled_path = Path(COMPILED_MODEL_PATH) + if not compiled_path.exists() or not (compiled_path / "model.pt").exists(): + print(f"\nCompiling model to {COMPILED_MODEL_PATH}...") + t0 = time.time() + model.compile(COMPILED_MODEL_PATH) + print(f"Compilation done in {(time.time() - t0) / 60:.1f} min") + + # 7. Load + print(f"\nLoading model from {COMPILED_MODEL_PATH}...") + t0 = time.time() + model.load(COMPILED_MODEL_PATH) + print(f"Model loaded in {(time.time() - t0) / 60:.1f} min") + + return model + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def build_multimodal_prompt( + tokenizer, vision_embeddings, seq_len, user_text="Describe this image in detail." +): + """Build multimodal prompt with vision placeholder tokens. + + Returns: + input_ids_list: List of token IDs + ve: [1, seq_len, 7168] vision embedding tensor + vm: [1, seq_len, 1] vision mask tensor + n_prompt: Number of prompt tokens + """ + n_vision = vision_embeddings.shape[0] + hidden_size = vision_embeddings.shape[1] + + text_ids = tokenizer.encode(user_text) + placeholder_ids = [MEDIA_PLACEHOLDER_TOKEN_ID] * n_vision + + input_ids_list = ( + [BOS_TOKEN_ID, IM_USER_TOKEN_ID] + + placeholder_ids + + text_ids + + [IM_END_TOKEN_ID, IM_ASSISTANT_TOKEN_ID] + ) + n_prompt = len(input_ids_list) + + vision_start = 2 # After BOS + im_user + + ve = torch.zeros(1, seq_len, hidden_size, dtype=torch.bfloat16) + vm = torch.full((1, seq_len, 1), fill_value=seq_len - 1, dtype=torch.int32) + + for i in range(n_vision): + ve[0, i] = vision_embeddings[i] + vm[0, i, 0] = vision_start + i + + return input_ids_list, ve, vm, n_prompt + + +def generate_multimodal( + model, tokenizer, vision_embeddings, max_new_tokens=32, min_tokens_before_eos=3 +): + """Generate tokens from a multimodal prompt. 
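+
+    One CTE pass consumes the full multimodal prompt; decoding then runs
+    greedy argmax on CPU, one token per TKG call, stopping at
+    IM_END_TOKEN_ID / tokenizer.eos_token_id or max_new_tokens. TPOT can
+    be derived from the returned timings (illustrative):
+
+        text, ids, cte_s, tkg_s = generate_multimodal(model, tok, emb)
+        tpot_ms = tkg_s * 1000 / max(len(ids) - 1, 1)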
+ + Returns: + output_text: Generated text + generated_ids: List of generated token IDs + cte_time: CTE latency in seconds + tkg_time: TKG latency in seconds + """ + input_ids_list, ve, vm, n_prompt = build_multimodal_prompt( + tokenizer, vision_embeddings, SEQ_LEN + ) + + model.reset() + seq_ids = torch.arange(BATCH_SIZE, dtype=torch.long) + + # Context encoding + cte_input_ids = torch.tensor([input_ids_list], dtype=torch.long) + cte_mask = torch.zeros(1, SEQ_LEN, dtype=torch.long) + cte_mask[0, :n_prompt] = 1 + cte_pos = torch.arange(n_prompt, dtype=torch.long).unsqueeze(0) + + t0 = time.perf_counter() + output = model( + input_ids=cte_input_ids, + attention_mask=cte_mask, + position_ids=cte_pos, + seq_ids=seq_ids, + vision_embeddings=ve, + vision_mask=vm, + ) + cte_time = time.perf_counter() - t0 + + logits = output.logits if hasattr(output, "logits") else output[0] + if logits.dim() == 3: + first_token = logits[0, -1].argmax(dim=-1).item() + else: + first_token = logits[0].argmax(dim=-1).item() + + generated_ids = [first_token] + current_pos = n_prompt + + # Token generation + t_gen_start = time.perf_counter() + for step in range(max_new_tokens - 1): + next_input_ids = torch.tensor([[generated_ids[-1]]], dtype=torch.long) + next_position_ids = torch.tensor([[current_pos]], dtype=torch.long) + next_attention_mask = torch.zeros(1, SEQ_LEN, dtype=torch.long) + next_attention_mask[0, : current_pos + 1] = 1 + + output = model( + input_ids=next_input_ids, + attention_mask=next_attention_mask, + position_ids=next_position_ids, + seq_ids=seq_ids, + vision_embeddings=torch.zeros(0, dtype=torch.bfloat16), + vision_mask=torch.zeros(0, dtype=torch.bool), + ) + + logits_tkg = output.logits if hasattr(output, "logits") else output[0] + if logits_tkg.dim() == 3: + next_token = logits_tkg[0, -1].argmax(dim=-1).item() + elif logits_tkg.dim() == 2: + next_token = logits_tkg[0].argmax(dim=-1).item() + else: + next_token = logits_tkg.argmax(dim=-1).item() + + generated_ids.append(next_token) + current_pos += 1 + + if next_token in (IM_END_TOKEN_ID, tokenizer.eos_token_id): + break + if current_pos >= SEQ_LEN: + break + + tkg_time = time.perf_counter() - t_gen_start + + output_text = tokenizer.decode(generated_ids, skip_special_tokens=True) + return output_text, generated_ids, cte_time, tkg_time + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_model_loads(compiled_model): + """Smoke test: model loads successfully.""" + assert compiled_model is not None + assert hasattr(compiled_model, "config") + assert hasattr(compiled_model.config, "neuron_config") + print("PASS: Model loaded successfully") + + +def test_multimodal_generates(compiled_model, tokenizer, vision_embeddings): + """Test multimodal generation produces coherent output.""" + output, gen_ids, cte_time, tkg_time = generate_multimodal( + compiled_model, tokenizer, vision_embeddings, max_new_tokens=32 + ) + assert len(output) > 0, "Output should not be empty" + words = output.split() + assert len(words) >= 3, f"Output too short: {output}" + print(f"PASS: Multimodal generation - Output: {output[:200]}") + print(f" CTE: {cte_time:.3f}s, TKG: {tkg_time:.3f}s") + + +def test_vision_affects_output(compiled_model, tokenizer, vision_embeddings): + """Test that vision embeddings actually affect model output. + + Compare output with real vision vs zero vision — they should differ. 
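+
+    If the scatter had been compiled away (the zero-trace hazard
+    K25ImageToTextModelWrapper guards against), both runs would return
+    identical logits and the assertion below would fail.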
+ """ + seq_ids = torch.arange(BATCH_SIZE, dtype=torch.long) + hidden_size = 7168 + sl = SEQ_LEN + + n_test = 128 + test_ids_list = [BOS_TOKEN_ID, IM_USER_TOKEN_ID] + [MEDIA_PLACEHOLDER_TOKEN_ID] * ( + n_test - 2 + ) + test_input_ids = torch.tensor([test_ids_list], dtype=torch.long) + test_mask = torch.zeros(1, sl, dtype=torch.long) + test_mask[0, :n_test] = 1 + test_pos = torch.arange(n_test, dtype=torch.long).unsqueeze(0) + + # Test A: Real vision + ve_real = torch.zeros(1, sl, hidden_size, dtype=torch.bfloat16) + vm_real = torch.full((1, sl, 1), fill_value=sl - 1, dtype=torch.int32) + n_vision = min(vision_embeddings.shape[0], n_test - 2) + for i in range(n_vision): + ve_real[0, i] = vision_embeddings[i] + vm_real[0, i, 0] = i + 2 + + compiled_model.reset() + out_a = compiled_model( + input_ids=test_input_ids, + attention_mask=test_mask, + position_ids=test_pos, + seq_ids=seq_ids, + vision_embeddings=ve_real, + vision_mask=vm_real, + ) + logits_a = out_a.logits if hasattr(out_a, "logits") else out_a[0] + + # Test B: Zero vision + ve_zero = torch.zeros(1, sl, hidden_size, dtype=torch.bfloat16) + vm_zero = torch.full((1, sl, 1), fill_value=sl - 1, dtype=torch.int32) + + compiled_model.reset() + out_b = compiled_model( + input_ids=test_input_ids, + attention_mask=test_mask, + position_ids=test_pos, + seq_ids=seq_ids, + vision_embeddings=ve_zero, + vision_mask=vm_zero, + ) + logits_b = out_b.logits if hasattr(out_b, "logits") else out_b[0] + + # Compare + if logits_a.dim() == 3: + token_a = logits_a[0, -1].argmax(dim=-1).item() + token_b = logits_b[0, -1].argmax(dim=-1).item() + else: + token_a = logits_a[0].argmax(dim=-1).item() + token_b = logits_b[0].argmax(dim=-1).item() + + max_diff = (logits_a.float() - logits_b.float()).abs().max().item() + + assert token_a != token_b, ( + f"Vision did not affect output: token_a={token_a}, token_b={token_b}, " + f"max_logit_diff={max_diff:.6f}" + ) + print( + f"PASS: Vision affects output (token_a={token_a}, token_b={token_b}, " + f"max_logit_diff={max_diff:.3f})" + ) + + +def test_output_coherence(compiled_model, tokenizer, vision_embeddings): + """Test that output is not gibberish or repetitive.""" + output, gen_ids, _, _ = generate_multimodal( + compiled_model, tokenizer, vision_embeddings, max_new_tokens=64 + ) + words = output.split() + assert len(words) >= 3, f"Output too short: {output}" + + # Check for excessive repetition + if len(words) >= 10: + for i in range(len(words) - 5): + repeated = all(words[i + j] == words[i] for j in range(5)) + assert not repeated, f"Excessive repetition in output: {output}" + + print(f"PASS: Coherence test - Output: {output[:200]}") + + +def test_performance_tpot(compiled_model, tokenizer, vision_embeddings): + """Measure per-token output latency (TPOT).""" + # Warmup + generate_multimodal(compiled_model, tokenizer, vision_embeddings, max_new_tokens=10) + + # Measure + n_runs = 3 + tpots = [] + for _ in range(n_runs): + _, gen_ids, cte_time, tkg_time = generate_multimodal( + compiled_model, tokenizer, vision_embeddings, max_new_tokens=32 + ) + n_tkg_tokens = len(gen_ids) - 1 # First token comes from CTE + if n_tkg_tokens > 0: + tpot = (tkg_time * 1000) / n_tkg_tokens + tpots.append(tpot) + + if tpots: + median_tpot = sorted(tpots)[len(tpots) // 2] + tok_per_sec = 1000.0 / median_tpot + print(f"PASS: TPOT = {median_tpot:.1f} ms ({tok_per_sec:.1f} tok/s)") + # K2.5 at TP=64 EP=1 LNC=2: expected ~21 ms/token (45.9 tok/s) + assert median_tpot < 100, f"TPOT {median_tpot:.1f}ms exceeds 100ms threshold" + else: + 
pytest.skip("Could not measure TPOT") + + +# --------------------------------------------------------------------------- +# Standalone runner +# --------------------------------------------------------------------------- + + +if __name__ == "__main__": + # Force line-buffered stdout for real-time log output when redirected to file + import sys + + sys.stdout.reconfigure(line_buffering=True) + sys.stderr.reconfigure(line_buffering=True) + + # Configure logging to show checkpoint loader progress + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + stream=sys.stderr, + ) + + print("=" * 80) + print("Kimi-K2.5 Multimodal Integration Tests") + print("=" * 80) + + # Load vision embeddings + if not os.path.exists(VISION_EMBEDDINGS_PATH): + print(f"ERROR: Vision embeddings not found at {VISION_EMBEDDINGS_PATH}") + sys.exit(1) + vision_emb = torch.load(VISION_EMBEDDINGS_PATH, map_location="cpu").to( + torch.bfloat16 + ) + print(f"Vision embeddings: {vision_emb.shape}") + + # Setup model + text_model_dir = create_text_only_model_dir(MODEL_PATH, TEXT_MODEL_DIR) + apply_k25_patches(NeuronKimiK2ForCausalLM, NeuronKimiK2Model, ep_degree=EP_DEGREE) + + config = build_k25_config( + text_model_dir, + tp_degree=TP_DEGREE, + ep_degree=EP_DEGREE, + lnc=LNC, + batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + n_active_tokens=N_ACTIVE_TOKENS, + quantized=True, + ) + + model = NeuronKimiK2ForCausalLM(text_model_dir, config=config) + apply_k25_checkpoint_patch(model) + + compiled_path = Path(COMPILED_MODEL_PATH) + if not compiled_path.exists() or not (compiled_path / "model.pt").exists(): + print(f"\nCompiling model to {COMPILED_MODEL_PATH}...") + t0 = time.time() + model.compile(COMPILED_MODEL_PATH) + print(f"Compilation done in {(time.time() - t0) / 60:.1f} min") + + print(f"\nLoading model from {COMPILED_MODEL_PATH}...") + t0 = time.time() + model.load(COMPILED_MODEL_PATH) + print(f"Model loaded in {(time.time() - t0) / 60:.1f} min") + + tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + + print("\n" + "=" * 80) + print("Running Tests") + print("=" * 80) + + print("\n1. Smoke Test...") + test_model_loads(model) + + print("\n2. Multimodal Generation Test...") + test_multimodal_generates(model, tok, vision_emb) + + print("\n3. Vision A/B Test...") + test_vision_affects_output(model, tok, vision_emb) + + print("\n4. Coherence Test...") + test_output_coherence(model, tok, vision_emb) + + print("\n5. TPOT Performance Test...") + test_performance_tpot(model, tok, vision_emb) + + print("\n" + "=" * 80) + print("All tests passed!") + print("=" * 80) diff --git a/contrib/models/Kimi-K2.5/test/unit/__init__.py b/contrib/models/Kimi-K2.5/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b From 13eccd9043e0c383af0cff3630910b6f76619c7a Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Wed, 29 Apr 2026 23:08:32 -0400 Subject: [PATCH 2/2] Add Kimi-K2.6 support to K2.5 contrib (validated, identical architecture) K2.6 is a post-training update of K2.5 with identical architecture (only eos_token_id differs: 163585 -> 163586). No source code changes needed -- NxDI reads eos_token_id from config at load time. 
Changes:
- README: Add K2.6 model info, benchmarks (57.9 tok/s TKG, 17.3ms TPOT, 1010ms TTFT), compatibility matrix, behavioral notes
- test_model.py: Make paths configurable via env vars (KIMI_MODEL_PATH, KIMI_TEXT_MODEL_DIR, KIMI_COMPILED_PATH, KIMI_VISION_EMB_PATH) so tests work with either K2.5 or K2.6 checkpoints

Validated on trn2.48xlarge, SDK 2.29, TP=64, EP=1, LNC=2.
---
 contrib/models/Kimi-K2.5/README.md            | 98 ++++++++++++++-----
 .../Kimi-K2.5/test/integration/test_model.py  | 46 +++++++--
 2 files changed, 111 insertions(+), 33 deletions(-)

diff --git a/contrib/models/Kimi-K2.5/README.md b/contrib/models/Kimi-K2.5/README.md
index fbbcffe0..8aab7ee0 100644
--- a/contrib/models/Kimi-K2.5/README.md
+++ b/contrib/models/Kimi-K2.5/README.md
@@ -1,13 +1,20 @@
-# Contrib Model: Kimi-K2.5 (Multimodal)
+# Contrib Model: Kimi-K2.5 / K2.6 (Multimodal)
 
-NeuronX Distributed Inference implementation of Moonshot AI's Kimi-K2.5 — a native multimodal agentic model with MoonViT vision encoder.
+NeuronX Distributed Inference implementation of Moonshot AI's Kimi-K2.5 and Kimi-K2.6 — native multimodal agentic models with MoonViT vision encoder.
+
+**K2.6 is a post-training update of K2.5 with an identical architecture.** The only config difference is `eos_token_id` (163585 → 163586). This contrib supports both models with no code changes — just point to the desired checkpoint.
 
 ## Model Information
 
-- **HuggingFace ID:** `moonshotai/Kimi-K2.5`
-- **Model Type:** Multimodal (image + text) Mixture of Experts decoder
-- **Architecture:** Kimi-K2 text decoder + MoonViT-400M vision encoder
-- **License:** Check HuggingFace model card
+| Field | K2.5 | K2.6 |
+|-------|------|------|
+| **HuggingFace ID** | `moonshotai/Kimi-K2.5` | `moonshotai/Kimi-K2.6` |
+| **Model Type** | Multimodal (image + text) Mixture of Experts decoder | Same |
+| **Architecture** | Kimi-K2 text decoder + MoonViT-400M vision encoder | Same (architecture unchanged; weights differ) |
+| **eos_token_id** | 163585 (`[EOS]`) | 163586 (`<\|im_end\|>`) |
+| **License** | Check HuggingFace model card | Check HuggingFace model card |
+
+> **Note:** NxDI reads `eos_token_id` from the model config at load time, so both models work without code changes.
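+
+As a quick self-check, you can diff the two checkpoint configs directly — a minimal sketch, assuming both checkpoints are downloaded locally and expose a top-level `config.json` (paths illustrative):
+
+```python
+import json
+
+# Report every top-level config key whose value differs between K2.5 and K2.6.
+with open("/mnt/nvme/models/Kimi-K2.5/config.json") as f:
+    k25 = json.load(f)
+with open("/mnt/nvme/models/Kimi-K2.6/config.json") as f:
+    k26 = json.load(f)
+
+delta = {k: (k25.get(k), k26.get(k))
+         for k in set(k25) | set(k26) if k25.get(k) != k26.get(k)}
+print(delta)  # Expected, per the table above: {'eos_token_id': (163585, 163586)}
+```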
 
 ## Architecture Details
 
@@ -58,10 +65,12 @@ K2.5 uses a different weight format than K2:
 
 ## Validation Results
 
+### K2.5
+
 **Validated:** 2026-04-25 (SDK 2.29)
 **Configuration:** TP=64, EP=1, LNC=2, batch_size=1, seq_len=512, FP8 per-channel
 
-### Test Results
+#### Test Results
 
 | Test | Status | Result |
 |------|--------|--------|
@@ -71,7 +80,7 @@ K2.5 uses a different weight format than K2:
 | Coherence | PASS | No repetition, natural text |
 | Throughput | PASS | 45.9 tok/s at BS=1 (LNC=2) |
 
-### Performance Metrics
+#### Performance Metrics
 
 | Metric | Value |
 |--------|-------|
@@ -85,7 +94,7 @@ K2.5 uses a different weight format than K2:
 | **Model load time** | ~79 min (weight sharding) |
 | **Compile time** | ~10 min (CTE ~5 min, TKG ~5 min) |
 
-### Benchmark Details (10 iterations, 3 warmup)
+#### Benchmark Details (10 iterations, 3 warmup)
 
 | Component | Mean | p50 | Std |
 |-----------|------|-----|-----|
@@ -95,15 +104,36 @@ K2.5 uses a different weight format than K2:
 | TKG per-token | 21.4 ms | 21.3 ms | 0.4 ms |
 | End-to-end (CTE+TKG) | 4,863.2 ms | 4,864.7 ms | 12.1 ms |
 
-### Accuracy Validation (vs GPU reference)
+### K2.6
+
+**Validated:** 2026-04-28 (SDK 2.29)
+**Configuration:** TP=64, EP=1, LNC=2, batch_size=1, seq_len=512, FP8 per-channel
+**Note:** K2.6 reuses K2.5 compiled NEFFs — only weights differ.
+
+#### Test Results
+
+| Test | Status | Result |
+|------|--------|--------|
+| Smoke Test | PASS | Model loads on trn2.48xlarge (reuses K2.5 NEFFs) |
+| Text Decoder (13 prompts) | PASS | Factual, reasoning, coding, creative, math, multilingual, long_form |
+| Multimodal Generation | PASS | Correctly describes puppy image |
+| Vision A/B Test | PASS | Real vision ≠ zero vision (max logit diff: 16.3) |
+
+#### Performance Metrics
 
 | Metric | Value |
 |--------|-------|
-| Vision encoder cosine similarity | 0.9995 |
-| Token match rate (vs vLLM H100) | 3.1% (expected for 1T MoE) |
-| Semantic quality | Both describe same image correctly |
+| **TKG throughput** | **57.9 tok/s** (mean), 58.0 tok/s (p50) |
+| **TPOT (per-token latency)** | 17.3 ms (mean), 17.2 ms (p50) |
+| **CTE latency (TTFT)** | 1,010.2 ms |
+| **MoonViT latency** | 35.8 ms |
+| **MoonViT cosine sim (vs CPU)** | 0.999460 |
 
-Token-level divergence is expected: FP8 vs BF16 quantization + 384-expert MoE routing causes different expert selections that cascade through autoregressive generation. Both outputs are semantically equivalent and correctly describe the input image.
+> **K2.6 vs K2.5 throughput:** K2.6 shows 57.9 tok/s vs K2.5's 45.9 tok/s. The improvement is likely due to different text generation patterns (K2.6 produces reasoning chains that may have more favorable token distributions), not architectural differences.
+
+### K2.6 Behavioral Note
+
+K2.6 produces chain-of-thought reasoning even without an explicit reasoning prefix in the prompt ("The user is asking a very simple factual question..."). K2.5 gives direct answers. This is a post-training behavior change, not an architecture change.
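+
+For reference, the TKG throughput and TPOT rows in the tables above are two views of the same measurement (tok/s ≈ 1000 / TPOT in ms); a quick arithmetic check against the K2.6 numbers:
+
+```python
+tpot_ms = 17.3  # K2.6 mean TPOT from the table above
+print(f"{1000.0 / tpot_ms:.1f} tok/s")  # 57.8 tok/s, matching the reported 57.9 to rounding
+```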
## Usage @@ -125,7 +155,8 @@ from modeling_kimi_k25 import ( IM_ASSISTANT_TOKEN_ID, MEDIA_PLACEHOLDER_TOKEN_ID, ) -model_path = "/path/to/Kimi-K2.5" +# Use either K2.5 or K2.6 — same code, just swap the checkpoint path +model_path = "/path/to/Kimi-K2.5" # or "/path/to/Kimi-K2.6" text_model_dir = "/path/to/Kimi-K2.5-text" compiled_path = "/path/to/compiled" vision_emb_path = "/path/to/moonvit_448_real_embeddings.pt" @@ -219,13 +250,16 @@ torch.save(vision_output.to(torch.bfloat16), "moonvit_448_real_embeddings.pt") | Instance / SDK Version | 2.29 | 2.28 | 2.27 and earlier | |------------------------|------|------|------------------| -| trn2.48xlarge (LNC=2, TP=64, EP=1) | **Working (45.9 tok/s)** | Not tested | Not tested | +| trn2.48xlarge K2.5 (LNC=2, TP=64, EP=1) | **Working (45.9 tok/s)** | Not tested | Not tested | +| trn2.48xlarge K2.6 (LNC=2, TP=64, EP=1) | **Working (57.9 tok/s)** | Not tested | Not tested | | trn2.48xlarge (LNC=2, TP=32, EP=2) | Not recommended* | Not tested | Not tested | | trn2.3xlarge | Not supported (needs TP=64) | Not supported | Not supported | | inf2 | Not supported | Not supported | Not supported | \*EP=2 has known blockwise CTE kernel regression in SDK 2.29 (see K2 contrib notes). +**Note:** K2.6 reuses K2.5 compiled NEFFs since the architecture is identical. Only model weights need to be re-downloaded. + ## Testing Run integration tests on a trn2.48xlarge: @@ -233,10 +267,17 @@ Run integration tests on a trn2.48xlarge: ```bash # Activate Neuron venv (SDK 2.29) source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate -pip install tiktoken # Required for K2.5 tokenizer +pip install tiktoken # Required for K2.5/K2.6 tokenizer + +# Run tests (defaults to K2.5 paths) +NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ + pytest test/integration/test_model.py -v --capture=tee-sys -# Run tests +# Run tests with K2.6 weights (override paths via env vars) NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ + KIMI_MODEL_PATH=/mnt/nvme/models/Kimi-K2.6 \ + KIMI_TEXT_MODEL_DIR=/home/ubuntu/models/Kimi-K2.6-text \ + KIMI_COMPILED_PATH=/mnt/nvme/models/Kimi-K2.6-text/neuron-compiled \ pytest test/integration/test_model.py -v --capture=tee-sys ``` @@ -245,16 +286,26 @@ Or run standalone: ```bash NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ python test/integration/test_model.py + +# With K2.6: +NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ + KIMI_MODEL_PATH=/mnt/nvme/models/Kimi-K2.6 \ + python test/integration/test_model.py ``` **Note:** Compilation takes ~10 min, model loading takes ~79 min (dominated by weight sharding across 64 ranks). The first run will compile NEFFs; subsequent runs reuse cached NEFFs. ## Prerequisites -1. **Model weights:** Download from HuggingFace (~555 GB): +1. **Model weights:** Download from HuggingFace (~555 GB each): ```bash + # K2.5 huggingface-cli download moonshotai/Kimi-K2.5 \ --local-dir /mnt/nvme/models/Kimi-K2.5 + + # K2.6 (same architecture, different weights) + huggingface-cli download moonshotai/Kimi-K2.6 \ + --local-dir /mnt/nvme/models/Kimi-K2.6 ``` 2. **Pre-computed vision embeddings:** Trace MoonViT and pre-compute embeddings before loading the text decoder (see "Pre-computing MoonViT Embeddings" above). 
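+
+   As a quick sanity check, verify the saved embeddings file before running the tests — a short sketch (the token count depends on the traced image, but the hidden size is always 7168, and the test fixtures cast to bfloat16):
+
+   ```python
+   import torch
+
+   emb = torch.load("/mnt/nvme/models/moonvit_448_real_embeddings.pt", map_location="cpu")
+   print(emb.shape, emb.dtype)  # expect [n_vision_tokens, 7168]
+   assert emb.dim() == 2 and emb.shape[1] == 7168
+   ```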
@@ -292,6 +343,7 @@ MoonViT uses real-number decomposition of 2D complex RoPE and eager attention (n ## Example Checkpoints * [moonshotai/Kimi-K2.5](https://huggingface.co/moonshotai/Kimi-K2.5) +* [moonshotai/Kimi-K2.6](https://huggingface.co/moonshotai/Kimi-K2.6) — post-training update, identical architecture ## Known Limitations @@ -317,13 +369,13 @@ MoonViT uses real-number decomposition of 2D complex RoPE and eager attention (n This is an extension of the Kimi-K2 text-only NxDI contrib (PR #131). Key differences: -| Aspect | K2 | K2.5 | +| Aspect | K2 | K2.5 / K2.6 | |--------|-----|------| | Modality | Text-only | Multimodal (image + text) | | Config | TP=32, EP=2 | TP=64, EP=1 | | Quantization | Blockwise FP8 (native) | INT4 → BF16 → FP8 per-channel | -| Weight format | K2 safetensors | K2.5 compressed-tensors | -| TKG throughput | 6.0 tok/s | 45.9 tok/s | +| Weight format | K2 safetensors | K2.5/K2.6 compressed-tensors | +| TKG throughput | 6.0 tok/s | 45.9 tok/s (K2.5), 57.9 tok/s (K2.6) | | Vision encoder | N/A | MoonViT-400M (35.5 ms) | The 7.6x throughput improvement (6.0 → 45.9 tok/s) comes from TP=64 EP=1 (vs TP=32 EP=2), which eliminates inter-EP communication overhead and gives each core more bandwidth. @@ -332,4 +384,4 @@ The 7.6x throughput improvement (6.0 → 45.9 tok/s) comes from TP=64 EP=1 (vs T Annapurna Labs -**Last Updated:** 2026-04-25 +**Last Updated:** 2026-04-28 diff --git a/contrib/models/Kimi-K2.5/test/integration/test_model.py b/contrib/models/Kimi-K2.5/test/integration/test_model.py index 7ade0015..2a895692 100644 --- a/contrib/models/Kimi-K2.5/test/integration/test_model.py +++ b/contrib/models/Kimi-K2.5/test/integration/test_model.py @@ -1,20 +1,34 @@ #!/usr/bin/env python3 """ -Integration tests for Kimi-K2.5 multimodal NeuronX implementation. +Integration tests for Kimi-K2.5/K2.6 multimodal NeuronX implementation. Tests compilation, loading, and multimodal inference on trn2.48xlarge. +Supports both K2.5 and K2.6 checkpoints — override paths via environment variables. 
Requirements: - trn2.48xlarge with NEURON_LOGICAL_NC_CONFIG=2 (64 logical cores) - LOCAL_WORLD_SIZE=64 - - K2.5 model weights at MODEL_PATH - - Pre-computed MoonViT embeddings at VISION_EMBEDDINGS_PATH + - Model weights at KIMI_MODEL_PATH (default: /mnt/nvme/models/Kimi-K2.5) + - Pre-computed MoonViT embeddings at KIMI_VISION_EMB_PATH - Neuron SDK 2.29 (Deep Learning AMI Neuron Ubuntu 24.04 20260410) - tiktoken package installed (pip install tiktoken) +Environment variables: + KIMI_MODEL_PATH - Path to K2.5 or K2.6 checkpoint (default: /mnt/nvme/models/Kimi-K2.5) + KIMI_TEXT_MODEL_DIR - Path for text-only model dir (default: /home/ubuntu/models/Kimi-K2.5-text) + KIMI_COMPILED_PATH - Path for compiled NEFFs (default: derived from TEXT_MODEL_DIR) + KIMI_VISION_EMB_PATH - Path to pre-computed vision embeddings (default: /mnt/nvme/models/moonvit_448_real_embeddings.pt) + Usage: - # Full test (compile + load + generate): + # Full test with K2.5 (default): + NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ + pytest test_model.py -v --capture=tee-sys + + # Full test with K2.6: NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ + KIMI_MODEL_PATH=/mnt/nvme/models/Kimi-K2.6 \ + KIMI_TEXT_MODEL_DIR=/home/ubuntu/models/Kimi-K2.6-text \ + KIMI_COMPILED_PATH=/mnt/nvme/models/Kimi-K2.6-text/neuron-compiled \ pytest test_model.py -v --capture=tee-sys # Load-only (skip compile, use existing NEFFs): @@ -58,13 +72,21 @@ # --------------------------------------------------------------------------- -# Configuration +# Configuration — override via environment variables for K2.6 or custom paths # --------------------------------------------------------------------------- -MODEL_PATH = "/mnt/nvme/models/Kimi-K2.5" -TEXT_MODEL_DIR = "/home/ubuntu/models/Kimi-K2.5-text" -COMPILED_MODEL_PATH = "/mnt/nvme/models/Kimi-K2.5-text/neuron-compiled-k25-vl-s512-v5" -VISION_EMBEDDINGS_PATH = "/mnt/nvme/models/moonvit_448_real_embeddings.pt" +MODEL_PATH = os.environ.get("KIMI_MODEL_PATH", "/mnt/nvme/models/Kimi-K2.5") +TEXT_MODEL_DIR = os.environ.get( + "KIMI_TEXT_MODEL_DIR", "/home/ubuntu/models/Kimi-K2.5-text" +) +COMPILED_MODEL_PATH = os.environ.get( + "KIMI_COMPILED_PATH", + "/mnt/nvme/models/Kimi-K2.5-text/neuron-compiled-k25-vl-s512-v5", +) +VISION_EMBEDDINGS_PATH = os.environ.get( + "KIMI_VISION_EMB_PATH", + "/mnt/nvme/models/moonvit_448_real_embeddings.pt", +) # Model configuration (TP=64, EP=1, LNC=2) TP_DEGREE = 64 @@ -430,7 +452,11 @@ def test_performance_tpot(compiled_model, tokenizer, vision_embeddings): ) print("=" * 80) - print("Kimi-K2.5 Multimodal Integration Tests") + print("Kimi-K2.5/K2.6 Multimodal Integration Tests") + print(f" Model path: {MODEL_PATH}") + print(f" Text model dir: {TEXT_MODEL_DIR}") + print(f" Compiled path: {COMPILED_MODEL_PATH}") + print(f" Vision embeddings: {VISION_EMBEDDINGS_PATH}") print("=" * 80) # Load vision embeddings