diff --git a/contrib/models/MiniMax-M2/README.md b/contrib/models/MiniMax-M2/README.md
new file mode 100644
index 00000000..c77c2ab5
--- /dev/null
+++ b/contrib/models/MiniMax-M2/README.md
@@ -0,0 +1,240 @@
+# Contrib Model: MiniMax-M2 / M2.7
+
+NeuronX Distributed Inference implementation of the MiniMax-M2 family on Trn2.
+
+- **Reference checkpoint used for validation:** [MiniMaxAI/MiniMax-M2.7](https://huggingface.co/MiniMaxAI/MiniMax-M2.7)
+- Works with any `MiniMaxM2ForCausalLM` variant (M2 / M2.7 / any minor version) — the config schema is stable across M2 / M2.7.
+
+## Model Information
+
+- **HuggingFace ID:** `MiniMaxAI/MiniMax-M2.7` (and compatible M2 siblings)
+- **Model Type:** Decoder-only MoE transformer with uniform GQA attention
+- **Architecture:** Custom MoE with sigmoid routing, `e_score_correction_bias` (noaux_tc), per-layer QK RMSNorm
+- **License:** Check HuggingFace model card
+
+## Architecture Details
+
+| Parameter | Value |
+|-----------|-------|
+| Hidden Size | 3072 |
+| Layers | 62 |
+| Attention Heads | 48 Q / 8 KV (GQA) |
+| Head Dim | 128 (Q=K=V; uniform, no asymmetry) |
+| Experts | 256 (top-8 routing) |
+| Expert Intermediate | 1536 |
+| Vocab Size | 200,064 |
+| RoPE | Partial (rotary_dim=64 of head_dim=128), theta=5M |
+| Max Position | 204,800 |
+
+Key features:
+- **Uniform GQA** (no hybrid attention / sliding window / sink bias — M2 is structurally simpler than Flash).
+- **QK RMSNorm**: Per-layer RMSNorm applied on Q and K after projection, before RoPE (uses Neuron-native `RmsNorm.apply` for CE/TKG consistency).
+- **Sigmoid router + noaux_tc**: `e_score_correction_bias` added to the sigmoid scores before top-k selection; unbiased scores become the expert-affinity weights.
+- **FP8-native**: Routed experts ship in blockwise FP8 (128×128 blocks). Per-row FP8 for attention Q/K/V after preprocess, which converts the HF OCP FP8 to Neuron FP8 (o_proj is dequantized to BF16; see Checkpoint Preparation).
+
+## Prerequisites
+
+- **Instance**: trn2.48xlarge (32 NeuronCores, logical_nc_config=2 → 64 logical cores)
+- **Neuron SDK**: 2.29 (Python 3.12, PyTorch 2.9)
+- **Venvs**: `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference` (for preprocess + direct NxDI smoke), `/opt/aws_neuronx_venv_pytorch_inference_vllm_0_16` (for vLLM serving). Both ship with the DLAMI.
+- **Disk**: ~500 GB free under `/opt/dlami/nvme` (HF FP8 checkpoint ~215 GB, Neuron-FP8 preprocessed output ~230 GB, plus `save_sharded_checkpoint` writes another ~140 GB per compiled config).
+
+## Quick Start (FP8 on Trn2)
+
+End-to-end recipe. First-time compile is ~25 minutes; subsequent runs hit the neuronx-cc cache and start in a few minutes.
+
+```bash
+# 1. Clone this repo on the Trn2 instance
+cd $HOME
+git clone /neuronx-distributed-inference.git
+cd neuronx-distributed-inference
+git checkout contrib/MiniMax-M2
+
+# 2. Download the HuggingFace FP8 checkpoint (~215 GB)
+source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate
+huggingface-cli download MiniMaxAI/MiniMax-M2.7 \
+    --local-dir /opt/dlami/nvme/models/MiniMax-M2.7
+
+# 3. Preprocess HF FP8 -> Neuron FP8 (~13 min, ~15 GB peak RAM)
+python contrib/models/MiniMax-M2/src/conversion_script/preprocess_minimax_m2_fp8.py \
+    --hf_model_path /opt/dlami/nvme/models/MiniMax-M2.7 \
+    --save_path /opt/dlami/nvme/models/MiniMax-M2.7-Neuron-FP8 \
+    --tp_degree 64
+
+# 4. 
(Optional) sanity-check without vLLM (~25 min first compile, then ~20s to load) +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate +python contrib/models/MiniMax-M2/perf_test/smoke_compile_minimax_m2.py +python contrib/models/MiniMax-M2/perf_test/smoke_generate_minimax_m2.py + +# 5. Install vllm-neuron with the contrib registration patch +bash contrib/models/MiniMax-M2/perf_test/0_setup.sh + +# 6. Start vLLM + bench (BS=32/moe_ep=64, BS=128/moe_ep=64) +bash contrib/models/MiniMax-M2/perf_test/bench_minimax_m2.sh +``` + +The bench script runs two configurations (BS=32 and BS=128, both `moe_tp_degree=1 / moe_ep_degree=64`) and logs results under `/tmp/bench_results/minimax_m2/`. + +Quick `curl` sanity check once the server is up: + +```bash +curl -s http://localhost:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{"model": "/opt/dlami/nvme/models/MiniMax-M2.7-Neuron-FP8", + "messages": [{"role": "user", "content": "Hello! Introduce yourself in one sentence."}], + "max_tokens": 64, "temperature": 0.7}' | python3 -m json.tool +``` + +If you see fluent sentence output on a 50+ token generation, the FP8 path is working correctly. If you see repetition collapse (single-token loops like "helpful helpful helpful..."), double-check that `moe_tp_degree=1`, `moe_ep_degree=64`, `batch_size>=32`, and that you're loading the preprocessed Neuron-FP8 checkpoint (not the raw HF FP8 directory). + +## Checkpoint Preparation + +The HuggingFace checkpoint ships as block-wise OCP FP8 (E4M3, ±448 range), which is not directly compatible with Neuron FP8 (IEEE-754 E4M3, ±240 range). The preprocess script in `src/conversion_script/preprocess_minimax_m2_fp8.py` rescales it: + +- **Attention q/k/v**: OCP FP8 blockwise → Neuron FP8 per-row. Per-row scales are used because at TP=64 each rank's output dim is <128, which would collapse a blockwise scale to a singleton. A `_apply_2d_per_channel_fix` monkey-patch installed at compile time routes the 2D weights through PER_CHANNEL_SYMMETRIC to match. +- **Attention o_proj**: OCP FP8 blockwise → **BF16 (dequantized)**. The NxDI modeling code binds `self_attn.o_proj` to a plain `RowParallelLinear` rather than the auto-swapped `QuantizedRowParallel`, so the loader does not expect `.scale` or FP8 bytes for o_proj and would drop them as "redundant". Preprocess dequantizes to BF16, and the smoke/bench configs list `o_proj` in `modules_to_not_convert` to keep NxDI from re-swapping it at `convert()` time. +- **MoE experts**: w1/w3 fused into packed `gate_up_proj [num_experts, hidden, 2*IM]`, w2 stacked into `down_proj [num_experts, IM, hidden]`. Scales stay blockwise. +- **Router gate + `e_score_correction_bias`**: renamed into the NxDI router namespace (`block_sparse_moe.router.linear_router.weight` and `...router.e_score_correction_bias`). +- **Norms + embed_tokens + lm_head**: passed through BF16. + +Output layout: +``` +save_path/ + config.json, tokenizer.*, chat_template.jinja + configuration_minimax_m2.py, modeling_minimax_m2.py (trust_remote_code) + model.safetensors.index.json + model_extras.safetensors (embed/norm/lm_head) + model_layer{N}.safetensors (one per decoder layer, N=0..61) +``` + +Runtime characteristics: ~15 GB peak RAM, ~13 minutes total on trn2.48xlarge. + +## FP8 Configuration Notes + +Three non-obvious constraints on Trn2, identical to the Flash FP8 path and for the same underlying reasons: + +1. 
**`moe_tp_degree=1, moe_ep_degree=64` is the only working FP8 ratio.** At `moe_tp=64` each rank's intermediate slice is 24 rows (<128 blockwise block), and NxDI's `_setup_for_scale` collapses the per-rank scale to a singleton — losing per-channel FP8 scale granularity. The resulting drift compounds across M2's 62 MoE layers and manifests as output collapse after ~30 decode tokens. `moe_tp=1, moe_ep=64` keeps each expert's weight + blockwise scale intact on a single rank and produces correct output. + +2. **`batch_size >= 32` on the FP8 path.** NxDI's TKG path refuses Expert Parallelism when `batch_size < num_experts / top_k = 256 / 8 = 32`. BS=1 single-stream latency demos on FP8 are not possible. + +3. **Keep outer `ep_degree=1`.** `MoENeuronConfig.ep_degree` is the full-model expert-parallel factor and multiplies `world_size` to `tp_degree * ep_degree`. At `world_size > 64` on a 64-NC Trn2, sharded-checkpoint size grows linearly, ranks beyond 63 have no backing hardware, and load fails. MoE EP is controlled exclusively via `moe_ep_degree`. + +The bench and smoke scripts have all three pinned correctly; the items above matter only if you're hand-crafting a `MoENeuronConfig`. + +## vLLM Integration + +MiniMax-M2 can be served via [vllm-neuron](https://github.com/aws-neuron/vllm-neuron). A contrib registration patch (`perf_test/vllm-neuron-patch.patch`) is required to plug the NxDI modeling code into vllm-neuron's lookup tables. + +### Setup + +```bash +# The setup script clones vllm-project/vllm-neuron at release-0.5.0, applies +# the contrib registration patch, installs it editable, and fetches the HF +# FP8 checkpoint (or skips if already present). It also prints the +# preprocess command if the Neuron-FP8 output dir is empty. +bash contrib/models/MiniMax-M2/perf_test/0_setup.sh +``` + +### Serving (FP8, recommended) + +The bench script already starts a vLLM server at port 8000 with the right config; to start one manually: + +```bash +export NXDI_CONTRIB_MINIMAX_M2_SRC=/path/to/neuronx-distributed-inference/contrib/models/MiniMax-M2/src +export MINIMAX_M2_PATH=/path/to/MiniMax-M2.7-Neuron-FP8 +export VLLM_ENGINE_READY_TIMEOUT_S=7200 +# Optional: isolate compile cache per config so parallel M2/Flash/Pro compiles +# don't race on the default /var/tmp/neuron-compile-cache lock files. 
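+# (The smoke scripts pin BASE_COMPILE_WORK_DIR per compiled path for the
+# same reason; see smoke_compile_minimax_m2.py.)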
+export NEURON_COMPILED_ARTIFACTS=/path/to/compiled/minimax_m2_bs32_moetp1_ep64_fp8 + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MINIMAX_M2_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 32 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + "tp_degree": 64, + "logical_nc_config": 2, + "fused_qkv": false, + "sequence_parallel_enabled": false, + "glu_mlp": true, + "moe_mask_padded_tokens": true, + "disable_numeric_cc_token": true, + "save_sharded_checkpoint": true, + "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, + "quantized": true, + "quantized_checkpoints_path": "/path/to/MiniMax-M2.7-Neuron-FP8", + "quantization_dtype": "f8e4m3", + "quantization_type": "blockwise_symmetric", + "quantization_block_axis": [1, 2], + "quantization_block_size": [128, 128], + "modules_to_not_convert": ["embed_tokens", "lm_head", "norm", "router", "o_proj"], + "blockwise_matmul_config": {"use_shard_on_block_dynamic_while": true, "block_sharding_strategy": "PING_PONG"}, + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 32, + "ctx_batch_size": 1, + "tkg_batch_size": 32, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": true, + "enable_bucketing": true, + "context_encoding_buckets": [1024], + "token_generation_buckets": [1024], + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' +``` + +### vllm-neuron patch summary + +The patch is applied to vllm-neuron 0.5.0 and: + +- Registers `NeuronMiniMaxM2ForCausalLM` into NxDI's `MODEL_TYPES` under `minimax_m2` when `NXDI_CONTRIB_MINIMAX_M2_SRC` points at this contrib package's `src/`. +- Passes `hf_config` from vLLM into `load_pretrained_config` so NxDI does not re-load the config without `trust_remote_code=True`. +- Replaces vllm-neuron's internal `AutoModelForCausalLM.from_pretrained` with `huggingface_hub.snapshot_download`, which is the only path that works for `trust_remote_code=True` models when no GPU is available for HF's CUDA-gated FP8 quantizer. + +## Testing + +```bash +pytest contrib/models/MiniMax-M2/test/integration/test_model.py -v +``` + +## Key Implementation Notes + +1. **QK Norm**: `MiniMaxM2QKNorm` uses Neuron-native `RmsNorm.apply` (not hand-rolled pow/mean/rsqrt). Hand-rolled PyTorch RMSNorm compiles into different HLO in CE vs TG and produces incorrect TG results. +2. **Router Bias**: `RouterTopKWithBias` stores `e_score_correction_bias` as an `nn.Parameter` initialised to `torch.arange(num_experts, dtype=torch.bfloat16)`. Two non-obvious reasons: + - `register_buffer` (zeros) gets constant-folded by XLA and the checkpoint bias never binds at inference time. + - `dtype=float32` triggers a silent dtype mismatch in the NxDI loader's `LayoutTransformation`, which then drops the weight. +3. **CONVERT_TO_MHA**: When `tp_degree > num_kv_heads` (64 > 8), K/V are replicated to `num_attention_heads` (48) during state-dict conversion; on the FP8 path this applies to the per-row `.scale` tensors in lockstep with the weights. +4. **FP8 Runtime Patches** (installed in `NeuronMiniMaxM2ForCausalLM.__init__` when `quantized=True`, idempotent): + - `_apply_ep_scale_fix` — don't EP-shard `[1,1,W]` singleton scales. 
+ - `_apply_blockwise_scale_stride_fix` — force `partition_stride=1` for `BLOCKWISE_SYMMETRIC` to avoid strided-split failures when per-rank weight is smaller than a 128-wide scale block. + - `_apply_2d_per_channel_fix` — flip q_config from `BLOCKWISE_SYMMETRIC` to `PER_CHANNEL_SYMMETRIC` for 2D attention weights at layer-swap time. +5. **`save_quantized_state_dict` override**: short-circuits the HF re-quantize path (which requires CUDA and materialises a ~600 GB BF16 copy) when the preprocess-produced Neuron-FP8 index is already on disk. + +## Compatibility Matrix + +| Instance | Neuron SDK 2.29+ (PyTorch 2.9) | 2.21 and earlier | +|----------|--------------------------------|------------------| +| Trn2 (trn2.48xlarge) | Tested | Not tested | +| Trn1 | Not supported (requires 64 logical cores via logical_nc_config=2) | Not supported | +| Inf2 | Not supported | Not supported | + +## Example Checkpoints + +* [MiniMaxAI/MiniMax-M2.7](https://huggingface.co/MiniMaxAI/MiniMax-M2.7) +* [MiniMaxAI/MiniMax-M2](https://huggingface.co/MiniMaxAI/MiniMax-M2) (same config schema, compatible preprocess) + +## Maintainer + +Henan Wan (whn09) + +**Last Updated:** 2026-04-25 diff --git a/contrib/models/MiniMax-M2/perf_test/0_setup.sh b/contrib/models/MiniMax-M2/perf_test/0_setup.sh new file mode 100755 index 00000000..e6cf64ce --- /dev/null +++ b/contrib/models/MiniMax-M2/perf_test/0_setup.sh @@ -0,0 +1,73 @@ +#!/bin/bash +# Setup for MiniMax-M2 vLLM benchmarking on Trn2. +# +# Clones upstream vllm-project/vllm-neuron at release-0.5.0 and applies +# vllm-neuron-patch.patch, which adds a runtime registration hook so the +# contrib NeuronMiniMaxM2ForCausalLM is plugged into NxDI's MODEL_TYPES +# at vllm-neuron plugin init time. vLLM's ModelRegistry already recognizes +# MiniMaxM2ForCausalLM so no vLLM-side registration is needed. +set -e + +echo "==========================================" +echo "Setup: vllm-neuron + MiniMax-M2 weights" +echo "==========================================" + +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +PATCH_FILE="$(cd "$(dirname "$0")" && pwd)/vllm-neuron-patch.patch" + +echo "" +echo "[1/2] Installing vllm-neuron (release-0.5.0) with the contrib registration patch..." + +if [ ! -d $HOME/vllm-neuron ]; then + git clone --branch release-0.5.0 https://github.com/vllm-project/vllm-neuron.git $HOME/vllm-neuron +fi + +cd $HOME/vllm-neuron + +if git apply --check "$PATCH_FILE" 2>/dev/null; then + git apply "$PATCH_FILE" + echo " Applied $PATCH_FILE" +else + echo " Patch already applied or conflicts; continuing." +fi + +pip install --extra-index-url=https://pip.repos.neuron.amazonaws.com -e . +pip install s5cmd + +python3 -c "import vllm_neuron; print('vllm-neuron installed:', vllm_neuron.__file__)" + +echo "" +echo "[2/2] Fetching MiniMax-M2.7 FP8 weights (HuggingFace)..." + +# Source HF checkpoint (FP8 OCP, ~215 GB). Preprocessing this into +# Neuron-FP8 via src/conversion_script/preprocess_minimax_m2_fp8.py is +# a separate step (~13 min); the bench script expects the preprocessed +# output at $MINIMAX_M2_PATH. +HF_PATH="${MINIMAX_M2_HF_PATH:-/opt/dlami/nvme/models/MiniMax-M2.7}" +if [ -d "$HF_PATH" ] && [ "$(ls "$HF_PATH"/*.safetensors 2>/dev/null | wc -l)" -gt 0 ]; then + echo " HF weights already at $HF_PATH, skipping download" +else + echo " Downloading HF FP8 weights (this takes ~5 min at S3 speeds)..." 
+ huggingface-cli download MiniMaxAI/MiniMax-M2.7 --local-dir "$HF_PATH" + echo " Download complete: $(du -sh $HF_PATH | cut -f1)" +fi + +MINIMAX_PATH="${MINIMAX_M2_PATH:-/opt/dlami/nvme/models/MiniMax-M2.7-Neuron-FP8}" +if [ -d "$MINIMAX_PATH" ] && [ "$(ls "$MINIMAX_PATH"/*.safetensors 2>/dev/null | wc -l)" -gt 0 ]; then + echo " Neuron-FP8 checkpoint already exists at $MINIMAX_PATH" +else + echo "" + echo "Next step (not run automatically): preprocess HF -> Neuron-FP8" + echo " python contrib/models/MiniMax-M2/src/conversion_script/preprocess_minimax_m2_fp8.py \\" + echo " --hf_model_path $HF_PATH \\" + echo " --save_path $MINIMAX_PATH \\" + echo " --tp_degree 64" +fi + +CONTRIB_SRC="$(cd "$(dirname "$0")/.." && pwd)/src" + +echo "" +echo "Setup complete. Before running the benchmark, export:" +echo " export MINIMAX_M2_PATH=$MINIMAX_PATH" +echo " export NXDI_CONTRIB_MINIMAX_M2_SRC=$CONTRIB_SRC" diff --git a/contrib/models/MiniMax-M2/perf_test/bench_minimax_m2.sh b/contrib/models/MiniMax-M2/perf_test/bench_minimax_m2.sh new file mode 100755 index 00000000..ce49edef --- /dev/null +++ b/contrib/models/MiniMax-M2/perf_test/bench_minimax_m2.sh @@ -0,0 +1,232 @@ +#!/bin/bash +set -e + +# MiniMax-M2 / M2.7 FP8 vLLM benchmark on Trn2. +# +# Requires a Neuron-FP8 preprocessed checkpoint (see +# `src/conversion_script/preprocess_minimax_m2_fp8.py`). The configs below +# all use moe_tp_degree=1 / moe_ep_degree=64 (experts sharded by expert +# parallelism only, no intra-expert TP split) because moe_tp_degree=64 +# collapses the per-rank FP8 blockwise scale to a singleton — M2's MoE +# intermediate is only 1536, so moe_tp=64 gives per-rank intermediate=24 +# rows, well below the 128-row blockwise scale block. +# Using moe_ep_degree=64 keeps all of each expert's weight + scale on one +# rank (4 experts per rank), which preserves per-channel scale intact. +# +# NxDI's TKG path refuses Expert Parallelism with BS < num_experts/top_k +# (256 / 8 = 32 for M2), so the smallest working batch size here is 32. + +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +MODEL_PATH="${MINIMAX_M2_PATH:-/opt/dlami/nvme/models/MiniMax-M2.7-Neuron-FP8}" +# The NxDI contrib MiniMax-M2 modeling code is registered into NxDI's +# MODEL_TYPES by vllm-neuron's register() hook using this env var. +# Default to this contrib package's own src/ relative to the script. +: "${NXDI_CONTRIB_MINIMAX_M2_SRC:=$(cd "$(dirname "$0")/.." && pwd)/src}" +export NXDI_CONTRIB_MINIMAX_M2_SRC + +# First-time M2 FP8 compile takes ~25 minutes; extend vLLM's ready timeout. +export VLLM_ENGINE_READY_TIMEOUT_S=7200 + +PORT=8000 +RESULTS_DIR="/tmp/bench_results/minimax_m2" +mkdir -p "$RESULTS_DIR" + +# Common neuron config shared across all MiniMax-M2 FP8 configs. +# save_sharded_checkpoint=true persists per-rank sharded weights to +# /weights/tp{N}_sharded_checkpoint.safetensors during +# compile; load() then reads those directly instead of re-sharding. +# modules_to_not_convert: HF's quantization_config only skips +# {gate, e_score_correction_bias, lm_head}, and NxDI-side we additionally +# keep embed_tokens / norm in BF16. We also list o_proj here — M2's +# modeling code binds self_attn.o_proj to a plain RowParallelLinear +# (not auto-swapped to QuantizedRowParallel), so preprocess dequantizes +# the FP8 o_proj weights to BF16 and this list prevents NxDI from +# trying to re-swap it at convert() time. Without this, the loader +# silently drops the o_proj weights as "redundant keys" and attention +# output is garbage. 
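+
+# Fail fast if MODEL_PATH is not a preprocessed Neuron-FP8 checkpoint.
+# (A sketch keyed off model_extras.safetensors, one of the documented
+# preprocess output files; pointing vLLM at the raw HF FP8 dir otherwise
+# fails late or degenerates into repetition collapse.)
+if [ ! -f "$MODEL_PATH/model_extras.safetensors" ]; then
+  echo "ERROR: $MODEL_PATH does not look like a preprocessed Neuron-FP8 checkpoint."
+  echo "Run contrib/models/MiniMax-M2/src/conversion_script/preprocess_minimax_m2_fp8.py first."
+  exit 1
+fi
+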
+COMMON_MINIMAX_CONFIG='"tp_degree": 64, + "logical_nc_config": 2, + "fused_qkv": false, + "sequence_parallel_enabled": false, + "glu_mlp": true, + "moe_mask_padded_tokens": true, + "disable_numeric_cc_token": true, + "save_sharded_checkpoint": true, + "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, + "quantized": true, + "quantized_checkpoints_path": "'"$MODEL_PATH"'", + "quantization_dtype": "f8e4m3", + "quantization_type": "blockwise_symmetric", + "quantization_block_axis": [1, 2], + "quantization_block_size": [128, 128], + "modules_to_not_convert": ["embed_tokens", "lm_head", "norm", "router", "o_proj"], + "blockwise_matmul_config": {"use_shard_on_block_dynamic_while": true, "block_sharding_strategy": "PING_PONG"}' + +# Helper: wait for vLLM server to be ready. First-time compilation of a +# 256-expert MoE model takes 25-60 minutes, so we poll for up to 2 hours. +wait_for_server() { + echo " Waiting for vLLM server to be ready (up to 2h for first compile)..." + local interval=10 + local max_attempts=720 # 720 * 10s = 7200s = 2h + local start=$SECONDS + for i in $(seq 1 $max_attempts); do + if curl -s http://localhost:$PORT/health > /dev/null 2>&1; then + echo " Server ready! (waited $((SECONDS - start))s)" + return 0 + fi + # Show a progress blip every minute so the user knows we're alive + if [ $((i % 6)) -eq 0 ]; then + echo " ...still waiting ($((SECONDS - start))s elapsed)" + fi + sleep $interval + done + echo " ERROR: Server did not start within $((max_attempts * interval))s" + return 1 +} + +# Helper: run benchmark +run_bench() { + local config_name=$1 + local concurrency=$2 + local num_prompts=$3 + + echo " Benchmark: concurrency=$concurrency, prompts=$num_prompts" + vllm bench serve \ + --backend vllm \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --endpoint /v1/completions \ + --dataset-name random \ + --num-prompts "$num_prompts" \ + --random-input-len 900 \ + --random-output-len 90 \ + --random-range-ratio 0.03 \ + --max-concurrency "$concurrency" \ + 2>&1 | tee "$RESULTS_DIR/${config_name}_c${concurrency}.txt" + echo "" +} + +# Helper: stop server +stop_server() { + echo " Stopping vLLM server..." + pkill -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true + sleep 5 +} + +# Helper: quick sanity check +sanity_check() { + echo " Running sanity check..." + curl -s http://localhost:$PORT/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [{"role": "user", "content": "What is 1+1? Answer briefly."}], + "model": "'"$MODEL_PATH"'", + "max_tokens": 64, + "temperature": 0.0, + "stream": false + }' | python3 -c "import sys,json; r=json.load(sys.stdin); print(' Sanity:', r['choices'][0]['message']['content'][:100])" 2>/dev/null || echo " Sanity check: could not parse response" +} + +echo "==========================================" +echo "MiniMax-M2 FP8 Performance Benchmark" +echo "==========================================" +echo "Model: $MODEL_PATH" +echo "Results: $RESULTS_DIR" +echo "" + +############################################################################### +# Config 1: BS=32, TP=64 + moe_tp=1/moe_ep=64, CB + bucketing (smallest BS +# that satisfies NxDI's Expert-Parallel BS >= num_experts/top_k requirement). 
+############################################################################### +CONFIG_NAME="bs32_tp64_moetp1_ep64" +echo "--- Config 1: BS=32, moe_tp=1/moe_ep=64, CB + bucketing ---" + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 32 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + '"$COMMON_MINIMAX_CONFIG"', + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 32, + "ctx_batch_size": 1, + "tkg_batch_size": 32, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": true, + "enable_bucketing": true, + "context_encoding_buckets": [1024], + "token_generation_buckets": [1024], + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +run_bench "$CONFIG_NAME" 16 128 +run_bench "$CONFIG_NAME" 32 128 +stop_server + +############################################################################### +# Config 2: BS=128, TP=64 + moe_tp=1/moe_ep=64, CB + bucketing (throughput). +############################################################################### +CONFIG_NAME="bs128_tp64_moetp1_ep64" +echo "--- Config 2: BS=128, moe_tp=1/moe_ep=64, CB + bucketing ---" + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 128 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + '"$COMMON_MINIMAX_CONFIG"', + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 128, + "ctx_batch_size": 1, + "tkg_batch_size": 128, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": true, + "enable_bucketing": true, + "context_encoding_buckets": [1024], + "token_generation_buckets": [1024], + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +run_bench "$CONFIG_NAME" 16 128 +run_bench "$CONFIG_NAME" 32 128 +run_bench "$CONFIG_NAME" 128 512 +stop_server + +echo "==========================================" +echo "MiniMax-M2 FP8 benchmarks complete!" +echo "Results saved to: $RESULTS_DIR" +echo "==========================================" +ls -la "$RESULTS_DIR" diff --git a/contrib/models/MiniMax-M2/perf_test/run_bench_single.sh b/contrib/models/MiniMax-M2/perf_test/run_bench_single.sh new file mode 100755 index 00000000..5d58f1ff --- /dev/null +++ b/contrib/models/MiniMax-M2/perf_test/run_bench_single.sh @@ -0,0 +1,75 @@ +#!/bin/bash +# Run a single vllm-bench-serve pass against an already-running vLLM server. +# +# Unlike bench_minimax_m2.sh this script does NOT launch or kill the vLLM +# server — you bring your own. That makes it convenient when the bench driver +# in bench_minimax_m2.sh times out during first-time compilation: the server +# keeps running, and once it's ready you can collect numbers with this. 
+# +# Usage: +# bash run_bench_single.sh # defaults: c=1, 16 prompts +# CONCURRENCY=16 NUM_PROMPTS=128 bash run_bench_single.sh +# CONFIG_NAME=bs256_tp1_ep64_opt CONCURRENCY=32 NUM_PROMPTS=128 bash run_bench_single.sh +# +# Environment knobs: +# PORT vLLM server port (default 8000) +# MINIMAX_M2_PATH Path to the BF16 checkpoint (default +# /opt/dlami/nvme/models/MiniMax-M2-BF16) +# CONCURRENCY --max-concurrency (default 1) +# NUM_PROMPTS --num-prompts (default 16) +# INPUT_LEN --random-input-len (default 900) +# OUTPUT_LEN --random-output-len (default 90) +# RANGE_RATIO --random-range-ratio (default 0.03) +# CONFIG_NAME Used in the output filename (default bs1_tp64_ep1) +# RESULTS_DIR Where to dump per-run log (default /tmp/bench_results/minimax_m2) + +set -e + +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +MODEL_PATH="${MINIMAX_M2_PATH:-/opt/dlami/nvme/models/MiniMax-M2-BF16}" +PORT="${PORT:-8000}" +CONCURRENCY="${CONCURRENCY:-1}" +NUM_PROMPTS="${NUM_PROMPTS:-16}" +INPUT_LEN="${INPUT_LEN:-900}" +OUTPUT_LEN="${OUTPUT_LEN:-90}" +RANGE_RATIO="${RANGE_RATIO:-0.03}" +CONFIG_NAME="${CONFIG_NAME:-bs1_tp64_ep1}" +RESULTS_DIR="${RESULTS_DIR:-/tmp/bench_results/minimax_m2}" + +mkdir -p "$RESULTS_DIR" + +echo "==========================================" +echo "MiniMax-M2 single-run benchmark" +echo "==========================================" +echo " Model: $MODEL_PATH" +echo " Port: $PORT" +echo " Config: $CONFIG_NAME" +echo " Concurrency: $CONCURRENCY" +echo " Prompts: $NUM_PROMPTS" +echo " Input len: $INPUT_LEN Output len: $OUTPUT_LEN" +echo " Results: $RESULTS_DIR/${CONFIG_NAME}_c${CONCURRENCY}.txt" +echo "" + +if ! curl -sf "http://localhost:$PORT/health" > /dev/null; then + echo "ERROR: vLLM server is not responding on http://localhost:$PORT" + echo "Start it first (e.g., bench_minimax_m2.sh) and wait until" + echo "'Application startup complete.' is printed." + exit 1 +fi + +vllm bench serve \ + --backend vllm \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --endpoint /v1/completions \ + --dataset-name random \ + --num-prompts "$NUM_PROMPTS" \ + --random-input-len "$INPUT_LEN" \ + --random-output-len "$OUTPUT_LEN" \ + --random-range-ratio "$RANGE_RATIO" \ + --max-concurrency "$CONCURRENCY" \ + 2>&1 | tee "$RESULTS_DIR/${CONFIG_NAME}_c${CONCURRENCY}.txt" + +echo "" +echo "Saved to: $RESULTS_DIR/${CONFIG_NAME}_c${CONCURRENCY}.txt" diff --git a/contrib/models/MiniMax-M2/perf_test/sanity_check.sh b/contrib/models/MiniMax-M2/perf_test/sanity_check.sh new file mode 100755 index 00000000..7449159f --- /dev/null +++ b/contrib/models/MiniMax-M2/perf_test/sanity_check.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# Quick sanity check against an already-running vLLM server. +# +# Assumes vLLM is already listening on $PORT (default 8000) with MiniMax-M2 +# loaded. Sends a single chat completion and prints the model's reply. +# +# Usage: +# bash sanity_check.sh # uses defaults +# PORT=8001 bash sanity_check.sh # custom port +# PROMPT="..." bash sanity_check.sh # custom prompt + +set -e + +MODEL_PATH="${MINIMAX_M2_PATH:-/opt/dlami/nvme/models/MiniMax-M2-BF16}" +PORT="${PORT:-8000}" +PROMPT="${PROMPT:-What is 1+1? Answer briefly.}" +MAX_TOKENS="${MAX_TOKENS:-64}" + +echo "Sanity check: POST /v1/chat/completions on port $PORT" +echo " Model: $MODEL_PATH" +echo " Prompt: $PROMPT" +echo " Max tokens: $MAX_TOKENS" +echo "" + +if ! 
curl -sf "http://localhost:$PORT/health" > /dev/null; then + echo "ERROR: vLLM server is not responding on http://localhost:$PORT" + echo "Start it with 'bash bench_minimax_m2.sh' or your own launcher first." + exit 1 +fi + +RESPONSE=$(curl -s "http://localhost:$PORT/v1/chat/completions" \ + -H 'Content-Type: application/json' \ + -d "$(cat </dev/null || echo "$RESPONSE" +echo "" + +REPLY=$(echo "$RESPONSE" | python3 -c " +import json, sys +try: + r = json.load(sys.stdin) + print(r['choices'][0]['message']['content'].strip()) +except Exception as e: + print(f'(could not parse reply: {e})') +" 2>/dev/null) + +echo "Model reply: $REPLY" diff --git a/contrib/models/MiniMax-M2/perf_test/smoke_compile_minimax_m2.py b/contrib/models/MiniMax-M2/perf_test/smoke_compile_minimax_m2.py new file mode 100644 index 00000000..11bf2a05 --- /dev/null +++ b/contrib/models/MiniMax-M2/perf_test/smoke_compile_minimax_m2.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 +"""Minimal compile+load smoke test for MiniMax-M2.7 FP8 on Trn2. + +Bypasses vLLM entirely so we can iterate on the preprocessed Neuron-FP8 +checkpoint without paying vllm-neuron's startup cost. Builds the Flash BS=1 +recipe (TP=64, EP=1, blockwise FP8 for routed experts), compiles to a temp +dir, then loads. EP=1 lets the TKG path enter forward_selective_loading +legally so BS=1 compiles — with EP>1 NxDI raises NotImplementedError and +forces BS>=num_experts/top_k = 32. + +STAGE controls how far we go: + instantiate | compile | load | all (default: all) + +DRY_RUN=1 does HLO-only compile (no torch.jit.save + shard). Fastest sanity +check for the preprocessed checkpoint. SKIP_WARMUP=1 on load() skips the +forward pass that allocates the shared scratchpad — useful when HBM is +tight. + +Run under /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 (same venv used +by the bench script). +""" + +import os +import sys +import time +import traceback + +MODEL_PATH = os.environ.get( + "MINIMAX_M2_MODEL_PATH", + "/opt/dlami/nvme/models/MiniMax-M2.7-Neuron-FP8", +) +COMPILED_PATH = os.environ.get( + "MINIMAX_M2_COMPILED_PATH", + "/opt/dlami/nvme/compiled/minimax_m2_tp64_moetp1_ep64_fp8/", +) + +TP_DEGREE = int(os.environ.get("TP_DEGREE", "64")) +SEQ_LEN = int(os.environ.get("SEQ_LEN", "1024")) +BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1")) +CTX_BATCH_SIZE = int(os.environ.get("CTX_BATCH_SIZE", "1")) +# Default to moe_tp=1 / moe_ep=64. Under FP8 + moe_tp=64 (our old default) +# each rank's MoE expert intermediate slice is 32 rows (<128, the scale +# block size), which collapses the per-rank scale to a singleton in +# NxDI's `_setup_for_scale` — losing per-channel FP8 scale granularity +# and producing a BF16-accumulator drift that compounds into output +# collapse after ~30 decode tokens. moe_tp=1/moe_ep=64 keeps every expert +# on a single rank (4 full experts per rank), so each expert's scale +# survives intact. Override via MOE_TP / MOE_EP env vars for other recipes. +MOE_TP = int(os.environ.get("MOE_TP", "1")) +MOE_EP = int(os.environ.get("MOE_EP", "64")) + +STAGE = os.environ.get("STAGE", "all").lower() + +os.makedirs(COMPILED_PATH, exist_ok=True) + +# NxDI's model builder uses a per-process temp workdir for HLO/NEFF staging +# (BASE_COMPILE_WORK_DIR, default "/tmp/nxd_model/"). If two compiles run in +# parallel with the same default, they silently overwrite each other's +# .hlo_module.pb files and one or both compilations crash with +# "neuronx-cc returned non-zero exit status 70". 
Pin the workdir to a +# unique per-COMPILED_PATH subdir to stay safe under any parallel invocation. +os.environ.setdefault( + "BASE_COMPILE_WORK_DIR", + os.path.join("/tmp/nxd_model", os.path.basename(COMPILED_PATH.rstrip("/"))), +) + + +def main(): + from neuronx_distributed_inference.models.config import MoENeuronConfig + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + + # Import the contrib wrapper (sibling src dir). + contrib_src = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "src", + ) + sys.path.insert(0, os.path.abspath(contrib_src)) + + from modeling_minimax_m2 import ( + MiniMaxM2InferenceConfig, + NeuronMiniMaxM2ForCausalLM, + ) + + print(f"[smoke] MODEL_PATH={MODEL_PATH}") + print(f"[smoke] COMPILED_PATH={COMPILED_PATH}") + print(f"[smoke] TP_DEGREE={TP_DEGREE}, SEQ_LEN={SEQ_LEN}, BS={BATCH_SIZE}") + print(f"[smoke] MOE_TP={MOE_TP}, MOE_EP={MOE_EP}") + print(f"[smoke] STAGE={STAGE}") + + print("[smoke] Building MoENeuronConfig (quantized FP8 MoE, blockwise_symmetric)...") + # NOTE: ep_degree at the top level controls the OUTER (full model) + # expert-parallel factor, which multiplies world_size to + # tp_degree * ep_degree and duplicates non-MoE weights per replica. + # At world_size > 64 on a 64-NC Trn2, sharded weights grow accordingly + # (e.g. tp=64 + ep=4 -> 256 ranks -> 4x the sharded checkpoint size, + # and at runtime the model doesn't fit on the device). For MoE-only + # EP we want ep_degree=1 at the outer level and the per-MoE split + # controlled solely by moe_ep_degree (which Pro's working benches + # also do). Keep ep_degree=1 unconditionally. + neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + ep_degree=1, + logical_nc_config=2, + batch_size=BATCH_SIZE, + max_batch_size=BATCH_SIZE, + ctx_batch_size=CTX_BATCH_SIZE, + tkg_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + n_active_tokens=128, + torch_dtype="bfloat16", + capacity_factor=1.0, + glu_mlp=True, + moe_ep_degree=MOE_EP, + moe_tp_degree=MOE_TP, + context_encoding_buckets=[SEQ_LEN], + router_config={"act_fn": "sigmoid", "dtype": "float32"}, + # SDK 2.29 ships only bwmm_shard_on_block / bwmm_shard_on_intermediate; + # default routes to _call_shard_hidden_kernel which is missing, so we + # take the shard-on-block path via this flag. + blockwise_matmul_config={ + "use_shard_on_block_dynamic_while": True, + "block_sharding_strategy": "PING_PONG", + }, + # Persist sharded FP8 weights to disk so subsequent load()s skip the + # ~10-minute shard_checkpoint step (writes weights/tp{0..63}_*.safetensors + # on NVMe; NxDI load() reads these directly when present). + save_sharded_checkpoint=True, + # FP8 blockwise for routed experts (Kimi-K2 recipe). + quantized=True, + quantized_checkpoints_path=MODEL_PATH, + quantization_dtype="f8e4m3", + quantization_type="blockwise_symmetric", + quantization_block_axis=[1, 2], + quantization_block_size=[128, 128], + modules_to_not_convert=[ + # M2.7 HF quantization_config.modules_to_not_convert lists + # {gate, e_score_correction_bias, lm_head}. Plus the usual + # embed_tokens / norm kept as BF16 on the NxDI side. + "embed_tokens", + "lm_head", + "norm", + "router", + # o_proj on the NxDI side binds to a plain RowParallelLinear + # (not the auto-swapped QuantizedRowParallel), so we dequantize + # it to BF16 during preprocess AND list it here so NxDI does + # not try to swap it to the quantized class at convert() time. 
+ # Without this the loader drops o_proj.weight and .scale as + # "redundant" during checkpoint sharding and attention output + # becomes garbage. + "o_proj", + ], + ) + + print("[smoke] Building MiniMaxM2InferenceConfig...") + from transformers import AutoConfig + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + config = MiniMaxM2InferenceConfig( + neuron_config, load_config=load_pretrained_config(hf_config=hf_config) + ) + print(f"[smoke] config.hidden_size={config.hidden_size}") + print(f"[smoke] config.num_hidden_layers={config.num_hidden_layers}") + print(f"[smoke] config.num_local_experts={config.num_local_experts}") + print(f"[smoke] config.num_experts_per_tok={config.num_experts_per_tok}") + print(f"[smoke] config.num_key_value_heads={config.num_key_value_heads}") + print(f"[smoke] config.attn_type_list[:5]={getattr(config, 'attn_type_list', [None]*5)[:5]}") + + print("[smoke] Instantiating NeuronMiniMaxM2ForCausalLM (build model-on-cpu)...") + t0 = time.time() + model = NeuronMiniMaxM2ForCausalLM(MODEL_PATH, config) + print(f"[smoke] Instantiated in {time.time() - t0:.1f}s") + + if STAGE == "instantiate": + print("[smoke] STAGE=instantiate only, skipping compile/load.") + return + + DRY_RUN = os.environ.get("DRY_RUN", "0") == "1" + if STAGE in ("compile", "all"): + label = "Dry-run compile (HLO only)" if DRY_RUN else "Full compile" + print(f"[smoke] {label} -> {COMPILED_PATH}") + t0 = time.time() + try: + model.compile(COMPILED_PATH, dry_run=DRY_RUN) + print(f"[smoke] {label} OK in {time.time() - t0:.1f}s") + except Exception: + print(f"[smoke] {label} FAILED:") + traceback.print_exc() + raise + + if STAGE in ("load", "all") and not DRY_RUN: + SKIP_WARMUP = os.environ.get("SKIP_WARMUP", "1") == "1" + print(f"[smoke] Loading compiled model from {COMPILED_PATH} (skip_warmup={SKIP_WARMUP})") + t0 = time.time() + model.load(COMPILED_PATH, skip_warmup=SKIP_WARMUP) + print(f"[smoke] Loaded in {time.time() - t0:.1f}s") + + print("[smoke] Done.") + + +if __name__ == "__main__": + try: + main() + except Exception: + traceback.print_exc() + sys.exit(1) diff --git a/contrib/models/MiniMax-M2/perf_test/smoke_generate_minimax_m2.py b/contrib/models/MiniMax-M2/perf_test/smoke_generate_minimax_m2.py new file mode 100644 index 00000000..146a4287 --- /dev/null +++ b/contrib/models/MiniMax-M2/perf_test/smoke_generate_minimax_m2.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +"""Minimal generate smoke test for MiniMax-M2.7 FP8 on Trn2. + +Assumes the compiled NEFF already exists at MINIMAX_M2_COMPILED_PATH +(from smoke_compile_minimax_m2.py). Rebuilds the same MoENeuronConfig / +Flash wrapper, loads with skip_warmup=False, and generates 20 tokens for a +single prompt via HuggingFaceGenerationAdapter. Purpose: sanity-check that +the FP8 MoE + preprocessed scales actually produce coherent tokens. + +Run under /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16. +""" + +import os +import sys +import time +import traceback + +MODEL_PATH = os.environ.get( + "MINIMAX_M2_MODEL_PATH", + "/opt/dlami/nvme/models/MiniMax-M2.7-Neuron-FP8", +) +COMPILED_PATH = os.environ.get( + "MINIMAX_M2_COMPILED_PATH", + "/opt/dlami/nvme/compiled/minimax_m2_tp64_moetp1_ep64_fp8/", +) + +# Must match smoke_compile_minimax_m2.py exactly, else load() sees a +# mismatched NEFF. 
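+# The knobs below read the same env-var names as smoke_compile_minimax_m2.py,
+# so exporting one recipe in the shell keeps both scripts in sync.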
+TP_DEGREE = int(os.environ.get("TP_DEGREE", "64")) +SEQ_LEN = int(os.environ.get("SEQ_LEN", "1024")) +BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1")) +CTX_BATCH_SIZE = int(os.environ.get("CTX_BATCH_SIZE", "1")) +MOE_TP = int(os.environ.get("MOE_TP", "1")) +MOE_EP = int(os.environ.get("MOE_EP", "64")) + +PROMPT = os.environ.get( + "MINIMAX_M2_PROMPT", + "Hello! Please introduce yourself in one sentence.", +) +MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "20")) + +# Keep the per-compile BASE_COMPILE_WORK_DIR in sync with +# smoke_compile_minimax_m2.py so load() under the same COMPILED_PATH +# doesn't collide with a concurrent compile or reuse a stale workdir. +os.environ.setdefault( + "BASE_COMPILE_WORK_DIR", + os.path.join("/tmp/nxd_model", os.path.basename(COMPILED_PATH.rstrip("/"))), +) + + +def main(): + from transformers import AutoConfig, AutoTokenizer, GenerationConfig + + from neuronx_distributed_inference.models.config import MoENeuronConfig + from neuronx_distributed_inference.utils.hf_adapter import ( + HuggingFaceGenerationAdapter, + load_pretrained_config, + ) + + contrib_src = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "src", + ) + sys.path.insert(0, os.path.abspath(contrib_src)) + + from modeling_minimax_m2 import ( + MiniMaxM2InferenceConfig, + NeuronMiniMaxM2ForCausalLM, + ) + + print(f"[gen] MODEL_PATH={MODEL_PATH}") + print(f"[gen] COMPILED_PATH={COMPILED_PATH}") + print(f"[gen] TP={TP_DEGREE}, SEQ={SEQ_LEN}, BS={BATCH_SIZE}") + + # Outer ep_degree must match the compile-time value (kept at 1 so + # world_size = tp_degree; see smoke_compile_minimax_m2.py comment). + neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + ep_degree=1, + logical_nc_config=2, + batch_size=BATCH_SIZE, + max_batch_size=BATCH_SIZE, + ctx_batch_size=CTX_BATCH_SIZE, + tkg_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + n_active_tokens=128, + torch_dtype="bfloat16", + capacity_factor=1.0, + glu_mlp=True, + moe_ep_degree=MOE_EP, + moe_tp_degree=MOE_TP, + context_encoding_buckets=[SEQ_LEN], + router_config={"act_fn": "sigmoid", "dtype": "float32"}, + blockwise_matmul_config={ + "use_shard_on_block_dynamic_while": True, + "block_sharding_strategy": "PING_PONG", + }, + save_sharded_checkpoint=True, + quantized=True, + quantized_checkpoints_path=MODEL_PATH, + quantization_dtype="f8e4m3", + quantization_type="blockwise_symmetric", + quantization_block_axis=[1, 2], + quantization_block_size=[128, 128], + modules_to_not_convert=[ + "embed_tokens", + "lm_head", + "norm", + "router", + "o_proj", # binds to plain RowParallelLinear; preprocess dequants to BF16. + ], + ) + + hf_config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + config = MiniMaxM2InferenceConfig( + neuron_config, load_config=load_pretrained_config(hf_config=hf_config) + ) + + print("[gen] Instantiating model...") + t0 = time.time() + model = NeuronMiniMaxM2ForCausalLM(MODEL_PATH, config) + print(f"[gen] Instantiated in {time.time() - t0:.1f}s") + + # skip_warmup=False so generate() hits a primed graph (the warmup forward + # allocates the shared scratchpad the generation path needs). + print(f"[gen] Loading from {COMPILED_PATH} (skip_warmup=False)") + t0 = time.time() + model.load(COMPILED_PATH, skip_warmup=False) + print(f"[gen] Loaded in {time.time() - t0:.1f}s") + + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + # M2 tokenizer doesn't ship a pad_token; fall back to eos so batched + # `padding=True` works when BATCH_SIZE > 1. 
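+    # (The GenerationConfig below passes the same pad_token_id explicitly.)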
+ if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + adapter = HuggingFaceGenerationAdapter(model) + + inputs = tokenizer([PROMPT] * BATCH_SIZE, return_tensors="pt", padding=True) + gen_config = GenerationConfig( + max_new_tokens=MAX_NEW_TOKENS, + min_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + pad_token_id=getattr(tokenizer, "pad_token_id", None) or tokenizer.eos_token_id, + ) + + print(f"[gen] prompt: {PROMPT!r}") + print(f"[gen] input_ids.shape={tuple(inputs['input_ids'].shape)}") + t0 = time.time() + output_ids = adapter.generate( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + generation_config=gen_config, + ) + dt = time.time() - t0 + + prompt_len = inputs["input_ids"].shape[1] + new_tokens = output_ids[0, prompt_len:] + decoded = tokenizer.decode(new_tokens, skip_special_tokens=True) + full = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + print(f"[gen] generated {new_tokens.numel()} tokens in {dt:.2f}s " + f"({new_tokens.numel() / dt:.2f} tok/s)") + print(f"[gen] new token ids: {new_tokens.tolist()}") + print(f"[gen] new text : {decoded!r}") + print(f"[gen] full text : {full!r}") + print("[gen] Done.") + + +if __name__ == "__main__": + try: + main() + except Exception: + traceback.print_exc() + sys.exit(1) diff --git a/contrib/models/MiniMax-M2/perf_test/vllm-neuron-patch.patch b/contrib/models/MiniMax-M2/perf_test/vllm-neuron-patch.patch new file mode 100644 index 00000000..4a84a558 --- /dev/null +++ b/contrib/models/MiniMax-M2/perf_test/vllm-neuron-patch.patch @@ -0,0 +1,107 @@ +diff --git a/vllm_neuron/worker/neuronx_distributed_model_loader.py b/vllm_neuron/worker/neuronx_distributed_model_loader.py +index d2099eb..0c162e4 100644 +--- a/vllm_neuron/worker/neuronx_distributed_model_loader.py ++++ b/vllm_neuron/worker/neuronx_distributed_model_loader.py +@@ -922,6 +922,94 @@ def _camel_to_kebab(name: str) -> str: + return re.sub("([a-z0-9])([A-Z])", r"\1-\2", s1).lower() + + ++ ++def _patch_autoconfig_trust_remote_code(): ++ """Monkey-patch ``AutoConfig.from_pretrained`` to default ``trust_remote_code=True``. ++ ++ NxDI's ``hf_adapter.load_config`` calls ``AutoConfig.from_pretrained(path)`` ++ without ``trust_remote_code``. Contrib models like MiMo-V2-Flash that ++ ship a ``configuration_*.py`` with the checkpoint require custom code ++ execution, so the default behaviour crashes with ``ValueError: The ++ repository ... contains custom code which must be executed``. ++ ++ vLLM's top-level ``--trust-remote-code`` flag only affects vLLM's own ++ config load, not NxDI's. Patching here is cheap and idempotent. ++ """ ++ try: ++ from transformers import AutoConfig ++ except ImportError: ++ return ++ if getattr(AutoConfig, "_nxdi_contrib_patched", False): ++ return ++ _orig = AutoConfig.from_pretrained ++ ++ def _patched(*args, **kwargs): ++ kwargs.setdefault("trust_remote_code", True) ++ return _orig(*args, **kwargs) ++ ++ AutoConfig.from_pretrained = _patched ++ AutoConfig._nxdi_contrib_patched = True ++ ++ ++def _register_contrib_models(): ++ """Lazy-register NxDI contrib models on each process that calls the loader. ++ ++ Driven by env vars: ++ NXDI_CONTRIB_MIMO_V2_FLASH_SRC -> path to contrib MiMo-V2-Flash src/ ++ NXDI_CONTRIB_MINIMAX_M2_SRC -> path to contrib MiniMax-M2 src/ ++ ++ Registers the contrib model class into NxDI's MODEL_TYPES and, where ++ vLLM does not already know the architecture, registers it into vLLM's ++ ModelRegistry. 
Runs every time _get_neuron_model_cls is called so that ++ vLLM's spawn'd EngineCore workers (which don't inherit the parent's ++ module-level state) pick up the registration too. Registration is ++ idempotent. ++ """ ++ import os as _os ++ import sys as _sys ++ import warnings as _w ++ ++ _patch_autoconfig_trust_remote_code() ++ ++ mimo_src = _os.environ.get("NXDI_CONTRIB_MIMO_V2_FLASH_SRC") ++ if mimo_src and _os.path.isdir(mimo_src) and "mimov2flash" not in MODEL_TYPES: ++ if mimo_src not in _sys.path: ++ _sys.path.insert(0, mimo_src) ++ try: ++ from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM ++ MODEL_TYPES.setdefault( ++ "mimov2flash", {"causal-lm": NeuronMiMoV2ForCausalLM} ++ ) ++ try: ++ from vllm.model_executor.models.registry import ModelRegistry ++ if "MiMoV2FlashForCausalLM" not in ModelRegistry.get_supported_archs(): ++ ModelRegistry.register_model( ++ "MiMoV2FlashForCausalLM", NeuronMiMoV2ForCausalLM ++ ) ++ except ImportError: ++ pass ++ except Exception as e: ++ _w.warn( ++ f"Failed to register MiMo-V2-Flash contrib model: {e}", ++ category=UserWarning, ++ ) ++ ++ minimax_src = _os.environ.get("NXDI_CONTRIB_MINIMAX_M2_SRC") ++ if minimax_src and _os.path.isdir(minimax_src) and "minimaxm2" not in MODEL_TYPES: ++ if minimax_src not in _sys.path: ++ _sys.path.insert(0, minimax_src) ++ try: ++ from modeling_minimax_m2 import NeuronMiniMaxM2ForCausalLM ++ MODEL_TYPES.setdefault( ++ "minimaxm2", {"causal-lm": NeuronMiniMaxM2ForCausalLM} ++ ) ++ except Exception as e: ++ _w.warn( ++ f"Failed to register MiniMax-M2 contrib model: {e}", ++ category=UserWarning, ++ ) ++ ++ + def _get_neuron_model_cls(architecture: str): + """ + Get Neuron model class from architecture string. +@@ -941,6 +1029,7 @@ def _get_neuron_model_cls(architecture: str): + _get_neuron_model_cls("NeuronLlamaForCausalLM") + + """ ++ _register_contrib_models() + # Handle Neuron class name (starts with "Neuron") - strip prefix + if architecture.startswith("Neuron") and "For" in architecture: + original_architecture = architecture diff --git a/contrib/models/MiniMax-M2/src/__init__.py b/contrib/models/MiniMax-M2/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiniMax-M2/src/config.json b/contrib/models/MiniMax-M2/src/config.json new file mode 100644 index 00000000..237efe37 --- /dev/null +++ b/contrib/models/MiniMax-M2/src/config.json @@ -0,0 +1,112 @@ +{ + "architectures": [ + "MiniMaxM2ForCausalLM" + ], + "attention_dropout": 0.0, + "attn_type_list": [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1 + ], + "auto_map": { + "AutoConfig": "configuration_minimax_m2.MiniMaxM2Config", + "AutoModelForCausalLM": "modeling_minimax_m2.MiniMaxM2ForCausalLM" + }, + "bos_token_id": null, + "eos_token_id": null, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 3072, + "initializer_range": 0.02, + "intermediate_size": 1536, + "layernorm_full_attention_beta": 1.0, + "layernorm_linear_attention_beta": 1.0, + "layernorm_mlp_beta": 1.0, + "max_position_embeddings": 196608, + "mlp_intermediate_size": 8192, + "model_type": "minimax_m2", + "mtp_transformer_layers": 1, + "num_attention_heads": 48, + "num_experts_per_tok": 8, + "num_hidden_layers": 62, + "num_key_value_heads": 8, + "num_local_experts": 256, + "num_mtp_modules": 
3, + "output_router_logits": false, + "qk_norm_type": "per_layer", + "rms_norm_eps": 1e-06, + "rope_theta": 5000000, + "rotary_dim": 64, + "router_aux_loss_coef": 0.001, + "router_jitter_noise": 0.0, + "scoring_func": "sigmoid", + "shared_intermediate_size": 0, + "shared_moe_mode": "sigmoid", + "sliding_window": null, + "tie_word_embeddings": false, + "transformers_version": "4.57.1", + "use_cache": true, + "use_mtp": true, + "use_qk_norm": true, + "use_routing_bias": true, + "vocab_size": 200064 +} diff --git a/contrib/models/MiniMax-M2/src/configuration_minimax_m2.py b/contrib/models/MiniMax-M2/src/configuration_minimax_m2.py new file mode 100644 index 00000000..76ea7dcc --- /dev/null +++ b/contrib/models/MiniMax-M2/src/configuration_minimax_m2.py @@ -0,0 +1,201 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/minimax_m2/modular_minimax_m2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_minimax_m2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from transformers.configuration_utils import PretrainedConfig + + + +class MiniMaxM2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MiniMaxM2Model`]. It is used to instantiate an + MiniMaxM2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the MiniMaxM2-7B-v0.1 or MiniMaxM2-7B-Instruct-v0.1. + + [minimax_m2ai/MiniMaxM2-8x7B](https://huggingface.co/minimax_m2ai/MiniMaxM2-8x7B) + [minimax_m2ai/MiniMaxM2-7B-Instruct-v0.1](https://huggingface.co/minimax_m2ai/MiniMaxM2-7B-Instruct-v0.1) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MiniMaxM2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MiniMaxM2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`. + head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. MiniMaxM2's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss. See [here]() for more details + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + router_jitter_noise (`float`, *optional*, defaults to 0.0): + Amount of noise to add to the router. 
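+        use_qk_norm (`bool`, *optional*, defaults to `False`):
+            Whether to apply per-layer RMSNorm to the query and key projections
+            after projection and before RoPE. (M2-specific kwarg consumed in
+            `__init__` below; documented here for completeness.)
+        rotary_dim (`int`, *optional*, defaults to `head_dim`):
+            Number of head dimensions that receive rotary embeddings. M2 ships
+            partial RoPE with `rotary_dim=64` of `head_dim=128`.
+        partial_rotary_factor (`float`, *optional*, defaults to 1):
+            Ratio `rotary_dim / head_dim`; recomputed from those two values
+            when `head_dim` is set.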
+ + ```python + >>> from transformers import MiniMaxM2Model, MiniMaxM2Config + + >>> # Initializing a MiniMaxM2 7B style configuration + >>> configuration = MiniMaxM2Config() + + >>> # Initializing a model from the MiniMaxM2 7B style configuration + >>> model = MiniMaxM2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "minimax_m2" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts + "layers.*.block_sparse_moe.experts.*.w1": "colwise", + "layers.*.block_sparse_moe.experts.*.w2": "rowwise", + "layers.*.block_sparse_moe.experts.*.w3": "colwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + head_dim=None, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.head_dim = head_dim + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + + self.use_qk_norm = kwargs.pop("use_qk_norm", False) + self.rotary_dim = kwargs.pop("rotary_dim", self.head_dim) + self.partial_rotary_factor = kwargs.pop("partial_rotary_factor", 1) + if self.head_dim is not None: + self.partial_rotary_factor = self.rotary_dim / self.head_dim + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["MiniMaxM2Config"] \ No newline at end of file diff --git a/contrib/models/MiniMax-M2/src/conversion_script/preprocess_minimax_m2_fp8.py b/contrib/models/MiniMax-M2/src/conversion_script/preprocess_minimax_m2_fp8.py new file mode 100644 index 00000000..d741a993 --- /dev/null +++ b/contrib/models/MiniMax-M2/src/conversion_script/preprocess_minimax_m2_fp8.py @@ -0,0 +1,476 @@ +""" +Preprocess 
MiniMax-M2 / M2.7 FP8 checkpoint for Neuron inference.
+
+Streaming (per-layer) rewrite that mirrors the MiMo-V2-Flash preprocess
+flow. The HF checkpoint ships 130 sharded safetensors files with weights
+of one layer scattered across several shards; this script keeps one
+`safe_open` handle live at a time (via `LazyWeightMap`) and writes one
+output file per decoder layer (`model_layer{N}.safetensors`), plus
+`model_extras.safetensors` for embed / norm / lm_head. Peak RAM is
+~15 GB and total runtime is ~13 minutes on trn2.48xlarge.
+
+MiniMax-M2 checkpoint layout:
+  - q/k/v/o_proj and expert w1/w2/w3 are stored FP8 blockwise (128x128)
+    with separate `.weight_scale_inv` fp32 tensors. The NxDI-side layers
+    expect:
+        Quantized{Column,Row}Parallel for attention q/k/v/o and
+        QuantizedExpertFused{Column,Row}Parallel for MoE experts
+    so we rescale OCP FP8 (±448) to Neuron FP8 (±240) and emit `.scale`.
+  - q/k/v and o_proj are 2D with out-dim (= num_heads * head_dim) that at
+    TP=64 goes below the 128-row scale block. The NxDI modeling code
+    runs a `_apply_2d_per_channel_fix` monkey-patch at compile time to
+    swap these layers' q_config to PER_CHANNEL_SYMMETRIC, which expects
+    per-row scales of shape [out, 1]. So for these 2D tensors we emit
+    per-row Neuron-FP8 scales (one scalar per output row).
+  - Expert w1/w2/w3 stay block-quantized; we fuse gate+up along the last
+    dim and stack experts to match ExpertFusedRowParallelLinear's packed
+    layout:
+        gate_up_proj.weight [num_experts, hidden, 2*IM]
+        gate_up_proj.scale [num_experts, H_blocks, 2*IM_blocks]
+        down_proj.weight [num_experts, IM, hidden]
+        down_proj.scale [num_experts, IM_blocks, H_blocks]
+  - `block_sparse_moe.gate.weight` and `e_score_correction_bias` are
+    renamed into the NxDI router namespace:
+        block_sparse_moe.router.linear_router.weight
+        block_sparse_moe.router.e_score_correction_bias
+  - embed_tokens / norm / lm_head / per-layer {input,post_attention}_
+    layernorm / q_norm / k_norm are passed through BF16 unchanged.
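+
+Worked example of the per-row rescale (shapes follow from the reference
+M2.7 config): q_proj is [6144, 3072] (48 heads x head_dim 128, hidden
+3072) with a [48, 24] blockwise scale. At TP=64 each rank holds only 96
+output rows, less than one 128-row block, so we emit a [6144, 1]
+per-row Neuron-FP8 scale instead and let PER_CHANNEL_SYMMETRIC consume
+it after sharding.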
+ +Output layout: + save_path/ + config.json, tokenizer.*, chat_template.jinja if present + configuration_minimax_m2.py, modeling_minimax_m2.py (trust_remote_code) + model.safetensors.index.json (regenerated) + model_extras.safetensors (embed_tokens, norm, lm_head) + model_layer{N}.safetensors (one per decoder layer, N=0..61) + +Usage: + python preprocess_minimax_m2_fp8.py \\ + --hf_model_path /opt/dlami/nvme/models/MiniMax-M2.7 \\ + --save_path /opt/dlami/nvme/models/MiniMax-M2.7-Neuron-FP8 \\ + --tp_degree 64 +""" + +import argparse +import gc +import json +import os +import shutil +import time +from typing import Dict, List, Optional, Tuple + +import torch +from safetensors import safe_open +from safetensors.torch import save_file + + +FP8_SCALING_FACTOR = 448.0 / 240.0 +NEURON_FP8_MAX = 240.0 + + +# --------------------------------------------------------------------------- +# Quantization primitives +# --------------------------------------------------------------------------- + +def convert_bf16_to_fp8_per_row( + weight: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """BF16 [out, in] -> Neuron FP8 per-row (scales shape [out, 1]).""" + weight_float = weight.float() + row_max_abs = weight_float.abs().max(dim=1, keepdim=True)[0] + scales = torch.clamp(row_max_abs / NEURON_FP8_MAX, min=1e-10) + quantized = (weight_float / scales).to(torch.float8_e4m3fn) + return quantized, scales.to(torch.float32) + + +def rescale_fp8_to_per_row( + weight: torch.Tensor, scale: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Block-wise FP8 + blockwise scale -> Neuron per-row FP8. + + Dequantize to float32 using block broadcast, then per-row requantize. + """ + out_features, in_features = weight.shape + scale_h, scale_w = scale.shape + + block_h = (out_features + scale_h - 1) // scale_h + block_w = (in_features + scale_w - 1) // scale_w + + weight_float = weight.float() + dequantized = torch.zeros(out_features, in_features, dtype=torch.float32) + for i in range(scale_h): + for j in range(scale_w): + h0, h1 = i * block_h, min((i + 1) * block_h, out_features) + w0, w1 = j * block_w, min((j + 1) * block_w, in_features) + dequantized[h0:h1, w0:w1] = ( + weight_float[h0:h1, w0:w1] * scale[i, j].item() + ) + + row_max_abs = dequantized.abs().max(dim=1, keepdim=True)[0] + scales = torch.clamp(row_max_abs / NEURON_FP8_MAX, min=1e-10) + quantized = (dequantized / scales).to(torch.float8_e4m3fn) + return quantized, scales.to(torch.float32) + + +def rescale_fp8_weight_blockwise( + weight: torch.Tensor, scale: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Keep blockwise scales, just rescale into Neuron FP8 range. + + MoE expert weights stay block-quantized; only the dtype range changes. 
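+
+    The rescale preserves dequantized values up to FP8 rounding: with
+    r = 448 / 240 (FP8_SCALING_FACTOR), we store w' = w / r and
+    s' = s * r, so w' * s' == w * s while every stored value now fits
+    the Neuron E4M3 range (max 240 instead of OCP's 448).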
+ """ + weight_bf16 = weight.bfloat16() + rescaled = (weight_bf16 / FP8_SCALING_FACTOR).to(torch.float8_e4m3fn) + neuron_scale = scale.float() * FP8_SCALING_FACTOR + return rescaled, neuron_scale.to(torch.float32) + + +# --------------------------------------------------------------------------- +# Streaming weight access (one open safetensors handle at a time) +# --------------------------------------------------------------------------- + +class LazyWeightMap: + """Lazily fetch tensors from sharded safetensors, keeping one handle live.""" + + def __init__(self, model_dir: str, weight_map: Dict[str, str]): + self.model_dir = model_dir + self.weight_map = weight_map + self._cur_filename: Optional[str] = None + self._cur_handle = None + + def _open(self, filename: str): + if self._cur_filename == filename: + return self._cur_handle + if self._cur_handle is not None: + self._cur_handle.__exit__(None, None, None) + self._cur_handle = None + path = os.path.join(self.model_dir, filename) + self._cur_handle = safe_open(path, framework="pt", device="cpu") + self._cur_handle.__enter__() + self._cur_filename = filename + return self._cur_handle + + def get(self, key: str) -> Optional[torch.Tensor]: + filename = self.weight_map.get(key) + if filename is None: + return None + return self._open(filename).get_tensor(key) + + def has(self, key: str) -> bool: + return key in self.weight_map + + def close(self): + if self._cur_handle is not None: + self._cur_handle.__exit__(None, None, None) + self._cur_handle = None + self._cur_filename = None + + +# --------------------------------------------------------------------------- +# Per-tensor helper +# --------------------------------------------------------------------------- + +def _maybe_fp8_to_neuron_per_row( + weight: torch.Tensor, scale: Optional[torch.Tensor] +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """FP8 blockwise -> per-row, or BF16 -> FP8 per-row. Pass-through otherwise.""" + if weight.dtype == torch.float8_e4m3fn and scale is not None: + return rescale_fp8_to_per_row(weight, scale) + if weight.dtype == torch.bfloat16: + return convert_bf16_to_fp8_per_row(weight) + return weight, scale + + +# --------------------------------------------------------------------------- +# Per-layer processing +# --------------------------------------------------------------------------- + +def process_layer( + layer_idx: int, + lazy: LazyWeightMap, + config: dict, +) -> Dict[str, torch.Tensor]: + out: Dict[str, torch.Tensor] = {} + prefix = f"model.layers.{layer_idx}." + out_prefix = f"layers.{layer_idx}." + + # --- Layer norms (BF16, untouched) --- + for name in ("input_layernorm", "post_attention_layernorm"): + t = lazy.get(f"{prefix}{name}.weight") + if t is not None: + out[f"{out_prefix}{name}.weight"] = t.detach().clone() + + # --- Attention: q/k/v/o, all FP8 -> Neuron FP8 per-row --- + # q/k/v -> Neuron FP8 per-row (go through QuantizedColumnParallel). + for proj in ("q_proj", "k_proj", "v_proj"): + w = lazy.get(f"{prefix}self_attn.{proj}.weight") + if w is None: + continue + s = lazy.get(f"{prefix}self_attn.{proj}.weight_scale_inv") + w2, s2 = _maybe_fp8_to_neuron_per_row(w, s) + out[f"{out_prefix}self_attn.{proj}.weight"] = w2 + if s2 is not None: + out[f"{out_prefix}self_attn.{proj}.scale"] = s2 + + # o_proj -> BF16 (dequantized). 
On the Neuron side the modeling code + # binds self_attn.o_proj to a plain RowParallelLinear, NOT the auto- + # swapped QuantizedRowParallel — so the NxDI loader does not expect + # .scale or FP8 weight bytes for o_proj and would drop them as + # "redundant" during checkpoint sharding, leaving the projection + # zero-initialised and producing garbage outputs. Dequantizing here + # and emitting only a BF16 .weight matches what the loader expects. + # The bench / smoke config must also add "o_proj" to + # modules_to_not_convert to keep NxDI from trying to re-swap this + # layer to QuantizedRowParallel during convert(). + o_w = lazy.get(f"{prefix}self_attn.o_proj.weight") + o_s = lazy.get(f"{prefix}self_attn.o_proj.weight_scale_inv") + if o_w is not None: + if o_w.dtype == torch.float8_e4m3fn: + assert o_s is not None, "FP8 o_proj requires weight_scale_inv" + out_features, in_features = o_w.shape + scale_h, scale_w = o_s.shape + block_h = (out_features + scale_h - 1) // scale_h + block_w = (in_features + scale_w - 1) // scale_w + wf = o_w.float() + tmp = torch.zeros(out_features, in_features, dtype=torch.float32) + for i in range(scale_h): + for j in range(scale_w): + h0, h1 = i * block_h, min((i + 1) * block_h, out_features) + w0, w1 = j * block_w, min((j + 1) * block_w, in_features) + tmp[h0:h1, w0:w1] = wf[h0:h1, w0:w1] * o_s[i, j].item() + o_bf16 = tmp.to(torch.bfloat16) + else: + o_bf16 = o_w.to(torch.bfloat16) + out[f"{out_prefix}self_attn.o_proj.weight"] = o_bf16.detach().clone() + + # --- QK norm (BF16) --- + for name in ("q_norm", "k_norm"): + t = lazy.get(f"{prefix}self_attn.{name}.weight") + if t is not None: + out[f"{out_prefix}self_attn.{name}.weight"] = t.detach().clone() + + # --- MoE router --- + # HF: block_sparse_moe.gate.weight + # NxDI: block_sparse_moe.router.linear_router.weight + router_w = lazy.get(f"{prefix}block_sparse_moe.gate.weight") + if router_w is not None: + out[f"{out_prefix}block_sparse_moe.router.linear_router.weight"] = ( + router_w.detach().clone() + ) + router_bias = lazy.get(f"{prefix}block_sparse_moe.e_score_correction_bias") + if router_bias is not None: + out[f"{out_prefix}block_sparse_moe.router.e_score_correction_bias"] = ( + router_bias.detach().clone() + ) + + # --- MoE experts: fuse gate+up, stack across experts --- + num_experts = config["num_local_experts"] + + # Peek expert 0 to know shapes/dtypes. + e0_w1 = lazy.get(f"{prefix}block_sparse_moe.experts.0.w1.weight") + if e0_w1 is None: + return out + e0_w1_s = lazy.get(f"{prefix}block_sparse_moe.experts.0.w1.weight_scale_inv") + + if e0_w1.dtype == torch.float8_e4m3fn and e0_w1_s is not None: + sample_w, sample_s = rescale_fp8_weight_blockwise(e0_w1, e0_w1_s) + elif e0_w1.dtype == torch.bfloat16: + raise NotImplementedError( + f"Layer {layer_idx} expert 0 w1 is BF16; MiniMax-M2 expects FP8." + ) + else: + sample_w, sample_s = e0_w1, e0_w1_s + + intermediate_size, hidden_size = sample_w.shape # [IM, H] + # Packed transpose layout: [num_experts, H, 2*IM] for gate_up. 
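+    # Worked shapes for the reference M2.7 config (hidden=3072, IM=1536,
+    # 256 experts, 128x128 scale blocks), per layer:
+    #   gate_up_proj.weight [256, 3072, 3072]  scale [256, 24, 24]
+    #   down_proj.weight    [256, 1536, 3072]  scale [256, 12, 24]
+    # Gate occupies the first IM columns (first i_blocks scale entries),
+    # up the second half.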
+ gate_up_proj = torch.empty( + num_experts, hidden_size, 2 * intermediate_size, dtype=sample_w.dtype + ) + i_blocks, h_blocks = sample_s.shape # [IM_blocks, H_blocks] + gate_up_scale = torch.empty( + num_experts, h_blocks, 2 * i_blocks, dtype=sample_s.dtype + ) + + e0_w2 = lazy.get(f"{prefix}block_sparse_moe.experts.0.w2.weight") + e0_w2_s = lazy.get(f"{prefix}block_sparse_moe.experts.0.w2.weight_scale_inv") + if e0_w2.dtype == torch.float8_e4m3fn and e0_w2_s is not None: + sample_dw, sample_ds = rescale_fp8_weight_blockwise(e0_w2, e0_w2_s) + else: + raise NotImplementedError( + f"Layer {layer_idx} expert 0 w2 dtype {e0_w2.dtype} not handled." + ) + d_h_blocks, d_i_blocks = sample_ds.shape # [H_blocks, IM_blocks] + down_proj = torch.empty( + num_experts, intermediate_size, hidden_size, dtype=sample_dw.dtype + ) + down_scale = torch.empty( + num_experts, d_i_blocks, d_h_blocks, dtype=sample_ds.dtype + ) + + # Slot expert 0 (already rescaled above). + gate_up_proj[0, :, :intermediate_size] = sample_w.T + gate_up_scale[0, :, :i_blocks] = sample_s.T + e0_w3 = lazy.get(f"{prefix}block_sparse_moe.experts.0.w3.weight") + e0_w3_s = lazy.get(f"{prefix}block_sparse_moe.experts.0.w3.weight_scale_inv") + up_w0, up_s0 = rescale_fp8_weight_blockwise(e0_w3, e0_w3_s) + gate_up_proj[0, :, intermediate_size:] = up_w0.T + gate_up_scale[0, :, i_blocks:] = up_s0.T + down_proj[0] = sample_dw.T + down_scale[0] = sample_ds.T + del e0_w1, e0_w1_s, e0_w3, e0_w3_s, e0_w2, e0_w2_s + del sample_w, sample_s, sample_dw, sample_ds, up_w0, up_s0 + + for e in range(1, num_experts): + w1 = lazy.get(f"{prefix}block_sparse_moe.experts.{e}.w1.weight") + w1_s = lazy.get(f"{prefix}block_sparse_moe.experts.{e}.w1.weight_scale_inv") + w3 = lazy.get(f"{prefix}block_sparse_moe.experts.{e}.w3.weight") + w3_s = lazy.get(f"{prefix}block_sparse_moe.experts.{e}.w3.weight_scale_inv") + w2 = lazy.get(f"{prefix}block_sparse_moe.experts.{e}.w2.weight") + w2_s = lazy.get(f"{prefix}block_sparse_moe.experts.{e}.w2.weight_scale_inv") + g_w, g_s = rescale_fp8_weight_blockwise(w1, w1_s) + u_w, u_s = rescale_fp8_weight_blockwise(w3, w3_s) + d_w, d_s = rescale_fp8_weight_blockwise(w2, w2_s) + gate_up_proj[e, :, :intermediate_size] = g_w.T + gate_up_proj[e, :, intermediate_size:] = u_w.T + gate_up_scale[e, :, :i_blocks] = g_s.T + gate_up_scale[e, :, i_blocks:] = u_s.T + down_proj[e] = d_w.T + down_scale[e] = d_s.T + del w1, w1_s, w3, w3_s, w2, w2_s, g_w, g_s, u_w, u_s, d_w, d_s + + out[f"{out_prefix}block_sparse_moe.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj + out[f"{out_prefix}block_sparse_moe.expert_mlps.mlp_op.gate_up_proj.scale"] = gate_up_scale + out[f"{out_prefix}block_sparse_moe.expert_mlps.mlp_op.down_proj.weight"] = down_proj + out[f"{out_prefix}block_sparse_moe.expert_mlps.mlp_op.down_proj.scale"] = down_scale + return out + + +# --------------------------------------------------------------------------- +# Shard saving / index +# --------------------------------------------------------------------------- + +def save_shard( + tensors: Dict[str, torch.Tensor], + save_path: str, + filename: str, + weight_map: Dict[str, str], +) -> int: + """Save a sub-state-dict; clone tensors so safetensors doesn't complain + about views of mmapped storage. 
Returns bytes written.""" + path = os.path.join(save_path, filename) + materialized: Dict[str, torch.Tensor] = {} + total_bytes = 0 + for k, v in tensors.items(): + if not v.is_contiguous(): + v = v.contiguous() + v = v.detach().clone() + materialized[k] = v + total_bytes += v.numel() * v.element_size() + save_file(materialized, path) + for k in materialized.keys(): + weight_map[k] = filename + del materialized + return total_bytes + + +# --------------------------------------------------------------------------- +# Main driver +# --------------------------------------------------------------------------- + +def process_minimax_m2_checkpoint(hf_model_path: str, save_path: str, tp_degree: int): + os.makedirs(save_path, exist_ok=True) + + with open(os.path.join(hf_model_path, "model.safetensors.index.json")) as f: + weight_map_in = json.load(f)["weight_map"] + + with open(os.path.join(hf_model_path, "config.json")) as f: + config = json.load(f) + + num_layers = config["num_hidden_layers"] + + print( + f"Processing {num_layers} decoder layers " + f"(hidden={config['hidden_size']}, moe_IM={config['intermediate_size']}, " + f"experts={config['num_local_experts']})", + flush=True, + ) + + lazy = LazyWeightMap(hf_model_path, weight_map_in) + weight_map_out: Dict[str, str] = {} + + try: + for li in range(num_layers): + t0 = time.time() + layer_sd = process_layer(li, lazy, config) + filename = f"model_layer{li}.safetensors" + size = save_shard(layer_sd, save_path, filename, weight_map_out) + del layer_sd + gc.collect() + print( + f" layer {li:2d} {size/1e9:6.2f} GB in {time.time()-t0:5.1f}s", + flush=True, + ) + + print("Processing embed_tokens, norm, lm_head ...", flush=True) + extras: Dict[str, torch.Tensor] = {} + for src, dst in ( + ("model.embed_tokens.weight", "embed_tokens.weight"), + ("model.norm.weight", "norm.weight"), + ("lm_head.weight", "lm_head.weight"), + ): + t = lazy.get(src) + if t is not None: + extras[dst] = t.detach().clone() + else: + print(f" WARNING: missing {src}", flush=True) + if "lm_head.weight" not in extras and "embed_tokens.weight" in extras: + extras["lm_head.weight"] = extras["embed_tokens.weight"].detach().clone() + save_shard(extras, save_path, "model_extras.safetensors", weight_map_out) + del extras + finally: + lazy.close() + + # --- Index file --- + total_size = 0 + for f in set(weight_map_out.values()): + total_size += os.path.getsize(os.path.join(save_path, f)) + index = { + "metadata": {"total_size": total_size}, + "weight_map": weight_map_out, + } + with open(os.path.join(save_path, "model.safetensors.index.json"), "w") as f: + json.dump(index, f, indent=2) + + # --- Copy auxiliary files (config.json, tokenizer, chat template, and the + # trust_remote_code modules the HF config references). + for name in sorted(os.listdir(hf_model_path)): + if name.endswith(".safetensors"): + continue + if name == "model.safetensors.index.json": + continue + src = os.path.join(hf_model_path, name) + if os.path.isfile(src): + shutil.copy(src, os.path.join(save_path, name)) + + print(f"\nPreprocess complete. 
total_size={total_size/1e9:.2f} GB", flush=True) + print(f" tensors written: {len(weight_map_out)}", flush=True) + print(f" output dir: {save_path}", flush=True) + + +def main(): + parser = argparse.ArgumentParser( + description="Preprocess MiniMax-M2 FP8 checkpoint for Neuron inference" + ) + parser.add_argument("--hf_model_path", required=True) + parser.add_argument("--save_path", required=True) + parser.add_argument( + "--tp_degree", type=int, default=64, + help="Tensor parallelism (currently informational only; " + "the framework does the TP sharding at load time).", + ) + args = parser.parse_args() + process_minimax_m2_checkpoint(args.hf_model_path, args.save_path, args.tp_degree) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/MiniMax-M2/src/modeling_minimax_m2.py b/contrib/models/MiniMax-M2/src/modeling_minimax_m2.py new file mode 100644 index 00000000..89eb013d --- /dev/null +++ b/contrib/models/MiniMax-M2/src/modeling_minimax_m2.py @@ -0,0 +1,1676 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +MiniMax-M2 model for NeuronX Distributed Inference. + +Architecture: 229B total, ~10B active. 62 decoder layers, 256 MoE experts (top-8), +sigmoid routing with e_score_correction_bias, partial RoPE (64/128 head dim), +QK normalization (RMSNorm before reshape), GQA 48Q/8KV heads, SwiGLU experts. + +Based on Henan's (whn09) implementation with SDK 2.28 improvements: +- Fused MoE NKI kernels (router_topk, moe_cte, moe_tkg) +- ModuleMarker wrappers for compiler optimization +- Fused QKV support +- Shard-on-intermediate padding for blockwise matmul +- RouterTopKWithBias preserving e_score_correction_bias for accuracy +""" + +import gc +import math +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from neuronx_distributed.modules.moe.routing import RouterTopK +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, +) +from neuronx_distributed.utils import cpu_mode + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + MOE_TKG_MK_INTERMEDIATE_PER_TP, + MoENeuronConfig, + SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP, +) +from neuronx_distributed_inference.models.layer_boundary_marker import ( + ModuleMarkerEndWrapper, + ModuleMarkerStartWrapper, +) +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) + +# nki-library attention block kernel (partial RoPE support) +try: + from nkilib.experimental.transformer.attention_block_tkg import attention_block_tkg + from nkilib.core.utils.common_types import ( + QuantizationType as NkilibQuantizationType, + ) + + _HAS_NKILIB_ATTN_BLOCK = True +except ImportError: + _HAS_NKILIB_ATTN_BLOCK = False +from neuronx_distributed_inference.modules.attention.gqa import GQA +from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm +from neuronx_distributed_inference.modules.moe_v2 import ( + initialize_moe_process_group, +) + +GQA_SHARDING_STRATEGY = GQA.REPLICATE_TO_TP_DEGREE + + +# 
--------------------------------------------------------------------------- +# Utility helpers +# --------------------------------------------------------------------------- + + +def get_rmsnorm_cls(): + """Return the appropriate RMSNorm class for the execution environment.""" + if cpu_mode(): + + class SimpleRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + return self.weight * hidden_states.to(input_dtype) + + return SimpleRMSNorm + return CustomRMSNorm + + +def get_modules_to_not_convert(neuron_config: MoENeuronConfig): + return getattr(neuron_config, "modules_to_not_convert", None) + + +# --------------------------------------------------------------------------- +# Fused QKV helpers +# --------------------------------------------------------------------------- + + +def _helper_concat_and_delete_qkv( + state_dict: Dict[str, Any], layer_num: int, attr: str +): + """Concatenate Q/K/V into fused Wqkv for a single attribute (weight or scale). + + The fused key uses the ``qkv_proj.Wqkv`` path because the NxDI model nests + the Wqkv linear layer under ``self_attn.qkv_proj`` (a GroupQueryAttention_QKV module). + """ + state_dict[f"layers.{layer_num}.self_attn.qkv_proj.Wqkv.{attr}"] = torch.cat( + [ + state_dict[f"layers.{layer_num}.self_attn.q_proj.{attr}"], + state_dict[f"layers.{layer_num}.self_attn.k_proj.{attr}"], + state_dict[f"layers.{layer_num}.self_attn.v_proj.{attr}"], + ], + ) + del state_dict[f"layers.{layer_num}.self_attn.q_proj.{attr}"] + del state_dict[f"layers.{layer_num}.self_attn.k_proj.{attr}"] + del state_dict[f"layers.{layer_num}.self_attn.v_proj.{attr}"] + + +def convert_state_dict_to_fused_qkv(state_dict: Dict[str, Any], cfg: InferenceConfig): + """Fuse separate Q/K/V weights into a single Wqkv tensor per layer.""" + mods_to_not_conv = get_modules_to_not_convert(cfg.neuron_config) or [] + for layer_idx in range(cfg.num_hidden_layers): + _helper_concat_and_delete_qkv(state_dict, layer_idx, "weight") + if ( + cfg.neuron_config.quantized_mlp_kernel_enabled + or cfg.neuron_config.quantized + ) and f"layers.{layer_idx}.self_attn" not in mods_to_not_conv: + _helper_concat_and_delete_qkv(state_dict, layer_idx, "scale") + gc.collect() + return state_dict + + +def maybe_dequantize_layer(neuron_state_dict: dict, config): + """Dequantize FP8 layers (weight_scale_inv) to the configured torch dtype.""" + scale_layers = [] + for layer_key in list(neuron_state_dict.keys()): + if "_scale_inv" in layer_key: + scales = neuron_state_dict[layer_key] + scale_layers.append(layer_key) + fp8_layer_name = layer_key.replace("_scale_inv", "") + fp8_layer = neuron_state_dict[fp8_layer_name] + block_size = config.quantization_config["weight_block_size"] + scales_expanded = scales.repeat_interleave( + block_size[0], dim=0 + ).repeat_interleave(block_size[1], dim=1) + scaled_layer = fp8_layer.to(torch.float32) * scales_expanded.to( + torch.float32 + ) + neuron_state_dict[fp8_layer_name] = scaled_layer.to( + config.neuron_config.torch_dtype + ) + for key in scale_layers: + del neuron_state_dict[key] + + +# --------------------------------------------------------------------------- +# MiniMax-M2 specific modules +# 
--------------------------------------------------------------------------- + + +class MiniMaxM2QKNorm(nn.Module): + """ + QK normalization for MiniMax-M2 using Neuron's fused RmsNorm custom call. + + MiniMax-M2 applies RMSNorm on the Q/K projection output before reshape. + This implementation uses the Neuron-native AwsNeuronRmsNorm custom call + (via RmsNorm.apply) which is validated for both context encoding and token + generation NEFFs. Hand-rolled PyTorch RMSNorm (pow/mean/rsqrt) compiles + into different HLO in CE vs TG and produces incorrect TG results. + + Normalization is computed per-rank (no all-reduce) on the flat projection + output [B, S, per_rank_dim]. The per-element weight is selected dynamically + by SPMD rank from a padded weight tensor. + + Args: + hidden_size: Per-rank hidden dimension (num_heads_per_rank * head_dim) + eps: Epsilon for numerical stability + tp_degree: Tensor parallelism degree + padded_hidden_size: Total weight storage size (tp_degree * per_rank_size) + """ + + def __init__( + self, + hidden_size, + eps=1e-6, + tp_degree=1, + padded_hidden_size=None, + ): + super().__init__() + self.hidden_size = hidden_size + self.variance_epsilon = eps + self.tp_degree = tp_degree + self.padded_hidden_size = ( + padded_hidden_size + if padded_hidden_size is not None + else (hidden_size * tp_degree) + ) + # Weight stored at full padded size for SPMD rank-based selection + self.weight = nn.Parameter(torch.ones(self.padded_hidden_size)) + + def forward(self, hidden_states, rank_util=None): + """ + Apply Neuron-native RMSNorm on flat Q or K tensor (no all-reduce). + + Args: + hidden_states: [B, S, per_rank_dim] — flat projection output + rank_util: SPMDRank for dynamic weight slice selection + """ + from neuronx_distributed_inference.modules.custom_calls import RmsNorm + + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + # Dynamically select weight slice by SPMD rank (XLA-compatible) + if rank_util is not None and self.tp_degree > 1: + weight_reshaped = self.weight.view(self.tp_degree, self.hidden_size) + rank_index = rank_util.rank[:1] + local_weight = torch.index_select(weight_reshaped, 0, rank_index).squeeze(0) + else: + local_weight = self.weight[: self.hidden_size] + + # Use Neuron-native fused RmsNorm (AwsNeuronRmsNorm custom call) + dim = len(hidden_states.shape) - 1 + result = RmsNorm.apply(hidden_states, local_weight, self.variance_epsilon, dim) + + return result.to(input_dtype) + + +class RouterTopKWithBias(RouterTopK): + """ + RouterTopK with e_score_correction_bias for MiniMax-M2 sigmoid routing. + + MiniMax-M2 applies sigmoid to router logits to obtain expert affinities, then + adds a learned per-expert bias before top-K selection. The bias influences which + experts are chosen but does NOT affect the affinity weights passed to experts. + + The bias MUST be an nn.Parameter (not a buffer) because: + - XLA tracing bakes register_buffer values as constants in the NEFF + - shard_children only processes nn.Parameter in supported modules + - replace_weights only loads tensors present in the traced model's separated weights + Using nn.Parameter ensures the bias is separated during tracing and loaded from + the checkpoint at inference time. + + Dropping the bias (as v3 does for XLA simplicity) causes ~75% wrong expert selection + because bias values (~8.0-9.5) dominate sigmoid scores (0-1). 
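+
+    Illustrative numbers (invented, but in the documented ~8.0-9.5 bias
+    range): with sigmoid affinities [0.9, 0.2] and biases [8.0, 9.5], top-1
+    selection picks expert 1 (0.2 + 9.5 = 9.7 > 0.9 + 8.0 = 8.9) even though
+    expert 0 has the higher affinity; the weight forwarded for expert 1 is
+    still the unbiased 0.2.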
+ """ + + def __init__(self, num_experts: int, *args, **kwargs): + super().__init__(num_experts=num_experts, *args, **kwargs) + # nn.Parameter so it gets separated from NEFF and loaded from checkpoint. + # requires_grad=False since this is inference-only. + # CRITICAL: Initialize with non-uniform values to prevent XLA graph optimization + # from eliminating the add-bias operation. Uniform values (zeros, ones) don't + # change relative ordering in topk, so XLA can prove the add is a no-op and + # eliminate it — removing the bias parameter from the HLO entirely and making it + # impossible to load the real bias values at inference time. + # Using arange produces distinct per-expert values that genuinely affect topk + # ordering, forcing the compiler to keep the bias as a runtime parameter. + # IMPORTANT: Initialize as bfloat16 to match the dtype that _cast_helper + # will produce from the checkpoint (FP32 → BF16). If the NEFF expects FP32 + # but the checkpoint provides BF16, the LayoutTransformation silently + # ignores the weight and leaves the trace-time values in place. + self.e_score_correction_bias = nn.Parameter( + torch.arange(num_experts, dtype=torch.bfloat16), + requires_grad=False, + ) + + def forward(self, hidden_states): + router_logits = self.get_router_logits(hidden_states) + expert_affinities = self.apply_activation_fn(router_logits) + + # Add bias for expert selection only (MiniMax-M2 specific). + # sigmoid(logits) + bias determines WHICH experts are selected, + # but the un-biased sigmoid scores are used as affinity weights. + scores_for_choice = ( + expert_affinities.float() + self.e_score_correction_bias.unsqueeze(0) + ) + _, expert_index = torch.topk(scores_for_choice, self.top_k, dim=-1) + + expert_affinities = expert_affinities.to(dtype=hidden_states.dtype) + expert_index = expert_index.detach().to(dtype=torch.long) + + return router_logits, expert_affinities, expert_index + + +# --------------------------------------------------------------------------- +# MoE initialization +# --------------------------------------------------------------------------- + + +def initialize_minimax_m2_moe_module( + config: InferenceConfig, rmsnorm=None, init_tkg_module=False +): + """ + Create the MoE module for MiniMax-M2 with e_score_correction_bias. + + Instead of wrapping the standard MoE, we inject a RouterTopKWithBias directly + as the router. This ensures the bias is an nn.Parameter that gets: + 1. Separated from the NEFF during XLA tracing (not baked as a constant) + 2. Loaded from the checkpoint via replace_weights at inference time + + The bias values (~8.0-9.5) dominate sigmoid scores (0-1) and are critical + for correct expert selection. Without them, ~75% of experts are wrong. 
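+
+    Usage sketch (mirrors NeuronMiniMaxM2DecoderLayer below; `config` and
+    `hidden_states` here are illustrative):
+
+        moe = initialize_minimax_m2_moe_module(config=config)
+        hidden_states = moe(hidden_states, padding_mask)[0]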
+ """ + from neuronx_distributed.modules.moe.expert_mlps_v2 import ExpertMLPsV2 + from neuronx_distributed.modules.moe.model import MoE + from neuronx_distributed.modules.moe.moe_configs import RoutedExpertsMLPOpsConfig + from neuronx_distributed.parallel_layers import parallel_state + from neuronx_distributed.parallel_layers.parallel_state import ( + get_expert_model_parallel_size, + get_tensor_model_parallel_group, + get_world_group, + ) + + from neuronx_distributed_inference.modules.moe_v2 import ( + initialize_moe_process_group, + ) + + enabled_hybrid_sharding = config.neuron_config.hybrid_sharding_config is not None + ( + moe_tkg_tensor_model_parallel_group, + moe_tkg_expert_model_parallel_group, + moe_cte_tensor_model_parallel_group, + moe_cte_expert_model_parallel_group, + ) = initialize_moe_process_group(config, enabled_hybrid_sharding) + + # Use RouterTopKWithBias instead of standard RouterTopK + router = RouterTopKWithBias( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + dtype=config.neuron_config.router_config.dtype, + act_fn=config.neuron_config.router_config.act_fn, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + sequence_dimension=1, + bias=False, # no linear bias; we use e_score_correction_bias instead + apply_act_fn_over_topk=False, + store_transposed_weights=init_tkg_module, + ) + + hidden_size_actual = getattr(config, "original_hidden_size", None) + intermediate_size_actual = getattr(config, "original_intermediate_size", None) + + expert_mlps = ExpertMLPsV2( + routed_experts_mlp_config=RoutedExpertsMLPOpsConfig( + num_experts=config.num_local_experts, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_size_actual=hidden_size_actual, + intermediate_size_actual=intermediate_size_actual, + is_hidden_dim_shuffled=config.neuron_config.is_hidden_dim_shuffled, + is_intermediate_dim_shuffled=config.neuron_config.is_intermediate_dim_shuffled, + top_k=config.num_experts_per_tok, + hidden_act=config.hidden_act, + bias=False, + glu_mlp=config.neuron_config.glu_mlp, + glu_type=config.neuron_config.glu_type, + hidden_act_scaling_factor=config.neuron_config.hidden_act_scaling_factor, + hidden_act_bias=config.neuron_config.hidden_act_bias, + use_index_calc_kernel=config.neuron_config.use_index_calc_kernel, + gate_clamp_upper_limit=config.neuron_config.gate_clamp_upper_limit, + gate_clamp_lower_limit=config.neuron_config.gate_clamp_lower_limit, + up_clamp_upper_limit=config.neuron_config.up_clamp_upper_limit, + up_clamp_lower_limit=config.neuron_config.up_clamp_lower_limit, + early_expert_affinity_modulation=config.neuron_config.early_expert_affinity_modulation, + normalize_top_k_affinities=config.neuron_config.normalize_top_k_affinities, + enable_spmd_rank=config.neuron_config.blockwise_matmul_config.parallelize_token_to_block_mapping, + ), + blockwise_matmul_config=config.neuron_config.blockwise_matmul_config, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + dtype=config.neuron_config.torch_dtype, + is_prefill=config.neuron_config.is_prefill_stage, + enabled_hybrid_sharding=enabled_hybrid_sharding, + tensor_model_parallel_group=parallel_state.get_tensor_model_parallel_group(), + expert_model_parallel_group=parallel_state.get_expert_model_parallel_group(), + cte_tensor_model_parallel_group=moe_cte_tensor_model_parallel_group, + cte_expert_model_parallel_group=moe_cte_expert_model_parallel_group, + 
tkg_tensor_model_parallel_group=moe_tkg_tensor_model_parallel_group, + tkg_expert_model_parallel_group=moe_tkg_expert_model_parallel_group, + ) + + if init_tkg_module: + from neuronx_distributed.modules.moe.model import MoEFusedTKGConfig + + tkg_config = MoEFusedTKGConfig( + quantized=config.neuron_config.quantized, + moe_fused_kernel_enabled=config.neuron_config.moe_fused_nki_kernel_enabled, + router_topk_kernel_enabled=config.neuron_config.router_topk_nki_kernel_enabled, + expert_mlp_kernel_enabled=config.neuron_config.expert_mlp_nki_kernel_enabled, + shared_mlp_kernel_enabled=config.neuron_config.shared_mlp_nki_kernel_enabled, + norm_topk_prob=config.neuron_config.normalize_top_k_affinities, + is_mxfp4_compute=config.neuron_config.is_mxfp4_compute, + router_mm_dtype=config.neuron_config.router_config.dtype, + ) + else: + tkg_config = None + + moe = MoE( + router=router, + expert_mlps=expert_mlps, + shared_experts=None, # MiniMax-M2 has no shared experts + rmsnorm=rmsnorm, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + return_expert_index=config.neuron_config.return_expert_index, + return_router_logits=config.neuron_config.return_router_logits, + sequence_dimension=1, + init_tkg_module=init_tkg_module, + tkg_config=tkg_config, + ) + + moe.eval() + return moe + + +# --------------------------------------------------------------------------- +# Weight conversion +# --------------------------------------------------------------------------- + + +def convert_minimax_m2_hf_to_neuron_state_dict( + neuron_state_dict: Dict[str, Any], + config: "MiniMaxM2InferenceConfig", +) -> Dict[str, Any]: + """ + Convert a HuggingFace MiniMax-M2 checkpoint to the NxDI-compatible format. + + Key transformations: + 1. Stack per-expert w1/w3 into gate_up_proj, w2 into down_proj + 2. Rename router gate -> router.linear_router (or router.e_score_correction_bias) + 3. Pad QK norm weights to match TP sharding (interleaved for Q, replicated for K) + 4. Optionally pad intermediate_size for shard-on-I blockwise matmul + 5. Optionally fuse QKV into Wqkv + """ + from neuronx_distributed_inference.modules.attention.gqa import ( + GQA, + _maybe_pad_interleaved, + get_shardable_head_counts, + ) + + assert config.neuron_config.glu_mlp is True, ( + "MiniMax-M2 requires glu_mlp=True (SwiGLU)" + ) + + # Dequantize FP8 weights to BF16 ONLY if we are NOT running the native + # FP8 inference path. When neuron_config.quantized=True and the source + # checkpoint was produced by preprocess_minimax_m2_fp8.py, the FP8 bytes + # and .scale tensors must be preserved for NxDI's quantized layers to + # load them directly; dequantizing here would re-inflate weights to BF16 + # and lose the FP8 path's ~2x throughput advantage. 
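+    # In short:
+    #   BF16 path (quantized=False): raw HF FP8 checkpoint -> dequantize here.
+    #   FP8 path  (quantized=True):  preprocess_minimax_m2_fp8.py output is
+    #   loaded as-is (FP8 bytes + .scale); dequantization is skipped.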
+ if not getattr(config.neuron_config, "quantized", False): + maybe_dequantize_layer(neuron_state_dict, config) + + with torch.no_grad(): + tp_degree = config.neuron_config.tp_degree + head_dim = config.head_dim + has_qk_norm = getattr(config, "use_qk_norm", True) + + # Rank utility tensor for SPMD operations (int32 for NKI compatibility) + rank_tensor = torch.arange(0, tp_degree, dtype=torch.int32) + neuron_state_dict["rank_util.rank"] = rank_tensor + + # Pre-compute sharded head counts for QK norm padding + sharding_strategy = GQA.REPLICATE_TO_TP_DEGREE + padded_num_attention_heads, padded_num_kv_heads = get_shardable_head_counts( + tp_degree, + config.num_attention_heads, + config.num_key_value_heads, + sharding_strategy, + ) + + gc_interval = 64 # GC every N experts to control memory + + for layer_idx in range(config.num_hidden_layers): + # Per-layer rank tensor for attention SPMD + neuron_state_dict[f"layers.{layer_idx}.self_attn.rank_util.rank"] = ( + rank_tensor.clone() + ) + + # --- QK norm weight padding --- + if has_qk_norm: + # Q norm: interleaved padding (48 -> padded heads) + q_norm_key = f"layers.{layer_idx}.self_attn.q_norm.weight" + if q_norm_key in neuron_state_dict: + q_norm_full = neuron_state_dict[q_norm_key] + source_group_size = ( + config.num_attention_heads // config.num_key_value_heads + ) + q_norm_padded = _maybe_pad_interleaved( + q_norm_full.unsqueeze(0), + pad_dim=1, + source_heads=config.num_attention_heads, + target_heads=padded_num_attention_heads, + source_group_size=source_group_size, + ).squeeze(0) + neuron_state_dict[q_norm_key] = q_norm_padded + + # K norm: replicate from original KV heads to padded KV heads + k_norm_key = f"layers.{layer_idx}.self_attn.k_norm.weight" + if k_norm_key in neuron_state_dict: + k_norm_full = neuron_state_dict[k_norm_key] + k_norm_reshaped = k_norm_full.reshape( + config.num_key_value_heads, head_dim + ) + repeats = padded_num_kv_heads // config.num_key_value_heads + k_norm_replicated = k_norm_reshaped.repeat_interleave( + repeats, dim=0 + ) + neuron_state_dict[k_norm_key] = k_norm_replicated.reshape(-1) + + # --- Router weights --- + # Only rename if the HF-format keys are still present. The + # preprocess_minimax_m2_fp8.py streaming preprocess already emits + # under NxDI names (block_sparse_moe.router.linear_router.weight + # and block_sparse_moe.router.e_score_correction_bias), so for + # the FP8 path the pop() below would KeyError without this guard. + gate_key = f"layers.{layer_idx}.block_sparse_moe.gate.weight" + router_key = ( + f"layers.{layer_idx}.block_sparse_moe.router.linear_router.weight" + ) + if gate_key in neuron_state_dict: + neuron_state_dict[router_key] = neuron_state_dict.pop(gate_key) + + # e_score_correction_bias: map to RouterTopKWithBias.e_score_correction_bias + # This is an nn.Parameter in the router, so it will be separated from the + # NEFF during tracing and loaded via replace_weights at inference time. + bias_src_key = ( + f"layers.{layer_idx}.block_sparse_moe.e_score_correction_bias" + ) + bias_dst_key = ( + f"layers.{layer_idx}.block_sparse_moe.router.e_score_correction_bias" + ) + if bias_src_key in neuron_state_dict: + neuron_state_dict[bias_dst_key] = neuron_state_dict.pop(bias_src_key) + + # --- Expert weight stacking --- + # Skip entirely when the preprocessed checkpoint already has the + # fused layout (preprocess_minimax_m2_fp8.py emits + # block_sparse_moe.expert_mlps.mlp_op.gate_up_proj.weight directly). 
+ w1_key = f"layers.{layer_idx}.block_sparse_moe.experts.0.w1.weight" + if w1_key not in neuron_state_dict: + continue + intermediate_size, hidden_size = neuron_state_dict[w1_key].shape + device = neuron_state_dict[w1_key].device + dtype = neuron_state_dict[w1_key].dtype + + # Stack gate (w1) + up (w3) into gate_up_proj: [E, H, 2*I] + gate_up_proj = torch.empty( + config.num_local_experts, + hidden_size, + 2 * intermediate_size, + dtype=dtype, + device=device, + ) + for expert_idx in range(config.num_local_experts): + ew1 = f"layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w1.weight" + ew3 = f"layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w3.weight" + + gate_up_slice = torch.narrow(gate_up_proj, 0, expert_idx, 1) + torch.narrow(gate_up_slice, 2, 0, intermediate_size).copy_( + neuron_state_dict[ew1].T + ) + torch.narrow( + gate_up_slice, 2, intermediate_size, intermediate_size + ).copy_(neuron_state_dict[ew3].T) + del neuron_state_dict[ew1], neuron_state_dict[ew3] + if (expert_idx + 1) % gc_interval == 0: + gc.collect() + + # Pad gate_up_proj intermediate dimension if needed for shard-on-I + pad_size = getattr(config, "moe_intermediate_pad_size", 0) + if pad_size > 0: + gate_up_proj = gate_up_proj.reshape( + config.num_local_experts, hidden_size, 2, -1 + ) + gate_up_proj = torch.nn.functional.pad(gate_up_proj, (0, pad_size)) + gate_up_proj = gate_up_proj.reshape( + config.num_local_experts, hidden_size, -1 + ) + + neuron_state_dict[ + f"layers.{layer_idx}.block_sparse_moe.expert_mlps.mlp_op.gate_up_proj.weight" + ] = gate_up_proj + + # Stack down (w2) into down_proj: [E, I, H] + down_proj = torch.empty( + config.num_local_experts, + intermediate_size, + hidden_size, + dtype=dtype, + device=device, + ) + for expert_idx in range(config.num_local_experts): + ew2 = f"layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w2.weight" + torch.narrow(down_proj, 0, expert_idx, 1).copy_( + neuron_state_dict[ew2].T + ) + del neuron_state_dict[ew2] + if (expert_idx + 1) % gc_interval == 0: + gc.collect() + + if pad_size > 0: + down_proj = torch.nn.functional.pad(down_proj, (0, 0, 0, pad_size)) + + neuron_state_dict[ + f"layers.{layer_idx}.block_sparse_moe.expert_mlps.mlp_op.down_proj.weight" + ] = down_proj + + gc.collect() + + # Fuse QKV if configured (must run BEFORE the rename below, since + # convert_state_dict_to_fused_qkv expects layers.X.self_attn.q_proj.weight) + if config.neuron_config.fused_qkv: + neuron_state_dict = convert_state_dict_to_fused_qkv( + neuron_state_dict, config + ) + + # --- Attention projection key renaming --- + # The NxDI traced model uses nested module names for attention projections: + # self_attn.qkv_proj.q_proj.weight (not self_attn.q_proj.weight) + # self_attn.qkv_proj.k_proj.weight (not self_attn.k_proj.weight) + # self_attn.qkv_proj.v_proj.weight (not self_attn.v_proj.weight) + # self_attn.o_proj.o_proj.weight (not self_attn.o_proj.weight) + # The preshard hook in RowParallelLinear handles the o_proj rename + # (o_proj.weight -> o_proj.o_proj.weight), so we only rename Q/K/V here. + # When fused_qkv=True, Q/K/V are already merged into Wqkv above. 
+ for layer_idx in range(config.num_hidden_layers): + prefix = f"layers.{layer_idx}.self_attn" + # Q/K/V projections -> nested under qkv_proj + for proj in ("q_proj", "k_proj", "v_proj"): + old_key = f"{prefix}.{proj}.weight" + new_key = f"{prefix}.qkv_proj.{proj}.weight" + if old_key in neuron_state_dict: + neuron_state_dict[new_key] = neuron_state_dict.pop(old_key) + # Same rename for the FP8 .scale tensor (only present on the + # quantized path; BF16 path has no .scale). + old_scale_key = f"{prefix}.{proj}.scale" + new_scale_key = f"{prefix}.qkv_proj.{proj}.scale" + if old_scale_key in neuron_state_dict: + neuron_state_dict[new_scale_key] = neuron_state_dict.pop(old_scale_key) + + # --- Expand MoE blockwise scales along the TP-partitioned dim (FP8 only). --- + # NxDI's shard_checkpoint splits the scale on its partition dim into + # `per_partition_size = dim_size / moe_tp_degree`. When the per-rank + # MoE intermediate is smaller than the 128-wide blockwise scale block + # (moe_tp=64 on MiniMax-M2 gives per-rank IM=24, well below 128), several + # ranks share one scale block — we need to replicate scale entries along + # that dim. Adjacent ranks whose weight falls inside the same 128-wide + # block genuinely share that block's scale. No-op when the .scale keys + # are absent (BF16 path) or moe_tp is large enough (e.g. moe_tp=1). + if getattr(config.neuron_config, "quantized", False): + moe_tp = ( + getattr(config.neuron_config, "moe_tp_degree", None) + or config.neuron_config.tp_degree + ) + for layer_idx in range(config.num_hidden_layers): + # down_proj (RowParallel on intermediate dim). Scale: + # [E, I_blocks, H_blocks] + dp_key = ( + f"layers.{layer_idx}.block_sparse_moe.expert_mlps.mlp_op.down_proj.scale" + ) + if dp_key in neuron_state_dict: + s = neuron_state_dict[dp_key] + i_blocks = s.shape[1] + h_blocks = s.shape[2] + intermediate = i_blocks * 128 + i_per_rank = intermediate // moe_tp + if i_per_rank < 128: + ranks_per_block = 128 // i_per_rank + s_exp = s.unsqueeze(2).expand(-1, -1, ranks_per_block, -1) + s_exp = s_exp.reshape(s.shape[0], i_blocks * ranks_per_block, h_blocks) + assert s_exp.shape[1] == moe_tp, ( + f"down_proj.scale expansion produced {s_exp.shape[1]} rows, " + f"expected moe_tp={moe_tp}" + ) + neuron_state_dict[dp_key] = s_exp.contiguous() + + # gate_up_proj (ColumnParallel on 2*intermediate dim, gate|up fused + # along last axis). Scale: [E, H_blocks, 2*I_blocks] stored as + # [gate_half | up_half]. Module parameter has per-rank last-dim=1 + # (via _apply_blockwise_scale_stride_fix), so the full scale must + # have last-dim=moe_tp with gate entries 0..moe_tp/2 and up + # entries moe_tp/2..moe_tp. Expand each half independently. 
+ gu_key = ( + f"layers.{layer_idx}.block_sparse_moe.expert_mlps.mlp_op.gate_up_proj.scale" + ) + if gu_key in neuron_state_dict: + s = neuron_state_dict[gu_key] + h_blocks = s.shape[1] + two_i_blocks = s.shape[2] + assert two_i_blocks % 2 == 0, ( + f"gate_up_proj.scale last dim must be 2*i_blocks, got {two_i_blocks}" + ) + i_blocks = two_i_blocks // 2 + intermediate = i_blocks * 128 + out_per_rank = (2 * intermediate) // moe_tp + if out_per_rank < 128: + assert moe_tp % 2 == 0, ( + f"moe_tp={moe_tp} must be even for gate/up scale split" + ) + ranks_per_half = moe_tp // 2 + assert ranks_per_half % i_blocks == 0, ( + f"ranks_per_half={ranks_per_half} must be divisible by " + f"i_blocks={i_blocks}" + ) + ranks_per_block = ranks_per_half // i_blocks + gate_half = s[..., :i_blocks] + up_half = s[..., i_blocks:] + gate_exp = ( + gate_half.unsqueeze(-1) + .expand(-1, -1, -1, ranks_per_block) + .reshape(s.shape[0], h_blocks, ranks_per_half) + ) + up_exp = ( + up_half.unsqueeze(-1) + .expand(-1, -1, -1, ranks_per_block) + .reshape(s.shape[0], h_blocks, ranks_per_half) + ) + s_exp = torch.cat([gate_exp, up_exp], dim=-1) + assert s_exp.shape[-1] == moe_tp, ( + f"gate_up_proj.scale expansion produced {s_exp.shape[-1]} " + f"entries, expected moe_tp={moe_tp}" + ) + neuron_state_dict[gu_key] = s_exp.contiguous() + + return neuron_state_dict + + +# --------------------------------------------------------------------------- +# Inference config +# --------------------------------------------------------------------------- + + +class MiniMaxM2InferenceConfig(InferenceConfig): + """ + Inference configuration for MiniMax-M2. + + Extends InferenceConfig with MoE-specific setup: + - Sigmoid routing with FP32 router precision + - Intermediate-size padding for shard-on-I blockwise matmul + - Fused MoE NKI kernel enablement + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # MiniMax-M2 has no shared experts + self.n_shared_experts = 0 + + # Store MoE intermediate size before any padding + self.moe_intermediate_size = self.intermediate_size + + # Pad intermediate for shard-on-I compatibility + self.moe_intermediate_pad_size = 0 + self._maybe_pad_intermediate() + + # Enable fused MoE NKI kernels where dimensions allow + self._enable_moe_fused_nki_kernel() + + # Router config: MiniMax-M2 uses sigmoid routing with FP32 precision + self.neuron_config.router_config.dtype = torch.float32 + self.neuron_config.router_config.act_fn = "sigmoid" + + # MiniMax-M2 normalizes top-K affinities + self.neuron_config.normalize_top_k_affinities = True + + # Disable numeric CC token for MoE stability + self.neuron_config.disable_numeric_cc_token = True + + def _maybe_pad_intermediate(self): + """Pad intermediate_size so shard-on-I blockwise matmul kernels tile correctly.""" + moe_tp_degree = self.neuron_config.moe_tp_degree + i_tp = self.intermediate_size // moe_tp_degree + if getattr( + self.neuron_config.blockwise_matmul_config, + "use_shard_on_intermediate_dynamic_while", + False, + ): + if i_tp % SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP != 0: + padded = ( + math.ceil(i_tp / SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP) + * SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP + * moe_tp_degree + ) + self.moe_intermediate_pad_size = max(padded - self.intermediate_size, 0) + self.intermediate_size = padded + + def _enable_moe_fused_nki_kernel(self): + """Enable fused MoE NKI kernel if the per-TP intermediate dimension is aligned.""" + i_tp = self.intermediate_size // self.neuron_config.moe_tp_degree + if 
getattr(self.neuron_config, "moe_fused_nki_kernel_enabled", False):
+            if i_tp % MOE_TKG_MK_INTERMEDIATE_PER_TP == 0:
+                self.moe_fused_nki_kernel_enabled = True
+
+    def get_required_attributes(self) -> List[str]:
+        return [
+            "head_dim",
+            "hidden_act",
+            "hidden_size",
+            "intermediate_size",
+            "max_position_embeddings",
+            "num_attention_heads",
+            "num_experts_per_tok",
+            "num_hidden_layers",
+            "num_key_value_heads",
+            "num_local_experts",
+            "rms_norm_eps",
+            "rope_theta",
+            "tie_word_embeddings",
+            "vocab_size",
+            "use_qk_norm",
+            "rotary_dim",
+        ]
+
+    @classmethod
+    def get_neuron_config_cls(cls):
+        return MoENeuronConfig
+
+
+# ---------------------------------------------------------------------------
+# Attention
+# ---------------------------------------------------------------------------
+
+
+class NeuronMiniMaxM2Attention(NeuronAttentionBase):
+    """
+    MiniMax-M2 attention with two non-standard features:
+
+    1. QK normalization applied BEFORE reshape to per-head layout (on the full
+       Q/K projection output). Uses MiniMaxM2QKNorm, computed locally per rank
+       with no all-reduce.
+    2. Partial RoPE: rotary embeddings applied to only the first ``rotary_dim``
+       dimensions of each head (64 out of 128).
+    """
+
+    def __init__(self, config: MiniMaxM2InferenceConfig):
+        self.rotary_dim = getattr(config, "rotary_dim", config.head_dim)
+
+        # RotaryEmbedding sized to rotary_dim (64), not head_dim (128)
+        rotary_emb = RotaryEmbedding(
+            self.rotary_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+        )
+
+        super().__init__(
+            config=config,
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            num_key_value_heads=config.num_key_value_heads,
+            head_dim=config.head_dim,
+            rotary_emb=rotary_emb,
+            rms_norm_eps=config.rms_norm_eps,
+            use_qk_norm=False,  # handled by MiniMaxM2QKNorm below
+        )
+
+        # --- QK normalization (local per-rank, no all-reduce) ---
+        self.use_minimax_qk_norm = getattr(config, "use_qk_norm", True)
+        tp_degree = config.neuron_config.tp_degree
+
+        if self.use_minimax_qk_norm:
+            q_per_rank = self.num_heads * self.head_dim
+            k_per_rank = self.num_key_value_heads * self.head_dim
+
+            # Weight storage: padded to tp_degree * per_rank for SPMD selection
+            padded_q = self.num_heads * tp_degree * config.head_dim
+            padded_kv = self.num_key_value_heads * tp_degree
+            padded_k = padded_kv * config.head_dim
+
+            self.q_norm = MiniMaxM2QKNorm(
+                q_per_rank,
+                eps=config.rms_norm_eps,
+                tp_degree=tp_degree,
+                padded_hidden_size=padded_q,
+            )
+            self.k_norm = MiniMaxM2QKNorm(
+                k_per_rank,
+                eps=config.rms_norm_eps,
+                tp_degree=tp_degree,
+                padded_hidden_size=padded_k,
+            )
+
+        if not parallel_state.model_parallel_is_initialized():
+            raise ValueError(
+                "NeuronMiniMaxM2Attention requires an initialized distributed environment. "
+                "Use neuronx_distributed to initialize."
+ ) + + def prep_qkv_tensors( + self, + position_ids, + hidden_states, + past_key_value, + adapter_ids=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + skip_rope=False, + residual=None, + use_polar_compatible_rope=False, + ): + """Apply local QK norm on flat projection, reshape to heads, then partial RoPE.""" + Q, K, V, residual = self.get_qkv_proj()( + hidden_states=hidden_states, + rmsnorm=rmsnorm, + adapter_ids=adapter_ids, + residual=residual, + ) + + # QK norm on flat per-rank projection output BEFORE reshape (no all-reduce) + if self.use_minimax_qk_norm: + Q = self.q_norm(Q, self.rank_util) + K = self.k_norm(K, self.rank_util) + + bsz, q_len, _ = hidden_states.size() + if self.sequence_parallel_enabled: + q_len *= self.tensor_model_parallel_group.size() + + # Reshape to [B, S, num_heads, head_dim] then transpose to [B, H, S, D] + Q = ( + Q.view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + K = ( + K.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + V = ( + V.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + if not skip_rope: + Q, K, cos_cache, sin_cache = self.apply_rotary_embedding( + Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope + ) + + return Q, K, V, cos_cache, sin_cache, residual + + def apply_rotary_embedding( + self, Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope + ): + """Apply partial rotary embeddings (first rotary_dim dimensions only).""" + from neuronx_distributed_inference.modules.attention.utils import ( + apply_rotary_pos_emb, + ) + + if not use_polar_compatible_rope and self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + + if self.rotary_dim < self.head_dim: + Q_rot, Q_pass = Q[..., : self.rotary_dim], Q[..., self.rotary_dim :] + K_rot, K_pass = K[..., : self.rotary_dim], K[..., self.rotary_dim :] + Q_rot, K_rot = apply_rotary_pos_emb(Q_rot, K_rot, cos_cache, sin_cache) + Q = torch.cat([Q_rot, Q_pass], dim=-1) + K = torch.cat([K_rot, K_pass], dim=-1) + else: + Q, K = apply_rotary_pos_emb(Q, K, cos_cache, sin_cache) + + return Q, K, cos_cache, sin_cache + + def attention_block_tokengen_nki_kernel( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + active_mask=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + rotary_position_ids=None, + update_kv_per_layer=True, + active_block_table=None, + use_polar_compatible_rope=False, + ): + """ + Override base class to use nki-library attention_block_tkg kernel with + partial RoPE support (rotary_dim < head_dim). + + Uses the nki-library kernel instead of the compiler's private kernel. + QK norm is fused into the kernel via the flat QK RMSNorm feature, which + normalizes across all Q (or K) heads concatenated before head splitting. + """ + assert _HAS_NKILIB_ATTN_BLOCK, ( + "nki-library attention_block_tkg not available. " + "Install the nki-library fork with partial RoPE support." 
+ ) + + from neuronx_distributed.parallel_layers.mappings import ( + gather_from_sequence_parallel_region, + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, + gather_from_tensor_model_parallel_region_with_dim, + reduce_scatter_to_tensor_model_parallel_region_with_dim, + ) + from neuronx_distributed_inference.modules.attention.attention_base import ( + EPDispatchOption, + get_data_parallel_attention_dp_group, + ) + # NKI 0.3.0: use kernel[lnc_int] instead of kernel[(nc(lnc),)] + + if ( + self.sequence_parallel_enabled + and self.tensor_model_parallel_group is not None + ): + hidden_states = gather_from_sequence_parallel_region( + hidden_states, + self.sequence_dimension, + process_group=self.tensor_model_parallel_group, + ) + + # Get shapes + bsz, s_tkg, h = hidden_states.shape + h_out = h // 2 if self.is_eagle3_draft else h + num_q_heads = self.num_heads + + # Prepare rmsnorm params + rmsnorm_enabled = rmsnorm is not None + W_gamma = rmsnorm.weight.data.unsqueeze(0) if rmsnorm is not None else None + + # Prepare RoPE params + rope_contiguous_layout = not use_polar_compatible_rope + + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb( + hidden_states, rotary_position_ids + ) + # Take first half and reshape to [dim//2, batch_size, seq_len] + cos_cache = cos_cache[..., : cos_cache.shape[-1] // 2].permute(2, 0, 1) + sin_cache = sin_cache[..., : sin_cache.shape[-1] // 2].permute(2, 0, 1) + elif use_polar_compatible_rope: + from neuronx_distributed.modules.attention.utils import precompute_freqs_cis + + rotary_freqs = precompute_freqs_cis( + self.head_dim, + self.neuron_config.max_context_length * 2, + self.rope_theta, + self.use_scaled_rope, + device=hidden_states.device, + ) + rotary_freqs = rotary_freqs[position_ids] + cos_cache = rotary_freqs.cos().permute(2, 0, 1) + sin_cache = rotary_freqs.sin().permute(2, 0, 1) + else: + cos_cache = None + sin_cache = None + + # Prepare attention mask: merge active_mask and transpose for kernel layout + attention_mask = attention_mask.expand(-1, num_q_heads, -1, -1) + expected_active_mask_shape = (bsz, 1, s_tkg, s_tkg) + if s_tkg == 1: + active_mask = torch.ones( + expected_active_mask_shape, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + else: + assert active_mask.shape == expected_active_mask_shape, ( + f"{active_mask.shape} != {expected_active_mask_shape}" + ) + active_mask = active_mask.expand(-1, num_q_heads, -1, -1) + attention_mask[:, :, :, -s_tkg:] = active_mask + # Transpose to [S_ctx, B, q_heads, S_tkg] for nki-library kernel + attention_mask = attention_mask.permute(3, 0, 1, 2) + + # Prepare KV cache + K_prior, V_prior = past_key_value[:2] + K_prior = K_prior.data + V_prior = V_prior.data + update_cache_in_kernel = ( + update_kv_per_layer and self.attn_block_tkg_nki_kernel_cache_update + ) + sink = ( + self.get_learned_sinks().data.unsqueeze(-1) + if self.learned_sinks_size is not None + else None + ) + kv_cache_update_idx = position_ids[:, :1].to(torch.int32) + + # Prepare output projection + W_out = self.get_o_proj().o_proj.weight.data + if self.o_bias: + W_out_bias = ( + self.get_o_proj().o_proj.bias.data / self.tp_degree + ).unsqueeze(0) + else: + W_out_bias = None + + # Prepare QKV projection + W_qkv = self.get_qkv_proj().Wqkv.weight.data + bias_qkv = ( + self.get_qkv_proj().Wqkv.bias.data.unsqueeze(0) if self.qkv_bias else None + ) + + grid = self.logical_nc_config + + # Prepare flat QK norm weights (per-rank 
slice via SPMD rank selection) + # The kernel expects [1, per_rank_width] weights for each of Q and K. + flat_qk_norm_enabled = self.use_minimax_qk_norm + flat_qk_W_Q = None + flat_qk_W_K = None + if flat_qk_norm_enabled: + # Q norm: select per-rank slice from padded weight + q_norm_weight = self.q_norm.weight.data # [padded_q_hidden_size] + q_per_rank = self.q_norm.hidden_size + if self.q_norm.tp_degree > 1: + q_w_reshaped = q_norm_weight.view(self.q_norm.tp_degree, q_per_rank) + rank_index = self.rank_util.rank[:1] + flat_qk_W_Q = torch.index_select( + q_w_reshaped, 0, rank_index + ) # [1, q_per_rank] + else: + flat_qk_W_Q = q_norm_weight[:q_per_rank].unsqueeze(0) # [1, q_per_rank] + + # K norm: select per-rank slice from padded weight + k_norm_weight = self.k_norm.weight.data # [padded_k_hidden_size] + k_per_rank = self.k_norm.hidden_size + if self.k_norm.tp_degree > 1: + k_w_reshaped = k_norm_weight.view(self.k_norm.tp_degree, k_per_rank) + rank_index = self.rank_util.rank[:1] + flat_qk_W_K = torch.index_select( + k_w_reshaped, 0, rank_index + ) # [1, k_per_rank] + else: + flat_qk_W_K = k_norm_weight[:k_per_rank].unsqueeze(0) # [1, k_per_rank] + + attn_output, K, V = attention_block_tkg[grid]( + # -- input + X=hidden_states, + X_hidden_dim_actual=getattr(self.config, "original_hidden_size", None), + # -- rmsnorm X + rmsnorm_X_enabled=rmsnorm_enabled, + rmsnorm_X_eps=self.rms_norm_eps, + rmsnorm_X_gamma=W_gamma, + # -- qkv projections + W_qkv=W_qkv, + bias_qkv=bias_qkv, + quantization_type_qkv=NkilibQuantizationType.NONE, + weight_dequant_scale_qkv=None, + input_dequant_scale_qkv=None, + # -- Q/K processing: flat QK RMSNorm (before head split) + rmsnorm_QK_flat_enabled=flat_qk_norm_enabled, + rmsnorm_QK_flat_eps=self.rms_norm_eps if flat_qk_norm_enabled else 0.0, + rmsnorm_QK_flat_W_Q=flat_qk_W_Q, + rmsnorm_QK_flat_W_K=flat_qk_W_K, + # -- Q/K processing: per-head pre-RoPE RMSNorm (disabled) + rmsnorm_QK_pre_rope_enabled=False, + rmsnorm_QK_pre_rope_eps=0.0, + rmsnorm_QK_pre_rope_W_Q=None, + rmsnorm_QK_pre_rope_W_K=None, + # -- Q/K processing: RoPE with partial rotary_dim + cos=cos_cache, + sin=sin_cache, + rope_contiguous_layout=rope_contiguous_layout, + rotary_dim=self.rotary_dim, + # -- Q/K processing: post-RoPE RMSNorm (disabled) + rmsnorm_QK_post_rope_enabled=False, + rmsnorm_QK_post_rope_eps=0.0, + rmsnorm_QK_post_rope_W_Q=None, + rmsnorm_QK_post_rope_W_K=None, + # -- attention + K_cache_transposed=self.k_cache_transposed, + active_blocks_table=( + active_block_table.to(torch.uint32) + if active_block_table is not None + else None + ), + K_cache=K_prior, + V_cache=V_prior, + attention_mask=attention_mask, + sink=sink, + softmax_scale=None, + # -- KV cache update + update_cache=update_cache_in_kernel, + kv_cache_update_idx=kv_cache_update_idx, + # -- output projection + W_out=W_out, + bias_out=W_out_bias, + quantization_type_out=NkilibQuantizationType.NONE, + weight_dequant_scale_out=None, + input_dequant_scale_out=None, + transposed_out=False, + # -- output + out_in_sb=False, + ) + + # Reshape and reduce output + attn_output = attn_output.reshape((bsz, s_tkg, h_out)) + if self.sequence_parallel_enabled: + attn_output = reduce_scatter_to_sequence_parallel_region( + attn_output, 1, process_group=self.tensor_model_parallel_group + ) + else: + if self.ep_dispatch_cc_option == EPDispatchOption.AR_AG: + attn_output = reduce_from_tensor_model_parallel_region( + attn_output, process_group=self.tensor_model_parallel_group + ) + elif self.ep_dispatch_cc_option == EPDispatchOption.RS_AG: + 
attn_output = reduce_scatter_to_tensor_model_parallel_region_with_dim( + attn_output, + partition_dim=0, + process_group=self.tensor_model_parallel_group, + ) + elif self.ep_dispatch_cc_option == EPDispatchOption.AG_AR: + attn_output = gather_from_tensor_model_parallel_region_with_dim( + attn_output, + gather_dim=0, + process_group=get_data_parallel_attention_dp_group(), + ) + else: + raise ValueError( + f"Unknown EPDispatchOption: {self.ep_dispatch_cc_option}" + ) + + # KV cache handling + if update_cache_in_kernel: + KV = past_key_value + else: + # Reshape K/V from kernel output layout to the rank-4 [B, N, S, D] + # layout expected by kv_cache_manager.update_kv_by_layer_id. + # K from kernel: [head_dim, bsz, q_len] (dBS) + # V from kernel: [bsz, q_len, head_dim] (BSd) + # Target: [B, 1, S, D] (BNSd) or [B, 1, D, S] (BNdS) for transposed K + K = K.permute(1, 0, 2) if self.k_cache_transposed else K.permute(1, 2, 0) + K = K.unsqueeze(1) + V = V.unsqueeze(1) + KV = (K, V) + + return attn_output, KV, cos_cache, sin_cache + + +class NeuronMiniMaxM2DecoderLayer(nn.Module): + """MiniMax-M2 decoder layer: attention + MoE with ModuleMarker wrappers.""" + + def __init__(self, config: MiniMaxM2InferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = NeuronMiniMaxM2Attention(config=config) + self.moe_fused_nki_kernel_enabled = getattr( + config, "moe_fused_nki_kernel_enabled", False + ) + + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + + # Fused MoE kernel absorbs post-attention layernorm + if self.moe_fused_nki_kernel_enabled: + self.block_sparse_moe = initialize_minimax_m2_moe_module( + config=config, + rmsnorm=self.post_attention_layernorm, + init_tkg_module=True, + ) + else: + self.block_sparse_moe = initialize_minimax_m2_moe_module(config=config) + + self.qkv_kernel_enabled = config.neuron_config.qkv_kernel_enabled + self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + self.qkv_kernel_fused_rmsnorm = not self.sequence_parallel_enabled + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, ...]: + residual = hidden_states + hidden_states = ModuleMarkerStartWrapper()(hidden_states) + + qkv_fused_rmsnorm = None + if self.input_layernorm: + if self.qkv_kernel_enabled and self.qkv_kernel_fused_rmsnorm: + qkv_fused_rmsnorm = self.input_layernorm + else: + hidden_states = self.input_layernorm(hidden_states) + + # Self-attention + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + rmsnorm=qkv_fused_rmsnorm, + **kwargs, + ) + hidden_states = residual + hidden_states + + # MoE + residual = hidden_states + if not self.moe_fused_nki_kernel_enabled: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states, padding_mask)[0] + hidden_states = residual + hidden_states + + hidden_states = ModuleMarkerEndWrapper()(hidden_states) + return (hidden_states, present_key_value, cos_cache, sin_cache, None) + + +# 
--------------------------------------------------------------------------- +# Model +# --------------------------------------------------------------------------- + + +class NeuronMiniMaxM2Model(NeuronBaseModel): + """Traceable MiniMax-M2 base model.""" + + def setup_attr_for_model(self, config: MiniMaxM2InferenceConfig): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: MiniMaxM2InferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + self.layers = nn.ModuleList( + [ + NeuronMiniMaxM2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = get_rmsnorm_cls()(self.hidden_size, eps=config.rms_norm_eps) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + ) + + +# --------------------------------------------------------------------------- +# CausalLM wrapper +# --------------------------------------------------------------------------- + + +class NeuronMiniMaxM2ForCausalLM(NeuronBaseForCausalLM): + """MiniMax-M2 causal language model for NxDI inference.""" + + _model_cls = NeuronMiniMaxM2Model + + def __init__(self, *args, **kwargs): + # Install FP8 monkey-patches BEFORE super().__init__ so the patched + # quantization-layer classes are in effect when NxDI builds the + # decoder. Gated on quantized=True so the BF16 path is untouched. + ncfg = kwargs.get("config") or (args[1] if len(args) > 1 else None) + if ncfg is not None and getattr(getattr(ncfg, "neuron_config", None), "quantized", False): + self._apply_ep_scale_fix() + self._apply_blockwise_scale_stride_fix() + self._apply_2d_per_channel_fix() + super().__init__(*args, **kwargs) + + @staticmethod + def load_hf_model(model_path, **kwargs): + return None + + @classmethod + def get_config_cls(cls): + return MiniMaxM2InferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, config: MiniMaxM2InferenceConfig + ) -> dict: + return convert_minimax_m2_hf_to_neuron_state_dict(state_dict, config) + + # ------------------------------------------------------------------ + # FP8 quantized-inference monkey-patches (no-op unless quantized=True). + # + # Reconcile the preprocessed Neuron-FP8 checkpoint (blockwise-MoE + + # per-row-attn) with NxDI's global blockwise_symmetric q_config. Ported + # from the MiMo-V2-Flash FP8 enablement work (same MoE block-size math, + # same Quantized{Column,Row}Parallel issues). All three are gated by + # self.neuron_config.quantized so the BF16 path is completely untouched. 
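+    #
+    # Dequant layout the patches reconcile (a sketch of standard symmetric
+    # FP8 dequantization using the shapes documented in the patch docstrings
+    # below, not NxDI's actual kernels):
+    #   2D attention q/k/v : w_bf16 = w_fp8.to(torch.bfloat16) * scale  # scale: [out, 1]
+    #   3D MoE experts     : each 128x128 block of w_fp8 is multiplied by its
+    #                        own scalar entry of the blockwise scale tensor.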
+ # ------------------------------------------------------------------ + + @staticmethod + def _apply_ep_scale_fix(): + """Skip per-channel `scale` params when marking expert-parallel + weights; they have shape [1, 1, W] and cannot be EP-sharded.""" + from neuronx_distributed.modules.moe.moe_parallel_layers import ( + ExpertFusedLinear, + ) + + if getattr(ExpertFusedLinear, "_minimax_m2_ep_scale_patched", False): + return + + def _patched_mark( + self_inner, + iterable=None, + expert_parallel_group_size=None, + is_prefill=True, + expert_distribution=None, + ): + from neuronx_distributed.parallel_layers.parallel_state import ( + get_expert_model_parallel_size, + ) + + if expert_parallel_group_size is None: + expert_parallel_group_size = get_expert_model_parallel_size() + + if expert_parallel_group_size > 1: + if iterable is None: + params_to_mark = [] + for name, p in self_inner.named_parameters(): + if name == "scale" and p.shape[0] == 1: + continue + params_to_mark.append(p) + iterable = params_to_mark + + for p in iterable: + p.expert_model_parallel = True + if is_prefill: + p.is_prefill = True + p.expert_distribution = expert_distribution + + ExpertFusedLinear._mark_expert_parallel_weights = _patched_mark + ExpertFusedLinear._minimax_m2_ep_scale_patched = True + + @staticmethod + def _apply_blockwise_scale_stride_fix(): + """Force scale.partition_stride=1 for BLOCKWISE_SYMMETRIC quantization + — stride>1 causes strided-splitting failures when per-rank weight + size is smaller than a block.""" + from neuronx_distributed.quantization.quantization_config import ( + QuantizationType, + ) + from neuronx_distributed.quantization.quantization_layers import ( + BaseQuantizeParallelLinear, + ) + + if getattr(BaseQuantizeParallelLinear, "_minimax_m2_blockwise_stride_patched", False): + return + + _original_setup = BaseQuantizeParallelLinear._setup_for_scale + + def _patched_setup(self_inner, *args, **kwargs): + _original_setup(self_inner, *args, **kwargs) + if ( + hasattr(self_inner, "quantization_type") + and self_inner.quantization_type == QuantizationType.BLOCKWISE_SYMMETRIC + and hasattr(self_inner, "scale") + and hasattr(self_inner.scale, "partition_stride") + and self_inner.scale.partition_stride > 1 + ): + self_inner.scale.partition_stride = 1 + + BaseQuantizeParallelLinear._setup_for_scale = _patched_setup + BaseQuantizeParallelLinear._minimax_m2_blockwise_stride_patched = True + + @staticmethod + def _apply_2d_per_channel_fix(): + """Route 2D self_attn swaps through per_channel_symmetric. + + MiniMax-M2's preprocess writes: + - MoE experts: 3D weights with (E, out//128, in//128) blockwise scales. + - self_attn q/k/v/o: 2D weights with (out, 1) per-row scales + (at TP=64 each rank's out-dim is <128, so blockwise scale + would collapse to a singleton; per-row avoids that). + + NxDI's q_config is global blockwise_symmetric (for the MoE). Feeding + that into the 2D classes triggers `block axis cannot be < 0 or > 2, + received 2` in _setup_for_scale (block axes [1, 2] exceed rank-2 + weight_shape). This wraps the 2D classes' from_float to override + q_config on the fly: flip quantization_type to per_channel_symmetric, + drop block_axis / block_size, force quantization_per_channel_axis=0. + MoE classes are untouched. 
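+
+        Sketch of the effective q_config rewrite performed below (dict-style
+        q_config, keys as consumed by from_float; values elided):
+
+            {"quantization_type": BLOCKWISE_SYMMETRIC,
+             "block_axis": ..., "block_size": ..., ...}
+            becomes
+            {"quantization_type": PER_CHANNEL_SYMMETRIC,
+             "quantization_per_channel_axis": 0, ...}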
+ """ + from neuronx_distributed.quantization.quantization_config import ( + QuantizationType, + ) + from neuronx_distributed.quantization.quantization_layers import ( + QuantizedColumnParallel, + QuantizedRowParallel, + ) + + def _wrap(cls): + if getattr(cls, "_minimax_m2_2d_patched", False): + return + original_from_float = cls.from_float + + def _patched_from_float(klass, mod, q_config=None, _orig=original_from_float): + if q_config is not None and q_config.get("quantization_type") == \ + QuantizationType.BLOCKWISE_SYMMETRIC: + q_config = dict(q_config) + q_config["quantization_type"] = QuantizationType.PER_CHANNEL_SYMMETRIC + q_config["quantization_per_channel_axis"] = 0 + q_config.pop("block_axis", None) + q_config.pop("block_size", None) + if q_config is None: + return _orig(mod) + return _orig(mod, q_config) + + cls.from_float = classmethod(_patched_from_float) + cls._minimax_m2_2d_patched = True + + _wrap(QuantizedColumnParallel) + _wrap(QuantizedRowParallel) + + def _install_fp8_patches(self): + """Install all FP8-specific runtime patches. No-op for BF16.""" + if not getattr(self.neuron_config, "quantized", False): + return + self._apply_ep_scale_fix() + self._apply_blockwise_scale_stride_fix() + self._apply_2d_per_channel_fix() + + def compile(self, *args, **kwargs): + # save_sharded_checkpoint=True serializes shards during compile() and + # that code path reads scale.partition_stride — patches must be live. + self._install_fp8_patches() + return super().compile(*args, **kwargs) + + def load(self, *args, **kwargs): + self._install_fp8_patches() + return super().load(*args, **kwargs) + + @classmethod + def save_quantized_state_dict(cls, model_path, config): + """MiniMax-M2 ships pre-quantized FP8 safetensors via our preprocess + script. The base implementation calls AutoModelForCausalLM.from_pretrained + to re-quantize, which requires a CUDA GPU (HF's finegrained_fp8 + quantizer is gated on CUDA) and materializes a ~600 GB BF16 copy. + Skip if the checkpoint directory already contains a Neuron-FP8 + index produced by the preprocess script.""" + import os as _os + qpath = ( + getattr(config.neuron_config, "quantized_checkpoints_path", None) + or model_path + ) + if qpath and _os.path.isdir(qpath): + index = _os.path.join(qpath, "model.safetensors.index.json") + if _os.path.isfile(index): + return + return super().save_quantized_state_dict(model_path, config) + + def enable_context_encoding(self): + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def enable_token_generation(self): + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + super().enable_token_generation() + + def get_compiler_args(self): + """Compiler arguments tuned for MiniMax-M2 MoE. + + Uses -O1 by default. -O2 was tested but provides no scratchpad memory + savings vs -O1 (identical 22 GB tensor allocation at 62 layers TP=32). 
+ """ + if self.compile_tag == TOKEN_GENERATION_MODEL_TAG: + opt_level = "-O1" + else: + opt_level = "-O1" + + args = f"--enable-saturate-infinity --enable-mixed-precision-accumulation --model-type transformer {opt_level}" + args += ( + " --tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2'" + ) + args += " --auto-cast=none" + args += " --internal-enable-dge-levels vector_dynamic_offsets" + args += " --internal-hlo2tensorizer-options='--verify-hlo=true'" + + if self.neuron_config.scratchpad_page_size: + args += ( + f" --hbm-scratchpad-page-size={self.neuron_config.scratchpad_page_size}" + ) + + if self.neuron_config.attn_block_tkg_nki_kernel_enabled: + assert self.neuron_config.attn_block_tkg_nki_kernel_cascaded_attention, ( + "attn_block_tkg_nki_kernel_enabled requires attn_block_tkg_nki_kernel_cascaded_attention" + ) + self.neuron_config.pre_rope_rmsnorm = True + args += " --internal-max-instruction-limit=15000000" + + return args + + @classmethod + def get_state_dict(cls, model_name_or_path: str, config: InferenceConfig) -> dict: + """Load and convert state dict from a HuggingFace safetensors checkpoint.""" + import json + import os + + from safetensors import safe_open + + if os.path.isdir(model_name_or_path): + index_path = os.path.join( + model_name_or_path, "model.safetensors.index.json" + ) + if os.path.exists(index_path): + with open(index_path, "r") as f: + index = json.load(f) + + model_sd: Dict[str, Any] = {} + shard_files = sorted(set(index["weight_map"].values())) + for i, shard_file in enumerate(shard_files): + if i % 20 == 0: + print( + f" Loading shard {i + 1}/{len(shard_files)}: {shard_file}" + ) + shard_path = os.path.join(model_name_or_path, shard_file) + with safe_open(shard_path, framework="pt", device="cpu") as f: + for key in f.keys(): + model_sd[key] = f.get_tensor(key) + + print( + f" Loaded {len(model_sd)} parameters from {len(shard_files)} shards" + ) + + # Strip model. 
prefix + for param_name in list(model_sd.keys()): + if param_name.startswith(cls._STATE_DICT_MODEL_PREFIX): + new_name = param_name.replace( + cls._STATE_DICT_MODEL_PREFIX, + cls._NEW_STATE_DICT_MODEL_PREFIX, + 1, + ) + model_sd[new_name] = model_sd.pop(param_name) + + model_sd = cls.convert_hf_to_neuron_state_dict(model_sd, config) + + if getattr(config, "tie_word_embeddings", False): + cls.update_state_dict_for_tied_weights(model_sd) + + if cls._FUSED_PREFIX: + for param_name in list(model_sd.keys()): + model_sd[f"{cls._FUSED_PREFIX}.{param_name}"] = model_sd.pop( + param_name + ) + + return model_sd + else: + from neuronx_distributed_inference.modules.checkpoint import ( + load_state_dict, + ) + + return load_state_dict(model_name_or_path) + else: + return super().get_state_dict(model_name_or_path, config) diff --git a/contrib/models/MiniMax-M2/test/__init__.py b/contrib/models/MiniMax-M2/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiniMax-M2/test/integration/__init__.py b/contrib/models/MiniMax-M2/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiniMax-M2/test/integration/test_model.py b/contrib/models/MiniMax-M2/test/integration/test_model.py new file mode 100644 index 00000000..7dc97ad3 --- /dev/null +++ b/contrib/models/MiniMax-M2/test/integration/test_model.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +"""Integration tests for MiniMax M2 NeuronX implementation.""" + +import pytest +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) + + +def test_config_import(): + """Test that config class can be imported.""" + from modeling_minimax_m2 import MiniMaxM2InferenceConfig, NeuronMiniMaxM2ForCausalLM + assert MiniMaxM2InferenceConfig is not None + assert NeuronMiniMaxM2ForCausalLM is not None + print("PASS: Config and model classes imported successfully") + + +def test_required_attributes(): + """Test that required attributes are defined.""" + from modeling_minimax_m2 import MiniMaxM2InferenceConfig + from neuronx_distributed_inference.models.config import MoENeuronConfig + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + from transformers import AutoConfig + import torch + + neuron_config = MoENeuronConfig( + tp_degree=64, + batch_size=1, + seq_len=512, + torch_dtype=torch.bfloat16, + on_cpu=True, + ) + # Use the bundled config.json to provide model-specific attributes + config_path = Path(__file__).resolve().parents[2] / "src" + hf_config = AutoConfig.from_pretrained(str(config_path), trust_remote_code=True) + config = MiniMaxM2InferenceConfig(neuron_config, load_config=load_pretrained_config(hf_config=hf_config)) + required = config.get_required_attributes() + assert "hidden_size" in required + assert "num_local_experts" in required + assert "num_experts_per_tok" in required + print(f"PASS: {len(required)} required attributes defined") + + +def test_neuron_config_cls(): + """Test that MoENeuronConfig is returned.""" + from modeling_minimax_m2 import MiniMaxM2InferenceConfig + from neuronx_distributed_inference.models.config import MoENeuronConfig + assert MiniMaxM2InferenceConfig.get_neuron_config_cls() == MoENeuronConfig + print("PASS: MoENeuronConfig returned") + + +if __name__ == "__main__": + test_config_import() + test_required_attributes() + test_neuron_config_cls() + print("\nAll tests passed!") diff --git a/contrib/models/MiniMax-M2/test/unit/__init__.py 
b/contrib/models/MiniMax-M2/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b