From 82d7eade1d9bb358e98982c70ea635a38e69851d Mon Sep 17 00:00:00 2001 From: whn09 Date: Tue, 28 Apr 2026 05:52:07 +0800 Subject: [PATCH 01/24] [contrib] Add MiMo-V2.5-Pro initial port (copied from MiMo-V2-Pro, renamed) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bootstrap contrib entry for XiaomiMiMo/MiMo-V2.5-Pro on Trn2 via NxDI. Same starting point as the MiMo-V2-Pro port: src/modeling_mimo_v2.py (was modeling_mimo_v2_pro.py) src/conversion_script/preprocess_mimo_v2_fp8.py perf_test/{smoke_compile,smoke_generate,bench}_mimo_v2.{py,sh} Rename-only changes in this commit: MiMoV2Pro* identifiers -> MiMoV2* (classes, configs, modules) mimo_v2_pro paths -> mimo_v2 / mimo_v25_pro (compile dirs) HF repo XiaomiMiMo/MiMo-V2-Pro -> XiaomiMiMo/MiMo-V2.5-Pro README architecture table updated to V2.5-Pro config (70 layers, 6144 hidden, 128 heads, 384 experts, etc.) README disk footprint updated to match V2.5-Pro actual size (~962GB HF) Not yet adapted to V2.5-specific differences — these still need work: - attention_chunk_size=128 (new in V2.5, not handled in V2-Pro code) - MoE group-limited noaux_tc (n_group, topk_group) — V2.5 config sets 1,1 so it degenerates to plain noaux_tc; the Pro monkey-patch already matches - FP8 recipe verification on V2.5 weights (V2-Pro workarounds may or may not apply: mean-subtract router bias, split_qkv_fused interleaved layout, blockwise scale stride fix) Subsequent commits will adapt each of the above after validation on Trn2. Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2.5-Pro/README.md | 364 ++++ .../models/MiMo-V2.5-Pro/perf_test/0_setup.sh | 61 + .../MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh | 228 +++ .../perf_test/run_bench_single.sh | 76 + .../MiMo-V2.5-Pro/perf_test/sanity_check.sh | 59 + .../perf_test/smoke_compile_mimo_v2.py | 195 ++ .../perf_test/smoke_generate_mimo_v2.py | 231 +++ .../perf_test/vllm-neuron-patch.patch | 107 ++ contrib/models/MiMo-V2.5-Pro/src/__init__.py | 0 .../preprocess_mimo_v2_fp8.py | 641 +++++++ .../MiMo-V2.5-Pro/src/modeling_mimo_v2.py | 1676 +++++++++++++++++ contrib/models/MiMo-V2.5-Pro/test/__init__.py | 0 .../test/integration/__init__.py | 0 .../test/integration/test_model.py | 53 + .../MiMo-V2.5-Pro/test/unit/__init__.py | 0 15 files changed, 3691 insertions(+) create mode 100644 contrib/models/MiMo-V2.5-Pro/README.md create mode 100755 contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh create mode 100755 contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh create mode 100755 contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh create mode 100755 contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh create mode 100755 contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py create mode 100755 contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py create mode 100644 contrib/models/MiMo-V2.5-Pro/perf_test/vllm-neuron-patch.patch create mode 100644 contrib/models/MiMo-V2.5-Pro/src/__init__.py create mode 100644 contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py create mode 100644 contrib/models/MiMo-V2.5-Pro/src/modeling_mimo_v2.py create mode 100644 contrib/models/MiMo-V2.5-Pro/test/__init__.py create mode 100644 contrib/models/MiMo-V2.5-Pro/test/integration/__init__.py create mode 100644 contrib/models/MiMo-V2.5-Pro/test/integration/test_model.py create mode 100644 contrib/models/MiMo-V2.5-Pro/test/unit/__init__.py diff --git a/contrib/models/MiMo-V2.5-Pro/README.md b/contrib/models/MiMo-V2.5-Pro/README.md new file 
mode 100644 index 00000000..38623380 --- /dev/null +++ b/contrib/models/MiMo-V2.5-Pro/README.md @@ -0,0 +1,364 @@ +# Contrib Model: MiMo-V2.5-Pro + +NeuronX Distributed Inference implementation of [XiaomiMiMo/MiMo-V2.5-Pro](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro). + +## Model Information + +- **HuggingFace ID:** `XiaomiMiMo/MiMo-V2.5-Pro` +- **Model Type:** Decoder-only MoE transformer with hybrid attention +- **Architecture:** Custom MoE with full + sliding window attention +- **License:** Check HuggingFace model card + +## Architecture Details + +| Parameter | Value | +|-----------|-------| +| Hidden Size | 6144 | +| Layers | 70 | +| Attention Heads | 128 Q | +| KV Heads (full & sliding window) | 8 | +| Q/K Head Dim | 192 | +| V Head Dim | 128 | +| Experts | 384 routed (top-8 routing), no shared expert | +| Expert Intermediate | 2048 | +| Dense MLP Intermediate (layer 0) | 16,384 | +| Vocab Size | 152,576 | +| RoPE | Partial (33.4% → 64 of 192 dims), theta=10M (full) / 10K (SWA) | +| Sliding Window | 128 | +| Max Position | 1,048,576 (1M) | +| Attention Projection | `fused_qkv` (single `qkv_proj.weight`) | + +Key features: +- **Hybrid Attention**: 10 full attention layers (0, 7, 15, 23, 31, 39, 47, 55, 62, 69) + 60 sliding window layers, per `hybrid_layer_pattern` +- **Asymmetric Head Dims**: Q/K use head_dim=192, V uses v_head_dim=128 +- **Attention Sink Bias**: Learnable per-head bias on sliding window layers only (`add_swa_attention_sink_bias=True`, `add_full_attention_sink_bias=False`) +- **Sigmoid Router + noaux_tc**: `sigmoid(logits) + e_score_correction_bias` is used to pick top-8 experts; unbiased `sigmoid(logits)` becomes the affinity weights. `n_group=1, topk_group=1` degenerates group-limited routing to plain noaux_tc. +- **attention_value_scale = 0.612**: HF reference multiplies `value_states` by this before `softmax(QK^T) × V` (NOT applied post-attention); the NxDI port matches. + +## Prerequisites + +- **Instance**: trn2.48xlarge (32 NeuronCores, logical_nc_config=2 → 64 logical cores) +- **Neuron SDK**: 2.29 (Python 3.12, PyTorch 2.9) +- **Venvs**: `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference` (for preprocess + NxDI direct smoke), `/opt/aws_neuronx_venv_pytorch_inference_vllm_0_16` (for vLLM serving). Both ship with the DLAMI. +- **Disk**: ~3 TB free under `/opt/dlami/nvme` (the HF FP8 checkpoint is ~962 GB, the Neuron-FP8 preprocessed output is ~1 TB, and `save_sharded_checkpoint=true` writes another ~300-1000 GB per compiled config (varies with recipe)). + +## Quick Start (FP8 on Trn2) + +End-to-end recipe to go from a fresh trn2.48xlarge to a working vLLM OpenAI server serving MiMo-V2.5-Pro FP8. First-time compile takes ~45-60 minutes; subsequent runs hit the neuronx-cc cache and start in a few minutes. + +```bash +# 1. Clone this repo on the Trn2 instance +cd $HOME +git clone /neuronx-distributed-inference.git +cd neuronx-distributed-inference +git checkout contrib/MiMo-V2.5-Pro # the branch this README lives on + +# 2. Download the HuggingFace FP8 checkpoint (~290 GB). Any HF-compatible +# downloader works; huggingface-cli example: +huggingface-cli download XiaomiMiMo/MiMo-V2.5-Pro \ + --local-dir /opt/dlami/nvme/models/MiMo-V2.5-Pro + +# 3. 
Preprocess HF FP8 -> Neuron FP8 (~20 min, ~24 GB peak RAM) +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +python contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py \ + --hf_model_path /opt/dlami/nvme/models/MiMo-V2.5-Pro \ + --save_path /opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8 \ + --tp_degree 64 + +# 4. (Optional) sanity-check the Neuron-FP8 checkpoint without vLLM +# ~45 min first compile; subsequent runs ~30s to load the pre-sharded NEFF. +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate +python contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py # compile +python contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py # 20-token generate + +# 5. Install vllm-neuron with the contrib registration patch +bash contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh + +# 6. Start vLLM serving MiMo-V2.5-Pro FP8 (first compile ~60 min; subsequent ~3 min) +bash contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh +``` + +The bench script runs two configurations (BS=32 and BS=128, both +`moe_tp_degree=X / moe_ep_degree=Y (see bench script)`) and logs results under +`/tmp/bench_results/mimo_v25_pro/`. + +For a quick `curl` sanity check while the server is up: + +```bash +curl -s http://localhost:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{"model": "/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8", + "messages": [{"role": "user", "content": "Hello! Introduce yourself in one sentence."}], + "max_tokens": 64, "temperature": 0.0}' | python3 -m json.tool +``` + +If you get fluent sentence-ending output on a 30+ token generation, the +FP8 path is working correctly. If you see repetition collapse +("helpful helpful helpful..."), double-check that `moe_tp_degree=1`, +`moe_ep_degree=64`, `batch_size>=32`, and that you are loading the +preprocessed Neuron-FP8 checkpoint (not the raw HF FP8 directory). + +## Checkpoint Preparation + +The HuggingFace checkpoint ships as block-wise OCP FP8 (E4M3, ±448 range), which is not directly compatible with Neuron FP8 (IEEE-754 E4M3, ±240 range). Two preprocess scripts are provided: + +### Recommended: FP8 → Neuron-FP8 (streaming) + +`src/conversion_script/preprocess_mimo_v2_fp8.py` performs a per-layer streaming rescale from OCP FP8 to Neuron FP8 (per-row scales for attention Q/K/V and layer-0 dense MLP; blockwise scales for MoE experts). `o_proj` is listed in HF's `quantization_config.ignored_layers` and is kept BF16 on the Neuron side (it binds to a plain `RowParallelLinear`, not `QuantizedRowParallel`). Output is ~1 TB across 70 per-layer safetensors shards. + +```bash +python contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py \ + --hf_model_path /path/to/MiMo-V2.5-Pro \ + --save_path /path/to/MiMo-V2.5-Pro-Neuron-FP8 \ + --tp_degree 64 +``` + +Peak RAM during preprocessing is ~24 GB; total runtime ~20 minutes on a trn2.48xlarge instance. + +### Fallback: FP8 → BF16 + +`src/conversion_script/preprocess_mimo_v2_fp8.py` dequantizes the entire checkpoint to BF16. Output is ~290 GB; BF16 is numerically equivalent to the published HF FP8 weights and is useful as a known-good reference. Throughput is ~2× worse than the FP8 path because every attention/MLP matmul operates on full BF16 weights. + +## Usage + +```python +import sys +from pathlib import Path + +# Make this contrib package's src/ importable (flat, per upstream contrib convention). 
+sys.path.insert(0, str(Path("contrib/models/MiMo-V2.5-Pro/src").resolve())) + +import torch +from transformers import AutoConfig, AutoTokenizer +from neuronx_distributed_inference.models.config import MoENeuronConfig, OnDeviceSamplingConfig +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config, HuggingFaceGenerationAdapter + +from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM, MiMoV2InferenceConfig + +model_path = "/path/to/MiMo-V2.5-Pro-Neuron-FP8/" +compiled_path = "/path/to/compiled/" + +# Recommended FP8 recipe: +# moe_tp_degree = 1, moe_ep_degree = 64 +# See "FP8 Configuration Notes" below for why other moe_tp/ep ratios collapse. +neuron_config = MoENeuronConfig( + tp_degree=64, + ep_degree=1, # keep outer EP = 1; only MoE-internal EP varies + moe_tp_degree=1, + moe_ep_degree=64, + batch_size=32, # must be >= num_experts / top_k = 256 / 8 = 32 + max_batch_size=32, + ctx_batch_size=1, + tkg_batch_size=32, + seq_len=1024, + n_active_tokens=128, + torch_dtype=torch.bfloat16, + logical_nc_config=2, + capacity_factor=1.0, + glu_mlp=True, + fused_qkv=False, # required: asymmetric Q/K (192) vs V (128) head dims + router_config={"act_fn": "sigmoid", "dtype": "float32"}, + blockwise_matmul_config={ + "use_shard_on_block_dynamic_while": True, + "block_sharding_strategy": "PING_PONG", + }, + save_sharded_checkpoint=True, + quantized=True, + quantized_checkpoints_path=model_path, + quantization_dtype="f8e4m3", + quantization_type="blockwise_symmetric", + quantization_block_axis=[1, 2], + quantization_block_size=[128, 128], + modules_to_not_convert=[ + "embed_tokens", "lm_head", "norm", "router", "o_proj", + ], + on_device_sampling_config=OnDeviceSamplingConfig( + do_sample=True, temperature=0.6, top_k=20, top_p=0.95, + ), +) + +# trust_remote_code is required by Flash's HF config; pre-load via AutoConfig +# and pass to NxDI so load_pretrained_config does not re-load without the flag. +hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) +config = MiMoV2InferenceConfig( + neuron_config, load_config=load_pretrained_config(hf_config=hf_config), +) + +model = NeuronMiMoV2ForCausalLM(model_path, config) +model.compile(compiled_path) +model.load(compiled_path) + +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) +adapter = HuggingFaceGenerationAdapter(model) +inputs = tokenizer(["Hello, how are you?"] * 32, return_tensors="pt", padding=True) +output = adapter.generate( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + max_new_tokens=128, +) +``` + +For a minimal end-to-end smoke test that bypasses vLLM, see: + +- `perf_test/smoke_compile_mimo_v2.py` — compile + load (STAGE=instantiate|compile|load|all, DRY_RUN, SKIP_WARMUP) +- `perf_test/smoke_generate_mimo_v2.py` — 20-token generation via HuggingFaceGenerationAdapter + +Both default to the recommended FP8 recipe (`moe_tp=1`, `moe_ep=64`). + +## FP8 Configuration Notes + +### moe_tp_degree = 1, moe_ep_degree = 64 + +**Why**: at `moe_tp_degree=64` each rank owns 1/64 of the intermediate dim, which for Flash (MoE intermediate = 2048) is 32 rows — **below the 128-row blockwise scale block**. NxDI's `_setup_for_scale` detects `weight_shape[axis] < block_size` and collapses the per-rank scale dim to 1, losing per-channel FP8 scale granularity. The resulting drift compounds across Flash's 47 MoE layers and manifests as output collapse ("helpful helpful helpful ...") after roughly 30 decode tokens. 
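+The sizing argument above reduces to simple arithmetic. A minimal sketch (plain Python; the 2048 expert intermediate and the 128-row scale block come from the tables above, everything else is illustrative):
+
+```python
+# Per-rank expert intermediate rows under different moe_tp_degree values,
+# compared against the 128-row FP8 blockwise scale block.
+MOE_INTERMEDIATE = 2048   # expert intermediate size (architecture table)
+BLOCK_ROWS = 128          # quantization_block_size[0]
+
+for moe_tp in (64, 32, 16, 1):
+    rows_per_rank = MOE_INTERMEDIATE // moe_tp
+    if rows_per_rank < BLOCK_ROWS:
+        verdict = f"< {BLOCK_ROWS}-row block -> per-rank scale collapses to a singleton"
+    else:
+        verdict = f">= {BLOCK_ROWS}-row block -> blockwise scale granularity preserved"
+    print(f"moe_tp_degree={moe_tp:3d}: {rows_per_rank:5d} rows/rank  {verdict}")
+```
+
+Note that passing this check is necessary but, per the empirical results below, not sufficient: only `moe_tp_degree=1` has been validated end-to-end on this model.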
+ +`moe_tp_degree=1, moe_ep_degree=64` keeps each expert's weights and blockwise scales intact on a single rank (4 experts per rank), which preserves per-channel scale and produces correct output even on long multi-turn prompts. + +Intermediate ratios (`moe_tp=32/ep=2` or `moe_tp=16/ep=4`) have been empirically tested and still produce gibberish, so this is the only currently-supported moe_tp/ep combination for MiMo-V2.5-Pro FP8. + +### batch_size >= 32 + +NxDI's TKG (token generation) path refuses Expert Parallelism when `batch_size < num_experts / top_k`. For Flash that is 256 / 8 = 32, so the smallest working BS on the FP8 path is 32. BS=1 latency demos are not currently possible on FP8; use the BF16 checkpoint with `moe_tp=64, moe_ep=1, batch_size=1` for single-stream latency measurements. + +### outer ep_degree = 1 + +`MoENeuronConfig.ep_degree` is the **full-model** expert-parallel factor. Setting it to anything > 1 multiplies `world_size` to `tp_degree * ep_degree`, which on a 64-NC Trn2 overflows the device (ranks beyond 63 have no backing hardware, sharded-checkpoint size grows linearly, and load fails). The MoE-internal expert parallelism is controlled exclusively by `moe_ep_degree` — keep `ep_degree=1` at the outer level. + +## vLLM Integration + +MiMo-V2.5-Pro can be served via [vllm-neuron](https://github.com/aws-neuron/vllm-neuron). A contrib registration patch is required to plug the NxDI modeling code into vllm-neuron's lookup tables. + +### Setup + +```bash +# The setup script clones vllm-project/vllm-neuron at release-0.5.0, applies +# the contrib registration patch, installs it editable, and downloads Flash +# weights (BF16 by default; set MIMO_V2_FLASH_PATH to override). +bash contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh +``` + +The patch (`perf_test/vllm-neuron-patch.patch`) is 40 lines and only touches `vllm_neuron/__init__.py`. It adds a `_register_contrib_models()` hook that, when `NXDI_CONTRIB_MIMO_V2_FLASH_SRC` is set, registers `NeuronMiMoV2ForCausalLM` into NxDI's `MODEL_TYPES` under the key `mimo_v2` **and** registers the `MiMoV2ForCausalLM` architecture into vLLM's `ModelRegistry`. No upstream vLLM or NxDI source is modified. + +### Serving (FP8, recommended) + +```bash +export NXDI_CONTRIB_MIMO_V2_FLASH_SRC=/path/to/neuronx-distributed-inference/contrib/models/MiMo-V2.5-Pro/src +export MIMO_V2_FLASH_PATH=/path/to/MiMo-V2.5-Pro-Neuron-FP8 +# First-time compile of Flash's 256-expert MoE takes 30-60 minutes. +export VLLM_ENGINE_READY_TIMEOUT_S=7200 +# Optional: isolate compile cache per config so parallel Flash/Pro/etc. compiles +# don't race on the default /var/tmp/neuron-compile-cache lock files. 
+export NEURON_COMPILED_ARTIFACTS=/path/to/compiled/mimo_v25_pro_bs32_moetp1_ep64_fp8 + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MIMO_V2_FLASH_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 32 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + "tp_degree": 64, + "logical_nc_config": 2, + "fused_qkv": false, + "sequence_parallel_enabled": false, + "glu_mlp": true, + "normalize_top_k_affinities": true, + "save_sharded_checkpoint": true, + "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, + "quantized": true, + "quantized_checkpoints_path": "/path/to/MiMo-V2.5-Pro-Neuron-FP8", + "quantization_dtype": "f8e4m3", + "quantization_type": "blockwise_symmetric", + "quantization_block_axis": [1, 2], + "quantization_block_size": [128, 128], + "modules_to_not_convert": ["embed_tokens", "lm_head", "norm", "router", "o_proj"], + "blockwise_matmul_config": {"use_shard_on_block_dynamic_while": true, "block_sharding_strategy": "PING_PONG"}, + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 32, + "ctx_batch_size": 1, + "tkg_batch_size": 32, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": true, + "enable_bucketing": true, + "context_encoding_buckets": [1024], + "token_generation_buckets": [1024], + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' +``` + +See `perf_test/bench_mimo_v2.sh` for the full benchmark recipe at BS=32 and BS=128. + +### vllm-neuron patch summary + +The patch is applied to vllm-neuron 0.5.0 and: + +- Maps the `MiMoV2ForCausalLM` architecture to Flash's model loader (reusing the Qwen2-family loader path, which Flash's tokenizer inherits from). +- Passes `hf_config` from vLLM into `load_pretrained_config` so NxDI does not re-load the config without `trust_remote_code=True`. +- Replaces vllm-neuron's internal `AutoModelForCausalLM.from_pretrained` call with `huggingface_hub.snapshot_download`, which is the only path that works for `trust_remote_code=True` models when no GPU is available for HF's CUDA-gated FP8 quantizer. + +## Performance + +> These numbers are from the earlier BF16 recipe (pre-FP8 rollout). FP8 numbers will be added once a stable bench run completes on the new recipe; preliminary single-stream qualitative tests show fluent multi-sentence output on long Chinese chat prompts with `moe_tp=1, moe_ep=64, batch_size=32`. + +### Standalone NxDI (trn2.48xlarge, BF16, TP=64, EP=64) + +| Batch Size | Throughput (tok/s) | +|------------|-------------------| +| 1 | 29.92 | +| 8 | 215.94 | +| 32 | 649.14 | + +### vLLM Serving (trn2.48xlarge, BF16, BS=32, TP=64/EP=64, CB) + +Input/output: 900/90 tokens (random dataset) + +| Concurrency | Throughput (tok/s) | TPOT (ms) | TTFT (ms) | +|-------------|-------------------|-----------|-----------| +| 1 | 27.98 | 33.65 | 222 | +| 16 | 224.57 | 64.95 | 570 | +| 32 | 302.61 | 90.23 | 1351 | + +> **Compile time:** the first Flash compile on SDK 2.29 is ~30-60 minutes for the TKG NEFF and similar for the CTE NEFF. Subsequent runs with the same `override_neuron_config` hit the neuronx-cc cache and start in ~1-2 minutes. `save_sharded_checkpoint=true` additionally persists per-rank FP8 shards under `/weights/`, letting future `load()` calls skip the ~10-minute shard_checkpoint pass. 
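+A practical corollary of the compile-time note: before relaunching vLLM after an interrupted first compile, it can save a long recompile/reshard to confirm the artifacts are actually there. A rough check (illustrative Python; the `weights/` layout and per-rank shard naming follow the notes above and may differ in your setup):
+
+```python
+# Count per-rank sharded weight files under a previously compiled config.
+# The path is an example; point it at your NEURON_COMPILED_ARTIFACTS directory.
+import glob
+import os
+
+compiled = "/path/to/compiled/mimo_v25_pro_bs32_moetp1_ep64_fp8"
+shards = glob.glob(os.path.join(compiled, "weights", "*.safetensors"))
+print(f"{len(shards)} sharded weight file(s) found")  # roughly one per TP rank (64) when complete
+```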
+ +## Compatibility Matrix + +| Instance | Neuron SDK 2.29+ (PyTorch 2.9) | 2.21 and earlier | +|----------|--------------------------------|------------------| +| Trn2 (trn2.48xlarge) | Tested | Not tested | +| Trn1 | Not supported (requires 64 logical cores via logical_nc_config=2) | Not supported | +| Inf2 | Not supported | Not supported | + +## Testing + +```bash +pytest contrib/models/MiMo-V2.5-Pro/test/integration/test_model.py -v +``` + +## Key Implementation Notes + +1. **Hybrid Attention**: `hybrid_layer_pattern` list determines full vs sliding window per layer; the modeling code constructs one `NeuronMiMoV2Attention` per layer with the correct `is_sliding_window` flag and rope_theta. +2. **CONVERT_TO_MHA**: When `tp_degree > num_kv_heads` (64 > 4 full / 64 > 8 SWA), K/V are replicated to `num_attention_heads` (64) during state-dict conversion; this applies to both `.weight` and the per-row `.scale` on the FP8 path. +3. **Attention Sink Bias**: Learnable per-head bias added as an extra "sink" column to attention scores in sliding window layers (not added in full-attention layers). Per-rank slicing of the bias happens inside `forward()` based on `parallel_state.get_tensor_model_parallel_rank()`. +4. **FP8 Path Caveats**: + - Must use `moe_tp_degree=1, moe_ep_degree=64` (see "FP8 Configuration Notes" above). + - Must use `batch_size >= 32` (NxDI EP>1 requirement). + - Must keep outer `ep_degree=1` (only `moe_ep_degree` should vary). + - Several runtime monkey-patches (router bias, blockwise scale stride, 2D per-channel, EP scale handling) are installed automatically in `NeuronMiMoV2ForCausalLM.__init__` when `quantized=True`; the BF16 path is untouched. + +## Example Checkpoints + +* [XiaomiMiMo/MiMo-V2.5-Pro](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro) — HF FP8 source checkpoint + +## Maintainer + +Henan Wan (whn09) + +**Last Updated:** 2026-04-25 diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh new file mode 100755 index 00000000..5d43c529 --- /dev/null +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Setup for MiMo-V2.5-Pro vLLM benchmarking on Trn2. +# +# This clones upstream vllm-project/vllm-neuron at release-0.5.0 (which already +# has the mimov2flash -> mimo_v2 model_type rewrite), then applies +# vllm-neuron-patch.patch to add a runtime registration hook so the contrib +# NeuronMiMoV2ForCausalLM is plugged into both NxDI's MODEL_TYPES and vLLM's +# ModelRegistry at vllm-neuron plugin init time. +set -e + +echo "==========================================" +echo "Setup: vllm-neuron + MiMo-V2.5-Pro weights" +echo "==========================================" + +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +PATCH_FILE="$(cd "$(dirname "$0")" && pwd)/vllm-neuron-patch.patch" + +echo "" +echo "[1/2] Installing vllm-neuron (release-0.5.0) with the contrib registration patch..." + +if [ ! -d $HOME/vllm-neuron ]; then + git clone --branch release-0.5.0 https://github.com/vllm-project/vllm-neuron.git $HOME/vllm-neuron +fi + +cd $HOME/vllm-neuron + +# Apply patch (idempotent via `git apply --check` first). +if git apply --check "$PATCH_FILE" 2>/dev/null; then + git apply "$PATCH_FILE" + echo " Applied $PATCH_FILE" +else + echo " Patch already applied or conflicts; continuing." +fi + +pip install --extra-index-url=https://pip.repos.neuron.amazonaws.com -e . 
+pip install s5cmd + +python3 -c "import vllm_neuron; print('vllm-neuron installed:', vllm_neuron.__file__)" + +echo "" +echo "[2/2] Downloading MiMo-V2.5-Pro BF16 weights..." + +MIMO_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-BF16}" +if [ -d "$MIMO_PATH" ] && [ "$(ls "$MIMO_PATH"/*.safetensors 2>/dev/null | wc -l)" -gt 0 ]; then + echo " MiMo weights already exist at $MIMO_PATH, skipping download" +else + echo " Downloading BF16 weights from your S3 bucket (edit the URI if needed)..." + mkdir -p "$MIMO_PATH" + s5cmd cp "s3://datalab/xiaomi/models/MiMo-V2.5-Pro-BF16/**" "$MIMO_PATH/" + echo " Download complete: $(du -sh $MIMO_PATH | cut -f1)" +fi + +# Figure out where this contrib package's src/ lives so the registration hook +# can add it to sys.path inside vllm-neuron. +CONTRIB_SRC="$(cd "$(dirname "$0")/.." && pwd)/src" + +echo "" +echo "Setup complete. Before running the benchmark, export:" +echo " export MIMO_V2_FLASH_PATH=$MIMO_PATH" +echo " export NXDI_CONTRIB_MIMO_V2_FLASH_SRC=$CONTRIB_SRC" diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh new file mode 100755 index 00000000..69c7d417 --- /dev/null +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh @@ -0,0 +1,228 @@ +#!/bin/bash +set -e + +# MiMo-V2.5-Pro FP8 vLLM benchmark on Trn2. +# +# Requires a Neuron-FP8 preprocessed checkpoint (see +# `src/conversion_script/preprocess_mimo_v2_fp8.py`). The configs below +# all use moe_tp_degree=1 / moe_ep_degree=64 (experts sharded by expert +# parallelism only, no intra-expert TP split) because moe_tp_degree=64 collapses +# the per-rank FP8 blockwise scale to a singleton — per-rank expert +# intermediate is 32 rows, below the 128-row blockwise block, so +# NxDI's `_setup_for_scale` drops per-channel scale granularity. The resulting +# drift compounds across 47 MoE layers and gives repetition / output collapse. +# Using moe_ep_degree=64 keeps all of each expert's weight + scale on one rank +# (4 experts per rank), which preserves the blockwise scale intact. +# +# NxDI's TKG path refuses Expert Parallelism with BS < num_experts/top_k +# (256 / 8 = 32 for Flash), so the smallest working batch size here is 32. +# If you want BS=1 behaviour, the FP8 path is not currently supported on +# this model on Trn2 — use the BF16 checkpoint with the old bench recipe +# (`moe_tp_degree=64, moe_ep_degree=1, batch_size=1`). + +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8}" +# The NxDI contrib MiMo-V2.5-Pro modeling code is registered into vLLM / +# NxDI lookup tables by vllm-neuron's register() hook using this env var. +# Default to this contrib package's own src/ relative to the script. +: "${NXDI_CONTRIB_MIMO_V2_FLASH_SRC:=$(cd "$(dirname "$0")/.." && pwd)/src}" +export NXDI_CONTRIB_MIMO_V2_FLASH_SRC + +# First-time Flash FP8 compile takes 30-60 minutes; extend vLLM's ready +# timeout and the compiler's environment variables for FP8 numerics. +export VLLM_ENGINE_READY_TIMEOUT_S=7200 + +PORT=8000 +RESULTS_DIR="/tmp/bench_results/mimo_v25_pro" +mkdir -p "$RESULTS_DIR" + +# Common neuron config shared across all MiMo-V2.5-Pro FP8 configs. 
+# save_sharded_checkpoint=true persists per-rank sharded weights to +# /weights/tp{N}_sharded_checkpoint.safetensors during compile; +# load() then reads those directly (~30s) instead of re-sharding the entire +# checkpoint on every vllm-neuron startup (~10+ min). +COMMON_MIMO_CONFIG='"tp_degree": 64, + "logical_nc_config": 2, + "fused_qkv": false, + "sequence_parallel_enabled": false, + "glu_mlp": true, + "normalize_top_k_affinities": true, + "save_sharded_checkpoint": true, + "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, + "quantized": true, + "quantized_checkpoints_path": "'"$MODEL_PATH"'", + "quantization_dtype": "f8e4m3", + "quantization_type": "blockwise_symmetric", + "quantization_block_axis": [1, 2], + "quantization_block_size": [128, 128], + "modules_to_not_convert": ["embed_tokens", "lm_head", "norm", "router", "o_proj"], + "blockwise_matmul_config": {"use_shard_on_block_dynamic_while": true, "block_sharding_strategy": "PING_PONG"}' + +# Helper: wait for vLLM server to be ready. First-time compilation of a +# 256-expert MoE model takes 30-90 minutes, so we poll for up to 2 hours. +wait_for_server() { + echo " Waiting for vLLM server to be ready (up to 2h for first compile)..." + local interval=10 + local max_attempts=720 # 720 * 10s = 7200s = 2h + local start=$SECONDS + for i in $(seq 1 $max_attempts); do + if curl -s http://localhost:$PORT/health > /dev/null 2>&1; then + echo " Server ready! (waited $((SECONDS - start))s)" + return 0 + fi + # Show a progress blip every minute so the user knows we're alive + if [ $((i % 6)) -eq 0 ]; then + echo " ...still waiting ($((SECONDS - start))s elapsed)" + fi + sleep $interval + done + echo " ERROR: Server did not start within $((max_attempts * interval))s" + return 1 +} + +# Helper: run benchmark +run_bench() { + local config_name=$1 + local concurrency=$2 + local num_prompts=$3 + + echo " Benchmark: concurrency=$concurrency, prompts=$num_prompts" + vllm bench serve \ + --backend vllm \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --endpoint /v1/completions \ + --dataset-name random \ + --num-prompts "$num_prompts" \ + --random-input-len 900 \ + --random-output-len 90 \ + --random-range-ratio 0.03 \ + --max-concurrency "$concurrency" \ + 2>&1 | tee "$RESULTS_DIR/${config_name}_c${concurrency}.txt" + echo "" +} + +# Helper: stop server +stop_server() { + echo " Stopping vLLM server..." + pkill -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true + sleep 5 +} + +# Helper: quick sanity check +sanity_check() { + echo " Running sanity check..." + curl -s http://localhost:$PORT/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [{"role": "user", "content": "What is 1+1? Answer briefly."}], + "model": "'"$MODEL_PATH"'", + "max_tokens": 64, + "temperature": 0.0, + "stream": false + }' | python3 -c "import sys,json; r=json.load(sys.stdin); print(' Sanity:', r['choices'][0]['message']['content'][:100])" 2>/dev/null || echo " Sanity check: could not parse response" +} + +echo "==========================================" +echo "MiMo-V2.5-Pro FP8 Performance Benchmark" +echo "==========================================" +echo "Model: $MODEL_PATH" +echo "Results: $RESULTS_DIR" +echo "" + +############################################################################### +# Config 1: BS=32, TP=64 + moe_tp=1/moe_ep=64, CB + bucketing (smallest BS +# that satisfies NxDI's Expert-Parallel BS >= num_experts/top_k requirement). 
+############################################################################### +CONFIG_NAME="bs32_tp64_moetp1_ep64" +echo "--- Config 1: BS=32, moe_tp=1/moe_ep=64, CB + bucketing ---" + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 32 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + '"$COMMON_MIMO_CONFIG"', + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 32, + "ctx_batch_size": 1, + "tkg_batch_size": 32, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": true, + "enable_bucketing": true, + "context_encoding_buckets": [1024], + "token_generation_buckets": [1024], + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +run_bench "$CONFIG_NAME" 16 128 +run_bench "$CONFIG_NAME" 32 128 +stop_server + +############################################################################### +# Config 2: BS=128, TP=64 + moe_tp=1/moe_ep=64, CB + bucketing (throughput). +############################################################################### +CONFIG_NAME="bs128_tp64_moetp1_ep64" +echo "--- Config 2: BS=128, moe_tp=1/moe_ep=64, CB + bucketing ---" + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 128 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + '"$COMMON_MIMO_CONFIG"', + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 128, + "ctx_batch_size": 1, + "tkg_batch_size": 128, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": true, + "enable_bucketing": true, + "context_encoding_buckets": [1024], + "token_generation_buckets": [1024], + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +run_bench "$CONFIG_NAME" 16 128 +run_bench "$CONFIG_NAME" 32 128 +run_bench "$CONFIG_NAME" 128 512 +stop_server + +echo "==========================================" +echo "MiMo-V2.5-Pro FP8 benchmarks complete!" +echo "Results saved to: $RESULTS_DIR" +echo "==========================================" +ls -la "$RESULTS_DIR" diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh new file mode 100755 index 00000000..45729cd2 --- /dev/null +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh @@ -0,0 +1,76 @@ +#!/bin/bash +# Run a single vllm-bench-serve pass against an already-running vLLM server. +# +# Unlike bench_mimo_v2.sh this script does NOT launch or kill the vLLM +# server — you bring your own. That makes it convenient when the bench driver +# in bench_mimo_v2.sh times out during first-time compilation: the server +# keeps running, and once it's ready you can collect numbers with this. 
+# +# Usage: +# bash run_bench_single.sh # defaults: c=1, 16 prompts +# CONCURRENCY=16 NUM_PROMPTS=128 bash run_bench_single.sh +# CONFIG_NAME=bs32_tp1_ep64_opt CONCURRENCY=16 NUM_PROMPTS=128 bash run_bench_single.sh +# +# Environment knobs: +# PORT vLLM server port (default 8000) +# MIMO_V2_FLASH_PATH Path to the BF16 checkpoint (default +# /opt/dlami/nvme/models/MiMo-V2.5-Pro-BF16) +# CONCURRENCY --max-concurrency (default 1) +# NUM_PROMPTS --num-prompts (default 16) +# INPUT_LEN --random-input-len (default 900) +# OUTPUT_LEN --random-output-len (default 90) +# RANGE_RATIO --random-range-ratio (default 0.03) +# CONFIG_NAME Used in the output filename (default bs1_tp64_ep1) +# RESULTS_DIR Where to dump per-run log (default /tmp/bench_results/mimo_v25_pro) + +set -e + +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-BF16}" +PORT="${PORT:-8000}" +CONCURRENCY="${CONCURRENCY:-1}" +NUM_PROMPTS="${NUM_PROMPTS:-16}" +INPUT_LEN="${INPUT_LEN:-900}" +OUTPUT_LEN="${OUTPUT_LEN:-90}" +RANGE_RATIO="${RANGE_RATIO:-0.03}" +CONFIG_NAME="${CONFIG_NAME:-bs1_tp64_ep1}" +RESULTS_DIR="${RESULTS_DIR:-/tmp/bench_results/mimo_v25_pro}" + +mkdir -p "$RESULTS_DIR" + +echo "==========================================" +echo "MiMo-V2.5-Pro single-run benchmark" +echo "==========================================" +echo " Model: $MODEL_PATH" +echo " Port: $PORT" +echo " Config: $CONFIG_NAME" +echo " Concurrency: $CONCURRENCY" +echo " Prompts: $NUM_PROMPTS" +echo " Input len: $INPUT_LEN Output len: $OUTPUT_LEN" +echo " Results: $RESULTS_DIR/${CONFIG_NAME}_c${CONCURRENCY}.txt" +echo "" + +# Quick health check +if ! curl -sf "http://localhost:$PORT/health" > /dev/null; then + echo "ERROR: vLLM server is not responding on http://localhost:$PORT" + echo "Start it first (e.g., bench_mimo_v2.sh) and wait until" + echo "'Application startup complete.' is printed." + exit 1 +fi + +vllm bench serve \ + --backend vllm \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --endpoint /v1/completions \ + --dataset-name random \ + --num-prompts "$NUM_PROMPTS" \ + --random-input-len "$INPUT_LEN" \ + --random-output-len "$OUTPUT_LEN" \ + --random-range-ratio "$RANGE_RATIO" \ + --max-concurrency "$CONCURRENCY" \ + 2>&1 | tee "$RESULTS_DIR/${CONFIG_NAME}_c${CONCURRENCY}.txt" + +echo "" +echo "Saved to: $RESULTS_DIR/${CONFIG_NAME}_c${CONCURRENCY}.txt" diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh new file mode 100755 index 00000000..9cef80ab --- /dev/null +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# Quick sanity check against an already-running vLLM server. +# +# Assumes vLLM is already listening on $PORT (default 8000) with MiMo-V2.5-Pro +# loaded. Sends a single chat completion and prints the model's reply. +# +# Usage: +# bash sanity_check.sh # uses defaults +# PORT=8001 bash sanity_check.sh # custom port +# PROMPT="..." bash sanity_check.sh # custom prompt + +set -e + +MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-BF16}" +PORT="${PORT:-8000}" +PROMPT="${PROMPT:-What is 1+1? Answer briefly.}" +MAX_TOKENS="${MAX_TOKENS:-64}" + +echo "Sanity check: POST /v1/chat/completions on port $PORT" +echo " Model: $MODEL_PATH" +echo " Prompt: $PROMPT" +echo " Max tokens: $MAX_TOKENS" +echo "" + +# Health check first — fail fast if server isn't up. +if ! 
curl -sf "http://localhost:$PORT/health" > /dev/null; then + echo "ERROR: vLLM server is not responding on http://localhost:$PORT" + echo "Start it with 'bash bench_mimo_v2.sh' or your own launcher first." + exit 1 +fi + +RESPONSE=$(curl -s "http://localhost:$PORT/v1/chat/completions" \ + -H 'Content-Type: application/json' \ + -d "$(cat </dev/null || echo "$RESPONSE" +echo "" + +# Extract the model's reply for a human-friendly one-liner summary. +REPLY=$(echo "$RESPONSE" | python3 -c " +import json, sys +try: + r = json.load(sys.stdin) + print(r['choices'][0]['message']['content'].strip()) +except Exception as e: + print(f'(could not parse reply: {e})') +" 2>/dev/null) + +echo "Model reply: $REPLY" diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py new file mode 100755 index 00000000..b67a95f8 --- /dev/null +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +"""Minimal compile+load smoke test for MiMo-V2.5-Pro FP8 on Trn2. + +Bypasses vLLM entirely so we can iterate on the preprocessed Neuron-FP8 +checkpoint without paying vllm-neuron's startup cost. Builds the Flash BS=1 +recipe (TP=64, EP=1, blockwise FP8 for routed experts), compiles to a temp +dir, then loads. EP=1 lets the TKG path enter forward_selective_loading +legally so BS=1 compiles — with EP>1 NxDI raises NotImplementedError and +forces BS>=num_experts/top_k = 32. + +STAGE controls how far we go: + instantiate | compile | load | all (default: all) + +DRY_RUN=1 does HLO-only compile (no torch.jit.save + shard). Fastest sanity +check for the preprocessed checkpoint. SKIP_WARMUP=1 on load() skips the +forward pass that allocates the shared scratchpad — useful when HBM is +tight. + +Run under /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 (same venv used +by the bench script). +""" + +import os +import sys +import time +import traceback + +MODEL_PATH = os.environ.get( + "MIMO_V25_PRO_MODEL_PATH", + "/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8", +) +COMPILED_PATH = os.environ.get( + "MIMO_V25_PRO_COMPILED_PATH", + "/opt/dlami/nvme/compiled/mimo_v25_pro_moetp16_ep4_bs48/", +) + +TP_DEGREE = int(os.environ.get("TP_DEGREE", "64")) +SEQ_LEN = int(os.environ.get("SEQ_LEN", "1024")) +# BS=48 is the minimum that avoids forward_selective_loading on decode: +# `BS * top_k / num_experts >= 1.0` → BS >= 384/8 = 48. At BS=1 the TKG +# path raises `NotImplementedError: Selective Loading with Expert parallelism`. +BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "48")) +CTX_BATCH_SIZE = int(os.environ.get("CTX_BATCH_SIZE", "1")) +# moe_tp=16 / moe_ep=4: after moetp1/ep64 (prefill broken, E_local=6) and +# moetp32/ep2 (HBM OOM by 28MB at load, E_local=192), try the middle ground. +# E_local=384/4=96. Also intermediate/moe_tp = 2048/16 = 128 which matches +# the FP8 scale block size exactly — avoids the stride_fix workaround path. +# world_size = 16*4 = 64 OK. +MOE_TP = int(os.environ.get("MOE_TP", "16")) +MOE_EP = int(os.environ.get("MOE_EP", "4")) + +STAGE = os.environ.get("STAGE", "all").lower() + +os.makedirs(COMPILED_PATH, exist_ok=True) + +# NxDI's model builder uses a per-process temp workdir for HLO/NEFF staging +# (BASE_COMPILE_WORK_DIR, default "/tmp/nxd_model/"). If two compiles run in +# parallel with the same default, they silently overwrite each other's +# .hlo_module.pb files and one or both compilations crash with +# "neuronx-cc returned non-zero exit status 70". 
Pin the workdir to a +# unique per-COMPILED_PATH subdir to stay safe under any parallel invocation. +os.environ.setdefault( + "BASE_COMPILE_WORK_DIR", + os.path.join("/tmp/nxd_model", os.path.basename(COMPILED_PATH.rstrip("/"))), +) + + +def main(): + from neuronx_distributed_inference.models.config import MoENeuronConfig + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + + # Import the contrib wrapper (sibling src dir). + contrib_src = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "src", + ) + sys.path.insert(0, os.path.abspath(contrib_src)) + + from modeling_mimo_v2 import ( + MiMoV2InferenceConfig, + NeuronMiMoV2ForCausalLM, + ) + + print(f"[smoke] MODEL_PATH={MODEL_PATH}") + print(f"[smoke] COMPILED_PATH={COMPILED_PATH}") + print(f"[smoke] TP_DEGREE={TP_DEGREE}, SEQ_LEN={SEQ_LEN}, BS={BATCH_SIZE}") + print(f"[smoke] MOE_TP={MOE_TP}, MOE_EP={MOE_EP}") + print(f"[smoke] STAGE={STAGE}") + + print("[smoke] Building MoENeuronConfig (quantized FP8 MoE, blockwise_symmetric)...") + # NOTE: ep_degree at the top level controls the OUTER (full model) + # expert-parallel factor, which multiplies world_size to + # tp_degree * ep_degree and duplicates non-MoE weights per replica. + # At world_size > 64 on a 64-NC Trn2, sharded weights grow accordingly + # (e.g. tp=64 + ep=4 -> 256 ranks -> 4x the sharded checkpoint size, + # and at runtime the model doesn't fit on the device). For MoE-only + # EP we want ep_degree=1 at the outer level and the per-MoE split + # controlled solely by moe_ep_degree (which Pro's working benches + # also do). Keep ep_degree=1 unconditionally. + neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + ep_degree=1, + logical_nc_config=2, + batch_size=BATCH_SIZE, + max_batch_size=BATCH_SIZE, + ctx_batch_size=CTX_BATCH_SIZE, + tkg_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + n_active_tokens=128, + torch_dtype="bfloat16", + capacity_factor=1.0, + glu_mlp=True, + moe_ep_degree=MOE_EP, + moe_tp_degree=MOE_TP, + context_encoding_buckets=[SEQ_LEN], + router_config={"act_fn": "sigmoid", "dtype": "float32"}, + # SDK 2.29 ships only bwmm_shard_on_block / bwmm_shard_on_intermediate; + # default routes to _call_shard_hidden_kernel which is missing, so we + # take the shard-on-block path via this flag. Matches Flash + Kimi. + blockwise_matmul_config={ + "use_shard_on_block_dynamic_while": True, + "block_sharding_strategy": "PING_PONG", + }, + # Persist sharded FP8 weights to disk so subsequent load()s skip the + # ~10-minute shard_checkpoint step (writes weights/tp{0..63}_*.safetensors + # on NVMe; NxDI load() reads these directly when present). + save_sharded_checkpoint=True, + # FP8 blockwise for routed experts (Kimi-K2 recipe). 
+ quantized=True, + quantized_checkpoints_path=MODEL_PATH, + quantization_dtype="f8e4m3", + quantization_type="blockwise_symmetric", + quantization_block_axis=[1, 2], + quantization_block_size=[128, 128], + modules_to_not_convert=[ + "embed_tokens", + "lm_head", + "norm", + "router", + "o_proj", + ], + ) + + print("[smoke] Building MiMoV2InferenceConfig...") + from transformers import AutoConfig + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + config = MiMoV2InferenceConfig( + neuron_config, load_config=load_pretrained_config(hf_config=hf_config) + ) + print(f"[smoke] config.hidden_size={config.hidden_size}") + print(f"[smoke] config.num_hidden_layers={config.num_hidden_layers}") + print(f"[smoke] config.n_routed_experts={config.n_routed_experts}") + print(f"[smoke] config.num_experts_per_tok={config.num_experts_per_tok}") + print(f"[smoke] config.layer_uses_moe[:5]={config.layer_uses_moe[:5]}") + print(f"[smoke] config.layer_attention_types[:5]={config.layer_attention_types[:5]}") + + print("[smoke] Instantiating NeuronMiMoV2ForCausalLM (build model-on-cpu)...") + t0 = time.time() + model = NeuronMiMoV2ForCausalLM(MODEL_PATH, config) + print(f"[smoke] Instantiated in {time.time() - t0:.1f}s") + + if STAGE == "instantiate": + print("[smoke] STAGE=instantiate only, skipping compile/load.") + return + + DRY_RUN = os.environ.get("DRY_RUN", "0") == "1" + if STAGE in ("compile", "all"): + label = "Dry-run compile (HLO only)" if DRY_RUN else "Full compile" + print(f"[smoke] {label} -> {COMPILED_PATH}") + t0 = time.time() + try: + model.compile(COMPILED_PATH, dry_run=DRY_RUN) + print(f"[smoke] {label} OK in {time.time() - t0:.1f}s") + except Exception: + print(f"[smoke] {label} FAILED:") + traceback.print_exc() + raise + + if STAGE in ("load", "all") and not DRY_RUN: + SKIP_WARMUP = os.environ.get("SKIP_WARMUP", "1") == "1" + print(f"[smoke] Loading compiled model from {COMPILED_PATH} (skip_warmup={SKIP_WARMUP})") + t0 = time.time() + model.load(COMPILED_PATH, skip_warmup=SKIP_WARMUP) + print(f"[smoke] Loaded in {time.time() - t0:.1f}s") + + print("[smoke] Done.") + + +if __name__ == "__main__": + try: + main() + except Exception: + traceback.print_exc() + sys.exit(1) diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py new file mode 100755 index 00000000..f357aee5 --- /dev/null +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python3 +"""Minimal generate smoke test for MiMo-V2.5-Pro FP8 on Trn2. + +Assumes the compiled NEFF already exists at MIMO_V25_PRO_COMPILED_PATH +(from smoke_compile_mimo_v2.py). Rebuilds the same MoENeuronConfig / +Flash wrapper, loads with skip_warmup=False, and generates 20 tokens for a +single prompt via HuggingFaceGenerationAdapter. Purpose: sanity-check that +the FP8 MoE + preprocessed scales actually produce coherent tokens. + +Run under /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16. +""" + +import os +import sys +import time +import traceback + +MODEL_PATH = os.environ.get( + "MIMO_V25_PRO_MODEL_PATH", + "/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8", +) +COMPILED_PATH = os.environ.get( + "MIMO_V25_PRO_COMPILED_PATH", + "/opt/dlami/nvme/compiled/mimo_v25_pro_tp64_moetp1_ep64_fp8/", +) + +# Must match smoke_compile_mimo_v2.py exactly, else load() sees a +# mismatched NEFF. 
+TP_DEGREE = int(os.environ.get("TP_DEGREE", "64")) +SEQ_LEN = int(os.environ.get("SEQ_LEN", "1024")) +BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "48")) # must match smoke_compile +CTX_BATCH_SIZE = int(os.environ.get("CTX_BATCH_SIZE", "1")) +MOE_TP = int(os.environ.get("MOE_TP", "16")) +MOE_EP = int(os.environ.get("MOE_EP", "4")) + +PROMPT = os.environ.get( + "MIMO_V25_PRO_PROMPT", + "Hello! Please introduce yourself in one sentence.", +) +MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "20")) + +# Keep the per-compile BASE_COMPILE_WORK_DIR in sync with +# smoke_compile_mimo_v2.py so load() under the same COMPILED_PATH +# doesn't collide with a concurrent compile or reuse a stale workdir. +os.environ.setdefault( + "BASE_COMPILE_WORK_DIR", + os.path.join("/tmp/nxd_model", os.path.basename(COMPILED_PATH.rstrip("/"))), +) + + +def main(): + from transformers import AutoConfig, AutoTokenizer, GenerationConfig + + from neuronx_distributed_inference.models.config import MoENeuronConfig + from neuronx_distributed_inference.utils.hf_adapter import ( + HuggingFaceGenerationAdapter, + load_pretrained_config, + ) + + contrib_src = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "src", + ) + sys.path.insert(0, os.path.abspath(contrib_src)) + + from modeling_mimo_v2 import ( + MiMoV2InferenceConfig, + NeuronMiMoV2ForCausalLM, + ) + + print(f"[gen] MODEL_PATH={MODEL_PATH}") + print(f"[gen] COMPILED_PATH={COMPILED_PATH}") + print(f"[gen] TP={TP_DEGREE}, SEQ={SEQ_LEN}, BS={BATCH_SIZE}") + + # Outer ep_degree must match the compile-time value (kept at 1 so + # world_size = tp_degree; see smoke_compile_mimo_v2.py comment). + neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + ep_degree=1, + logical_nc_config=2, + batch_size=BATCH_SIZE, + max_batch_size=BATCH_SIZE, + ctx_batch_size=CTX_BATCH_SIZE, + tkg_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + n_active_tokens=128, + torch_dtype="bfloat16", + capacity_factor=1.0, + glu_mlp=True, + moe_ep_degree=MOE_EP, + moe_tp_degree=MOE_TP, + context_encoding_buckets=[SEQ_LEN], + router_config={"act_fn": "sigmoid", "dtype": "float32"}, + blockwise_matmul_config={ + "use_shard_on_block_dynamic_while": True, + "block_sharding_strategy": "PING_PONG", + }, + save_sharded_checkpoint=True, + quantized=True, + quantized_checkpoints_path=MODEL_PATH, + quantization_dtype="f8e4m3", + quantization_type="blockwise_symmetric", + quantization_block_axis=[1, 2], + quantization_block_size=[128, 128], + modules_to_not_convert=[ + "embed_tokens", + "lm_head", + "norm", + "router", + "o_proj", + ], + ) + + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + config = MiMoV2InferenceConfig( + neuron_config, load_config=load_pretrained_config(hf_config=hf_config) + ) + + print("[gen] Instantiating model...") + t0 = time.time() + model = NeuronMiMoV2ForCausalLM(MODEL_PATH, config) + print(f"[gen] Instantiated in {time.time() - t0:.1f}s") + + # skip_warmup=False so generate() hits a primed graph (the warmup forward + # allocates the shared scratchpad the generation path needs). 
+ print(f"[gen] Loading from {COMPILED_PATH} (skip_warmup=False)") + t0 = time.time() + model.load(COMPILED_PATH, skip_warmup=False) + print(f"[gen] Loaded in {time.time() - t0:.1f}s") + + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + adapter = HuggingFaceGenerationAdapter(model) + + # When CHAT_TEMPLATE=1, wrap the raw prompt in the checkpoint's chat + # template (system + user turns with <|im_start|>/<|im_end|> markers and + # trailing assistant cue). Matches how vllm /v1/chat/completions prepares + # inputs. Without this, the model free-continues the prompt as raw text + # instead of answering it. + use_chat_template = os.environ.get("CHAT_TEMPLATE", "0") == "1" + minimal_chat = os.environ.get("MINIMAL_CHAT", "0") == "1" + if minimal_chat: + # Skip the Pro default system prompt entirely; wrap prompt in bare + # <|im_start|>user ... <|im_end|><|im_start|>assistant\n framing. + templated = ( + f"<|im_start|>user\n{PROMPT}<|im_end|>" + f"<|im_start|>assistant\n" + ) + print(f"[gen] minimal-chat prompt ({len(templated)} chars, no system)") + inputs = tokenizer( + [templated] * BATCH_SIZE, + return_tensors="pt", + padding=True, + add_special_tokens=False, + ) + elif use_chat_template: + system = os.environ.get("CHAT_SYSTEM", "") + messages = [] + if system: + messages.append({"role": "system", "content": system}) + messages.append({"role": "user", "content": PROMPT}) + templated = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, + ) + print(f"[gen] chat-templated prompt ({len(templated)} chars)") + inputs = tokenizer([templated] * BATCH_SIZE, return_tensors="pt", padding=True) + else: + inputs = tokenizer([PROMPT] * BATCH_SIZE, return_tensors="pt", padding=True) + gen_config = GenerationConfig( + max_new_tokens=MAX_NEW_TOKENS, + min_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + pad_token_id=getattr(tokenizer, "pad_token_id", None) or tokenizer.eos_token_id, + ) + + # DUMP_LOGITS=1 -> request scores so we can see top-k per step. 
+ dump_logits = os.environ.get("DUMP_LOGITS", "0") == "1" + if dump_logits: + gen_config.output_scores = True + gen_config.return_dict_in_generate = True + + print(f"[gen] prompt: {PROMPT!r}") + print(f"[gen] input_ids.shape={tuple(inputs['input_ids'].shape)}") + t0 = time.time() + output = adapter.generate( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + generation_config=gen_config, + ) + dt = time.time() - t0 + + if dump_logits and hasattr(output, "sequences"): + output_ids = output.sequences + scores = output.scores # tuple of [bs, vocab] per step + else: + output_ids = output + scores = None + + prompt_len = inputs["input_ids"].shape[1] + new_tokens = output_ids[0, prompt_len:] + decoded = tokenizer.decode(new_tokens, skip_special_tokens=True) + full = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + print(f"[gen] generated {new_tokens.numel()} tokens in {dt:.2f}s " + f"({new_tokens.numel() / dt:.2f} tok/s)") + print(f"[gen] new token ids: {new_tokens.tolist()}") + print(f"[gen] new text : {decoded!r}") + print(f"[gen] full text : {full!r}") + + if scores is not None: + import torch as _t + print("[gen] === top-5 per decode step (batch slot 0) ===") + for step, step_logits in enumerate(scores): + lp = _t.log_softmax(step_logits[0].float(), dim=-1) + top_lp, top_id = _t.topk(lp, 5) + parts = [] + for l, i in zip(top_lp.tolist(), top_id.tolist()): + tok = tokenizer.decode([i]).replace("\n", "\\n") + parts.append(f"({tok!r}:{i}:{l:.2f})") + chosen = new_tokens[step].item() + print(f" step {step:3d} chose id={chosen} top5={' '.join(parts)}") + + print("[gen] Done.") + + +if __name__ == "__main__": + try: + main() + except Exception: + traceback.print_exc() + sys.exit(1) diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/vllm-neuron-patch.patch b/contrib/models/MiMo-V2.5-Pro/perf_test/vllm-neuron-patch.patch new file mode 100644 index 00000000..d8a85b89 --- /dev/null +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/vllm-neuron-patch.patch @@ -0,0 +1,107 @@ +diff --git a/vllm_neuron/worker/neuronx_distributed_model_loader.py b/vllm_neuron/worker/neuronx_distributed_model_loader.py +index d2099eb..0c162e4 100644 +--- a/vllm_neuron/worker/neuronx_distributed_model_loader.py ++++ b/vllm_neuron/worker/neuronx_distributed_model_loader.py +@@ -922,6 +922,94 @@ def _camel_to_kebab(name: str) -> str: + return re.sub("([a-z0-9])([A-Z])", r"\1-\2", s1).lower() + + ++ ++def _patch_autoconfig_trust_remote_code(): ++ """Monkey-patch ``AutoConfig.from_pretrained`` to default ``trust_remote_code=True``. ++ ++ NxDI's ``hf_adapter.load_config`` calls ``AutoConfig.from_pretrained(path)`` ++ without ``trust_remote_code``. Contrib models like MiMo-V2.5-Pro that ++ ship a ``configuration_*.py`` with the checkpoint require custom code ++ execution, so the default behaviour crashes with ``ValueError: The ++ repository ... contains custom code which must be executed``. ++ ++ vLLM's top-level ``--trust-remote-code`` flag only affects vLLM's own ++ config load, not NxDI's. Patching here is cheap and idempotent. 
++ """ ++ try: ++ from transformers import AutoConfig ++ except ImportError: ++ return ++ if getattr(AutoConfig, "_nxdi_contrib_patched", False): ++ return ++ _orig = AutoConfig.from_pretrained ++ ++ def _patched(*args, **kwargs): ++ kwargs.setdefault("trust_remote_code", True) ++ return _orig(*args, **kwargs) ++ ++ AutoConfig.from_pretrained = _patched ++ AutoConfig._nxdi_contrib_patched = True ++ ++ ++def _register_contrib_models(): ++ """Lazy-register NxDI contrib models on each process that calls the loader. ++ ++ Driven by env vars: ++ NXDI_CONTRIB_MIMO_V2_FLASH_SRC -> path to contrib MiMo-V2.5-Pro src/ ++ NXDI_CONTRIB_MINIMAX_M2_SRC -> path to contrib MiniMax-M2 src/ ++ ++ Registers the contrib model class into NxDI's MODEL_TYPES and, where ++ vLLM does not already know the architecture, registers it into vLLM's ++ ModelRegistry. Runs every time _get_neuron_model_cls is called so that ++ vLLM's spawn'd EngineCore workers (which don't inherit the parent's ++ module-level state) pick up the registration too. Registration is ++ idempotent. ++ """ ++ import os as _os ++ import sys as _sys ++ import warnings as _w ++ ++ _patch_autoconfig_trust_remote_code() ++ ++ mimo_src = _os.environ.get("NXDI_CONTRIB_MIMO_V2_FLASH_SRC") ++ if mimo_src and _os.path.isdir(mimo_src) and "mimov2flash" not in MODEL_TYPES: ++ if mimo_src not in _sys.path: ++ _sys.path.insert(0, mimo_src) ++ try: ++ from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM ++ MODEL_TYPES.setdefault( ++ "mimov2flash", {"causal-lm": NeuronMiMoV2ForCausalLM} ++ ) ++ try: ++ from vllm.model_executor.models.registry import ModelRegistry ++ if "MiMoV2ForCausalLM" not in ModelRegistry.get_supported_archs(): ++ ModelRegistry.register_model( ++ "MiMoV2ForCausalLM", NeuronMiMoV2ForCausalLM ++ ) ++ except ImportError: ++ pass ++ except Exception as e: ++ _w.warn( ++ f"Failed to register MiMo-V2.5-Pro contrib model: {e}", ++ category=UserWarning, ++ ) ++ ++ minimax_src = _os.environ.get("NXDI_CONTRIB_MINIMAX_M2_SRC") ++ if minimax_src and _os.path.isdir(minimax_src) and "minimaxm2" not in MODEL_TYPES: ++ if minimax_src not in _sys.path: ++ _sys.path.insert(0, minimax_src) ++ try: ++ from modeling_minimax_m2 import NeuronMiniMaxM2ForCausalLM ++ MODEL_TYPES.setdefault( ++ "minimaxm2", {"causal-lm": NeuronMiniMaxM2ForCausalLM} ++ ) ++ except Exception as e: ++ _w.warn( ++ f"Failed to register MiniMax-M2 contrib model: {e}", ++ category=UserWarning, ++ ) ++ ++ + def _get_neuron_model_cls(architecture: str): + """ + Get Neuron model class from architecture string. +@@ -941,6 +1029,7 @@ def _get_neuron_model_cls(architecture: str): + _get_neuron_model_cls("NeuronLlamaForCausalLM") + + """ ++ _register_contrib_models() + # Handle Neuron class name (starts with "Neuron") - strip prefix + if architecture.startswith("Neuron") and "For" in architecture: + original_architecture = architecture diff --git a/contrib/models/MiMo-V2.5-Pro/src/__init__.py b/contrib/models/MiMo-V2.5-Pro/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py b/contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py new file mode 100644 index 00000000..56eabeaa --- /dev/null +++ b/contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py @@ -0,0 +1,641 @@ +""" +Preprocess MiMo-V2.5-Pro FP8 checkpoint for Neuron inference. + +This is a streaming (per-layer) rewrite of preprocess_mimo_v2_fp8.py. 
The
+original preprocess loaded the entire ~290 GB FP8 checkpoint into RAM via
+load_state_dict(); that peaks well over 600 GB after dequantize/requantize
+copies and is fragile. This version keeps a single safe_open handle live
+at a time and emits per-layer safetensors shards, capping peak RAM at
+~24 GB and finishing in ~20 minutes.
+
+MiMo-V2.5-Pro checkpoint layout:
+  - q_proj, k_proj, v_proj are FUSED into a single `qkv_proj` tensor per
+    layer (num_kv_heads interleaved groups, Pro-specific). We split into
+    three per-row-quantized projections via `split_qkv_fused()`.
+  - o_proj is BF16 (listed in quantization_config.ignored_layers); kept
+    as BF16 on the Neuron side (RowParallelLinear, not QuantizedRowParallel).
+  - Layer 0 is a dense MLP (moe_layer_freq[0] == 0) with intermediate_size
+    16384; layers 1..69 are MoE with 384 experts each.
+  - Hybrid attention: 10 "full" layers (hybrid_layer_pattern[i] == 0) and
+    60 "sliding window" layers (== 1). SWA layers carry
+    attention_sink_bias (add_swa_attention_sink_bias=True in the config;
+    add_full_attention_sink_bias=False, so full layers do NOT get it).
+
+Neuron-side rescaling (same as Pro/original-Flash):
+  - OCP FP8 e4m3 (±448) -> Neuron FP8 e4m3 (±240) with FP8_SCALING_FACTOR=448/240.
+  - Per-row scales for attention/dense-mlp projections (q/k/v/o, gate/up/down
+    of the dense layer).
+  - Blockwise (128x128) scales kept for MoE expert weights; per-expert weights
+    are transposed and fused to match ExpertFusedRowParallelLinear's packed
+    layout (gate_up_proj: [num_experts, H, 2*IM]; down_proj: [num_experts, IM, H]).
+
+Output layout:
+  save_path/
+    config.json, tokenizer.*, chat_template.jinja if present
+    configuration_mimo_v2.py, modeling_mimo_v2.py (trust_remote_code)
+    model.safetensors.index.json (regenerated)
+    model_extras.safetensors (embed_tokens, norm, lm_head)
+    model_layer{N}.safetensors (one per decoder layer, N=0..69)
+
+Usage:
+  python preprocess_mimo_v2_fp8.py \\
+    --hf_model_path /opt/dlami/nvme/models/MiMo-V2.5-Pro \\
+    --save_path /opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8 \\
+    --tp_degree 64
+"""
+
+import argparse
+import gc
+import json
+import os
+import shutil
+import time
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from safetensors import safe_open
+from safetensors.torch import save_file
+
+
+FP8_SCALING_FACTOR = 448.0 / 240.0
+NEURON_FP8_MAX = 240.0
+
+
+# ---------------------------------------------------------------------------
+# Quantization primitives
+# ---------------------------------------------------------------------------
+
+def convert_bf16_to_fp8_per_row(
+    weight: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """BF16 [out, in] -> Neuron FP8 per-row (scales shape [out, 1])."""
+    weight_float = weight.float()
+    row_max_abs = weight_float.abs().max(dim=1, keepdim=True)[0]
+    scales = torch.clamp(row_max_abs / NEURON_FP8_MAX, min=1e-10)
+    quantized = (weight_float / scales).to(torch.float8_e4m3fn)
+    return quantized, scales.to(torch.float32)
+
+
+def rescale_fp8_to_per_row(
+    weight: torch.Tensor, scale: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Block-wise FP8 + blockwise scale -> Neuron per-row FP8.
+
+    Dequantize to float32 using block broadcast, then per-row requantize.
+ """ + out_features, in_features = weight.shape + scale_h, scale_w = scale.shape + + block_h = (out_features + scale_h - 1) // scale_h + block_w = (in_features + scale_w - 1) // scale_w + + weight_float = weight.float() + dequantized = torch.zeros(out_features, in_features, dtype=torch.float32) + for i in range(scale_h): + for j in range(scale_w): + h0, h1 = i * block_h, min((i + 1) * block_h, out_features) + w0, w1 = j * block_w, min((j + 1) * block_w, in_features) + dequantized[h0:h1, w0:w1] = ( + weight_float[h0:h1, w0:w1] * scale[i, j].item() + ) + + row_max_abs = dequantized.abs().max(dim=1, keepdim=True)[0] + scales = torch.clamp(row_max_abs / NEURON_FP8_MAX, min=1e-10) + quantized = (dequantized / scales).to(torch.float8_e4m3fn) + return quantized, scales.to(torch.float32) + + +def rescale_fp8_weight_blockwise( + weight: torch.Tensor, scale: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Keep blockwise scales, just rescale into Neuron FP8 range. + + MoE expert weights stay block-quantized; only the dtype range changes. + """ + weight_bf16 = weight.bfloat16() + rescaled = (weight_bf16 / FP8_SCALING_FACTOR).to(torch.float8_e4m3fn) + neuron_scale = scale.float() * FP8_SCALING_FACTOR + return rescaled, neuron_scale.to(torch.float32) + + +# --------------------------------------------------------------------------- +# Streaming weight access (one open safetensors handle at a time) +# --------------------------------------------------------------------------- + +class LazyWeightMap: + """Lazily fetch tensors from sharded safetensors, keeping one handle live.""" + + def __init__(self, model_dir: str, weight_map: Dict[str, str]): + self.model_dir = model_dir + self.weight_map = weight_map + self._cur_filename: Optional[str] = None + self._cur_handle = None + + def _open(self, filename: str): + if self._cur_filename == filename: + return self._cur_handle + if self._cur_handle is not None: + self._cur_handle.__exit__(None, None, None) + self._cur_handle = None + path = os.path.join(self.model_dir, filename) + self._cur_handle = safe_open(path, framework="pt", device="cpu") + self._cur_handle.__enter__() + self._cur_filename = filename + return self._cur_handle + + def get(self, key: str) -> Optional[torch.Tensor]: + filename = self.weight_map.get(key) + if filename is None: + return None + return self._open(filename).get_tensor(key) + + def has(self, key: str) -> bool: + return key in self.weight_map + + def close(self): + if self._cur_handle is not None: + self._cur_handle.__exit__(None, None, None) + self._cur_handle = None + self._cur_filename = None + + +# --------------------------------------------------------------------------- +# Per-tensor helpers +# --------------------------------------------------------------------------- + +def _requantize_per_row(dequant: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """BF16/FP32 -> Neuron FP8 per-row.""" + row_max_abs = dequant.abs().max(dim=1, keepdim=True)[0] + scales = row_max_abs / NEURON_FP8_MAX + scales = torch.clamp(scales, min=1e-10) + quantized = (dequant / scales).to(torch.float8_e4m3fn) + return quantized, scales.to(torch.float32) + + +def split_qkv_fused( + qkv_weight: torch.Tensor, + qkv_scale: Optional[torch.Tensor], + num_q_heads: int, + num_kv_heads: int, + head_dim: int, + v_head_dim: int, +) -> Dict[str, Tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Split Pro's pre-fused qkv_proj into q/k/v. MiMo-V2.5-Pro specific. 
+ + HF layout — cross-validated against sglang on H200: + `qkv_proj.weight` is NOT `[all_Q | all_K | all_V]`. It is num_kv_heads + interleaved groups, each holding (heads_per_group Q heads, 1 K head, + 1 V head) packed contiguously: + + group g (g = 0 .. num_kv_heads-1): + rows [g*R : g*R + qg] = Q heads [g*hpg : (g+1)*hpg] + rows [g*R+qg : g*R + qg + kg] = K head g + rows [g*R+qg+kg : g*R + R] = V head g + where + hpg = num_q_heads / num_kv_heads (e.g. 128/8 = 16) + qg = hpg * head_dim (e.g. 16 * 192 = 3072) + kg = 1 * head_dim (e.g. 192) + vg = 1 * v_head_dim (e.g. 128) + R = qg + kg + vg (e.g. 3392) + + Scale: per-group 27 scale rows covering 27*128 = 3456 "padded" rows: + 24 rows for Q (24 * 128 = 3072 real Q rows) + 2 rows for K (1 full block + 1 half-real/half-phantom block) + 1 row for V (128 rows) + Total: 8 * 27 = 216 scale rows, 8 * 3392 = 27136 weight rows. + + The "phantom" 64 rows sit between each group's K tail and V start in + *scale block coordinates* only; in the physical weight tensor, group g's + V is immediately followed by group (g+1)'s Q. We recover the correct + dequant by padding each group up to 3456 rows before applying the scale, + then stripping the phantom rows. + """ + in_features = qkv_weight.shape[1] + hpg = num_q_heads // num_kv_heads + qg_rows = hpg * head_dim + kg_rows = 1 * head_dim + vg_rows = 1 * v_head_dim + real_rows_per_group = qg_rows + kg_rows + vg_rows + total_real_rows = num_kv_heads * real_rows_per_group + + BLOCK = 128 + q_scale_rows_per_group = qg_rows // BLOCK + k_scale_rows_per_group = (kg_rows + BLOCK - 1) // BLOCK + v_scale_rows_per_group = (vg_rows + BLOCK - 1) // BLOCK + scale_rows_per_group = (q_scale_rows_per_group + + k_scale_rows_per_group + + v_scale_rows_per_group) + padded_rows_per_group = scale_rows_per_group * BLOCK + + assert qkv_weight.shape[0] == total_real_rows, ( + f"qkv_proj.weight row count {qkv_weight.shape[0]} != " + f"expected {total_real_rows} " + f"(num_kv_heads={num_kv_heads}, R={real_rows_per_group})" + ) + + if qkv_weight.dtype != torch.float8_e4m3fn or qkv_scale is None: + # BF16 path + w = qkv_weight.view(num_kv_heads, real_rows_per_group, in_features) + q_w = w[:, :qg_rows, :].reshape(num_kv_heads * qg_rows, in_features).contiguous() + k_w = w[:, qg_rows:qg_rows + kg_rows, :].reshape(num_kv_heads * kg_rows, in_features).contiguous() + v_w = w[:, qg_rows + kg_rows:, :].reshape(num_kv_heads * vg_rows, in_features).contiguous() + q_w2, q_s2 = convert_bf16_to_fp8_per_row(q_w) + k_w2, k_s2 = convert_bf16_to_fp8_per_row(k_w) + v_w2, v_s2 = convert_bf16_to_fp8_per_row(v_w) + return {"q_proj": (q_w2, q_s2), "k_proj": (k_w2, k_s2), "v_proj": (v_w2, v_s2)} + + # FP8 + blockwise scale path. 
+ expected_scale_rows = num_kv_heads * scale_rows_per_group + expected_scale_cols = (in_features + BLOCK - 1) // BLOCK + assert qkv_scale.shape == (expected_scale_rows, expected_scale_cols), ( + f"qkv scale shape {tuple(qkv_scale.shape)} != expected " + f"({expected_scale_rows}, {expected_scale_cols})" + ) + + w = qkv_weight.to(torch.float32).view( + num_kv_heads, real_rows_per_group, in_features + ) + w_padded = torch.zeros( + num_kv_heads, padded_rows_per_group, in_features, dtype=torch.float32 + ) + w_padded[:, :real_rows_per_group, :] = w + + s = qkv_scale.to(torch.float32).view( + num_kv_heads, scale_rows_per_group, expected_scale_cols + ) + s_exp = s.repeat_interleave(BLOCK, dim=1).repeat_interleave(BLOCK, dim=2) + s_exp = s_exp[:, :padded_rows_per_group, :in_features] + + deq_padded = w_padded * s_exp + deq = deq_padded[:, :real_rows_per_group, :] + + q_deq = deq[:, :qg_rows, :].reshape(num_kv_heads * qg_rows, in_features).contiguous() + k_deq = deq[:, qg_rows:qg_rows + kg_rows, :].reshape(num_kv_heads * kg_rows, in_features).contiguous() + v_deq = deq[:, qg_rows + kg_rows:, :].reshape(num_kv_heads * vg_rows, in_features).contiguous() + + q_w2, q_s2 = _requantize_per_row(q_deq) + k_w2, k_s2 = _requantize_per_row(k_deq) + v_w2, v_s2 = _requantize_per_row(v_deq) + + return {"q_proj": (q_w2, q_s2), "k_proj": (k_w2, k_s2), "v_proj": (v_w2, v_s2)} + + +def _maybe_fp8_to_neuron_per_row( + weight: torch.Tensor, scale: Optional[torch.Tensor] +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """FP8 blockwise -> per-row, or BF16 -> FP8 per-row. Pass-through otherwise.""" + if weight.dtype == torch.float8_e4m3fn and scale is not None: + return rescale_fp8_to_per_row(weight, scale) + if weight.dtype == torch.bfloat16: + return convert_bf16_to_fp8_per_row(weight) + return weight, scale + + +# --------------------------------------------------------------------------- +# Per-layer processing +# --------------------------------------------------------------------------- + +def process_layer( + layer_idx: int, + lazy: LazyWeightMap, + config: dict, + is_dense: bool, + is_swa: bool, +) -> Dict[str, torch.Tensor]: + out: Dict[str, torch.Tensor] = {} + prefix = f"model.layers.{layer_idx}." + out_prefix = f"layers.{layer_idx}." + + # --- Layer norms (BF16, untouched) --- + for name in ("input_layernorm", "post_attention_layernorm"): + t = lazy.get(f"{prefix}{name}.weight") + if t is not None: + out[f"{out_prefix}{name}.weight"] = t.detach().clone() + + # --- Attention: Pro ships a pre-fused qkv_proj; Flash ships q/k/v split. + # Support both: detect qkv_proj.weight first. If present, split it using + # Pro's interleaved num_kv_heads-group layout (with phantom-row FP8 scale + # handling). Otherwise fall back to the Flash-style per-proj path. + qkv_w = lazy.get(f"{prefix}self_attn.qkv_proj.weight") + if qkv_w is not None: + qkv_s = lazy.get(f"{prefix}self_attn.qkv_proj.weight_scale_inv") + # Attention heads: use swa_* for SWA layers, else main. 
+ if is_swa: + num_q = config.get("swa_num_attention_heads", config["num_attention_heads"]) + num_kv = config.get("swa_num_key_value_heads", config["num_key_value_heads"]) + hd = config.get("swa_head_dim", config.get("head_dim")) + vhd = config.get("swa_v_head_dim", config.get("v_head_dim", hd)) + else: + num_q = config["num_attention_heads"] + num_kv = config["num_key_value_heads"] + hd = config.get("head_dim") + vhd = config.get("v_head_dim", hd) + split = split_qkv_fused(qkv_w, qkv_s, num_q, num_kv, hd, vhd) + for proj, (w2, s2) in split.items(): + out[f"{out_prefix}self_attn.{proj}.weight"] = w2 + if s2 is not None: + out[f"{out_prefix}self_attn.{proj}.scale"] = s2 + else: + # Flash-style: q/k/v stored separately. + for proj in ("q_proj", "k_proj", "v_proj"): + w = lazy.get(f"{prefix}self_attn.{proj}.weight") + if w is None: + continue + s = lazy.get(f"{prefix}self_attn.{proj}.weight_scale_inv") + w2, s2 = _maybe_fp8_to_neuron_per_row(w, s) + out[f"{out_prefix}self_attn.{proj}.weight"] = w2 + if s2 is not None: + out[f"{out_prefix}self_attn.{proj}.scale"] = s2 + + # o_proj is listed in HF quantization_config.ignored_layers and ships as + # BF16; on Neuron it binds to a plain RowParallelLinear (see + # modeling_mimo_v2.py: self.o_proj = RowParallelLinear(...)), NOT a + # QuantizedRowParallel. Writing FP8 + .scale here would silently be + # reinterpreted as BF16 bytes at load time and produce garbage outputs. + # Keep BF16, never emit .scale. + o_w = lazy.get(f"{prefix}self_attn.o_proj.weight") + o_s = lazy.get(f"{prefix}self_attn.o_proj.weight_scale_inv") + if o_w is not None: + if o_w.dtype == torch.float8_e4m3fn: + # Defensive: if a future checkpoint FP8-quantizes o_proj, dequant + # blockwise back to BF16 (no per-row requant; RowParallelLinear has + # no .scale parameter). + assert o_s is not None, "FP8 o_proj requires weight_scale_inv" + out_features, in_features = o_w.shape + scale_h, scale_w = o_s.shape + block_h = (out_features + scale_h - 1) // scale_h + block_w = (in_features + scale_w - 1) // scale_w + wf = o_w.float() + tmp = torch.zeros(out_features, in_features, dtype=torch.float32) + for i in range(scale_h): + for j in range(scale_w): + h0, h1 = i * block_h, min((i + 1) * block_h, out_features) + w0, w1 = j * block_w, min((j + 1) * block_w, in_features) + tmp[h0:h1, w0:w1] = wf[h0:h1, w0:w1] * o_s[i, j].item() + o_bf16 = tmp.to(torch.bfloat16) + else: + o_bf16 = o_w.to(torch.bfloat16) + out[f"{out_prefix}self_attn.o_proj.weight"] = o_bf16.detach().clone() + + # --- attention_sink_bias: present only on SWA layers in MiMo-V2.5-Pro. + # config.add_swa_attention_sink_bias=True, add_full_attention_sink_bias=False. + if is_swa and config.get("add_swa_attention_sink_bias", False): + sink = lazy.get(f"{prefix}self_attn.attention_sink_bias") + if sink is not None: + out[f"{out_prefix}self_attn.attention_sink_bias"] = sink.detach().clone() + elif not is_swa and config.get("add_full_attention_sink_bias", False): + sink = lazy.get(f"{prefix}self_attn.attention_sink_bias") + if sink is not None: + out[f"{out_prefix}self_attn.attention_sink_bias"] = sink.detach().clone() + + # --- MLP: dense vs MoE --- + if is_dense: + # Dense MLP: gate_proj, up_proj, down_proj (FP8 blockwise in Flash layer 0). 
+ for proj in ("gate_proj", "up_proj", "down_proj"): + w = lazy.get(f"{prefix}mlp.{proj}.weight") + if w is None: + continue + s = lazy.get(f"{prefix}mlp.{proj}.weight_scale_inv") + w2, s2 = _maybe_fp8_to_neuron_per_row(w, s) + out[f"{out_prefix}mlp.{proj}.weight"] = w2 + if s2 is not None: + out[f"{out_prefix}mlp.{proj}.scale"] = s2 + return out + + # --- MoE --- + # Router: mlp.gate -> mlp.router.linear_router + router_w = lazy.get(f"{prefix}mlp.gate.weight") + if router_w is not None: + out[f"{out_prefix}mlp.router.linear_router.weight"] = router_w.detach().clone() + router_bias = lazy.get(f"{prefix}mlp.gate.e_score_correction_bias") + if router_bias is not None: + # Pro-specific: HF bias has mean ~71 with per-expert std ~3e-4. NxDI + # casts router parameters to bf16 at load time, and bf16 step size at + # magnitude 71 is ~0.5 — which completely wipes out the per-expert + # std=3e-4 variation, collapsing all 384 experts to a single bias + # value (all 71.0) and reducing noaux_tc topk to plain sigmoid topk. + # Subtracting the mean first puts the bias at ~0, where bf16 step is + # 2.4e-4 (small enough to preserve the variation). topk is invariant + # to additive constants across all experts, so this is safe. + bias_f32 = router_bias.detach().float().clone() + bias_f32 = bias_f32 - bias_f32.mean() + out[f"{out_prefix}mlp.router.e_score_correction_bias"] = bias_f32 + + num_experts = config["n_routed_experts"] + + # Peek expert 0 to learn shapes/dtypes. + e0_gw = lazy.get(f"{prefix}mlp.experts.0.gate_proj.weight") + if e0_gw is None: + return out # no experts (shouldn't happen for MoE layers, but be safe) + e0_gs = lazy.get(f"{prefix}mlp.experts.0.gate_proj.weight_scale_inv") + + if e0_gw.dtype == torch.float8_e4m3fn and e0_gs is not None: + sample_w, sample_s = rescale_fp8_weight_blockwise(e0_gw, e0_gs) + elif e0_gw.dtype == torch.bfloat16: + # Should not happen for Flash (experts ship in FP8); flag loudly. + raise NotImplementedError( + f"Layer {layer_idx} expert 0 gate_proj is BF16; Flash expects FP8." + ) + else: + sample_w, sample_s = e0_gw, e0_gs + + intermediate_size, hidden_size = sample_w.shape # [IM, H] + # Packed transpose layout: [num_experts, H, 2*IM] for gate_up. + gate_up_proj = torch.empty( + num_experts, hidden_size, 2 * intermediate_size, dtype=sample_w.dtype + ) + i_blocks, h_blocks = sample_s.shape # [IM_blocks, H_blocks] + gate_up_scale = torch.empty( + num_experts, h_blocks, 2 * i_blocks, dtype=sample_s.dtype + ) + + e0_dw = lazy.get(f"{prefix}mlp.experts.0.down_proj.weight") + e0_ds = lazy.get(f"{prefix}mlp.experts.0.down_proj.weight_scale_inv") + if e0_dw.dtype == torch.float8_e4m3fn and e0_ds is not None: + sample_dw, sample_ds = rescale_fp8_weight_blockwise(e0_dw, e0_ds) + else: + raise NotImplementedError( + f"Layer {layer_idx} expert 0 down_proj dtype {e0_dw.dtype} not handled." + ) + d_h_blocks, d_i_blocks = sample_ds.shape # [H_blocks, IM_blocks] + down_proj = torch.empty( + num_experts, intermediate_size, hidden_size, dtype=sample_dw.dtype + ) + down_scale = torch.empty( + num_experts, d_i_blocks, d_h_blocks, dtype=sample_ds.dtype + ) + + # Slot expert 0 (already rescaled above). 
+ gate_up_proj[0, :, :intermediate_size] = sample_w.T + gate_up_scale[0, :, :i_blocks] = sample_s.T + e0_uw = lazy.get(f"{prefix}mlp.experts.0.up_proj.weight") + e0_us = lazy.get(f"{prefix}mlp.experts.0.up_proj.weight_scale_inv") + up_w0, up_s0 = rescale_fp8_weight_blockwise(e0_uw, e0_us) + gate_up_proj[0, :, intermediate_size:] = up_w0.T + gate_up_scale[0, :, i_blocks:] = up_s0.T + down_proj[0] = sample_dw.T + down_scale[0] = sample_ds.T + del e0_gw, e0_gs, e0_uw, e0_us, e0_dw, e0_ds + del sample_w, sample_s, sample_dw, sample_ds, up_w0, up_s0 + + for e in range(1, num_experts): + gw = lazy.get(f"{prefix}mlp.experts.{e}.gate_proj.weight") + gs = lazy.get(f"{prefix}mlp.experts.{e}.gate_proj.weight_scale_inv") + uw = lazy.get(f"{prefix}mlp.experts.{e}.up_proj.weight") + us = lazy.get(f"{prefix}mlp.experts.{e}.up_proj.weight_scale_inv") + dw = lazy.get(f"{prefix}mlp.experts.{e}.down_proj.weight") + ds = lazy.get(f"{prefix}mlp.experts.{e}.down_proj.weight_scale_inv") + g_w, g_s = rescale_fp8_weight_blockwise(gw, gs) + u_w, u_s = rescale_fp8_weight_blockwise(uw, us) + d_w, d_s = rescale_fp8_weight_blockwise(dw, ds) + gate_up_proj[e, :, :intermediate_size] = g_w.T + gate_up_proj[e, :, intermediate_size:] = u_w.T + gate_up_scale[e, :, :i_blocks] = g_s.T + gate_up_scale[e, :, i_blocks:] = u_s.T + down_proj[e] = d_w.T + down_scale[e] = d_s.T + del gw, gs, uw, us, dw, ds, g_w, g_s, u_w, u_s, d_w, d_s + + out[f"{out_prefix}mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj + out[f"{out_prefix}mlp.expert_mlps.mlp_op.gate_up_proj.scale"] = gate_up_scale + out[f"{out_prefix}mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj + out[f"{out_prefix}mlp.expert_mlps.mlp_op.down_proj.scale"] = down_scale + return out + + +# --------------------------------------------------------------------------- +# Shard saving / index +# --------------------------------------------------------------------------- + +def save_shard( + tensors: Dict[str, torch.Tensor], + save_path: str, + filename: str, + weight_map: Dict[str, str], +) -> int: + """Save a sub-state-dict; clone tensors so safetensors doesn't complain + about views of mmapped storage. 
Returns bytes written.""" + path = os.path.join(save_path, filename) + materialized: Dict[str, torch.Tensor] = {} + total_bytes = 0 + for k, v in tensors.items(): + if not v.is_contiguous(): + v = v.contiguous() + v = v.detach().clone() + materialized[k] = v + total_bytes += v.numel() * v.element_size() + save_file(materialized, path) + for k in materialized.keys(): + weight_map[k] = filename + del materialized + return total_bytes + + +# --------------------------------------------------------------------------- +# Main driver +# --------------------------------------------------------------------------- + +def process_flash_checkpoint(hf_model_path: str, save_path: str, tp_degree: int): + os.makedirs(save_path, exist_ok=True) + + with open(os.path.join(hf_model_path, "model.safetensors.index.json")) as f: + weight_map_in = json.load(f)["weight_map"] + + with open(os.path.join(hf_model_path, "config.json")) as f: + config = json.load(f) + + num_layers = config["num_hidden_layers"] + hybrid = config.get("hybrid_layer_pattern", [0] * num_layers) + moe_freq = config.get("moe_layer_freq", [1] * num_layers) + + print( + f"Processing {num_layers} decoder layers" + f" (full={sum(1 for v in hybrid if v == 0)}," + f" swa={sum(1 for v in hybrid if v == 1)}," + f" dense={sum(1 for v in moe_freq if v == 0)}," + f" moe={sum(1 for v in moe_freq if v == 1)})", + flush=True, + ) + + lazy = LazyWeightMap(hf_model_path, weight_map_in) + weight_map_out: Dict[str, str] = {} + + try: + for li in range(num_layers): + t0 = time.time() + is_dense = moe_freq[li] == 0 + is_swa = hybrid[li] == 1 + layer_sd = process_layer(li, lazy, config, is_dense=is_dense, is_swa=is_swa) + filename = f"model_layer{li}.safetensors" + size = save_shard(layer_sd, save_path, filename, weight_map_out) + del layer_sd + gc.collect() + tag = "dense" if is_dense else "moe " + attn = "swa " if is_swa else "full" + print( + f" layer {li:2d} [{tag} {attn}] {size/1e9:6.2f} GB in {time.time()-t0:5.1f}s", + flush=True, + ) + + print("Processing embed_tokens, norm, lm_head ...", flush=True) + extras: Dict[str, torch.Tensor] = {} + for src, dst in ( + ("model.embed_tokens.weight", "embed_tokens.weight"), + ("model.norm.weight", "norm.weight"), + ("lm_head.weight", "lm_head.weight"), + ): + t = lazy.get(src) + if t is not None: + extras[dst] = t.detach().clone() + else: + print(f" WARNING: missing {src}", flush=True) + if "lm_head.weight" not in extras and "embed_tokens.weight" in extras: + # Tied embeddings + extras["lm_head.weight"] = extras["embed_tokens.weight"].detach().clone() + save_shard(extras, save_path, "model_extras.safetensors", weight_map_out) + del extras + finally: + lazy.close() + + # --- Index file --- + total_size = 0 + for f in set(weight_map_out.values()): + total_size += os.path.getsize(os.path.join(save_path, f)) + index = { + "metadata": {"total_size": total_size}, + "weight_map": weight_map_out, + } + with open(os.path.join(save_path, "model.safetensors.index.json"), "w") as f: + json.dump(index, f, indent=2) + + # --- Copy auxiliary files (config.json, tokenizer, chat template, + # and crucially the trust_remote_code modules the HF config references). + for name in sorted(os.listdir(hf_model_path)): + if name.endswith(".safetensors"): + continue + if name == "model.safetensors.index.json": + continue + src = os.path.join(hf_model_path, name) + if os.path.isfile(src): + shutil.copy(src, os.path.join(save_path, name)) + + print(f"\nPreprocess complete. 
total_size={total_size/1e9:.2f} GB", flush=True) + print(f" tensors written: {len(weight_map_out)}", flush=True) + print(f" output dir: {save_path}", flush=True) + + +def main(): + parser = argparse.ArgumentParser( + description="Preprocess MiMo-V2.5-Pro FP8 checkpoint for Neuron inference" + ) + parser.add_argument("--hf_model_path", required=True) + parser.add_argument("--save_path", required=True) + parser.add_argument("--tp_degree", type=int, default=64, + help="Tensor parallelism (currently informational only; " + "the framework does the TP sharding at load time).") + args = parser.parse_args() + process_flash_checkpoint(args.hf_model_path, args.save_path, args.tp_degree) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/MiMo-V2.5-Pro/src/modeling_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/src/modeling_mimo_v2.py new file mode 100644 index 00000000..85014672 --- /dev/null +++ b/contrib/models/MiMo-V2.5-Pro/src/modeling_mimo_v2.py @@ -0,0 +1,1676 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# This implementation is based on the MiMo-V2.5-Pro model from Xiaomi. +# Reference: https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro + +"""MiMo-V2.5-Pro model for NXD inference.""" + +import gc +import math +import warnings +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers.mappings import ( + gather_from_sequence_parallel_region, + gather_from_tensor_model_parallel_region_with_dim, +) +from neuronx_distributed.utils import cpu_mode + +from neuronx_distributed_inference.utils.distributed import ( + split_along_dim, + get_cp_rank, +) +from neuronx_distributed_inference.modules.attention.attention_process_groups import ( + get_context_parallel_attention_cp_group, +) + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + MoENeuronConfig, +) +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm +from neuronx_distributed_inference.modules.moe_v2 import initialize_moe_module + +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel + +from torch_neuronx.xla_impl.ops import nki_jit + +_flash_fwd_call = nki_jit()(attention_isa_kernel) + + +def get_rmsnorm_cls(): + """Get appropriate RMSNorm class based on execution environment.""" + return MiMoV2RMSNorm if cpu_mode() else CustomRMSNorm + + +class MiMoV2RMSNorm(nn.Module): + """RMSNorm implementation for CPU mode.""" + + def __init__(self, hidden_size: int, eps: float = 1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + 
input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class MiMoV2RotaryEmbedding(nn.Module): + """Rotary Position Embedding for MiMo-V2.5-Pro. + + Supports partial rotary embedding where only a fraction of dimensions + use rotary position encoding. + """ + + def __init__( + self, + dim: int, + max_position_embeddings: int = 262144, + base: float = 5000000.0, + partial_rotary_factor: float = 1.0, + ): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.partial_rotary_factor = partial_rotary_factor + + # Calculate the actual dimension used for rotary embedding + self.rope_dim = int(dim * partial_rotary_factor) + # Ensure rope_dim is even + self.rope_dim = self.rope_dim - (self.rope_dim % 2) + + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.rope_dim, 2, dtype=torch.float32) / self.rope_dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward( + self, x: torch.Tensor, position_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute rotary embeddings. + + Args: + x: Input tensor of shape (batch_size, seq_len, hidden_size) + position_ids: Position indices of shape (batch_size, seq_len) + + Returns: + Tuple of (cos, sin) tensors for rotary embedding + """ + inv_freq_expanded = self.inv_freq[None, :, None].float().expand( + position_ids.shape[0], -1, 1 + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotate half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + unsqueeze_dim: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply rotary position embedding to query and key tensors.""" + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MiMoV2InferenceConfig(InferenceConfig): + """Configuration class for MiMo-V2.5-Pro inference on Neuron.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # MoE configuration + self.num_local_experts = self.n_routed_experts + self.n_shared_experts = 0 # MiMo-V2.5-Pro has no shared experts + + # Set intermediate_size for MoE layers + self.intermediate_size = self.moe_intermediate_size + + # Check and pad intermediate size if needed + self.maybe_pad_intermediate() + + # Router configuration + self.neuron_config.router_config.dtype = torch.float32 + self.neuron_config.router_config.act_fn = "sigmoid" # MiMo uses sigmoid + + # Disable numeric CC token as workaround + self.neuron_config.disable_numeric_cc_token = True + + # MiMo normalizes top-k affinities + 
self.neuron_config.normalize_top_k_affinities = True + + # Parse hybrid layer pattern + self._parse_hybrid_pattern() + + def _parse_hybrid_pattern(self): + """Parse hybrid layer pattern to determine attention types.""" + if hasattr(self, 'hybrid_layer_pattern') and self.hybrid_layer_pattern: + self.layer_attention_types = [ + "sliding_window" if p == 1 else "full" + for p in self.hybrid_layer_pattern + ] + else: + self.layer_attention_types = ["full"] * self.num_hidden_layers + + # Parse MoE layer frequency + if hasattr(self, 'moe_layer_freq') and self.moe_layer_freq: + self.layer_uses_moe = [bool(f) for f in self.moe_layer_freq] + else: + self.layer_uses_moe = [True] * self.num_hidden_layers + + def maybe_pad_intermediate(self): + """Pad intermediate size if required for efficient computation.""" + from neuronx_distributed_inference.models.config import ( + SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP, + ) + + moe_tp_degree = self.neuron_config.moe_tp_degree + I_TP = self.moe_intermediate_size // moe_tp_degree + + if getattr( + self.neuron_config.blockwise_matmul_config, + "use_shard_on_intermediate_dynamic_while", + False, + ): + if I_TP % SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP != 0: + padded_size = ( + math.ceil(I_TP / SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP) + * SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP + * moe_tp_degree + ) + self.moe_intermediate_pad_size = max( + padded_size - self.moe_intermediate_size, 0 + ) + self.moe_intermediate_size = padded_size + + def get_required_attributes(self) -> List[str]: + return [ + "attention_bias", + "head_dim", + "hidden_act", + "hidden_size", + "hybrid_layer_pattern", + "layernorm_epsilon", + "max_position_embeddings", + "moe_intermediate_size", + "moe_layer_freq", + "n_routed_experts", + "norm_topk_prob", + "num_attention_heads", + "num_experts_per_tok", + "num_hidden_layers", + "num_key_value_heads", + "partial_rotary_factor", + "rope_theta", + "scoring_func", + "sliding_window", + "swa_head_dim", + "swa_num_attention_heads", + "swa_num_key_value_heads", + "swa_rope_theta", + "swa_v_head_dim", + "tie_word_embeddings", + "v_head_dim", + "vocab_size", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[MoENeuronConfig]: + return MoENeuronConfig + + +class NeuronMiMoV2Attention(NeuronAttentionBase): + """MiMo-V2.5-Pro Attention implementation supporting hybrid attention patterns. + + Supports both full attention and sliding window attention with different + head dimensions for Q/K vs V. 
+ """ + + def __init__( + self, + config: MiMoV2InferenceConfig, + layer_idx: int, + is_sliding_window: bool = False, + ): + self.layer_idx = layer_idx + self.is_sliding_window = is_sliding_window + + # Select parameters based on attention type + if is_sliding_window: + self.attn_head_dim = config.swa_head_dim + self.attn_v_head_dim = config.swa_v_head_dim + self.attn_num_heads = config.swa_num_attention_heads + self.attn_num_kv_heads = config.swa_num_key_value_heads + rope_theta = getattr(config, 'swa_rope_theta', 10000.0) + self.sliding_window_size = config.sliding_window + else: + self.attn_head_dim = config.head_dim + self.attn_v_head_dim = config.v_head_dim + self.attn_num_heads = config.num_attention_heads + self.attn_num_kv_heads = config.num_key_value_heads + rope_theta = config.rope_theta + self.sliding_window_size = None + + # Calculate partial rotary dimensions + self.partial_rotary_factor = config.partial_rotary_factor + self.rope_dim = int(self.attn_head_dim * self.partial_rotary_factor) + self.rope_dim = self.rope_dim - (self.rope_dim % 2) # Ensure even + self.nope_dim = self.attn_head_dim - self.rope_dim + + # Create rotary embedding + rotary_emb = MiMoV2RotaryEmbedding( + dim=self.attn_head_dim, + max_position_embeddings=config.max_position_embeddings, + base=rope_theta, + partial_rotary_factor=self.partial_rotary_factor, + ) + + # Initialize base attention + # NOTE: We pass v_head_dim to base class, but MiMo uses asymmetric Q/K (192) vs V (128). + # We override init_gqa_properties() to prevent the base class from creating + # incompatible projection layers (which cause crashes when CP > 1). + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=self.attn_num_heads, + num_key_value_heads=self.attn_num_kv_heads, + head_dim=self.attn_v_head_dim, # Use v_head_dim for base class + rotary_emb=rotary_emb, + rms_norm_eps=config.layernorm_epsilon, + use_qk_norm=False, + ) + + # Initialize MiMo-specific projections with correct dimensions + self._init_projections(config) + + # Scaling factor + self.scaling = self.attn_head_dim ** -0.5 + # HF MiMoV2Attention (modeling_mimo_v2.py) multiplies value_states + # by config.attention_value_scale (0.707 for Flash) right after the V + # projection, before attention softmax*V. Matching that here — applied + # to value_states in forward() rather than to attn_output. + self.value_scale = float(getattr(config, "attention_value_scale", 1.0)) + + # Store cache KV heads for cache compatibility + # With CONVERT_TO_MHA, all layers have num_attention_heads KV heads + # Otherwise, use max of full and sliding window kv heads + tp_degree = config.neuron_config.tp_degree + if self.use_gqa_convert_to_mha: + # CONVERT_TO_MHA: cache stores num_attention_heads (same as Q heads) + self.cache_num_kv_heads = self.attn_num_heads + self.local_cache_kv_heads = self.local_num_heads + else: + # Standard GQA: cache uses max of full and sliding window kv heads + self.cache_num_kv_heads = max( + config.num_key_value_heads, + getattr(config, 'swa_num_key_value_heads', config.num_key_value_heads) + ) + self.local_cache_kv_heads = max(1, self.cache_num_kv_heads // tp_degree) + + def init_gqa_properties(self): + """Override base class to prevent creating incompatible QKV projections. + + MiMo-V2.5-Pro has asymmetric Q/K head_dim (192) vs V head_dim (128), + which is incompatible with the base class's GroupQueryAttention_QKV. + MiMo uses its own custom projections via _init_projections() instead. 
+ + When CP > 1, the base class would create cte_qkv_proj/tkg_qkv_proj with + wrong head_dim=128, causing compilation crashes. This no-op prevents that. + """ + pass + + def _init_projections(self, config: MiMoV2InferenceConfig): + """Initialize projection layers with correct dimensions. + + When CONVERT_TO_MHA is needed (tp_degree > num_kv_heads), K/V projections + are sized for num_attention_heads (not original num_kv_heads). The checkpoint + weights are replicated in preshard_hook before loading. + """ + dtype = config.neuron_config.torch_dtype + tp_degree = config.neuron_config.tp_degree + + # Check if we need GQA CONVERT_TO_MHA (when tp_degree > num_kv_heads) + self.use_gqa_convert_to_mha = tp_degree > self.attn_num_kv_heads + + # Store source heads for preshard_hook + self._src_num_kv_heads = self.attn_num_kv_heads + self._kv_replication_factor = self.attn_num_heads // self.attn_num_kv_heads if self.use_gqa_convert_to_mha else 1 + + if self.use_gqa_convert_to_mha: + # CONVERT_TO_MHA: K and V use num_attention_heads for proper TP splitting + k_num_heads = self.attn_num_heads + v_num_heads = self.attn_num_heads + else: + k_num_heads = self.attn_num_kv_heads + v_num_heads = self.attn_num_kv_heads + + # Q/K use head_dim, V uses v_head_dim + q_hidden_size = self.attn_num_heads * self.attn_head_dim + k_hidden_size = k_num_heads * self.attn_head_dim + v_hidden_size = v_num_heads * self.attn_v_head_dim + o_hidden_size = self.attn_num_heads * self.attn_v_head_dim + + if parallel_state.model_parallel_is_initialized(): + tp_group = parallel_state.get_tensor_model_parallel_group() + + # Q projection + self.q_proj = ColumnParallelLinear( + config.hidden_size, + q_hidden_size, + bias=config.attention_bias, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + + # K projection + self.k_proj = ColumnParallelLinear( + config.hidden_size, + k_hidden_size, + bias=config.attention_bias, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + + # V projection + self.v_proj = ColumnParallelLinear( + config.hidden_size, + v_hidden_size, + bias=config.attention_bias, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + + # Output projection - with sequence parallel to scatter output + self.o_proj = RowParallelLinear( + o_hidden_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + dtype=dtype, + tensor_model_parallel_group=tp_group, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=1 if self.sequence_parallel_enabled else None, + ) + + # Calculate local dimensions after TP split + self.local_num_heads = self.attn_num_heads // tp_degree + if self.use_gqa_convert_to_mha: + # With CONVERT_TO_MHA, local KV heads = local Q heads + self.local_num_kv_heads = self.local_num_heads + else: + self.local_num_kv_heads = max(1, self.attn_num_kv_heads // tp_degree) + else: + self.q_proj = nn.Linear(config.hidden_size, q_hidden_size, bias=config.attention_bias) + self.k_proj = nn.Linear(config.hidden_size, k_hidden_size, bias=config.attention_bias) + self.v_proj = nn.Linear(config.hidden_size, v_hidden_size, bias=config.attention_bias) + self.o_proj = nn.Linear(o_hidden_size, config.hidden_size, bias=False) + + self.local_num_heads = self.attn_num_heads + self.local_num_kv_heads = k_num_heads + + # Override base class attributes that were computed with wrong head_dim + # The base class init_gqa_properties() uses head_dim=v_head_dim which is wrong for Q/K + # We need to override these to 
ensure correct computation + self.num_heads = self.local_num_heads + self.num_key_value_heads = self.local_num_kv_heads + self.num_key_value_groups = self.local_num_heads // self.local_num_kv_heads + self.head_dim = self.attn_head_dim # Override to use actual Q/K head_dim (192) + + # Remove qkv_proj from base class if exists (we use separate q_proj, k_proj, v_proj) + if hasattr(self, 'qkv_proj'): + self.qkv_proj = None + + # Attention sink bias for attention layers (following HF implementation) + # This is a learnable parameter that allows attention to "sink" to an extra position + add_full_attention_sink_bias = getattr(config, 'add_full_attention_sink_bias', False) + add_swa_attention_sink_bias = getattr(config, 'add_swa_attention_sink_bias', True) + + # Determine if this layer uses sink bias based on config + self._use_sink_bias = (add_full_attention_sink_bias and not self.is_sliding_window) or \ + (add_swa_attention_sink_bias and self.is_sliding_window) + + if self._use_sink_bias: + # Shape: [num_attention_heads] - will be split across TP ranks + # The weight is loaded from checkpoint with shape [num_attention_heads] + # and will be sliced to [local_num_heads] during forward + self.attention_sink_bias = nn.Parameter( + torch.zeros(self.attn_num_heads, dtype=dtype), requires_grad=False + ) + else: + self.attention_sink_bias = None + + def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: + """Pre-shard hook to replicate K/V weights for CONVERT_TO_MHA. + + NOTE: This method is NOT currently called because NeuronMiMoV2Attention + is not a BaseGroupQueryAttention subclass. K/V weight replication is + instead done in convert_mimo_v2_hf_to_neuron_state_dict(). + + This method is kept for reference and potential future use. + """ + # This hook is not called - see note above + return False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[torch.Tensor] = None, + cos_cache: Optional[torch.Tensor] = None, + sin_cache: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """Forward pass for MiMo-V2.5-Pro attention with Context Parallelism support.""" + + # Context Parallelism: only active during context encoding (no past_key_value) + is_context_parallel = past_key_value is None and self.cp_degree > 1 + cp_rank = None + + if is_context_parallel: + cp_rank = get_cp_rank( + self.rank_util.get_rank(), self.tp_degree, + self.cp_degree, self.neuron_config.switch_cc, + ) + # Split attention_mask (dim=2 = Q rows) and position_ids (dim=1 = seq) + attention_mask = split_along_dim( + attention_mask, dim=2, rank=cp_rank, num_partitions=self.cp_degree + ) + # Keep full position_ids for RoPE computation on full-length K/V + local_position_ids = split_along_dim( + position_ids, dim=1, rank=cp_rank, num_partitions=self.cp_degree + ) + + # Handle sequence parallel + if self.sequence_parallel_enabled and parallel_state.model_parallel_is_initialized(): + hidden_states = gather_from_sequence_parallel_region( + hidden_states, + self.sequence_dimension, + process_group=parallel_state.get_tensor_model_parallel_group(), + ) + + # Context Parallelism without sequence parallel: split hidden_states + if is_context_parallel and not self.sequence_parallel_enabled: + hidden_states = split_along_dim( + hidden_states, dim=1, rank=cp_rank, num_partitions=self.cp_degree + ) + + bsz, q_len, _ = 
hidden_states.size() + + # Determine if this is token generation (past_key_value is not None) + is_token_gen = past_key_value is not None + + # Project Q, K, V + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # HF MiMoV2Attention scales V by attention_value_scale (0.707 for Flash) + # right after v_proj, before the attention softmax*V. Earlier revisions + # of this file applied it post-attention or not at all; both produce + # gibberish for prompts longer than ~20 tokens. + if self.value_scale != 1.0: + value_states = value_states * self.value_scale + + # Reshape for multi-head attention: [bsz, num_heads, seq_len, head_dim] + query_states = query_states.view(bsz, q_len, self.local_num_heads, self.attn_head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.local_num_kv_heads, self.attn_head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.local_num_kv_heads, self.attn_v_head_dim).transpose(1, 2) + + # Split into rope and non-rope parts + query_rope = query_states[..., :self.rope_dim] + query_nope = query_states[..., self.rope_dim:] + key_rope = key_states[..., :self.rope_dim] + key_nope = key_states[..., self.rope_dim:] + + # Compute rotary embeddings + # IMPORTANT: Always compute for this layer because different layer types + # (full vs sliding window) use different rope_theta values. + # Pro: full=config.rope_theta (10M), SWA=config.swa_rope_theta (10K). + # Cannot reuse cached cos/sin across layer types. + # + # For CP with sequence_parallel: Q/K/V have full S, use full position_ids for RoPE. + # For CP without sequence_parallel: Q/K/V have S/CP, use local_position_ids for RoPE + # (local_position_ids contain the correct global positions for this CP rank). + if is_context_parallel and not self.sequence_parallel_enabled: + rope_position_ids = local_position_ids + else: + rope_position_ids = position_ids + cos_cache, sin_cache = self.rotary_emb(value_states, rope_position_ids) + + # Apply rotary position embedding to rope parts only + query_rope, key_rope = apply_rotary_pos_emb( + query_rope, key_rope, cos_cache, sin_cache, rope_position_ids + ) + + # Concatenate rope and non-rope parts + query_states = torch.cat([query_rope, query_nope], dim=-1) + key_states = torch.cat([key_rope, key_nope], dim=-1) + + # Context Parallelism: split Q and save local KV for cache + if is_context_parallel: + if self.sequence_parallel_enabled: + # Q/K/V have full S. Split Q to local portion, save local KV for cache. + # Use split_along_dim (torch.index_select) instead of Python slicing + # because XLA tracing doesn't support dynamic tensor indices in slice notation. + query_states = split_along_dim(query_states, dim=2, rank=cp_rank, num_partitions=self.cp_degree) + key_states_for_cache = split_along_dim(key_states, dim=2, rank=cp_rank, num_partitions=self.cp_degree) + value_states_for_cache = split_along_dim(value_states, dim=2, rank=cp_rank, num_partitions=self.cp_degree) + q_len = q_len // self.cp_degree + # K/V stay at full S for attention computation + else: + # Q/K/V have S/CP. Save local KV for cache, then all-gather K/V. 
+ key_states_for_cache = key_states + value_states_for_cache = value_states + key_states = gather_from_tensor_model_parallel_region_with_dim( + key_states, gather_dim=2, + process_group=get_context_parallel_attention_cp_group(), + ) + value_states = gather_from_tensor_model_parallel_region_with_dim( + value_states, gather_dim=2, + process_group=get_context_parallel_attention_cp_group(), + ) + # Q stays at S/CP + else: + # Store key/value states BEFORE GQA repeat for KV cache + key_states_for_cache = key_states + value_states_for_cache = value_states + + # WORKAROUND 1: Pad V from v_head_dim (128) to head_dim (192) for KV cache compatibility + if self.attn_v_head_dim < self.attn_head_dim: + pad_size = self.attn_head_dim - self.attn_v_head_dim + value_states_for_cache = F.pad(value_states_for_cache, (0, pad_size), value=0.0) + + # WORKAROUND 2: Pad KV heads if layer has fewer than cache expects + # Only needed when NOT using CONVERT_TO_MHA (standard GQA mode) + if not self.use_gqa_convert_to_mha and self.local_num_kv_heads < self.local_cache_kv_heads: + # Pad KV heads by repeating + repeat_factor = self.local_cache_kv_heads // self.local_num_kv_heads + key_states_for_cache = key_states_for_cache.repeat(1, repeat_factor, 1, 1) + value_states_for_cache = value_states_for_cache.repeat(1, repeat_factor, 1, 1) + + # Repeat KV heads for GQA (only needed without CONVERT_TO_MHA) + # With CONVERT_TO_MHA, K/V already have num_attention_heads + num_key_value_groups = self.local_num_heads // self.local_num_kv_heads + if num_key_value_groups > 1: + key_states = key_states.repeat_interleave(num_key_value_groups, dim=1) + value_states = value_states.repeat_interleave(num_key_value_groups, dim=1) + + if is_token_gen: + # Token generation: use decomposed attention with prior (cached) and active (current) KV + # past_key_value[0] = cached K, shape [bsz, cache_kv_heads, kv_seq_len, head_dim] + # past_key_value[1] = cached V, shape [bsz, cache_kv_heads, kv_seq_len, head_dim] (padded) + K_prior = past_key_value[0] + V_prior = past_key_value[1] + + # WORKAROUND 1: Slice KV heads if cache has more than layer needs + # Only needed when NOT using CONVERT_TO_MHA (standard GQA mode) + # With CONVERT_TO_MHA, cache and layer have same num_kv_heads + if not self.use_gqa_convert_to_mha and self.local_num_kv_heads < self.local_cache_kv_heads: + # Cache has repeated heads, just take the first local_num_kv_heads + K_prior = K_prior[:, :self.local_num_kv_heads, :, :] + V_prior = V_prior[:, :self.local_num_kv_heads, :, :] + + # WORKAROUND 2: Slice V_prior back to v_head_dim (128) from head_dim (192) + if self.attn_v_head_dim < self.attn_head_dim: + V_prior = V_prior[..., :self.attn_v_head_dim] + + # Repeat cached KV for GQA (only needed without CONVERT_TO_MHA) + # With CONVERT_TO_MHA, cached K/V already have num_attention_heads + if num_key_value_groups > 1: + K_prior = K_prior.repeat_interleave(num_key_value_groups, dim=1) + V_prior = V_prior.repeat_interleave(num_key_value_groups, dim=1) + + # Compute attention on prior (cached) KV + # K_prior shape: [bsz, num_heads, kv_seq_len, head_dim] + prior_scores = torch.matmul(query_states, K_prior.transpose(-2, -1)) * self.scaling + + # Apply attention mask to prior scores + if attention_mask is not None: + # Convert boolean mask to additive mask if needed + if attention_mask.dtype == torch.bool: + prior_scores = prior_scores.masked_fill(~attention_mask, float('-inf')) + else: + prior_scores = prior_scores + attention_mask + + # Apply sliding window mask for SWA layers + if 
self.is_sliding_window and self.sliding_window_size is not None and position_ids is not None: + kv_seq_len = prior_scores.size(-1) + current_pos = position_ids[0, 0] + pos_indices = torch.arange(kv_seq_len, device=prior_scores.device) + sliding_mask = pos_indices >= (current_pos - self.sliding_window_size + 1) + sliding_mask = sliding_mask[None, None, None, :] + prior_scores = prior_scores.masked_fill(~sliding_mask, float('-inf')) + + prior_scores = prior_scores.to(torch.float32) + + # Compute attention on active (current) KV + active_scores = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling + active_scores = active_scores.to(torch.float32) + + # Combined softmax over prior and active scores + all_scores = torch.cat([prior_scores, active_scores], dim=-1) + + # Add attention sink bias (following HF implementation) + # This must be applied to token generation as well! + use_sink = self._use_sink_bias and self.attention_sink_bias is not None + if use_sink: + tp_rank = parallel_state.get_tensor_model_parallel_rank() if parallel_state.model_parallel_is_initialized() else 0 + local_sink = self.attention_sink_bias[tp_rank * self.local_num_heads:(tp_rank + 1) * self.local_num_heads] + sink_bias = local_sink.reshape(1, -1, 1, 1).expand(bsz, -1, q_len, 1) + all_scores = torch.cat([all_scores, sink_bias], dim=-1) + + # Numerical stability: subtract max before softmax + all_scores = all_scores - all_scores.max(dim=-1, keepdim=True).values + attn_weights = F.softmax(all_scores, dim=-1, dtype=torch.float32) + + # Drop the sink column after softmax + if use_sink: + attn_weights = attn_weights[..., :-1] + + # Split attention weights back + prior_weights = attn_weights[..., :-q_len].to(V_prior.dtype) + active_weights = attn_weights[..., -q_len:].to(value_states.dtype) + + # Compute attention outputs + attn_prior = torch.matmul(prior_weights, V_prior) + attn_active = torch.matmul(active_weights, value_states) + attn_output = attn_prior + attn_active + else: + # Context encoding: standard attention + # With CP: Q is local [B, H, S/CP, D], K/V are full [B, H, S, D] + # Without CP: Q/K/V all have same seq_len + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling + + # Apply attention mask (additive mask: 0 = attend, -inf = mask out) + # The framework creates boolean masks, so we need to convert them + # With CP: attention_mask is already split to [B, 1, S/CP, S] (local Q rows, full K cols) + if attention_mask is not None: + # Convert boolean mask to additive mask if needed + if attention_mask.dtype == torch.bool: + # True = attend (0), False = mask (-inf) + additive_mask = torch.zeros_like(attn_weights) + additive_mask = additive_mask.masked_fill(~attention_mask, float('-inf')) + attn_weights = attn_weights + additive_mask + else: + # Already additive mask + attn_weights = attn_weights + attention_mask + + # Apply sliding window mask for SWA layers + if self.is_sliding_window and self.sliding_window_size is not None: + kv_seq_len = attn_weights.size(-1) + if is_context_parallel: + # With CP: Q has local seq len, K has full seq len. + # Use local_position_ids for correct global Q positions. 
+ row_idx = local_position_ids[0].unsqueeze(1).to(attn_weights.device) + else: + row_idx = torch.arange(kv_seq_len, device=attn_weights.device).unsqueeze(1) + col_idx = torch.arange(kv_seq_len, device=attn_weights.device).unsqueeze(0) + # Causal: col <= row, and within window: col >= row - window_size + 1 + sliding_mask = (col_idx <= row_idx) & (col_idx >= row_idx - self.sliding_window_size + 1) + sliding_mask = sliding_mask[None, None, :, :] + # Convert to additive mask + attn_weights = attn_weights.masked_fill(~sliding_mask, float('-inf')) + + # Add attention sink bias (following HF implementation) + # This adds an extra "sink" column to attention weights + use_sink = self._use_sink_bias and self.attention_sink_bias is not None + if use_sink: + # Get local portion of sink bias for this TP rank + tp_rank = parallel_state.get_tensor_model_parallel_rank() if parallel_state.model_parallel_is_initialized() else 0 + local_sink = self.attention_sink_bias[tp_rank * self.local_num_heads:(tp_rank + 1) * self.local_num_heads] + # Reshape and expand: [local_num_heads] -> [bsz, local_num_heads, q_len, 1] + sink_bias = local_sink.reshape(1, -1, 1, 1).expand(bsz, -1, q_len, 1) + attn_weights = torch.cat([attn_weights, sink_bias], dim=-1) + + # Numerical stability: subtract max before softmax (like HF implementation) + attn_weights = attn_weights - attn_weights.max(dim=-1, keepdim=True).values + + # Softmax + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32) + + # Drop the sink column after softmax + if use_sink: + attn_weights = attn_weights[..., :-1] + + attn_weights = attn_weights.to(value_states.dtype) + + # Apply attention to values + attn_output = torch.matmul(attn_weights, value_states) + + # Reshape and project output + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.local_num_heads * self.attn_v_head_dim) + + # Context Parallelism: gather output across CP ranks BEFORE o_proj. + # With SP enabled, o_proj scatters along seq dim. The input must have full S + # (not S/CP), otherwise the SP-scattered output won't match the residual. + # Without SP, gather after o_proj to restore full seq_len for residual. 
+ if is_context_parallel: + attn_output = gather_from_tensor_model_parallel_region_with_dim( + attn_output, gather_dim=1, + process_group=get_context_parallel_attention_cp_group(), + ) + + attn_output = self.o_proj(attn_output) + + # Prepare KV cache output - return as tuple for KV cache manager + # Return LOCAL key/value states for cache (each CP rank stores its portion) + new_key_value = (key_states_for_cache, value_states_for_cache) + + return attn_output, new_key_value, cos_cache, sin_cache + + +class MiMoV2MLP(nn.Module): + """Standard MLP for non-MoE layers in MiMo-V2.5-Pro.""" + + def __init__(self, config: MiMoV2InferenceConfig): + super().__init__() + self.hidden_size = config.hidden_size + # Use the dense intermediate size for non-MoE layers + self.intermediate_size = getattr(config, 'dense_intermediate_size', config.intermediate_size * 8) + + dtype = config.neuron_config.torch_dtype + + if parallel_state.model_parallel_is_initialized(): + tp_group = parallel_state.get_tensor_model_parallel_group() + + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + input_is_parallel=True, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + else: + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + self.act_fn = F.silu + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class NeuronMiMoV2DecoderLayer(nn.Module): + """MiMo-V2.5-Pro Decoder Layer with hybrid attention and conditional MoE.""" + + def __init__(self, config: MiMoV2InferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + # Determine attention type for this layer + is_sliding_window = config.layer_attention_types[layer_idx] == "sliding_window" + self.attention_type = "sliding_window" if is_sliding_window else "full" + + # Create attention module + self.self_attn = NeuronMiMoV2Attention( + config=config, + layer_idx=layer_idx, + is_sliding_window=is_sliding_window, + ) + + # Determine if this layer uses MoE + self.uses_moe = config.layer_uses_moe[layer_idx] + + # Create MLP/MoE module + if self.uses_moe: + self.mlp = initialize_moe_module(config=config) + else: + self.mlp = MiMoV2MLP(config) + + # Layer norms + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.layernorm_epsilon, + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.layernorm_epsilon, + ) + + # Config flags + self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + cos_cache: Optional[torch.Tensor] = None, + sin_cache: Optional[torch.Tensor] = None, + **kwargs, + ) -> 
Tuple[torch.FloatTensor, ...]: + """Forward pass for decoder layer.""" + + # Self attention with residual + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + cos_cache=cos_cache, + sin_cache=sin_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # MLP/MoE with residual + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + if self.uses_moe: + hidden_states = self.mlp(hidden_states, padding_mask)[0] + else: + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value, cos_cache, sin_cache, None) + return outputs + + +class NeuronMiMoV2Model(NeuronBaseModel): + """MiMo-V2.5-Pro Model for NXD inference.""" + + def setup_attr_for_model(self, config: MiMoV2InferenceConfig): + self.on_device_sampling = config.neuron_config.on_device_sampling_config is not None + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + + # Check if we need GQA CONVERT_TO_MHA mode + # When tp_degree > num_kv_heads, we replicate K/V to match num_attention_heads + min_kv_heads = min( + config.num_key_value_heads, + getattr(config, 'swa_num_key_value_heads', config.num_key_value_heads) + ) + self.use_gqa_convert_to_mha = self.tp_degree > min_kv_heads + + if self.use_gqa_convert_to_mha: + # With CONVERT_TO_MHA, KV cache stores num_attention_heads (same as Q) + self.num_key_value_heads = config.num_attention_heads + else: + # Standard GQA: use the maximum num_kv_heads for KV cache + # (handles hybrid full/sliding window attention) + self.num_key_value_heads = max( + config.num_key_value_heads, + getattr(config, 'swa_num_key_value_heads', config.num_key_value_heads) + ) + + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + # MiMo has hybrid attention (full + sliding window) + # NOTE: Do NOT set self.sliding_window here because it affects KV cache size globally. + # MiMo handles sliding window per-layer in the attention module itself. + # Setting has_mixed_attn = True enables proper mask creation without affecting cache size. + self.has_mixed_attn = True + + def init_model(self, config: MiMoV2InferenceConfig): + self.padding_idx = getattr(config, 'pad_token_id', None) + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + + self.layers = nn.ModuleList([ + NeuronMiMoV2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + + self.norm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.layernorm_epsilon, + ) + + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + ) + + +def _replicate_kv_weights_for_convert_to_mha( + tensor: torch.Tensor, + source_heads: int, + target_heads: int, + head_dim: int, +) -> torch.Tensor: + """Replicate K/V weights from source_heads to target_heads for CONVERT_TO_MHA. 
+ + Args: + tensor: Weight tensor of shape [source_heads * head_dim, hidden_size] + source_heads: Number of source KV heads + target_heads: Number of target heads (num_attention_heads) + head_dim: Head dimension + + Returns: + Replicated tensor of shape [target_heads * head_dim, hidden_size] + """ + if tensor is None or source_heads >= target_heads: + return tensor + + repeats = target_heads // source_heads + + # Reshape to [source_heads, head_dim, hidden_size] + original_shape = tensor.shape + tensor = tensor.view(source_heads, head_dim, -1) + + # Repeat along head dimension + tensor = tensor.repeat_interleave(repeats, dim=0) + + # Reshape back to [num_heads * head_dim, hidden_size] + tensor = tensor.view(-1, original_shape[-1]) + + return tensor + + +def convert_mimo_v2_hf_to_neuron_state_dict( + neuron_state_dict: Dict[str, Any], + config: MiMoV2InferenceConfig, +) -> Dict[str, Any]: + """Convert HuggingFace MiMo-V2.5-Pro weights to Neuron format. + + This handles: + 1. Router weight renaming + 2. Expert weight concatenation and transposition + 3. FP8 dequantization if needed + 4. K/V weight replication for CONVERT_TO_MHA mode + """ + + assert config.neuron_config.glu_mlp is True, "Only GLU MLP is supported" + + # Dequantize layers if needed + _maybe_dequantize_layer(neuron_state_dict, config) + + # Add rank utility tensors + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, config.neuron_config.tp_degree, dtype=torch.int32 + ) + + # Determine if CONVERT_TO_MHA is needed + tp_degree = config.neuron_config.tp_degree + num_attention_heads = config.num_attention_heads + + # MiMo-V2.5-Pro has different KV heads for full and sliding window attention + full_num_kv_heads = config.num_key_value_heads # 4 + swa_num_kv_heads = config.swa_num_key_value_heads # 8 + + # Check if we need to replicate K/V weights + full_use_convert_to_mha = tp_degree > full_num_kv_heads + swa_use_convert_to_mha = tp_degree > swa_num_kv_heads + + for layer_idx in range(config.num_hidden_layers): + # Add rank utility for attention + neuron_state_dict[f"layers.{layer_idx}.self_attn.rank_util.rank"] = torch.arange( + 0, config.neuron_config.tp_degree, dtype=torch.int32 + ) + + # Determine attention type for this layer + is_sliding_window = config.layer_attention_types[layer_idx] == "sliding_window" + + if is_sliding_window: + src_num_kv_heads = swa_num_kv_heads + use_convert_to_mha = swa_use_convert_to_mha + head_dim = config.swa_head_dim # 192 + v_head_dim = config.swa_v_head_dim # 128 + else: + src_num_kv_heads = full_num_kv_heads + use_convert_to_mha = full_use_convert_to_mha + head_dim = config.head_dim # 192 + v_head_dim = config.v_head_dim # 128 + + # Replicate K/V weights if CONVERT_TO_MHA is needed + if use_convert_to_mha: + k_proj_key = f"layers.{layer_idx}.self_attn.k_proj.weight" + v_proj_key = f"layers.{layer_idx}.self_attn.v_proj.weight" + + if k_proj_key in neuron_state_dict: + neuron_state_dict[k_proj_key] = _replicate_kv_weights_for_convert_to_mha( + neuron_state_dict[k_proj_key], + src_num_kv_heads, + num_attention_heads, + head_dim, + ) + + if v_proj_key in neuron_state_dict: + neuron_state_dict[v_proj_key] = _replicate_kv_weights_for_convert_to_mha( + neuron_state_dict[v_proj_key], + src_num_kv_heads, + num_attention_heads, + v_head_dim, + ) + + # FP8 path: replicate per-row scales ([src_heads*head_dim, 1]) in + # lockstep with the weights. Without this the shard_weights step + # rejects the scale shape mismatch (e.g. [12,1] vs expected [192,1]). 
+ # BF16 has no .scale key, so this loop is a no-op there. + for proj, hd in (("k_proj", head_dim), ("v_proj", v_head_dim)): + scale_key = f"layers.{layer_idx}.self_attn.{proj}.scale" + if scale_key in neuron_state_dict: + neuron_state_dict[scale_key] = _replicate_kv_weights_for_convert_to_mha( + neuron_state_dict[scale_key], + src_num_kv_heads, + num_attention_heads, + hd, + ) + + # Only convert MoE layers + if not config.layer_uses_moe[layer_idx]: + continue + + # Check if this layer has MoE weights + gate_key = f"layers.{layer_idx}.mlp.gate.weight" + if gate_key not in neuron_state_dict: + continue + + # Rename router weights + neuron_state_dict[f"layers.{layer_idx}.mlp.router.linear_router.weight"] = ( + neuron_state_dict[gate_key].detach().clone() + ) + del neuron_state_dict[gate_key] + + # Get dimensions from first expert + expert_0_gate = f"layers.{layer_idx}.mlp.experts.0.gate_proj.weight" + if expert_0_gate not in neuron_state_dict: + continue + + intermediate_size, hidden_size = neuron_state_dict[expert_0_gate].shape + device = neuron_state_dict[expert_0_gate].device + dtype = neuron_state_dict[expert_0_gate].dtype + + num_experts = config.n_routed_experts + + # Concatenate gate and up projections + gate_up_proj = torch.empty( + num_experts, + hidden_size, + 2 * intermediate_size, + dtype=dtype, + device=device, + ) + + for e in range(num_experts): + gate_proj_weights = neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight" + ].T.detach().clone() + up_proj_weights = neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight" + ].T.detach().clone() + + gate_up_proj[e, :, :intermediate_size] = gate_proj_weights + gate_up_proj[e, :, intermediate_size:] = up_proj_weights + + del neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"] + del neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"] + + # Pad if needed + pad_size = getattr(config, "moe_intermediate_pad_size", 0) + if pad_size > 0: + gate_up_proj = gate_up_proj.reshape(num_experts, hidden_size, 2, -1) + gate_up_proj = F.pad(gate_up_proj, (0, pad_size)) + gate_up_proj = gate_up_proj.reshape(num_experts, hidden_size, -1) + + neuron_state_dict[f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj + + # Convert down projections + down_proj = torch.empty( + num_experts, + intermediate_size, + hidden_size, + dtype=dtype, + device=device, + ) + + for e in range(num_experts): + down_proj_weights = neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" + ].T.detach().clone() + down_proj[e] = down_proj_weights + del neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight"] + + # Pad if needed + if pad_size > 0: + down_proj = F.pad(down_proj, (0, 0, 0, pad_size)) + + neuron_state_dict[f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj + + gc.collect() + + # --- Expand MoE blockwise scales along the TP-partitioned dim (FP8 only). --- + # NxDI's shard_checkpoint splits the scale on its partition dim into + # `per_partition_size = dim_size / tp_degree`. At TP=64 both projections + # have per-rank "intermediate" smaller than the 128-wide scale block, so + # several ranks share one scale block — we need to replicate scale entries + # along that dim. Adjacent ranks whose weight falls inside the same + # 128-wide block genuinely share that block's scale. No-op when the + # .scale keys are absent (BF16 path). 
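+    # Worked example (V2.5-Pro shapes, assuming moe_tp_degree=64): the expert
+    # intermediate is 2048, so i_blocks = 2048/128 = 16 scale rows per expert,
+    # i_per_rank = 2048/64 = 32 < 128, ranks_per_block = 128/32 = 4, and each
+    # of the 16 rows is repeated 4x -> 64 rows, one per MoE-TP rank. At
+    # moe_tp_degree=16, i_per_rank = 128 and the expansion below is skipped.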
+ if getattr(config.neuron_config, "quantized", False): + # IMPORTANT: MoE expert weights are sharded by moe_tp_degree (not the + # top-level tp_degree — attention uses tp_degree, MoE can use a + # different split). At moe_tp=64 the per-rank intermediate is 32 (<128) + # so we had to expand the scale to make the shard layout match; at + # moe_tp=16 per-rank intermediate is 128 (>=128) and no expansion is + # needed. + moe_tp = getattr(config.neuron_config, "moe_tp_degree", None) or config.neuron_config.tp_degree + for layer_idx in range(config.num_hidden_layers): + if not config.layer_uses_moe[layer_idx]: + continue + + # down_proj (RowParallel on intermediate dim). Scale: [E, I_blocks, H_blocks] + dp_key = f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.down_proj.scale" + if dp_key in neuron_state_dict: + s = neuron_state_dict[dp_key] + i_blocks = s.shape[1] + h_blocks = s.shape[2] + intermediate = i_blocks * 128 + i_per_rank = intermediate // moe_tp + if i_per_rank < 128: + ranks_per_block = 128 // i_per_rank + s_exp = s.unsqueeze(2).expand(-1, -1, ranks_per_block, -1) + s_exp = s_exp.reshape(s.shape[0], i_blocks * ranks_per_block, h_blocks) + assert s_exp.shape[1] == moe_tp, ( + f"down_proj.scale expansion produced {s_exp.shape[1]} rows, " + f"expected moe_tp={moe_tp}" + ) + neuron_state_dict[dp_key] = s_exp.contiguous() + + # gate_up_proj (ColumnParallel on 2*intermediate dim, gate|up fused + # along last axis). Scale: [E, H_blocks, 2*I_blocks] stored as + # [gate_half | up_half]. Module parameter has per-rank last-dim=1 + # (via _apply_blockwise_scale_stride_fix patch forcing + # partition_stride=1), so the full scale must have last-dim=moe_tp + # with gate entries 0..moe_tp/2 and up entries moe_tp/2..moe_tp. + # Expand each half independently to preserve the gate/up boundary + # when NxD does `split(per_partition=2*I/moe_tp, dim=-1)`. 
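+            # Worked example (same assumptions, moe_tp=64): the fused scale is
+            # [E, h_blocks, 32] = [gate 16 | up 16]; out_per_rank = 2*2048/64 =
+            # 64 < 128, ranks_per_half = 32, ranks_per_block = 32/16 = 2, so
+            # each half expands 16 -> 32 entries and the concat yields 64 ==
+            # moe_tp entries with the gate/up boundary preserved at 32.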
+ gu_key = f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.gate_up_proj.scale" + if gu_key in neuron_state_dict: + s = neuron_state_dict[gu_key] + h_blocks = s.shape[1] + two_i_blocks = s.shape[2] + assert two_i_blocks % 2 == 0, ( + f"gate_up_proj.scale last dim must be 2*i_blocks, got {two_i_blocks}" + ) + i_blocks = two_i_blocks // 2 + intermediate = i_blocks * 128 + out_per_rank = (2 * intermediate) // moe_tp + if out_per_rank < 128: + assert moe_tp % 2 == 0, f"moe_tp={moe_tp} must be even for gate/up scale split" + ranks_per_half = moe_tp // 2 + assert ranks_per_half % i_blocks == 0, ( + f"ranks_per_half={ranks_per_half} must be divisible by " + f"i_blocks={i_blocks}" + ) + ranks_per_block = ranks_per_half // i_blocks + gate_half = s[..., :i_blocks] # [E, H_blocks, i_blocks] + up_half = s[..., i_blocks:] + gate_exp = ( + gate_half.unsqueeze(-1) + .expand(-1, -1, -1, ranks_per_block) + .reshape(s.shape[0], h_blocks, ranks_per_half) + ) + up_exp = ( + up_half.unsqueeze(-1) + .expand(-1, -1, -1, ranks_per_block) + .reshape(s.shape[0], h_blocks, ranks_per_half) + ) + s_exp = torch.cat([gate_exp, up_exp], dim=-1) + assert s_exp.shape[-1] == moe_tp, ( + f"gate_up_proj.scale expansion produced {s_exp.shape[-1]} " + f"entries, expected moe_tp={moe_tp}" + ) + neuron_state_dict[gu_key] = s_exp.contiguous() + + return neuron_state_dict + + +def _maybe_dequantize_layer( + neuron_state_dict: Dict[str, Any], + config: MiMoV2InferenceConfig, +): + """Dequantize FP8 layers if present.""" + scale_layers = [] + + for layer_key in list(neuron_state_dict.keys()): + if "_scale_inv" in layer_key: + scales = neuron_state_dict[layer_key] + scale_layers.append(layer_key) + + fp8_layer_name = layer_key.replace("_scale_inv", "") + if fp8_layer_name not in neuron_state_dict: + continue + + fp8_layer = neuron_state_dict[fp8_layer_name] + + # Get block size from config if available + if hasattr(config, 'quantization_config') and config.quantization_config: + block_size = config.quantization_config.get("weight_block_size", [128, 128]) + else: + block_size = [128, 128] + + # Expand scales and dequantize + scales_expanded = scales.repeat_interleave(block_size[0], dim=0) + scales_expanded = scales_expanded.repeat_interleave(block_size[1], dim=1) + + # Ensure shapes match + if scales_expanded.shape != fp8_layer.shape: + scales_expanded = scales_expanded[:fp8_layer.shape[0], :fp8_layer.shape[1]] + + scaled_layer = fp8_layer.to(torch.float32) * scales_expanded.to(torch.float32) + neuron_state_dict[fp8_layer_name] = scaled_layer.to(config.neuron_config.torch_dtype) + + # Remove scale layers + for scale_layer in scale_layers: + del neuron_state_dict[scale_layer] + + +class NeuronMiMoV2ForCausalLM(NeuronBaseForCausalLM): + """MiMo-V2.5-Pro for Causal Language Modeling on Neuron.""" + + _model_cls = NeuronMiMoV2Model + + def __init__(self, *args, **kwargs): + # Install FP8 monkey-patches BEFORE super().__init__ so the patched + # RouterTopK.__init__ and quantization layer classes are in effect + # when NxDI builds the decoder (and instantiates routers). Harnesses + # that drive us via model.compile()/model.load() (e.g. vllm-neuron) + # call those methods AFTER construction, so patching from inside + # compile()/load() is too late — RouterTopK instances would already + # lack our e_score_correction_bias parameter, silently routing tokens + # to wrong experts and producing gibberish output. 
+ # + # _install_fp8_patches() reads self.neuron_config, which needs to + # exist; grab it from the args or the config arg the same way the + # base class does. + ncfg = kwargs.get("config") or (args[1] if len(args) > 1 else None) + if ncfg is not None and getattr(getattr(ncfg, "neuron_config", None), "quantized", False): + self._apply_ep_scale_fix() + self._apply_blockwise_scale_stride_fix() + self._apply_2d_per_channel_fix() + self._apply_router_noaux_tc_fix() + super().__init__(*args, **kwargs) + + @staticmethod + def load_hf_model(model_path: str, **kwargs): + """Load HuggingFace model. + + Note: MiMo-V2.5-Pro uses custom code, so we need trust_remote_code=True + """ + from transformers import AutoModelForCausalLM + return AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + **kwargs, + ) + + @classmethod + def get_config_cls(cls) -> Type[MiMoV2InferenceConfig]: + return MiMoV2InferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: Dict[str, Any], + config: MiMoV2InferenceConfig, + ) -> Dict[str, Any]: + return convert_mimo_v2_hf_to_neuron_state_dict(state_dict, config) + + def enable_context_encoding(self): + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def enable_token_generation(self): + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + super().enable_token_generation() + + # ------------------------------------------------------------------ + # FP8 quantized-inference monkey-patches (no-op unless quantized=True). + # + # Reconcile the preprocessed Neuron-FP8 checkpoint (blockwise-MoE + + # per-row-attn) with NxDI's global blockwise_symmetric q_config. All + # four are gated by self.neuron_config.quantized so the BF16 path is + # completely untouched. 
+ # ------------------------------------------------------------------ + + @staticmethod + def _apply_ep_scale_fix(): + """Skip per-channel `scale` params when marking expert-parallel + weights; they have shape [1, 1, W] and cannot be EP-sharded.""" + from neuronx_distributed.modules.moe.moe_parallel_layers import ( + ExpertFusedLinear, + ) + + if getattr(ExpertFusedLinear, "_mimo_v2_ep_scale_patched", False): + return + + def _patched_mark( + self_inner, + iterable=None, + expert_parallel_group_size=None, + is_prefill=True, + expert_distribution=None, + ): + from neuronx_distributed.parallel_layers.parallel_state import ( + get_expert_model_parallel_size, + ) + + if expert_parallel_group_size is None: + expert_parallel_group_size = get_expert_model_parallel_size() + + if expert_parallel_group_size > 1: + if iterable is None: + params_to_mark = [] + for name, p in self_inner.named_parameters(): + if name == "scale" and p.shape[0] == 1: + continue + params_to_mark.append(p) + iterable = params_to_mark + + for p in iterable: + p.expert_model_parallel = True + if is_prefill: + p.is_prefill = True + p.expert_distribution = expert_distribution + + ExpertFusedLinear._mark_expert_parallel_weights = _patched_mark + ExpertFusedLinear._mimo_v2_ep_scale_patched = True + + @staticmethod + def _apply_blockwise_scale_stride_fix(): + """Force scale.partition_stride=1 for BLOCKWISE_SYMMETRIC quantization + — stride>1 causes strided-splitting failures when per-rank weight size + is smaller than a block.""" + from neuronx_distributed.quantization.quantization_config import ( + QuantizationType, + ) + from neuronx_distributed.quantization.quantization_layers import ( + BaseQuantizeParallelLinear, + ) + + if getattr(BaseQuantizeParallelLinear, "_mimo_v2_blockwise_stride_patched", False): + return + + _original_setup = BaseQuantizeParallelLinear._setup_for_scale + + def _patched_setup(self_inner, *args, **kwargs): + _original_setup(self_inner, *args, **kwargs) + if ( + hasattr(self_inner, "quantization_type") + and self_inner.quantization_type == QuantizationType.BLOCKWISE_SYMMETRIC + and hasattr(self_inner, "scale") + and hasattr(self_inner.scale, "partition_stride") + and self_inner.scale.partition_stride > 1 + ): + self_inner.scale.partition_stride = 1 + + BaseQuantizeParallelLinear._setup_for_scale = _patched_setup + BaseQuantizeParallelLinear._mimo_v2_blockwise_stride_patched = True + + @staticmethod + def _apply_2d_per_channel_fix(): + """Route 2D self_attn + layer-0 dense-MLP swaps through per_channel_symmetric. + + Flash's preprocess writes: + - MoE experts: 3D weights with (E, out//128, in//128) blockwise scales. + - self_attn q/k/v + layer-0 mlp gate/up/down: 2D weights with + (out, 1) per-row scales. + + NxDI's q_config is global blockwise_symmetric (to satisfy the MoE). + Feeding that into the 2D classes triggers + `block axis cannot be < 0 or > 2, received 2` in _setup_for_scale + (block axes [1, 2] exceed rank-2 weight_shape). This wraps the 2D + classes' from_float to override q_config on the fly. 
+ """ + from neuronx_distributed.quantization.quantization_config import ( + QuantizationType, + ) + from neuronx_distributed.quantization.quantization_layers import ( + QuantizedColumnParallel, + QuantizedRowParallel, + ) + + def _wrap(cls): + if getattr(cls, "_mimo_v2_2d_patched", False): + return + original_from_float = cls.from_float + + def _patched_from_float(klass, mod, q_config=None, _orig=original_from_float): + if q_config is not None and q_config.get("quantization_type") == \ + QuantizationType.BLOCKWISE_SYMMETRIC: + q_config = dict(q_config) + q_config["quantization_type"] = QuantizationType.PER_CHANNEL_SYMMETRIC + q_config["quantization_per_channel_axis"] = 0 + q_config.pop("block_axis", None) + q_config.pop("block_size", None) + if q_config is None: + return _orig(mod) + return _orig(mod, q_config) + + cls.from_float = classmethod(_patched_from_float) + cls._mimo_v2_2d_patched = True + + _wrap(QuantizedColumnParallel) + _wrap(QuantizedRowParallel) + + @staticmethod + def _apply_router_noaux_tc_fix(): + """Register e_score_correction_bias on NxD RouterTopK and fold it into + top-k selection so Flash's noaux_tc routing matches HF reference. + + Flash's HF config uses topk_method='noaux_tc': each expert score is + `sigmoid(logits) + e_score_correction_bias`, top-k indices are chosen + from THAT biased score; the returned expert weights (affinities) + come from the UNBIASED sigmoid(logits). NxD's stock RouterTopK is + plain topk with no bias slot, so without this the bias is silently + dropped and ~all tokens route to wrong experts. + """ + from neuronx_distributed.modules.moe.routing import RouterTopK + + if getattr(RouterTopK, "_mimo_v2_noaux_tc_patched", False): + return + + original_init = RouterTopK.__init__ + + def _patched_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + # CRITICAL: dtype + init value both matter for XLA tracing. + # + # 1) dtype=torch.bfloat16: the NxDI checkpoint loader casts router + # bias from FP32 -> BF16 ("Found torch.float32 weights in + # checkpoint ... Will convert to torch.bfloat16"). If the traced + # NEFF expects FP32 but the checkpoint supplies BF16, the + # LayoutTransformation silently drops the weight and keeps the + # trace-time init values — so the bias at runtime is whatever + # we init here, not the checkpoint values. + # + # 2) init=arange, NOT zeros: if every entry is identical (all + # zeros), the `+ bias` op does not change the relative ordering + # of topk, so XLA's constant-folding passes can prove the add + # is a no-op and eliminate it entirely — dropping the bias + # parameter from the HLO. At that point checkpoint loading has + # nothing to bind to and the real bias is silently discarded. + # Using arange guarantees distinct per-expert values, forcing + # the compiler to keep the add as a runtime op with a live + # parameter. Source: Jim Burtoft's MiniMax-M2 fix notes + # (jimburtoft/neuronx-distributed-inference@49f8e164). + self.e_score_correction_bias = nn.Parameter( + torch.arange(self.num_experts, dtype=torch.bfloat16), + requires_grad=False, + ) + + def _patched_forward(self, hidden_states): + router_logits = self.get_router_logits(hidden_states) + expert_affinities = self.apply_activation_fn(router_logits) + + # MiMo (and MiniMax-M2) uses topk_method='noaux_tc': the bias is + # added ONLY for top-k selection, but the unbiased sigmoid scores + # remain as the expert-affinity weights passed to the experts. 
+ scores_for_choice = ( + expert_affinities.float() + self.e_score_correction_bias.unsqueeze(0) + ) + _, expert_index = torch.topk(scores_for_choice, self.top_k, dim=-1) + + expert_affinities = expert_affinities.to(dtype=hidden_states.dtype) + expert_index = expert_index.detach().to(dtype=torch.long) + return router_logits, expert_affinities, expert_index + + RouterTopK.__init__ = _patched_init + RouterTopK.forward = _patched_forward + RouterTopK._mimo_v2_noaux_tc_patched = True + + def _install_fp8_patches(self): + """Install all FP8-specific runtime patches. No-op for BF16.""" + if not getattr(self.neuron_config, "quantized", False): + return + self._apply_ep_scale_fix() + self._apply_blockwise_scale_stride_fix() + self._apply_2d_per_channel_fix() + self._apply_router_noaux_tc_fix() + + def compile(self, *args, **kwargs): + # save_sharded_checkpoint=True serializes shards during compile() and + # that code path reads scale.partition_stride — patches must be live. + self._install_fp8_patches() + return super().compile(*args, **kwargs) + + def load(self, *args, **kwargs): + self._install_fp8_patches() + return super().load(*args, **kwargs) + + @classmethod + def save_quantized_state_dict(cls, model_path, config): + """Flash ships pre-quantized FP8 safetensors via our preprocess script. + The base implementation calls AutoModelForCausalLM.from_pretrained to + re-quantize, which requires a CUDA GPU (finegrained_fp8 gate) and + materializes an ~600 GB BF16 copy. Skip if the checkpoint directory + already contains a Neuron-FP8 index produced by preprocess.""" + import os as _os + qpath = ( + getattr(config.neuron_config, "quantized_checkpoints_path", None) + or model_path + ) + if qpath and _os.path.isdir(qpath): + index = _os.path.join(qpath, "model.safetensors.index.json") + if _os.path.isfile(index): + return + return super().save_quantized_state_dict(model_path, config) + + def get_compiler_args(self) -> str: + """Get compiler arguments optimized for MiMo-V2.5-Pro.""" + if self.compile_tag == CONTEXT_ENCODING_MODEL_TAG: + optimization_level = "-O1" + elif self.compile_tag == TOKEN_GENERATION_MODEL_TAG: + optimization_level = "-O3" if self.neuron_config.moe_ep_degree > 1 else "-O1" + else: + optimization_level = "-O1" + + compiler_args = ( + f"--enable-saturate-infinity " + f"--enable-mixed-precision-accumulation " + f"--model-type transformer " + f"{optimization_level}" + ) + + # Add CC overlap optimization + compiler_args += ( + " --tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2'" + ) + + compiler_args += " --auto-cast=none" + + # Enable vector-offset DGE + compiler_args += " --internal-enable-dge-levels vector_dynamic_offsets" + compiler_args += " --internal-hlo2tensorizer-options='--verify-hlo=true'" + + if self.neuron_config.scratchpad_page_size: + compiler_args += f" --hbm-scratchpad-page-size={self.neuron_config.scratchpad_page_size}" + + return compiler_args diff --git a/contrib/models/MiMo-V2.5-Pro/test/__init__.py b/contrib/models/MiMo-V2.5-Pro/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiMo-V2.5-Pro/test/integration/__init__.py b/contrib/models/MiMo-V2.5-Pro/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiMo-V2.5-Pro/test/integration/test_model.py b/contrib/models/MiMo-V2.5-Pro/test/integration/test_model.py new file mode 100644 index 00000000..bcbc368e --- /dev/null +++ b/contrib/models/MiMo-V2.5-Pro/test/integration/test_model.py @@ -0,0 
+1,53 @@ +#!/usr/bin/env python3 +"""Integration tests for MiMo-V2-Flash NeuronX implementation.""" + +import pytest +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) + + +def test_config_import(): + """Test that config class can be imported.""" + from modeling_mimo_v2 import MiMoV2InferenceConfig, NeuronMiMoV2ForCausalLM + assert MiMoV2InferenceConfig is not None + assert NeuronMiMoV2ForCausalLM is not None + print("PASS: Config and model classes imported successfully") + + +def test_required_attributes(): + """Test that required attributes are defined.""" + from modeling_mimo_v2 import MiMoV2InferenceConfig + # Check get_required_attributes without instantiation (requires many params) + required = MiMoV2InferenceConfig.get_required_attributes(MiMoV2InferenceConfig) + assert "hidden_size" in required + assert "n_routed_experts" in required + assert "num_experts_per_tok" in required + assert "hybrid_layer_pattern" in required + assert "v_head_dim" in required + assert "swa_head_dim" in required + print(f"PASS: {len(required)} required attributes defined") + + +def test_neuron_config_cls(): + """Test that MoENeuronConfig is returned.""" + from modeling_mimo_v2 import MiMoV2InferenceConfig + from neuronx_distributed_inference.models.config import MoENeuronConfig + assert MiMoV2InferenceConfig.get_neuron_config_cls() == MoENeuronConfig + print("PASS: MoENeuronConfig returned") + + +def test_state_dict_converter(): + """Test that state dict converter function exists.""" + from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM + assert hasattr(NeuronMiMoV2ForCausalLM, "convert_hf_to_neuron_state_dict") + print("PASS: State dict converter exists") + + +if __name__ == "__main__": + test_config_import() + test_required_attributes() + test_neuron_config_cls() + test_state_dict_converter() + print("\nAll tests passed!") diff --git a/contrib/models/MiMo-V2.5-Pro/test/unit/__init__.py b/contrib/models/MiMo-V2.5-Pro/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b From e4a13c14fefe505edfa6c748a6a8abacff27e583 Mon Sep 17 00:00:00 2001 From: whn09 Date: Tue, 28 Apr 2026 06:26:18 +0800 Subject: [PATCH 02/24] [contrib] MiMo-V2.5-Pro: apply review findings from HF reference diff MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit P0 fix: - MiMoV2InferenceConfig now stashes `dense_intermediate_size` from HF's `intermediate_size` BEFORE overwriting `self.intermediate_size` with the MoE value. MiMoV2MLP reads this explicit field instead of the brittle `config.intermediate_size * 8` fallback (which happened to equal 16384 for V2.5-Pro by coincidence). P3 — stale V2-Pro / Flash comments updated: - attention_value_scale comments: "(0.707 for Flash)" → "(0.612 for V2.5-Pro)" - convert_mimo_v2_hf_to_neuron_state_dict kv heads comments: V2.5-Pro has num_key_value_heads=8 (same as SWA), not 4 as in V2-Pro. - smoke_compile docstring reworded to drop "Flash BS=1 recipe" wording. - smoke_compile default recipe changed to moe_tp=1/moe_ep=64/BS=48 (per user request: first V2.5-Pro test uses this recipe because it compiles fastest; bug surface on V2-Pro under this recipe was FP8 precision loss in expert MLP weights, which may not reproduce on V2.5). - preprocess router bias comment: noted measured mean=70.906 std=2.4e-4 (identical pathology to V2-Pro, mean-subtract still required). 
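A quick illustration of the mean-subtract rationale above (toy values around the
measured mean; not the actual checkpoint bias):

```python
import torch

# Two experts whose correction biases differ by ~3e-4 around the ~70.906 mean.
bias = torch.tensor([70.9060, 70.9063], dtype=torch.float32)
print(bias.to(torch.bfloat16))                  # tensor([71., 71.]): per-expert variation gone
print((bias - bias.mean()).to(torch.bfloat16))  # ~[-1.5e-4, 1.5e-4]: ordering preserved
```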
No behavioral change to FP8 monkey-patches or qkv interleaved-group split logic — HF reference diff confirmed V2.5-Pro ships the same interleaved `[16Q|1K|1V]*8` FP8 qkv layout and the same noaux_tc routing (n_group=1, topk_group=1 degenerate to plain noaux_tc). Co-Authored-By: Claude Opus 4.7 --- .../perf_test/smoke_compile_mimo_v2.py | 30 +++++++++---------- .../perf_test/smoke_generate_mimo_v2.py | 6 ++-- .../preprocess_mimo_v2_fp8.py | 4 +-- .../MiMo-V2.5-Pro/src/modeling_mimo_v2.py | 22 ++++++++++---- 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py index b67a95f8..ef5d0049 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py @@ -2,11 +2,11 @@ """Minimal compile+load smoke test for MiMo-V2.5-Pro FP8 on Trn2. Bypasses vLLM entirely so we can iterate on the preprocessed Neuron-FP8 -checkpoint without paying vllm-neuron's startup cost. Builds the Flash BS=1 -recipe (TP=64, EP=1, blockwise FP8 for routed experts), compiles to a temp -dir, then loads. EP=1 lets the TKG path enter forward_selective_loading -legally so BS=1 compiles — with EP>1 NxDI raises NotImplementedError and -forces BS>=num_experts/top_k = 32. +checkpoint without paying vllm-neuron's startup cost. Compiles the model +(TP=64, configurable MoE TP/EP, blockwise FP8 for routed experts) to a +temp dir, then loads. For `moe_ep_degree > 1` the TKG path raises +`NotImplementedError: Selective Loading with Expert parallelism` unless +`batch_size * top_k / num_experts >= 1.0` → `batch_size >= 384 / 8 = 48`. STAGE controls how far we go: instantiate | compile | load | all (default: all) @@ -31,7 +31,7 @@ ) COMPILED_PATH = os.environ.get( "MIMO_V25_PRO_COMPILED_PATH", - "/opt/dlami/nvme/compiled/mimo_v25_pro_moetp16_ep4_bs48/", + "/opt/dlami/nvme/compiled/mimo_v25_pro_moetp1_ep64_bs48/", ) TP_DEGREE = int(os.environ.get("TP_DEGREE", "64")) @@ -41,13 +41,14 @@ # path raises `NotImplementedError: Selective Loading with Expert parallelism`. BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "48")) CTX_BATCH_SIZE = int(os.environ.get("CTX_BATCH_SIZE", "1")) -# moe_tp=16 / moe_ep=4: after moetp1/ep64 (prefill broken, E_local=6) and -# moetp32/ep2 (HBM OOM by 28MB at load, E_local=192), try the middle ground. -# E_local=384/4=96. Also intermediate/moe_tp = 2048/16 = 128 which matches -# the FP8 scale block size exactly — avoids the stride_fix workaround path. -# world_size = 16*4 = 64 OK. -MOE_TP = int(os.environ.get("MOE_TP", "16")) -MOE_EP = int(os.environ.get("MOE_EP", "4")) +# moe_tp=1 / moe_ep=64: first recipe to try on V2.5-Pro. Lowest compile time +# (no intra-expert TP split) and output quality should be comparable to +# Flash, which uses the same recipe. On V2-Pro this produced garbage +# prefill ("0.0.0.0:8080"), but we're re-testing on V2.5-Pro because the +# V2-Pro root cause ended up being FP8 expert-MLP precision loss, which +# V2.5 may or may not inherit. +MOE_TP = int(os.environ.get("MOE_TP", "1")) +MOE_EP = int(os.environ.get("MOE_EP", "64")) STAGE = os.environ.get("STAGE", "all").lower() @@ -96,8 +97,7 @@ def main(): # (e.g. tp=64 + ep=4 -> 256 ranks -> 4x the sharded checkpoint size, # and at runtime the model doesn't fit on the device). 
For MoE-only # EP we want ep_degree=1 at the outer level and the per-MoE split - # controlled solely by moe_ep_degree (which Pro's working benches - # also do). Keep ep_degree=1 unconditionally. + # controlled solely by moe_ep_degree. Keep ep_degree=1 unconditionally. neuron_config = MoENeuronConfig( tp_degree=TP_DEGREE, ep_degree=1, diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py index f357aee5..b3318714 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py @@ -21,7 +21,7 @@ ) COMPILED_PATH = os.environ.get( "MIMO_V25_PRO_COMPILED_PATH", - "/opt/dlami/nvme/compiled/mimo_v25_pro_tp64_moetp1_ep64_fp8/", + "/opt/dlami/nvme/compiled/mimo_v25_pro_moetp1_ep64_bs48/", ) # Must match smoke_compile_mimo_v2.py exactly, else load() sees a @@ -30,8 +30,8 @@ SEQ_LEN = int(os.environ.get("SEQ_LEN", "1024")) BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "48")) # must match smoke_compile CTX_BATCH_SIZE = int(os.environ.get("CTX_BATCH_SIZE", "1")) -MOE_TP = int(os.environ.get("MOE_TP", "16")) -MOE_EP = int(os.environ.get("MOE_EP", "4")) +MOE_TP = int(os.environ.get("MOE_TP", "1")) +MOE_EP = int(os.environ.get("MOE_EP", "64")) PROMPT = os.environ.get( "MIMO_V25_PRO_PROMPT", diff --git a/contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py b/contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py index 56eabeaa..b9882fbb 100644 --- a/contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py +++ b/contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py @@ -10,7 +10,7 @@ MiMo-V2.5-Pro checkpoint layout: - q_proj, k_proj, v_proj are FUSED into a single `qkv_proj` tensor per - layer (num_kv_heads interleaved groups, Pro-specific). We split into + layer (num_kv_heads interleaved groups, MiMo-V2.5-Pro-specific). We split into three per-row-quantized projections via `split_qkv_fused()`. - o_proj is BF16 (listed in quantization_config.ignored_layers); kept as BF16 on the Neuron side (RowParallelLinear, not QuantizedRowParallel). @@ -409,7 +409,7 @@ def process_layer( out[f"{out_prefix}mlp.router.linear_router.weight"] = router_w.detach().clone() router_bias = lazy.get(f"{prefix}mlp.gate.e_score_correction_bias") if router_bias is not None: - # Pro-specific: HF bias has mean ~71 with per-expert std ~3e-4. NxDI + # V2.5-Pro: HF bias has mean ~71 (same pathology as V2-Pro; measured mean=70.906, std=2.4e-4) with per-expert std ~3e-4. NxDI # casts router parameters to bf16 at load time, and bf16 step size at # magnitude 71 is ~0.5 — which completely wipes out the per-expert # std=3e-4 variation, collapsing all 384 experts to a single bias diff --git a/contrib/models/MiMo-V2.5-Pro/src/modeling_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/src/modeling_mimo_v2.py index 85014672..230f627c 100644 --- a/contrib/models/MiMo-V2.5-Pro/src/modeling_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/src/modeling_mimo_v2.py @@ -177,6 +177,14 @@ def __init__(self, *args, **kwargs): self.num_local_experts = self.n_routed_experts self.n_shared_experts = 0 # MiMo-V2.5-Pro has no shared experts + # Stash the HF config's `intermediate_size` (used by the dense MLP + # in layer 0) BEFORE we overwrite `self.intermediate_size` with the + # MoE value. 
`MiMoV2MLP` reads `dense_intermediate_size` and falls + # back to `config.intermediate_size * 8` if absent, which happens + # to equal 16384 for V2.5-Pro (2048 * 8) but is brittle if Xiaomi + # ever tweaks the ratio. + self.dense_intermediate_size = self.intermediate_size + # Set intermediate_size for MoE layers self.intermediate_size = self.moe_intermediate_size @@ -340,7 +348,7 @@ def __init__( # Scaling factor self.scaling = self.attn_head_dim ** -0.5 # HF MiMoV2Attention (modeling_mimo_v2.py) multiplies value_states - # by config.attention_value_scale (0.707 for Flash) right after the V + # by config.attention_value_scale (0.612 for MiMo-V2.5-Pro) right after the V # projection, before attention softmax*V. Matching that here — applied # to value_states in forward() rather than to attn_output. self.value_scale = float(getattr(config, "attention_value_scale", 1.0)) @@ -562,7 +570,7 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - # HF MiMoV2Attention scales V by attention_value_scale (0.707 for Flash) + # HF MiMoV2Attention scales V by attention_value_scale (0.612 for MiMo-V2.5-Pro) # right after v_proj, before the attention softmax*V. Earlier revisions # of this file applied it post-attention or not at all; both produce # gibberish for prompts longer than ~20 tokens. @@ -824,8 +832,10 @@ class MiMoV2MLP(nn.Module): def __init__(self, config: MiMoV2InferenceConfig): super().__init__() self.hidden_size = config.hidden_size - # Use the dense intermediate size for non-MoE layers - self.intermediate_size = getattr(config, 'dense_intermediate_size', config.intermediate_size * 8) + # Use the dense intermediate size for non-MoE layers. + # `dense_intermediate_size` is stashed in MiMoV2InferenceConfig.__init__ + # before `self.intermediate_size` is overwritten with the MoE value. + self.intermediate_size = config.dense_intermediate_size dtype = config.neuron_config.torch_dtype @@ -1081,8 +1091,8 @@ def convert_mimo_v2_hf_to_neuron_state_dict( num_attention_heads = config.num_attention_heads # MiMo-V2.5-Pro has different KV heads for full and sliding window attention - full_num_kv_heads = config.num_key_value_heads # 4 - swa_num_kv_heads = config.swa_num_key_value_heads # 8 + full_num_kv_heads = config.num_key_value_heads # V2.5-Pro: 8 + swa_num_kv_heads = config.swa_num_key_value_heads # V2.5-Pro: 8 # Check if we need to replicate K/V weights full_use_convert_to_mha = tp_degree > full_num_kv_heads From 78b8e0a97f1b592b84d06c842ce97fe371d2396e Mon Sep 17 00:00:00 2001 From: whn09 Date: Tue, 28 Apr 2026 10:50:17 +0800 Subject: [PATCH 03/24] [contrib] MiMo-V2.5-Pro: document current WIP status in README MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Record what works and what doesn't on 2026-04-28: - Compile + load succeed on Trn2 (moe_tp=1/ep=64/BS=48 recipe). - Prefill produces coherent English but off-topic output ("100% of the time..." loop for a "explain transformer" prompt). Same signature as V2-Pro's earlier FP8 failures — per-expert weight distribution too narrow for FP8 e4m3 precision. - Note observed token IDs 15/16/4/315/279/882 look suspiciously small but are just " of/ the/ time" etc. — top-BPE English subwords. Greedy decode is correct, the logit distribution itself is wrong. - List recipes still to try (moe_tp=16/ep=4, moe_tp=32/ep=2 etc.) and NxDI constraints that rule out BS=1 when moe_ep>1. 
Points future debuggers at Jim Burtoft's Flash FP8 observation and his Kimi PR #131 SDK 2.28 recommendation. No code changes. Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2.5-Pro/README.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/contrib/models/MiMo-V2.5-Pro/README.md b/contrib/models/MiMo-V2.5-Pro/README.md index 38623380..6e44fd82 100644 --- a/contrib/models/MiMo-V2.5-Pro/README.md +++ b/contrib/models/MiMo-V2.5-Pro/README.md @@ -35,6 +35,22 @@ Key features: - **Sigmoid Router + noaux_tc**: `sigmoid(logits) + e_score_correction_bias` is used to pick top-8 experts; unbiased `sigmoid(logits)` becomes the affinity weights. `n_group=1, topk_group=1` degenerates group-limited routing to plain noaux_tc. - **attention_value_scale = 0.612**: HF reference multiplies `value_states` by this before `softmax(QK^T) × V` (NOT applied post-attention); the NxDI port matches. +## Status (work-in-progress) + +**This port compiles cleanly and loads on Trn2 but does not yet produce coherent output.** Current symptoms under the default recipe (`tp_degree=64, moe_tp_degree=1, moe_ep_degree=64, batch_size=48, seq_len=1024`) on 2026-04-28: + +- Prefill drifts from token 1: "Explain in one sentence what a transformer neural network is." → `"100% of the time, 100% of the time, ..."` (greedy decode, temperature=0). Decode speed 0.72 tok/s. Note the output is syntactically valid English ("100% of the time" = BPE tokens `15/16/4/315/279/882` = `"1"/"0"/"%"/" of"/" the"/" time"`, all high-frequency) — not a sampling bug: greedy argmax is correctly picking the model's top token, but the logit distribution itself is wrong (output unrelated to the prompt). Same signature as Jim Burtoft's "Flash FP8 → `erotici` repeat" symptom. +- Same failure pattern was observed on the MiMo-V2-Pro port under the same recipe (`"0.0.0.0:8080"` etc.). Root cause identified there was **FP8 expert-MLP precision loss**: Pro's expert weight std ≈ 0.0018 (10× smaller than Flash's ≈ 0.019), landing right at FP8 e4m3's subnormal threshold. V2.5-Pro inherits the same per-expert scale pathology (router bias mean ≈ 70.906 std ≈ 2.4e-4, verified via preprocess; mean-subtract workaround applied). +- Reference: Jim Burtoft observed similar prompt-dependent FP8 degradation on Flash (see PR notes) and recommends selective BF16 retention for precision-sensitive layers. + +Other recipes to try (none verified yet on V2.5-Pro): +- `moe_tp_degree=16, moe_ep_degree=4, BS=48` — balances E_local=96 vs HBM. +- `moe_tp_degree=32, moe_ep_degree=2, BS=1` — mirrors Jim's Kimi-K2 PR, but V2-Pro OOM'd by 28MB on load at BS=48; BS=1 hits `NotImplementedError: Selective Loading with Expert parallelism`. + +Known NxDI limits that constrain recipe choice: +- `BS * top_k / num_experts >= 1.0` required when `moe_ep_degree > 1` at decode (else NotImplementedError). With `num_experts=384, top_k=8` this forces `BS >= 48`. +- `n_routed_experts=384 = 2^7 * 3` → `384 / ep_degree` is never a power of 2 (6, 12, 24, 48, 96, 192, 384). Kimi PR #131 says NKI `_bwmm_shard_on_block_nki_call` on SDK 2.29 has "depressed logits with EP=2" and recommends SDK 2.28. SDK 2.28 venv is not currently installed on the target DLAMI. 
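+
+The arithmetic behind these limits can be sanity-checked offline with a
+throwaway helper (illustrative only, not part of the port):
+
+```python
+# Enumerate (moe_tp, moe_ep) splits of the 64 logical cores, the minimum decode
+# batch size implied by BS * top_k / num_experts >= 1 when moe_ep > 1, and the
+# per-rank expert intermediate relative to the 128-wide FP8 scale block.
+NUM_EXPERTS, TOP_K, WORLD, INTERMEDIATE = 384, 8, 64, 2048
+
+for moe_ep in (1, 2, 4, 8, 16, 32, 64):
+    moe_tp = WORLD // moe_ep
+    min_bs = 1 if moe_ep == 1 else -(-NUM_EXPERTS // TOP_K)  # ceil(384 / 8) = 48
+    print(f"moe_tp={moe_tp:2d} moe_ep={moe_ep:2d} "
+          f"E_local={NUM_EXPERTS // moe_ep:3d} min_BS={min_bs:2d} "
+          f"I_per_rank={INTERMEDIATE // moe_tp:4d}")  # 128 only at moe_tp=16
+```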
+ ## Prerequisites - **Instance**: trn2.48xlarge (32 NeuronCores, logical_nc_config=2 → 64 logical cores) From e990f76004ca914aff30e326f21b2ef20540142e Mon Sep 17 00:00:00 2001 From: whn09 Date: Tue, 28 Apr 2026 11:08:08 +0800 Subject: [PATCH 04/24] [contrib] MiMo-V2.5-Pro: correct FP8 root-cause framing in Status MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Earlier wording said "Pro's expert weight std is too small for FP8 precision" in absolute terms. That's misleading — sglang on H100/H200 runs the exact same OCP FP8 checkpoint and produces correct output, because GPU cutlass/sglang paths dequantize FP8 to BF16 before the matmul. The actual issue appears to be Neuron's NKI blockwise FP8 compute kernel (_bwmm_shard_on_block_nki_call) running FP8 compute directly on subnormal-leaning tensors. Jim Burtoft's Kimi PR #131 names the Neuron SDK 2.29 blockwise kernel as producing "depressed logits with EP=2" and recommends SDK 2.28. Also noted: V2.5-Pro MoE expert weights are byte-identical to V2-Pro (measured layer 1 expert 0 gate_proj stats match to 6 decimals), so all V2-Pro FP8 workarounds remain required — not a new bug. Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2.5-Pro/README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/contrib/models/MiMo-V2.5-Pro/README.md b/contrib/models/MiMo-V2.5-Pro/README.md index 6e44fd82..552354d9 100644 --- a/contrib/models/MiMo-V2.5-Pro/README.md +++ b/contrib/models/MiMo-V2.5-Pro/README.md @@ -40,8 +40,9 @@ Key features: **This port compiles cleanly and loads on Trn2 but does not yet produce coherent output.** Current symptoms under the default recipe (`tp_degree=64, moe_tp_degree=1, moe_ep_degree=64, batch_size=48, seq_len=1024`) on 2026-04-28: - Prefill drifts from token 1: "Explain in one sentence what a transformer neural network is." → `"100% of the time, 100% of the time, ..."` (greedy decode, temperature=0). Decode speed 0.72 tok/s. Note the output is syntactically valid English ("100% of the time" = BPE tokens `15/16/4/315/279/882` = `"1"/"0"/"%"/" of"/" the"/" time"`, all high-frequency) — not a sampling bug: greedy argmax is correctly picking the model's top token, but the logit distribution itself is wrong (output unrelated to the prompt). Same signature as Jim Burtoft's "Flash FP8 → `erotici` repeat" symptom. -- Same failure pattern was observed on the MiMo-V2-Pro port under the same recipe (`"0.0.0.0:8080"` etc.). Root cause identified there was **FP8 expert-MLP precision loss**: Pro's expert weight std ≈ 0.0018 (10× smaller than Flash's ≈ 0.019), landing right at FP8 e4m3's subnormal threshold. V2.5-Pro inherits the same per-expert scale pathology (router bias mean ≈ 70.906 std ≈ 2.4e-4, verified via preprocess; mean-subtract workaround applied). -- Reference: Jim Burtoft observed similar prompt-dependent FP8 degradation on Flash (see PR notes) and recommends selective BF16 retention for precision-sensitive layers. +- Same failure pattern was observed on the MiMo-V2-Pro port under the same recipe (`"0.0.0.0:8080"` etc.). Root cause identified there appears to be **Neuron's NKI blockwise FP8 compute kernel handling Pro's tight expert-weight distribution** (std ≈ 0.0018, ~10× smaller than Flash's ≈ 0.019). This is NOT an FP8-format problem per se — sglang on H100/H200 runs the exact same OCP FP8 checkpoint and produces correct output, because GPU paths dequantize FP8→BF16 before the matmul. 
Neuron NKI does FP8 compute directly and seems to lose precision on subnormal-leaning tensors. +- V2.5-Pro's MoE expert weights are byte-identical to V2-Pro (verified layer 1 expert 0 dequant stats match to 6 decimal places on 2026-04-28), so all V2-Pro workarounds remain required (router bias mean-subtract, qkv interleaved split, `_apply_2d_per_channel_fix`, `_apply_blockwise_scale_stride_fix`). +- Reference: Jim Burtoft observed similar prompt-dependent FP8 degradation on Flash and his Kimi PR #131 names "blockwise kernel padding produces depressed logits with EP=2 on SDK 2.29; SDK 2.28 recommended". Other recipes to try (none verified yet on V2.5-Pro): - `moe_tp_degree=16, moe_ep_degree=4, BS=48` — balances E_local=96 vs HBM. From f0d9c0b5f036c1e9a8d4442ec8f5585e9dca075b Mon Sep 17 00:00:00 2001 From: whn09 Date: Tue, 28 Apr 2026 11:59:29 +0800 Subject: [PATCH 05/24] [contrib] MiMo-V2.5-Pro: parallel preprocess + NVMe mount docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Parallel preprocess wrapper: - preprocess_mimo_v2_parallel.py: multiprocessing Pool wrapper around preprocess_mimo_v2_fp8.process_layer. Each worker opens its own LazyWeightMap and processes one layer at a time. N_WORKERS default raised to 12 (user request: "越多越好"); 70 layers * ~25 GB peak/layer stays under ~300 GB RAM on a 2 TB trn2.48xl. - run_preprocess_parallel.sh: thin shell wrapper exposing HF_MODEL_PATH, SAVE_PATH, TP_DEGREE, N_WORKERS env vars. Defaults to the 2_9_nxd_inference venv (same one used by the serial preprocess). Wall-clock ~30 min serial → ~5-6 min at 12 workers on fresh cache. README: - Added "NVMe mount" subsection under Prerequisites. trn2.48xl DLAMI assembles four NVMe into RAID0 at /opt/dlami/nvme but does NOT remount automatically after a reboot. Document mdadm --assemble + mount /dev/md0 /opt/dlami/nvme before any path in the recipes resolves. Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2.5-Pro/README.md | 19 ++ .../preprocess_mimo_v2_parallel.py | 179 ++++++++++++++++++ .../run_preprocess_parallel.sh | 34 ++++ 3 files changed, 232 insertions(+) create mode 100644 contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_parallel.py create mode 100755 contrib/models/MiMo-V2.5-Pro/src/conversion_script/run_preprocess_parallel.sh diff --git a/contrib/models/MiMo-V2.5-Pro/README.md b/contrib/models/MiMo-V2.5-Pro/README.md index 552354d9..1b35c567 100644 --- a/contrib/models/MiMo-V2.5-Pro/README.md +++ b/contrib/models/MiMo-V2.5-Pro/README.md @@ -59,6 +59,25 @@ Known NxDI limits that constrain recipe choice: - **Venvs**: `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference` (for preprocess + NxDI direct smoke), `/opt/aws_neuronx_venv_pytorch_inference_vllm_0_16` (for vLLM serving). Both ship with the DLAMI. - **Disk**: ~3 TB free under `/opt/dlami/nvme` (the HF FP8 checkpoint is ~962 GB, the Neuron-FP8 preprocessed output is ~1 TB, and `save_sharded_checkpoint=true` writes another ~300-1000 GB per compiled config (varies with recipe)). +### NVMe mount + +The Trn2 DLAMI ships with four local NVMe SSDs that are assembled into a +RAID0 array at `/opt/dlami/nvme`. 
After a reboot the mount is **NOT** +reassembled automatically — you must re-mount manually before the paths +below resolve: + +```bash +lsblk # confirm you see nvme0n1..nvme3n1 devices +sudo mdadm --assemble /dev/md0 /dev/nvme[0-3]n1 2>/dev/null || true +sudo mount /dev/md0 /opt/dlami/nvme +df -h /opt/dlami/nvme # should show ~6.9 TB total +``` + +If `mdadm --assemble` says the array is already assembled, the mount +step alone is enough. If `/dev/md0` doesn't exist, the array was never +created on this instance — run `/opt/dlami/setup-nvme.sh` (or the +DLAMI's built-in helper; consult `ls /opt/dlami/*.sh`) before mounting. + ## Quick Start (FP8 on Trn2) End-to-end recipe to go from a fresh trn2.48xlarge to a working vLLM OpenAI server serving MiMo-V2.5-Pro FP8. First-time compile takes ~45-60 minutes; subsequent runs hit the neuronx-cc cache and start in a few minutes. diff --git a/contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_parallel.py b/contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_parallel.py new file mode 100644 index 00000000..23f05bcc --- /dev/null +++ b/contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_parallel.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +"""Parallel wrapper around preprocess_mimo_v2_fp8.process_layer. + +Each layer is independent: `process_layer(L, lazy, config, ...)` reads only +keys under `model.layers.{L}.*` from the HF shards and returns the Neuron +layer-shard dict. With 70 layers and per-MoE-layer cost ~60s serial, 4-8 +workers cuts wallclock from ~70 min to ~15-20 min (I/O + CPU FP8 math). + +Each worker opens its own LazyWeightMap so there's no shared safetensors +handle. Output dir is a CLI arg so it can write to a clean path without +touching the serial run's output. +""" +import argparse +import gc +import json +import multiprocessing as mp +import os +import shutil +import sys +import time + +# Resolve the sibling single-layer preprocess module. This file lives at +# .../MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_parallel.py, +# so the importable parent is two levels up. 
+_SRC_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if _SRC_DIR not in sys.path: + sys.path.insert(0, _SRC_DIR) +from conversion_script.preprocess_mimo_v2_fp8 import ( # noqa: E402 + LazyWeightMap, + process_layer, + save_shard, +) +from safetensors.torch import save_file # noqa: E402 + + +def _worker(task): + layer_idx, hf_model_path, save_path, config = task + hybrid = config.get( + "hybrid_layer_pattern", [0] * config["num_hidden_layers"] + ) + moe_freq = config.get("moe_layer_freq", [1] * config["num_hidden_layers"]) + is_dense = moe_freq[layer_idx] == 0 + is_swa = hybrid[layer_idx] == 1 + + with open( + os.path.join(hf_model_path, "model.safetensors.index.json") + ) as fh: + weight_map_in = json.load(fh)["weight_map"] + lazy = LazyWeightMap(hf_model_path, weight_map_in) + try: + t0 = time.time() + layer_sd = process_layer( + layer_idx, lazy, config, is_dense=is_dense, is_swa=is_swa + ) + filename = f"model_layer{layer_idx}.safetensors" + path = os.path.join(save_path, filename) + materialized = {} + total_bytes = 0 + for k, v in layer_sd.items(): + if not v.is_contiguous(): + v = v.contiguous() + v = v.detach().clone() + materialized[k] = v + total_bytes += v.numel() * v.element_size() + save_file(materialized, path) + keys = list(materialized.keys()) + del materialized, layer_sd + gc.collect() + elapsed = time.time() - t0 + finally: + lazy.close() + return layer_idx, is_dense, is_swa, keys, filename, total_bytes, elapsed + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--hf_model_path", required=True) + p.add_argument("--save_path", required=True) + p.add_argument("--tp_degree", type=int, default=64) + p.add_argument( + "--workers", + type=int, + default=int(os.environ.get("N_WORKERS", "12")), + ) + args = p.parse_args() + + os.makedirs(args.save_path, exist_ok=True) + with open(os.path.join(args.hf_model_path, "config.json")) as fh: + config = json.load(fh) + num_layers = config["num_hidden_layers"] + + print( + f"[par] {num_layers} layers x {args.workers} workers -> {args.save_path}", + flush=True, + ) + tasks = [ + (L, args.hf_model_path, args.save_path, config) for L in range(num_layers) + ] + weight_map_out = {} + t_start = time.time() + + ctx = mp.get_context("spawn") + with ctx.Pool(args.workers) as pool: + done = 0 + for li, is_dense, is_swa, keys, filename, total_bytes, elapsed in pool.imap_unordered( + _worker, tasks + ): + done += 1 + for k in keys: + weight_map_out[k] = filename + tag = "dense" if is_dense else "moe" + attn = "swa" if is_swa else "full" + print( + f" [{done:2d}/{num_layers}] layer {li:2d} [{tag:5s} {attn:4s}] " + f"{total_bytes/1e9:6.2f} GB in {elapsed:5.1f}s " + f"(wall {time.time()-t_start:5.1f}s)", + flush=True, + ) + + print( + f"[par] all {num_layers} layers done in {time.time()-t_start:.1f}s", + flush=True, + ) + + print("[par] processing embed_tokens / norm / lm_head ...", flush=True) + with open( + os.path.join(args.hf_model_path, "model.safetensors.index.json") + ) as fh: + weight_map_in = json.load(fh)["weight_map"] + lazy = LazyWeightMap(args.hf_model_path, weight_map_in) + extras = {} + try: + for src, dst in ( + ("model.embed_tokens.weight", "embed_tokens.weight"), + ("model.norm.weight", "norm.weight"), + ("lm_head.weight", "lm_head.weight"), + ): + t = lazy.get(src) + if t is not None: + extras[dst] = t.detach().clone() + else: + print(f" WARN: missing {src}", flush=True) + if "lm_head.weight" not in extras and "embed_tokens.weight" in extras: + extras["lm_head.weight"] = 
extras["embed_tokens.weight"].detach().clone() + finally: + lazy.close() + save_shard(extras, args.save_path, "model_extras.safetensors", weight_map_out) + del extras + + total_size = 0 + for f in set(weight_map_out.values()): + total_size += os.path.getsize(os.path.join(args.save_path, f)) + index = { + "metadata": {"total_size": total_size}, + "weight_map": weight_map_out, + } + with open( + os.path.join(args.save_path, "model.safetensors.index.json"), "w" + ) as fh: + json.dump(index, fh, indent=2) + + for name in sorted(os.listdir(args.hf_model_path)): + if name.endswith(".safetensors"): + continue + if name == "model.safetensors.index.json": + continue + src = os.path.join(args.hf_model_path, name) + if os.path.isfile(src): + shutil.copy(src, os.path.join(args.save_path, name)) + + print( + f"\n[par] DONE. total_size={total_size/1e9:.2f} GB " + f"tensors={len(weight_map_out)} -> {args.save_path}", + flush=True, + ) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/MiMo-V2.5-Pro/src/conversion_script/run_preprocess_parallel.sh b/contrib/models/MiMo-V2.5-Pro/src/conversion_script/run_preprocess_parallel.sh new file mode 100755 index 00000000..5001ecd6 --- /dev/null +++ b/contrib/models/MiMo-V2.5-Pro/src/conversion_script/run_preprocess_parallel.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Parallel wrapper around preprocess_mimo_v2_parallel.py. +# +# Each worker dequants one MoE layer at a time (peak ~25 GB per layer on +# V2.5-Pro's 6144 hidden / 384 experts / 2048 intermediate shape). 12 +# workers stay under ~300 GB CPU RAM on a 2 TB box while keeping the +# 192-core CPU busy. On a trn2.48xl that brings total wall time from +# ~30 min (serial) to ~5-6 min. +# +# Env: +# HF_MODEL_PATH raw HF checkpoint (default: /opt/dlami/nvme/models/MiMo-V2.5-Pro) +# SAVE_PATH output Neuron checkpoint (default: /opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8) +# TP_DEGREE tensor-parallel degree used at compile time (default: 64) +# N_WORKERS concurrent layer workers (default: 12) +# VENV venv with torch + safetensors + contrib pkg on sys.path +set -e + +VENV=${VENV:-/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference} +source "$VENV/bin/activate" + +HF_MODEL_PATH=${HF_MODEL_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro} +SAVE_PATH=${SAVE_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8} +TP_DEGREE=${TP_DEGREE:-64} +N_WORKERS=${N_WORKERS:-12} + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SRC_DIR="$(cd "$SCRIPT_DIR/.." && pwd)" +export PYTHONPATH="$SRC_DIR:$PYTHONPATH" + +exec python3 "$SCRIPT_DIR/preprocess_mimo_v2_parallel.py" \ + --hf_model_path "$HF_MODEL_PATH" \ + --save_path "$SAVE_PATH" \ + --tp_degree "$TP_DEGREE" \ + --workers "$N_WORKERS" From b455a1ef850cf3a91a76fef24ae07d23181244cf Mon Sep 17 00:00:00 2001 From: whn09 Date: Tue, 28 Apr 2026 12:16:04 +0800 Subject: [PATCH 06/24] [contrib] MiMo-V2.5-Pro: set AWS Llama-405B FP8 env vars in smoke scripts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AWS trn2-llama3.1-405b-speculative FP8 tutorial ("Scenario 2, Step 2") requires XLA_HANDLE_SPECIAL_SCALAR=1 and UNSAFE_FP8FNCAST=1 for OCP-sourced FP8 checkpoints on Neuron. Setting them in both smoke_compile_mimo_v2.py and smoke_generate_mimo_v2.py via os.environ.setdefault (user-level env overrides still win). Note: our preprocess output has 0 bytes in the IEEE-NaN-adjacent range (byte exp=0b1111), verified on layers.1 attn q/k/v and MoE gate_up/down in /opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8. 
So these flags are theoretically optional for our pipeline, but they match the exact surface of AWS's reference FP8 tutorial — cheap safety. Also corrected stale docstrings: smoke_compile now says the NxDI venv (pytorch_2_9_nxd_inference) is the target, not the vllm venv. Co-Authored-By: Claude Opus 4.7 --- .../perf_test/smoke_compile_mimo_v2.py | 16 ++++++++++++++-- .../perf_test/smoke_generate_mimo_v2.py | 6 +++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py index ef5d0049..c92946e4 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py @@ -16,8 +16,9 @@ forward pass that allocates the shared scratchpad — useful when HBM is tight. -Run under /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 (same venv used -by the bench script). +Run under /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference (NxDI direct). +The `/opt/aws_neuronx_venv_pytorch_inference_vllm_0_16` venv is only +needed for vllm serving (bench_mimo_v2.sh). """ import os @@ -25,6 +26,17 @@ import time import traceback +# AWS Llama-3.1-405B FP8 tutorial requires these two env vars to correctly +# handle OCP-derived FP8 checkpoints on Neuron: XLA_HANDLE_SPECIAL_SCALAR=1 +# opts in to XLA emitting the bit-reinterpretation path for fp8_e4m3fn scalars, +# and UNSAFE_FP8FNCAST=1 mirrors it for torch-side casts. Our preprocess output +# has 0 bytes in the IEEE-NaN range (verified 2026-04-28), so these flags are +# theoretically unnecessary, but setting them matches the AWS tutorial +# surface exactly. Source: trn2-llama3.1-405b-speculative-tutorial.html +# "Scenario 2, Step 2". +os.environ.setdefault("XLA_HANDLE_SPECIAL_SCALAR", "1") +os.environ.setdefault("UNSAFE_FP8FNCAST", "1") + MODEL_PATH = os.environ.get( "MIMO_V25_PRO_MODEL_PATH", "/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8", diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py index b3318714..2eb7ab13 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py @@ -7,7 +7,7 @@ single prompt via HuggingFaceGenerationAdapter. Purpose: sanity-check that the FP8 MoE + preprocessed scales actually produce coherent tokens. -Run under /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16. +Run under /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference. """ import os @@ -15,6 +15,10 @@ import time import traceback +# Same env vars as smoke_compile — match AWS Llama-405B FP8 tutorial. +os.environ.setdefault("XLA_HANDLE_SPECIAL_SCALAR", "1") +os.environ.setdefault("UNSAFE_FP8FNCAST", "1") + MODEL_PATH = os.environ.get( "MIMO_V25_PRO_MODEL_PATH", "/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8", From 9d29bbf27864fa228ba67f4deb5e9e5ef26654e8 Mon Sep 17 00:00:00 2001 From: whn09 Date: Tue, 28 Apr 2026 14:06:32 +0800 Subject: [PATCH 07/24] [contrib] MiMo-V2.5-Pro: fix tokenizer padding_side='left' for decoder-only LM HuggingFace tokenizer defaults to padding_side='right', which silently corrupts batched prefill on a causal LM: the last token of each slot becomes a pad token, and the logit used for generating the next token is predicting "what comes after the pad", not "what comes after the real prompt". 
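A minimal repro of the difference outside the Neuron stack (prompts are made up; MODEL_PATH is the checkpoint dir, as in the smoke script):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
if tok.pad_token is None:          # only needed for this demo
    tok.pad_token = tok.eos_token
prompts = ["What is 1+1?", "Explain a transformer network in one sentence."]

# Default padding_side='right': the shorter prompt's last position is <pad>,
# so logits[:, -1, :] predicts "what follows a pad token".
batch_right = tok(prompts, return_tensors="pt", padding=True)

# padding_side='left': pads go in front and every row ends on its real last
# prompt token, which is what batched greedy prefill on a causal LM needs.
tok.padding_side = "left"
batch_left = tok(prompts, return_tensors="pt", padding=True)
```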
Observed when running a 6-prompt probe at BS=48: prompts that nearly fill the 267-token batch dimension produced garbage output like "all spaces" (token 220) or random short-id BPE noise. Fix: explicitly set padding_side='left' after tokenizer load. Single-prompt smoke (all slots == same prompt, so no padding triggered) was not affected by this bug, but was producing wrong output for a different reason (the underlying FP8 expert-MLP precision issue). Co-Authored-By: Claude Opus 4.7 --- .../MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py index 2eb7ab13..9f678a09 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py @@ -15,7 +15,9 @@ import time import traceback -# Same env vars as smoke_compile — match AWS Llama-405B FP8 tutorial. +# AWS Llama-3.1-405B FP8 tutorial env vars — these are RUNTIME flags that +# affect XLA's fp8 special-scalar handling. Safe to set at generate time; +# do NOT set at compile time (it changes HLO and busts the neuronx-cc cache). os.environ.setdefault("XLA_HANDLE_SPECIAL_SCALAR", "1") os.environ.setdefault("UNSAFE_FP8FNCAST", "1") @@ -134,6 +136,10 @@ def main(): print(f"[gen] Loaded in {time.time() - t0:.1f}s") tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + # Decoder-only LM requires left-padding so the last token of each batch + # slot is the real prompt ending, not a pad token. Default HF tokenizer + # padding_side is 'right' which silently corrupts batched prefill. + tokenizer.padding_side = "left" adapter = HuggingFaceGenerationAdapter(model) # When CHAT_TEMPLATE=1, wrap the raw prompt in the checkpoint's chat From 2a4c9ff11ad2e248e93e40a977003dc8b80ae0a9 Mon Sep 17 00:00:00 2001 From: whn09 Date: Tue, 28 Apr 2026 17:05:48 +0800 Subject: [PATCH 08/24] [contrib] MiMo-V2.5-Pro: skip FP8 quant on q/k/v_proj (attention path) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit V2.5-Pro's attention Q/K/V weights have dequantized abs_mean ~0.001-0.005, roughly 4x smaller than V2.5 (which works). Preprocess has been patched to rewrite the q/k/v_proj tensors in the preprocessed checkpoint as BF16 (matching how o_proj is already handled). Add q_proj/k_proj/v_proj to modules_to_not_convert so NxDI does not try to swap them to QuantizedColumnParallel at convert() time — they remain plain ColumnParallelLinear with BF16 weights. MoE expert weights (gate_up_proj, down_proj) stay FP8 blockwise; their weights saturate the full FP8 ±240 range so quantization is lossless there. Only the attention path goes BF16. 
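The rewrite itself is ordinary blockwise dequantization; a rough sketch, assuming a 128x128 block layout with an [out/128, in/128] float32 scale tensor (names and shapes are illustrative, not the preprocess script's exact code):

```python
import torch

def dequant_blockwise_to_bf16(w_fp8: torch.Tensor, scale: torch.Tensor,
                              block: int = 128) -> torch.Tensor:
    # w_fp8: [out_dim, in_dim] in float8_e4m3fn
    # scale: [ceil(out_dim/block), ceil(in_dim/block)] in float32
    w = w_fp8.to(torch.float32)
    s = scale.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    return (w * s[: w.shape[0], : w.shape[1]]).to(torch.bfloat16)
```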
Co-Authored-By: Claude Opus 4.7 --- .../perf_test/smoke_compile_mimo_v2.py | 19 ++++++++----------- .../perf_test/smoke_generate_mimo_v2.py | 2 +- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py index c92946e4..6b5fc2e7 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py @@ -26,16 +26,13 @@ import time import traceback -# AWS Llama-3.1-405B FP8 tutorial requires these two env vars to correctly -# handle OCP-derived FP8 checkpoints on Neuron: XLA_HANDLE_SPECIAL_SCALAR=1 -# opts in to XLA emitting the bit-reinterpretation path for fp8_e4m3fn scalars, -# and UNSAFE_FP8FNCAST=1 mirrors it for torch-side casts. Our preprocess output -# has 0 bytes in the IEEE-NaN range (verified 2026-04-28), so these flags are -# theoretically unnecessary, but setting them matches the AWS tutorial -# surface exactly. Source: trn2-llama3.1-405b-speculative-tutorial.html -# "Scenario 2, Step 2". -os.environ.setdefault("XLA_HANDLE_SPECIAL_SCALAR", "1") -os.environ.setdefault("UNSAFE_FP8FNCAST", "1") +# NOTE: AWS Llama-3.1-405B FP8 tutorial recommends XLA_HANDLE_SPECIAL_SCALAR=1 +# and UNSAFE_FP8FNCAST=1 for OCP-derived FP8 checkpoints. Setting them at +# compile time, however, appears to change the HLO that gets emitted (likely +# because XLA lowering of fp8_e4m3fn special scalars switches paths), which +# busts the neuronx-cc cache and forces a full recompile (~90 min). If you +# need these flags, set them ONLY at generate time and rely on the compiled +# NEFF's built-in `--experimental-unsafe-fp8e4m3fn-as-fp8e4m3` handling. MODEL_PATH = os.environ.get( "MIMO_V25_PRO_MODEL_PATH", @@ -150,7 +147,7 @@ def main(): "lm_head", "norm", "router", - "o_proj", + "o_proj", "q_proj", "k_proj", "v_proj", ], ) diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py index 9f678a09..762fb0d8 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py @@ -114,7 +114,7 @@ def main(): "lm_head", "norm", "router", - "o_proj", + "o_proj", "q_proj", "k_proj", "v_proj", ], ) From 2b8b5771c53e227b0e3bc7dcbeb3af3862fa7608 Mon Sep 17 00:00:00 2001 From: whn09 Date: Tue, 28 Apr 2026 19:54:56 +0800 Subject: [PATCH 09/24] [contrib] MiMo-V2.5-Pro: try use_torch_block_wise + restore FP8 q/k/v Previous attempt (2a4c9ff) rewrote q/k/v_proj as BF16 to work around Pro's attention weight precision (q_proj abs_mean ~0.00124, 4x smaller than V2.5). Compile succeeded, but load failed with HBM OOM: the BF16 attention weights added ~2 GB per rank, pushing Tensors to 20.93 GB on a 24 GB Neuron HBM and leaving no room for collective DMA rings. Back off on the BF16-attn approach and try a different hypothesis: the NKI blockwise matmul kernel has accumulator precision issues on Pro's MoE expert weights (scale_mean ~5e-5 vs 2.5e-4 on V2.5). Switch blockwise_matmul_config from use_shard_on_block_dynamic_while to use_torch_block_wise=True, which uses a PyTorch fallback that dequantizes each block to BF16 before matmul. Slower but more precise in the accumulator. q/k/v_proj return to FP8 (back out of modules_to_not_convert) so the attention weights don't blow HBM. 
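For intuition only, the compute pattern this path is expected to follow, written as a pure-PyTorch reference with a simplified per-block scalar scale (not the NxDI/NKI implementation):

```python
import torch

def blockwise_mm_dequant_first(x, w_fp8, block_scales, block=128):
    # x: [tokens, in_dim] bf16; w_fp8: [out_dim, in_dim] float8_e4m3fn;
    # block_scales: one float per 128-column block of w (simplified layout).
    # Each block is dequantized to fp32 before the matmul, so the accumulator
    # never sees raw FP8 products; that extra precision is what this
    # experiment was after, at the cost of throughput.
    out = torch.zeros(x.shape[0], w_fp8.shape[0], dtype=torch.float32)
    for i, j in enumerate(range(0, w_fp8.shape[1], block)):
        w_blk = w_fp8[:, j:j + block].to(torch.float32) * block_scales[i]
        out += x[:, j:j + block].float() @ w_blk.T
    return out.to(torch.bfloat16)
```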
Co-Authored-By: Claude Opus 4.7 --- .../perf_test/smoke_compile_mimo_v2.py | 17 +++++++++++------ .../perf_test/smoke_generate_mimo_v2.py | 7 ++++--- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py index 6b5fc2e7..be7f72e7 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py @@ -124,12 +124,17 @@ def main(): moe_tp_degree=MOE_TP, context_encoding_buckets=[SEQ_LEN], router_config={"act_fn": "sigmoid", "dtype": "float32"}, - # SDK 2.29 ships only bwmm_shard_on_block / bwmm_shard_on_intermediate; - # default routes to _call_shard_hidden_kernel which is missing, so we - # take the shard-on-block path via this flag. Matches Flash + Kimi. + # V2.5-Pro experiment: use PyTorch-fallback blockwise matmul instead + # of the NKI _call_shard_hidden_kernel path. Rationale: Pro's MoE + # expert weight scales run ~3-7x smaller than V2.5 (scale_mean + # ~5e-5 vs 2.5e-4); the NKI blockwise kernel's lower accumulator + # precision compounds across 70 layers into decode-time gibberish. + # PyTorch fallback dequantizes each block to BF16 before matmul, + # trading throughput for accumulator precision. blockwise_matmul_config={ - "use_shard_on_block_dynamic_while": True, - "block_sharding_strategy": "PING_PONG", + "use_torch_block_wise": True, + "use_shard_on_intermediate_dynamic_while": True, + "skip_dma_token": True, }, # Persist sharded FP8 weights to disk so subsequent load()s skip the # ~10-minute shard_checkpoint step (writes weights/tp{0..63}_*.safetensors @@ -147,7 +152,7 @@ def main(): "lm_head", "norm", "router", - "o_proj", "q_proj", "k_proj", "v_proj", + "o_proj", ], ) diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py index 762fb0d8..c3388e89 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py @@ -99,8 +99,9 @@ def main(): context_encoding_buckets=[SEQ_LEN], router_config={"act_fn": "sigmoid", "dtype": "float32"}, blockwise_matmul_config={ - "use_shard_on_block_dynamic_while": True, - "block_sharding_strategy": "PING_PONG", + "use_torch_block_wise": True, + "use_shard_on_intermediate_dynamic_while": True, + "skip_dma_token": True, }, save_sharded_checkpoint=True, quantized=True, @@ -114,7 +115,7 @@ def main(): "lm_head", "norm", "router", - "o_proj", "q_proj", "k_proj", "v_proj", + "o_proj", ], ) From b3a8487f08437bed3858ad781e434d0cb9c57da0 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 29 Apr 2026 10:30:07 +0800 Subject: [PATCH 10/24] [contrib] MiMo-V2.5-Pro: wire up vLLM serving and record FP8 perf Pro is now serveable via vllm-neuron 0.5.0 on Trn2 (TP=64, moe_ep=64, BS=48). Output quality under the FP8 recipe is still prompt-dependent (drift on most prompts, coherent on self-intro style), consistent with Pro's 4-7x smaller MoE FP8 scales compared to V2.5 and the V2-Pro symptom. Changes: - Revert blockwise_matmul_config back to use_shard_on_block_dynamic_while + PING_PONG (Flash/Kimi recipe). The use_torch_block_wise + BF16-attn experiments both OOM on load. - Fix bench_mimo_v2.sh / smoke configs from BS=32 (Flash) to BS=48 (Pro: 384/8=48), plus all accompanying text in the README. 
- vLLM patch now registers both MiMoV2FlashForCausalLM and MiMoV2ProForCausalLM in vLLM's ModelRegistry, overriding the built-in GPU stubs; patch works against vllm-neuron release-0.5.0. - Point sanity_check.sh, run_bench_single.sh, 0_setup.sh defaults at the Neuron-FP8 checkpoint (not BF16). - Record measured vLLM serving throughput at c=1/16/48 in the README Performance section (replaces stale BF16 numbers). - Rewrite the Status section: document the drift pattern with prompt examples, the recipes that were tried and failed (BF16-attn, torch blockwise), and the two-node BF16 experiment queued next. Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2.5-Pro/README.md | 85 +++++++++---------- .../models/MiMo-V2.5-Pro/perf_test/0_setup.sh | 8 +- .../MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh | 19 +++-- .../perf_test/run_bench_single.sh | 6 +- .../MiMo-V2.5-Pro/perf_test/sanity_check.sh | 2 +- .../perf_test/smoke_compile_mimo_v2.py | 15 ++-- .../perf_test/smoke_generate_mimo_v2.py | 5 +- .../perf_test/vllm-neuron-patch.patch | 20 +++-- 8 files changed, 79 insertions(+), 81 deletions(-) diff --git a/contrib/models/MiMo-V2.5-Pro/README.md b/contrib/models/MiMo-V2.5-Pro/README.md index 1b35c567..7078e480 100644 --- a/contrib/models/MiMo-V2.5-Pro/README.md +++ b/contrib/models/MiMo-V2.5-Pro/README.md @@ -37,16 +37,21 @@ Key features: ## Status (work-in-progress) -**This port compiles cleanly and loads on Trn2 but does not yet produce coherent output.** Current symptoms under the default recipe (`tp_degree=64, moe_tp_degree=1, moe_ep_degree=64, batch_size=48, seq_len=1024`) on 2026-04-28: +**This port compiles cleanly and serves via vLLM on Trn2, but output quality is not production-ready under the default FP8 recipe.** Sanity checks on 2026-04-29 under `tp_degree=64, moe_tp_degree=1, moe_ep_degree=64, batch_size=48, seq_len=1024`: -- Prefill drifts from token 1: "Explain in one sentence what a transformer neural network is." → `"100% of the time, 100% of the time, ..."` (greedy decode, temperature=0). Decode speed 0.72 tok/s. Note the output is syntactically valid English ("100% of the time" = BPE tokens `15/16/4/315/279/882` = `"1"/"0"/"%"/" of"/" the"/" time"`, all high-frequency) — not a sampling bug: greedy argmax is correctly picking the model's top token, but the logit distribution itself is wrong (output unrelated to the prompt). Same signature as Jim Burtoft's "Flash FP8 → `erotici` repeat" symptom. -- Same failure pattern was observed on the MiMo-V2-Pro port under the same recipe (`"0.0.0.0:8080"` etc.). Root cause identified there appears to be **Neuron's NKI blockwise FP8 compute kernel handling Pro's tight expert-weight distribution** (std ≈ 0.0018, ~10× smaller than Flash's ≈ 0.019). This is NOT an FP8-format problem per se — sglang on H100/H200 runs the exact same OCP FP8 checkpoint and produces correct output, because GPU paths dequantize FP8→BF16 before the matmul. Neuron NKI does FP8 compute directly and seems to lose precision on subnormal-leaning tensors. -- V2.5-Pro's MoE expert weights are byte-identical to V2-Pro (verified layer 1 expert 0 dequant stats match to 6 decimal places on 2026-04-28), so all V2-Pro workarounds remain required (router bias mean-subtract, qkv interleaved split, `_apply_2d_per_channel_fix`, `_apply_blockwise_scale_stride_fix`). +- Prompt-dependent drift. Self-intro style prompts ("introduce yourself in one sentence") return coherent Chinese output; most other prompts collapse to repetition or unrelated text within a few tokens (e.g. 
`"The capital of France is\n# 1000000000000000"`, `"Once upon a time in a small village there lived\n# 0000000000..."`). +- Same failure pattern was observed on the MiMo-V2-Pro port under the same recipe (`"0.0.0.0:8080"`, etc.). Root cause appears to be **Neuron's NKI blockwise FP8 compute kernel handling Pro's tight expert-weight distribution**: Pro's per-expert weight `abs_mean ≈ 0.00124` and blockwise `scale_mean ≈ 2.3e-5` are 4-7× smaller than V2.5 (256 experts) on the same recipe, and the NKI accumulator loses precision compounded across 70 layers. V2.5-Pro's MoE expert weights are byte-identical to V2-Pro (verified layer 1 expert 0 dequant stats match to 6 decimal places on 2026-04-28), so all V2-Pro workarounds remain required (router bias mean-subtract, qkv interleaved split, `_apply_2d_per_channel_fix`, `_apply_blockwise_scale_stride_fix`). +- GPU stacks (sglang on H100/H200) run the exact same OCP FP8 checkpoint correctly, because they dequantize FP8→BF16 before the matmul. Neuron NKI does FP8 compute directly and drifts on subnormal-leaning tensors. - Reference: Jim Burtoft observed similar prompt-dependent FP8 degradation on Flash and his Kimi PR #131 names "blockwise kernel padding produces depressed logits with EP=2 on SDK 2.29; SDK 2.28 recommended". -Other recipes to try (none verified yet on V2.5-Pro): -- `moe_tp_degree=16, moe_ep_degree=4, BS=48` — balances E_local=96 vs HBM. -- `moe_tp_degree=32, moe_ep_degree=2, BS=1` — mirrors Jim's Kimi-K2 PR, but V2-Pro OOM'd by 28MB on load at BS=48; BS=1 hits `NotImplementedError: Selective Loading with Expert parallelism`. +Recipes that were tried and did not resolve the drift (all on 2026-04-28/29): +- **q/k/v_proj dequant to BF16** (compile-time `modules_to_not_convert` includes `q_proj, k_proj, v_proj`): compiled but HBM OOM on load — BF16 attention weights pushed per-rank tensors to 20.93/24 GB, leaving no room for collective DMA rings. +- **`use_torch_block_wise=True`** (PyTorch-fallback blockwise matmul for higher accumulator precision): compile+shard succeeded after ~2 h, but `model.load()` crashed with `status=4 Allocation Failure` — the fallback path raises HBM demand. + +Next experiments queued (none verified yet): +- BF16 weights across two Trn2 instances with cross-node TP/PP (single-instance HBM cannot hold BF16 Pro). +- SDK 2.28 venv test once installed, per Kimi PR #131. +- Selective BF16 only on `gate_up_proj` (smallest MoE scales) while keeping `down_proj` FP8, if HBM allows. Known NxDI limits that constrain recipe choice: - `BS * top_k / num_experts >= 1.0` required when `moe_ep_degree > 1` at decode (else NotImplementedError). With `num_experts=384, top_k=8` this forces `BS >= 48`. @@ -114,8 +119,8 @@ bash contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh bash contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh ``` -The bench script runs two configurations (BS=32 and BS=128, both -`moe_tp_degree=X / moe_ep_degree=Y (see bench script)`) and logs results under +The bench script runs two configurations (BS=48 and BS=128, both +`moe_tp_degree=1 / moe_ep_degree=64`) and logs results under `/tmp/bench_results/mimo_v25_pro/`. 
For a quick `curl` sanity check while the server is up: @@ -182,10 +187,10 @@ neuron_config = MoENeuronConfig( ep_degree=1, # keep outer EP = 1; only MoE-internal EP varies moe_tp_degree=1, moe_ep_degree=64, - batch_size=32, # must be >= num_experts / top_k = 256 / 8 = 32 - max_batch_size=32, + batch_size=48, # must be >= num_experts / top_k = 384 / 8 = 48 + max_batch_size=48, ctx_batch_size=1, - tkg_batch_size=32, + tkg_batch_size=48, seq_len=1024, n_active_tokens=128, torch_dtype=torch.bfloat16, @@ -245,15 +250,15 @@ Both default to the recommended FP8 recipe (`moe_tp=1`, `moe_ep=64`). ### moe_tp_degree = 1, moe_ep_degree = 64 -**Why**: at `moe_tp_degree=64` each rank owns 1/64 of the intermediate dim, which for Flash (MoE intermediate = 2048) is 32 rows — **below the 128-row blockwise scale block**. NxDI's `_setup_for_scale` detects `weight_shape[axis] < block_size` and collapses the per-rank scale dim to 1, losing per-channel FP8 scale granularity. The resulting drift compounds across Flash's 47 MoE layers and manifests as output collapse ("helpful helpful helpful ...") after roughly 30 decode tokens. +**Why**: at `moe_tp_degree=64` each rank owns 1/64 of the intermediate dim, which for MiMo-V2.5-Pro (MoE intermediate = 2048) is 32 rows — **below the 128-row blockwise scale block**. NxDI's `_setup_for_scale` detects `weight_shape[axis] < block_size` and collapses the per-rank scale dim to 1, losing per-channel FP8 scale granularity. The resulting drift compounds across Pro's 69 MoE layers and manifests as output collapse ("helpful helpful helpful ...") after roughly 30 decode tokens. -`moe_tp_degree=1, moe_ep_degree=64` keeps each expert's weights and blockwise scales intact on a single rank (4 experts per rank), which preserves per-channel scale and produces correct output even on long multi-turn prompts. +`moe_tp_degree=1, moe_ep_degree=64` keeps each expert's weights and blockwise scales intact on a single rank (6 experts per rank for Pro's 384 experts), which preserves per-channel scale. On V2.5 (256 experts) this recipe yields coherent output; on V2.5-Pro it still exhibits prompt-dependent drift (see Status). -Intermediate ratios (`moe_tp=32/ep=2` or `moe_tp=16/ep=4`) have been empirically tested and still produce gibberish, so this is the only currently-supported moe_tp/ep combination for MiMo-V2.5-Pro FP8. +Intermediate ratios (`moe_tp=32/ep=2`, `moe_tp=16/ep=4`) have been empirically tested and still produce gibberish, so `moe_tp=1/moe_ep=64` is the only currently-usable moe_tp/ep combination. -### batch_size >= 32 +### batch_size >= 48 -NxDI's TKG (token generation) path refuses Expert Parallelism when `batch_size < num_experts / top_k`. For Flash that is 256 / 8 = 32, so the smallest working BS on the FP8 path is 32. BS=1 latency demos are not currently possible on FP8; use the BF16 checkpoint with `moe_tp=64, moe_ep=1, batch_size=1` for single-stream latency measurements. +NxDI's TKG (token generation) path refuses Expert Parallelism when `batch_size < num_experts / top_k`. For Pro that is 384 / 8 = 48, so the smallest working BS on the FP8 path is 48. BS=1 latency demos are not possible on FP8; use the BF16 checkpoint with `moe_tp=64, moe_ep=1, batch_size=1` for single-stream latency measurements. 
### outer ep_degree = 1 @@ -272,7 +277,7 @@ MiMo-V2.5-Pro can be served via [vllm-neuron](https://github.com/aws-neuron/vllm bash contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh ``` -The patch (`perf_test/vllm-neuron-patch.patch`) is 40 lines and only touches `vllm_neuron/__init__.py`. It adds a `_register_contrib_models()` hook that, when `NXDI_CONTRIB_MIMO_V2_FLASH_SRC` is set, registers `NeuronMiMoV2ForCausalLM` into NxDI's `MODEL_TYPES` under the key `mimo_v2` **and** registers the `MiMoV2ForCausalLM` architecture into vLLM's `ModelRegistry`. No upstream vLLM or NxDI source is modified. +The patch (`perf_test/vllm-neuron-patch.patch`) touches `vllm_neuron/worker/neuronx_distributed_model_loader.py`. It adds a `_register_contrib_models()` hook that, when `NXDI_CONTRIB_MIMO_V2_FLASH_SRC` is set, registers `NeuronMiMoV2ForCausalLM` into NxDI's `MODEL_TYPES` under keys `mimov2flash` **and** `mimov2pro`, **and** overrides vLLM's built-in `MiMoV2FlashForCausalLM` / `MiMoV2ProForCausalLM` (GPU-only stubs) in `ModelRegistry` with the Neuron wrapper so ModelConfig validation accepts either architecture. No upstream vLLM or NxDI source is modified. The checkpoint's `config.json` must set `architectures` to `["MiMoV2ProForCausalLM"]` (or `MiMoV2FlashForCausalLM` for V2.5); the preprocess script takes care of this. ### Serving (FP8, recommended) @@ -289,7 +294,7 @@ python3 -m vllm.entrypoints.openai.api_server \ --model "$MIMO_V2_FLASH_PATH" \ --tensor-parallel-size 64 \ --max-model-len 1024 \ - --max-num-seqs 32 \ + --max-num-seqs 48 \ --no-enable-chunked-prefill \ --no-enable-prefix-caching \ --trust_remote_code \ @@ -313,9 +318,9 @@ python3 -m vllm.entrypoints.openai.api_server \ "blockwise_matmul_config": {"use_shard_on_block_dynamic_while": true, "block_sharding_strategy": "PING_PONG"}, "moe_tp_degree": 1, "moe_ep_degree": 64, - "batch_size": 32, + "batch_size": 48, "ctx_batch_size": 1, - "tkg_batch_size": 32, + "tkg_batch_size": 48, "max_context_length": 1024, "seq_len": 1024, "is_continuous_batching": true, @@ -330,39 +335,33 @@ python3 -m vllm.entrypoints.openai.api_server \ }' ``` -See `perf_test/bench_mimo_v2.sh` for the full benchmark recipe at BS=32 and BS=128. +See `perf_test/bench_mimo_v2.sh` for the full benchmark recipe at BS=48 and BS=128. ### vllm-neuron patch summary The patch is applied to vllm-neuron 0.5.0 and: -- Maps the `MiMoV2ForCausalLM` architecture to Flash's model loader (reusing the Qwen2-family loader path, which Flash's tokenizer inherits from). -- Passes `hf_config` from vLLM into `load_pretrained_config` so NxDI does not re-load the config without `trust_remote_code=True`. -- Replaces vllm-neuron's internal `AutoModelForCausalLM.from_pretrained` call with `huggingface_hub.snapshot_download`, which is the only path that works for `trust_remote_code=True` models when no GPU is available for HF's CUDA-gated FP8 quantizer. +- Patches `AutoConfig.from_pretrained` to default `trust_remote_code=True` so NxDI's `hf_adapter.load_config` can load the `MiMoV2Config` custom code that ships with the checkpoint. +- Registers `NeuronMiMoV2ForCausalLM` into NxDI's `MODEL_TYPES` under `mimov2flash` and `mimov2pro` so the NxDI loader resolves either model_type to the contrib Neuron wrapper. +- Overrides vLLM's built-in `MiMoV2FlashForCausalLM` and `MiMoV2ProForCausalLM` GPU stubs in `ModelRegistry`, since vLLM's ModelConfig validator rejects any architecture not in its registry and the Neuron path never instantiates vLLM's stub class anyway. 
## Performance -> These numbers are from the earlier BF16 recipe (pre-FP8 rollout). FP8 numbers will be added once a stable bench run completes on the new recipe; preliminary single-stream qualitative tests show fluent multi-sentence output on long Chinese chat prompts with `moe_tp=1, moe_ep=64, batch_size=32`. - -### Standalone NxDI (trn2.48xlarge, BF16, TP=64, EP=64) +> The throughput numbers below are from a working vLLM server run on 2026-04-29 under the recommended FP8 recipe. Output quality under this recipe is **not production-usable** (see Status); the numbers show that the serving infrastructure runs end-to-end, not that the model answers correctly. -| Batch Size | Throughput (tok/s) | -|------------|-------------------| -| 1 | 29.92 | -| 8 | 215.94 | -| 32 | 649.14 | +### vLLM Serving (trn2.48xlarge, FP8, BS=48, TP=64, moe_tp=1/moe_ep=64, CB + bucketing) -### vLLM Serving (trn2.48xlarge, BF16, BS=32, TP=64/EP=64, CB) +Input/output: 900/90 tokens (`vllm bench serve --dataset-name random`), `on_device_sampling_config={do_sample:true, temperature:0.6, top_k:20, top_p:0.95}`. -Input/output: 900/90 tokens (random dataset) +| Concurrency | Total tok/s | Output tok/s | TTFT median (ms) | TTFT P99 (ms) | TPOT median (ms) | +|-------------|-------------|--------------|------------------|---------------|------------------| +| 1 | 47 | 4.3 | 1,392 | 1,393 | 220 | +| 16 | 391 | 35.6 | 2,361 | 17,394 | 422 | +| 48 | 606 | 55 | 7,322 | 54,413 | 752 | -| Concurrency | Throughput (tok/s) | TPOT (ms) | TTFT (ms) | -|-------------|-------------------|-----------|-----------| -| 1 | 27.98 | 33.65 | 222 | -| 16 | 224.57 | 64.95 | 570 | -| 32 | 302.61 | 90.23 | 1351 | +Per-stream ITL median holds at ~220 ms across all concurrency levels; TPOT/TTFT growth at higher concurrency comes from continuous-batching queue pressure, not per-step compute. -> **Compile time:** the first Flash compile on SDK 2.29 is ~30-60 minutes for the TKG NEFF and similar for the CTE NEFF. Subsequent runs with the same `override_neuron_config` hit the neuronx-cc cache and start in ~1-2 minutes. `save_sharded_checkpoint=true` additionally persists per-rank FP8 shards under `/weights/`, letting future `load()` calls skip the ~10-minute shard_checkpoint pass. +> **Compile time:** the first Pro compile on SDK 2.29 is ~60 minutes for the TKG NEFF and ~15 minutes for the CTE NEFF; subsequent runs with the same `override_neuron_config` hit the neuronx-cc cache and start in ~1-2 minutes. `save_sharded_checkpoint=true` additionally persists per-rank FP8 shards under `/weights/`, letting future `load()` calls skip the ~10-minute shard_checkpoint pass. First full server launch (compile + shard + warmup) is ~2 hours wall-clock. ## Compatibility Matrix @@ -385,7 +384,7 @@ pytest contrib/models/MiMo-V2.5-Pro/test/integration/test_model.py -v 3. **Attention Sink Bias**: Learnable per-head bias added as an extra "sink" column to attention scores in sliding window layers (not added in full-attention layers). Per-rank slicing of the bias happens inside `forward()` based on `parallel_state.get_tensor_model_parallel_rank()`. 4. **FP8 Path Caveats**: - Must use `moe_tp_degree=1, moe_ep_degree=64` (see "FP8 Configuration Notes" above). - - Must use `batch_size >= 32` (NxDI EP>1 requirement). + - Must use `batch_size >= 48` (NxDI EP>1 requirement, `384 / 8 = 48`). - Must keep outer `ep_degree=1` (only `moe_ep_degree` should vary). 
- Several runtime monkey-patches (router bias, blockwise scale stride, 2D per-channel, EP scale handling) are installed automatically in `NeuronMiMoV2ForCausalLM.__init__` when `quantized=True`; the BF16 path is untouched. @@ -397,4 +396,4 @@ pytest contrib/models/MiMo-V2.5-Pro/test/integration/test_model.py -v Henan Wan (whn09) -**Last Updated:** 2026-04-25 +**Last Updated:** 2026-04-29 diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh index 5d43c529..bdcbe11c 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh @@ -39,15 +39,15 @@ pip install s5cmd python3 -c "import vllm_neuron; print('vllm-neuron installed:', vllm_neuron.__file__)" echo "" -echo "[2/2] Downloading MiMo-V2.5-Pro BF16 weights..." +echo "[2/2] Downloading MiMo-V2.5-Pro Neuron-FP8 weights..." -MIMO_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-BF16}" +MIMO_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8}" if [ -d "$MIMO_PATH" ] && [ "$(ls "$MIMO_PATH"/*.safetensors 2>/dev/null | wc -l)" -gt 0 ]; then echo " MiMo weights already exist at $MIMO_PATH, skipping download" else - echo " Downloading BF16 weights from your S3 bucket (edit the URI if needed)..." + echo " Downloading Neuron-FP8 weights from your S3 bucket (edit the URI if needed)..." mkdir -p "$MIMO_PATH" - s5cmd cp "s3://datalab/xiaomi/models/MiMo-V2.5-Pro-BF16/**" "$MIMO_PATH/" + s5cmd cp "s3://datalab/xiaomi/models/MiMo-V2.5-Pro-Neuron-FP8/**" "$MIMO_PATH/" echo " Download complete: $(du -sh $MIMO_PATH | cut -f1)" fi diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh index 69c7d417..45c38be1 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh @@ -15,7 +15,7 @@ set -e # (4 experts per rank), which preserves the blockwise scale intact. # # NxDI's TKG path refuses Expert Parallelism with BS < num_experts/top_k -# (256 / 8 = 32 for Flash), so the smallest working batch size here is 32. +# (384 / 8 = 48 for V2.5-Pro), so the smallest working batch size here is 48. # If you want BS=1 behaviour, the FP8 path is not currently supported on # this model on Trn2 — use the BF16 checkpoint with the old bench recipe # (`moe_tp_degree=64, moe_ep_degree=1, batch_size=1`). @@ -132,18 +132,19 @@ echo "Results: $RESULTS_DIR" echo "" ############################################################################### -# Config 1: BS=32, TP=64 + moe_tp=1/moe_ep=64, CB + bucketing (smallest BS -# that satisfies NxDI's Expert-Parallel BS >= num_experts/top_k requirement). +# Config 1: BS=48, TP=64 + moe_tp=1/moe_ep=64, CB + bucketing (smallest BS +# that satisfies NxDI's Expert-Parallel BS >= num_experts/top_k requirement: +# 384 / 8 = 48). 
############################################################################### -CONFIG_NAME="bs32_tp64_moetp1_ep64" -echo "--- Config 1: BS=32, moe_tp=1/moe_ep=64, CB + bucketing ---" +CONFIG_NAME="bs48_tp64_moetp1_ep64" +echo "--- Config 1: BS=48, moe_tp=1/moe_ep=64, CB + bucketing ---" python3 -m vllm.entrypoints.openai.api_server \ --model "$MODEL_PATH" \ --tokenizer "$MODEL_PATH" \ --tensor-parallel-size 64 \ --max-model-len 1024 \ - --max-num-seqs 32 \ + --max-num-seqs 48 \ --no-enable-chunked-prefill \ --no-enable-prefix-caching \ --port $PORT \ @@ -153,9 +154,9 @@ python3 -m vllm.entrypoints.openai.api_server \ '"$COMMON_MIMO_CONFIG"', "moe_tp_degree": 1, "moe_ep_degree": 64, - "batch_size": 32, + "batch_size": 48, "ctx_batch_size": 1, - "tkg_batch_size": 32, + "tkg_batch_size": 48, "max_context_length": 1024, "seq_len": 1024, "is_continuous_batching": true, @@ -173,7 +174,7 @@ wait_for_server sanity_check run_bench "$CONFIG_NAME" 1 16 run_bench "$CONFIG_NAME" 16 128 -run_bench "$CONFIG_NAME" 32 128 +run_bench "$CONFIG_NAME" 48 192 stop_server ############################################################################### diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh index 45729cd2..81f978bb 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh @@ -13,8 +13,8 @@ # # Environment knobs: # PORT vLLM server port (default 8000) -# MIMO_V2_FLASH_PATH Path to the BF16 checkpoint (default -# /opt/dlami/nvme/models/MiMo-V2.5-Pro-BF16) +# MIMO_V2_FLASH_PATH Path to the Neuron-FP8 checkpoint (default +# /opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8) # CONCURRENCY --max-concurrency (default 1) # NUM_PROMPTS --num-prompts (default 16) # INPUT_LEN --random-input-len (default 900) @@ -27,7 +27,7 @@ set -e source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate -MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-BF16}" +MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8}" PORT="${PORT:-8000}" CONCURRENCY="${CONCURRENCY:-1}" NUM_PROMPTS="${NUM_PROMPTS:-16}" diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh index 9cef80ab..684211f4 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh @@ -11,7 +11,7 @@ set -e -MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-BF16}" +MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8}" PORT="${PORT:-8000}" PROMPT="${PROMPT:-What is 1+1? Answer briefly.}" MAX_TOKENS="${MAX_TOKENS:-64}" diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py index be7f72e7..a0933f7f 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py @@ -124,17 +124,12 @@ def main(): moe_tp_degree=MOE_TP, context_encoding_buckets=[SEQ_LEN], router_config={"act_fn": "sigmoid", "dtype": "float32"}, - # V2.5-Pro experiment: use PyTorch-fallback blockwise matmul instead - # of the NKI _call_shard_hidden_kernel path. 
Rationale: Pro's MoE - # expert weight scales run ~3-7x smaller than V2.5 (scale_mean - # ~5e-5 vs 2.5e-4); the NKI blockwise kernel's lower accumulator - # precision compounds across 70 layers into decode-time gibberish. - # PyTorch fallback dequantizes each block to BF16 before matmul, - # trading throughput for accumulator precision. + # SDK 2.29 ships only bwmm_shard_on_block / bwmm_shard_on_intermediate; + # default routes to _call_shard_hidden_kernel which is missing, so we + # take the shard-on-block path via this flag. Matches Flash + Kimi. blockwise_matmul_config={ - "use_torch_block_wise": True, - "use_shard_on_intermediate_dynamic_while": True, - "skip_dma_token": True, + "use_shard_on_block_dynamic_while": True, + "block_sharding_strategy": "PING_PONG", }, # Persist sharded FP8 weights to disk so subsequent load()s skip the # ~10-minute shard_checkpoint step (writes weights/tp{0..63}_*.safetensors diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py index c3388e89..9f678a09 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py @@ -99,9 +99,8 @@ def main(): context_encoding_buckets=[SEQ_LEN], router_config={"act_fn": "sigmoid", "dtype": "float32"}, blockwise_matmul_config={ - "use_torch_block_wise": True, - "use_shard_on_intermediate_dynamic_while": True, - "skip_dma_token": True, + "use_shard_on_block_dynamic_while": True, + "block_sharding_strategy": "PING_PONG", }, save_sharded_checkpoint=True, quantized=True, diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/vllm-neuron-patch.patch b/contrib/models/MiMo-V2.5-Pro/perf_test/vllm-neuron-patch.patch index d8a85b89..a29814d4 100644 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/vllm-neuron-patch.patch +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/vllm-neuron-patch.patch @@ -61,20 +61,24 @@ index d2099eb..0c162e4 100644 + _sys.path.insert(0, mimo_src) + try: + from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM -+ MODEL_TYPES.setdefault( -+ "mimov2flash", {"causal-lm": NeuronMiMoV2ForCausalLM} -+ ) ++ # Register under both Flash and Pro model_type keys so the same ++ # NxDI wrapper serves MiMo-V2-Flash, MiMo-V2.5-Pro, and any sibling ++ # that inherits the same config. ++ for _mt in ("mimov2flash", "mimov2pro"): ++ MODEL_TYPES.setdefault(_mt, {"causal-lm": NeuronMiMoV2ForCausalLM}) + try: + from vllm.model_executor.models.registry import ModelRegistry -+ if "MiMoV2ForCausalLM" not in ModelRegistry.get_supported_archs(): -+ ModelRegistry.register_model( -+ "MiMoV2ForCausalLM", NeuronMiMoV2ForCausalLM -+ ) ++ # Override vLLM's GPU-only MiMoV2* stubs with our Neuron ++ # wrapper so ModelConfig validation accepts the architecture ++ # regardless of whether the checkpoint calls itself Flash or ++ # Pro. ++ for _arch in ("MiMoV2FlashForCausalLM", "MiMoV2ProForCausalLM"): ++ ModelRegistry.register_model(_arch, NeuronMiMoV2ForCausalLM) + except ImportError: + pass + except Exception as e: + _w.warn( -+ f"Failed to register MiMo-V2.5-Pro contrib model: {e}", ++ f"Failed to register MiMo-V2 contrib model: {e}", + category=UserWarning, + ) + From bbb1e1fd187ed6c0a855c1f2e7d053e8750f7bc4 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 29 Apr 2026 10:42:43 +0800 Subject: [PATCH 11/24] [contrib] MiMo-V2.5-Pro: fix physical NC count (128, not 32) trn2.48xlarge has 16 Trainium2 chips x 8 cores = 128 physical NeuronCores. 
logical_nc_config=2 halves that to 64 logical cores, which matches tp_degree=64. Previous Prerequisites line said "32 NeuronCores" which is wrong. Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2.5-Pro/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contrib/models/MiMo-V2.5-Pro/README.md b/contrib/models/MiMo-V2.5-Pro/README.md index 7078e480..dcf2c30e 100644 --- a/contrib/models/MiMo-V2.5-Pro/README.md +++ b/contrib/models/MiMo-V2.5-Pro/README.md @@ -59,7 +59,7 @@ Known NxDI limits that constrain recipe choice: ## Prerequisites -- **Instance**: trn2.48xlarge (32 NeuronCores, logical_nc_config=2 → 64 logical cores) +- **Instance**: trn2.48xlarge (128 physical NeuronCores, logical_nc_config=2 → 64 logical cores) - **Neuron SDK**: 2.29 (Python 3.12, PyTorch 2.9) - **Venvs**: `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference` (for preprocess + NxDI direct smoke), `/opt/aws_neuronx_venv_pytorch_inference_vllm_0_16` (for vLLM serving). Both ship with the DLAMI. - **Disk**: ~3 TB free under `/opt/dlami/nvme` (the HF FP8 checkpoint is ~962 GB, the Neuron-FP8 preprocessed output is ~1 TB, and `save_sharded_checkpoint=true` writes another ~300-1000 GB per compiled config (varies with recipe)). From d0eb413f5da79b5d3cadd78a43f977bb68432442 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 29 Apr 2026 11:42:15 +0800 Subject: [PATCH 12/24] [contrib] MiMo-V2.5-Pro: split vLLM launcher into start/bench/sanity trio Mirror the V2.5 structure so Pro has: - start_vllm_server.sh (new): foreground launcher baking in the full override_neuron_config, persistent NEURON_COMPILED_ARTIFACTS path, and all env-var plumbing. Stays up for ad-hoc curl/sanity. - bench_mimo_v2.sh: rewritten as a one-shot composer (start_vllm_server in background + wait + sanity + run_bench_single at c=1/16/48). Replaces the old inline-launch-with-full-JSON version (~110 lines shorter). - run_bench_single.sh: default CONFIG_NAME/RESULTS_DIR brought in line with bench_mimo_v2.sh and the V2.5 port. README: - Add "Keeping a server up for ad-hoc testing" section and an Environment variables table (NXDI_CONTRIB_MIMO_V2_FLASH_SRC, NEURON_COMPILED_ARTIFACTS, BASE_COMPILE_WORK_DIR, etc.). - Replace the ~60-line inline vllm api_server invocation with pointers to start_vllm_server.sh / bench_mimo_v2.sh; the README no longer duplicates the config that lives in the scripts. - Fix "downloads Flash weights" text in the 0_setup.sh blurb (now downloads Pro Neuron-FP8 weights). - Bench results dir default moved to /opt/dlami/nvme/logs/bench_results/mimo_v2_5_pro/ to align with V2.5. Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2.5-Pro/README.md | 121 +++++----- .../MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh | 228 ++++-------------- .../perf_test/run_bench_single.sh | 9 +- .../perf_test/start_vllm_server.sh | 100 ++++++++ 4 files changed, 207 insertions(+), 251 deletions(-) create mode 100644 contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh diff --git a/contrib/models/MiMo-V2.5-Pro/README.md b/contrib/models/MiMo-V2.5-Pro/README.md index dcf2c30e..b0f371e4 100644 --- a/contrib/models/MiMo-V2.5-Pro/README.md +++ b/contrib/models/MiMo-V2.5-Pro/README.md @@ -119,9 +119,52 @@ bash contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh bash contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh ``` -The bench script runs two configurations (BS=48 and BS=128, both -`moe_tp_degree=1 / moe_ep_degree=64`) and logs results under -`/tmp/bench_results/mimo_v25_pro/`. 
+The bench script runs one configuration (BS=48, +`moe_tp_degree=1 / moe_ep_degree=64`) at three concurrency levels (1, 16, 48) +and logs results under `/opt/dlami/nvme/logs/bench_results/mimo_v2_5_pro/`. + +### Keeping a server up for ad-hoc testing + +`bench_mimo_v2.sh` is a one-shot wrapper (launch server → sanity → +3 bench runs → teardown). If you want a long-running server to iterate +against, use the three underlying scripts separately: + +```bash +# Terminal 1: launch the server in the foreground (Ctrl-C to stop). +bash contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh + +# Terminal 2: once "Application startup complete." prints, sanity-check: +bash contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh + +# Run a single bench pass with a chosen concurrency: +CONCURRENCY=16 NUM_PROMPTS=128 \ + bash contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh +``` + +`bench_mimo_v2.sh` composes exactly these three pieces; use whichever +is more convenient. + +### Environment variables + +`0_setup.sh` prints these at the end; setting them explicitly makes the +smoke / bench / manual-launch paths all behave the same. All of them have +sensible defaults in the scripts — export them only if you want to +override or if you plan to launch vLLM outside of `bench_mimo_v2.sh`. + +**Required (at least for manual `vllm api_server` launches):** + +| Variable | Purpose | +|---|---| +| `NXDI_CONTRIB_MIMO_V2_FLASH_SRC` | Path to `contrib/models/MiMo-V2.5-Pro/src/`. `vllm-neuron`'s registration hook reads it to plug `NeuronMiMoV2ForCausalLM` into NxDI's `MODEL_TYPES` table. The `_FLASH_` suffix is kept for backward compatibility with the shared registration hook that also serves V2-Flash and V2.5. | +| `MIMO_V2_FLASH_PATH` | Preprocessed Neuron-FP8 checkpoint dir (the `--save_path` output from preprocess). Same naming rationale as above. | + +**Optional (recommended):** + +| Variable | Default | Purpose | +|---|---|---| +| `NEURON_COMPILED_ARTIFACTS` | `/opt/dlami/nvme/compiled/mimo_v2_5_pro_bs48_moetp1_ep64_fp8_vllm` | Where vLLM writes the NEFF + per-rank sharded weights. Default points at a persistent path under `/opt/dlami/nvme/compiled/` so multiple configs don't collide and runs after the nightly reboot can reuse the sharded weights. vLLM's fallback is `/neuron-compiled-artifacts//` which buries output inside the checkpoint dir. | +| `BASE_COMPILE_WORK_DIR` | `/opt/dlami/nvme/tmp/nxd_model/` | NxDI's HLO / NEFF staging workdir. Default is `/tmp/nxd_model/`, which is wiped by the nightly Trn2 reboot and can silently corrupt parallel compiles that share a basename; the pinned value lives on persistent storage and is unique per config. | +| `VLLM_ENGINE_READY_TIMEOUT_S` | `7200` | First-time compile of Pro's 384-expert MoE is ~60 min TKG + ~15 min CTE + ~30 min shard, well past vLLM's default. | For a quick `curl` sanity check while the server is up: @@ -133,11 +176,9 @@ curl -s http://localhost:8000/v1/chat/completions \ "max_tokens": 64, "temperature": 0.0}' | python3 -m json.tool ``` -If you get fluent sentence-ending output on a 30+ token generation, the -FP8 path is working correctly. If you see repetition collapse -("helpful helpful helpful..."), double-check that `moe_tp_degree=1`, -`moe_ep_degree=64`, `batch_size>=32`, and that you are loading the -preprocessed Neuron-FP8 checkpoint (not the raw HF FP8 directory). +Output quality is currently prompt-dependent under the FP8 recipe (see +Status). 
A successful sanity check confirms the serving path works; it +does not yet confirm that all prompts produce coherent text. ## Checkpoint Preparation @@ -272,8 +313,8 @@ MiMo-V2.5-Pro can be served via [vllm-neuron](https://github.com/aws-neuron/vllm ```bash # The setup script clones vllm-project/vllm-neuron at release-0.5.0, applies -# the contrib registration patch, installs it editable, and downloads Flash -# weights (BF16 by default; set MIMO_V2_FLASH_PATH to override). +# the contrib registration patch, installs it editable, and downloads +# Pro Neuron-FP8 weights from S3 (set MIMO_V2_FLASH_PATH to override). bash contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh ``` @@ -281,61 +322,17 @@ The patch (`perf_test/vllm-neuron-patch.patch`) touches `vllm_neuron/worker/neur ### Serving (FP8, recommended) +Use `perf_test/start_vllm_server.sh` for a foreground launch (stays up until Ctrl-C), or `perf_test/bench_mimo_v2.sh` for the one-shot launch → sanity → bench → teardown flow. Both scripts bake in the full `override_neuron_config` (TP=64, moe_tp=1, moe_ep=64, BS=48, CB + bucketing, blockwise FP8 MoE with `PING_PONG`, on-device sampling), the required env vars, and the persistent compile-artifact path. See "Keeping a server up for ad-hoc testing" above for the three-terminal workflow. + ```bash -export NXDI_CONTRIB_MIMO_V2_FLASH_SRC=/path/to/neuronx-distributed-inference/contrib/models/MiMo-V2.5-Pro/src -export MIMO_V2_FLASH_PATH=/path/to/MiMo-V2.5-Pro-Neuron-FP8 -# First-time compile of Flash's 256-expert MoE takes 30-60 minutes. -export VLLM_ENGINE_READY_TIMEOUT_S=7200 -# Optional: isolate compile cache per config so parallel Flash/Pro/etc. compiles -# don't race on the default /var/tmp/neuron-compile-cache lock files. -export NEURON_COMPILED_ARTIFACTS=/path/to/compiled/mimo_v25_pro_bs32_moetp1_ep64_fp8 - -python3 -m vllm.entrypoints.openai.api_server \ - --model "$MIMO_V2_FLASH_PATH" \ - --tensor-parallel-size 64 \ - --max-model-len 1024 \ - --max-num-seqs 48 \ - --no-enable-chunked-prefill \ - --no-enable-prefix-caching \ - --trust_remote_code \ - --additional-config '{ - "override_neuron_config": { - "tp_degree": 64, - "logical_nc_config": 2, - "fused_qkv": false, - "sequence_parallel_enabled": false, - "glu_mlp": true, - "normalize_top_k_affinities": true, - "save_sharded_checkpoint": true, - "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, - "quantized": true, - "quantized_checkpoints_path": "/path/to/MiMo-V2.5-Pro-Neuron-FP8", - "quantization_dtype": "f8e4m3", - "quantization_type": "blockwise_symmetric", - "quantization_block_axis": [1, 2], - "quantization_block_size": [128, 128], - "modules_to_not_convert": ["embed_tokens", "lm_head", "norm", "router", "o_proj"], - "blockwise_matmul_config": {"use_shard_on_block_dynamic_while": true, "block_sharding_strategy": "PING_PONG"}, - "moe_tp_degree": 1, - "moe_ep_degree": 64, - "batch_size": 48, - "ctx_batch_size": 1, - "tkg_batch_size": 48, - "max_context_length": 1024, - "seq_len": 1024, - "is_continuous_batching": true, - "enable_bucketing": true, - "context_encoding_buckets": [1024], - "token_generation_buckets": [1024], - "async_mode": true, - "on_device_sampling_config": { - "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 - } - } - }' +# One-shot launch + bench + teardown (~2 h on cold cache, ~5 min on warm cache). 
+bash contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh + +# Or keep the server up for interactive work: +bash contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh ``` -See `perf_test/bench_mimo_v2.sh` for the full benchmark recipe at BS=48 and BS=128. +See "Environment variables" above for all the knobs (`NEURON_COMPILED_ARTIFACTS`, `BASE_COMPILE_WORK_DIR`, etc.) and their defaults. ### vllm-neuron patch summary diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh index 45c38be1..23c3f59d 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/bench_mimo_v2.sh @@ -1,77 +1,42 @@ #!/bin/bash set -e -# MiMo-V2.5-Pro FP8 vLLM benchmark on Trn2. +# MiMo-V2.5-Pro FP8 vLLM benchmark on Trn2. One-shot wrapper: +# launch server -> sanity check -> bench at c=1,16,48 -> stop server. # -# Requires a Neuron-FP8 preprocessed checkpoint (see -# `src/conversion_script/preprocess_mimo_v2_fp8.py`). The configs below -# all use moe_tp_degree=1 / moe_ep_degree=64 (experts sharded by expert -# parallelism only, no intra-expert TP split) because moe_tp_degree=64 collapses -# the per-rank FP8 blockwise scale to a singleton — per-rank expert -# intermediate is 32 rows, below the 128-row blockwise block, so -# NxDI's `_setup_for_scale` drops per-channel scale granularity. The resulting -# drift compounds across 47 MoE layers and gives repetition / output collapse. -# Using moe_ep_degree=64 keeps all of each expert's weight + scale on one rank -# (4 experts per rank), which preserves the blockwise scale intact. +# This script composes three building blocks in perf_test/: +# start_vllm_server.sh - server launch + env-var setup (backgrounded here) +# sanity_check.sh - one-shot curl against the running server +# run_bench_single.sh - one concurrency level of `vllm bench serve` # -# NxDI's TKG path refuses Expert Parallelism with BS < num_experts/top_k -# (384 / 8 = 48 for V2.5-Pro), so the smallest working batch size here is 48. -# If you want BS=1 behaviour, the FP8 path is not currently supported on -# this model on Trn2 — use the BF16 checkpoint with the old bench recipe -# (`moe_tp_degree=64, moe_ep_degree=1, batch_size=1`). - -source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate - -MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8}" -# The NxDI contrib MiMo-V2.5-Pro modeling code is registered into vLLM / -# NxDI lookup tables by vllm-neuron's register() hook using this env var. -# Default to this contrib package's own src/ relative to the script. -: "${NXDI_CONTRIB_MIMO_V2_FLASH_SRC:=$(cd "$(dirname "$0")/.." && pwd)/src}" -export NXDI_CONTRIB_MIMO_V2_FLASH_SRC - -# First-time Flash FP8 compile takes 30-60 minutes; extend vLLM's ready -# timeout and the compiler's environment variables for FP8 numerics. -export VLLM_ENGINE_READY_TIMEOUT_S=7200 +# Use those directly if you want to keep a long-running server and iterate +# on bench parameters from another shell. +# +# Server recipe: TP=64, moe_tp=1/moe_ep=64, BS=48, continuous batching. +# BS=48 is the smallest working batch size on the FP8 path (NxDI's TKG +# path refuses Expert Parallelism with BS < num_experts/top_k = 384/8 = 48). +# BS=1 single-stream latency demos are not currently supported on Pro FP8. 
+ +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PORT="${PORT:-8000}" +RESULTS_DIR="${RESULTS_DIR:-/opt/dlami/nvme/logs/bench_results/mimo_v2_5_pro}" +CONFIG_NAME="bs48_tp64_moetp1_ep64" -PORT=8000 -RESULTS_DIR="/tmp/bench_results/mimo_v25_pro" mkdir -p "$RESULTS_DIR" -# Common neuron config shared across all MiMo-V2.5-Pro FP8 configs. -# save_sharded_checkpoint=true persists per-rank sharded weights to -# /weights/tp{N}_sharded_checkpoint.safetensors during compile; -# load() then reads those directly (~30s) instead of re-sharding the entire -# checkpoint on every vllm-neuron startup (~10+ min). -COMMON_MIMO_CONFIG='"tp_degree": 64, - "logical_nc_config": 2, - "fused_qkv": false, - "sequence_parallel_enabled": false, - "glu_mlp": true, - "normalize_top_k_affinities": true, - "save_sharded_checkpoint": true, - "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, - "quantized": true, - "quantized_checkpoints_path": "'"$MODEL_PATH"'", - "quantization_dtype": "f8e4m3", - "quantization_type": "blockwise_symmetric", - "quantization_block_axis": [1, 2], - "quantization_block_size": [128, 128], - "modules_to_not_convert": ["embed_tokens", "lm_head", "norm", "router", "o_proj"], - "blockwise_matmul_config": {"use_shard_on_block_dynamic_while": true, "block_sharding_strategy": "PING_PONG"}' - -# Helper: wait for vLLM server to be ready. First-time compilation of a -# 256-expert MoE model takes 30-90 minutes, so we poll for up to 2 hours. +# Wait for vLLM server to be ready. First-time compile of the 384-expert +# MoE model takes ~90 min and can stretch past 2 h under contention, so +# poll for up to 2 h. wait_for_server() { - echo " Waiting for vLLM server to be ready (up to 2h for first compile)..." + echo " Waiting for vLLM server on port $PORT (up to 2 h for first compile)..." local interval=10 - local max_attempts=720 # 720 * 10s = 7200s = 2h + local max_attempts=720 local start=$SECONDS for i in $(seq 1 $max_attempts); do - if curl -s http://localhost:$PORT/health > /dev/null 2>&1; then - echo " Server ready! (waited $((SECONDS - start))s)" + if curl -s "http://localhost:$PORT/health" > /dev/null 2>&1; then + echo " Server ready after $((SECONDS - start))s." return 0 fi - # Show a progress blip every minute so the user knows we're alive if [ $((i % 6)) -eq 0 ]; then echo " ...still waiting ($((SECONDS - start))s elapsed)" fi @@ -81,149 +46,42 @@ wait_for_server() { return 1 } -# Helper: run benchmark -run_bench() { - local config_name=$1 - local concurrency=$2 - local num_prompts=$3 - - echo " Benchmark: concurrency=$concurrency, prompts=$num_prompts" - vllm bench serve \ - --backend vllm \ - --model "$MODEL_PATH" \ - --tokenizer "$MODEL_PATH" \ - --endpoint /v1/completions \ - --dataset-name random \ - --num-prompts "$num_prompts" \ - --random-input-len 900 \ - --random-output-len 90 \ - --random-range-ratio 0.03 \ - --max-concurrency "$concurrency" \ - 2>&1 | tee "$RESULTS_DIR/${config_name}_c${concurrency}.txt" - echo "" -} - -# Helper: stop server stop_server() { echo " Stopping vLLM server..." pkill -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true sleep 5 } -# Helper: quick sanity check -sanity_check() { - echo " Running sanity check..." - curl -s http://localhost:$PORT/v1/chat/completions \ - -H 'Content-Type: application/json' \ - -d '{ - "messages": [{"role": "user", "content": "What is 1+1? 
Answer briefly."}], - "model": "'"$MODEL_PATH"'", - "max_tokens": 64, - "temperature": 0.0, - "stream": false - }' | python3 -c "import sys,json; r=json.load(sys.stdin); print(' Sanity:', r['choices'][0]['message']['content'][:100])" 2>/dev/null || echo " Sanity check: could not parse response" -} - echo "==========================================" echo "MiMo-V2.5-Pro FP8 Performance Benchmark" echo "==========================================" -echo "Model: $MODEL_PATH" +echo "Port: $PORT" echo "Results: $RESULTS_DIR" echo "" -############################################################################### -# Config 1: BS=48, TP=64 + moe_tp=1/moe_ep=64, CB + bucketing (smallest BS -# that satisfies NxDI's Expert-Parallel BS >= num_experts/top_k requirement: -# 384 / 8 = 48). -############################################################################### -CONFIG_NAME="bs48_tp64_moetp1_ep64" -echo "--- Config 1: BS=48, moe_tp=1/moe_ep=64, CB + bucketing ---" - -python3 -m vllm.entrypoints.openai.api_server \ - --model "$MODEL_PATH" \ - --tokenizer "$MODEL_PATH" \ - --tensor-parallel-size 64 \ - --max-model-len 1024 \ - --max-num-seqs 48 \ - --no-enable-chunked-prefill \ - --no-enable-prefix-caching \ - --port $PORT \ - --trust_remote_code \ - --additional-config '{ - "override_neuron_config": { - '"$COMMON_MIMO_CONFIG"', - "moe_tp_degree": 1, - "moe_ep_degree": 64, - "batch_size": 48, - "ctx_batch_size": 1, - "tkg_batch_size": 48, - "max_context_length": 1024, - "seq_len": 1024, - "is_continuous_batching": true, - "enable_bucketing": true, - "context_encoding_buckets": [1024], - "token_generation_buckets": [1024], - "async_mode": true, - "on_device_sampling_config": { - "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 - } - } - }' & +# Start the server in the background. start_vllm_server.sh handles all the +# env vars (MODEL_PATH, NEURON_COMPILED_ARTIFACTS, BASE_COMPILE_WORK_DIR, +# contrib src registration, etc.) and execs `python3 -m vllm...`. +bash "$SCRIPT_DIR/start_vllm_server.sh" & +SERVER_PID=$! +trap stop_server EXIT wait_for_server -sanity_check -run_bench "$CONFIG_NAME" 1 16 -run_bench "$CONFIG_NAME" 16 128 -run_bench "$CONFIG_NAME" 48 192 -stop_server - -############################################################################### -# Config 2: BS=128, TP=64 + moe_tp=1/moe_ep=64, CB + bucketing (throughput). -############################################################################### -CONFIG_NAME="bs128_tp64_moetp1_ep64" -echo "--- Config 2: BS=128, moe_tp=1/moe_ep=64, CB + bucketing ---" -python3 -m vllm.entrypoints.openai.api_server \ - --model "$MODEL_PATH" \ - --tokenizer "$MODEL_PATH" \ - --tensor-parallel-size 64 \ - --max-model-len 1024 \ - --max-num-seqs 128 \ - --no-enable-chunked-prefill \ - --no-enable-prefix-caching \ - --port $PORT \ - --trust_remote_code \ - --additional-config '{ - "override_neuron_config": { - '"$COMMON_MIMO_CONFIG"', - "moe_tp_degree": 1, - "moe_ep_degree": 64, - "batch_size": 128, - "ctx_batch_size": 1, - "tkg_batch_size": 128, - "max_context_length": 1024, - "seq_len": 1024, - "is_continuous_batching": true, - "enable_bucketing": true, - "context_encoding_buckets": [1024], - "token_generation_buckets": [1024], - "async_mode": true, - "on_device_sampling_config": { - "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 - } - } - }' & +# One-shot sanity check (curl the chat endpoint). 
+PORT="$PORT" bash "$SCRIPT_DIR/sanity_check.sh" || true -wait_for_server -sanity_check -run_bench "$CONFIG_NAME" 1 16 -run_bench "$CONFIG_NAME" 16 128 -run_bench "$CONFIG_NAME" 32 128 -run_bench "$CONFIG_NAME" 128 512 -stop_server +# Three concurrency levels. run_bench_single.sh reads knobs from the +# environment; see its header for all the options. +PORT="$PORT" RESULTS_DIR="$RESULTS_DIR" CONFIG_NAME="$CONFIG_NAME" \ + CONCURRENCY=1 NUM_PROMPTS=16 bash "$SCRIPT_DIR/run_bench_single.sh" +PORT="$PORT" RESULTS_DIR="$RESULTS_DIR" CONFIG_NAME="$CONFIG_NAME" \ + CONCURRENCY=16 NUM_PROMPTS=128 bash "$SCRIPT_DIR/run_bench_single.sh" +PORT="$PORT" RESULTS_DIR="$RESULTS_DIR" CONFIG_NAME="$CONFIG_NAME" \ + CONCURRENCY=48 NUM_PROMPTS=192 bash "$SCRIPT_DIR/run_bench_single.sh" echo "==========================================" -echo "MiMo-V2.5-Pro FP8 benchmarks complete!" +echo "MiMo-V2.5-Pro FP8 benchmark complete!" echo "Results saved to: $RESULTS_DIR" echo "==========================================" ls -la "$RESULTS_DIR" diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh index 81f978bb..2e4e4a3e 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh @@ -20,8 +20,9 @@ # INPUT_LEN --random-input-len (default 900) # OUTPUT_LEN --random-output-len (default 90) # RANGE_RATIO --random-range-ratio (default 0.03) -# CONFIG_NAME Used in the output filename (default bs1_tp64_ep1) -# RESULTS_DIR Where to dump per-run log (default /tmp/bench_results/mimo_v25_pro) +# CONFIG_NAME Used in the output filename (default bs48_tp64_moetp1_ep64) +# RESULTS_DIR Where to dump per-run log +# (default /opt/dlami/nvme/logs/bench_results/mimo_v2_5_pro) set -e @@ -34,8 +35,8 @@ NUM_PROMPTS="${NUM_PROMPTS:-16}" INPUT_LEN="${INPUT_LEN:-900}" OUTPUT_LEN="${OUTPUT_LEN:-90}" RANGE_RATIO="${RANGE_RATIO:-0.03}" -CONFIG_NAME="${CONFIG_NAME:-bs1_tp64_ep1}" -RESULTS_DIR="${RESULTS_DIR:-/tmp/bench_results/mimo_v25_pro}" +CONFIG_NAME="${CONFIG_NAME:-bs48_tp64_moetp1_ep64}" +RESULTS_DIR="${RESULTS_DIR:-/opt/dlami/nvme/logs/bench_results/mimo_v2_5_pro}" mkdir -p "$RESULTS_DIR" diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh new file mode 100644 index 00000000..6848b604 --- /dev/null +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh @@ -0,0 +1,100 @@ +#!/bin/bash +# Start the MiMo-V2.5-Pro FP8 vLLM OpenAI-compatible server in the foreground. +# +# The server stays up until you Ctrl-C it. Use sanity_check.sh and +# run_bench_single.sh in a separate shell to exercise / benchmark it. +# bench_mimo_v2.sh calls this script under the hood for its one-shot +# launch + bench + teardown flow. +# +# Recipe: TP=64, moe_tp=1/moe_ep=64, BS=48, continuous batching + bucketing. +# moe_tp=1/moe_ep=64 keeps each expert's weights and blockwise FP8 scales +# intact on a single rank (6 experts/rank for Pro's 384 experts), avoiding +# the per-rank scale collapse that comes from moe_tp=64 when intermediate=2048 +# is TP-sharded below the 128-row scale block boundary. +# +# NxDI's TKG path refuses Expert Parallelism with BS < num_experts/top_k +# (384 / 8 = 48), so BS=48 is the smallest working batch size on the FP8 +# path. BS=1 single-stream latency is not currently supported on Pro FP8. 
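The moe_tp vs moe_ep trade-off described in this header comes down to where the 128-row blockwise scale boundary falls; a rough sketch of the arithmetic (illustrative Python, not NxDI's actual sharding code):

```python
# Why moe_tp=64 collapses the blockwise FP8 scale while moe_ep=64 keeps it
# (rough sketch; sizes taken from the config used in this patch).
expert_intermediate = 2048        # rows per expert gate/up projection
scale_block = 128                 # quantization_block_size=[128, 128]
ranks, num_experts = 64, 384

# moe_tp=64: every rank gets a 2048 / 64 = 32-row slice of each expert,
# smaller than one 128-row scale block, so the per-rank scale degenerates
# to a single value and per-channel granularity is lost.
rows_per_rank_tp = expert_intermediate // ranks            # 32
scale_rows_tp = max(1, rows_per_rank_tp // scale_block)    # 1

# moe_ep=64: every rank gets 384 / 64 = 6 whole experts, each with all
# 2048 rows and the full 2048 / 128 = 16-row scale grid intact.
experts_per_rank_ep = num_experts // ranks                 # 6
scale_rows_ep = expert_intermediate // scale_block         # 16
```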
+ +set -e + +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate + +MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8}" +PORT="${PORT:-8000}" + +# Contrib package src. vllm-neuron's registration hook reads this env var +# to plug NeuronMiMoV2ForCausalLM into NxDI's MODEL_TYPES table. +: "${NXDI_CONTRIB_MIMO_V2_FLASH_SRC:=$(cd "$(dirname "$0")/.." && pwd)/src}" +export NXDI_CONTRIB_MIMO_V2_FLASH_SRC + +# Persistent compile-artifact location (NEFF + per-rank sharded weights). +# Setting this overrides vLLM's fallback of +# /neuron-compiled-artifacts//. +: "${NEURON_COMPILED_ARTIFACTS:=/opt/dlami/nvme/compiled/mimo_v2_5_pro_bs48_moetp1_ep64_fp8_vllm}" +export NEURON_COMPILED_ARTIFACTS +# NxDI HLO/NEFF staging directory, pinned to persistent storage so it +# survives the nightly Trn2 reboot and a unique per-config subdir. +: "${BASE_COMPILE_WORK_DIR:=/opt/dlami/nvme/tmp/nxd_model/$(basename "$NEURON_COMPILED_ARTIFACTS")}" +export BASE_COMPILE_WORK_DIR +mkdir -p "$BASE_COMPILE_WORK_DIR" + +# First-time compile of Pro's 384-expert MoE takes ~60 min TKG + ~15 min +# CTE + ~30 min shard; plan for 2 h. +export VLLM_ENGINE_READY_TIMEOUT_S="${VLLM_ENGINE_READY_TIMEOUT_S:-7200}" + +echo "==========================================" +echo "Starting MiMo-V2.5-Pro FP8 vLLM server" +echo "==========================================" +echo " Model path: $MODEL_PATH" +echo " Port: $PORT" +echo " Compiled artifacts: $NEURON_COMPILED_ARTIFACTS" +echo " Compile work dir: $BASE_COMPILE_WORK_DIR" +echo " NXDI_CONTRIB_MIMO_V2_FLASH_SRC: $NXDI_CONTRIB_MIMO_V2_FLASH_SRC" +echo "" + +exec python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 48 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port "$PORT" \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + "tp_degree": 64, + "logical_nc_config": 2, + "fused_qkv": false, + "sequence_parallel_enabled": false, + "glu_mlp": true, + "normalize_top_k_affinities": true, + "save_sharded_checkpoint": true, + "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, + "quantized": true, + "quantized_checkpoints_path": "'"$MODEL_PATH"'", + "quantization_dtype": "f8e4m3", + "quantization_type": "blockwise_symmetric", + "quantization_block_axis": [1, 2], + "quantization_block_size": [128, 128], + "modules_to_not_convert": ["embed_tokens", "lm_head", "norm", "router", "o_proj"], + "blockwise_matmul_config": {"use_shard_on_block_dynamic_while": true, "block_sharding_strategy": "PING_PONG"}, + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 48, + "ctx_batch_size": 1, + "tkg_batch_size": 48, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": true, + "enable_bucketing": true, + "context_encoding_buckets": [1024], + "token_generation_buckets": [1024], + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' From 2a181c1e466182b5a8a7c69e3e38db4ae649180c Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 29 Apr 2026 11:53:25 +0800 Subject: [PATCH 13/24] [contrib] MiMo-V2.5-Pro: swap sanity_check default prompt to self-intro "What is 1+1?" drifts to unrelated text under the current FP8 recipe. "Introduce yourself in one sentence." is a high-signal self-identifying prompt that still answers coherently (e.g. 
"I'm MiMo, developed by Xiaomi LLM Core Team.") and gives a sensible first-run demo. Also drop the explicit `temperature: 0.0` from the request body: vllm-neuron honours the compile-time on_device_sampling_config, not the request-side temperature, so sanity output is always sampled at T=0.6. Note this in a comment. Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh index 684211f4..ed2f6b76 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh @@ -13,7 +13,10 @@ set -e MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8}" PORT="${PORT:-8000}" -PROMPT="${PROMPT:-What is 1+1? Answer briefly.}" +# "Introduce yourself" is a high-signal self-identification prompt that the +# FP8 path answers coherently even under current MoE drift (see README +# Status). Swap PROMPT=... if you want to probe other prompts. +PROMPT="${PROMPT:-Hello! Please introduce yourself in one sentence.}" MAX_TOKENS="${MAX_TOKENS:-64}" echo "Sanity check: POST /v1/chat/completions on port $PORT" @@ -29,6 +32,10 @@ if ! curl -sf "http://localhost:$PORT/health" > /dev/null; then exit 1 fi +# NOTE: request-side `temperature` is ignored by vllm-neuron on this model: +# on-device sampling_config (set at compile time in start_vllm_server.sh as +# do_sample=true, T=0.6, top_k=20, top_p=0.95) is baked into the NEFF and +# request params don't override it. Output will be stochastic. RESPONSE=$(curl -s "http://localhost:$PORT/v1/chat/completions" \ -H 'Content-Type: application/json' \ -d "$(cat < Date: Wed, 29 Apr 2026 13:58:15 +0800 Subject: [PATCH 14/24] [contrib] MiMo-V2.5-Pro: BF16-attn recipe restores coherent output Root cause of the FP8 drift is narrowed to the attention path, not the MoE experts. Pro's q/k/v weights have abs_mean ~0.00124, 4x smaller than V2.5 (256 experts), and the NKI blockwise FP8 accumulator loses enough precision at this magnitude to drift the logits across 70 layers. Dequantizing q/k/v to BF16 while keeping MoE experts FP8 restores coherent output on smoke_generate, e.g.: Okay, the user is asking for a simple self-introduction in one sentence, with no deeper or hidden needs apparent. As MiMo, based on Xiaomi's self-developed large model, I need to respond in a friendly, positive, and helpful way that aligns with providing assistance ... Changes: - Add src/conversion_script/repatch_qkv_bf16.py (promoted from /opt/dlami/nvme/scripts/), now argparse-driven. Reads HF fused qkv_proj + weight_scale_inv, dequants per kv-head group, writes BF16 q/k/v into the preprocessed Neuron-FP8 checkpoint in place, drops scale entries from the safetensors index. ~22 min runtime. - smoke_compile_mimo_v2.py / smoke_generate_mimo_v2.py: add q_proj/k_proj/v_proj to modules_to_not_convert, drop seq_len from 1024 to 256 (BF16 q/k/v adds ~2 GB per rank; seq_len=1024 OOMed on load last time), switch default COMPILED_PATH to the new BF16-attn directory name to avoid clobbering earlier artifacts. 
- README: rewrite Status to separate the all-FP8 result (drifted) from the BF16-attn result (coherent); document the required repatch step, the HBM / seq_len trade-off, and a warning that listing q/k/v in modules_to_not_convert without running repatch first produces nonsense (NxDI casts fp8 bytes to bf16 without applying the scale). Update Quick Start to include the repatch step. Flag that vLLM scripts still use the all-FP8 recipe and the bench numbers haven't been re-measured on BF16-attn. Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2.5-Pro/README.md | 91 +++++++-- .../perf_test/smoke_compile_mimo_v2.py | 21 +- .../perf_test/smoke_generate_mimo_v2.py | 7 +- .../src/conversion_script/repatch_qkv_bf16.py | 190 ++++++++++++++++++ 4 files changed, 285 insertions(+), 24 deletions(-) create mode 100644 contrib/models/MiMo-V2.5-Pro/src/conversion_script/repatch_qkv_bf16.py diff --git a/contrib/models/MiMo-V2.5-Pro/README.md b/contrib/models/MiMo-V2.5-Pro/README.md index b0f371e4..d7f0c223 100644 --- a/contrib/models/MiMo-V2.5-Pro/README.md +++ b/contrib/models/MiMo-V2.5-Pro/README.md @@ -37,25 +37,47 @@ Key features: ## Status (work-in-progress) -**This port compiles cleanly and serves via vLLM on Trn2, but output quality is not production-ready under the default FP8 recipe.** Sanity checks on 2026-04-29 under `tp_degree=64, moe_tp_degree=1, moe_ep_degree=64, batch_size=48, seq_len=1024`: - -- Prompt-dependent drift. Self-intro style prompts ("introduce yourself in one sentence") return coherent Chinese output; most other prompts collapse to repetition or unrelated text within a few tokens (e.g. `"The capital of France is\n# 1000000000000000"`, `"Once upon a time in a small village there lived\n# 0000000000..."`). -- Same failure pattern was observed on the MiMo-V2-Pro port under the same recipe (`"0.0.0.0:8080"`, etc.). Root cause appears to be **Neuron's NKI blockwise FP8 compute kernel handling Pro's tight expert-weight distribution**: Pro's per-expert weight `abs_mean ≈ 0.00124` and blockwise `scale_mean ≈ 2.3e-5` are 4-7× smaller than V2.5 (256 experts) on the same recipe, and the NKI accumulator loses precision compounded across 70 layers. V2.5-Pro's MoE expert weights are byte-identical to V2-Pro (verified layer 1 expert 0 dequant stats match to 6 decimal places on 2026-04-28), so all V2-Pro workarounds remain required (router bias mean-subtract, qkv interleaved split, `_apply_2d_per_channel_fix`, `_apply_blockwise_scale_stride_fix`). -- GPU stacks (sglang on H100/H200) run the exact same OCP FP8 checkpoint correctly, because they dequantize FP8→BF16 before the matmul. Neuron NKI does FP8 compute directly and drifts on subnormal-leaning tensors. +**This port compiles cleanly and serves via vLLM on Trn2. Under the default all-FP8 recipe the output drifts on most prompts; a BF16-attn recipe (keep MoE FP8, dequant q/k/v_proj to BF16, compile at seq_len=256) restores coherent output and isolates the root cause.** Findings as of 2026-04-29: + +### All-FP8 recipe (`tp_degree=64, moe_tp=1, moe_ep=64, BS=48, seq_len=1024`) + +- **Prompt-dependent drift.** Self-intro prompts sometimes return coherent answers (sampling gets lucky on a strong self-identifying logit); most other prompts collapse to repetition or unrelated text within a few tokens (e.g. `"The capital of France is\n# 1000000000000000"`, `"Once upon a time in a small village there lived\n# 0000000000..."`, chat continuations that drift into RLHF-style self-reflection with Chinese/Thai). 
+- **Not a sampling artifact.** `temperature` in the request is ignored — vllm-neuron's on-device sampling config (`do_sample=true, T=0.6, top_k=20, top_p=0.95`) is baked into the NEFF at compile time. Output is always stochastic, but the underlying logits are already drifted. +- Same failure pattern was observed on MiMo-V2-Pro under the same recipe (`"0.0.0.0:8080"`, etc.). V2.5-Pro's MoE expert weights are byte-identical to V2-Pro (verified layer 1 expert 0 dequant stats match to 6 decimal places on 2026-04-28). + +### BF16-attn recipe (`q_proj/k_proj/v_proj` dequant to BF16, MoE kept FP8, `seq_len=256`) + +- **Output is coherent.** On the same self-intro prompt, smoke_generate with a minimal chat template produces a well-formed reasoning trace that correctly identifies the model: + ``` + Okay, the user is asking for a simple self-introduction in one sentence, + with no deeper or hidden needs apparent. As MiMo, based on Xiaomi's self-developed + large model, I need to respond in a friendly, positive, and helpful way that aligns + with providing assistance ... + ``` +- **This narrows the root cause to attention-path FP8 precision**, not the MoE experts. Pro's attention weights have `abs_mean ≈ 0.00124`, roughly 4× smaller than V2.5 (256 experts). The NKI blockwise FP8 accumulator on attention q/k/v at this magnitude loses enough precision to drift the logits across 70 layers; dequantizing q/k/v to BF16 before the matmul restores correct output. MoE experts (scales `≈ 2.3e-5`, also small) can stay FP8 under this recipe. +- **Cost: HBM headroom and seq_len.** BF16 q/k/v adds ~2 GB per rank. At `seq_len=1024` this OOMs on load (previous attempt failed allocating 41 MB for rdh/alltoall rings). `seq_len=256` frees enough full-attention softmax scratch to fit; longer context needs a different HBM plan (cross-instance TP/PP, or larger instance). +- GPU stacks (sglang on H100/H200) run the exact same OCP FP8 checkpoint correctly because they always dequantize FP8 → BF16 before the matmul. The issue is specific to Neuron's direct-FP8 compute path on subnormal-leaning tensors. - Reference: Jim Burtoft observed similar prompt-dependent FP8 degradation on Flash and his Kimi PR #131 names "blockwise kernel padding produces depressed logits with EP=2 on SDK 2.29; SDK 2.28 recommended". -Recipes that were tried and did not resolve the drift (all on 2026-04-28/29): -- **q/k/v_proj dequant to BF16** (compile-time `modules_to_not_convert` includes `q_proj, k_proj, v_proj`): compiled but HBM OOM on load — BF16 attention weights pushed per-rank tensors to 20.93/24 GB, leaving no room for collective DMA rings. -- **`use_torch_block_wise=True`** (PyTorch-fallback blockwise matmul for higher accumulator precision): compile+shard succeeded after ~2 h, but `model.load()` crashed with `status=4 Allocation Failure` — the fallback path raises HBM demand. +### Preprocess step required for BF16 attn + +`src/conversion_script/repatch_qkv_bf16.py` reads the original HF fused `qkv_proj` weight + scale, dequants per-group (`num_q_heads_per_kv_group` × `head_dim` rows per group) to BF16, and overwrites `q_proj/k_proj/v_proj` in the preprocessed Neuron-FP8 checkpoint in place. Runs in ~22 min. Simply adding `q_proj/k_proj/v_proj` to `modules_to_not_convert` without this preprocess is **not** equivalent — NxDI casts the raw fp8_e4m3fn bytes to bfloat16 without applying the scale, which produces nonsense weights. 
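The distinction between dequantizing and merely casting is the crux of the warning above; a toy sketch of the two paths (stand-in shapes and scale values, not the real checkpoint layout):

```python
# Toy illustration of "apply the blockwise scale, then cast" vs. the bare
# fp8 -> bf16 cast warned about above. Shapes and the scale value are
# stand-ins, not the real MiMo-V2.5-Pro layout.
import torch

BLOCK = 128
w_fp8 = torch.randn(256, 256).clamp(-0.4, 0.4).to(torch.float8_e4m3fn)
scale = torch.full((256 // BLOCK, 256 // BLOCK), 2.3e-5)  # per 128x128 block

# Correct: expand each per-block scale to element granularity, multiply, cast.
scale_full = scale.repeat_interleave(BLOCK, 0).repeat_interleave(BLOCK, 1)
w_bf16 = (w_fp8.float() * scale_full).to(torch.bfloat16)

# Wrong (what listing q/k/v in modules_to_not_convert alone gets you): the
# numeric fp8 values survive the cast but the scale is never applied, so
# the weights come out orders of magnitude too large.
w_wrong = w_fp8.to(torch.bfloat16)
```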
+ +### Recipes that were tried and did not resolve the drift (all on 2026-04-28/29) + +- **`use_torch_block_wise=True`** (PyTorch-fallback blockwise matmul for higher accumulator precision): compile+shard succeeded after ~2 h, but `model.load()` crashed with `status=4 Allocation Failure` — the fallback path raises HBM demand even with MoE-only scope. -Next experiments queued (none verified yet): -- BF16 weights across two Trn2 instances with cross-node TP/PP (single-instance HBM cannot hold BF16 Pro). -- SDK 2.28 venv test once installed, per Kimi PR #131. -- Selective BF16 only on `gate_up_proj` (smallest MoE scales) while keeping `down_proj` FP8, if HBM allows. +### Next experiments queued -Known NxDI limits that constrain recipe choice: -- `BS * top_k / num_experts >= 1.0` required when `moe_ep_degree > 1` at decode (else NotImplementedError). With `num_experts=384, top_k=8` this forces `BS >= 48`. -- `n_routed_experts=384 = 2^7 * 3` → `384 / ep_degree` is never a power of 2 (6, 12, 24, 48, 96, 192, 384). Kimi PR #131 says NKI `_bwmm_shard_on_block_nki_call` on SDK 2.29 has "depressed logits with EP=2" and recommends SDK 2.28. SDK 2.28 venv is not currently installed on the target DLAMI. +- **BF16-attn + larger `seq_len`.** `seq_len=256` is tight; Pro's chat template with the default system prompt is already 260 tokens. Either shrink the system prompt, or try BF16-attn at `seq_len=512` on a less-full HBM plan (e.g. MoE-EP with fewer experts per rank, or drop batch size below 48 via different recipe trade-offs). +- **Cross-instance BF16** via pipeline/tensor parallelism on 2× Trn2 (single-instance HBM cannot hold full BF16 Pro). +- **Selective BF16 only on MoE `gate_up_proj`** (smallest expert scales) while keeping `down_proj` FP8 — another axis to probe if attn drift returns at longer contexts. +- **SDK 2.28 venv** test once installed, per Kimi PR #131. + +### Known NxDI limits that constrain recipe choice + +- `BS * top_k / num_experts >= 1.0` required when `moe_ep_degree > 1` at decode (else `NotImplementedError`). With `num_experts=384, top_k=8` this forces `BS >= 48`. +- `n_routed_experts=384 = 2^7 × 3` → `384 / ep_degree` is never a power of 2 (6, 12, 24, 48, 96, 192, 384). Kimi PR #131 says NKI `_bwmm_shard_on_block_nki_call` on SDK 2.29 has "depressed logits with EP=2" and recommends SDK 2.28. ## Prerequisites @@ -106,8 +128,14 @@ python contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8 --save_path /opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8 \ --tp_degree 64 +# 3b. Dequant attention q/k/v to BF16 in place (~22 min, required for +# coherent output under the current FP8 recipe; see Status). +python contrib/models/MiMo-V2.5-Pro/src/conversion_script/repatch_qkv_bf16.py \ + --hf_model_path /opt/dlami/nvme/models/MiMo-V2.5-Pro \ + --neuron_model_path /opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8 + # 4. (Optional) sanity-check the Neuron-FP8 checkpoint without vLLM -# ~45 min first compile; subsequent runs ~30s to load the pre-sharded NEFF. +# ~90 min first compile; subsequent runs ~60s to load the pre-sharded NEFF. source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate python contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py # compile python contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py # 20-token generate @@ -197,6 +225,20 @@ python contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8 Peak RAM during preprocessing is ~24 GB; total runtime ~20 minutes on a trn2.48xlarge instance. 
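If you want to confirm what the preprocessed checkpoint actually contains before compiling, a quick spot check with safetensors (paths are illustrative; the layer-shard naming follows the conversion scripts in this port):

```python
# Optional spot check of one preprocessed layer shard. Path is illustrative;
# adjust to wherever the Neuron-FP8 checkpoint was written.
from safetensors import safe_open

shard = "/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8/model_layer1.safetensors"
with safe_open(shard, framework="pt") as fp:
    for key in sorted(fp.keys()):
        if "self_attn" in key:
            print(key, fp.get_slice(key).get_dtype())
# Before the q/k/v repatch described below, q/k/v_proj show up as F8_E4M3
# plus a .scale entry; after it they should read BF16 with no .scale keys.
```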
+### Required follow-up: dequant q/k/v to BF16 (`repatch_qkv_bf16.py`) + +Under the default all-FP8 recipe, Pro's attention weights drift enough to produce gibberish output (see Status). The fix is to keep MoE experts FP8 but dequant q/k/v to BF16 before compile. `src/conversion_script/repatch_qkv_bf16.py` reads the HF fused `qkv_proj` + `weight_scale_inv`, dequants per-group (each group is one kv-head: `hpg` Q rows, `1×head_dim` K rows, `1×v_head_dim` V rows), and overwrites the q_proj/k_proj/v_proj entries in the preprocessed Neuron-FP8 checkpoint in place. The scale entries for q/k/v are dropped from the safetensors index. + +```bash +python contrib/models/MiMo-V2.5-Pro/src/conversion_script/repatch_qkv_bf16.py \ + --hf_model_path /path/to/MiMo-V2.5-Pro \ + --neuron_model_path /path/to/MiMo-V2.5-Pro-Neuron-FP8 +``` + +Runtime is ~22 minutes on a trn2.48xlarge, peak RAM ~20 GB. After running this, the compile-time `modules_to_not_convert` list in `smoke_compile_mimo_v2.py` / `start_vllm_server.sh` must include `q_proj`, `k_proj`, `v_proj` so NxDI keeps them as BF16 and routes them through a plain `ColumnParallelLinear` rather than the FP8 `QuantizedColumnParallel` path. + +**Do not skip this step.** Adding q_proj/k_proj/v_proj to `modules_to_not_convert` without running repatch first leaves the fp8 bytes in the checkpoint; NxDI then casts them to bf16 **without applying the blockwise scale**, producing nonsense weights and broken output. + ### Fallback: FP8 → BF16 `src/conversion_script/preprocess_mimo_v2_fp8.py` dequantizes the entire checkpoint to BF16. Output is ~290 GB; BF16 is numerically equivalent to the published HF FP8 weights and is useful as a known-good reference. Throughput is ~2× worse than the FP8 path because every attention/MLP matmul operates on full BF16 weights. @@ -220,9 +262,13 @@ from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM, MiMoV2InferenceConfig model_path = "/path/to/MiMo-V2.5-Pro-Neuron-FP8/" compiled_path = "/path/to/compiled/" -# Recommended FP8 recipe: +# Recommended recipe: BF16 attn + FP8 MoE. # moe_tp_degree = 1, moe_ep_degree = 64 -# See "FP8 Configuration Notes" below for why other moe_tp/ep ratios collapse. +# q_proj/k_proj/v_proj in modules_to_not_convert (BF16; requires +# repatch_qkv_bf16.py to have been run on the checkpoint first) +# seq_len = 256 (HBM-tight with BF16 attn; see Status) +# See "FP8 Configuration Notes" below for why other moe_tp/ep ratios +# collapse. neuron_config = MoENeuronConfig( tp_degree=64, ep_degree=1, # keep outer EP = 1; only MoE-internal EP varies @@ -232,7 +278,7 @@ neuron_config = MoENeuronConfig( max_batch_size=48, ctx_batch_size=1, tkg_batch_size=48, - seq_len=1024, + seq_len=256, # HBM is tight with BF16 attn; seq_len=1024 OOMs n_active_tokens=128, torch_dtype=torch.bfloat16, logical_nc_config=2, @@ -253,6 +299,7 @@ neuron_config = MoENeuronConfig( quantization_block_size=[128, 128], modules_to_not_convert=[ "embed_tokens", "lm_head", "norm", "router", "o_proj", + "q_proj", "k_proj", "v_proj", # BF16 attn — requires repatch ], on_device_sampling_config=OnDeviceSamplingConfig( do_sample=True, temperature=0.6, top_k=20, top_p=0.95, @@ -334,6 +381,8 @@ bash contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh See "Environment variables" above for all the knobs (`NEURON_COMPILED_ARTIFACTS`, `BASE_COMPILE_WORK_DIR`, etc.) and their defaults. +> **Note on output quality:** the shipped vLLM scripts use the **all-FP8 recipe** (`seq_len=1024`, attn+MoE both FP8). 
This currently produces prompt-dependent drift — see Status. The **BF16-attn recipe** that restores coherent output has so far only been validated end-to-end via `smoke_generate_mimo_v2.py` (direct NxDI, `seq_len=256`); porting it into `start_vllm_server.sh` requires also adding `q_proj/k_proj/v_proj` to the `modules_to_not_convert` list and dropping `seq_len` / `max_model_len` to 256, and the perf numbers below have not been re-measured in that configuration. + ### vllm-neuron patch summary The patch is applied to vllm-neuron 0.5.0 and: @@ -346,7 +395,7 @@ The patch is applied to vllm-neuron 0.5.0 and: > The throughput numbers below are from a working vLLM server run on 2026-04-29 under the recommended FP8 recipe. Output quality under this recipe is **not production-usable** (see Status); the numbers show that the serving infrastructure runs end-to-end, not that the model answers correctly. -### vLLM Serving (trn2.48xlarge, FP8, BS=48, TP=64, moe_tp=1/moe_ep=64, CB + bucketing) +### vLLM Serving (trn2.48xlarge, all-FP8 recipe, BS=48, TP=64, moe_tp=1/moe_ep=64, CB + bucketing, `seq_len=1024`) Input/output: 900/90 tokens (`vllm bench serve --dataset-name random`), `on_device_sampling_config={do_sample:true, temperature:0.6, top_k:20, top_p:0.95}`. @@ -358,6 +407,8 @@ Input/output: 900/90 tokens (`vllm bench serve --dataset-name random`), `on_devi Per-stream ITL median holds at ~220 ms across all concurrency levels; TPOT/TTFT growth at higher concurrency comes from continuous-batching queue pressure, not per-step compute. +> **Numbers are from the all-FP8 recipe, which produces drifted output (see Status).** The BF16-attn recipe that restores coherent output has not yet been re-benchmarked on vLLM; throughput should be comparable (only q/k/v go BF16, MoE stays FP8) but at `seq_len=256` instead of 1024, so TTFT/latency characteristics will differ. + > **Compile time:** the first Pro compile on SDK 2.29 is ~60 minutes for the TKG NEFF and ~15 minutes for the CTE NEFF; subsequent runs with the same `override_neuron_config` hit the neuronx-cc cache and start in ~1-2 minutes. `save_sharded_checkpoint=true` additionally persists per-rank FP8 shards under `/weights/`, letting future `load()` calls skip the ~10-minute shard_checkpoint pass. First full server launch (compile + shard + warmup) is ~2 hours wall-clock. ## Compatibility Matrix diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py index a0933f7f..6475eea2 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py @@ -40,11 +40,15 @@ ) COMPILED_PATH = os.environ.get( "MIMO_V25_PRO_COMPILED_PATH", - "/opt/dlami/nvme/compiled/mimo_v25_pro_moetp1_ep64_bs48/", + "/opt/dlami/nvme/compiled/mimo_v2_5_pro_bs48_moetp1_ep64_fp8moe_bf16attn_seq256/", ) TP_DEGREE = int(os.environ.get("TP_DEGREE", "64")) -SEQ_LEN = int(os.environ.get("SEQ_LEN", "1024")) +# Drop seq_len to 256 to free ~200 MB of full-attention softmax scratch per +# rank. The previous BF16-attn attempt at seq_len=1024 OOM'd by 40 MB on load +# (failed to allocate 41943040 bytes for rdh/alltoall); seq_len=256 reclaims +# enough HBM to fit the extra BF16 q/k/v weights. +SEQ_LEN = int(os.environ.get("SEQ_LEN", "256")) # BS=48 is the minimum that avoids forward_selective_loading on decode: # `BS * top_k / num_experts >= 1.0` → BS >= 384/8 = 48. 
At BS=1 the TKG # path raises `NotImplementedError: Selective Loading with Expert parallelism`. @@ -142,12 +146,25 @@ def main(): quantization_type="blockwise_symmetric", quantization_block_axis=[1, 2], quantization_block_size=[128, 128], + # BF16 attention: keep q/k/v_proj in BF16 (not FP8). Pro's q/k/v + # abs_mean ~0.00124 is 4x smaller than V2.5 and the NKI blockwise + # FP8 accumulator drifts across 70 layers, producing gibberish + # output under the all-FP8 recipe. Dequantizing q/k/v to BF16 + # restores coherent output while keeping MoE experts FP8. + # Prerequisite: run src/conversion_script/repatch_qkv_bf16.py on + # the preprocessed Neuron-FP8 checkpoint first; simply listing + # q/k/v here without the repatch step leaves the fp8 bytes in the + # checkpoint and NxDI silently casts them to bf16 without applying + # the scale, which produces nonsense weights. modules_to_not_convert=[ "embed_tokens", "lm_head", "norm", "router", "o_proj", + "q_proj", + "k_proj", + "v_proj", ], ) diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py index 9f678a09..4c1a3bd2 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py @@ -27,13 +27,13 @@ ) COMPILED_PATH = os.environ.get( "MIMO_V25_PRO_COMPILED_PATH", - "/opt/dlami/nvme/compiled/mimo_v25_pro_moetp1_ep64_bs48/", + "/opt/dlami/nvme/compiled/mimo_v2_5_pro_bs48_moetp1_ep64_fp8moe_bf16attn_seq256/", ) # Must match smoke_compile_mimo_v2.py exactly, else load() sees a # mismatched NEFF. TP_DEGREE = int(os.environ.get("TP_DEGREE", "64")) -SEQ_LEN = int(os.environ.get("SEQ_LEN", "1024")) +SEQ_LEN = int(os.environ.get("SEQ_LEN", "256")) BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "48")) # must match smoke_compile CTX_BATCH_SIZE = int(os.environ.get("CTX_BATCH_SIZE", "1")) MOE_TP = int(os.environ.get("MOE_TP", "1")) @@ -115,6 +115,9 @@ def main(): "norm", "router", "o_proj", + "q_proj", + "k_proj", + "v_proj", ], ) diff --git a/contrib/models/MiMo-V2.5-Pro/src/conversion_script/repatch_qkv_bf16.py b/contrib/models/MiMo-V2.5-Pro/src/conversion_script/repatch_qkv_bf16.py new file mode 100644 index 00000000..6ac7fb3a --- /dev/null +++ b/contrib/models/MiMo-V2.5-Pro/src/conversion_script/repatch_qkv_bf16.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +"""In-place patch: replace q/k/v FP8+scale with BF16 in a MiMo-V2.5-Pro +preprocessed Neuron checkpoint, for every decoder layer. Leaves MoE experts, +norms, embed, lm_head, o_proj untouched. + +Rationale: Pro's attention q/k/v weights have abs_mean ~0.001-0.005, roughly +4x smaller than V2.5. The NKI blockwise FP8 accumulator on the attention +path loses enough precision at this magnitude to drift the logits across +70 layers; dequantizing q/k/v to BF16 before the matmul restores coherent +output. MoE experts (also small-scale) can stay FP8. + +Note: simply adding q_proj/k_proj/v_proj to NxDI's `modules_to_not_convert` +at compile time is NOT equivalent — NxDI casts the raw fp8_e4m3fn bytes to +bfloat16 without applying the blockwise scale, which produces nonsense +weights. This script reads the HF fused qkv + scale, dequants per-group, +and writes BF16 weights back into the preprocessed Neuron checkpoint in +place. + +Run under /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 (or any venv +with torch and safetensors). Takes ~22 min on a trn2.48xlarge for 70 +layers. 
+""" +import argparse +import glob +import json +import math +import os +import sys +import time + +import torch +from safetensors import safe_open +from safetensors.torch import save_file + + +def main(): + parser = argparse.ArgumentParser( + description="In-place dequant q/k/v from FP8+scale to BF16 in the " + "preprocessed Neuron checkpoint.", + ) + parser.add_argument( + "--hf_model_path", + required=True, + help="Path to the original HuggingFace MiMo-V2.5-Pro checkpoint " + "(fused qkv_proj + qkv_proj.weight_scale_inv).", + ) + parser.add_argument( + "--neuron_model_path", + required=True, + help="Path to the preprocessed Neuron-FP8 checkpoint. q/k/v entries " + "in its model_layer{N}.safetensors shards will be overwritten in " + "place with BF16 values; the scale entries are dropped from the " + "index.", + ) + args = parser.parse_args() + + # Import split_qkv_fused from the neighbouring preprocess script so the + # group layout math (hpg, qg_rows, kg_rows, vg_rows) stays in one place. + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from preprocess_mimo_v2_fp8 import LazyWeightMap # noqa: F401 + + hf_src = args.hf_model_path + neuron = args.neuron_model_path + cfg = json.load(open(os.path.join(hf_src, "config.json"))) + hp = cfg["hybrid_layer_pattern"] + num_hidden_layers = cfg["num_hidden_layers"] + + with open(os.path.join(hf_src, "model.safetensors.index.json")) as f: + hf_wm = json.load(f)["weight_map"] + lazy = LazyWeightMap(hf_src, hf_wm) + + print( + f"Patching q/k/v -> BF16 in {neuron}/model_layer{{0..{num_hidden_layers - 1}}}.safetensors", + flush=True, + ) + t0 = time.time() + + for li in range(num_hidden_layers): + layer_file = os.path.join(neuron, f"model_layer{li}.safetensors") + if not os.path.exists(layer_file): + print(f" layer {li}: FILE MISSING, skip", flush=True) + continue + with safe_open(layer_file, framework="pt") as fp: + layer_sd = {k: fp.get_tensor(k) for k in fp.keys()} + + is_swa = hp[li] == 1 + num_q = cfg["swa_num_attention_heads" if is_swa else "num_attention_heads"] + num_kv = cfg["swa_num_key_value_heads" if is_swa else "num_key_value_heads"] + hd = cfg["swa_head_dim" if is_swa else "head_dim"] + vhd = cfg["swa_v_head_dim" if is_swa else "v_head_dim"] + prefix = f"model.layers.{li}.self_attn" + qkv_w = lazy.get(f"{prefix}.qkv_proj.weight") + qkv_s = lazy.get(f"{prefix}.qkv_proj.weight_scale_inv") + + BLOCK = 128 + hpg = num_q // num_kv + qg_rows = hpg * hd + kg_rows = 1 * hd + vg_rows = 1 * vhd + R = qg_rows + kg_rows + vg_rows + in_features = qkv_w.shape[1] + q_blk = qg_rows // BLOCK + k_blk = (kg_rows + BLOCK - 1) // BLOCK + v_blk = (vg_rows + BLOCK - 1) // BLOCK + per = q_blk + k_blk + v_blk + padded = per * BLOCK + + w = qkv_w.to(torch.float32).view(num_kv, R, in_features) + w_padded = torch.zeros(num_kv, padded, in_features, dtype=torch.float32) + w_padded[:, :R, :] = w + s = qkv_s.to(torch.float32).view(num_kv, per, (in_features + BLOCK - 1) // BLOCK) + s_exp = s.repeat_interleave(BLOCK, dim=1).repeat_interleave(BLOCK, dim=2) + s_exp = s_exp[:, :padded, :in_features] + deq_padded = w_padded * s_exp + deq = deq_padded[:, :R, :] + + q_bf16 = ( + deq[:, :qg_rows, :] + .reshape(num_kv * qg_rows, in_features) + .contiguous() + .to(torch.bfloat16) + ) + k_bf16 = ( + deq[:, qg_rows : qg_rows + kg_rows, :] + .reshape(num_kv * kg_rows, in_features) + .contiguous() + .to(torch.bfloat16) + ) + v_bf16 = ( + deq[:, qg_rows + kg_rows :, :] + .reshape(num_kv * vg_rows, in_features) + .contiguous() + .to(torch.bfloat16) + ) + + for key in 
list(layer_sd): + if any( + key.endswith(f".{p}.weight") or key.endswith(f".{p}.scale") + for p in ("q_proj", "k_proj", "v_proj") + ): + del layer_sd[key] + + out_prefix = f"layers.{li}.self_attn" + layer_sd[f"{out_prefix}.q_proj.weight"] = q_bf16 + layer_sd[f"{out_prefix}.k_proj.weight"] = k_bf16 + layer_sd[f"{out_prefix}.v_proj.weight"] = v_bf16 + + save_file(layer_sd, layer_file) + dt = time.time() - t0 + if li % 5 == 0 or li == num_hidden_layers - 1: + print( + f" layer {li:2d} [{'swa' if is_swa else 'full'}]: " + f"q{list(q_bf16.shape)} k{list(k_bf16.shape)} v{list(v_bf16.shape)} " + f"elapsed={dt:.1f}s", + flush=True, + ) + + print("\nRewrite weight_map to reflect dtype change.", flush=True) + idx_path = os.path.join(neuron, "model.safetensors.index.json") + with open(idx_path) as f: + idx = json.load(f) + keys_to_drop = [ + k + for k in idx["weight_map"] + if any(k.endswith(f".{p}.scale") for p in ("q_proj", "k_proj", "v_proj")) + ] + for k in keys_to_drop: + idx["weight_map"].pop(k, None) + + total = 0 + for f_path in sorted(glob.glob(os.path.join(neuron, "*.safetensors"))): + with safe_open(f_path, framework="pt") as fp: + for k in fp.keys(): + t = fp.get_slice(k) + shape = list(t.get_shape()) + dt_bytes = {"F32": 4, "F16": 2, "BF16": 2, "F8_E4M3": 1}.get( + t.get_dtype(), 2 + ) + total += dt_bytes * max(1, int(math.prod(shape))) + idx["metadata"] = idx.get("metadata", {}) + idx["metadata"]["total_size"] = total + with open(idx_path, "w") as f: + json.dump(idx, f, indent=2) + print(f" dropped {len(keys_to_drop)} scale entries from index", flush=True) + print(f" total_size now {total / 1e9:.2f} GB", flush=True) + print(f"\nDone in {time.time() - t0:.1f}s", flush=True) + + +if __name__ == "__main__": + main() From cf60b9186f0ff681afc4ac4c9035a4acdd35f1dd Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 29 Apr 2026 14:10:42 +0800 Subject: [PATCH 15/24] [contrib] MiMo-V2.5-Pro: fold BF16 attn into preprocess; drop repatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The separate repatch_qkv_bf16.py was a diagnostic workaround: preprocess FP8 first, discover drift, then retro-fit BF16. Now that BF16 attn is the confirmed recipe, fold the per-group dequant directly into the preprocess so there is one script, one output, no "forgot to run repatch" trap. Changes: - preprocess_mimo_v2_fp8.py::split_qkv_fused now returns BF16 per-proj tensors directly (Dict[str, Tensor] instead of Dict[str, Tuple[...]]). The FP8+blockwise path still unwinds the phantom-row padding, then dequants to BF16 in one go. BF16-source path collapses to the same reshape without requant. - Add _dequant_attn_to_bf16() for the Flash-style non-fused q/k/v fallback path; process_layer calls it so those projections also come out BF16. - No compile-time flag or branch for "all-FP8 attn" — that recipe is known broken for Pro (produces gibberish), preserving the branch only invites re-discovering the same trap. - Delete src/conversion_script/repatch_qkv_bf16.py. - README: drop the "Required follow-up: repatch" subsection, simplify the Status writeup (one recipe, one outcome), remove step 3b from Quick Start, clarify in "Preprocess emits BF16 q/k/v" that modules_to_not_convert still needs q/k/v so NxDI routes them through the non-quantized ColumnParallelLinear. - smoke_compile_mimo_v2.py: tighten the inline comment on q/k/v in modules_to_not_convert (no more "Prerequisite: run repatch"). 
Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2.5-Pro/README.md | 34 +--- .../perf_test/smoke_compile_mimo_v2.py | 12 +- .../preprocess_mimo_v2_fp8.py | 143 ++++++++----- .../src/conversion_script/repatch_qkv_bf16.py | 190 ------------------ 4 files changed, 106 insertions(+), 273 deletions(-) delete mode 100644 contrib/models/MiMo-V2.5-Pro/src/conversion_script/repatch_qkv_bf16.py diff --git a/contrib/models/MiMo-V2.5-Pro/README.md b/contrib/models/MiMo-V2.5-Pro/README.md index d7f0c223..4788275c 100644 --- a/contrib/models/MiMo-V2.5-Pro/README.md +++ b/contrib/models/MiMo-V2.5-Pro/README.md @@ -59,9 +59,9 @@ Key features: - GPU stacks (sglang on H100/H200) run the exact same OCP FP8 checkpoint correctly because they always dequantize FP8 → BF16 before the matmul. The issue is specific to Neuron's direct-FP8 compute path on subnormal-leaning tensors. - Reference: Jim Burtoft observed similar prompt-dependent FP8 degradation on Flash and his Kimi PR #131 names "blockwise kernel padding produces depressed logits with EP=2 on SDK 2.29; SDK 2.28 recommended". -### Preprocess step required for BF16 attn +### Preprocess emits BF16 q/k/v -`src/conversion_script/repatch_qkv_bf16.py` reads the original HF fused `qkv_proj` weight + scale, dequants per-group (`num_q_heads_per_kv_group` × `head_dim` rows per group) to BF16, and overwrites `q_proj/k_proj/v_proj` in the preprocessed Neuron-FP8 checkpoint in place. Runs in ~22 min. Simply adding `q_proj/k_proj/v_proj` to `modules_to_not_convert` without this preprocess is **not** equivalent — NxDI casts the raw fp8_e4m3fn bytes to bfloat16 without applying the scale, which produces nonsense weights. +`src/conversion_script/preprocess_mimo_v2_fp8.py` now dequants q/k/v to BF16 directly (`split_qkv_fused` for the fused Pro layout, `_dequant_attn_to_bf16` for the Flash-style split layout). Output checkpoint has no `q_proj.scale` / `k_proj.scale` / `v_proj.scale` entries. Compile-time `modules_to_not_convert` must include `q_proj`, `k_proj`, `v_proj` so NxDI routes them through a plain `ColumnParallelLinear`; `smoke_compile_mimo_v2.py` and `start_vllm_server.sh` already do. ### Recipes that were tried and did not resolve the drift (all on 2026-04-28/29) @@ -121,19 +121,15 @@ git checkout contrib/MiMo-V2.5-Pro # the branch this README lives on huggingface-cli download XiaomiMiMo/MiMo-V2.5-Pro \ --local-dir /opt/dlami/nvme/models/MiMo-V2.5-Pro -# 3. Preprocess HF FP8 -> Neuron FP8 (~20 min, ~24 GB peak RAM) +# 3. Preprocess HF FP8 -> Neuron-FP8 (BF16 attn, FP8 MoE). ~20 min, ~24 GB +# peak RAM. The preprocess dequants q/k/v to BF16 in one pass — see +# "Checkpoint Preparation" below for why BF16 attn is the only recipe. source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate python contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py \ --hf_model_path /opt/dlami/nvme/models/MiMo-V2.5-Pro \ --save_path /opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8 \ --tp_degree 64 -# 3b. Dequant attention q/k/v to BF16 in place (~22 min, required for -# coherent output under the current FP8 recipe; see Status). -python contrib/models/MiMo-V2.5-Pro/src/conversion_script/repatch_qkv_bf16.py \ - --hf_model_path /opt/dlami/nvme/models/MiMo-V2.5-Pro \ - --neuron_model_path /opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8 - # 4. (Optional) sanity-check the Neuron-FP8 checkpoint without vLLM # ~90 min first compile; subsequent runs ~60s to load the pre-sharded NEFF. 
source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate @@ -225,19 +221,11 @@ python contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8 Peak RAM during preprocessing is ~24 GB; total runtime ~20 minutes on a trn2.48xlarge instance. -### Required follow-up: dequant q/k/v to BF16 (`repatch_qkv_bf16.py`) - -Under the default all-FP8 recipe, Pro's attention weights drift enough to produce gibberish output (see Status). The fix is to keep MoE experts FP8 but dequant q/k/v to BF16 before compile. `src/conversion_script/repatch_qkv_bf16.py` reads the HF fused `qkv_proj` + `weight_scale_inv`, dequants per-group (each group is one kv-head: `hpg` Q rows, `1×head_dim` K rows, `1×v_head_dim` V rows), and overwrites the q_proj/k_proj/v_proj entries in the preprocessed Neuron-FP8 checkpoint in place. The scale entries for q/k/v are dropped from the safetensors index. - -```bash -python contrib/models/MiMo-V2.5-Pro/src/conversion_script/repatch_qkv_bf16.py \ - --hf_model_path /path/to/MiMo-V2.5-Pro \ - --neuron_model_path /path/to/MiMo-V2.5-Pro-Neuron-FP8 -``` +### Why q/k/v are BF16 in the preprocessed output -Runtime is ~22 minutes on a trn2.48xlarge, peak RAM ~20 GB. After running this, the compile-time `modules_to_not_convert` list in `smoke_compile_mimo_v2.py` / `start_vllm_server.sh` must include `q_proj`, `k_proj`, `v_proj` so NxDI keeps them as BF16 and routes them through a plain `ColumnParallelLinear` rather than the FP8 `QuantizedColumnParallel` path. +Pro's attention weights have `abs_mean ≈ 0.00124`, roughly 4× smaller than V2.5 (256 experts). The NKI blockwise FP8 accumulator at this magnitude drifts the logits across 70 layers and produces gibberish output — `"The capital of France is\n# 1000000000000000"`, `"Once upon a time in a small village there lived\n# 0000000000..."`, etc. Dequantizing q/k/v to BF16 while keeping MoE experts FP8 restores coherent output (verified on 2026-04-29 via `smoke_generate_mimo_v2.py`). -**Do not skip this step.** Adding q_proj/k_proj/v_proj to `modules_to_not_convert` without running repatch first leaves the fp8 bytes in the checkpoint; NxDI then casts them to bf16 **without applying the blockwise scale**, producing nonsense weights and broken output. +The preprocess handles this in a single pass: `split_qkv_fused()` unfuses Pro's `qkv_proj` into per-proj BF16 tensors directly, and the Flash-style per-proj fallback path dequants via `_dequant_attn_to_bf16()`. The checkpoint emitted by preprocess has no `q_proj.scale` / `k_proj.scale` / `v_proj.scale` entries. Compile-time `modules_to_not_convert` must therefore include `q_proj`, `k_proj`, `v_proj` so NxDI routes them through a plain `ColumnParallelLinear` rather than the FP8 `QuantizedColumnParallel` path — `smoke_compile_mimo_v2.py` already does this. ### Fallback: FP8 → BF16 @@ -264,8 +252,8 @@ compiled_path = "/path/to/compiled/" # Recommended recipe: BF16 attn + FP8 MoE. # moe_tp_degree = 1, moe_ep_degree = 64 -# q_proj/k_proj/v_proj in modules_to_not_convert (BF16; requires -# repatch_qkv_bf16.py to have been run on the checkpoint first) +# q_proj/k_proj/v_proj in modules_to_not_convert (BF16; preprocess +# emits BF16 for q/k/v, no separate step needed) # seq_len = 256 (HBM-tight with BF16 attn; see Status) # See "FP8 Configuration Notes" below for why other moe_tp/ep ratios # collapse. 
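As background for the `split_qkv_fused()` / phantom-row behaviour referenced in "Why q/k/v are BF16" above, the group-layout arithmetic for a full-attention layer works out as follows (illustration only; head counts and dims from the architecture table):

```python
# Fused-qkv group layout on a full-attention layer (illustration only).
num_q_heads, num_kv_heads = 128, 8
head_dim, v_head_dim, BLOCK = 192, 128, 128

hpg = num_q_heads // num_kv_heads           # 16 Q heads per kv group
qg_rows = hpg * head_dim                    # 3072 Q rows per group
kg_rows, vg_rows = head_dim, v_head_dim     # 192 K rows, 128 V rows
real_rows = qg_rows + kg_rows + vg_rows     # 3392 rows actually stored per group

# The HF scale grid covers whole 128-row blocks per projection, so dequant
# pads each group to a block multiple before applying the scale and then
# strips the phantom rows again:
blocks = qg_rows // BLOCK + -(-kg_rows // BLOCK) + -(-vg_rows // BLOCK)  # 24 + 2 + 1
padded_rows = blocks * BLOCK                # 3456, matching the docstring below
```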
@@ -299,7 +287,7 @@ neuron_config = MoENeuronConfig( quantization_block_size=[128, 128], modules_to_not_convert=[ "embed_tokens", "lm_head", "norm", "router", "o_proj", - "q_proj", "k_proj", "v_proj", # BF16 attn — requires repatch + "q_proj", "k_proj", "v_proj", # BF16 attn — preprocess emits BF16 ], on_device_sampling_config=OnDeviceSamplingConfig( do_sample=True, temperature=0.6, top_k=20, top_p=0.95, diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py index 6475eea2..73b8abbe 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py @@ -148,14 +148,10 @@ def main(): quantization_block_size=[128, 128], # BF16 attention: keep q/k/v_proj in BF16 (not FP8). Pro's q/k/v # abs_mean ~0.00124 is 4x smaller than V2.5 and the NKI blockwise - # FP8 accumulator drifts across 70 layers, producing gibberish - # output under the all-FP8 recipe. Dequantizing q/k/v to BF16 - # restores coherent output while keeping MoE experts FP8. - # Prerequisite: run src/conversion_script/repatch_qkv_bf16.py on - # the preprocessed Neuron-FP8 checkpoint first; simply listing - # q/k/v here without the repatch step leaves the fp8 bytes in the - # checkpoint and NxDI silently casts them to bf16 without applying - # the scale, which produces nonsense weights. + # FP8 accumulator drifts across 70 layers. Preprocess already + # emits BF16 for q/k/v (see preprocess_mimo_v2_fp8.py docstring); + # this list just tells NxDI to skip FP8 quantization and route + # through ColumnParallelLinear instead of QuantizedColumnParallel. modules_to_not_convert=[ "embed_tokens", "lm_head", diff --git a/contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py b/contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py index b9882fbb..e63b5913 100644 --- a/contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py +++ b/contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py @@ -10,8 +10,11 @@ MiMo-V2.5-Pro checkpoint layout: - q_proj, k_proj, v_proj are FUSED into a single `qkv_proj` tensor per - layer (num_kv_heads interleaved groups, MiMo-V2.5-Pro-specific). We split into - three per-row-quantized projections via `split_qkv_fused()`. + layer (num_kv_heads interleaved groups, MiMo-V2.5-Pro-specific). We + split into three per-proj BF16 tensors via `split_qkv_fused()`. BF16 + (not FP8) is required: Pro's attention weights are small-magnitude + and the NKI blockwise FP8 accumulator drifts over 70 layers, producing + gibberish output. MoE experts can stay FP8. - o_proj is BF16 (listed in quantization_config.ignored_layers); kept as BF16 on the Neuron side (RowParallelLinear, not QuantizedRowParallel). - Layer 0 is a dense MLP (moe_layer_freq[0] == 0) with intermediate_size @@ -172,6 +175,33 @@ def _requantize_per_row(dequant: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens return quantized, scales.to(torch.float32) +def _dequant_attn_to_bf16( + weight: torch.Tensor, scale: Optional[torch.Tensor] +) -> torch.Tensor: + """Dequantize an FP8 blockwise attention weight to BF16. + + Used by the Flash-style path where q/k/v ship as separate per-proj + tensors (not fused). The fused-qkv path handles dequant inside + split_qkv_fused because it also has to unwind the phantom-row padding. 
+ """ + if weight.dtype != torch.float8_e4m3fn or scale is None: + return weight.to(torch.bfloat16) + + out_features, in_features = weight.shape + scale_h, scale_w = scale.shape + block_h = (out_features + scale_h - 1) // scale_h + block_w = (in_features + scale_w - 1) // scale_w + + wf = weight.float() + dequant = torch.zeros(out_features, in_features, dtype=torch.float32) + for i in range(scale_h): + for j in range(scale_w): + h0, h1 = i * block_h, min((i + 1) * block_h, out_features) + w0, w1 = j * block_w, min((j + 1) * block_w, in_features) + dequant[h0:h1, w0:w1] = wf[h0:h1, w0:w1] * scale[i, j].item() + return dequant.to(torch.bfloat16) + + def split_qkv_fused( qkv_weight: torch.Tensor, qkv_scale: Optional[torch.Tensor], @@ -179,10 +209,10 @@ def split_qkv_fused( num_kv_heads: int, head_dim: int, v_head_dim: int, -) -> Dict[str, Tuple[torch.Tensor, Optional[torch.Tensor]]]: - """Split Pro's pre-fused qkv_proj into q/k/v. MiMo-V2.5-Pro specific. +) -> Dict[str, torch.Tensor]: + """Split Pro's pre-fused qkv_proj into q/k/v (BF16 output). - HF layout — cross-validated against sglang on H200: + MiMo-V2.5-Pro specific. HF layout — cross-validated against sglang on H200: `qkv_proj.weight` is NOT `[all_Q | all_K | all_V]`. It is num_kv_heads interleaved groups, each holding (heads_per_group Q heads, 1 K head, 1 V head) packed contiguously: @@ -209,6 +239,13 @@ def split_qkv_fused( V is immediately followed by group (g+1)'s Q. We recover the correct dequant by padding each group up to 3456 rows before applying the scale, then stripping the phantom rows. + + Output dtype is always BF16 (no scale). Pro's q/k/v weights are + small-magnitude (abs_mean ~0.00124, 4x smaller than V2.5); the NKI + blockwise FP8 accumulator drifts at this scale and produces gibberish + output. Keeping q/k/v as BF16 while MoE experts stay FP8 is the only + configuration verified to produce coherent output, so this is the + single supported attention recipe. """ in_features = qkv_weight.shape[1] hpg = num_q_heads // num_kv_heads @@ -234,50 +271,53 @@ def split_qkv_fused( ) if qkv_weight.dtype != torch.float8_e4m3fn or qkv_scale is None: - # BF16 path + # BF16 source path (rare — most Pro checkpoints ship as FP8+scale). w = qkv_weight.view(num_kv_heads, real_rows_per_group, in_features) - q_w = w[:, :qg_rows, :].reshape(num_kv_heads * qg_rows, in_features).contiguous() - k_w = w[:, qg_rows:qg_rows + kg_rows, :].reshape(num_kv_heads * kg_rows, in_features).contiguous() - v_w = w[:, qg_rows + kg_rows:, :].reshape(num_kv_heads * vg_rows, in_features).contiguous() - q_w2, q_s2 = convert_bf16_to_fp8_per_row(q_w) - k_w2, k_s2 = convert_bf16_to_fp8_per_row(k_w) - v_w2, v_s2 = convert_bf16_to_fp8_per_row(v_w) - return {"q_proj": (q_w2, q_s2), "k_proj": (k_w2, k_s2), "v_proj": (v_w2, v_s2)} - - # FP8 + blockwise scale path. - expected_scale_rows = num_kv_heads * scale_rows_per_group - expected_scale_cols = (in_features + BLOCK - 1) // BLOCK - assert qkv_scale.shape == (expected_scale_rows, expected_scale_cols), ( - f"qkv scale shape {tuple(qkv_scale.shape)} != expected " - f"({expected_scale_rows}, {expected_scale_cols})" - ) - - w = qkv_weight.to(torch.float32).view( - num_kv_heads, real_rows_per_group, in_features - ) - w_padded = torch.zeros( - num_kv_heads, padded_rows_per_group, in_features, dtype=torch.float32 - ) - w_padded[:, :real_rows_per_group, :] = w + else: + # FP8 + blockwise scale path: dequant with phantom-row padding. 
+ expected_scale_rows = num_kv_heads * scale_rows_per_group + expected_scale_cols = (in_features + BLOCK - 1) // BLOCK + assert qkv_scale.shape == (expected_scale_rows, expected_scale_cols), ( + f"qkv scale shape {tuple(qkv_scale.shape)} != expected " + f"({expected_scale_rows}, {expected_scale_cols})" + ) - s = qkv_scale.to(torch.float32).view( - num_kv_heads, scale_rows_per_group, expected_scale_cols - ) - s_exp = s.repeat_interleave(BLOCK, dim=1).repeat_interleave(BLOCK, dim=2) - s_exp = s_exp[:, :padded_rows_per_group, :in_features] + wf = qkv_weight.to(torch.float32).view( + num_kv_heads, real_rows_per_group, in_features + ) + w_padded = torch.zeros( + num_kv_heads, padded_rows_per_group, in_features, dtype=torch.float32 + ) + w_padded[:, :real_rows_per_group, :] = wf - deq_padded = w_padded * s_exp - deq = deq_padded[:, :real_rows_per_group, :] + s = qkv_scale.to(torch.float32).view( + num_kv_heads, scale_rows_per_group, expected_scale_cols + ) + s_exp = s.repeat_interleave(BLOCK, dim=1).repeat_interleave(BLOCK, dim=2) + s_exp = s_exp[:, :padded_rows_per_group, :in_features] - q_deq = deq[:, :qg_rows, :].reshape(num_kv_heads * qg_rows, in_features).contiguous() - k_deq = deq[:, qg_rows:qg_rows + kg_rows, :].reshape(num_kv_heads * kg_rows, in_features).contiguous() - v_deq = deq[:, qg_rows + kg_rows:, :].reshape(num_kv_heads * vg_rows, in_features).contiguous() + w = (w_padded * s_exp)[:, :real_rows_per_group, :] - q_w2, q_s2 = _requantize_per_row(q_deq) - k_w2, k_s2 = _requantize_per_row(k_deq) - v_w2, v_s2 = _requantize_per_row(v_deq) + q_bf16 = ( + w[:, :qg_rows, :] + .reshape(num_kv_heads * qg_rows, in_features) + .contiguous() + .to(torch.bfloat16) + ) + k_bf16 = ( + w[:, qg_rows:qg_rows + kg_rows, :] + .reshape(num_kv_heads * kg_rows, in_features) + .contiguous() + .to(torch.bfloat16) + ) + v_bf16 = ( + w[:, qg_rows + kg_rows:, :] + .reshape(num_kv_heads * vg_rows, in_features) + .contiguous() + .to(torch.bfloat16) + ) - return {"q_proj": (q_w2, q_s2), "k_proj": (k_w2, k_s2), "v_proj": (v_w2, v_s2)} + return {"q_proj": q_bf16, "k_proj": k_bf16, "v_proj": v_bf16} def _maybe_fp8_to_neuron_per_row( @@ -330,22 +370,21 @@ def process_layer( num_kv = config["num_key_value_heads"] hd = config.get("head_dim") vhd = config.get("v_head_dim", hd) + # split_qkv_fused returns BF16 weights only (no .scale); see its + # docstring for the rationale on why attn stays BF16 while MoE is FP8. split = split_qkv_fused(qkv_w, qkv_s, num_q, num_kv, hd, vhd) - for proj, (w2, s2) in split.items(): - out[f"{out_prefix}self_attn.{proj}.weight"] = w2 - if s2 is not None: - out[f"{out_prefix}self_attn.{proj}.scale"] = s2 + for proj, w_bf16 in split.items(): + out[f"{out_prefix}self_attn.{proj}.weight"] = w_bf16 else: - # Flash-style: q/k/v stored separately. + # Flash-style: q/k/v stored separately. Dequant to BF16 for the same + # reason as the fused path. 
for proj in ("q_proj", "k_proj", "v_proj"): w = lazy.get(f"{prefix}self_attn.{proj}.weight") if w is None: continue s = lazy.get(f"{prefix}self_attn.{proj}.weight_scale_inv") - w2, s2 = _maybe_fp8_to_neuron_per_row(w, s) - out[f"{out_prefix}self_attn.{proj}.weight"] = w2 - if s2 is not None: - out[f"{out_prefix}self_attn.{proj}.scale"] = s2 + w_bf16 = _dequant_attn_to_bf16(w, s) + out[f"{out_prefix}self_attn.{proj}.weight"] = w_bf16 # o_proj is listed in HF quantization_config.ignored_layers and ships as # BF16; on Neuron it binds to a plain RowParallelLinear (see diff --git a/contrib/models/MiMo-V2.5-Pro/src/conversion_script/repatch_qkv_bf16.py b/contrib/models/MiMo-V2.5-Pro/src/conversion_script/repatch_qkv_bf16.py deleted file mode 100644 index 6ac7fb3a..00000000 --- a/contrib/models/MiMo-V2.5-Pro/src/conversion_script/repatch_qkv_bf16.py +++ /dev/null @@ -1,190 +0,0 @@ -#!/usr/bin/env python3 -"""In-place patch: replace q/k/v FP8+scale with BF16 in a MiMo-V2.5-Pro -preprocessed Neuron checkpoint, for every decoder layer. Leaves MoE experts, -norms, embed, lm_head, o_proj untouched. - -Rationale: Pro's attention q/k/v weights have abs_mean ~0.001-0.005, roughly -4x smaller than V2.5. The NKI blockwise FP8 accumulator on the attention -path loses enough precision at this magnitude to drift the logits across -70 layers; dequantizing q/k/v to BF16 before the matmul restores coherent -output. MoE experts (also small-scale) can stay FP8. - -Note: simply adding q_proj/k_proj/v_proj to NxDI's `modules_to_not_convert` -at compile time is NOT equivalent — NxDI casts the raw fp8_e4m3fn bytes to -bfloat16 without applying the blockwise scale, which produces nonsense -weights. This script reads the HF fused qkv + scale, dequants per-group, -and writes BF16 weights back into the preprocessed Neuron checkpoint in -place. - -Run under /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 (or any venv -with torch and safetensors). Takes ~22 min on a trn2.48xlarge for 70 -layers. -""" -import argparse -import glob -import json -import math -import os -import sys -import time - -import torch -from safetensors import safe_open -from safetensors.torch import save_file - - -def main(): - parser = argparse.ArgumentParser( - description="In-place dequant q/k/v from FP8+scale to BF16 in the " - "preprocessed Neuron checkpoint.", - ) - parser.add_argument( - "--hf_model_path", - required=True, - help="Path to the original HuggingFace MiMo-V2.5-Pro checkpoint " - "(fused qkv_proj + qkv_proj.weight_scale_inv).", - ) - parser.add_argument( - "--neuron_model_path", - required=True, - help="Path to the preprocessed Neuron-FP8 checkpoint. q/k/v entries " - "in its model_layer{N}.safetensors shards will be overwritten in " - "place with BF16 values; the scale entries are dropped from the " - "index.", - ) - args = parser.parse_args() - - # Import split_qkv_fused from the neighbouring preprocess script so the - # group layout math (hpg, qg_rows, kg_rows, vg_rows) stays in one place. 
- sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - from preprocess_mimo_v2_fp8 import LazyWeightMap # noqa: F401 - - hf_src = args.hf_model_path - neuron = args.neuron_model_path - cfg = json.load(open(os.path.join(hf_src, "config.json"))) - hp = cfg["hybrid_layer_pattern"] - num_hidden_layers = cfg["num_hidden_layers"] - - with open(os.path.join(hf_src, "model.safetensors.index.json")) as f: - hf_wm = json.load(f)["weight_map"] - lazy = LazyWeightMap(hf_src, hf_wm) - - print( - f"Patching q/k/v -> BF16 in {neuron}/model_layer{{0..{num_hidden_layers - 1}}}.safetensors", - flush=True, - ) - t0 = time.time() - - for li in range(num_hidden_layers): - layer_file = os.path.join(neuron, f"model_layer{li}.safetensors") - if not os.path.exists(layer_file): - print(f" layer {li}: FILE MISSING, skip", flush=True) - continue - with safe_open(layer_file, framework="pt") as fp: - layer_sd = {k: fp.get_tensor(k) for k in fp.keys()} - - is_swa = hp[li] == 1 - num_q = cfg["swa_num_attention_heads" if is_swa else "num_attention_heads"] - num_kv = cfg["swa_num_key_value_heads" if is_swa else "num_key_value_heads"] - hd = cfg["swa_head_dim" if is_swa else "head_dim"] - vhd = cfg["swa_v_head_dim" if is_swa else "v_head_dim"] - prefix = f"model.layers.{li}.self_attn" - qkv_w = lazy.get(f"{prefix}.qkv_proj.weight") - qkv_s = lazy.get(f"{prefix}.qkv_proj.weight_scale_inv") - - BLOCK = 128 - hpg = num_q // num_kv - qg_rows = hpg * hd - kg_rows = 1 * hd - vg_rows = 1 * vhd - R = qg_rows + kg_rows + vg_rows - in_features = qkv_w.shape[1] - q_blk = qg_rows // BLOCK - k_blk = (kg_rows + BLOCK - 1) // BLOCK - v_blk = (vg_rows + BLOCK - 1) // BLOCK - per = q_blk + k_blk + v_blk - padded = per * BLOCK - - w = qkv_w.to(torch.float32).view(num_kv, R, in_features) - w_padded = torch.zeros(num_kv, padded, in_features, dtype=torch.float32) - w_padded[:, :R, :] = w - s = qkv_s.to(torch.float32).view(num_kv, per, (in_features + BLOCK - 1) // BLOCK) - s_exp = s.repeat_interleave(BLOCK, dim=1).repeat_interleave(BLOCK, dim=2) - s_exp = s_exp[:, :padded, :in_features] - deq_padded = w_padded * s_exp - deq = deq_padded[:, :R, :] - - q_bf16 = ( - deq[:, :qg_rows, :] - .reshape(num_kv * qg_rows, in_features) - .contiguous() - .to(torch.bfloat16) - ) - k_bf16 = ( - deq[:, qg_rows : qg_rows + kg_rows, :] - .reshape(num_kv * kg_rows, in_features) - .contiguous() - .to(torch.bfloat16) - ) - v_bf16 = ( - deq[:, qg_rows + kg_rows :, :] - .reshape(num_kv * vg_rows, in_features) - .contiguous() - .to(torch.bfloat16) - ) - - for key in list(layer_sd): - if any( - key.endswith(f".{p}.weight") or key.endswith(f".{p}.scale") - for p in ("q_proj", "k_proj", "v_proj") - ): - del layer_sd[key] - - out_prefix = f"layers.{li}.self_attn" - layer_sd[f"{out_prefix}.q_proj.weight"] = q_bf16 - layer_sd[f"{out_prefix}.k_proj.weight"] = k_bf16 - layer_sd[f"{out_prefix}.v_proj.weight"] = v_bf16 - - save_file(layer_sd, layer_file) - dt = time.time() - t0 - if li % 5 == 0 or li == num_hidden_layers - 1: - print( - f" layer {li:2d} [{'swa' if is_swa else 'full'}]: " - f"q{list(q_bf16.shape)} k{list(k_bf16.shape)} v{list(v_bf16.shape)} " - f"elapsed={dt:.1f}s", - flush=True, - ) - - print("\nRewrite weight_map to reflect dtype change.", flush=True) - idx_path = os.path.join(neuron, "model.safetensors.index.json") - with open(idx_path) as f: - idx = json.load(f) - keys_to_drop = [ - k - for k in idx["weight_map"] - if any(k.endswith(f".{p}.scale") for p in ("q_proj", "k_proj", "v_proj")) - ] - for k in keys_to_drop: - idx["weight_map"].pop(k, 
None) - - total = 0 - for f_path in sorted(glob.glob(os.path.join(neuron, "*.safetensors"))): - with safe_open(f_path, framework="pt") as fp: - for k in fp.keys(): - t = fp.get_slice(k) - shape = list(t.get_shape()) - dt_bytes = {"F32": 4, "F16": 2, "BF16": 2, "F8_E4M3": 1}.get( - t.get_dtype(), 2 - ) - total += dt_bytes * max(1, int(math.prod(shape))) - idx["metadata"] = idx.get("metadata", {}) - idx["metadata"]["total_size"] = total - with open(idx_path, "w") as f: - json.dump(idx, f, indent=2) - print(f" dropped {len(keys_to_drop)} scale entries from index", flush=True) - print(f" total_size now {total / 1e9:.2f} GB", flush=True) - print(f"\nDone in {time.time() - t0:.1f}s", flush=True) - - -if __name__ == "__main__": - main() From 62fea531bd500a3624aa4df02fa651a85b23cded Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 29 Apr 2026 14:14:44 +0800 Subject: [PATCH 16/24] [contrib] MiMo-V2.5-Pro: standardize on pytorch_inference_vllm_0_16 venv Previously preprocess / smoke / vLLM-serving used three different venvs depending on which stage of the port we were in; both 2_9_nxd and inference_vllm_0_16 happen to have working NxDI + torch installs, so everything ran but the split was noise. Pick one and stick with it. pytorch_inference_vllm_0_16 is the right choice because: - 0_setup.sh installs vllm-neuron (editable) there, so vllm serving has no alternative. - NxDI direct calls from smoke_compile / smoke_generate also work there (nxdi is preinstalled by the DLAMI in both venvs). - Keeping one venv means no confusion about which python to invoke. Files updated: 0_setup.sh, run_bench_single.sh, smoke_compile_mimo_v2.py and smoke_generate_mimo_v2.py docstrings, run_preprocess_parallel.sh, README Prerequisites. Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2.5-Pro/README.md | 4 ++-- contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh | 2 +- contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh | 2 +- .../models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py | 5 ++--- .../models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py | 2 +- .../src/conversion_script/run_preprocess_parallel.sh | 2 +- 6 files changed, 8 insertions(+), 9 deletions(-) diff --git a/contrib/models/MiMo-V2.5-Pro/README.md b/contrib/models/MiMo-V2.5-Pro/README.md index 4788275c..55a846f3 100644 --- a/contrib/models/MiMo-V2.5-Pro/README.md +++ b/contrib/models/MiMo-V2.5-Pro/README.md @@ -83,7 +83,7 @@ Key features: - **Instance**: trn2.48xlarge (128 physical NeuronCores, logical_nc_config=2 → 64 logical cores) - **Neuron SDK**: 2.29 (Python 3.12, PyTorch 2.9) -- **Venvs**: `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference` (for preprocess + NxDI direct smoke), `/opt/aws_neuronx_venv_pytorch_inference_vllm_0_16` (for vLLM serving). Both ship with the DLAMI. +- **Venv**: `/opt/aws_neuronx_venv_pytorch_inference_vllm_0_16` (used by preprocess, smoke, and vLLM serving alike; ships with the DLAMI and is where `0_setup.sh` installs the patched `vllm-neuron`). - **Disk**: ~3 TB free under `/opt/dlami/nvme` (the HF FP8 checkpoint is ~962 GB, the Neuron-FP8 preprocessed output is ~1 TB, and `save_sharded_checkpoint=true` writes another ~300-1000 GB per compiled config (varies with recipe)). ### NVMe mount @@ -124,7 +124,7 @@ huggingface-cli download XiaomiMiMo/MiMo-V2.5-Pro \ # 3. Preprocess HF FP8 -> Neuron-FP8 (BF16 attn, FP8 MoE). ~20 min, ~24 GB # peak RAM. The preprocess dequants q/k/v to BF16 in one pass — see # "Checkpoint Preparation" below for why BF16 attn is the only recipe. 
-source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate python contrib/models/MiMo-V2.5-Pro/src/conversion_script/preprocess_mimo_v2_fp8.py \ --hf_model_path /opt/dlami/nvme/models/MiMo-V2.5-Pro \ --save_path /opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8 \ diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh index bdcbe11c..eb0ee005 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/0_setup.sh @@ -12,7 +12,7 @@ echo "==========================================" echo "Setup: vllm-neuron + MiMo-V2.5-Pro weights" echo "==========================================" -source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate PATCH_FILE="$(cd "$(dirname "$0")" && pwd)/vllm-neuron-patch.patch" diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh index 2e4e4a3e..318ef477 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh @@ -26,7 +26,7 @@ set -e -source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8}" PORT="${PORT:-8000}" diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py index 73b8abbe..0f992dab 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py @@ -16,9 +16,8 @@ forward pass that allocates the shared scratchpad — useful when HBM is tight. -Run under /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference (NxDI direct). -The `/opt/aws_neuronx_venv_pytorch_inference_vllm_0_16` venv is only -needed for vllm serving (bench_mimo_v2.sh). +Run under /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 (same venv +as vllm serving; both NxDI direct and vllm-neuron are installed there). """ import os diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py index 4c1a3bd2..1b76f4a4 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py @@ -7,7 +7,7 @@ single prompt via HuggingFaceGenerationAdapter. Purpose: sanity-check that the FP8 MoE + preprocessed scales actually produce coherent tokens. -Run under /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference. +Run under /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16. 
""" import os diff --git a/contrib/models/MiMo-V2.5-Pro/src/conversion_script/run_preprocess_parallel.sh b/contrib/models/MiMo-V2.5-Pro/src/conversion_script/run_preprocess_parallel.sh index 5001ecd6..623946c0 100755 --- a/contrib/models/MiMo-V2.5-Pro/src/conversion_script/run_preprocess_parallel.sh +++ b/contrib/models/MiMo-V2.5-Pro/src/conversion_script/run_preprocess_parallel.sh @@ -15,7 +15,7 @@ # VENV venv with torch + safetensors + contrib pkg on sys.path set -e -VENV=${VENV:-/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference} +VENV=${VENV:-/opt/aws_neuronx_venv_pytorch_inference_vllm_0_16} source "$VENV/bin/activate" HF_MODEL_PATH=${HF_MODEL_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro} From e9ae0949a5c9de8305ab67fc07b8687ea8b77136 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 29 Apr 2026 14:34:16 +0800 Subject: [PATCH 17/24] [contrib] MiMo-V2.5-Pro: trim Status, reframe perf numbers as historical Status had a 3-way split (all-FP8 vs BF16-attn vs preprocess emits BF16) that made sense during the diagnosis but doesn't once BF16-attn is the only shipping recipe. Collapse it into four focused subsections: * Why BF16 attn + FP8 MoE * Cost and constraints (HBM, seq_len=256, BS>=48, EP constraints) * Recipes tried that did not work (all-FP8, use_torch_block_wise) * Next experiments queued Performance: reframe the vLLM throughput table as a historical all-FP8 capture kept for infra validation and order-of-magnitude reference. The shipping recipe (BF16 attn + seq_len=256) hasn't been re-benchmarked yet; note the expected delta (only q/k/v change, MoE unchanged) so readers can project. vLLM Serving note: since the shipped start_vllm_server.sh still has seq_len=1024 and doesn't list q/k/v in modules_to_not_convert, spell out exactly what to change if the BF16-attn checkpoint OOMs on load. Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2.5-Pro/README.md | 49 ++++++++++---------------- 1 file changed, 18 insertions(+), 31 deletions(-) diff --git a/contrib/models/MiMo-V2.5-Pro/README.md b/contrib/models/MiMo-V2.5-Pro/README.md index 55a846f3..9f77f799 100644 --- a/contrib/models/MiMo-V2.5-Pro/README.md +++ b/contrib/models/MiMo-V2.5-Pro/README.md @@ -37,48 +37,35 @@ Key features: ## Status (work-in-progress) -**This port compiles cleanly and serves via vLLM on Trn2. Under the default all-FP8 recipe the output drifts on most prompts; a BF16-attn recipe (keep MoE FP8, dequant q/k/v_proj to BF16, compile at seq_len=256) restores coherent output and isolates the root cause.** Findings as of 2026-04-29: +**This port compiles cleanly and serves via vLLM on Trn2. The shipping recipe is BF16 attention + FP8 MoE at `seq_len=256` — verified to produce coherent output end-to-end via `smoke_generate_mimo_v2.py`.** Last updated 2026-04-29. -### All-FP8 recipe (`tp_degree=64, moe_tp=1, moe_ep=64, BS=48, seq_len=1024`) +### Why BF16 attn + FP8 MoE -- **Prompt-dependent drift.** Self-intro prompts sometimes return coherent answers (sampling gets lucky on a strong self-identifying logit); most other prompts collapse to repetition or unrelated text within a few tokens (e.g. `"The capital of France is\n# 1000000000000000"`, `"Once upon a time in a small village there lived\n# 0000000000..."`, chat continuations that drift into RLHF-style self-reflection with Chinese/Thai). -- **Not a sampling artifact.** `temperature` in the request is ignored — vllm-neuron's on-device sampling config (`do_sample=true, T=0.6, top_k=20, top_p=0.95`) is baked into the NEFF at compile time. 
Output is always stochastic, but the underlying logits are already drifted. -- Same failure pattern was observed on MiMo-V2-Pro under the same recipe (`"0.0.0.0:8080"`, etc.). V2.5-Pro's MoE expert weights are byte-identical to V2-Pro (verified layer 1 expert 0 dequant stats match to 6 decimal places on 2026-04-28). +Pro's attention weights have `abs_mean ≈ 0.00124`, roughly 4× smaller than V2.5 (256 experts). Under an all-FP8 recipe, the NKI blockwise FP8 accumulator on attention q/k/v at this magnitude drifts the logits across 70 layers and produces prompt-dependent gibberish (`"The capital of France is\n# 1000000000000000"`, `"Once upon a time in a small village there lived\n# 0000000000..."`, etc.). Dequantizing q/k/v to BF16 before the matmul restores coherent output. MoE experts (scales `≈ 2.3e-5`, similarly small) can stay FP8. -### BF16-attn recipe (`q_proj/k_proj/v_proj` dequant to BF16, MoE kept FP8, `seq_len=256`) +Verified end-to-end: `smoke_generate_mimo_v2.py` with a minimal chat template returns a well-formed reasoning trace that correctly identifies the model ("As MiMo, based on Xiaomi's self-developed large model..."). `preprocess_mimo_v2_fp8.py` emits BF16 q/k/v directly so no separate step is required. -- **Output is coherent.** On the same self-intro prompt, smoke_generate with a minimal chat template produces a well-formed reasoning trace that correctly identifies the model: - ``` - Okay, the user is asking for a simple self-introduction in one sentence, - with no deeper or hidden needs apparent. As MiMo, based on Xiaomi's self-developed - large model, I need to respond in a friendly, positive, and helpful way that aligns - with providing assistance ... - ``` -- **This narrows the root cause to attention-path FP8 precision**, not the MoE experts. Pro's attention weights have `abs_mean ≈ 0.00124`, roughly 4× smaller than V2.5 (256 experts). The NKI blockwise FP8 accumulator on attention q/k/v at this magnitude loses enough precision to drift the logits across 70 layers; dequantizing q/k/v to BF16 before the matmul restores correct output. MoE experts (scales `≈ 2.3e-5`, also small) can stay FP8 under this recipe. -- **Cost: HBM headroom and seq_len.** BF16 q/k/v adds ~2 GB per rank. At `seq_len=1024` this OOMs on load (previous attempt failed allocating 41 MB for rdh/alltoall rings). `seq_len=256` frees enough full-attention softmax scratch to fit; longer context needs a different HBM plan (cross-instance TP/PP, or larger instance). -- GPU stacks (sglang on H100/H200) run the exact same OCP FP8 checkpoint correctly because they always dequantize FP8 → BF16 before the matmul. The issue is specific to Neuron's direct-FP8 compute path on subnormal-leaning tensors. -- Reference: Jim Burtoft observed similar prompt-dependent FP8 degradation on Flash and his Kimi PR #131 names "blockwise kernel padding produces depressed logits with EP=2 on SDK 2.29; SDK 2.28 recommended". +GPU stacks (sglang on H100/H200) run the same OCP FP8 checkpoint correctly because they always dequantize FP8 → BF16 before the matmul. The issue is specific to Neuron's direct-FP8 compute path on small-magnitude tensors. Kimi PR #131 observes similar FP8 degradation on Flash and recommends SDK 2.28. -### Preprocess emits BF16 q/k/v +### Cost and constraints -`src/conversion_script/preprocess_mimo_v2_fp8.py` now dequants q/k/v to BF16 directly (`split_qkv_fused` for the fused Pro layout, `_dequant_attn_to_bf16` for the Flash-style split layout). 
Output checkpoint has no `q_proj.scale` / `k_proj.scale` / `v_proj.scale` entries. Compile-time `modules_to_not_convert` must include `q_proj`, `k_proj`, `v_proj` so NxDI routes them through a plain `ColumnParallelLinear`; `smoke_compile_mimo_v2.py` and `start_vllm_server.sh` already do. +- **HBM headroom.** BF16 q/k/v adds ~2 GB per rank. `seq_len=1024` OOMs on load (the previous attempt failed allocating ~40 MB for rdh/alltoall rings after per-rank tensors already reached 20.9/24 GB). `seq_len=256` frees enough full-attention softmax scratch to fit. +- **Short context.** At `seq_len=256`, Pro's own chat template with the default system prompt is already 260 tokens. Longer context needs a different HBM plan (cross-instance TP/PP, or a larger instance). +- `BS * top_k / num_experts >= 1.0` required when `moe_ep_degree > 1` at decode (else `NotImplementedError`). With `num_experts=384, top_k=8` this forces `BS >= 48`. +- `n_routed_experts=384 = 2^7 × 3` → `384 / ep_degree` is never a power of 2 (6, 12, 24, 48, 96, 192, 384). Kimi PR #131 says NKI `_bwmm_shard_on_block_nki_call` on SDK 2.29 has "depressed logits with EP=2" and recommends SDK 2.28. -### Recipes that were tried and did not resolve the drift (all on 2026-04-28/29) +### Recipes tried that did not work -- **`use_torch_block_wise=True`** (PyTorch-fallback blockwise matmul for higher accumulator precision): compile+shard succeeded after ~2 h, but `model.load()` crashed with `status=4 Allocation Failure` — the fallback path raises HBM demand even with MoE-only scope. +- **All-FP8 attention (`modules_to_not_convert` without q/k/v).** Drifts as described above. Known broken; `preprocess_mimo_v2_fp8.py` no longer emits it. +- **`use_torch_block_wise=True`** (PyTorch-fallback blockwise matmul for higher accumulator precision): compile+shard succeeded after ~2 h, but `model.load()` crashed with `status=4 Allocation Failure` — the fallback path raises HBM demand even when scoped to MoE. ### Next experiments queued -- **BF16-attn + larger `seq_len`.** `seq_len=256` is tight; Pro's chat template with the default system prompt is already 260 tokens. Either shrink the system prompt, or try BF16-attn at `seq_len=512` on a less-full HBM plan (e.g. MoE-EP with fewer experts per rank, or drop batch size below 48 via different recipe trade-offs). +- **BF16-attn at `seq_len=512`** (needs a tighter HBM plan — smaller batch, different EP ratio, or shrinking the default system prompt). - **Cross-instance BF16** via pipeline/tensor parallelism on 2× Trn2 (single-instance HBM cannot hold full BF16 Pro). - **Selective BF16 only on MoE `gate_up_proj`** (smallest expert scales) while keeping `down_proj` FP8 — another axis to probe if attn drift returns at longer contexts. - **SDK 2.28 venv** test once installed, per Kimi PR #131. -### Known NxDI limits that constrain recipe choice - -- `BS * top_k / num_experts >= 1.0` required when `moe_ep_degree > 1` at decode (else `NotImplementedError`). With `num_experts=384, top_k=8` this forces `BS >= 48`. -- `n_routed_experts=384 = 2^7 × 3` → `384 / ep_degree` is never a power of 2 (6, 12, 24, 48, 96, 192, 384). Kimi PR #131 says NKI `_bwmm_shard_on_block_nki_call` on SDK 2.29 has "depressed logits with EP=2" and recommends SDK 2.28. 
- ## Prerequisites - **Instance**: trn2.48xlarge (128 physical NeuronCores, logical_nc_config=2 → 64 logical cores) @@ -369,7 +356,7 @@ bash contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh See "Environment variables" above for all the knobs (`NEURON_COMPILED_ARTIFACTS`, `BASE_COMPILE_WORK_DIR`, etc.) and their defaults. -> **Note on output quality:** the shipped vLLM scripts use the **all-FP8 recipe** (`seq_len=1024`, attn+MoE both FP8). This currently produces prompt-dependent drift — see Status. The **BF16-attn recipe** that restores coherent output has so far only been validated end-to-end via `smoke_generate_mimo_v2.py` (direct NxDI, `seq_len=256`); porting it into `start_vllm_server.sh` requires also adding `q_proj/k_proj/v_proj` to the `modules_to_not_convert` list and dropping `seq_len` / `max_model_len` to 256, and the perf numbers below have not been re-measured in that configuration. +> **Note on the shipped vLLM scripts:** the current `start_vllm_server.sh` still uses `seq_len=1024` and does not list `q_proj/k_proj/v_proj` in `modules_to_not_convert`. Coupled with a BF16-attn preprocessed checkpoint this runs correctly (NxDI just sees BF16 tensors where it expected FP8 and casts them as-is) but at a longer context than the BF16-attn recipe has been HBM-validated for. If you hit a `status=4 Allocation Failure` on load, drop `seq_len` / `max_model_len` / `context_encoding_buckets` / `token_generation_buckets` to 256 and add `q_proj/k_proj/v_proj` to `modules_to_not_convert` to match the smoke-verified configuration. The bench numbers below were taken on the older all-FP8 checkpoint and have not been re-measured since the preprocess switched to BF16 attn. ### vllm-neuron patch summary @@ -381,9 +368,9 @@ The patch is applied to vllm-neuron 0.5.0 and: ## Performance -> The throughput numbers below are from a working vLLM server run on 2026-04-29 under the recommended FP8 recipe. Output quality under this recipe is **not production-usable** (see Status); the numbers show that the serving infrastructure runs end-to-end, not that the model answers correctly. +> The throughput numbers below were captured on 2026-04-29 against a pre-BF16-attn checkpoint (all-FP8, `seq_len=1024`). They are historical — the shipping recipe is BF16 attn + FP8 MoE at `seq_len=256` and has not yet been re-benchmarked. The numbers are kept here for infra validation (continuous batching + bucketing + on-device sampling + vllm-neuron plugin all wire up end-to-end) and order-of-magnitude reference. -### vLLM Serving (trn2.48xlarge, all-FP8 recipe, BS=48, TP=64, moe_tp=1/moe_ep=64, CB + bucketing, `seq_len=1024`) +### vLLM Serving (trn2.48xlarge, historical all-FP8 run, BS=48, TP=64, moe_tp=1/moe_ep=64, CB + bucketing, `seq_len=1024`) Input/output: 900/90 tokens (`vllm bench serve --dataset-name random`), `on_device_sampling_config={do_sample:true, temperature:0.6, top_k:20, top_p:0.95}`. @@ -395,7 +382,7 @@ Input/output: 900/90 tokens (`vllm bench serve --dataset-name random`), `on_devi Per-stream ITL median holds at ~220 ms across all concurrency levels; TPOT/TTFT growth at higher concurrency comes from continuous-batching queue pressure, not per-step compute. -> **Numbers are from the all-FP8 recipe, which produces drifted output (see Status).** The BF16-attn recipe that restores coherent output has not yet been re-benchmarked on vLLM; throughput should be comparable (only q/k/v go BF16, MoE stays FP8) but at `seq_len=256` instead of 1024, so TTFT/latency characteristics will differ. 
+> Expected BF16-attn delta: only q/k/v go from FP8 to BF16 (MoE is unchanged), so steady-state throughput should be within a few percent. TTFT should drop proportionally with `seq_len` (256 vs 1024 prefill tokens). > **Compile time:** the first Pro compile on SDK 2.29 is ~60 minutes for the TKG NEFF and ~15 minutes for the CTE NEFF; subsequent runs with the same `override_neuron_config` hit the neuronx-cc cache and start in ~1-2 minutes. `save_sharded_checkpoint=true` additionally persists per-rank FP8 shards under `/weights/`, letting future `load()` calls skip the ~10-minute shard_checkpoint pass. First full server launch (compile + shard + warmup) is ~2 hours wall-clock. From d1ea946867385e6e4037a02892e16d7576707b35 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 29 Apr 2026 14:36:31 +0800 Subject: [PATCH 18/24] [contrib] MiMo-V2.5-Pro: fix maintainer name typo --- contrib/models/MiMo-V2.5-Pro/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contrib/models/MiMo-V2.5-Pro/README.md b/contrib/models/MiMo-V2.5-Pro/README.md index 9f77f799..c28aa4cb 100644 --- a/contrib/models/MiMo-V2.5-Pro/README.md +++ b/contrib/models/MiMo-V2.5-Pro/README.md @@ -417,6 +417,6 @@ pytest contrib/models/MiMo-V2.5-Pro/test/integration/test_model.py -v ## Maintainer -Henan Wan (whn09) +Henan Wang (whn09) **Last Updated:** 2026-04-29 From ad81f3f3c7ea3ac2211dca0a833525b16a3ed1a7 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 29 Apr 2026 15:26:06 +0800 Subject: [PATCH 19/24] [contrib] MiMo-V2.5-Pro: align start_vllm_server.sh with BF16-attn ckpt The preprocess now emits BF16 q/k/v (no .scale entries), so vllm-neuron must route attention through the non-quantized ColumnParallelLinear. Three required changes: - Add q_proj/k_proj/v_proj to modules_to_not_convert. Without this, NxDI tries to load q_proj.scale and bails with "Cannot find layers.0.self_attn.q_proj.scale in state_dict". - Drop seq_len / max_model_len / context_encoding_buckets / token_generation_buckets from 1024 to 256. BF16 q/k/v adds ~2 GB per rank and seq_len=1024 OOMs on load; seq_len=256 is the smoke-verified upper bound. - Move NEURON_COMPILED_ARTIFACTS default to a new path (mimo_v2_5_pro_bs48_moetp1_ep64_bf16attn_seq256_vllm) so it doesn't collide with the old all-FP8 compile dir that's been S3-backed up. Note for longer context: seq_len is the single biggest HBM constraint on this recipe; raising it will require either a smaller batch, a different EP ratio, or cross-instance sharding (see README "Next experiments queued"). Co-Authored-By: Claude Opus 4.7 --- .../MiMo-V2.5-Pro/perf_test/start_vllm_server.sh | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh index 6848b604..04c1c5a3 100644 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh @@ -31,7 +31,7 @@ export NXDI_CONTRIB_MIMO_V2_FLASH_SRC # Persistent compile-artifact location (NEFF + per-rank sharded weights). # Setting this overrides vLLM's fallback of # /neuron-compiled-artifacts//. 
-: "${NEURON_COMPILED_ARTIFACTS:=/opt/dlami/nvme/compiled/mimo_v2_5_pro_bs48_moetp1_ep64_fp8_vllm}" +: "${NEURON_COMPILED_ARTIFACTS:=/opt/dlami/nvme/compiled/mimo_v2_5_pro_bs48_moetp1_ep64_bf16attn_seq256_vllm}" export NEURON_COMPILED_ARTIFACTS # NxDI HLO/NEFF staging directory, pinned to persistent storage so it # survives the nightly Trn2 reboot and a unique per-config subdir. @@ -57,7 +57,7 @@ exec python3 -m vllm.entrypoints.openai.api_server \ --model "$MODEL_PATH" \ --tokenizer "$MODEL_PATH" \ --tensor-parallel-size 64 \ - --max-model-len 1024 \ + --max-model-len 256 \ --max-num-seqs 48 \ --no-enable-chunked-prefill \ --no-enable-prefix-caching \ @@ -79,19 +79,19 @@ exec python3 -m vllm.entrypoints.openai.api_server \ "quantization_type": "blockwise_symmetric", "quantization_block_axis": [1, 2], "quantization_block_size": [128, 128], - "modules_to_not_convert": ["embed_tokens", "lm_head", "norm", "router", "o_proj"], + "modules_to_not_convert": ["embed_tokens", "lm_head", "norm", "router", "o_proj", "q_proj", "k_proj", "v_proj"], "blockwise_matmul_config": {"use_shard_on_block_dynamic_while": true, "block_sharding_strategy": "PING_PONG"}, "moe_tp_degree": 1, "moe_ep_degree": 64, "batch_size": 48, "ctx_batch_size": 1, "tkg_batch_size": 48, - "max_context_length": 1024, - "seq_len": 1024, + "max_context_length": 256, + "seq_len": 256, "is_continuous_batching": true, "enable_bucketing": true, - "context_encoding_buckets": [1024], - "token_generation_buckets": [1024], + "context_encoding_buckets": [256], + "token_generation_buckets": [256], "async_mode": true, "on_device_sampling_config": { "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 From 5a297cba03fb930539d6ef628e79feecb1877a88 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 29 Apr 2026 17:47:53 +0800 Subject: [PATCH 20/24] [contrib] MiMo-V2.5-Pro: fit sanity + bench within seq_len=256 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The start_vllm_server.sh now compiles with seq_len=256 (BF16-attn HBM constraint). Pro's default chat template prepends a ~240-token system prompt that by itself busts the bucket, and the old bench default (input 900, output 90) is also way over. sanity_check.sh: - Switch from /v1/chat/completions to /v1/completions with a hand-rolled <|im_start|>user... <|im_end|><|im_start|>assistant frame that tokenises to ~17 tokens. - Do the HTTP POST from python (bash heredoc mangles the \n inside the chat template, which used to make the model emit a garbage first token — UTF-8 replacement char "?" at the start of every reply). - Note in-comment that request-side temperature / top_k / top_p are ignored; the NEFF's on_device_sampling_config wins. run_bench_single.sh: - Default INPUT_LEN 900 -> 180, OUTPUT_LEN 90 -> 60 (180+60 = 240, fits under seq_len=256 with a small margin for random-range-ratio). - Comment explains the seq_len=256 constraint. bench_mimo_v2.sh is unchanged; it delegates length knobs to run_bench_single.sh. 
Co-Authored-By: Claude Opus 4.7 --- .../perf_test/run_bench_single.sh | 10 +- .../MiMo-V2.5-Pro/perf_test/sanity_check.sh | 94 +++++++++++-------- 2 files changed, 63 insertions(+), 41 deletions(-) diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh index 318ef477..0ec5e90e 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/run_bench_single.sh @@ -17,8 +17,8 @@ # /opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8) # CONCURRENCY --max-concurrency (default 1) # NUM_PROMPTS --num-prompts (default 16) -# INPUT_LEN --random-input-len (default 900) -# OUTPUT_LEN --random-output-len (default 90) +# INPUT_LEN --random-input-len (default 180; matches seq_len=256) +# OUTPUT_LEN --random-output-len (default 60; matches seq_len=256) # RANGE_RATIO --random-range-ratio (default 0.03) # CONFIG_NAME Used in the output filename (default bs48_tp64_moetp1_ep64) # RESULTS_DIR Where to dump per-run log @@ -32,9 +32,11 @@ MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP PORT="${PORT:-8000}" CONCURRENCY="${CONCURRENCY:-1}" NUM_PROMPTS="${NUM_PROMPTS:-16}" -INPUT_LEN="${INPUT_LEN:-900}" -OUTPUT_LEN="${OUTPUT_LEN:-90}" +INPUT_LEN="${INPUT_LEN:-180}" +OUTPUT_LEN="${OUTPUT_LEN:-60}" RANGE_RATIO="${RANGE_RATIO:-0.03}" +# seq_len=256 on the compiled server, so input+output must stay under 256. +# Default 180+60=240 leaves a small margin for random-range-ratio expansion. CONFIG_NAME="${CONFIG_NAME:-bs48_tp64_moetp1_ep64}" RESULTS_DIR="${RESULTS_DIR:-/opt/dlami/nvme/logs/bench_results/mimo_v2_5_pro}" diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh index ed2f6b76..a0ed893c 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh @@ -1,25 +1,28 @@ #!/bin/bash # Quick sanity check against an already-running vLLM server. # -# Assumes vLLM is already listening on $PORT (default 8000) with MiMo-V2.5-Pro -# loaded. Sends a single chat completion and prints the model's reply. +# Posts a minimally-templated chat request to /v1/completions and prints the +# model's reply. We go through /v1/completions (not /v1/chat/completions) +# because Pro's default chat template prepends a ~240-token system prompt +# that by itself overflows the seq_len=256 compile-time bucket; building the +# im_start/im_end/assistant frame by hand keeps the prompt under ~30 tokens +# and fits cleanly. # # Usage: # bash sanity_check.sh # uses defaults # PORT=8001 bash sanity_check.sh # custom port -# PROMPT="..." bash sanity_check.sh # custom prompt +# PROMPT="..." bash sanity_check.sh # custom user content set -e MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8}" PORT="${PORT:-8000}" -# "Introduce yourself" is a high-signal self-identification prompt that the -# FP8 path answers coherently even under current MoE drift (see README -# Status). Swap PROMPT=... if you want to probe other prompts. +# "Introduce yourself" is the self-identification prompt that consistently +# lands in the model's MiMo-aware region. Swap PROMPT=... to probe others. PROMPT="${PROMPT:-Hello! 
Please introduce yourself in one sentence.}" -MAX_TOKENS="${MAX_TOKENS:-64}" +MAX_TOKENS="${MAX_TOKENS:-80}" -echo "Sanity check: POST /v1/chat/completions on port $PORT" +echo "Sanity check: POST /v1/completions on port $PORT" echo " Model: $MODEL_PATH" echo " Prompt: $PROMPT" echo " Max tokens: $MAX_TOKENS" @@ -28,38 +31,55 @@ echo "" # Health check first — fail fast if server isn't up. if ! curl -sf "http://localhost:$PORT/health" > /dev/null; then echo "ERROR: vLLM server is not responding on http://localhost:$PORT" - echo "Start it with 'bash bench_mimo_v2.sh' or your own launcher first." + echo "Start it with 'bash start_vllm_server.sh' (or bench_mimo_v2.sh)" + echo "first and wait for 'Application startup complete.'" exit 1 fi -# NOTE: request-side `temperature` is ignored by vllm-neuron on this model: -# on-device sampling_config (set at compile time in start_vllm_server.sh as -# do_sample=true, T=0.6, top_k=20, top_p=0.95) is baked into the NEFF and -# request params don't override it. Output will be stochastic. -RESPONSE=$(curl -s "http://localhost:$PORT/v1/chat/completions" \ - -H 'Content-Type: application/json' \ - -d "$(cat </dev/null || echo "$RESPONSE" -echo "" +# NOTE: request-side `temperature` / `top_k` / `top_p` are ignored by +# vllm-neuron on this model: the on_device_sampling_config baked into the +# NEFF at compile time wins. Output is always stochastic; re-run to see +# variance. +# +# Build the chat framing in python so newlines and special tokens survive +# JSON encoding without shell escape pitfalls, then POST to /v1/completions. +python3 <user\n" + + user + + "<|im_end|>\n<|im_start|>assistant\n" +) +body = json.dumps({ + "model": model, + "prompt": prompt, + "max_tokens": int("$MAX_TOKENS"), + "stream": False, +}).encode() +req = urllib.request.Request( + "http://localhost:$PORT/v1/completions", + data=body, + headers={"Content-Type": "application/json"}, +) try: - r = json.load(sys.stdin) - print(r['choices'][0]['message']['content'].strip()) -except Exception as e: - print(f'(could not parse reply: {e})') -" 2>/dev/null) + with urllib.request.urlopen(req, timeout=120) as r: + resp = json.load(r) +except urllib.error.HTTPError as e: + print("HTTP error:", e.code, e.read().decode(errors="replace")) + sys.exit(1) + +if "error" in resp: + print("Error from server:", json.dumps(resp["error"], indent=2)) + sys.exit(1) -echo "Model reply: $REPLY" +text = resp["choices"][0]["text"] +print("Response:") +print(text) +print() +print("Usage:", resp.get("usage", {})) +PYEOF From 965a947839bf626d22a17aef3bb2f23c23c91de6 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 29 Apr 2026 17:53:26 +0800 Subject: [PATCH 21/24] =?UTF-8?q?[contrib]=20MiMo-V2.5-Pro:=20sanity=5Fche?= =?UTF-8?q?ck.sh=20=E2=80=94=20short=20system=20prompt,=20/v1/chat?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Switch back to /v1/chat/completions with an explicit short system message ("You are MiMo, a helpful assistant..."). apply_chat_template then uses our system turn instead of Pro's ~240-token default, and the prompt comes out to ~25 tokens — well under seq_len=256. This is simpler than the /v1/completions + manually-framed-chat route (no shell-escape \n landmines, native OpenAI API shape) and composes cleanly with other chat clients that assume /v1/chat. Override via SYSTEM=... / PROMPT=... / MAX_TOKENS=... env vars. 
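
For the record, the request the reworked script now sends is equivalent to
roughly the following curl (model path, port, system/user text, and max_tokens
are the script defaults; the reply is sampled on-device, so wording varies per
run):

```bash
curl -s http://localhost:8000/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{
    "model": "/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8",
    "messages": [
      {"role": "system", "content": "You are MiMo, a helpful assistant developed by Xiaomi."},
      {"role": "user", "content": "Hello! Please introduce yourself in one sentence."}
    ],
    "max_tokens": 80
  }'
```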
Co-Authored-By: Claude Opus 4.7
---
 .../MiMo-V2.5-Pro/perf_test/sanity_check.sh   | 43 ++++++++++---------
 1 file changed, 23 insertions(+), 20 deletions(-)

diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh
index a0ed893c..a80e85c1 100755
--- a/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh
+++ b/contrib/models/MiMo-V2.5-Pro/perf_test/sanity_check.sh
@@ -1,29 +1,34 @@
 #!/bin/bash
 # Quick sanity check against an already-running vLLM server.
 #
-# Posts a minimally-templated chat request to /v1/completions and prints the
-# model's reply. We go through /v1/completions (not /v1/chat/completions)
-# because Pro's default chat template prepends a ~240-token system prompt
-# that by itself overflows the seq_len=256 compile-time bucket; building the
-# im_start/im_end/assistant frame by hand keeps the prompt under ~30 tokens
-# and fits cleanly.
+# Posts a chat request to /v1/chat/completions and prints the reply.
+#
+# Pro's default chat template prepends a ~240-token system prompt that by
+# itself overflows the seq_len=256 compile-time bucket, so we send an
+# explicit short system message — apply_chat_template then uses ours
+# instead of the default and the whole prompt fits in ~25 tokens.
 #
 # Usage:
 #   bash sanity_check.sh               # uses defaults
 #   PORT=8001 bash sanity_check.sh     # custom port
 #   PROMPT="..." bash sanity_check.sh  # custom user content
+#   SYSTEM="..." bash sanity_check.sh  # custom system message

 set -e

 MODEL_PATH="${MIMO_V2_FLASH_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Pro-Neuron-FP8}"
 PORT="${PORT:-8000}"
+# Short system message (keeps total prompt ~25 tokens) — the checkpoint's
+# default system prompt is ~240 tokens and would overflow seq_len=256.
+SYSTEM="${SYSTEM:-You are MiMo, a helpful assistant developed by Xiaomi.}"
 # "Introduce yourself" is the self-identification prompt that consistently
 # lands in the model's MiMo-aware region. Swap PROMPT=... to probe others.
 PROMPT="${PROMPT:-Hello! Please introduce yourself in one sentence.}"
 MAX_TOKENS="${MAX_TOKENS:-80}"

-echo "Sanity check: POST /v1/completions on port $PORT"
+echo "Sanity check: POST /v1/chat/completions on port $PORT"
 echo "  Model: $MODEL_PATH"
+echo "  System: $SYSTEM"
 echo "  Prompt: $PROMPT"
 echo "  Max tokens: $MAX_TOKENS"
 echo ""
@@ -39,30 +44,28 @@ fi
 # NOTE: request-side `temperature` / `top_k` / `top_p` are ignored by
 # vllm-neuron on this model: the on_device_sampling_config baked into the
 # NEFF at compile time wins. Output is always stochastic; re-run to see
-# variance.
-#
-# Build the chat framing in python so newlines and special tokens survive
-# JSON encoding without shell escape pitfalls, then POST to /v1/completions.
+# variance, or restart the server with `do_sample=false` in
+# start_vllm_server.sh to force deterministic greedy decoding.
python3 <user\n" - + user - + "<|im_end|>\n<|im_start|>assistant\n" -) body = json.dumps({ "model": model, - "prompt": prompt, + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], "max_tokens": int("$MAX_TOKENS"), "stream": False, }).encode() req = urllib.request.Request( - "http://localhost:$PORT/v1/completions", + "http://localhost:$PORT/v1/chat/completions", data=body, headers={"Content-Type": "application/json"}, ) @@ -77,7 +80,7 @@ if "error" in resp: print("Error from server:", json.dumps(resp["error"], indent=2)) sys.exit(1) -text = resp["choices"][0]["text"] +text = resp["choices"][0]["message"]["content"] print("Response:") print(text) print() From 935510aca026c9d58706d79d60cf5b08a50f92c1 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 29 Apr 2026 18:39:32 +0800 Subject: [PATCH 22/24] [contrib] MiMo-V2.5-Pro: default vLLM to smoke NEFF (workaround) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit vllm-neuron's own compile path — -O3, --enable-internal-neff-wrapper, on_device_sampling baked into the NEFF, continuous batching — produces garbled first-decode output on Pro: every reply starts with a UTF-8 replacement char and then coherent but completely off-topic text. V2.5 under the same vllm-neuron compile path works fine, so the trigger is Pro-specific (likely SWA + attention sink bias interacting with one of the compile / runtime options above, root cause not isolated). The NxDI-smoke compile path (-O1, no on-device sampler, static batch, produced by perf_test/smoke_compile_mimo_v2.py) does not hit the problem. vllm-neuron can load that NEFF at runtime and serves coherent chat completions with proper `` traces. As a workaround, default NEURON_COMPILED_ARTIFACTS to the smoke compile dir. Users can still override the env var to point at a vllm-neuron-compiled NEFF for testing. Co-Authored-By: Claude Opus 4.7 --- .../MiMo-V2.5-Pro/perf_test/start_vllm_server.sh | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh index 04c1c5a3..95bd8484 100644 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh @@ -31,7 +31,15 @@ export NXDI_CONTRIB_MIMO_V2_FLASH_SRC # Persistent compile-artifact location (NEFF + per-rank sharded weights). # Setting this overrides vLLM's fallback of # /neuron-compiled-artifacts//. -: "${NEURON_COMPILED_ARTIFACTS:=/opt/dlami/nvme/compiled/mimo_v2_5_pro_bs48_moetp1_ep64_bf16attn_seq256_vllm}" +# Default points at the NxDI-smoke compile artifacts (produced by +# perf_test/smoke_compile_mimo_v2.py) instead of a vllm-neuron-compiled +# NEFF. Empirically, vllm-neuron's own compile path (-O3 + +# --enable-internal-neff-wrapper + on_device_sampling + CB-baked NEFF) +# produces garbled first-decode output on Pro (per-request `?` UTF-8 +# replacement chars followed by off-topic coherent text). The smoke +# compile (-O1, no on-device sampler, static batch) works correctly +# when vLLM loads it at runtime. Root cause TBD; not yet isolated. +: "${NEURON_COMPILED_ARTIFACTS:=/opt/dlami/nvme/compiled/mimo_v2_5_pro_bs48_moetp1_ep64_fp8moe_bf16attn_seq256}" export NEURON_COMPILED_ARTIFACTS # NxDI HLO/NEFF staging directory, pinned to persistent storage so it # survives the nightly Trn2 reboot and a unique per-config subdir. 
From 6da7188c5ffb161b4dc4b23585b02b671bce8e33 Mon Sep 17 00:00:00 2001 From: whn09 Date: Wed, 29 Apr 2026 18:40:20 +0800 Subject: [PATCH 23/24] Revert "[contrib] MiMo-V2.5-Pro: default vLLM to smoke NEFF (workaround)" This reverts commit 935510aca026c9d58706d79d60cf5b08a50f92c1. --- .../MiMo-V2.5-Pro/perf_test/start_vllm_server.sh | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh b/contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh index 95bd8484..04c1c5a3 100644 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh @@ -31,15 +31,7 @@ export NXDI_CONTRIB_MIMO_V2_FLASH_SRC # Persistent compile-artifact location (NEFF + per-rank sharded weights). # Setting this overrides vLLM's fallback of # /neuron-compiled-artifacts//. -# Default points at the NxDI-smoke compile artifacts (produced by -# perf_test/smoke_compile_mimo_v2.py) instead of a vllm-neuron-compiled -# NEFF. Empirically, vllm-neuron's own compile path (-O3 + -# --enable-internal-neff-wrapper + on_device_sampling + CB-baked NEFF) -# produces garbled first-decode output on Pro (per-request `?` UTF-8 -# replacement chars followed by off-topic coherent text). The smoke -# compile (-O1, no on-device sampler, static batch) works correctly -# when vLLM loads it at runtime. Root cause TBD; not yet isolated. -: "${NEURON_COMPILED_ARTIFACTS:=/opt/dlami/nvme/compiled/mimo_v2_5_pro_bs48_moetp1_ep64_fp8moe_bf16attn_seq256}" +: "${NEURON_COMPILED_ARTIFACTS:=/opt/dlami/nvme/compiled/mimo_v2_5_pro_bs48_moetp1_ep64_bf16attn_seq256_vllm}" export NEURON_COMPILED_ARTIFACTS # NxDI HLO/NEFF staging directory, pinned to persistent storage so it # survives the nightly Trn2 reboot and a unique per-config subdir. From af27106b9a49e68a5a0026802e7e559721bb4a07 Mon Sep 17 00:00:00 2001 From: whn09 Date: Thu, 30 Apr 2026 10:39:37 +0800 Subject: [PATCH 24/24] [contrib] MiMo-V2.5-Pro: bump default seq_len 256 -> 512; document vLLM bug seq_len=512 under the BF16-attn recipe was verified end-to-end (compile + shard + load + 5x deterministic greedy generate) via smoke. HBM fits; seq_len=1024 still OOMs. Also documents the vllm-neuron "first request coherent, subsequent requests garbled" bug (tracked upstream at vllm-project/vllm-neuron#31). Every configuration knob we tried (all-FP8 attn, BF16 attn at 256 or 512, CB on/off, on-device sampling on/off, -O3 -> -O1) reproduced the same symptom on Pro but not on V2.5; the same NEFF serves 5 successive greedy generates byte-identically under smoke_generate_mimo_v2.py, so the bug is in vllm-neuron's runtime, not the NEFF. README changes: - Status opener now says the smoke path is verified and the vLLM serving path is blocked on issue #31. - Bump seq_len=256 references to seq_len=512 in HBM/constraints, Usage example, and the MoENeuronConfig code block. - Rewrite the vLLM "Note" callout to point at issue #31 as the single source of truth for the broken state, drop the obsolete "drop to 256" recovery hints. Script changes: - smoke_compile_mimo_v2.py: SEQ_LEN default 256 -> 512; COMPILED_PATH suffix seq256 -> seq512. Comment rewritten. - smoke_generate_mimo_v2.py: matching SEQ_LEN and COMPILED_PATH default changes so a bare `python smoke_generate_mimo_v2.py` picks up the seq_len=512 NEFF. 
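
As a sketch of how to reproduce the smoke-verified path (SEQ_LEN and the
compiled-artifacts path are env vars the smoke scripts already read; the
values shown below are the new defaults, spelled out only for clarity):

```bash
source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate
# Compile + shard at the new default seq_len, then generate against the same NEFF.
SEQ_LEN=512 python contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py
SEQ_LEN=512 python contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py
```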
Co-Authored-By: Claude Opus 4.7 --- contrib/models/MiMo-V2.5-Pro/README.md | 21 +++++++++++-------- .../perf_test/smoke_compile_mimo_v2.py | 11 +++++----- .../perf_test/smoke_generate_mimo_v2.py | 4 ++-- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/contrib/models/MiMo-V2.5-Pro/README.md b/contrib/models/MiMo-V2.5-Pro/README.md index c28aa4cb..1bdf8766 100644 --- a/contrib/models/MiMo-V2.5-Pro/README.md +++ b/contrib/models/MiMo-V2.5-Pro/README.md @@ -37,7 +37,9 @@ Key features: ## Status (work-in-progress) -**This port compiles cleanly and serves via vLLM on Trn2. The shipping recipe is BF16 attention + FP8 MoE at `seq_len=256` — verified to produce coherent output end-to-end via `smoke_generate_mimo_v2.py`.** Last updated 2026-04-29. +**This port compiles cleanly and is verified to produce coherent output end-to-end via the NxDI direct smoke path (`smoke_generate_mimo_v2.py`). The shipping recipe is BF16 attention + FP8 MoE at `seq_len=512` (largest seq_len that fits HBM).** + +**Known issue — vLLM serving is broken.** The first `/v1/chat/completions` request against `vllm-neuron` returns coherent output; every subsequent request returns garbled text. Same compiled NEFF serves 5 successive greedy generations byte-identically via the smoke path, so the bug is specifically in vllm-neuron's runtime / request-state handling. Tracking upstream at https://github.com/vllm-project/vllm-neuron/issues/31. Last updated 2026-04-30. ### Why BF16 attn + FP8 MoE @@ -49,8 +51,8 @@ GPU stacks (sglang on H100/H200) run the same OCP FP8 checkpoint correctly becau ### Cost and constraints -- **HBM headroom.** BF16 q/k/v adds ~2 GB per rank. `seq_len=1024` OOMs on load (the previous attempt failed allocating ~40 MB for rdh/alltoall rings after per-rank tensors already reached 20.9/24 GB). `seq_len=256` frees enough full-attention softmax scratch to fit. -- **Short context.** At `seq_len=256`, Pro's own chat template with the default system prompt is already 260 tokens. Longer context needs a different HBM plan (cross-instance TP/PP, or a larger instance). +- **HBM headroom.** BF16 q/k/v adds ~2 GB per rank. `seq_len=1024` OOMs on load (the previous attempt failed allocating ~40 MB for rdh/alltoall rings after per-rank tensors already reached 20.9/24 GB). `seq_len=512` is the largest value empirically verified to fit HBM at BS=48. +- **Short context.** Even at `seq_len=512`, Pro's full chat template with the default system prompt is ~260 tokens; that leaves ~250 tokens for user input + generation. Longer context needs a different HBM plan (cross-instance TP/PP, or a larger instance). - `BS * top_k / num_experts >= 1.0` required when `moe_ep_degree > 1` at decode (else `NotImplementedError`). With `num_experts=384, top_k=8` this forces `BS >= 48`. - `n_routed_experts=384 = 2^7 × 3` → `384 / ep_degree` is never a power of 2 (6, 12, 24, 48, 96, 192, 384). Kimi PR #131 says NKI `_bwmm_shard_on_block_nki_call` on SDK 2.29 has "depressed logits with EP=2" and recommends SDK 2.28. @@ -61,7 +63,8 @@ GPU stacks (sglang on H100/H200) run the same OCP FP8 checkpoint correctly becau ### Next experiments queued -- **BF16-attn at `seq_len=512`** (needs a tighter HBM plan — smaller batch, different EP ratio, or shrinking the default system prompt). +- **Even longer `seq_len`** (> 512): needs a tighter HBM plan — smaller batch, different EP ratio, or cross-instance sharding. 
+- **Upstream vllm-neuron fix** for the "first-request-only" serving bug (issue #31); patch branch at `whn09/vllm-neuron#fix/hybrid-attn-swa-spec` is a placeholder that did not resolve the symptom. - **Cross-instance BF16** via pipeline/tensor parallelism on 2× Trn2 (single-instance HBM cannot hold full BF16 Pro). - **Selective BF16 only on MoE `gate_up_proj`** (smallest expert scales) while keeping `down_proj` FP8 — another axis to probe if attn drift returns at longer contexts. - **SDK 2.28 venv** test once installed, per Kimi PR #131. @@ -241,7 +244,7 @@ compiled_path = "/path/to/compiled/" # moe_tp_degree = 1, moe_ep_degree = 64 # q_proj/k_proj/v_proj in modules_to_not_convert (BF16; preprocess # emits BF16 for q/k/v, no separate step needed) -# seq_len = 256 (HBM-tight with BF16 attn; see Status) +# seq_len = 512 (largest empirically verified; see Status) # See "FP8 Configuration Notes" below for why other moe_tp/ep ratios # collapse. neuron_config = MoENeuronConfig( @@ -253,7 +256,7 @@ neuron_config = MoENeuronConfig( max_batch_size=48, ctx_batch_size=1, tkg_batch_size=48, - seq_len=256, # HBM is tight with BF16 attn; seq_len=1024 OOMs + seq_len=512, # largest empirically verified; seq_len=1024 OOMs n_active_tokens=128, torch_dtype=torch.bfloat16, logical_nc_config=2, @@ -356,7 +359,7 @@ bash contrib/models/MiMo-V2.5-Pro/perf_test/start_vllm_server.sh See "Environment variables" above for all the knobs (`NEURON_COMPILED_ARTIFACTS`, `BASE_COMPILE_WORK_DIR`, etc.) and their defaults. -> **Note on the shipped vLLM scripts:** the current `start_vllm_server.sh` still uses `seq_len=1024` and does not list `q_proj/k_proj/v_proj` in `modules_to_not_convert`. Coupled with a BF16-attn preprocessed checkpoint this runs correctly (NxDI just sees BF16 tensors where it expected FP8 and casts them as-is) but at a longer context than the BF16-attn recipe has been HBM-validated for. If you hit a `status=4 Allocation Failure` on load, drop `seq_len` / `max_model_len` / `context_encoding_buckets` / `token_generation_buckets` to 256 and add `q_proj/k_proj/v_proj` to `modules_to_not_convert` to match the smoke-verified configuration. The bench numbers below were taken on the older all-FP8 checkpoint and have not been re-measured since the preprocess switched to BF16 attn. +> **vLLM serving is currently broken.** With the BF16-attn checkpoint, every `vllm-neuron` configuration we tried (all-FP8-attn, BF16-attn with `seq_len=256` or `512`, CB on/off, on-device sampling on/off, `-O3` or `-O1` TKG compile) reproduces the same pattern: the first chat request returns coherent output, every subsequent request returns UTF-8-replacement-char + off-topic text. The same compiled NEFF serves 5 successive greedy `adapter.generate()` calls byte-identically under `smoke_generate_mimo_v2.py` — the bug is in vllm-neuron's runtime, not in the model or the NEFF. Tracking at https://github.com/vllm-project/vllm-neuron/issues/31. Until that is fixed, use `smoke_generate_mimo_v2.py` for direct NxDI inference; the bench numbers below are historical infra-validation data from the pre-BF16-attn all-FP8 checkpoint. ### vllm-neuron patch summary @@ -368,7 +371,7 @@ The patch is applied to vllm-neuron 0.5.0 and: ## Performance -> The throughput numbers below were captured on 2026-04-29 against a pre-BF16-attn checkpoint (all-FP8, `seq_len=1024`). They are historical — the shipping recipe is BF16 attn + FP8 MoE at `seq_len=256` and has not yet been re-benchmarked. 
The numbers are kept here for infra validation (continuous batching + bucketing + on-device sampling + vllm-neuron plugin all wire up end-to-end) and order-of-magnitude reference. +> The throughput numbers below were captured on 2026-04-29 against a pre-BF16-attn checkpoint (all-FP8, `seq_len=1024`) before we discovered the vllm-neuron first-request bug. They are historical — the shipping recipe is BF16 attn + FP8 MoE at `seq_len=512` via the smoke path, and vLLM serving is currently blocked on issue #31. The numbers are kept here for order-of-magnitude reference. ### vLLM Serving (trn2.48xlarge, historical all-FP8 run, BS=48, TP=64, moe_tp=1/moe_ep=64, CB + bucketing, `seq_len=1024`) @@ -419,4 +422,4 @@ pytest contrib/models/MiMo-V2.5-Pro/test/integration/test_model.py -v Henan Wang (whn09) -**Last Updated:** 2026-04-29 +**Last Updated:** 2026-04-30 diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py index 0f992dab..d52f5846 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_compile_mimo_v2.py @@ -39,15 +39,14 @@ ) COMPILED_PATH = os.environ.get( "MIMO_V25_PRO_COMPILED_PATH", - "/opt/dlami/nvme/compiled/mimo_v2_5_pro_bs48_moetp1_ep64_fp8moe_bf16attn_seq256/", + "/opt/dlami/nvme/compiled/mimo_v2_5_pro_bs48_moetp1_ep64_fp8moe_bf16attn_seq512/", ) TP_DEGREE = int(os.environ.get("TP_DEGREE", "64")) -# Drop seq_len to 256 to free ~200 MB of full-attention softmax scratch per -# rank. The previous BF16-attn attempt at seq_len=1024 OOM'd by 40 MB on load -# (failed to allocate 41943040 bytes for rdh/alltoall); seq_len=256 reclaims -# enough HBM to fit the extra BF16 q/k/v weights. -SEQ_LEN = int(os.environ.get("SEQ_LEN", "256")) +# seq_len=512 is the largest value verified to fit HBM under the BF16-attn +# recipe. seq_len=1024 OOMs on load (previous attempt failed allocating +# ~40 MB for rdh/alltoall rings after per-rank tensors reached 20.9/24 GB). +SEQ_LEN = int(os.environ.get("SEQ_LEN", "512")) # BS=48 is the minimum that avoids forward_selective_loading on decode: # `BS * top_k / num_experts >= 1.0` → BS >= 384/8 = 48. At BS=1 the TKG # path raises `NotImplementedError: Selective Loading with Expert parallelism`. diff --git a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py index 1b76f4a4..6a11a565 100755 --- a/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py +++ b/contrib/models/MiMo-V2.5-Pro/perf_test/smoke_generate_mimo_v2.py @@ -27,13 +27,13 @@ ) COMPILED_PATH = os.environ.get( "MIMO_V25_PRO_COMPILED_PATH", - "/opt/dlami/nvme/compiled/mimo_v2_5_pro_bs48_moetp1_ep64_fp8moe_bf16attn_seq256/", + "/opt/dlami/nvme/compiled/mimo_v2_5_pro_bs48_moetp1_ep64_fp8moe_bf16attn_seq512/", ) # Must match smoke_compile_mimo_v2.py exactly, else load() sees a # mismatched NEFF. TP_DEGREE = int(os.environ.get("TP_DEGREE", "64")) -SEQ_LEN = int(os.environ.get("SEQ_LEN", "256")) +SEQ_LEN = int(os.environ.get("SEQ_LEN", "512")) BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "48")) # must match smoke_compile CTX_BATCH_SIZE = int(os.environ.get("CTX_BATCH_SIZE", "1")) MOE_TP = int(os.environ.get("MOE_TP", "1"))