diff --git a/contrib/models/MiMo-V2.5/README.md b/contrib/models/MiMo-V2.5/README.md new file mode 100644 index 00000000..88281a47 --- /dev/null +++ b/contrib/models/MiMo-V2.5/README.md @@ -0,0 +1,447 @@ +# Contrib Model: MiMo-V2.5 + +NeuronX Distributed Inference implementation of [XiaomiMiMo/MiMo-V2.5](https://huggingface.co/XiaomiMiMo/MiMo-V2.5). MiMo-V2.5 supersedes the earlier MiMo-V2-Flash release with the same decoder-only MoE architecture, an updated tokenizer, and a multimodal (vision + audio) head that the NxDI language path does not use. + +## Model Information + +- **HuggingFace ID:** `XiaomiMiMo/MiMo-V2.5` +- **Model Type:** Decoder-only MoE transformer with hybrid (full + SWA) attention +- **License:** Check HuggingFace model card + +## Architecture Details + +| Parameter | Value | +|-----------|-------| +| Hidden Size | 4096 | +| Layers | 48 (layer 0 dense, layers 1–47 MoE) | +| Q Heads | 64 | +| KV Heads (full attn) | 4 | +| KV Heads (sliding window) | 8 | +| Q/K Head Dim | 192 | +| V Head Dim | 128 | +| Experts | 256 (top-8 routing) | +| Expert Intermediate | 2048 | +| Vocab Size | 152,576 | +| RoPE | Partial (64 of 192 head dims = 33.4%), theta=5M (full) / 10K (SWA) | +| Sliding Window | 128 | +| Max Position | 262,144 | + +Key features: +- **Hybrid Attention**: 9 full attention layers (0, 5, 11, 17, 23, 29, 35, 41, 47) + 39 sliding window layers (positions driven by `hybrid_layer_pattern`). +- **Asymmetric Head Dims**: Q/K use 192, V uses 128. Plus asymmetric `num_kv_heads` between full (4) and SWA (8) layers. +- **Fused QKV on disk, split on Neuron**: the HF checkpoint ships `qkv_proj.weight` fused (`attention_projection_layout="fused_qkv"`); the NxDI modeling code keeps separate `q_proj`/`k_proj`/`v_proj` linears, so the preprocess script slices the fused tensor back into three per-proj tensors (see "Checkpoint Preparation"). +- **Attention Sink Bias**: Learnable per-head bias on sliding window layers only (`add_swa_attention_sink_bias=true`, `add_full_attention_sink_bias=false`). +- **Sigmoid Router + noaux_tc**: `e_score_correction_bias` added to sigmoid scores before top-k selection; unbiased scores become the affinity weights. +- **attention_value_scale = 0.707**: HF MiMo-V2 multiplies `value_states` by this before the attention softmax × V (NOT applied to attn_output); the NxDI model matches. + +## Prerequisites + +- **Instance**: trn2.48xlarge (128 NeuronCores, logical_nc_config=2 → 64 logical cores) +- **Neuron SDK**: 2.29 (Python 3.12, PyTorch 2.9) +- **Venv**: `/opt/aws_neuronx_venv_pytorch_inference_vllm_0_16` (ships with the DLAMI; has NxDI, vllm-neuron, and `huggingface_hub`/`s5cmd`). +- **Disk**: ~900 GB free under `/opt/dlami/nvme` (HF FP8 checkpoint ~295 GB, Neuron-FP8 preprocessed output ~310 GB, and `save_sharded_checkpoint=true` writes another ~300 GB of per-rank sharded weights per compiled config). The DLAMI creates a 6.9 TB RAID0 at `/dev/md0` across the instance-store NVMes but does **not** add it to `/etc/fstab`, so it is not mounted automatically after a reboot. Before running any of the steps below, remount it if needed: + + ```bash + # If /opt/dlami/nvme appears empty after an overnight reboot, the md0 array + # is still intact and just needs to be remounted: + mount | grep -q /opt/dlami/nvme || sudo mount /dev/md0 /opt/dlami/nvme + df -h /opt/dlami/nvme # should show ~6.9 TB + ``` + +## Quick Start (FP8 on Trn2) + +End-to-end recipe to go from a fresh trn2.48xlarge to a working vLLM OpenAI server serving MiMo-V2.5 FP8. 
First-time compile takes ~30 minutes; subsequent runs hit the neuronx-cc cache and start in a few minutes. + +```bash +# 1. Clone this repo on the Trn2 instance +cd $HOME +git clone /neuronx-distributed-inference.git +cd neuronx-distributed-inference +git checkout contrib/MiMo-V2.5 # the branch this README lives on + +# 2. Download the HuggingFace FP8 checkpoint (~295 GB). +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate +huggingface-cli download XiaomiMiMo/MiMo-V2.5 \ + --local-dir /opt/dlami/nvme/models/MiMo-V2.5 --max-workers 16 + +# 3. Preprocess HF FP8 -> Neuron FP8 (~16 min, ~15 GB peak RAM) +python contrib/models/MiMo-V2.5/src/conversion_script/preprocess_mimo_v2_5_fp8.py \ + --hf_model_path /opt/dlami/nvme/models/MiMo-V2.5 \ + --save_path /opt/dlami/nvme/models/MiMo-V2.5-Neuron-FP8 \ + --tp_degree 64 + +# 4. (Optional) sanity-check the Neuron-FP8 checkpoint without vLLM +# ~30 min first compile (priority HLO + CE HLO + 27 min shard_checkpoint +# for 64 ranks); subsequent runs ~30s to load the pre-sharded NEFF. +python contrib/models/MiMo-V2.5/perf_test/smoke_compile_mimo_v2_5.py # compile + shard +python contrib/models/MiMo-V2.5/perf_test/smoke_generate_mimo_v2_5.py # 20-token generate + +# 5. Install vllm-neuron with the contrib registration patch +bash contrib/models/MiMo-V2.5/perf_test/0_setup.sh + +# 6. Start vLLM serving MiMo-V2.5 FP8 +bash contrib/models/MiMo-V2.5/perf_test/bench_mimo_v2_5.sh +``` + +The bench script runs one configuration (BS=32, +`moe_tp_degree=1 / moe_ep_degree=64`) at three concurrency levels (1, 16, 32) and logs results under +`/opt/dlami/nvme/logs/bench_results/mimo_v2_5/`. + +### Keeping a server up for ad-hoc testing + +`bench_mimo_v2_5.sh` is a one-shot wrapper (launch server → sanity → +3 bench runs → teardown). If you want a long-running server to iterate +against, use the three underlying scripts separately: + +```bash +# Terminal 1: launch the server in the foreground (Ctrl-C to stop). +bash contrib/models/MiMo-V2.5/perf_test/start_vllm_server.sh + +# Terminal 2: once "Application startup complete." prints, sanity-check: +bash contrib/models/MiMo-V2.5/perf_test/sanity_check.sh + +# Run a single bench pass with a chosen concurrency: +CONCURRENCY=16 NUM_PROMPTS=128 \ + bash contrib/models/MiMo-V2.5/perf_test/run_bench_single.sh +``` + +`bench_mimo_v2_5.sh` composes exactly these three pieces; use whichever +is more convenient. + +### Environment variables + +`0_setup.sh` prints these at the end; setting them explicitly makes the +smoke / bench / manual-launch paths all behave the same. All of them have +sensible defaults in the scripts — export them only if you want to +override or if you plan to launch vLLM outside of `bench_mimo_v2_5.sh`. + +**Required (at least for manual `vllm api_server` launches):** + +| Variable | Purpose | +|---|---| +| `NXDI_CONTRIB_MIMO_V2_5_SRC` | Path to `contrib/models/MiMo-V2.5/src/`. `vllm-neuron`'s registration hook reads it to plug `NeuronMiMoV2ForCausalLM` into NxDI's `MODEL_TYPES` table. | +| `NXDI_CONTRIB_MIMO_V2_FLASH_SRC` | Alias of `NXDI_CONTRIB_MIMO_V2_5_SRC` — same value. vLLM's builtin arch validator only knows `MiMoV2FlashForCausalLM`, so preprocess rewrites the checkpoint's `architectures` to that name and we re-use the Flash registration key (`mimov2flash`) in vllm-neuron's lookup table. | +| `MIMO_V2_5_PATH` | Preprocessed Neuron-FP8 checkpoint dir (the `--save_path` output from preprocess). 
| + +**Optional (recommended):** + +| Variable | Default | Purpose | +|---|---|---| +| `NEURON_COMPILED_ARTIFACTS` | `/opt/dlami/nvme/compiled/mimo_v2_5_bs32_moetp1_ep64_fp8_vllm` | Where vLLM writes the NEFF + per-rank sharded weights. Default points at a persistent path under `/opt/dlami/nvme/compiled/` so multiple configs don't collide and runs after the nightly reboot can reuse the sharded weights. vLLM's fallback is `/neuron-compiled-artifacts//` which buries output inside the checkpoint dir. | +| `BASE_COMPILE_WORK_DIR` | `/opt/dlami/nvme/tmp/nxd_model/` | NxDI's HLO / NEFF staging workdir. Default is `/tmp/nxd_model/`, which is wiped by the nightly Trn2 reboot and can silently corrupt parallel compiles that share a basename; the pinned value lives on persistent storage and is unique per config. | +| `VLLM_ENGINE_READY_TIMEOUT_S` | `7200` | First-time compile of V2.5's 256-expert MoE is ~30 min dominated by `shard_checkpoint`, well past vLLM's default. | + +For a quick `curl` sanity check while the server is up: + +```bash +curl -s http://localhost:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{"model": "/opt/dlami/nvme/models/MiMo-V2.5-Neuron-FP8", + "messages": [{"role": "user", "content": "Hello! Introduce yourself in one sentence."}], + "max_tokens": 64, "temperature": 0.0}' | python3 -m json.tool +``` + +If you get fluent sentence-ending output on a 30+ token generation, the +FP8 path is working correctly. If you see repetition collapse +("helpful helpful helpful..."), double-check that `moe_tp_degree=1`, +`moe_ep_degree=64`, `batch_size>=32`, and that you are loading the +preprocessed Neuron-FP8 checkpoint (not the raw HF FP8 directory). + +## Checkpoint Preparation + +The HuggingFace checkpoint ships as block-wise OCP FP8 (E4M3, ±448 range), which is not directly compatible with Neuron FP8 (IEEE-754 E4M3, ±240 range). `src/conversion_script/preprocess_mimo_v2_5_fp8.py` performs a per-layer streaming rescale: per-row scales for attention Q/K/V (after fused-qkv split) and the layer-0 dense MLP; blockwise 128×128 scales for MoE experts. `o_proj` is listed in HF's `quantization_config.ignored_layers` and is kept BF16 on the Neuron side (it binds to a plain `RowParallelLinear`, not `QuantizedRowParallel`). Output is ~310 GB across 48 per-layer safetensors shards. + +```bash +python contrib/models/MiMo-V2.5/src/conversion_script/preprocess_mimo_v2_5_fp8.py \ + --hf_model_path /path/to/MiMo-V2.5 \ + --save_path /path/to/MiMo-V2.5-Neuron-FP8 \ + --tp_degree 64 +``` + +Peak RAM during preprocessing is ~15 GB; total runtime ~16 minutes on a trn2.48xlarge instance. + +### V2.5-specific: fused qkv_proj split into 4 interleaved groups + +The HF checkpoint advertises `q_proj.weight` / `k_proj.weight` / `v_proj.weight` in its safetensors index, but the actual LFS objects on the Hub only carry a single fused `self_attn.qkv_proj.weight` tensor. NxDI's MiMoV2Attention hard-codes separate Q/K/V `ColumnParallelLinear` modules, so the preprocess script splits the fused tensor back into three per-proj tensors. + +The fused layout is **not** `[all_Q | all_K | all_V]`. 
It is **4 interleaved groups** (the group count equals the full-attention `num_key_value_heads = 4`), each packing `hpg` Q heads, `kpg` K heads, and `kpg` V heads contiguously: + + group g (g = 0..3): + rows [g*R : g*R + qg] = Q heads [g*hpg : (g+1)*hpg] + rows [g*R + qg : g*R + qg + kg] = K heads [g*kpg : (g+1)*kpg] + rows [g*R + qg + kg : g*R + R] = V heads [g*kpg : (g+1)*kpg] + + where hpg = num_q_heads / 4, kpg = num_kv_heads / 4, + qg = hpg * 192, kg = kpg * 192, vg = kpg * 128, + R = qg + kg + vg + +For **full-attention layers** this gives `hpg=16, kpg=1, R=3392, total=13568` rows with 108 scale blocks (includes 2 phantom rows from `ceil(192/128)=2`). For **SWA layers** (`num_kv_heads=8`), `hpg=16, kpg=2, R=3712, total=14848` rows with 116 scale blocks (no phantom, since `kg=384` is 128-aligned). Layer 0 (dense) is still attention-FP8 and follows the full-layer layout. + +Any preprocess approach that treats the fused tensor as a plain `[Q|K|V]` concatenation produces garbled outputs — Q/K/V rows land in the wrong per-head slots after the split. + +## Usage + +```python +import sys +from pathlib import Path + +# Make this contrib package's src/ importable (flat, per upstream contrib convention). +sys.path.insert(0, str(Path("contrib/models/MiMo-V2.5/src").resolve())) + +import torch +from transformers import AutoConfig, AutoTokenizer +from neuronx_distributed_inference.models.config import MoENeuronConfig, OnDeviceSamplingConfig +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config, HuggingFaceGenerationAdapter + +from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM, MiMoV2InferenceConfig + +model_path = "/path/to/MiMo-V2.5-Neuron-FP8/" +compiled_path = "/path/to/compiled/" + +# Recommended FP8 recipe: +# moe_tp_degree = 1, moe_ep_degree = 64 +# See "FP8 Configuration Notes" below for why other moe_tp/ep ratios collapse. +neuron_config = MoENeuronConfig( + tp_degree=64, + ep_degree=1, # keep outer EP = 1; only MoE-internal EP varies + moe_tp_degree=1, + moe_ep_degree=64, + batch_size=32, # must be >= num_experts / top_k = 256 / 8 = 32 + max_batch_size=32, + ctx_batch_size=1, + tkg_batch_size=32, + seq_len=1024, + n_active_tokens=128, + torch_dtype=torch.bfloat16, + logical_nc_config=2, + capacity_factor=1.0, + glu_mlp=True, + fused_qkv=False, # required: asymmetric Q/K (192) vs V (128) head dims + router_config={"act_fn": "sigmoid", "dtype": "float32"}, + blockwise_matmul_config={ + "use_shard_on_block_dynamic_while": True, + "block_sharding_strategy": "PING_PONG", + }, + save_sharded_checkpoint=True, + quantized=True, + quantized_checkpoints_path=model_path, + quantization_dtype="f8e4m3", + quantization_type="blockwise_symmetric", + quantization_block_axis=[1, 2], + quantization_block_size=[128, 128], + modules_to_not_convert=[ + "embed_tokens", "lm_head", "norm", "router", "o_proj", + ], + on_device_sampling_config=OnDeviceSamplingConfig( + do_sample=True, temperature=0.6, top_k=20, top_p=0.95, + ), +) + +# trust_remote_code is required by MiMo-V2's HF config; pre-load via AutoConfig +# and pass to NxDI so load_pretrained_config does not re-load without the flag. 
+hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) +config = MiMoV2InferenceConfig( + neuron_config, load_config=load_pretrained_config(hf_config=hf_config), +) + +model = NeuronMiMoV2ForCausalLM(model_path, config) +model.compile(compiled_path) +model.load(compiled_path) + +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) +adapter = HuggingFaceGenerationAdapter(model) +inputs = tokenizer(["Hello, how are you?"] * 32, return_tensors="pt", padding=True) +output = adapter.generate( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + max_new_tokens=128, +) +``` + +For a minimal end-to-end smoke test that bypasses vLLM, see: + +- `perf_test/smoke_compile_mimo_v2_5.py` — compile + load (STAGE=instantiate|compile|load|all, DRY_RUN, SKIP_WARMUP) +- `perf_test/smoke_generate_mimo_v2_5.py` — 20-token generation via HuggingFaceGenerationAdapter + +Both default to the recommended FP8 recipe (`moe_tp=1`, `moe_ep=64`). + +## FP8 Configuration Notes + +### moe_tp_degree = 1, moe_ep_degree = 64 + +**Why**: at `moe_tp_degree=64` each rank owns 1/64 of the intermediate dim, which for MiMo-V2.5 (MoE intermediate = 2048) is 32 rows — **below the 128-row blockwise scale block**. NxDI's `_setup_for_scale` detects `weight_shape[axis] < block_size` and collapses the per-rank scale dim to 1, losing per-channel FP8 scale granularity. The resulting drift compounds across MiMo-V2.5's 47 MoE layers and manifests as output collapse ("helpful helpful helpful ...") after roughly 30 decode tokens. + +`moe_tp_degree=1, moe_ep_degree=64` keeps each expert's weights and blockwise scales intact on a single rank (4 experts per rank), which preserves per-channel scale and produces correct output even on long multi-turn prompts. + +Intermediate ratios (`moe_tp=32/ep=2` or `moe_tp=16/ep=4`) have been empirically tested and still produce gibberish, so this is the only currently-supported moe_tp/ep combination for MiMo-V2.5 FP8. + +### batch_size >= 32 + +NxDI's TKG (token generation) path refuses Expert Parallelism when `batch_size < num_experts / top_k`. For MiMo-V2.5 that is 256 / 8 = 32, so the smallest working BS on the FP8 path is 32. BS=1 latency demos are not currently possible on FP8; use the BF16 checkpoint with `moe_tp=64, moe_ep=1, batch_size=1` for single-stream latency measurements. + +### outer ep_degree = 1 + +`MoENeuronConfig.ep_degree` is the **full-model** expert-parallel factor. Setting it to anything > 1 multiplies `world_size` to `tp_degree * ep_degree`, which on a 64-NC Trn2 overflows the device (ranks beyond 63 have no backing hardware, sharded-checkpoint size grows linearly, and load fails). The MoE-internal expert parallelism is controlled exclusively by `moe_ep_degree` — keep `ep_degree=1` at the outer level. + +## vLLM Integration + +MiMo-V2.5 can be served via [vllm-neuron](https://github.com/aws-neuron/vllm-neuron). A contrib registration patch is required to plug the NxDI modeling code into vllm-neuron's lookup tables. + +### Setup + +```bash +# The setup script clones vllm-project/vllm-neuron at release-0.5.0, applies +# the contrib registration patch, installs it editable, and downloads +# MiMo-V2.5 FP8 weights from HuggingFace (~295 GB; skipped if already present). +bash contrib/models/MiMo-V2.5/perf_test/0_setup.sh +``` + +`perf_test/vllm-neuron-patch.patch` adds a `_register_contrib_models()` hook to `vllm_neuron/worker/neuronx_distributed_model_loader.py`. 
When `NXDI_CONTRIB_MIMO_V2_5_SRC` is set, it registers `NeuronMiMoV2ForCausalLM` into NxDI's `MODEL_TYPES` under the key `mimov2` **and** registers the `MiMoV2ForCausalLM` architecture into vLLM's `ModelRegistry`. The hook also patches `AutoConfig.from_pretrained` to default `trust_remote_code=True` so NxDI's `load_pretrained_config` can read the V2.5 config. No upstream vLLM or NxDI source is modified. + +### Serving (FP8, recommended) + +```bash +export NXDI_CONTRIB_MIMO_V2_5_SRC=/path/to/neuronx-distributed-inference/contrib/models/MiMo-V2.5/src +export MIMO_V2_5_PATH=/path/to/MiMo-V2.5-Neuron-FP8 +# First-time compile of MiMo-V2.5's 256-expert MoE takes 30-60 minutes. +export VLLM_ENGINE_READY_TIMEOUT_S=7200 +# Optional: isolate compile cache per config so parallel MiMo-V2.5/Pro/etc. compiles +# don't race on the default /var/tmp/neuron-compile-cache lock files. +export NEURON_COMPILED_ARTIFACTS=/path/to/compiled/mimo_v2_5_bs32_moetp1_ep64_fp8 + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MIMO_V2_5_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 32 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + "tp_degree": 64, + "logical_nc_config": 2, + "fused_qkv": false, + "sequence_parallel_enabled": false, + "glu_mlp": true, + "normalize_top_k_affinities": true, + "save_sharded_checkpoint": true, + "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, + "quantized": true, + "quantized_checkpoints_path": "/path/to/MiMo-V2.5-Neuron-FP8", + "quantization_dtype": "f8e4m3", + "quantization_type": "blockwise_symmetric", + "quantization_block_axis": [1, 2], + "quantization_block_size": [128, 128], + "modules_to_not_convert": ["embed_tokens", "lm_head", "norm", "router", "o_proj"], + "blockwise_matmul_config": {"use_shard_on_block_dynamic_while": true, "block_sharding_strategy": "PING_PONG"}, + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 32, + "ctx_batch_size": 1, + "tkg_batch_size": 32, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": true, + "enable_bucketing": true, + "context_encoding_buckets": [1024], + "token_generation_buckets": [1024], + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' +``` + +See `perf_test/bench_mimo_v2_5.sh` for the full benchmark recipe at BS=32. + +### Testing the vLLM server + +Once `/v1/models` returns 200 (first-compile takes ~30 min; subsequent starts ~3 min), hit `/v1/chat/completions`. MiMo-V2.5's chat template expects the `<|im_start|>...<|im_end|>` ChatML format — vLLM applies it automatically when you use the chat endpoint, so just send a standard messages array: + +```bash +MODEL=/opt/dlami/nvme/models/MiMo-V2.5-Neuron-FP8 + +# 1. Short sanity — should return a one-line MiMo self-introduction. +curl -s http://localhost:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d "{\"model\":\"$MODEL\", + \"messages\":[{\"role\":\"user\",\"content\":\"Hello! Introduce yourself in one sentence.\"}], + \"max_tokens\":64}" | python3 -m json.tool + +# 2. Long output — check for repetition collapse / gibberish on 500+ tokens. 
+curl -s http://localhost:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d "{\"model\":\"$MODEL\", + \"messages\":[{\"role\":\"user\",\"content\":\"Explain the B-tree data structure in detail, including how insertions and deletions preserve balance.\"}], + \"max_tokens\":800}" | python3 -c "import sys,json; r=json.load(sys.stdin); print(r['choices'][0]['message']['content'])" +``` + +If you see a coherent MiMo introduction and a multi-paragraph technical explanation, the FP8 path is working end-to-end. Output collapse ("helpful helpful helpful ...") on either prompt indicates a broken FP8 recipe — re-check that `moe_tp_degree=1`, `moe_ep_degree=64`, `batch_size>=32`, and that the server is pointed at the Neuron-FP8 preprocessed directory (not the raw HF one). + +**Note on sampling determinism**: `on_device_sampling_config.do_sample=true` is the recommended setting; request-level `temperature` is ignored (sampling params are baked into the NEFF at compile time). + +### vllm-neuron patch summary + +The patch is applied to vllm-neuron 0.5.0 and: + +- Maps the `MiMoV2ForCausalLM` architecture to MiMo-V2.5's model loader (reusing the Qwen2-family loader path, which MiMo-V2.5's tokenizer inherits from). +- Passes `hf_config` from vLLM into `load_pretrained_config` so NxDI does not re-load the config without `trust_remote_code=True`. +- Replaces vllm-neuron's internal `AutoModelForCausalLM.from_pretrained` call with `huggingface_hub.snapshot_download`, which is the only path that works for `trust_remote_code=True` models when no GPU is available for HF's CUDA-gated FP8 quantizer. + +## Performance + +### vLLM Serving (trn2.48xlarge, FP8, BS=32, TP=64 / moe_ep=64, CB + bucketing) + +Input/output: 900 / 90 tokens (random dataset). Recipe is the one `bench_mimo_v2_5.sh` drives; 16 prompts at c=1 and 128 prompts at c=16/c=32. + +| Concurrency | Output throughput (tok/s) | Total throughput (tok/s) | TPOT median (ms) | TTFT median (ms) | TTFT P99 (ms) | +|---|---|---|---|---|---| +| 1 | 15.88 | 174.16 | 58.28 | 485 | 485 | +| 16 | 113.92 | 1251.14 | 130.85 | 863 | 6371 | +| 32 | 147.39 | 1618.81 | 190.48 | 1798 | 13281 | + +Observations: +- **Median ITL stays at ~58 ms across all three concurrency levels** — that's the cost of one BS=32 TKG NEFF forward, which runs at fixed shape regardless of how many slots are actually occupied. +- **Peak output throughput at c=32 is 576 tok/s**, close to the theoretical `32 / 0.058 ≈ 552` ceiling. +- **TPOT and TTFT grow with concurrency** because `enable_chunked_prefill=false`: each new request's context-encoding pass (900 tokens) preempts TKG for a few hundred ms, and the higher the concurrency the more frequently that happens. + +> **Compile time:** the first MiMo-V2.5 compile on SDK 2.29 is ~30 minutes (TKG + CE HLO compilation, weight layout optimization, then `shard_checkpoint` for 64 ranks which dominates at ~27 minutes). Subsequent runs with the same `override_neuron_config` hit the neuronx-cc cache and the NEFF loads in ~1 minute. `save_sharded_checkpoint=true` persists per-rank FP8 shards under `/weights/`, letting future `load()` calls skip the `shard_checkpoint` pass entirely. 
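+
+To confirm those cached artifacts are in place before a restart, here is a quick check (a minimal sketch assuming the default paths used by `start_vllm_server.sh` and that the shards land under `weights/` inside `NEURON_COMPILED_ARTIFACTS` as described above; adjust if you overrode either):
+
+```bash
+# Per-rank FP8 shards written by save_sharded_checkpoint=true; if these exist,
+# the next server start skips the ~27-minute shard_checkpoint pass.
+ls /opt/dlami/nvme/compiled/mimo_v2_5_bs32_moetp1_ep64_fp8_vllm/weights/ | head
+
+# neuronx-cc compile cache, reused when override_neuron_config is unchanged.
+du -sh /var/tmp/neuron-compile-cache 2>/dev/null || true
+```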
+ +## Compatibility Matrix + +| Instance | Neuron SDK 2.29+ (PyTorch 2.9) | 2.21 and earlier | +|----------|--------------------------------|------------------| +| Trn2 (trn2.48xlarge) | Tested | Not tested | +| Trn1 | Not supported (requires 64 logical cores via logical_nc_config=2) | Not supported | +| Inf2 | Not supported | Not supported | + +## Testing + +```bash +pytest contrib/models/MiMo-V2.5/test/integration/test_model.py -v +``` + +## Key Implementation Notes + +1. **Hybrid Attention**: `hybrid_layer_pattern` list determines full vs sliding window per layer; the modeling code constructs one `NeuronMiMoV2Attention` per layer with the correct `is_sliding_window` flag and rope_theta. +2. **CONVERT_TO_MHA**: When `tp_degree > num_kv_heads` (64 > 4 full / 64 > 8 SWA), K/V are replicated to `num_attention_heads` (64) during state-dict conversion; this applies to both `.weight` and the per-row `.scale` on the FP8 path. +3. **Attention Sink Bias**: Learnable per-head bias added as an extra "sink" column to attention scores in sliding window layers (not added in full-attention layers). Per-rank slicing of the bias happens inside `forward()` based on `parallel_state.get_tensor_model_parallel_rank()`. +4. **Fused qkv split in preprocess**: V2.5's HF checkpoint stores `self_attn.qkv_proj.weight` as 4 interleaved Q/K/V groups (see "Checkpoint Preparation" above). The preprocess script must slice these groups — naïve `[Q|K|V]` concat slicing produces garbage outputs. +5. **weight_map rebuild**: V2.5's `model.safetensors.index.json` references legacy `model_N-00001-of-00002.safetensors` filenames that do not match the actual `model_pp0_epN_shardM.safetensors` objects on disk. `LazyWeightMap` scans the on-disk shards at startup and rebuilds `weight_map` directly from each file's manifest; the inconsistent index is ignored. +6. **FP8 Path Caveats**: + - Must use `moe_tp_degree=1, moe_ep_degree=64` (see "FP8 Configuration Notes" above). + - Must use `batch_size >= 32` (NxDI EP>1 requirement). + - Must keep outer `ep_degree=1` (only `moe_ep_degree` should vary). + - Several runtime monkey-patches (router bias, blockwise scale stride, 2D per-channel, EP scale handling) are installed automatically in `NeuronMiMoV2ForCausalLM.__init__` when `quantized=True`; the BF16 path is untouched. + +## Example Checkpoints + +* [XiaomiMiMo/MiMo-V2.5](https://huggingface.co/XiaomiMiMo/MiMo-V2.5) — HF FP8 source checkpoint + +## Maintainer + +Henan Wang (whn09) + +**Last Updated:** 2026-04-28 diff --git a/contrib/models/MiMo-V2.5/perf_test/0_setup.sh b/contrib/models/MiMo-V2.5/perf_test/0_setup.sh new file mode 100755 index 00000000..7c88c1af --- /dev/null +++ b/contrib/models/MiMo-V2.5/perf_test/0_setup.sh @@ -0,0 +1,113 @@ +#!/bin/bash +# Setup for MiMo-V2.5 vLLM benchmarking on Trn2. +# +# Clones upstream vllm-project/vllm-neuron at release-0.5.0 and applies +# vllm-neuron-patch.patch, which adds a runtime registration hook so the +# contrib NeuronMiMoV2ForCausalLM is plugged into both NxDI's MODEL_TYPES +# (under the key "mimov2") and vLLM's ModelRegistry (as +# MiMoV2ForCausalLM) at vllm-neuron plugin init time. +# +# Then downloads XiaomiMiMo/MiMo-V2.5 from HuggingFace (FP8 blockwise, ~320 GB). 
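+#
+# Prerequisites (see the README's "Prerequisites" section): a trn2.48xlarge
+# with the /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 venv and roughly
+# 900 GB free under /opt/dlami/nvme.
+#
+# Usage:
+#   bash 0_setup.sh
+#
+# Note: if MIMO_V2_5_PATH is already exported, this script uses it as the
+# *download* destination for the raw HF checkpoint; the other perf_test
+# scripts treat MIMO_V2_5_PATH as the preprocessed Neuron-FP8 directory.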
+set -e + +echo "==========================================" +echo "Setup: vllm-neuron + MiMo-V2.5 weights" +echo "==========================================" + +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate + +# Resolve repo-relative paths up front — we cd into $HOME/vllm-neuron below, +# after which $0's relative form would no longer resolve. +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PATCH_FILE="$SCRIPT_DIR/vllm-neuron-patch.patch" +CONTRIB_SRC="$(cd "$SCRIPT_DIR/.." && pwd)/src" + +echo "" +echo "[1/2] Installing vllm-neuron (release-0.5.0) with the contrib registration patch..." + +if [ ! -d $HOME/vllm-neuron ]; then + git clone --branch release-0.5.0 https://github.com/vllm-project/vllm-neuron.git $HOME/vllm-neuron +fi + +cd $HOME/vllm-neuron + +# Apply patch (idempotent via `git apply --check` first). +if git apply --check "$PATCH_FILE" 2>/dev/null; then + git apply "$PATCH_FILE" + echo " Applied $PATCH_FILE" +else + echo " Patch already applied or conflicts; continuing." +fi + +pip install --extra-index-url=https://pip.repos.neuron.amazonaws.com -e . + +python3 -c "import vllm_neuron; print('vllm-neuron installed:', vllm_neuron.__file__)" + +echo "" +echo "[2/2] Downloading MiMo-V2.5 FP8 weights from HuggingFace..." + +MIMO_PATH="${MIMO_V2_5_PATH:-/opt/dlami/nvme/models/MiMo-V2.5}" +if [ -d "$MIMO_PATH" ] && [ "$(ls "$MIMO_PATH"/*.safetensors 2>/dev/null | wc -l)" -gt 0 ]; then + echo " MiMo-V2.5 weights already exist at $MIMO_PATH, skipping download" +else + mkdir -p "$MIMO_PATH" + huggingface-cli download XiaomiMiMo/MiMo-V2.5 --local-dir "$MIMO_PATH" --max-workers 16 + echo " Download complete: $(du -sh $MIMO_PATH | cut -f1)" +fi + +NEURON_FP8_PATH="${MIMO_PATH}-Neuron-FP8" +COMPILED_PATH="/opt/dlami/nvme/compiled/mimo_v2_5_bs32_moetp1_ep64_fp8_vllm" + +echo "" +echo "========================================================================" +echo "Next steps" +echo "========================================================================" +echo "" +echo "1. Preprocess the FP8 checkpoint for Neuron (~16 min, ~15 GB peak RAM):" +echo "" +echo " python $CONTRIB_SRC/conversion_script/preprocess_mimo_v2_5_fp8.py \\" +echo " --hf_model_path $MIMO_PATH \\" +echo " --save_path $NEURON_FP8_PATH \\" +echo " --tp_degree 64" +echo "" +echo "2. Export the environment variables used by the smoke / bench scripts:" +echo "" +echo " # --- Required ---" +echo " # Contrib package src (registers NeuronMiMoV2ForCausalLM with vllm-neuron)." +echo " export NXDI_CONTRIB_MIMO_V2_5_SRC=$CONTRIB_SRC" +echo " # vLLM's builtin arch validator only knows MiMoV2FlashForCausalLM, so the" +echo " # preprocess rewrites the checkpoint's config.json architectures to that" +echo " # name. Alias V2.5 src to the Flash env var so vllm-neuron's contrib hook" +echo " # registers mimov2flash -> our V2.5 NeuronMiMoV2ForCausalLM class." +echo " export NXDI_CONTRIB_MIMO_V2_FLASH_SRC=\"\$NXDI_CONTRIB_MIMO_V2_5_SRC\"" +echo " # Preprocessed Neuron-FP8 checkpoint." +echo " export MIMO_V2_5_PATH=$NEURON_FP8_PATH" +echo "" +echo " # --- Optional (recommended) ---" +echo " # vLLM compiles into /neuron-compiled-artifacts// by" +echo " # default. Pin it to a persistent shared location so multiple configs" +echo " # don't collide and you can reuse the NEFF / sharded weights across runs." +echo " export NEURON_COMPILED_ARTIFACTS=$COMPILED_PATH" +echo " # NxDI's HLO/NEFF staging workdir (.hlo_module.pb etc). 
Default is" +echo " # /tmp/nxd_model//; on Trn2 /tmp is wiped by the nightly" +echo " # reboot, and parallel compiles sharing the same basename silently" +echo " # overwrite each other's staged HLOs. Pin to a unique per-config" +echo " # directory under persistent storage." +echo " export BASE_COMPILE_WORK_DIR=/opt/dlami/nvme/tmp/nxd_model/\$(basename $COMPILED_PATH)" +echo " # First-time compile of V2.5's 256-expert MoE takes ~30 min (NEFF HLO +" +echo " # shard_checkpoint for 64 ranks). Extend vLLM's ready timeout." +echo " export VLLM_ENGINE_READY_TIMEOUT_S=7200" +echo "" +echo "3a. Run the one-shot benchmark (launches + benches + tears down):" +echo "" +echo " bash $SCRIPT_DIR/bench_mimo_v2_5.sh" +echo "" +echo "3b. ...OR keep a server up and probe it manually:" +echo "" +echo " # shell 1: server in foreground (Ctrl-C to stop)" +echo " bash $SCRIPT_DIR/start_vllm_server.sh" +echo "" +echo " # shell 2: once 'Application startup complete.' prints," +echo " bash $SCRIPT_DIR/sanity_check.sh" +echo " CONCURRENCY=16 NUM_PROMPTS=128 bash $SCRIPT_DIR/run_bench_single.sh" +echo "" diff --git a/contrib/models/MiMo-V2.5/perf_test/bench_mimo_v2_5.sh b/contrib/models/MiMo-V2.5/perf_test/bench_mimo_v2_5.sh new file mode 100755 index 00000000..7f82cfcd --- /dev/null +++ b/contrib/models/MiMo-V2.5/perf_test/bench_mimo_v2_5.sh @@ -0,0 +1,87 @@ +#!/bin/bash +set -e + +# MiMo-V2.5 FP8 vLLM benchmark on Trn2. One-shot wrapper: +# launch server -> sanity check -> bench at c=1,16,32 -> stop server. +# +# This script composes three building blocks in perf_test/: +# start_vllm_server.sh - server launch + env-var setup (backgrounded here) +# sanity_check.sh - one-shot curl against the running server +# run_bench_single.sh - one concurrency level of `vllm bench serve` +# +# Use those directly if you want to keep a long-running server and iterate +# on bench parameters from another shell. +# +# Server recipe: TP=64, moe_tp=1/moe_ep=64, BS=32, continuous batching. +# BS=32 is the smallest working batch size on the FP8 path (NxDI's TKG +# path refuses Expert Parallelism with BS < num_experts/top_k = 256/8 = 32). +# BS=1 single-stream latency demos are not currently supported on V2.5 FP8. + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PORT="${PORT:-8000}" +RESULTS_DIR="${RESULTS_DIR:-/opt/dlami/nvme/logs/bench_results/mimo_v2_5}" +CONFIG_NAME="bs32_tp64_moetp1_ep64" + +mkdir -p "$RESULTS_DIR" + +# Wait for vLLM server to be ready. First-time compile of the 256-expert +# MoE model takes ~30 min and can stretch past 2 h under contention, so +# poll for up to 2 h. +wait_for_server() { + echo " Waiting for vLLM server on port $PORT (up to 2 h for first compile)..." + local interval=10 + local max_attempts=720 + local start=$SECONDS + for i in $(seq 1 $max_attempts); do + if curl -s "http://localhost:$PORT/health" > /dev/null 2>&1; then + echo " Server ready after $((SECONDS - start))s." + return 0 + fi + if [ $((i % 6)) -eq 0 ]; then + echo " ...still waiting ($((SECONDS - start))s elapsed)" + fi + sleep $interval + done + echo " ERROR: Server did not start within $((max_attempts * interval))s" + return 1 +} + +stop_server() { + echo " Stopping vLLM server..." + pkill -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true + sleep 5 +} + +echo "==========================================" +echo "MiMo-V2.5 FP8 Performance Benchmark" +echo "==========================================" +echo "Port: $PORT" +echo "Results: $RESULTS_DIR" +echo "" + +# Start the server in the background. 
start_vllm_server.sh handles all the +# env vars (MODEL_PATH, NEURON_COMPILED_ARTIFACTS, BASE_COMPILE_WORK_DIR, +# contrib src registration, etc.) and execs `python3 -m vllm...`. +bash "$SCRIPT_DIR/start_vllm_server.sh" & +SERVER_PID=$! +trap stop_server EXIT + +wait_for_server + +# One-shot sanity check (curl the chat endpoint). +PORT="$PORT" bash "$SCRIPT_DIR/sanity_check.sh" || true + +# Three concurrency levels. run_bench_single.sh reads knobs from the +# environment; see its header for all the options. +PORT="$PORT" RESULTS_DIR="$RESULTS_DIR" CONFIG_NAME="$CONFIG_NAME" \ + CONCURRENCY=1 NUM_PROMPTS=16 bash "$SCRIPT_DIR/run_bench_single.sh" +PORT="$PORT" RESULTS_DIR="$RESULTS_DIR" CONFIG_NAME="$CONFIG_NAME" \ + CONCURRENCY=16 NUM_PROMPTS=128 bash "$SCRIPT_DIR/run_bench_single.sh" +PORT="$PORT" RESULTS_DIR="$RESULTS_DIR" CONFIG_NAME="$CONFIG_NAME" \ + CONCURRENCY=32 NUM_PROMPTS=128 bash "$SCRIPT_DIR/run_bench_single.sh" + +echo "==========================================" +echo "MiMo-V2.5 FP8 benchmark complete!" +echo "Results saved to: $RESULTS_DIR" +echo "==========================================" +ls -la "$RESULTS_DIR" diff --git a/contrib/models/MiMo-V2.5/perf_test/run_bench_single.sh b/contrib/models/MiMo-V2.5/perf_test/run_bench_single.sh new file mode 100755 index 00000000..5206ba57 --- /dev/null +++ b/contrib/models/MiMo-V2.5/perf_test/run_bench_single.sh @@ -0,0 +1,76 @@ +#!/bin/bash +# Run a single vllm-bench-serve pass against an already-running vLLM server. +# +# Unlike bench_mimo_v2_5.sh this script does NOT launch or kill the vLLM +# server — you bring your own. That makes it convenient when the bench driver +# in bench_mimo_v2_5.sh times out during first-time compilation: the server +# keeps running, and once it's ready you can collect numbers with this. +# +# Usage: +# bash run_bench_single.sh # defaults: c=1, 16 prompts +# CONCURRENCY=16 NUM_PROMPTS=128 bash run_bench_single.sh +# CONFIG_NAME=bs32_tp1_ep64_opt CONCURRENCY=16 NUM_PROMPTS=128 bash run_bench_single.sh +# +# Environment knobs: +# PORT vLLM server port (default 8000) +# MIMO_V2_5_PATH Path to the Neuron-FP8 checkpoint (default +# /opt/dlami/nvme/models/MiMo-V2.5-Neuron-FP8) +# CONCURRENCY --max-concurrency (default 1) +# NUM_PROMPTS --num-prompts (default 16) +# INPUT_LEN --random-input-len (default 900) +# OUTPUT_LEN --random-output-len (default 90) +# RANGE_RATIO --random-range-ratio (default 0.03) +# CONFIG_NAME Used in the output filename (default bs1_tp64_ep1) +# RESULTS_DIR Where to dump per-run log (default /opt/dlami/nvme/logs/bench_results/mimo_v2_5) + +set -e + +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate + +MODEL_PATH="${MIMO_V2_5_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Neuron-FP8}" +PORT="${PORT:-8000}" +CONCURRENCY="${CONCURRENCY:-1}" +NUM_PROMPTS="${NUM_PROMPTS:-16}" +INPUT_LEN="${INPUT_LEN:-900}" +OUTPUT_LEN="${OUTPUT_LEN:-90}" +RANGE_RATIO="${RANGE_RATIO:-0.03}" +CONFIG_NAME="${CONFIG_NAME:-bs1_tp64_ep1}" +RESULTS_DIR="${RESULTS_DIR:-/opt/dlami/nvme/logs/bench_results/mimo_v2_5}" + +mkdir -p "$RESULTS_DIR" + +echo "==========================================" +echo "MiMo-V2.5 single-run benchmark" +echo "==========================================" +echo " Model: $MODEL_PATH" +echo " Port: $PORT" +echo " Config: $CONFIG_NAME" +echo " Concurrency: $CONCURRENCY" +echo " Prompts: $NUM_PROMPTS" +echo " Input len: $INPUT_LEN Output len: $OUTPUT_LEN" +echo " Results: $RESULTS_DIR/${CONFIG_NAME}_c${CONCURRENCY}.txt" +echo "" + +# Quick health check +if ! 
curl -sf "http://localhost:$PORT/health" > /dev/null; then + echo "ERROR: vLLM server is not responding on http://localhost:$PORT" + echo "Start it first (e.g., bench_mimo_v2_5.sh) and wait until" + echo "'Application startup complete.' is printed." + exit 1 +fi + +vllm bench serve \ + --backend vllm \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --endpoint /v1/completions \ + --dataset-name random \ + --num-prompts "$NUM_PROMPTS" \ + --random-input-len "$INPUT_LEN" \ + --random-output-len "$OUTPUT_LEN" \ + --random-range-ratio "$RANGE_RATIO" \ + --max-concurrency "$CONCURRENCY" \ + 2>&1 | tee "$RESULTS_DIR/${CONFIG_NAME}_c${CONCURRENCY}.txt" + +echo "" +echo "Saved to: $RESULTS_DIR/${CONFIG_NAME}_c${CONCURRENCY}.txt" diff --git a/contrib/models/MiMo-V2.5/perf_test/sanity_check.sh b/contrib/models/MiMo-V2.5/perf_test/sanity_check.sh new file mode 100755 index 00000000..3e43a279 --- /dev/null +++ b/contrib/models/MiMo-V2.5/perf_test/sanity_check.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# Quick sanity check against an already-running vLLM server. +# +# Assumes vLLM is already listening on $PORT (default 8000) with MiMo-V2.5 +# loaded. Sends a single chat completion and prints the model's reply. +# +# Usage: +# bash sanity_check.sh # uses defaults +# PORT=8001 bash sanity_check.sh # custom port +# PROMPT="..." bash sanity_check.sh # custom prompt + +set -e + +MODEL_PATH="${MIMO_V2_5_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Neuron-FP8}" +PORT="${PORT:-8000}" +PROMPT="${PROMPT:-What is 1+1? Answer briefly.}" +MAX_TOKENS="${MAX_TOKENS:-64}" + +echo "Sanity check: POST /v1/chat/completions on port $PORT" +echo " Model: $MODEL_PATH" +echo " Prompt: $PROMPT" +echo " Max tokens: $MAX_TOKENS" +echo "" + +# Health check first — fail fast if server isn't up. +if ! curl -sf "http://localhost:$PORT/health" > /dev/null; then + echo "ERROR: vLLM server is not responding on http://localhost:$PORT" + echo "Start it with 'bash bench_mimo_v2_5.sh' or your own launcher first." + exit 1 +fi + +RESPONSE=$(curl -s "http://localhost:$PORT/v1/chat/completions" \ + -H 'Content-Type: application/json' \ + -d "$(cat </dev/null || echo "$RESPONSE" +echo "" + +# Extract the model's reply for a human-friendly one-liner summary. +REPLY=$(echo "$RESPONSE" | python3 -c " +import json, sys +try: + r = json.load(sys.stdin) + print(r['choices'][0]['message']['content'].strip()) +except Exception as e: + print(f'(could not parse reply: {e})') +" 2>/dev/null) + +echo "Model reply: $REPLY" diff --git a/contrib/models/MiMo-V2.5/perf_test/smoke_compile_mimo_v2_5.py b/contrib/models/MiMo-V2.5/perf_test/smoke_compile_mimo_v2_5.py new file mode 100755 index 00000000..136dba6a --- /dev/null +++ b/contrib/models/MiMo-V2.5/perf_test/smoke_compile_mimo_v2_5.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +"""Minimal compile+load smoke test for MiMo-V2.5 FP8 on Trn2. + +Bypasses vLLM entirely so we can iterate on the preprocessed Neuron-FP8 +checkpoint without paying vllm-neuron's startup cost. Builds the MiMo-V2.5 BS=1 +recipe (TP=64, EP=1, blockwise FP8 for routed experts), compiles to a temp +dir, then loads. EP=1 lets the TKG path enter forward_selective_loading +legally so BS=1 compiles — with EP>1 NxDI raises NotImplementedError and +forces BS>=num_experts/top_k = 32. + +STAGE controls how far we go: + instantiate | compile | load | all (default: all) + +DRY_RUN=1 does HLO-only compile (no torch.jit.save + shard). Fastest sanity +check for the preprocessed checkpoint. 
SKIP_WARMUP=1 on load() skips the +forward pass that allocates the shared scratchpad — useful when HBM is +tight. + +Run under /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 (same venv used +by the bench script). +""" + +import os +import sys +import time +import traceback + +MODEL_PATH = os.environ.get( + "MIMO_V2_5_MODEL_PATH", + "/opt/dlami/nvme/models/MiMo-V2.5-Neuron-FP8", +) +COMPILED_PATH = os.environ.get( + "MIMO_V2_5_COMPILED_PATH", + "/opt/dlami/nvme/compiled/mimo_v2_5_tp64_moetp1_ep64_fp8/", +) + +TP_DEGREE = int(os.environ.get("TP_DEGREE", "64")) +SEQ_LEN = int(os.environ.get("SEQ_LEN", "1024")) +BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1")) +CTX_BATCH_SIZE = int(os.environ.get("CTX_BATCH_SIZE", "1")) +# Default to moe_tp=1 / moe_ep=64. Under FP8 + moe_tp=64 (our old default) +# each rank's MoE expert intermediate slice is 32 rows (<128, the scale +# block size), which collapses the per-rank scale to a singleton in +# NxDI's `_setup_for_scale` — losing per-channel FP8 scale granularity +# and producing a BF16-accumulator drift that compounds into output +# collapse after ~30 decode tokens. moe_tp=1/moe_ep=64 keeps every expert +# on a single rank (4 full experts per rank), so each expert's scale +# survives intact. Override via MOE_TP / MOE_EP env vars for other recipes. +MOE_TP = int(os.environ.get("MOE_TP", "1")) +MOE_EP = int(os.environ.get("MOE_EP", "64")) + +STAGE = os.environ.get("STAGE", "all").lower() + +os.makedirs(COMPILED_PATH, exist_ok=True) + +# NxDI's model builder uses a per-process temp workdir for HLO/NEFF staging +# (BASE_COMPILE_WORK_DIR, default "/tmp/nxd_model/"). If two compiles run in +# parallel with the same default, they silently overwrite each other's +# .hlo_module.pb files and one or both compilations crash with +# "neuronx-cc returned non-zero exit status 70". Pin the workdir to a +# unique per-COMPILED_PATH subdir to stay safe under any parallel invocation. +# Default under /opt/dlami/nvme rather than /tmp so the HLO/NEFF artifacts +# survive the nightly Trn2 reboot. +os.environ.setdefault( + "BASE_COMPILE_WORK_DIR", + os.path.join( + "/opt/dlami/nvme/tmp/nxd_model", + os.path.basename(COMPILED_PATH.rstrip("/")), + ), +) + + +def main(): + from neuronx_distributed_inference.models.config import MoENeuronConfig + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + + # Import the contrib wrapper (sibling src dir). + contrib_src = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "src", + ) + sys.path.insert(0, os.path.abspath(contrib_src)) + + from modeling_mimo_v2 import ( + MiMoV2InferenceConfig, + NeuronMiMoV2ForCausalLM, + ) + + print(f"[smoke] MODEL_PATH={MODEL_PATH}") + print(f"[smoke] COMPILED_PATH={COMPILED_PATH}") + print(f"[smoke] TP_DEGREE={TP_DEGREE}, SEQ_LEN={SEQ_LEN}, BS={BATCH_SIZE}") + print(f"[smoke] MOE_TP={MOE_TP}, MOE_EP={MOE_EP}") + print(f"[smoke] STAGE={STAGE}") + + print("[smoke] Building MoENeuronConfig (quantized FP8 MoE, blockwise_symmetric)...") + # NOTE: ep_degree at the top level controls the OUTER (full model) + # expert-parallel factor, which multiplies world_size to + # tp_degree * ep_degree and duplicates non-MoE weights per replica. + # At world_size > 64 on a 64-NC Trn2, sharded weights grow accordingly + # (e.g. tp=64 + ep=4 -> 256 ranks -> 4x the sharded checkpoint size, + # and at runtime the model doesn't fit on the device). 
For MoE-only + # EP we want ep_degree=1 at the outer level and the per-MoE split + # controlled solely by moe_ep_degree (which Pro's working benches + # also do). Keep ep_degree=1 unconditionally. + neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + ep_degree=1, + logical_nc_config=2, + batch_size=BATCH_SIZE, + max_batch_size=BATCH_SIZE, + ctx_batch_size=CTX_BATCH_SIZE, + tkg_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + n_active_tokens=128, + torch_dtype="bfloat16", + capacity_factor=1.0, + glu_mlp=True, + moe_ep_degree=MOE_EP, + moe_tp_degree=MOE_TP, + context_encoding_buckets=[SEQ_LEN], + router_config={"act_fn": "sigmoid", "dtype": "float32"}, + # SDK 2.29 ships only bwmm_shard_on_block / bwmm_shard_on_intermediate; + # default routes to _call_shard_hidden_kernel which is missing, so we + # take the shard-on-block path via this flag. + blockwise_matmul_config={ + "use_shard_on_block_dynamic_while": True, + "block_sharding_strategy": "PING_PONG", + }, + # Persist sharded FP8 weights to disk so subsequent load()s skip the + # ~10-minute shard_checkpoint step (writes weights/tp{0..63}_*.safetensors + # on NVMe; NxDI load() reads these directly when present). + save_sharded_checkpoint=True, + # FP8 blockwise for routed experts (Kimi-K2 recipe). + quantized=True, + quantized_checkpoints_path=MODEL_PATH, + quantization_dtype="f8e4m3", + quantization_type="blockwise_symmetric", + quantization_block_axis=[1, 2], + quantization_block_size=[128, 128], + modules_to_not_convert=[ + "embed_tokens", + "lm_head", + "norm", + "router", + "o_proj", + ], + ) + + print("[smoke] Building MiMoV2InferenceConfig...") + from transformers import AutoConfig + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + config = MiMoV2InferenceConfig( + neuron_config, load_config=load_pretrained_config(hf_config=hf_config) + ) + print(f"[smoke] config.hidden_size={config.hidden_size}") + print(f"[smoke] config.num_hidden_layers={config.num_hidden_layers}") + print(f"[smoke] config.n_routed_experts={config.n_routed_experts}") + print(f"[smoke] config.num_experts_per_tok={config.num_experts_per_tok}") + print(f"[smoke] config.layer_uses_moe[:5]={config.layer_uses_moe[:5]}") + print(f"[smoke] config.layer_attention_types[:5]={config.layer_attention_types[:5]}") + + print("[smoke] Instantiating NeuronMiMoV2ForCausalLM (build model-on-cpu)...") + t0 = time.time() + model = NeuronMiMoV2ForCausalLM(MODEL_PATH, config) + print(f"[smoke] Instantiated in {time.time() - t0:.1f}s") + + if STAGE == "instantiate": + print("[smoke] STAGE=instantiate only, skipping compile/load.") + return + + DRY_RUN = os.environ.get("DRY_RUN", "0") == "1" + if STAGE in ("compile", "all"): + label = "Dry-run compile (HLO only)" if DRY_RUN else "Full compile" + print(f"[smoke] {label} -> {COMPILED_PATH}") + t0 = time.time() + try: + model.compile(COMPILED_PATH, dry_run=DRY_RUN) + print(f"[smoke] {label} OK in {time.time() - t0:.1f}s") + except Exception: + print(f"[smoke] {label} FAILED:") + traceback.print_exc() + raise + + if STAGE in ("load", "all") and not DRY_RUN: + SKIP_WARMUP = os.environ.get("SKIP_WARMUP", "1") == "1" + print(f"[smoke] Loading compiled model from {COMPILED_PATH} (skip_warmup={SKIP_WARMUP})") + t0 = time.time() + model.load(COMPILED_PATH, skip_warmup=SKIP_WARMUP) + print(f"[smoke] Loaded in {time.time() - t0:.1f}s") + + print("[smoke] Done.") + + +if __name__ == "__main__": + try: + main() + except Exception: + traceback.print_exc() + sys.exit(1) diff --git 
a/contrib/models/MiMo-V2.5/perf_test/smoke_generate_mimo_v2_5.py b/contrib/models/MiMo-V2.5/perf_test/smoke_generate_mimo_v2_5.py new file mode 100755 index 00000000..45f4bce0 --- /dev/null +++ b/contrib/models/MiMo-V2.5/perf_test/smoke_generate_mimo_v2_5.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +"""Minimal generate smoke test for MiMo-V2.5 FP8 on Trn2. + +Assumes the compiled NEFF already exists at MIMO_V2_5_COMPILED_PATH +(from smoke_compile_mimo_v2_5.py). Rebuilds the same MoENeuronConfig / +MiMo-V2.5 wrapper, loads with skip_warmup=False, and generates 20 tokens for a +single prompt via HuggingFaceGenerationAdapter. Purpose: sanity-check that +the FP8 MoE + preprocessed scales actually produce coherent tokens. + +Run under /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16. +""" + +import os +import sys +import time +import traceback + +MODEL_PATH = os.environ.get( + "MIMO_V2_5_MODEL_PATH", + "/opt/dlami/nvme/models/MiMo-V2.5-Neuron-FP8", +) +COMPILED_PATH = os.environ.get( + "MIMO_V2_5_COMPILED_PATH", + "/opt/dlami/nvme/compiled/mimo_v2_5_tp64_moetp1_ep64_fp8/", +) + +# Must match smoke_compile_mimo_v2_5.py exactly, else load() sees a +# mismatched NEFF. +TP_DEGREE = int(os.environ.get("TP_DEGREE", "64")) +SEQ_LEN = int(os.environ.get("SEQ_LEN", "1024")) +BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1")) +CTX_BATCH_SIZE = int(os.environ.get("CTX_BATCH_SIZE", "1")) +MOE_TP = int(os.environ.get("MOE_TP", "1")) +MOE_EP = int(os.environ.get("MOE_EP", "64")) + +PROMPT = os.environ.get( + "MIMO_V2_5_PROMPT", + "Hello! Please introduce yourself in one sentence.", +) +MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "20")) + +# Keep the per-compile BASE_COMPILE_WORK_DIR in sync with +# smoke_compile_mimo_v2_5.py so load() under the same COMPILED_PATH +# doesn't collide with a concurrent compile or reuse a stale workdir. +# Default under /opt/dlami/nvme so artifacts survive the nightly Trn2 reboot. +os.environ.setdefault( + "BASE_COMPILE_WORK_DIR", + os.path.join( + "/opt/dlami/nvme/tmp/nxd_model", + os.path.basename(COMPILED_PATH.rstrip("/")), + ), +) + + +def main(): + from transformers import AutoConfig, AutoTokenizer, GenerationConfig + + from neuronx_distributed_inference.models.config import MoENeuronConfig + from neuronx_distributed_inference.utils.hf_adapter import ( + HuggingFaceGenerationAdapter, + load_pretrained_config, + ) + + contrib_src = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "src", + ) + sys.path.insert(0, os.path.abspath(contrib_src)) + + from modeling_mimo_v2 import ( + MiMoV2InferenceConfig, + NeuronMiMoV2ForCausalLM, + ) + + print(f"[gen] MODEL_PATH={MODEL_PATH}") + print(f"[gen] COMPILED_PATH={COMPILED_PATH}") + print(f"[gen] TP={TP_DEGREE}, SEQ={SEQ_LEN}, BS={BATCH_SIZE}") + + # Outer ep_degree must match the compile-time value (kept at 1 so + # world_size = tp_degree; see smoke_compile_mimo_v2_5.py comment). 
+ neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + ep_degree=1, + logical_nc_config=2, + batch_size=BATCH_SIZE, + max_batch_size=BATCH_SIZE, + ctx_batch_size=CTX_BATCH_SIZE, + tkg_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + n_active_tokens=128, + torch_dtype="bfloat16", + capacity_factor=1.0, + glu_mlp=True, + moe_ep_degree=MOE_EP, + moe_tp_degree=MOE_TP, + context_encoding_buckets=[SEQ_LEN], + router_config={"act_fn": "sigmoid", "dtype": "float32"}, + blockwise_matmul_config={ + "use_shard_on_block_dynamic_while": True, + "block_sharding_strategy": "PING_PONG", + }, + save_sharded_checkpoint=True, + quantized=True, + quantized_checkpoints_path=MODEL_PATH, + quantization_dtype="f8e4m3", + quantization_type="blockwise_symmetric", + quantization_block_axis=[1, 2], + quantization_block_size=[128, 128], + modules_to_not_convert=[ + "embed_tokens", + "lm_head", + "norm", + "router", + "o_proj", + ], + ) + + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + config = MiMoV2InferenceConfig( + neuron_config, load_config=load_pretrained_config(hf_config=hf_config) + ) + + print("[gen] Instantiating model...") + t0 = time.time() + model = NeuronMiMoV2ForCausalLM(MODEL_PATH, config) + print(f"[gen] Instantiated in {time.time() - t0:.1f}s") + + # skip_warmup=False so generate() hits a primed graph (the warmup forward + # allocates the shared scratchpad the generation path needs). + print(f"[gen] Loading from {COMPILED_PATH} (skip_warmup=False)") + t0 = time.time() + model.load(COMPILED_PATH, skip_warmup=False) + print(f"[gen] Loaded in {time.time() - t0:.1f}s") + + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + adapter = HuggingFaceGenerationAdapter(model) + + inputs = tokenizer([PROMPT] * BATCH_SIZE, return_tensors="pt", padding=True) + gen_config = GenerationConfig( + max_new_tokens=MAX_NEW_TOKENS, + min_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + pad_token_id=getattr(tokenizer, "pad_token_id", None) or tokenizer.eos_token_id, + ) + + print(f"[gen] prompt: {PROMPT!r}") + print(f"[gen] input_ids.shape={tuple(inputs['input_ids'].shape)}") + t0 = time.time() + output_ids = adapter.generate( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + generation_config=gen_config, + ) + dt = time.time() - t0 + + prompt_len = inputs["input_ids"].shape[1] + new_tokens = output_ids[0, prompt_len:] + decoded = tokenizer.decode(new_tokens, skip_special_tokens=True) + full = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + print(f"[gen] generated {new_tokens.numel()} tokens in {dt:.2f}s " + f"({new_tokens.numel() / dt:.2f} tok/s)") + print(f"[gen] new token ids: {new_tokens.tolist()}") + print(f"[gen] new text : {decoded!r}") + print(f"[gen] full text : {full!r}") + print("[gen] Done.") + + +if __name__ == "__main__": + try: + main() + except Exception: + traceback.print_exc() + sys.exit(1) diff --git a/contrib/models/MiMo-V2.5/perf_test/start_vllm_server.sh b/contrib/models/MiMo-V2.5/perf_test/start_vllm_server.sh new file mode 100755 index 00000000..dc515ba0 --- /dev/null +++ b/contrib/models/MiMo-V2.5/perf_test/start_vllm_server.sh @@ -0,0 +1,104 @@ +#!/bin/bash +# Start the MiMo-V2.5 FP8 vLLM OpenAI-compatible server in the foreground. +# +# The server stays up until you Ctrl-C it. Use sanity_check.sh and +# run_bench_single.sh in a separate shell to exercise / benchmark it. +# bench_mimo_v2_5.sh calls this script under the hood for its one-shot +# launch + bench + teardown flow. 
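+#
+# Environment knobs (all optional; the defaults are the ones the export
+# block below derives):
+#   PORT                          Server port (default 8000)
+#   MIMO_V2_5_PATH                Preprocessed Neuron-FP8 checkpoint dir
+#                                 (default /opt/dlami/nvme/models/MiMo-V2.5-Neuron-FP8)
+#   NXDI_CONTRIB_MIMO_V2_5_SRC    Contrib src dir (default: the src/ next to this script)
+#   NEURON_COMPILED_ARTIFACTS     NEFF + per-rank sharded-weight output dir
+#   BASE_COMPILE_WORK_DIR         NxDI HLO/NEFF staging dir
+#   VLLM_ENGINE_READY_TIMEOUT_S   Engine-ready timeout in seconds (default 7200)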
+# +# Recipe: TP=64, moe_tp=1/moe_ep=64, BS=32, continuous batching + bucketing. +# moe_tp=1/moe_ep=64 keeps each expert's weights and blockwise FP8 scales +# intact on a single rank (4 experts/rank), avoiding the per-rank scale +# collapse that comes from moe_tp=64 when intermediate=2048 is TP-sharded +# below the 128-row scale block boundary. +# +# NxDI's TKG path refuses Expert Parallelism with BS < num_experts/top_k +# (256 / 8 = 32), so BS=32 is the smallest working batch size on the FP8 +# path. BS=1 single-stream latency is not currently supported on V2.5 FP8. + +set -e + +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate + +MODEL_PATH="${MIMO_V2_5_PATH:-/opt/dlami/nvme/models/MiMo-V2.5-Neuron-FP8}" +PORT="${PORT:-8000}" + +# Contrib package src. vllm-neuron's registration hook reads these env vars +# to plug NeuronMiMoV2ForCausalLM into NxDI's MODEL_TYPES table. +: "${NXDI_CONTRIB_MIMO_V2_5_SRC:=$(cd "$(dirname "$0")/.." && pwd)/src}" +export NXDI_CONTRIB_MIMO_V2_5_SRC +# vLLM 0.16's builtin arch validator knows MiMoV2FlashForCausalLM but not +# MiMoV2ForCausalLM. Preprocess rewrites the checkpoint's config.json +# architectures to the Flash name, and we reuse the Flash registration +# key in vllm-neuron (MODEL_TYPES['mimov2flash']). The modeling module +# (modeling_mimo_v2) and class (NeuronMiMoV2ForCausalLM) are shared. +export NXDI_CONTRIB_MIMO_V2_FLASH_SRC="$NXDI_CONTRIB_MIMO_V2_5_SRC" + +# Persistent compile-artifact location (NEFF + per-rank sharded weights). +# Setting this overrides vLLM's fallback of /neuron-compiled-artifacts//. +: "${NEURON_COMPILED_ARTIFACTS:=/opt/dlami/nvme/compiled/mimo_v2_5_bs32_moetp1_ep64_fp8_vllm}" +export NEURON_COMPILED_ARTIFACTS +# NxDI HLO/NEFF staging directory, pinned to persistent storage so it +# survives the nightly Trn2 reboot and a unique per-config subdir. +: "${BASE_COMPILE_WORK_DIR:=/opt/dlami/nvme/tmp/nxd_model/$(basename "$NEURON_COMPILED_ARTIFACTS")}" +export BASE_COMPILE_WORK_DIR +mkdir -p "$BASE_COMPILE_WORK_DIR" + +# First-time compile of V2.5's 256-expert MoE takes ~30 min (HLO + shard). 
+export VLLM_ENGINE_READY_TIMEOUT_S="${VLLM_ENGINE_READY_TIMEOUT_S:-7200}" + +echo "==========================================" +echo "Starting MiMo-V2.5 FP8 vLLM server" +echo "==========================================" +echo " Model path: $MODEL_PATH" +echo " Port: $PORT" +echo " Compiled artifacts: $NEURON_COMPILED_ARTIFACTS" +echo " Compile work dir: $BASE_COMPILE_WORK_DIR" +echo " NXDI_CONTRIB_MIMO_V2_5_SRC: $NXDI_CONTRIB_MIMO_V2_5_SRC" +echo "" + +exec python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 32 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port "$PORT" \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + "tp_degree": 64, + "logical_nc_config": 2, + "fused_qkv": false, + "sequence_parallel_enabled": false, + "glu_mlp": true, + "normalize_top_k_affinities": true, + "save_sharded_checkpoint": true, + "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, + "quantized": true, + "quantized_checkpoints_path": "'"$MODEL_PATH"'", + "quantization_dtype": "f8e4m3", + "quantization_type": "blockwise_symmetric", + "quantization_block_axis": [1, 2], + "quantization_block_size": [128, 128], + "modules_to_not_convert": ["embed_tokens", "lm_head", "norm", "router", "o_proj"], + "blockwise_matmul_config": {"use_shard_on_block_dynamic_while": true, "block_sharding_strategy": "PING_PONG"}, + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 32, + "ctx_batch_size": 1, + "tkg_batch_size": 32, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": true, + "enable_bucketing": true, + "context_encoding_buckets": [1024], + "token_generation_buckets": [1024], + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' diff --git a/contrib/models/MiMo-V2.5/perf_test/vllm-neuron-patch.patch b/contrib/models/MiMo-V2.5/perf_test/vllm-neuron-patch.patch new file mode 100644 index 00000000..4a84a558 --- /dev/null +++ b/contrib/models/MiMo-V2.5/perf_test/vllm-neuron-patch.patch @@ -0,0 +1,107 @@ +diff --git a/vllm_neuron/worker/neuronx_distributed_model_loader.py b/vllm_neuron/worker/neuronx_distributed_model_loader.py +index d2099eb..0c162e4 100644 +--- a/vllm_neuron/worker/neuronx_distributed_model_loader.py ++++ b/vllm_neuron/worker/neuronx_distributed_model_loader.py +@@ -922,6 +922,94 @@ def _camel_to_kebab(name: str) -> str: + return re.sub("([a-z0-9])([A-Z])", r"\1-\2", s1).lower() + + ++ ++def _patch_autoconfig_trust_remote_code(): ++ """Monkey-patch ``AutoConfig.from_pretrained`` to default ``trust_remote_code=True``. ++ ++ NxDI's ``hf_adapter.load_config`` calls ``AutoConfig.from_pretrained(path)`` ++ without ``trust_remote_code``. Contrib models like MiMo-V2-Flash that ++ ship a ``configuration_*.py`` with the checkpoint require custom code ++ execution, so the default behaviour crashes with ``ValueError: The ++ repository ... contains custom code which must be executed``. ++ ++ vLLM's top-level ``--trust-remote-code`` flag only affects vLLM's own ++ config load, not NxDI's. Patching here is cheap and idempotent. 
++ """ ++ try: ++ from transformers import AutoConfig ++ except ImportError: ++ return ++ if getattr(AutoConfig, "_nxdi_contrib_patched", False): ++ return ++ _orig = AutoConfig.from_pretrained ++ ++ def _patched(*args, **kwargs): ++ kwargs.setdefault("trust_remote_code", True) ++ return _orig(*args, **kwargs) ++ ++ AutoConfig.from_pretrained = _patched ++ AutoConfig._nxdi_contrib_patched = True ++ ++ ++def _register_contrib_models(): ++ """Lazy-register NxDI contrib models on each process that calls the loader. ++ ++ Driven by env vars: ++ NXDI_CONTRIB_MIMO_V2_FLASH_SRC -> path to contrib MiMo-V2-Flash src/ ++ NXDI_CONTRIB_MINIMAX_M2_SRC -> path to contrib MiniMax-M2 src/ ++ ++ Registers the contrib model class into NxDI's MODEL_TYPES and, where ++ vLLM does not already know the architecture, registers it into vLLM's ++ ModelRegistry. Runs every time _get_neuron_model_cls is called so that ++ vLLM's spawn'd EngineCore workers (which don't inherit the parent's ++ module-level state) pick up the registration too. Registration is ++ idempotent. ++ """ ++ import os as _os ++ import sys as _sys ++ import warnings as _w ++ ++ _patch_autoconfig_trust_remote_code() ++ ++ mimo_src = _os.environ.get("NXDI_CONTRIB_MIMO_V2_FLASH_SRC") ++ if mimo_src and _os.path.isdir(mimo_src) and "mimov2flash" not in MODEL_TYPES: ++ if mimo_src not in _sys.path: ++ _sys.path.insert(0, mimo_src) ++ try: ++ from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM ++ MODEL_TYPES.setdefault( ++ "mimov2flash", {"causal-lm": NeuronMiMoV2ForCausalLM} ++ ) ++ try: ++ from vllm.model_executor.models.registry import ModelRegistry ++ if "MiMoV2FlashForCausalLM" not in ModelRegistry.get_supported_archs(): ++ ModelRegistry.register_model( ++ "MiMoV2FlashForCausalLM", NeuronMiMoV2ForCausalLM ++ ) ++ except ImportError: ++ pass ++ except Exception as e: ++ _w.warn( ++ f"Failed to register MiMo-V2-Flash contrib model: {e}", ++ category=UserWarning, ++ ) ++ ++ minimax_src = _os.environ.get("NXDI_CONTRIB_MINIMAX_M2_SRC") ++ if minimax_src and _os.path.isdir(minimax_src) and "minimaxm2" not in MODEL_TYPES: ++ if minimax_src not in _sys.path: ++ _sys.path.insert(0, minimax_src) ++ try: ++ from modeling_minimax_m2 import NeuronMiniMaxM2ForCausalLM ++ MODEL_TYPES.setdefault( ++ "minimaxm2", {"causal-lm": NeuronMiniMaxM2ForCausalLM} ++ ) ++ except Exception as e: ++ _w.warn( ++ f"Failed to register MiniMax-M2 contrib model: {e}", ++ category=UserWarning, ++ ) ++ ++ + def _get_neuron_model_cls(architecture: str): + """ + Get Neuron model class from architecture string. +@@ -941,6 +1029,7 @@ def _get_neuron_model_cls(architecture: str): + _get_neuron_model_cls("NeuronLlamaForCausalLM") + + """ ++ _register_contrib_models() + # Handle Neuron class name (starts with "Neuron") - strip prefix + if architecture.startswith("Neuron") and "For" in architecture: + original_architecture = architecture diff --git a/contrib/models/MiMo-V2.5/src/__init__.py b/contrib/models/MiMo-V2.5/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiMo-V2.5/src/conversion_script/preprocess_mimo_v2_5_fp8.py b/contrib/models/MiMo-V2.5/src/conversion_script/preprocess_mimo_v2_5_fp8.py new file mode 100644 index 00000000..1530e93f --- /dev/null +++ b/contrib/models/MiMo-V2.5/src/conversion_script/preprocess_mimo_v2_5_fp8.py @@ -0,0 +1,710 @@ +""" +Preprocess MiMo-V2.5 FP8 checkpoint for Neuron inference. + +This is a streaming (per-layer) rewrite of preprocess_mimo_v2_fp8.py. 
The
+original preprocess loaded the entire ~290 GB FP8 checkpoint into RAM via
+load_state_dict(); that peaks well over 600 GB after dequantize/requantize
+copies and is fragile. This version keeps a single safe_open handle live
+at a time and emits per-layer safetensors shards, capping peak RAM at
+~24 GB and finishing in ~20 minutes.
+
+MiMo-V2.5 checkpoint layout:
+  - q_proj, k_proj, v_proj ship *fused* as a single self_attn.qkv_proj.weight
+    with an interleaved-group layout; split_qkv_fused slices the fused tensor
+    back into separate q/k/v tensors (with a fallback for pre-V2.5 checkpoints
+    that already ship them split).
+  - o_proj is BF16 (listed in quantization_config.ignored_layers); kept
+    as BF16 on the Neuron side (RowParallelLinear, not QuantizedRowParallel).
+  - Layer 0 is a dense MLP (moe_layer_freq[0] == 0) with intermediate_size
+    16384; layers 1..47 are MoE with 256 experts each.
+  - Hybrid attention: 9 "full" layers (hybrid_layer_pattern[i] == 0) and
+    39 "sliding window" layers (== 1). SWA layers carry
+    attention_sink_bias (add_swa_attention_sink_bias=True in the config;
+    add_full_attention_sink_bias=False, so full layers do NOT get it).
+
+Neuron-side rescaling (same as MiMo-V2 siblings):
+  - OCP FP8 e4m3 (±448) -> Neuron FP8 e4m3 (±240) with FP8_SCALING_FACTOR=448/240.
+  - Per-row scales for attention/dense-mlp projections (q/k/v/o, gate/up/down
+    of the dense layer).
+  - Blockwise (128x128) scales kept for MoE expert weights; per-expert weights
+    are transposed and fused to match ExpertFusedRowParallelLinear's packed
+    layout (gate_up_proj: [num_experts, H, 2*IM]; down_proj: [num_experts, IM, H]).
+
+Output layout:
+  save_path/
+    config.json, tokenizer.*, chat_template.jinja if present
+    configuration_mimo_v2.py, modeling_mimo_v2.py (trust_remote_code)
+    model.safetensors.index.json (regenerated)
+    model_extras.safetensors (embed_tokens, norm, lm_head)
+    model_layer{N}.safetensors (one per decoder layer, N=0..47)
+
+Usage:
+    python preprocess_mimo_v2_5_fp8.py \\
+        --hf_model_path /opt/dlami/nvme/models/MiMo-V2.5 \\
+        --save_path /opt/dlami/nvme/models/MiMo-V2.5-Neuron-FP8 \\
+        --tp_degree 64
+"""
+
+import argparse
+import gc
+import json
+import os
+import shutil
+import time
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from safetensors import safe_open
+from safetensors.torch import save_file
+
+
+FP8_SCALING_FACTOR = 448.0 / 240.0
+NEURON_FP8_MAX = 240.0
+
+
+# ---------------------------------------------------------------------------
+# Quantization primitives
+# ---------------------------------------------------------------------------
+
+def convert_bf16_to_fp8_per_row(
+    weight: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """BF16 [out, in] -> Neuron FP8 per-row (scales shape [out, 1])."""
+    weight_float = weight.float()
+    row_max_abs = weight_float.abs().max(dim=1, keepdim=True)[0]
+    scales = torch.clamp(row_max_abs / NEURON_FP8_MAX, min=1e-10)
+    quantized = (weight_float / scales).to(torch.float8_e4m3fn)
+    return quantized, scales.to(torch.float32)
+
+
+def rescale_fp8_to_per_row(
+    weight: torch.Tensor, scale: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Block-wise FP8 + blockwise scale -> Neuron per-row FP8.
+
+    Dequantize to float32 using block broadcast, then per-row requantize.
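+
+    For example (shapes here are illustrative, not read from the checkpoint):
+    a [4096, 4096] weight with a [32, 32] scale grid implies 128x128 blocks,
+    so scale[i, j] rescales weight[128*i:128*(i+1), 128*j:128*(j+1)] before
+    the per-row max-abs requantization.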
+ """ + out_features, in_features = weight.shape + scale_h, scale_w = scale.shape + + block_h = (out_features + scale_h - 1) // scale_h + block_w = (in_features + scale_w - 1) // scale_w + + weight_float = weight.float() + dequantized = torch.zeros(out_features, in_features, dtype=torch.float32) + for i in range(scale_h): + for j in range(scale_w): + h0, h1 = i * block_h, min((i + 1) * block_h, out_features) + w0, w1 = j * block_w, min((j + 1) * block_w, in_features) + dequantized[h0:h1, w0:w1] = ( + weight_float[h0:h1, w0:w1] * scale[i, j].item() + ) + + row_max_abs = dequantized.abs().max(dim=1, keepdim=True)[0] + scales = torch.clamp(row_max_abs / NEURON_FP8_MAX, min=1e-10) + quantized = (dequantized / scales).to(torch.float8_e4m3fn) + return quantized, scales.to(torch.float32) + + +def rescale_fp8_weight_blockwise( + weight: torch.Tensor, scale: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Keep blockwise scales, just rescale into Neuron FP8 range. + + MoE expert weights stay block-quantized; only the dtype range changes. + """ + weight_bf16 = weight.bfloat16() + rescaled = (weight_bf16 / FP8_SCALING_FACTOR).to(torch.float8_e4m3fn) + neuron_scale = scale.float() * FP8_SCALING_FACTOR + return rescaled, neuron_scale.to(torch.float32) + + +def _requantize_per_row(dequant: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """BF16/FP32 -> Neuron FP8 per-row.""" + row_max_abs = dequant.abs().max(dim=1, keepdim=True)[0] + scales = row_max_abs / NEURON_FP8_MAX + scales = torch.clamp(scales, min=1e-10) + quantized = (dequant / scales).to(torch.float8_e4m3fn) + return quantized, scales.to(torch.float32) + + +def split_qkv_fused( + qkv_weight: torch.Tensor, + qkv_scale: Optional[torch.Tensor], + num_q_heads: int, + num_kv_heads_this_layer: int, + num_groups: int, + head_dim: int, + v_head_dim: int, +) -> Dict[str, Tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Split V2.5's fused qkv_proj into q/k/v. + + Layout — validated empirically via per-group Q/K/V magnitude probes: + + The weight is NOT `[all_Q | all_K | all_V]`. It is ``num_groups`` + interleaved groups, each packing ``hpg`` Q heads, ``kpg`` K heads, and + ``kpg`` V heads contiguously: + + group g (g = 0 .. num_groups-1): + rows [g*R : g*R + qg] = Q heads [g*hpg : (g+1)*hpg] + rows [g*R + qg : g*R + qg + kg] = K heads [g*kpg : (g+1)*kpg] + rows [g*R + qg + kg : g*R + R] = V heads [g*kpg : (g+1)*kpg] + where + hpg = num_q_heads / num_groups + kpg = num_kv_heads_this_layer / num_groups + qg = hpg * head_dim + kg = kpg * head_dim + vg = kpg * v_head_dim + R = qg + kg + vg + + ``num_groups`` is a model-level constant (= the full-attention + ``num_key_value_heads``, 4 for V2.5). It is the same for full and SWA + layers, so a SWA layer with ``num_kv_heads_this_layer=8`` packs + ``kpg=2`` K heads + 2 V heads per group. + + Scale layout: each group holds ``q_blk + k_blk + v_blk`` scale rows + where ``q_blk = qg // 128`` (exact; qg is always 128-aligned), + ``k_blk = ceil(kg / 128)`` (may add 64 rows of "phantom" padding on + V2.5 full layers where kg = 192 and the last half-block is unused), + ``v_blk = ceil(vg / 128)``. When a phantom half-block appears, the + physical weight rows stop before the phantom rows do — we recover the + correct dequant by padding each group's weight out to + ``per_group_scale * 128`` rows, broadcasting the scale, then stripping + the phantom rows. 
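+
+    Worked example with the V2.5 config values (64 Q heads, 4 KV heads on
+    full layers, head_dim=192, v_head_dim=128, num_groups=4): hpg=16, kpg=1,
+    so qg=3072, kg=192, vg=128 and R=3392 rows per group. Scale rows per
+    group: q_blk=24, k_blk=ceil(192/128)=2, v_blk=1, total 27, i.e. 3456
+    padded rows vs 3392 real rows, which leaves the 64 phantom rows mentioned
+    above. A SWA layer (8 KV heads, so kpg=2, same Q-head count) gives
+    qg=3072, kg=384, vg=256 and R=3712 = 29*128, so SWA layers need no
+    phantom padding.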
+ """ + in_features = qkv_weight.shape[1] + assert num_q_heads % num_groups == 0 and num_kv_heads_this_layer % num_groups == 0, ( + f"num_q_heads={num_q_heads} and num_kv_heads_this_layer=" + f"{num_kv_heads_this_layer} must both be divisible by " + f"num_groups={num_groups}" + ) + hpg = num_q_heads // num_groups + kpg = num_kv_heads_this_layer // num_groups + qg_rows = hpg * head_dim + kg_rows = kpg * head_dim + vg_rows = kpg * v_head_dim + real_rows_per_group = qg_rows + kg_rows + vg_rows + total_real_rows = num_groups * real_rows_per_group + + BLOCK = 128 + q_scale_rows_per_group = qg_rows // BLOCK # exact + k_scale_rows_per_group = (kg_rows + BLOCK - 1) // BLOCK + v_scale_rows_per_group = (vg_rows + BLOCK - 1) // BLOCK + scale_rows_per_group = ( + q_scale_rows_per_group + k_scale_rows_per_group + v_scale_rows_per_group + ) + padded_rows_per_group = scale_rows_per_group * BLOCK + + assert qkv_weight.shape[0] == total_real_rows, ( + f"qkv_proj.weight row count {qkv_weight.shape[0]} != expected " + f"{total_real_rows} (num_groups={num_groups}, hpg={hpg}, kpg={kpg}, " + f"R={real_rows_per_group})" + ) + + if qkv_weight.dtype != torch.float8_e4m3fn or qkv_scale is None: + # BF16 path: no scale to worry about. + w = qkv_weight.view(num_groups, real_rows_per_group, in_features) + q_w = ( + w[:, :qg_rows, :] + .reshape(num_groups * qg_rows, in_features) + .contiguous() + ) + k_w = ( + w[:, qg_rows : qg_rows + kg_rows, :] + .reshape(num_groups * kg_rows, in_features) + .contiguous() + ) + v_w = ( + w[:, qg_rows + kg_rows :, :] + .reshape(num_groups * vg_rows, in_features) + .contiguous() + ) + q_w2, q_s2 = convert_bf16_to_fp8_per_row(q_w) + k_w2, k_s2 = convert_bf16_to_fp8_per_row(k_w) + v_w2, v_s2 = convert_bf16_to_fp8_per_row(v_w) + return {"q_proj": (q_w2, q_s2), "k_proj": (k_w2, k_s2), "v_proj": (v_w2, v_s2)} + + # FP8 + blockwise scale path. 
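+    # Note: the zero rows added below for phantom padding dequantize to zero
+    # and are stripped again before the per-row requantization, so they never
+    # influence the emitted scales.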
+ expected_scale_rows = num_groups * scale_rows_per_group + expected_scale_cols = (in_features + BLOCK - 1) // BLOCK + assert qkv_scale.shape == (expected_scale_rows, expected_scale_cols), ( + f"qkv scale shape {tuple(qkv_scale.shape)} != expected " + f"({expected_scale_rows}, {expected_scale_cols}) for " + f"num_groups={num_groups}, per_group={scale_rows_per_group}" + ) + + w = qkv_weight.to(torch.float32).view( + num_groups, real_rows_per_group, in_features + ) + w_padded = torch.zeros( + num_groups, padded_rows_per_group, in_features, dtype=torch.float32 + ) + w_padded[:, :real_rows_per_group, :] = w + + s = qkv_scale.to(torch.float32).view( + num_groups, scale_rows_per_group, expected_scale_cols + ) + s_exp = s.repeat_interleave(BLOCK, dim=1).repeat_interleave(BLOCK, dim=2) + s_exp = s_exp[:, :padded_rows_per_group, :in_features] + + deq_padded = w_padded * s_exp + deq = deq_padded[:, :real_rows_per_group, :] + + q_deq = ( + deq[:, :qg_rows, :] + .reshape(num_groups * qg_rows, in_features) + .contiguous() + ) + k_deq = ( + deq[:, qg_rows : qg_rows + kg_rows, :] + .reshape(num_groups * kg_rows, in_features) + .contiguous() + ) + v_deq = ( + deq[:, qg_rows + kg_rows :, :] + .reshape(num_groups * vg_rows, in_features) + .contiguous() + ) + + q_w2, q_s2 = _requantize_per_row(q_deq) + k_w2, k_s2 = _requantize_per_row(k_deq) + v_w2, v_s2 = _requantize_per_row(v_deq) + + return {"q_proj": (q_w2, q_s2), "k_proj": (k_w2, k_s2), "v_proj": (v_w2, v_s2)} + + +# --------------------------------------------------------------------------- +# Streaming weight access (one open safetensors handle at a time) +# --------------------------------------------------------------------------- + +class LazyWeightMap: + """Lazily fetch tensors from sharded safetensors, keeping one handle live.""" + + def __init__(self, model_dir: str, weight_map: Dict[str, str]): + self.model_dir = model_dir + self._cur_filename: Optional[str] = None + self._cur_handle = None + # V2.5's published model.safetensors.index.json references filenames + # like `model_N-00001-of-00002.safetensors`, but the shards on disk + # are `model_pp0_epN_shardM.safetensors` and the ep/N numbers don't + # line up. Rather than try to reverse-engineer the mapping, scan the + # on-disk shards and rebuild weight_map by reading each file's + # manifest directly. Falls back to the provided weight_map when the + # shard files match the names on disk (pre-V2.5 checkpoints). 
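+        # Rebuilding only reads each shard's safetensors header (safe_open
+        # does not materialize tensor data), so the scan stays cheap even for
+        # hundreds of GB of shards.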
+ actual_files = sorted(f for f in os.listdir(model_dir) if f.endswith(".safetensors")) + names_in_weight_map = set(weight_map.values()) + if actual_files and not (names_in_weight_map & set(actual_files)): + rebuilt: Dict[str, str] = {} + for fname in actual_files: + path = os.path.join(model_dir, fname) + with safe_open(path, framework="pt", device="cpu") as fp: + for k in fp.keys(): + rebuilt[k] = fname + self.weight_map = rebuilt + else: + self.weight_map = weight_map + + def _open(self, filename: str): + if self._cur_filename == filename: + return self._cur_handle + if self._cur_handle is not None: + self._cur_handle.__exit__(None, None, None) + self._cur_handle = None + path = os.path.join(self.model_dir, filename) + self._cur_handle = safe_open(path, framework="pt", device="cpu") + self._cur_handle.__enter__() + self._cur_filename = filename + return self._cur_handle + + def get(self, key: str) -> Optional[torch.Tensor]: + filename = self.weight_map.get(key) + if filename is None: + return None + return self._open(filename).get_tensor(key) + + def has(self, key: str) -> bool: + return key in self.weight_map + + def close(self): + if self._cur_handle is not None: + self._cur_handle.__exit__(None, None, None) + self._cur_handle = None + self._cur_filename = None + + +# --------------------------------------------------------------------------- +# Per-tensor helper +# --------------------------------------------------------------------------- + +def _maybe_fp8_to_neuron_per_row( + weight: torch.Tensor, scale: Optional[torch.Tensor] +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """FP8 blockwise -> per-row, or BF16 -> FP8 per-row. Pass-through otherwise.""" + if weight.dtype == torch.float8_e4m3fn and scale is not None: + return rescale_fp8_to_per_row(weight, scale) + if weight.dtype == torch.bfloat16: + return convert_bf16_to_fp8_per_row(weight) + return weight, scale + + +# --------------------------------------------------------------------------- +# Per-layer processing +# --------------------------------------------------------------------------- + +def process_layer( + layer_idx: int, + lazy: LazyWeightMap, + config: dict, + is_dense: bool, + is_swa: bool, +) -> Dict[str, torch.Tensor]: + out: Dict[str, torch.Tensor] = {} + prefix = f"model.layers.{layer_idx}." + out_prefix = f"layers.{layer_idx}." + + # --- Layer norms (BF16, untouched) --- + for name in ("input_layernorm", "post_attention_layernorm"): + t = lazy.get(f"{prefix}{name}.weight") + if t is not None: + out[f"{out_prefix}{name}.weight"] = t.detach().clone() + + # --- Attention: q/k/v --- + # MiMo-V2.5 ships QKV *fused* into a single self_attn.qkv_proj.weight, + # with an interleaved-group layout (see split_qkv_fused for details). + # The NxDI modeling code expects separate q_proj / k_proj / v_proj, so + # split the fused tensor back out. Falls back to per-proj tensors if + # the checkpoint is already split (pre-V2.5). + qkv_w = lazy.get(f"{prefix}self_attn.qkv_proj.weight") + qkv_s = lazy.get(f"{prefix}self_attn.qkv_proj.weight_scale_inv") + if qkv_w is not None: + num_heads = config["swa_num_attention_heads" if is_swa else "num_attention_heads"] + num_kv_heads_this = config[ + "swa_num_key_value_heads" if is_swa else "num_key_value_heads" + ] + qk_head_dim = config["swa_head_dim" if is_swa else "head_dim"] + v_hd = config["swa_v_head_dim" if is_swa else "v_head_dim"] + # num_groups is a model-level constant = full-attention num_kv_heads. + # SWA layers with num_kv_heads=8 still use 4 groups (2 K heads per group). 
+ num_groups = config["num_key_value_heads"] + split = split_qkv_fused( + qkv_w, + qkv_s, + num_q_heads=num_heads, + num_kv_heads_this_layer=num_kv_heads_this, + num_groups=num_groups, + head_dim=qk_head_dim, + v_head_dim=v_hd, + ) + for proj, (w2, s2) in split.items(): + out[f"{out_prefix}self_attn.{proj}.weight"] = w2 + if s2 is not None: + out[f"{out_prefix}self_attn.{proj}.scale"] = s2 + else: + # Fallback path for checkpoints that ship split q/k/v (pre-V2.5). + for proj in ("q_proj", "k_proj", "v_proj"): + w = lazy.get(f"{prefix}self_attn.{proj}.weight") + if w is None: + continue + s = lazy.get(f"{prefix}self_attn.{proj}.weight_scale_inv") + w2, s2 = _maybe_fp8_to_neuron_per_row(w, s) + out[f"{out_prefix}self_attn.{proj}.weight"] = w2 + if s2 is not None: + out[f"{out_prefix}self_attn.{proj}.scale"] = s2 + + # o_proj is listed in HF quantization_config.ignored_layers and ships as + # BF16; on Neuron it binds to a plain RowParallelLinear (see + # modeling_mimo_v2.py: self.o_proj = RowParallelLinear(...)), NOT a + # QuantizedRowParallel. Writing FP8 + .scale here would silently be + # reinterpreted as BF16 bytes at load time and produce garbage outputs. + # Keep BF16, never emit .scale. + o_w = lazy.get(f"{prefix}self_attn.o_proj.weight") + o_s = lazy.get(f"{prefix}self_attn.o_proj.weight_scale_inv") + if o_w is not None: + if o_w.dtype == torch.float8_e4m3fn: + # Defensive: if a future checkpoint FP8-quantizes o_proj, dequant + # blockwise back to BF16 (no per-row requant; RowParallelLinear has + # no .scale parameter). + assert o_s is not None, "FP8 o_proj requires weight_scale_inv" + out_features, in_features = o_w.shape + scale_h, scale_w = o_s.shape + block_h = (out_features + scale_h - 1) // scale_h + block_w = (in_features + scale_w - 1) // scale_w + wf = o_w.float() + tmp = torch.zeros(out_features, in_features, dtype=torch.float32) + for i in range(scale_h): + for j in range(scale_w): + h0, h1 = i * block_h, min((i + 1) * block_h, out_features) + w0, w1 = j * block_w, min((j + 1) * block_w, in_features) + tmp[h0:h1, w0:w1] = wf[h0:h1, w0:w1] * o_s[i, j].item() + o_bf16 = tmp.to(torch.bfloat16) + else: + o_bf16 = o_w.to(torch.bfloat16) + out[f"{out_prefix}self_attn.o_proj.weight"] = o_bf16.detach().clone() + + # --- attention_sink_bias: present only on SWA layers in MiMo-V2.5. + # config.add_swa_attention_sink_bias=True, add_full_attention_sink_bias=False. + if is_swa and config.get("add_swa_attention_sink_bias", False): + sink = lazy.get(f"{prefix}self_attn.attention_sink_bias") + if sink is not None: + out[f"{out_prefix}self_attn.attention_sink_bias"] = sink.detach().clone() + elif not is_swa and config.get("add_full_attention_sink_bias", False): + sink = lazy.get(f"{prefix}self_attn.attention_sink_bias") + if sink is not None: + out[f"{out_prefix}self_attn.attention_sink_bias"] = sink.detach().clone() + + # --- MLP: dense vs MoE --- + if is_dense: + # Dense MLP: gate_proj, up_proj, down_proj (FP8 blockwise in MiMo-V2.5 layer 0). 
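+        # Unlike the MoE experts below, the dense layer's projections are
+        # requantized to per-row scales via _maybe_fp8_to_neuron_per_row
+        # (see the "Neuron-side rescaling" notes in the module docstring).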
+ for proj in ("gate_proj", "up_proj", "down_proj"): + w = lazy.get(f"{prefix}mlp.{proj}.weight") + if w is None: + continue + s = lazy.get(f"{prefix}mlp.{proj}.weight_scale_inv") + w2, s2 = _maybe_fp8_to_neuron_per_row(w, s) + out[f"{out_prefix}mlp.{proj}.weight"] = w2 + if s2 is not None: + out[f"{out_prefix}mlp.{proj}.scale"] = s2 + return out + + # --- MoE --- + # Router: mlp.gate -> mlp.router.linear_router + router_w = lazy.get(f"{prefix}mlp.gate.weight") + if router_w is not None: + out[f"{out_prefix}mlp.router.linear_router.weight"] = router_w.detach().clone() + router_bias = lazy.get(f"{prefix}mlp.gate.e_score_correction_bias") + if router_bias is not None: + out[f"{out_prefix}mlp.router.e_score_correction_bias"] = router_bias.detach().clone() + + num_experts = config["n_routed_experts"] + + # Peek expert 0 to learn shapes/dtypes. + e0_gw = lazy.get(f"{prefix}mlp.experts.0.gate_proj.weight") + if e0_gw is None: + return out # no experts (shouldn't happen for MoE layers, but be safe) + e0_gs = lazy.get(f"{prefix}mlp.experts.0.gate_proj.weight_scale_inv") + + if e0_gw.dtype == torch.float8_e4m3fn and e0_gs is not None: + sample_w, sample_s = rescale_fp8_weight_blockwise(e0_gw, e0_gs) + elif e0_gw.dtype == torch.bfloat16: + # Should not happen for MiMo-V2.5 (experts ship in FP8); flag loudly. + raise NotImplementedError( + f"Layer {layer_idx} expert 0 gate_proj is BF16; MiMo-V2.5 expects FP8." + ) + else: + sample_w, sample_s = e0_gw, e0_gs + + intermediate_size, hidden_size = sample_w.shape # [IM, H] + # Packed transpose layout: [num_experts, H, 2*IM] for gate_up. + gate_up_proj = torch.empty( + num_experts, hidden_size, 2 * intermediate_size, dtype=sample_w.dtype + ) + i_blocks, h_blocks = sample_s.shape # [IM_blocks, H_blocks] + gate_up_scale = torch.empty( + num_experts, h_blocks, 2 * i_blocks, dtype=sample_s.dtype + ) + + e0_dw = lazy.get(f"{prefix}mlp.experts.0.down_proj.weight") + e0_ds = lazy.get(f"{prefix}mlp.experts.0.down_proj.weight_scale_inv") + if e0_dw.dtype == torch.float8_e4m3fn and e0_ds is not None: + sample_dw, sample_ds = rescale_fp8_weight_blockwise(e0_dw, e0_ds) + else: + raise NotImplementedError( + f"Layer {layer_idx} expert 0 down_proj dtype {e0_dw.dtype} not handled." + ) + d_h_blocks, d_i_blocks = sample_ds.shape # [H_blocks, IM_blocks] + down_proj = torch.empty( + num_experts, intermediate_size, hidden_size, dtype=sample_dw.dtype + ) + down_scale = torch.empty( + num_experts, d_i_blocks, d_h_blocks, dtype=sample_ds.dtype + ) + + # Slot expert 0 (already rescaled above). 
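+    # HF stores each expert's gate/up as [IM, H] and down as [H, IM]; the
+    # packed Neuron layout wants gate_up as [H, 2*IM] and down as [IM, H]
+    # per expert, hence the transposes while slotting experts in.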
+ gate_up_proj[0, :, :intermediate_size] = sample_w.T + gate_up_scale[0, :, :i_blocks] = sample_s.T + e0_uw = lazy.get(f"{prefix}mlp.experts.0.up_proj.weight") + e0_us = lazy.get(f"{prefix}mlp.experts.0.up_proj.weight_scale_inv") + up_w0, up_s0 = rescale_fp8_weight_blockwise(e0_uw, e0_us) + gate_up_proj[0, :, intermediate_size:] = up_w0.T + gate_up_scale[0, :, i_blocks:] = up_s0.T + down_proj[0] = sample_dw.T + down_scale[0] = sample_ds.T + del e0_gw, e0_gs, e0_uw, e0_us, e0_dw, e0_ds + del sample_w, sample_s, sample_dw, sample_ds, up_w0, up_s0 + + for e in range(1, num_experts): + gw = lazy.get(f"{prefix}mlp.experts.{e}.gate_proj.weight") + gs = lazy.get(f"{prefix}mlp.experts.{e}.gate_proj.weight_scale_inv") + uw = lazy.get(f"{prefix}mlp.experts.{e}.up_proj.weight") + us = lazy.get(f"{prefix}mlp.experts.{e}.up_proj.weight_scale_inv") + dw = lazy.get(f"{prefix}mlp.experts.{e}.down_proj.weight") + ds = lazy.get(f"{prefix}mlp.experts.{e}.down_proj.weight_scale_inv") + g_w, g_s = rescale_fp8_weight_blockwise(gw, gs) + u_w, u_s = rescale_fp8_weight_blockwise(uw, us) + d_w, d_s = rescale_fp8_weight_blockwise(dw, ds) + gate_up_proj[e, :, :intermediate_size] = g_w.T + gate_up_proj[e, :, intermediate_size:] = u_w.T + gate_up_scale[e, :, :i_blocks] = g_s.T + gate_up_scale[e, :, i_blocks:] = u_s.T + down_proj[e] = d_w.T + down_scale[e] = d_s.T + del gw, gs, uw, us, dw, ds, g_w, g_s, u_w, u_s, d_w, d_s + + out[f"{out_prefix}mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj + out[f"{out_prefix}mlp.expert_mlps.mlp_op.gate_up_proj.scale"] = gate_up_scale + out[f"{out_prefix}mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj + out[f"{out_prefix}mlp.expert_mlps.mlp_op.down_proj.scale"] = down_scale + return out + + +# --------------------------------------------------------------------------- +# Shard saving / index +# --------------------------------------------------------------------------- + +def save_shard( + tensors: Dict[str, torch.Tensor], + save_path: str, + filename: str, + weight_map: Dict[str, str], +) -> int: + """Save a sub-state-dict; clone tensors so safetensors doesn't complain + about views of mmapped storage. 
Returns bytes written.""" + path = os.path.join(save_path, filename) + materialized: Dict[str, torch.Tensor] = {} + total_bytes = 0 + for k, v in tensors.items(): + if not v.is_contiguous(): + v = v.contiguous() + v = v.detach().clone() + materialized[k] = v + total_bytes += v.numel() * v.element_size() + save_file(materialized, path) + for k in materialized.keys(): + weight_map[k] = filename + del materialized + return total_bytes + + +# --------------------------------------------------------------------------- +# Main driver +# --------------------------------------------------------------------------- + +def process_flash_checkpoint(hf_model_path: str, save_path: str, tp_degree: int): + os.makedirs(save_path, exist_ok=True) + + with open(os.path.join(hf_model_path, "model.safetensors.index.json")) as f: + weight_map_in = json.load(f)["weight_map"] + + with open(os.path.join(hf_model_path, "config.json")) as f: + config = json.load(f) + + num_layers = config["num_hidden_layers"] + hybrid = config.get("hybrid_layer_pattern", [0] * num_layers) + moe_freq = config.get("moe_layer_freq", [1] * num_layers) + + print( + f"Processing {num_layers} decoder layers" + f" (full={sum(1 for v in hybrid if v == 0)}," + f" swa={sum(1 for v in hybrid if v == 1)}," + f" dense={sum(1 for v in moe_freq if v == 0)}," + f" moe={sum(1 for v in moe_freq if v == 1)})", + flush=True, + ) + + lazy = LazyWeightMap(hf_model_path, weight_map_in) + weight_map_out: Dict[str, str] = {} + + try: + for li in range(num_layers): + t0 = time.time() + is_dense = moe_freq[li] == 0 + is_swa = hybrid[li] == 1 + layer_sd = process_layer(li, lazy, config, is_dense=is_dense, is_swa=is_swa) + filename = f"model_layer{li}.safetensors" + size = save_shard(layer_sd, save_path, filename, weight_map_out) + del layer_sd + gc.collect() + tag = "dense" if is_dense else "moe " + attn = "swa " if is_swa else "full" + print( + f" layer {li:2d} [{tag} {attn}] {size/1e9:6.2f} GB in {time.time()-t0:5.1f}s", + flush=True, + ) + + print("Processing embed_tokens, norm, lm_head ...", flush=True) + extras: Dict[str, torch.Tensor] = {} + for src, dst in ( + ("model.embed_tokens.weight", "embed_tokens.weight"), + ("model.norm.weight", "norm.weight"), + ("lm_head.weight", "lm_head.weight"), + ): + t = lazy.get(src) + if t is not None: + extras[dst] = t.detach().clone() + else: + print(f" WARNING: missing {src}", flush=True) + if "lm_head.weight" not in extras and "embed_tokens.weight" in extras: + # Tied embeddings + extras["lm_head.weight"] = extras["embed_tokens.weight"].detach().clone() + save_shard(extras, save_path, "model_extras.safetensors", weight_map_out) + del extras + finally: + lazy.close() + + # --- Index file --- + total_size = 0 + for f in set(weight_map_out.values()): + total_size += os.path.getsize(os.path.join(save_path, f)) + index = { + "metadata": {"total_size": total_size}, + "weight_map": weight_map_out, + } + with open(os.path.join(save_path, "model.safetensors.index.json"), "w") as f: + json.dump(index, f, indent=2) + + # --- Copy auxiliary files (config.json, tokenizer, chat template, + # and crucially the trust_remote_code modules the HF config references). 
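+    # For MiMo-V2.5 the trust_remote_code modules are configuration_mimo_v2.py
+    # and modeling_mimo_v2.py (see the output layout in the module docstring);
+    # everything except *.safetensors files and the original index is copied
+    # verbatim.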
+ for name in sorted(os.listdir(hf_model_path)): + if name.endswith(".safetensors"): + continue + if name == "model.safetensors.index.json": + continue + src = os.path.join(hf_model_path, name) + if os.path.isfile(src): + shutil.copy(src, os.path.join(save_path, name)) + + # --- Rewrite architectures in the copied config.json so vLLM's + # pydantic ModelConfig validator accepts the checkpoint. vLLM's + # builtin supported-archs list (Xiaomi's upstream PR) contains + # MiMoV2FlashForCausalLM but not the V2.5 arch `MiMoV2ForCausalLM`, + # and the vllm-neuron plugin registers contribs too late to patch + # that list. The NxDI side loads via auto_map + trust_remote_code, + # so the arch name only has to survive the vLLM pydantic check. + # auto_map still points at the V2.5 modeling/configuration modules. + cfg_path = os.path.join(save_path, "config.json") + if os.path.isfile(cfg_path): + with open(cfg_path) as _f: + _cfg = json.load(_f) + if _cfg.get("architectures") == ["MiMoV2ForCausalLM"]: + _cfg["architectures"] = ["MiMoV2FlashForCausalLM"] + with open(cfg_path, "w") as _f: + json.dump(_cfg, _f, indent=2) + + print(f"\nPreprocess complete. total_size={total_size/1e9:.2f} GB", flush=True) + print(f" tensors written: {len(weight_map_out)}", flush=True) + print(f" output dir: {save_path}", flush=True) + + +def main(): + parser = argparse.ArgumentParser( + description="Preprocess MiMo-V2.5 FP8 checkpoint for Neuron inference" + ) + parser.add_argument("--hf_model_path", required=True) + parser.add_argument("--save_path", required=True) + parser.add_argument("--tp_degree", type=int, default=64, + help="Tensor parallelism (currently informational only; " + "the framework does the TP sharding at load time).") + args = parser.parse_args() + process_flash_checkpoint(args.hf_model_path, args.save_path, args.tp_degree) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/MiMo-V2.5/src/modeling_mimo_v2.py b/contrib/models/MiMo-V2.5/src/modeling_mimo_v2.py new file mode 100644 index 00000000..947dfe2c --- /dev/null +++ b/contrib/models/MiMo-V2.5/src/modeling_mimo_v2.py @@ -0,0 +1,1677 @@ +# coding=utf-8 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# This implementation is based on the MiMo-V2.5 model from Xiaomi. 
+# Reference: https://huggingface.co/XiaomiMiMo/MiMo-V2.5 + +"""MiMo-V2.5 model for NXD inference.""" + +import gc +import math +import warnings +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers.mappings import ( + gather_from_sequence_parallel_region, + gather_from_tensor_model_parallel_region_with_dim, +) +from neuronx_distributed.utils import cpu_mode + +from neuronx_distributed_inference.utils.distributed import ( + split_along_dim, + get_cp_rank, +) +from neuronx_distributed_inference.modules.attention.attention_process_groups import ( + get_context_parallel_attention_cp_group, +) + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + MoENeuronConfig, +) +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm +from neuronx_distributed_inference.modules.moe_v2 import initialize_moe_module + +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel + +from torch_neuronx.xla_impl.ops import nki_jit + +_flash_fwd_call = nki_jit()(attention_isa_kernel) + + +def get_rmsnorm_cls(): + """Get appropriate RMSNorm class based on execution environment.""" + return MiMoV2RMSNorm if cpu_mode() else CustomRMSNorm + + +class MiMoV2RMSNorm(nn.Module): + """RMSNorm implementation for CPU mode.""" + + def __init__(self, hidden_size: int, eps: float = 1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class MiMoV2RotaryEmbedding(nn.Module): + """Rotary Position Embedding for MiMo-V2.5. + + Supports partial rotary embedding where only a fraction of dimensions + use rotary position encoding. 
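+
+    For MiMo-V2.5 (per the model config, roughly a third of the 192 Q/K head
+    dims) only the first 64 dims of each Q/K head receive RoPE; the remaining
+    "nope" dims pass through unrotated and are split off and re-concatenated
+    in the attention forward.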
+ """ + + def __init__( + self, + dim: int, + max_position_embeddings: int = 262144, + base: float = 5000000.0, + partial_rotary_factor: float = 1.0, + ): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.partial_rotary_factor = partial_rotary_factor + + # Calculate the actual dimension used for rotary embedding + self.rope_dim = int(dim * partial_rotary_factor) + # Ensure rope_dim is even + self.rope_dim = self.rope_dim - (self.rope_dim % 2) + + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.rope_dim, 2, dtype=torch.float32) / self.rope_dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward( + self, x: torch.Tensor, position_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute rotary embeddings. + + Args: + x: Input tensor of shape (batch_size, seq_len, hidden_size) + position_ids: Position indices of shape (batch_size, seq_len) + + Returns: + Tuple of (cos, sin) tensors for rotary embedding + """ + inv_freq_expanded = self.inv_freq[None, :, None].float().expand( + position_ids.shape[0], -1, 1 + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotate half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + unsqueeze_dim: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply rotary position embedding to query and key tensors.""" + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MiMoV2InferenceConfig(InferenceConfig): + """Configuration class for MiMo-V2.5 inference on Neuron.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # MoE configuration + self.num_local_experts = self.n_routed_experts + self.n_shared_experts = 0 # MiMo-V2.5 has no shared experts + + # Set intermediate_size for MoE layers + self.intermediate_size = self.moe_intermediate_size + + # Check and pad intermediate size if needed + self.maybe_pad_intermediate() + + # Router configuration + self.neuron_config.router_config.dtype = torch.float32 + self.neuron_config.router_config.act_fn = "sigmoid" # MiMo uses sigmoid + + # Disable numeric CC token as workaround + self.neuron_config.disable_numeric_cc_token = True + + # MiMo normalizes top-k affinities + self.neuron_config.normalize_top_k_affinities = True + + # Parse hybrid layer pattern + self._parse_hybrid_pattern() + + def _parse_hybrid_pattern(self): + """Parse hybrid layer pattern to determine attention types.""" + if hasattr(self, 'hybrid_layer_pattern') and self.hybrid_layer_pattern: + self.layer_attention_types = [ + "sliding_window" if p == 1 else "full" + for p in self.hybrid_layer_pattern + ] + else: + self.layer_attention_types = ["full"] * self.num_hidden_layers + + # 
Parse MoE layer frequency + if hasattr(self, 'moe_layer_freq') and self.moe_layer_freq: + self.layer_uses_moe = [bool(f) for f in self.moe_layer_freq] + else: + self.layer_uses_moe = [True] * self.num_hidden_layers + + def maybe_pad_intermediate(self): + """Pad intermediate size if required for efficient computation.""" + from neuronx_distributed_inference.models.config import ( + SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP, + ) + + moe_tp_degree = self.neuron_config.moe_tp_degree + I_TP = self.moe_intermediate_size // moe_tp_degree + + if getattr( + self.neuron_config.blockwise_matmul_config, + "use_shard_on_intermediate_dynamic_while", + False, + ): + if I_TP % SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP != 0: + padded_size = ( + math.ceil(I_TP / SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP) + * SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP + * moe_tp_degree + ) + self.moe_intermediate_pad_size = max( + padded_size - self.moe_intermediate_size, 0 + ) + self.moe_intermediate_size = padded_size + + def get_required_attributes(self) -> List[str]: + return [ + "attention_bias", + "head_dim", + "hidden_act", + "hidden_size", + "hybrid_layer_pattern", + "layernorm_epsilon", + "max_position_embeddings", + "moe_intermediate_size", + "moe_layer_freq", + "n_routed_experts", + "norm_topk_prob", + "num_attention_heads", + "num_experts_per_tok", + "num_hidden_layers", + "num_key_value_heads", + "partial_rotary_factor", + "rope_theta", + "scoring_func", + "sliding_window", + "swa_head_dim", + "swa_num_attention_heads", + "swa_num_key_value_heads", + "swa_rope_theta", + "swa_v_head_dim", + "tie_word_embeddings", + "v_head_dim", + "vocab_size", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[MoENeuronConfig]: + return MoENeuronConfig + + +class NeuronMiMoV2Attention(NeuronAttentionBase): + """MiMo-V2.5 Attention implementation supporting hybrid attention patterns. + + Supports both full attention and sliding window attention with different + head dimensions for Q/K vs V. + """ + + def __init__( + self, + config: MiMoV2InferenceConfig, + layer_idx: int, + is_sliding_window: bool = False, + ): + self.layer_idx = layer_idx + self.is_sliding_window = is_sliding_window + + # Select parameters based on attention type + if is_sliding_window: + self.attn_head_dim = config.swa_head_dim + self.attn_v_head_dim = config.swa_v_head_dim + self.attn_num_heads = config.swa_num_attention_heads + self.attn_num_kv_heads = config.swa_num_key_value_heads + rope_theta = getattr(config, 'swa_rope_theta', 10000.0) + self.sliding_window_size = config.sliding_window + else: + self.attn_head_dim = config.head_dim + self.attn_v_head_dim = config.v_head_dim + self.attn_num_heads = config.num_attention_heads + self.attn_num_kv_heads = config.num_key_value_heads + rope_theta = config.rope_theta + self.sliding_window_size = None + + # Calculate partial rotary dimensions + self.partial_rotary_factor = config.partial_rotary_factor + self.rope_dim = int(self.attn_head_dim * self.partial_rotary_factor) + self.rope_dim = self.rope_dim - (self.rope_dim % 2) # Ensure even + self.nope_dim = self.attn_head_dim - self.rope_dim + + # Create rotary embedding + rotary_emb = MiMoV2RotaryEmbedding( + dim=self.attn_head_dim, + max_position_embeddings=config.max_position_embeddings, + base=rope_theta, + partial_rotary_factor=self.partial_rotary_factor, + ) + + # Initialize base attention + # NOTE: We pass v_head_dim to base class, but MiMo uses asymmetric Q/K (192) vs V (128). 
+ # We override init_gqa_properties() to prevent the base class from creating + # incompatible projection layers (which cause crashes when CP > 1). + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=self.attn_num_heads, + num_key_value_heads=self.attn_num_kv_heads, + head_dim=self.attn_v_head_dim, # Use v_head_dim for base class + rotary_emb=rotary_emb, + rms_norm_eps=config.layernorm_epsilon, + use_qk_norm=False, + ) + + # Initialize MiMo-specific projections with correct dimensions + self._init_projections(config) + + # Scaling factor + self.scaling = self.attn_head_dim ** -0.5 + # HF MiMoV2Attention (modeling_mimo_v2.py) multiplies value_states + # by config.attention_value_scale (0.707 for MiMo-V2) right after the V + # projection, before attention softmax*V. Matching that here — applied + # to value_states in forward() rather than to attn_output. + self.value_scale = float(getattr(config, "attention_value_scale", 1.0)) + + # Store cache KV heads for cache compatibility + # With CONVERT_TO_MHA, all layers have num_attention_heads KV heads + # Otherwise, use max of full and sliding window kv heads + tp_degree = config.neuron_config.tp_degree + if self.use_gqa_convert_to_mha: + # CONVERT_TO_MHA: cache stores num_attention_heads (same as Q heads) + self.cache_num_kv_heads = self.attn_num_heads + self.local_cache_kv_heads = self.local_num_heads + else: + # Standard GQA: cache uses max of full and sliding window kv heads + self.cache_num_kv_heads = max( + config.num_key_value_heads, + getattr(config, 'swa_num_key_value_heads', config.num_key_value_heads) + ) + self.local_cache_kv_heads = max(1, self.cache_num_kv_heads // tp_degree) + + def init_gqa_properties(self): + """Override base class to prevent creating incompatible QKV projections. + + MiMo-V2.5 has asymmetric Q/K head_dim (192) vs V head_dim (128), + which is incompatible with the base class's GroupQueryAttention_QKV. + MiMo uses its own custom projections via _init_projections() instead. + + When CP > 1, the base class would create cte_qkv_proj/tkg_qkv_proj with + wrong head_dim=128, causing compilation crashes. This no-op prevents that. + """ + pass + + def _init_projections(self, config: MiMoV2InferenceConfig): + """Initialize projection layers with correct dimensions. + + When CONVERT_TO_MHA is needed (tp_degree > num_kv_heads), K/V projections + are sized for num_attention_heads (not original num_kv_heads). The checkpoint + weights are replicated in preshard_hook before loading. 
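+
+        (In the current code path the replication actually happens in
+        convert_mimo_v2_hf_to_neuron_state_dict() rather than in this class's
+        preshard_hook; see the note on preshard_hook below.)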
+ """ + dtype = config.neuron_config.torch_dtype + tp_degree = config.neuron_config.tp_degree + + # Check if we need GQA CONVERT_TO_MHA (when tp_degree > num_kv_heads) + self.use_gqa_convert_to_mha = tp_degree > self.attn_num_kv_heads + + # Store source heads for preshard_hook + self._src_num_kv_heads = self.attn_num_kv_heads + self._kv_replication_factor = self.attn_num_heads // self.attn_num_kv_heads if self.use_gqa_convert_to_mha else 1 + + if self.use_gqa_convert_to_mha: + # CONVERT_TO_MHA: K and V use num_attention_heads for proper TP splitting + k_num_heads = self.attn_num_heads + v_num_heads = self.attn_num_heads + else: + k_num_heads = self.attn_num_kv_heads + v_num_heads = self.attn_num_kv_heads + + # Q/K use head_dim, V uses v_head_dim + q_hidden_size = self.attn_num_heads * self.attn_head_dim + k_hidden_size = k_num_heads * self.attn_head_dim + v_hidden_size = v_num_heads * self.attn_v_head_dim + o_hidden_size = self.attn_num_heads * self.attn_v_head_dim + + if parallel_state.model_parallel_is_initialized(): + tp_group = parallel_state.get_tensor_model_parallel_group() + + # Q projection + self.q_proj = ColumnParallelLinear( + config.hidden_size, + q_hidden_size, + bias=config.attention_bias, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + + # K projection + self.k_proj = ColumnParallelLinear( + config.hidden_size, + k_hidden_size, + bias=config.attention_bias, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + + # V projection + self.v_proj = ColumnParallelLinear( + config.hidden_size, + v_hidden_size, + bias=config.attention_bias, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + + # Output projection - with sequence parallel to scatter output + self.o_proj = RowParallelLinear( + o_hidden_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + dtype=dtype, + tensor_model_parallel_group=tp_group, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=1 if self.sequence_parallel_enabled else None, + ) + + # Calculate local dimensions after TP split + self.local_num_heads = self.attn_num_heads // tp_degree + if self.use_gqa_convert_to_mha: + # With CONVERT_TO_MHA, local KV heads = local Q heads + self.local_num_kv_heads = self.local_num_heads + else: + self.local_num_kv_heads = max(1, self.attn_num_kv_heads // tp_degree) + else: + self.q_proj = nn.Linear(config.hidden_size, q_hidden_size, bias=config.attention_bias) + self.k_proj = nn.Linear(config.hidden_size, k_hidden_size, bias=config.attention_bias) + self.v_proj = nn.Linear(config.hidden_size, v_hidden_size, bias=config.attention_bias) + self.o_proj = nn.Linear(o_hidden_size, config.hidden_size, bias=False) + + self.local_num_heads = self.attn_num_heads + self.local_num_kv_heads = k_num_heads + + # Override base class attributes that were computed with wrong head_dim + # The base class init_gqa_properties() uses head_dim=v_head_dim which is wrong for Q/K + # We need to override these to ensure correct computation + self.num_heads = self.local_num_heads + self.num_key_value_heads = self.local_num_kv_heads + self.num_key_value_groups = self.local_num_heads // self.local_num_kv_heads + self.head_dim = self.attn_head_dim # Override to use actual Q/K head_dim (192) + + # Remove qkv_proj from base class if exists (we use separate q_proj, k_proj, v_proj) + if hasattr(self, 'qkv_proj'): + self.qkv_proj = None + + # Attention sink bias for attention layers (following HF implementation) + # 
This is a learnable parameter that allows attention to "sink" to an extra position + add_full_attention_sink_bias = getattr(config, 'add_full_attention_sink_bias', False) + add_swa_attention_sink_bias = getattr(config, 'add_swa_attention_sink_bias', True) + + # Determine if this layer uses sink bias based on config + self._use_sink_bias = (add_full_attention_sink_bias and not self.is_sliding_window) or \ + (add_swa_attention_sink_bias and self.is_sliding_window) + + if self._use_sink_bias: + # Shape: [num_attention_heads] - will be split across TP ranks + # The weight is loaded from checkpoint with shape [num_attention_heads] + # and will be sliced to [local_num_heads] during forward + self.attention_sink_bias = nn.Parameter( + torch.zeros(self.attn_num_heads, dtype=dtype), requires_grad=False + ) + else: + self.attention_sink_bias = None + + def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: + """Pre-shard hook to replicate K/V weights for CONVERT_TO_MHA. + + NOTE: This method is NOT currently called because NeuronMiMoV2Attention + is not a BaseGroupQueryAttention subclass. K/V weight replication is + instead done in convert_mimo_v2_hf_to_neuron_state_dict(). + + This method is kept for reference and potential future use. + """ + # This hook is not called - see note above + return False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[torch.Tensor] = None, + cos_cache: Optional[torch.Tensor] = None, + sin_cache: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """Forward pass for MiMo-V2.5 attention with Context Parallelism support.""" + + # Context Parallelism: only active during context encoding (no past_key_value) + is_context_parallel = past_key_value is None and self.cp_degree > 1 + cp_rank = None + + if is_context_parallel: + cp_rank = get_cp_rank( + self.rank_util.get_rank(), self.tp_degree, + self.cp_degree, self.neuron_config.switch_cc, + ) + # Split attention_mask (dim=2 = Q rows) and position_ids (dim=1 = seq) + attention_mask = split_along_dim( + attention_mask, dim=2, rank=cp_rank, num_partitions=self.cp_degree + ) + # Keep full position_ids for RoPE computation on full-length K/V + local_position_ids = split_along_dim( + position_ids, dim=1, rank=cp_rank, num_partitions=self.cp_degree + ) + + # Handle sequence parallel + if self.sequence_parallel_enabled and parallel_state.model_parallel_is_initialized(): + hidden_states = gather_from_sequence_parallel_region( + hidden_states, + self.sequence_dimension, + process_group=parallel_state.get_tensor_model_parallel_group(), + ) + + # Context Parallelism without sequence parallel: split hidden_states + if is_context_parallel and not self.sequence_parallel_enabled: + hidden_states = split_along_dim( + hidden_states, dim=1, rank=cp_rank, num_partitions=self.cp_degree + ) + + bsz, q_len, _ = hidden_states.size() + + # Determine if this is token generation (past_key_value is not None) + is_token_gen = past_key_value is not None + + # Project Q, K, V + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # HF MiMoV2Attention scales V by attention_value_scale (0.707 for MiMo-V2) + # right after v_proj, before the attention softmax*V. 
Earlier revisions + # of this file applied it post-attention or not at all; both produce + # gibberish for prompts longer than ~20 tokens. + if self.value_scale != 1.0: + value_states = value_states * self.value_scale + + # Reshape for multi-head attention: [bsz, num_heads, seq_len, head_dim] + query_states = query_states.view(bsz, q_len, self.local_num_heads, self.attn_head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.local_num_kv_heads, self.attn_head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.local_num_kv_heads, self.attn_v_head_dim).transpose(1, 2) + + # Split into rope and non-rope parts + query_rope = query_states[..., :self.rope_dim] + query_nope = query_states[..., self.rope_dim:] + key_rope = key_states[..., :self.rope_dim] + key_nope = key_states[..., self.rope_dim:] + + # Compute rotary embeddings + # IMPORTANT: Always compute for this layer because different layer types + # (full vs sliding window) use different rope_theta values. + # Full attention: rope_theta = 5000000 + # Sliding window: rope_theta = 10000 + # We cannot reuse cached cos/sin from other layers! + # + # For CP with sequence_parallel: Q/K/V have full S, use full position_ids for RoPE. + # For CP without sequence_parallel: Q/K/V have S/CP, use local_position_ids for RoPE + # (local_position_ids contain the correct global positions for this CP rank). + if is_context_parallel and not self.sequence_parallel_enabled: + rope_position_ids = local_position_ids + else: + rope_position_ids = position_ids + cos_cache, sin_cache = self.rotary_emb(value_states, rope_position_ids) + + # Apply rotary position embedding to rope parts only + query_rope, key_rope = apply_rotary_pos_emb( + query_rope, key_rope, cos_cache, sin_cache, rope_position_ids + ) + + # Concatenate rope and non-rope parts + query_states = torch.cat([query_rope, query_nope], dim=-1) + key_states = torch.cat([key_rope, key_nope], dim=-1) + + # Context Parallelism: split Q and save local KV for cache + if is_context_parallel: + if self.sequence_parallel_enabled: + # Q/K/V have full S. Split Q to local portion, save local KV for cache. + # Use split_along_dim (torch.index_select) instead of Python slicing + # because XLA tracing doesn't support dynamic tensor indices in slice notation. + query_states = split_along_dim(query_states, dim=2, rank=cp_rank, num_partitions=self.cp_degree) + key_states_for_cache = split_along_dim(key_states, dim=2, rank=cp_rank, num_partitions=self.cp_degree) + value_states_for_cache = split_along_dim(value_states, dim=2, rank=cp_rank, num_partitions=self.cp_degree) + q_len = q_len // self.cp_degree + # K/V stay at full S for attention computation + else: + # Q/K/V have S/CP. Save local KV for cache, then all-gather K/V. 
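+                # After the all-gather K/V cover the full sequence while Q
+                # stays at S/CP, so each rank's score matmul below is
+                # [S/CP x S].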
+ key_states_for_cache = key_states + value_states_for_cache = value_states + key_states = gather_from_tensor_model_parallel_region_with_dim( + key_states, gather_dim=2, + process_group=get_context_parallel_attention_cp_group(), + ) + value_states = gather_from_tensor_model_parallel_region_with_dim( + value_states, gather_dim=2, + process_group=get_context_parallel_attention_cp_group(), + ) + # Q stays at S/CP + else: + # Store key/value states BEFORE GQA repeat for KV cache + key_states_for_cache = key_states + value_states_for_cache = value_states + + # WORKAROUND 1: Pad V from v_head_dim (128) to head_dim (192) for KV cache compatibility + if self.attn_v_head_dim < self.attn_head_dim: + pad_size = self.attn_head_dim - self.attn_v_head_dim + value_states_for_cache = F.pad(value_states_for_cache, (0, pad_size), value=0.0) + + # WORKAROUND 2: Pad KV heads if layer has fewer than cache expects + # Only needed when NOT using CONVERT_TO_MHA (standard GQA mode) + if not self.use_gqa_convert_to_mha and self.local_num_kv_heads < self.local_cache_kv_heads: + # Pad KV heads by repeating + repeat_factor = self.local_cache_kv_heads // self.local_num_kv_heads + key_states_for_cache = key_states_for_cache.repeat(1, repeat_factor, 1, 1) + value_states_for_cache = value_states_for_cache.repeat(1, repeat_factor, 1, 1) + + # Repeat KV heads for GQA (only needed without CONVERT_TO_MHA) + # With CONVERT_TO_MHA, K/V already have num_attention_heads + num_key_value_groups = self.local_num_heads // self.local_num_kv_heads + if num_key_value_groups > 1: + key_states = key_states.repeat_interleave(num_key_value_groups, dim=1) + value_states = value_states.repeat_interleave(num_key_value_groups, dim=1) + + if is_token_gen: + # Token generation: use decomposed attention with prior (cached) and active (current) KV + # past_key_value[0] = cached K, shape [bsz, cache_kv_heads, kv_seq_len, head_dim] + # past_key_value[1] = cached V, shape [bsz, cache_kv_heads, kv_seq_len, head_dim] (padded) + K_prior = past_key_value[0] + V_prior = past_key_value[1] + + # WORKAROUND 1: Slice KV heads if cache has more than layer needs + # Only needed when NOT using CONVERT_TO_MHA (standard GQA mode) + # With CONVERT_TO_MHA, cache and layer have same num_kv_heads + if not self.use_gqa_convert_to_mha and self.local_num_kv_heads < self.local_cache_kv_heads: + # Cache has repeated heads, just take the first local_num_kv_heads + K_prior = K_prior[:, :self.local_num_kv_heads, :, :] + V_prior = V_prior[:, :self.local_num_kv_heads, :, :] + + # WORKAROUND 2: Slice V_prior back to v_head_dim (128) from head_dim (192) + if self.attn_v_head_dim < self.attn_head_dim: + V_prior = V_prior[..., :self.attn_v_head_dim] + + # Repeat cached KV for GQA (only needed without CONVERT_TO_MHA) + # With CONVERT_TO_MHA, cached K/V already have num_attention_heads + if num_key_value_groups > 1: + K_prior = K_prior.repeat_interleave(num_key_value_groups, dim=1) + V_prior = V_prior.repeat_interleave(num_key_value_groups, dim=1) + + # Compute attention on prior (cached) KV + # K_prior shape: [bsz, num_heads, kv_seq_len, head_dim] + prior_scores = torch.matmul(query_states, K_prior.transpose(-2, -1)) * self.scaling + + # Apply attention mask to prior scores + if attention_mask is not None: + # Convert boolean mask to additive mask if needed + if attention_mask.dtype == torch.bool: + prior_scores = prior_scores.masked_fill(~attention_mask, float('-inf')) + else: + prior_scores = prior_scores + attention_mask + + # Apply sliding window mask for SWA layers + if 
self.is_sliding_window and self.sliding_window_size is not None and position_ids is not None: + kv_seq_len = prior_scores.size(-1) + current_pos = position_ids[0, 0] + pos_indices = torch.arange(kv_seq_len, device=prior_scores.device) + sliding_mask = pos_indices >= (current_pos - self.sliding_window_size + 1) + sliding_mask = sliding_mask[None, None, None, :] + prior_scores = prior_scores.masked_fill(~sliding_mask, float('-inf')) + + prior_scores = prior_scores.to(torch.float32) + + # Compute attention on active (current) KV + active_scores = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling + active_scores = active_scores.to(torch.float32) + + # Combined softmax over prior and active scores + all_scores = torch.cat([prior_scores, active_scores], dim=-1) + + # Add attention sink bias (following HF implementation) + # This must be applied to token generation as well! + use_sink = self._use_sink_bias and self.attention_sink_bias is not None + if use_sink: + tp_rank = parallel_state.get_tensor_model_parallel_rank() if parallel_state.model_parallel_is_initialized() else 0 + local_sink = self.attention_sink_bias[tp_rank * self.local_num_heads:(tp_rank + 1) * self.local_num_heads] + sink_bias = local_sink.reshape(1, -1, 1, 1).expand(bsz, -1, q_len, 1) + all_scores = torch.cat([all_scores, sink_bias], dim=-1) + + # Numerical stability: subtract max before softmax + all_scores = all_scores - all_scores.max(dim=-1, keepdim=True).values + attn_weights = F.softmax(all_scores, dim=-1, dtype=torch.float32) + + # Drop the sink column after softmax + if use_sink: + attn_weights = attn_weights[..., :-1] + + # Split attention weights back + prior_weights = attn_weights[..., :-q_len].to(V_prior.dtype) + active_weights = attn_weights[..., -q_len:].to(value_states.dtype) + + # Compute attention outputs + attn_prior = torch.matmul(prior_weights, V_prior) + attn_active = torch.matmul(active_weights, value_states) + attn_output = attn_prior + attn_active + else: + # Context encoding: standard attention + # With CP: Q is local [B, H, S/CP, D], K/V are full [B, H, S, D] + # Without CP: Q/K/V all have same seq_len + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling + + # Apply attention mask (additive mask: 0 = attend, -inf = mask out) + # The framework creates boolean masks, so we need to convert them + # With CP: attention_mask is already split to [B, 1, S/CP, S] (local Q rows, full K cols) + if attention_mask is not None: + # Convert boolean mask to additive mask if needed + if attention_mask.dtype == torch.bool: + # True = attend (0), False = mask (-inf) + additive_mask = torch.zeros_like(attn_weights) + additive_mask = additive_mask.masked_fill(~attention_mask, float('-inf')) + attn_weights = attn_weights + additive_mask + else: + # Already additive mask + attn_weights = attn_weights + attention_mask + + # Apply sliding window mask for SWA layers + if self.is_sliding_window and self.sliding_window_size is not None: + kv_seq_len = attn_weights.size(-1) + if is_context_parallel: + # With CP: Q has local seq len, K has full seq len. + # Use local_position_ids for correct global Q positions. 
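+                    # Toy example (illustrative, assuming cp_degree=2, S=8,
+                    # sliding_window_size=4): rank 1's local Q rows carry global
+                    # positions 4..7 via local_position_ids, so the row for position 5
+                    # keeps key columns 2..5; a plain arange() here would wrongly
+                    # treat those rows as positions 0..3.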
+ row_idx = local_position_ids[0].unsqueeze(1).to(attn_weights.device) + else: + row_idx = torch.arange(kv_seq_len, device=attn_weights.device).unsqueeze(1) + col_idx = torch.arange(kv_seq_len, device=attn_weights.device).unsqueeze(0) + # Causal: col <= row, and within window: col >= row - window_size + 1 + sliding_mask = (col_idx <= row_idx) & (col_idx >= row_idx - self.sliding_window_size + 1) + sliding_mask = sliding_mask[None, None, :, :] + # Convert to additive mask + attn_weights = attn_weights.masked_fill(~sliding_mask, float('-inf')) + + # Add attention sink bias (following HF implementation) + # This adds an extra "sink" column to attention weights + use_sink = self._use_sink_bias and self.attention_sink_bias is not None + if use_sink: + # Get local portion of sink bias for this TP rank + tp_rank = parallel_state.get_tensor_model_parallel_rank() if parallel_state.model_parallel_is_initialized() else 0 + local_sink = self.attention_sink_bias[tp_rank * self.local_num_heads:(tp_rank + 1) * self.local_num_heads] + # Reshape and expand: [local_num_heads] -> [bsz, local_num_heads, q_len, 1] + sink_bias = local_sink.reshape(1, -1, 1, 1).expand(bsz, -1, q_len, 1) + attn_weights = torch.cat([attn_weights, sink_bias], dim=-1) + + # Numerical stability: subtract max before softmax (like HF implementation) + attn_weights = attn_weights - attn_weights.max(dim=-1, keepdim=True).values + + # Softmax + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32) + + # Drop the sink column after softmax + if use_sink: + attn_weights = attn_weights[..., :-1] + + attn_weights = attn_weights.to(value_states.dtype) + + # Apply attention to values + attn_output = torch.matmul(attn_weights, value_states) + + # Reshape and project output + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.local_num_heads * self.attn_v_head_dim) + + # Context Parallelism: gather output across CP ranks BEFORE o_proj. + # With SP enabled, o_proj scatters along seq dim. The input must have full S + # (not S/CP), otherwise the SP-scattered output won't match the residual. + # Without SP, gather after o_proj to restore full seq_len for residual. 
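+        # Shape sketch (illustrative, assuming cp_degree=4 over S=1024): attn_output
+        # here is the local [bsz, 256, local_num_heads * v_head_dim] slice; the
+        # gather below restores the full [bsz, 1024, ...] sequence before o_proj.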
+ if is_context_parallel: + attn_output = gather_from_tensor_model_parallel_region_with_dim( + attn_output, gather_dim=1, + process_group=get_context_parallel_attention_cp_group(), + ) + + attn_output = self.o_proj(attn_output) + + # Prepare KV cache output - return as tuple for KV cache manager + # Return LOCAL key/value states for cache (each CP rank stores its portion) + new_key_value = (key_states_for_cache, value_states_for_cache) + + return attn_output, new_key_value, cos_cache, sin_cache + + +class MiMoV2MLP(nn.Module): + """Standard MLP for non-MoE layers in MiMo-V2.5.""" + + def __init__(self, config: MiMoV2InferenceConfig): + super().__init__() + self.hidden_size = config.hidden_size + # Use the dense intermediate size for non-MoE layers + self.intermediate_size = getattr(config, 'dense_intermediate_size', config.intermediate_size * 8) + + dtype = config.neuron_config.torch_dtype + + if parallel_state.model_parallel_is_initialized(): + tp_group = parallel_state.get_tensor_model_parallel_group() + + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + input_is_parallel=True, + dtype=dtype, + tensor_model_parallel_group=tp_group, + ) + else: + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + self.act_fn = F.silu + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class NeuronMiMoV2DecoderLayer(nn.Module): + """MiMo-V2.5 Decoder Layer with hybrid attention and conditional MoE.""" + + def __init__(self, config: MiMoV2InferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + # Determine attention type for this layer + is_sliding_window = config.layer_attention_types[layer_idx] == "sliding_window" + self.attention_type = "sliding_window" if is_sliding_window else "full" + + # Create attention module + self.self_attn = NeuronMiMoV2Attention( + config=config, + layer_idx=layer_idx, + is_sliding_window=is_sliding_window, + ) + + # Determine if this layer uses MoE + self.uses_moe = config.layer_uses_moe[layer_idx] + + # Create MLP/MoE module + if self.uses_moe: + self.mlp = initialize_moe_module(config=config) + else: + self.mlp = MiMoV2MLP(config) + + # Layer norms + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.layernorm_epsilon, + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.layernorm_epsilon, + ) + + # Config flags + self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + cos_cache: Optional[torch.Tensor] = None, + sin_cache: Optional[torch.Tensor] = None, + **kwargs, + ) -> 
Tuple[torch.FloatTensor, ...]: + """Forward pass for decoder layer.""" + + # Self attention with residual + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + cos_cache=cos_cache, + sin_cache=sin_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # MLP/MoE with residual + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + if self.uses_moe: + hidden_states = self.mlp(hidden_states, padding_mask)[0] + else: + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value, cos_cache, sin_cache, None) + return outputs + + +class NeuronMiMoV2Model(NeuronBaseModel): + """MiMo-V2.5 Model for NXD inference.""" + + def setup_attr_for_model(self, config: MiMoV2InferenceConfig): + self.on_device_sampling = config.neuron_config.on_device_sampling_config is not None + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + + # Check if we need GQA CONVERT_TO_MHA mode + # When tp_degree > num_kv_heads, we replicate K/V to match num_attention_heads + min_kv_heads = min( + config.num_key_value_heads, + getattr(config, 'swa_num_key_value_heads', config.num_key_value_heads) + ) + self.use_gqa_convert_to_mha = self.tp_degree > min_kv_heads + + if self.use_gqa_convert_to_mha: + # With CONVERT_TO_MHA, KV cache stores num_attention_heads (same as Q) + self.num_key_value_heads = config.num_attention_heads + else: + # Standard GQA: use the maximum num_kv_heads for KV cache + # (handles hybrid full/sliding window attention) + self.num_key_value_heads = max( + config.num_key_value_heads, + getattr(config, 'swa_num_key_value_heads', config.num_key_value_heads) + ) + + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + # MiMo has hybrid attention (full + sliding window) + # NOTE: Do NOT set self.sliding_window here because it affects KV cache size globally. + # MiMo handles sliding window per-layer in the attention module itself. + # Setting has_mixed_attn = True enables proper mask creation without affecting cache size. + self.has_mixed_attn = True + + def init_model(self, config: MiMoV2InferenceConfig): + self.padding_idx = getattr(config, 'pad_token_id', None) + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + + self.layers = nn.ModuleList([ + NeuronMiMoV2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + + self.norm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.layernorm_epsilon, + ) + + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + ) + + +def _replicate_kv_weights_for_convert_to_mha( + tensor: torch.Tensor, + source_heads: int, + target_heads: int, + head_dim: int, +) -> torch.Tensor: + """Replicate K/V weights from source_heads to target_heads for CONVERT_TO_MHA. 
+ + Args: + tensor: Weight tensor of shape [source_heads * head_dim, hidden_size] + source_heads: Number of source KV heads + target_heads: Number of target heads (num_attention_heads) + head_dim: Head dimension + + Returns: + Replicated tensor of shape [target_heads * head_dim, hidden_size] + """ + if tensor is None or source_heads >= target_heads: + return tensor + + repeats = target_heads // source_heads + + # Reshape to [source_heads, head_dim, hidden_size] + original_shape = tensor.shape + tensor = tensor.view(source_heads, head_dim, -1) + + # Repeat along head dimension + tensor = tensor.repeat_interleave(repeats, dim=0) + + # Reshape back to [num_heads * head_dim, hidden_size] + tensor = tensor.view(-1, original_shape[-1]) + + return tensor + + +def convert_mimo_v2_hf_to_neuron_state_dict( + neuron_state_dict: Dict[str, Any], + config: MiMoV2InferenceConfig, +) -> Dict[str, Any]: + """Convert HuggingFace MiMo-V2.5 weights to Neuron format. + + This handles: + 1. Router weight renaming + 2. Expert weight concatenation and transposition + 3. FP8 dequantization if needed + 4. K/V weight replication for CONVERT_TO_MHA mode + """ + + assert config.neuron_config.glu_mlp is True, "Only GLU MLP is supported" + + # Dequantize layers if needed + _maybe_dequantize_layer(neuron_state_dict, config) + + # Add rank utility tensors + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, config.neuron_config.tp_degree, dtype=torch.int32 + ) + + # Determine if CONVERT_TO_MHA is needed + tp_degree = config.neuron_config.tp_degree + num_attention_heads = config.num_attention_heads + + # MiMo-V2.5 has different KV heads for full and sliding window attention + full_num_kv_heads = config.num_key_value_heads # 4 + swa_num_kv_heads = config.swa_num_key_value_heads # 8 + + # Check if we need to replicate K/V weights + full_use_convert_to_mha = tp_degree > full_num_kv_heads + swa_use_convert_to_mha = tp_degree > swa_num_kv_heads + + for layer_idx in range(config.num_hidden_layers): + # Add rank utility for attention + neuron_state_dict[f"layers.{layer_idx}.self_attn.rank_util.rank"] = torch.arange( + 0, config.neuron_config.tp_degree, dtype=torch.int32 + ) + + # Determine attention type for this layer + is_sliding_window = config.layer_attention_types[layer_idx] == "sliding_window" + + if is_sliding_window: + src_num_kv_heads = swa_num_kv_heads + use_convert_to_mha = swa_use_convert_to_mha + head_dim = config.swa_head_dim # 192 + v_head_dim = config.swa_v_head_dim # 128 + else: + src_num_kv_heads = full_num_kv_heads + use_convert_to_mha = full_use_convert_to_mha + head_dim = config.head_dim # 192 + v_head_dim = config.v_head_dim # 128 + + # Replicate K/V weights if CONVERT_TO_MHA is needed + if use_convert_to_mha: + k_proj_key = f"layers.{layer_idx}.self_attn.k_proj.weight" + v_proj_key = f"layers.{layer_idx}.self_attn.v_proj.weight" + + if k_proj_key in neuron_state_dict: + neuron_state_dict[k_proj_key] = _replicate_kv_weights_for_convert_to_mha( + neuron_state_dict[k_proj_key], + src_num_kv_heads, + num_attention_heads, + head_dim, + ) + + if v_proj_key in neuron_state_dict: + neuron_state_dict[v_proj_key] = _replicate_kv_weights_for_convert_to_mha( + neuron_state_dict[v_proj_key], + src_num_kv_heads, + num_attention_heads, + v_head_dim, + ) + + # FP8 path: replicate per-row scales ([src_heads*head_dim, 1]) in + # lockstep with the weights. Without this the shard_weights step + # rejects the scale shape mismatch (e.g. [12,1] vs expected [192,1]). + # BF16 has no .scale key, so this loop is a no-op there. 
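+            # Worked example (a sketch, assuming a full-attention layer at tp_degree=64):
+            # k_proj.scale ships as [4 * 192, 1] = [768, 1]; naive sharding would hand
+            # each rank only 768 / 64 = 12 rows while the replicated weight expects 192.
+            # Replicating to 64 heads gives [64 * 192, 1] = [12288, 1], which shards
+            # cleanly to [192, 1] per rank.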
+ for proj, hd in (("k_proj", head_dim), ("v_proj", v_head_dim)): + scale_key = f"layers.{layer_idx}.self_attn.{proj}.scale" + if scale_key in neuron_state_dict: + neuron_state_dict[scale_key] = _replicate_kv_weights_for_convert_to_mha( + neuron_state_dict[scale_key], + src_num_kv_heads, + num_attention_heads, + hd, + ) + + # Only convert MoE layers + if not config.layer_uses_moe[layer_idx]: + continue + + # Check if this layer has MoE weights + gate_key = f"layers.{layer_idx}.mlp.gate.weight" + if gate_key not in neuron_state_dict: + continue + + # Rename router weights + neuron_state_dict[f"layers.{layer_idx}.mlp.router.linear_router.weight"] = ( + neuron_state_dict[gate_key].detach().clone() + ) + del neuron_state_dict[gate_key] + + # Get dimensions from first expert + expert_0_gate = f"layers.{layer_idx}.mlp.experts.0.gate_proj.weight" + if expert_0_gate not in neuron_state_dict: + continue + + intermediate_size, hidden_size = neuron_state_dict[expert_0_gate].shape + device = neuron_state_dict[expert_0_gate].device + dtype = neuron_state_dict[expert_0_gate].dtype + + num_experts = config.n_routed_experts + + # Concatenate gate and up projections + gate_up_proj = torch.empty( + num_experts, + hidden_size, + 2 * intermediate_size, + dtype=dtype, + device=device, + ) + + for e in range(num_experts): + gate_proj_weights = neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight" + ].T.detach().clone() + up_proj_weights = neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight" + ].T.detach().clone() + + gate_up_proj[e, :, :intermediate_size] = gate_proj_weights + gate_up_proj[e, :, intermediate_size:] = up_proj_weights + + del neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"] + del neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"] + + # Pad if needed + pad_size = getattr(config, "moe_intermediate_pad_size", 0) + if pad_size > 0: + gate_up_proj = gate_up_proj.reshape(num_experts, hidden_size, 2, -1) + gate_up_proj = F.pad(gate_up_proj, (0, pad_size)) + gate_up_proj = gate_up_proj.reshape(num_experts, hidden_size, -1) + + neuron_state_dict[f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj + + # Convert down projections + down_proj = torch.empty( + num_experts, + intermediate_size, + hidden_size, + dtype=dtype, + device=device, + ) + + for e in range(num_experts): + down_proj_weights = neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" + ].T.detach().clone() + down_proj[e] = down_proj_weights + del neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight"] + + # Pad if needed + if pad_size > 0: + down_proj = F.pad(down_proj, (0, 0, 0, pad_size)) + + neuron_state_dict[f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj + + gc.collect() + + # --- Expand MoE blockwise scales along the TP-partitioned dim (FP8 only). --- + # NxDI's shard_checkpoint splits the scale on its partition dim into + # `per_partition_size = dim_size / tp_degree`. At TP=64 both projections + # have per-rank "intermediate" smaller than the 128-wide scale block, so + # several ranks share one scale block — we need to replicate scale entries + # along that dim. Adjacent ranks whose weight falls inside the same + # 128-wide block genuinely share that block's scale. No-op when the + # .scale keys are absent (BF16 path). 
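+    # Worked example for down_proj (a sketch, assuming expert intermediate 2048,
+    # hidden 4096, 128-wide blocks, no padding, moe_tp=64): the scale arrives as
+    # [E, 16, 32] (2048/128 intermediate blocks x 4096/128 hidden blocks). Per-rank
+    # intermediate is 2048/64 = 32 < 128, so ranks_per_block = 128/32 = 4 and each
+    # block row is repeated 4x, giving [E, 64, 32], i.e. one scale row per MoE-TP rank.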
+ if getattr(config.neuron_config, "quantized", False): + # IMPORTANT: MoE expert weights are sharded by moe_tp_degree (not the + # top-level tp_degree — attention uses tp_degree, MoE can use a + # different split). At moe_tp=64 the per-rank intermediate is 32 (<128) + # so we had to expand the scale to make the shard layout match; at + # moe_tp=16 per-rank intermediate is 128 (>=128) and no expansion is + # needed. + moe_tp = getattr(config.neuron_config, "moe_tp_degree", None) or config.neuron_config.tp_degree + for layer_idx in range(config.num_hidden_layers): + if not config.layer_uses_moe[layer_idx]: + continue + + # down_proj (RowParallel on intermediate dim). Scale: [E, I_blocks, H_blocks] + dp_key = f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.down_proj.scale" + if dp_key in neuron_state_dict: + s = neuron_state_dict[dp_key] + i_blocks = s.shape[1] + h_blocks = s.shape[2] + intermediate = i_blocks * 128 + i_per_rank = intermediate // moe_tp + if i_per_rank < 128: + ranks_per_block = 128 // i_per_rank + s_exp = s.unsqueeze(2).expand(-1, -1, ranks_per_block, -1) + s_exp = s_exp.reshape(s.shape[0], i_blocks * ranks_per_block, h_blocks) + assert s_exp.shape[1] == moe_tp, ( + f"down_proj.scale expansion produced {s_exp.shape[1]} rows, " + f"expected moe_tp={moe_tp}" + ) + neuron_state_dict[dp_key] = s_exp.contiguous() + + # gate_up_proj (ColumnParallel on 2*intermediate dim, gate|up fused + # along last axis). Scale: [E, H_blocks, 2*I_blocks] stored as + # [gate_half | up_half]. Module parameter has per-rank last-dim=1 + # (via _apply_blockwise_scale_stride_fix patch forcing + # partition_stride=1), so the full scale must have last-dim=moe_tp + # with gate entries 0..moe_tp/2 and up entries moe_tp/2..moe_tp. + # Expand each half independently to preserve the gate/up boundary + # when NxD does `split(per_partition=2*I/moe_tp, dim=-1)`. 
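+            # Worked example (a sketch, same assumptions as above: intermediate 2048,
+            # hidden 4096, moe_tp=64): the scale arrives as [E, 32, 32] (16 gate blocks
+            # followed by 16 up blocks). Per-rank output is 2*2048/64 = 64 < 128, so each
+            # half expands by ranks_per_block = (64/2)/16 = 2 to 32 entries; concatenating
+            # the halves gives a last dim of 64 = moe_tp, with ranks 0..31 reading gate
+            # scales and ranks 32..63 reading up scales.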
+ gu_key = f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.gate_up_proj.scale" + if gu_key in neuron_state_dict: + s = neuron_state_dict[gu_key] + h_blocks = s.shape[1] + two_i_blocks = s.shape[2] + assert two_i_blocks % 2 == 0, ( + f"gate_up_proj.scale last dim must be 2*i_blocks, got {two_i_blocks}" + ) + i_blocks = two_i_blocks // 2 + intermediate = i_blocks * 128 + out_per_rank = (2 * intermediate) // moe_tp + if out_per_rank < 128: + assert moe_tp % 2 == 0, f"moe_tp={moe_tp} must be even for gate/up scale split" + ranks_per_half = moe_tp // 2 + assert ranks_per_half % i_blocks == 0, ( + f"ranks_per_half={ranks_per_half} must be divisible by " + f"i_blocks={i_blocks}" + ) + ranks_per_block = ranks_per_half // i_blocks + gate_half = s[..., :i_blocks] # [E, H_blocks, i_blocks] + up_half = s[..., i_blocks:] + gate_exp = ( + gate_half.unsqueeze(-1) + .expand(-1, -1, -1, ranks_per_block) + .reshape(s.shape[0], h_blocks, ranks_per_half) + ) + up_exp = ( + up_half.unsqueeze(-1) + .expand(-1, -1, -1, ranks_per_block) + .reshape(s.shape[0], h_blocks, ranks_per_half) + ) + s_exp = torch.cat([gate_exp, up_exp], dim=-1) + assert s_exp.shape[-1] == moe_tp, ( + f"gate_up_proj.scale expansion produced {s_exp.shape[-1]} " + f"entries, expected moe_tp={moe_tp}" + ) + neuron_state_dict[gu_key] = s_exp.contiguous() + + return neuron_state_dict + + +def _maybe_dequantize_layer( + neuron_state_dict: Dict[str, Any], + config: MiMoV2InferenceConfig, +): + """Dequantize FP8 layers if present.""" + scale_layers = [] + + for layer_key in list(neuron_state_dict.keys()): + if "_scale_inv" in layer_key: + scales = neuron_state_dict[layer_key] + scale_layers.append(layer_key) + + fp8_layer_name = layer_key.replace("_scale_inv", "") + if fp8_layer_name not in neuron_state_dict: + continue + + fp8_layer = neuron_state_dict[fp8_layer_name] + + # Get block size from config if available + if hasattr(config, 'quantization_config') and config.quantization_config: + block_size = config.quantization_config.get("weight_block_size", [128, 128]) + else: + block_size = [128, 128] + + # Expand scales and dequantize + scales_expanded = scales.repeat_interleave(block_size[0], dim=0) + scales_expanded = scales_expanded.repeat_interleave(block_size[1], dim=1) + + # Ensure shapes match + if scales_expanded.shape != fp8_layer.shape: + scales_expanded = scales_expanded[:fp8_layer.shape[0], :fp8_layer.shape[1]] + + scaled_layer = fp8_layer.to(torch.float32) * scales_expanded.to(torch.float32) + neuron_state_dict[fp8_layer_name] = scaled_layer.to(config.neuron_config.torch_dtype) + + # Remove scale layers + for scale_layer in scale_layers: + del neuron_state_dict[scale_layer] + + +class NeuronMiMoV2ForCausalLM(NeuronBaseForCausalLM): + """MiMo-V2.5 for Causal Language Modeling on Neuron.""" + + _model_cls = NeuronMiMoV2Model + + def __init__(self, *args, **kwargs): + # Install FP8 monkey-patches BEFORE super().__init__ so the patched + # RouterTopK.__init__ and quantization layer classes are in effect + # when NxDI builds the decoder (and instantiates routers). Harnesses + # that drive us via model.compile()/model.load() (e.g. vllm-neuron) + # call those methods AFTER construction, so patching from inside + # compile()/load() is too late — RouterTopK instances would already + # lack our e_score_correction_bias parameter, silently routing tokens + # to wrong experts and producing gibberish output. 
+ # + # _install_fp8_patches() reads self.neuron_config, which needs to + # exist; grab it from the args or the config arg the same way the + # base class does. + ncfg = kwargs.get("config") or (args[1] if len(args) > 1 else None) + if ncfg is not None and getattr(getattr(ncfg, "neuron_config", None), "quantized", False): + self._apply_ep_scale_fix() + self._apply_blockwise_scale_stride_fix() + self._apply_2d_per_channel_fix() + self._apply_router_noaux_tc_fix() + super().__init__(*args, **kwargs) + + @staticmethod + def load_hf_model(model_path: str, **kwargs): + """Load HuggingFace model. + + Note: MiMo-V2.5 uses custom code, so we need trust_remote_code=True + """ + from transformers import AutoModelForCausalLM + return AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + **kwargs, + ) + + @classmethod + def get_config_cls(cls) -> Type[MiMoV2InferenceConfig]: + return MiMoV2InferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: Dict[str, Any], + config: MiMoV2InferenceConfig, + ) -> Dict[str, Any]: + return convert_mimo_v2_hf_to_neuron_state_dict(state_dict, config) + + def enable_context_encoding(self): + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def enable_token_generation(self): + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + super().enable_token_generation() + + # ------------------------------------------------------------------ + # FP8 quantized-inference monkey-patches (no-op unless quantized=True). + # + # Reconcile the preprocessed Neuron-FP8 checkpoint (blockwise-MoE + + # per-row-attn) with NxDI's global blockwise_symmetric q_config. All + # four are gated by self.neuron_config.quantized so the BF16 path is + # completely untouched. 
+ # ------------------------------------------------------------------ + + @staticmethod + def _apply_ep_scale_fix(): + """Skip per-channel `scale` params when marking expert-parallel + weights; they have shape [1, 1, W] and cannot be EP-sharded.""" + from neuronx_distributed.modules.moe.moe_parallel_layers import ( + ExpertFusedLinear, + ) + + if getattr(ExpertFusedLinear, "_mimo_v2_ep_scale_patched", False): + return + + def _patched_mark( + self_inner, + iterable=None, + expert_parallel_group_size=None, + is_prefill=True, + expert_distribution=None, + ): + from neuronx_distributed.parallel_layers.parallel_state import ( + get_expert_model_parallel_size, + ) + + if expert_parallel_group_size is None: + expert_parallel_group_size = get_expert_model_parallel_size() + + if expert_parallel_group_size > 1: + if iterable is None: + params_to_mark = [] + for name, p in self_inner.named_parameters(): + if name == "scale" and p.shape[0] == 1: + continue + params_to_mark.append(p) + iterable = params_to_mark + + for p in iterable: + p.expert_model_parallel = True + if is_prefill: + p.is_prefill = True + p.expert_distribution = expert_distribution + + ExpertFusedLinear._mark_expert_parallel_weights = _patched_mark + ExpertFusedLinear._mimo_v2_ep_scale_patched = True + + @staticmethod + def _apply_blockwise_scale_stride_fix(): + """Force scale.partition_stride=1 for BLOCKWISE_SYMMETRIC quantization + — stride>1 causes strided-splitting failures when per-rank weight size + is smaller than a block.""" + from neuronx_distributed.quantization.quantization_config import ( + QuantizationType, + ) + from neuronx_distributed.quantization.quantization_layers import ( + BaseQuantizeParallelLinear, + ) + + if getattr(BaseQuantizeParallelLinear, "_mimo_v2_blockwise_stride_patched", False): + return + + _original_setup = BaseQuantizeParallelLinear._setup_for_scale + + def _patched_setup(self_inner, *args, **kwargs): + _original_setup(self_inner, *args, **kwargs) + if ( + hasattr(self_inner, "quantization_type") + and self_inner.quantization_type == QuantizationType.BLOCKWISE_SYMMETRIC + and hasattr(self_inner, "scale") + and hasattr(self_inner.scale, "partition_stride") + and self_inner.scale.partition_stride > 1 + ): + self_inner.scale.partition_stride = 1 + + BaseQuantizeParallelLinear._setup_for_scale = _patched_setup + BaseQuantizeParallelLinear._mimo_v2_blockwise_stride_patched = True + + @staticmethod + def _apply_2d_per_channel_fix(): + """Route 2D self_attn + layer-0 dense-MLP swaps through per_channel_symmetric. + + MiMo-V2.5's preprocess writes: + - MoE experts: 3D weights with (E, out//128, in//128) blockwise scales. + - self_attn q/k/v + layer-0 mlp gate/up/down: 2D weights with + (out, 1) per-row scales. + + NxDI's q_config is global blockwise_symmetric (to satisfy the MoE). + Feeding that into the 2D classes triggers + `block axis cannot be < 0 or > 2, received 2` in _setup_for_scale + (block axes [1, 2] exceed rank-2 weight_shape). This wraps the 2D + classes' from_float to override q_config on the fly. 
+ """ + from neuronx_distributed.quantization.quantization_config import ( + QuantizationType, + ) + from neuronx_distributed.quantization.quantization_layers import ( + QuantizedColumnParallel, + QuantizedRowParallel, + ) + + def _wrap(cls): + if getattr(cls, "_mimo_v2_2d_patched", False): + return + original_from_float = cls.from_float + + def _patched_from_float(klass, mod, q_config=None, _orig=original_from_float): + if q_config is not None and q_config.get("quantization_type") == \ + QuantizationType.BLOCKWISE_SYMMETRIC: + q_config = dict(q_config) + q_config["quantization_type"] = QuantizationType.PER_CHANNEL_SYMMETRIC + q_config["quantization_per_channel_axis"] = 0 + q_config.pop("block_axis", None) + q_config.pop("block_size", None) + if q_config is None: + return _orig(mod) + return _orig(mod, q_config) + + cls.from_float = classmethod(_patched_from_float) + cls._mimo_v2_2d_patched = True + + _wrap(QuantizedColumnParallel) + _wrap(QuantizedRowParallel) + + @staticmethod + def _apply_router_noaux_tc_fix(): + """Register e_score_correction_bias on NxD RouterTopK and fold it into + top-k selection so MiMo-V2's noaux_tc routing matches HF reference. + + MiMo-V2's HF config uses topk_method='noaux_tc': each expert score is + `sigmoid(logits) + e_score_correction_bias`, top-k indices are chosen + from THAT biased score; the returned expert weights (affinities) + come from the UNBIASED sigmoid(logits). NxD's stock RouterTopK is + plain topk with no bias slot, so without this the bias is silently + dropped and ~all tokens route to wrong experts. + """ + from neuronx_distributed.modules.moe.routing import RouterTopK + + if getattr(RouterTopK, "_mimo_v2_noaux_tc_patched", False): + return + + original_init = RouterTopK.__init__ + + def _patched_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + # CRITICAL: dtype + init value both matter for XLA tracing. + # + # 1) dtype=torch.bfloat16: the NxDI checkpoint loader casts router + # bias from FP32 -> BF16 ("Found torch.float32 weights in + # checkpoint ... Will convert to torch.bfloat16"). If the traced + # NEFF expects FP32 but the checkpoint supplies BF16, the + # LayoutTransformation silently drops the weight and keeps the + # trace-time init values — so the bias at runtime is whatever + # we init here, not the checkpoint values. + # + # 2) init=arange, NOT zeros: if every entry is identical (all + # zeros), the `+ bias` op does not change the relative ordering + # of topk, so XLA's constant-folding passes can prove the add + # is a no-op and eliminate it entirely — dropping the bias + # parameter from the HLO. At that point checkpoint loading has + # nothing to bind to and the real bias is silently discarded. + # Using arange guarantees distinct per-expert values, forcing + # the compiler to keep the add as a runtime op with a live + # parameter. Source: Jim Burtoft's MiniMax-M2 fix notes + # (jimburtoft/neuronx-distributed-inference@49f8e164). + self.e_score_correction_bias = nn.Parameter( + torch.arange(self.num_experts, dtype=torch.bfloat16), + requires_grad=False, + ) + + def _patched_forward(self, hidden_states): + router_logits = self.get_router_logits(hidden_states) + expert_affinities = self.apply_activation_fn(router_logits) + + # MiMo (and MiniMax-M2) uses topk_method='noaux_tc': the bias is + # added ONLY for top-k selection, but the unbiased sigmoid scores + # remain as the expert-affinity weights passed to the experts. 
+ scores_for_choice = ( + expert_affinities.float() + self.e_score_correction_bias.unsqueeze(0) + ) + _, expert_index = torch.topk(scores_for_choice, self.top_k, dim=-1) + + expert_affinities = expert_affinities.to(dtype=hidden_states.dtype) + expert_index = expert_index.detach().to(dtype=torch.long) + return router_logits, expert_affinities, expert_index + + RouterTopK.__init__ = _patched_init + RouterTopK.forward = _patched_forward + RouterTopK._mimo_v2_noaux_tc_patched = True + + def _install_fp8_patches(self): + """Install all FP8-specific runtime patches. No-op for BF16.""" + if not getattr(self.neuron_config, "quantized", False): + return + self._apply_ep_scale_fix() + self._apply_blockwise_scale_stride_fix() + self._apply_2d_per_channel_fix() + self._apply_router_noaux_tc_fix() + + def compile(self, *args, **kwargs): + # save_sharded_checkpoint=True serializes shards during compile() and + # that code path reads scale.partition_stride — patches must be live. + self._install_fp8_patches() + return super().compile(*args, **kwargs) + + def load(self, *args, **kwargs): + self._install_fp8_patches() + return super().load(*args, **kwargs) + + @classmethod + def save_quantized_state_dict(cls, model_path, config): + """MiMo-V2.5 ships pre-quantized FP8 safetensors via our preprocess script. + The base implementation calls AutoModelForCausalLM.from_pretrained to + re-quantize, which requires a CUDA GPU (finegrained_fp8 gate) and + materializes an ~600 GB BF16 copy. Skip if the checkpoint directory + already contains a Neuron-FP8 index produced by preprocess.""" + import os as _os + qpath = ( + getattr(config.neuron_config, "quantized_checkpoints_path", None) + or model_path + ) + if qpath and _os.path.isdir(qpath): + index = _os.path.join(qpath, "model.safetensors.index.json") + if _os.path.isfile(index): + return + return super().save_quantized_state_dict(model_path, config) + + def get_compiler_args(self) -> str: + """Get compiler arguments optimized for MiMo-V2.5.""" + if self.compile_tag == CONTEXT_ENCODING_MODEL_TAG: + optimization_level = "-O1" + elif self.compile_tag == TOKEN_GENERATION_MODEL_TAG: + optimization_level = "-O3" if self.neuron_config.moe_ep_degree > 1 else "-O1" + else: + optimization_level = "-O1" + + compiler_args = ( + f"--enable-saturate-infinity " + f"--enable-mixed-precision-accumulation " + f"--model-type transformer " + f"{optimization_level}" + ) + + # Add CC overlap optimization + compiler_args += ( + " --tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2'" + ) + + compiler_args += " --auto-cast=none" + + # Enable vector-offset DGE + compiler_args += " --internal-enable-dge-levels vector_dynamic_offsets" + compiler_args += " --internal-hlo2tensorizer-options='--verify-hlo=true'" + + if self.neuron_config.scratchpad_page_size: + compiler_args += f" --hbm-scratchpad-page-size={self.neuron_config.scratchpad_page_size}" + + return compiler_args diff --git a/contrib/models/MiMo-V2.5/test/__init__.py b/contrib/models/MiMo-V2.5/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiMo-V2.5/test/integration/__init__.py b/contrib/models/MiMo-V2.5/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiMo-V2.5/test/integration/test_model.py b/contrib/models/MiMo-V2.5/test/integration/test_model.py new file mode 100644 index 00000000..d52fabb6 --- /dev/null +++ b/contrib/models/MiMo-V2.5/test/integration/test_model.py @@ -0,0 +1,53 @@ +#!/usr/bin/env 
python3 +"""Integration tests for MiMo-V2.5 NeuronX implementation.""" + +import pytest +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) + + +def test_config_import(): + """Test that config class can be imported.""" + from modeling_mimo_v2 import MiMoV2InferenceConfig, NeuronMiMoV2ForCausalLM + assert MiMoV2InferenceConfig is not None + assert NeuronMiMoV2ForCausalLM is not None + print("PASS: Config and model classes imported successfully") + + +def test_required_attributes(): + """Test that required attributes are defined.""" + from modeling_mimo_v2 import MiMoV2InferenceConfig + # Check get_required_attributes without instantiation (requires many params) + required = MiMoV2InferenceConfig.get_required_attributes(MiMoV2InferenceConfig) + assert "hidden_size" in required + assert "n_routed_experts" in required + assert "num_experts_per_tok" in required + assert "hybrid_layer_pattern" in required + assert "v_head_dim" in required + assert "swa_head_dim" in required + print(f"PASS: {len(required)} required attributes defined") + + +def test_neuron_config_cls(): + """Test that MoENeuronConfig is returned.""" + from modeling_mimo_v2 import MiMoV2InferenceConfig + from neuronx_distributed_inference.models.config import MoENeuronConfig + assert MiMoV2InferenceConfig.get_neuron_config_cls() == MoENeuronConfig + print("PASS: MoENeuronConfig returned") + + +def test_state_dict_converter(): + """Test that state dict converter function exists.""" + from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM + assert hasattr(NeuronMiMoV2ForCausalLM, "convert_hf_to_neuron_state_dict") + print("PASS: State dict converter exists") + + +if __name__ == "__main__": + test_config_import() + test_required_attributes() + test_neuron_config_cls() + test_state_dict_converter() + print("\nAll tests passed!") diff --git a/contrib/models/MiMo-V2.5/test/unit/__init__.py b/contrib/models/MiMo-V2.5/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b