From ac186ab14c73c624729a87226d18d101b0f5fae8 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 11:56:57 +0800 Subject: [PATCH 01/23] [contrib] Add MiMo-V2-Flash NeuronX port (TP=64, EP=64 MoE) Xiaomi MiMo-V2-Flash in NxDI contrib format. All code lives under contrib/models/MiMo-V2-Flash/, with zero changes to the upstream src/ tree. Architecture: 48 decoder layers, 256 MoE experts (top-8), hybrid attention (full + sliding window), asymmetric Q/K/V dims (Q/K=192, V=128), partial RoPE (34%), sigmoid router, no shared experts. Structure: src/modeling_mimo_v2.py - full modeling code (1333 lines) src/conversion_script/ - FP8 -> BF16 preprocessor test/integration/test_model.py - config/state-dict/import tests perf_test/0_setup.sh - vllm-neuron install + weight fetch perf_test/bench_mimo_v2_flash.sh - vLLM serving benchmark (BS=1/32/128) perf_test/vllm-neuron-patch.patch - maps MiMo architecture to Qwen2 loader + hf_config plumbing Import pattern: tests/examples add src/ to sys.path and import the flat module name (e.g. `from modeling_mimo_v2 import ...`), matching the convention in upstream contrib/Qwen2-Audio-7B. Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2-Flash/README.md | 217 +++ .../models/MiMo-V2-Flash/perf_test/0_setup.sh | 37 + .../perf_test/bench_mimo_v2_flash.sh | 239 +++ .../perf_test/vllm-neuron-patch.patch | 129 ++ contrib/models/MiMo-V2-Flash/src/__init__.py | 0 .../preprocess_mimo_v2_fp8.py | 630 ++++++++ .../MiMo-V2-Flash/src/modeling_mimo_v2.py | 1333 +++++++++++++++++ contrib/models/MiMo-V2-Flash/test/__init__.py | 0 .../test/integration/__init__.py | 0 .../test/integration/test_model.py | 53 + .../MiMo-V2-Flash/test/unit/__init__.py | 0 11 files changed, 2638 insertions(+) create mode 100644 contrib/models/MiMo-V2-Flash/README.md create mode 100755 contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh create mode 100755 contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh create mode 100644 contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch create mode 100644 contrib/models/MiMo-V2-Flash/src/__init__.py create mode 100644 contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_fp8.py create mode 100644 contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py create mode 100644 contrib/models/MiMo-V2-Flash/test/__init__.py create mode 100644 contrib/models/MiMo-V2-Flash/test/integration/__init__.py create mode 100644 contrib/models/MiMo-V2-Flash/test/integration/test_model.py create mode 100644 contrib/models/MiMo-V2-Flash/test/unit/__init__.py diff --git a/contrib/models/MiMo-V2-Flash/README.md b/contrib/models/MiMo-V2-Flash/README.md new file mode 100644 index 00000000..43028ffa --- /dev/null +++ b/contrib/models/MiMo-V2-Flash/README.md @@ -0,0 +1,217 @@ +# Contrib Model: MiMo-V2-Flash + +NeuronX Distributed Inference implementation of [XiaomiMiMo/MiMo-V2-Flash](https://huggingface.co/XiaomiMiMo/MiMo-V2-Flash). 
## Model Information

- **HuggingFace ID:** `XiaomiMiMo/MiMo-V2-Flash`
- **Model Type:** Decoder-only MoE transformer with hybrid attention
- **Architecture:** Custom MoE with full + sliding window attention
- **License:** See the HuggingFace model card

## Architecture Details

| Parameter | Value |
|-----------|-------|
| Hidden Size | 4096 |
| Layers | 48 |
| Attention Heads | 64 Q |
| KV Heads (full attn) | 4 |
| KV Heads (sliding window) | 8 |
| Q/K Head Dim | 192 |
| V Head Dim | 128 |
| Experts | 256 (top-8 routing) |
| Expert Intermediate | 1536 |
| Vocab Size | 151,936 |
| RoPE | Partial (34% of dims), theta=5M |
| Sliding Window | 32,768 |
| Max Position | 262,144 |
| Total Params | ~143B (checkpoint ~143 GB in FP8 / ~286 GB in BF16) |

Key features:
- **Hybrid Attention**: 9 full attention layers (0, 5, 11, 17, 23, 29, 35, 41, 47) + 39 sliding window layers
- **Asymmetric Head Dims**: Q/K use 192, V uses 128 (fused_qkv not supported)
- **Attention Sink Bias**: Learnable per-head bias in sliding window layers
- **Sigmoid Router**: For MoE expert selection
- **Expert Parallelism**: Supports EP=64 for prefill with hybrid sharding (token generation falls back to EP=1 at small batch sizes)

## Prerequisites

- **Instance**: trn2.48xlarge (128 NeuronCores, logical_nc_config=2 -> 64 logical cores)
- **Weights**: BF16 format (convert from FP8 using `conversion_script/preprocess_mimo_v2_fp8.py`)

## FP8 to BF16 Conversion

The original checkpoint uses block-wise FP8 quantization that is incompatible with Neuron FP8. Convert it first:

```bash
python contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_fp8.py \
    --hf_model_path /path/to/MiMo-V2-Flash \
    --save_path /path/to/MiMo-V2-Flash-BF16
```

## Usage

```python
import sys
from pathlib import Path

# Make this contrib package's src/ importable (flat, per upstream contrib convention).
sys.path.insert(0, str(Path("contrib/models/MiMo-V2-Flash/src").resolve()))

import torch
from transformers import AutoTokenizer
from neuronx_distributed_inference.models.config import MoENeuronConfig, OnDeviceSamplingConfig
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config, HuggingFaceGenerationAdapter

from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM, MiMoV2InferenceConfig

model_path = "/path/to/MiMo-V2-Flash-BF16/"
compiled_path = "/path/to/compiled/"

neuron_config = MoENeuronConfig(
    tp_degree=64,
    moe_tp_degree=1,
    moe_ep_degree=64,
    batch_size=1,
    seq_len=512,
    max_context_length=128,
    torch_dtype=torch.bfloat16,
    logical_nc_config=2,
    sequence_parallel_enabled=True,
    fused_qkv=False,  # Required: asymmetric Q/K vs V dims
    on_device_sampling_config=OnDeviceSamplingConfig(
        do_sample=True, temperature=0.6, top_k=20, top_p=0.95
    ),
    router_config={"act_fn": "sigmoid"},
)

config = MiMoV2InferenceConfig(
    neuron_config, load_config=load_pretrained_config(model_path)
)

model = NeuronMiMoV2ForCausalLM(model_path, config)
model.compile(compiled_path)
model.load(compiled_path)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
adapter = HuggingFaceGenerationAdapter(model, tokenizer)
inputs = tokenizer("Hello, how are you?", return_tensors="pt")
output = adapter.generate(inputs.input_ids, max_new_tokens=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

## vLLM Integration

MiMo-V2-Flash can be served via [vllm-neuron](https://github.com/aws-neuron/vllm-neuron). A patch is required to add MiMo architecture support.

### Setup

```bash
# 1. Install vllm-neuron
pip install vllm-neuron

# 2.
Apply the MiMo/MiniMax patch +cd /path/to/vllm-neuron +git apply /path/to/neuronx-distributed-inference/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch +pip install -e . +``` + +### Serving + +```bash +python3 -m vllm.entrypoints.openai.api_server \ + --model /path/to/MiMo-V2-Flash-BF16 \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 32 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + "tp_degree": 64, + "logical_nc_config": 2, + "fused_qkv": false, + "flash_decoding_enabled": false, + "sequence_parallel_enabled": true, + "glu_mlp": true, + "normalize_top_k_affinities": true, + "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 32, + "ctx_batch_size": 1, + "tkg_batch_size": 32, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": true, + "enable_bucketing": true, + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' +``` + +### Key vLLM Patch Changes + +The patch (`contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch`) modifies vllm-neuron to: +- Map MiMo architecture to Qwen2 model loader (MiMo is Qwen2-based) +- Pass `hf_config` from vLLM to NxDI (required for `trust_remote_code` models) +- Replace `AutoModelForCausalLM.from_pretrained` with `snapshot_download` for model loading + +See `contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh` for full benchmark configurations with BS=1/32/128. Run `contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh` first to install vllm-neuron and fetch weights. + +## Performance + +### Standalone NxDI (trn2.48xlarge, BF16, TP=64, EP=64) + +| Batch Size | Throughput (tok/s) | +|------------|-------------------| +| 1 | 29.92 | +| 8 | 215.94 | +| 32 | 649.14 | + +### vLLM Serving (trn2.48xlarge, BF16, BS=32, TP=64/EP=64, CB) + +Input/output: 900/90 tokens (random dataset) + +| Concurrency | Throughput (tok/s) | TPOT (ms) | TTFT (ms) | +|-------------|-------------------|-----------|-----------| +| 1 | 27.98 | 33.65 | 222 | +| 16 | 224.57 | 64.95 | 570 | +| 32 | 302.61 | 90.23 | 1351 | + +> **Note:** Large MoE models like MiMo-V2-Flash require extended engine startup time (~47 min for compile+load). Set `VLLM_ENGINE_READY_TIMEOUT_S=3600` before launching the vLLM server. + +## Compatibility Matrix + +| Instance/Version | 2.22+ (PyTorch 2.9) | 2.21 and earlier | +|------------------|---------------------|------------------| +| Trn2 (trn2.48xlarge) | Tested | Not tested | +| Trn1 | Not supported (requires 64 cores) | Not supported | +| Inf2 | Not supported | Not supported | + +## Testing + +```bash +pytest contrib/models/MiMo-V2-Flash/test/integration/test_model.py -v +``` + +## Key Implementation Notes + +1. **Hybrid Attention**: `hybrid_layer_pattern` list determines full vs sliding window per layer. +2. **CONVERT_TO_MHA**: When TP > num_kv_heads (4), K/V are replicated to match Q heads (64). +3. **Attention Sink Bias**: Adds learnable sink column to attention weights in sliding window layers. +4. **EP Hybrid Sharding**: EP is used during prefill only; token generation uses EP=1 unless batch_size >= 32. +5. **FP8 Conversion**: Original uses OCP block-wise FP8, requires conversion to BF16 or Neuron-compatible FP8 format. 
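To make note 1 concrete, here is a minimal sketch of how the per-layer attention types are derived, mirroring `_parse_hybrid_pattern` in `src/modeling_mimo_v2.py`. The pattern list below is reconstructed from the layer indices listed above; the authoritative values come from the model's `config.json`:

```python
# Illustrative only: the real hybrid_layer_pattern is read from config.json.
FULL_ATTN_LAYERS = {0, 5, 11, 17, 23, 29, 35, 41, 47}
num_layers = 48

# Convention used by the port: 1 = sliding-window layer, 0 = full-attention layer.
hybrid_layer_pattern = [0 if i in FULL_ATTN_LAYERS else 1 for i in range(num_layers)]

layer_attention_types = [
    "sliding_window" if p == 1 else "full" for p in hybrid_layer_pattern
]
assert layer_attention_types.count("full") == 9             # full-attention layers
assert layer_attention_types.count("sliding_window") == 39  # SWA layers
```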
+ +## Example Checkpoints + +* [XiaomiMiMo/MiMo-V2-Flash](https://huggingface.co/XiaomiMiMo/MiMo-V2-Flash) + +## Maintainer + +Henan Wan (whn09) + +**Last Updated:** 2026-04-13 diff --git a/contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh b/contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh new file mode 100755 index 00000000..baf4c3bc --- /dev/null +++ b/contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Setup for MiMo-V2-Flash vLLM benchmarking on Trn2. +set -e + +echo "==========================================" +echo "Setup: vllm-neuron + MiMo-V2-Flash weights" +echo "==========================================" + +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +echo "" +echo "[1/2] Installing vllm-neuron with the MiMo/MiniMax patch..." + +if [ ! -d /tmp/vllm-neuron ]; then + git clone --branch feature/mimo-support https://github.com/whn09/vllm-neuron.git /tmp/vllm-neuron +fi +cd /tmp/vllm-neuron +pip install --extra-index-url=https://pip.repos.neuron.amazonaws.com -e . +pip install s5cmd + +python3 -c "import vllm_neuron; print('vllm-neuron installed:', vllm_neuron.__file__)" + +echo "" +echo "[2/2] Downloading MiMo-V2-Flash BF16 weights..." + +MIMO_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2-Flash-BF16}" +if [ -d "$MIMO_PATH" ] && [ "$(ls "$MIMO_PATH"/*.safetensors 2>/dev/null | wc -l)" -gt 0 ]; then + echo " MiMo weights already exist at $MIMO_PATH, skipping download" +else + echo " Downloading BF16 weights from your S3 bucket (edit the URI if needed)..." + mkdir -p "$MIMO_PATH" + s5cmd cp "s3://datalab/xiaomi/models/MiMo-V2-Flash-BF16/**" "$MIMO_PATH/" + echo " Download complete: $(du -sh $MIMO_PATH | cut -f1)" +fi + +echo "" +echo "Setup complete. Set MIMO_V2_FLASH_PATH=$MIMO_PATH before running the benchmark." diff --git a/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh b/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh new file mode 100755 index 00000000..4a655991 --- /dev/null +++ b/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh @@ -0,0 +1,239 @@ +#!/bin/bash +set -e + +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2-Flash-BF16}" +PORT=8000 +RESULTS_DIR="/tmp/bench_results/mimo_v2_flash" +mkdir -p "$RESULTS_DIR" + +# Common neuron config shared across all MiMo configs +COMMON_MIMO_CONFIG='"tp_degree": 64, + "logical_nc_config": 2, + "fused_qkv": false, + "flash_decoding_enabled": false, + "sequence_parallel_enabled": true, + "qkv_kernel_enabled": false, + "qkv_nki_kernel_enabled": false, + "qkv_cte_nki_kernel_fuse_rope": false, + "attn_kernel_enabled": false, + "strided_context_parallel_kernel_enabled": false, + "glu_mlp": true, + "normalize_top_k_affinities": true, + "router_config": {"act_fn": "sigmoid", "dtype": "float32"}' + +# Helper: wait for vLLM server to be ready +wait_for_server() { + echo " Waiting for vLLM server to be ready..." + for i in $(seq 1 120); do + if curl -s http://localhost:$PORT/health > /dev/null 2>&1; then + echo " Server ready! 
(${i}s)" + return 0 + fi + sleep 5 + done + echo " ERROR: Server did not start within 600s" + return 1 +} + +# Helper: run benchmark +run_bench() { + local config_name=$1 + local concurrency=$2 + local num_prompts=$3 + + echo " Benchmark: concurrency=$concurrency, prompts=$num_prompts" + vllm bench serve \ + --backend vllm \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --endpoint /v1/completions \ + --dataset-name random \ + --num-prompts "$num_prompts" \ + --random-input-len 900 \ + --random-output-len 90 \ + --random-range-ratio 0.03 \ + --max-concurrency "$concurrency" \ + 2>&1 | tee "$RESULTS_DIR/${config_name}_c${concurrency}.txt" + echo "" +} + +# Helper: stop server +stop_server() { + echo " Stopping vLLM server..." + pkill -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true + sleep 5 +} + +# Helper: quick sanity check +sanity_check() { + echo " Running sanity check..." + curl -s http://localhost:$PORT/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [{"role": "user", "content": "What is 1+1? Answer briefly."}], + "model": "'"$MODEL_PATH"'", + "max_tokens": 64, + "temperature": 0.0, + "stream": false + }' | python3 -c "import sys,json; r=json.load(sys.stdin); print(' Sanity:', r['choices'][0]['message']['content'][:100])" 2>/dev/null || echo " Sanity check: could not parse response" +} + +echo "==========================================" +echo "MiMo-V2-Flash Performance Benchmark" +echo "==========================================" +echo "Model: $MODEL_PATH" +echo "Results: $RESULTS_DIR" +echo "" + +############################################################################### +# Config 1: BS=1, TP=64/EP=1, non-CB (baseline latency) +############################################################################### +CONFIG_NAME="bs1_tp64_ep1" +echo "--- Config 1: BS=1, TP=64/EP=1, non-CB (baseline) ---" + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 1 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + '"$COMMON_MIMO_CONFIG"', + "moe_tp_degree": 64, + "moe_ep_degree": 1, + "batch_size": 1, + "ctx_batch_size": 1, + "tkg_batch_size": 1, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": false, + "enable_bucketing": false, + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +stop_server + +############################################################################### +# Config 2: BS=32, TP=1/EP=64, CB + optimizations +############################################################################### +CONFIG_NAME="bs32_tp1_ep64_opt" +echo "--- Config 2: BS=32, TP=1/EP=64, CB + optimizations ---" + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 32 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + '"$COMMON_MIMO_CONFIG"', + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 32, + "ctx_batch_size": 1, + "tkg_batch_size": 32, + "max_context_length": 1024, + "seq_len": 1024, + 
"is_continuous_batching": true, + "enable_bucketing": true, + "context_encoding_buckets": [1024], + "token_generation_buckets": [1024], + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + }, + "use_index_calc_kernel": true, + "moe_mask_padded_tokens": true, + "blockwise_matmul_config": { + "use_shard_on_intermediate_dynamic_while": true, + "skip_dma_token": true + }, + "disable_numeric_cc_token": true, + "scratchpad_page_size": 1024 + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +run_bench "$CONFIG_NAME" 16 128 +run_bench "$CONFIG_NAME" 32 128 +stop_server + +############################################################################### +# Config 3: BS=128, TP=1/EP=64, CB + optimizations +############################################################################### +CONFIG_NAME="bs128_tp1_ep64_opt" +echo "--- Config 3: BS=128, TP=1/EP=64, CB + optimizations ---" + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 128 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + '"$COMMON_MIMO_CONFIG"', + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 128, + "ctx_batch_size": 1, + "tkg_batch_size": 128, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": true, + "enable_bucketing": true, + "context_encoding_buckets": [1024], + "token_generation_buckets": [1024], + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + }, + "use_index_calc_kernel": true, + "moe_mask_padded_tokens": true, + "blockwise_matmul_config": { + "use_shard_on_intermediate_dynamic_while": true, + "skip_dma_token": true + }, + "disable_numeric_cc_token": true, + "scratchpad_page_size": 1024 + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +run_bench "$CONFIG_NAME" 16 128 +run_bench "$CONFIG_NAME" 32 128 +run_bench "$CONFIG_NAME" 128 512 +stop_server + +echo "==========================================" +echo "MiMo-V2-Flash benchmarks complete!" 
+echo "Results saved to: $RESULTS_DIR" +echo "==========================================" +ls -la "$RESULTS_DIR" diff --git a/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch b/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch new file mode 100644 index 00000000..cb8c0421 --- /dev/null +++ b/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch @@ -0,0 +1,129 @@ +diff --git a/vllm_neuron/worker/neuronx_distributed_model_loader.py b/vllm_neuron/worker/neuronx_distributed_model_loader.py +index d2099eb..e246249 100644 +--- a/vllm_neuron/worker/neuronx_distributed_model_loader.py ++++ b/vllm_neuron/worker/neuronx_distributed_model_loader.py +@@ -41,7 +41,7 @@ from neuronx_distributed_inference.models.config import ( # yapf: disable + from neuronx_distributed_inference.modules.lora_serving import LoraServingConfig + from neuronx_distributed_inference.utils.constants import MODEL_TYPES + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config +-from transformers import AutoModelForCausalLM, PretrainedConfig ++from transformers import PretrainedConfig + from vllm.config import ( + CacheConfig, + ModelConfig, +@@ -186,8 +186,14 @@ class NeuronModelBase(nn.Module): + + neuron_config = neuronx_model_cls.get_neuron_config_cls()(**neuron_config_dict) + ++ # Use pre-loaded hf_config if available (loaded by vLLM with trust_remote_code=True) ++ hf_config = kwargs.get("hf_config") ++ if hf_config is not None: ++ load_config_fn = load_pretrained_config(hf_config=hf_config) ++ else: ++ load_config_fn = load_pretrained_config(model_name_or_path) + config = kwargs.get("config") or neuronx_model_cls.get_config_cls()( +- neuron_config, load_config=load_pretrained_config(model_name_or_path) ++ neuron_config, load_config=load_config_fn + ) + + # If fused speculation is enabled, attach the draft model config. 
+@@ -254,11 +260,10 @@ class NeuronModelBase(nn.Module): + "Using pre-compiled artifacts, override_neuron_config will be ignored" + ) + +- def _save_pretrained_model(self, model_name: str): +- hf_model = AutoModelForCausalLM.from_pretrained(model_name) +- saved_path = os.path.join("local-models", model_name) +- hf_model.save_pretrained(saved_path) +- return saved_path ++ def _get_model_path(self, model_name: str): ++ """Get local path for model, using HuggingFace cache if available.""" ++ from huggingface_hub import snapshot_download ++ return snapshot_download(repo_id=model_name, trust_remote_code=True) + + def _compile_and_load_model( + self, model_path: str, neuronx_model_cls, config, compiled_path: str +@@ -565,7 +570,7 @@ class NeuronCausalLM(NeuronModelBase): + + if not success: + if not os.path.exists(model_name_or_path): +- model_name_or_path = self._save_pretrained_model(model_name_or_path) ++ model_name_or_path = self._get_model_path(model_name_or_path) + self._compile_and_load_model( + model_name_or_path, neuronx_model_cls, config, compiled_model_path + ) +@@ -611,10 +616,15 @@ class NeuronMultiModalCausalLM(NeuronCausalLM): + **text_neuron_config + ) + ++ hf_config = kwargs.get("hf_config") ++ if hf_config is not None: ++ load_config_fn = load_pretrained_config(hf_config=hf_config) ++ else: ++ load_config_fn = load_pretrained_config(model_name_or_path) + config = neuronx_model_cls.get_config_cls()( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, +- load_config=load_pretrained_config(model_name_or_path), ++ load_config=load_config_fn, + ) + + success, compiled_model_path, _ = self._load_weights_common( +@@ -623,7 +633,7 @@ class NeuronMultiModalCausalLM(NeuronCausalLM): + + if not success: + if not os.path.exists(model_name_or_path): +- model_name_or_path = self._save_pretrained_model(model_name_or_path) ++ model_name_or_path = self._get_model_path(model_name_or_path) + + self._compile_and_load_model( + model_name_or_path, neuronx_model_cls, config, compiled_model_path +@@ -758,14 +768,6 @@ class NeuronPixtralForCausalLM(NeuronMultiModalCausalLM): + + + class NeuronQwen2VLForCausalLM(NeuronMultiModalCausalLM): +- # overwrite _save_pretrained_model as Qwen2VL is not in AutoModelForCausalLM +- def _save_pretrained_model(self, model_name: str): +- from transformers import Qwen2VLForConditionalGeneration +- +- hf_model = Qwen2VLForConditionalGeneration.from_pretrained(model_name) +- saved_path = os.path.join("local-models", model_name) +- hf_model.save_pretrained(saved_path) +- return saved_path + + def execute_model(self, model_input): + """Helper to run model with defaults for missing multimodal inputs.""" +@@ -819,13 +821,7 @@ class NeuronQwen2VLForCausalLM(NeuronMultiModalCausalLM): + + + class NeuronQwen3VLForCausalLM(NeuronQwen2VLForCausalLM): +- def _save_pretrained_model(self, model_name: str): +- from transformers import Qwen3VLForConditionalGeneration +- +- hf_model = Qwen3VLForConditionalGeneration.from_pretrained(model_name) +- saved_path = os.path.join("local-models", model_name) +- hf_model.save_pretrained(saved_path) +- return saved_path ++ pass + + + class NeuronLlama4ForCausalLM(NeuronMultiModalCausalLM): +@@ -964,6 +960,10 @@ def _get_neuron_model_cls(architecture: str): + if model == "qwen3moe": + model = "qwen3_moe" + ++ # MiMo is based on Qwen2 architecture ++ if model == "mimo": ++ model = "qwen2" ++ + if model == "qwen2vl": + model = "qwen2_vl" + +@@ -1050,6 +1050,7 @@ def get_neuron_model( + neuron_config=neuron_config, + 
override_neuron_config=override_neuron_config, + speculative_config=speculative_config, ++ hf_config=model_config.hf_config, + ) + model.neuron_config = model.model.config.neuron_config + model.architecture = architecture diff --git a/contrib/models/MiMo-V2-Flash/src/__init__.py b/contrib/models/MiMo-V2-Flash/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_fp8.py b/contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_fp8.py new file mode 100644 index 00000000..e96258f0 --- /dev/null +++ b/contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_fp8.py @@ -0,0 +1,630 @@ +""" +Preprocess MiMo-V2-Flash FP8 checkpoint for Neuron inference. + +The HuggingFace FP8 checkpoint cannot be directly used for inference on Neuron. +This script preprocesses the checkpoint to make it compatible. + +Steps: +1. Rescale FP8 weights from OCP format (range ±448) to Neuron format (range ±240) +2. Convert weight_scale_inv to .scale format (reciprocal + rescaling) +3. Fuse gate/up projections for MoE experts +4. Handle K/V weight and scale replication for CONVERT_TO_MHA mode +5. Save to preprocessed checkpoint directory + +Usage: + python preprocess_mimo_v2_fp8.py \ + --hf_model_path /path/to/MiMo-V2-Flash \ + --save_path /path/to/preprocessed_mimo_v2_fp8 \ + --tp_degree 32 \ + --convert_to_mha +""" + +import argparse +import gc +import json +import os +from typing import Dict, Any, List, Optional + +import torch + +from neuronx_distributed_inference.modules.checkpoint import ( + load_state_dict, + save_state_dict_safetensors, +) + + +# FP8 range difference between OCP (HuggingFace) and Neuron (IEEE-754) +# OCP FP8 E4M3/e4m3fn: range ±448 +# Neuron FP8 E4M3 (IEEE-754): range ±240 +FP8_SCALING_FACTOR = 448.0 / 240.0 + +# Neuron FP8 E4M3 max value +NEURON_FP8_MAX = 240.0 + + +def convert_bf16_to_fp8_per_row(weight: torch.Tensor): + """ + Convert BF16 weight to FP8 with per-row (per-channel) scales for Neuron. + + This is used for weights like o_proj that are BF16 in the original checkpoint. + The Neuron framework expects per-row scaling for these layers. + + Args: + weight: BF16 weight tensor [out_features, in_features] + + Returns: + Tuple of (fp8_weight, scale) + - fp8_weight: Weight quantized to FP8 (float8_e4m3fn) + - scale: Per-row scale tensor [out_features, 1] + """ + out_features, in_features = weight.shape + + # Compute per-row max absolute values + weight_float = weight.float() + row_max_abs = weight_float.abs().max(dim=1, keepdim=True)[0] + + # Compute scales (avoid division by zero) + scales = row_max_abs / NEURON_FP8_MAX + scales = torch.clamp(scales, min=1e-10) + + # Quantize + quantized = (weight_float / scales).to(torch.float8_e4m3fn) + + return quantized, scales.to(torch.float32) + + +def convert_bf16_to_fp8_blockwise( + weight: torch.Tensor, + block_size: List[int] = [128, 128], +): + """ + Convert BF16 weight to FP8 with block-wise scales for Neuron. + + Some weights in MiMo-V2-Flash (like o_proj) are in BF16, not FP8. + This function quantizes them to FP8 with appropriate block-wise scales. 
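
    For example (illustrative shapes): a [4096, 4096] BF16 weight with
    block_size [128, 128] yields a ceil(4096/128) x ceil(4096/128) = 32 x 32
    scale grid, one FP32 scale per block.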
+ + Args: + weight: BF16 weight tensor [out_features, in_features] + block_size: Block size for quantization [128, 128] + + Returns: + Tuple of (fp8_weight, scale) + - fp8_weight: Weight quantized to FP8 (float8_e4m3fn) + - scale: Block-wise scale tensor [scale_h, scale_w] + """ + h, w = weight.shape + block_h, block_w = block_size + + # Calculate scale grid dimensions + scale_h = (h + block_h - 1) // block_h + scale_w = (w + block_w - 1) // block_w + + # Initialize output tensors + fp8_weight = torch.zeros_like(weight, dtype=torch.float8_e4m3fn) + scale = torch.zeros(scale_h, scale_w, dtype=torch.float32) + + # Process each block + for i in range(scale_h): + for j in range(scale_w): + # Block boundaries + h_start = i * block_h + h_end = min((i + 1) * block_h, h) + w_start = j * block_w + w_end = min((j + 1) * block_w, w) + + # Extract block + block = weight[h_start:h_end, w_start:w_end].float() + + # Compute scale: max_abs / FP8_MAX + max_abs = block.abs().max().item() + if max_abs == 0: + block_scale = 1.0 + else: + block_scale = max_abs / NEURON_FP8_MAX + + # Quantize block + quantized_block = (block / block_scale).to(torch.float8_e4m3fn) + + # Store results + fp8_weight[h_start:h_end, w_start:w_end] = quantized_block + scale[i, j] = block_scale + + return fp8_weight, scale + + +def rescale_fp8_to_per_row(weight: torch.Tensor, scale: torch.Tensor): + """ + Rescale FP8 weight from OCP format to Neuron format with per-row scaling. + + The original HuggingFace checkpoint uses block-wise FP8 quantization. + The Neuron framework expects per-row (per-channel) scaling. + This function converts block-wise to per-row scaling. + + Args: + weight: FP8 weight tensor (float8_e4m3fn) [out_features, in_features] + scale: Block-wise scale tensor (weight_scale_inv) [scale_h, scale_w] + + Returns: + Tuple of (rescaled_weight, neuron_scale) + - rescaled_weight: FP8 weight compatible with Neuron + - neuron_scale: Per-row scale [out_features, 1] + """ + out_features, in_features = weight.shape + scale_h, scale_w = scale.shape + + # Block size inferred from scale dimensions + block_h = (out_features + scale_h - 1) // scale_h + block_w = (in_features + scale_w - 1) // scale_w + + # First dequantize using block-wise scales + # HF convention: original = fp8_weight * weight_scale_inv + weight_float = weight.float() + dequantized = torch.zeros(out_features, in_features, dtype=torch.float32) + + for i in range(scale_h): + for j in range(scale_w): + h_start = i * block_h + h_end = min((i + 1) * block_h, out_features) + w_start = j * block_w + w_end = min((j + 1) * block_w, in_features) + + block_scale = scale[i, j].item() + dequantized[h_start:h_end, w_start:w_end] = ( + weight_float[h_start:h_end, w_start:w_end] * block_scale + ) + + # Now requantize with per-row scaling for Neuron + # Compute per-row max absolute values + row_max_abs = dequantized.abs().max(dim=1, keepdim=True)[0] + + # Compute scales (avoid division by zero) + # Need to fit in Neuron FP8 range (±240) + scales = row_max_abs / NEURON_FP8_MAX + scales = torch.clamp(scales, min=1e-10) + + # Quantize to FP8 + quantized = (dequantized / scales).to(torch.float8_e4m3fn) + + return quantized, scales.to(torch.float32) + + +def rescale_fp8_weight_blockwise(weight: torch.Tensor, scale: torch.Tensor): + """ + Rescale FP8 weight from OCP format to Neuron format, keeping block-wise scaling. + + This is kept for MoE experts which may use block-wise scaling. 
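
    Numerically (per the module constants above): OCP e4m3fn spans +/-448 while
    Neuron's IEEE-style e4m3 spans +/-240, so the weight is divided by 448/240
    (~1.867) and the returned scale is multiplied by the same factor, leaving
    the dequantized product weight * scale unchanged.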
+ + Args: + weight: FP8 weight tensor (float8_e4m3fn) + scale: Scale tensor (float32 or bfloat16), this is weight_scale_inv (1/scale) + + Returns: + Tuple of (rescaled_weight, neuron_scale) + - rescaled_weight: FP8 weight compatible with Neuron + - neuron_scale: Scale in Neuron format (direct scale, not reciprocal) + """ + # Convert weight to BF16 for rescaling + weight_bf16 = weight.bfloat16() + + # Divide by scaling factor to fit in Neuron's smaller range + rescaled_weight_bf16 = weight_bf16 / FP8_SCALING_FACTOR + + # Convert back to FP8 + rescaled_weight = rescaled_weight_bf16.to(torch.float8_e4m3fn) + + # After our rescaling: + # rescaled_weight = fp8_weight / FP8_SCALING_FACTOR + # We need: original = rescaled_weight * new_scale + # So: original = (fp8_weight / FP8_SCALING_FACTOR) * new_scale = fp8_weight * weight_scale_inv + # Therefore: new_scale = weight_scale_inv * FP8_SCALING_FACTOR + + neuron_scale = scale.float() * FP8_SCALING_FACTOR + + return rescaled_weight, neuron_scale.to(torch.float32) + + +def replicate_for_convert_to_mha( + weight: torch.Tensor, + scale: Optional[torch.Tensor], + num_kv_heads: int, + num_attention_heads: int, + head_dim: int, +): + """ + Replicate K/V weights and per-row scales for CONVERT_TO_MHA mode. + + When TP > num_kv_heads, we need to replicate K/V heads to match Q heads. + This uses repeat_interleave to create the correct GQA pattern. + + Args: + weight: FP8 K or V weight [num_kv_heads * head_dim, hidden_size] + scale: Per-row scale tensor [num_kv_heads * head_dim, 1] + num_kv_heads: Original number of KV heads + num_attention_heads: Target number of attention heads + head_dim: Dimension per head + + Returns: + Tuple of (replicated_weight, replicated_scale) + """ + if num_kv_heads >= num_attention_heads: + return weight, scale + + repeat_factor = num_attention_heads // num_kv_heads + + # Reshape weight to [num_kv_heads, head_dim, hidden_size] + weight_reshaped = weight.view(num_kv_heads, head_dim, -1) + + # Replicate using repeat_interleave (correct GQA pattern) + # This creates [h0, h0, ..., h1, h1, ...] pattern + weight_replicated = weight_reshaped.repeat_interleave(repeat_factor, dim=0) + + # Reshape back to [num_attention_heads * head_dim, hidden_size] + weight_replicated = weight_replicated.view(-1, weight_replicated.shape[-1]) + + if scale is None: + return weight_replicated, None + + # Replicate per-row scales + # Scale shape: [num_kv_heads * head_dim, 1] + # Reshape to [num_kv_heads, head_dim, 1] + scale_reshaped = scale.view(num_kv_heads, head_dim, -1) + + # Replicate scales + scale_replicated = scale_reshaped.repeat_interleave(repeat_factor, dim=0) + + # Reshape back to [num_attention_heads * head_dim, 1] + scale_replicated = scale_replicated.view(-1, scale_replicated.shape[-1]) + + return weight_replicated, scale_replicated + + +def process_mimo_v2_checkpoint( + hf_model_path: str, + save_path: str, + tp_degree: int = 32, + convert_to_mha: bool = True, +): + """ + Process MiMo-V2-Flash checkpoint for Neuron FP8 inference. 
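
    Attention and dense-MLP weights are requantized to per-row scales, while
    MoE expert weights keep block-wise scales and are fused into stacked
    gate_up_proj / down_proj tensors with one slice per expert.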
+ + Args: + hf_model_path: Path to HuggingFace MiMo-V2-Flash checkpoint + save_path: Path to save preprocessed checkpoint + tp_degree: Tensor parallelism degree + convert_to_mha: Whether to replicate K/V for CONVERT_TO_MHA mode + """ + print(f"Loading checkpoint from: {hf_model_path}", flush=True) + state_dict = load_state_dict(hf_model_path) + + # Load config + config_path = os.path.join(hf_model_path, "config.json") + with open(config_path, "r") as f: + config = json.load(f) + + # Extract model dimensions + num_layers = config["num_hidden_layers"] + hidden_size = config["hidden_size"] + num_attention_heads = config["num_attention_heads"] + num_kv_heads = config["num_key_value_heads"] # Full attention: 4 + swa_num_attention_heads = config["swa_num_attention_heads"] # Sliding window: 32 + swa_num_kv_heads = config["swa_num_key_value_heads"] # Sliding window: 8 + head_dim = config["head_dim"] # Q/K head dim: 192 + v_head_dim = config["v_head_dim"] # V head dim: 128 + swa_head_dim = config.get("swa_head_dim", head_dim) + swa_v_head_dim = config.get("swa_v_head_dim", v_head_dim) + + # Get hybrid layer pattern + hybrid_layer_pattern = config.get("hybrid_layer_pattern", [0] * num_layers) + + # MoE configuration + num_experts = config["n_routed_experts"] # 256 + moe_intermediate_size = config["moe_intermediate_size"] + moe_layer_freq = config.get("moe_layer_freq", [1] * num_layers) + + # Block size for quantization + quant_config = config.get("quantization_config", {}) + block_size = quant_config.get("weight_block_size", [128, 128]) + + print(f"\nModel configuration:", flush=True) + print(f" num_layers: {num_layers}", flush=True) + print(f" hidden_size: {hidden_size}", flush=True) + print(f" num_attention_heads: {num_attention_heads}", flush=True) + print(f" num_kv_heads (full): {num_kv_heads}", flush=True) + print(f" swa_num_kv_heads (sliding): {swa_num_kv_heads}", flush=True) + print(f" head_dim (Q/K): {head_dim}", flush=True) + print(f" v_head_dim: {v_head_dim}", flush=True) + print(f" num_experts: {num_experts}", flush=True) + print(f" moe_intermediate_size: {moe_intermediate_size}", flush=True) + print(f" block_size: {block_size}", flush=True) + print(f" tp_degree: {tp_degree}", flush=True) + print(f" convert_to_mha: {convert_to_mha}", flush=True) + + state_dict_keys = set(state_dict.keys()) + new_state_dict = {} + + # Process each layer + for layer_idx in range(num_layers): + print(f"\nProcessing layer {layer_idx}...", end="", flush=True) + + prefix = f"model.layers.{layer_idx}." 
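        # Per the config.json semantics used below: hybrid_layer_pattern[i] == 1
        # marks a sliding-window layer (0 = full attention), and
        # moe_layer_freq[i] == 1 marks an MoE FFN layer (0 = dense FFN).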
+ is_sliding_window = hybrid_layer_pattern[layer_idx] == 1 + is_moe_layer = moe_layer_freq[layer_idx] == 1 + + # Get layer-specific parameters + if is_sliding_window: + layer_num_heads = swa_num_attention_heads + layer_num_kv_heads = swa_num_kv_heads + layer_head_dim = swa_head_dim + layer_v_head_dim = swa_v_head_dim + else: + layer_num_heads = num_attention_heads + layer_num_kv_heads = num_kv_heads + layer_head_dim = head_dim + layer_v_head_dim = v_head_dim + + attn_type = "sliding_window" if is_sliding_window else "full" + print(f" ({attn_type}, kv_heads={layer_num_kv_heads})", end="", flush=True) + + # Process attention weights + for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]: + weight_key = f"{prefix}self_attn.{proj}.weight" + scale_key = f"{prefix}self_attn.{proj}.weight_scale_inv" + + if weight_key not in state_dict_keys: + continue + + weight = state_dict[weight_key] + scale = state_dict.get(scale_key) + + # Handle FP8 weights - convert to per-row scaling for Neuron + # Neuron framework expects per-row (per-channel) scaling for attention layers + if weight.dtype == torch.float8_e4m3fn and scale is not None: + weight, scale = rescale_fp8_to_per_row(weight, scale) + # Handle BF16 weights (convert to FP8 with per-row scales) + elif weight.dtype == torch.bfloat16: + weight, scale = convert_bf16_to_fp8_per_row(weight) + + # NOTE: Do NOT apply CONVERT_TO_MHA replication here. + # The Neuron framework handles K/V replication internally. + # Pre-replicating would cause double-replication. + + # Save with Neuron naming convention + new_weight_key = f"layers.{layer_idx}.self_attn.{proj}.weight" + new_state_dict[new_weight_key] = weight + + if scale is not None: + new_scale_key = f"layers.{layer_idx}.self_attn.{proj}.scale" + new_state_dict[new_scale_key] = scale + + # Process layer norms (no FP8) + for norm in ["input_layernorm", "post_attention_layernorm"]: + weight_key = f"{prefix}{norm}.weight" + if weight_key in state_dict_keys: + new_key = f"layers.{layer_idx}.{norm}.weight" + new_state_dict[new_key] = state_dict[weight_key] + + # Process MoE router + router_key = f"{prefix}mlp.gate.weight" + if router_key in state_dict_keys: + new_key = f"layers.{layer_idx}.mlp.router.linear_router.weight" + new_state_dict[new_key] = state_dict[router_key] + + # Process MoE experts + if is_moe_layer: + # Prepare fused gate_up and down projections + gate_weights = [] + gate_scales = [] + up_weights = [] + up_scales = [] + down_weights = [] + down_scales = [] + + for expert_idx in range(num_experts): + expert_prefix = f"{prefix}mlp.experts.{expert_idx}." 
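                # Each projection is transposed before stacking, so the fused
                # tensors saved under expert_mlps.mlp_op.* below end up as
                # [num_experts, in_features, out_features].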
+ + # Gate projection + gate_w_key = f"{expert_prefix}gate_proj.weight" + gate_s_key = f"{expert_prefix}gate_proj.weight_scale_inv" + + if gate_w_key in state_dict_keys: + gate_w = state_dict[gate_w_key] + gate_s = state_dict.get(gate_s_key) + + if gate_w.dtype == torch.float8_e4m3fn and gate_s is not None: + gate_w, gate_s = rescale_fp8_weight_blockwise(gate_w, gate_s) + elif gate_w.dtype == torch.bfloat16: + gate_w, gate_s = convert_bf16_to_fp8_blockwise(gate_w, block_size) + + gate_weights.append(gate_w.T) # Transpose for fusion + if gate_s is not None: + gate_scales.append(gate_s) + + # Up projection + up_w_key = f"{expert_prefix}up_proj.weight" + up_s_key = f"{expert_prefix}up_proj.weight_scale_inv" + + if up_w_key in state_dict_keys: + up_w = state_dict[up_w_key] + up_s = state_dict.get(up_s_key) + + if up_w.dtype == torch.float8_e4m3fn and up_s is not None: + up_w, up_s = rescale_fp8_weight_blockwise(up_w, up_s) + elif up_w.dtype == torch.bfloat16: + up_w, up_s = convert_bf16_to_fp8_blockwise(up_w, block_size) + + up_weights.append(up_w.T) # Transpose for fusion + if up_s is not None: + up_scales.append(up_s) + + # Down projection + down_w_key = f"{expert_prefix}down_proj.weight" + down_s_key = f"{expert_prefix}down_proj.weight_scale_inv" + + if down_w_key in state_dict_keys: + down_w = state_dict[down_w_key] + down_s = state_dict.get(down_s_key) + + if down_w.dtype == torch.float8_e4m3fn and down_s is not None: + down_w, down_s = rescale_fp8_weight_blockwise(down_w, down_s) + elif down_w.dtype == torch.bfloat16: + down_w, down_s = convert_bf16_to_fp8_blockwise(down_w, block_size) + + down_weights.append(down_w.T) # Transpose for fusion + if down_s is not None: + down_scales.append(down_s) + + # Fuse gate and up projections + if gate_weights and up_weights: + # Stack experts: [num_experts, hidden_size, intermediate_size] + gate_stacked = torch.stack(gate_weights, dim=0) + up_stacked = torch.stack(up_weights, dim=0) + + # Concatenate gate and up: [num_experts, hidden_size, 2 * intermediate_size] + gate_up_fused = torch.cat([gate_stacked, up_stacked], dim=2) + + new_key = f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + new_state_dict[new_key] = gate_up_fused + + # Fuse scales if present + if gate_scales and up_scales: + # Scales shape after transpose: [scale_h, scale_w] + # After stacking: [num_experts, scale_h, scale_w] + gate_s_stacked = torch.stack(gate_scales, dim=0) + up_s_stacked = torch.stack(up_scales, dim=0) + + # Concatenate scales along last dim + gate_up_scale = torch.cat([gate_s_stacked, up_s_stacked], dim=-1) + + new_scale_key = f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.gate_up_proj.scale" + new_state_dict[new_scale_key] = gate_up_scale + + # Down projection + if down_weights: + # Stack: [num_experts, intermediate_size, hidden_size] + down_stacked = torch.stack(down_weights, dim=0) + + new_key = f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.down_proj.weight" + new_state_dict[new_key] = down_stacked + + if down_scales: + down_s_stacked = torch.stack(down_scales, dim=0) + new_scale_key = f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.down_proj.scale" + new_state_dict[new_scale_key] = down_s_stacked + else: + # Non-MoE layer: regular MLP with gate_proj, up_proj, down_proj + for proj in ["gate_proj", "up_proj", "down_proj"]: + weight_key = f"{prefix}mlp.{proj}.weight" + scale_key = f"{prefix}mlp.{proj}.weight_scale_inv" + + if weight_key not in state_dict_keys: + continue + + weight = state_dict[weight_key] + scale = state_dict.get(scale_key) + + # Handle FP8 
weights - convert to per-row scaling for Neuron + if weight.dtype == torch.float8_e4m3fn and scale is not None: + weight, scale = rescale_fp8_to_per_row(weight, scale) + # Handle BF16 weights (convert to FP8 with per-row scales) + elif weight.dtype == torch.bfloat16: + weight, scale = convert_bf16_to_fp8_per_row(weight) + + # Save with Neuron naming convention + new_weight_key = f"layers.{layer_idx}.mlp.{proj}.weight" + new_state_dict[new_weight_key] = weight + + if scale is not None: + new_scale_key = f"layers.{layer_idx}.mlp.{proj}.scale" + new_state_dict[new_scale_key] = scale + + gc.collect() + print(" done", flush=True) + + # Process embeddings and final layer norm + print("\nProcessing embeddings and final norm...", flush=True) + + if "model.embed_tokens.weight" in state_dict_keys: + new_state_dict["embed_tokens.weight"] = state_dict["model.embed_tokens.weight"] + + if "model.norm.weight" in state_dict_keys: + new_state_dict["norm.weight"] = state_dict["model.norm.weight"] + + if "lm_head.weight" in state_dict_keys: + new_state_dict["lm_head.weight"] = state_dict["lm_head.weight"] + elif "model.embed_tokens.weight" in state_dict_keys: + # Tied embeddings + new_state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"] + + # Save preprocessed checkpoint + print(f"\nSaving preprocessed checkpoint to: {save_path}", flush=True) + os.makedirs(save_path, exist_ok=True) + + save_state_dict_safetensors(new_state_dict, save_path) + + # Copy config.json + import shutil + shutil.copy(config_path, os.path.join(save_path, "config.json")) + + # Copy tokenizer files + for tokenizer_file in ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]: + src_path = os.path.join(hf_model_path, tokenizer_file) + if os.path.exists(src_path): + shutil.copy(src_path, os.path.join(save_path, tokenizer_file)) + + print(f"\nPreprocessing complete!", flush=True) + print(f" Total parameters: {len(new_state_dict)}", flush=True) + + # Print FP8 weight count + fp8_count = sum(1 for v in new_state_dict.values() if v.dtype == torch.float8_e4m3fn) + scale_count = sum(1 for k in new_state_dict.keys() if k.endswith(".scale")) + print(f" FP8 weights: {fp8_count}", flush=True) + print(f" Scale parameters: {scale_count}", flush=True) + + +def main(): + parser = argparse.ArgumentParser( + description="Preprocess MiMo-V2-Flash FP8 checkpoint for Neuron inference" + ) + parser.add_argument( + "--hf_model_path", + type=str, + required=True, + help="Path to HuggingFace MiMo-V2-Flash checkpoint", + ) + parser.add_argument( + "--save_path", + type=str, + required=True, + help="Path to save preprocessed checkpoint", + ) + parser.add_argument( + "--tp_degree", + type=int, + default=32, + help="Tensor parallelism degree (default: 32)", + ) + parser.add_argument( + "--convert_to_mha", + action="store_true", + default=True, + help="Replicate K/V for CONVERT_TO_MHA mode (default: True)", + ) + parser.add_argument( + "--no_convert_to_mha", + action="store_false", + dest="convert_to_mha", + help="Disable K/V replication", + ) + + args = parser.parse_args() + + process_mimo_v2_checkpoint( + hf_model_path=args.hf_model_path, + save_path=args.save_path, + tp_degree=args.tp_degree, + convert_to_mha=args.convert_to_mha, + ) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py new file mode 100644 index 00000000..4ea33fb0 --- /dev/null +++ b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py @@ -0,0 +1,1333 @@ 
+# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# This implementation is based on the MiMo-V2-Flash model from Xiaomi. +# Reference: https://huggingface.co/XiaomiMiMo/MiMo-V2-Flash + +"""MiMo-V2-Flash model for NXD inference.""" + +import gc +import math +import warnings +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers.mappings import ( + gather_from_sequence_parallel_region, + gather_from_tensor_model_parallel_region_with_dim, +) +from neuronx_distributed.utils import cpu_mode + +from neuronx_distributed_inference.utils.distributed import ( + split_along_dim, + get_cp_rank, +) +from neuronx_distributed_inference.modules.attention.attention_process_groups import ( + get_context_parallel_attention_cp_group, +) + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + MoENeuronConfig, +) +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm +from neuronx_distributed_inference.modules.moe_v2 import initialize_moe_module + +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel + +from torch_neuronx.xla_impl.ops import nki_jit + +_flash_fwd_call = nki_jit()(attention_isa_kernel) + + +def get_rmsnorm_cls(): + """Get appropriate RMSNorm class based on execution environment.""" + return MiMoV2RMSNorm if cpu_mode() else CustomRMSNorm + + +class MiMoV2RMSNorm(nn.Module): + """RMSNorm implementation for CPU mode.""" + + def __init__(self, hidden_size: int, eps: float = 1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class MiMoV2RotaryEmbedding(nn.Module): + """Rotary Position Embedding for MiMo-V2-Flash. + + Supports partial rotary embedding where only a fraction of dimensions + use rotary position encoding. 
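
    For MiMo-V2-Flash (head_dim=192, partial_rotary_factor=0.34) this gives
    rope_dim = int(192 * 0.34) = 65, rounded down to 64 so it is even; the
    remaining 128 dims bypass rotation entirely.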
+ """ + + def __init__( + self, + dim: int, + max_position_embeddings: int = 262144, + base: float = 5000000.0, + partial_rotary_factor: float = 1.0, + ): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.partial_rotary_factor = partial_rotary_factor + + # Calculate the actual dimension used for rotary embedding + self.rope_dim = int(dim * partial_rotary_factor) + # Ensure rope_dim is even + self.rope_dim = self.rope_dim - (self.rope_dim % 2) + + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.rope_dim, 2, dtype=torch.float32) / self.rope_dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward( + self, x: torch.Tensor, position_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute rotary embeddings. + + Args: + x: Input tensor of shape (batch_size, seq_len, hidden_size) + position_ids: Position indices of shape (batch_size, seq_len) + + Returns: + Tuple of (cos, sin) tensors for rotary embedding + """ + inv_freq_expanded = self.inv_freq[None, :, None].float().expand( + position_ids.shape[0], -1, 1 + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotate half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + unsqueeze_dim: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply rotary position embedding to query and key tensors.""" + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MiMoV2InferenceConfig(InferenceConfig): + """Configuration class for MiMo-V2-Flash inference on Neuron.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # MoE configuration + self.num_local_experts = self.n_routed_experts + self.n_shared_experts = 0 # MiMo-V2-Flash has no shared experts + + # Set intermediate_size for MoE layers + self.intermediate_size = self.moe_intermediate_size + + # Check and pad intermediate size if needed + self.maybe_pad_intermediate() + + # Router configuration + self.neuron_config.router_config.dtype = torch.float32 + self.neuron_config.router_config.act_fn = "sigmoid" # MiMo uses sigmoid + + # Disable numeric CC token as workaround + self.neuron_config.disable_numeric_cc_token = True + + # MiMo normalizes top-k affinities + self.neuron_config.normalize_top_k_affinities = True + + # Parse hybrid layer pattern + self._parse_hybrid_pattern() + + def _parse_hybrid_pattern(self): + """Parse hybrid layer pattern to determine attention types.""" + if hasattr(self, 'hybrid_layer_pattern') and self.hybrid_layer_pattern: + self.layer_attention_types = [ + "sliding_window" if p == 1 else "full" + for p in self.hybrid_layer_pattern + ] + else: + self.layer_attention_types = ["full"] * self.num_hidden_layers 
+ + # Parse MoE layer frequency + if hasattr(self, 'moe_layer_freq') and self.moe_layer_freq: + self.layer_uses_moe = [bool(f) for f in self.moe_layer_freq] + else: + self.layer_uses_moe = [True] * self.num_hidden_layers + + def maybe_pad_intermediate(self): + """Pad intermediate size if required for efficient computation.""" + from neuronx_distributed_inference.models.config import ( + SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP, + ) + + moe_tp_degree = self.neuron_config.moe_tp_degree + I_TP = self.moe_intermediate_size // moe_tp_degree + + if getattr( + self.neuron_config.blockwise_matmul_config, + "use_shard_on_intermediate_dynamic_while", + False, + ): + if I_TP % SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP != 0: + padded_size = ( + math.ceil(I_TP / SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP) + * SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP + * moe_tp_degree + ) + self.moe_intermediate_pad_size = max( + padded_size - self.moe_intermediate_size, 0 + ) + self.moe_intermediate_size = padded_size + + def get_required_attributes(self) -> List[str]: + return [ + "attention_bias", + "head_dim", + "hidden_act", + "hidden_size", + "hybrid_layer_pattern", + "layernorm_epsilon", + "max_position_embeddings", + "moe_intermediate_size", + "moe_layer_freq", + "n_routed_experts", + "norm_topk_prob", + "num_attention_heads", + "num_experts_per_tok", + "num_hidden_layers", + "num_key_value_heads", + "partial_rotary_factor", + "rope_theta", + "scoring_func", + "sliding_window", + "swa_head_dim", + "swa_num_attention_heads", + "swa_num_key_value_heads", + "swa_rope_theta", + "swa_v_head_dim", + "tie_word_embeddings", + "v_head_dim", + "vocab_size", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[MoENeuronConfig]: + return MoENeuronConfig + + +class NeuronMiMoV2Attention(NeuronAttentionBase): + """MiMo-V2-Flash Attention implementation supporting hybrid attention patterns. + + Supports both full attention and sliding window attention with different + head dimensions for Q/K vs V. + """ + + def __init__( + self, + config: MiMoV2InferenceConfig, + layer_idx: int, + is_sliding_window: bool = False, + ): + self.layer_idx = layer_idx + self.is_sliding_window = is_sliding_window + + # Select parameters based on attention type + if is_sliding_window: + self.attn_head_dim = config.swa_head_dim + self.attn_v_head_dim = config.swa_v_head_dim + self.attn_num_heads = config.swa_num_attention_heads + self.attn_num_kv_heads = config.swa_num_key_value_heads + rope_theta = getattr(config, 'swa_rope_theta', 10000.0) + self.sliding_window_size = config.sliding_window + else: + self.attn_head_dim = config.head_dim + self.attn_v_head_dim = config.v_head_dim + self.attn_num_heads = config.num_attention_heads + self.attn_num_kv_heads = config.num_key_value_heads + rope_theta = config.rope_theta + self.sliding_window_size = None + + # Calculate partial rotary dimensions + self.partial_rotary_factor = config.partial_rotary_factor + self.rope_dim = int(self.attn_head_dim * self.partial_rotary_factor) + self.rope_dim = self.rope_dim - (self.rope_dim % 2) # Ensure even + self.nope_dim = self.attn_head_dim - self.rope_dim + + # Create rotary embedding + rotary_emb = MiMoV2RotaryEmbedding( + dim=self.attn_head_dim, + max_position_embeddings=config.max_position_embeddings, + base=rope_theta, + partial_rotary_factor=self.partial_rotary_factor, + ) + + # Initialize base attention + # NOTE: We pass v_head_dim to base class, but MiMo uses asymmetric Q/K (192) vs V (128). 
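        # (Concretely: Q/K rows are 192-wide per head while V rows are 128-wide
        # per head, so a single fused QKV projection with one shared head_dim
        # cannot express both.)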
+ # We override init_gqa_properties() to prevent the base class from creating + # incompatible projection layers (which cause crashes when CP > 1). + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=self.attn_num_heads, + num_key_value_heads=self.attn_num_kv_heads, + head_dim=self.attn_v_head_dim, # Use v_head_dim for base class + rotary_emb=rotary_emb, + rms_norm_eps=config.layernorm_epsilon, + use_qk_norm=False, + ) + + # Initialize MiMo-specific projections with correct dimensions + self._init_projections(config) + + # Scaling factor + self.scaling = self.attn_head_dim ** -0.5 + # NOTE: The config may have 'attention_value_scale' (e.g., 0.707), but the HF model + # (modeling_mimo_v2_flash.py) does NOT use this value. The HF model only uses + # head_dim ** -0.5 for attention scaling, which is already applied via self.scaling. + # We must NOT apply attention_value_scale here, as it would cause divergence from HF. + self.value_scale = 1.0 + + # Store cache KV heads for cache compatibility + # With CONVERT_TO_MHA, all layers have num_attention_heads KV heads + # Otherwise, use max of full and sliding window kv heads + tp_degree = config.neuron_config.tp_degree + if self.use_gqa_convert_to_mha: + # CONVERT_TO_MHA: cache stores num_attention_heads (same as Q heads) + self.cache_num_kv_heads = self.attn_num_heads + self.local_cache_kv_heads = self.local_num_heads + else: + # Standard GQA: cache uses max of full and sliding window kv heads + self.cache_num_kv_heads = max( + config.num_key_value_heads, + getattr(config, 'swa_num_key_value_heads', config.num_key_value_heads) + ) + self.local_cache_kv_heads = max(1, self.cache_num_kv_heads // tp_degree) + + def init_gqa_properties(self): + """Override base class to prevent creating incompatible QKV projections. + + MiMo-V2-Flash has asymmetric Q/K head_dim (192) vs V head_dim (128), + which is incompatible with the base class's GroupQueryAttention_QKV. + MiMo uses its own custom projections via _init_projections() instead. + + When CP > 1, the base class would create cte_qkv_proj/tkg_qkv_proj with + wrong head_dim=128, causing compilation crashes. This no-op prevents that. + """ + pass + + def _init_projections(self, config: MiMoV2InferenceConfig): + """Initialize projection layers with correct dimensions. + + When CONVERT_TO_MHA is needed (tp_degree > num_kv_heads), K/V projections + are sized for num_attention_heads (not original num_kv_heads). The checkpoint + weights are replicated in preshard_hook before loading. 
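
        Example: for full-attention layers with tp_degree=64 and
        num_key_value_heads=4, K/V are sized for all 64 attention heads and
        each original KV head is replicated 64 // 4 = 16 times (the
        _kv_replication_factor computed below).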
+ """ + dtype = config.neuron_config.torch_dtype + tp_degree = config.neuron_config.tp_degree + + # Check if we need GQA CONVERT_TO_MHA (when tp_degree > num_kv_heads) + self.use_gqa_convert_to_mha = tp_degree > self.attn_num_kv_heads + + # Store source heads for preshard_hook + self._src_num_kv_heads = self.attn_num_kv_heads + self._kv_replication_factor = self.attn_num_heads // self.attn_num_kv_heads if self.use_gqa_convert_to_mha else 1 + + if self.use_gqa_convert_to_mha: + # CONVERT_TO_MHA: K and V use num_attention_heads for proper TP splitting + k_num_heads = self.attn_num_heads + v_num_heads = self.attn_num_heads + else: + k_num_heads = self.attn_num_kv_heads + v_num_heads = self.attn_num_kv_heads + + # Q/K use head_dim, V uses v_head_dim + q_hidden_size = self.attn_num_heads * self.attn_head_dim + k_hidden_size = k_num_heads * self.attn_head_dim + v_hidden_size = v_num_heads * self.attn_v_head_dim + o_hidden_size = self.attn_num_heads * self.attn_v_head_dim + + if parallel_state.model_parallel_is_initialized(): + tp_group = parallel_state.get_tensor_model_parallel_group() + + # Q projection + self.q_proj = ColumnParallelLinear( + config.hidden_size, + q_hidden_size, + bias=config.attention_bias, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + + # K projection + self.k_proj = ColumnParallelLinear( + config.hidden_size, + k_hidden_size, + bias=config.attention_bias, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + + # V projection + self.v_proj = ColumnParallelLinear( + config.hidden_size, + v_hidden_size, + bias=config.attention_bias, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + + # Output projection - with sequence parallel to scatter output + self.o_proj = RowParallelLinear( + o_hidden_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + dtype=dtype, + tensor_model_parallel_group=tp_group, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=1 if self.sequence_parallel_enabled else None, + ) + + # Calculate local dimensions after TP split + self.local_num_heads = self.attn_num_heads // tp_degree + if self.use_gqa_convert_to_mha: + # With CONVERT_TO_MHA, local KV heads = local Q heads + self.local_num_kv_heads = self.local_num_heads + else: + self.local_num_kv_heads = max(1, self.attn_num_kv_heads // tp_degree) + else: + self.q_proj = nn.Linear(config.hidden_size, q_hidden_size, bias=config.attention_bias) + self.k_proj = nn.Linear(config.hidden_size, k_hidden_size, bias=config.attention_bias) + self.v_proj = nn.Linear(config.hidden_size, v_hidden_size, bias=config.attention_bias) + self.o_proj = nn.Linear(o_hidden_size, config.hidden_size, bias=False) + + self.local_num_heads = self.attn_num_heads + self.local_num_kv_heads = k_num_heads + + # Override base class attributes that were computed with wrong head_dim + # The base class init_gqa_properties() uses head_dim=v_head_dim which is wrong for Q/K + # We need to override these to ensure correct computation + self.num_heads = self.local_num_heads + self.num_key_value_heads = self.local_num_kv_heads + self.num_key_value_groups = self.local_num_heads // self.local_num_kv_heads + self.head_dim = self.attn_head_dim # Override to use actual Q/K head_dim (192) + + # Remove qkv_proj from base class if exists (we use separate q_proj, k_proj, v_proj) + if hasattr(self, 'qkv_proj'): + self.qkv_proj = None + + # Attention sink bias for attention layers (following HF implementation) + # 
This is a learnable parameter that allows attention to "sink" to an extra position + add_full_attention_sink_bias = getattr(config, 'add_full_attention_sink_bias', False) + add_swa_attention_sink_bias = getattr(config, 'add_swa_attention_sink_bias', True) + + # Determine if this layer uses sink bias based on config + self._use_sink_bias = (add_full_attention_sink_bias and not self.is_sliding_window) or \ + (add_swa_attention_sink_bias and self.is_sliding_window) + + if self._use_sink_bias: + # Shape: [num_attention_heads] - will be split across TP ranks + # The weight is loaded from checkpoint with shape [num_attention_heads] + # and will be sliced to [local_num_heads] during forward + self.attention_sink_bias = nn.Parameter( + torch.zeros(self.attn_num_heads, dtype=dtype), requires_grad=False + ) + else: + self.attention_sink_bias = None + + def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: + """Pre-shard hook to replicate K/V weights for CONVERT_TO_MHA. + + NOTE: This method is NOT currently called because NeuronMiMoV2Attention + is not a BaseGroupQueryAttention subclass. K/V weight replication is + instead done in convert_mimo_v2_hf_to_neuron_state_dict(). + + This method is kept for reference and potential future use. + """ + # This hook is not called - see note above + return False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[torch.Tensor] = None, + cos_cache: Optional[torch.Tensor] = None, + sin_cache: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """Forward pass for MiMo-V2-Flash attention with Context Parallelism support.""" + + # Context Parallelism: only active during context encoding (no past_key_value) + is_context_parallel = past_key_value is None and self.cp_degree > 1 + cp_rank = None + + if is_context_parallel: + cp_rank = get_cp_rank( + self.rank_util.get_rank(), self.tp_degree, + self.cp_degree, self.neuron_config.switch_cc, + ) + # Split attention_mask (dim=2 = Q rows) and position_ids (dim=1 = seq) + attention_mask = split_along_dim( + attention_mask, dim=2, rank=cp_rank, num_partitions=self.cp_degree + ) + # Keep full position_ids for RoPE computation on full-length K/V + local_position_ids = split_along_dim( + position_ids, dim=1, rank=cp_rank, num_partitions=self.cp_degree + ) + + # Handle sequence parallel + if self.sequence_parallel_enabled and parallel_state.model_parallel_is_initialized(): + hidden_states = gather_from_sequence_parallel_region( + hidden_states, + self.sequence_dimension, + process_group=parallel_state.get_tensor_model_parallel_group(), + ) + + # Context Parallelism without sequence parallel: split hidden_states + if is_context_parallel and not self.sequence_parallel_enabled: + hidden_states = split_along_dim( + hidden_states, dim=1, rank=cp_rank, num_partitions=self.cp_degree + ) + + bsz, q_len, _ = hidden_states.size() + + # Determine if this is token generation (past_key_value is not None) + is_token_gen = past_key_value is not None + + # Project Q, K, V + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Reshape for multi-head attention: [bsz, num_heads, seq_len, head_dim] + query_states = query_states.view(bsz, q_len, self.local_num_heads, self.attn_head_dim).transpose(1, 2) + key_states = 
key_states.view(bsz, q_len, self.local_num_kv_heads, self.attn_head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.local_num_kv_heads, self.attn_v_head_dim).transpose(1, 2) + + # Split into rope and non-rope parts + query_rope = query_states[..., :self.rope_dim] + query_nope = query_states[..., self.rope_dim:] + key_rope = key_states[..., :self.rope_dim] + key_nope = key_states[..., self.rope_dim:] + + # Compute rotary embeddings + # IMPORTANT: Always compute for this layer because different layer types + # (full vs sliding window) use different rope_theta values. + # Full attention: rope_theta = 5000000 + # Sliding window: rope_theta = 10000 + # We cannot reuse cached cos/sin from other layers! + # + # For CP with sequence_parallel: Q/K/V have full S, use full position_ids for RoPE. + # For CP without sequence_parallel: Q/K/V have S/CP, use local_position_ids for RoPE + # (local_position_ids contain the correct global positions for this CP rank). + if is_context_parallel and not self.sequence_parallel_enabled: + rope_position_ids = local_position_ids + else: + rope_position_ids = position_ids + cos_cache, sin_cache = self.rotary_emb(value_states, rope_position_ids) + + # Apply rotary position embedding to rope parts only + query_rope, key_rope = apply_rotary_pos_emb( + query_rope, key_rope, cos_cache, sin_cache, rope_position_ids + ) + + # Concatenate rope and non-rope parts + query_states = torch.cat([query_rope, query_nope], dim=-1) + key_states = torch.cat([key_rope, key_nope], dim=-1) + + # Context Parallelism: split Q and save local KV for cache + if is_context_parallel: + if self.sequence_parallel_enabled: + # Q/K/V have full S. Split Q to local portion, save local KV for cache. + # Use split_along_dim (torch.index_select) instead of Python slicing + # because XLA tracing doesn't support dynamic tensor indices in slice notation. + query_states = split_along_dim(query_states, dim=2, rank=cp_rank, num_partitions=self.cp_degree) + key_states_for_cache = split_along_dim(key_states, dim=2, rank=cp_rank, num_partitions=self.cp_degree) + value_states_for_cache = split_along_dim(value_states, dim=2, rank=cp_rank, num_partitions=self.cp_degree) + q_len = q_len // self.cp_degree + # K/V stay at full S for attention computation + else: + # Q/K/V have S/CP. Save local KV for cache, then all-gather K/V. 
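+ # Shape sketch (illustrative values): with S=1024 and cp_degree=4 each
+ # rank computes K/V of shape [B, kv_heads, 256, D]; the all-gather below
+ # restores [B, kv_heads, 1024, D] while Q stays at [B, heads, 256, D].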
+ key_states_for_cache = key_states + value_states_for_cache = value_states + key_states = gather_from_tensor_model_parallel_region_with_dim( + key_states, gather_dim=2, + process_group=get_context_parallel_attention_cp_group(), + ) + value_states = gather_from_tensor_model_parallel_region_with_dim( + value_states, gather_dim=2, + process_group=get_context_parallel_attention_cp_group(), + ) + # Q stays at S/CP + else: + # Store key/value states BEFORE GQA repeat for KV cache + key_states_for_cache = key_states + value_states_for_cache = value_states + + # WORKAROUND 1: Pad V from v_head_dim (128) to head_dim (192) for KV cache compatibility + if self.attn_v_head_dim < self.attn_head_dim: + pad_size = self.attn_head_dim - self.attn_v_head_dim + value_states_for_cache = F.pad(value_states_for_cache, (0, pad_size), value=0.0) + + # WORKAROUND 2: Pad KV heads if layer has fewer than cache expects + # Only needed when NOT using CONVERT_TO_MHA (standard GQA mode) + if not self.use_gqa_convert_to_mha and self.local_num_kv_heads < self.local_cache_kv_heads: + # Pad KV heads by repeating + repeat_factor = self.local_cache_kv_heads // self.local_num_kv_heads + key_states_for_cache = key_states_for_cache.repeat(1, repeat_factor, 1, 1) + value_states_for_cache = value_states_for_cache.repeat(1, repeat_factor, 1, 1) + + # Repeat KV heads for GQA (only needed without CONVERT_TO_MHA) + # With CONVERT_TO_MHA, K/V already have num_attention_heads + num_key_value_groups = self.local_num_heads // self.local_num_kv_heads + if num_key_value_groups > 1: + key_states = key_states.repeat_interleave(num_key_value_groups, dim=1) + value_states = value_states.repeat_interleave(num_key_value_groups, dim=1) + + if is_token_gen: + # Token generation: use decomposed attention with prior (cached) and active (current) KV + # past_key_value[0] = cached K, shape [bsz, cache_kv_heads, kv_seq_len, head_dim] + # past_key_value[1] = cached V, shape [bsz, cache_kv_heads, kv_seq_len, head_dim] (padded) + K_prior = past_key_value[0] + V_prior = past_key_value[1] + + # WORKAROUND 1: Slice KV heads if cache has more than layer needs + # Only needed when NOT using CONVERT_TO_MHA (standard GQA mode) + # With CONVERT_TO_MHA, cache and layer have same num_kv_heads + if not self.use_gqa_convert_to_mha and self.local_num_kv_heads < self.local_cache_kv_heads: + # Cache has repeated heads, just take the first local_num_kv_heads + K_prior = K_prior[:, :self.local_num_kv_heads, :, :] + V_prior = V_prior[:, :self.local_num_kv_heads, :, :] + + # WORKAROUND 2: Slice V_prior back to v_head_dim (128) from head_dim (192) + if self.attn_v_head_dim < self.attn_head_dim: + V_prior = V_prior[..., :self.attn_v_head_dim] + + # Repeat cached KV for GQA (only needed without CONVERT_TO_MHA) + # With CONVERT_TO_MHA, cached K/V already have num_attention_heads + if num_key_value_groups > 1: + K_prior = K_prior.repeat_interleave(num_key_value_groups, dim=1) + V_prior = V_prior.repeat_interleave(num_key_value_groups, dim=1) + + # Compute attention on prior (cached) KV + # K_prior shape: [bsz, num_heads, kv_seq_len, head_dim] + prior_scores = torch.matmul(query_states, K_prior.transpose(-2, -1)) * self.scaling + + # Apply attention mask to prior scores + if attention_mask is not None: + # Convert boolean mask to additive mask if needed + if attention_mask.dtype == torch.bool: + prior_scores = prior_scores.masked_fill(~attention_mask, float('-inf')) + else: + prior_scores = prior_scores + attention_mask + + # Apply sliding window mask for SWA layers + if 
self.is_sliding_window and self.sliding_window_size is not None and position_ids is not None: + kv_seq_len = prior_scores.size(-1) + current_pos = position_ids[0, 0] + pos_indices = torch.arange(kv_seq_len, device=prior_scores.device) + sliding_mask = pos_indices >= (current_pos - self.sliding_window_size + 1) + sliding_mask = sliding_mask[None, None, None, :] + prior_scores = prior_scores.masked_fill(~sliding_mask, float('-inf')) + + prior_scores = prior_scores.to(torch.float32) + + # Compute attention on active (current) KV + active_scores = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling + active_scores = active_scores.to(torch.float32) + + # Combined softmax over prior and active scores + all_scores = torch.cat([prior_scores, active_scores], dim=-1) + + # Add attention sink bias (following HF implementation) + # This must be applied to token generation as well! + use_sink = self._use_sink_bias and self.attention_sink_bias is not None + if use_sink: + tp_rank = parallel_state.get_tensor_model_parallel_rank() if parallel_state.model_parallel_is_initialized() else 0 + local_sink = self.attention_sink_bias[tp_rank * self.local_num_heads:(tp_rank + 1) * self.local_num_heads] + sink_bias = local_sink.reshape(1, -1, 1, 1).expand(bsz, -1, q_len, 1) + all_scores = torch.cat([all_scores, sink_bias], dim=-1) + + # Numerical stability: subtract max before softmax + all_scores = all_scores - all_scores.max(dim=-1, keepdim=True).values + attn_weights = F.softmax(all_scores, dim=-1, dtype=torch.float32) + + # Drop the sink column after softmax + if use_sink: + attn_weights = attn_weights[..., :-1] + + # Split attention weights back + prior_weights = attn_weights[..., :-q_len].to(V_prior.dtype) + active_weights = attn_weights[..., -q_len:].to(value_states.dtype) + + # Compute attention outputs + attn_prior = torch.matmul(prior_weights, V_prior) + attn_active = torch.matmul(active_weights, value_states) + attn_output = attn_prior + attn_active + else: + # Context encoding: standard attention + # With CP: Q is local [B, H, S/CP, D], K/V are full [B, H, S, D] + # Without CP: Q/K/V all have same seq_len + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling + + # Apply attention mask (additive mask: 0 = attend, -inf = mask out) + # The framework creates boolean masks, so we need to convert them + # With CP: attention_mask is already split to [B, 1, S/CP, S] (local Q rows, full K cols) + if attention_mask is not None: + # Convert boolean mask to additive mask if needed + if attention_mask.dtype == torch.bool: + # True = attend (0), False = mask (-inf) + additive_mask = torch.zeros_like(attn_weights) + additive_mask = additive_mask.masked_fill(~attention_mask, float('-inf')) + attn_weights = attn_weights + additive_mask + else: + # Already additive mask + attn_weights = attn_weights + attention_mask + + # Apply sliding window mask for SWA layers + if self.is_sliding_window and self.sliding_window_size is not None: + kv_seq_len = attn_weights.size(-1) + if is_context_parallel: + # With CP: Q has local seq len, K has full seq len. + # Use local_position_ids for correct global Q positions. 
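+ # In either branch the mask below keeps cols in [r - window + 1, r].
+ # Worked example: with sliding_window_size=4, query row r=10 may attend
+ # to cols 7..10 (col <= r and col >= r - 4 + 1).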
+ row_idx = local_position_ids[0].unsqueeze(1).to(attn_weights.device)
+ else:
+ row_idx = torch.arange(kv_seq_len, device=attn_weights.device).unsqueeze(1)
+ col_idx = torch.arange(kv_seq_len, device=attn_weights.device).unsqueeze(0)
+ # Causal: col <= row, and within window: col >= row - window_size + 1
+ sliding_mask = (col_idx <= row_idx) & (col_idx >= row_idx - self.sliding_window_size + 1)
+ sliding_mask = sliding_mask[None, None, :, :]
+ # Convert to additive mask
+ attn_weights = attn_weights.masked_fill(~sliding_mask, float('-inf'))
+
+ # Add attention sink bias (following HF implementation)
+ # This adds an extra "sink" column to attention weights
+ use_sink = self._use_sink_bias and self.attention_sink_bias is not None
+ if use_sink:
+ # Get local portion of sink bias for this TP rank
+ tp_rank = parallel_state.get_tensor_model_parallel_rank() if parallel_state.model_parallel_is_initialized() else 0
+ local_sink = self.attention_sink_bias[tp_rank * self.local_num_heads:(tp_rank + 1) * self.local_num_heads]
+ # Reshape and expand: [local_num_heads] -> [bsz, local_num_heads, q_len, 1]
+ sink_bias = local_sink.reshape(1, -1, 1, 1).expand(bsz, -1, q_len, 1)
+ attn_weights = torch.cat([attn_weights, sink_bias], dim=-1)
+
+ # Numerical stability: subtract max before softmax (like HF implementation)
+ attn_weights = attn_weights - attn_weights.max(dim=-1, keepdim=True).values
+
+ # Softmax
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
+
+ # Drop the sink column after softmax
+ if use_sink:
+ attn_weights = attn_weights[..., :-1]
+
+ attn_weights = attn_weights.to(value_states.dtype)
+
+ # Apply attention to values
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ # Apply value scale if specified
+ if self.value_scale != 1.0:
+ attn_output = attn_output * self.value_scale
+
+ # Reshape and project output
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.local_num_heads * self.attn_v_head_dim)
+
+ # Context Parallelism: gather output across CP ranks BEFORE o_proj.
+ # With SP enabled, o_proj scatters along seq dim. The input must have full S
+ # (not S/CP), otherwise the SP-scattered output won't match the residual.
+ # Without SP, o_proj does not scatter; the gather below still runs first
+ # (o_proj acts per-token, so gathering before or after it is equivalent)
+ # and restores full seq_len for the residual.
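+ # (e.g. cp_degree=4, S=1024: attn_output goes [B, 256, local_heads*128]
+ # -> [B, 1024, local_heads*128] in the gather below, before o_proj.)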
+ if is_context_parallel: + attn_output = gather_from_tensor_model_parallel_region_with_dim( + attn_output, gather_dim=1, + process_group=get_context_parallel_attention_cp_group(), + ) + + attn_output = self.o_proj(attn_output) + + # Prepare KV cache output - return as tuple for KV cache manager + # Return LOCAL key/value states for cache (each CP rank stores its portion) + new_key_value = (key_states_for_cache, value_states_for_cache) + + return attn_output, new_key_value, cos_cache, sin_cache + + +class MiMoV2MLP(nn.Module): + """Standard MLP for non-MoE layers in MiMo-V2-Flash.""" + + def __init__(self, config: MiMoV2InferenceConfig): + super().__init__() + self.hidden_size = config.hidden_size + # Use the dense intermediate size for non-MoE layers + self.intermediate_size = getattr(config, 'dense_intermediate_size', config.intermediate_size * 8) + + dtype = config.neuron_config.torch_dtype + + if parallel_state.model_parallel_is_initialized(): + tp_group = parallel_state.get_tensor_model_parallel_group() + + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + input_is_parallel=True, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + else: + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + self.act_fn = F.silu + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class NeuronMiMoV2DecoderLayer(nn.Module): + """MiMo-V2-Flash Decoder Layer with hybrid attention and conditional MoE.""" + + def __init__(self, config: MiMoV2InferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + # Determine attention type for this layer + is_sliding_window = config.layer_attention_types[layer_idx] == "sliding_window" + self.attention_type = "sliding_window" if is_sliding_window else "full" + + # Create attention module + self.self_attn = NeuronMiMoV2Attention( + config=config, + layer_idx=layer_idx, + is_sliding_window=is_sliding_window, + ) + + # Determine if this layer uses MoE + self.uses_moe = config.layer_uses_moe[layer_idx] + + # Create MLP/MoE module + if self.uses_moe: + self.mlp = initialize_moe_module(config=config) + else: + self.mlp = MiMoV2MLP(config) + + # Layer norms + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.layernorm_epsilon, + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.layernorm_epsilon, + ) + + # Config flags + self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + cos_cache: Optional[torch.Tensor] = None, + sin_cache: Optional[torch.Tensor] = None, + **kwargs, + ) -> 
Tuple[torch.FloatTensor, ...]: + """Forward pass for decoder layer.""" + + # Self attention with residual + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + cos_cache=cos_cache, + sin_cache=sin_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # MLP/MoE with residual + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + if self.uses_moe: + hidden_states = self.mlp(hidden_states, padding_mask)[0] + else: + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value, cos_cache, sin_cache, None) + return outputs + + +class NeuronMiMoV2Model(NeuronBaseModel): + """MiMo-V2-Flash Model for NXD inference.""" + + def setup_attr_for_model(self, config: MiMoV2InferenceConfig): + self.on_device_sampling = config.neuron_config.on_device_sampling_config is not None + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + + # Check if we need GQA CONVERT_TO_MHA mode + # When tp_degree > num_kv_heads, we replicate K/V to match num_attention_heads + min_kv_heads = min( + config.num_key_value_heads, + getattr(config, 'swa_num_key_value_heads', config.num_key_value_heads) + ) + self.use_gqa_convert_to_mha = self.tp_degree > min_kv_heads + + if self.use_gqa_convert_to_mha: + # With CONVERT_TO_MHA, KV cache stores num_attention_heads (same as Q) + self.num_key_value_heads = config.num_attention_heads + else: + # Standard GQA: use the maximum num_kv_heads for KV cache + # (handles hybrid full/sliding window attention) + self.num_key_value_heads = max( + config.num_key_value_heads, + getattr(config, 'swa_num_key_value_heads', config.num_key_value_heads) + ) + + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + # MiMo has hybrid attention (full + sliding window) + # NOTE: Do NOT set self.sliding_window here because it affects KV cache size globally. + # MiMo handles sliding window per-layer in the attention module itself. + # Setting has_mixed_attn = True enables proper mask creation without affecting cache size. + self.has_mixed_attn = True + + def init_model(self, config: MiMoV2InferenceConfig): + self.padding_idx = getattr(config, 'pad_token_id', None) + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + + self.layers = nn.ModuleList([ + NeuronMiMoV2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + + self.norm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.layernorm_epsilon, + ) + + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + ) + + +def _replicate_kv_weights_for_convert_to_mha( + tensor: torch.Tensor, + source_heads: int, + target_heads: int, + head_dim: int, +) -> torch.Tensor: + """Replicate K/V weights from source_heads to target_heads for CONVERT_TO_MHA. 
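+
+ E.g. a full-attention K weight of shape [4 * 192, hidden] expands to
+ [64 * 192, hidden]; repeat_interleave keeps the copies of each source
+ head adjacent, so a TP split by head assigns whole heads to ranks.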
+ + Args: + tensor: Weight tensor of shape [source_heads * head_dim, hidden_size] + source_heads: Number of source KV heads + target_heads: Number of target heads (num_attention_heads) + head_dim: Head dimension + + Returns: + Replicated tensor of shape [target_heads * head_dim, hidden_size] + """ + if tensor is None or source_heads >= target_heads: + return tensor + + repeats = target_heads // source_heads + + # Reshape to [source_heads, head_dim, hidden_size] + original_shape = tensor.shape + tensor = tensor.view(source_heads, head_dim, -1) + + # Repeat along head dimension + tensor = tensor.repeat_interleave(repeats, dim=0) + + # Reshape back to [num_heads * head_dim, hidden_size] + tensor = tensor.view(-1, original_shape[-1]) + + return tensor + + +def convert_mimo_v2_hf_to_neuron_state_dict( + neuron_state_dict: Dict[str, Any], + config: MiMoV2InferenceConfig, +) -> Dict[str, Any]: + """Convert HuggingFace MiMo-V2-Flash weights to Neuron format. + + This handles: + 1. Router weight renaming + 2. Expert weight concatenation and transposition + 3. FP8 dequantization if needed + 4. K/V weight replication for CONVERT_TO_MHA mode + """ + + assert config.neuron_config.glu_mlp is True, "Only GLU MLP is supported" + + # Dequantize layers if needed + _maybe_dequantize_layer(neuron_state_dict, config) + + # Add rank utility tensors + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, config.neuron_config.tp_degree, dtype=torch.int32 + ) + + # Determine if CONVERT_TO_MHA is needed + tp_degree = config.neuron_config.tp_degree + num_attention_heads = config.num_attention_heads + + # MiMo-V2-Flash has different KV heads for full and sliding window attention + full_num_kv_heads = config.num_key_value_heads # 4 + swa_num_kv_heads = config.swa_num_key_value_heads # 8 + + # Check if we need to replicate K/V weights + full_use_convert_to_mha = tp_degree > full_num_kv_heads + swa_use_convert_to_mha = tp_degree > swa_num_kv_heads + + print(f"\n[DEBUG] CONVERT_TO_MHA status:") + print(f" tp_degree: {tp_degree}") + print(f" num_attention_heads: {num_attention_heads}") + print(f" full_num_kv_heads: {full_num_kv_heads}, use_convert_to_mha: {full_use_convert_to_mha}") + print(f" swa_num_kv_heads: {swa_num_kv_heads}, use_convert_to_mha: {swa_use_convert_to_mha}") + + for layer_idx in range(config.num_hidden_layers): + # Add rank utility for attention + neuron_state_dict[f"layers.{layer_idx}.self_attn.rank_util.rank"] = torch.arange( + 0, config.neuron_config.tp_degree, dtype=torch.int32 + ) + + # Determine attention type for this layer + is_sliding_window = config.layer_attention_types[layer_idx] == "sliding_window" + + if is_sliding_window: + src_num_kv_heads = swa_num_kv_heads + use_convert_to_mha = swa_use_convert_to_mha + head_dim = config.swa_head_dim # 192 + v_head_dim = config.swa_v_head_dim # 128 + else: + src_num_kv_heads = full_num_kv_heads + use_convert_to_mha = full_use_convert_to_mha + head_dim = config.head_dim # 192 + v_head_dim = config.v_head_dim # 128 + + # Replicate K/V weights if CONVERT_TO_MHA is needed + if use_convert_to_mha: + k_proj_key = f"layers.{layer_idx}.self_attn.k_proj.weight" + v_proj_key = f"layers.{layer_idx}.self_attn.v_proj.weight" + + if k_proj_key in neuron_state_dict: + old_shape = neuron_state_dict[k_proj_key].shape + neuron_state_dict[k_proj_key] = _replicate_kv_weights_for_convert_to_mha( + neuron_state_dict[k_proj_key], + src_num_kv_heads, + num_attention_heads, + head_dim, + ) + print(f"[DEBUG] Layer {layer_idx} ({'SWA' if is_sliding_window else 'Full'}): 
Replicated K: {old_shape} -> {neuron_state_dict[k_proj_key].shape}") + + if v_proj_key in neuron_state_dict: + old_shape = neuron_state_dict[v_proj_key].shape + neuron_state_dict[v_proj_key] = _replicate_kv_weights_for_convert_to_mha( + neuron_state_dict[v_proj_key], + src_num_kv_heads, + num_attention_heads, + v_head_dim, + ) + print(f"[DEBUG] Layer {layer_idx} ({'SWA' if is_sliding_window else 'Full'}): Replicated V: {old_shape} -> {neuron_state_dict[v_proj_key].shape}") + + # Only convert MoE layers + if not config.layer_uses_moe[layer_idx]: + continue + + # Check if this layer has MoE weights + gate_key = f"layers.{layer_idx}.mlp.gate.weight" + if gate_key not in neuron_state_dict: + continue + + # Rename router weights + neuron_state_dict[f"layers.{layer_idx}.mlp.router.linear_router.weight"] = ( + neuron_state_dict[gate_key].detach().clone() + ) + del neuron_state_dict[gate_key] + + # Get dimensions from first expert + expert_0_gate = f"layers.{layer_idx}.mlp.experts.0.gate_proj.weight" + if expert_0_gate not in neuron_state_dict: + continue + + intermediate_size, hidden_size = neuron_state_dict[expert_0_gate].shape + device = neuron_state_dict[expert_0_gate].device + dtype = neuron_state_dict[expert_0_gate].dtype + + num_experts = config.n_routed_experts + + # Concatenate gate and up projections + gate_up_proj = torch.empty( + num_experts, + hidden_size, + 2 * intermediate_size, + dtype=dtype, + device=device, + ) + + for e in range(num_experts): + gate_proj_weights = neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight" + ].T.detach().clone() + up_proj_weights = neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight" + ].T.detach().clone() + + gate_up_proj[e, :, :intermediate_size] = gate_proj_weights + gate_up_proj[e, :, intermediate_size:] = up_proj_weights + + del neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"] + del neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"] + + # Pad if needed + pad_size = getattr(config, "moe_intermediate_pad_size", 0) + if pad_size > 0: + gate_up_proj = gate_up_proj.reshape(num_experts, hidden_size, 2, -1) + gate_up_proj = F.pad(gate_up_proj, (0, pad_size)) + gate_up_proj = gate_up_proj.reshape(num_experts, hidden_size, -1) + + neuron_state_dict[f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj + + # Convert down projections + down_proj = torch.empty( + num_experts, + intermediate_size, + hidden_size, + dtype=dtype, + device=device, + ) + + for e in range(num_experts): + down_proj_weights = neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" + ].T.detach().clone() + down_proj[e] = down_proj_weights + del neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight"] + + # Pad if needed + if pad_size > 0: + down_proj = F.pad(down_proj, (0, 0, 0, pad_size)) + + neuron_state_dict[f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj + + gc.collect() + + return neuron_state_dict + + +def _maybe_dequantize_layer( + neuron_state_dict: Dict[str, Any], + config: MiMoV2InferenceConfig, +): + """Dequantize FP8 layers if present.""" + scale_layers = [] + + for layer_key in list(neuron_state_dict.keys()): + if "_scale_inv" in layer_key: + scales = neuron_state_dict[layer_key] + scale_layers.append(layer_key) + + fp8_layer_name = layer_key.replace("_scale_inv", "") + if fp8_layer_name not in neuron_state_dict: + continue + + fp8_layer = neuron_state_dict[fp8_layer_name] + + # Get block size 
from config if available + if hasattr(config, 'quantization_config') and config.quantization_config: + block_size = config.quantization_config.get("weight_block_size", [128, 128]) + else: + block_size = [128, 128] + + # Expand scales and dequantize + scales_expanded = scales.repeat_interleave(block_size[0], dim=0) + scales_expanded = scales_expanded.repeat_interleave(block_size[1], dim=1) + + # Ensure shapes match + if scales_expanded.shape != fp8_layer.shape: + scales_expanded = scales_expanded[:fp8_layer.shape[0], :fp8_layer.shape[1]] + + scaled_layer = fp8_layer.to(torch.float32) * scales_expanded.to(torch.float32) + neuron_state_dict[fp8_layer_name] = scaled_layer.to(config.neuron_config.torch_dtype) + + # Remove scale layers + for scale_layer in scale_layers: + del neuron_state_dict[scale_layer] + + +class NeuronMiMoV2ForCausalLM(NeuronBaseForCausalLM): + """MiMo-V2-Flash for Causal Language Modeling on Neuron.""" + + _model_cls = NeuronMiMoV2Model + + @staticmethod + def load_hf_model(model_path: str, **kwargs): + """Load HuggingFace model. + + Note: MiMo-V2-Flash uses custom code, so we need trust_remote_code=True + """ + from transformers import AutoModelForCausalLM + return AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + **kwargs, + ) + + @classmethod + def get_config_cls(cls) -> Type[MiMoV2InferenceConfig]: + return MiMoV2InferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: Dict[str, Any], + config: MiMoV2InferenceConfig, + ) -> Dict[str, Any]: + return convert_mimo_v2_hf_to_neuron_state_dict(state_dict, config) + + def enable_context_encoding(self): + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def enable_token_generation(self): + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + super().enable_token_generation() + + def get_compiler_args(self) -> str: + """Get compiler arguments optimized for MiMo-V2-Flash.""" + if self.compile_tag == CONTEXT_ENCODING_MODEL_TAG: + optimization_level = "-O1" + elif self.compile_tag == TOKEN_GENERATION_MODEL_TAG: + optimization_level = "-O3" if self.neuron_config.moe_ep_degree > 1 else "-O1" + else: + optimization_level = "-O1" + + compiler_args = ( + f"--enable-saturate-infinity " + f"--enable-mixed-precision-accumulation " + f"--model-type transformer " + f"{optimization_level}" + ) + + # Add CC overlap optimization + compiler_args += ( + " --tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2'" + ) + + compiler_args += " --auto-cast=none" + + # Enable vector-offset DGE + compiler_args += " --internal-enable-dge-levels vector_dynamic_offsets" + compiler_args += " --internal-hlo2tensorizer-options='--verify-hlo=true'" + + if self.neuron_config.scratchpad_page_size: + compiler_args += f" --hbm-scratchpad-page-size={self.neuron_config.scratchpad_page_size}" + + return compiler_args diff --git a/contrib/models/MiMo-V2-Flash/test/__init__.py b/contrib/models/MiMo-V2-Flash/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiMo-V2-Flash/test/integration/__init__.py b/contrib/models/MiMo-V2-Flash/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiMo-V2-Flash/test/integration/test_model.py b/contrib/models/MiMo-V2-Flash/test/integration/test_model.py new file mode 100644 index 00000000..bcbc368e --- /dev/null +++ b/contrib/models/MiMo-V2-Flash/test/integration/test_model.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 
+"""Integration tests for MiMo-V2-Flash NeuronX implementation.""" + +import pytest +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) + + +def test_config_import(): + """Test that config class can be imported.""" + from modeling_mimo_v2 import MiMoV2InferenceConfig, NeuronMiMoV2ForCausalLM + assert MiMoV2InferenceConfig is not None + assert NeuronMiMoV2ForCausalLM is not None + print("PASS: Config and model classes imported successfully") + + +def test_required_attributes(): + """Test that required attributes are defined.""" + from modeling_mimo_v2 import MiMoV2InferenceConfig + # Check get_required_attributes without instantiation (requires many params) + required = MiMoV2InferenceConfig.get_required_attributes(MiMoV2InferenceConfig) + assert "hidden_size" in required + assert "n_routed_experts" in required + assert "num_experts_per_tok" in required + assert "hybrid_layer_pattern" in required + assert "v_head_dim" in required + assert "swa_head_dim" in required + print(f"PASS: {len(required)} required attributes defined") + + +def test_neuron_config_cls(): + """Test that MoENeuronConfig is returned.""" + from modeling_mimo_v2 import MiMoV2InferenceConfig + from neuronx_distributed_inference.models.config import MoENeuronConfig + assert MiMoV2InferenceConfig.get_neuron_config_cls() == MoENeuronConfig + print("PASS: MoENeuronConfig returned") + + +def test_state_dict_converter(): + """Test that state dict converter function exists.""" + from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM + assert hasattr(NeuronMiMoV2ForCausalLM, "convert_hf_to_neuron_state_dict") + print("PASS: State dict converter exists") + + +if __name__ == "__main__": + test_config_import() + test_required_attributes() + test_neuron_config_cls() + test_state_dict_converter() + print("\nAll tests passed!") diff --git a/contrib/models/MiMo-V2-Flash/test/unit/__init__.py b/contrib/models/MiMo-V2-Flash/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b From 88e247f232e99e833f1cd63b4eaf7d7ed6606639 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 16:57:59 +0800 Subject: [PATCH 02/23] Switch vllm-neuron patch to a runtime registration hook on release-0.5.0 The previous vllm-neuron-patch.patch was a 129-line copy-of-everything that targeted a fork branch (whn09/vllm-neuron feature/mimo-support) with stale history; it could not be applied cleanly to upstream vllm-project/ vllm-neuron. Replaced with a 40-line patch against release-0.5.0 that adds a `_register_contrib_models()` hook to `vllm_neuron.register()`: - If NXDI_CONTRIB_MIMO_V2_FLASH_SRC is set, import NeuronMiMoV2ForCausalLM from that directory. - Register it into NxDI's MODEL_TYPES under key "mimo_v2_flash" (matches the `mimov2flash -> mimo_v2_flash` rewrite that already exists in release-0.5.0's _get_neuron_model_cls). - Register "MiMoV2FlashForCausalLM" into vLLM's ModelRegistry so vLLM's architecture allowlist passes. This avoids modifying upstream NxDI's `utils/constants.py` (preserves the contrib zero-invasion property) and avoids modifying upstream vllm-neuron's model loader (the patch only adds a hook function). Updated accordingly: - perf_test/0_setup.sh now clones release-0.5.0 and `git apply`s the patch. - perf_test/bench_mimo_v2_flash.sh exports NXDI_CONTRIB_MIMO_V2_FLASH_SRC defaulting to this package's own src/. - README serving instructions document the new env var. 
Verified on trn2.48xlarge (NxDI 2.29, vLLM 0.16, vllm-neuron 0.5.0): NxDI 'mimo_v2_flash': True vLLM 'MiMoV2FlashForCausalLM': True ModelConfig(model=MiMo-V2-Flash-BF16) creation: OK Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2-Flash/README.md | 25 +- .../models/MiMo-V2-Flash/perf_test/0_setup.sh | 30 ++- .../perf_test/bench_mimo_v2_flash.sh | 6 + .../perf_test/vllm-neuron-patch.patch | 218 ++++++++---------- 4 files changed, 143 insertions(+), 136 deletions(-) diff --git a/contrib/models/MiMo-V2-Flash/README.md b/contrib/models/MiMo-V2-Flash/README.md index 43028ffa..45e7b205 100644 --- a/contrib/models/MiMo-V2-Flash/README.md +++ b/contrib/models/MiMo-V2-Flash/README.md @@ -106,18 +106,33 @@ MiMo-V2-Flash can be served via [vllm-neuron](https://github.com/aws-neuron/vllm ### Setup ```bash -# 1. Install vllm-neuron -pip install vllm-neuron +# 1. Clone vllm-project/vllm-neuron at release-0.5.0 +git clone --branch release-0.5.0 https://github.com/vllm-project/vllm-neuron.git /tmp/vllm-neuron -# 2. Apply the MiMo/MiniMax patch -cd /path/to/vllm-neuron +# 2. Apply the contrib registration patch +cd /tmp/vllm-neuron git apply /path/to/neuronx-distributed-inference/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch -pip install -e . + +# 3. Install +pip install --extra-index-url=https://pip.repos.neuron.amazonaws.com -e . ``` +Or let `perf_test/0_setup.sh` do steps 1-3 plus weight download. + +The patch is 40 lines and only touches `vllm_neuron/__init__.py`. It adds a +`_register_contrib_models()` hook that, when `NXDI_CONTRIB_MIMO_V2_FLASH_SRC` +is set, registers `NeuronMiMoV2ForCausalLM` into NxDI's `MODEL_TYPES` under +the key `mimo_v2_flash` **and** registers the `MiMoV2FlashForCausalLM` +architecture into vLLM's `ModelRegistry`. No upstream vLLM or NxDI source +is modified. + ### Serving ```bash +# The contrib src/ must be reachable so the plugin hook can import it. +export NXDI_CONTRIB_MIMO_V2_FLASH_SRC=/path/to/neuronx-distributed-inference/contrib/models/MiMo-V2-Flash/src +export MIMO_V2_FLASH_PATH=/path/to/MiMo-V2-Flash-BF16 + python3 -m vllm.entrypoints.openai.api_server \ --model /path/to/MiMo-V2-Flash-BF16 \ --tensor-parallel-size 64 \ diff --git a/contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh b/contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh index baf4c3bc..b0bc7940 100755 --- a/contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh +++ b/contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh @@ -1,5 +1,11 @@ #!/bin/bash # Setup for MiMo-V2-Flash vLLM benchmarking on Trn2. +# +# This clones upstream vllm-project/vllm-neuron at release-0.5.0 (which already +# has the mimov2flash -> mimo_v2_flash model_type rewrite), then applies +# vllm-neuron-patch.patch to add a runtime registration hook so the contrib +# NeuronMiMoV2ForCausalLM is plugged into both NxDI's MODEL_TYPES and vLLM's +# ModelRegistry at vllm-neuron plugin init time. set -e echo "==========================================" @@ -8,13 +14,25 @@ echo "==========================================" source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +PATCH_FILE="$(cd "$(dirname "$0")" && pwd)/vllm-neuron-patch.patch" + echo "" -echo "[1/2] Installing vllm-neuron with the MiMo/MiniMax patch..." +echo "[1/2] Installing vllm-neuron (release-0.5.0) with the contrib registration patch..." if [ ! 
-d /tmp/vllm-neuron ]; then - git clone --branch feature/mimo-support https://github.com/whn09/vllm-neuron.git /tmp/vllm-neuron + git clone --branch release-0.5.0 https://github.com/vllm-project/vllm-neuron.git /tmp/vllm-neuron fi + cd /tmp/vllm-neuron + +# Apply patch (idempotent via `git apply --check` first). +if git apply --check "$PATCH_FILE" 2>/dev/null; then + git apply "$PATCH_FILE" + echo " Applied $PATCH_FILE" +else + echo " Patch already applied or conflicts; continuing." +fi + pip install --extra-index-url=https://pip.repos.neuron.amazonaws.com -e . pip install s5cmd @@ -33,5 +51,11 @@ else echo " Download complete: $(du -sh $MIMO_PATH | cut -f1)" fi +# Figure out where this contrib package's src/ lives so the registration hook +# can add it to sys.path inside vllm-neuron. +CONTRIB_SRC="$(cd "$(dirname "$0")/.." && pwd)/src" + echo "" -echo "Setup complete. Set MIMO_V2_FLASH_PATH=$MIMO_PATH before running the benchmark." +echo "Setup complete. Before running the benchmark, export:" +echo " export MIMO_V2_FLASH_PATH=$MIMO_PATH" +echo " export NXDI_CONTRIB_MIMO_V2_FLASH_SRC=$CONTRIB_SRC" diff --git a/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh b/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh index 4a655991..1449b485 100755 --- a/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh +++ b/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh @@ -4,6 +4,12 @@ set -e source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2-Flash-BF16}" +# The NxDI contrib MiMo-V2-Flash modeling code is registered into vLLM / +# NxDI lookup tables by vllm-neuron's register() hook using this env var. +# Default to this contrib package's own src/ relative to the script. +: "${NXDI_CONTRIB_MIMO_V2_FLASH_SRC:=$(cd "$(dirname "$0")/.." 
&& pwd)/src}" +export NXDI_CONTRIB_MIMO_V2_FLASH_SRC + PORT=8000 RESULTS_DIR="/tmp/bench_results/mimo_v2_flash" mkdir -p "$RESULTS_DIR" diff --git a/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch b/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch index cb8c0421..4bdd58d8 100644 --- a/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch +++ b/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch @@ -1,129 +1,91 @@ -diff --git a/vllm_neuron/worker/neuronx_distributed_model_loader.py b/vllm_neuron/worker/neuronx_distributed_model_loader.py -index d2099eb..e246249 100644 ---- a/vllm_neuron/worker/neuronx_distributed_model_loader.py -+++ b/vllm_neuron/worker/neuronx_distributed_model_loader.py -@@ -41,7 +41,7 @@ from neuronx_distributed_inference.models.config import ( # yapf: disable - from neuronx_distributed_inference.modules.lora_serving import LoraServingConfig - from neuronx_distributed_inference.utils.constants import MODEL_TYPES - from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config --from transformers import AutoModelForCausalLM, PretrainedConfig -+from transformers import PretrainedConfig - from vllm.config import ( - CacheConfig, - ModelConfig, -@@ -186,8 +186,14 @@ class NeuronModelBase(nn.Module): - - neuron_config = neuronx_model_cls.get_neuron_config_cls()(**neuron_config_dict) - -+ # Use pre-loaded hf_config if available (loaded by vLLM with trust_remote_code=True) -+ hf_config = kwargs.get("hf_config") -+ if hf_config is not None: -+ load_config_fn = load_pretrained_config(hf_config=hf_config) -+ else: -+ load_config_fn = load_pretrained_config(model_name_or_path) - config = kwargs.get("config") or neuronx_model_cls.get_config_cls()( -- neuron_config, load_config=load_pretrained_config(model_name_or_path) -+ neuron_config, load_config=load_config_fn - ) - - # If fused speculation is enabled, attach the draft model config. 
-@@ -254,11 +260,10 @@ class NeuronModelBase(nn.Module): - "Using pre-compiled artifacts, override_neuron_config will be ignored" - ) - -- def _save_pretrained_model(self, model_name: str): -- hf_model = AutoModelForCausalLM.from_pretrained(model_name) -- saved_path = os.path.join("local-models", model_name) -- hf_model.save_pretrained(saved_path) -- return saved_path -+ def _get_model_path(self, model_name: str): -+ """Get local path for model, using HuggingFace cache if available.""" -+ from huggingface_hub import snapshot_download -+ return snapshot_download(repo_id=model_name, trust_remote_code=True) - - def _compile_and_load_model( - self, model_path: str, neuronx_model_cls, config, compiled_path: str -@@ -565,7 +570,7 @@ class NeuronCausalLM(NeuronModelBase): - - if not success: - if not os.path.exists(model_name_or_path): -- model_name_or_path = self._save_pretrained_model(model_name_or_path) -+ model_name_or_path = self._get_model_path(model_name_or_path) - self._compile_and_load_model( - model_name_or_path, neuronx_model_cls, config, compiled_model_path - ) -@@ -611,10 +616,15 @@ class NeuronMultiModalCausalLM(NeuronCausalLM): - **text_neuron_config - ) - -+ hf_config = kwargs.get("hf_config") -+ if hf_config is not None: -+ load_config_fn = load_pretrained_config(hf_config=hf_config) -+ else: -+ load_config_fn = load_pretrained_config(model_name_or_path) - config = neuronx_model_cls.get_config_cls()( - text_neuron_config=text_neuron_config, - vision_neuron_config=vision_neuron_config, -- load_config=load_pretrained_config(model_name_or_path), -+ load_config=load_config_fn, - ) - - success, compiled_model_path, _ = self._load_weights_common( -@@ -623,7 +633,7 @@ class NeuronMultiModalCausalLM(NeuronCausalLM): - - if not success: - if not os.path.exists(model_name_or_path): -- model_name_or_path = self._save_pretrained_model(model_name_or_path) -+ model_name_or_path = self._get_model_path(model_name_or_path) - - self._compile_and_load_model( - model_name_or_path, neuronx_model_cls, config, compiled_model_path -@@ -758,14 +768,6 @@ class NeuronPixtralForCausalLM(NeuronMultiModalCausalLM): - - - class NeuronQwen2VLForCausalLM(NeuronMultiModalCausalLM): -- # overwrite _save_pretrained_model as Qwen2VL is not in AutoModelForCausalLM -- def _save_pretrained_model(self, model_name: str): -- from transformers import Qwen2VLForConditionalGeneration -- -- hf_model = Qwen2VLForConditionalGeneration.from_pretrained(model_name) -- saved_path = os.path.join("local-models", model_name) -- hf_model.save_pretrained(saved_path) -- return saved_path - - def execute_model(self, model_input): - """Helper to run model with defaults for missing multimodal inputs.""" -@@ -819,13 +821,7 @@ class NeuronQwen2VLForCausalLM(NeuronMultiModalCausalLM): - - - class NeuronQwen3VLForCausalLM(NeuronQwen2VLForCausalLM): -- def _save_pretrained_model(self, model_name: str): -- from transformers import Qwen3VLForConditionalGeneration -- -- hf_model = Qwen3VLForConditionalGeneration.from_pretrained(model_name) -- saved_path = os.path.join("local-models", model_name) -- hf_model.save_pretrained(saved_path) -- return saved_path -+ pass - - - class NeuronLlama4ForCausalLM(NeuronMultiModalCausalLM): -@@ -964,6 +960,10 @@ def _get_neuron_model_cls(architecture: str): - if model == "qwen3moe": - model = "qwen3_moe" - -+ # MiMo is based on Qwen2 architecture -+ if model == "mimo": -+ model = "qwen2" +diff --git a/vllm_neuron/__init__.py b/vllm_neuron/__init__.py +index ce4c6d2..f4505ae 100644 +--- a/vllm_neuron/__init__.py 
++++ b/vllm_neuron/__init__.py +@@ -2,6 +2,8 @@ + """VllmNeuronPlugin module.""" + + import glob ++import os ++import sys + import warnings + from vllm_neuron.utils import set_unique_rt_root_comm_id + +@@ -12,6 +14,69 @@ def _is_neuron_dev() -> bool: + return len(neuron_devices) > 0 + + ++def _register_contrib_models(): ++ """Register NxDI contrib models that aren't in upstream NxDI's MODEL_TYPES. ++ ++ Activated by environment variables: ++ NXDI_CONTRIB_MIMO_V2_FLASH_SRC -> path to contrib/models/MiMo-V2-Flash/src + - if model == "qwen2vl": - model = "qwen2_vl" - -@@ -1050,6 +1050,7 @@ def get_neuron_model( - neuron_config=neuron_config, - override_neuron_config=override_neuron_config, - speculative_config=speculative_config, -+ hf_config=model_config.hf_config, - ) - model.neuron_config = model.model.config.neuron_config - model.architecture = architecture ++ Each model is both: ++ - Registered to vLLM's ModelRegistry so vLLM's architecture allowlist passes ++ - Registered to NxDI's MODEL_TYPES so vllm-neuron's _get_neuron_model_cls finds it ++ ++ This lets zero-invasion NxDI contrib models be served by vllm-neuron without ++ modifying either NxDI or vllm-neuron's upstream source. ++ """ ++ mimo_src = os.environ.get("NXDI_CONTRIB_MIMO_V2_FLASH_SRC") ++ if mimo_src and os.path.isdir(mimo_src): ++ if mimo_src not in sys.path: ++ sys.path.insert(0, mimo_src) ++ # Register into NxDI first (no circular import concerns, NxDI doesn't ++ # touch vllm). Use a try/except so a broken contrib dir doesn't block ++ # vllm-neuron startup. ++ try: ++ from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM ++ from neuronx_distributed_inference.utils.constants import MODEL_TYPES ++ MODEL_TYPES.setdefault( ++ "mimo_v2_flash", {"causal-lm": NeuronMiMoV2ForCausalLM} ++ ) ++ except Exception as e: ++ warnings.warn( ++ f"Failed to register MiMo-V2-Flash in NxDI MODEL_TYPES: {e}", ++ category=UserWarning, ++ ) ++ return ++ # Register into vLLM's architecture allowlist. Import vLLM lazily ++ # because vllm_neuron.register() is itself called during vLLM's plugin ++ # init — importing `from vllm import ModelRegistry` at that stage ++ # triggers a partially-initialized circular import. Defer to a ++ # runtime hook via vllm.model_executor.models.registry module-level ++ # singleton, which is safe to access after plugin init completes. ++ # The actual vLLM architecture check happens later in ModelConfig ++ # creation, so registering here (even slightly later than NxDI) is ++ # still in time. ++ # Try to register immediately; if vLLM's module graph is still mid-init ++ # (circular import), silently skip — vLLM will re-trigger plugin ++ # registration later and the retry will succeed. We deliberately don't ++ # warn on circular-import failures because they're transient and ++ # expected during plugin init. ++ try: ++ from vllm.model_executor.models.registry import ModelRegistry as _MR ++ if "MiMoV2FlashForCausalLM" not in _MR.get_supported_archs(): ++ _MR.register_model( ++ "MiMoV2FlashForCausalLM", NeuronMiMoV2ForCausalLM ++ ) ++ except ImportError: ++ # vLLM still initializing — will be retried on the next plugin ++ # activation pass. 
++ pass ++ except Exception as e: ++ warnings.warn( ++ f"Failed to register MiMoV2FlashForCausalLM in vLLM ModelRegistry: {e}", ++ category=UserWarning, ++ ) ++ ++ + def register(): + """Register the Neuron platform if Neuron devices are present, else return None.""" + if not _is_neuron_dev(): +@@ -20,6 +85,7 @@ def register(): + category=UserWarning, + ) + return None ++ _register_contrib_models() + return "vllm_neuron.platform.NeuronPlatform" + + From d40f579ccd3142499167dbedc32fe546dca13e08 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 17:12:53 +0800 Subject: [PATCH 03/23] vllm-neuron patch: move registration into _get_neuron_model_cls The previous version hooked into vllm_neuron.register(), which only runs in the parent APIServer process. vLLM V1 spawns EngineCore workers via multiprocessing.spawn and those child processes start a fresh Python interpreter; vLLM's plugin discovery does run there, but the module-level state (in particular the NxDI MODEL_TYPES dict) is a fresh copy so the parent's registration does not carry over. Move _register_contrib_models() into the loader itself and call it at the top of _get_neuron_model_cls(). Every process that tries to look up an architecture now gets a fresh idempotent registration attempt driven by NXDI_CONTRIB_MIMO_V2_FLASH_SRC / NXDI_CONTRIB_MINIMAX_M2_SRC. Also correct the MODEL_TYPES key: release-0.5.0's loader does not have the mimov2flash->mimo_v2_flash rewrite, so we must register under "mimov2flash" (matches architecture.lower()) and "minimaxm2". Verified on trn2.48xlarge: _get_neuron_model_cls("MiMoV2FlashForCausalLM") -> NeuronMiMoV2ForCausalLM Co-Authored-By: Claude Opus 4.7 --- .../perf_test/vllm-neuron-patch.patch | 128 ++++++++---------- 1 file changed, 57 insertions(+), 71 deletions(-) diff --git a/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch b/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch index 4bdd58d8..b7482315 100644 --- a/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch +++ b/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch @@ -1,91 +1,77 @@ -diff --git a/vllm_neuron/__init__.py b/vllm_neuron/__init__.py -index ce4c6d2..f4505ae 100644 ---- a/vllm_neuron/__init__.py -+++ b/vllm_neuron/__init__.py -@@ -2,6 +2,8 @@ - """VllmNeuronPlugin module.""" +diff --git a/vllm_neuron/worker/neuronx_distributed_model_loader.py b/vllm_neuron/worker/neuronx_distributed_model_loader.py +index d2099eb..ed6971f 100644 +--- a/vllm_neuron/worker/neuronx_distributed_model_loader.py ++++ b/vllm_neuron/worker/neuronx_distributed_model_loader.py +@@ -922,6 +922,64 @@ def _camel_to_kebab(name: str) -> str: + return re.sub("([a-z0-9])([A-Z])", r"\1-\2", s1).lower() - import glob -+import os -+import sys - import warnings - from vllm_neuron.utils import set_unique_rt_root_comm_id -@@ -12,6 +14,69 @@ def _is_neuron_dev() -> bool: - return len(neuron_devices) > 0 - - -+def _register_contrib_models(): -+ """Register NxDI contrib models that aren't in upstream NxDI's MODEL_TYPES. + -+ Activated by environment variables: -+ NXDI_CONTRIB_MIMO_V2_FLASH_SRC -> path to contrib/models/MiMo-V2-Flash/src ++def _register_contrib_models(): ++ """Lazy-register NxDI contrib models on each process that calls the loader. 
+ -+ Each model is both: -+ - Registered to vLLM's ModelRegistry so vLLM's architecture allowlist passes -+ - Registered to NxDI's MODEL_TYPES so vllm-neuron's _get_neuron_model_cls finds it ++ Driven by env vars: ++ NXDI_CONTRIB_MIMO_V2_FLASH_SRC -> path to contrib MiMo-V2-Flash src/ ++ NXDI_CONTRIB_MINIMAX_M2_SRC -> path to contrib MiniMax-M2 src/ + -+ This lets zero-invasion NxDI contrib models be served by vllm-neuron without -+ modifying either NxDI or vllm-neuron's upstream source. ++ Registers the contrib model class into NxDI's MODEL_TYPES and, where ++ vLLM does not already know the architecture, registers it into vLLM's ++ ModelRegistry. Runs every time _get_neuron_model_cls is called so that ++ vLLM's spawn'd EngineCore workers (which don't inherit the parent's ++ module-level state) pick up the registration too. Registration is ++ idempotent. + """ -+ mimo_src = os.environ.get("NXDI_CONTRIB_MIMO_V2_FLASH_SRC") -+ if mimo_src and os.path.isdir(mimo_src): -+ if mimo_src not in sys.path: -+ sys.path.insert(0, mimo_src) -+ # Register into NxDI first (no circular import concerns, NxDI doesn't -+ # touch vllm). Use a try/except so a broken contrib dir doesn't block -+ # vllm-neuron startup. ++ import os as _os ++ import sys as _sys ++ import warnings as _w ++ ++ mimo_src = _os.environ.get("NXDI_CONTRIB_MIMO_V2_FLASH_SRC") ++ if mimo_src and _os.path.isdir(mimo_src) and "mimov2flash" not in MODEL_TYPES: ++ if mimo_src not in _sys.path: ++ _sys.path.insert(0, mimo_src) + try: + from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM -+ from neuronx_distributed_inference.utils.constants import MODEL_TYPES + MODEL_TYPES.setdefault( -+ "mimo_v2_flash", {"causal-lm": NeuronMiMoV2ForCausalLM} ++ "mimov2flash", {"causal-lm": NeuronMiMoV2ForCausalLM} + ) ++ try: ++ from vllm.model_executor.models.registry import ModelRegistry ++ if "MiMoV2FlashForCausalLM" not in ModelRegistry.get_supported_archs(): ++ ModelRegistry.register_model( ++ "MiMoV2FlashForCausalLM", NeuronMiMoV2ForCausalLM ++ ) ++ except ImportError: ++ pass + except Exception as e: -+ warnings.warn( -+ f"Failed to register MiMo-V2-Flash in NxDI MODEL_TYPES: {e}", ++ _w.warn( ++ f"Failed to register MiMo-V2-Flash contrib model: {e}", + category=UserWarning, + ) -+ return -+ # Register into vLLM's architecture allowlist. Import vLLM lazily -+ # because vllm_neuron.register() is itself called during vLLM's plugin -+ # init — importing `from vllm import ModelRegistry` at that stage -+ # triggers a partially-initialized circular import. Defer to a -+ # runtime hook via vllm.model_executor.models.registry module-level -+ # singleton, which is safe to access after plugin init completes. -+ # The actual vLLM architecture check happens later in ModelConfig -+ # creation, so registering here (even slightly later than NxDI) is -+ # still in time. -+ # Try to register immediately; if vLLM's module graph is still mid-init -+ # (circular import), silently skip — vLLM will re-trigger plugin -+ # registration later and the retry will succeed. We deliberately don't -+ # warn on circular-import failures because they're transient and -+ # expected during plugin init. 
++ ++ minimax_src = _os.environ.get("NXDI_CONTRIB_MINIMAX_M2_SRC") ++ if minimax_src and _os.path.isdir(minimax_src) and "minimaxm2" not in MODEL_TYPES: ++ if minimax_src not in _sys.path: ++ _sys.path.insert(0, minimax_src) + try: -+ from vllm.model_executor.models.registry import ModelRegistry as _MR -+ if "MiMoV2FlashForCausalLM" not in _MR.get_supported_archs(): -+ _MR.register_model( -+ "MiMoV2FlashForCausalLM", NeuronMiMoV2ForCausalLM -+ ) -+ except ImportError: -+ # vLLM still initializing — will be retried on the next plugin -+ # activation pass. -+ pass ++ from modeling_minimax_m2 import NeuronMiniMaxM2ForCausalLM ++ MODEL_TYPES.setdefault( ++ "minimaxm2", {"causal-lm": NeuronMiniMaxM2ForCausalLM} ++ ) + except Exception as e: -+ warnings.warn( -+ f"Failed to register MiMoV2FlashForCausalLM in vLLM ModelRegistry: {e}", ++ _w.warn( ++ f"Failed to register MiniMax-M2 contrib model: {e}", + category=UserWarning, + ) + + - def register(): - """Register the Neuron platform if Neuron devices are present, else return None.""" - if not _is_neuron_dev(): -@@ -20,6 +85,7 @@ def register(): - category=UserWarning, - ) - return None + def _get_neuron_model_cls(architecture: str): + """ + Get Neuron model class from architecture string. +@@ -941,6 +999,7 @@ def _get_neuron_model_cls(architecture: str): + _get_neuron_model_cls("NeuronLlamaForCausalLM") + + """ + _register_contrib_models() - return "vllm_neuron.platform.NeuronPlatform" - - + # Handle Neuron class name (starts with "Neuron") - strip prefix + if architecture.startswith("Neuron") and "For" in architecture: + original_architecture = architecture From 83fb9e94a0aa2c8e319f2433cf0acd40cde2d63d Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 17:21:09 +0800 Subject: [PATCH 04/23] vllm-neuron patch: also monkey-patch AutoConfig for trust_remote_code NxDI's hf_adapter.load_config calls AutoConfig.from_pretrained(path) without trust_remote_code=True. Contrib models like MiMo-V2-Flash ship a configuration_*.py in the checkpoint that requires custom code execution, so this trips: ValueError: The repository ... contains custom code which must be executed to correctly load the model. vLLM's top-level --trust-remote-code only affects vLLM's own config load, not NxDI's re-load via hf_adapter. Add a _patch_autoconfig_trust_remote_code() helper that wraps AutoConfig.from_pretrained to default trust_remote_code=True. Called from _register_contrib_models() alongside the MODEL_TYPES registration so every process that reaches _get_neuron_model_cls installs the patch (idempotent via a _nxdi_contrib_patched sentinel on the class). Verified on trn2.48xlarge: AutoConfig.from_pretrained('/opt/dlami/nvme/models/MiMo-V2-Flash-BF16') now succeeds instead of asking for user input. 
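
An illustrative serve invocation with the registration env var from the
previous patch (model path as verified above; all other serving flags are
deployment-specific and elided here):

    NXDI_CONTRIB_MIMO_V2_FLASH_SRC=$PWD/contrib/models/MiMo-V2-Flash/src \
    python3 -m vllm.entrypoints.openai.api_server \
        --model /opt/dlami/nvme/models/MiMo-V2-Flash-BF16 \
        --trust-remote-code
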
Co-Authored-By: Claude Opus 4.7 --- .../perf_test/vllm-neuron-patch.patch | 36 +++++++++++++++++-- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch b/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch index b7482315..4a84a558 100644 --- a/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch +++ b/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch @@ -1,11 +1,39 @@ diff --git a/vllm_neuron/worker/neuronx_distributed_model_loader.py b/vllm_neuron/worker/neuronx_distributed_model_loader.py -index d2099eb..ed6971f 100644 +index d2099eb..0c162e4 100644 --- a/vllm_neuron/worker/neuronx_distributed_model_loader.py +++ b/vllm_neuron/worker/neuronx_distributed_model_loader.py -@@ -922,6 +922,64 @@ def _camel_to_kebab(name: str) -> str: +@@ -922,6 +922,94 @@ def _camel_to_kebab(name: str) -> str: return re.sub("([a-z0-9])([A-Z])", r"\1-\2", s1).lower() ++ ++def _patch_autoconfig_trust_remote_code(): ++ """Monkey-patch ``AutoConfig.from_pretrained`` to default ``trust_remote_code=True``. ++ ++ NxDI's ``hf_adapter.load_config`` calls ``AutoConfig.from_pretrained(path)`` ++ without ``trust_remote_code``. Contrib models like MiMo-V2-Flash that ++ ship a ``configuration_*.py`` with the checkpoint require custom code ++ execution, so the default behaviour crashes with ``ValueError: The ++ repository ... contains custom code which must be executed``. ++ ++ vLLM's top-level ``--trust-remote-code`` flag only affects vLLM's own ++ config load, not NxDI's. Patching here is cheap and idempotent. ++ """ ++ try: ++ from transformers import AutoConfig ++ except ImportError: ++ return ++ if getattr(AutoConfig, "_nxdi_contrib_patched", False): ++ return ++ _orig = AutoConfig.from_pretrained ++ ++ def _patched(*args, **kwargs): ++ kwargs.setdefault("trust_remote_code", True) ++ return _orig(*args, **kwargs) ++ ++ AutoConfig.from_pretrained = _patched ++ AutoConfig._nxdi_contrib_patched = True ++ + +def _register_contrib_models(): + """Lazy-register NxDI contrib models on each process that calls the loader. @@ -25,6 +53,8 @@ index d2099eb..ed6971f 100644 + import sys as _sys + import warnings as _w + ++ _patch_autoconfig_trust_remote_code() ++ + mimo_src = _os.environ.get("NXDI_CONTRIB_MIMO_V2_FLASH_SRC") + if mimo_src and _os.path.isdir(mimo_src) and "mimov2flash" not in MODEL_TYPES: + if mimo_src not in _sys.path: @@ -67,7 +97,7 @@ index d2099eb..ed6971f 100644 def _get_neuron_model_cls(architecture: str): """ Get Neuron model class from architecture string. -@@ -941,6 +999,7 @@ def _get_neuron_model_cls(architecture: str): +@@ -941,6 +1029,7 @@ def _get_neuron_model_cls(architecture: str): _get_neuron_model_cls("NeuronLlamaForCausalLM") """ From 4b858166f13c64f12126d86d19d54c975017adc1 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 17:45:15 +0800 Subject: [PATCH 05/23] bench: set use_torch_block_wise=true to avoid missing NKI kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Neuron SDK 2.29 ships with neuronx_distributed 0.17 whose moe/blockwise.py expects blockwise_mm_baseline_shard_hidden at either neuronxcc.nki._private.blockwise_mm or _pre_prod_kernels.blockwise_mm. Both import paths resolve in the installed SDK but neither exports the baseline_shard_hidden variant; the MoE forward reaches _call_shard_hidden_kernel and raises NotImplementedError. 
Setting blockwise_matmul_config.use_torch_block_wise=true makes the blockwise matmul go through the PyTorch reference implementation, bypassing the missing NKI kernel. It is slower than the NKI path but unblocks end-to-end vLLM benchmarking on the current stack. Remove when the NKI kernel is promoted back to a public path. Applied to the COMMON_MIMO_CONFIG block and merged into the Config 2/3 blockwise_matmul_config overrides (JSON does not recursively merge nested dicts — the per-config override wins, so use_torch_block_wise must be listed there too). Co-Authored-By: Claude Opus 4.7 --- .../MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh b/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh index 1449b485..df68effd 100755 --- a/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh +++ b/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh @@ -27,7 +27,14 @@ COMMON_MIMO_CONFIG='"tp_degree": 64, "strided_context_parallel_kernel_enabled": false, "glu_mlp": true, "normalize_top_k_affinities": true, - "router_config": {"act_fn": "sigmoid", "dtype": "float32"}' + "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, + "blockwise_matmul_config": {"use_torch_block_wise": true}' +# NOTE: use_torch_block_wise=true forces MoE blockwise to use the PyTorch +# reference implementation. The NKI kernel path pulls +# neuronxcc.nki._private.blockwise_mm.blockwise_mm_baseline_shard_hidden +# which is missing from Neuron SDK 2.29's public path; skipping it here is +# the practical way to get the bench running on current stacks. Remove +# this once the NKI kernel is promoted back to the stable path. # Helper: wait for vLLM server to be ready wait_for_server() { @@ -170,6 +177,7 @@ python3 -m vllm.entrypoints.openai.api_server \ "use_index_calc_kernel": true, "moe_mask_padded_tokens": true, "blockwise_matmul_config": { + "use_torch_block_wise": true, "use_shard_on_intermediate_dynamic_while": true, "skip_dma_token": true }, @@ -222,6 +230,7 @@ python3 -m vllm.entrypoints.openai.api_server \ "use_index_calc_kernel": true, "moe_mask_padded_tokens": true, "blockwise_matmul_config": { + "use_torch_block_wise": true, "use_shard_on_intermediate_dynamic_while": true, "skip_dma_token": true }, From d95207cc0d9e545d679846ae096bb97c38a76c0d Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 18:07:29 +0800 Subject: [PATCH 06/23] bench: extend wait_for_server timeout to 2h for MoE first-compile First-time compilation of a 256-expert MoE model on trn2.48xlarge takes 30-90 minutes (~3 configs x 3 buckets x 64 TP ranks of neuron-cc work). The previous 600s timeout aborts the benchmark driver while the background compile is still running. Bump to 7200s (2h) and emit a progress blip every minute so the user knows it's alive rather than hung. Co-Authored-By: Claude Opus 4.7 --- .../perf_test/bench_mimo_v2_flash.sh | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh b/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh index df68effd..a4fcda90 100755 --- a/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh +++ b/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh @@ -36,17 +36,25 @@ COMMON_MIMO_CONFIG='"tp_degree": 64, # the practical way to get the bench running on current stacks. 
Remove # this once the NKI kernel is promoted back to the stable path. -# Helper: wait for vLLM server to be ready +# Helper: wait for vLLM server to be ready. First-time compilation of a +# 256-expert MoE model takes 30-90 minutes, so we poll for up to 2 hours. wait_for_server() { - echo " Waiting for vLLM server to be ready..." - for i in $(seq 1 120); do + echo " Waiting for vLLM server to be ready (up to 2h for first compile)..." + local interval=10 + local max_attempts=720 # 720 * 10s = 7200s = 2h + local start=$SECONDS + for i in $(seq 1 $max_attempts); do if curl -s http://localhost:$PORT/health > /dev/null 2>&1; then - echo " Server ready! (${i}s)" + echo " Server ready! (waited $((SECONDS - start))s)" return 0 fi - sleep 5 + # Show a progress blip every minute so the user knows we're alive + if [ $((i % 6)) -eq 0 ]; then + echo " ...still waiting ($((SECONDS - start))s elapsed)" + fi + sleep $interval done - echo " ERROR: Server did not start within 600s" + echo " ERROR: Server did not start within $((max_attempts * interval))s" return 1 } From f669f1c892e73f682f02123147bfb5417d5fedce Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 22 Apr 2026 19:08:32 +0800 Subject: [PATCH 07/23] perf_test: add sanity_check.sh and run_bench_single.sh helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two small companion scripts to bench_mimo_v2_flash.sh, intended for use after the monolithic bench has already brought a vLLM server up (or when the bench driver timed out during first-compile and you want to salvage the still-running server). sanity_check.sh - POSTs a one-shot chat completion against localhost:$PORT. - Prints the JSON response and a one-line summary of the model's reply. - Health-checks /health first and fails fast if the server isn't up. run_bench_single.sh - Runs one 'vllm bench serve' pass with configurable CONCURRENCY / NUM_PROMPTS / INPUT_LEN / OUTPUT_LEN. - Does NOT launch or kill the server — you bring your own. - Writes the transcript to $RESULTS_DIR/${CONFIG_NAME}_c${CONCURRENCY}.txt, matching bench_mimo_v2_flash.sh's output layout. Typical usage after a long first-compile: # terminal 1: start the server via the main bench (it'll fail wait_for_server # but the server process stays up and keeps compiling in the background) bash bench_mimo_v2_flash.sh # terminal 2: once the server prints "Application startup complete.": bash sanity_check.sh bash run_bench_single.sh CONCURRENCY=16 NUM_PROMPTS=128 bash run_bench_single.sh Co-Authored-By: Claude Opus 4.7 --- .../perf_test/run_bench_single.sh | 76 +++++++++++++++++++ .../MiMo-V2-Flash/perf_test/sanity_check.sh | 59 ++++++++++++++ 2 files changed, 135 insertions(+) create mode 100755 contrib/models/MiMo-V2-Flash/perf_test/run_bench_single.sh create mode 100755 contrib/models/MiMo-V2-Flash/perf_test/sanity_check.sh diff --git a/contrib/models/MiMo-V2-Flash/perf_test/run_bench_single.sh b/contrib/models/MiMo-V2-Flash/perf_test/run_bench_single.sh new file mode 100755 index 00000000..17b62e6f --- /dev/null +++ b/contrib/models/MiMo-V2-Flash/perf_test/run_bench_single.sh @@ -0,0 +1,76 @@ +#!/bin/bash +# Run a single vllm-bench-serve pass against an already-running vLLM server. +# +# Unlike bench_mimo_v2_flash.sh this script does NOT launch or kill the vLLM +# server — you bring your own. That makes it convenient when the bench driver +# in bench_mimo_v2_flash.sh times out during first-time compilation: the server +# keeps running, and once it's ready you can collect numbers with this. 
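+#
+# Pair it with sanity_check.sh: run that first to confirm the server answers
+# a single chat completion, then collect numbers here.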
+# +# Usage: +# bash run_bench_single.sh # defaults: c=1, 16 prompts +# CONCURRENCY=16 NUM_PROMPTS=128 bash run_bench_single.sh +# CONFIG_NAME=bs32_tp1_ep64_opt CONCURRENCY=16 NUM_PROMPTS=128 bash run_bench_single.sh +# +# Environment knobs: +# PORT vLLM server port (default 8000) +# MIMO_V2_FLASH_PATH Path to the BF16 checkpoint (default +# /opt/dlami/nvme/models/MiMo-V2-Flash-BF16) +# CONCURRENCY --max-concurrency (default 1) +# NUM_PROMPTS --num-prompts (default 16) +# INPUT_LEN --random-input-len (default 900) +# OUTPUT_LEN --random-output-len (default 90) +# RANGE_RATIO --random-range-ratio (default 0.03) +# CONFIG_NAME Used in the output filename (default bs1_tp64_ep1) +# RESULTS_DIR Where to dump per-run log (default /tmp/bench_results/mimo_v2_flash) + +set -e + +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2-Flash-BF16}" +PORT="${PORT:-8000}" +CONCURRENCY="${CONCURRENCY:-1}" +NUM_PROMPTS="${NUM_PROMPTS:-16}" +INPUT_LEN="${INPUT_LEN:-900}" +OUTPUT_LEN="${OUTPUT_LEN:-90}" +RANGE_RATIO="${RANGE_RATIO:-0.03}" +CONFIG_NAME="${CONFIG_NAME:-bs1_tp64_ep1}" +RESULTS_DIR="${RESULTS_DIR:-/tmp/bench_results/mimo_v2_flash}" + +mkdir -p "$RESULTS_DIR" + +echo "==========================================" +echo "MiMo-V2-Flash single-run benchmark" +echo "==========================================" +echo " Model: $MODEL_PATH" +echo " Port: $PORT" +echo " Config: $CONFIG_NAME" +echo " Concurrency: $CONCURRENCY" +echo " Prompts: $NUM_PROMPTS" +echo " Input len: $INPUT_LEN Output len: $OUTPUT_LEN" +echo " Results: $RESULTS_DIR/${CONFIG_NAME}_c${CONCURRENCY}.txt" +echo "" + +# Quick health check +if ! curl -sf "http://localhost:$PORT/health" > /dev/null; then + echo "ERROR: vLLM server is not responding on http://localhost:$PORT" + echo "Start it first (e.g., bench_mimo_v2_flash.sh) and wait until" + echo "'Application startup complete.' is printed." + exit 1 +fi + +vllm bench serve \ + --backend vllm \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --endpoint /v1/completions \ + --dataset-name random \ + --num-prompts "$NUM_PROMPTS" \ + --random-input-len "$INPUT_LEN" \ + --random-output-len "$OUTPUT_LEN" \ + --random-range-ratio "$RANGE_RATIO" \ + --max-concurrency "$CONCURRENCY" \ + 2>&1 | tee "$RESULTS_DIR/${CONFIG_NAME}_c${CONCURRENCY}.txt" + +echo "" +echo "Saved to: $RESULTS_DIR/${CONFIG_NAME}_c${CONCURRENCY}.txt" diff --git a/contrib/models/MiMo-V2-Flash/perf_test/sanity_check.sh b/contrib/models/MiMo-V2-Flash/perf_test/sanity_check.sh new file mode 100755 index 00000000..b4da1895 --- /dev/null +++ b/contrib/models/MiMo-V2-Flash/perf_test/sanity_check.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# Quick sanity check against an already-running vLLM server. +# +# Assumes vLLM is already listening on $PORT (default 8000) with MiMo-V2-Flash +# loaded. Sends a single chat completion and prints the model's reply. +# +# Usage: +# bash sanity_check.sh # uses defaults +# PORT=8001 bash sanity_check.sh # custom port +# PROMPT="..." bash sanity_check.sh # custom prompt + +set -e + +MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2-Flash-BF16}" +PORT="${PORT:-8000}" +PROMPT="${PROMPT:-What is 1+1? Answer briefly.}" +MAX_TOKENS="${MAX_TOKENS:-64}" + +echo "Sanity check: POST /v1/chat/completions on port $PORT" +echo " Model: $MODEL_PATH" +echo " Prompt: $PROMPT" +echo " Max tokens: $MAX_TOKENS" +echo "" + +# Health check first — fail fast if server isn't up. +if ! 
curl -sf "http://localhost:$PORT/health" > /dev/null; then + echo "ERROR: vLLM server is not responding on http://localhost:$PORT" + echo "Start it with 'bash bench_mimo_v2_flash.sh' or your own launcher first." + exit 1 +fi + +RESPONSE=$(curl -s "http://localhost:$PORT/v1/chat/completions" \ + -H 'Content-Type: application/json' \ + -d "$(cat </dev/null || echo "$RESPONSE" +echo "" + +# Extract the model's reply for a human-friendly one-liner summary. +REPLY=$(echo "$RESPONSE" | python3 -c " +import json, sys +try: + r = json.load(sys.stdin) + print(r['choices'][0]['message']['content'].strip()) +except Exception as e: + print(f'(could not parse reply: {e})') +" 2>/dev/null) + +echo "Model reply: $REPLY" From 42420cb9de12eab0abb2c999313914506640407c Mon Sep 17 00:00:00 2001 From: whn09 Date: Thu, 23 Apr 2026 19:44:19 +0800 Subject: [PATCH 08/23] Add streaming FP8 preprocess script for MiMo-V2-Flash The original preprocess_mimo_v2_fp8.py loaded the entire ~290 GB HF FP8 checkpoint into RAM via load_state_dict(), peaking well over 600 GB after dequant/requant copies. This per-layer streaming rewrite (one safe_open handle at a time) reduces peak memory to ~24 GB and runs in ~20 minutes, producing a 311 GB Neuron-FP8 checkpoint as model_layer{0..47}.safetensors plus model_extras.safetensors. Key points: - Attention q/k/v: rescale HF OCP FP8 (+/-448) to Neuron FP8 (+/-240) with per-row scales. - Attention o_proj: listed in HF quantization_config.ignored_layers; keep as BF16 and DO NOT emit .scale. The Neuron side binds o_proj to plain RowParallelLinear (not QuantizedRowParallel), so writing FP8 + .scale would be silently reinterpreted as BF16 bytes at load and produce garbage outputs. - MoE experts: keep blockwise scales, fuse gate|up into the packed [num_experts, H, 2*IM] layout expected by ExpertFusedRowParallelLinear. - Layer 0 dense MLP and attention_sink_bias handling matches the Flash config (add_swa_attention_sink_bias=True, add_full_attention_sink_bias=False). Co-Authored-By: Claude Opus 4.7 --- .../preprocess_mimo_v2_flash_fp8.py | 490 ++++++++++++++++++ 1 file changed, 490 insertions(+) create mode 100644 contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_flash_fp8.py diff --git a/contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_flash_fp8.py b/contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_flash_fp8.py new file mode 100644 index 00000000..2491aa18 --- /dev/null +++ b/contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_flash_fp8.py @@ -0,0 +1,490 @@ +""" +Preprocess MiMo-V2-Flash FP8 checkpoint for Neuron inference. + +This is a streaming (per-layer) rewrite of preprocess_mimo_v2_fp8.py. The +original preprocess loaded the entire ~290 GB FP8 checkpoint into RAM via +load_state_dict(); that peaks well over 600 GB after dequantize/requantize +copies and is fragile. This version keeps a single safe_open handle live +at a time and emits per-layer safetensors shards, capping peak RAM at +~24 GB and finishing in ~20 minutes. + +MiMo-V2-Flash checkpoint layout: + - q_proj, k_proj, v_proj are stored *separately* in the HF checkpoint + (not pre-fused). No split_qkv_fused needed. + - o_proj is BF16 (listed in quantization_config.ignored_layers); kept + as BF16 on the Neuron side (RowParallelLinear, not QuantizedRowParallel). + - Layer 0 is a dense MLP (moe_layer_freq[0] == 0) with intermediate_size + 16384; layers 1..47 are MoE with 256 experts each. 
+ - Hybrid attention: 9 "full" layers (hybrid_layer_pattern[i] == 0) and + 39 "sliding window" layers (== 1). SWA layers carry + attention_sink_bias (add_swa_attention_sink_bias=True in the config; + add_full_attention_sink_bias=False, so full layers do NOT get it). + +Neuron-side rescaling (same as Pro/original-Flash): + - OCP FP8 e4m3 (±448) -> Neuron FP8 e4m3 (±240) with FP8_SCALING_FACTOR=448/240. + - Per-row scales for attention/dense-mlp projections (q/k/v/o, gate/up/down + of the dense layer). + - Blockwise (128x128) scales kept for MoE expert weights; per-expert weights + are transposed and fused to match ExpertFusedRowParallelLinear's packed + layout (gate_up_proj: [num_experts, H, 2*IM]; down_proj: [num_experts, IM, H]). + +Output layout: + save_path/ + config.json, tokenizer.*, chat_template.jinja if present + configuration_mimo_v2_flash.py, modeling_mimo_v2_flash.py (trust_remote_code) + model.safetensors.index.json (regenerated) + model_extras.safetensors (embed_tokens, norm, lm_head) + model_layer{N}.safetensors (one per decoder layer, N=0..47) + +Usage: + python preprocess_mimo_v2_flash_fp8.py \\ + --hf_model_path /opt/dlami/nvme/models/MiMo-V2-Flash \\ + --save_path /opt/dlami/nvme/models/MiMo-V2-Flash-Neuron-FP8 \\ + --tp_degree 64 +""" + +import argparse +import gc +import json +import os +import shutil +import time +from typing import Dict, List, Optional, Tuple + +import torch +from safetensors import safe_open +from safetensors.torch import save_file + + +FP8_SCALING_FACTOR = 448.0 / 240.0 +NEURON_FP8_MAX = 240.0 + + +# --------------------------------------------------------------------------- +# Quantization primitives +# --------------------------------------------------------------------------- + +def convert_bf16_to_fp8_per_row( + weight: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """BF16 [out, in] -> Neuron FP8 per-row (scales shape [out, 1]).""" + weight_float = weight.float() + row_max_abs = weight_float.abs().max(dim=1, keepdim=True)[0] + scales = torch.clamp(row_max_abs / NEURON_FP8_MAX, min=1e-10) + quantized = (weight_float / scales).to(torch.float8_e4m3fn) + return quantized, scales.to(torch.float32) + + +def rescale_fp8_to_per_row( + weight: torch.Tensor, scale: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Block-wise FP8 + blockwise scale -> Neuron per-row FP8. + + Dequantize to float32 using block broadcast, then per-row requantize. + """ + out_features, in_features = weight.shape + scale_h, scale_w = scale.shape + + block_h = (out_features + scale_h - 1) // scale_h + block_w = (in_features + scale_w - 1) // scale_w + + weight_float = weight.float() + dequantized = torch.zeros(out_features, in_features, dtype=torch.float32) + for i in range(scale_h): + for j in range(scale_w): + h0, h1 = i * block_h, min((i + 1) * block_h, out_features) + w0, w1 = j * block_w, min((j + 1) * block_w, in_features) + dequantized[h0:h1, w0:w1] = ( + weight_float[h0:h1, w0:w1] * scale[i, j].item() + ) + + row_max_abs = dequantized.abs().max(dim=1, keepdim=True)[0] + scales = torch.clamp(row_max_abs / NEURON_FP8_MAX, min=1e-10) + quantized = (dequantized / scales).to(torch.float8_e4m3fn) + return quantized, scales.to(torch.float32) + + +def rescale_fp8_weight_blockwise( + weight: torch.Tensor, scale: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Keep blockwise scales, just rescale into Neuron FP8 range. + + MoE expert weights stay block-quantized; only the dtype range changes. 
+ """ + weight_bf16 = weight.bfloat16() + rescaled = (weight_bf16 / FP8_SCALING_FACTOR).to(torch.float8_e4m3fn) + neuron_scale = scale.float() * FP8_SCALING_FACTOR + return rescaled, neuron_scale.to(torch.float32) + + +# --------------------------------------------------------------------------- +# Streaming weight access (one open safetensors handle at a time) +# --------------------------------------------------------------------------- + +class LazyWeightMap: + """Lazily fetch tensors from sharded safetensors, keeping one handle live.""" + + def __init__(self, model_dir: str, weight_map: Dict[str, str]): + self.model_dir = model_dir + self.weight_map = weight_map + self._cur_filename: Optional[str] = None + self._cur_handle = None + + def _open(self, filename: str): + if self._cur_filename == filename: + return self._cur_handle + if self._cur_handle is not None: + self._cur_handle.__exit__(None, None, None) + self._cur_handle = None + path = os.path.join(self.model_dir, filename) + self._cur_handle = safe_open(path, framework="pt", device="cpu") + self._cur_handle.__enter__() + self._cur_filename = filename + return self._cur_handle + + def get(self, key: str) -> Optional[torch.Tensor]: + filename = self.weight_map.get(key) + if filename is None: + return None + return self._open(filename).get_tensor(key) + + def has(self, key: str) -> bool: + return key in self.weight_map + + def close(self): + if self._cur_handle is not None: + self._cur_handle.__exit__(None, None, None) + self._cur_handle = None + self._cur_filename = None + + +# --------------------------------------------------------------------------- +# Per-tensor helper +# --------------------------------------------------------------------------- + +def _maybe_fp8_to_neuron_per_row( + weight: torch.Tensor, scale: Optional[torch.Tensor] +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """FP8 blockwise -> per-row, or BF16 -> FP8 per-row. Pass-through otherwise.""" + if weight.dtype == torch.float8_e4m3fn and scale is not None: + return rescale_fp8_to_per_row(weight, scale) + if weight.dtype == torch.bfloat16: + return convert_bf16_to_fp8_per_row(weight) + return weight, scale + + +# --------------------------------------------------------------------------- +# Per-layer processing +# --------------------------------------------------------------------------- + +def process_layer( + layer_idx: int, + lazy: LazyWeightMap, + config: dict, + is_dense: bool, + is_swa: bool, +) -> Dict[str, torch.Tensor]: + out: Dict[str, torch.Tensor] = {} + prefix = f"model.layers.{layer_idx}." + out_prefix = f"layers.{layer_idx}." + + # --- Layer norms (BF16, untouched) --- + for name in ("input_layernorm", "post_attention_layernorm"): + t = lazy.get(f"{prefix}{name}.weight") + if t is not None: + out[f"{out_prefix}{name}.weight"] = t.detach().clone() + + # --- Attention: q/k/v/o are stored separately in Flash --- + # q/k/v: rescale to Neuron FP8 per-row. + for proj in ("q_proj", "k_proj", "v_proj"): + w = lazy.get(f"{prefix}self_attn.{proj}.weight") + if w is None: + continue + s = lazy.get(f"{prefix}self_attn.{proj}.weight_scale_inv") + w2, s2 = _maybe_fp8_to_neuron_per_row(w, s) + out[f"{out_prefix}self_attn.{proj}.weight"] = w2 + if s2 is not None: + out[f"{out_prefix}self_attn.{proj}.scale"] = s2 + + # o_proj is listed in HF quantization_config.ignored_layers and ships as + # BF16; on Neuron it binds to a plain RowParallelLinear (see + # modeling_mimo_v2.py: self.o_proj = RowParallelLinear(...)), NOT a + # QuantizedRowParallel. 
Writing FP8 + .scale here would silently be + # reinterpreted as BF16 bytes at load time and produce garbage outputs. + # Keep BF16, never emit .scale. + o_w = lazy.get(f"{prefix}self_attn.o_proj.weight") + o_s = lazy.get(f"{prefix}self_attn.o_proj.weight_scale_inv") + if o_w is not None: + if o_w.dtype == torch.float8_e4m3fn: + # Defensive: if a future checkpoint FP8-quantizes o_proj, dequant + # blockwise back to BF16 (no per-row requant; RowParallelLinear has + # no .scale parameter). + assert o_s is not None, "FP8 o_proj requires weight_scale_inv" + out_features, in_features = o_w.shape + scale_h, scale_w = o_s.shape + block_h = (out_features + scale_h - 1) // scale_h + block_w = (in_features + scale_w - 1) // scale_w + wf = o_w.float() + tmp = torch.zeros(out_features, in_features, dtype=torch.float32) + for i in range(scale_h): + for j in range(scale_w): + h0, h1 = i * block_h, min((i + 1) * block_h, out_features) + w0, w1 = j * block_w, min((j + 1) * block_w, in_features) + tmp[h0:h1, w0:w1] = wf[h0:h1, w0:w1] * o_s[i, j].item() + o_bf16 = tmp.to(torch.bfloat16) + else: + o_bf16 = o_w.to(torch.bfloat16) + out[f"{out_prefix}self_attn.o_proj.weight"] = o_bf16.detach().clone() + + # --- attention_sink_bias: present only on SWA layers in MiMo-V2-Flash. + # config.add_swa_attention_sink_bias=True, add_full_attention_sink_bias=False. + if is_swa and config.get("add_swa_attention_sink_bias", False): + sink = lazy.get(f"{prefix}self_attn.attention_sink_bias") + if sink is not None: + out[f"{out_prefix}self_attn.attention_sink_bias"] = sink.detach().clone() + elif not is_swa and config.get("add_full_attention_sink_bias", False): + sink = lazy.get(f"{prefix}self_attn.attention_sink_bias") + if sink is not None: + out[f"{out_prefix}self_attn.attention_sink_bias"] = sink.detach().clone() + + # --- MLP: dense vs MoE --- + if is_dense: + # Dense MLP: gate_proj, up_proj, down_proj (FP8 blockwise in Flash layer 0). + for proj in ("gate_proj", "up_proj", "down_proj"): + w = lazy.get(f"{prefix}mlp.{proj}.weight") + if w is None: + continue + s = lazy.get(f"{prefix}mlp.{proj}.weight_scale_inv") + w2, s2 = _maybe_fp8_to_neuron_per_row(w, s) + out[f"{out_prefix}mlp.{proj}.weight"] = w2 + if s2 is not None: + out[f"{out_prefix}mlp.{proj}.scale"] = s2 + return out + + # --- MoE --- + # Router: mlp.gate -> mlp.router.linear_router + router_w = lazy.get(f"{prefix}mlp.gate.weight") + if router_w is not None: + out[f"{out_prefix}mlp.router.linear_router.weight"] = router_w.detach().clone() + router_bias = lazy.get(f"{prefix}mlp.gate.e_score_correction_bias") + if router_bias is not None: + out[f"{out_prefix}mlp.router.e_score_correction_bias"] = router_bias.detach().clone() + + num_experts = config["n_routed_experts"] + + # Peek expert 0 to learn shapes/dtypes. + e0_gw = lazy.get(f"{prefix}mlp.experts.0.gate_proj.weight") + if e0_gw is None: + return out # no experts (shouldn't happen for MoE layers, but be safe) + e0_gs = lazy.get(f"{prefix}mlp.experts.0.gate_proj.weight_scale_inv") + + if e0_gw.dtype == torch.float8_e4m3fn and e0_gs is not None: + sample_w, sample_s = rescale_fp8_weight_blockwise(e0_gw, e0_gs) + elif e0_gw.dtype == torch.bfloat16: + # Should not happen for Flash (experts ship in FP8); flag loudly. + raise NotImplementedError( + f"Layer {layer_idx} expert 0 gate_proj is BF16; Flash expects FP8." + ) + else: + sample_w, sample_s = e0_gw, e0_gs + + intermediate_size, hidden_size = sample_w.shape # [IM, H] + # Packed transpose layout: [num_experts, H, 2*IM] for gate_up. 
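+    # Each HF expert tensor arrives as [IM, H]; transposing to [H, IM] and
+    # concatenating gate|up along the last axis packs every expert into one
+    # [H, 2*IM] slab. The blockwise scales get the same transpose+concat at
+    # block granularity: [i_blocks, h_blocks] -> [h_blocks, 2*i_blocks].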
+ gate_up_proj = torch.empty( + num_experts, hidden_size, 2 * intermediate_size, dtype=sample_w.dtype + ) + i_blocks, h_blocks = sample_s.shape # [IM_blocks, H_blocks] + gate_up_scale = torch.empty( + num_experts, h_blocks, 2 * i_blocks, dtype=sample_s.dtype + ) + + e0_dw = lazy.get(f"{prefix}mlp.experts.0.down_proj.weight") + e0_ds = lazy.get(f"{prefix}mlp.experts.0.down_proj.weight_scale_inv") + if e0_dw.dtype == torch.float8_e4m3fn and e0_ds is not None: + sample_dw, sample_ds = rescale_fp8_weight_blockwise(e0_dw, e0_ds) + else: + raise NotImplementedError( + f"Layer {layer_idx} expert 0 down_proj dtype {e0_dw.dtype} not handled." + ) + d_h_blocks, d_i_blocks = sample_ds.shape # [H_blocks, IM_blocks] + down_proj = torch.empty( + num_experts, intermediate_size, hidden_size, dtype=sample_dw.dtype + ) + down_scale = torch.empty( + num_experts, d_i_blocks, d_h_blocks, dtype=sample_ds.dtype + ) + + # Slot expert 0 (already rescaled above). + gate_up_proj[0, :, :intermediate_size] = sample_w.T + gate_up_scale[0, :, :i_blocks] = sample_s.T + e0_uw = lazy.get(f"{prefix}mlp.experts.0.up_proj.weight") + e0_us = lazy.get(f"{prefix}mlp.experts.0.up_proj.weight_scale_inv") + up_w0, up_s0 = rescale_fp8_weight_blockwise(e0_uw, e0_us) + gate_up_proj[0, :, intermediate_size:] = up_w0.T + gate_up_scale[0, :, i_blocks:] = up_s0.T + down_proj[0] = sample_dw.T + down_scale[0] = sample_ds.T + del e0_gw, e0_gs, e0_uw, e0_us, e0_dw, e0_ds + del sample_w, sample_s, sample_dw, sample_ds, up_w0, up_s0 + + for e in range(1, num_experts): + gw = lazy.get(f"{prefix}mlp.experts.{e}.gate_proj.weight") + gs = lazy.get(f"{prefix}mlp.experts.{e}.gate_proj.weight_scale_inv") + uw = lazy.get(f"{prefix}mlp.experts.{e}.up_proj.weight") + us = lazy.get(f"{prefix}mlp.experts.{e}.up_proj.weight_scale_inv") + dw = lazy.get(f"{prefix}mlp.experts.{e}.down_proj.weight") + ds = lazy.get(f"{prefix}mlp.experts.{e}.down_proj.weight_scale_inv") + g_w, g_s = rescale_fp8_weight_blockwise(gw, gs) + u_w, u_s = rescale_fp8_weight_blockwise(uw, us) + d_w, d_s = rescale_fp8_weight_blockwise(dw, ds) + gate_up_proj[e, :, :intermediate_size] = g_w.T + gate_up_proj[e, :, intermediate_size:] = u_w.T + gate_up_scale[e, :, :i_blocks] = g_s.T + gate_up_scale[e, :, i_blocks:] = u_s.T + down_proj[e] = d_w.T + down_scale[e] = d_s.T + del gw, gs, uw, us, dw, ds, g_w, g_s, u_w, u_s, d_w, d_s + + out[f"{out_prefix}mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj + out[f"{out_prefix}mlp.expert_mlps.mlp_op.gate_up_proj.scale"] = gate_up_scale + out[f"{out_prefix}mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj + out[f"{out_prefix}mlp.expert_mlps.mlp_op.down_proj.scale"] = down_scale + return out + + +# --------------------------------------------------------------------------- +# Shard saving / index +# --------------------------------------------------------------------------- + +def save_shard( + tensors: Dict[str, torch.Tensor], + save_path: str, + filename: str, + weight_map: Dict[str, str], +) -> int: + """Save a sub-state-dict; clone tensors so safetensors doesn't complain + about views of mmapped storage. 
Returns bytes written.""" + path = os.path.join(save_path, filename) + materialized: Dict[str, torch.Tensor] = {} + total_bytes = 0 + for k, v in tensors.items(): + if not v.is_contiguous(): + v = v.contiguous() + v = v.detach().clone() + materialized[k] = v + total_bytes += v.numel() * v.element_size() + save_file(materialized, path) + for k in materialized.keys(): + weight_map[k] = filename + del materialized + return total_bytes + + +# --------------------------------------------------------------------------- +# Main driver +# --------------------------------------------------------------------------- + +def process_flash_checkpoint(hf_model_path: str, save_path: str, tp_degree: int): + os.makedirs(save_path, exist_ok=True) + + with open(os.path.join(hf_model_path, "model.safetensors.index.json")) as f: + weight_map_in = json.load(f)["weight_map"] + + with open(os.path.join(hf_model_path, "config.json")) as f: + config = json.load(f) + + num_layers = config["num_hidden_layers"] + hybrid = config.get("hybrid_layer_pattern", [0] * num_layers) + moe_freq = config.get("moe_layer_freq", [1] * num_layers) + + print( + f"Processing {num_layers} decoder layers" + f" (full={sum(1 for v in hybrid if v == 0)}," + f" swa={sum(1 for v in hybrid if v == 1)}," + f" dense={sum(1 for v in moe_freq if v == 0)}," + f" moe={sum(1 for v in moe_freq if v == 1)})", + flush=True, + ) + + lazy = LazyWeightMap(hf_model_path, weight_map_in) + weight_map_out: Dict[str, str] = {} + + try: + for li in range(num_layers): + t0 = time.time() + is_dense = moe_freq[li] == 0 + is_swa = hybrid[li] == 1 + layer_sd = process_layer(li, lazy, config, is_dense=is_dense, is_swa=is_swa) + filename = f"model_layer{li}.safetensors" + size = save_shard(layer_sd, save_path, filename, weight_map_out) + del layer_sd + gc.collect() + tag = "dense" if is_dense else "moe " + attn = "swa " if is_swa else "full" + print( + f" layer {li:2d} [{tag} {attn}] {size/1e9:6.2f} GB in {time.time()-t0:5.1f}s", + flush=True, + ) + + print("Processing embed_tokens, norm, lm_head ...", flush=True) + extras: Dict[str, torch.Tensor] = {} + for src, dst in ( + ("model.embed_tokens.weight", "embed_tokens.weight"), + ("model.norm.weight", "norm.weight"), + ("lm_head.weight", "lm_head.weight"), + ): + t = lazy.get(src) + if t is not None: + extras[dst] = t.detach().clone() + else: + print(f" WARNING: missing {src}", flush=True) + if "lm_head.weight" not in extras and "embed_tokens.weight" in extras: + # Tied embeddings + extras["lm_head.weight"] = extras["embed_tokens.weight"].detach().clone() + save_shard(extras, save_path, "model_extras.safetensors", weight_map_out) + del extras + finally: + lazy.close() + + # --- Index file --- + total_size = 0 + for f in set(weight_map_out.values()): + total_size += os.path.getsize(os.path.join(save_path, f)) + index = { + "metadata": {"total_size": total_size}, + "weight_map": weight_map_out, + } + with open(os.path.join(save_path, "model.safetensors.index.json"), "w") as f: + json.dump(index, f, indent=2) + + # --- Copy auxiliary files (config.json, tokenizer, chat template, + # and crucially the trust_remote_code modules the HF config references). + for name in sorted(os.listdir(hf_model_path)): + if name.endswith(".safetensors"): + continue + if name == "model.safetensors.index.json": + continue + src = os.path.join(hf_model_path, name) + if os.path.isfile(src): + shutil.copy(src, os.path.join(save_path, name)) + + print(f"\nPreprocess complete. 
total_size={total_size/1e9:.2f} GB", flush=True) + print(f" tensors written: {len(weight_map_out)}", flush=True) + print(f" output dir: {save_path}", flush=True) + + +def main(): + parser = argparse.ArgumentParser( + description="Preprocess MiMo-V2-Flash FP8 checkpoint for Neuron inference" + ) + parser.add_argument("--hf_model_path", required=True) + parser.add_argument("--save_path", required=True) + parser.add_argument("--tp_degree", type=int, default=64, + help="Tensor parallelism (currently informational only; " + "the framework does the TP sharding at load time).") + args = parser.parse_args() + process_flash_checkpoint(args.hf_model_path, args.save_path, args.tp_degree) + + +if __name__ == "__main__": + main() From 2fbbdca7098fb1aad9afa4bdcd0ddcc0cd86e92b Mon Sep 17 00:00:00 2001 From: whn09 Date: Thu, 23 Apr 2026 19:44:52 +0800 Subject: [PATCH 09/23] Enable FP8 inference for MiMo-V2-Flash (BF16 path unchanged) Wire the runtime pieces needed to run Flash's preprocessed Neuron-FP8 checkpoint. All modifications are gated by neuron_config.quantized, so the existing BF16 path is untouched. New pieces: - Four monkey-patch installers on NeuronMiMoV2ForCausalLM that reconcile NxDI's global blockwise_symmetric q_config with the mixed 3D-blockwise-MoE + 2D-per-row-attn checkpoint layout: * _apply_ep_scale_fix: don't EP-shard singleton [1,1,W] scales. * _apply_blockwise_scale_stride_fix: force partition_stride=1 for BLOCKWISE_SYMMETRIC to avoid strided-split failures when per-rank weight is smaller than a 128-wide scale block. * _apply_2d_per_channel_fix: 2D attention/dense-MLP weights use per-row (out, 1) scales; flip their from_float q_config from BLOCKWISE_SYMMETRIC to PER_CHANNEL_SYMMETRIC at construction. * _apply_router_noaux_tc_fix: Flash's topk_method=noaux_tc needs e_score_correction_bias in the top-k selection; stock RouterTopK silently drops this bias. - compile()/load() overrides call _install_fp8_patches() before super(). - save_quantized_state_dict override: skip the HF-side re-quantize path (requires CUDA, materializes a ~600 GB BF16 copy) when the preprocess-produced Neuron-FP8 index is already on disk. - convert_mimo_v2_hf_to_neuron_state_dict additions (FP8-only): * Replicate per-row K/V .scale tensors in lockstep with the existing CONVERT_TO_MHA weight replication (TP=64 > 4/8 KV heads). * Expand MoE blockwise gate_up_proj/down_proj .scale tensors along the TP-partitioned dim so per_partition_size == 1 after sharding (preserves gate|up boundary by expanding each half independently). Cleanup: drop the verbose [DEBUG] prints from the BF16-era CONVERT_TO_MHA block - useful during bring-up, noisy in steady state (48 layers x 4 prints per run). Verified end-to-end on Trn2 (TP=64, EP=1, SEQ=1024, BS=1): preprocess -> smoke_compile (~18.5 min) -> smoke_generate Prompt : "Hello! Please introduce yourself in one sentence." Output : "**Hi, I'm Alex AI, a virtual AI assistant created by Meta AI to help answer questions" (20 tokens in 1.19s, 16.7 tok/s) Coherent fluent output, no token collapse. 
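
For reference, the selection rule the RouterTopK patch implements, in
pseudo-code mirroring _patched_forward in the diff (top_k=8 for Flash):

    scores = sigmoid(router_logits)                      # unbiased affinities
    top_idx = topk(scores + e_score_correction_bias, 8)  # bias decides winners
    # returned expert weights are the unbiased scores at top_idx
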
Co-Authored-By: Claude Opus 4.7 --- .../MiMo-V2-Flash/src/modeling_mimo_v2.py | 310 +++++++++++++++++- 1 file changed, 300 insertions(+), 10 deletions(-) diff --git a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py index 4ea33fb0..955ae1a0 100644 --- a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py +++ b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py @@ -1086,12 +1086,6 @@ def convert_mimo_v2_hf_to_neuron_state_dict( full_use_convert_to_mha = tp_degree > full_num_kv_heads swa_use_convert_to_mha = tp_degree > swa_num_kv_heads - print(f"\n[DEBUG] CONVERT_TO_MHA status:") - print(f" tp_degree: {tp_degree}") - print(f" num_attention_heads: {num_attention_heads}") - print(f" full_num_kv_heads: {full_num_kv_heads}, use_convert_to_mha: {full_use_convert_to_mha}") - print(f" swa_num_kv_heads: {swa_num_kv_heads}, use_convert_to_mha: {swa_use_convert_to_mha}") - for layer_idx in range(config.num_hidden_layers): # Add rank utility for attention neuron_state_dict[f"layers.{layer_idx}.self_attn.rank_util.rank"] = torch.arange( @@ -1118,24 +1112,34 @@ def convert_mimo_v2_hf_to_neuron_state_dict( v_proj_key = f"layers.{layer_idx}.self_attn.v_proj.weight" if k_proj_key in neuron_state_dict: - old_shape = neuron_state_dict[k_proj_key].shape neuron_state_dict[k_proj_key] = _replicate_kv_weights_for_convert_to_mha( neuron_state_dict[k_proj_key], src_num_kv_heads, num_attention_heads, head_dim, ) - print(f"[DEBUG] Layer {layer_idx} ({'SWA' if is_sliding_window else 'Full'}): Replicated K: {old_shape} -> {neuron_state_dict[k_proj_key].shape}") if v_proj_key in neuron_state_dict: - old_shape = neuron_state_dict[v_proj_key].shape neuron_state_dict[v_proj_key] = _replicate_kv_weights_for_convert_to_mha( neuron_state_dict[v_proj_key], src_num_kv_heads, num_attention_heads, v_head_dim, ) - print(f"[DEBUG] Layer {layer_idx} ({'SWA' if is_sliding_window else 'Full'}): Replicated V: {old_shape} -> {neuron_state_dict[v_proj_key].shape}") + + # FP8 path: replicate per-row scales ([src_heads*head_dim, 1]) in + # lockstep with the weights. Without this the shard_weights step + # rejects the scale shape mismatch (e.g. [12,1] vs expected [192,1]). + # BF16 has no .scale key, so this loop is a no-op there. + for proj, hd in (("k_proj", head_dim), ("v_proj", v_head_dim)): + scale_key = f"layers.{layer_idx}.self_attn.{proj}.scale" + if scale_key in neuron_state_dict: + neuron_state_dict[scale_key] = _replicate_kv_weights_for_convert_to_mha( + neuron_state_dict[scale_key], + src_num_kv_heads, + num_attention_heads, + hd, + ) # Only convert MoE layers if not config.layer_uses_moe[layer_idx]: @@ -1219,6 +1223,84 @@ def convert_mimo_v2_hf_to_neuron_state_dict( gc.collect() + # --- Expand MoE blockwise scales along the TP-partitioned dim (FP8 only). --- + # NxDI's shard_checkpoint splits the scale on its partition dim into + # `per_partition_size = dim_size / tp_degree`. At TP=64 both projections + # have per-rank "intermediate" smaller than the 128-wide scale block, so + # several ranks share one scale block — we need to replicate scale entries + # along that dim. Adjacent ranks whose weight falls inside the same + # 128-wide block genuinely share that block's scale. No-op when the + # .scale keys are absent (BF16 path). 
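+    # Worked example (down_proj at TP=64, IM=2048): i_blocks=16, i_per_rank=32,
+    # so ranks_per_block = 128 // 32 = 4 and the [E, 16, H_blocks] scale
+    # expands to [E, 64, H_blocks] with each block row repeated 4x, giving
+    # one scale row per rank.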
+ if getattr(config.neuron_config, "quantized", False): + tp = config.neuron_config.tp_degree + for layer_idx in range(config.num_hidden_layers): + if not config.layer_uses_moe[layer_idx]: + continue + + # down_proj (RowParallel on intermediate dim). Scale: [E, I_blocks, H_blocks] + dp_key = f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.down_proj.scale" + if dp_key in neuron_state_dict: + s = neuron_state_dict[dp_key] + i_blocks = s.shape[1] + h_blocks = s.shape[2] + intermediate = i_blocks * 128 + i_per_rank = intermediate // tp # 32 @ TP=64 for IM=2048 + if i_per_rank < 128: + ranks_per_block = 128 // i_per_rank # 4 @ TP=64 + s_exp = s.unsqueeze(2).expand(-1, -1, ranks_per_block, -1) + s_exp = s_exp.reshape(s.shape[0], i_blocks * ranks_per_block, h_blocks) + assert s_exp.shape[1] == tp, ( + f"down_proj.scale expansion produced {s_exp.shape[1]} rows, " + f"expected TP={tp}" + ) + neuron_state_dict[dp_key] = s_exp.contiguous() + + # gate_up_proj (ColumnParallel on 2*intermediate dim, gate|up fused + # along last axis). Scale: [E, H_blocks, 2*I_blocks] stored as + # [gate_half | up_half]. Module parameter has per-rank last-dim=1 + # (via _apply_blockwise_scale_stride_fix patch forcing + # partition_stride=1), so the full scale must have last-dim=tp + # with gate entries 0..tp/2 and up entries tp/2..tp. Expand each + # half independently to preserve the gate/up boundary when NxD + # does `split(per_partition=2*I/tp, dim=-1)`. + gu_key = f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.gate_up_proj.scale" + if gu_key in neuron_state_dict: + s = neuron_state_dict[gu_key] + h_blocks = s.shape[1] + two_i_blocks = s.shape[2] + assert two_i_blocks % 2 == 0, ( + f"gate_up_proj.scale last dim must be 2*i_blocks, got {two_i_blocks}" + ) + i_blocks = two_i_blocks // 2 + intermediate = i_blocks * 128 + out_per_rank = (2 * intermediate) // tp # 64 @ TP=64 for IM=2048 + if out_per_rank < 128: + assert tp % 2 == 0, f"TP={tp} must be even for gate/up scale split" + ranks_per_half = tp // 2 # 32 @ TP=64 + assert ranks_per_half % i_blocks == 0, ( + f"ranks_per_half={ranks_per_half} must be divisible by " + f"i_blocks={i_blocks}" + ) + ranks_per_block = ranks_per_half // i_blocks # 2 @ TP=64 (i_blocks=16) + gate_half = s[..., :i_blocks] # [E, H_blocks, i_blocks] + up_half = s[..., i_blocks:] + gate_exp = ( + gate_half.unsqueeze(-1) + .expand(-1, -1, -1, ranks_per_block) + .reshape(s.shape[0], h_blocks, ranks_per_half) + ) + up_exp = ( + up_half.unsqueeze(-1) + .expand(-1, -1, -1, ranks_per_block) + .reshape(s.shape[0], h_blocks, ranks_per_half) + ) + s_exp = torch.cat([gate_exp, up_exp], dim=-1) + assert s_exp.shape[-1] == tp, ( + f"gate_up_proj.scale expansion produced {s_exp.shape[-1]} " + f"entries, expected TP={tp}" + ) + neuron_state_dict[gu_key] = s_exp.contiguous() + return neuron_state_dict @@ -1299,6 +1381,214 @@ def enable_token_generation(self): self.compile_tag = TOKEN_GENERATION_MODEL_TAG super().enable_token_generation() + # ------------------------------------------------------------------ + # FP8 quantized-inference monkey-patches (no-op unless quantized=True). + # + # Reconcile the preprocessed Neuron-FP8 checkpoint (blockwise-MoE + + # per-row-attn) with NxDI's global blockwise_symmetric q_config. All + # four are gated by self.neuron_config.quantized so the BF16 path is + # completely untouched. 
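+    #
+    # The four fixes, in install order:
+    #   1. _apply_ep_scale_fix               - skip singleton [1,1,W] scales in EP marking
+    #   2. _apply_blockwise_scale_stride_fix - force scale.partition_stride = 1
+    #   3. _apply_2d_per_channel_fix         - per-row 2D scales as per_channel_symmetric
+    #   4. _apply_router_noaux_tc_fix        - fold e_score_correction_bias into top-k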
+ # ------------------------------------------------------------------ + + @staticmethod + def _apply_ep_scale_fix(): + """Skip per-channel `scale` params when marking expert-parallel + weights; they have shape [1, 1, W] and cannot be EP-sharded.""" + from neuronx_distributed.modules.moe.moe_parallel_layers import ( + ExpertFusedLinear, + ) + + if getattr(ExpertFusedLinear, "_mimo_v2_ep_scale_patched", False): + return + + def _patched_mark( + self_inner, + iterable=None, + expert_parallel_group_size=None, + is_prefill=True, + expert_distribution=None, + ): + from neuronx_distributed.parallel_layers.parallel_state import ( + get_expert_model_parallel_size, + ) + + if expert_parallel_group_size is None: + expert_parallel_group_size = get_expert_model_parallel_size() + + if expert_parallel_group_size > 1: + if iterable is None: + params_to_mark = [] + for name, p in self_inner.named_parameters(): + if name == "scale" and p.shape[0] == 1: + continue + params_to_mark.append(p) + iterable = params_to_mark + + for p in iterable: + p.expert_model_parallel = True + if is_prefill: + p.is_prefill = True + p.expert_distribution = expert_distribution + + ExpertFusedLinear._mark_expert_parallel_weights = _patched_mark + ExpertFusedLinear._mimo_v2_ep_scale_patched = True + + @staticmethod + def _apply_blockwise_scale_stride_fix(): + """Force scale.partition_stride=1 for BLOCKWISE_SYMMETRIC quantization + — stride>1 causes strided-splitting failures when per-rank weight size + is smaller than a block.""" + from neuronx_distributed.quantization.quantization_config import ( + QuantizationType, + ) + from neuronx_distributed.quantization.quantization_layers import ( + BaseQuantizeParallelLinear, + ) + + if getattr(BaseQuantizeParallelLinear, "_mimo_v2_blockwise_stride_patched", False): + return + + _original_setup = BaseQuantizeParallelLinear._setup_for_scale + + def _patched_setup(self_inner, *args, **kwargs): + _original_setup(self_inner, *args, **kwargs) + if ( + hasattr(self_inner, "quantization_type") + and self_inner.quantization_type == QuantizationType.BLOCKWISE_SYMMETRIC + and hasattr(self_inner, "scale") + and hasattr(self_inner.scale, "partition_stride") + and self_inner.scale.partition_stride > 1 + ): + self_inner.scale.partition_stride = 1 + + BaseQuantizeParallelLinear._setup_for_scale = _patched_setup + BaseQuantizeParallelLinear._mimo_v2_blockwise_stride_patched = True + + @staticmethod + def _apply_2d_per_channel_fix(): + """Route 2D self_attn + layer-0 dense-MLP swaps through per_channel_symmetric. + + Flash's preprocess writes: + - MoE experts: 3D weights with (E, out//128, in//128) blockwise scales. + - self_attn q/k/v + layer-0 mlp gate/up/down: 2D weights with + (out, 1) per-row scales. + + NxDI's q_config is global blockwise_symmetric (to satisfy the MoE). + Feeding that into the 2D classes triggers + `block axis cannot be < 0 or > 2, received 2` in _setup_for_scale + (block axes [1, 2] exceed rank-2 weight_shape). This wraps the 2D + classes' from_float to override q_config on the fly. 
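+        A (out, 1) per-row scale is exactly per-channel quantization along
+        axis 0, so the override changes bookkeeping, not numerics.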
+ """ + from neuronx_distributed.quantization.quantization_config import ( + QuantizationType, + ) + from neuronx_distributed.quantization.quantization_layers import ( + QuantizedColumnParallel, + QuantizedRowParallel, + ) + + def _wrap(cls): + if getattr(cls, "_mimo_v2_2d_patched", False): + return + original_from_float = cls.from_float + + def _patched_from_float(klass, mod, q_config=None, _orig=original_from_float): + if q_config is not None and q_config.get("quantization_type") == \ + QuantizationType.BLOCKWISE_SYMMETRIC: + q_config = dict(q_config) + q_config["quantization_type"] = QuantizationType.PER_CHANNEL_SYMMETRIC + q_config["quantization_per_channel_axis"] = 0 + q_config.pop("block_axis", None) + q_config.pop("block_size", None) + if q_config is None: + return _orig(mod) + return _orig(mod, q_config) + + cls.from_float = classmethod(_patched_from_float) + cls._mimo_v2_2d_patched = True + + _wrap(QuantizedColumnParallel) + _wrap(QuantizedRowParallel) + + @staticmethod + def _apply_router_noaux_tc_fix(): + """Register e_score_correction_bias on NxD RouterTopK and fold it into + top-k selection so Flash's noaux_tc routing matches HF reference. + + Flash's HF config uses topk_method='noaux_tc': each expert score is + `sigmoid(logits) + e_score_correction_bias`, top-k indices are chosen + from THAT biased score; the returned expert weights (affinities) + come from the UNBIASED sigmoid(logits). NxD's stock RouterTopK is + plain topk with no bias slot, so without this the bias is silently + dropped and ~all tokens route to wrong experts. + """ + from neuronx_distributed.modules.moe.routing import RouterTopK + + if getattr(RouterTopK, "_mimo_v2_noaux_tc_patched", False): + return + + original_init = RouterTopK.__init__ + + def _patched_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + self.e_score_correction_bias = nn.Parameter( + torch.zeros(self.num_experts, dtype=torch.float32), + requires_grad=False, + ) + + def _patched_forward(self, hidden_states): + router_logits = self.get_router_logits(hidden_states) + expert_affinities = self.apply_activation_fn(router_logits) + + selection_scores = expert_affinities.to(torch.float32) + \ + self.e_score_correction_bias + _, expert_index = torch.topk(selection_scores, self.top_k, dim=-1) + + expert_affinities = expert_affinities.to(dtype=hidden_states.dtype) + expert_index = expert_index.detach().to(dtype=torch.long) + return router_logits, expert_affinities, expert_index + + RouterTopK.__init__ = _patched_init + RouterTopK.forward = _patched_forward + RouterTopK._mimo_v2_noaux_tc_patched = True + + def _install_fp8_patches(self): + """Install all FP8-specific runtime patches. No-op for BF16.""" + if not getattr(self.neuron_config, "quantized", False): + return + self._apply_ep_scale_fix() + self._apply_blockwise_scale_stride_fix() + self._apply_2d_per_channel_fix() + self._apply_router_noaux_tc_fix() + + def compile(self, *args, **kwargs): + # save_sharded_checkpoint=True serializes shards during compile() and + # that code path reads scale.partition_stride — patches must be live. + self._install_fp8_patches() + return super().compile(*args, **kwargs) + + def load(self, *args, **kwargs): + self._install_fp8_patches() + return super().load(*args, **kwargs) + + @classmethod + def save_quantized_state_dict(cls, model_path, config): + """Flash ships pre-quantized FP8 safetensors via our preprocess script. 
+ The base implementation calls AutoModelForCausalLM.from_pretrained to + re-quantize, which requires a CUDA GPU (finegrained_fp8 gate) and + materializes an ~600 GB BF16 copy. Skip if the checkpoint directory + already contains a Neuron-FP8 index produced by preprocess.""" + import os as _os + qpath = ( + getattr(config.neuron_config, "quantized_checkpoints_path", None) + or model_path + ) + if qpath and _os.path.isdir(qpath): + index = _os.path.join(qpath, "model.safetensors.index.json") + if _os.path.isfile(index): + return + return super().save_quantized_state_dict(model_path, config) + def get_compiler_args(self) -> str: """Get compiler arguments optimized for MiMo-V2-Flash.""" if self.compile_tag == CONTEXT_ENCODING_MODEL_TAG: From ec2aa406bbd7507c9336fee99a2c201267e03d50 Mon Sep 17 00:00:00 2001 From: whn09 Date: Thu, 23 Apr 2026 19:45:10 +0800 Subject: [PATCH 10/23] Add FP8 smoke tests and save_sharded_checkpoint to bench Two minimal scripts that bypass vLLM so FP8 bring-up can iterate on the preprocessed Neuron-FP8 checkpoint without paying vllm-neuron startup cost: - smoke_compile_mimo_v2_flash.py: STAGE={instantiate,compile,load,all}, DRY_RUN=1 for HLO-only, SKIP_WARMUP=1 when HBM is tight. Builds the Flash BS=1 recipe (TP=64, EP=1, blockwise_symmetric, use_shard_on_block _dynamic_while=True) and calls compile()+load() directly. - smoke_generate_mimo_v2_flash.py: 20-token generation via HuggingFaceGenerationAdapter using the same config (hash matches so the NEFF is reused). bench: add "save_sharded_checkpoint": true to COMMON_MIMO_CONFIG. During compile this writes per-rank tp{N}_sharded_checkpoint.safetensors under /weights/; subsequent load()s read those directly (~55s) instead of re-sharding the full checkpoint (~10+ min). Co-Authored-By: Claude Opus 4.7 --- .../perf_test/bench_mimo_v2_flash.sh | 5 + .../perf_test/smoke_compile_mimo_v2_flash.py | 167 ++++++++++++++++++ .../perf_test/smoke_generate_mimo_v2_flash.py | 161 +++++++++++++++++ 3 files changed, 333 insertions(+) create mode 100755 contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py create mode 100755 contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py diff --git a/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh b/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh index a4fcda90..d3fb5910 100755 --- a/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh +++ b/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh @@ -27,8 +27,13 @@ COMMON_MIMO_CONFIG='"tp_degree": 64, "strided_context_parallel_kernel_enabled": false, "glu_mlp": true, "normalize_top_k_affinities": true, + "save_sharded_checkpoint": true, "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, "blockwise_matmul_config": {"use_torch_block_wise": true}' +# save_sharded_checkpoint=true persists per-rank sharded weights to +# /weights/tp{N}_sharded_checkpoint.safetensors during compile; +# load() then reads those directly (~55s) instead of re-sharding the entire +# checkpoint on every vllm-neuron startup (~10+ min). # NOTE: use_torch_block_wise=true forces MoE blockwise to use the PyTorch # reference implementation. 
The NKI kernel path pulls # neuronxcc.nki._private.blockwise_mm.blockwise_mm_baseline_shard_hidden diff --git a/contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py b/contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py new file mode 100755 index 00000000..775dbaee --- /dev/null +++ b/contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 +"""Minimal compile+load smoke test for MiMo-V2-Flash FP8 on Trn2. + +Bypasses vLLM entirely so we can iterate on the preprocessed Neuron-FP8 +checkpoint without paying vllm-neuron's startup cost. Builds the Flash BS=1 +recipe (TP=64, EP=1, blockwise FP8 for routed experts), compiles to a temp +dir, then loads. EP=1 lets the TKG path enter forward_selective_loading +legally so BS=1 compiles — with EP>1 NxDI raises NotImplementedError and +forces BS>=num_experts/top_k = 32. + +STAGE controls how far we go: + instantiate | compile | load | all (default: all) + +DRY_RUN=1 does HLO-only compile (no torch.jit.save + shard). Fastest sanity +check for the preprocessed checkpoint. SKIP_WARMUP=1 on load() skips the +forward pass that allocates the shared scratchpad — useful when HBM is +tight. + +Run under /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 (same venv used +by the bench script). +""" + +import os +import sys +import time +import traceback + +MODEL_PATH = os.environ.get( + "MIMO_V2_FLASH_MODEL_PATH", + "/opt/dlami/nvme/models/MiMo-V2-Flash-Neuron-FP8", +) +COMPILED_PATH = os.environ.get( + "MIMO_V2_FLASH_COMPILED_PATH", + "/opt/dlami/nvme/compiled/mimo_v2_flash_tp64_ep1_fp8/", +) + +TP_DEGREE = int(os.environ.get("TP_DEGREE", "64")) +SEQ_LEN = int(os.environ.get("SEQ_LEN", "1024")) +BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1")) +CTX_BATCH_SIZE = int(os.environ.get("CTX_BATCH_SIZE", "1")) +MOE_TP = int(os.environ.get("MOE_TP", "64")) +MOE_EP = int(os.environ.get("MOE_EP", "1")) + +STAGE = os.environ.get("STAGE", "all").lower() + +os.makedirs(COMPILED_PATH, exist_ok=True) + + +def main(): + from neuronx_distributed_inference.models.config import MoENeuronConfig + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + + # Import the contrib wrapper (sibling src dir). + contrib_src = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "src", + ) + sys.path.insert(0, os.path.abspath(contrib_src)) + + from modeling_mimo_v2 import ( + MiMoV2InferenceConfig, + NeuronMiMoV2ForCausalLM, + ) + + print(f"[smoke] MODEL_PATH={MODEL_PATH}") + print(f"[smoke] COMPILED_PATH={COMPILED_PATH}") + print(f"[smoke] TP_DEGREE={TP_DEGREE}, SEQ_LEN={SEQ_LEN}, BS={BATCH_SIZE}") + print(f"[smoke] MOE_TP={MOE_TP}, MOE_EP={MOE_EP}") + print(f"[smoke] STAGE={STAGE}") + + print("[smoke] Building MoENeuronConfig (quantized FP8 MoE, blockwise_symmetric)...") + neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + ep_degree=MOE_EP, + logical_nc_config=2, + batch_size=BATCH_SIZE, + max_batch_size=BATCH_SIZE, + ctx_batch_size=CTX_BATCH_SIZE, + tkg_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + n_active_tokens=128, + torch_dtype="bfloat16", + capacity_factor=1.0, + glu_mlp=True, + moe_ep_degree=MOE_EP, + moe_tp_degree=MOE_TP, + context_encoding_buckets=[SEQ_LEN], + router_config={"act_fn": "sigmoid", "dtype": "float32"}, + # SDK 2.29 ships only bwmm_shard_on_block / bwmm_shard_on_intermediate; + # default routes to _call_shard_hidden_kernel which is missing, so we + # take the shard-on-block path via this flag. 
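+        # NB: smoke_generate_mimo_v2_flash.py must mirror this config verbatim;
+        # the cached NEFF is keyed on the config hash, so any drift here forces
+        # a recompile instead of reuse.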
+ blockwise_matmul_config={ + "use_shard_on_block_dynamic_while": True, + "block_sharding_strategy": "PING_PONG", + }, + # Persist sharded FP8 weights to disk so subsequent load()s skip the + # ~10-minute shard_checkpoint step (writes weights/tp{0..63}_*.safetensors + # on NVMe; NxDI load() reads these directly when present). + save_sharded_checkpoint=True, + # FP8 blockwise for routed experts (Kimi-K2 recipe). + quantized=True, + quantized_checkpoints_path=MODEL_PATH, + quantization_dtype="f8e4m3", + quantization_type="blockwise_symmetric", + quantization_block_axis=[1, 2], + quantization_block_size=[128, 128], + modules_to_not_convert=[ + "embed_tokens", + "lm_head", + "norm", + "router", + "o_proj", + ], + ) + + print("[smoke] Building MiMoV2InferenceConfig...") + from transformers import AutoConfig + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + config = MiMoV2InferenceConfig( + neuron_config, load_config=load_pretrained_config(hf_config=hf_config) + ) + print(f"[smoke] config.hidden_size={config.hidden_size}") + print(f"[smoke] config.num_hidden_layers={config.num_hidden_layers}") + print(f"[smoke] config.n_routed_experts={config.n_routed_experts}") + print(f"[smoke] config.num_experts_per_tok={config.num_experts_per_tok}") + print(f"[smoke] config.layer_uses_moe[:5]={config.layer_uses_moe[:5]}") + print(f"[smoke] config.layer_attention_types[:5]={config.layer_attention_types[:5]}") + + print("[smoke] Instantiating NeuronMiMoV2ForCausalLM (build model-on-cpu)...") + t0 = time.time() + model = NeuronMiMoV2ForCausalLM(MODEL_PATH, config) + print(f"[smoke] Instantiated in {time.time() - t0:.1f}s") + + if STAGE == "instantiate": + print("[smoke] STAGE=instantiate only, skipping compile/load.") + return + + DRY_RUN = os.environ.get("DRY_RUN", "0") == "1" + if STAGE in ("compile", "all"): + label = "Dry-run compile (HLO only)" if DRY_RUN else "Full compile" + print(f"[smoke] {label} -> {COMPILED_PATH}") + t0 = time.time() + try: + model.compile(COMPILED_PATH, dry_run=DRY_RUN) + print(f"[smoke] {label} OK in {time.time() - t0:.1f}s") + except Exception: + print(f"[smoke] {label} FAILED:") + traceback.print_exc() + raise + + if STAGE in ("load", "all") and not DRY_RUN: + SKIP_WARMUP = os.environ.get("SKIP_WARMUP", "1") == "1" + print(f"[smoke] Loading compiled model from {COMPILED_PATH} (skip_warmup={SKIP_WARMUP})") + t0 = time.time() + model.load(COMPILED_PATH, skip_warmup=SKIP_WARMUP) + print(f"[smoke] Loaded in {time.time() - t0:.1f}s") + + print("[smoke] Done.") + + +if __name__ == "__main__": + try: + main() + except Exception: + traceback.print_exc() + sys.exit(1) diff --git a/contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py b/contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py new file mode 100755 index 00000000..82670ead --- /dev/null +++ b/contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +"""Minimal generate smoke test for MiMo-V2-Flash FP8 on Trn2. + +Assumes the compiled NEFF already exists at MIMO_V2_FLASH_COMPILED_PATH +(from smoke_compile_mimo_v2_flash.py). Rebuilds the same MoENeuronConfig / +Flash wrapper, loads with skip_warmup=False, and generates 20 tokens for a +single prompt via HuggingFaceGenerationAdapter. Purpose: sanity-check that +the FP8 MoE + preprocessed scales actually produce coherent tokens. + +Run under /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16. 
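+
+Example invocation (values illustrative; both env vars are read at module
+scope below):
+
+    MIMO_V2_FLASH_PROMPT="Tell me a joke." MAX_NEW_TOKENS=20 \
+        python smoke_generate_mimo_v2_flash.py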
+""" + +import os +import sys +import time +import traceback + +MODEL_PATH = os.environ.get( + "MIMO_V2_FLASH_MODEL_PATH", + "/opt/dlami/nvme/models/MiMo-V2-Flash-Neuron-FP8", +) +COMPILED_PATH = os.environ.get( + "MIMO_V2_FLASH_COMPILED_PATH", + "/opt/dlami/nvme/compiled/mimo_v2_flash_tp64_ep1_fp8/", +) + +# Must match smoke_compile_mimo_v2_flash.py exactly, else load() sees a +# mismatched NEFF. +TP_DEGREE = int(os.environ.get("TP_DEGREE", "64")) +SEQ_LEN = int(os.environ.get("SEQ_LEN", "1024")) +BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1")) +CTX_BATCH_SIZE = int(os.environ.get("CTX_BATCH_SIZE", "1")) +MOE_TP = int(os.environ.get("MOE_TP", "64")) +MOE_EP = int(os.environ.get("MOE_EP", "1")) + +PROMPT = os.environ.get( + "MIMO_V2_FLASH_PROMPT", + "Hello! Please introduce yourself in one sentence.", +) +MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "20")) + + +def main(): + from transformers import AutoConfig, AutoTokenizer, GenerationConfig + + from neuronx_distributed_inference.models.config import MoENeuronConfig + from neuronx_distributed_inference.utils.hf_adapter import ( + HuggingFaceGenerationAdapter, + load_pretrained_config, + ) + + contrib_src = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "src", + ) + sys.path.insert(0, os.path.abspath(contrib_src)) + + from modeling_mimo_v2 import ( + MiMoV2InferenceConfig, + NeuronMiMoV2ForCausalLM, + ) + + print(f"[gen] MODEL_PATH={MODEL_PATH}") + print(f"[gen] COMPILED_PATH={COMPILED_PATH}") + print(f"[gen] TP={TP_DEGREE}, SEQ={SEQ_LEN}, BS={BATCH_SIZE}") + + neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + ep_degree=MOE_EP, + logical_nc_config=2, + batch_size=BATCH_SIZE, + max_batch_size=BATCH_SIZE, + ctx_batch_size=CTX_BATCH_SIZE, + tkg_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + n_active_tokens=128, + torch_dtype="bfloat16", + capacity_factor=1.0, + glu_mlp=True, + moe_ep_degree=MOE_EP, + moe_tp_degree=MOE_TP, + context_encoding_buckets=[SEQ_LEN], + router_config={"act_fn": "sigmoid", "dtype": "float32"}, + blockwise_matmul_config={ + "use_shard_on_block_dynamic_while": True, + "block_sharding_strategy": "PING_PONG", + }, + save_sharded_checkpoint=True, + quantized=True, + quantized_checkpoints_path=MODEL_PATH, + quantization_dtype="f8e4m3", + quantization_type="blockwise_symmetric", + quantization_block_axis=[1, 2], + quantization_block_size=[128, 128], + modules_to_not_convert=[ + "embed_tokens", + "lm_head", + "norm", + "router", + "o_proj", + ], + ) + + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + config = MiMoV2InferenceConfig( + neuron_config, load_config=load_pretrained_config(hf_config=hf_config) + ) + + print("[gen] Instantiating model...") + t0 = time.time() + model = NeuronMiMoV2ForCausalLM(MODEL_PATH, config) + print(f"[gen] Instantiated in {time.time() - t0:.1f}s") + + # skip_warmup=False so generate() hits a primed graph (the warmup forward + # allocates the shared scratchpad the generation path needs). 
+ print(f"[gen] Loading from {COMPILED_PATH} (skip_warmup=False)") + t0 = time.time() + model.load(COMPILED_PATH, skip_warmup=False) + print(f"[gen] Loaded in {time.time() - t0:.1f}s") + + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + adapter = HuggingFaceGenerationAdapter(model) + + inputs = tokenizer([PROMPT] * BATCH_SIZE, return_tensors="pt", padding=True) + gen_config = GenerationConfig( + max_new_tokens=MAX_NEW_TOKENS, + min_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + pad_token_id=getattr(tokenizer, "pad_token_id", None) or tokenizer.eos_token_id, + ) + + print(f"[gen] prompt: {PROMPT!r}") + print(f"[gen] input_ids.shape={tuple(inputs['input_ids'].shape)}") + t0 = time.time() + output_ids = adapter.generate( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + generation_config=gen_config, + ) + dt = time.time() - t0 + + prompt_len = inputs["input_ids"].shape[1] + new_tokens = output_ids[0, prompt_len:] + decoded = tokenizer.decode(new_tokens, skip_special_tokens=True) + full = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + print(f"[gen] generated {new_tokens.numel()} tokens in {dt:.2f}s " + f"({new_tokens.numel() / dt:.2f} tok/s)") + print(f"[gen] new token ids: {new_tokens.tolist()}") + print(f"[gen] new text : {decoded!r}") + print(f"[gen] full text : {full!r}") + print("[gen] Done.") + + +if __name__ == "__main__": + try: + main() + except Exception: + traceback.print_exc() + sys.exit(1) From d7ac76e476bb9d034ba256cf9a3dbd73d4072a4b Mon Sep 17 00:00:00 2001 From: whn09 Date: Fri, 24 Apr 2026 16:25:43 +0800 Subject: [PATCH 11/23] perf_test/0_setup: clone vllm-neuron into \$HOME not /tmp AMIs on Trn2 dev instances periodically wipe /tmp on reboot, which breaks the editable pip install (finder maps vllm_neuron -> directory that no longer exists and all subsequent imports fail). Using \$HOME makes the install survive reboots; re-running 0_setup.sh after a wipe still works thanks to the existing idempotency guards. Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh b/contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh index b0bc7940..6fafa96c 100755 --- a/contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh +++ b/contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh @@ -19,11 +19,11 @@ PATCH_FILE="$(cd "$(dirname "$0")" && pwd)/vllm-neuron-patch.patch" echo "" echo "[1/2] Installing vllm-neuron (release-0.5.0) with the contrib registration patch..." -if [ ! -d /tmp/vllm-neuron ]; then - git clone --branch release-0.5.0 https://github.com/vllm-project/vllm-neuron.git /tmp/vllm-neuron +if [ ! -d $HOME/vllm-neuron ]; then + git clone --branch release-0.5.0 https://github.com/vllm-project/vllm-neuron.git $HOME/vllm-neuron fi -cd /tmp/vllm-neuron +cd $HOME/vllm-neuron # Apply patch (idempotent via `git apply --check` first). 
if git apply --check "$PATCH_FILE" 2>/dev/null; then From 1259a1a708abe2741aa09e796e3e7e0a93d9f3eb Mon Sep 17 00:00:00 2001 From: whn09 Date: Fri, 24 Apr 2026 16:29:09 +0800 Subject: [PATCH 12/23] Apply attention_value_scale to value_states (matches HF reference) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The HF MiMoV2Flash modeling code (modeling_mimo_v2_flash.py:358-360 in the checkpoint) multiplies value_states by config.attention_value_scale right after the V projection, before attention softmax*V: if self.v_scale is not None: value_states = value_states * self.v_scale Flash config has attention_value_scale = 0.707, so value_states is consistently scaled down by that factor in every attention layer. Previously this file explicitly overrode self.value_scale to 1.0 based on a mistaken reading of the HF source, which made every attention layer's output ~0.707x too large. Short prompts stayed coherent by luck; prompts >=20 tokens accumulated enough error for the logits distribution to collapse, producing repeated single-word gibberish ("sentence sentence sentence" or "the default value is the default value"). Fix: read attention_value_scale from config (defaulting to 1.0) and apply it to value_states at the same point HF does. The old post-attention application point (attn_output *= value_scale) is mathematically equivalent when value_scale != 1.0, but keeping the application point aligned with HF makes future parity checks simpler. Verified on Trn2 TP=64 EP=1 FP8: prompt previously now -------- ---------- -------- "Hello! Please introduce yourself..." ok by luck ok "The quick brown fox...where it lives" ok ok "The quick brown fox...forest, where" "the moon dog "The fox is purs the deep a symbol of dog are..." cleverness..." 35-token chat-template prompt "I I sentence coherent sentence..." think+answer Note: same bug almost certainly present in MiMo-V2-Pro — Pro also force-sets self.value_scale = 1.0, but Pro's config has attention_value_scale = 0.612. Co-Authored-By: Claude Opus 4.7 --- .../MiMo-V2-Flash/src/modeling_mimo_v2.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py index 955ae1a0..a711203e 100644 --- a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py +++ b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py @@ -339,11 +339,11 @@ def __init__( # Scaling factor self.scaling = self.attn_head_dim ** -0.5 - # NOTE: The config may have 'attention_value_scale' (e.g., 0.707), but the HF model - # (modeling_mimo_v2_flash.py) does NOT use this value. The HF model only uses - # head_dim ** -0.5 for attention scaling, which is already applied via self.scaling. - # We must NOT apply attention_value_scale here, as it would cause divergence from HF. - self.value_scale = 1.0 + # HF MiMoV2Attention (modeling_mimo_v2_flash.py) multiplies value_states + # by config.attention_value_scale (0.707 for Flash) right after the V + # projection, before attention softmax*V. Matching that here — applied + # to value_states in forward() rather than to attn_output. 
+ self.value_scale = float(getattr(config, "attention_value_scale", 1.0)) # Store cache KV heads for cache compatibility # With CONVERT_TO_MHA, all layers have num_attention_heads KV heads @@ -562,6 +562,13 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) + # HF MiMoV2Attention scales V by attention_value_scale (0.707 for Flash) + # right after v_proj, before the attention softmax*V. Earlier revisions + # of this file applied it post-attention or not at all; both produce + # gibberish for prompts longer than ~20 tokens. + if self.value_scale != 1.0: + value_states = value_states * self.value_scale + # Reshape for multi-head attention: [bsz, num_heads, seq_len, head_dim] query_states = query_states.view(bsz, q_len, self.local_num_heads, self.attn_head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.local_num_kv_heads, self.attn_head_dim).transpose(1, 2) @@ -789,10 +796,6 @@ def forward( # Apply attention to values attn_output = torch.matmul(attn_weights, value_states) - # Apply value scale if specified - if self.value_scale != 1.0: - attn_output = attn_output * self.value_scale - # Reshape and project output attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.local_num_heads * self.attn_v_head_dim) From 1655dc952c056da61dc0869026b8ee0734214bc5 Mon Sep 17 00:00:00 2001 From: whn09 Date: Fri, 24 Apr 2026 16:29:32 +0800 Subject: [PATCH 13/23] Install FP8 monkey-patches in __init__ (belt-and-braces) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit compile() / load() already call _install_fp8_patches() via our overrides, but harnesses (e.g. vllm-neuron) may trigger RouterTopK and Quantized{Column,Row}Parallel construction during model instantiation — before compile()/load() get a chance to run. By installing the patches up-front in __init__ (gated on quantized=True) we guarantee the patched classes are in effect by the time any of NxDI's layer factories see them, regardless of which harness drives the model. The patches themselves are idempotent (guarded by _mimo_v2_*_patched sentinels), so installing them twice is harmless. Co-Authored-By: Claude Opus 4.7 --- .../MiMo-V2-Flash/src/modeling_mimo_v2.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py index a711203e..9caa1de0 100644 --- a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py +++ b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py @@ -1352,6 +1352,27 @@ class NeuronMiMoV2ForCausalLM(NeuronBaseForCausalLM): _model_cls = NeuronMiMoV2Model + def __init__(self, *args, **kwargs): + # Install FP8 monkey-patches BEFORE super().__init__ so the patched + # RouterTopK.__init__ and quantization layer classes are in effect + # when NxDI builds the decoder (and instantiates routers). Harnesses + # that drive us via model.compile()/model.load() (e.g. vllm-neuron) + # call those methods AFTER construction, so patching from inside + # compile()/load() is too late — RouterTopK instances would already + # lack our e_score_correction_bias parameter, silently routing tokens + # to wrong experts and producing gibberish output. + # + # _install_fp8_patches() reads self.neuron_config, which needs to + # exist; grab it from the args or the config arg the same way the + # base class does. 
+ ncfg = kwargs.get("config") or (args[1] if len(args) > 1 else None) + if ncfg is not None and getattr(getattr(ncfg, "neuron_config", None), "quantized", False): + self._apply_ep_scale_fix() + self._apply_blockwise_scale_stride_fix() + self._apply_2d_per_channel_fix() + self._apply_router_noaux_tc_fix() + super().__init__(*args, **kwargs) + @staticmethod def load_hf_model(model_path: str, **kwargs): """Load HuggingFace model. From b9eea11ad8e513a6951e08b89d98e4fab3fc0c9b Mon Sep 17 00:00:00 2001 From: whn09 Date: Fri, 24 Apr 2026 22:49:36 +0800 Subject: [PATCH 14/23] WIP: symmetric K/V head_dim via preprocess-side V padding Pre-pad V projection weights from [num_kv_heads*v_head_dim(128), hidden] to [num_kv_heads*head_dim(192), hidden] by appending 64 zero rows per head. This lets the Neuron KV cache manager hold K and V in a single symmetric shape instead of us runtime-padding V on every forward step. The modeling-side forward() now slices the attention output back to real_v_head_dim(128) right before the reshape+o_proj path, matching HF's o_proj weight shape. Changes: - preprocess_mimo_v2_flash_fp8.py: pre-pad v_proj weight and scale to head_dim=192 per head (zero-fill the tail). No-op if the checkpoint already has v_head_dim == head_dim. - modeling_mimo_v2.py: introduce attn_real_v_head_dim alongside attn_v_head_dim (now always == attn_head_dim). Delete the two runtime WORKAROUND blocks that used to pad V and repeat/slice KV heads on every forward step. Slice attention output to real_v_head_dim before reshape+o_proj. o_proj input dim switches to num_heads*real_v_head_dim so the HF o_proj weight shape still matches. - convert_mimo_v2_hf_to_neuron_state_dict: v_proj replication for CONVERT_TO_MHA now uses head_dim (V is pre-padded), not v_head_dim. v_proj scale replication likewise uses head_dim. STATUS: does NOT fix the long-decode output collapse that triggered this refactor. Chinese chat-template prompts at 40 tokens still degrade into repetition even after this change. Kept on the branch because the symmetric-KV-cache layout is architecturally cleaner and matches what the Kimi-K2 contrib model does; future debugging can build on this instead of having to reason about runtime V pad/slice. Co-Authored-By: Claude Opus 4.7 --- .../preprocess_mimo_v2_flash_fp8.py | 54 ++++++++++++- .../MiMo-V2-Flash/src/modeling_mimo_v2.py | 78 +++++++++++-------- 2 files changed, 96 insertions(+), 36 deletions(-) diff --git a/contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_flash_fp8.py b/contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_flash_fp8.py index 2491aa18..8a436c0f 100644 --- a/contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_flash_fp8.py +++ b/contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_flash_fp8.py @@ -195,8 +195,8 @@ def process_layer( out[f"{out_prefix}{name}.weight"] = t.detach().clone() # --- Attention: q/k/v/o are stored separately in Flash --- - # q/k/v: rescale to Neuron FP8 per-row. - for proj in ("q_proj", "k_proj", "v_proj"): + # q/k: rescale to Neuron FP8 per-row. + for proj in ("q_proj", "k_proj"): w = lazy.get(f"{prefix}self_attn.{proj}.weight") if w is None: continue @@ -206,6 +206,56 @@ def process_layer( if s2 is not None: out[f"{out_prefix}self_attn.{proj}.scale"] = s2 + # v_proj: MiMo-V2 has asymmetric head_dim (Q/K=192, V=128). 
NxDI's KV cache + # manager requires K and V to share the same head_dim; rather than pad V at + # runtime (every decode step) which is error-prone, we pre-pad V's output + # dimension to match head_dim here: each V head goes from 128 rows to 192 + # rows, with the extra 64 rows zero. Downstream Q @ V yields zeros in the + # padded dims, and the forward() slices the attention output back to + # v_head_dim=128 before o_proj. This lets the modeling code drop the + # `if v_head_dim < head_dim` runtime pad/slice logic entirely. + v_w_hf = lazy.get(f"{prefix}self_attn.v_proj.weight") + v_s_hf = lazy.get(f"{prefix}self_attn.v_proj.weight_scale_inv") + if v_w_hf is not None: + # First run the normal FP8-blockwise -> Neuron-per-row rescale on the + # native [num_kv_heads*128, hidden] weight. + v_w, v_s = _maybe_fp8_to_neuron_per_row(v_w_hf, v_s_hf) + # Now pad output dim from (num_kv_heads*128) to (num_kv_heads*192) by + # inserting 64 zero rows after every 128 real rows (preserve per-head + # head_dim layout). + num_kv_heads_swa = config.get("swa_num_key_value_heads", config["num_key_value_heads"]) + num_kv_heads_full = config["num_key_value_heads"] + num_kv_heads = num_kv_heads_swa if is_swa else num_kv_heads_full + v_head_dim = config.get("v_head_dim", 128) + head_dim = config.get("head_dim", 192) + assert v_w.shape[0] == num_kv_heads * v_head_dim, ( + f"v_proj out-dim {v_w.shape[0]} != num_kv_heads({num_kv_heads}) * " + f"v_head_dim({v_head_dim}) = {num_kv_heads * v_head_dim}" + ) + pad_per_head = head_dim - v_head_dim + hidden = v_w.shape[1] + # Reshape to [num_kv_heads, v_head_dim, hidden], pad to + # [num_kv_heads, head_dim, hidden], flatten back. + v_w_per_head = v_w.view(num_kv_heads, v_head_dim, hidden) + v_w_padded = torch.zeros( + num_kv_heads, head_dim, hidden, dtype=v_w_per_head.dtype, + ) + v_w_padded[:, :v_head_dim, :] = v_w_per_head + v_w_padded = v_w_padded.reshape(num_kv_heads * head_dim, hidden).contiguous() + out[f"{out_prefix}self_attn.v_proj.weight"] = v_w_padded + if v_s is not None: + # scale is per-row [out_rows, 1]; same pad-per-head rule. The + # padded rows have zero weight, so their scale value is irrelevant + # — use the min-clamp value (1e-10) to stay numerically neutral. + assert v_s.shape == (num_kv_heads * v_head_dim, 1), v_s.shape + v_s_per_head = v_s.view(num_kv_heads, v_head_dim, 1) + v_s_padded = torch.full( + (num_kv_heads, head_dim, 1), 1e-10, dtype=v_s.dtype, + ) + v_s_padded[:, :v_head_dim, :] = v_s_per_head + v_s_padded = v_s_padded.reshape(num_kv_heads * head_dim, 1).contiguous() + out[f"{out_prefix}self_attn.v_proj.scale"] = v_s_padded + # o_proj is listed in HF quantization_config.ignored_layers and ships as # BF16; on Neuron it binds to a plain RowParallelLinear (see # modeling_mimo_v2.py: self.o_proj = RowParallelLinear(...)), NOT a diff --git a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py index 9caa1de0..96e78989 100644 --- a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py +++ b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py @@ -289,21 +289,33 @@ def __init__( self.layer_idx = layer_idx self.is_sliding_window = is_sliding_window - # Select parameters based on attention type + # Select parameters based on attention type. + # + # IMPORTANT: we intentionally force attn_v_head_dim == attn_head_dim. + # The HF model has asymmetric head_dim (Q/K=192, V=128), but NxDI's + # KV cache manager requires K and V to share the same last-dim size. 
+ # Runtime-padding V (was the previous approach) introduced subtle + # numerical drift under TP=64 + long decode that compounded into + # output collapse. Instead, the preprocess script pre-pads V's output + # dim with zeros so V is physically [num_kv_heads, head_dim, hidden]; + # the forward() slices the attention output back to the real + # v_head_dim before o_proj. See preprocess_mimo_v2_flash_fp8.py. if is_sliding_window: self.attn_head_dim = config.swa_head_dim - self.attn_v_head_dim = config.swa_v_head_dim self.attn_num_heads = config.swa_num_attention_heads self.attn_num_kv_heads = config.swa_num_key_value_heads + self.attn_real_v_head_dim = config.swa_v_head_dim rope_theta = getattr(config, 'swa_rope_theta', 10000.0) self.sliding_window_size = config.sliding_window else: self.attn_head_dim = config.head_dim - self.attn_v_head_dim = config.v_head_dim self.attn_num_heads = config.num_attention_heads self.attn_num_kv_heads = config.num_key_value_heads + self.attn_real_v_head_dim = config.v_head_dim rope_theta = config.rope_theta self.sliding_window_size = None + # Padded dim used for projection + KV cache: same as Q/K head_dim. + self.attn_v_head_dim = self.attn_head_dim # Calculate partial rotary dimensions self.partial_rotary_factor = config.partial_rotary_factor @@ -398,11 +410,14 @@ def _init_projections(self, config: MiMoV2InferenceConfig): k_num_heads = self.attn_num_kv_heads v_num_heads = self.attn_num_kv_heads - # Q/K use head_dim, V uses v_head_dim + # Q/K use head_dim. V is pre-padded to head_dim in the preprocess + # script so K and V share one KV-cache shape; the attention output is + # sliced back to the real v_head_dim before o_proj, so o_proj still + # takes real_v_head_dim per head (matches the HF o_proj weight shape). q_hidden_size = self.attn_num_heads * self.attn_head_dim k_hidden_size = k_num_heads * self.attn_head_dim v_hidden_size = v_num_heads * self.attn_v_head_dim - o_hidden_size = self.attn_num_heads * self.attn_v_head_dim + o_hidden_size = self.attn_num_heads * self.attn_real_v_head_dim if parallel_state.model_parallel_is_initialized(): tp_group = parallel_state.get_tensor_model_parallel_group() @@ -634,18 +649,14 @@ def forward( key_states_for_cache = key_states value_states_for_cache = value_states - # WORKAROUND 1: Pad V from v_head_dim (128) to head_dim (192) for KV cache compatibility - if self.attn_v_head_dim < self.attn_head_dim: - pad_size = self.attn_head_dim - self.attn_v_head_dim - value_states_for_cache = F.pad(value_states_for_cache, (0, pad_size), value=0.0) - - # WORKAROUND 2: Pad KV heads if layer has fewer than cache expects - # Only needed when NOT using CONVERT_TO_MHA (standard GQA mode) - if not self.use_gqa_convert_to_mha and self.local_num_kv_heads < self.local_cache_kv_heads: - # Pad KV heads by repeating - repeat_factor = self.local_cache_kv_heads // self.local_num_kv_heads - key_states_for_cache = key_states_for_cache.repeat(1, repeat_factor, 1, 1) - value_states_for_cache = value_states_for_cache.repeat(1, repeat_factor, 1, 1) + # NOTE: No runtime V pad or KV head pad here. Previously this block + # padded V's head_dim from 128 to 192 at runtime so K and V could share + # the same KV cache shape, and separately repeated KV heads for layers + # that have fewer kv heads than the cache expects. Both are now + # handled in the preprocess script (V is physically padded to 192 + # with zero rows, and CONVERT_TO_MHA at TP=64 makes local_num_kv_heads + # == local_cache_kv_heads, so head-pad is a no-op). 
Runtime pad/slice + # on KV cache storage was the source of the long-decode output drift. # Repeat KV heads for GQA (only needed without CONVERT_TO_MHA) # With CONVERT_TO_MHA, K/V already have num_attention_heads @@ -657,22 +668,10 @@ def forward( if is_token_gen: # Token generation: use decomposed attention with prior (cached) and active (current) KV # past_key_value[0] = cached K, shape [bsz, cache_kv_heads, kv_seq_len, head_dim] - # past_key_value[1] = cached V, shape [bsz, cache_kv_heads, kv_seq_len, head_dim] (padded) + # past_key_value[1] = cached V, shape [bsz, cache_kv_heads, kv_seq_len, head_dim] K_prior = past_key_value[0] V_prior = past_key_value[1] - # WORKAROUND 1: Slice KV heads if cache has more than layer needs - # Only needed when NOT using CONVERT_TO_MHA (standard GQA mode) - # With CONVERT_TO_MHA, cache and layer have same num_kv_heads - if not self.use_gqa_convert_to_mha and self.local_num_kv_heads < self.local_cache_kv_heads: - # Cache has repeated heads, just take the first local_num_kv_heads - K_prior = K_prior[:, :self.local_num_kv_heads, :, :] - V_prior = V_prior[:, :self.local_num_kv_heads, :, :] - - # WORKAROUND 2: Slice V_prior back to v_head_dim (128) from head_dim (192) - if self.attn_v_head_dim < self.attn_head_dim: - V_prior = V_prior[..., :self.attn_v_head_dim] - # Repeat cached KV for GQA (only needed without CONVERT_TO_MHA) # With CONVERT_TO_MHA, cached K/V already have num_attention_heads if num_key_value_groups > 1: @@ -796,9 +795,17 @@ def forward( # Apply attention to values attn_output = torch.matmul(attn_weights, value_states) + # Slice off the padded v_head_dim tail. V was physically padded from + # real_v_head_dim (128) to head_dim (192) in the preprocess script so + # it can share NxDI's KV cache shape with K; the padded dims hold + # zero weights, so the attention output in those dims is always zero + # and the slice is numerically exact. + if self.attn_real_v_head_dim < self.attn_v_head_dim: + attn_output = attn_output[..., :self.attn_real_v_head_dim] + # Reshape and project output attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.local_num_heads * self.attn_v_head_dim) + attn_output = attn_output.reshape(bsz, q_len, self.local_num_heads * self.attn_real_v_head_dim) # Context Parallelism: gather output across CP ranks BEFORE o_proj. # With SP enabled, o_proj scatters along seq dim. The input must have full S @@ -1123,18 +1130,21 @@ def convert_mimo_v2_hf_to_neuron_state_dict( ) if v_proj_key in neuron_state_dict: + # V is pre-padded to head_dim in the preprocess script so K/V + # can share KV cache shape; use head_dim (not v_head_dim) here. neuron_state_dict[v_proj_key] = _replicate_kv_weights_for_convert_to_mha( neuron_state_dict[v_proj_key], src_num_kv_heads, num_attention_heads, - v_head_dim, + head_dim, ) # FP8 path: replicate per-row scales ([src_heads*head_dim, 1]) in # lockstep with the weights. Without this the shard_weights step # rejects the scale shape mismatch (e.g. [12,1] vs expected [192,1]). - # BF16 has no .scale key, so this loop is a no-op there. - for proj, hd in (("k_proj", head_dim), ("v_proj", v_head_dim)): + # BF16 has no .scale key, so this loop is a no-op there. Both + # k_proj and v_proj scales use head_dim since V is pre-padded. 
+ for proj, hd in (("k_proj", head_dim), ("v_proj", head_dim)): scale_key = f"layers.{layer_idx}.self_attn.{proj}.scale" if scale_key in neuron_state_dict: neuron_state_dict[scale_key] = _replicate_kv_weights_for_convert_to_mha( From 5b77f8148c7eb3776fd0db32c7868d5a5e62d39f Mon Sep 17 00:00:00 2001 From: whn09 Date: Sat, 25 Apr 2026 04:26:12 +0800 Subject: [PATCH 15/23] Revert "WIP: symmetric K/V head_dim via preprocess-side V padding" This reverts commit b9eea11ad8e513a6951e08b89d98e4fab3fc0c9b. --- .../preprocess_mimo_v2_flash_fp8.py | 54 +------------ .../MiMo-V2-Flash/src/modeling_mimo_v2.py | 78 ++++++++----------- 2 files changed, 36 insertions(+), 96 deletions(-) diff --git a/contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_flash_fp8.py b/contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_flash_fp8.py index 8a436c0f..2491aa18 100644 --- a/contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_flash_fp8.py +++ b/contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_flash_fp8.py @@ -195,8 +195,8 @@ def process_layer( out[f"{out_prefix}{name}.weight"] = t.detach().clone() # --- Attention: q/k/v/o are stored separately in Flash --- - # q/k: rescale to Neuron FP8 per-row. - for proj in ("q_proj", "k_proj"): + # q/k/v: rescale to Neuron FP8 per-row. + for proj in ("q_proj", "k_proj", "v_proj"): w = lazy.get(f"{prefix}self_attn.{proj}.weight") if w is None: continue @@ -206,56 +206,6 @@ def process_layer( if s2 is not None: out[f"{out_prefix}self_attn.{proj}.scale"] = s2 - # v_proj: MiMo-V2 has asymmetric head_dim (Q/K=192, V=128). NxDI's KV cache - # manager requires K and V to share the same head_dim; rather than pad V at - # runtime (every decode step) which is error-prone, we pre-pad V's output - # dimension to match head_dim here: each V head goes from 128 rows to 192 - # rows, with the extra 64 rows zero. Downstream Q @ V yields zeros in the - # padded dims, and the forward() slices the attention output back to - # v_head_dim=128 before o_proj. This lets the modeling code drop the - # `if v_head_dim < head_dim` runtime pad/slice logic entirely. - v_w_hf = lazy.get(f"{prefix}self_attn.v_proj.weight") - v_s_hf = lazy.get(f"{prefix}self_attn.v_proj.weight_scale_inv") - if v_w_hf is not None: - # First run the normal FP8-blockwise -> Neuron-per-row rescale on the - # native [num_kv_heads*128, hidden] weight. - v_w, v_s = _maybe_fp8_to_neuron_per_row(v_w_hf, v_s_hf) - # Now pad output dim from (num_kv_heads*128) to (num_kv_heads*192) by - # inserting 64 zero rows after every 128 real rows (preserve per-head - # head_dim layout). - num_kv_heads_swa = config.get("swa_num_key_value_heads", config["num_key_value_heads"]) - num_kv_heads_full = config["num_key_value_heads"] - num_kv_heads = num_kv_heads_swa if is_swa else num_kv_heads_full - v_head_dim = config.get("v_head_dim", 128) - head_dim = config.get("head_dim", 192) - assert v_w.shape[0] == num_kv_heads * v_head_dim, ( - f"v_proj out-dim {v_w.shape[0]} != num_kv_heads({num_kv_heads}) * " - f"v_head_dim({v_head_dim}) = {num_kv_heads * v_head_dim}" - ) - pad_per_head = head_dim - v_head_dim - hidden = v_w.shape[1] - # Reshape to [num_kv_heads, v_head_dim, hidden], pad to - # [num_kv_heads, head_dim, hidden], flatten back. 
- v_w_per_head = v_w.view(num_kv_heads, v_head_dim, hidden) - v_w_padded = torch.zeros( - num_kv_heads, head_dim, hidden, dtype=v_w_per_head.dtype, - ) - v_w_padded[:, :v_head_dim, :] = v_w_per_head - v_w_padded = v_w_padded.reshape(num_kv_heads * head_dim, hidden).contiguous() - out[f"{out_prefix}self_attn.v_proj.weight"] = v_w_padded - if v_s is not None: - # scale is per-row [out_rows, 1]; same pad-per-head rule. The - # padded rows have zero weight, so their scale value is irrelevant - # — use the min-clamp value (1e-10) to stay numerically neutral. - assert v_s.shape == (num_kv_heads * v_head_dim, 1), v_s.shape - v_s_per_head = v_s.view(num_kv_heads, v_head_dim, 1) - v_s_padded = torch.full( - (num_kv_heads, head_dim, 1), 1e-10, dtype=v_s.dtype, - ) - v_s_padded[:, :v_head_dim, :] = v_s_per_head - v_s_padded = v_s_padded.reshape(num_kv_heads * head_dim, 1).contiguous() - out[f"{out_prefix}self_attn.v_proj.scale"] = v_s_padded - # o_proj is listed in HF quantization_config.ignored_layers and ships as # BF16; on Neuron it binds to a plain RowParallelLinear (see # modeling_mimo_v2.py: self.o_proj = RowParallelLinear(...)), NOT a diff --git a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py index 96e78989..9caa1de0 100644 --- a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py +++ b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py @@ -289,33 +289,21 @@ def __init__( self.layer_idx = layer_idx self.is_sliding_window = is_sliding_window - # Select parameters based on attention type. - # - # IMPORTANT: we intentionally force attn_v_head_dim == attn_head_dim. - # The HF model has asymmetric head_dim (Q/K=192, V=128), but NxDI's - # KV cache manager requires K and V to share the same last-dim size. - # Runtime-padding V (was the previous approach) introduced subtle - # numerical drift under TP=64 + long decode that compounded into - # output collapse. Instead, the preprocess script pre-pads V's output - # dim with zeros so V is physically [num_kv_heads, head_dim, hidden]; - # the forward() slices the attention output back to the real - # v_head_dim before o_proj. See preprocess_mimo_v2_flash_fp8.py. + # Select parameters based on attention type if is_sliding_window: self.attn_head_dim = config.swa_head_dim + self.attn_v_head_dim = config.swa_v_head_dim self.attn_num_heads = config.swa_num_attention_heads self.attn_num_kv_heads = config.swa_num_key_value_heads - self.attn_real_v_head_dim = config.swa_v_head_dim rope_theta = getattr(config, 'swa_rope_theta', 10000.0) self.sliding_window_size = config.sliding_window else: self.attn_head_dim = config.head_dim + self.attn_v_head_dim = config.v_head_dim self.attn_num_heads = config.num_attention_heads self.attn_num_kv_heads = config.num_key_value_heads - self.attn_real_v_head_dim = config.v_head_dim rope_theta = config.rope_theta self.sliding_window_size = None - # Padded dim used for projection + KV cache: same as Q/K head_dim. - self.attn_v_head_dim = self.attn_head_dim # Calculate partial rotary dimensions self.partial_rotary_factor = config.partial_rotary_factor @@ -410,14 +398,11 @@ def _init_projections(self, config: MiMoV2InferenceConfig): k_num_heads = self.attn_num_kv_heads v_num_heads = self.attn_num_kv_heads - # Q/K use head_dim. 
V is pre-padded to head_dim in the preprocess - # script so K and V share one KV-cache shape; the attention output is - # sliced back to the real v_head_dim before o_proj, so o_proj still - # takes real_v_head_dim per head (matches the HF o_proj weight shape). + # Q/K use head_dim, V uses v_head_dim q_hidden_size = self.attn_num_heads * self.attn_head_dim k_hidden_size = k_num_heads * self.attn_head_dim v_hidden_size = v_num_heads * self.attn_v_head_dim - o_hidden_size = self.attn_num_heads * self.attn_real_v_head_dim + o_hidden_size = self.attn_num_heads * self.attn_v_head_dim if parallel_state.model_parallel_is_initialized(): tp_group = parallel_state.get_tensor_model_parallel_group() @@ -649,14 +634,18 @@ def forward( key_states_for_cache = key_states value_states_for_cache = value_states - # NOTE: No runtime V pad or KV head pad here. Previously this block - # padded V's head_dim from 128 to 192 at runtime so K and V could share - # the same KV cache shape, and separately repeated KV heads for layers - # that have fewer kv heads than the cache expects. Both are now - # handled in the preprocess script (V is physically padded to 192 - # with zero rows, and CONVERT_TO_MHA at TP=64 makes local_num_kv_heads - # == local_cache_kv_heads, so head-pad is a no-op). Runtime pad/slice - # on KV cache storage was the source of the long-decode output drift. + # WORKAROUND 1: Pad V from v_head_dim (128) to head_dim (192) for KV cache compatibility + if self.attn_v_head_dim < self.attn_head_dim: + pad_size = self.attn_head_dim - self.attn_v_head_dim + value_states_for_cache = F.pad(value_states_for_cache, (0, pad_size), value=0.0) + + # WORKAROUND 2: Pad KV heads if layer has fewer than cache expects + # Only needed when NOT using CONVERT_TO_MHA (standard GQA mode) + if not self.use_gqa_convert_to_mha and self.local_num_kv_heads < self.local_cache_kv_heads: + # Pad KV heads by repeating + repeat_factor = self.local_cache_kv_heads // self.local_num_kv_heads + key_states_for_cache = key_states_for_cache.repeat(1, repeat_factor, 1, 1) + value_states_for_cache = value_states_for_cache.repeat(1, repeat_factor, 1, 1) # Repeat KV heads for GQA (only needed without CONVERT_TO_MHA) # With CONVERT_TO_MHA, K/V already have num_attention_heads @@ -668,10 +657,22 @@ def forward( if is_token_gen: # Token generation: use decomposed attention with prior (cached) and active (current) KV # past_key_value[0] = cached K, shape [bsz, cache_kv_heads, kv_seq_len, head_dim] - # past_key_value[1] = cached V, shape [bsz, cache_kv_heads, kv_seq_len, head_dim] + # past_key_value[1] = cached V, shape [bsz, cache_kv_heads, kv_seq_len, head_dim] (padded) K_prior = past_key_value[0] V_prior = past_key_value[1] + # WORKAROUND 1: Slice KV heads if cache has more than layer needs + # Only needed when NOT using CONVERT_TO_MHA (standard GQA mode) + # With CONVERT_TO_MHA, cache and layer have same num_kv_heads + if not self.use_gqa_convert_to_mha and self.local_num_kv_heads < self.local_cache_kv_heads: + # Cache has repeated heads, just take the first local_num_kv_heads + K_prior = K_prior[:, :self.local_num_kv_heads, :, :] + V_prior = V_prior[:, :self.local_num_kv_heads, :, :] + + # WORKAROUND 2: Slice V_prior back to v_head_dim (128) from head_dim (192) + if self.attn_v_head_dim < self.attn_head_dim: + V_prior = V_prior[..., :self.attn_v_head_dim] + # Repeat cached KV for GQA (only needed without CONVERT_TO_MHA) # With CONVERT_TO_MHA, cached K/V already have num_attention_heads if num_key_value_groups > 1: @@ -795,17 +796,9 @@ def 
forward( # Apply attention to values attn_output = torch.matmul(attn_weights, value_states) - # Slice off the padded v_head_dim tail. V was physically padded from - # real_v_head_dim (128) to head_dim (192) in the preprocess script so - # it can share NxDI's KV cache shape with K; the padded dims hold - # zero weights, so the attention output in those dims is always zero - # and the slice is numerically exact. - if self.attn_real_v_head_dim < self.attn_v_head_dim: - attn_output = attn_output[..., :self.attn_real_v_head_dim] - # Reshape and project output attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.local_num_heads * self.attn_real_v_head_dim) + attn_output = attn_output.reshape(bsz, q_len, self.local_num_heads * self.attn_v_head_dim) # Context Parallelism: gather output across CP ranks BEFORE o_proj. # With SP enabled, o_proj scatters along seq dim. The input must have full S @@ -1130,21 +1123,18 @@ def convert_mimo_v2_hf_to_neuron_state_dict( ) if v_proj_key in neuron_state_dict: - # V is pre-padded to head_dim in the preprocess script so K/V - # can share KV cache shape; use head_dim (not v_head_dim) here. neuron_state_dict[v_proj_key] = _replicate_kv_weights_for_convert_to_mha( neuron_state_dict[v_proj_key], src_num_kv_heads, num_attention_heads, - head_dim, + v_head_dim, ) # FP8 path: replicate per-row scales ([src_heads*head_dim, 1]) in # lockstep with the weights. Without this the shard_weights step # rejects the scale shape mismatch (e.g. [12,1] vs expected [192,1]). - # BF16 has no .scale key, so this loop is a no-op there. Both - # k_proj and v_proj scales use head_dim since V is pre-padded. - for proj, hd in (("k_proj", head_dim), ("v_proj", head_dim)): + # BF16 has no .scale key, so this loop is a no-op there. + for proj, hd in (("k_proj", head_dim), ("v_proj", v_head_dim)): scale_key = f"layers.{layer_idx}.self_attn.{proj}.scale" if scale_key in neuron_state_dict: neuron_state_dict[scale_key] = _replicate_kv_weights_for_convert_to_mha( From 13d855e8ee15c63f8aa9c2310fc45df949a504a9 Mon Sep 17 00:00:00 2001 From: whn09 Date: Sat, 25 Apr 2026 07:12:50 +0800 Subject: [PATCH 16/23] smoke: isolate BASE_COMPILE_WORK_DIR per COMPILED_PATH NxDI's model builder stages HLOs into a global temp workdir controlled by the BASE_COMPILE_WORK_DIR env var (default "/tmp/nxd_model/"). Two concurrent compiles with different neuron_config hashes still share that directory and silently overwrite each other's model.hlo_module.pb files, which makes neuronx-cc exit 70 when the compiler tries to read what it just staged and finds a different graph. Setting BASE_COMPILE_WORK_DIR to a unique per-compile subdir (derived from COMPILED_PATH) lets FP8/BF16 smoke compiles run in parallel safely. Co-Authored-By: Claude Opus 4.7 --- .../perf_test/smoke_compile_mimo_v2_flash.py | 11 +++++++++++ .../perf_test/smoke_generate_mimo_v2_flash.py | 8 ++++++++ 2 files changed, 19 insertions(+) diff --git a/contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py b/contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py index 775dbaee..e48153e1 100755 --- a/contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py +++ b/contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py @@ -45,6 +45,17 @@ os.makedirs(COMPILED_PATH, exist_ok=True) +# NxDI's model builder uses a per-process temp workdir for HLO/NEFF staging +# (BASE_COMPILE_WORK_DIR, default "/tmp/nxd_model/"). 
If two compiles run in +# parallel with the same default, they silently overwrite each other's +# .hlo_module.pb files and one or both compilations crash with +# "neuronx-cc returned non-zero exit status 70". Pin the workdir to a +# unique per-COMPILED_PATH subdir to stay safe under any parallel invocation. +os.environ.setdefault( + "BASE_COMPILE_WORK_DIR", + os.path.join("/tmp/nxd_model", os.path.basename(COMPILED_PATH.rstrip("/"))), +) + def main(): from neuronx_distributed_inference.models.config import MoENeuronConfig diff --git a/contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py b/contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py index 82670ead..e03eb36f 100755 --- a/contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py +++ b/contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py @@ -39,6 +39,14 @@ ) MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "20")) +# Keep the per-compile BASE_COMPILE_WORK_DIR in sync with +# smoke_compile_mimo_v2_flash.py so load() under the same COMPILED_PATH +# doesn't collide with a concurrent compile or reuse a stale workdir. +os.environ.setdefault( + "BASE_COMPILE_WORK_DIR", + os.path.join("/tmp/nxd_model", os.path.basename(COMPILED_PATH.rstrip("/"))), +) + def main(): from transformers import AutoConfig, AutoTokenizer, GenerationConfig From 1d3916a6255fba06bc2ea9f99c4f4f086a40be33 Mon Sep 17 00:00:00 2001 From: whn09 Date: Sat, 25 Apr 2026 07:13:39 +0800 Subject: [PATCH 17/23] Router bias: use arange + bf16 to survive XLA constant folding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two subtle issues around e_score_correction_bias tracing: 1. dtype=torch.float32 was wrong. NxDI's checkpoint loader casts router bias from FP32 to BF16 at load time (see the "Found torch.float32 weights ... Will convert to torch.bfloat16" warnings in the smoke log). If the traced NEFF still expects FP32 after the cast, the LayoutTransformation silently drops the weight and leaves the trace-time init values live, so the bias at runtime is whatever we init here — not the checkpoint values. 2. torch.zeros(num_experts) was wrong. If every entry is identical, the `+ bias` op does not change topk's relative ordering, so XLA's constant-folding passes prove the add is a no-op and eliminate it, dropping the bias parameter from the HLO entirely. Checkpoint loading has nothing to bind to, and the real bias values never reach the device. Use torch.arange(num_experts, dtype=torch.bfloat16) instead: distinct per-expert values force the compiler to keep the add as a runtime op with a live parameter, and BF16 matches the loader's cast target. Also move the un-bias-affinity logic into scores_for_choice to match the MiMo HF reference and MiniMax-M2's working implementation. Source: Jim Burtoft's MiniMax-M2 fix (jimburtoft/neuronx-distributed-inference@49f8e164). This change alone does NOT fix the long-decode output collapse for Flash (BF16 produces coherent Chinese for the same 40-token chat prompt where FP8 collapses to "helpful helpful helpful"), but it is required for correctness once the underlying FP8 issue is found and fixed. 
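
A quick way to sanity-check the ordering argument outside XLA (sketch
only, not part of the patch; `scores` stands in for one token's sigmoid
affinities, and 256/8 are Flash's num_experts/top_k):

    import torch

    scores = torch.rand(256)       # per-expert affinities for one token
    uniform = torch.zeros(256)     # all-equal init (the old behaviour)

    # scores + uniform is elementwise identical to scores, and only the
    # topk *indices* feed the expert path, so the add is provably a
    # no-op; that is exactly what lets the compiler fold it (and the
    # bias parameter) out of the graph.
    assert torch.equal(torch.topk(scores + uniform, 8).indices,
                       torch.topk(scores, 8).indices)

    # arange init gives every expert a distinct offset; the add can now
    # change which experts win, so it must stay live in the HLO.
    distinct = torch.arange(256, dtype=torch.float32)
    print(torch.topk(scores + distinct, 8).indices)  # offsets dominate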
Co-Authored-By: Claude Opus 4.7 --- .../MiMo-V2-Flash/src/modeling_mimo_v2.py | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py index 9caa1de0..8a9d3885 100644 --- a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py +++ b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py @@ -1555,8 +1555,28 @@ def _apply_router_noaux_tc_fix(): def _patched_init(self, *args, **kwargs): original_init(self, *args, **kwargs) + # CRITICAL: dtype + init value both matter for XLA tracing. + # + # 1) dtype=torch.bfloat16: the NxDI checkpoint loader casts router + # bias from FP32 -> BF16 ("Found torch.float32 weights in + # checkpoint ... Will convert to torch.bfloat16"). If the traced + # NEFF expects FP32 but the checkpoint supplies BF16, the + # LayoutTransformation silently drops the weight and keeps the + # trace-time init values — so the bias at runtime is whatever + # we init here, not the checkpoint values. + # + # 2) init=arange, NOT zeros: if every entry is identical (all + # zeros), the `+ bias` op does not change the relative ordering + # of topk, so XLA's constant-folding passes can prove the add + # is a no-op and eliminate it entirely — dropping the bias + # parameter from the HLO. At that point checkpoint loading has + # nothing to bind to and the real bias is silently discarded. + # Using arange guarantees distinct per-expert values, forcing + # the compiler to keep the add as a runtime op with a live + # parameter. Source: Jim Burtoft's MiniMax-M2 fix notes + # (jimburtoft/neuronx-distributed-inference@49f8e164). self.e_score_correction_bias = nn.Parameter( - torch.zeros(self.num_experts, dtype=torch.float32), + torch.arange(self.num_experts, dtype=torch.bfloat16), requires_grad=False, ) @@ -1564,9 +1584,13 @@ def _patched_forward(self, hidden_states): router_logits = self.get_router_logits(hidden_states) expert_affinities = self.apply_activation_fn(router_logits) - selection_scores = expert_affinities.to(torch.float32) + \ - self.e_score_correction_bias - _, expert_index = torch.topk(selection_scores, self.top_k, dim=-1) + # MiMo (and MiniMax-M2) uses topk_method='noaux_tc': the bias is + # added ONLY for top-k selection, but the unbiased sigmoid scores + # remain as the expert-affinity weights passed to the experts. + scores_for_choice = ( + expert_affinities.float() + self.e_score_correction_bias.unsqueeze(0) + ) + _, expert_index = torch.topk(scores_for_choice, self.top_k, dim=-1) expert_affinities = expert_affinities.to(dtype=hidden_states.dtype) expert_index = expert_index.detach().to(dtype=torch.long) From 30a30d3a44032ba97523580deeeb712cd4f306f5 Mon Sep 17 00:00:00 2001 From: whn09 Date: Sat, 25 Apr 2026 09:01:39 +0800 Subject: [PATCH 18/23] smoke: pin outer ep_degree=1 (only moe_ep_degree varies) Previously the smoke scripts set ep_degree=MOE_EP at the top-level MoENeuronConfig *and* moe_ep_degree=MOE_EP inside the MoE config. The outer ep_degree is the full-model expert-parallel factor and multiplies world_size to tp_degree*ep_degree. On a 64-NC Trn2 with tp_degree=64 and ep_degree>1 that blew world_size past 64 (e.g. moe_ep=4 -> world_size=256), which: - produced a sharded checkpoint with 4x as many tp{N} files (tp0..tp255 instead of tp0..tp63) and 4x the on-disk size; - at runtime would try to address ranks beyond the 64 physical cores, failing load or OOM'ing. 
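
(Put differently: world_size = tp_degree * ep_degree must not exceed the
64 physical cores on this instance, so with tp_degree already at 64 the
only legal outer value is ep_degree=1; any expert parallelism has to
come from moe_ep_degree alone.)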
Pro's working vLLM configs only set moe_ep_degree (no ep_degree in the override), so NxDI's default ep_degree=1 keeps world_size=64. Pin ep_degree=1 in the smoke scripts so varying MOE_EP only affects the MoE-internal split, matching Pro's layout and keeping the sharded checkpoint sized correctly. Also generalize the MoE scale-expansion math in the state-dict converter to use moe_tp_degree (the shard dim for expert weights) rather than tp_degree. The two are the same at our BS=1 baseline (moe_tp=64) so the bug was latent, but manifests the moment you try any other moe_tp (e.g. 16 or 32): expansion produces scale tensors sized for the full TP instead of just the MoE TP, triggering a shape mismatch during shard_checkpoint. Co-Authored-By: Claude Opus 4.7 --- .../perf_test/smoke_compile_mimo_v2_flash.py | 11 ++++++++++- .../perf_test/smoke_generate_mimo_v2_flash.py | 4 +++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py b/contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py index e48153e1..e9c9b285 100755 --- a/contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py +++ b/contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py @@ -81,9 +81,18 @@ def main(): print(f"[smoke] STAGE={STAGE}") print("[smoke] Building MoENeuronConfig (quantized FP8 MoE, blockwise_symmetric)...") + # NOTE: ep_degree at the top level controls the OUTER (full model) + # expert-parallel factor, which multiplies world_size to + # tp_degree * ep_degree and duplicates non-MoE weights per replica. + # At world_size > 64 on a 64-NC Trn2, sharded weights grow accordingly + # (e.g. tp=64 + ep=4 -> 256 ranks -> 4x the sharded checkpoint size, + # and at runtime the model doesn't fit on the device). For MoE-only + # EP we want ep_degree=1 at the outer level and the per-MoE split + # controlled solely by moe_ep_degree (which Pro's working benches + # also do). Keep ep_degree=1 unconditionally. neuron_config = MoENeuronConfig( tp_degree=TP_DEGREE, - ep_degree=MOE_EP, + ep_degree=1, logical_nc_config=2, batch_size=BATCH_SIZE, max_batch_size=BATCH_SIZE, diff --git a/contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py b/contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py index e03eb36f..3fcb8678 100755 --- a/contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py +++ b/contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py @@ -73,9 +73,11 @@ def main(): print(f"[gen] COMPILED_PATH={COMPILED_PATH}") print(f"[gen] TP={TP_DEGREE}, SEQ={SEQ_LEN}, BS={BATCH_SIZE}") + # Outer ep_degree must match the compile-time value (kept at 1 so + # world_size = tp_degree; see smoke_compile_mimo_v2_flash.py comment). neuron_config = MoENeuronConfig( tp_degree=TP_DEGREE, - ep_degree=MOE_EP, + ep_degree=1, logical_nc_config=2, batch_size=BATCH_SIZE, max_batch_size=BATCH_SIZE, From d3c1c961b23ceacb8719a2b91a3eece05d2498b0 Mon Sep 17 00:00:00 2001 From: whn09 Date: Sat, 25 Apr 2026 09:01:50 +0800 Subject: [PATCH 19/23] MoE scale expansion: use moe_tp_degree, not tp_degree MoE expert weights shard along moe_tp_degree (the MoE-internal TP factor), which can differ from the top-level tp_degree. The MoE scale-expansion code in convert_mimo_v2_hf_to_neuron_state_dict was hard-coded to tp_degree, so it happened to work at the BS=1 baseline (moe_tp=tp=64) but produced wrongly-sized scale tensors the moment MOE_TP != TP_DEGREE. 
Symptom: RuntimeError: expected shape torch.Size([4, 32, 32]) for layers.1.mlp.expert_mlps.mlp_op.gate_up_proj.scale but found torch.Size([4, 32, 64]) Read moe_tp_degree from neuron_config (falling back to tp_degree for non-MoE configs) and use it as the expansion denominator. Co-Authored-By: Claude Opus 4.7 --- .../MiMo-V2-Flash/src/modeling_mimo_v2.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py index 8a9d3885..21e7e47e 100644 --- a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py +++ b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py @@ -1235,7 +1235,13 @@ def convert_mimo_v2_hf_to_neuron_state_dict( # 128-wide block genuinely share that block's scale. No-op when the # .scale keys are absent (BF16 path). if getattr(config.neuron_config, "quantized", False): - tp = config.neuron_config.tp_degree + # IMPORTANT: MoE expert weights are sharded by moe_tp_degree (not the + # top-level tp_degree — attention uses tp_degree, MoE can use a + # different split). At moe_tp=64 the per-rank intermediate is 32 (<128) + # so we had to expand the scale to make the shard layout match; at + # moe_tp=16 per-rank intermediate is 128 (>=128) and no expansion is + # needed. + moe_tp = getattr(config.neuron_config, "moe_tp_degree", None) or config.neuron_config.tp_degree for layer_idx in range(config.num_hidden_layers): if not config.layer_uses_moe[layer_idx]: continue @@ -1247,14 +1253,14 @@ def convert_mimo_v2_hf_to_neuron_state_dict( i_blocks = s.shape[1] h_blocks = s.shape[2] intermediate = i_blocks * 128 - i_per_rank = intermediate // tp # 32 @ TP=64 for IM=2048 + i_per_rank = intermediate // moe_tp if i_per_rank < 128: - ranks_per_block = 128 // i_per_rank # 4 @ TP=64 + ranks_per_block = 128 // i_per_rank s_exp = s.unsqueeze(2).expand(-1, -1, ranks_per_block, -1) s_exp = s_exp.reshape(s.shape[0], i_blocks * ranks_per_block, h_blocks) - assert s_exp.shape[1] == tp, ( + assert s_exp.shape[1] == moe_tp, ( f"down_proj.scale expansion produced {s_exp.shape[1]} rows, " - f"expected TP={tp}" + f"expected moe_tp={moe_tp}" ) neuron_state_dict[dp_key] = s_exp.contiguous() @@ -1262,10 +1268,10 @@ def convert_mimo_v2_hf_to_neuron_state_dict( # along last axis). Scale: [E, H_blocks, 2*I_blocks] stored as # [gate_half | up_half]. Module parameter has per-rank last-dim=1 # (via _apply_blockwise_scale_stride_fix patch forcing - # partition_stride=1), so the full scale must have last-dim=tp - # with gate entries 0..tp/2 and up entries tp/2..tp. Expand each - # half independently to preserve the gate/up boundary when NxD - # does `split(per_partition=2*I/tp, dim=-1)`. + # partition_stride=1), so the full scale must have last-dim=moe_tp + # with gate entries 0..moe_tp/2 and up entries moe_tp/2..moe_tp. + # Expand each half independently to preserve the gate/up boundary + # when NxD does `split(per_partition=2*I/moe_tp, dim=-1)`. 
gu_key = f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.gate_up_proj.scale" if gu_key in neuron_state_dict: s = neuron_state_dict[gu_key] @@ -1276,15 +1282,15 @@ def convert_mimo_v2_hf_to_neuron_state_dict( ) i_blocks = two_i_blocks // 2 intermediate = i_blocks * 128 - out_per_rank = (2 * intermediate) // tp # 64 @ TP=64 for IM=2048 + out_per_rank = (2 * intermediate) // moe_tp if out_per_rank < 128: - assert tp % 2 == 0, f"TP={tp} must be even for gate/up scale split" - ranks_per_half = tp // 2 # 32 @ TP=64 + assert moe_tp % 2 == 0, f"moe_tp={moe_tp} must be even for gate/up scale split" + ranks_per_half = moe_tp // 2 assert ranks_per_half % i_blocks == 0, ( f"ranks_per_half={ranks_per_half} must be divisible by " f"i_blocks={i_blocks}" ) - ranks_per_block = ranks_per_half // i_blocks # 2 @ TP=64 (i_blocks=16) + ranks_per_block = ranks_per_half // i_blocks gate_half = s[..., :i_blocks] # [E, H_blocks, i_blocks] up_half = s[..., i_blocks:] gate_exp = ( @@ -1298,9 +1304,9 @@ def convert_mimo_v2_hf_to_neuron_state_dict( .reshape(s.shape[0], h_blocks, ranks_per_half) ) s_exp = torch.cat([gate_exp, up_exp], dim=-1) - assert s_exp.shape[-1] == tp, ( + assert s_exp.shape[-1] == moe_tp, ( f"gate_up_proj.scale expansion produced {s_exp.shape[-1]} " - f"entries, expected TP={tp}" + f"entries, expected moe_tp={moe_tp}" ) neuron_state_dict[gu_key] = s_exp.contiguous() From b9aaa74b9efe43bddc89c32d78b829f023ff590d Mon Sep 17 00:00:00 2001 From: whn09 Date: Sat, 25 Apr 2026 09:47:08 +0800 Subject: [PATCH 20/23] smoke: default to moe_tp=1 / moe_ep=64 for correct FP8 output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit FP8 with moe_tp=64 (all TP ranks splitting MoE outputs) reduces each rank's expert intermediate slice to 32 rows — below the 128-row blockwise scale block — which collapses the per-rank scale to a singleton in NxDI's `_setup_for_scale`. With no per-channel scale granularity, MoE forward accumulates a per-layer drift that compounds across 47 MoE layers and lands as output collapse (long decodes degrade into "helpful helpful helpful" or similar repetition). Switching the compile/generate smoke defaults to moe_tp=1 / moe_ep=64 keeps every expert intact on a single rank (n_local_experts=4, no intra-expert TP shard), so the full per-channel FP8 scale survives. Verified on Trn2 TP=64 FP8: 40-token Chinese chat prompt produces coherent multi-sentence output instead of collapsing. Other FP8 ratios still mis-behave: moe_tp=32/ep=2 leaves down-proj per-rank intermediate at 64 (<128, still collapses), and moe_tp=16/ep=4 (per-rank gate_up=256/down=128) also gives gibberish on the same prompt. Only moe_tp=1/ep=64 — the only config that keeps both dims well above the 128 block boundary — gives correct output. COMPILED_PATH default also updated to the new directory so reruns don't accidentally reuse the old (broken) NEFF cache. 
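
For reference, the per-rank row arithmetic behind those observations,
using the IM=2048 expert intermediate that the in-code comments work
with and the 128-row blockwise scale block:

    moe_tp=64 -> down-proj rows/rank = 2048/64 = 32      (< 128, collapses)
    moe_tp=32 -> 2048/32 = 64                            (< 128, collapses)
    moe_tp=16 -> 2048/16 = 128 (gate_up 2*2048/16 = 256) (at the boundary)
    moe_tp=1  -> whole experts per rank, no row split    (scale intact)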
Co-Authored-By: Claude Opus 4.7 --- .../perf_test/smoke_compile_mimo_v2_flash.py | 14 +++++++++++--- .../perf_test/smoke_generate_mimo_v2_flash.py | 6 +++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py b/contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py index e9c9b285..1f448cf7 100755 --- a/contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py +++ b/contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py @@ -31,15 +31,23 @@ ) COMPILED_PATH = os.environ.get( "MIMO_V2_FLASH_COMPILED_PATH", - "/opt/dlami/nvme/compiled/mimo_v2_flash_tp64_ep1_fp8/", + "/opt/dlami/nvme/compiled/mimo_v2_flash_tp64_moetp1_ep64_fp8/", ) TP_DEGREE = int(os.environ.get("TP_DEGREE", "64")) SEQ_LEN = int(os.environ.get("SEQ_LEN", "1024")) BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1")) CTX_BATCH_SIZE = int(os.environ.get("CTX_BATCH_SIZE", "1")) -MOE_TP = int(os.environ.get("MOE_TP", "64")) -MOE_EP = int(os.environ.get("MOE_EP", "1")) +# Default to moe_tp=1 / moe_ep=64. Under FP8 + moe_tp=64 (our old default) +# each rank's MoE expert intermediate slice is 32 rows (<128, the scale +# block size), which collapses the per-rank scale to a singleton in +# NxDI's `_setup_for_scale` — losing per-channel FP8 scale granularity +# and producing a BF16-accumulator drift that compounds into output +# collapse after ~30 decode tokens. moe_tp=1/moe_ep=64 keeps every expert +# on a single rank (4 full experts per rank), so each expert's scale +# survives intact. Override via MOE_TP / MOE_EP env vars for other recipes. +MOE_TP = int(os.environ.get("MOE_TP", "1")) +MOE_EP = int(os.environ.get("MOE_EP", "64")) STAGE = os.environ.get("STAGE", "all").lower() diff --git a/contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py b/contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py index 3fcb8678..9f306370 100755 --- a/contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py +++ b/contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py @@ -21,7 +21,7 @@ ) COMPILED_PATH = os.environ.get( "MIMO_V2_FLASH_COMPILED_PATH", - "/opt/dlami/nvme/compiled/mimo_v2_flash_tp64_ep1_fp8/", + "/opt/dlami/nvme/compiled/mimo_v2_flash_tp64_moetp1_ep64_fp8/", ) # Must match smoke_compile_mimo_v2_flash.py exactly, else load() sees a @@ -30,8 +30,8 @@ SEQ_LEN = int(os.environ.get("SEQ_LEN", "1024")) BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1")) CTX_BATCH_SIZE = int(os.environ.get("CTX_BATCH_SIZE", "1")) -MOE_TP = int(os.environ.get("MOE_TP", "64")) -MOE_EP = int(os.environ.get("MOE_EP", "1")) +MOE_TP = int(os.environ.get("MOE_TP", "1")) +MOE_EP = int(os.environ.get("MOE_EP", "64")) PROMPT = os.environ.get( "MIMO_V2_FLASH_PROMPT", From 305279f907eaa41e07cf49c010f194df1d2dc521 Mon Sep 17 00:00:00 2001 From: whn09 Date: Sat, 25 Apr 2026 14:32:32 +0800 Subject: [PATCH 21/23] bench: rewrite for FP8 recipe with moe_tp=1/moe_ep=64 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous bench defaulted to the BF16 checkpoint and defined three configs (BS=1/32/128). The Config 1 FP8 recipe (moe_tp=64/moe_ep=1, BS=1) is now known to produce output collapse after ~30 decode tokens — NxDI's `_setup_for_scale` drops per-channel FP8 blockwise scale when per-rank weight is below the 128-row block, which for Flash happens at moe_tp=64. 
In addition, BS=1 combined with Expert Parallelism (moe_ep>1) hits NxDI's
`BS >= num_experts / top_k = 32` assertion during TKG HLO generation.

New bench:

- MODEL_PATH defaults to the Neuron-FP8 preprocessed checkpoint.
- COMMON_MIMO_CONFIG carries all FP8 quantization fields inline so every
  config inherits them (quantized=true, blockwise_symmetric, 128x128
  blocks, o_proj + embed/lm_head/norm/router held out).
- Config 1: BS=32, moe_tp=1, moe_ep=64 (smallest BS the FP8 path supports).
- Config 2: BS=128, moe_tp=1, moe_ep=64 (throughput-leaning).
- Drops sequence_parallel_enabled=true from COMMON (it interacts badly
  with our attention forward at generation time) and drops
  use_torch_block_wise=true (the FP8 path uses the native shard-on-block
  NKI kernel).

Numeric CC token / scratchpad page size tweaks that were specific to the
Pro benchmark are removed.

Co-Authored-By: Claude Opus 4.7
---
 .../perf_test/bench_mimo_v2_flash.sh          | 137 +++++++-----------
 1 file changed, 49 insertions(+), 88 deletions(-)

diff --git a/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh b/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh
index d3fb5910..e01fc3df 100755
--- a/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh
+++ b/contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh
@@ -1,45 +1,63 @@
 #!/bin/bash
 set -e
 
+# MiMo-V2-Flash FP8 vLLM benchmark on Trn2.
+#
+# Requires a Neuron-FP8 preprocessed checkpoint (see
+# `src/conversion_script/preprocess_mimo_v2_flash_fp8.py`). The configs below
+# all use moe_tp_degree=1 / moe_ep_degree=64 (experts sharded by expert
+# parallelism only, no intra-expert TP split) because moe_tp_degree=64 collapses
+# the per-rank FP8 blockwise scale to a singleton — per-rank expert
+# intermediate is 32 rows, below the 128-row blockwise block, so
+# NxDI's `_setup_for_scale` drops per-channel scale granularity. The resulting
+# drift compounds across 47 MoE layers and gives repetition / output collapse.
+# Using moe_ep_degree=64 keeps all of each expert's weight + scale on one rank
+# (4 experts per rank), which preserves the blockwise scale intact.
+#
+# NxDI's TKG path refuses Expert Parallelism with BS < num_experts/top_k
+# (256 / 8 = 32 for Flash), so the smallest working batch size here is 32.
+# BS=1 is not currently supported on the FP8 path for this model on Trn2;
+# for single-stream latency, use the BF16 checkpoint with the old bench
+# recipe (`moe_tp_degree=64, moe_ep_degree=1, batch_size=1`).
+
 source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate
 
-MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2-Flash-BF16}"
+MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2-Flash-Neuron-FP8}"
 
 # The NxDI contrib MiMo-V2-Flash modeling code is registered into vLLM /
 # NxDI lookup tables by vllm-neuron's register() hook using this env var.
 # Default to this contrib package's own src/ relative to the script.
 : "${NXDI_CONTRIB_MIMO_V2_FLASH_SRC:=$(cd "$(dirname "$0")/.." && pwd)/src}"
 export NXDI_CONTRIB_MIMO_V2_FLASH_SRC
 
+# First-time Flash FP8 compile takes 30-60 minutes; extend vLLM's
+# engine-ready timeout so the server is not torn down mid-compile.
+export VLLM_ENGINE_READY_TIMEOUT_S=7200
+
 PORT=8000
 RESULTS_DIR="/tmp/bench_results/mimo_v2_flash"
 mkdir -p "$RESULTS_DIR"
 
-# Common neuron config shared across all MiMo configs
+# Common neuron config shared across all MiMo-V2-Flash FP8 configs.
+# save_sharded_checkpoint=true persists per-rank sharded weights to +# /weights/tp{N}_sharded_checkpoint.safetensors during compile; +# load() then reads those directly (~30s) instead of re-sharding the entire +# checkpoint on every vllm-neuron startup (~10+ min). COMMON_MIMO_CONFIG='"tp_degree": 64, "logical_nc_config": 2, "fused_qkv": false, - "flash_decoding_enabled": false, - "sequence_parallel_enabled": true, - "qkv_kernel_enabled": false, - "qkv_nki_kernel_enabled": false, - "qkv_cte_nki_kernel_fuse_rope": false, - "attn_kernel_enabled": false, - "strided_context_parallel_kernel_enabled": false, + "sequence_parallel_enabled": false, "glu_mlp": true, "normalize_top_k_affinities": true, "save_sharded_checkpoint": true, "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, - "blockwise_matmul_config": {"use_torch_block_wise": true}' -# save_sharded_checkpoint=true persists per-rank sharded weights to -# /weights/tp{N}_sharded_checkpoint.safetensors during compile; -# load() then reads those directly (~55s) instead of re-sharding the entire -# checkpoint on every vllm-neuron startup (~10+ min). -# NOTE: use_torch_block_wise=true forces MoE blockwise to use the PyTorch -# reference implementation. The NKI kernel path pulls -# neuronxcc.nki._private.blockwise_mm.blockwise_mm_baseline_shard_hidden -# which is missing from Neuron SDK 2.29's public path; skipping it here is -# the practical way to get the bench running on current stacks. Remove -# this once the NKI kernel is promoted back to the stable path. + "quantized": true, + "quantized_checkpoints_path": "'"$MODEL_PATH"'", + "quantization_dtype": "f8e4m3", + "quantization_type": "blockwise_symmetric", + "quantization_block_axis": [1, 2], + "quantization_block_size": [128, 128], + "modules_to_not_convert": ["embed_tokens", "lm_head", "norm", "router", "o_proj"], + "blockwise_matmul_config": {"use_shard_on_block_dynamic_while": true, "block_sharding_strategy": "PING_PONG"}' # Helper: wait for vLLM server to be ready. First-time compilation of a # 256-expert MoE model takes 30-90 minutes, so we poll for up to 2 hours. 
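For reference, the readiness poll amounts to something like the following (a minimal sketch; the script's actual `wait_for_server` body lies outside this hunk, and the `/health` endpoint and 2-hour budget are assumptions carried over from the comment above):

```bash
wait_for_server() {
    # Poll the vLLM health endpoint until it answers or the budget expires.
    local deadline=$(( $(date +%s) + 7200 ))   # 2 hours for first-time compile
    until curl -sf "http://localhost:${PORT}/health" > /dev/null; do
        if [ "$(date +%s)" -ge "$deadline" ]; then
            echo "vLLM server not ready after 2h; check compile logs" >&2
            return 1
        fi
        sleep 30
    done
    echo "vLLM server is ready."
}
```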
@@ -107,57 +125,18 @@ sanity_check() { } echo "==========================================" -echo "MiMo-V2-Flash Performance Benchmark" +echo "MiMo-V2-Flash FP8 Performance Benchmark" echo "==========================================" echo "Model: $MODEL_PATH" echo "Results: $RESULTS_DIR" echo "" ############################################################################### -# Config 1: BS=1, TP=64/EP=1, non-CB (baseline latency) -############################################################################### -CONFIG_NAME="bs1_tp64_ep1" -echo "--- Config 1: BS=1, TP=64/EP=1, non-CB (baseline) ---" - -python3 -m vllm.entrypoints.openai.api_server \ - --model "$MODEL_PATH" \ - --tokenizer "$MODEL_PATH" \ - --tensor-parallel-size 64 \ - --max-model-len 1024 \ - --max-num-seqs 1 \ - --no-enable-chunked-prefill \ - --no-enable-prefix-caching \ - --port $PORT \ - --trust_remote_code \ - --additional-config '{ - "override_neuron_config": { - '"$COMMON_MIMO_CONFIG"', - "moe_tp_degree": 64, - "moe_ep_degree": 1, - "batch_size": 1, - "ctx_batch_size": 1, - "tkg_batch_size": 1, - "max_context_length": 1024, - "seq_len": 1024, - "is_continuous_batching": false, - "enable_bucketing": false, - "async_mode": true, - "on_device_sampling_config": { - "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 - } - } - }' & - -wait_for_server -sanity_check -run_bench "$CONFIG_NAME" 1 16 -stop_server - -############################################################################### -# Config 2: BS=32, TP=1/EP=64, CB + optimizations +# Config 1: BS=32, TP=64 + moe_tp=1/moe_ep=64, CB + bucketing (smallest BS +# that satisfies NxDI's Expert-Parallel BS >= num_experts/top_k requirement). ############################################################################### -CONFIG_NAME="bs32_tp1_ep64_opt" -echo "--- Config 2: BS=32, TP=1/EP=64, CB + optimizations ---" +CONFIG_NAME="bs32_tp64_moetp1_ep64" +echo "--- Config 1: BS=32, moe_tp=1/moe_ep=64, CB + bucketing ---" python3 -m vllm.entrypoints.openai.api_server \ --model "$MODEL_PATH" \ @@ -186,16 +165,7 @@ python3 -m vllm.entrypoints.openai.api_server \ "async_mode": true, "on_device_sampling_config": { "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 - }, - "use_index_calc_kernel": true, - "moe_mask_padded_tokens": true, - "blockwise_matmul_config": { - "use_torch_block_wise": true, - "use_shard_on_intermediate_dynamic_while": true, - "skip_dma_token": true - }, - "disable_numeric_cc_token": true, - "scratchpad_page_size": 1024 + } } }' & @@ -207,10 +177,10 @@ run_bench "$CONFIG_NAME" 32 128 stop_server ############################################################################### -# Config 3: BS=128, TP=1/EP=64, CB + optimizations +# Config 2: BS=128, TP=64 + moe_tp=1/moe_ep=64, CB + bucketing (throughput). 
############################################################################### -CONFIG_NAME="bs128_tp1_ep64_opt" -echo "--- Config 3: BS=128, TP=1/EP=64, CB + optimizations ---" +CONFIG_NAME="bs128_tp64_moetp1_ep64" +echo "--- Config 2: BS=128, moe_tp=1/moe_ep=64, CB + bucketing ---" python3 -m vllm.entrypoints.openai.api_server \ --model "$MODEL_PATH" \ @@ -239,16 +209,7 @@ python3 -m vllm.entrypoints.openai.api_server \ "async_mode": true, "on_device_sampling_config": { "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 - }, - "use_index_calc_kernel": true, - "moe_mask_padded_tokens": true, - "blockwise_matmul_config": { - "use_torch_block_wise": true, - "use_shard_on_intermediate_dynamic_while": true, - "skip_dma_token": true - }, - "disable_numeric_cc_token": true, - "scratchpad_page_size": 1024 + } } }' & @@ -261,7 +222,7 @@ run_bench "$CONFIG_NAME" 128 512 stop_server echo "==========================================" -echo "MiMo-V2-Flash benchmarks complete!" +echo "MiMo-V2-Flash FP8 benchmarks complete!" echo "Results saved to: $RESULTS_DIR" echo "==========================================" ls -la "$RESULTS_DIR" From 21aca4826d4fb51924caa53b5aecb5448825e95b Mon Sep 17 00:00:00 2001 From: whn09 Date: Sat, 25 Apr 2026 14:32:47 +0800 Subject: [PATCH 22/23] README: document the FP8 recipe and its constraints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The FP8 inference path is now the recommended path on Trn2, but it has several non-obvious configuration requirements. Capture them here so users don't have to rediscover them: - **moe_tp_degree=1, moe_ep_degree=64 is the only working FP8 ratio.** moe_tp=64 collapses the per-rank blockwise scale to a singleton (per-rank intermediate = 32 rows < 128-row block), which compounds into output collapse. moe_tp=32/16 have been empirically verified to still produce gibberish. - **batch_size must be >= num_experts / top_k = 32** on the FP8 path (NxDI refuses EP>1 at TKG under that threshold). - **Outer ep_degree must stay 1** — it multiplies world_size, and world_size > tp_degree overflows the physical NC count. Other updates: - Recommend the new streaming preprocess script (`preprocess_mimo_v2_flash_fp8.py`) as the default, demote the FP8->BF16 dequant path to "fallback". - Update the Python usage example to the current NeuronConfig surface (quantized=True + blockwise_symmetric, AutoConfig hf_config plumbing, modules_to_not_convert list, etc.). - Replace the BS=1 vLLM serving example with the working BS=32 FP8 recipe; mention NEURON_COMPILED_ARTIFACTS for isolating compile dirs. - Correct the sliding_window value (128, not 32,768) and the expert intermediate size (2048, not 1536). - Leave the existing BF16 performance table in place as reference while FP8 benchmark numbers are still being collected. 
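The three constraints condense into a few checks (an illustrative helper, not code from this patch; default values follow Flash's config):

```python
def validate_flash_fp8_recipe(tp_degree, ep_degree, moe_tp, moe_ep,
                              batch_size, num_experts=256, top_k=8,
                              moe_intermediate=2048, scale_block=128):
    """Illustrative checks mirroring the documented FP8 constraints."""
    # Intra-expert TP must not shard below one 128-row blockwise scale block;
    # in practice only moe_tp=1 / moe_ep=64 is verified correct on Flash.
    assert moe_intermediate // moe_tp >= scale_block
    assert (moe_tp, moe_ep) == (1, 64)
    # NxDI's TKG path refuses EP > 1 when BS < num_experts / top_k.
    assert batch_size >= num_experts // top_k  # 256 / 8 = 32
    # Outer EP multiplies world_size; tp_degree * ep_degree must fit 64 NCs.
    assert ep_degree == 1 and tp_degree * ep_degree <= 64
```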
Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2-Flash/README.md | 208 +++++++++++++++++-------- 1 file changed, 141 insertions(+), 67 deletions(-) diff --git a/contrib/models/MiMo-V2-Flash/README.md b/contrib/models/MiMo-V2-Flash/README.md index 45e7b205..ae40c700 100644 --- a/contrib/models/MiMo-V2-Flash/README.md +++ b/contrib/models/MiMo-V2-Flash/README.md @@ -21,35 +21,46 @@ NeuronX Distributed Inference implementation of [XiaomiMiMo/MiMo-V2-Flash](https | Q/K Head Dim | 192 | | V Head Dim | 128 | | Experts | 256 (top-8 routing) | -| Expert Intermediate | 1536 | +| Expert Intermediate | 2048 | | Vocab Size | 151,936 | -| RoPE | Partial (34% of dims), theta=5M | -| Sliding Window | 32,768 | +| RoPE | Partial (34% of dims), theta=5M (full), 10K (SWA) | +| Sliding Window | 128 | | Max Position | 262,144 | -| Total Params | ~143B (FP8) / ~286B (BF16) | +| Total Params | ~143B (FP8 native) / ~286B (BF16 upcast) | Key features: - **Hybrid Attention**: 9 full attention layers (0, 5, 11, 17, 23, 29, 35, 41, 47) + 39 sliding window layers - **Asymmetric Head Dims**: Q/K use 192, V uses 128 (fused_qkv not supported) -- **Attention Sink Bias**: Learnable per-head bias in sliding window layers -- **Sigmoid Router**: For MoE expert selection -- **Expert Parallelism**: Supports EP=64 for prefill with hybrid sharding (EP=1 for token generation) +- **Attention Sink Bias**: Learnable per-head bias on sliding window layers only +- **Sigmoid Router + noaux_tc**: e_score_correction_bias added to sigmoid scores before top-k selection; unbiased scores become affinity weights +- **attention_value_scale = 0.707**: HF Flash multiplies `value_states` by this before the attention softmax × V (NOT applied to attn_output); the NxDI model matches ## Prerequisites -- **Instance**: trn2.48xlarge (32 NeuronCores, logical_nc_config=2 -> 64 logical cores) -- **Weights**: BF16 format (convert from FP8 using `conversion_script/preprocess_mimo_v2_fp8.py`) +- **Instance**: trn2.48xlarge (32 NeuronCores, logical_nc_config=2 → 64 logical cores) +- **Weights**: Neuron FP8 (produced by `conversion_script/preprocess_mimo_v2_flash_fp8.py`) for the recommended FP8 path, or BF16 (produced by `conversion_script/preprocess_mimo_v2_fp8.py`) for a BF16 fallback. -## FP8 to BF16 Conversion +## Checkpoint Preparation -The original model uses block-wise FP8 quantization incompatible with Neuron FP8. Convert to BF16: +The HuggingFace checkpoint ships as block-wise OCP FP8 (E4M3, ±448 range), which is not directly compatible with Neuron FP8 (IEEE-754 E4M3, ±240 range). Two preprocess scripts are provided: + +### Recommended: FP8 → Neuron-FP8 (streaming) + +`src/conversion_script/preprocess_mimo_v2_flash_fp8.py` performs a per-layer streaming rescale from OCP FP8 to Neuron FP8 (per-row scales for attention Q/K/V and layer-0 dense MLP; blockwise scales for MoE experts). `o_proj` is listed in HF's `quantization_config.ignored_layers` and is kept BF16 on the Neuron side (it binds to a plain `RowParallelLinear`, not `QuantizedRowParallel`). Output is ~310 GB across 48 per-layer safetensors shards. 
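The heart of the rescale is a range remap between the two E4M3 variants; a minimal per-tensor sketch of the idea (the script's actual streaming I/O and per-row vs. blockwise scale handling are more involved, and the dtype plumbing shown here is an assumption):

```python
import torch

OCP_E4M3_MAX = 448.0      # range of the HF checkpoint's FP8 (OCP E4M3)
NEURON_E4M3_MAX = 240.0   # range of Neuron's IEEE-754-style E4M3

def rescale_ocp_to_neuron(w_fp8: torch.Tensor, scale: torch.Tensor):
    """Shrink FP8 values into ±240 and fold the ratio into the dequant
    scale, so the dequantized product (w * scale) is unchanged."""
    ratio = NEURON_E4M3_MAX / OCP_E4M3_MAX
    w = (w_fp8.to(torch.float32) * ratio).to(torch.float8_e4m3fn)
    return w, scale / ratio
```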
```bash -python contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_fp8.py \ - --input-dir /path/to/MiMo-V2-Flash \ - --output-dir /path/to/MiMo-V2-Flash-BF16 +python contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_flash_fp8.py \ + --hf_model_path /path/to/MiMo-V2-Flash \ + --save_path /path/to/MiMo-V2-Flash-Neuron-FP8 \ + --tp_degree 64 ``` +Peak RAM during preprocessing is ~24 GB; total runtime ~20 minutes on a trn2.48xlarge instance. + +### Fallback: FP8 → BF16 + +`src/conversion_script/preprocess_mimo_v2_fp8.py` dequantizes the entire checkpoint to BF16. Output is ~290 GB; BF16 is numerically equivalent to the published HF FP8 weights and is useful as a known-good reference. Throughput is ~2× worse than the FP8 path because every attention/MLP matmul operates on full BF16 weights. + ## Usage ```python @@ -60,34 +71,59 @@ from pathlib import Path sys.path.insert(0, str(Path("contrib/models/MiMo-V2-Flash/src").resolve())) import torch -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer from neuronx_distributed_inference.models.config import MoENeuronConfig, OnDeviceSamplingConfig from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config, HuggingFaceGenerationAdapter from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM, MiMoV2InferenceConfig -model_path = "/path/to/MiMo-V2-Flash-BF16/" +model_path = "/path/to/MiMo-V2-Flash-Neuron-FP8/" compiled_path = "/path/to/compiled/" +# Recommended FP8 recipe: +# moe_tp_degree = 1, moe_ep_degree = 64 +# See "FP8 Configuration Notes" below for why other moe_tp/ep ratios collapse. neuron_config = MoENeuronConfig( tp_degree=64, + ep_degree=1, # keep outer EP = 1; only MoE-internal EP varies moe_tp_degree=1, moe_ep_degree=64, - batch_size=1, - seq_len=512, - max_context_length=128, + batch_size=32, # must be >= num_experts / top_k = 256 / 8 = 32 + max_batch_size=32, + ctx_batch_size=1, + tkg_batch_size=32, + seq_len=1024, + n_active_tokens=128, torch_dtype=torch.bfloat16, logical_nc_config=2, - sequence_parallel_enabled=True, - fused_qkv=False, # Required: asymmetric Q/K vs V dims + capacity_factor=1.0, + glu_mlp=True, + fused_qkv=False, # required: asymmetric Q/K (192) vs V (128) head dims + router_config={"act_fn": "sigmoid", "dtype": "float32"}, + blockwise_matmul_config={ + "use_shard_on_block_dynamic_while": True, + "block_sharding_strategy": "PING_PONG", + }, + save_sharded_checkpoint=True, + quantized=True, + quantized_checkpoints_path=model_path, + quantization_dtype="f8e4m3", + quantization_type="blockwise_symmetric", + quantization_block_axis=[1, 2], + quantization_block_size=[128, 128], + modules_to_not_convert=[ + "embed_tokens", "lm_head", "norm", "router", "o_proj", + ], on_device_sampling_config=OnDeviceSamplingConfig( - do_sample=True, temperature=0.6, top_k=20, top_p=0.95 + do_sample=True, temperature=0.6, top_k=20, top_p=0.95, ), - router_config={act_fn: sigmoid}, ) +# trust_remote_code is required by Flash's HF config; pre-load via AutoConfig +# and pass to NxDI so load_pretrained_config does not re-load without the flag. 
+hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) config = MiMoV2InferenceConfig( - neuron_config, load_config=load_pretrained_config(model_path) + neuron_config, load_config=load_pretrained_config(hf_config=hf_config), ) model = NeuronMiMoV2ForCausalLM(model_path, config) @@ -95,46 +131,68 @@ model.compile(compiled_path) model.load(compiled_path) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -adapter = HuggingFaceGenerationAdapter(model, tokenizer) -output = adapter.generate("Hello, how are you?", max_new_tokens=128) +adapter = HuggingFaceGenerationAdapter(model) +inputs = tokenizer(["Hello, how are you?"] * 32, return_tensors="pt", padding=True) +output = adapter.generate( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + max_new_tokens=128, +) ``` +For a minimal end-to-end smoke test that bypasses vLLM, see: + +- `perf_test/smoke_compile_mimo_v2_flash.py` — compile + load (STAGE=instantiate|compile|load|all, DRY_RUN, SKIP_WARMUP) +- `perf_test/smoke_generate_mimo_v2_flash.py` — 20-token generation via HuggingFaceGenerationAdapter + +Both default to the recommended FP8 recipe (`moe_tp=1`, `moe_ep=64`). + +## FP8 Configuration Notes + +### moe_tp_degree = 1, moe_ep_degree = 64 + +**Why**: at `moe_tp_degree=64` each rank owns 1/64 of the intermediate dim, which for Flash (MoE intermediate = 2048) is 32 rows — **below the 128-row blockwise scale block**. NxDI's `_setup_for_scale` detects `weight_shape[axis] < block_size` and collapses the per-rank scale dim to 1, losing per-channel FP8 scale granularity. The resulting drift compounds across Flash's 47 MoE layers and manifests as output collapse ("helpful helpful helpful ...") after roughly 30 decode tokens. + +`moe_tp_degree=1, moe_ep_degree=64` keeps each expert's weights and blockwise scales intact on a single rank (4 experts per rank), which preserves per-channel scale and produces correct output even on long multi-turn prompts. + +Intermediate ratios (`moe_tp=32/ep=2` or `moe_tp=16/ep=4`) have been empirically tested and still produce gibberish, so this is the only currently-supported moe_tp/ep combination for Flash FP8. + +### batch_size >= 32 + +NxDI's TKG (token generation) path refuses Expert Parallelism when `batch_size < num_experts / top_k`. For Flash that is 256 / 8 = 32, so the smallest working BS on the FP8 path is 32. BS=1 latency demos are not currently possible on FP8; use the BF16 checkpoint with `moe_tp=64, moe_ep=1, batch_size=1` for single-stream latency measurements. + +### outer ep_degree = 1 + +`MoENeuronConfig.ep_degree` is the **full-model** expert-parallel factor. Setting it to anything > 1 multiplies `world_size` to `tp_degree * ep_degree`, which on a 64-NC Trn2 overflows the device (ranks beyond 63 have no backing hardware, sharded-checkpoint size grows linearly, and load fails). The MoE-internal expert parallelism is controlled exclusively by `moe_ep_degree` — keep `ep_degree=1` at the outer level. + ## vLLM Integration -MiMo-V2-Flash can be served via [vllm-neuron](https://github.com/aws-neuron/vllm-neuron). A patch is required to add MiMo architecture support. +MiMo-V2-Flash can be served via [vllm-neuron](https://github.com/aws-neuron/vllm-neuron). A contrib registration patch is required to plug the NxDI modeling code into vllm-neuron's lookup tables. ### Setup ```bash -# 1. 
Clone vllm-project/vllm-neuron at release-0.5.0 -git clone --branch release-0.5.0 https://github.com/vllm-project/vllm-neuron.git /tmp/vllm-neuron - -# 2. Apply the contrib registration patch -cd /tmp/vllm-neuron -git apply /path/to/neuronx-distributed-inference/contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch - -# 3. Install -pip install --extra-index-url=https://pip.repos.neuron.amazonaws.com -e . +# The setup script clones vllm-project/vllm-neuron at release-0.5.0, applies +# the contrib registration patch, installs it editable, and downloads Flash +# weights (BF16 by default; set MIMO_V2_FLASH_PATH to override). +bash contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh ``` -Or let `perf_test/0_setup.sh` do steps 1-3 plus weight download. - -The patch is 40 lines and only touches `vllm_neuron/__init__.py`. It adds a -`_register_contrib_models()` hook that, when `NXDI_CONTRIB_MIMO_V2_FLASH_SRC` -is set, registers `NeuronMiMoV2ForCausalLM` into NxDI's `MODEL_TYPES` under -the key `mimo_v2_flash` **and** registers the `MiMoV2FlashForCausalLM` -architecture into vLLM's `ModelRegistry`. No upstream vLLM or NxDI source -is modified. +The patch (`perf_test/vllm-neuron-patch.patch`) is 40 lines and only touches `vllm_neuron/__init__.py`. It adds a `_register_contrib_models()` hook that, when `NXDI_CONTRIB_MIMO_V2_FLASH_SRC` is set, registers `NeuronMiMoV2ForCausalLM` into NxDI's `MODEL_TYPES` under the key `mimo_v2_flash` **and** registers the `MiMoV2FlashForCausalLM` architecture into vLLM's `ModelRegistry`. No upstream vLLM or NxDI source is modified. -### Serving +### Serving (FP8, recommended) ```bash -# The contrib src/ must be reachable so the plugin hook can import it. export NXDI_CONTRIB_MIMO_V2_FLASH_SRC=/path/to/neuronx-distributed-inference/contrib/models/MiMo-V2-Flash/src -export MIMO_V2_FLASH_PATH=/path/to/MiMo-V2-Flash-BF16 +export MIMO_V2_FLASH_PATH=/path/to/MiMo-V2-Flash-Neuron-FP8 +# First-time compile of Flash's 256-expert MoE takes 30-60 minutes. +export VLLM_ENGINE_READY_TIMEOUT_S=7200 +# Optional: isolate compile cache per config so parallel Flash/Pro/etc. compiles +# don't race on the default /var/tmp/neuron-compile-cache lock files. 
+export NEURON_COMPILED_ARTIFACTS=/path/to/compiled/mimo_v2_flash_bs32_moetp1_ep64_fp8 python3 -m vllm.entrypoints.openai.api_server \ - --model /path/to/MiMo-V2-Flash-BF16 \ + --model "$MIMO_V2_FLASH_PATH" \ --tensor-parallel-size 64 \ --max-model-len 1024 \ --max-num-seqs 32 \ @@ -146,11 +204,19 @@ python3 -m vllm.entrypoints.openai.api_server \ "tp_degree": 64, "logical_nc_config": 2, "fused_qkv": false, - "flash_decoding_enabled": false, - "sequence_parallel_enabled": true, + "sequence_parallel_enabled": false, "glu_mlp": true, "normalize_top_k_affinities": true, + "save_sharded_checkpoint": true, "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, + "quantized": true, + "quantized_checkpoints_path": "/path/to/MiMo-V2-Flash-Neuron-FP8", + "quantization_dtype": "f8e4m3", + "quantization_type": "blockwise_symmetric", + "quantization_block_axis": [1, 2], + "quantization_block_size": [128, 128], + "modules_to_not_convert": ["embed_tokens", "lm_head", "norm", "router", "o_proj"], + "blockwise_matmul_config": {"use_shard_on_block_dynamic_while": true, "block_sharding_strategy": "PING_PONG"}, "moe_tp_degree": 1, "moe_ep_degree": 64, "batch_size": 32, @@ -160,6 +226,8 @@ python3 -m vllm.entrypoints.openai.api_server \ "seq_len": 1024, "is_continuous_batching": true, "enable_bucketing": true, + "context_encoding_buckets": [1024], + "token_generation_buckets": [1024], "async_mode": true, "on_device_sampling_config": { "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 @@ -168,17 +236,20 @@ python3 -m vllm.entrypoints.openai.api_server \ }' ``` -### Key vLLM Patch Changes +See `perf_test/bench_mimo_v2_flash.sh` for the full benchmark recipe at BS=32 and BS=128. -The patch (`contrib/models/MiMo-V2-Flash/perf_test/vllm-neuron-patch.patch`) modifies vllm-neuron to: -- Map MiMo architecture to Qwen2 model loader (MiMo is Qwen2-based) -- Pass `hf_config` from vLLM to NxDI (required for `trust_remote_code` models) -- Replace `AutoModelForCausalLM.from_pretrained` with `snapshot_download` for model loading +### vllm-neuron patch summary -See `contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh` for full benchmark configurations with BS=1/32/128. Run `contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh` first to install vllm-neuron and fetch weights. +The patch is applied to vllm-neuron 0.5.0 and: + +- Maps the `MiMoV2FlashForCausalLM` architecture to Flash's model loader (reusing the Qwen2-family loader path, which Flash's tokenizer inherits from). +- Passes `hf_config` from vLLM into `load_pretrained_config` so NxDI does not re-load the config without `trust_remote_code=True`. +- Replaces vllm-neuron's internal `AutoModelForCausalLM.from_pretrained` call with `huggingface_hub.snapshot_download`, which is the only path that works for `trust_remote_code=True` models when no GPU is available for HF's CUDA-gated FP8 quantizer. ## Performance +> These numbers are from the earlier BF16 recipe (pre-FP8 rollout). FP8 numbers will be added once a stable bench run completes on the new recipe; preliminary single-stream qualitative tests show fluent multi-sentence output on long Chinese chat prompts with `moe_tp=1, moe_ep=64, batch_size=32`. + ### Standalone NxDI (trn2.48xlarge, BF16, TP=64, EP=64) | Batch Size | Throughput (tok/s) | @@ -197,14 +268,14 @@ Input/output: 900/90 tokens (random dataset) | 16 | 224.57 | 64.95 | 570 | | 32 | 302.61 | 90.23 | 1351 | -> **Note:** Large MoE models like MiMo-V2-Flash require extended engine startup time (~47 min for compile+load). 
Set `VLLM_ENGINE_READY_TIMEOUT_S=3600` before launching the vLLM server. +> **Compile time:** the first Flash compile on SDK 2.29 is ~30-60 minutes for the TKG NEFF and similar for the CTE NEFF. Subsequent runs with the same `override_neuron_config` hit the neuronx-cc cache and start in ~1-2 minutes. `save_sharded_checkpoint=true` additionally persists per-rank FP8 shards under `/weights/`, letting future `load()` calls skip the ~10-minute shard_checkpoint pass. ## Compatibility Matrix -| Instance/Version | 2.22+ (PyTorch 2.9) | 2.21 and earlier | -|------------------|---------------------|------------------| +| Instance | Neuron SDK 2.29+ (PyTorch 2.9) | 2.21 and earlier | +|----------|--------------------------------|------------------| | Trn2 (trn2.48xlarge) | Tested | Not tested | -| Trn1 | Not supported (requires 64 cores) | Not supported | +| Trn1 | Not supported (requires 64 logical cores via logical_nc_config=2) | Not supported | | Inf2 | Not supported | Not supported | ## Testing @@ -215,18 +286,21 @@ pytest contrib/models/MiMo-V2-Flash/test/integration/test_model.py -v ## Key Implementation Notes -1. **Hybrid Attention**: `hybrid_layer_pattern` list determines full vs sliding window per layer. -2. **CONVERT_TO_MHA**: When TP > num_kv_heads (4), K/V are replicated to match Q heads (64). -3. **Attention Sink Bias**: Adds learnable sink column to attention weights in sliding window layers. -4. **EP Hybrid Sharding**: EP is used during prefill only; token generation uses EP=1 unless batch_size >= 32. -5. **FP8 Conversion**: Original uses OCP block-wise FP8, requires conversion to BF16 or Neuron-compatible FP8 format. +1. **Hybrid Attention**: `hybrid_layer_pattern` list determines full vs sliding window per layer; the modeling code constructs one `NeuronMiMoV2Attention` per layer with the correct `is_sliding_window` flag and rope_theta. +2. **CONVERT_TO_MHA**: When `tp_degree > num_kv_heads` (64 > 4 full / 64 > 8 SWA), K/V are replicated to `num_attention_heads` (64) during state-dict conversion; this applies to both `.weight` and the per-row `.scale` on the FP8 path. +3. **Attention Sink Bias**: Learnable per-head bias added as an extra "sink" column to attention scores in sliding window layers (not added in full-attention layers). Per-rank slicing of the bias happens inside `forward()` based on `parallel_state.get_tensor_model_parallel_rank()`. +4. **FP8 Path Caveats**: + - Must use `moe_tp_degree=1, moe_ep_degree=64` (see "FP8 Configuration Notes" above). + - Must use `batch_size >= 32` (NxDI EP>1 requirement). + - Must keep outer `ep_degree=1` (only `moe_ep_degree` should vary). + - Several runtime monkey-patches (router bias, blockwise scale stride, 2D per-channel, EP scale handling) are installed automatically in `NeuronMiMoV2ForCausalLM.__init__` when `quantized=True`; the BF16 path is untouched. 
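To make note 2 concrete, the MHA conversion amounts to replicating each grouped KV head (an illustrative sketch with assumed names and layout; the real converter also walks the per-row FP8 `.scale` tensors the same way):

```python
import torch

def replicate_kv_for_mha(k_or_v: torch.Tensor, num_kv_heads: int,
                         num_attention_heads: int, head_dim: int) -> torch.Tensor:
    """[num_kv_heads * head_dim, hidden] -> [num_attention_heads * head_dim, hidden].

    With tp_degree=64 > num_kv_heads (4 full-attn / 8 SWA), every KV head
    is replicated so each TP rank owns one full head.
    """
    reps = num_attention_heads // num_kv_heads        # 16 (full) or 8 (SWA)
    w = k_or_v.view(num_kv_heads, head_dim, -1)       # e.g. [4, 192, 4096] for K
    w = w.repeat_interleave(reps, dim=0)              # [64, 192, 4096]
    return w.reshape(num_attention_heads * head_dim, -1)
```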
 ## Example Checkpoints
 
-* [XiaomiMiMo/MiMo-V2-Flash](https://huggingface.co/XiaomiMiMo/MiMo-V2-Flash)
+* [XiaomiMiMo/MiMo-V2-Flash](https://huggingface.co/XiaomiMiMo/MiMo-V2-Flash) — HF FP8 source checkpoint
 
 ## Maintainer
 
 Henan Wan (whn09)
 
-**Last Updated:** 2026-04-13
+**Last Updated:** 2026-04-25

From c303b52fbc77065e3e79818454e85a9f7cbbb8a1 Mon Sep 17 00:00:00 2001
From: whn09
Date: Sat, 25 Apr 2026 14:37:08 +0800
Subject: [PATCH 23/23] README: add Quick Start section for FP8 reproduction
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The existing README had all the pieces (preprocess, smoke, vLLM) but
scattered across different sections, so a new user couldn't tell what
order to run them in. Add a concrete 6-step Quick Start that walks from
a fresh trn2.48xlarge to a working vLLM server — clone, download,
preprocess, smoke-verify, install vllm-neuron, bench — with approximate
timings so the 60-minute first compile isn't a surprise.

Also spell out the prerequisites (SDK version, which DLAMI venv to use
for which stage, disk requirement) and add a curl snippet for a
post-deployment sanity check, with the specific symptom ("helpful
helpful helpful ...") that indicates the FP8 recipe is misconfigured.

Co-Authored-By: Claude Opus 4.7
---
 contrib/models/MiMo-V2-Flash/README.md | 60 +++++++++++++++++++++++++-
 1 file changed, 59 insertions(+), 1 deletion(-)

diff --git a/contrib/models/MiMo-V2-Flash/README.md b/contrib/models/MiMo-V2-Flash/README.md
index ae40c700..220cb7d5 100644
--- a/contrib/models/MiMo-V2-Flash/README.md
+++ b/contrib/models/MiMo-V2-Flash/README.md
@@ -38,7 +38,65 @@ Key features:
 ## Prerequisites
 
 - **Instance**: trn2.48xlarge (32 NeuronCores, logical_nc_config=2 → 64 logical cores)
-- **Weights**: Neuron FP8 (produced by `conversion_script/preprocess_mimo_v2_flash_fp8.py`) for the recommended FP8 path, or BF16 (produced by `conversion_script/preprocess_mimo_v2_fp8.py`) for a BF16 fallback.
+- **Neuron SDK**: 2.29 (Python 3.12, PyTorch 2.9)
+- **Venvs**: `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference` (for preprocess + NxDI direct smoke), `/opt/aws_neuronx_venv_pytorch_inference_vllm_0_16` (for vLLM serving). Both ship with the DLAMI.
+- **Disk**: ~900 GB free under `/opt/dlami/nvme` (the HF FP8 checkpoint is ~290 GB, the Neuron-FP8 preprocessed output is ~310 GB, and `save_sharded_checkpoint=true` writes another ~300 GB per compiled config).
+
+## Quick Start (FP8 on Trn2)
+
+End-to-end recipe to go from a fresh trn2.48xlarge to a vLLM OpenAI-compatible server serving Flash FP8. First-time compile takes ~45-60 minutes; subsequent runs hit the neuronx-cc cache and start in a few minutes.
+
+```bash
+# 1. Clone this repo on the Trn2 instance
+cd $HOME
+git clone <repo-url>/neuronx-distributed-inference.git
+cd neuronx-distributed-inference
+git checkout contrib/MiMo-V2-Flash  # the branch this README lives on
+
+# 2. Download the HuggingFace FP8 checkpoint (~290 GB). Any HF-compatible
+# downloader works; huggingface-cli example:
+huggingface-cli download XiaomiMiMo/MiMo-V2-Flash \
+  --local-dir /opt/dlami/nvme/models/MiMo-V2-Flash
+
+# 3. Preprocess HF FP8 -> Neuron FP8 (~20 min, ~24 GB peak RAM)
+source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate
+python contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_flash_fp8.py \
+  --hf_model_path /opt/dlami/nvme/models/MiMo-V2-Flash \
+  --save_path /opt/dlami/nvme/models/MiMo-V2-Flash-Neuron-FP8 \
+  --tp_degree 64
+
+# 4. (Optional) sanity-check the Neuron-FP8 checkpoint without vLLM
+# ~45 min first compile; subsequent runs ~30s (cached NEFF + pre-sharded weights).
+source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate
+python contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py   # compile
+python contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py  # 20-token generate
+
+# 5. Install vllm-neuron with the contrib registration patch
+bash contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh
+
+# 6. Start vLLM serving Flash FP8 (first compile ~60 min; subsequent ~3 min)
+bash contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh
+```
+
+The bench script runs two configurations (BS=32 and BS=128, both
+`moe_tp_degree=1 / moe_ep_degree=64`) and logs results under
+`/tmp/bench_results/mimo_v2_flash/`.
+
+For a quick `curl` sanity check while the server is up:
+
+```bash
+curl -s http://localhost:8000/v1/chat/completions \
+  -H 'Content-Type: application/json' \
+  -d '{"model": "/opt/dlami/nvme/models/MiMo-V2-Flash-Neuron-FP8",
+       "messages": [{"role": "user", "content": "Hello! Introduce yourself in one sentence."}],
+       "max_tokens": 64, "temperature": 0.0}' | python3 -m json.tool
+```
+
+If the response is fluent, coherent text across a 30+ token generation,
+the FP8 path is working correctly. If you see repetition collapse
+("helpful helpful helpful..."), double-check that `moe_tp_degree=1`,
+`moe_ep_degree=64`, `batch_size>=32`, and that you are loading the
+preprocessed Neuron-FP8 checkpoint (not the raw HF FP8 directory).
 
 ## Checkpoint Preparation