From d599eff1b51374e94b2a1c049df09fe52632431d Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Thu, 16 Apr 2026 16:02:01 -0400 Subject: [PATCH 01/10] Add Kimi-K2-Instruct-0905 contrib model Onboard moonshotai/Kimi-K2-Instruct-0905 (1T MoE, 384 experts, MLA attention, DeepSeek-V3 architecture) to NxDI on trn2.48xlarge. Configuration: TP=64, EP=2, LNC=1, blockwise FP8 (e4m3, 128x128 blocks) Performance: 3.4 tok/s at BS=1, TPOT=297.5ms, TTFT=1,788ms Key implementation details: - Multi-Latent Attention with compressed KV cache (576 bytes/token/layer) - Blockwise FP8 quantization for routed expert weights (non-experts in BF16) - Streaming checkpoint loader for 62 safetensor shards (avoids OOM) - Sigmoid routing with e_score_correction_bias loaded as router bias - Monkey patches for EP scale sharding and blockwise scale stride Tested on: trn2.48xlarge, Neuron SDK 2.28, us-east-2 --- .../models/Kimi-K2-Instruct-0905/README.md | 256 +++ .../Kimi-K2-Instruct-0905/src/__init__.py | 3 + .../src/modeling_kimi_k2.py | 1548 +++++++++++++++++ .../Kimi-K2-Instruct-0905/test/__init__.py | 0 .../test/integration/__init__.py | 0 .../test/integration/test_model.py | 364 ++++ .../test/unit/__init__.py | 0 7 files changed, 2171 insertions(+) create mode 100644 contrib/models/Kimi-K2-Instruct-0905/README.md create mode 100644 contrib/models/Kimi-K2-Instruct-0905/src/__init__.py create mode 100644 contrib/models/Kimi-K2-Instruct-0905/src/modeling_kimi_k2.py create mode 100644 contrib/models/Kimi-K2-Instruct-0905/test/__init__.py create mode 100644 contrib/models/Kimi-K2-Instruct-0905/test/integration/__init__.py create mode 100644 contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py create mode 100644 contrib/models/Kimi-K2-Instruct-0905/test/unit/__init__.py diff --git a/contrib/models/Kimi-K2-Instruct-0905/README.md b/contrib/models/Kimi-K2-Instruct-0905/README.md new file mode 100644 index 00000000..e690b5ce --- /dev/null +++ 
b/contrib/models/Kimi-K2-Instruct-0905/README.md @@ -0,0 +1,256 @@ +# Contrib Model: Kimi-K2-Instruct-0905 + +NeuronX Distributed Inference implementation of Moonshot AI's Kimi-K2-Instruct-0905. + +## Model Information + +- **HuggingFace ID:** `moonshotai/Kimi-K2-Instruct-0905` +- **Model Type:** Mixture of Experts (MoE) decoder-only transformer +- **Architecture:** DeepSeek-V3 variant with MLA attention +- **License:** Check HuggingFace model card + +## Architecture Details + +| Parameter | Value | +|-----------|-------| +| Total parameters | ~1,000B | +| Active parameters per token | ~32B | +| Hidden size | 7168 | +| Attention heads | 128 | +| Layers | 61 | +| Vocabulary size | 163840 | +| Routed experts | 384 (8 active per token) | +| Shared experts | 1 per MoE layer | +| Dense layers | 1 (layer 0, `first_k_dense_replace=1`) | +| Expert intermediate size | 2048 | +| Dense intermediate size | 18432 | +| Attention type | Multi-Latent Attention (MLA) | +| KV LoRA rank | 512 | +| QK rope head dim | 64 | +| Q LoRA rank | 1536 | +| RoPE | YaRN (factor=64, max_position_embeddings=262144) | +| Quantization | Blockwise FP8 (e4m3, 128x128 blocks) | +| Router activation | Sigmoid with `e_score_correction_bias` | +| Top-K normalization | Enabled (`norm_topk_prob=True`) | +| Routed scaling factor | 2.827 | + +### Key Implementation Details + +- **Multi-Latent Attention (MLA):** Compressed KV cache with only 576 bytes/token/layer + (qk_rope_head_dim + kv_lora_rank = 64 + 512). Weight absorption is used to avoid + decompressing KV during decode. + +- **Blockwise FP8 Quantization:** Routed expert weights are kept in FP8 (e4m3) with + 128x128 block scales. Non-expert weights (attention, embeddings, shared experts, norms) + are dequantized to BF16 during loading. Requires the + `--experimental-unsafe-fp8e4m3fn-as-fp8e4m3` compiler flag. 
+ +- **Streaming Checkpoint Loader:** Custom `checkpoint_loader_fn` that processes the 62 + safetensor shards one at a time to avoid OOM on 2TB host RAM. Each shard is loaded, + processed (FP8 handling, expert packing, router renaming), and accumulated. + +- **Monkey Patches (applied during `load()`):** + - `_apply_ep_scale_fix`: Prevents EP-sharding of per-channel FP8 scale tensors (shape [1,1,W]). + - `_apply_blockwise_scale_stride_fix`: Forces stride=1 for blockwise scale partitioning. + +- **Selective Loading Threshold:** Must be patched to 0.0 in + `neuronx_distributed/modules/moe/model_utils.py` on the target instance to ensure all + 384 expert weights load correctly. + +## Validation Results + +**Validated:** 2026-04-16 +**Configuration:** TP=64, EP=2, LNC=1, batch_size=1, seq_len=1024, blockwise FP8 + +### Test Results + +| Test | Status | Result | +|------|--------|--------| +| Smoke Test | PASS | Model compiles and loads on trn2.48xlarge | +| Generation | PASS | Correct answers for factual questions (10/13 prompts) | +| Throughput | PASS | 3.4 tok/s at BS=1 | + +### Performance Metrics + +| Metric | Value | +|--------|-------| +| TPOT (per-token latency) | 297.5 ms | +| Throughput (BS=1) | 3.4 tok/s | +| TTFT (61 input tokens) | 1,788 ms | +| Compile time (total) | 73 min (TKG -O3: 49 min, CTE -O1: 24 min) | +| Model load time | 47 min | +| HBM utilization | ~78% (1,200 GB / 1,536 GB) | + +### Token Generation Sweep (BS=1, seq_len=1024) + +| Output Tokens | TTFT P50 (ms) | TPOT P50 (ms) | tok/s | E2E P50 (ms) | +|---------------|---------------|----------------|-------|---------------| +| 16 | 1,787.9 | 297.36 | 3.4 | 6,248.3 | +| 32 | 1,787.9 | 297.37 | 3.4 | 11,006.6 | +| 64 | 1,788.3 | 297.52 | 3.4 | 20,533.8 | +| 128 | 1,787.9 | 297.44 | 3.4 | 39,564.4 | +| 256 | 1,788.4 | 297.61 | 3.4 | 77,681.2 | +| 512 | 1,795.9 | 297.55 | 3.4 | 153,842.1 | + +### Batching Results + +Batching provides **zero throughput improvement** on this model. 
The MoE computation is +perfectly bandwidth-bound -- each TKG step must load all 192 local expert weight matrices +from HBM regardless of batch size. BS=4 TPOT scales linearly (1,191 ms), yielding the +same aggregate throughput as BS=1. + +### Performance Bottleneck + +TPOT breakdown (estimated per 297.5 ms token): + +1. **MoE expert MLPs (~250 ms, ~84%):** 192 local experts x 2 matmuls per layer. + FP8 weights are dequantized to BF16 before the NKI kernel. +2. **MLA attention (~25 ms, ~8%):** Weight absorption projections + KV cache. +3. **Router + all-to-all (~15 ms, ~5%):** Router TopK + expert dispatch across EP=2. +4. **Other (~7.5 ms, ~3%):** RMSNorm, residuals, lm_head. + +Primary optimization opportunity: native blockwise FP8 kernel in the nki-lib MoE TKG +pipeline (currently blocked -- nki-lib requires per-channel FP8 scales). + +## Usage + +```python +import json +import os +import torch +from neuronx_distributed_inference.models.config import MoENeuronConfig, RouterConfig + +# Import model classes +from src.modeling_kimi_k2 import NeuronKimiK2ForCausalLM, KimiK2InferenceConfig + +model_path = "/path/to/Kimi-K2-Instruct-0905" +compiled_path = "/path/to/compiled" + +# Read HF config +with open(os.path.join(model_path, "config.json")) as f: + hf_config = json.load(f) + +# Configure for trn2.48xlarge +neuron_config = MoENeuronConfig( + tp_degree=64, + ep_degree=2, + logical_nc_config=1, + max_batch_size=1, + seq_len=1024, + n_active_tokens=128, + torch_dtype="bfloat16", + capacity_factor=1.0, + glu_mlp=True, + moe_ep_degree=2, + moe_tp_degree=64, + context_encoding_buckets=[128, 1024], + router_config=RouterConfig(act_fn="sigmoid", dtype="float32"), + # FP8 quantization + quantized=True, + quantized_checkpoints_path=model_path, + quantization_dtype="f8e4m3", + modules_to_not_convert=[ + "self_attn", "shared_experts", "embed_tokens", + "lm_head", "norm", "router", "layers.0", + ], + quantization_type="blockwise_symmetric", + quantization_block_axis=[1, 2], + 
quantization_block_size=[128, 128], +) + +# Build config from HF config fields +hf_kwargs = {k: v for k, v in hf_config.items() + if k not in ("auto_map", "torch_dtype", "transformers_version", "architectures")} +config = KimiK2InferenceConfig(neuron_config=neuron_config, **hf_kwargs) + +# Compile and load +model = NeuronKimiK2ForCausalLM(model_path, config) +model.compile(compiled_path) # ~73 min +model.load(compiled_path) # ~47 min + +# Generate (CPU greedy sampling, no on-device sampling) +from transformers import AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) +# See test/integration/test_model.py for the full generation loop +``` + +**Important:** Run with environment variables: +```bash +NEURON_LOGICAL_NC_CONFIG=1 LOCAL_WORLD_SIZE=128 python your_script.py +``` + +## Compatibility Matrix + +| Instance / SDK Version | 2.28+ | 2.27 and earlier | +|------------------------|-------|------------------| +| trn2.48xlarge (LNC=1) | Working | Not tested | +| trn2.3xlarge | Not supported (needs TP=64, EP=2) | Not supported | +| trn1.32xlarge | Not supported (needs 128 cores) | Not supported | +| inf2 | Not supported | Not supported | + +## Testing + +Run integration tests on a trn2.48xlarge: + +```bash +# Activate Neuron venv +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/bin/activate + +# Run tests +NEURON_LOGICAL_NC_CONFIG=1 LOCAL_WORLD_SIZE=128 \ + pytest test/integration/test_model.py -v --capture=tee-sys +``` + +Or run standalone: + +```bash +NEURON_LOGICAL_NC_CONFIG=1 LOCAL_WORLD_SIZE=128 \ + python test/integration/test_model.py +``` + +**Note:** Compilation takes ~73 min and loading takes ~47 min. The first run will compile +NEFFs to the compiled model path. Subsequent runs with existing NEFFs skip compilation. + +## Prerequisites + +1. 
**Selective loading threshold patch:** On the target instance, patch + `neuronx_distributed/modules/moe/model_utils.py` to set the selective loading + threshold to 0.0 (default is too high for 384 experts). + +2. **Model weights:** Download from HuggingFace: + ```bash + huggingface-cli download moonshotai/Kimi-K2-Instruct-0905 \ + --local-dir /home/ubuntu/models/Kimi-K2-Instruct-0905 + ``` + +3. **Host RAM:** At least 2 TB (the streaming loader peaks at ~95 GB RSS, but + safetensors mmap can use more virtual memory). + +## Example Checkpoints + +* [moonshotai/Kimi-K2-Instruct-0905](https://huggingface.co/moonshotai/Kimi-K2-Instruct-0905) + +## Known Limitations + +- **No on-device sampling:** The model uses CPU greedy sampling because the vocabulary + size (163840) is not divisible by common TP degrees, causing shape mismatches in the + on-device sampling kernel. + +- **Elevated EOS logit:** The `<|im_end|>` token (ID 163586) has an elevated logit in + early generation steps, likely due to the FP8->BF16 dequantization of shared expert + weights or slight router bias approximation. Mitigated by masking EOS for the first + few generation tokens (`min_tokens_before_eos=3`). + +- **Batching does not improve throughput:** The MoE computation is bandwidth-bound + (192 expert weight loads per step), so higher batch sizes increase latency linearly + without improving aggregate throughput. + +- **Compiler flags have no measurable impact:** -O3 with DGE vs -O1 showed 0% difference, + confirming the bottleneck is weight bandwidth, not compute or scheduling. 
+ +## Maintainer + +Annapurna Labs + +**Last Updated:** 2026-04-16 diff --git a/contrib/models/Kimi-K2-Instruct-0905/src/__init__.py b/contrib/models/Kimi-K2-Instruct-0905/src/__init__.py new file mode 100644 index 00000000..a39f1c0b --- /dev/null +++ b/contrib/models/Kimi-K2-Instruct-0905/src/__init__.py @@ -0,0 +1,3 @@ +from .modeling_kimi_k2 import NeuronKimiK2ForCausalLM, KimiK2InferenceConfig + +__all__ = ["NeuronKimiK2ForCausalLM", "KimiK2InferenceConfig"] diff --git a/contrib/models/Kimi-K2-Instruct-0905/src/modeling_kimi_k2.py b/contrib/models/Kimi-K2-Instruct-0905/src/modeling_kimi_k2.py new file mode 100644 index 00000000..221a76c2 --- /dev/null +++ b/contrib/models/Kimi-K2-Instruct-0905/src/modeling_kimi_k2.py @@ -0,0 +1,1548 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Kimi-K2 (moonshotai/Kimi-K2-Instruct-0905) on Neuron via NxDI. +# +# Architecture: DeepseekV3ForCausalLM variant with MLA attention + MoE +# - 1T total parameters, 32B active per token +# - 384 routed experts (8 active per token) + 1 shared expert +# - Multi-Latent Attention (MLA) with compressed KV cache +# - Sigmoid routing with e_score_correction_bias + normalized top-K +# - Blockwise FP8 quantization (e4m3, 128x128 blocks) +# - YaRN RoPE (factor=64, max_position_embeddings=262144) +# +# Supported configuration: +# - trn2.48xlarge: TP=64, EP=2, LNC=1 (128 logical cores) +# - Blockwise FP8 for routed expert weights +# - CPU greedy sampling (no on-device sampling) +# +# References: +# - NxDI DeepseekV3Attention: models/deepseek/modeling_deepseek.py +# - Qwen3-MoE reference: models/qwen3_moe/modeling_qwen3_moe.py + +import gc +import logging +import math +import os +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from neuronx_distributed.parallel_layers import parallel_state +from 
neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers.mappings import ( + gather_from_sequence_parallel_region, +) +from neuronx_distributed.utils import cpu_mode + +from neuronx_distributed_inference.models.config import InferenceConfig, MoENeuronConfig +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.utils import manual_softmax +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm +from neuronx_distributed_inference.modules.moe_v2 import initialize_moe_module + +from neuronx_distributed.modules.moe.routing import RouterTopK + +from neuronx_distributed_inference.models.deepseek.rope_util import ( + DeepseekV3YarnRotaryEmbedding, + apply_rotary_pos_emb, + yarn_get_mscale, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# RMSNorm +# --------------------------------------------------------------------------- + + +def get_rmsnorm_cls(): + return KimiK2RMSNorm if cpu_mode() else CustomRMSNorm + + +class KimiK2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# 
--------------------------------------------------------------------------- +# MoE initialization +# --------------------------------------------------------------------------- + + +def initialize_kimi_k2_moe_module(config: "KimiK2InferenceConfig"): + """Initialize MoE module with sigmoid routing + e_score_correction_bias. + + Uses standard RouterTopK with bias=True. The e_score_correction_bias is + loaded as the router linear layer's bias (pre-sigmoid), which is a slight + semantics change from the original post-sigmoid application. The biases are + small (~0.03-0.1) so the approximation is acceptable. + """ + from neuronx_distributed_inference.modules.moe_v2 import ( + initialize_moe_process_group, + ) + from neuronx_distributed.modules.moe.expert_mlps_v2 import ExpertMLPsV2 + from neuronx_distributed.modules.moe.model import MoE + from neuronx_distributed.modules.moe.moe_configs import RoutedExpertsMLPOpsConfig + + enabled_hybrid_sharding = config.neuron_config.hybrid_sharding_config is not None + ( + moe_tkg_tensor_model_parallel_group, + moe_tkg_expert_model_parallel_group, + moe_cte_tensor_model_parallel_group, + moe_cte_expert_model_parallel_group, + ) = initialize_moe_process_group(config, enabled_hybrid_sharding) + + router = RouterTopK( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + dtype=config.neuron_config.router_config.dtype, + act_fn=config.neuron_config.router_config.act_fn, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + sequence_dimension=1, + bias=True, + apply_act_fn_over_topk=False, + store_transposed_weights=False, + ) + + expert_mlps = ExpertMLPsV2( + routed_experts_mlp_config=RoutedExpertsMLPOpsConfig( + num_experts=config.num_local_experts, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + top_k=config.num_experts_per_tok, + hidden_act=config.hidden_act, + bias=False, + glu_mlp=config.neuron_config.glu_mlp, + 
glu_type=config.neuron_config.glu_type, + hidden_act_scaling_factor=config.neuron_config.hidden_act_scaling_factor, + hidden_act_bias=config.neuron_config.hidden_act_bias, + early_expert_affinity_modulation=config.neuron_config.early_expert_affinity_modulation, + normalize_top_k_affinities=config.neuron_config.normalize_top_k_affinities, + enable_spmd_rank=config.neuron_config.blockwise_matmul_config.parallelize_token_to_block_mapping, + ), + blockwise_matmul_config=config.neuron_config.blockwise_matmul_config, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + dtype=config.neuron_config.torch_dtype, + is_prefill=config.neuron_config.is_prefill_stage, + enabled_hybrid_sharding=enabled_hybrid_sharding, + tensor_model_parallel_group=parallel_state.get_tensor_model_parallel_group(), + expert_model_parallel_group=parallel_state.get_expert_model_parallel_group(), + cte_tensor_model_parallel_group=moe_cte_tensor_model_parallel_group, + cte_expert_model_parallel_group=moe_cte_expert_model_parallel_group, + tkg_tensor_model_parallel_group=moe_tkg_tensor_model_parallel_group, + tkg_expert_model_parallel_group=moe_tkg_expert_model_parallel_group, + ) + + moe = MoE( + router=router, + expert_mlps=expert_mlps, + shared_experts=None, + rmsnorm=None, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + return_expert_index=config.neuron_config.return_expert_index, + sequence_dimension=1, + init_tkg_module=False, + tkg_config=None, + ) + moe.eval() + return moe + + +# --------------------------------------------------------------------------- +# Shared Expert MLP (dense, always active) +# --------------------------------------------------------------------------- + + +class KimiK2SharedExpertMLP(nn.Module): + """Shared expert MLP (SwiGLU) for Kimi-K2. 
Always active, not routed.""" + + def __init__(self, config: "KimiK2InferenceConfig"): + super().__init__() + self.hidden_size = config.hidden_size + self.shared_intermediate_size = config.moe_intermediate_size + + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + self.shared_intermediate_size, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + self.shared_intermediate_size, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.down_proj = RowParallelLinear( + self.shared_intermediate_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + ) + self.act_fn = nn.SiLU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# --------------------------------------------------------------------------- +# Dense MLP (for first_k_dense_replace layers, i.e. 
layer 0) +# --------------------------------------------------------------------------- + + +class KimiK2DenseMLP(nn.Module): + """Standard SwiGLU MLP for the dense layers (first_k_dense_replace = 1).""" + + def __init__(self, config: "KimiK2InferenceConfig"): + super().__init__() + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.down_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + ) + self.act_fn = nn.SiLU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# --------------------------------------------------------------------------- +# MLA Attention (Multi-Latent Attention) +# --------------------------------------------------------------------------- + + +class KimiK2Attention(NeuronAttentionBase): + """ + Multi-Latent Attention (MLA) for Kimi-K2. + + KV cache format: stores (k_pe, compressed_kv) concatenated. 
+ - K cache: [batch, 1, seq, qk_rope_head_dim + kv_lora_rank] + - V cache: same (placeholder, only K is read during decode) + """ + + def __init__( + self, + config: "KimiK2InferenceConfig", + layer_idx: int, + tensor_model_parallel_group=None, + ): + super().__init__( + config=config, + tensor_model_parallel_group=tensor_model_parallel_group, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_attention_heads, + head_dim=config.v_head_dim, + num_cores_per_group=config.num_cores_per_group, + rms_norm_eps=config.rms_norm_eps, + ) + + self.layer_idx = layer_idx + self.bias = getattr(config, "attention_bias", False) + self.attention_dropout = config.attention_dropout + self.num_total_heads = config.num_attention_heads + + if cpu_mode(): + self.num_heads = self.num_total_heads + else: + self.num_heads = self.num_total_heads // config.neuron_config.tp_degree + + # MLA dimensions + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + self.head_dim = self.v_head_dim + + # MLA doesn't use fused QKV + self.qkv_proj = None + + # YaRN RoPE + self.rotary_emb = DeepseekV3YarnRotaryEmbedding( + dim=config.qk_rope_head_dim, + max_position_embeddings=config.max_position_embeddings, + scaling_factor=config.rope_scaling["factor"], + base=config.rope_theta, + mscale=config.rope_scaling["mscale"], + mscale_all_dim=config.rope_scaling["mscale_all_dim"], + beta_fast=config.rope_scaling["beta_fast"], + beta_slow=config.rope_scaling["beta_slow"], + ) + + # Softmax scale with mscale adjustment + self.softmax_scale = self.q_head_dim ** (-0.5) + if config.rope_scaling is not None: + mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = config.rope_scaling["factor"] + 
if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + self.is_causal = True + self._init_mla_projections(config) + + def _init_mla_projections(self, config): + dtype = self.torch_dtype + tp_group = self.tensor_model_parallel_group + + # Query projection (LoRA: hidden -> q_lora_rank -> num_heads * q_head_dim) + if self.q_lora_rank is None: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_total_heads * self.q_head_dim, + bias=False, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, + config.q_lora_rank, + bias=config.attention_bias, + dtype=dtype, + ) + self.q_a_layernorm = get_rmsnorm_cls()(config.q_lora_rank) + self.q_b_proj = ColumnParallelLinear( + config.q_lora_rank, + self.num_total_heads * self.q_head_dim, + bias=False, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + + # KV projection (compressed: hidden -> kv_lora_rank + qk_rope_head_dim) + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + dtype=dtype, + ) + self.kv_a_layernorm = get_rmsnorm_cls()(config.kv_lora_rank) + + # kv_b_proj: decompresses latent -> per-head qk_nope + v + if tp_group is not None: + self.kv_b_proj = ColumnParallelLinear( + config.kv_lora_rank, + self.num_total_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + else: + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_total_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + # Output projection + if tp_group is not None: + self.o_proj = RowParallelLinear( + self.num_attention_heads * self.head_dim, + self.hidden_size, + bias=self.bias, + input_is_parallel=True, + dtype=self.torch_dtype, + 
sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=self.sequence_dimension, + tensor_model_parallel_group=tp_group, + reduce_dtype=self.rpl_reduce_dtype, + ) + else: + self.o_proj = nn.Linear( + self.num_attention_heads * self.head_dim, + self.hidden_size, + bias=self.bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: torch.Tensor = None, + active_mask: Optional[torch.LongTensor] = None, + adapter_ids=None, + cos_cache: Optional[torch.Tensor] = None, + sin_cache: Optional[torch.Tensor] = None, + **kwargs, + ): + if ( + self.sequence_parallel_enabled + and self.tensor_model_parallel_group is not None + ): + hidden_states = gather_from_sequence_parallel_region( + hidden_states, + self.sequence_dimension, + process_group=self.tensor_model_parallel_group, + ) + + bsz, q_len, _ = hidden_states.size() + + # Weight absorption: precompute from kv_b_proj weights + wkv_b = self.kv_b_proj.weight + wkv_b = wkv_b.view(self.num_heads, -1, self.kv_lora_rank) + out_absorb = wkv_b[:, self.qk_nope_head_dim :, :] # V absorption + q_absorb = wkv_b[:, : self.qk_nope_head_dim, :] # Q nope absorption + + # Query projection + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + + # KV compression + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + q_nope, q_pe = torch.tensor_split(q, (self.qk_nope_head_dim,), dim=-1) + compressed_kv, k_pe = torch.tensor_split( + compressed_kv, (self.kv_lora_rank,), dim=-1 + ) + compressed_kv = self.kv_a_layernorm(compressed_kv) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + + # Q nope absorption + q_nope = torch.einsum("hdc,bhqd->bhqc", q_absorb, q_nope) + + # RoPE + seq_len = self.neuron_config.seq_len 
+ if sin_cache is None and cos_cache is None: + cos_cache, sin_cache = self.rotary_emb(k_pe, seq_len) + q_pe = apply_rotary_pos_emb(q_pe, cos_cache, sin_cache, position_ids) + k_pe = apply_rotary_pos_emb(k_pe, cos_cache, sin_cache, position_ids) + + # Attention scores + active_scores = torch.matmul(q_pe, k_pe.transpose(2, 3)) + torch.einsum( + "bhqc,blc->bhql", q_nope, compressed_kv + ) + active_scores *= self.softmax_scale + + if past_key_value is None: + # Context encoding (prefill) + active_scores = torch.where( + attention_mask, + active_scores, + torch.finfo(active_scores.dtype).min, + ) + active_scores = nn.functional.softmax( + active_scores, dim=-1, dtype=torch.float32 + ).to(k_pe.dtype) + x = torch.einsum("bhql,blc->bhqc", active_scores, compressed_kv) + attn_output = torch.einsum("bhqc,hdc->bhqd", x, out_absorb) + else: + # Token generation (decode) -- split prior cache + cached_kv = past_key_value[0] + if cached_kv.dim() == 4: + cached_kv = cached_kv.squeeze(1) + k_pe_prior, compressed_kv_prior = torch.tensor_split( + cached_kv, + [self.qk_rope_head_dim], + dim=-1, + ) + k_pe_prior = k_pe_prior.reshape( + bsz, + 1, + compressed_kv_prior.shape[1], + self.qk_rope_head_dim, + ) + + prior_scores = torch.matmul( + q_pe, k_pe_prior.transpose(2, 3) + ) + torch.einsum("bhqc,blc->bhql", q_nope, compressed_kv_prior) + prior_scores *= self.softmax_scale + prior_scores = torch.where( + attention_mask, + prior_scores, + torch.finfo(prior_scores.dtype).min, + ) + prior_scores = prior_scores.to(torch.float32) + + softmax_prior, softmax_active = manual_softmax( + prior_scores, + active_scores, + is_speculation=False, + ) + softmax_prior = softmax_prior.to(k_pe.dtype) + softmax_active = softmax_active.to(k_pe.dtype) + + x = torch.einsum("bhql,blc->bhqc", softmax_active, compressed_kv) + attn_active = torch.einsum("bhqc,hdc->bhqd", x, out_absorb) + + x = torch.einsum("bhql,blc->bhqc", softmax_prior, compressed_kv_prior) + attn_prior = torch.einsum("bhqc,hdc->bhqd", x, 
out_absorb) + + attn_output = attn_prior + attn_active + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + attn_output = self.o_proj(attn_output) + + # KV cache: concatenate k_pe and compressed_kv + k_pe_flat = k_pe.squeeze(1) + kv_for_cache = torch.cat([k_pe_flat, compressed_kv], dim=-1) + kv_for_cache = kv_for_cache.unsqueeze(1) + + # Store same data in both K and V cache slots (only K is read during decode) + past_key_value = (kv_for_cache, kv_for_cache) + + return attn_output, past_key_value, cos_cache, sin_cache + + +# --------------------------------------------------------------------------- +# Decoder Layer +# --------------------------------------------------------------------------- + + +class KimiK2DecoderLayer(nn.Module): + """Decoder layer: MLA attention + (MoE or Dense MLP) + optional shared expert.""" + + def __init__(self, config: "KimiK2InferenceConfig", layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.self_attn = KimiK2Attention(config=config, layer_idx=layer_idx) + + # MLP: dense for layer 0 (first_k_dense_replace=1), MoE for the rest + self.is_moe_layer = layer_idx >= config.first_k_dense_replace + + if self.is_moe_layer: + self.mlp = initialize_kimi_k2_moe_module(config=config) + if config.n_shared_experts > 0: + self.shared_experts = KimiK2SharedExpertMLP(config) + else: + self.shared_experts = None + else: + self.mlp = KimiK2DenseMLP(config) + self.shared_experts = None + + # Routed expert scaling factor (DeepSeek-V3 / Kimi-K2 pattern) + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) + + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: 
Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + cos_cache: Optional[torch.Tensor] = None, + sin_cache: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, ...]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + cos_cache=cos_cache, + sin_cache=sin_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + if self.is_moe_layer: + moe_output = self.mlp(hidden_states, padding_mask)[0] + moe_output = moe_output * self.routed_scaling_factor + if self.shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + hidden_states = moe_output + shared_output + else: + hidden_states = moe_output + else: + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value, cos_cache, sin_cache, None) + return outputs + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + + +class KimiK2InferenceConfig(InferenceConfig): + """Inference config for Kimi-K2 (DeepSeek-V3 architecture).""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.num_local_experts = self.n_routed_experts # 384 + self.first_k_dense_replace = getattr(self, "first_k_dense_replace", 1) + self.n_shared_experts = getattr(self, "n_shared_experts", 1) + + # Router config: sigmoid with normalized top-k + self.neuron_config.router_config.dtype = torch.float32 + 
self.neuron_config.router_config.act_fn = "sigmoid" + self.neuron_config.normalize_top_k_affinities = True + + # GLU MLP for SwiGLU + self.neuron_config.glu_mlp = True + + def get_required_attributes(self) -> List[str]: + return [ + "attention_bias", + "hidden_act", + "hidden_size", + "intermediate_size", + "kv_lora_rank", + "max_position_embeddings", + "moe_intermediate_size", + "n_routed_experts", + "n_shared_experts", + "num_attention_heads", + "num_experts_per_tok", + "num_hidden_layers", + "num_key_value_heads", + "q_lora_rank", + "qk_nope_head_dim", + "qk_rope_head_dim", + "rms_norm_eps", + "rope_scaling", + "rope_theta", + "scoring_func", + "v_head_dim", + "vocab_size", + ] + + def add_derived_config(self): + self.num_cores_per_group = 1 + + @classmethod + def get_neuron_config_cls(cls) -> Type[MoENeuronConfig]: + return MoENeuronConfig + + +# --------------------------------------------------------------------------- +# Model +# --------------------------------------------------------------------------- + + +class NeuronKimiK2Model(NeuronBaseModel): + def setup_attr_for_model(self, config: KimiK2InferenceConfig): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + + # MLA KV cache: 1 "head" with dim = qk_rope_head_dim + kv_lora_rank + self.num_key_value_heads = 1 + config.head_dim = ( + config.qk_rope_head_dim + config.kv_lora_rank + ) # 64 + 512 = 576 + + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: KimiK2InferenceConfig): + self.padding_idx = getattr(config, "pad_token_id", None) + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + 
shard_across_embedding=True, + ) + + self.layers = nn.ModuleList( + [ + KimiK2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + ) + + +# --------------------------------------------------------------------------- +# Block FP8 Dequantization Utilities +# --------------------------------------------------------------------------- + +# FP8 E4M3 max representable value. +# PyTorch's e4m3fn has max=448 (no NaN encoding), but with +# --experimental-unsafe-fp8e4m3fn-as-fp8e4m3, exponent-15 values (>240) become NaN. +# Use 240.0 to ensure all quantized values stay in the e4m3-safe range. +_FP8_E4M3_MAX = 240.0 + + +def _dequant_block_fp8_to_fp32( + fp8_weight: Tensor, block_scales: Tensor, block_size: List[int] +) -> Tensor: + """Dequantize block-FP8 (e4m3, 128x128 blocks) weight to FP32.""" + se = block_scales.repeat_interleave(block_size[0], dim=0).repeat_interleave( + block_size[1], dim=1 + ) + if se.shape != fp8_weight.shape: + se = se[: fp8_weight.shape[0], : fp8_weight.shape[1]] + return fp8_weight.to(torch.float32) * se.to(torch.float32) + + +def _clamp_fp8_exponent15(fp8_weight: Tensor) -> Tensor: + """Clamp FP8 e4m3fn bytes that have exponent=15, which become NaN in e4m3. + + On Neuron hardware with --experimental-unsafe-fp8e4m3fn-as-fp8e4m3, + bytes 0x78-0x7E and 0xF8-0xFE become NaN. Clamp to max safe values. 
+ """ + raw = fp8_weight.view(torch.uint8) + pos_exp15 = (raw >= 0x78) & (raw <= 0x7E) + neg_exp15 = (raw >= 0xF8) & (raw <= 0xFE) + clamped = raw.clone() + clamped[pos_exp15] = 0x77 # +240.0 + clamped[neg_exp15] = 0xF7 # -240.0 + return clamped.view(torch.float8_e4m3fn) + + +def _pack_experts_blockwise_fp8( + expert_fp8_weights: List[Tensor], + expert_block_scales: List[Tensor], + block_size: List[int], + tp_degree: int, + layout: str = "gate_up", +) -> Tuple[Tensor, Tensor]: + """Pack per-expert FP8 weights and block scales into fused tensors. + + Preserves original FP8 bytes (with exponent-15 clamping) and packs + block-wise scales into the matching layout. This avoids the lossy + FP8->FP32->FP8 re-quantization path. + + For gate_up (ColumnParallel, stride=2): packs [gate_w.T, up_w.T] -> [E, H, 2*I] + For down (RowParallel, stride=1): packs [down_w.T] -> [E, I, H] + """ + n_experts = len(expert_fp8_weights) + bs0, bs1 = block_size + + if layout == "gate_up": + gate0, up0 = expert_fp8_weights[0] + I, H = gate0.shape + packed_w = torch.empty(n_experts, H, 2 * I, dtype=torch.float8_e4m3fn) + + gs0, us0 = expert_block_scales[0] + sI, sH = gs0.shape + raw_scale = torch.empty(n_experts, sH, 2 * sI, dtype=torch.float32) + + for e in range(n_experts): + g_fp8, u_fp8 = expert_fp8_weights[e] + g_scale, u_scale = expert_block_scales[e] + g_fp8 = _clamp_fp8_exponent15(g_fp8) + u_fp8 = _clamp_fp8_exponent15(u_fp8) + packed_w[e, :, :I] = g_fp8.T + packed_w[e, :, I:] = u_fp8.T + raw_scale[e, :, :sI] = g_scale.T + raw_scale[e, :, sI:] = u_scale.T + + out_dim = 2 * I + per_tp = out_dim // tp_degree + repeat_factor = bs1 // per_tp if per_tp < bs1 else 1 + if repeat_factor > 1: + expanded_scale = raw_scale.repeat_interleave(repeat_factor, dim=2) + else: + expanded_scale = raw_scale + + return packed_w, expanded_scale + + elif layout == "down": + d0 = expert_fp8_weights[0] + H_orig, I = d0.shape + packed_w = torch.empty(n_experts, I, H_orig, dtype=torch.float8_e4m3fn) + + ds0 = 
expert_block_scales[0] + sH, sI = ds0.shape + raw_scale = torch.empty(n_experts, sI, sH, dtype=torch.float32) + + for e in range(n_experts): + d_fp8 = expert_fp8_weights[e] + d_scale = expert_block_scales[e] + d_fp8 = _clamp_fp8_exponent15(d_fp8) + packed_w[e] = d_fp8.T + raw_scale[e] = d_scale.T + + in_dim = I + per_tp = in_dim // tp_degree + repeat_factor = bs0 // per_tp if per_tp < bs0 else 1 + if repeat_factor > 1: + expanded_scale = raw_scale.repeat_interleave(repeat_factor, dim=1) + else: + expanded_scale = raw_scale + + return packed_w, expanded_scale + + else: + raise ValueError(f"Unknown layout: {layout}") + + +# --------------------------------------------------------------------------- +# State Dict Conversion +# --------------------------------------------------------------------------- + + +def convert_kimi_k2_hf_to_neuron_state_dict( + neuron_state_dict: Dict[str, Any], + config: KimiK2InferenceConfig, +) -> Dict[str, Any]: + """Convert HuggingFace Kimi-K2 / DeepSeek-V3 weights to NxDI format. + + Key conversions: + 1. Dequantize block-FP8 weights to BF16 (if FP8 weights detected) + 2. Rename router weights (gate.weight -> router.linear_router.weight) + 3. Concatenate per-expert weights into packed tensors for ExpertMLPsV2 + 4. Handle shared expert weight naming + 5. Handle e_score_correction_bias loading as router bias + """ + # Check if weights are already pre-converted + has_packed_experts = any( + "expert_mlps.mlp_op.gate_up_proj.weight" in k for k in neuron_state_dict + ) + has_per_expert = any( + "mlp.experts.0.gate_proj.weight" in k for k in neuron_state_dict + ) + + if has_packed_experts and not has_per_expert: + logger.info("Weights already pre-converted. 
Adding rank_util only.") + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + for layer_idx in range(config.num_hidden_layers): + neuron_state_dict[f"layers.{layer_idx}.self_attn.rank_util.rank"] = ( + torch.arange(0, config.neuron_config.tp_degree, dtype=torch.int32) + ) + return neuron_state_dict + + # Dequantize FP8 weights + _maybe_dequantize_fp8(neuron_state_dict, config) + + # Add rank utility tensors + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + for layer_idx in range(config.num_hidden_layers): + neuron_state_dict[f"layers.{layer_idx}.self_attn.rank_util.rank"] = ( + torch.arange(0, config.neuron_config.tp_degree, dtype=torch.int32) + ) + + is_moe_layer = layer_idx >= config.first_k_dense_replace + if not is_moe_layer: + continue + + # Router weights + gate_key = f"layers.{layer_idx}.mlp.gate.weight" + if gate_key in neuron_state_dict: + neuron_state_dict[f"layers.{layer_idx}.mlp.router.linear_router.weight"] = ( + neuron_state_dict[gate_key].detach().clone() + ) + del neuron_state_dict[gate_key] + + # e_score_correction_bias + bias_key = f"layers.{layer_idx}.mlp.gate.e_score_correction_bias" + if bias_key in neuron_state_dict: + neuron_state_dict[ + f"layers.{layer_idx}.mlp.router.e_score_correction_bias" + ] = neuron_state_dict[bias_key].detach().clone() + del neuron_state_dict[bias_key] + + # Expert weights: per-expert -> packed format + expert_0_gate = f"layers.{layer_idx}.mlp.experts.0.gate_proj.weight" + if expert_0_gate not in neuron_state_dict: + continue + + intermediate_size, hidden_size = neuron_state_dict[expert_0_gate].shape + device = neuron_state_dict[expert_0_gate].device + dtype = neuron_state_dict[expert_0_gate].dtype + num_experts = config.n_routed_experts + + # Concatenate gate + up projections + gate_up_proj = torch.empty( + num_experts, + hidden_size, + 2 * intermediate_size, + dtype=dtype, + 
device=device, + ) + for e in range(num_experts): + gate_w = ( + neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight" + ] + .T.detach() + .clone() + ) + up_w = ( + neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"] + .T.detach() + .clone() + ) + gate_up_proj[e, :, :intermediate_size] = gate_w + gate_up_proj[e, :, intermediate_size:] = up_w + del neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight" + ] + del neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"] + + neuron_state_dict[ + f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + ] = gate_up_proj + + # Down projections + down_proj = torch.empty( + num_experts, + intermediate_size, + hidden_size, + dtype=dtype, + device=device, + ) + for e in range(num_experts): + down_w = ( + neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" + ] + .T.detach() + .clone() + ) + down_proj[e] = down_w + del neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" + ] + + neuron_state_dict[ + f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.down_proj.weight" + ] = down_proj + + # Shared expert rename + for proj in ["gate_proj", "up_proj", "down_proj"]: + hf_key = f"layers.{layer_idx}.mlp.shared_experts.{proj}.weight" + nxdi_key = f"layers.{layer_idx}.shared_experts.{proj}.weight" + if hf_key in neuron_state_dict: + neuron_state_dict[nxdi_key] = neuron_state_dict[hf_key] + del neuron_state_dict[hf_key] + + gc.collect() + + return neuron_state_dict + + +def _maybe_dequantize_fp8( + neuron_state_dict: Dict[str, Any], + config: "KimiK2InferenceConfig", +): + """Dequantize block-FP8 weights to BF16.""" + scale_layers = [] + + for layer_key in list(neuron_state_dict.keys()): + if "_scale_inv" in layer_key or "weight_scale_inv" in layer_key: + scales = neuron_state_dict[layer_key] + scale_layers.append(layer_key) + + if layer_key.endswith(".weight_scale_inv"): + fp8_layer_name = 
layer_key.replace(".weight_scale_inv", ".weight") + elif "_scale_inv" in layer_key: + fp8_layer_name = layer_key.replace("_scale_inv", "") + + if fp8_layer_name not in neuron_state_dict: + continue + + fp8_layer = neuron_state_dict[fp8_layer_name] + + if hasattr(config, "quantization_config") and config.quantization_config: + block_size = config.quantization_config.get( + "weight_block_size", [128, 128] + ) + else: + block_size = [128, 128] + + fp32_val = _dequant_block_fp8_to_fp32(fp8_layer, scales, block_size) + neuron_state_dict[fp8_layer_name] = fp32_val.to( + config.neuron_config.torch_dtype + ) + + for scale_layer in scale_layers: + del neuron_state_dict[scale_layer] + + +# --------------------------------------------------------------------------- +# Top-level CausalLM +# --------------------------------------------------------------------------- + + +class NeuronKimiK2ForCausalLM(NeuronBaseForCausalLM): + """Kimi-K2 for causal language modeling on Neuron.""" + + _model_cls = NeuronKimiK2Model + + @staticmethod + def load_hf_model(model_path: str, **kwargs): + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + **kwargs, + ) + + @classmethod + def get_config_cls(cls) -> Type[KimiK2InferenceConfig]: + return KimiK2InferenceConfig + + @staticmethod + def _apply_ep_scale_fix(): + """Monkey-patch ExpertFusedLinear._mark_expert_parallel_weights to skip + per-channel scale params that can't be EP-sharded (shape [1, 1, W]).""" + from neuronx_distributed.modules.moe.moe_parallel_layers import ( + ExpertFusedLinear, + ) + + _original_mark = ExpertFusedLinear._mark_expert_parallel_weights + + def _patched_mark( + self_inner, + iterable=None, + expert_parallel_group_size=None, + is_prefill=True, + expert_distribution=None, + ): + from neuronx_distributed.parallel_layers.parallel_state import ( + get_expert_model_parallel_size, + ) + + if expert_parallel_group_size is None: + 
expert_parallel_group_size = get_expert_model_parallel_size() + + if expert_parallel_group_size > 1: + if iterable is None: + params_to_mark = [] + for name, p in self_inner.named_parameters(): + if name == "scale" and p.shape[0] == 1: + continue + params_to_mark.append(p) + iterable = params_to_mark + + for p in iterable: + p.expert_model_parallel = True + if is_prefill: + p.is_prefill = True + p.expert_distribution = expert_distribution + + ExpertFusedLinear._mark_expert_parallel_weights = _patched_mark + + @staticmethod + def _apply_blockwise_scale_stride_fix(): + """Monkey-patch _setup_for_scale to use stride=1 for blockwise symmetric + scales, which avoids strided splitting failures when per-rank weight size + is smaller than block size.""" + from neuronx_distributed.quantization.quantization_layers import ( + BaseQuantizeParallelLinear, + ) + from neuronx_distributed.quantization.quantization_config import ( + QuantizationType, + ) + + _original_setup = BaseQuantizeParallelLinear._setup_for_scale + + def _patched_setup(self_inner, *args, **kwargs): + _original_setup(self_inner, *args, **kwargs) + if ( + hasattr(self_inner, "quantization_type") + and self_inner.quantization_type == QuantizationType.BLOCKWISE_SYMMETRIC + and hasattr(self_inner, "scale") + and hasattr(self_inner.scale, "partition_stride") + and self_inner.scale.partition_stride > 1 + ): + self_inner.scale.partition_stride = 1 + + BaseQuantizeParallelLinear._setup_for_scale = _patched_setup + + def load( + self, + compiled_model_path, + start_rank_id=None, + local_ranks_size=None, + skip_warmup=False, + ): + """Override to apply EP scale fix before loading.""" + if getattr(self.neuron_config, "quantized", False): + self._apply_ep_scale_fix() + self._apply_blockwise_scale_stride_fix() + return super().load( + compiled_model_path, + start_rank_id=start_rank_id, + local_ranks_size=local_ranks_size, + skip_warmup=skip_warmup, + ) + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: 
dict, + config: KimiK2InferenceConfig, + ) -> dict: + return convert_kimi_k2_hf_to_neuron_state_dict(state_dict, config) + + def checkpoint_loader_fn(self, mmap: bool = False): + """Memory-efficient streaming checkpoint loader for 1T-parameter models. + + Loads safetensor shards one at a time, processes weights (FP8 handling, + expert packing, router renaming), and accumulates results to avoid OOM. + + When quantized=True (FP8 path): + - Routed expert weights stay in FP8 with blockwise scales + - All other weights are dequantized to BF16 + + When quantized=False (BF16 path): + - All weights are dequantized from FP8 to BF16 + """ + import json as json_mod + from safetensors.torch import load_file + + model_path = getattr(self.config, "_name_or_path", None) + if model_path is None or not os.path.exists(str(model_path)): + model_path = self.model_path + + index_path = os.path.join(model_path, "model.safetensors.index.json") + + if not os.path.exists(index_path): + return super().checkpoint_loader_fn(mmap=mmap) + + with open(index_path, "r") as f: + index = json_mod.load(f) + + weight_map = index["weight_map"] + shard_files = sorted(set(weight_map.values())) + + quant_config = getattr(self.config, "quantization_config", None) + if isinstance(quant_config, dict): + block_size = quant_config.get("weight_block_size", [128, 128]) + else: + block_size = [128, 128] + n_routed_experts = getattr(self.config, "n_routed_experts", 384) + first_k_dense_replace = getattr(self.config, "first_k_dense_replace", 1) + keep_experts_fp8 = getattr(self.config.neuron_config, "quantized", False) + num_layers = self.config.num_hidden_layers + + # Determine which shards are needed (supports reduced-layer testing) + needed_shards = set() + for key, shard_file in weight_map.items(): + clean_key = key[len("model.") :] if key.startswith("model.") else key + if "layers." 
in clean_key: + parts = clean_key.split(".") + idx = parts.index("layers") + 1 + layer_idx = int(parts[idx]) + if layer_idx < num_layers: + needed_shards.add(shard_file) + else: + needed_shards.add(shard_file) + + logger.info( + f"Streaming loader: {len(shard_files)} shards, {len(needed_shards)} needed, " + f"block_size={block_size}, experts={n_routed_experts}, fp8={keep_experts_fp8}" + ) + + result_dict = {} + for i, shard_file in enumerate(shard_files): + if shard_file not in needed_shards: + continue + shard_path = os.path.join(model_path, shard_file) + logger.info(f"Loading shard [{i + 1}/{len(shard_files)}]: {shard_file}") + + shard_data = load_file(shard_path) + + # Strip "model." prefix + for key in list(shard_data.keys()): + if key.startswith("model."): + shard_data[key[len("model.") :]] = shard_data.pop(key) + + # Filter out keys for layers beyond num_hidden_layers + for key in list(shard_data.keys()): + if "layers." in key: + parts = key.split(".") + idx = parts.index("layers") + 1 + layer_idx = int(parts[idx]) + if layer_idx >= num_layers: + del shard_data[key] + + # Determine layers in this shard + layer_ids = set() + for key in shard_data: + if "layers." in key: + parts = key.split(".") + idx = parts.index("layers") + 1 + layer_ids.add(int(parts[idx])) + + # Build expert weight/scale key mapping + expert_weight_keys = set() + expert_scale_keys = {} + if keep_experts_fp8: + for key in list(shard_data.keys()): + if ".mlp.experts." in key and ".weight" in key: + if ".shared_experts." 
not in key: + expert_weight_keys.add(key) + + for key in list(shard_data.keys()): + if "_scale_inv" in key or "weight_scale_inv" in key: + if key.endswith(".weight_scale_inv"): + wk = key.replace(".weight_scale_inv", ".weight") + elif "_scale_inv" in key: + wk = key.replace("_scale_inv", "") + else: + continue + if wk in expert_weight_keys: + expert_scale_keys[wk] = key + + # Process FP8 weights: dequant non-expert, keep expert raw + scale_keys = [ + k for k in shard_data if "weight_scale_inv" in k or "_scale_inv" in k + ] + for scale_key in scale_keys: + scales = shard_data[scale_key] + if scale_key.endswith(".weight_scale_inv"): + weight_key = scale_key.replace(".weight_scale_inv", ".weight") + elif "_scale_inv" in scale_key: + weight_key = scale_key.replace("_scale_inv", "") + else: + del shard_data[scale_key] + continue + + if weight_key not in shard_data: + del shard_data[scale_key] + continue + + if shard_data[weight_key].dtype != torch.float8_e4m3fn: + del shard_data[scale_key] + continue + + if weight_key in expert_weight_keys: + pass # Keep for expert packing + else: + fp8_w = shard_data[weight_key] + fp32_w = _dequant_block_fp8_to_fp32(fp8_w, scales, block_size) + shard_data[weight_key] = fp32_w.to(torch.bfloat16) + del fp8_w, fp32_w + del shard_data[scale_key] + + # Remove orphan scale keys + for k in [ + k + for k in shard_data + if ("_scale_inv" in k or "weight_scale_inv" in k) + and k not in expert_scale_keys.values() + ]: + del shard_data[k] + + # Cast non-FP8 tensors to BF16 + for key in list(shard_data.keys()): + t = shard_data[key] + if torch.is_floating_point(t) and t.dtype not in ( + torch.bfloat16, + torch.float8_e4m3fn, + torch.float32, + ): + if t.dtype != torch.int64 and t.dtype != torch.int32: + shard_data[key] = t.to(torch.bfloat16) + + # Pack experts and rename for MoE layers + for layer_idx in sorted(layer_ids): + prefix = f"layers.{layer_idx}" + if layer_idx >= first_k_dense_replace: + # Router rename + gate_key = 
f"{prefix}.mlp.gate.weight" + if gate_key in shard_data: + shard_data[f"{prefix}.mlp.router.linear_router.weight"] = ( + shard_data.pop(gate_key) + ) + bias_key = f"{prefix}.mlp.gate.e_score_correction_bias" + if bias_key in shard_data: + shard_data[f"{prefix}.mlp.router.linear_router.bias"] = ( + shard_data.pop(bias_key) + ) + + # Pack experts + e0_gate = f"{prefix}.mlp.experts.0.gate_proj.weight" + if e0_gate in shard_data: + isize, hsize = shard_data[e0_gate].shape + + if keep_experts_fp8: + gate_up_weights = [] + gate_up_scales = [] + down_weights = [] + down_scales = [] + + for e in range(n_routed_experts): + gk = f"{prefix}.mlp.experts.{e}.gate_proj.weight" + uk = f"{prefix}.mlp.experts.{e}.up_proj.weight" + dk = f"{prefix}.mlp.experts.{e}.down_proj.weight" + gsk = expert_scale_keys.get(gk) + usk = expert_scale_keys.get(uk) + dsk = expert_scale_keys.get(dk) + + g_fp8 = shard_data.pop(gk) if gk in shard_data else None + u_fp8 = shard_data.pop(uk) if uk in shard_data else None + g_scale = ( + shard_data.pop(gsk) + if gsk and gsk in shard_data + else None + ) + u_scale = ( + shard_data.pop(usk) + if usk and usk in shard_data + else None + ) + + if g_fp8 is not None and u_fp8 is not None: + gate_up_weights.append((g_fp8, u_fp8)) + gate_up_scales.append((g_scale, u_scale)) + + d_fp8 = shard_data.pop(dk) if dk in shard_data else None + d_scale = ( + shard_data.pop(dsk) + if dsk and dsk in shard_data + else None + ) + if d_fp8 is not None: + down_weights.append(d_fp8) + down_scales.append(d_scale) + + if gate_up_weights: + gu_fp8, gu_scale = _pack_experts_blockwise_fp8( + gate_up_weights, + gate_up_scales, + block_size, + tp_degree=self.config.neuron_config.tp_degree, + layout="gate_up", + ) + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + ] = gu_fp8 + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.gate_up_proj.scale" + ] = gu_scale + del gate_up_weights, gate_up_scales + + if down_weights: + dn_fp8, dn_scale = _pack_experts_blockwise_fp8( + 
down_weights, + down_scales, + block_size, + tp_degree=self.config.neuron_config.tp_degree, + layout="down", + ) + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.down_proj.weight" + ] = dn_fp8 + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.down_proj.scale" + ] = dn_scale + del down_weights, down_scales + + else: + # BF16 path + dtype = shard_data[e0_gate].dtype + gate_up = torch.zeros( + n_routed_experts, + hsize, + 2 * isize, + dtype=dtype, + device="cpu", + ) + for e in range(n_routed_experts): + gk = f"{prefix}.mlp.experts.{e}.gate_proj.weight" + uk = f"{prefix}.mlp.experts.{e}.up_proj.weight" + if gk in shard_data: + gate_up[e, :, :isize] = shard_data.pop(gk).T + if uk in shard_data: + gate_up[e, :, isize:] = shard_data.pop(uk).T + + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + ] = gate_up + + down = torch.zeros( + n_routed_experts, + isize, + hsize, + dtype=dtype, + device="cpu", + ) + for e in range(n_routed_experts): + dk = f"{prefix}.mlp.experts.{e}.down_proj.weight" + if dk in shard_data: + down[e] = shard_data.pop(dk).T + + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.down_proj.weight" + ] = down + + # Clean up remaining per-expert keys + for e in range(n_routed_experts): + for proj in ["gate_proj", "up_proj", "down_proj"]: + for suffix in [".weight", ".weight_scale_inv"]: + k = f"{prefix}.mlp.experts.{e}.{proj}{suffix}" + if k in shard_data: + del shard_data[k] + + # Shared expert rename + for proj in ["gate_proj", "up_proj", "down_proj"]: + hf_key = f"{prefix}.mlp.shared_experts.{proj}.weight" + nxdi_key = f"{prefix}.shared_experts.{proj}.weight" + if hf_key in shard_data: + shard_data[nxdi_key] = shard_data.pop(hf_key) + + # Cast remaining float32 non-scale tensors to BF16 + for key in list(shard_data.keys()): + t = shard_data[key] + if ( + t.dtype == torch.float32 + and not key.endswith(".scale") + and not key.endswith("_scale_inv") + and not key.endswith("linear_router.bias") + ): + shard_data[key] = 
t.to(torch.bfloat16) + + result_dict.update(shard_data) + del shard_data + gc.collect() + + # Add rank_util tensors + tp = self.config.neuron_config.tp_degree + result_dict["rank_util.rank"] = torch.arange(0, tp, dtype=torch.int32) + for layer_idx in range(self.config.num_hidden_layers): + result_dict[f"layers.{layer_idx}.self_attn.rank_util.rank"] = torch.arange( + 0, tp, dtype=torch.int32 + ) + + # Add fused prefix if needed + if self._FUSED_PREFIX != "": + for key in list(result_dict.keys()): + result_dict[f"{self._FUSED_PREFIX}.{key}"] = result_dict.pop(key) + + logger.info(f"Streaming loader done. Total keys: {len(result_dict)}") + return result_dict + + def enable_context_encoding(self): + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def enable_token_generation(self): + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + super().enable_token_generation() + + def get_compiler_args(self) -> str: + """Compiler args optimized for Kimi-K2 on trn2. + + Key flags: + - -O3 for TKG with EP > 1: avoids Modular flow perf degradation + - --internal-enable-dge-levels vector_dynamic_offsets: DGE optimization + - --enable-ccop-compute-overlap: CC overlap for MoE all-to-all + - --lnc: must match runtime NEURON_LOGICAL_NC_CONFIG + """ + lnc = getattr(self.neuron_config, "logical_nc_config", 1) + ep_degree = getattr(self.neuron_config, "moe_ep_degree", 1) + + if self.compile_tag == TOKEN_GENERATION_MODEL_TAG and ep_degree > 1: + optimization_level = "-O3" + else: + optimization_level = "-O1" + + compiler_args = ( + "--enable-saturate-infinity " + "--enable-mixed-precision-accumulation " + "--model-type transformer " + f"{optimization_level}" + ) + compiler_args += ( + " --tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2'" + ) + compiler_args += " --tensorizer-options='--vectorize-strided-dma'" + compiler_args += " --auto-cast=none" + compiler_args += " --internal-enable-dge-levels vector_dynamic_offsets" 
+ compiler_args += f" --lnc={lnc}" + + hlo2tensorizer_opts = "--verify-hlo=true" + if ( + getattr(self.neuron_config, "quantized", False) + and getattr(self.neuron_config, "quantization_dtype", "") == "f8e4m3" + ): + hlo2tensorizer_opts += " --experimental-unsafe-fp8e4m3fn-as-fp8e4m3" + + compiler_args += f" --internal-hlo2tensorizer-options='{hlo2tensorizer_opts}'" + + return compiler_args diff --git a/contrib/models/Kimi-K2-Instruct-0905/test/__init__.py b/contrib/models/Kimi-K2-Instruct-0905/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Kimi-K2-Instruct-0905/test/integration/__init__.py b/contrib/models/Kimi-K2-Instruct-0905/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py b/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py new file mode 100644 index 00000000..ece05b2b --- /dev/null +++ b/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py @@ -0,0 +1,364 @@ +#!/usr/bin/env python3 +""" +Integration tests for Kimi-K2-Instruct-0905 NeuronX implementation. + +Tests model compilation, loading, and inference on trn2.48xlarge. 
+ +Requirements: + - trn2.48xlarge with NEURON_LOGICAL_NC_CONFIG=1 (128 logical cores) + - LOCAL_WORLD_SIZE=128 + - Model weights at MODEL_PATH + - Neuron SDK 2.28+ (Deep Learning AMI Neuron Ubuntu 24.04) + - Selective loading threshold patched to 0.0 in + neuronx_distributed/modules/moe/model_utils.py + +Usage: + # Full test (compile + load + generate): + NEURON_LOGICAL_NC_CONFIG=1 LOCAL_WORLD_SIZE=128 pytest test_model.py -v --capture=tee-sys + + # Load-only (skip compile, use existing NEFFs): + NEURON_LOGICAL_NC_CONFIG=1 LOCAL_WORLD_SIZE=128 pytest test_model.py -v --capture=tee-sys -k "not compile" +""" + +import json +import os +import sys +import time +from pathlib import Path + +import pytest +import torch +from transformers import AutoTokenizer + +from neuronx_distributed_inference.models.config import MoENeuronConfig, RouterConfig +from neuronx_distributed_inference.utils.hf_adapter import ( + HuggingFaceGenerationAdapter, + load_pretrained_config, +) + +# Import from src directory +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) +from modeling_kimi_k2 import NeuronKimiK2ForCausalLM, KimiK2InferenceConfig + + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +MODEL_PATH = "/home/ubuntu/models/Kimi-K2-Instruct-0905" +COMPILED_MODEL_PATH = "/home/ubuntu/kimi-k2/neuron-compiled-fp8-bw-no-ods" + +# Model configuration +TP_DEGREE = 64 +EP_DEGREE = 2 +LNC = 1 +BATCH_SIZE = 1 +SEQ_LEN = 1024 +N_ACTIVE_TOKENS = 128 + + +def build_config(): + """Build KimiK2InferenceConfig for trn2.48xlarge.""" + with open(os.path.join(MODEL_PATH, "config.json"), "r") as f: + hf_config = json.load(f) + + neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + ep_degree=EP_DEGREE, + logical_nc_config=LNC, + max_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + n_active_tokens=N_ACTIVE_TOKENS, + torch_dtype="bfloat16", + capacity_factor=1.0, 
+ glu_mlp=True, + moe_ep_degree=EP_DEGREE, + moe_tp_degree=TP_DEGREE, + context_encoding_buckets=[N_ACTIVE_TOKENS, SEQ_LEN], + router_config=RouterConfig(act_fn="sigmoid", dtype="float32"), + # FP8 quantization for routed experts + quantized=True, + quantized_checkpoints_path=MODEL_PATH, + quantization_dtype="f8e4m3", + modules_to_not_convert=[ + "self_attn", + "shared_experts", + "embed_tokens", + "lm_head", + "norm", + "router", + "layers.0", + ], + quantization_type="blockwise_symmetric", + quantization_block_axis=[1, 2], + quantization_block_size=[128, 128], + ) + + hf_kwargs = { + k: v + for k, v in hf_config.items() + if k not in ("auto_map", "torch_dtype", "transformers_version", "architectures") + } + + config = KimiK2InferenceConfig(neuron_config=neuron_config, **hf_kwargs) + return config + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def tokenizer(): + """Load tokenizer.""" + tokenizer = AutoTokenizer.from_pretrained( + MODEL_PATH, padding_side="left", trust_remote_code=True + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +@pytest.fixture(scope="module") +def compiled_model(): + """Compile (if needed) and load model.""" + config = build_config() + model = NeuronKimiK2ForCausalLM(MODEL_PATH, config) + + compiled_path = Path(COMPILED_MODEL_PATH) + if not compiled_path.exists() or not (compiled_path / "model.pt").exists(): + print(f"\nCompiling model to {COMPILED_MODEL_PATH}...") + t0 = time.time() + model.compile(COMPILED_MODEL_PATH) + print(f"Compilation done in {(time.time() - t0) / 60:.1f} min") + + print(f"\nLoading model from {COMPILED_MODEL_PATH}...") + t0 = time.time() + model.load(COMPILED_MODEL_PATH) + print(f"Model loaded in {(time.time() - t0) / 60:.1f} min") + + return model + + +# 
--------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + + +def generate_tokens( + model, tokenizer, prompt, max_new_tokens=32, min_tokens_before_eos=3 +): + """Generate tokens using CPU greedy sampling (no on-device sampling). + + Uses model.forward(input_ids, attention_mask, position_ids, seq_ids) API. + Applies chat template with <|im_start|>user/assistant format. + """ + messages = [{"role": "user", "content": prompt}] + input_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + input_ids = tokenizer.encode(input_text, return_tensors="pt") + + batch_size = input_ids.shape[0] + seq_len = input_ids.shape[1] + seq_ids = torch.arange(batch_size, dtype=torch.long) + + generated_tokens = [] + eos_id = 163586 # <|im_end|> + + # Reset KV cache for a fresh generation + model.reset() + + # Context encoding + position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0) + attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long) + + outputs = model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + seq_ids=seq_ids, + ) + + logits = outputs.logits if hasattr(outputs, "logits") else outputs[0] + last_logits = logits[:, -1, :] + + # Mask EOS for first token if requested + if min_tokens_before_eos > 0 and eos_id < last_logits.shape[-1]: + last_logits[:, eos_id] = float("-inf") + + next_token = torch.argmax(last_logits, dim=-1, keepdim=True) + next_token_id = next_token[0].item() + generated_tokens.append(next_token_id) + + # Token generation loop + cur_pos = seq_len + for step in range(max_new_tokens - 1): + input_ids_step = next_token.to(torch.long) + position_ids_step = torch.tensor([[cur_pos]], dtype=torch.long) + attention_mask_step = torch.ones(batch_size, cur_pos + 1, dtype=torch.long) + + outputs = model.forward( + input_ids=input_ids_step, + 
attention_mask=attention_mask_step, + position_ids=position_ids_step, + seq_ids=seq_ids, + ) + + logits = outputs.logits if hasattr(outputs, "logits") else outputs[0] + last_logits = logits[:, -1, :] + + # Mask EOS for first few tokens + if step + 1 < min_tokens_before_eos and eos_id < last_logits.shape[-1]: + last_logits[:, eos_id] = float("-inf") + + next_token = torch.argmax(last_logits, dim=-1, keepdim=True) + next_token_id = next_token[0].item() + generated_tokens.append(next_token_id) + cur_pos += 1 + + if next_token_id == eos_id: + break + + output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) + all_ids = torch.cat([input_ids, torch.tensor([generated_tokens]).long()], dim=-1) + return output_text, all_ids + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_model_loads(compiled_model): + """Smoke test: model loads successfully.""" + assert compiled_model is not None + assert hasattr(compiled_model, "config") + assert hasattr(compiled_model.config, "neuron_config") + print("PASS: Model loaded successfully") + + +def test_model_generates(compiled_model, tokenizer): + """Test that model generates coherent text.""" + output, _ = generate_tokens( + compiled_model, tokenizer, "What is the capital of France?" + ) + assert len(output) > 0, "Output should not be empty" + # Due to the known elevated EOS logit, the model may not always answer + # factual questions correctly at BS=1. Check for reasonable output. 
+ words = output.split() + assert len(words) >= 3, f"Output too short: {output}" + print(f"PASS: Generation test - Output: {output[:200]}") + + +def test_output_coherence(compiled_model, tokenizer): + """Test that output is not gibberish or repetitive.""" + output, _ = generate_tokens( + compiled_model, + tokenizer, + "Explain quantum computing in one sentence.", + max_new_tokens=64, + ) + words = output.split() + assert len(words) >= 3, f"Output too short: {output}" + + # Check for excessive repetition + if len(words) >= 10: + for i in range(len(words) - 5): + repeated = all(words[i + j] == words[i] for j in range(5)) + assert not repeated, f"Excessive repetition in output: {output}" + + print(f"PASS: Coherence test - Output: {output[:100]}") + + +def test_performance_tpot(compiled_model, tokenizer): + """Measure per-token output latency (TPOT).""" + prompt = "What is the capital of France?" + + # Warmup + generate_tokens(compiled_model, tokenizer, prompt, max_new_tokens=10) + + # Determine input length for TPOT calculation + messages = [{"role": "user", "content": prompt}] + input_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + input_len = len(tokenizer.encode(input_text)) + + # Measure + num_tokens = 32 + n_runs = 5 + tpots = [] + + for _ in range(n_runs): + t0 = time.perf_counter() + _, gen_ids = generate_tokens( + compiled_model, tokenizer, prompt, max_new_tokens=num_tokens + ) + elapsed = time.perf_counter() - t0 + actual_generated = gen_ids.shape[1] - input_len + if actual_generated > 1: + tpot = (elapsed * 1000) / actual_generated + tpots.append(tpot) + + if tpots: + median_tpot = sorted(tpots)[len(tpots) // 2] + tok_per_sec = 1000.0 / median_tpot + print(f"PASS: TPOT = {median_tpot:.1f} ms ({tok_per_sec:.1f} tok/s)") + # Kimi-K2 at BS=1: expected ~297 ms/token (3.4 tok/s) + assert median_tpot < 500, f"TPOT {median_tpot:.1f}ms exceeds 500ms threshold" + else: + pytest.skip("Could not measure TPOT (no tokens 
generated)") + + +# --------------------------------------------------------------------------- +# Standalone runner +# --------------------------------------------------------------------------- + + +if __name__ == "__main__": + print("=" * 80) + print("Kimi-K2-Instruct-0905 Integration Tests") + print("=" * 80) + + config = build_config() + model = NeuronKimiK2ForCausalLM(MODEL_PATH, config) + + compiled_path = Path(COMPILED_MODEL_PATH) + if not compiled_path.exists() or not (compiled_path / "model.pt").exists(): + print(f"\nCompiling model to {COMPILED_MODEL_PATH}...") + t0 = time.time() + model.compile(COMPILED_MODEL_PATH) + print(f"Compilation done in {(time.time() - t0) / 60:.1f} min") + + print(f"\nLoading model from {COMPILED_MODEL_PATH}...") + t0 = time.time() + model.load(COMPILED_MODEL_PATH) + print(f"Model loaded in {(time.time() - t0) / 60:.1f} min") + + tok = AutoTokenizer.from_pretrained( + MODEL_PATH, padding_side="left", trust_remote_code=True + ) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + + print("\n" + "=" * 80) + print("Running Tests") + print("=" * 80) + + print("\n1. Smoke Test...") + test_model_loads(model) + + print("\n2. Generation Test...") + test_model_generates(model, tok) + + print("\n3. Coherence Test...") + test_output_coherence(model, tok) + + print("\n4. TPOT Performance Test...") + test_performance_tpot(model, tok) + + print("\n" + "=" * 80) + print("All tests passed!") + print("=" * 80) diff --git a/contrib/models/Kimi-K2-Instruct-0905/test/unit/__init__.py b/contrib/models/Kimi-K2-Instruct-0905/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b From 3a2d9d253c5957b1997993a8f4106f5424f3bb30 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Fri, 17 Apr 2026 09:45:09 -0400 Subject: [PATCH 02/10] Switch to LNC=2/TP=32/EP=2 config (53% faster: 5.2 tok/s vs 3.4 tok/s) LNC=2 gives 2x HBM bandwidth per logical core on trn2.48xlarge. With TP=32 and EP=2 (64 ranks), NEFF I/O fits at 17.55 GB / 24 GB. 
TPOT improves from 297 ms to ~191 ms (purely bandwidth-bound MoE). --- .../models/Kimi-K2-Instruct-0905/README.md | 83 +++++++++---------- .../test/integration/test_model.py | 14 ++-- 2 files changed, 48 insertions(+), 49 deletions(-) diff --git a/contrib/models/Kimi-K2-Instruct-0905/README.md b/contrib/models/Kimi-K2-Instruct-0905/README.md index e690b5ce..7aa7e08e 100644 --- a/contrib/models/Kimi-K2-Instruct-0905/README.md +++ b/contrib/models/Kimi-K2-Instruct-0905/README.md @@ -59,38 +59,36 @@ NeuronX Distributed Inference implementation of Moonshot AI's Kimi-K2-Instruct-0 ## Validation Results -**Validated:** 2026-04-16 -**Configuration:** TP=64, EP=2, LNC=1, batch_size=1, seq_len=1024, blockwise FP8 +**Validated:** 2026-04-17 +**Recommended Configuration:** TP=32, EP=2, LNC=2, batch_size=1, seq_len=1024, blockwise FP8 ### Test Results | Test | Status | Result | |------|--------|--------| | Smoke Test | PASS | Model compiles and loads on trn2.48xlarge | -| Generation | PASS | Correct answers for factual questions (10/13 prompts) | -| Throughput | PASS | 3.4 tok/s at BS=1 | +| Generation | PASS | Generates coherent text | +| Throughput | PASS | 5.2 tok/s at BS=1 (LNC=2) | -### Performance Metrics +### Performance Metrics (Recommended: LNC=2, TP=32) | Metric | Value | |--------|-------| -| TPOT (per-token latency) | 297.5 ms | -| Throughput (BS=1) | 3.4 tok/s | -| TTFT (61 input tokens) | 1,788 ms | -| Compile time (total) | 73 min (TKG -O3: 49 min, CTE -O1: 24 min) | -| Model load time | 47 min | -| HBM utilization | ~78% (1,200 GB / 1,536 GB) | - -### Token Generation Sweep (BS=1, seq_len=1024) - -| Output Tokens | TTFT P50 (ms) | TPOT P50 (ms) | tok/s | E2E P50 (ms) | -|---------------|---------------|----------------|-------|---------------| -| 16 | 1,787.9 | 297.36 | 3.4 | 6,248.3 | -| 32 | 1,787.9 | 297.37 | 3.4 | 11,006.6 | -| 64 | 1,788.3 | 297.52 | 3.4 | 20,533.8 | -| 128 | 1,787.9 | 297.44 | 3.4 | 39,564.4 | -| 256 | 1,788.4 | 297.61 | 3.4 | 77,681.2 | 
-| 512 | 1,795.9 | 297.55 | 3.4 | 153,842.1 | +| TPOT (per-token latency) | ~191 ms | +| Throughput (BS=1) | 5.2 tok/s | +| Compile time (total) | 67 min | +| Model load time | 30 min | +| HBM I/O utilization | 17.55 GB / 24 GB | + +### LNC=2 vs LNC=1 Comparison + +LNC=2 (TP=32, EP=2) is **53% faster** than LNC=1 (TP=64, EP=2) because each +logical core gets 2x HBM bandwidth, and MoE decode is purely bandwidth-bound. + +| Config | TP | EP | Cores | TPOT | tok/s | Speedup | +|--------|----|----|-------|------|-------|---------| +| LNC=2 (recommended) | 32 | 2 | 64 | ~191 ms | 5.2 | **+53%** | +| LNC=1 | 64 | 2 | 128 | 297 ms | 3.4 | baseline | ### Batching Results @@ -101,13 +99,13 @@ same aggregate throughput as BS=1. ### Performance Bottleneck -TPOT breakdown (estimated per 297.5 ms token): +TPOT breakdown (estimated per ~191 ms token at LNC=2): -1. **MoE expert MLPs (~250 ms, ~84%):** 192 local experts x 2 matmuls per layer. +1. **MoE expert MLPs (~160 ms, ~84%):** 192 local experts x 2 matmuls per layer. FP8 weights are dequantized to BF16 before the NKI kernel. -2. **MLA attention (~25 ms, ~8%):** Weight absorption projections + KV cache. -3. **Router + all-to-all (~15 ms, ~5%):** Router TopK + expert dispatch across EP=2. -4. **Other (~7.5 ms, ~3%):** RMSNorm, residuals, lm_head. +2. **MLA attention (~16 ms, ~8%):** Weight absorption projections + KV cache. +3. **Router + all-to-all (~10 ms, ~5%):** Router TopK + expert dispatch across EP=2. +4. **Other (~5 ms, ~3%):** RMSNorm, residuals, lm_head. Primary optimization opportunity: native blockwise FP8 kernel in the nki-lib MoE TKG pipeline (currently blocked -- nki-lib requires per-channel FP8 scales). 
@@ -130,11 +128,11 @@ compiled_path = "/path/to/compiled" with open(os.path.join(model_path, "config.json")) as f: hf_config = json.load(f) -# Configure for trn2.48xlarge +# Configure for trn2.48xlarge (LNC=2, recommended) neuron_config = MoENeuronConfig( - tp_degree=64, + tp_degree=32, ep_degree=2, - logical_nc_config=1, + logical_nc_config=2, max_batch_size=1, seq_len=1024, n_active_tokens=128, @@ -142,7 +140,7 @@ neuron_config = MoENeuronConfig( capacity_factor=1.0, glu_mlp=True, moe_ep_degree=2, - moe_tp_degree=64, + moe_tp_degree=32, context_encoding_buckets=[128, 1024], router_config=RouterConfig(act_fn="sigmoid", dtype="float32"), # FP8 quantization @@ -165,8 +163,8 @@ config = KimiK2InferenceConfig(neuron_config=neuron_config, **hf_kwargs) # Compile and load model = NeuronKimiK2ForCausalLM(model_path, config) -model.compile(compiled_path) # ~73 min -model.load(compiled_path) # ~47 min +model.compile(compiled_path) # ~67 min +model.load(compiled_path) # ~30 min # Generate (CPU greedy sampling, no on-device sampling) from transformers import AutoTokenizer @@ -177,16 +175,17 @@ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) **Important:** Run with environment variables: ```bash -NEURON_LOGICAL_NC_CONFIG=1 LOCAL_WORLD_SIZE=128 python your_script.py +NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 python your_script.py ``` ## Compatibility Matrix | Instance / SDK Version | 2.28+ | 2.27 and earlier | |------------------------|-------|------------------| -| trn2.48xlarge (LNC=1) | Working | Not tested | -| trn2.3xlarge | Not supported (needs TP=64, EP=2) | Not supported | -| trn1.32xlarge | Not supported (needs 128 cores) | Not supported | +| trn2.48xlarge (LNC=2, recommended) | Working (5.2 tok/s) | Not tested | +| trn2.48xlarge (LNC=1) | Working (3.4 tok/s) | Not tested | +| trn2.3xlarge | Not supported (needs TP=32, EP=2 = 64 cores) | Not supported | +| trn1.32xlarge | Not supported (needs 64 cores at LNC=2) | Not supported | | 
inf2 | Not supported | Not supported | ## Testing @@ -197,19 +196,19 @@ Run integration tests on a trn2.48xlarge: # Activate Neuron venv source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/bin/activate -# Run tests -NEURON_LOGICAL_NC_CONFIG=1 LOCAL_WORLD_SIZE=128 \ +# Run tests (LNC=2, recommended) +NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ pytest test/integration/test_model.py -v --capture=tee-sys ``` Or run standalone: ```bash -NEURON_LOGICAL_NC_CONFIG=1 LOCAL_WORLD_SIZE=128 \ +NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ python test/integration/test_model.py ``` -**Note:** Compilation takes ~73 min and loading takes ~47 min. The first run will compile +**Note:** Compilation takes ~67 min and loading takes ~30 min. The first run will compile NEFFs to the compiled model path. Subsequent runs with existing NEFFs skip compilation. ## Prerequisites @@ -253,4 +252,4 @@ NEFFs to the compiled model path. Subsequent runs with existing NEFFs skip compi Annapurna Labs -**Last Updated:** 2026-04-16 +**Last Updated:** 2026-04-17 diff --git a/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py b/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py index ece05b2b..1e67a243 100644 --- a/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py +++ b/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py @@ -5,8 +5,8 @@ Tests model compilation, loading, and inference on trn2.48xlarge. 
Requirements: - - trn2.48xlarge with NEURON_LOGICAL_NC_CONFIG=1 (128 logical cores) - - LOCAL_WORLD_SIZE=128 + - trn2.48xlarge with NEURON_LOGICAL_NC_CONFIG=2 (64 logical cores) + - LOCAL_WORLD_SIZE=64 - Model weights at MODEL_PATH - Neuron SDK 2.28+ (Deep Learning AMI Neuron Ubuntu 24.04) - Selective loading threshold patched to 0.0 in @@ -14,10 +14,10 @@ Usage: # Full test (compile + load + generate): - NEURON_LOGICAL_NC_CONFIG=1 LOCAL_WORLD_SIZE=128 pytest test_model.py -v --capture=tee-sys + NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 pytest test_model.py -v --capture=tee-sys # Load-only (skip compile, use existing NEFFs): - NEURON_LOGICAL_NC_CONFIG=1 LOCAL_WORLD_SIZE=128 pytest test_model.py -v --capture=tee-sys -k "not compile" + NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 pytest test_model.py -v --capture=tee-sys -k "not compile" """ import json @@ -49,9 +49,9 @@ COMPILED_MODEL_PATH = "/home/ubuntu/kimi-k2/neuron-compiled-fp8-bw-no-ods" # Model configuration -TP_DEGREE = 64 +TP_DEGREE = 32 EP_DEGREE = 2 -LNC = 1 +LNC = 2 BATCH_SIZE = 1 SEQ_LEN = 1024 N_ACTIVE_TOKENS = 128 @@ -306,7 +306,7 @@ def test_performance_tpot(compiled_model, tokenizer): median_tpot = sorted(tpots)[len(tpots) // 2] tok_per_sec = 1000.0 / median_tpot print(f"PASS: TPOT = {median_tpot:.1f} ms ({tok_per_sec:.1f} tok/s)") - # Kimi-K2 at BS=1: expected ~297 ms/token (3.4 tok/s) + # Kimi-K2 at BS=1 LNC=2: expected ~191 ms/token (5.2 tok/s) assert median_tpot < 500, f"TPOT {median_tpot:.1f}ms exceeds 500ms threshold" else: pytest.skip("Could not measure TPOT (no tokens generated)") From eeaab0dfd9b560517b7788150db6289402ab5e6c Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Fri, 17 Apr 2026 11:49:33 -0400 Subject: [PATCH 03/10] Update benchmark results: 165.5ms TPOT, 6.0 tok/s at LNC=2 (76% faster than LNC=1) Full sweep across 16-512 output tokens shows rock-stable TPOT. TTFT P50 = 1,420 ms. Throughput is 1.8x the LNC=1 baseline. 
--- .../models/Kimi-K2-Instruct-0905/README.md | 32 +++++++++++++------ .../test/integration/test_model.py | 2 +- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/contrib/models/Kimi-K2-Instruct-0905/README.md b/contrib/models/Kimi-K2-Instruct-0905/README.md index 7aa7e08e..c0781e82 100644 --- a/contrib/models/Kimi-K2-Instruct-0905/README.md +++ b/contrib/models/Kimi-K2-Instruct-0905/README.md @@ -68,28 +68,40 @@ NeuronX Distributed Inference implementation of Moonshot AI's Kimi-K2-Instruct-0 |------|--------|--------| | Smoke Test | PASS | Model compiles and loads on trn2.48xlarge | | Generation | PASS | Generates coherent text | -| Throughput | PASS | 5.2 tok/s at BS=1 (LNC=2) | +| Throughput | PASS | 6.0 tok/s at BS=1 (LNC=2) | ### Performance Metrics (Recommended: LNC=2, TP=32) | Metric | Value | |--------|-------| -| TPOT (per-token latency) | ~191 ms | -| Throughput (BS=1) | 5.2 tok/s | +| TPOT (per-token latency) | 165.5 ms | +| Throughput (BS=1) | 6.0 tok/s | +| TTFT (61 input tokens) | 1,420 ms | | Compile time (total) | 67 min | | Model load time | 30 min | | HBM I/O utilization | 17.55 GB / 24 GB | ### LNC=2 vs LNC=1 Comparison -LNC=2 (TP=32, EP=2) is **53% faster** than LNC=1 (TP=64, EP=2) because each +LNC=2 (TP=32, EP=2) is **76% faster** than LNC=1 (TP=64, EP=2) because each logical core gets 2x HBM bandwidth, and MoE decode is purely bandwidth-bound. 
| Config | TP | EP | Cores | TPOT | tok/s | Speedup | |--------|----|----|-------|------|-------|---------| -| LNC=2 (recommended) | 32 | 2 | 64 | ~191 ms | 5.2 | **+53%** | +| LNC=2 (recommended) | 32 | 2 | 64 | 165.5 ms | 6.0 | **+76%** | | LNC=1 | 64 | 2 | 128 | 297 ms | 3.4 | baseline | +### Token Generation Sweep (LNC=2, BS=1, seq_len=1024) + +| Output Tokens | TTFT P50 (ms) | TPOT P50 (ms) | tok/s | E2E P50 (ms) | +|---------------|---------------|----------------|-------|---------------| +| 16 | 1,420.4 | 166.38 | 6.0 | 3,916.1 | +| 32 | 1,419.8 | 165.58 | 6.0 | 6,553.0 | +| 64 | 1,419.7 | 165.56 | 6.0 | 11,849.8 | +| 128 | 1,419.8 | 165.48 | 6.0 | 22,435.8 | +| 256 | 1,419.9 | 165.42 | 6.0 | 43,604.1 | +| 512 | 1,420.0 | 165.47 | 6.0 | 85,974.4 | + ### Batching Results Batching provides **zero throughput improvement** on this model. The MoE computation is @@ -99,13 +111,13 @@ same aggregate throughput as BS=1. ### Performance Bottleneck -TPOT breakdown (estimated per ~191 ms token at LNC=2): +TPOT breakdown (estimated per ~165.5 ms token at LNC=2): -1. **MoE expert MLPs (~160 ms, ~84%):** 192 local experts x 2 matmuls per layer. +1. **MoE expert MLPs (~139 ms, ~84%):** 192 local experts x 2 matmuls per layer. FP8 weights are dequantized to BF16 before the NKI kernel. -2. **MLA attention (~16 ms, ~8%):** Weight absorption projections + KV cache. -3. **Router + all-to-all (~10 ms, ~5%):** Router TopK + expert dispatch across EP=2. -4. **Other (~5 ms, ~3%):** RMSNorm, residuals, lm_head. +2. **MLA attention (~13 ms, ~8%):** Weight absorption projections + KV cache. +3. **Router + all-to-all (~8 ms, ~5%):** Router TopK + expert dispatch across EP=2. +4. **Other (~5.5 ms, ~3%):** RMSNorm, residuals, lm_head. Primary optimization opportunity: native blockwise FP8 kernel in the nki-lib MoE TKG pipeline (currently blocked -- nki-lib requires per-channel FP8 scales). 
diff --git a/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py b/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py index 1e67a243..2787ce7c 100644 --- a/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py +++ b/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py @@ -306,7 +306,7 @@ def test_performance_tpot(compiled_model, tokenizer): median_tpot = sorted(tpots)[len(tpots) // 2] tok_per_sec = 1000.0 / median_tpot print(f"PASS: TPOT = {median_tpot:.1f} ms ({tok_per_sec:.1f} tok/s)") - # Kimi-K2 at BS=1 LNC=2: expected ~191 ms/token (5.2 tok/s) + # Kimi-K2 at BS=1 LNC=2: expected ~165 ms/token (6.0 tok/s) assert median_tpot < 500, f"TPOT {median_tpot:.1f}ms exceeds 500ms threshold" else: pytest.skip("Could not measure TPOT (no tokens generated)") From 0f8922e526a400b1d09f97d990fdb5188320d446 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Sun, 19 Apr 2026 04:06:51 -0400 Subject: [PATCH 04/10] Add SDK 2.29 compatibility notes and workaround for blockwise MoE CTE issue --- .../models/Kimi-K2-Instruct-0905/README.md | 70 ++++++++++++++++--- 1 file changed, 60 insertions(+), 10 deletions(-) diff --git a/contrib/models/Kimi-K2-Instruct-0905/README.md b/contrib/models/Kimi-K2-Instruct-0905/README.md index c0781e82..cba1552a 100644 --- a/contrib/models/Kimi-K2-Instruct-0905/README.md +++ b/contrib/models/Kimi-K2-Instruct-0905/README.md @@ -59,7 +59,7 @@ NeuronX Distributed Inference implementation of Moonshot AI's Kimi-K2-Instruct-0 ## Validation Results -**Validated:** 2026-04-17 +**Validated:** 2026-04-18 (SDK 2.28 and 2.29) **Recommended Configuration:** TP=32, EP=2, LNC=2, batch_size=1, seq_len=1024, blockwise FP8 ### Test Results @@ -192,22 +192,27 @@ NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 python your_script.py ## Compatibility Matrix -| Instance / SDK Version | 2.28+ | 2.27 and earlier | -|------------------------|-------|------------------| -| trn2.48xlarge (LNC=2, recommended) | Working 
(5.2 tok/s) | Not tested | -| trn2.48xlarge (LNC=1) | Working (3.4 tok/s) | Not tested | -| trn2.3xlarge | Not supported (needs TP=32, EP=2 = 64 cores) | Not supported | -| trn1.32xlarge | Not supported (needs 64 cores at LNC=2) | Not supported | -| inf2 | Not supported | Not supported | +| Instance / SDK Version | 2.29 | 2.28 | 2.27 and earlier | +|------------------------|------|------|------------------| +| trn2.48xlarge (LNC=2, recommended) | Working (6.0 tok/s)* | Working (6.0 tok/s) | Not tested | +| trn2.48xlarge (LNC=1) | Not tested | Working (3.4 tok/s) | Not tested | +| trn2.3xlarge | Not supported (needs TP=32, EP=2 = 64 cores) | Not supported | Not supported | +| trn1.32xlarge | Not supported (needs 64 cores at LNC=2) | Not supported | Not supported | +| inf2 | Not supported | Not supported | Not supported | + +\*SDK 2.29 requires a workaround for context encoding (see SDK 2.29 Notes below). ## Testing Run integration tests on a trn2.48xlarge: ```bash -# Activate Neuron venv +# Activate Neuron venv (SDK 2.28) source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/bin/activate +# Or for SDK 2.29 (apply forward_blockwise workaround first, install tiktoken) +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + # Run tests (LNC=2, recommended) NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ pytest test/integration/test_model.py -v --capture=tee-sys @@ -242,6 +247,51 @@ NEFFs to the compiled model path. Subsequent runs with existing NEFFs skip compi * [moonshotai/Kimi-K2-Instruct-0905](https://huggingface.co/moonshotai/Kimi-K2-Instruct-0905) +## SDK 2.29 Notes + +SDK 2.29 (NxDI 0.9.17334) introduces a new `forward_blockwise` code path for MoE context +encoding. The default kernel dispatch (`_call_shard_hidden_kernel`) is a stub that raises +`NotImplementedError`. 
While nkilib IS installed in the DLAMI, the available alternative +kernels (`shard_on_intermediate`, `shard_on_block`) are incompatible with this model's +dimensions: + +- `use_shard_on_intermediate_dynamic_while`: MLIR verification failure due to small per-TP + intermediate dimension (64) not matching kernel tile expectations. +- `use_shard_on_block_dynamic_while` + `PING_PONG`: Compiles but produces incorrect outputs + (likely due to blockwise FP8 scale dequantization interaction with the kernel). + +**Recommended workaround:** Patch `expert_mlps_v2.py` in the `neuronx_distributed` package +to use `forward_all_experts_EP` instead of `forward_blockwise` when expert parallelism is +enabled: + +```python +# In neuronx_distributed/modules/moe/expert_mlps_v2.py, in the forward() method, +# find the context encoding dispatch (around line 1497): +# return self.forward_blockwise(...) +# Replace with: +if self.moe_expert_model_parallel_group.size() > 1: + return self.forward_all_experts_EP(hidden_states, expert_affinities, expert_index) +return self.forward_blockwise(hidden_states, expert_affinities, expert_index, ...) +``` + +**Impact:** Token generation (TPOT) is unaffected (166.1 ms, identical to SDK 2.28). Context +encoding (TTFT) is ~7x slower (10,185 ms vs 1,420 ms) because `forward_all_experts_EP` sends +every token through every local expert rather than using the optimized blockwise dispatch. For +long-output workloads this is negligible; for TTFT-sensitive workloads, use SDK 2.28. 
+ +### SDK 2.29 Benchmark (LNC=2, TP=32, EP=2, BS=1, seq_len=1024) + +| Output Tokens | TTFT P50 (ms) | TPOT P50 (ms) | tok/s | E2E P50 (ms) | +|---------------|---------------|----------------|-------|---------------| +| 16 | 10,184.9 | 166.27 | 6.0 | 12,678.9 | +| 32 | 10,185.1 | 166.13 | 6.0 | 15,335.3 | +| 64 | 10,184.6 | 166.10 | 6.0 | 20,651.2 | +| 128 | 10,184.7 | 166.15 | 6.0 | 31,286.4 | +| 256 | 10,184.5 | 166.03 | 6.0 | 52,522.3 | +| 512 | 10,184.6 | 166.06 | 6.0 | 95,040.9 | + +**Additional SDK 2.29 setup:** Install `tiktoken` (`pip install tiktoken`) in the venv. + ## Known Limitations - **No on-device sampling:** The model uses CPU greedy sampling because the vocabulary @@ -264,4 +314,4 @@ NEFFs to the compiled model path. Subsequent runs with existing NEFFs skip compi Annapurna Labs -**Last Updated:** 2026-04-17 +**Last Updated:** 2026-04-18 From 29f76f4b120d87e58b46e75d5d8b6def64883575 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Sun, 19 Apr 2026 23:56:25 -0400 Subject: [PATCH 05/10] Document batch size finding: NxDI processes sequences sequentially, BS>1 provides no throughput benefit --- contrib/models/Kimi-K2-Instruct-0905/README.md | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/contrib/models/Kimi-K2-Instruct-0905/README.md b/contrib/models/Kimi-K2-Instruct-0905/README.md index cba1552a..292dd0ed 100644 --- a/contrib/models/Kimi-K2-Instruct-0905/README.md +++ b/contrib/models/Kimi-K2-Instruct-0905/README.md @@ -251,9 +251,9 @@ NEFFs to the compiled model path. Subsequent runs with existing NEFFs skip compi SDK 2.29 (NxDI 0.9.17334) introduces a new `forward_blockwise` code path for MoE context encoding. The default kernel dispatch (`_call_shard_hidden_kernel`) is a stub that raises -`NotImplementedError`. While nkilib IS installed in the DLAMI, the available alternative -kernels (`shard_on_intermediate`, `shard_on_block`) are incompatible with this model's -dimensions: +`NotImplementedError`. 
While nkilib IS installed in the DLAMI (bundled version matches the +standalone `nki-library` April 2026 release), the available alternative kernels +(`shard_on_intermediate`, `shard_on_block`) are incompatible with this model's dimensions: - `use_shard_on_intermediate_dynamic_while`: MLIR verification failure due to small per-TP intermediate dimension (64) not matching kernel tile expectations. @@ -303,9 +303,12 @@ long-output workloads this is negligible; for TTFT-sensitive workloads, use SDK weights or slight router bias approximation. Mitigated by masking EOS for the first few generation tokens (`min_tokens_before_eos=3`). -- **Batching does not improve throughput:** The MoE computation is bandwidth-bound - (192 expert weight loads per step), so higher batch sizes increase latency linearly - without improving aggregate throughput. +- **Batching does not improve throughput:** NxDI compiles HLO with per-sequence shapes + (`[1, seq_len]` for CTE, `[1, 1]` for TKG) regardless of `max_batch_size`. Multiple + sequences in a batch are processed sequentially through the same NEFF. Combined with + the bandwidth-bound nature of MoE (192 expert weight loads per decode step), BS>1 + provides no aggregate throughput benefit. Verified: BS=2 compile produces identical + NEFF shapes to BS=1. - **Compiler flags have no measurable impact:** -O3 with DGE vs -O1 showed 0% difference, confirming the bottleneck is weight bandwidth, not compute or scheduling. 
@@ -314,4 +317,4 @@ long-output workloads this is negligible; for TTFT-sensitive workloads, use SDK Annapurna Labs -**Last Updated:** 2026-04-18 +**Last Updated:** 2026-04-20 From 2c3e782d30b9a01a24c2036ee2854cf622e011a7 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Thu, 23 Apr 2026 15:13:11 -0400 Subject: [PATCH 06/10] Update SDK 2.29 notes: confirm blockwise kernel padding produces depressed logits, EP workaround verified correct --- contrib/models/Kimi-K2-Instruct-0905/README.md | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/contrib/models/Kimi-K2-Instruct-0905/README.md b/contrib/models/Kimi-K2-Instruct-0905/README.md index 292dd0ed..ccfdf104 100644 --- a/contrib/models/Kimi-K2-Instruct-0905/README.md +++ b/contrib/models/Kimi-K2-Instruct-0905/README.md @@ -231,8 +231,9 @@ NEFFs to the compiled model path. Subsequent runs with existing NEFFs skip compi ## Prerequisites 1. **Selective loading threshold patch:** On the target instance, patch - `neuronx_distributed/modules/moe/model_utils.py` to set the selective loading - threshold to 0.0 (default is too high for 384 experts). + `neuronx_distributed/modules/moe/model_utils.py` to set + `DEFAULT_SELECTIVE_LOADING_THRESHOLD = 0.0` (default is 0.05 on SDK 2.28 or + 1.0 on SDK 2.29, both too high for 384 experts). 2. **Model weights:** Download from HuggingFace: ```bash @@ -256,9 +257,14 @@ standalone `nki-library` April 2026 release), the available alternative kernels (`shard_on_intermediate`, `shard_on_block`) are incompatible with this model's dimensions: - `use_shard_on_intermediate_dynamic_while`: MLIR verification failure due to small per-TP - intermediate dimension (64) not matching kernel tile expectations. -- `use_shard_on_block_dynamic_while` + `PING_PONG`: Compiles but produces incorrect outputs - (likely due to blockwise FP8 scale dequantization interaction with the kernel). 
+ intermediate dimension (2048/32=64, sharded to 32 at LNC=2) being below the kernel's + minimum TILE_SIZE of 128 for the `nc_matmul` stationary free dimension. +- `use_shard_on_block_dynamic_while` + `PING_PONG`: Compiles but produces incorrect outputs. +- **Zero-padding fix attempted:** Padding `I_TP_sharded` from 32 to 128 in the + `shard_on_intermediate` kernel eliminates the MLIR error and compiles successfully, but + produces logits ~10 points lower than expected (13.1 vs 23.0). The EP workaround on the + same instance produces correct logits (23.0), confirming the issue is in the kernel's + handling of padded dimensions. Filed as internal ticket V2185857494. **Recommended workaround:** Patch `expert_mlps_v2.py` in the `neuronx_distributed` package to use `forward_all_experts_EP` instead of `forward_blockwise` when expert parallelism is @@ -317,4 +323,4 @@ long-output workloads this is negligible; for TTFT-sensitive workloads, use SDK Annapurna Labs -**Last Updated:** 2026-04-20 +**Last Updated:** 2026-04-23 From 0e637df3dfb684916547209567d694ed3dcccae4 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Sat, 25 Apr 2026 01:52:50 -0400 Subject: [PATCH 07/10] Identify root cause: SDK 2.29 removed shard_hidden kernel, replacement CTE kernels produce wrong output with EP=2 --- .../models/Kimi-K2-Instruct-0905/README.md | 78 +++++++++++++------ 1 file changed, 56 insertions(+), 22 deletions(-) diff --git a/contrib/models/Kimi-K2-Instruct-0905/README.md b/contrib/models/Kimi-K2-Instruct-0905/README.md index ccfdf104..e74dfdd7 100644 --- a/contrib/models/Kimi-K2-Instruct-0905/README.md +++ b/contrib/models/Kimi-K2-Instruct-0905/README.md @@ -59,7 +59,7 @@ NeuronX Distributed Inference implementation of Moonshot AI's Kimi-K2-Instruct-0 ## Validation Results -**Validated:** 2026-04-18 (SDK 2.28 and 2.29) +**Validated:** 2026-04-24 (SDK 2.28 and 2.29) **Recommended Configuration:** TP=32, EP=2, LNC=2, batch_size=1, seq_len=1024, blockwise FP8 ### Test Results @@ -194,13 
+194,15 @@ NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 python your_script.py | Instance / SDK Version | 2.29 | 2.28 | 2.27 and earlier | |------------------------|------|------|------------------| -| trn2.48xlarge (LNC=2, recommended) | Working (6.0 tok/s)* | Working (6.0 tok/s) | Not tested | +| trn2.48xlarge (LNC=2, recommended) | TKG working, CTE requires EP workaround* | **Working (6.0 tok/s)** | Not tested | | trn2.48xlarge (LNC=1) | Not tested | Working (3.4 tok/s) | Not tested | | trn2.3xlarge | Not supported (needs TP=32, EP=2 = 64 cores) | Not supported | Not supported | | trn1.32xlarge | Not supported (needs 64 cores at LNC=2) | Not supported | Not supported | | inf2 | Not supported | Not supported | Not supported | -\*SDK 2.29 requires a workaround for context encoding (see SDK 2.29 Notes below). +\*SDK 2.29 blockwise CTE kernels (shard_on_I, shard_on_block) produce wrong output with EP=2. +The `forward_all_experts_EP` workaround gives correct output but 7x slower TTFT. See SDK 2.29 +Notes below for root cause analysis and workaround. ## Testing @@ -250,21 +252,49 @@ NEFFs to the compiled model path. Subsequent runs with existing NEFFs skip compi ## SDK 2.29 Notes -SDK 2.29 (NxDI 0.9.17334) introduces a new `forward_blockwise` code path for MoE context -encoding. The default kernel dispatch (`_call_shard_hidden_kernel`) is a stub that raises -`NotImplementedError`. While nkilib IS installed in the DLAMI (bundled version matches the -standalone `nki-library` April 2026 release), the available alternative kernels -(`shard_on_intermediate`, `shard_on_block`) are incompatible with this model's dimensions: - -- `use_shard_on_intermediate_dynamic_while`: MLIR verification failure due to small per-TP - intermediate dimension (2048/32=64, sharded to 32 at LNC=2) being below the kernel's - minimum TILE_SIZE of 128 for the `nc_matmul` stationary free dimension. -- `use_shard_on_block_dynamic_while` + `PING_PONG`: Compiles but produces incorrect outputs. 
-- **Zero-padding fix attempted:** Padding `I_TP_sharded` from 32 to 128 in the - `shard_on_intermediate` kernel eliminates the MLIR error and compiles successfully, but - produces logits ~10 points lower than expected (13.1 vs 23.0). The EP workaround on the - same instance produces correct logits (23.0), confirming the issue is in the kernel's - handling of padded dimensions. Filed as internal ticket V2185857494. +SDK 2.29 (NxDI 0.9.17334) removed the `blockwise_mm_baseline_shard_hidden` NKI kernel that +was used for MoE context encoding inference on SDK 2.28. The kernel was in +`neuronxcc.nki._private.blockwise_mm` which no longer exists in SDK 2.29 (the entire +`_private` blockwise module was removed). The replacement dynamic-while kernels +(`shard_on_intermediate` from nkilib `bwmm_shard_on_I.py`, and `shard_on_block` from nkilib +`bwmm_shard_on_block.py`) produce incorrect output for this model's configuration. + +### Root Cause + +On SDK 2.28, the blockwise CTE dispatch in `blockwise.py` defaulted to +`_call_shard_hidden_kernel` (a static kernel, no dynamic while loop). This kernel was the +**only** blockwise CTE kernel validated to work with Kimi-K2's configuration (384 experts, +EP=2, I_TP=64). In SDK 2.29: + +1. `_call_shard_hidden_kernel` now raises `NotImplementedError` (inference path removed) +2. `blockwise_mm_baseline_shard_hidden` still appears in training kernel imports but fails + to load (`No module named 'neuronxcc.nki._private.blockwise_mm'`) +3. 
The nkilib replacement kernels (`shard_on_I`, `shard_on_block`) were never ported from + the static `shard_hidden` approach — they use fundamentally different algorithms + +### Blockwise CTE Investigation Summary + +All available blockwise CTE kernel paths were tested exhaustively: + +| Kernel Path | Config | Compile | Output | TTFT | +|------------|--------|---------|--------|------| +| `shard_hidden` (SDK 2.28 default) | default | N/A | N/A | Raises `NotImplementedError` | +| `shard_on_intermediate` (non-hybrid) | `use_shard_on_intermediate_dynamic_while=True` | PASS | "Kimi," (wrong) | 2,386 ms | +| `shard_on_intermediate` (hybrid) | `use_shard_on_intermediate_dynamic_while=True` | PASS | "Moonshot AI." (wrong) | 1,159 ms | +| `shard_on_block` + PING_PONG | `use_shard_on_block_dynamic_while=True` | PASS | ",,," (wrong) | 1,676 ms | +| `shard_on_intermediate` (unpatched) | `use_shard_on_intermediate_dynamic_while=True` | FAIL | N/A | MLIR: `I_TP_sharded=32 < 128` | +| PyTorch fallback | `use_torch_block_wise=True` | FAIL | N/A | `selective loading + EP` error | +| EP workaround (`forward_all_experts_EP`) | patched dispatch | PASS | **"Paris" (correct)** | 10,185 ms | + +Key observations: +- Wrong outputs are coherent text from the system prompt ("Kimi", "Moonshot AI") suggesting + corrupted CTE hidden states cause the model to regurgitate system message content +- The `shard_on_intermediate` kernel passes standalone tests (5/5 cosine > 0.99998) — the + kernel math is correct in isolation +- **Both** shard_on_I and shard_on_block produce wrong output, confirming the issue is not + specific to either kernel but to the dynamic-while blockwise approach with this model's + EP=2 / I_TP=64 configuration +- Filed as internal ticket V2185857494 **Recommended workaround:** Patch `expert_mlps_v2.py` in the `neuronx_distributed` package to use `forward_all_experts_EP` instead of `forward_blockwise` when expert parallelism is @@ -285,6 +315,10 @@ encoding (TTFT) is ~7x slower 
(10,185 ms vs 1,420 ms) because `forward_all_exper every token through every local expert rather than using the optimized blockwise dispatch. For long-output workloads this is negligible; for TTFT-sensitive workloads, use SDK 2.28. +**Recommendation:** Use SDK 2.28 for production deployments until the blockwise CTE regression +is resolved. The `shard_hidden` kernel available in SDK 2.28 provides both correct output and +optimal TTFT (1,420 ms). + ### SDK 2.29 Benchmark (LNC=2, TP=32, EP=2, BS=1, seq_len=1024) | Output Tokens | TTFT P50 (ms) | TPOT P50 (ms) | tok/s | E2E P50 (ms) | @@ -300,9 +334,9 @@ long-output workloads this is negligible; for TTFT-sensitive workloads, use SDK ## Known Limitations -- **No on-device sampling:** The model uses CPU greedy sampling because the vocabulary - size (163840) is not divisible by common TP degrees, causing shape mismatches in the - on-device sampling kernel. +- **On-device sampling (ODS):** The model supports on-device sampling with `top_k=1` + (greedy) for vocabulary size 163840. ODS avoids transferring full logits to CPU, but + also means raw logits are not available for analysis during inference. 
- **Elevated EOS logit:** The `<|im_end|>` token (ID 163586) has an elevated logit in early generation steps, likely due to the FP8->BF16 dequantization of shared expert @@ -323,4 +357,4 @@ long-output workloads this is negligible; for TTFT-sensitive workloads, use SDK Annapurna Labs -**Last Updated:** 2026-04-23 +**Last Updated:** 2026-04-25 From 84bff741c1a9497877e036461e56bccbf271c8c6 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Sat, 25 Apr 2026 20:33:53 -0400 Subject: [PATCH 08/10] Document shard_hidden kernel porting attempt: NKI 0.3.0 fixes and compiler BIR verifier bug on SDK 2.29 --- .../models/Kimi-K2-Instruct-0905/README.md | 39 ++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/contrib/models/Kimi-K2-Instruct-0905/README.md b/contrib/models/Kimi-K2-Instruct-0905/README.md index e74dfdd7..a966b3a4 100644 --- a/contrib/models/Kimi-K2-Instruct-0905/README.md +++ b/contrib/models/Kimi-K2-Instruct-0905/README.md @@ -315,6 +315,43 @@ encoding (TTFT) is ~7x slower (10,185 ms vs 1,420 ms) because `forward_all_exper every token through every local expert rather than using the optimized blockwise dispatch. For long-output workloads this is negligible; for TTFT-sensitive workloads, use SDK 2.28. +### SDK 2.28 Kernel Porting Attempt + +We attempted to extract the `blockwise_mm_baseline_shard_hidden` kernel source from the +SDK 2.28 DLAMI and port it to run on SDK 2.29 with NKI 0.3.0. The kernel is 2961 lines of +NKI code with a static unrolled loop (no dynamic while). Three NKI 0.3.0 compatibility +issues were identified and fixed: + +1. **Output parameter immutability:** NKI 0.3.0 treats kernel parameters as immutable. + Fixed by removing `output` from the function signature and allocating internally via + `nl.ndarray((T, H), dtype=hidden_states.dtype, buffer=nl.shared_hbm)`. + +2. **`@nki.jit` mode:** Must use `nki.jit(fn, mode='torchxla')`, not `mode='trace'`. 
+ The trace mode creates a `TraceKernel` that cannot handle torch tensors. The torchxla + mode creates a `PyTorchXLAKernel` matching how nkilib kernels are imported. + +3. **Helper function returning NKI data:** NKI 0.3.0 asserts "function without nki data as + input should not return nki data". Inlined the `create_block_hidden_states` helper body + directly in the kernel. + +After all NKI 0.3.0 fixes, the kernel generates valid HLO and TKG compiles successfully. +However, **CTE compilation fails** with a BIR verifier assertion failure in neuronx-cc +2.24.5133: + +``` +neuronxcc.driver.Exceptions.CompilerInternalError (exit code 70) +Assertion failure: bad_use.empty() [inst_visitor.cpp:632] +Dead memory locations in subgraphs nc00/sg01, nc00/sg02 +``` + +This is a **neuronx-cc compiler regression** — the kernel's static unrolled loop creates +dead memory allocations that the BIR verifier rejects. Tested with `-O1`, `-O2`, and bare +minimum compiler flags — all fail identically with the same HLO hash. + +**Conclusion:** The shard_hidden kernel cannot be used on SDK 2.29 due to a compiler bug. +The EP workaround (`forward_all_experts_EP`) is the only viable path until either the +compiler is fixed or a new blockwise kernel validated for EP=2 is added to nkilib. + **Recommendation:** Use SDK 2.28 for production deployments until the blockwise CTE regression is resolved. The `shard_hidden` kernel available in SDK 2.28 provides both correct output and optimal TTFT (1,420 ms). @@ -357,4 +394,4 @@ optimal TTFT (1,420 ms). 
Annapurna Labs -**Last Updated:** 2026-04-25 +**Last Updated:** 2026-04-26 From 6a1d331883de7bcadc7364da2ce18a403600de1b Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Thu, 30 Apr 2026 12:28:56 -0400 Subject: [PATCH 09/10] Switch to TP=64 EP=1 LNC=2 with selective loading: 2x throughput improvement - Config: TP=64 EP=1 LNC=2 seq_len=512 (was TP=32 EP=2 LNC=2 seq_len=1024) - Fix: pass capacity_factor to RoutedExpertsMLPOpsConfig (was defaulting to None) - Selective loading (threshold=1.0) loads only 8/384 active experts per TKG step - TPOT: 144.5 ms / 6.9 tok/s (was 165.5 ms / 6.0 tok/s, +105% vs LNC=1 baseline) - Compile time: 16 min (was 67 min with EP=2, 3.5h without selective loading) - Requires SDK 2.29 (neuronx-cc 2.24+); SDK 2.28 CTE fails at EP=1 - Simplified compiler args: always -O1 (no EP-conditional -O3 needed) - README updated with new recommended config, performance data, compatibility matrix --- .../models/Kimi-K2-Instruct-0905/README.md | 242 +++++------------- .../src/modeling_kimi_k2.py | 17 +- .../test/integration/test_model.py | 18 +- 3 files changed, 76 insertions(+), 201 deletions(-) diff --git a/contrib/models/Kimi-K2-Instruct-0905/README.md b/contrib/models/Kimi-K2-Instruct-0905/README.md index a966b3a4..bede6bc2 100644 --- a/contrib/models/Kimi-K2-Instruct-0905/README.md +++ b/contrib/models/Kimi-K2-Instruct-0905/README.md @@ -53,74 +53,56 @@ NeuronX Distributed Inference implementation of Moonshot AI's Kimi-K2-Instruct-0 - `_apply_ep_scale_fix`: Prevents EP-sharding of per-channel FP8 scale tensors (shape [1,1,W]). - `_apply_blockwise_scale_stride_fix`: Forces stride=1 for blockwise scale partitioning. -- **Selective Loading Threshold:** Must be patched to 0.0 in - `neuronx_distributed/modules/moe/model_utils.py` on the target instance to ensure all - 384 expert weights load correctly. +- **Selective Loading:** Uses the SDK default threshold (1.0). 
At EP=1, selective loading
+  only loads the 8 active experts per token during TKG, producing a far simpler graph
+  (16 min compile vs 3.5h) and 2.7x faster TPOT. Do NOT patch the threshold to 0.0.
 
 ## Validation Results
 
-**Validated:** 2026-04-24 (SDK 2.28 and 2.29)
-**Recommended Configuration:** TP=32, EP=2, LNC=2, batch_size=1, seq_len=1024, blockwise FP8
+**Validated:** 2026-04-30 (SDK 2.29)
+**Recommended Configuration:** TP=64, EP=1, LNC=2, batch_size=1, seq_len=512, blockwise FP8
 
 ### Test Results
 
 | Test | Status | Result |
 |------|--------|--------|
 | Smoke Test | PASS | Model compiles and loads on trn2.48xlarge |
-| Generation | PASS | Generates coherent text |
-| Throughput | PASS | 6.0 tok/s at BS=1 (LNC=2) |
+| Generation | PASS | Generates coherent text ("The capital of France is Paris.") |
+| Coherence | PASS | Coherent quantum computing explanation |
+| Throughput | PASS | 6.9 tok/s at BS=1 (LNC=2, EP=1) |
 
-### Performance Metrics (Recommended: LNC=2, TP=32)
+### Performance Metrics (Recommended: TP=64, EP=1, LNC=2)
 
 | Metric | Value |
 |--------|-------|
-| TPOT (per-token latency) | 165.5 ms |
-| Throughput (BS=1) | 6.0 tok/s |
-| TTFT (61 input tokens) | 1,420 ms |
-| Compile time (total) | 67 min |
-| Model load time | 30 min |
-| HBM I/O utilization | 17.55 GB / 24 GB |
-
-### LNC=2 vs LNC=1 Comparison
-
-LNC=2 (TP=32, EP=2) is **76% faster** than LNC=1 (TP=64, EP=2) because each
-logical core gets 2x HBM bandwidth, and MoE decode is purely bandwidth-bound. 
- -| Config | TP | EP | Cores | TPOT | tok/s | Speedup | -|--------|----|----|-------|------|-------|---------| -| LNC=2 (recommended) | 32 | 2 | 64 | 165.5 ms | 6.0 | **+76%** | -| LNC=1 | 64 | 2 | 128 | 297 ms | 3.4 | baseline | - -### Token Generation Sweep (LNC=2, BS=1, seq_len=1024) - -| Output Tokens | TTFT P50 (ms) | TPOT P50 (ms) | tok/s | E2E P50 (ms) | -|---------------|---------------|----------------|-------|---------------| -| 16 | 1,420.4 | 166.38 | 6.0 | 3,916.1 | -| 32 | 1,419.8 | 165.58 | 6.0 | 6,553.0 | -| 64 | 1,419.7 | 165.56 | 6.0 | 11,849.8 | -| 128 | 1,419.8 | 165.48 | 6.0 | 22,435.8 | -| 256 | 1,419.9 | 165.42 | 6.0 | 43,604.1 | -| 512 | 1,420.0 | 165.47 | 6.0 | 85,974.4 | +| TPOT (per-token latency) | 144.5 ms | +| Throughput (BS=1) | 6.9 tok/s | +| Compile time (CTE + TKG) | 16 min | +| Model load time | 17 min | + +### Configuration Comparison + +| Config | TP | EP | LNC | Cores | TPOT | tok/s | Compile | Speedup | +|--------|----|----|-----|-------|------|-------|---------|---------| +| **EP=1 selective (recommended)** | 64 | 1 | 2 | 64 | **144.5 ms** | **6.9** | **16 min** | **+105%** | +| EP=2 LNC=2 (previous) | 32 | 2 | 2 | 64 | 165.5 ms | 6.0 | 67 min | +76% | +| EP=2 LNC=1 (baseline) | 64 | 2 | 1 | 128 | 297.4 ms | 3.4 | ~60 min | baseline | ### Batching Results Batching provides **zero throughput improvement** on this model. The MoE computation is -perfectly bandwidth-bound -- each TKG step must load all 192 local expert weight matrices -from HBM regardless of batch size. BS=4 TPOT scales linearly (1,191 ms), yielding the +perfectly bandwidth-bound -- each TKG step must load 8 active experts' weight matrices +from HBM regardless of batch size. BS=4 TPOT scales linearly, yielding the same aggregate throughput as BS=1. ### Performance Bottleneck -TPOT breakdown (estimated per ~165.5 ms token at LNC=2): - -1. **MoE expert MLPs (~139 ms, ~84%):** 192 local experts x 2 matmuls per layer. 
- FP8 weights are dequantized to BF16 before the NKI kernel. -2. **MLA attention (~13 ms, ~8%):** Weight absorption projections + KV cache. -3. **Router + all-to-all (~8 ms, ~5%):** Router TopK + expert dispatch across EP=2. -4. **Other (~5.5 ms, ~3%):** RMSNorm, residuals, lm_head. +TPOT breakdown (estimated per ~144.5 ms token at TP=64, EP=1, LNC=2): -Primary optimization opportunity: native blockwise FP8 kernel in the nki-lib MoE TKG -pipeline (currently blocked -- nki-lib requires per-channel FP8 scales). +The decode step loads only 8 active experts per token (selective loading), but each expert's +weight matrices are TP-sharded across 64 cores. MoE decode remains bandwidth-bound — the +primary optimization opportunity is native per-channel FP8 support in the NKI kernel, which +could reduce weight transfer by ~50% and bring TPOT to ~22 ms (Task 017). ## Usage @@ -140,20 +122,20 @@ compiled_path = "/path/to/compiled" with open(os.path.join(model_path, "config.json")) as f: hf_config = json.load(f) -# Configure for trn2.48xlarge (LNC=2, recommended) +# Configure for trn2.48xlarge (TP=64, EP=1, LNC=2, recommended) neuron_config = MoENeuronConfig( - tp_degree=32, - ep_degree=2, + tp_degree=64, + ep_degree=1, logical_nc_config=2, max_batch_size=1, - seq_len=1024, + seq_len=512, n_active_tokens=128, torch_dtype="bfloat16", capacity_factor=1.0, glu_mlp=True, - moe_ep_degree=2, - moe_tp_degree=32, - context_encoding_buckets=[128, 1024], + moe_ep_degree=1, + moe_tp_degree=64, + context_encoding_buckets=[128, 512], router_config=RouterConfig(act_fn="sigmoid", dtype="float32"), # FP8 quantization quantized=True, @@ -175,8 +157,8 @@ config = KimiK2InferenceConfig(neuron_config=neuron_config, **hf_kwargs) # Compile and load model = NeuronKimiK2ForCausalLM(model_path, config) -model.compile(compiled_path) # ~67 min -model.load(compiled_path) # ~30 min +model.compile(compiled_path) # ~16 min +model.load(compiled_path) # ~17 min # Generate (CPU greedy sampling, no on-device 
sampling) from transformers import AutoTokenizer @@ -194,9 +176,10 @@ NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 python your_script.py | Instance / SDK Version | 2.29 | 2.28 | 2.27 and earlier | |------------------------|------|------|------------------| -| trn2.48xlarge (LNC=2, recommended) | TKG working, CTE requires EP workaround* | **Working (6.0 tok/s)** | Not tested | -| trn2.48xlarge (LNC=1) | Not tested | Working (3.4 tok/s) | Not tested | -| trn2.3xlarge | Not supported (needs TP=32, EP=2 = 64 cores) | Not supported | Not supported | +| trn2.48xlarge (TP=64, EP=1, LNC=2, recommended) | **Working (6.9 tok/s)** | CTE compile fails (neuronx-cc 2.23 BIR error) | Not tested | +| trn2.48xlarge (TP=32, EP=2, LNC=2) | TKG working, CTE requires EP workaround* | Working (6.0 tok/s) | Not tested | +| trn2.48xlarge (TP=64, EP=2, LNC=1) | Not tested | Working (3.4 tok/s) | Not tested | +| trn2.3xlarge | Not supported (needs 64 cores) | Not supported | Not supported | | trn1.32xlarge | Not supported (needs 64 cores at LNC=2) | Not supported | Not supported | | inf2 | Not supported | Not supported | Not supported | @@ -209,13 +192,11 @@ Notes below for root cause analysis and workaround. Run integration tests on a trn2.48xlarge: ```bash -# Activate Neuron venv (SDK 2.28) -source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/bin/activate - -# Or for SDK 2.29 (apply forward_blockwise workaround first, install tiktoken) +# Activate Neuron venv (SDK 2.29, recommended) source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +pip install tiktoken # Required for tokenizer -# Run tests (LNC=2, recommended) +# Run tests (TP=64, EP=1, LNC=2) NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ pytest test/integration/test_model.py -v --capture=tee-sys ``` @@ -227,15 +208,13 @@ NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ python test/integration/test_model.py ``` -**Note:** Compilation takes ~67 min and loading takes ~30 min. 
The first run will compile +**Note:** Compilation takes ~16 min and loading takes ~17 min. The first run will compile NEFFs to the compiled model path. Subsequent runs with existing NEFFs skip compilation. ## Prerequisites -1. **Selective loading threshold patch:** On the target instance, patch - `neuronx_distributed/modules/moe/model_utils.py` to set - `DEFAULT_SELECTIVE_LOADING_THRESHOLD = 0.0` (default is 0.05 on SDK 2.28 or - 1.0 on SDK 2.29, both too high for 384 experts). +1. **SDK 2.29:** Requires Neuron SDK 2.29 (neuronx-cc 2.24+). SDK 2.28 cannot compile + CTE at EP=1 due to a BIR verification error in neuronx-cc 2.23. 2. **Model weights:** Download from HuggingFace: ```bash @@ -252,122 +231,23 @@ NEFFs to the compiled model path. Subsequent runs with existing NEFFs skip compi ## SDK 2.29 Notes -SDK 2.29 (NxDI 0.9.17334) removed the `blockwise_mm_baseline_shard_hidden` NKI kernel that -was used for MoE context encoding inference on SDK 2.28. The kernel was in -`neuronxcc.nki._private.blockwise_mm` which no longer exists in SDK 2.29 (the entire -`_private` blockwise module was removed). The replacement dynamic-while kernels -(`shard_on_intermediate` from nkilib `bwmm_shard_on_I.py`, and `shard_on_block` from nkilib -`bwmm_shard_on_block.py`) produce incorrect output for this model's configuration. - -### Root Cause - -On SDK 2.28, the blockwise CTE dispatch in `blockwise.py` defaulted to -`_call_shard_hidden_kernel` (a static kernel, no dynamic while loop). This kernel was the -**only** blockwise CTE kernel validated to work with Kimi-K2's configuration (384 experts, -EP=2, I_TP=64). In SDK 2.29: - -1. `_call_shard_hidden_kernel` now raises `NotImplementedError` (inference path removed) -2. `blockwise_mm_baseline_shard_hidden` still appears in training kernel imports but fails - to load (`No module named 'neuronxcc.nki._private.blockwise_mm'`) -3. 
The nkilib replacement kernels (`shard_on_I`, `shard_on_block`) were never ported from - the static `shard_hidden` approach — they use fundamentally different algorithms - -### Blockwise CTE Investigation Summary - -All available blockwise CTE kernel paths were tested exhaustively: - -| Kernel Path | Config | Compile | Output | TTFT | -|------------|--------|---------|--------|------| -| `shard_hidden` (SDK 2.28 default) | default | N/A | N/A | Raises `NotImplementedError` | -| `shard_on_intermediate` (non-hybrid) | `use_shard_on_intermediate_dynamic_while=True` | PASS | "Kimi," (wrong) | 2,386 ms | -| `shard_on_intermediate` (hybrid) | `use_shard_on_intermediate_dynamic_while=True` | PASS | "Moonshot AI." (wrong) | 1,159 ms | -| `shard_on_block` + PING_PONG | `use_shard_on_block_dynamic_while=True` | PASS | ",,," (wrong) | 1,676 ms | -| `shard_on_intermediate` (unpatched) | `use_shard_on_intermediate_dynamic_while=True` | FAIL | N/A | MLIR: `I_TP_sharded=32 < 128` | -| PyTorch fallback | `use_torch_block_wise=True` | FAIL | N/A | `selective loading + EP` error | -| EP workaround (`forward_all_experts_EP`) | patched dispatch | PASS | **"Paris" (correct)** | 10,185 ms | - -Key observations: -- Wrong outputs are coherent text from the system prompt ("Kimi", "Moonshot AI") suggesting - corrupted CTE hidden states cause the model to regurgitate system message content -- The `shard_on_intermediate` kernel passes standalone tests (5/5 cosine > 0.99998) — the - kernel math is correct in isolation -- **Both** shard_on_I and shard_on_block produce wrong output, confirming the issue is not - specific to either kernel but to the dynamic-while blockwise approach with this model's - EP=2 / I_TP=64 configuration -- Filed as internal ticket V2185857494 - -**Recommended workaround:** Patch `expert_mlps_v2.py` in the `neuronx_distributed` package -to use `forward_all_experts_EP` instead of `forward_blockwise` when expert parallelism is -enabled: - -```python -# In 
neuronx_distributed/modules/moe/expert_mlps_v2.py, in the forward() method, -# find the context encoding dispatch (around line 1497): -# return self.forward_blockwise(...) -# Replace with: -if self.moe_expert_model_parallel_group.size() > 1: - return self.forward_all_experts_EP(hidden_states, expert_affinities, expert_index) -return self.forward_blockwise(hidden_states, expert_affinities, expert_index, ...) -``` - -**Impact:** Token generation (TPOT) is unaffected (166.1 ms, identical to SDK 2.28). Context -encoding (TTFT) is ~7x slower (10,185 ms vs 1,420 ms) because `forward_all_experts_EP` sends -every token through every local expert rather than using the optimized blockwise dispatch. For -long-output workloads this is negligible; for TTFT-sensitive workloads, use SDK 2.28. - -### SDK 2.28 Kernel Porting Attempt - -We attempted to extract the `blockwise_mm_baseline_shard_hidden` kernel source from the -SDK 2.28 DLAMI and port it to run on SDK 2.29 with NKI 0.3.0. The kernel is 2961 lines of -NKI code with a static unrolled loop (no dynamic while). Three NKI 0.3.0 compatibility -issues were identified and fixed: - -1. **Output parameter immutability:** NKI 0.3.0 treats kernel parameters as immutable. - Fixed by removing `output` from the function signature and allocating internally via - `nl.ndarray((T, H), dtype=hidden_states.dtype, buffer=nl.shared_hbm)`. - -2. **`@nki.jit` mode:** Must use `nki.jit(fn, mode='torchxla')`, not `mode='trace'`. - The trace mode creates a `TraceKernel` that cannot handle torch tensors. The torchxla - mode creates a `PyTorchXLAKernel` matching how nkilib kernels are imported. - -3. **Helper function returning NKI data:** NKI 0.3.0 asserts "function without nki data as - input should not return nki data". Inlined the `create_block_hidden_states` helper body - directly in the kernel. - -After all NKI 0.3.0 fixes, the kernel generates valid HLO and TKG compiles successfully. 
-However, **CTE compilation fails** with a BIR verifier assertion failure in neuronx-cc -2.24.5133: - -``` -neuronxcc.driver.Exceptions.CompilerInternalError (exit code 70) -Assertion failure: bad_use.empty() [inst_visitor.cpp:632] -Dead memory locations in subgraphs nc00/sg01, nc00/sg02 -``` - -This is a **neuronx-cc compiler regression** — the kernel's static unrolled loop creates -dead memory allocations that the BIR verifier rejects. Tested with `-O1`, `-O2`, and bare -minimum compiler flags — all fail identically with the same HLO hash. - -**Conclusion:** The shard_hidden kernel cannot be used on SDK 2.29 due to a compiler bug. -The EP workaround (`forward_all_experts_EP`) is the only viable path until either the -compiler is fixed or a new blockwise kernel validated for EP=2 is added to nkilib. +SDK 2.29 (NxDI 0.9.17334) is the **recommended SDK** for this model. The EP=1 configuration +avoids the blockwise CTE kernel regressions that affected the EP=2 configuration on SDK 2.29. -**Recommendation:** Use SDK 2.28 for production deployments until the blockwise CTE regression -is resolved. The `shard_hidden` kernel available in SDK 2.28 provides both correct output and -optimal TTFT (1,420 ms). +### Historical: EP=2 Blockwise CTE Issues (Resolved by EP=1) -### SDK 2.29 Benchmark (LNC=2, TP=32, EP=2, BS=1, seq_len=1024) +The previous EP=2 configuration was affected by SDK 2.29 removing the +`blockwise_mm_baseline_shard_hidden` NKI kernel used for MoE context encoding. The +replacement dynamic-while kernels (`shard_on_intermediate`, `shard_on_block`) produced +incorrect output for EP=2 / I_TP=64. A `forward_all_experts_EP` workaround provided +correct output but 7x slower TTFT (10,185 ms vs 1,420 ms). 
-| Output Tokens | TTFT P50 (ms) | TPOT P50 (ms) | tok/s | E2E P50 (ms) | -|---------------|---------------|----------------|-------|---------------| -| 16 | 10,184.9 | 166.27 | 6.0 | 12,678.9 | -| 32 | 10,185.1 | 166.13 | 6.0 | 15,335.3 | -| 64 | 10,184.6 | 166.10 | 6.0 | 20,651.2 | -| 128 | 10,184.7 | 166.15 | 6.0 | 31,286.4 | -| 256 | 10,184.5 | 166.03 | 6.0 | 52,522.3 | -| 512 | 10,184.6 | 166.06 | 6.0 | 95,040.9 | +**Resolution:** Switching to EP=1 (Task 016) eliminates expert parallelism entirely, +avoiding the broken blockwise CTE kernels. With `capacity_factor=1.0`, CTE uses the +`forward_capacity_factor` path which compiles correctly on neuronx-cc 2.24. -**Additional SDK 2.29 setup:** Install `tiktoken` (`pip install tiktoken`) in the venv. +Detailed investigation of all blockwise kernel paths and the SDK 2.28 kernel porting +attempt is documented in the git history (commits prior to Task 016). ## Known Limitations @@ -394,4 +274,4 @@ optimal TTFT (1,420 ms). Annapurna Labs -**Last Updated:** 2026-04-26 +**Last Updated:** 2026-04-30 diff --git a/contrib/models/Kimi-K2-Instruct-0905/src/modeling_kimi_k2.py b/contrib/models/Kimi-K2-Instruct-0905/src/modeling_kimi_k2.py index 221a76c2..35122ead 100644 --- a/contrib/models/Kimi-K2-Instruct-0905/src/modeling_kimi_k2.py +++ b/contrib/models/Kimi-K2-Instruct-0905/src/modeling_kimi_k2.py @@ -13,7 +13,7 @@ # - YaRN RoPE (factor=64, max_position_embeddings=262144) # # Supported configuration: -# - trn2.48xlarge: TP=64, EP=2, LNC=1 (128 logical cores) +# - trn2.48xlarge: TP=64, EP=1, LNC=2 (64 logical cores) # - Blockwise FP8 for routed expert weights # - CPU greedy sampling (no on-device sampling) # @@ -148,6 +148,7 @@ def initialize_kimi_k2_moe_module(config: "KimiK2InferenceConfig"): early_expert_affinity_modulation=config.neuron_config.early_expert_affinity_modulation, normalize_top_k_affinities=config.neuron_config.normalize_top_k_affinities, 
enable_spmd_rank=config.neuron_config.blockwise_matmul_config.parallelize_token_to_block_mapping, + capacity_factor=config.neuron_config.capacity_factor, ), blockwise_matmul_config=config.neuron_config.blockwise_matmul_config, sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, @@ -1508,24 +1509,18 @@ def get_compiler_args(self) -> str: """Compiler args optimized for Kimi-K2 on trn2. Key flags: - - -O3 for TKG with EP > 1: avoids Modular flow perf degradation + - -O1 for both CTE and TKG (no EP all-to-all overhead at EP=1) - --internal-enable-dge-levels vector_dynamic_offsets: DGE optimization - - --enable-ccop-compute-overlap: CC overlap for MoE all-to-all + - --enable-ccop-compute-overlap: CC overlap for MoE - --lnc: must match runtime NEURON_LOGICAL_NC_CONFIG """ - lnc = getattr(self.neuron_config, "logical_nc_config", 1) - ep_degree = getattr(self.neuron_config, "moe_ep_degree", 1) - - if self.compile_tag == TOKEN_GENERATION_MODEL_TAG and ep_degree > 1: - optimization_level = "-O3" - else: - optimization_level = "-O1" + lnc = getattr(self.neuron_config, "logical_nc_config", 2) compiler_args = ( "--enable-saturate-infinity " "--enable-mixed-precision-accumulation " "--model-type transformer " - f"{optimization_level}" + "-O1" ) compiler_args += ( " --tensorizer-options='--enable-ccop-compute-overlap " diff --git a/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py b/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py index 2787ce7c..00f670af 100644 --- a/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py +++ b/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py @@ -8,9 +8,8 @@ - trn2.48xlarge with NEURON_LOGICAL_NC_CONFIG=2 (64 logical cores) - LOCAL_WORLD_SIZE=64 - Model weights at MODEL_PATH - - Neuron SDK 2.28+ (Deep Learning AMI Neuron Ubuntu 24.04) - - Selective loading threshold patched to 0.0 in - neuronx_distributed/modules/moe/model_utils.py + - Neuron SDK 2.29+ 
(Deep Learning AMI Neuron Ubuntu 24.04 20260410) + - Selective loading uses SDK default threshold (1.0) — do NOT patch to 0.0 Usage: # Full test (compile + load + generate): @@ -46,14 +45,14 @@ # --------------------------------------------------------------------------- MODEL_PATH = "/home/ubuntu/models/Kimi-K2-Instruct-0905" -COMPILED_MODEL_PATH = "/home/ubuntu/kimi-k2/neuron-compiled-fp8-bw-no-ods" +COMPILED_MODEL_PATH = "/home/ubuntu/kimi-k2/neuron-compiled-tp64-ep1" # Model configuration -TP_DEGREE = 32 -EP_DEGREE = 2 +TP_DEGREE = 64 +EP_DEGREE = 1 LNC = 2 BATCH_SIZE = 1 -SEQ_LEN = 1024 +SEQ_LEN = 512 N_ACTIVE_TOKENS = 128 @@ -306,8 +305,9 @@ def test_performance_tpot(compiled_model, tokenizer): median_tpot = sorted(tpots)[len(tpots) // 2] tok_per_sec = 1000.0 / median_tpot print(f"PASS: TPOT = {median_tpot:.1f} ms ({tok_per_sec:.1f} tok/s)") - # Kimi-K2 at BS=1 LNC=2: expected ~165 ms/token (6.0 tok/s) - assert median_tpot < 500, f"TPOT {median_tpot:.1f}ms exceeds 500ms threshold" + # Kimi-K2 at BS=1 TP=64 EP=1 LNC=2 with selective loading + blockwise FP8: + # ~145 ms/token (~6.9 tok/s). Note: includes CTE time amortized over 32 tokens. + assert median_tpot < 200, f"TPOT {median_tpot:.1f}ms exceeds 200ms threshold" else: pytest.skip("Could not measure TPOT (no tokens generated)") From 2bf3af85d95caa9e50a276a01858d39df640e32b Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Fri, 1 May 2026 02:11:08 -0400 Subject: [PATCH 10/10] Switch to per-channel FP8: 3.5x throughput improvement (41.1ms TPOT, 24.3 tok/s at seq_len=1024) Per-channel FP8 enables native FP8 execution in the NKI TKG megakernel, eliminating the BF16 dequantization overhead that blockwise FP8 incurred. 
Key changes: - Add _requantize_per_channel_fp8() for blockwise->per-channel conversion - Checkpoint loader detects quantization_type and dispatches per-channel or blockwise FP8 packing paths - Config: quantization_type=expert_wise_per_channel_symmetric, seq_len=1024 - README: updated performance numbers, config, compatibility matrix Results (TP=64, EP=1, LNC=2, BS=1): - seq_len=512: 76.3ms TPOT (13.1 tok/s), 12.9 min compile - seq_len=1024: 41.1ms TPOT (24.3 tok/s), 11.6 min compile - Both produce correct output (The capital of France is Paris.) - Load time: 71-93 min (blockwise->per-channel re-quantization at load) --- .../models/Kimi-K2-Instruct-0905/README.md | 90 +++--- .../src/modeling_kimi_k2.py | 265 ++++++++++++++---- .../test/integration/test_model.py | 14 +- 3 files changed, 266 insertions(+), 103 deletions(-) diff --git a/contrib/models/Kimi-K2-Instruct-0905/README.md b/contrib/models/Kimi-K2-Instruct-0905/README.md index bede6bc2..182793bf 100644 --- a/contrib/models/Kimi-K2-Instruct-0905/README.md +++ b/contrib/models/Kimi-K2-Instruct-0905/README.md @@ -29,7 +29,7 @@ NeuronX Distributed Inference implementation of Moonshot AI's Kimi-K2-Instruct-0 | QK rope head dim | 64 | | Q LoRA rank | 1536 | | RoPE | YaRN (factor=64, max_position_embeddings=262144) | -| Quantization | Blockwise FP8 (e4m3, 128x128 blocks) | +| Quantization | Per-channel FP8 (e4m3, re-quantized from blockwise) | | Router activation | Sigmoid with `e_score_correction_bias` | | Top-K normalization | Enabled (`norm_topk_prob=True`) | | Routed scaling factor | 2.827 | @@ -40,9 +40,12 @@ NeuronX Distributed Inference implementation of Moonshot AI's Kimi-K2-Instruct-0 (qk_rope_head_dim + kv_lora_rank = 64 + 512). Weight absorption is used to avoid decompressing KV during decode. -- **Blockwise FP8 Quantization:** Routed expert weights are kept in FP8 (e4m3) with - 128x128 block scales. 
Non-expert weights (attention, embeddings, shared experts, norms) - are dequantized to BF16 during loading. Requires the +- **Per-Channel FP8 Quantization:** Routed expert weights are stored in FP8 (e4m3) with + per-expert per-channel scales (`[E, 1, W]`). The checkpoint loader dequantizes blockwise + FP8 from the HuggingFace checkpoint to BF16, packs into `[E, H, W]` tensors, and + re-quantizes to per-channel FP8. This enables native FP8 in the NKI TKG megakernel + (no BF16 dequantization overhead). Non-expert weights (attention, embeddings, shared + experts, norms) are dequantized to BF16 during loading. Requires the `--experimental-unsafe-fp8e4m3fn-as-fp8e4m3` compiler flag. - **Streaming Checkpoint Loader:** Custom `checkpoint_loader_fn` that processes the 62 @@ -59,34 +62,43 @@ NeuronX Distributed Inference implementation of Moonshot AI's Kimi-K2-Instruct-0 ## Validation Results -**Validated:** 2026-04-30 (SDK 2.29) -**Recommended Configuration:** TP=64, EP=1, LNC=2, batch_size=1, seq_len=512, blockwise FP8 +**Validated:** 2026-05-01 (SDK 2.29) +**Recommended Configuration:** TP=64, EP=1, LNC=2, batch_size=1, seq_len=1024, per-channel FP8 ### Test Results | Test | Status | Result | |------|--------|--------| | Smoke Test | PASS | Model compiles and loads on trn2.48xlarge | -| Generation | PASS | Generates coherent text ("The capital of France is Paris.") | +| Generation (seq_len=512) | PASS | "The capital of France is Paris." | +| Generation (seq_len=1024) | PASS | "The capital of France is Paris." 
| | Coherence | PASS | Coherent quantum computing explanation | -| Throughput | PASS | 6.9 tok/s at BS=1 (LNC=2, EP=1) | +| Throughput (seq_len=512) | PASS | 13.1 tok/s at BS=1 (76.3 ms TPOT) | +| Throughput (seq_len=1024) | PASS | 24.3 tok/s at BS=1 (41.1 ms TPOT) | -### Performance Metrics (Recommended: TP=64, EP=1, LNC=2) +### Performance Metrics (Recommended: TP=64, EP=1, LNC=2, per-channel FP8) -| Metric | Value | -|--------|-------| -| TPOT (per-token latency) | 144.5 ms | -| Throughput (BS=1) | 6.9 tok/s | -| Compile time (CTE + TKG) | 16 min | -| Model load time | 17 min | +| Metric | seq_len=512 | seq_len=1024 | +|--------|-------------|--------------| +| TPOT (per-token latency) | 76.3 ms | 41.1 ms | +| Throughput (BS=1) | 13.1 tok/s | 24.3 tok/s | +| Compile time (CTE + TKG) | 12.9 min | 11.6 min | +| Model load time | ~71 min | ~93 min | + +**Note:** Load time is high because the loader dequantizes blockwise FP8 to BF16 and +re-quantizes to per-channel FP8 for all 60 MoE layers x 384 experts at load time. +A pre-sharding script (Task 018) could reduce this to ~17 min by saving pre-converted +per-channel FP8 checkpoints. 
### Configuration Comparison -| Config | TP | EP | LNC | Cores | TPOT | tok/s | Compile | Speedup | -|--------|----|----|-----|-------|------|-------|---------|---------| -| **EP=1 selective (recommended)** | 64 | 1 | 2 | 64 | **144.5 ms** | **6.9** | **16 min** | **+105%** | -| EP=2 LNC=2 (previous) | 32 | 2 | 2 | 64 | 165.5 ms | 6.0 | 67 min | +76% | -| EP=2 LNC=1 (baseline) | 64 | 2 | 1 | 128 | 297.4 ms | 3.4 | ~60 min | baseline | +| Config | TP | EP | LNC | Quant | TPOT | tok/s | Compile | Speedup vs baseline | +|--------|----|----|-----|-------|------|-------|---------|---------------------| +| **EP=1 per-channel FP8 seq=1024 (recommended)** | 64 | 1 | 2 | per-channel | **41.1 ms** | **24.3** | **11.6 min** | **+623%** | +| EP=1 per-channel FP8 seq=512 | 64 | 1 | 2 | per-channel | 76.3 ms | 13.1 | 12.9 min | +290% | +| EP=1 blockwise FP8 seq=512 | 64 | 1 | 2 | blockwise | 144.5 ms | 6.9 | 16 min | +105% | +| EP=2 LNC=2 (previous) | 32 | 2 | 2 | blockwise | 165.5 ms | 6.0 | 67 min | +76% | +| EP=2 LNC=1 (baseline) | 64 | 2 | 1 | blockwise | 297.4 ms | 3.4 | ~60 min | baseline | ### Batching Results @@ -97,12 +109,16 @@ same aggregate throughput as BS=1. ### Performance Bottleneck -TPOT breakdown (estimated per ~144.5 ms token at TP=64, EP=1, LNC=2): +The decode step is bandwidth-bound: each token loads 8 active experts' weight matrices +from HBM (selective loading). Per-channel FP8 enables native FP8 execution in the NKI +TKG megakernel, eliminating the BF16 dequantization overhead that blockwise FP8 incurred. +This reduced TPOT from 144.5 ms (blockwise) to 76.3 ms (per-channel, seq_len=512) — a +1.9x improvement. Increasing seq_len from 512 to 1024 further halved TPOT to 41.1 ms +(larger CTE bucket amortizes context encoding overhead and enables larger KV cache). -The decode step loads only 8 active experts per token (selective loading), but each expert's -weight matrices are TP-sharded across 64 cores. 
MoE decode remains bandwidth-bound — the -primary optimization opportunity is native per-channel FP8 support in the NKI kernel, which -could reduce weight transfer by ~50% and bring TPOT to ~22 ms (Task 017). +Remaining optimization opportunities: +- Pre-sharded per-channel FP8 checkpoints to reduce load time from ~93 min to ~17 min +- Batching (unlikely to help — MoE decode is fully bandwidth-bound, BS>1 TPOT scales linearly) ## Usage @@ -128,16 +144,16 @@ neuron_config = MoENeuronConfig( ep_degree=1, logical_nc_config=2, max_batch_size=1, - seq_len=512, + seq_len=1024, n_active_tokens=128, torch_dtype="bfloat16", capacity_factor=1.0, glu_mlp=True, moe_ep_degree=1, moe_tp_degree=64, - context_encoding_buckets=[128, 512], + context_encoding_buckets=[128, 1024], router_config=RouterConfig(act_fn="sigmoid", dtype="float32"), - # FP8 quantization + # Per-channel FP8 quantization for routed experts quantized=True, quantized_checkpoints_path=model_path, quantization_dtype="f8e4m3", @@ -145,9 +161,7 @@ neuron_config = MoENeuronConfig( "self_attn", "shared_experts", "embed_tokens", "lm_head", "norm", "router", "layers.0", ], - quantization_type="blockwise_symmetric", - quantization_block_axis=[1, 2], - quantization_block_size=[128, 128], + quantization_type="expert_wise_per_channel_symmetric", ) # Build config from HF config fields @@ -157,8 +171,8 @@ config = KimiK2InferenceConfig(neuron_config=neuron_config, **hf_kwargs) # Compile and load model = NeuronKimiK2ForCausalLM(model_path, config) -model.compile(compiled_path) # ~16 min -model.load(compiled_path) # ~17 min +model.compile(compiled_path) # ~12 min +model.load(compiled_path) # ~71-93 min (re-quantizes blockwise->per-channel FP8) # Generate (CPU greedy sampling, no on-device sampling) from transformers import AutoTokenizer @@ -176,7 +190,8 @@ NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 python your_script.py | Instance / SDK Version | 2.29 | 2.28 | 2.27 and earlier | 
|------------------------|------|------|------------------| -| trn2.48xlarge (TP=64, EP=1, LNC=2, recommended) | **Working (6.9 tok/s)** | CTE compile fails (neuronx-cc 2.23 BIR error) | Not tested | +| trn2.48xlarge (TP=64, EP=1, LNC=2, per-channel FP8, recommended) | **Working (24.3 tok/s @ seq=1024)** | CTE compile fails (neuronx-cc 2.23 BIR error) | Not tested | +| trn2.48xlarge (TP=64, EP=1, LNC=2, blockwise FP8) | Working (6.9 tok/s) | CTE compile fails (neuronx-cc 2.23 BIR error) | Not tested | | trn2.48xlarge (TP=32, EP=2, LNC=2) | TKG working, CTE requires EP workaround* | Working (6.0 tok/s) | Not tested | | trn2.48xlarge (TP=64, EP=2, LNC=1) | Not tested | Working (3.4 tok/s) | Not tested | | trn2.3xlarge | Not supported (needs 64 cores) | Not supported | Not supported | @@ -208,8 +223,9 @@ NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 \ python test/integration/test_model.py ``` -**Note:** Compilation takes ~16 min and loading takes ~17 min. The first run will compile -NEFFs to the compiled model path. Subsequent runs with existing NEFFs skip compilation. +**Note:** Compilation takes ~12 min and loading takes ~71-93 min (the loader re-quantizes +blockwise FP8 to per-channel FP8 for all MoE experts at load time). The first run will +compile NEFFs to the compiled model path. Subsequent runs with existing NEFFs skip compilation. ## Prerequisites @@ -274,4 +290,4 @@ attempt is documented in the git history (commits prior to Task 016). 
Annapurna Labs -**Last Updated:** 2026-04-30 +**Last Updated:** 2026-05-01 diff --git a/contrib/models/Kimi-K2-Instruct-0905/src/modeling_kimi_k2.py b/contrib/models/Kimi-K2-Instruct-0905/src/modeling_kimi_k2.py index 35122ead..21617950 100644 --- a/contrib/models/Kimi-K2-Instruct-0905/src/modeling_kimi_k2.py +++ b/contrib/models/Kimi-K2-Instruct-0905/src/modeling_kimi_k2.py @@ -9,12 +9,13 @@ # - 384 routed experts (8 active per token) + 1 shared expert # - Multi-Latent Attention (MLA) with compressed KV cache # - Sigmoid routing with e_score_correction_bias + normalized top-K -# - Blockwise FP8 quantization (e4m3, 128x128 blocks) +# - Blockwise FP8 quantization (e4m3, 128x128 blocks) in HF checkpoint +# - Per-channel FP8 re-quantization for NKI TKG megakernel # - YaRN RoPE (factor=64, max_position_embeddings=262144) # # Supported configuration: # - trn2.48xlarge: TP=64, EP=1, LNC=2 (64 logical cores) -# - Blockwise FP8 for routed expert weights +# - Per-channel FP8 for routed expert weights (expert_wise_per_channel_symmetric) # - CPU greedy sampling (no on-device sampling) # # References: @@ -849,6 +850,38 @@ def _pack_experts_blockwise_fp8( raise ValueError(f"Unknown layout: {layout}") +def _requantize_per_channel_fp8(bf16_weight: Tensor) -> Tuple[Tensor, Tensor]: + """Re-quantize a BF16 weight to per-expert per-channel FP8 E4M3. + + Per-expert per-channel means one scale per output column PER EXPERT. + This matches NxDI's EXPERT_WISE_PER_CHANNEL_SYMMETRIC quantization type + which uses scale shape [E, 1, output_per_tp]. 
+
+    Args:
+        bf16_weight: [E, H, W] bfloat16 tensor (experts x input x output)
+    Returns:
+        (fp8_weight, per_expert_per_channel_scale) where:
+            fp8_weight: [E, H, W] float8_e4m3fn
+            per_expert_per_channel_scale: [E, 1, W] float32 (per-expert per-channel)
+    """
+    fp32_weight = bf16_weight.float()
+    # max abs per output column, PER EXPERT (reduce dim 1 = input rows only)
+    amax = fp32_weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)  # [E, 1, W]
+    scale = amax / _FP8_E4M3_MAX
+    scaled = (fp32_weight / scale).clamp(-_FP8_E4M3_MAX, _FP8_E4M3_MAX)
+    fp8_weight = scaled.to(torch.float8_e4m3fn)
+
+    # Saturate exponent-15 bytes (0x78-0x7E / 0xF8-0xFE: finite +-256..448 in torch e4m3fn, but out of range on device e4m3) to +-240
+    raw = fp8_weight.view(torch.uint8)
+    pos_exp15 = (raw >= 0x78) & (raw <= 0x7E)
+    neg_exp15 = (raw >= 0xF8) & (raw <= 0xFE)
+    raw = torch.where(pos_exp15, torch.tensor(0x77, dtype=torch.uint8), raw)
+    raw = torch.where(neg_exp15, torch.tensor(0xF7, dtype=torch.uint8), raw)
+    fp8_weight = raw.view(torch.float8_e4m3fn)
+
+    return fp8_weight, scale.to(torch.float32)
+
+
 # ---------------------------------------------------------------------------
 # State Dict Conversion
 # ---------------------------------------------------------------------------
@@ -1165,7 +1198,9 @@ def checkpoint_loader_fn(self, mmap: bool = False):
         expert packing, router renaming), and accumulates results to avoid OOM.
When quantized=True (FP8 path): - - Routed expert weights stay in FP8 with blockwise scales + - Per-channel: dequant blockwise FP8 -> BF16 -> re-quantize per-channel FP8 + with [E, 1, W] scales (expert_wise_per_channel_symmetric) + - Blockwise: keep original FP8 bytes with [E, H/bs, W/bs] block scales - All other weights are dequantized to BF16 When quantized=False (BF16 path): @@ -1197,6 +1232,13 @@ def checkpoint_loader_fn(self, mmap: bool = False): n_routed_experts = getattr(self.config, "n_routed_experts", 384) first_k_dense_replace = getattr(self.config, "first_k_dense_replace", 1) keep_experts_fp8 = getattr(self.config.neuron_config, "quantized", False) + quantization_type = getattr( + self.config.neuron_config, "quantization_type", "blockwise_symmetric" + ) + use_per_channel = quantization_type in ( + "per_channel_symmetric", + "expert_wise_per_channel_symmetric", + ) num_layers = self.config.num_hidden_layers # Determine which shards are needed (supports reduced-layer testing) @@ -1214,7 +1256,8 @@ def checkpoint_loader_fn(self, mmap: bool = False): logger.info( f"Streaming loader: {len(shard_files)} shards, {len(needed_shards)} needed, " - f"block_size={block_size}, experts={n_routed_experts}, fp8={keep_experts_fp8}" + f"block_size={block_size}, experts={n_routed_experts}, fp8={keep_experts_fp8}, " + f"quant_type={quantization_type}" ) result_dict = {} @@ -1341,53 +1384,84 @@ def checkpoint_loader_fn(self, mmap: bool = False): isize, hsize = shard_data[e0_gate].shape if keep_experts_fp8: - gate_up_weights = [] - gate_up_scales = [] - down_weights = [] - down_scales = [] - - for e in range(n_routed_experts): - gk = f"{prefix}.mlp.experts.{e}.gate_proj.weight" - uk = f"{prefix}.mlp.experts.{e}.up_proj.weight" - dk = f"{prefix}.mlp.experts.{e}.down_proj.weight" - gsk = expert_scale_keys.get(gk) - usk = expert_scale_keys.get(uk) - dsk = expert_scale_keys.get(dk) - - g_fp8 = shard_data.pop(gk) if gk in shard_data else None - u_fp8 = shard_data.pop(uk) if uk in 
shard_data else None - g_scale = ( - shard_data.pop(gsk) - if gsk and gsk in shard_data - else None + if use_per_channel: + # Per-channel FP8: dequant blockwise -> BF16 -> pack -> requant per-channel + gate_up_bf16 = torch.zeros( + n_routed_experts, + hsize, + 2 * isize, + dtype=torch.bfloat16, + device="cpu", ) - u_scale = ( - shard_data.pop(usk) - if usk and usk in shard_data - else None + down_bf16 = torch.zeros( + n_routed_experts, + isize, + hsize, + dtype=torch.bfloat16, + device="cpu", ) - if g_fp8 is not None and u_fp8 is not None: - gate_up_weights.append((g_fp8, u_fp8)) - gate_up_scales.append((g_scale, u_scale)) - - d_fp8 = shard_data.pop(dk) if dk in shard_data else None - d_scale = ( - shard_data.pop(dsk) - if dsk and dsk in shard_data - else None - ) - if d_fp8 is not None: - down_weights.append(d_fp8) - down_scales.append(d_scale) - - if gate_up_weights: - gu_fp8, gu_scale = _pack_experts_blockwise_fp8( - gate_up_weights, - gate_up_scales, - block_size, - tp_degree=self.config.neuron_config.tp_degree, - layout="gate_up", + for e in range(n_routed_experts): + gk = f"{prefix}.mlp.experts.{e}.gate_proj.weight" + uk = f"{prefix}.mlp.experts.{e}.up_proj.weight" + dk = f"{prefix}.mlp.experts.{e}.down_proj.weight" + gsk = expert_scale_keys.get(gk) + usk = expert_scale_keys.get(uk) + dsk = expert_scale_keys.get(dk) + + g_fp8 = ( + shard_data.pop(gk) if gk in shard_data else None + ) + u_fp8 = ( + shard_data.pop(uk) if uk in shard_data else None + ) + g_scale = ( + shard_data.pop(gsk) + if gsk and gsk in shard_data + else None + ) + u_scale = ( + shard_data.pop(usk) + if usk and usk in shard_data + else None + ) + + if g_fp8 is not None and u_fp8 is not None: + g_bf16 = _dequant_block_fp8_to_fp32( + g_fp8, g_scale, block_size + ).to(torch.bfloat16) + u_bf16 = _dequant_block_fp8_to_fp32( + u_fp8, u_scale, block_size + ).to(torch.bfloat16) + gate_up_bf16[e, :, :isize] = g_bf16.T + gate_up_bf16[e, :, isize:] = u_bf16.T + del ( + g_fp8, + u_fp8, + g_scale, + 
u_scale, + g_bf16, + u_bf16, + ) + + d_fp8 = ( + shard_data.pop(dk) if dk in shard_data else None + ) + d_scale = ( + shard_data.pop(dsk) + if dsk and dsk in shard_data + else None + ) + if d_fp8 is not None: + d_bf16 = _dequant_block_fp8_to_fp32( + d_fp8, d_scale, block_size + ).to(torch.bfloat16) + down_bf16[e] = d_bf16.T + del d_fp8, d_scale, d_bf16 + + # Re-quantize to per-channel FP8 [E, 1, W] scales + gu_fp8, gu_scale = _requantize_per_channel_fp8( + gate_up_bf16 ) shard_data[ f"{prefix}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" @@ -1395,15 +1469,10 @@ def checkpoint_loader_fn(self, mmap: bool = False): shard_data[ f"{prefix}.mlp.expert_mlps.mlp_op.gate_up_proj.scale" ] = gu_scale - del gate_up_weights, gate_up_scales - - if down_weights: - dn_fp8, dn_scale = _pack_experts_blockwise_fp8( - down_weights, - down_scales, - block_size, - tp_degree=self.config.neuron_config.tp_degree, - layout="down", + del gate_up_bf16, gu_fp8, gu_scale + + dn_fp8, dn_scale = _requantize_per_channel_fp8( + down_bf16 ) shard_data[ f"{prefix}.mlp.expert_mlps.mlp_op.down_proj.weight" @@ -1411,7 +1480,87 @@ def checkpoint_loader_fn(self, mmap: bool = False): shard_data[ f"{prefix}.mlp.expert_mlps.mlp_op.down_proj.scale" ] = dn_scale - del down_weights, down_scales + del down_bf16, dn_fp8, dn_scale + + else: + # Blockwise FP8: keep original FP8 bytes with block scales + gate_up_weights = [] + gate_up_scales = [] + down_weights = [] + down_scales = [] + + for e in range(n_routed_experts): + gk = f"{prefix}.mlp.experts.{e}.gate_proj.weight" + uk = f"{prefix}.mlp.experts.{e}.up_proj.weight" + dk = f"{prefix}.mlp.experts.{e}.down_proj.weight" + gsk = expert_scale_keys.get(gk) + usk = expert_scale_keys.get(uk) + dsk = expert_scale_keys.get(dk) + + g_fp8 = ( + shard_data.pop(gk) if gk in shard_data else None + ) + u_fp8 = ( + shard_data.pop(uk) if uk in shard_data else None + ) + g_scale = ( + shard_data.pop(gsk) + if gsk and gsk in shard_data + else None + ) + u_scale = ( + 
shard_data.pop(usk) + if usk and usk in shard_data + else None + ) + + if g_fp8 is not None and u_fp8 is not None: + gate_up_weights.append((g_fp8, u_fp8)) + gate_up_scales.append((g_scale, u_scale)) + + d_fp8 = ( + shard_data.pop(dk) if dk in shard_data else None + ) + d_scale = ( + shard_data.pop(dsk) + if dsk and dsk in shard_data + else None + ) + if d_fp8 is not None: + down_weights.append(d_fp8) + down_scales.append(d_scale) + + if gate_up_weights: + gu_fp8, gu_scale = _pack_experts_blockwise_fp8( + gate_up_weights, + gate_up_scales, + block_size, + tp_degree=self.config.neuron_config.tp_degree, + layout="gate_up", + ) + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + ] = gu_fp8 + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.gate_up_proj.scale" + ] = gu_scale + del gate_up_weights, gate_up_scales + + if down_weights: + dn_fp8, dn_scale = _pack_experts_blockwise_fp8( + down_weights, + down_scales, + block_size, + tp_degree=self.config.neuron_config.tp_degree, + layout="down", + ) + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.down_proj.weight" + ] = dn_fp8 + shard_data[ + f"{prefix}.mlp.expert_mlps.mlp_op.down_proj.scale" + ] = dn_scale + del down_weights, down_scales else: # BF16 path diff --git a/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py b/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py index 00f670af..5be16521 100644 --- a/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py +++ b/contrib/models/Kimi-K2-Instruct-0905/test/integration/test_model.py @@ -45,14 +45,14 @@ # --------------------------------------------------------------------------- MODEL_PATH = "/home/ubuntu/models/Kimi-K2-Instruct-0905" -COMPILED_MODEL_PATH = "/home/ubuntu/kimi-k2/neuron-compiled-tp64-ep1" +COMPILED_MODEL_PATH = "/home/ubuntu/kimi-k2/neuron-compiled-tp64-ep1-perchannel-1024" # Model configuration TP_DEGREE = 64 EP_DEGREE = 1 LNC = 2 BATCH_SIZE = 1 -SEQ_LEN = 512 +SEQ_LEN = 1024 
N_ACTIVE_TOKENS = 128 @@ -75,7 +75,7 @@ def build_config(): moe_tp_degree=TP_DEGREE, context_encoding_buckets=[N_ACTIVE_TOKENS, SEQ_LEN], router_config=RouterConfig(act_fn="sigmoid", dtype="float32"), - # FP8 quantization for routed experts + # FP8 quantization for routed experts (per-channel) quantized=True, quantized_checkpoints_path=MODEL_PATH, quantization_dtype="f8e4m3", @@ -88,9 +88,7 @@ def build_config(): "router", "layers.0", ], - quantization_type="blockwise_symmetric", - quantization_block_axis=[1, 2], - quantization_block_size=[128, 128], + quantization_type="expert_wise_per_channel_symmetric", ) hf_kwargs = { @@ -305,8 +303,8 @@ def test_performance_tpot(compiled_model, tokenizer): median_tpot = sorted(tpots)[len(tpots) // 2] tok_per_sec = 1000.0 / median_tpot print(f"PASS: TPOT = {median_tpot:.1f} ms ({tok_per_sec:.1f} tok/s)") - # Kimi-K2 at BS=1 TP=64 EP=1 LNC=2 with selective loading + blockwise FP8: - # ~145 ms/token (~6.9 tok/s). Note: includes CTE time amortized over 32 tokens. + # Kimi-K2 at BS=1 TP=64 EP=1 LNC=2 with per-channel FP8: + # ~76 ms/token at seq_len=512, ~41 ms at seq_len=1024. assert median_tpot < 200, f"TPOT {median_tpot:.1f}ms exceeds 200ms threshold" else: pytest.skip("Could not measure TPOT (no tokens generated)")