diff --git a/contrib/models/Ministral-3-14B-Instruct-2512/README.md b/contrib/models/Ministral-3-14B-Instruct-2512/README.md
new file mode 100644
index 00000000..7bdfe700
--- /dev/null
+++ b/contrib/models/Ministral-3-14B-Instruct-2512/README.md
@@ -0,0 +1,307 @@
+# Contrib Model: Ministral-3-14B-Instruct-2512 (Leanstral)
+
+NeuronX Distributed Inference contrib for Ministral-3-14B-Instruct-2512 on AWS Trainium2.
+This model uses Mistral's 14B dense GQA text decoder with 8 KV heads, served via the
+LlamaForCausalLM code path in NxDI with custom NKI kernels for multi-KV-head attention.
+
+## Model Information
+
+- **HuggingFace ID:** `mistralai/Ministral-3-14B-Instruct-2512`
+- **Architecture:** Dense GQA (runs as LlamaForCausalLM via hf-overrides)
+- **Parameters:** 14B (40 layers, hidden=5120, 32 Q / 8 KV heads, d_head=128)
+- **Vocab:** 32768 (text-only extraction from VL checkpoint)
+- **License:** Check HuggingFace model card (gated access)
+
+## Architecture Details
+
+- 40 layers, hidden\_size=5120, intermediate\_size=16384
+- num\_attention\_heads=32, num\_kv\_heads=8, head\_dim=128, rope\_theta=1e9
+- At TP=4: q\_heads\_per\_rank=8, kv\_heads\_per\_rank=2
+- Original checkpoint is FP8 E4M3 — dequantized to BF16 via `extract_text_model.py`
+
+### Key Adaptations for SDK 2.29
+
+1. **LlamaForCausalLM code path**: vLLM 0.16 auto-promotes MistralForCausalLM to Pixtral.
+   We use `--hf-overrides '{"architectures": ["LlamaForCausalLM"], "model_type": "llama"}'`
+   to force the Llama code path, which handles the GQA sharding natively.
+
+2. **QKV NKI kernel (recommended)**: The standard NxDI QKV NKI kernel provides the best
+   per-request and aggregate throughput. No additional kernel patches are needed beyond
+   the QKV fixes in `setup_patches.py`.
+
+3. **FP8→BF16 text extraction**: The HuggingFace checkpoint is a VL model with FP8 weights.
+   `extract_text_model.py` strips vision keys, dequantizes FP8→BF16, fixes tokenizer issues,
+   and writes a clean text-only checkpoint. (A sketch of the dequantization step follows
+   this list.)
+
+4. **Multi-KV-head TKG kernel**: Modified `attention_block_tkg` kernel supporting
+   kv\_heads\_per\_rank > 1 via a virtual-batch approach. With LNC=2 (default on trn2.3xlarge),
+   the TKG kernel **matches baseline TPOT at BS=4** and adds only 5-9% overhead at BS=8.
+   See [TKG Kernel Results](#tkg-kernel-results) for details.
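+
+The dequantization in item 3 is a per-tensor rescale. A minimal sketch of the idea, assuming
+per-tensor `_scale` companion keys as in common FP8 checkpoints (key names here are
+illustrative; `extract_text_model.py` implements the real logic):
+
+```python
+# Sketch: dequantize FP8-E4M3 weights to BF16 and strip vision/VL artifacts.
+import torch
+from safetensors.torch import load_file, save_file
+
+state = load_file("shard.safetensors")
+out = {}
+for name, t in state.items():
+    if name.startswith("vision_tower."):  # drop vision keys (illustrative prefix)
+        continue
+    if name.endswith("_scale"):  # scales are consumed below, not kept
+        continue
+    if t.dtype == torch.float8_e4m3fn:
+        scale = state.get(name + "_scale")  # hypothetical per-tensor scale key
+        if scale is not None:
+            t = (t.to(torch.float32) * scale.to(torch.float32)).to(torch.bfloat16)
+    out[name.removeprefix("language_model.")] = t  # strip the VL prefix
+save_file(out, "shard-bf16.safetensors")
+```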
+ +## Prerequisites + +- **SDK 2.29** (neuronx-cc >= 2.24, neuronx-distributed-inference >= 0.9, vLLM 0.16 + vllm-neuron 0.5) +- **trn2.3xlarge** (TP=4, LNC=2, 96 GB HBM) +- **Model checkpoint**: `mistralai/Ministral-3-14B-Instruct-2512` from HuggingFace (gated) +- **Disk**: ~300 GB EBS for checkpoint + compiled model artifacts + +### Environment Setup + +```bash +# Activate pre-installed vLLM 0.16 environment (SDK 2.29) +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate + +# Install aiohttp for benchmark script +pip install aiohttp +``` + +## Quick Start + +### Step 1: Download Model + +```bash +huggingface-cli download mistralai/Ministral-3-14B-Instruct-2512 \ + --local-dir /home/ubuntu/models/Ministral-3-14B-Instruct-2512 +``` + +### Step 2: Extract Text-Only BF16 Checkpoint + +```bash +python src/extract_text_model.py \ + --input-dir /home/ubuntu/models/Ministral-3-14B-Instruct-2512 \ + --output-dir /home/ubuntu/models/Ministral-3-14B-text-bf16 +``` + +This produces a ~27 GB checkpoint with: +- 6 safetensors shards (BF16, vision keys removed, `language_model.` prefix stripped) +- Fixed `tokenizer_config.json` (removes Pixtral processor references) +- Proper `config.json` for LlamaForCausalLM + +### Step 3: Apply Runtime Patches + +```bash +python src/setup_patches.py +``` + +Applies 6 patches to the installed NxDI/nkilib packages: +1. `rms_norm_eps` pass-through in model base +2. nkilib QKV kernel epsilon guard +3. neuronxcc QKV kernel epsilon guard +4. `convert_state_dict_to_fused_qkv` fix for non-standard head counts +5. Fused RMSNorm config support +6. Multi-KV TKG kernel + NKI 0.3.0 V cache fix + attention adapter + +### Step 4: Launch vLLM Server + +```bash +export NEURON_CC_FLAGS="--auto-cast=matmult" + +python -m vllm.entrypoints.openai.api_server \ + --model /home/ubuntu/models/Ministral-3-14B-text-bf16 \ + --tensor-parallel-size 4 \ + --max-model-len 4096 \ + --max-num-seqs 4 \ + --block-size 8 \ + --no-enable-prefix-caching \ + --port 8000 \ + --hf-overrides '{"architectures": ["LlamaForCausalLM"], "model_type": "llama"}' \ + --additional-config '{"override_neuron_config": {"fused_qkv": true, "qkv_nki_kernel_enabled": true, "qkv_kernel_enabled": true}}' +``` + +First launch compiles the model (~5 minutes). Subsequent launches use the NCC cache. + +For higher batch sizes (BS=8), change `--max-num-seqs 8`. 
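+
+If you script the launch, you can poll the server until compilation finishes before sending
+traffic. A small helper (assuming the default port from the command above):
+
+```python
+# Wait until the OpenAI-compatible endpoint answers before benchmarking.
+import time
+import urllib.request
+
+def wait_for_server(url="http://localhost:8000/v1/models", timeout_s=900):
+    deadline = time.time() + timeout_s
+    while time.time() < deadline:
+        try:
+            with urllib.request.urlopen(url, timeout=5) as resp:
+                if resp.status == 200:
+                    return True
+        except OSError:
+            pass  # not listening yet (first launch compiles for ~5 minutes)
+        time.sleep(5)
+    return False
+
+assert wait_for_server(), "vLLM server did not come up in time"
+```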
+
+### Step 5: Query
+
+```bash
+curl -s http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{"model": "/home/ubuntu/models/Ministral-3-14B-text-bf16",
+       "messages": [{"role": "user", "content": "What is the capital of France?"}],
+       "max_tokens": 256}'
+```
+
+Without `--served-model-name`, the `model` field must match the path passed to `--model`
+at launch.
+
+## Performance Results
+
+Measured on trn2.3xlarge (TP=4, LNC=2, SDK 2.29) via vLLM 0.16:
+
+### vLLM Serving — Baseline (QKV NKI Kernel, Recommended)
+
+**BS=4 (`--max-num-seqs 4`):**
+
+| Workload | Conc | TTFT P50 (ms) | tok/s P50 | TPOT P50 (ms) | E2E P50 (ms) |
+|----------|------|---------------|-----------|---------------|--------------|
+| short-short (128/128) | 1 | 100.6 | 63.3 | 15.8 | 2106.9 |
+| short-short (128/128) | 4 | 200.0 | 58.5 | 15.9 | 2465.7 |
+| short-long (128/512) | 1 | 101.4 | 62.8 | 15.9 | 8234.2 |
+| short-long (128/512) | 4 | 200.0 | 62.3 | 15.9 | 8498.9 |
+| long-short (2048/128) | 1 | 303.6 | 57.4 | 17.4 | 2514.7 |
+| long-short (2048/128) | 4 | 609.2 | 50.9 | 17.3 | 3400.3 |
+| long-long (2048/512) | 1 | 304.0 | 57.7 | 17.3 | 9156.5 |
+| long-long (2048/512) | 4 | 608.9 | 55.9 | 17.3 | 10053.9 |
+
+**BS=8 (`--max-num-seqs 8`):**
+
+| Workload | Conc | TTFT P50 (ms) | tok/s P50 | TPOT P50 (ms) | E2E P50 (ms) |
+|----------|------|---------------|-----------|---------------|--------------|
+| short-short (128/128) | 1 | 102.7 | 59.7 | 16.7 | 2229.5 |
+| short-short (128/128) | 4 | 201.4 | 57.0 | 16.8 | 2519.0 |
+| short-short (128/128) | 8 | 348.6 | 53.6 | 16.8 | 2907.6 |
+| long-long (2048/512) | 1 | 306.8 | 53.5 | 18.7 | 9864.4 |
+| long-long (2048/512) | 4 | 613.1 | 51.5 | 18.8 | 10813.4 |
+| long-long (2048/512) | 8 | 1069.3 | 49.5 | 18.7 | 11982.9 |
+
+**Notes:**
+- The tok/s column reports **per-request** throughput. At conc=N, the **aggregate system
+  throughput** is ~Nx higher (e.g., TPOT=15.9ms → 4/0.0159 = **252 tok/s aggregate** at BS=4).
+- TPOT scales gracefully with batch size: 15.8ms (BS=4) → 16.7ms (BS=8) = only a 6% increase.
+
+### Aggregate Throughput Comparison
+
+| SDK | Config | Aggregate tok/s (BS=4) | Per-request tok/s (BS=1) |
+|-----|--------|----------------------|------------------------|
+| **2.29** | QKV NKI kernel (baseline) | **~252** (4/TPOT) | 63.3 |
+| **2.29** | TKG kernel (LNC=2) | **~255** (4/TPOT) | 63.5 |
+| 2.28 | Fused QKV+TKG | 213.7 | 71.0 |
+| GPU | H100 FP8 | 140.3 | — |
+
+At BS=4, both baseline and TKG configs **exceed H100 by 1.8x**. The TKG kernel with LNC=2
+matches baseline throughput while fusing the KV cache update into the attention kernel.
+
+## TKG Kernel Results
+
+The multi-KV-head TKG (token generation) kernel fuses attention computation and KV cache
+update into a single NKI kernel. It uses a virtual-batch approach
+(`B_virt = batch_size * kv_heads_per_rank`) to handle GQA models with kv_heads_per_rank > 1.
+
+With **LNC=2** (default on trn2.3xlarge, grid=2), the TKG kernel shards the virtual-batch
+computation across the 2 physical NeuronCores that back each logical core, effectively
+halving the per-core HBM load overhead. This eliminates the performance regression seen
+with LNC=1 (grid=1).
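+
+Concretely, each (batch, KV head) pair becomes its own entry in the attention call. A small
+illustration of the index mapping (plain Python; it mirrors the layout comments in
+`src/attention_block_tkg_multi_kv.py`):
+
+```python
+# Virtual-batch mapping used by the multi-KV TKG kernel (one TP=4 rank).
+B, kv_heads, q_heads = 4, 2, 8        # batch, KV heads/rank, Q heads/rank
+q_per_kv_group = q_heads // kv_heads  # 4 Q heads share each KV head
+
+for b in range(B):
+    for g in range(kv_heads):
+        vb = b * kv_heads + g         # virtual batch index; B_virt = B * kv_heads
+        group = list(range(g * q_per_kv_group, (g + 1) * q_per_kv_group))
+        print(f"virtual batch {vb}: batch {b}, KV head {g}, Q heads {group}")
+```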
+ +### TKG Launch Command + +To enable TKG, add two flags to `--additional-config`: + +```bash +python -m vllm.entrypoints.openai.api_server \ + --model /home/ubuntu/models/Ministral-3-14B-text-bf16 \ + --tensor-parallel-size 4 \ + --max-model-len 4096 \ + --max-num-seqs 4 \ + --block-size 8 \ + --no-enable-prefix-caching \ + --port 8000 \ + --hf-overrides '{"architectures": ["LlamaForCausalLM"], "model_type": "llama"}' \ + --additional-config '{"override_neuron_config": {"fused_qkv": true, "qkv_nki_kernel_enabled": true, "qkv_kernel_enabled": true, "attn_block_tkg_nki_kernel_enabled": true, "attn_block_tkg_nki_kernel_cache_update": true}}' +``` + +### TKG Performance (LNC=2, SDK 2.29) + +**BS=4 TKG (`--max-num-seqs 4`):** + +| Workload | Conc | TTFT P50 (ms) | tok/s P50 | TPOT P50 (ms) | E2E P50 (ms) | +|----------|------|---------------|-----------|---------------|--------------| +| short-short (128/128) | 1 | 101.9 | 63.5 | 15.7 | 2101.5 | +| short-short (128/128) | 4 | 247.1 | 59.4 | 15.7 | 2380.8 | +| short-long (128/512) | 1 | 102.2 | 63.4 | 15.8 | 8166.4 | +| short-long (128/512) | 4 | 249.7 | 61.0 | 16.1 | 8661.1 | +| long-short (2048/128) | 1 | 304.5 | 56.6 | 17.7 | 2550.1 | +| long-short (2048/128) | 4 | 754.9 | 48.6 | 17.1 | 3362.8 | +| long-long (2048/512) | 1 | 303.3 | 58.8 | 17.0 | 8992.8 | +| long-long (2048/512) | 4 | 753.1 | 56.3 | 16.9 | 9832.8 | + +**BS=8 TKG (`--max-num-seqs 8`):** + +| Workload | Conc | TTFT P50 (ms) | tok/s P50 | TPOT P50 (ms) | E2E P50 (ms) | +|----------|------|---------------|-----------|---------------|--------------| +| short-short (128/128) | 1 | 101.6 | 57.3 | 17.5 | 2317.7 | +| short-short (128/128) | 4 | 246.6 | 53.7 | 17.5 | 2613.6 | +| short-short (128/128) | 8 | 390.3 | 50.8 | 17.5 | 2981.1 | +| long-long (2048/512) | 1 | 301.9 | 49.6 | 20.1 | 10612.0 | +| long-long (2048/512) | 4 | 750.2 | 47.3 | 20.3 | 11560.2 | +| long-long (2048/512) | 8 | 1197.7 | 45.5 | 20.3 | 12705.5 | + +### TKG vs Baseline TPOT Comparison + +With LNC=2, the TKG overhead is minimal at BS=4 and modest at BS=8: + +| Batch Size | Workload | Baseline TPOT | TKG LNC=2 TPOT | Overhead | +|-----------|----------|---------------|-----------------|----------| +| BS=4 | short-short | 15.8ms | 15.7ms | **-0.6%** | +| BS=4 | long-long | 17.3ms | 17.0ms | **-1.7%** | +| BS=8 | short-short | 16.7ms | 17.5ms | **+4.8%** | +| BS=8 | long-long | 18.7ms | 20.1ms | **+7.5%** | + +For comparison, with LNC=1 (grid=1) the overhead was 11% at BS=4 and 27% at BS=8. +LNC=2 sharding distributes the virtual-batch HBM loads across 2 NeuronCores, halving +the per-core overhead. + +### Design Notes + +The TKG kernel uses a "virtual batch" expansion where each KV head group becomes a separate +batch entry (`B_virt = BS * kv_heads_per_rank`). The inner attention kernel's batch loops +(`_compute_qk_matmul`, `_compute_pv_matmul_and_store`) iterate over `B_virt` entries, each +loading a KV cache slice from HBM. With LNC=2, the `grid=2` NKI sharding distributes these +iterations across 2 physical NeuronCores, effectively halving the HBM bandwidth pressure. + +The kernel and adapter are included as a reference implementation for upstream multi-KV-head +TKG support in nkilib. + +## Known Limitations + +1. **TKG scaling at BS=8**: The multi-KV-head TKG kernel matches baseline at BS=4 but adds + 5-9% TPOT overhead at BS=8 due to the virtual-batch approach. For latency-critical BS=8 + workloads, the baseline QKV NKI kernel config may be preferred. + +2. 
**KVDP not supported**: KV data parallelism is not compatible with the multi-KV-head
+   kernel path.
+
+3. **FP8 checkpoint**: The original checkpoint uses FP8 E4M3 weights. These are dequantized
+   to BF16 during extraction. Runtime FP8 inference is not currently supported.
+
+4. **Pixtral auto-promotion**: vLLM 0.16 auto-promotes Mistral models to Pixtral even with
+   tokenizer fixes. The `--hf-overrides` flag is mandatory to force LlamaForCausalLM.
+
+5. **Text-only**: This contrib extracts and serves only the text decoder. Vision-language
+   inference requires the full VL model and additional patches not included here.
+
+## Compatibility Matrix
+
+| Instance | SDK 2.29 | SDK 2.28 | Earlier |
+|----------|----------|----------|---------|
+| trn2.3xlarge (TP=4) | **Tested** | Tested (prior version) | Not supported |
+| trn2.48xlarge | Not tested | Not tested | Not tested |
+| trn1 / inf2 | Not supported | Not supported | Not supported |
+
+## Source Files
+
+| File | Description |
+|------|-------------|
+| `src/setup_patches.py` | SDK 2.29 runtime patch installer (6 patches) |
+| `src/extract_text_model.py` | FP8→BF16 text-only checkpoint extraction |
+| `src/attention_block_tkg_multi_kv.py` | Multi-KV-head TKG kernel (NKI 0.3.0 compatible) |
+| `src/multi_kv_adapter.py` | TKG kernel adapter for attention_base.py |
+| `src/fix_nki030.py` | NKI 0.3.0 compatibility fixes |
+| `src/modeling_leanstral.py` | Legacy model class (SDK 2.28, reference only) |
+| `src/patch_native_multi_kv.py` | Legacy adapter (SDK 2.28, reference only) |
+| `bench.py` | Async streaming benchmark script |
+| `test/integration/test_model.py` | Integration test |
+
+## Upstream NxDI Gaps
+
+This contrib identifies NxDI gaps that would benefit from upstream support:
+
+1. **Multi-KV-head TKG kernel** — the bundled kernel hardcodes kv\_heads=1. The nkilib
+   kernel fork adds an `n_kv_heads` parameter with virtual-batch dispatch. With LNC=2
+   sharding (grid=2), performance matches baseline at BS=4 and has only 5-9% overhead at BS=8.
+2. **Fused QKV conversion** — `convert_state_dict_to_fused_qkv` assumes standard Llama head
+   ratios; non-standard ratios (32Q/8KV at TP=4) need a fix to compute interleave groups.
+3. **RMS norm epsilon** — NxDI's model base doesn't pass `rms_norm_eps` from the config; its
+   default happens to match Mistral's 1e-5 here, but other models use different values, so
+   the pass-through patch is still required.
+
+## Maintainer
+
+Leanstral Project
+
+**Last Updated:** 2026-04-26
diff --git a/contrib/models/Ministral-3-14B-Instruct-2512/bench.py b/contrib/models/Ministral-3-14B-Instruct-2512/bench.py
new file mode 100644
index 00000000..0989439b
--- /dev/null
+++ b/contrib/models/Ministral-3-14B-Instruct-2512/bench.py
@@ -0,0 +1,268 @@
+#!/usr/bin/env python3
+"""Benchmark vLLM server with streaming for TTFT, TPOT, output tok/s measurements.
+
+Uses the OpenAI-compatible chat/completions API with streaming to measure:
+- TTFT (time to first token)
+- Output tok/s (decode throughput)
+- TPOT (time per output token)
+- E2E latency
+
+Workloads: short-short(128/128), short-long(128/512), long-short(2048/128), long-long(2048/512)
+Concurrency: 1, 4
+"""
+
+import argparse
+import asyncio
+import json
+import time
+import statistics
+import aiohttp
+
+CONFIG = {
+    "base_url": "http://localhost:8000",
+    "model": "/home/ubuntu/models/Ministral-3-14B-text-bf16",
+}
+
+WORKLOADS = {
+    "short-short": (128, 128),
+    "short-long": (128, 512),
+    "long-short": (2048, 128),
+    "long-long": (2048, 512),
+}
+
+
+def make_prompt(n_tokens):
+    """Generate a prompt of approximately n_tokens tokens."""
+    base = "Explain the following topic in great detail with examples and analysis. "
+    filler = "The quick brown fox jumps over the lazy dog. "
+    n_repeats = max(1, int(n_tokens / 13))
+    prompt = base + filler * n_repeats
+    return prompt
+
+
+async def stream_request(session, model, prompt, max_tokens, request_id=0):
+    """Send a streaming chat completion request and measure timing."""
+    payload = {
+        "model": model,
+        "messages": [{"role": "user", "content": prompt}],
+        "max_tokens": max_tokens,
+        "stream": True,
+        # vLLM reads ignore_eos directly from the request body; the OpenAI
+        # client's extra_body wrapper does not apply to raw HTTP requests.
+        "ignore_eos": True,
+        "temperature": 0.0,
+    }
+
+    t_start = time.perf_counter()
+    t_first_token = None
+    token_times = []
+    total_tokens = 0
+
+    try:
+        async with session.post(
+            f"{CONFIG['base_url']}/v1/chat/completions",
+            json=payload,
+            timeout=aiohttp.ClientTimeout(total=600),
+        ) as resp:
+            if resp.status != 200:
+                text = await resp.text()
+                return {"error": f"HTTP {resp.status}: {text[:200]}"}
+
+            async for line in resp.content:
+                line = line.decode("utf-8").strip()
+                if not line.startswith("data: "):
+                    continue
+                data = line[6:]
+                if data == "[DONE]":
+                    break
+                try:
+                    chunk = json.loads(data)
+                    delta = chunk["choices"][0].get("delta", {})
+                    content = delta.get("content", "")
+                    if content:
+                        now = time.perf_counter()
+                        if t_first_token is None:
+                            t_first_token = now
+                        token_times.append(now)
+                        total_tokens += 1
+                except (json.JSONDecodeError, KeyError, IndexError):
+                    continue
+
+    except Exception as e:
+        return {"error": str(e)}
+
+    t_end = time.perf_counter()
+
+    if t_first_token is None:
+        return {"error": "No tokens received"}
+
+    ttft = (t_first_token - t_start) * 1000  # ms
+    e2e = (t_end - t_start) * 1000  # ms
+
+    # Calculate inter-token times (TPOT) for decode phase
+    if len(token_times) > 1:
+        inter_token = [
+            (token_times[i] - token_times[i - 1]) * 1000
+            for i in range(1, len(token_times))
+        ]
+    else:
+        inter_token = [0]
+
+    decode_time = (t_end - t_first_token) if total_tokens > 1 else 0
+    output_toks = max(1, total_tokens - 1)
+    tok_per_sec = output_toks / decode_time if decode_time > 0 else 0
+
+    return {
+        "request_id": request_id,
+        "ttft_ms": ttft,
+        "e2e_ms": e2e,
+        "total_tokens": total_tokens,
+        "output_tok_s": tok_per_sec,
+        "tpot_ms": statistics.median(inter_token) if inter_token else 0,
+        "inter_token_times": inter_token,
+    }
+
+
+async def run_workload(
+    model, workload_name, input_tokens, output_tokens, concurrency, n_requests=5
+):
+    """Run a workload at given concurrency."""
+    prompt = make_prompt(input_tokens)
+
+    print(
+        f"\n  Workload: {workload_name} (in={input_tokens}, out={output_tokens}), "
+        f"concurrency={concurrency}, requests={n_requests}"
+    )
+
+    # Warmup
+    print("    Warming up...")
+    async with aiohttp.ClientSession() as session:
+        result = await stream_request(session, model, 
prompt, output_tokens, 0) + if "error" in result: + print(f" ERROR in warmup: {result['error']}") + return None + + # Benchmark + results = [] + async with aiohttp.ClientSession() as session: + for batch_start in range(0, n_requests, concurrency): + batch_size = min(concurrency, n_requests - batch_start) + tasks = [ + stream_request(session, model, prompt, output_tokens, batch_start + i) + for i in range(batch_size) + ] + batch_results = await asyncio.gather(*tasks) + results.extend(batch_results) + + # Filter errors + errors = [r for r in results if "error" in r] + good = [r for r in results if "error" not in r] + + if not good: + print(f" ALL REQUESTS FAILED: {errors}") + return None + + if errors: + print(f" {len(errors)} errors out of {len(results)} requests") + + # Aggregate stats + ttfts = [r["ttft_ms"] for r in good] + e2es = [r["e2e_ms"] for r in good] + toks = [r["output_tok_s"] for r in good] + all_inter = [] + for r in good: + all_inter.extend(r["inter_token_times"]) + + def pcts(data): + data = sorted(data) + n = len(data) + return { + "median": statistics.median(data), + "p95": data[int(n * 0.95)] if n > 1 else data[0], + "p99": data[int(n * 0.99)] if n > 1 else data[0], + } + + stats = { + "workload": workload_name, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "concurrency": concurrency, + "n_requests": len(good), + "n_errors": len(errors), + "ttft": pcts(ttfts), + "e2e": pcts(e2es), + "output_tok_s": pcts(toks), + "tpot": pcts(all_inter) if all_inter else {"median": 0, "p95": 0, "p99": 0}, + "avg_tokens": statistics.mean([r["total_tokens"] for r in good]), + } + + print( + f" TTFT (ms): median={stats['ttft']['median']:.1f} P95={stats['ttft']['p95']:.1f} P99={stats['ttft']['p99']:.1f}" + ) + print( + f" Output tok/s: median={stats['output_tok_s']['median']:.1f} P95={stats['output_tok_s']['p95']:.1f} P99={stats['output_tok_s']['p99']:.1f}" + ) + print( + f" TPOT (ms): median={stats['tpot']['median']:.1f} P95={stats['tpot']['p95']:.1f} P99={stats['tpot']['p99']:.1f}" + ) + print( + f" E2E (ms): median={stats['e2e']['median']:.1f} P95={stats['e2e']['p95']:.1f} P99={stats['e2e']['p99']:.1f}" + ) + print(f" Avg tokens: {stats['avg_tokens']:.0f}") + + return stats + + +async def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", default=CONFIG["model"]) + parser.add_argument("--base-url", default=CONFIG["base_url"]) + parser.add_argument("--workloads", nargs="+", default=list(WORKLOADS.keys())) + parser.add_argument("--concurrency", nargs="+", type=int, default=[1, 4]) + parser.add_argument("--requests", type=int, default=5, help="Requests per workload") + args = parser.parse_args() + + CONFIG["base_url"] = args.base_url + CONFIG["model"] = args.model + + print(f"Benchmarking {CONFIG['model']}") + print(f"Server: {CONFIG['base_url']}") + print(f"Workloads: {args.workloads}") + print(f"Concurrency: {args.concurrency}") + print(f"Requests per workload: {args.requests}") + + all_results = [] + for wl_name in args.workloads: + if wl_name not in WORKLOADS: + print(f"Unknown workload: {wl_name}") + continue + in_toks, out_toks = WORKLOADS[wl_name] + for conc in args.concurrency: + stats = await run_workload( + CONFIG["model"], wl_name, in_toks, out_toks, conc, args.requests + ) + if stats: + all_results.append(stats) + + # Print summary table + print("\n" + "=" * 100) + print( + f"{'Workload':<15} {'Conc':>4} {'TTFT-P50':>10} {'TTFT-P95':>10} {'tok/s-P50':>10} {'TPOT-P50':>10} {'E2E-P50':>10}" + ) + print("-" * 100) + for r in 
all_results:
+        print(
+            f"{r['workload']:<15} {r['concurrency']:>4} "
+            f"{r['ttft']['median']:>9.1f} {r['ttft']['p95']:>9.1f} "
+            f"{r['output_tok_s']['median']:>9.1f} "
+            f"{r['tpot']['median']:>9.1f} "
+            f"{r['e2e']['median']:>9.1f}"
+        )
+    print("=" * 100)
+
+    # Save raw results
+    with open("/home/ubuntu/bench_results.json", "w") as f:
+        json.dump(all_results, f, indent=2)
+    print("\nRaw results saved to /home/ubuntu/bench_results.json")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/contrib/models/Ministral-3-14B-Instruct-2512/src/__init__.py b/contrib/models/Ministral-3-14B-Instruct-2512/src/__init__.py
new file mode 100644
index 00000000..bc27f491
--- /dev/null
+++ b/contrib/models/Ministral-3-14B-Instruct-2512/src/__init__.py
@@ -0,0 +1,15 @@
+from .modeling_leanstral import (
+    build_inference_config,
+    get_model_cls,
+    apply_shard_over_heads_patch,
+    apply_multi_kv_tkg_patch,
+    load_cpu_projector,
+)
+
+__all__ = [
+    "build_inference_config",
+    "get_model_cls",
+    "apply_shard_over_heads_patch",
+    "apply_multi_kv_tkg_patch",
+    "load_cpu_projector",
+]
diff --git a/contrib/models/Ministral-3-14B-Instruct-2512/src/attention_block_tkg_multi_kv.py b/contrib/models/Ministral-3-14B-Instruct-2512/src/attention_block_tkg_multi_kv.py
new file mode 100644
index 00000000..a34034fd
--- /dev/null
+++ b/contrib/models/Ministral-3-14B-Instruct-2512/src/attention_block_tkg_multi_kv.py
@@ -0,0 +1,1969 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Attention Block TKG Kernel — Multi-KV-Head Fork
+
+Forked from nkilib.experimental.transformer.attention_block_tkg to support
+kv_heads_per_rank > 1 (e.g., 8 KV heads at TP=4 = 2 KV heads per rank).
+
+The original kernel hardcodes kv_heads=1. This version:
+  - Takes kv_heads explicitly via the n_kv_heads parameter
+  - Processes K with n_heads=kv_heads (not 1)
+  - Folds the KV head groups into a virtual batch (B_virt = B * kv_heads) and
+    calls attention_tkg once over all groups
+  - Updates all KV cache heads
+
+The inner attention kernel (attention_tkg) is unchanged — it already operates on
+one KV head group at a time.
+
+Single kernel call per layer → avoids the NCC_ITEN404 compiler ICE that occurs
+with the multi-call monkeypatch approach.
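+
+Worked example (one TP=4 rank of this model: q_heads=8, kv_heads=2, so
+q_per_kv_group=4): with B=1, B_virt = 1 * 2 = 2 virtual batches. K_cache
+[1, 2, d, S_max] is viewed as [2, 1, d, S_max], and virtual batch
+vb = b * kv_heads + g selects KV head g of batch b.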
+""" + +from typing import Any, Dict, Optional, Tuple + +import nki +import nki.isa as nisa +import nki.language as nl +from nki.isa.constants import oob_mode + +from nkilib.core.attention.attention_tkg import AttnTKGConfig, attention_tkg +from nkilib.core.attention.attention_tkg_utils import is_fp8_e4m3 +from nkilib.core.embeddings.rope import RoPE_sbuf +from nkilib.core.output_projection.output_projection_tkg import output_projection_tkg +from nkilib.core.qkv.qkv import qkv_tkg +from nkilib.core.utils.allocator import SbufManager, create_auto_alloc_manager +from nkilib.core.utils.common_types import NormType, QuantizationType +from nkilib.core.utils.kernel_assert import kernel_assert +from nkilib.core.utils.kernel_helpers import ( + get_max_positive_value_for_dtype, + get_verified_program_sharding_info, + is_hbm_buffer, +) +from nkilib.core.utils.logging import Logger +from nkilib.core.utils.tensor_view import TensorView + +# KVDP sharding helpers — removed for multi-KV-head kernel. +# KVDP is NOT supported with kv_heads_per_rank > 1. +# The KVDP parameter is kept in the API for compatibility but asserted == 1. + + +# TODO(NKI-699): Refactor API to use configuration dataclasses for better clarity +# Note: Using keyword-only args (via *) to avoid breaking callers when adding/reordering +# parameters, and to improve readability given the large number of arguments. +@nki.jit +def attention_block_tkg( + # -- input + X: nl.ndarray, + *, + X_hidden_dim_actual: Optional[int], + # -- rmsnorm X + rmsnorm_X_enabled: bool, + rmsnorm_X_eps: Optional[float], + rmsnorm_X_gamma: Optional[nl.ndarray], + # -- qkv projections + W_qkv: nl.ndarray, + bias_qkv: Optional[nl.ndarray], + quantization_type_qkv: QuantizationType, + weight_dequant_scale_qkv: Optional[nl.ndarray], + input_dequant_scale_qkv: Optional[nl.ndarray], + # -- Q/K processing: pre-RoPE RMSNorm + rmsnorm_QK_pre_rope_enabled: bool, + rmsnorm_QK_pre_rope_eps: float, + rmsnorm_QK_pre_rope_W_Q: Optional[nl.ndarray], + rmsnorm_QK_pre_rope_W_K: Optional[nl.ndarray], + # -- Q/K processing: RoPE + cos: Optional[nl.ndarray], + sin: Optional[nl.ndarray], + rope_contiguous_layout: bool, + # -- Q/K processing: post-RoPE RMSNorm + rmsnorm_QK_post_rope_enabled: bool, + rmsnorm_QK_post_rope_eps: float, + rmsnorm_QK_post_rope_W_Q: Optional[nl.ndarray], + rmsnorm_QK_post_rope_W_K: Optional[nl.ndarray], + # -- attention + K_cache_transposed: bool, + active_blocks_table: Optional[nl.ndarray], + K_cache: nl.ndarray, + V_cache: nl.ndarray, + attention_mask: nl.ndarray, + sink: Optional[nl.ndarray], + softmax_scale: Optional[float] = None, + # -- KV cache update + update_cache: bool, + kv_cache_update_idx: Optional[nl.ndarray], + k_scale: Optional[nl.ndarray] = None, + v_scale: Optional[nl.ndarray] = None, + # -- output projection + W_out: Optional[nl.ndarray], + bias_out: Optional[nl.ndarray], + quantization_type_out: QuantizationType, + weight_dequant_scale_out: Optional[nl.ndarray], + input_dequant_scale_out: Optional[nl.ndarray], + transposed_out: bool, + # -- output + out_in_sb: bool, + sbm: Optional[SbufManager] = None, + skip_attention: bool = False, + # -- Multi-KV-head support + n_kv_heads: int = 1, + # -- Number of query heads per rank (passed explicitly to avoid .shape on PlaceholderParameter weights) + n_q_heads: int = 8, + # -- Head dimension (passed explicitly to avoid NKI .shape limitations on KV cache tensors) + head_dim: int = 128, + # -- Max context length for KV cache (passed explicitly for multi-KV-head) + s_max_ctx: int = 256, + # -- 
Per-group attention mask (for multi-KV-head, avoids NKI reshape-on-slice) + # Shape: [S_ctx, B, q_per_kv_group, S_tkg] — same mask reused for all KV head groups + group_attention_mask: Optional[nl.ndarray] = None, + # -- Pre-allocated HBM buffer for V active tokens (multi-KV-head) + # Shape: [B_virt, 1, S_tkg, d_head] — pre-allocated by adapter to avoid NCC_IBIR440 + # The compiler's DRAM allocator cannot handle kernel-internal HBM ndarray allocations. + v_active_hbm: Optional[nl.ndarray] = None, + # -- Replicated kv_cache_update_idx for multi-KV-head cache update + # Shape: [B_virt, 1] where B_virt = B * kv_heads — same position replicated per head + # Pre-allocated by adapter to avoid kernel-internal HBM allocation (NCC_IBIR440) + kv_cache_update_idx_virt: Optional[nl.ndarray] = None, +): + """ + Fused Attention Block for Token Generation (TKG). + + Performs end-to-end attention block computation optimized for autoregressive + decoding with all stages fused in SBUF to avoid HBM round-trips. Intended for + small batch sizes (B ≤ 16) and short sequence lengths (S_tkg ≤ 8) typical in + token generation workloads. + + Dimensions: + B: Batch size (≤ 16 recommended) + B_attn: Batch size for attention = B/KVDP when KV data parallelism enabled, otherwise B + S_tkg: Number of new tokens to generate (≤ 8 required) + S_ctx: KV cache sequence length in current bucket + S_max_ctx: Maximum KV cache capacity of current bucket + H: Hidden dimension (must be multiple of 128) + d_head: Head dimension (must be even) + q_heads: Number of query heads + kv_heads: Number of KV heads per rank (derived from W_qkv shape). Supports kv_heads >= 1. + num_blocks: Number of blocks in block KV cache + block_len: Block length for block KV cache + + Args: + X (nl.ndarray): Input hidden states + Shape: + [B, S_tkg, H] when in HBM + [H0=pmax, BxS, H1] where H1=lnc x (H//lnc//pmax) when in SBUF + + When in SBUF, the layout is obtained by rearranging HBM data: + HBM: (BxS, lnc, H0, H1//lnc) -> SBUF: (H0, BxS, (lnc, H1//lnc)) + This interleaves H1//lnc values from each lnc chunk along the H dimension, + matching qkv_tkg() kernel's expected SBUF input format. + X_hidden_dim_actual (Optional[int]): Actual hidden dim if X is padded + + rmsnorm_X_enabled (bool): Apply RMSNorm to X before QKV projection + rmsnorm_X_eps (Optional[float]): RMSNorm epsilon (default 1e-3) + rmsnorm_X_gamma (Optional[nl.ndarray]): [1, H] @ HBM, RMSNorm weights + + W_qkv (nl.ndarray): [H, d_head*(q_heads+2)] @ HBM, QKV projection weights + bias_qkv (Optional[nl.ndarray]): [1, d_head*(q_heads+2)] @ HBM, QKV bias + quantization_type_qkv (QuantizationType): Type of quantization for QKV projection (NONE, STATIC). + weight_dequant_scale_qkv (Optional[nl.ndarray]): Weight dequantization scale for QKV projection. + Shape: [PMAX, 1] @ HBM when quantization_type_qkv is STATIC. + input_dequant_scale_qkv (Optional[nl.ndarray]): Input dequantization scale for QKV projection. + Shape: [PMAX, 1] @ HBM when quantization_type_qkv is STATIC. 
+ + rmsnorm_QK_pre_rope_enabled (bool): Apply RMSNorm to Q/K before RoPE + rmsnorm_QK_pre_rope_eps (float): Pre-RoPE RMSNorm epsilon + rmsnorm_QK_pre_rope_W_Q (Optional[nl.ndarray]): [1, d_head] @ HBM, Pre-RoPE Q gamma weights + rmsnorm_QK_pre_rope_W_K (Optional[nl.ndarray]): [1, d_head] @ HBM, Pre-RoPE K gamma weights + cos (Optional[nl.ndarray]): [d_head//2, B, S_tkg] @ HBM, RoPE cosine embeddings (None = skip RoPE) + sin (Optional[nl.ndarray]): [d_head//2, B, S_tkg] @ HBM, RoPE sine embeddings (None = skip RoPE) + rope_contiguous_layout (bool): True for contiguous halves, False for interleaved + rmsnorm_QK_post_rope_enabled (bool): Apply RMSNorm to Q/K after RoPE + rmsnorm_QK_post_rope_eps (float): Post-RoPE RMSNorm epsilon + rmsnorm_QK_post_rope_W_Q (Optional[nl.ndarray]): [1, d_head] @ HBM, Post-RoPE Q weights + rmsnorm_QK_post_rope_W_K (Optional[nl.ndarray]): [1, d_head] @ HBM, Post-RoPE K weights + + K_cache_transposed (bool): Whether K cache is stored transposed in HBM. + If True: K cache is [B, d_head, S_ctx]. If False: K cache is [B, S_ctx, d_head]. + Must be False for block KV cache. + active_blocks_table (Optional[nl.ndarray]): [B, num_blocks] @ HBM, Block indices for block KV cache + K_cache (nl.ndarray): Key cache @ HBM. + Flat KV: [B, d_head, S_max_ctx] if K_cache_transposed else [B, S_max_ctx, d_head]. + Block KV: [num_blocks, block_len, d_head]. + V_cache (nl.ndarray): Value cache @ HBM. + Flat KV: [B, S_max_ctx, d_head]. + Block KV: [num_blocks, block_len, d_head]. + attention_mask (nl.ndarray): [S_ctx, B, q_heads, S_tkg] @ HBM, Attention mask + sink (Optional[nl.ndarray]): [H, 1] @ HBM, Attention sink tokens + softmax_scale (Optional[float]): Scaling factor for attention scores. If None, defaults to 1/sqrt(d_head). + + k_scale (Optional[nl.ndarray]): Scale for K quantization to FP8. Shape (PMAX, 1) or (1, 1) @ HBM. + Must contain a single scalar value (replicated or scalar). When provided with v_scale, + enables FP8 KV cache quantization. Supported dtypes: float32, float16, bfloat16. + v_scale (Optional[nl.ndarray]): Scale for V quantization to FP8. Shape (PMAX, 1) or (1, 1) @ HBM. + Must contain a single scalar value (replicated or scalar). When provided with k_scale, + enables FP8 KV cache quantization. Supported dtypes: float32, float16, bfloat16. + + update_cache (bool): Update KV cache with new tokens + kv_cache_update_idx (Optional[nl.ndarray]): [B, 1], Cache write positions (uint32_max = skip) + + W_out (Optional[nl.ndarray]): [q_heads*d_head, H] @ HBM, Output projection weights + bias_out (Optional[nl.ndarray]): [1, H] @ HBM, Output projection bias + quantization_type_out (QuantizationType): Type of quantization for output projection (NONE, STATIC). + weight_dequant_scale_out (Optional[nl.ndarray]): Weight dequantization scale for output projection. + Shape: [PMAX, 1] @ HBM when quantization_type_out is STATIC. + input_dequant_scale_out (Optional[nl.ndarray]): Input dequantization scale for output projection. + Shape: [PMAX, 1] @ HBM when quantization_type_out is STATIC. + transposed_out (bool): Transpose output layout (requires W_out) + out_in_sb (bool): Return output in SBUF instead of HBM + sbm (Optional[SbufManager]): SBUF memory manager (otherwise auto-allocated) + skip_attention (bool): Skip attention computation (for testing) + + KVDP (int): KV cache data parallelism degree - number of ranks that shard the KV cache + across the batch dimension (1 = disabled). Each rank processes B/KVDP batches. 
KVDP_replica_group (Optional[ReplicaGroup]): Replica group for collective ops
+
+    KV Data Parallelism (KVDP > 1):
+        KV-DP partitions the KV cache across ranks along the batch dimension. Each rank holds
+        B/KVDP batches of the KV cache. Before attention: all_gather Q heads, slice Q/K/V batch.
+        After attention: all_gather output batch, slice heads.
+
+        When KV data parallelism is enabled, input/output shapes change:
+        - B_attn = B / KVDP (batches per rank for attention)
+        - q_heads_attn = q_heads * KVDP (query heads per rank after gather)
+
+        Input shape changes:
+        - K_cache, V_cache: [B_attn, ...] instead of [B, ...]
+        - attention_mask: [S_ctx, B_attn, q_heads_attn, S_tkg]
+        - kv_cache_update_idx: [B_attn, 1] (caller must slice per rank)
+
+        Output shape changes (when update_cache=False):
+        - K_out: [d_head, B_attn, S_tkg]
+        - V_out: [B_attn, 1, S_tkg, d_head]
+
+    Returns:
+        out (nl.ndarray): Output tensor with shape depending on projection and output location:
+            - Without projection (W_out=None):
+                - out_in_sb=False: [B, q_heads, d_head, S_tkg] @ HBM
+                - out_in_sb=True: [d_head, B*q_heads*S_tkg] @ SBUF
+            - With projection (W_out provided):
+                - transposed_out=False, out_in_sb=False: [B*S_tkg, H] @ HBM
+                - transposed_out=False, out_in_sb=True: [B*S_tkg, H//lnc] @ SBUF
+                - transposed_out=True, out_in_sb=False: [128, lnc, H//lnc//128, B*S_tkg] @ HBM
+                - transposed_out=True, out_in_sb=True: [128, H//lnc//128, B*S_tkg] @ SBUF
+        K_out (nl.ndarray):
+            - If update_cache=True: Updated K cache (shape matches K_cache input)
+            - If update_cache=False: New K tokens [d_head, B_attn, S_tkg] @ HBM
+        V_out (nl.ndarray):
+            - If update_cache=True: Updated V cache (shape matches V_cache input)
+            - If update_cache=False: New V tokens [B_attn, 1, S_tkg, d_head] @ HBM
+
+    Notes:
+        - Requires NeuronCore v3+
+        - d_head must be even
+        - H must be multiple of pmax
+        - Requires batch * sequence_tkg * q_heads <= pmax (=128)
+        - Supports grouped-query attention (GQA); this fork allows one or more KV heads per rank
+        - LNC-2 sharding support for KV cache updates
+
+    Pseudocode:
+        # Stage 1: QKV Projection
+        if rmsnorm_X_enabled:
+            X_norm = rms_norm(X, rmsnorm_X_gamma, rmsnorm_X_eps)
+        QKV = matmul(X_norm, W_qkv) + bias_qkv
+
+        # Stage 2: Q/K Processing
+        Q, K = split_and_transpose(QKV)
+        if rmsnorm_QK_pre_rope_enabled:
+            Q = rms_norm(Q, rmsnorm_QK_pre_rope_W_Q)
+            K = rms_norm(K, rmsnorm_QK_pre_rope_W_K)
+        if cos is not None and sin is not None:
+            Q, K = rope(Q, cos, sin), rope(K, cos, sin)
+        if rmsnorm_QK_post_rope_enabled:
+            Q = rms_norm(Q, rmsnorm_QK_post_rope_W_Q)
+            K = rms_norm(K, rmsnorm_QK_post_rope_W_K)
+        V = extract_V(QKV)
+
+        # Stage 3: Attention
+        Q_scaled = Q / sqrt(d_head)
+        attn_out = attention_tkg(Q_scaled, K, V, K_cache, V_cache, attention_mask)
+
+        # Stage 4: KV Cache Update
+        if update_cache:
+            update_kv_cache(K_cache, V_cache, K, V, kv_cache_update_idx)
+            K_out, V_out = K_cache, V_cache  # Return updated caches
+        else:
+            K_out, V_out = K, V  # Return new tokens
+
+        # Stage 5: Output Projection
+        if W_out is not None:
+            output = matmul(attn_out, W_out) + bias_out
+        else:
+            output = attn_out
+
+        return output, K_out, V_out
+    """
+
+    # ========== Validation and Setup ==========
+    # Inline config extraction instead of calling _validate_and_extract_config,
+    # because the NKI tracer cannot handle mixed nl.ndarray + Python int function args.
+ + kv_heads = n_kv_heads if n_kv_heads is not None else 1 + + # Extract batch/sequence dimensions from attention_mask (always works in NKI tracer) + _, B, _, _ = attention_mask.shape + + if X.buffer == nl.sbuf: + kernel_assert(X.shape[0] == nl.tile_size.pmax, "SBUF input X dim0 must be pmax") + H = X.shape[2] * nl.tile_size.pmax + S_tkg = X.shape[1] // B + else: + B, S_tkg, H = X.shape + + # d_head from explicit parameter (avoids .shape on 4D KV cache which NKI can't trace) + d_head = ( + head_dim if head_dim is not None else V_cache.shape[2] + ) # fallback: 3D cache + # Compute I from explicit q_heads + kv_heads (avoids W_qkv.shape on PlaceholderParameter) + q_heads = n_q_heads + I = d_head * (q_heads + 2 * kv_heads) + half_d = d_head // 2 + + is_KVDP = False # KVDP not supported with multi-KV-head + B_attn = B + q_heads_attn = q_heads + is_block_kv = active_blocks_table is not None + + # Process KV cache + K_cache, V_cache, cache_had_head_dim = __internal_squeeze_head_dim( + K_cache, V_cache, is_block_kv, kv_heads + ) + + if is_block_kv: + blk_len = V_cache.shape[1] + S_ctx = S_max_ctx = active_blocks_table.shape[1] * blk_len + else: + S_ctx = attention_mask.shape[0] + # Always use explicit s_max_ctx to avoid .shape on PlaceholderParameter V_cache + S_max_ctx = s_max_ctx + blk_len = 0 + + do_out_proj = W_out is not None + + # KV Quantization + kv_quant = k_scale is not None and v_scale is not None + q_per_kv_group = q_heads // kv_heads # Q heads per KV head group + + sbm = ( + sbm + if sbm is not None + else create_auto_alloc_manager(logger=Logger("attn-block-tkg")) + ) + sbm.open_scope("attn-blk-tkg-scope") + + # ========== QKV Projection ========== + # Input: X [B, S_tkg, H] @ HBM + # Output: QKV_tkg_sb [B*S_tkg, I] @ SBUF where I = d_head * (q_heads + 2) + rmsnorm_X_eps = 1e-3 if rmsnorm_X_eps is None else rmsnorm_X_eps + QKV_tkg_sb = qkv_tkg( + hidden=X, + qkv_w=W_qkv, + norm_w=rmsnorm_X_gamma, + d_head=d_head, + num_q_heads=q_heads, + num_kv_heads=kv_heads, + eps=rmsnorm_X_eps, + norm_type=NormType.RMS_NORM if rmsnorm_X_enabled else NormType.NO_NORM, + quantization_type=quantization_type_qkv, + qkv_w_scale=weight_dequant_scale_qkv, + qkv_in_scale=input_dequant_scale_qkv, + output_in_sbuf=True, + qkv_bias=bias_qkv, + hidden_actual=X_hidden_dim_actual, + sbm=sbm, + ) + + # ========== Q/K Processing: Transpose + RMSNorm pre + RoPE + RMSNorm post ========== + # Input: QKV_tkg_sb [B*S_tkg, I] @ SBUF + # Output: Q_tkg_sb [d_head, B*q_heads*S_tkg] @ SBUF, K_tkg_sb [d_head, B*S_tkg] @ SBUF + Q_tkg_sb, K_tkg_sb = _QK_processing( + QKV_tkg_sb, + q_heads, + kv_heads, + B, + rmsnorm_QK_pre_rope_enabled, + rmsnorm_QK_pre_rope_eps, + rmsnorm_QK_pre_rope_W_Q, + rmsnorm_QK_pre_rope_W_K, + cos, + sin, + rope_contiguous_layout, + rmsnorm_QK_post_rope_enabled, + rmsnorm_QK_post_rope_eps, + rmsnorm_QK_post_rope_W_Q, + rmsnorm_QK_post_rope_W_K, + sbm, + ) + + # ========== Extract V from QKV ========== + # For multi-KV-head: extract kv_heads * d_head worth of V data + # V lives at offset q_heads*d + kv_heads*d in the QKV buffer + # V_tkg_sb shape: [B*S_tkg, kv_heads*d_head] — ALL KV heads + v_offset = d_head * (q_heads + kv_heads) + v_size = kv_heads * d_head + V_tkg_sb = nl.ndarray( + (B * S_tkg, v_size), + dtype=QKV_tkg_sb.dtype, + buffer=nl.sbuf, + name="attention_blk_V_tkg_sb", + ) + nisa.tensor_copy(V_tkg_sb, QKV_tkg_sb[:, nl.ds(v_offset, v_size)]) + + # Quantize K and V to FP8 for attention when kv_quant=True + if kv_quant: + K_tkg_sb = _quantize_to_fp8(K_tkg_sb, k_scale, sbm) + V_tkg_sb = 
_quantize_to_fp8(V_tkg_sb, v_scale, sbm) + + # ========== KV Data Parallelism: Input Collectives ========== + # NOTE: KVDP with multi-KV-head not yet supported + kernel_assert( + not is_KVDP or kv_heads == 1, "KVDP with multi-KV-head not yet supported" + ) + + # ========== Attention Computation (Virtual Batch Approach) ========== + # Instead of calling attention_tkg N times (once per KV head group), + # we reshape the data so all groups become "virtual batches" and call + # attention_tkg ONCE. This avoids compiler name collisions (NCC_INLA001) + # that occur when the same sub-kernel is called multiple times. + # + # Virtual batch: B_virt = B_attn * kv_heads + # Each virtual batch has q_per_kv_group Q heads and 1 KV head. + # Q_tkg_sb layout: [d, B*q_heads*S] = [d, B*(kv*qpg)*S] = [d, B_virt*qpg*S] ✓ (compatible) + # K_tkg_sb layout: [d, B*kv_heads*S] = [d, B_virt*S] ✓ (compatible) + # K_cache: [B, kv_heads, d, S_max] → reshape to [B_virt, 1, d, S_max] + # V_cache: [B, kv_heads, S_max, d] → reshape to [B_virt, 1, S_max, d] + # group_mask: [S_ctx, B_virt, qpg, S_tkg] (pre-created by adapter) + + # --- For multi-KV-head: copy V tokens to pre-allocated HBM buffer --- + # V_tkg_sb: [B*S_tkg, kv_heads*d_head] @ SBUF — needs reshaping to virtual batch layout. + # SBUF partition constraints prevent reshape from [1, 256] to [2, 128] directly. + # DMA copy to the pre-allocated v_active_hbm (passed from adapter) to avoid NCC_IBIR440. + if kv_heads > 1: + B_virt = B_attn * kv_heads + kernel_assert( + v_active_hbm is not None, + "v_active_hbm must be provided when n_kv_heads > 1", + ) + # DMA copy V from SBUF to HBM. + # Reshape HBM to match SBUF layout [B*S, kv*d] for the copy. + nisa.dma_copy( + v_active_hbm.reshape((B_attn * S_tkg, kv_heads * d_head)), + V_tkg_sb, + ) + V_active_hbm = v_active_hbm + + if skip_attention: + attn_out = Q_tkg_sb + else: + # Scale Q by softmax_scale (default: 1/sqrt(d_head)) + _softmax_scale = ( + softmax_scale if softmax_scale is not None else d_head ** (-0.5) + ) + nisa.tensor_scalar( + dst=Q_tkg_sb, data=Q_tkg_sb, op0=nl.multiply, operand0=_softmax_scale + ) + + if kv_heads == 1: + B_virt = B_attn + + # Single-head: V_active on HBM + V_active_dtype = nl.float8_e4m3 if kv_quant else V_tkg_sb.dtype + V_active_hbm = nl.ndarray( + (B_attn, 1, S_tkg, d_head), + dtype=V_active_dtype, + buffer=nl.shared_hbm, + name="v_active_virt_hbm", + ) + nisa.dma_copy(V_active_hbm.reshape((B_attn * S_tkg, d_head)), V_tkg_sb) + + # --- Prepare KV cache for attention --- + if is_block_kv: + k_prior_virt = K_cache + v_prior_virt = V_cache + else: + if kv_heads == 1: + # Original single-head path + k_shape = ( + (B_attn, 1, d_head, S_max_ctx) + if K_cache_transposed + else (B_attn, 1, S_max_ctx, d_head) + ) + k_prior_virt = K_cache.reshape(k_shape) + v_prior_virt = V_cache.reshape((B_attn, 1, S_max_ctx, d_head)) + else: + # Multi-KV-head: reshape 4D cache to virtual batch format. + # K_cache: [B, kv_heads, d, S_max] or [B, kv_heads, S_max, d] + # V_cache: [B, kv_heads, S_max, d] + # Target: [B_virt, 1, ...] where B_virt = B * kv_heads + # + # For B=1: K_cache [1, 2, d, S] -> reshape to [2, 1, d, S] — just a reshape! + # The data layout is [head0_data, head1_data] which maps to + # [vbatch0_data, vbatch1_data] with the same stride. 
+ if K_cache_transposed: + k_prior_virt = K_cache.reshape((B_virt, 1, d_head, S_max_ctx)) + else: + k_prior_virt = K_cache.reshape((B_virt, 1, S_max_ctx, d_head)) + v_prior_virt = V_cache.reshape((B_virt, 1, S_max_ctx, d_head)) + + # --- Attention mask --- + if kv_heads > 1: + kernel_assert( + group_attention_mask is not None, + "group_attention_mask required when n_kv_heads > 1", + ) + attn_mask = group_attention_mask # [S_ctx, B_virt, qpg, S_tkg] + else: + attn_mask = attention_mask # [S_ctx, B, q_heads, S_tkg] + + # --- Allocate attention output --- + attn_out = sbm.alloc_stack( + (d_head, B_virt * q_per_kv_group * S_tkg), dtype=X.dtype, buffer=nl.sbuf + ) + + attn_cfg = AttnTKGConfig( + bs=B_virt, + q_head=q_per_kv_group, + s_active=S_tkg, + curr_sprior=S_ctx, + full_sprior=S_max_ctx, + d_head=d_head, + block_len=blk_len if is_block_kv else 0, + tp_k_prior=not K_cache_transposed, + strided_mm1=not is_block_kv, + use_pos_id=False, + fuse_rope=False, + use_gpsimd_sb2sb=True, + qk_in_sb=True, + k_out_in_sb=False, + out_in_sb=True, + ) + + attention_tkg( + q=Q_tkg_sb, # [d, B_virt * qpg * S] @ SBUF + k_active=K_tkg_sb, # [d, B_virt * S] @ SBUF + v_active=V_active_hbm, # [B_virt, 1, S, d] @ HBM + k_prior=k_prior_virt, # [B_virt, 1, d, S_max] or [B_virt, 1, S_max, d] + v_prior=v_prior_virt, # [B_virt, 1, S_max, d] + mask=attn_mask, + out=attn_out, + cfg=attn_cfg, + sbm=sbm, + sink=sink, + active_blocks_table=active_blocks_table, + ) + + # ========== KV Data Parallelism: Output Gather ========== + # NOTE: KVDP not supported with multi-KV-head (asserted above) + + # ========== KV Cache Update ========== + # For multi-KV-head: K_tkg_sb is [d, B*kv_heads*S] and V_tkg_sb is [B*S, kv_heads*d] + # We loop over heads and call _kv_cache_update per head with 3D cache slices + if update_cache: + if kv_heads == 1: + # Original single-head path + _kv_cache_update( + K_cache=K_cache, + V_cache=V_cache, + K_tkg=K_tkg_sb, + V_tkg=V_tkg_sb, + kv_cache_update_idx=kv_cache_update_idx, + B=B_attn, + d_head=d_head, + S_tkg=S_tkg, + S_max_ctx=S_max_ctx, + K_cache_transposed=K_cache_transposed, + is_block_kv=is_block_kv, + ) + K_cache, V_cache = __internal_unsqueeze_head_dim( + K_cache, V_cache, cache_had_head_dim, is_block_kv + ) + else: + # Multi-KV-head: inline cache update using .ap() on 3D cache views. + # + # CRITICAL FIX: The original _update_flat_cache uses .ap() on 3D tensors + # [B, S_max, d] with indirect_dim=1 pointing to the sequence dimension. + # The previous code reshaped to 2D [B_virt*S_max, d], which changes the + # dimension numbering — indirect_dim=1 then points to d_head (columns) + # instead of the sequence position, causing writes to the WRONG location. + # + # Fix: reshape 4D cache [B, kv_heads, S_max, d] to 3D [B_virt, S_max, d] + # so the .ap() semantics match the original exactly. + # + # K_cache: [B, kv_heads, S_max, d] or [B, kv_heads, d, S_max] → 3D + # V_cache: [B, kv_heads, S_max, d] → 3D [B_virt, S_max, d] + # K_tkg_sb: [d, B*kv_heads*S] @ SBUF + # V_tkg_sb: [B*S, kv_heads*d] @ SBUF + + B_virt_cache = B_attn * kv_heads + + _, n_prgs, prg_id = get_verified_program_sharding_info( + "multi_kv_cache_update", (0, 1), 2 + ) + + start_position = nl.ndarray((1, 1), dtype=nl.uint32, buffer=nl.sbuf) + + # V update on lnc=0 + # V_cache: [B, kv_heads, S_max, d] → reshape to 3D [B_virt, S_max, d] + # Then use the SAME .ap() pattern as _update_flat_cache: 3D with indirect_dim=1 + # + # NKI 0.3.0 / neuronx-cc 2.24 fix: + # V_tkg_sb has shape (B*S_tkg, kv_heads*d_head) = (4, 256). 
+ # nc_transpose on the full tensor fails: stationary free dim 256 > gemm_stationary_fmax=128. + # tensor_copy from partition b also fails (cross-partition BIR error). + # + # Solution: Split V per KV head FIRST, then transpose each (B*S_tkg, d_head) = (4, 128) + # chunk separately. nc_transpose on (4, 128) has stationary free = 4 (well under 128). + # This keeps everything in SBUF — no HBM roundtrip overhead. + if n_prgs == 1 or prg_id == 0: + V_cache_3d = V_cache.reshape((B_virt_cache, S_max_ctx, d_head)) + + for g in range(kv_heads): + # Slice V for this KV head: (B*S_tkg, d_head) from V_tkg_sb + V_head_sb = V_tkg_sb[ + :, nl.ds(g * d_head, d_head) + ] # (B*S_tkg, d_head) + # Transpose (B*S_tkg, d_head) → (d_head, B*S_tkg) + # stationary free = B*S_tkg (e.g. 4) << 128 ✓ + V_head_T_psum = nl.ndarray( + (d_head, B_attn * S_tkg), + dtype=V_head_sb.dtype, + buffer=nl.psum, + ) + nisa.nc_transpose(dst=V_head_T_psum, data=V_head_sb) + V_head_T_sb = nl.ndarray( + (d_head, B_attn * S_tkg), + dtype=V_head_sb.dtype, + buffer=nl.sbuf, + ) + nisa.tensor_copy(V_head_T_sb, V_head_T_psum) + + for b in range(B_attn): + vb = b * kv_heads + g + nisa.dma_copy(start_position, kv_cache_update_idx[b]) + # Extract (d_head, S_tkg) column for this batch + V_col_sb = nl.ndarray( + (d_head, S_tkg), + dtype=V_head_T_sb.dtype, + buffer=nl.sbuf, + ) + nisa.tensor_copy( + V_col_sb, + V_head_T_sb[:, nl.ds(b * S_tkg, S_tkg)], + ) + # Transpose (d_head, S_tkg) → (S_tkg, d_head) for cache DMA + V_col_T_psum = nl.ndarray( + (S_tkg, d_head), + dtype=V_col_sb.dtype, + buffer=nl.psum, + ) + nisa.nc_transpose(dst=V_col_T_psum, data=V_col_sb) + V_col_T_sb = nl.ndarray( + (S_tkg, d_head), + dtype=V_col_sb.dtype, + buffer=nl.sbuf, + ) + nisa.tensor_copy(V_col_T_sb, V_col_T_psum) + nisa.dma_copy( + dst=V_cache_3d.ap( + pattern=[[d_head, S_tkg], [1, d_head]], + offset=vb * S_max_ctx * d_head, + scalar_offset=start_position, + indirect_dim=1, + ), + src=V_col_T_sb, + ) + + # K update on lnc=1 + if n_prgs == 1 or prg_id == 1: + for b in range(B_attn): + nisa.dma_copy(start_position, kv_cache_update_idx[b]) + for g in range(kv_heads): + vb = b * kv_heads + g + k_src_col = b * kv_heads * S_tkg + g * S_tkg + if K_cache_transposed: + # K_cache: [B, kv_heads, d, S_max] → 3D [B_virt, d, S_max] + K_cache_3d = K_cache.reshape( + (B_virt_cache, d_head, S_max_ctx) + ) + K_col_sb = nl.ndarray( + (d_head, S_tkg), dtype=K_tkg_sb.dtype, buffer=nl.sbuf + ) + nisa.tensor_copy( + K_col_sb, K_tkg_sb[:, nl.ds(k_src_col, S_tkg)] + ) + nisa.dma_copy( + dst=K_cache_3d.ap( + pattern=[[S_max_ctx, d_head], [1, S_tkg]], + offset=vb * d_head * S_max_ctx, + scalar_offset=start_position, + indirect_dim=2, + ), + src=K_col_sb, + ) + else: + # K_cache: [B, kv_heads, S_max, d] → 3D [B_virt, S_max, d] + K_cache_3d = K_cache.reshape( + (B_virt_cache, S_max_ctx, d_head) + ) + K_col_sb = nl.ndarray( + (d_head, S_tkg), dtype=K_tkg_sb.dtype, buffer=nl.sbuf + ) + nisa.tensor_copy( + K_col_sb, K_tkg_sb[:, nl.ds(k_src_col, S_tkg)] + ) + K_transposed_sb = nl.ndarray( + (S_tkg, d_head), K_tkg_sb.dtype, nl.sbuf + ) + _transpose_sbuf(K_col_sb, K_transposed_sb) + nisa.dma_copy( + dst=K_cache_3d.ap( + pattern=[[d_head, S_tkg], [1, d_head]], + offset=vb * S_max_ctx * d_head, + scalar_offset=start_position, + indirect_dim=1, + ), + src=K_transposed_sb, + ) + # K_cache and V_cache are already 4D, no unsqueeze needed + else: # No cache update: return new K/V tokens + if kv_heads == 1: + K_tkg_hbm = nl.ndarray( + (d_head, B_attn, S_tkg), + dtype=K_tkg_sb.dtype, + 
buffer=nl.shared_hbm, + name="K_hbm", + ) + nisa.dma_copy(K_tkg_hbm.reshape(K_tkg_sb.shape), K_tkg_sb) + # V_tkg_hbm: copy from SBUF [B*S, d] to HBM [B, 1, S, d] + V_tkg_hbm = nl.ndarray( + (B_attn, 1, S_tkg, d_head), + dtype=V_tkg_sb.dtype, + buffer=nl.shared_hbm, + name="V_hbm", + ) + nisa.dma_copy(V_tkg_hbm.reshape((B_attn * S_tkg, d_head)), V_tkg_sb) + else: + # Return multi-head K tokens: [d, B, kv_heads, S] + K_tkg_hbm = nl.ndarray( + (d_head, B_attn, kv_heads, S_tkg), + dtype=K_tkg_sb.dtype, + buffer=nl.shared_hbm, + name="K_hbm", + ) + nisa.dma_copy(K_tkg_hbm.reshape(K_tkg_sb.shape), K_tkg_sb) + # Return multi-head V tokens: [B, kv_heads, S, d] + # V_tkg_sb is [B*S, kv_heads*d] in SBUF + V_tkg_hbm = nl.ndarray( + (B_attn, kv_heads, S_tkg, d_head), + dtype=V_tkg_sb.dtype, + buffer=nl.shared_hbm, + name="V_hbm", + ) + nisa.dma_copy( + V_tkg_hbm.reshape((B_attn * S_tkg, kv_heads * d_head)), V_tkg_sb + ) + + # ========== Output Projection (Optional) ========== + # Input: attn_out [d_head, B, q_heads, S_tkg] @ SBUF/HBM + # Output: kernel_output layout depends on transposed_out and out_in_sb + if do_out_proj: + kernel_output = output_projection_tkg( + attention=attn_out.reshape((d_head, B, q_heads, S_tkg)), + weight=W_out, + bias=bias_out, + quantization_type=quantization_type_out, + weight_scale=weight_dequant_scale_out, + input_scale=input_dequant_scale_out, + TRANSPOSE_OUT=transposed_out, + OUT_IN_SB=out_in_sb, + ) + else: + kernel_assert( + not transposed_out, + "transposed_out requires output projection (W_out must be provided)", + ) + kernel_output = attn_out + + # Copy output to HBM if caller expects it on HBM but it's on SBUF. This is only used for debug when skipping both attention and output-projection. + if out_in_sb == False and kernel_output.buffer == nl.sbuf: + kernel_output_hbm = nl.ndarray( + kernel_output.shape, + kernel_output.dtype, + nl.shared_hbm, + name="kernel_output_hbm", + ) + nisa.dma_copy(kernel_output_hbm, kernel_output) + kernel_output = kernel_output_hbm + + # ========== Cleanup and Return ========== + sbm.close_scope() + if update_cache: + return kernel_output, K_cache, V_cache + else: + return kernel_output, K_tkg_hbm, V_tkg_hbm + + +############### Internal ############### + + +def _validate_and_extract_config( + X: nl.ndarray, + W_qkv: nl.ndarray, + K_cache: nl.ndarray, + V_cache: nl.ndarray, + attention_mask: nl.ndarray, + cos: Optional[nl.ndarray], + sin: Optional[nl.ndarray], + rmsnorm_X_gamma: Optional[nl.ndarray], + K_cache_transposed: bool, + active_blocks_table: Optional[nl.ndarray], + W_out: Optional[nl.ndarray], + k_scale: Optional[nl.ndarray], + v_scale: Optional[nl.ndarray], + kv_heads: int, + d_head_param: int, + s_max_ctx_param, # int or None +) -> Dict[str, Any]: + """ + Validate inputs and extract configuration parameters for attention block. 
+ + Args: + X (nl.ndarray): Input hidden states + W_qkv (nl.ndarray): QKV projection weights + K_cache (nl.ndarray): Key cache + V_cache (nl.ndarray): Value cache + attention_mask (nl.ndarray): Attention mask + cos (Optional[nl.ndarray]): RoPE cosine embeddings + sin (Optional[nl.ndarray]): RoPE sine embeddings + rmsnorm_X_gamma (Optional[nl.ndarray]): RMSNorm weights + K_cache_transposed (bool): K cache layout flag + active_blocks_table (Optional[nl.ndarray]): Block indices for block KV cache + W_out (Optional[nl.ndarray]): Output projection weights + + Returns: + Dict[str, Any]: Configuration dictionary with keys: B, S_tkg, H, d_head, half_d, + q_heads, S_ctx, S_max_ctx, is_block_kv, blk_len, cache_had_head_dim, + do_out_proj, K_cache, V_cache + + Notes: + - Validates tensor shapes and dimensions + - Extracts batch size, sequence lengths, and head dimensions + - Handles both block and flat KV cache layouts + """ + + kernel_assert( + nisa.get_nc_version() >= nisa.nc_version.gen3, + f"Kernel requires nc-version >= gen3, got {nisa.get_nc_version()}", + ) + + _, B, _, _ = attention_mask.shape + if X.buffer == nl.sbuf: + # X.shape = (pmax, B*S, H // pmax) @ SBUF + kernel_assert(len(X.shape) == 3, "SBUF input X must have 3 dimensions") + kernel_assert( + X.shape[0] == nl.tile_size.pmax, + f"SBUF input X dim0 must be {nl.tile_size.pmax}", + ) + kernel_assert( + X.shape[1] % B == 0, f"SBUF input X dim1 must be divisible by B={B}" + ) + H = X.shape[2] * nl.tile_size.pmax + S_tkg = X.shape[1] // B + else: + # X.shape = (B,S,H) @ HBM + kernel_assert(is_hbm_buffer(X), "Input X must be in HBM or SBUF") + B, S_tkg, H = X.shape + + # kv_heads and d_head_param are passed as positional int args (no NKI tracing issues). + # d_head_param is the head dimension. For multi-KV-head, it must be provided. + # For single-head, derive from V_cache.shape (3D cache, .shape works). + d_head = d_head_param + if d_head == 0 and kv_heads == 1: + # Fallback for single-head: V_cache is 3D (B, S, d) + d_head = V_cache.shape[2] + I = W_qkv.shape[1] + + # KVDP not supported with multi-KV-head kernel + # Hardcode is_KVDP=False, B_attn=B, q_heads_attn=q_heads + is_KVDP = False + kernel_assert( + B * S_tkg <= nl.tile_size.pmax, + f"B * S_tkg must be <= {nl.tile_size.pmax}, got {B * S_tkg}", + ) + kernel_assert(d_head % 2 == 0, f"d_head must be even, got {d_head}") + kernel_assert( + d_head > 0 and I % d_head == 0, + f"QKV weights must be packed as (q_heads + 2*kv_heads) * d_head, got I={I}, d_head={d_head}, kv_heads={kv_heads}", + ) + + q_heads = I // d_head - 2 * kv_heads + half_d = d_head // 2 + + B_attn = B + q_heads_attn = q_heads + + # Process KV cache + is_block_kv = active_blocks_table is not None + K_cache, V_cache, cache_had_head_dim = __internal_squeeze_head_dim( + K_cache, V_cache, is_block_kv, kv_heads + ) + + if is_block_kv: + blk_len = V_cache.shape[1] + S_ctx = S_max_ctx = active_blocks_table.shape[1] * blk_len + kernel_assert( + V_cache.shape == K_cache.shape, + f"Block KV cache shape mismatch: K={K_cache.shape} vs V={V_cache.shape}", + ) + else: + S_ctx = attention_mask.shape[0] + if kv_heads > 1: + # Multi-KV-head: NKI tracing cannot access .shape on 4D KV cache tensors. + # Use explicitly passed s_max_ctx parameter. 
+            kernel_assert(
+                s_max_ctx_param is not None,
+                "s_max_ctx must be provided when kv_heads > 1",
+            )
+            S_max_ctx = s_max_ctx_param
+        else:
+            S_max_ctx = V_cache.shape[1]
+        blk_len = 0
+        # Skip shape validation for multi-KV-head (NKI can't access .shape on 4D cache)
+        if kv_heads == 1:
+            kernel_assert(
+                V_cache.shape[0] == B_attn,
+                f"V_cache batch mismatch: expected {B_attn}, got {V_cache.shape[0]}",
+            )
+            expected_K_shape = (
+                (B_attn, d_head, S_max_ctx)
+                if K_cache_transposed
+                else (B_attn, S_max_ctx, d_head)
+            )
+            kernel_assert(
+                tuple(K_cache.shape) == expected_K_shape,
+                f"K_cache shape mismatch: expected {expected_K_shape}, got {K_cache.shape}",
+            )
+
+    # Validate attention mask
+    expected_mask_shape = (S_ctx, B_attn, q_heads_attn, S_tkg)
+    kernel_assert(
+        tuple(attention_mask.shape) == expected_mask_shape,
+        f"attention_mask shape mismatch: expected {expected_mask_shape}, got {attention_mask.shape}",
+    )
+
+    # Validate RMSNorm weights
+    if rmsnorm_X_gamma is not None:
+        kernel_assert(
+            tuple(rmsnorm_X_gamma.shape) == (1, H),
+            f"rmsnorm_X_gamma must be (1, {H}), got {rmsnorm_X_gamma.shape}",
+        )
+
+    # Validate RoPE embeddings
+    if cos is not None and sin is not None:
+        kernel_assert(
+            tuple(cos.shape) == (half_d, B, S_tkg),
+            f"cos shape mismatch: expected ({half_d}, {B}, {S_tkg}), got {cos.shape}",
+        )
+        kernel_assert(
+            tuple(sin.shape) == (half_d, B, S_tkg),
+            f"sin shape mismatch: expected ({half_d}, {B}, {S_tkg}), got {sin.shape}",
+        )
+
+    # KV Quantization
+    if k_scale is not None and v_scale is not None:
+        kernel_assert(
+            is_fp8_e4m3(K_cache.dtype),
+            f"KV quantization requires float8_e4m3 K_cache, got {K_cache.dtype}",
+        )
+        kernel_assert(
+            is_fp8_e4m3(V_cache.dtype),
+            f"KV quantization requires float8_e4m3 V_cache, got {V_cache.dtype}",
+        )
+        kv_quant = True
+    else:
+        kv_quant = False
+
+    return {
+        "B": B,
+        "S_tkg": S_tkg,
+        "H": H,
+        "d_head": d_head,
+        "half_d": half_d,
+        "q_heads": q_heads,
+        "S_ctx": S_ctx,
+        "S_max_ctx": S_max_ctx,
+        "is_block_kv": is_block_kv,
+        "blk_len": blk_len,
+        "cache_had_head_dim": cache_had_head_dim,
+        "do_out_proj": W_out is not None,
+        "K_cache": K_cache,
+        "V_cache": V_cache,
+        "kv_quant": kv_quant,
+        "B_attn": B_attn,
+        "q_heads_attn": q_heads_attn,
+        "is_KVDP": is_KVDP,
+        "kv_heads": kv_heads,
+    }
+
+
+def __internal_squeeze_head_dim(
+    K_cache: nl.ndarray, V_cache: nl.ndarray, is_block_kv: bool, kv_heads: int
+) -> Tuple[nl.ndarray, nl.ndarray, bool]:
+    """
+    Remove head dimension from 4D cache tensors when kv_heads == 1.
+    For multi-KV-head (kv_heads > 1), leave the 4D tensors as-is.
+
+    NKI tracing does not support len(tensor.shape), so we use the kv_heads
+    parameter to determine the cache layout.
+
+    Args:
+        K_cache (nl.ndarray): Key cache (3D or 4D)
+        V_cache (nl.ndarray): Value cache (3D or 4D)
+        is_block_kv (bool): Block KV cache flag
+        kv_heads (int): Number of KV heads per rank
+
+    Returns:
+        Tuple[nl.ndarray, nl.ndarray, bool]: (K_processed, V_processed, had_head_dim)
+    """
+    if kv_heads > 1:
+        # Multi-KV-head: cache is 4D (B, kv_heads, ...), keep as-is
+        # NKI tracing cannot access .shape on 4D KV cache tensors
+        return K_cache, V_cache, True
+    else:
+        # kv_heads == 1: cache is 3D from NxDI (B, S, d) or (num_blocks, block_len, d)
+        # NxDI with kv_heads==1 always creates 3D caches, so .shape is accessible.
+        # Just return as-is (no squeeze needed since it's already 3D).
+ return K_cache, V_cache, False + + +def __internal_unsqueeze_head_dim( + K_cache: nl.ndarray, + V_cache: nl.ndarray, + cache_had_head_dim: bool, + is_block_kv: bool, +) -> Tuple[nl.ndarray, nl.ndarray]: + """ + Add back head dimension if cache originally had one. + + Args: + K_cache (nl.ndarray): Key cache (3D) + V_cache (nl.ndarray): Value cache (3D) + cache_had_head_dim (bool): Whether cache originally had head dimension + is_block_kv (bool): Block KV cache flag + + Returns: + Tuple[nl.ndarray, nl.ndarray]: (K_cache, V_cache) with head dimension restored + + Notes: + - Inverse operation of __internal_squeeze_head_dim + - Returns original tensors if cache_had_head_dim is False + """ + if not cache_had_head_dim: + return K_cache, V_cache + + head_dim = 2 if is_block_kv else 1 + K_shape = list(K_cache.shape[:head_dim]) + [1] + list(K_cache.shape[head_dim:]) + V_shape = list(V_cache.shape[:head_dim]) + [1] + list(V_cache.shape[head_dim:]) + return K_cache.reshape(tuple(K_shape)), V_cache.reshape(tuple(V_shape)) + + +def _to_sbuf(buf: nl.ndarray, sbm: SbufManager) -> nl.ndarray: + """ + Ensure buffer is in SBUF; copy from HBM if needed. + + Args: + buf (nl.ndarray): Input buffer (HBM or SBUF) + sbm (SbufManager): SBUF memory manager + + Returns: + nl.ndarray: Buffer in SBUF + + Notes: + - Returns original buffer if already in SBUF + - Allocates and copies if buffer is in HBM + """ + if buf.buffer == nl.sbuf: + return buf + else: + sb = sbm.alloc_stack(buf.shape, dtype=buf.dtype, buffer=nl.sbuf) + nisa.dma_copy(sb, buf) + return sb + + +def _process_head_group( + QKV: nl.ndarray, + qkv_offset: int, + n_heads: int, + d: int, + B: int, + S: int, + rmsnorm_pre_enabled: bool, + rmsnorm_pre_eps: float, + rmsnorm_pre_W: Optional[nl.ndarray], + enable_rope: bool, + sb_cos: Optional[nl.ndarray], + sb_sin: Optional[nl.ndarray], + rope_contiguous_layout: bool, + rmsnorm_post_enabled: bool, + rmsnorm_post_eps: float, + rmsnorm_post_W: Optional[nl.ndarray], + sbm: SbufManager, +) -> nl.ndarray: + """ + Process Q or K: extract heads, transpose to [d, B*n_heads*S], apply optional RMSNorm pre/post and RoPE. 
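+
+    Packed QKV layout along the free dimension, sketched with the illustrative
+    TP=4 shard numbers q_heads=8, kv_heads=2, d=128 (the code only relies on
+    qkv_offset and n_heads):
+
+        [ Q_0 .. Q_7 | K_0 K_1 | V_0 V_1 ]
+        Q: cols 0..1023, K: cols 1024..1279, V: cols 1280..1535
+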
+    For Q: n_heads=q_heads, qkv_offset=0
+    For K: n_heads=kv_heads, qkv_offset=d*q_heads
+    """
+    # Transpose heads from [B*S, n_heads*d] to [d, B*n_heads*S]
+    out = sbm.alloc_stack(shape=(d, B * n_heads * S), dtype=QKV.dtype, buffer=nl.sbuf)
+    for head_idx in range(n_heads):
+        psum = nl.ndarray((d, B * S), dtype=QKV.dtype, buffer=nl.psum)
+        nisa.nc_transpose(psum, QKV[:, nl.ds(qkv_offset + head_idx * d, d)])
+        nisa.tensor_copy(
+            out.reshape((d, B, n_heads, S))[:, :, head_idx, :], psum.reshape((d, B, S))
+        )
+
+    # Pre-RoPE RMSNorm
+    if rmsnorm_pre_enabled:
+        _rms_norm_inplace(out, rmsnorm_pre_eps, w=rmsnorm_pre_W, sbm=sbm)
+
+    # RoPE
+    if enable_rope:
+        out_4d = out.reshape((d, B, n_heads, S))
+        out_rope = sbm.alloc_stack(out_4d.shape, dtype=out.dtype, buffer=nl.sbuf)
+        RoPE_sbuf(
+            out_4d,
+            sb_cos,
+            sb_sin,
+            out_rope,
+            convert_from_interleaved=not rope_contiguous_layout,
+        )
+        out = out_rope.reshape((d, B * n_heads * S))
+
+    # Post-RoPE RMSNorm
+    if rmsnorm_post_enabled:
+        _rms_norm_inplace(out, rmsnorm_post_eps, rmsnorm_post_W, sbm)
+
+    return out
+
+
+def _QK_processing(
+    QKV: nl.ndarray,
+    q_heads: int,
+    kv_heads: int,
+    B: int,
+    rmsnorm_pre_enabled: bool,
+    rmsnorm_pre_eps: float,
+    rmsnorm_pre_W_Q: Optional[nl.ndarray],
+    rmsnorm_pre_W_K: Optional[nl.ndarray],
+    cos: Optional[nl.ndarray],
+    sin: Optional[nl.ndarray],
+    rope_contiguous_layout: bool,
+    rmsnorm_post_enabled: bool,
+    rmsnorm_post_eps: float,
+    rmsnorm_post_W_Q: Optional[nl.ndarray],
+    rmsnorm_post_W_K: Optional[nl.ndarray],
+    sbm: SbufManager,
+) -> Tuple[nl.ndarray, nl.ndarray]:
+    """
+    Unified Q/K processing: transpose, optional pre-RoPE RMSNorm, optional RoPE, optional post-RoPE RMSNorm.
+
+    Args:
+        QKV: [B*S, I] @ SBUF - concatenated Q/K/V projections where I = d*(q_heads + 2*kv_heads)
+        q_heads: number of query heads
+        kv_heads: number of KV heads per rank
+        B: batch size
+        rmsnorm_pre_enabled: Apply RMSNorm before RoPE
+        rmsnorm_pre_eps: Pre-RoPE RMSNorm epsilon
+        rmsnorm_pre_W_Q: Pre-RoPE Q gamma weights (optional)
+        rmsnorm_pre_W_K: Pre-RoPE K gamma weights (optional)
+        cos: RoPE cosine embeddings (None = skip RoPE)
+        sin: RoPE sine embeddings (None = skip RoPE)
+        rope_contiguous_layout: True for contiguous halves, False for interleaved
+        rmsnorm_post_enabled: Apply RMSNorm after RoPE
+        rmsnorm_post_eps: Post-RoPE RMSNorm epsilon
+        rmsnorm_post_W_Q: Post-RoPE Q weights (optional)
+        rmsnorm_post_W_K: Post-RoPE K weights (optional)
+        sbm: SBUF memory manager
+
+    Returns:
+        Q: [d, B*q_heads*S] @ SBUF
+        K: [d, B*kv_heads*S] @ SBUF
+    """
+    kernel_assert(QKV.buffer == nl.sbuf, "QKV must be in SBUF")
+    kernel_assert(len(QKV.shape) == 2, "Expecting QKV.shape=(BxS, I)")
+
+    BxS, I = QKV.shape
+    S = BxS // B
+    d = I // (q_heads + 2 * kv_heads)
+
+    enable_rope = cos is not None and sin is not None
+
+    # Load RoPE embeddings to SBUF if needed
+    sb_cos, sb_sin = None, None
+    if enable_rope:
+        kernel_assert(cos.shape == sin.shape, "cos and sin must match")
+        sb_cos = _to_sbuf(cos, sbm)
+        sb_sin = _to_sbuf(sin, sbm)
+
+    Q = _process_head_group(
+        QKV,
+        qkv_offset=0,
+        n_heads=q_heads,
+        d=d,
+        B=B,
+        S=S,
+        rmsnorm_pre_enabled=rmsnorm_pre_enabled,
+        rmsnorm_pre_eps=rmsnorm_pre_eps,
+        rmsnorm_pre_W=rmsnorm_pre_W_Q,
+        enable_rope=enable_rope,
+        sb_cos=sb_cos,
+        sb_sin=sb_sin,
+        rope_contiguous_layout=rope_contiguous_layout,
+        rmsnorm_post_enabled=rmsnorm_post_enabled,
+        rmsnorm_post_eps=rmsnorm_post_eps,
+        rmsnorm_post_W=rmsnorm_post_W_Q,
+        sbm=sbm,
+    )
+    K = _process_head_group(
+        QKV,
+        qkv_offset=d * q_heads,
+        n_heads=kv_heads,
+        d=d,
+        B=B,
+        S=S,
+        rmsnorm_pre_enabled=rmsnorm_pre_enabled,
+        rmsnorm_pre_eps=rmsnorm_pre_eps,
+        rmsnorm_pre_W=rmsnorm_pre_W_K,
+        enable_rope=enable_rope,
+        sb_cos=sb_cos,
+        sb_sin=sb_sin,
+        rope_contiguous_layout=rope_contiguous_layout,
+        rmsnorm_post_enabled=rmsnorm_post_enabled,
+        rmsnorm_post_eps=rmsnorm_post_eps,
+        rmsnorm_post_W=rmsnorm_post_W_K,
+        sbm=sbm,
+    )
+
+    return Q, K
+
+
+def _rms_norm_inplace(
+    x: nl.ndarray,
+    eps: float,
+    w: Optional[nl.ndarray] = None,
+    sbm: Optional[SbufManager] = None,
+) -> None:
+    """
+    RMS normalization in-place: x / sqrt(mean(x^2) + eps), optionally scaled by w.
+    Computed in fp32, result written back to x in original dtype.
+
+    Args:
+        x: [d_head, BnS] @ SBUF - input tensor (d_head must be nl.tile_size.pmax), modified in-place
+        eps: epsilon for numerical stability
+        w: [d_head, 1] @ HBM - optional scale weights
+        sbm: SBUF memory manager
+    """
+    d_head, BnS = x.shape
+    kernel_assert(
+        d_head == nl.tile_size.pmax, f"d_head must be {nl.tile_size.pmax}, got {d_head}"
+    )
+
+    # Setup constants
+    ones_sb = sbm.alloc_stack((d_head, d_head), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.memset(ones_sb, 1.0)
+    eps_sb = sbm.alloc_stack((d_head, 1), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.memset(eps_sb, eps)
+
+    # Compute x^2 in fp32
+    x_squared = sbm.alloc_stack((d_head, BnS), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_tensor(x_squared, x, x, nl.multiply)
+
+    # Compute sum(x^2) via matmul with all-ones matrix
+    psum_sb = nl.ndarray((d_head, BnS), dtype=nl.float32, buffer=nl.psum)
+    nisa.nc_matmul(psum_sb, stationary=ones_sb, moving=x_squared)
+
+    # Compute rsqrt(mean(x^2) + eps)
+    rsqrt_sb = sbm.alloc_stack((d_head, BnS), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.activation(
+        dst=rsqrt_sb, op=nl.rsqrt, data=psum_sb, bias=eps_sb, scale=1.0 / d_head
+    )
+
+    # Normalize: x * rsqrt
+    out_sb = sbm.alloc_stack((d_head, BnS), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_tensor(out_sb, x, rsqrt_sb, nl.multiply)
+
+    # Optional scaling by weights
+    if w is not None:
+        w_sb = sbm.alloc_stack((d_head, 1), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.dma_copy(w_sb, w.reshape((d_head, 1)))
+        nisa.tensor_scalar(dst=out_sb, data=out_sb, op0=nl.multiply, operand0=w_sb)
+
+    # Copy result back to x with original dtype
+    nisa.tensor_copy(dst=x, src=out_sb)
+
+
+############################# KV cache update logic #############################
+
+
+def _kv_cache_update(
+    K_cache: nl.ndarray,
+    V_cache: nl.ndarray,
+    K_tkg: nl.ndarray,
+    V_tkg: nl.ndarray,
+    kv_cache_update_idx: nl.ndarray,
+    B: int,
+    d_head: int,
+    S_tkg: int,
+    S_max_ctx: int,
+    K_cache_transposed: bool,
+    is_block_kv: bool,
+) -> None:
+    """
+    Update KV cache with new tokens for token generation.
+
+    Args:
+        K_cache: K cache @ HBM
+            - Block KV: [num_blocks, block_len, d_head]
+            - Flat transposed: [B, d_head, S_max_ctx]
+            - Flat: [B, S_max_ctx, d_head]
+        V_cache: V cache @ HBM
+            - Block KV: [num_blocks, block_len, d_head]
+            - Flat: [B, S_max_ctx, d_head]
+        K_tkg: [d_head, B*S_tkg] @ SBUF
+        V_tkg: [B*S_tkg, d_head] @ SBUF
+        kv_cache_update_idx: [B, 1] per-batch slot indices for cache writes
+        B: batch size
+        d_head: head dimension
+        S_tkg: number of new tokens
+        S_max_ctx: max cache sequence length
+        K_cache_transposed: K cache layout flag
+        is_block_kv: block KV cache flag
+
+    Returns:
+        None. K_cache and V_cache are updated in place via DMA.
+    """
+
+    # TODO: oob_mode.skip not supported for flat cache. Using oob_mode.skip causes accuracy failures (root cause unknown).
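+
+    # Dispatch sketch (mirrors the branches below):
+    #   block KV cache                            -> _update_block_cache
+    #   flat cache, S_tkg == 1 vector-DMA path    -> _update_flat_cache_batched
+    #   flat cache, otherwise                     -> _update_flat_cache (B scalar DMAs)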
+ + if is_block_kv: + _update_block_cache( + K_cache, V_cache, K_tkg, V_tkg, kv_cache_update_idx, S_tkg, B + ) + elif S_tkg == 1 and B > 1 and (not K_cache_transposed or B > 16): + # one vector DMA of (B, S_tkg, d_head). Bug for S_tkg > 1. + _update_flat_cache_batched( + K_cache, + V_cache, + K_tkg, + V_tkg, + kv_cache_update_idx, + S_tkg, + S_max_ctx, + B, + d_head, + K_cache_transposed=K_cache_transposed, + ) + else: + # B scalar DMA of (S_tkg, d_head) + _update_flat_cache( + K_cache, + V_cache, + K_tkg, + V_tkg, + K_cache_transposed, + kv_cache_update_idx, + S_tkg, + S_max_ctx, + B, + d_head, + ) + + +def _update_flat_cache_batched( + K_cache: nl.ndarray, + V_cache: nl.ndarray, + K_tkg: nl.ndarray, + V_tkg: nl.ndarray, + kv_cache_update_idx: nl.ndarray, + S_tkg: int, + S_max_ctx: int, + B: int, + d_head: int, + K_cache_transposed: bool = False, +) -> None: + """ + Update flat (non-block) KV cache with new tokens using batched DMA operations. + + This optimized version writes all B batches in a single DMA operation using vector_offset + for indirect addressing. Currently limited to S_tkg=1 due to access pattern bug. + + Args: + K_cache: [B, S_max_ctx, d_head] K cache in HBM + V_cache: [B, S_max_ctx, d_head] V cache in HBM + K_tkg: [d_head, B*S_tkg] new K tokens in SBUF + V_tkg: [B*S_tkg, d_head] new V tokens in SBUF + kv_cache_update_idx: [B, 1] per-batch write positions + S_tkg: number of new tokens per batch + S_max_ctx: maximum cache sequence length + B: batch size + d_head: head dimension + """ + # Validate sharding configuration + _, n_prgs, prg_id = get_verified_program_sharding_info("kv_cache update", (0, 1), 2) + kernel_assert(n_prgs <= 2, f"Expected lnc in [1,2], got {n_prgs}") + kernel_assert( + S_tkg == 1, f"_update_flat_cache_batched() only supports S_tkg=1, got {S_tkg}" + ) + kernel_assert( + B * S_tkg <= nl.tile_size.pmax, + f"B * S_tkg must be <= {nl.tile_size.pmax}, got {B * S_tkg}", + ) + + # Validate tensor shapes + if K_cache_transposed: + kernel_assert( + K_cache.shape == (B, d_head, S_max_ctx), + f"K_cache shape mismatch: expected {(B, d_head, S_max_ctx)}, got {K_cache.shape}", + ) + else: + kernel_assert( + K_cache.shape == (B, S_max_ctx, d_head), + f"K_cache shape mismatch: expected {(B, S_max_ctx, d_head)}, got {K_cache.shape}", + ) + kernel_assert( + V_cache.shape == (B, S_max_ctx, d_head), + f"V_cache shape mismatch: expected {(B, S_max_ctx, d_head)}, got {V_cache.shape}", + ) + kernel_assert( + K_tkg.shape == (d_head, B * S_tkg), + f"K_tkg shape mismatch: expected {(d_head, B * S_tkg)}, got {K_tkg.shape}", + ) + kernel_assert( + V_tkg.shape == (B * S_tkg, d_head), + f"V_tkg shape mismatch: expected {(B * S_tkg, d_head)}, got {V_tkg.shape}", + ) + kernel_assert( + kv_cache_update_idx.shape == (B, 1), + f"kv_cache_update_idx shape mismatch: expected {(B, 1)}, got {kv_cache_update_idx.shape}", + ) + + # Compute absolute token indices for V_cache (and non-transposed K_cache): + # token_indices[b] = kv_cache_update_idx[b] + b * S_max_ctx + token_indices = nl.ndarray((B, 1), dtype=nl.uint32, buffer=nl.sbuf) + nisa.dma_copy(token_indices, kv_cache_update_idx) + + batch_offset = nl.ndarray((B, 1), dtype=nl.uint32, buffer=nl.sbuf) + nisa.iota(batch_offset, [[0, 1]], offset=0, channel_multiplier=S_max_ctx) + nisa.tensor_tensor(token_indices, token_indices, batch_offset, nl.add) + + # Vector DMA with indirect addressing: + # - Reshape cache to (B*S_max_ctx, d_head) so each row has stride d_head + # - vector_offset provides per-batch row indices: token_indices[b] + # - DMA 
engine scales token_indices[b] by d_head (stride of indirect_dim=0) + + # Update V_cache on lnc=0 + if n_prgs == 1 or prg_id == 0: + nisa.dma_copy( + dst=V_cache.reshape((B * S_max_ctx, d_head)).ap( + pattern=[[1, B], [d_head, S_tkg], [1, d_head]], + offset=0, + vector_offset=token_indices, + indirect_dim=0, + ), + src=V_tkg.ap(pattern=[[S_tkg * d_head, B], [d_head, S_tkg], [1, d_head]]), + ) + + # Update K_cache on lnc=1 + if n_prgs == 1 or prg_id == 1: + K_reshaped_sb = nl.ndarray((B * S_tkg, d_head), K_tkg.dtype, nl.sbuf) + _transpose_sbuf(K_tkg, K_reshaped_sb) + + if K_cache_transposed: + # K_cache [B, d_head, S_max_ctx] — strided scatter via vector DGE + # k_token_indices[b] = b * d_head * S_max_ctx + idx[b] + k_token_indices = nl.ndarray((B, 1), dtype=nl.uint32, buffer=nl.sbuf) + k_batch_offset = nl.ndarray((B, 1), dtype=nl.uint32, buffer=nl.sbuf) + nisa.iota( + k_batch_offset, + [[0, 1]], + offset=0, + channel_multiplier=d_head * S_max_ctx, + ) + nisa.dma_copy(k_token_indices, kv_cache_update_idx) + nisa.tensor_tensor(k_token_indices, k_token_indices, k_batch_offset, nl.add) + + nisa.dma_copy( + dst=K_cache.reshape((B * d_head * S_max_ctx,)).ap( + pattern=[[1, B], [S_max_ctx, d_head], [1, S_tkg]], + offset=0, + vector_offset=k_token_indices, + indirect_dim=0, + ), + src=K_reshaped_sb.ap(pattern=[[d_head, B], [1, d_head], [1, S_tkg]]), + ) + else: + # K_cache [B, S_max_ctx, d_head] — contiguous, same pattern as V + nisa.dma_copy( + dst=K_cache.reshape((B * S_max_ctx, d_head)).ap( + pattern=[[1, B], [d_head, S_tkg], [1, d_head]], + offset=0, + vector_offset=token_indices, + indirect_dim=0, + ), + src=K_reshaped_sb.ap( + pattern=[[S_tkg * d_head, B], [d_head, S_tkg], [1, d_head]] + ), + ) + + +def _update_flat_cache( + K_cache: nl.ndarray, + V_cache: nl.ndarray, + K_tkg: nl.ndarray, + V_tkg: nl.ndarray, + K_cache_transposed: bool, + kv_cache_update_idx: nl.ndarray, + S_tkg: int, + S_max_ctx: int, + B: int, + d_head: int, +) -> None: + """ + Update flat (non-block) KV cache with new tokens using per-batch scalar_offset. + + This version iterates over batches and uses scalar_offset for indirect addressing. + Supports any B*S_tkg <= PMAX and handles both transposed and non-transposed K cache layouts. 
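+
+    Illustrative write for batch b (non-transposed K): the (S_tkg, d_head)
+    tile lands at K_cache[b, idx_b : idx_b + S_tkg, :], where idx_b is read
+    from kv_cache_update_idx[b] at DMA time via scalar_offset.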
+ + Args: + K_cache: [B, d_head, S_max_ctx] if transposed else [B, S_max_ctx, d_head] @ HBM + V_cache: [B, S_max_ctx, d_head] @ HBM + K_tkg: [d_head, B*S_tkg] @ SBUF + V_tkg: [B*S_tkg, d_head] @ SBUF + K_cache_transposed: K cache layout flag + kv_cache_update_idx: [B, 1] per-batch write positions + S_tkg: number of new tokens per batch + S_max_ctx: maximum cache sequence length + B: batch size + d_head: head dimension + """ + _, n_prgs, prg_id = get_verified_program_sharding_info("kv_cache update", (0, 1), 2) + kernel_assert(n_prgs <= 2, f"Expected lnc in [1,2], got {n_prgs}") + + # Validate tensor shapes + kernel_assert( + kv_cache_update_idx.shape == (B, 1), + f"kv_cache_update_idx shape mismatch: expected {(B, 1)}, got {kv_cache_update_idx.shape}", + ) + kernel_assert( + V_cache.shape == (B, S_max_ctx, d_head), + f"V_cache shape mismatch: expected {(B, S_max_ctx, d_head)}, got {V_cache.shape}", + ) + kernel_assert( + V_tkg.shape == (B * S_tkg, d_head), + f"V_tkg shape mismatch: expected {(B * S_tkg, d_head)}, got {V_tkg.shape}", + ) + kernel_assert( + K_tkg.shape == (d_head, B * S_tkg), + f"K_tkg shape mismatch: expected {(d_head, B * S_tkg)}, got {K_tkg.shape}", + ) + + # Update V_cache on lnc=0 + if n_prgs == 1 or prg_id == 0: + start_position = nl.ndarray((1, 1), dtype=nl.uint32, buffer=nl.sbuf) + for batch_idx in range(B): + nisa.dma_copy(start_position, kv_cache_update_idx[batch_idx]) + nisa.dma_copy( + dst=V_cache.ap( + pattern=[[d_head, S_tkg], [1, d_head]], + offset=batch_idx * S_max_ctx * d_head, + scalar_offset=start_position, + indirect_dim=1, + ), + src=V_tkg[nl.ds(batch_idx * S_tkg, S_tkg), :], + ) + + # Update K_cache on lnc=1 + if n_prgs == 1 or prg_id == 1: + if K_cache_transposed: + kernel_assert( + K_cache.shape == (B, d_head, S_max_ctx), + f"K_cache shape mismatch: expected {(B, d_head, S_max_ctx)}, got {K_cache.shape}", + ) + # K_tkg is already in correct layout [d_head, B*S_tkg] + start_position = nl.ndarray((1, 1), dtype=nl.uint32, buffer=nl.sbuf) + for batch_idx in range(B): + nisa.dma_copy(start_position, kv_cache_update_idx[batch_idx]) + nisa.dma_copy( + dst=K_cache.ap( + pattern=[[S_max_ctx, d_head], [1, S_tkg]], + offset=batch_idx * d_head * S_max_ctx, + scalar_offset=start_position, + indirect_dim=2, + ), + src=K_tkg[:, nl.ds(batch_idx * S_tkg, S_tkg)], + ) + else: + kernel_assert( + K_cache.shape == (B, S_max_ctx, d_head), + f"K_cache shape mismatch: expected {(B, S_max_ctx, d_head)}, got {K_cache.shape}", + ) + kernel_assert( + B * S_tkg <= nl.tile_size.pmax and d_head <= nl.tile_size.pmax, + f"Transpose constraints: B*S_tkg={B * S_tkg}, d_head={d_head} (both must be <= {nl.tile_size.pmax})", + ) + + # Transpose K_tkg from [d_head, B*S_tkg] to [B*S_tkg, d_head] + K_reshaped_sb = nl.ndarray((B * S_tkg, d_head), K_tkg.dtype, nl.sbuf) + _transpose_sbuf(K_tkg, K_reshaped_sb) + + # Write transposed K to cache + start_position = nl.ndarray((1, 1), dtype=nl.uint32, buffer=nl.sbuf) + for batch_idx in range(B): + nisa.dma_copy(start_position, kv_cache_update_idx[batch_idx]) + nisa.dma_copy( + dst=K_cache.ap( + pattern=[[d_head, S_tkg], [1, d_head]], + offset=batch_idx * S_max_ctx * d_head, + scalar_offset=start_position, + indirect_dim=1, + ), + src=K_reshaped_sb[nl.ds(batch_idx * S_tkg, S_tkg), :], + ) + + +def _update_block_cache( + K_cache: nl.ndarray, + V_cache: nl.ndarray, + K_tkg: nl.ndarray, + V_tkg: nl.ndarray, + kv_cache_update_idx: nl.ndarray, + S_tkg: int, + B: int, +) -> None: + """ + Update block KV cache with new tokens. 
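+
+    Block caches are addressed by a flat slot index (conceptually
+    block_id * block_len + in-block offset), so kv_cache_update_idx carries
+    global slots rather than per-batch positions; uint32 max marks an entry
+    to skip.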
+ + Routes to batched or scalar_offset implementation based on S_tkg. + + Args: + K_cache: [num_blocks, block_len, d_head] + V_cache: [num_blocks, block_len, d_head] + K_tkg: [d_head, B*S_tkg] + V_tkg: [B*S_tkg, d_head] + kv_cache_update_idx: [B, 1] slot indices for cache update (uint32 max = skip) + S_tkg: number of new tokens + B: batch size + """ + # TODO: Use batched case for all S_tkg values once vector DMA access pattern supports S_tkg > 1 + if S_tkg == 1: + # one vector DMA, S_tkg = 1 + _update_block_cache_batched( + K_cache, V_cache, K_tkg, V_tkg, kv_cache_update_idx, S_tkg, B + ) + else: + # B scalar DMAs, with any S_tkg + _update_block_cache_scalar( + K_cache, V_cache, K_tkg, V_tkg, kv_cache_update_idx, S_tkg, B + ) + + +def _update_block_cache_batched( + K_cache: nl.ndarray, + V_cache: nl.ndarray, + K_tkg: nl.ndarray, + V_tkg: nl.ndarray, + kv_cache_update_idx: nl.ndarray, + S_tkg: int, + B: int, +) -> None: + """ + Update block KV cache with new tokens using batched DMA operations. + + This optimized version writes all B batches in a single DMA operation using vector_offset. + Currently limited to S_tkg=1 due to access pattern bug. + + Args: + K_cache: [num_blocks, block_len, d_head] + V_cache: [num_blocks, block_len, d_head] + K_tkg: [d_head, B*S_tkg] + V_tkg: [B*S_tkg, d_head] + kv_cache_update_idx: [B, 1] slot indices for cache update (uint32 max = skip) + S_tkg: number of new tokens (must be 1) + B: batch size + """ + _, n_prgs, prg_id = get_verified_program_sharding_info("kv_cache update", (0, 1), 2) + kernel_assert(n_prgs <= 2, f"Expected lnc in [1,2], got {n_prgs}") + kernel_assert( + S_tkg == 1, f"_update_block_cache_batched() only supports S_tkg=1, got {S_tkg}" + ) + kernel_assert( + B * S_tkg <= nl.tile_size.pmax, + f"B * S_tkg must be <= {nl.tile_size.pmax}, got {B * S_tkg}", + ) + + num_blocks, blk_len, d_head = K_cache.shape + + kernel_assert( + kv_cache_update_idx.shape == (B, 1), + f"kv_cache_update_idx shape mismatch: expected {(B, 1)}, got {kv_cache_update_idx.shape}", + ) + kernel_assert( + K_cache.shape == V_cache.shape, + f"K/V cache shape mismatch: K={K_cache.shape} vs V={V_cache.shape}", + ) + kernel_assert( + V_tkg.shape == (B * S_tkg, d_head), + f"V_tkg shape mismatch: expected {(B * S_tkg, d_head)}, got {V_tkg.shape}", + ) + kernel_assert( + K_tkg.shape == (d_head, B * S_tkg), + f"K_tkg shape mismatch: expected {(d_head, B * S_tkg)}, got {K_tkg.shape}", + ) + + # Transpose K_tkg on lnc=1 + if n_prgs == 1 or prg_id == 1: + K_transposed_sb = nl.ndarray((B * S_tkg, d_head), K_tkg.dtype, nl.sbuf) + _transpose_sbuf(K_tkg, K_transposed_sb) + + # Copy cache update indices to SBUF + kv_cache_update_idx_sb = nl.ndarray( + kv_cache_update_idx.shape, kv_cache_update_idx.dtype, nl.sbuf + ) + nisa.dma_copy(kv_cache_update_idx_sb, kv_cache_update_idx) + + # Vector DMA with indirect addressing: + # - Reshape cache to (num_blocks*blk_len, d_head) so each row has stride d_head + # - vector_offset provides per-batch global token indices: kv_cache_update_idx_sb[b] + # - DMA engine scales by d_head + + # Update V_cache on lnc=0 + if n_prgs == 1 or prg_id == 0: + nisa.dma_copy( + dst=V_cache.reshape((num_blocks * blk_len, d_head)).ap( + pattern=[[1, B], [d_head, S_tkg], [1, d_head]], + offset=0, + vector_offset=kv_cache_update_idx_sb, + indirect_dim=0, + ), + src=V_tkg.ap(pattern=[[S_tkg * d_head, B], [d_head, S_tkg], [1, d_head]]), + oob_mode=oob_mode.skip, # skip writes for invalid batch (position_id = uint32_max) + ) + + # Update K_cache on lnc=1 + if n_prgs == 1 or 
prg_id == 1: + nisa.dma_copy( + dst=K_cache.reshape((num_blocks * blk_len, d_head)).ap( + pattern=[[1, B], [d_head, S_tkg], [1, d_head]], + offset=0, + vector_offset=kv_cache_update_idx_sb, + indirect_dim=0, + ), + src=K_transposed_sb.ap( + pattern=[[S_tkg * d_head, B], [d_head, S_tkg], [1, d_head]] + ), + oob_mode=oob_mode.skip, # skip writes for invalid batch (position_id = uint32_max) + ) + + +def _update_block_cache_scalar( + K_cache: nl.ndarray, + V_cache: nl.ndarray, + K_tkg: nl.ndarray, + V_tkg: nl.ndarray, + kv_cache_update_idx: nl.ndarray, + S_tkg: int, + B: int, +) -> None: + """ + Update block KV cache with new tokens using per-batch scalar_offset. + + This version iterates over batches and uses scalar_offset for indirect addressing. + Supports S_tkg > 1. + + Args: + K_cache: [num_blocks, block_len, d_head] + V_cache: [num_blocks, block_len, d_head] + K_tkg: [d_head, B*S_tkg] + V_tkg: [B*S_tkg, d_head] + kv_cache_update_idx: [B, 1] slot indices for cache update (uint32 max = skip) + S_tkg: number of new tokens + B: batch size + """ + _, n_prgs, prg_id = get_verified_program_sharding_info("kv_cache update", (0, 1), 2) + kernel_assert(n_prgs <= 2, f"Expected lnc in [1,2], got {n_prgs}") + kernel_assert( + B * S_tkg <= nl.tile_size.pmax, + f"B * S_tkg must be <= {nl.tile_size.pmax}, got {B * S_tkg}", + ) + + num_blocks, blk_len, d_head = K_cache.shape + + kernel_assert( + kv_cache_update_idx.shape == (B, 1), + f"kv_cache_update_idx shape mismatch: expected {(B, 1)}, got {kv_cache_update_idx.shape}", + ) + kernel_assert( + K_cache.shape == V_cache.shape, + f"K/V cache shape mismatch: K={K_cache.shape} vs V={V_cache.shape}", + ) + kernel_assert( + V_tkg.shape == (B * S_tkg, d_head), + f"V_tkg shape mismatch: expected {(B * S_tkg, d_head)}, got {V_tkg.shape}", + ) + kernel_assert( + K_tkg.shape == (d_head, B * S_tkg), + f"K_tkg shape mismatch: expected {(d_head, B * S_tkg)}, got {K_tkg.shape}", + ) + + # Transpose K_tkg once on lnc=1 + if n_prgs == 1 or prg_id == 1: + K_transposed_sb = nl.ndarray((B * S_tkg, d_head), K_tkg.dtype, nl.sbuf) + _transpose_sbuf(K_tkg, K_transposed_sb) + + # Update cache per batch element using scalar_offset + for batch_idx in range(B): + start_position = nl.ndarray((1, 1), dtype=nl.uint32, buffer=nl.sbuf) + nisa.dma_copy(start_position, kv_cache_update_idx[batch_idx]) + + # Update V_cache on lnc=0 + if n_prgs == 1 or prg_id == 0: + nisa.dma_copy( + dst=V_cache.ap( + pattern=[[d_head, S_tkg], [1, d_head]], + offset=0, + scalar_offset=start_position, + indirect_dim=1, + ), + src=V_tkg[nl.ds(batch_idx * S_tkg, S_tkg), :], + oob_mode=oob_mode.skip, # skip writes for invalid batch (position_id = uint32_max) + ) + + # Update K_cache on lnc=1 + if n_prgs == 1 or prg_id == 1: + nisa.dma_copy( + dst=K_cache.ap( + pattern=[[d_head, S_tkg], [1, d_head]], + offset=0, + scalar_offset=start_position, + indirect_dim=1, + ), + src=K_transposed_sb[nl.ds(batch_idx * S_tkg, S_tkg), :], + oob_mode=oob_mode.skip, # skip writes for invalid batch (position_id = uint32_max) + ) + + +############################# FP8 Quantization Helpers ############################# + +_FP8_E4M3_MAX = get_max_positive_value_for_dtype(nl.float8_e4m3) +_FP8_E4M3_MIN = -_FP8_E4M3_MAX + + +def _quantize_to_fp8(tensor, scale, sbm): + """ + Quantize a tensor to FP8 E4M3 format using a single scalar scale. + + Computes: output = cast_to_fp8(clip(tensor * scale, [-240, 240])) + + The scale must represent a single scalar value. 
Two shapes are supported for + compatibility with different APIs: + - (1, 1): scalar, broadcast to partition dim + - (PMAX, 1): assumed to contain identical values, copied directly + + Args: + tensor: Input tensor in SBUF, shape (P, F), dtype bf16 or f32 + scale: Scale tensor in HBM, shape (PMAX, 1) or (1, 1). + Must contain a single scalar value (broadcast or replicated). + Supported dtypes: float32, float16, bfloat16. + sbm: SbufManager for allocations + + Returns: + FP8 E4M3 quantized tensor in SBUF, same shape as input + """ + kernel_assert(tensor.buffer == nl.sbuf, "quantize_to_fp8 requires tensor in SBUF") + kernel_assert( + not is_fp8_e4m3(tensor.dtype), + f"quantize_to_fp8 input already FP8: {tensor.dtype}", + ) + + partition_dim = tensor.shape[0] + + # Copy scale to SBUF + # ndarray avoids anti-dependency with other stack values + scale_sb = nl.ndarray(shape=(partition_dim, 1), dtype=nl.float32, buffer=nl.sbuf) + if scale.shape == (nl.tile_size.pmax, 1): + nisa.dma_copy(dst=scale_sb, src=scale[0:partition_dim, :]) + else: + kernel_assert( + scale.shape == (1, 1), + f"scale must be (pmax, 1) or (1, 1), got {scale.shape}", + ) + nisa.dma_copy( + dst=scale_sb, + src=TensorView(scale).broadcast(dim=0, size=partition_dim).get_view(), + ) + + # Scale: multiply by scale + tensor_scaled = sbm.alloc_stack(tensor.shape, dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar(tensor_scaled, tensor, nl.multiply, scale_sb) + + # Clip to FP8 range and cast + tensor_fp8 = sbm.alloc_stack(tensor.shape, dtype=nl.float8_e4m3, buffer=nl.sbuf) + nisa.tensor_scalar( + tensor_fp8, + tensor_scaled, + nl.minimum, + _FP8_E4M3_MAX, + op1=nl.maximum, + operand1=_FP8_E4M3_MIN, + ) + + return tensor_fp8 + + +def _transpose_sbuf(src, dst): + """ + Transpose tensor from SBUF to SBUF via PSUM. + + For FP8: nc_transpose doesn't support FP8, so we cast to bf16, transpose, cast back. + + Args: + src: Source tensor in SBUF (P, F) + dst: Destination tensor in SBUF (F, P) - must be pre-allocated + """ + if is_fp8_e4m3(src.dtype): + # FP8 workaround: cast to bf16, transpose, cast back + src_bf16 = nl.ndarray(src.shape, dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(src_bf16, src) + psum = nl.ndarray(dst.shape, dtype=nl.bfloat16, buffer=nl.psum) + nisa.nc_transpose(dst=psum, data=src_bf16) + nisa.tensor_copy(dst=dst, src=psum) + else: + kernel_assert( + src.dtype in (nl.bfloat16, nl.float16), + f"_transpose_sbuf only supports bf16, fp16, or fp8, got {src.dtype}", + ) + psum = nl.ndarray(dst.shape, dtype=src.dtype, buffer=nl.psum) + nisa.nc_transpose(dst=psum, data=src) + nisa.tensor_copy(dst=dst, src=psum) diff --git a/contrib/models/Ministral-3-14B-Instruct-2512/src/extract_text_model.py b/contrib/models/Ministral-3-14B-Instruct-2512/src/extract_text_model.py new file mode 100644 index 00000000..e61d3060 --- /dev/null +++ b/contrib/models/Ministral-3-14B-Instruct-2512/src/extract_text_model.py @@ -0,0 +1,400 @@ +#!/usr/bin/env python3 +""" +Ministral 14B: Extract text-only BF16 weights from FP8 multimodal checkpoint. + +The HuggingFace checkpoint `mistralai/Ministral-3-14B-Instruct-2512` is a +`Mistral3ForConditionalGeneration` model (multimodal: Pixtral vision encoder + +text decoder) with FP8 E4M3 quantized linear weights. + +This script: + 1. Reads safetensors weights directly (no model class needed) + 2. Strips `language_model.` prefix from text keys + 3. Dequantizes FP8 E4M3 weights to BF16 using per-tensor weight_scale_inv + 4. 
Drops vision_tower, multi_modal_projector, activation_scale, weight_scale_inv keys + 5. Creates a LlamaForCausalLM-compatible config.json (avoids Pixtral auto-promotion) + 6. Cleans tokenizer_config.json (removes processor_class references) + 7. Writes sharded BF16 safetensors + +Output: A directory loadable by vLLM as a standard LlamaForCausalLM with --hf-overrides. + +Model details: + - 40 layers, hidden=5120, heads=32/8kv, head_dim=128, intermediate=16384 + - vocab=131072, tie_word_embeddings=false (separate lm_head) + - YaRN RoPE scaling (factor=16, theta=1e9) + - FP8 E4M3 linear weights with per-tensor BF16 scalar scales + - embed_tokens, lm_head, layernorm weights are already BF16 + +Usage: + python extract_text_model.py [--src /path/to/full/model] [--dst /path/to/output] + +Requires only: safetensors, torch (no transformers version constraint) +""" + +import argparse +import json +import os +import shutil +import time + +import torch +from safetensors import safe_open +from safetensors.torch import save_file + + +def dequantize_fp8(weight, scale_inv): + """Dequantize FP8 E4M3 weight to BF16 using per-tensor scale. + + Formula: bf16_weight = fp8_weight.to(bf16) * weight_scale_inv + """ + return weight.to(torch.bfloat16) * scale_inv.to(torch.bfloat16) + + +def extract(src_dir, dst_dir): + print("=" * 60) + print("Ministral 14B: Extract text-only BF16 backbone") + print("=" * 60) + print(f" Source: {src_dir}") + print(f" Destination: {dst_dir}") + + os.makedirs(dst_dir, exist_ok=True) + + # Load source config + with open(os.path.join(src_dir, "config.json")) as f: + full_config = json.load(f) + + text_config = full_config.get("text_config", {}) + if not text_config: + print(" ERROR: no text_config found in source config.json") + return + + # --------------------------------------------------------------- + # Step 1: Read and dequantize weights + # --------------------------------------------------------------- + print(f"\n[1/5] Reading and dequantizing weights from safetensors...") + t0 = time.time() + + # Load safetensors index + index_path = os.path.join(src_dir, "model.safetensors.index.json") + with open(index_path) as f: + idx = json.load(f) + weight_map = idx["weight_map"] + + # Group keys by shard file + file_keys = {} + for key, fname in weight_map.items(): + if fname not in file_keys: + file_keys[fname] = [] + file_keys[fname].append(key) + + text_prefix = "language_model." 
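+    # Key handling sketch (illustrative example key): a text weight such as
+    #   language_model.model.layers.0.self_attn.q_proj.weight   (float8_e4m3fn)
+    # is dequantized against its sibling ...weight_scale_inv tensor and
+    # re-keyed as
+    #   model.layers.0.self_attn.q_proj.weight                  (bfloat16)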
+ text_weights = {} + skipped_vision = 0 + skipped_scales = 0 + dequantized = 0 + + # First pass: collect all scale keys for lookup + scale_map = {} # "language_model.model.layers.0.self_attn.q_proj" -> shard_file + for key in weight_map: + if key.endswith(".weight_scale_inv") and key.startswith(text_prefix): + base = key[: -len(".weight_scale_inv")] + scale_map[base] = weight_map[key] + + for fname in sorted(file_keys.keys()): + keys = file_keys[fname] + fpath = os.path.join(src_dir, fname) + print(f" Processing {fname} ({len(keys)} keys)...") + f = safe_open(fpath, framework="pt") + + for key in keys: + # Skip non-text keys + if not key.startswith(text_prefix): + skipped_vision += 1 + continue + + # Skip activation_scale and weight_scale_inv (not needed in output) + if key.endswith(".activation_scale") or key.endswith(".weight_scale_inv"): + skipped_scales += 1 + continue + + new_key = key[len(text_prefix) :] + tensor = f.get_tensor(key) + + # Dequantize FP8 weights + if tensor.dtype == torch.float8_e4m3fn: + # Find the corresponding scale + base_key = key[: -len(".weight")] # strip ".weight" + scale_key = base_key + ".weight_scale_inv" + if scale_key in weight_map: + # Scale might be in a different shard + scale_shard = weight_map[scale_key] + if scale_shard == fname: + scale = f.get_tensor(scale_key) + else: + sf = safe_open( + os.path.join(src_dir, scale_shard), framework="pt" + ) + scale = sf.get_tensor(scale_key) + tensor = dequantize_fp8(tensor, scale) + dequantized += 1 + else: + print(f" WARNING: no scale for {key}, casting directly to bf16") + tensor = tensor.to(torch.bfloat16) + elif tensor.dtype != torch.bfloat16: + tensor = tensor.to(torch.bfloat16) + + text_weights[new_key] = tensor + + elapsed = time.time() - t0 + print(f" Extracted {len(text_weights)} text weights") + print(f" Dequantized {dequantized} FP8 tensors to BF16") + print(f" Skipped {skipped_vision} vision/projector keys") + print(f" Skipped {skipped_scales} scale keys") + print(f" Time: {elapsed:.1f}s") + + # --------------------------------------------------------------- + # Step 2: Save sharded weights + # --------------------------------------------------------------- + print(f"\n[2/5] Saving text-only BF16 weights...") + t0 = time.time() + + total_bytes = sum(t.numel() * t.element_size() for t in text_weights.values()) + print(f" Total size: {total_bytes / 1e9:.2f} GB") + + MAX_SHARD = 5e9 # 5 GB per shard + shard_idx = 0 + current_shard = {} + current_size = 0 + new_weight_map = {} + + def flush(): + nonlocal shard_idx, current_shard, current_size + if not current_shard: + return + shard_idx += 1 + sname = f"model-{shard_idx:05d}-of-PLACEHOLDER.safetensors" + save_file(current_shard, os.path.join(dst_dir, sname)) + for k in current_shard: + new_weight_map[k] = sname + print( + f" Shard {shard_idx}: {len(current_shard)} tensors, " + f"{current_size / 1e9:.2f} GB" + ) + current_shard = {} + current_size = 0 + + for k in sorted(text_weights.keys()): + t = text_weights[k] + sz = t.numel() * t.element_size() + if current_size + sz > MAX_SHARD and current_shard: + flush() + current_shard[k] = t + current_size += sz + flush() + + # Rename shards with correct total count + total_shards = shard_idx + final_map = {} + for k, sname in new_weight_map.items(): + final = sname.replace("PLACEHOLDER", f"{total_shards:05d}") + final_map[k] = final + + # Rename shard files + for i in range(1, total_shards + 1): + old_name = f"model-{i:05d}-of-PLACEHOLDER.safetensors" + new_name = 
f"model-{i:05d}-of-{total_shards:05d}.safetensors" + old_path = os.path.join(dst_dir, old_name) + new_path = os.path.join(dst_dir, new_name) + if os.path.exists(old_path): + os.rename(old_path, new_path) + + with open(os.path.join(dst_dir, "model.safetensors.index.json"), "w") as f: + json.dump({"metadata": {}, "weight_map": final_map}, f, indent=2) + + print(f" Saved {total_shards} shards in {time.time() - t0:.1f}s") + + # --------------------------------------------------------------- + # Step 3: Create LlamaForCausalLM-compatible config + # --------------------------------------------------------------- + print(f"\n[3/5] Creating config.json...") + + # IMPORTANT: Use LlamaForCausalLM architecture. + # vLLM 0.16 auto-promotes MistralForCausalLM to PixtralForConditionalGeneration + # based on tokenizer/processor hints. Using LlamaForCausalLM avoids this entirely. + # The Llama code path also correctly handles head_dim != hidden_size/num_heads. + rope_params = text_config.get("rope_parameters", {}) + + config = { + "architectures": ["LlamaForCausalLM"], + "model_type": "llama", + "torch_dtype": "bfloat16", + "hidden_size": text_config["hidden_size"], + "intermediate_size": text_config["intermediate_size"], + "num_hidden_layers": text_config["num_hidden_layers"], + "num_attention_heads": text_config["num_attention_heads"], + "num_key_value_heads": text_config.get( + "num_key_value_heads", text_config["num_attention_heads"] + ), + "head_dim": text_config.get("head_dim", 128), + "vocab_size": text_config["vocab_size"], + "max_position_embeddings": text_config.get("max_position_embeddings", 262144), + "rms_norm_eps": text_config.get("rms_norm_eps", 1e-5), + "hidden_act": text_config.get("hidden_act", "silu"), + "tie_word_embeddings": full_config.get("tie_word_embeddings", False), + "attention_bias": text_config.get("attention_bias", False), + "attention_dropout": text_config.get("attention_dropout", 0.0), + "bos_token_id": text_config.get("bos_token_id", 1), + "eos_token_id": text_config.get("eos_token_id", 2), + "rope_theta": rope_params.get("rope_theta", 1000000000.0), + # YaRN rope scaling -- required for correct position embeddings + "rope_scaling": { + "rope_type": rope_params.get("rope_type", "yarn"), + "type": rope_params.get("type", "yarn"), + "factor": rope_params.get("factor", 16.0), + "beta_fast": rope_params.get("beta_fast", 32.0), + "beta_slow": rope_params.get("beta_slow", 1.0), + "original_max_position_embeddings": rope_params.get( + "original_max_position_embeddings", 16384 + ), + "mscale": rope_params.get("mscale", 1.0), + "mscale_all_dim": rope_params.get("mscale_all_dim", 1.0), + }, + # Do NOT include sliding_window -- causes NxDI tensor shape issues + } + + # Print config summary + computed_head_dim = config["hidden_size"] // config["num_attention_heads"] + print(f" model_type: {config['model_type']}") + print(f" architectures: {config['architectures']}") + print(f" hidden_size: {config['hidden_size']}") + print(f" num_hidden_layers: {config['num_hidden_layers']}") + print(f" num_attention_heads: {config['num_attention_heads']}") + print(f" num_key_value_heads: {config['num_key_value_heads']}") + print( + f" head_dim: {config['head_dim']} (config) vs {computed_head_dim} (computed)" + ) + print(f" intermediate_size: {config['intermediate_size']}") + print(f" vocab_size: {config['vocab_size']}") + print(f" rms_norm_eps: {config['rms_norm_eps']}") + print(f" rope_theta: {config['rope_theta']}") + print( + f" rope_scaling: {config['rope_scaling']['rope_type']}, " + 
f"factor={config['rope_scaling']['factor']}" + ) + print(f" tie_word_embeddings: {config['tie_word_embeddings']}") + + if config["head_dim"] != computed_head_dim: + print( + f" NOTE: head_dim ({config['head_dim']}) != " + f"hidden_size/num_heads ({computed_head_dim})" + ) + print( + f" This is intentional -- the model uses head_dim=128 with hidden_size=5120" + ) + + with open(os.path.join(dst_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + # --------------------------------------------------------------- + # Step 4: Copy and clean tokenizer files + # --------------------------------------------------------------- + print(f"\n[4/5] Copying tokenizer files...") + for fname in os.listdir(src_dir): + if ( + fname.startswith("tokenizer") + or fname == "special_tokens_map.json" + or fname == "tekken.json" + or fname == "generation_config.json" + or fname == "chat_template.jinja" + ): + src = os.path.join(src_dir, fname) + if os.path.isfile(src): + shutil.copy2(src, os.path.join(dst_dir, fname)) + print(f" Copied {fname}") + + # Clean tokenizer_config.json: + # 1. Remove processor_class to prevent vLLM 0.16 Pixtral auto-promotion + # 2. Fix tokenizer_class: "TokenizersBackend" is Mistral-internal, not in + # HuggingFace transformers. Replace with "PreTrainedTokenizerFast" which + # works with the standard tokenizer.json file. + tok_config_path = os.path.join(dst_dir, "tokenizer_config.json") + if os.path.exists(tok_config_path): + with open(tok_config_path) as f: + tok_config = json.load(f) + changed = False + for field in ["processor_class", "auto_map"]: + if field in tok_config: + del tok_config[field] + changed = True + print(f" Removed '{field}' from tokenizer_config.json") + if tok_config.get("tokenizer_class") == "TokenizersBackend": + tok_config["tokenizer_class"] = "PreTrainedTokenizerFast" + changed = True + print( + f" Fixed tokenizer_class: TokenizersBackend -> PreTrainedTokenizerFast" + ) + # Remove Mistral-specific fields that break HF transformers: + # - extra_special_tokens: list format not compatible (expects dict) + # - backend: Mistral-internal tokenizer backend identifier + for field in ["extra_special_tokens", "backend"]: + if field in tok_config: + del tok_config[field] + changed = True + print(f" Removed '{field}' from tokenizer_config.json") + if changed: + with open(tok_config_path, "w") as f: + json.dump(tok_config, f, indent=2) + + # Also clean processor_config.json if it was copied + proc_config_path = os.path.join(dst_dir, "processor_config.json") + if os.path.exists(proc_config_path): + os.remove(proc_config_path) + print(f" Removed processor_config.json") + + # --------------------------------------------------------------- + # Step 5: Summary + # --------------------------------------------------------------- + print(f"\n[5/5] Summary") + total_size = sum( + os.path.getsize(os.path.join(dst_dir, f)) + for f in os.listdir(dst_dir) + if os.path.isfile(os.path.join(dst_dir, f)) + ) + n_files = len(os.listdir(dst_dir)) + print(f" Output directory: {dst_dir}") + print(f" Total files: {n_files}") + print(f" Total size: {total_size / 1e9:.2f} GB") + print(f"\n To serve with vLLM + NxDI TKG:") + print(f" python -m vllm.entrypoints.openai.api_server \\") + print(f" --model {dst_dir} \\") + print(f" --tensor-parallel-size 4 --max-model-len 4096 --max-num-seqs 4 \\") + print(f" --no-enable-prefix-caching --block-size 8 \\") + print( + f' --hf-overrides \'{{"architectures": ["LlamaForCausalLM"], ' + f'"model_type": "llama"}}\' \\' + ) + print( + f' 
--additional-config \'{{"override_neuron_config": {{' + f'"fused_qkv": true, "qkv_nki_kernel_enabled": true, ' + f'"qkv_kernel_enabled": true, ' + f'"attn_block_tkg_nki_kernel_enabled": true, ' + f'"attn_block_tkg_nki_kernel_cache_update": true}}}}\'' + ) + print("=" * 60) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Extract text-only BF16 model from Ministral 14B FP8 multimodal checkpoint" + ) + parser.add_argument( + "--src", + default="/home/ubuntu/models/Ministral-3-14B-Instruct-2512", + help="Path to full multimodal FP8 checkpoint", + ) + parser.add_argument( + "--dst", + default="/home/ubuntu/models/Ministral-3-14B-text-bf16", + help="Output directory for text-only BF16 model", + ) + args = parser.parse_args() + extract(args.src, args.dst) diff --git a/contrib/models/Ministral-3-14B-Instruct-2512/src/fix_nki030.py b/contrib/models/Ministral-3-14B-Instruct-2512/src/fix_nki030.py new file mode 100644 index 00000000..d722eb44 --- /dev/null +++ b/contrib/models/Ministral-3-14B-Instruct-2512/src/fix_nki030.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +""" +Fix attention_block_tkg_multi_kv.py for NKI 0.3.0 compatibility. + +NKI 0.3.0 does NOT support keyword-only arguments (after *,). +Fix: remove *, entirely, give all non-defaulted params appropriate defaults. + +For ndarray params: default = None (they MUST be provided at call time) +For bool params: default based on their usage (False for enable flags) +For float/int params: default = 0.0 / 0 +For Optional params: default = None +""" + +import sys +import re + +fpath = ( + sys.argv[1] + if len(sys.argv) > 1 + else "/opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/lib/python3.12/site-packages/nkilib/experimental/transformer/attention_block_tkg_multi_kv.py" +) + +with open(fpath) as f: + content = f.read() + +# Step 1: Remove the *, line +content = content.replace(" *,\n", "") +print(" Removed *,") + +# Step 2: For params without defaults that come after params with defaults, +# add appropriate defaults. We need to find the function signature. +# The first defaulted param (that already had a default) determines the cutoff. + +# Find function signature boundaries +func_start = content.find("def attention_block_tkg(") +paren_depth = 0 +in_func = False +func_end = func_start +for i in range(func_start, len(content)): + if content[i] == "(": + paren_depth += 1 + in_func = True + elif content[i] == ")": + paren_depth -= 1 + if in_func and paren_depth == 0: + func_end = i + break + +sig = content[func_start : func_end + 1] +lines = sig.split("\n") +new_lines = [] + +# Track if we've seen a defaulted param +seen_default = False + +for line in lines: + stripped = line.strip() + + # Skip comments, empty lines, def line, closing paren + if ( + stripped.startswith("#") + or stripped == "" + or stripped.startswith("def ") + or stripped == ")" + or stripped == ")," + ): + new_lines.append(line) + continue + + # Check if line has a default + has_default = "=" in stripped and ":" in stripped + + if has_default: + seen_default = True + new_lines.append(line) + continue + + # No default - need to add one if we've seen a defaulted param + if not seen_default: + new_lines.append(line) + continue + + # Need to add a default. Determine appropriate default based on type. 
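+    # e.g. (illustrative; exact parameter names depend on the signature):
+    #   "W_out: Optional[nl.ndarray],"  ->  "W_out: Optional[nl.ndarray] = None,"
+    #   "update_cache: bool,"           ->  "update_cache: bool = False,"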
+ if "Optional[" in stripped: + default = "None" + elif ": nl.ndarray" in stripped: + default = "None" + elif ": bool" in stripped: + default = "False" + elif ": float" in stripped: + default = "0.0" + elif ": int" in stripped: + default = "0" + elif "QuantizationType" in stripped: + default = "None" + elif "SbufManager" in stripped: + default = "None" + else: + default = "None" + + # Add default before trailing comma + if stripped.endswith(","): + line = line.rstrip().rstrip(",") + f" = {default}," + else: + line = line.rstrip() + f" = {default}" + + new_lines.append(line) + +new_sig = "\n".join(new_lines) +content = content[:func_start] + new_sig + content[func_end + 1 :] + +with open(fpath, "w") as f: + f.write(content) + +# Verify syntax +import py_compile + +try: + py_compile.compile(fpath, doraise=True) + print("Syntax OK") +except py_compile.PyCompileError as e: + print(f"Syntax error: {e}") + sys.exit(1) + +# Show first 30 lines of new signature for verification +with open(fpath) as f: + all_lines = f.readlines() +for i, line in enumerate(all_lines): + if "def attention_block_tkg" in line: + for j in range(i, min(i + 45, len(all_lines))): + print(f"{j + 1}: {all_lines[j]}", end="") + break diff --git a/contrib/models/Ministral-3-14B-Instruct-2512/src/modeling_leanstral.py b/contrib/models/Ministral-3-14B-Instruct-2512/src/modeling_leanstral.py new file mode 100644 index 00000000..5105c211 --- /dev/null +++ b/contrib/models/Ministral-3-14B-Instruct-2512/src/modeling_leanstral.py @@ -0,0 +1,959 @@ +""" +Ministral-3-14B-Instruct-2512 (Leanstral) on AWS Neuron via NxDI. + +A vision-language model combining a Pixtral vision encoder with a Llama-compatible +text decoder and a Mistral3-specific PatchMerger projector. Architecture: 40 layers, +hidden=5120, heads=32/8kv, vocab=131072, intermediate=16384, head_dim=128. + +This contrib model reuses NxDI's Pixtral VL pipeline (NeuronPixtralVisionModel for the +vision encoder, NeuronLlamaModel for the text decoder) with three key adaptations: + +1. CPU projector: The Mistral3 PatchMerger (spatial 2x2 merge via F.unfold + 2-layer MLP) + runs on CPU since it has no NxDI equivalent. +2. SHARD_OVER_HEADS GQA: Avoids replicating KV heads when kv_heads >= tp_degree. With 8 KV + heads at TP=4, each rank gets kv_heads_per_rank=2 instead of stock NxDI's replication to 8. +3. Multi-KV-head TKG kernel: A modified nki-library attention_block_tkg kernel that supports + kv_heads_per_rank > 1 via a virtual-batch approach. + +Requires: +- SDK 2.28 (neuronx-cc >= 2.23, neuronx-distributed-inference >= 0.8) +- trn2.3xlarge (TP=4, LNC=2) +- Model checkpoint: mistralai/Ministral-3-14B-Instruct-2512 (HuggingFace, gated) + +Known limitations: +- TKG kernel uses grid=1 (NCC_IXLV002 workaround), ~4% text throughput cost +- FP8 checkpoint weights are dequantized to bf16 during state_dict conversion +""" + +import copy +import json +import logging +import os +from types import SimpleNamespace + +import torch +import torch.nn as nn +from safetensors import safe_open + +logger = logging.getLogger(__name__) + +# Mistral3 uses image_token_id=10 for [IMG] tokens in the vocabulary +IMAGE_TOKEN_ID = 10 +PATCH_SIZE = 16 +SPATIAL_MERGE_SIZE = 2 + + +# --------------------------------------------------------------------------- +# SHARD_OVER_HEADS GQA patch +# --------------------------------------------------------------------------- +# NxDI 0.8 only supports CONVERT_TO_MHA (replicates KV heads to match Q heads) +# and REPLICATE_TO_TP_DEGREE. 
For models where kv_heads >= tp_degree and +# kv_heads % tp_degree == 0, we can shard KV heads across ranks instead. +# This avoids inflating KV cache memory by 4x and enables the multi-KV-head +# TKG kernel path. +# +# This patch should be applied BEFORE any NxDI model classes are imported. +# It is a candidate for upstream NxDI inclusion (see fork branch +# feature/shard-over-heads-gqa on github.com/jimburtoft/neuronx-distributed-inference). + + +_shard_over_heads_applied = False + + +def apply_shard_over_heads_patch(): + """Patch NxDI's GQA sharding to support kv_heads >= tp_degree without replication.""" + global _shard_over_heads_applied + if _shard_over_heads_applied: + return + + import neuronx_distributed_inference.modules.attention.gqa as gqa_module + + _orig_determine = gqa_module.determine_sharding_strategy + _orig_get_shardable = gqa_module.get_shardable_head_counts + + def _patched_determine( + tp_degree, source_key_value_heads, desired_sharding_strategy=None + ): + if ( + source_key_value_heads >= tp_degree + and source_key_value_heads % tp_degree == 0 + ): + return gqa_module.GQA.CONVERT_TO_MHA + return _orig_determine( + tp_degree, source_key_value_heads, desired_sharding_strategy + ) + + def _patched_get_shardable( + tp_degree, num_attention_heads, num_key_value_heads, sharding_strategy + ): + if ( + sharding_strategy == gqa_module.GQA.CONVERT_TO_MHA + and num_key_value_heads >= tp_degree + and num_key_value_heads % tp_degree == 0 + ): + from neuronx_distributed_inference.modules.attention.gqa import ( + get_number_of_extra_heads, + ) + + updated = num_attention_heads + get_number_of_extra_heads( + num_attention_heads, tp_degree + ) + return updated, num_key_value_heads + return _orig_get_shardable( + tp_degree, num_attention_heads, num_key_value_heads, sharding_strategy + ) + + # Patch in all modules that import these functions + for module_path in [ + "neuronx_distributed_inference.modules.attention.gqa", + "neuronx_distributed_inference.modules.kvcache.kv_cache_manager", + ]: + try: + import importlib + + mod = importlib.import_module(module_path) + mod.determine_sharding_strategy = _patched_determine + mod.get_shardable_head_counts = _patched_get_shardable + except (ImportError, AttributeError): + pass + + # Also patch gpt_kv_cache_manager if present + try: + import neuronx_distributed_inference.modules.kvcache.gpt_kv_cache_manager as gpt_kv + + gpt_kv.determine_sharding_strategy = _patched_determine + gpt_kv.get_shardable_head_counts = _patched_get_shardable + except (ImportError, AttributeError): + pass + + _shard_over_heads_applied = True + logger.info("SHARD_OVER_HEADS GQA patch applied") + + +# --------------------------------------------------------------------------- +# Multi-KV-head TKG kernel adapter patch +# --------------------------------------------------------------------------- +# The stock NxDI TKG kernel hardcodes kv_heads=1 per rank. This adapter +# replaces the dispatch method to call our modified nki-library kernel +# (attention_block_tkg_multi_kv) which supports kv_heads_per_rank > 1. +# For kv_heads_per_rank == 1, the patch is a no-op passthrough. + +_multi_kv_patch_applied = False + + +def apply_multi_kv_tkg_patch(): + """Patch NxDI's TKG kernel dispatch for multi-KV-head support.""" + global _multi_kv_patch_applied + if _multi_kv_patch_applied: + return + + from . 
import patch_native_multi_kv + + patch_native_multi_kv.apply_patch() + _multi_kv_patch_applied = True + logger.info("Multi-KV-head TKG kernel adapter patch applied") + + +# --------------------------------------------------------------------------- +# CPU Projector (Mistral3 PatchMerger + MLP) +# --------------------------------------------------------------------------- +# The Mistral3 projector does spatial 2x2 merging of vision patches followed +# by a 2-layer MLP. It runs on CPU because NxDI's Pixtral pipeline does not +# include this specific projector variant. + + +class Mistral3PatchMerger(nn.Module): + """Spatial 2x2 patch merger using F.unfold for correct spatial ordering.""" + + def __init__(self, hidden_size, spatial_merge_size=SPATIAL_MERGE_SIZE): + super().__init__() + self.hidden_size = hidden_size + self.spatial_merge_size = spatial_merge_size + self.merging_layer = nn.Linear( + hidden_size * spatial_merge_size * spatial_merge_size, + hidden_size, + bias=False, + ) + + def forward(self, features, ph, pw): + merge = self.spatial_merge_size + ph_m = (ph // merge) * merge + pw_m = (pw // merge) * merge + feats = features.view(ph, pw, self.hidden_size)[:ph_m, :pw_m, :] + image_grid = feats.permute(2, 0, 1).unsqueeze(0) + grid = torch.nn.functional.unfold(image_grid, kernel_size=merge, stride=merge) + grid = grid.view(self.hidden_size * merge * merge, -1).t() + return self.merging_layer(grid) + + +class Mistral3CPUProjector(nn.Module): + """CPU-side vision-to-text projector: RMSNorm -> PatchMerger -> 2-layer MLP.""" + + def __init__( + self, + vision_hidden_size, + text_hidden_size, + spatial_merge_size=SPATIAL_MERGE_SIZE, + ): + super().__init__() + self.norm = nn.RMSNorm(vision_hidden_size, eps=1e-5) + self.patch_merger = Mistral3PatchMerger(vision_hidden_size, spatial_merge_size) + self.linear_1 = nn.Linear(vision_hidden_size, text_hidden_size, bias=False) + self.act = nn.GELU() + self.linear_2 = nn.Linear(text_hidden_size, text_hidden_size, bias=False) + + def forward(self, vision_features, image_h, image_w): + ph = image_h // PATCH_SIZE + pw = image_w // PATCH_SIZE + feats = vision_features.squeeze(0) + feats = self.norm(feats.float()).to(feats.dtype) + feats = self.patch_merger(feats, ph, pw) + feats = self.linear_1(feats) + feats = self.act(feats) + feats = self.linear_2(feats) + return feats + + +def load_cpu_projector(model_path, vision_hidden_size, text_hidden_size): + """Load Mistral3CPUProjector weights from safetensors checkpoint.""" + projector = Mistral3CPUProjector(vision_hidden_size, text_hidden_size) + weight_mapping = { + "multi_modal_projector.norm.weight": "norm.weight", + "multi_modal_projector.patch_merger.merging_layer.weight": "patch_merger.merging_layer.weight", + "multi_modal_projector.linear_1.weight": "linear_1.weight", + "multi_modal_projector.linear_2.weight": "linear_2.weight", + } + safetensors_files = sorted( + f + for f in os.listdir(model_path) + if f.endswith(".safetensors") and "consolidated" not in f + ) + for fname in safetensors_files: + with safe_open(os.path.join(model_path, fname), framework="pt") as f: + for key in f.keys(): + if key in weight_mapping: + target_key = weight_mapping[key] + parts = target_key.split(".") + module = projector + for part in parts[:-1]: + module = getattr(module, part) + existing = getattr(module, parts[-1]) + with torch.no_grad(): + existing.copy_(f.get_tensor(key).to(existing.dtype)) + projector.eval().to(torch.bfloat16) + return projector + + +# 
--------------------------------------------------------------------------- +# NxDI imports (deferred to avoid import before patches are applied) +# --------------------------------------------------------------------------- +# These are imported at function/class scope to ensure patches are applied +# before the NxDI module import chain runs. + + +def _get_nxdi_imports(): + """Lazy import of NxDI classes. Call after patches are applied.""" + from neuronx_distributed_inference.models.config import NeuronConfig + from neuronx_distributed_inference.models.llama.modeling_llama import ( + NeuronLlamaModel, + ) + from neuronx_distributed_inference.models.image_to_text_model_base import ( + NeuronBaseForImageToText, + ) + from neuronx_distributed_inference.models.image_to_text_model_wrapper import ( + ImageToTextModelWrapper, + ) + from neuronx_distributed_inference.models.model_wrapper import ( + VISION_ENCODER_MODEL_TAG, + ) + from neuronx_distributed_inference.models.pixtral.modeling_pixtral import ( + PixtralInferenceConfig, + ) + from neuronx_distributed_inference.models.pixtral.modeling_pixtral_vision import ( + NeuronPixtralVisionModel, + PixtralVisionModelWrapper, + ) + from neuronx_distributed_inference.models.llama4.utils.encoder_utils import ( + generate_positions_from_mask, + pad_positions, + pad_vision_embeddings, + scatter_by_index_put, + ) + import neuronx_distributed_inference.modules.autobucketing as autobucketing + + return { + "NeuronConfig": NeuronConfig, + "NeuronLlamaModel": NeuronLlamaModel, + "NeuronBaseForImageToText": NeuronBaseForImageToText, + "ImageToTextModelWrapper": ImageToTextModelWrapper, + "VISION_ENCODER_MODEL_TAG": VISION_ENCODER_MODEL_TAG, + "PixtralInferenceConfig": PixtralInferenceConfig, + "NeuronPixtralVisionModel": NeuronPixtralVisionModel, + "PixtralVisionModelWrapper": PixtralVisionModelWrapper, + "generate_positions_from_mask": generate_positions_from_mask, + "pad_positions": pad_positions, + "pad_vision_embeddings": pad_vision_embeddings, + "scatter_by_index_put": scatter_by_index_put, + "autobucketing": autobucketing, + } + + +# --------------------------------------------------------------------------- +# Inference config builder +# --------------------------------------------------------------------------- + + +def build_inference_config( + model_path, + tp_degree=4, + batch_size=1, + seq_len=2048, + n_positions=4096, + vision_seq_len=4096, + tkg_buckets=None, + enable_tkg_kernel=True, +): + """Build a PixtralInferenceConfig for Ministral-3-14B. + + Args: + model_path: Path to HuggingFace checkpoint directory. + tp_degree: Tensor parallelism degree. Default 4 for trn2.3xlarge. + batch_size: Batch size. Default 1. + seq_len: Maximum text sequence length. Default 2048. + n_positions: Maximum position embeddings. Default 4096. + vision_seq_len: Maximum vision sequence length. Default 4096. + tkg_buckets: Token generation buckets. Default [256, 512, 1024, seq_len]. + enable_tkg_kernel: Enable the fused NKI TKG attention kernel. Default True. + + Returns: + PixtralInferenceConfig instance. 
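+
+    Example (illustrative sketch; the checkpoint path is a placeholder, and
+    get_model_cls() is the lazy class accessor defined later in this module):
+
+        config = build_inference_config(
+            "/home/ubuntu/models/Ministral-3-14B-Instruct-2512",
+            tp_degree=4,
+            batch_size=1,
+            seq_len=2048,
+        )
+        model = get_model_cls()(
+            "/home/ubuntu/models/Ministral-3-14B-Instruct-2512", config
+        )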
+ """ + nxdi = _get_nxdi_imports() + NeuronConfig = nxdi["NeuronConfig"] + PixtralInferenceConfig = nxdi["PixtralInferenceConfig"] + + if tkg_buckets is None: + tkg_buckets = [256, 512, 1024, seq_len] + + with open(os.path.join(model_path, "config.json")) as f: + full_config = json.load(f) + text_cfg = full_config["text_config"] + vision_cfg = full_config["vision_config"] + rope_params = text_cfg.get("rope_parameters", {}) + + text_neuron_config = NeuronConfig( + tp_degree=tp_degree, + batch_size=batch_size, + seq_len=seq_len, + n_positions=n_positions, + torch_dtype=torch.bfloat16, + on_device_sampling_config=None, + enable_bucketing=True, + token_generation_buckets=tkg_buckets, + fused_qkv=True, + qkv_nki_kernel_enabled=True, + attn_block_tkg_nki_kernel_enabled=enable_tkg_kernel, + attn_block_tkg_nki_kernel_cache_update=enable_tkg_kernel, + ) + vision_neuron_config = NeuronConfig( + tp_degree=tp_degree, + batch_size=batch_size, + seq_len=vision_seq_len, + torch_dtype=torch.bfloat16, + enable_bucketing=True, + on_device_sampling_config=None, + ) + + def custom_load_config(config_obj): + """Populate PixtralInferenceConfig from Ministral3 config.json. + + Ministral3 is not registered in HuggingFace AutoConfig, so we build + the text_config and vision_config SimpleNamespace objects manually. + """ + rope_scaling_dict = { + "rope_type": rope_params.get("rope_type", "yarn"), + "type": rope_params.get("type", "yarn"), + "factor": rope_params.get("factor", 16.0), + "beta_fast": rope_params.get("beta_fast", 32.0), + "beta_slow": rope_params.get("beta_slow", 1.0), + "original_max_position_embeddings": rope_params.get( + "original_max_position_embeddings", 16384 + ), + "mscale": rope_params.get("mscale", 1.0), + "mscale_all_dim": rope_params.get("mscale_all_dim", 1.0), + } + tc = SimpleNamespace( + hidden_size=text_cfg["hidden_size"], + num_attention_heads=text_cfg["num_attention_heads"], + num_hidden_layers=text_cfg["num_hidden_layers"], + num_key_value_heads=text_cfg["num_key_value_heads"], + vocab_size=text_cfg["vocab_size"], + max_position_embeddings=text_cfg["max_position_embeddings"], + rope_theta=rope_params.get("rope_theta", 1e9), + rope_scaling=rope_scaling_dict, + rms_norm_eps=text_cfg["rms_norm_eps"], + hidden_act=text_cfg["hidden_act"], + intermediate_size=text_cfg["intermediate_size"], + head_dim=text_cfg.get("head_dim", 128), + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=full_config.get("tie_word_embeddings", False), + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ) + vc = SimpleNamespace( + hidden_size=vision_cfg["hidden_size"], + num_attention_heads=vision_cfg["num_attention_heads"], + num_hidden_layers=vision_cfg["num_hidden_layers"], + num_channels=vision_cfg["num_channels"], + patch_size=vision_cfg["patch_size"], + image_size=vision_cfg["image_size"], + rope_theta=vision_cfg.get("rope_parameters", {}).get("rope_theta", 10000.0), + head_dim=vision_cfg.get("head_dim", 64), + intermediate_size=vision_cfg.get("intermediate_size", 4096), + hidden_act=vision_cfg.get("hidden_act", "silu"), + ) + config_obj.text_config = tc + config_obj.vision_config = vc + config_obj.multimodal_projector_bias = False + config_obj.projector_hidden_act = "gelu" + config_obj.vision_feature_layer = -1 + config_obj.spatial_merge_size = full_config.get( + "spatial_merge_size", SPATIAL_MERGE_SIZE + ) + config_obj.image_token_index = IMAGE_TOKEN_ID + config_obj._name_or_path = model_path + config_obj.output_attentions = False + 
config_obj.output_hidden_states = False + config_obj.return_dict = True + + return PixtralInferenceConfig( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + load_config=custom_load_config, + ) + + +# --------------------------------------------------------------------------- +# Model classes (built lazily via _build_model_classes) +# --------------------------------------------------------------------------- + + +def _build_model_classes(): + """Build the actual model classes with NxDI base classes. + + Must be called after apply_shard_over_heads_patch() and apply_multi_kv_tkg_patch(). + Returns a dict of class objects. + """ + nxdi = _get_nxdi_imports() + + NeuronLlamaModel = nxdi["NeuronLlamaModel"] + NeuronBaseForImageToText = nxdi["NeuronBaseForImageToText"] + ImageToTextModelWrapper = nxdi["ImageToTextModelWrapper"] + VISION_ENCODER_MODEL_TAG = nxdi["VISION_ENCODER_MODEL_TAG"] + PixtralInferenceConfig = nxdi["PixtralInferenceConfig"] + NeuronPixtralVisionModel = nxdi["NeuronPixtralVisionModel"] + PixtralVisionModelWrapper = nxdi["PixtralVisionModelWrapper"] + generate_positions_from_mask = nxdi["generate_positions_from_mask"] + pad_positions = nxdi["pad_positions"] + pad_vision_embeddings = nxdi["pad_vision_embeddings"] + scatter_by_index_put = nxdi["scatter_by_index_put"] + autobucketing = nxdi["autobucketing"] + + class _NeuronLeanstralTextModel(NeuronLlamaModel): + """Llama text model with vision embedding injection.""" + + def encode_vision_to_input(self, inputs_embeds, vision_embeddings, vision_mask): + return scatter_by_index_put(inputs_embeds, vision_embeddings, vision_mask) + + class _NeuronLeanstralVisionModel(NeuronPixtralVisionModel): + """Pixtral ViT without built-in projector.""" + + def __init__(self, config): + super().__init__(config) + if hasattr(self, "multi_modal_projector"): + del self.multi_modal_projector + + def forward(self, patch_embeds, attention_mask, position_ids): + patch_embeds = self.vision_patch_conv_linear(patch_embeds) + patch_embeds = self.vision_ln_pre(patch_embeds) + return self.vision_transformer( + patch_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + class _LeanstralVisionWrapper(PixtralVisionModelWrapper): + """Fix unpad slice for no-projector output.""" + + def pad_inputs(self, patch_embeds, attention_mask, position_ids): + result = super().pad_inputs(patch_embeds, attention_mask, position_ids) + if self.original_patch_embed_slices is not None: + self.original_patch_embed_slices[-1][-1] = ( + self.config.vision_config.hidden_size + ) + return result + + class _NeuronLeanstralForCausalLM(NeuronBaseForImageToText): + """Full VL model: Pixtral vision + Llama text + CPU projector.""" + + text_model_cls = _NeuronLeanstralTextModel + vision_model_cls = _NeuronLeanstralVisionModel + text_model_wrapper = ImageToTextModelWrapper + vision_model_wrapper = _LeanstralVisionWrapper + + def __init__(self, model_path, inference_config, *args, **kwargs): + super().__init__( + self.text_model_cls, + self.vision_model_cls, + self.text_model_wrapper, + self.vision_model_wrapper, + model_path, + inference_config, + *args, + **kwargs, + ) + self.cpu_projector = load_cpu_projector( + model_path, + vision_hidden_size=self.config.vision_config.hidden_size, + text_hidden_size=self.config.text_config.hidden_size, + ) + + @classmethod + def get_config_cls(cls): + return PixtralInferenceConfig + + def _get_model_outputs( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, 
+ prev_hidden, + adapter_ids, + vision_embeddings, + vision_mask, + deepstack_vision_embeds, + medusa_args, + llava_args, + slot_mapping=None, + block_table=None, + full_context_lens=None, + computed_context_lens=None, + rotary_position_ids=None, + ): + """Override to drop deepstack_vision_embeds (25th arg). + + ImageToTextModelWrapper traces the model with 24 positional args. + The base class passes 25 (including deepstack_vision_embeds), which + causes an arg-count mismatch at runtime. We drop it here. + """ + if rotary_position_ids is None: + rotary_position_ids = torch.empty(0) + + if self._is_prefill(position_ids): + outputs = self.context_encoding_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + *[torch.empty(0) for _ in range(16)], + rotary_position_ids, + vision_embeddings, + vision_mask, + ) + self.kv_cache_populated = True + is_run_on_neuron = self.context_encoding_model.is_neuron() + else: + outputs = self.token_generation_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + *[torch.empty(0) for _ in range(16)], + rotary_position_ids, + torch.empty(0, dtype=self.text_config.neuron_config.torch_dtype), + torch.empty(0, dtype=torch.bool), + ) + is_run_on_neuron = self.token_generation_model.is_neuron() + + return outputs, is_run_on_neuron + + def get_vision_compiler_args(self): + return ( + "--enable-saturate-infinity --auto-cast=none --model-type=transformer " + "--tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2 ' -O1 " + "--internal-hlo2tensorizer-options='--verify-hlo=true'" + ) + + def get_compiler_args(self): + return ( + "--enable-saturate-infinity --auto-cast=none --model-type=transformer " + "--tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2 --vectorize-strided-dma ' -O1 " + "--internal-hlo2tensorizer-options='--verify-hlo=true'" + ) + + def enable_vision_encoder( + self, enable_wlt_optimization=True, **model_init_kwargs + ): + new_config = copy.deepcopy(self.config) + if new_config.vision_config.neuron_config.enable_bucketing: + vc_nc = new_config.vision_config.neuron_config + if vc_nc.buckets == [vc_nc.seq_len] or vc_nc.buckets is None: + if vc_nc.seq_len > 1024: + vc_nc.buckets = autobucketing.generate_buckets( + 1024, vc_nc.seq_len + ) + else: + vc_nc.buckets = [vc_nc.seq_len] + new_config.neuron_config = copy.deepcopy( + new_config.vision_config.neuron_config + ) + self.vision_encoder_model = self.vision_model_wrapper( + config=new_config, + model_cls=self.vision_model_cls, + tag=VISION_ENCODER_MODEL_TAG, + compiler_args=self.get_vision_compiler_args(), + model_init_kwargs=model_init_kwargs, + priority_model_idx=(0 if enable_wlt_optimization else None), + pipeline_execution=True, + return_ranked_to_cpu=True, + ) + self.vision_models.append(self.vision_encoder_model) + + def get_required_kwargs(self): + return ["pixel_values", "vision_mask", "image_sizes"] + + def get_padding_length(self, input_ids): + buckets = self.context_encoding_model.config.neuron_config.buckets + for val in buckets: + if val >= input_ids.shape[1]: + return val + raise ValueError( + f"No bucket found for input_ids length {input_ids.shape[1]}. 
" + f"Available buckets: {buckets}" + ) + + def forward_atomic_prefill( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + pixel_values, + vision_mask, + image_sizes, + ): + """Run vision encoder + CPU projector + text prefill for one batch item.""" + if image_sizes is None: + img_h, img_w = pixel_values.shape[2], pixel_values.shape[3] + image_sizes = torch.tensor([[img_h, img_w]], dtype=torch.int32) + + if position_ids is None: + position_ids = torch.arange( + input_ids.shape[1], dtype=torch.int32 + ).unsqueeze(0) + + if seq_ids is None: + seq_ids = torch.zeros(input_ids.shape[0], dtype=torch.int32) + + if sampling_params is None: + sampling_params = torch.zeros( + input_ids.shape[0], 3, dtype=torch.float32 + ) + + if vision_mask is None: + vision_mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).to(torch.bool) + vision_mask = generate_positions_from_mask(vision_mask.squeeze()) + + # 1. Vision encoder (on Neuron) + vision_embeddings = self.vision_encoder_model( + pixel_values.to(self.vision_config.neuron_config.torch_dtype), + image_sizes, + ) + + # 2. CPU projector: RMSNorm -> PatchMerger -> MLP + img_h = image_sizes[0, 0].item() + img_w = image_sizes[0, 1].item() + with torch.no_grad(): + projected = self.cpu_projector(vision_embeddings, img_h, img_w) + vision_embeddings = projected.unsqueeze(0).to( + self.text_config.neuron_config.torch_dtype + ) + + # 3. Pad to text bucket + pad_limit = self.get_padding_length(input_ids) + vision_mask = pad_positions(vision_mask, pad_limit, (pad_limit - 1)) + vision_embeddings = pad_vision_embeddings(vision_embeddings, pad_limit) + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + seq_ids=seq_ids, + sampling_params=sampling_params, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + deepstack_vision_embeds=None, + ) + + def check_empty_pixel_values(self, pixel_values): + if pixel_values is None: + return True + if isinstance(pixel_values, torch.Tensor): + return pixel_values.sum() == 0 + if isinstance(pixel_values, list): + return all(pv.sum() == 0 for pv in pixel_values) + return True + + def get_batch_line_mm_input(self, mm_input, index): + if mm_input is None: + return None + if isinstance(mm_input, list): + return mm_input[index] + if isinstance(mm_input, torch.Tensor): + return mm_input[index].unsqueeze(0) + return None + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + seq_ids=None, + sampling_params=None, + pixel_values=None, + vision_mask=None, + image_sizes=None, + **kwargs, + ): + """Forward pass supporting both VL prefill and text-only decode.""" + if input_ids.shape[-1] > 1 and not self.check_empty_pixel_values( + pixel_values + ): + # VL prefill: process each batch item through vision pipeline + outputs = [] + for index in range(input_ids.shape[0]): + outputs.append( + self.forward_atomic_prefill( + input_ids[index].unsqueeze(0), + attention_mask[index].unsqueeze(0) + if attention_mask is not None + else None, + position_ids[index].unsqueeze(0) + if position_ids is not None + else None, + seq_ids[index].unsqueeze(0) + if seq_ids is not None + else None, + sampling_params[index].unsqueeze(0) + if sampling_params is not None + else None, + self.get_batch_line_mm_input(pixel_values, index), + self.get_batch_line_mm_input(vision_mask, index), + self.get_batch_line_mm_input(image_sizes, index), + ) + ) + from transformers.modeling_outputs import CausalLMOutputWithPast + + logits = ( + 
torch.cat([o.logits for o in outputs], dim=0) + if outputs[0].logits is not None + else None + ) + tokens_list = [ + o.tokens + for o in outputs + if hasattr(o, "tokens") and o.tokens is not None + ] + tokens = torch.cat(tokens_list, dim=0) if tokens_list else None + out = CausalLMOutputWithPast(logits=logits, hidden_states=[]) + if tokens is not None: + out.tokens = tokens + return out + else: + # Text-only prefill or TKG decode + pad_limit = ( + self.get_padding_length(input_ids) if input_ids.shape[-1] > 1 else 1 + ) + vision_embeddings_dummy, vision_mask_dummy = ( + ImageToTextModelWrapper.get_dummy_vision_inputs( + config=self.text_config, + input_ids=input_ids, + n_active_tokens=pad_limit, + fill_value=(pad_limit - 1), + ) + ) + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + seq_ids=seq_ids, + sampling_params=sampling_params, + vision_embeddings=vision_embeddings_dummy, + vision_mask=vision_mask_dummy, + ) + + @staticmethod + def load_hf_model(model_path, **kwargs): + """Load HuggingFace checkpoint as a state_dict wrapper.""" + state_dict = {} + safetensors_files = sorted( + f + for f in os.listdir(model_path) + if f.endswith(".safetensors") and "consolidated" not in f + ) + for fname in safetensors_files: + with safe_open(os.path.join(model_path, fname), framework="pt") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key) + + class _StateDict: + def __init__(self, sd): + self._sd = sd + + def state_dict(self): + return self._sd + + return _StateDict(state_dict) + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict, inference_config): + """Convert HuggingFace Ministral3 weights to NxDI format. + + Handles: FP8 dequantization, text/vision split, QKV fusion, + attention key remapping, rank utilities, vision key remapping. 
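+
+            FP8 dequantization (Phase 1 below) is applied per tensor:
+                w_bf16 = bfloat16(float32(w_fp8) * weight_scale_inv)
+            where `weight_scale_inv` is the scale stored alongside each FP8
+            weight; tensors without a stored scale fall back to a scale of 1.0.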
+ """ + # Phase 1: FP8 dequantize + clean = {} + fp8_count = 0 + for key, val in state_dict.items(): + if ".activation_scale" in key or ".weight_scale_inv" in key: + continue + if val.dtype == torch.float8_e4m3fn: + scale_key = key.replace(".weight", ".weight_scale_inv") + scale = ( + state_dict[scale_key].float() + if scale_key in state_dict + else torch.tensor(1.0) + ) + clean[key] = (val.float() * scale).to(torch.bfloat16) + fp8_count += 1 + else: + clean[key] = val + logger.info("Dequantized %d FP8 tensors", fp8_count) + + # Phase 2: Split text vs vision + text_dict = {} + for key, val in clean.items(): + if key.startswith("language_model.model."): + text_dict[key.replace("language_model.model.", "")] = val + elif key.startswith("language_model."): + text_dict[key.replace("language_model.", "")] = val + + # Phase 3: Remap attention keys + if inference_config.text_config.neuron_config.fused_qkv: + num_layers = inference_config.text_config.num_hidden_layers + for i in range(num_layers): + q_key = f"layers.{i}.self_attn.q_proj.weight" + k_key = f"layers.{i}.self_attn.k_proj.weight" + v_key = f"layers.{i}.self_attn.v_proj.weight" + if q_key in text_dict and k_key in text_dict and v_key in text_dict: + fused = torch.cat( + [text_dict[q_key], text_dict[k_key], text_dict[v_key]], + dim=0, + ) + text_dict[f"layers.{i}.self_attn.Wqkv.weight"] = fused + del text_dict[q_key], text_dict[k_key], text_dict[v_key] + else: + remap = { + ".self_attn.q_proj.": ".self_attn.qkv_proj.q_proj.", + ".self_attn.k_proj.": ".self_attn.qkv_proj.k_proj.", + ".self_attn.v_proj.": ".self_attn.qkv_proj.v_proj.", + } + remapped = {} + for key, val in text_dict.items(): + new_key = key + for pat, rep in remap.items(): + if pat in new_key: + new_key = new_key.replace(pat, rep) + break + remapped[new_key] = val + text_dict = remapped + + # Phase 4: Add rank utilities + tp = inference_config.text_config.neuron_config.tp_degree + for i in range(inference_config.text_config.num_hidden_layers): + text_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( + 0, tp, dtype=torch.int32 + ) + text_dict["rank_util.rank"] = torch.arange(0, tp, dtype=torch.int32) + + # Phase 5: Vision keys + vision_dict = {} + for key, val in clean.items(): + if key.startswith("vision_tower."): + new_key = key.replace("vision_tower.", "vision_") + vision_dict[new_key] = val.to( + inference_config.vision_config.neuron_config.torch_dtype + ) + # Reshape patch conv weight: Conv2d -> Linear equivalent + patch_key = "vision_patch_conv.weight" + if patch_key in vision_dict: + vision_dict["vision_patch_conv_linear.weight"] = vision_dict.pop( + patch_key + ).reshape( + -1, + inference_config.vision_config.num_channels + * inference_config.vision_config.patch_size**2, + ) + + merged = {**text_dict, **vision_dict} + logger.info( + "State dict converted: %d text keys, %d vision keys", + len(text_dict), + len(vision_dict), + ) + return merged + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + pass + + return { + "NeuronLeanstralTextModel": _NeuronLeanstralTextModel, + "NeuronLeanstralVisionModel": _NeuronLeanstralVisionModel, + "LeanstralVisionWrapper": _LeanstralVisionWrapper, + "NeuronLeanstralForCausalLM": _NeuronLeanstralForCausalLM, + } + + +# --------------------------------------------------------------------------- +# Public API: NeuronLeanstralForCausalLM +# --------------------------------------------------------------------------- +# The class is built lazily to allow patches to be applied first. 
+ +_model_classes = None + + +def get_model_cls(): + """Get the NeuronLeanstralForCausalLM class, applying patches if needed. + + Returns the model class ready for instantiation: + model = get_model_cls()(model_path, inference_config) + """ + global _model_classes + if _model_classes is None: + # Ensure patches are applied before building classes + apply_shard_over_heads_patch() + apply_multi_kv_tkg_patch() + _model_classes = _build_model_classes() + return _model_classes["NeuronLeanstralForCausalLM"] diff --git a/contrib/models/Ministral-3-14B-Instruct-2512/src/multi_kv_adapter.py b/contrib/models/Ministral-3-14B-Instruct-2512/src/multi_kv_adapter.py new file mode 100644 index 00000000..f2ddff3e --- /dev/null +++ b/contrib/models/Ministral-3-14B-Instruct-2512/src/multi_kv_adapter.py @@ -0,0 +1,318 @@ +# ============================================================ +# MULTI-KV-HEAD TKG KERNEL ADAPTER +# ============================================================ +# Appended to attention_base.py by setup_patches.py +# Provides multi-KV-head support for the TKG fused attention NKI kernel. +# When kv_heads_per_rank == 1, delegates to the stock NxDI TKG method. +# When kv_heads_per_rank > 1, calls the Leanstral-derived forked kernel +# using a "virtual batch" approach (each KV head becomes a virtual batch entry). + +import logging as _mkv_logging + +_mkv_logger = _mkv_logging.getLogger("multi_kv_tkg_adapter") + +_original_attn_block_tkg_nki = NeuronAttentionBase.attention_block_tokengen_nki_kernel + + +def _multi_kv_attention_block_tokengen_nki_kernel( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + active_mask=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + rotary_position_ids=None, + update_kv_per_layer=True, + active_block_table=None, + use_polar_compatible_rope=False, +): + """ + Multi-KV-head TKG kernel adapter. + + When kv_heads_per_rank == 1, delegates to the original NxDI method. + When kv_heads_per_rank > 1, calls the Leanstral forked kernel with + n_kv_heads parameter, handling the interface translation. 
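+
+    Worked example (per-rank numbers for this model at TP=4): q_heads=8 and
+    kv_heads=2, so q_per_kv_group = 8 // 2 = 4, and a real batch of bsz=4
+    becomes B_virt = bsz * kv_heads = 8 virtual batch entries in the kernel.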
+ """ + import torch + + kv_heads = self.num_key_value_heads + q_heads = self.num_heads + + # Fast path: kv_heads=1 per rank, use original unmodified method + if kv_heads == 1: + return _original_attn_block_tkg_nki( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value, + active_mask, + cos_cache, + sin_cache, + rmsnorm, + rotary_position_ids, + update_kv_per_layer, + active_block_table, + use_polar_compatible_rope=use_polar_compatible_rope, + ) + + # --- Multi-KV-head path --- + assert q_heads % kv_heads == 0, ( + f"q_heads ({q_heads}) must be divisible by kv_heads ({kv_heads})" + ) + + # Sequence parallel gather + if self.sequence_parallel_enabled and self.tensor_model_parallel_group is not None: + hidden_states = gather_from_sequence_parallel_region( + hidden_states, + self.sequence_dimension, + process_group=self.tensor_model_parallel_group, + ) + + from nkilib.experimental.transformer.attention_block_tkg_multi_kv import ( + attention_block_tkg, + ) + + bsz, q_len, h = hidden_states.size() + h_out = h // 2 if getattr(self, "is_eagle3_draft", False) else h + + # ---- RoPE cos/sin preparation ---- + skip_rope = False + rope_contiguous_layout = not use_polar_compatible_rope + + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(hidden_states, rotary_position_ids) + cos_cache = cos_cache[..., : cos_cache.shape[-1] // 2].permute(2, 0, 1) + sin_cache = sin_cache[..., : sin_cache.shape[-1] // 2].permute(2, 0, 1) + elif use_polar_compatible_rope: + rotary_freqs = precompute_freqs_cis( + self.head_dim, + self.neuron_config.max_context_length * 2, + self.rope_theta, + self.use_scaled_rope, + device=hidden_states.device, + ) + rotary_freqs = rotary_freqs[position_ids] + cos_cache = rotary_freqs.cos().permute(2, 0, 1) + sin_cache = rotary_freqs.sin().permute(2, 0, 1) + else: + expected_shape = (self.head_dim // 2, bsz, q_len) + cos_cache = torch.zeros(expected_shape).to(hidden_states) + sin_cache = torch.zeros(expected_shape).to(hidden_states) + skip_rope = True + + cos_for_kernel = None if skip_rope else cos_cache + sin_for_kernel = None if skip_rope else sin_cache + + # ---- KV Cache ---- + K_prior = past_key_value[0].data + V_prior = past_key_value[1].data + + the_dtype = hidden_states.dtype + the_device = hidden_states.device + + # ---- Mask preparation ---- + s_prior = attention_mask.shape[-1] + attention_mask = attention_mask.expand(-1, q_heads, -1, -1).contiguous() + + expected_active_mask_shape = (bsz, 1, q_len, q_len) + if q_len == 1: + active_mask = torch.ones( + expected_active_mask_shape, dtype=the_dtype, device=the_device + ) + active_mask = active_mask.expand(-1, q_heads, -1, -1).contiguous() + attention_mask[:, :, :, -q_len:] = active_mask + attention_mask_nki = attention_mask.permute(3, 0, 1, 2).contiguous() + + # Per-group mask for virtual batch approach + q_per_kv_group = q_heads // kv_heads + group_attention_mask = attention_mask_nki[:, :, :q_per_kv_group, :].contiguous() + group_attention_mask = group_attention_mask.repeat(1, kv_heads, 1, 1).contiguous() + + # ---- Weights ---- + W_qkv = self.get_qkv_proj().Wqkv.weight.data + W_qkv_bias = ( + self.get_qkv_proj().Wqkv.bias.data.unsqueeze(0) if self.qkv_bias else None + ) + + fused_rmsnorm = rmsnorm is not None + W_gamma = ( + rmsnorm.weight.data.unsqueeze(0) + if fused_rmsnorm + else torch.ones((1, h), device=the_device) + ) + + update_cache_in_kernel = ( + update_kv_per_layer and self.attn_block_tkg_nki_kernel_cache_update + ) + + W_out = 
self.get_o_proj().o_proj.weight.data + W_out_bias = ( + self.get_o_proj().o_proj.bias.data.unsqueeze(0) if self.o_bias else None + ) + if W_out_bias is not None: + W_out_bias = W_out_bias / self.tp_degree + + # ---- Output buffers ---- + if update_cache_in_kernel: + K = K_prior + V = V_prior + else: + K = torch.zeros( + self.head_dim, bsz, kv_heads, q_len, dtype=the_dtype, device=the_device + ) + V = torch.zeros( + bsz, kv_heads, q_len, self.head_dim, dtype=the_dtype, device=the_device + ) + + # ---- V active HBM buffer (workaround NCC_IBIR440) ---- + B_virt = bsz * kv_heads + v_active_hbm_buf = torch.zeros( + B_virt, 1, q_len, self.head_dim, dtype=the_dtype, device=the_device + ) + + # ---- kv_cache_update_idx ---- + kv_cache_update_idx = position_ids[:, :1].to(torch.int32) + + # ---- Replicated update idx for multi-KV ---- + kv_cache_update_idx_virt = kv_cache_update_idx.repeat_interleave(kv_heads, dim=0) + + # ---- QK norm ---- + has_qk_layernorm = self.q_layernorm is not None and self.k_layernorm is not None + qk_norm_eps = self.rms_norm_eps if self.rms_norm_eps else 1e-6 + is_pre_rope_qk_norm = ( + has_qk_layernorm and self.qk_norm_placement == QKNormPlacement.PRE_ROPE + ) + is_post_rope_qk_norm = ( + has_qk_layernorm and self.qk_norm_placement == QKNormPlacement.POST_ROPE + ) + rmsnorm_QK_pre_rope_W_Q = ( + self.q_layernorm.weight.data.unsqueeze(0) if is_pre_rope_qk_norm else None + ) + rmsnorm_QK_pre_rope_W_K = ( + self.k_layernorm.weight.data.unsqueeze(0) if is_pre_rope_qk_norm else None + ) + rmsnorm_QK_post_rope_W_Q = ( + self.q_layernorm.weight.data.unsqueeze(0) if is_post_rope_qk_norm else None + ) + rmsnorm_QK_post_rope_W_K = ( + self.k_layernorm.weight.data.unsqueeze(0) if is_post_rope_qk_norm else None + ) + + # ---- Grid ---- + lnc = self.logical_nc_config + grid = lnc if isinstance(lnc, int) else int(lnc) + + # ---- Call multi-KV kernel ---- + from nkilib.core.utils.common_types import QuantizationType + + attn_output, K, V = attention_block_tkg[grid]( + X=hidden_states, + X_hidden_dim_actual=getattr(self.config, "original_hidden_size", None), + rmsnorm_X_enabled=fused_rmsnorm, + rmsnorm_X_eps=self.rms_norm_eps, + rmsnorm_X_gamma=W_gamma, + W_qkv=W_qkv, + bias_qkv=W_qkv_bias, + quantization_type_qkv=QuantizationType.NONE, + weight_dequant_scale_qkv=None, + input_dequant_scale_qkv=None, + rmsnorm_QK_pre_rope_enabled=is_pre_rope_qk_norm, + rmsnorm_QK_pre_rope_eps=qk_norm_eps if is_pre_rope_qk_norm else 0.0, + rmsnorm_QK_pre_rope_W_Q=rmsnorm_QK_pre_rope_W_Q, + rmsnorm_QK_pre_rope_W_K=rmsnorm_QK_pre_rope_W_K, + cos=cos_for_kernel, + sin=sin_for_kernel, + rope_contiguous_layout=rope_contiguous_layout, + rmsnorm_QK_post_rope_enabled=is_post_rope_qk_norm, + rmsnorm_QK_post_rope_eps=qk_norm_eps if is_post_rope_qk_norm else 0.0, + rmsnorm_QK_post_rope_W_Q=rmsnorm_QK_post_rope_W_Q, + rmsnorm_QK_post_rope_W_K=rmsnorm_QK_post_rope_W_K, + K_cache_transposed=self.k_cache_transposed, + active_blocks_table=active_block_table, + K_cache=K_prior, + V_cache=V_prior, + attention_mask=attention_mask_nki, + sink=None, + softmax_scale=None if self.softmax_scale is None else (1 / self.softmax_scale), + update_cache=update_cache_in_kernel, + kv_cache_update_idx=kv_cache_update_idx, + W_out=W_out, + bias_out=W_out_bias, + quantization_type_out=QuantizationType.NONE, + weight_dequant_scale_out=None, + input_dequant_scale_out=None, + transposed_out=False, + out_in_sb=False, + # Multi-KV-head parameters + n_kv_heads=kv_heads, + n_q_heads=q_heads, + head_dim=self.head_dim, + s_max_ctx=V_prior.shape[2], 
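+        # (s_max_ctx is read host-side from the V cache, whose NxDI layout is
+        #  [B, kv_heads, S_max, head_dim]; NKI cannot take .shape of the 4D
+        #  cache inside the kernel.)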
+ group_attention_mask=group_attention_mask, + v_active_hbm=v_active_hbm_buf, + kv_cache_update_idx_virt=kv_cache_update_idx_virt, + ) + + # ---- Post-processing ---- + attn_output = attn_output.reshape((bsz, q_len, h_out)) + + # All-reduce or reduce-scatter across TP ranks + if self.sequence_parallel_enabled: + attn_output = reduce_scatter_to_sequence_parallel_region( + attn_output, 1, process_group=self.tensor_model_parallel_group + ) + else: + from neuronx_distributed_inference.modules.attention.attention_base import ( + EPDispatchOption, + ) + + if self.ep_dispatch_cc_option == EPDispatchOption.AR_AG: + attn_output = reduce_from_tensor_model_parallel_region( + attn_output, process_group=self.tensor_model_parallel_group + ) + elif self.ep_dispatch_cc_option == EPDispatchOption.RS_AG: + attn_output = reduce_scatter_to_tensor_model_parallel_region_with_dim( + attn_output, + partition_dim=0, + process_group=self.tensor_model_parallel_group, + ) + elif self.ep_dispatch_cc_option == EPDispatchOption.AG_AR: + from neuronx_distributed.parallel_layers.parallel_state import ( + get_data_parallel_attention_dp_group, + ) + + attn_output = gather_from_tensor_model_parallel_region_with_dim( + attn_output, + gather_dim=0, + process_group=get_data_parallel_attention_dp_group(), + ) + else: + raise ValueError(f"Unknown EPDispatchOption: {self.ep_dispatch_cc_option}") + + # ---- KV output ---- + if not update_cache_in_kernel: + if K.dim() == 4: + if self.k_cache_transposed: + K = K.permute(1, 2, 0, 3) + else: + K = K.permute(1, 2, 3, 0) + else: + if self.k_cache_transposed: + K = K.permute(1, 0, 2).unsqueeze(1) + else: + K = K.permute(1, 2, 0).unsqueeze(1) + if V.dim() == 3: + V = V.unsqueeze(1) + + return attn_output, (K, V), cos_cache, sin_cache + + +NeuronAttentionBase.attention_block_tokengen_nki_kernel = ( + _multi_kv_attention_block_tokengen_nki_kernel +) +_mkv_logger.info("Applied multi-KV-head TKG kernel adapter to NeuronAttentionBase") diff --git a/contrib/models/Ministral-3-14B-Instruct-2512/src/patch_native_multi_kv.py b/contrib/models/Ministral-3-14B-Instruct-2512/src/patch_native_multi_kv.py new file mode 100644 index 00000000..ade660c4 --- /dev/null +++ b/contrib/models/Ministral-3-14B-Instruct-2512/src/patch_native_multi_kv.py @@ -0,0 +1,435 @@ +""" +Monkeypatch: Native multi-KV-head NKI kernel adapter for NxDI. + +Replaces NxDI's `NeuronAttentionBase.attention_block_tokengen_nki_kernel` dispatch +method to call our modified nki-library kernel (attention_block_tkg_multi_kv) +instead of the bundled `llama3_nki_attention_block_token_gen_kernel`. + +The bundled kernel hardcodes kv_heads=1 per rank, and the per-group monkeypatch +(patch_attn_block_multi_kv.py) hits a compiler ICE (NCC_ITEN404) at S_ctx >= 512. + +This adapter translates NxDI's calling convention to the nki-library kernel's +interface, handling: + - Mask merging: mask_cache + mask_active -> single attention_mask in NKI layout + - Parameter renaming (W_gamma -> rmsnorm_X_gamma, etc.) + - Passing n_kv_heads for multi-KV-head support + - Default values for nki-library-only parameters (quantization=NONE, etc.) + +For kv_heads_per_rank == 1, the patch is a no-op passthrough to the original method. + +Usage: + from . 
import patch_native_multi_kv + patch_native_multi_kv.apply_patch() +""" + +import logging +import torch + +logger = logging.getLogger(__name__) + +_original_method = None +_patched = False + + +def _patched_attention_block_tokengen_nki_kernel( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value, + active_mask, + cos_cache, + sin_cache, + rmsnorm, + rotary_position_ids, + update_kv_per_layer, + active_block_table, + use_polar_compatible_rope=False, +): + """ + Native multi-KV-head dispatch using the nki-library kernel fork. + + When kv_heads_per_rank == 1, delegates to the original NxDI method. + When kv_heads_per_rank > 1, calls our attention_block_tkg kernel with + n_kv_heads parameter, handling the interface translation. + """ + kv_heads = self.num_key_value_heads + q_heads = self.num_heads + + # Fast path: kv_heads=1 per rank, use original unmodified method + if kv_heads == 1: + return _original_method( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value, + active_mask, + cos_cache, + sin_cache, + rmsnorm, + rotary_position_ids, + update_kv_per_layer, + active_block_table, + use_polar_compatible_rope=use_polar_compatible_rope, + ) + + # --- Multi-KV-head path: use nki-library kernel with n_kv_heads --- + assert q_heads % kv_heads == 0, ( + f"q_heads ({q_heads}) must be divisible by kv_heads ({kv_heads})" + ) + + # Sequence parallel gather (same as original method) + if self.sequence_parallel_enabled and self.tensor_model_parallel_group is not None: + from neuronx_distributed.parallel_layers.mappings import ( + gather_from_sequence_parallel_region, + ) + + hidden_states = gather_from_sequence_parallel_region( + hidden_states, + self.sequence_dimension, + process_group=self.tensor_model_parallel_group, + ) + + from .attention_block_tkg_multi_kv import attention_block_tkg + + bsz, q_len, h = hidden_states.size() + + # ---- RoPE cos/sin preparation (same as original NxDI method) ---- + skip_rope = False + + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(hidden_states, rotary_position_ids) + cos_cache = cos_cache[..., : cos_cache.shape[-1] // 2].permute(2, 0, 1) + sin_cache = sin_cache[..., : sin_cache.shape[-1] // 2].permute(2, 0, 1) + elif use_polar_compatible_rope: + from neuronx_distributed_inference.modules.attention.attention_base import ( + precompute_freqs_cis, + ) + + rotary_freqs = precompute_freqs_cis( + self.head_dim, + self.neuron_config.max_context_length * 2, + self.rope_theta, + self.use_scaled_rope, + device=hidden_states.device, + ) + rotary_freqs = rotary_freqs[position_ids] + cos_cache = rotary_freqs.cos().permute(2, 0, 1) + sin_cache = rotary_freqs.sin().permute(2, 0, 1) + else: + expected_rope_coeff_shape = (self.head_dim // 2, bsz, q_len) + cos_cache = torch.zeros(expected_rope_coeff_shape).to(hidden_states) + sin_cache = torch.zeros(expected_rope_coeff_shape).to(hidden_states) + skip_rope = True + + # ---- KV Cache ---- + # .data unwraps PlaceholderParameter -> plain Tensor for NKI tracer compatibility + K_prior = past_key_value[0].data + V_prior = past_key_value[1].data + + the_dtype = hidden_states.dtype + the_device = hidden_states.device + + # ---- Mask preparation ---- + # NxDI provides: + # attention_mask (mask_cache): [bsz, 1, q_len, s_prior] or [bsz, q_heads, q_len, s_prior] + # active_mask: [bsz, 1, q_len, q_len] + # Our nki-library kernel expects: + # attention_mask: [S_ctx, B, q_heads, S_tkg] (NKI layout) + # where S_ctx = s_prior and S_tkg 
= q_len + + s_prior = attention_mask.shape[-1] + + # Expand to full q_heads if needed + attention_mask = attention_mask.expand(-1, q_heads, -1, -1).contiguous() + + expected_active_mask_shape = (bsz, 1, q_len, q_len) + if q_len == 1: + active_mask = torch.ones( + expected_active_mask_shape, dtype=the_dtype, device=the_device + ) + active_mask = active_mask.expand(-1, q_heads, -1, -1).contiguous() + + # Merge: overwrite the last q_len positions of s_prior with active_mask + attention_mask[:, :, :, -q_len:] = active_mask + + # Transpose to NKI layout: [bsz, q_heads, q_len, s_prior] -> [s_prior, bsz, q_heads, q_len] + attention_mask_nki = attention_mask.permute(3, 0, 1, 2).contiguous() + + # Create per-group mask for multi-KV-head attention (virtual batch approach) + # For the virtual batch approach, attention_tkg is called ONCE with + # B_virt = bsz * kv_heads, q_head = q_per_kv_group. + # Mask shape: [s_prior, B_virt, q_per_kv_group, q_len] + # All heads share the same causal mask, so expand bsz -> B_virt by repeating. + q_per_kv_group = q_heads // kv_heads + group_attention_mask = attention_mask_nki[:, :, :q_per_kv_group, :].contiguous() + # group_attention_mask: [s_prior, bsz, q_per_kv_group, q_len] + # Expand to [s_prior, bsz * kv_heads, q_per_kv_group, q_len] + group_attention_mask = group_attention_mask.repeat(1, kv_heads, 1, 1).contiguous() + + # ---- Weights ---- + # CRITICAL: Use .data to unwrap Parameter/PlaceholderParameter into plain Tensor. + # The NKI tracer cannot access .shape on Parameter objects — it treats them as "none" + # in the AST. All tensors passed to @nki.jit must be plain Tensor, not Parameter. + W_qkv = self.get_qkv_proj().Wqkv.weight.data + W_qkv_bias = ( + self.get_qkv_proj().Wqkv.bias.data.unsqueeze(0) if self.qkv_bias else None + ) + + fused_rmsnorm = rmsnorm is not None + W_gamma = ( + rmsnorm.weight.data.unsqueeze(0) + if fused_rmsnorm + else torch.ones((1, h), device=the_device) + ) + + update_cache_in_kernel = ( + update_kv_per_layer and self.attn_block_tkg_nki_kernel_cache_update + ) + + W_out = self.get_o_proj().o_proj.weight.data + h_out = h // 2 if getattr(self, "is_eagle3_draft", False) else h + assert W_out.shape == (q_heads * self.head_dim, h_out), ( + f"W_out.shape = {W_out.shape}" + ) + + W_out_bias = ( + self.get_o_proj().o_proj.bias.data.unsqueeze(0) if self.o_bias else None + ) + if W_out_bias is not None: + W_out_bias = W_out_bias / self.tp_degree + + # ---- Output buffers ---- + if update_cache_in_kernel: + K = K_prior + V = V_prior + else: + # Multi-KV-head: K output is [d, B, kv_heads, q_len], V is [B, kv_heads, q_len, d] + K = torch.zeros( + self.head_dim, bsz, kv_heads, q_len, dtype=the_dtype, device=the_device + ) + V = torch.zeros( + bsz, kv_heads, q_len, self.head_dim, dtype=the_dtype, device=the_device + ) + + # ---- V active HBM buffer (multi-KV-head) ---- + # Pre-allocate in PyTorch to avoid NCC_IBIR440 (DRAM allocator failure for + # kernel-internal nl.ndarray(..., buffer=nl.shared_hbm) with multi-KV-head). 
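+    # (Illustrative size: with bsz=4, kv_heads=2, q_len=1, head_dim=128 this is
+    #  8 * 1 * 1 * 128 bf16 elements, i.e. 2 KB of HBM per call.)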
+    # Shape: [B_virt, 1, q_len, head_dim] where B_virt = bsz * kv_heads
+    if kv_heads > 1:
+        B_virt = bsz * kv_heads
+        v_active_hbm_buf = torch.zeros(
+            B_virt, 1, q_len, self.head_dim, dtype=the_dtype, device=the_device
+        )
+    else:
+        v_active_hbm_buf = None
+
+    # ---- RoPE parameters ----
+    # skip_rope: pass None for cos/sin to skip RoPE in nki-library kernel
+    cos_for_kernel = None if skip_rope else cos_cache
+    sin_for_kernel = None if skip_rope else sin_cache
+
+    # rope_contiguous_layout (the kernel's first/second-half RoPE layout flag,
+    # a.k.a. rope_first_second_half_impl) is simply the inverse of
+    # use_polar_compatible_rope.
+    rope_contiguous_layout = not use_polar_compatible_rope
+
+    # ---- QK norm (pre-rope or post-rope) ----
+    # For Mistral3: use_qk_norm is typically False.
+    # The bundled kernel has a single qk_norm boolean; nki-library splits it
+    # into pre-RoPE and post-RoPE variants, so enable whichever placement matches:
+    qk_norm = self.use_qk_norm
+    pre_rope_rmsnorm = self.neuron_config.pre_rope_rmsnorm
+
+    rmsnorm_QK_pre_rope_enabled = qk_norm and pre_rope_rmsnorm
+    rmsnorm_QK_post_rope_enabled = qk_norm and not pre_rope_rmsnorm
+
+    # QK norm weights — the bundled kernel derives these internally from W_qkv.
+    # For now, pass None and let the kernel use unit scaling (identity norm).
+    # This is safe because Mistral3 doesn't use QK norm.
+    rmsnorm_QK_pre_rope_W_Q = None
+    rmsnorm_QK_pre_rope_W_K = None
+    rmsnorm_QK_post_rope_W_Q = None
+    rmsnorm_QK_post_rope_W_K = None
+
+    # ---- kv_cache_update_idx ----
+    # NxDI passes position_ids as [B, q_len] int32.
+    # nki-library kernel expects kv_cache_update_idx as [B, 1] uint32; positions
+    # are non-negative, so the int32 cast below is bit-compatible with uint32.
+    # For TKG (q_len=1), position_ids is already [B, 1].
+    kv_cache_update_idx = position_ids.to(torch.int32)
+    if kv_cache_update_idx.dim() == 1:
+        kv_cache_update_idx = kv_cache_update_idx.unsqueeze(1)
+
+    # ---- kv_cache_update_idx_virt (multi-KV-head) ----
+    # Replicate position indices for each KV head within each batch.
+    # Shape: [B_virt, 1] where B_virt = bsz * kv_heads
+    # Each batch's position index is repeated kv_heads times.
+    if kv_heads > 1:
+        kv_cache_update_idx_virt = kv_cache_update_idx.repeat_interleave(
+            kv_heads, dim=0
+        )
+    else:
+        kv_cache_update_idx_virt = None
+
+    # ---- H_actual (for padded checkpoints) ----
+    X_hidden_dim_actual = getattr(self.config, "original_hidden_size", None)
+
+    # ---- Grid ----
+    # The bundled kernel uses nc(lnc_config) which returns a VNC object.
+    # Our nki.jit kernel needs a plain integer grid.
+    # Extract the integer value from the logical_nc_config.
+    lnc = self.logical_nc_config
+    grid = lnc if isinstance(lnc, int) else int(lnc)
+
+    # WORKAROUND: Force grid=1 for multi-KV-head to avoid NCC_IXLV002 barrier
+    # mismatch when attention_tkg runs with B_virt > 1 on LNC=2.
+    # With grid=1, the kernel runs on a single program (no LNC split).
+    # Performance impact: ~5-10% slower cache update (no K/V core parallelism),
+    # but correctness is preserved.
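+    # (For this model at TP=4, kv_heads_per_rank=2, so the multi-KV path below
+    #  always runs with grid=1.)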
+ if kv_heads > 1: + grid = 1 + + # ---- Call our nki-library kernel ---- + from nkilib.core.utils.common_types import QuantizationType + + attn_output, K, V = attention_block_tkg[grid]( + # -- input + X=hidden_states, + X_hidden_dim_actual=X_hidden_dim_actual, + # -- rmsnorm X + rmsnorm_X_enabled=fused_rmsnorm, + rmsnorm_X_eps=self.rms_norm_eps, + rmsnorm_X_gamma=W_gamma, + # -- qkv projections + W_qkv=W_qkv, + bias_qkv=W_qkv_bias, + quantization_type_qkv=QuantizationType.NONE, + weight_dequant_scale_qkv=None, + input_dequant_scale_qkv=None, + # -- Q/K processing: pre-RoPE RMSNorm + rmsnorm_QK_pre_rope_enabled=rmsnorm_QK_pre_rope_enabled, + rmsnorm_QK_pre_rope_eps=self.rms_norm_eps, + rmsnorm_QK_pre_rope_W_Q=rmsnorm_QK_pre_rope_W_Q, + rmsnorm_QK_pre_rope_W_K=rmsnorm_QK_pre_rope_W_K, + # -- Q/K processing: RoPE + cos=cos_for_kernel, + sin=sin_for_kernel, + rope_contiguous_layout=rope_contiguous_layout, + # -- Q/K processing: post-RoPE RMSNorm + rmsnorm_QK_post_rope_enabled=rmsnorm_QK_post_rope_enabled, + rmsnorm_QK_post_rope_eps=self.rms_norm_eps, + rmsnorm_QK_post_rope_W_Q=rmsnorm_QK_post_rope_W_Q, + rmsnorm_QK_post_rope_W_K=rmsnorm_QK_post_rope_W_K, + # -- attention + K_cache_transposed=self.k_cache_transposed, + active_blocks_table=active_block_table, + K_cache=K_prior, + V_cache=V_prior, + attention_mask=attention_mask_nki, + sink=None, + softmax_scale=None, + # -- KV cache update + update_cache=update_cache_in_kernel, + kv_cache_update_idx=kv_cache_update_idx, + k_scale=None, + v_scale=None, + # -- output projection + W_out=W_out, + bias_out=W_out_bias, + quantization_type_out=QuantizationType.NONE, + weight_dequant_scale_out=None, + input_dequant_scale_out=None, + transposed_out=False, + # -- output + out_in_sb=False, + sbm=None, + skip_attention=False, + # -- Multi-KV-head + n_kv_heads=kv_heads, + # -- Number of query heads per rank (explicit, avoids W_qkv.shape on PlaceholderParameter) + n_q_heads=q_heads, + # -- Head dimension (explicit, avoids NKI .shape on 4D cache) + head_dim=self.head_dim, + # -- Max context length (from KV cache shape, accessible in PyTorch) + s_max_ctx=V_prior.shape[2], + # -- Per-group attention mask (avoids NKI reshape-on-slice inside kernel) + group_attention_mask=group_attention_mask, + # -- Pre-allocated HBM buffer for V active tokens (multi-KV-head only) + v_active_hbm=v_active_hbm_buf, + # -- Replicated kv_cache_update_idx for multi-KV-head cache update + kv_cache_update_idx_virt=kv_cache_update_idx_virt, + ) + + # ---- Post-processing: reshape output ---- + # Our kernel returns attn_output with O-proj already applied + # Shape depends on transposed_out and out_in_sb: + # transposed_out=False, out_in_sb=False: [B*S_tkg, H] @ HBM + attn_output = attn_output.reshape((bsz, q_len, h_out)) + + # All-reduce or reduce-scatter across TP ranks + from neuronx_distributed.parallel_layers.mappings import ( + reduce_from_tensor_model_parallel_region, + ) + + if self.sequence_parallel_enabled: + from neuronx_distributed.parallel_layers.mappings import ( + reduce_scatter_to_sequence_parallel_region, + ) + + attn_output = reduce_scatter_to_sequence_parallel_region( + attn_output, 1, process_group=self.tensor_model_parallel_group + ) + else: + attn_output = reduce_from_tensor_model_parallel_region( + attn_output, process_group=self.tensor_model_parallel_group + ) + + # ---- KV output handling ---- + if not update_cache_in_kernel: + # K from kernel: [d, B, kv_heads, q_len] or [d, B, q_len] (if kv_heads=1) + # V from kernel: [B, kv_heads, q_len, d] or [B, 1, q_len, d] 
(if kv_heads=1) + # NxDI expects: + # K: [B, kv_heads, d, q_len] if k_cache_transposed else [B, kv_heads, q_len, d] + # V: [B, kv_heads, q_len, d] + if K.dim() == 4: + # [d, B, kv_heads, q_len] -> [B, kv_heads, d, q_len] + if self.k_cache_transposed: + K = K.permute(1, 2, 0, 3) + else: + # [d, B, kv_heads, q_len] -> [B, kv_heads, q_len, d] + K = K.permute(1, 2, 3, 0) + else: + # Single head: [d, B, q_len] + if self.k_cache_transposed: + K = K.permute(1, 0, 2).unsqueeze(1) # [B, 1, d, q_len] + else: + K = K.permute(1, 2, 0).unsqueeze(1) # [B, 1, q_len, d] + if V.dim() == 3: + V = V.unsqueeze(1) # [B, q_len, d] -> [B, 1, q_len, d] + # V is already [B, kv_heads, q_len, d] from our kernel + + return attn_output, (K, V), cos_cache, sin_cache + + +def apply_patch(): + """ + Apply the native multi-KV-head kernel adapter to NeuronAttentionBase. + + Must be called after NxDI imports but before model compilation. + """ + global _original_method, _patched + + from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, + ) + + _original_method = NeuronAttentionBase.attention_block_tokengen_nki_kernel + NeuronAttentionBase.attention_block_tokengen_nki_kernel = ( + _patched_attention_block_tokengen_nki_kernel + ) + _patched = True + logger.info( + "Patched NeuronAttentionBase.attention_block_tokengen_nki_kernel " + "with native multi-KV-head nki-library kernel adapter" + ) + return True diff --git a/contrib/models/Ministral-3-14B-Instruct-2512/src/setup_patches.py b/contrib/models/Ministral-3-14B-Instruct-2512/src/setup_patches.py new file mode 100644 index 00000000..3e6b8be0 --- /dev/null +++ b/contrib/models/Ministral-3-14B-Instruct-2512/src/setup_patches.py @@ -0,0 +1,573 @@ +#!/usr/bin/env python3 +""" +Setup patches for Ministral-3-14B-Instruct-2512 (Leanstral) on SDK 2.29. + +Applies to: Ministral 14B (Mistral3ForConditionalGeneration, 32Q/8KV at TP=4). + +Applies all required patches to a fresh DLAMI 20260410 (SDK 2.29) installation: + 1. Mistral rms_norm_eps pass-through (NxDI) + 2. nkilib QKV CTE eps guard + 3. neuronxcc QKV CTE eps guard + 4. QKV weight fusion in convert_hf_to_neuron_state_dict (NxDI) + 5. Fused RMSNorm in Mistral decoder forward (NxDI) + 6. 
Multi-KV TKG kernel + adapter (nkilib + NxDI) + +Usage: + python setup_patches.py [--venv /path/to/venv] [--dry-run] +""" + +import argparse +import os +import re +import shutil +import sys + +# --------------------------------------------------------------------------- +# Path resolution +# --------------------------------------------------------------------------- + +DEFAULT_VENV = "/opt/aws_neuronx_venv_pytorch_inference_vllm_0_16" + + +def resolve_paths(venv_root): + sp = os.path.join(venv_root, "lib", "python3.12", "site-packages") + return { + "site_packages": sp, + "modeling_mistral": os.path.join( + sp, + "neuronx_distributed_inference", + "models", + "mistral", + "modeling_mistral.py", + ), + "attention_base": os.path.join( + sp, + "neuronx_distributed_inference", + "modules", + "attention", + "attention_base.py", + ), + "nkilib_qkv_cte": os.path.join(sp, "nkilib", "core", "qkv", "qkv_cte.py"), + "neuronxcc_qkv_cte": os.path.join( + sp, "neuronxcc", "nki", "_pre_prod_kernels", "qkv_cte_impl.py" + ), + "nkilib_transformer": os.path.join(sp, "nkilib", "experimental", "transformer"), + } + + +def backup(path): + bak = path + ".bak_contrib" + if not os.path.exists(bak): + shutil.copy2(path, bak) + print(f" Backed up {os.path.basename(path)}") + + +def read(path): + with open(path) as f: + return f.read() + + +def write(path, content): + with open(path, "w") as f: + f.write(content) + + +# --------------------------------------------------------------------------- +# Patch 1: Mistral rms_norm_eps pass-through +# --------------------------------------------------------------------------- + + +def patch_rms_norm_eps(paths, dry_run=False): + """Add rms_norm_eps=config.rms_norm_eps to NeuronMistralAttention super().__init__().""" + fpath = paths["modeling_mistral"] + content = read(fpath) + + if "rms_norm_eps=config.rms_norm_eps" in content: + print(" [1] rms_norm_eps: already patched") + return True + + # Find the super().__init__ call in NeuronMistralAttention and add rms_norm_eps + idx = content.find("class NeuronMistralAttention") + if idx == -1: + print(" [1] rms_norm_eps: ERROR - cannot find NeuronMistralAttention class") + return False + # Find the super().__init__ after this class + super_idx = content.find("super().__init__(", idx) + if super_idx == -1: + print(" [1] rms_norm_eps: ERROR - cannot find super().__init__ call") + return False + # Find the closing paren (handle nested parens like getattr()) + paren_depth = 0 + end_idx = super_idx + len("super().__init__(") + for i in range(end_idx, len(content)): + if content[i] == "(": + paren_depth += 1 + elif content[i] == ")": + if paren_depth == 0: + end_idx = i + break + paren_depth -= 1 + + # Insert rms_norm_eps before the closing paren + call_content = content[super_idx:end_idx] + if "rms_norm_eps" not in call_content: + # Add after the last parameter, with proper formatting + insert = ",\n rms_norm_eps=config.rms_norm_eps" + new_content = ( + content[:end_idx].rstrip() + insert + "\n " + content[end_idx:] + ) + if not dry_run: + backup(fpath) + write(fpath, new_content) + print(" [1] rms_norm_eps: PATCHED") + return True + else: + print(" [1] rms_norm_eps: already present in super().__init__") + return True + + +# --------------------------------------------------------------------------- +# Patch 2: nkilib QKV CTE eps guard +# --------------------------------------------------------------------------- + + +def patch_nkilib_eps(paths, dry_run=False): + """Guard nisa.memset for norm_eps when norm_eps=None.""" + fpath = 
paths["nkilib_qkv_cte"] + content = read(fpath) + + if "norm_eps if norm_eps is not None else 0" in content: + print(" [2] nkilib eps guard: already patched") + return True + + # Find: nisa.memset(dst=norm_eps_sb, value=norm_eps) + old = "value=norm_eps)" + new = "value=norm_eps if norm_eps is not None else 0)" + + if old in content: + if not dry_run: + backup(fpath) + # Replace only the first occurrence in the relevant context + content = content.replace(old, new, 1) + write(fpath, content) + print(" [2] nkilib eps guard: PATCHED") + return True + + print(" [2] nkilib eps guard: ERROR - target not found") + return False + + +# --------------------------------------------------------------------------- +# Patch 3: neuronxcc QKV CTE eps guard +# --------------------------------------------------------------------------- + + +def patch_neuronxcc_eps(paths, dry_run=False): + """Guard bias_eps[...] = eps when eps=None.""" + fpath = paths["neuronxcc_qkv_cte"] + content = read(fpath) + + if "if eps is not None:" in content and "bias_eps" in content: + print(" [3] neuronxcc eps guard: already patched") + return True + + # Find: bias_eps[...] = eps (without an if guard) + # Note: neuronxcc uses 2-space indentation + # Try both 2-space and 4-space indentation patterns + old_2sp = " bias_eps[...] = eps" + new_2sp = " if eps is not None:\n bias_eps[...] = eps" + old_4sp = " bias_eps[...] = eps" + new_4sp = " if eps is not None:\n bias_eps[...] = eps" + + if old_2sp in content and "if eps is not None:" not in content: + old = old_2sp + new = new_2sp + elif old_4sp in content and "if eps is not None:" not in content: + old = old_4sp + new = new_4sp + else: + old = None + new = None + + if "if eps is not None:" in content: + print(" [3] neuronxcc eps guard: already patched") + return True + + if old is not None: + if not dry_run: + backup(fpath) + content = content.replace(old, new, 1) + write(fpath, content) + print(" [3] neuronxcc eps guard: PATCHED") + return True + + print(" [3] neuronxcc eps guard: ERROR - target not found") + return False + + +# --------------------------------------------------------------------------- +# Patch 4: QKV weight fusion +# --------------------------------------------------------------------------- + + +def patch_fused_qkv(paths, dry_run=False): + """Add QKV weight fusion to convert_hf_to_neuron_state_dict for Mistral.""" + fpath = paths["modeling_mistral"] + content = read(fpath) + + if "Fuse Q/K/V weights into Wqkv" in content: + print(" [4] fused_qkv: already patched") + return True + + # Find the return state_dict in convert_hf_to_neuron_state_dict + # We need to insert the fusion code just before "return state_dict" + # in the convert_hf_to_neuron_state_dict function + func_marker = "def convert_hf_to_neuron_state_dict" + func_idx = content.find(func_marker) + if func_idx == -1: + print(" [4] fused_qkv: ERROR - cannot find convert_hf_to_neuron_state_dict") + return False + + # Find the last "return state_dict" after this function + # (there may be multiple return statements; we want the final one in this function) + return_pattern = " return state_dict" + last_return_idx = content.rfind(return_pattern, func_idx) + if last_return_idx == -1: + print(" [4] fused_qkv: ERROR - cannot find 'return state_dict'") + return False + + fusion_code = """ + # Fuse Q/K/V weights into Wqkv when fused_qkv is enabled + if getattr(neuron_config, "fused_qkv", False): + import torch as _torch_fqkv + for i in range(num_layers): + q_key = f"layers.{i}.self_attn.q_proj.weight" + k_key = 
f"layers.{i}.self_attn.k_proj.weight" + v_key = f"layers.{i}.self_attn.v_proj.weight" + if q_key in state_dict and k_key in state_dict and v_key in state_dict: + q_w = state_dict.pop(q_key) + k_w = state_dict.pop(k_key) + v_w = state_dict.pop(v_key) + fused_key = f"layers.{i}.self_attn.qkv_proj.Wqkv.weight" + state_dict[fused_key] = _torch_fqkv.cat([q_w, k_w, v_w], dim=0) + # Also handle biases if present + q_bias_key = f"layers.{i}.self_attn.q_proj.bias" + k_bias_key = f"layers.{i}.self_attn.k_proj.bias" + v_bias_key = f"layers.{i}.self_attn.v_proj.bias" + if q_bias_key in state_dict: + fused_bias_key = f"layers.{i}.self_attn.qkv_proj.Wqkv.bias" + state_dict[fused_bias_key] = _torch_fqkv.cat([ + state_dict.pop(q_bias_key), + state_dict.pop(k_bias_key), + state_dict.pop(v_bias_key), + ], dim=0) + +""" + + new_content = content[:last_return_idx] + fusion_code + content[last_return_idx:] + if not dry_run: + backup(fpath) + write(fpath, new_content) + print(" [4] fused_qkv: PATCHED") + return True + + +# --------------------------------------------------------------------------- +# Patch 5: Fused RMSNorm in decoder forward +# --------------------------------------------------------------------------- + + +def patch_fused_rmsnorm(paths, dry_run=False): + """Add rmsnorm=self.input_layernorm to attention call in Mistral decoder forward.""" + fpath = paths["modeling_mistral"] + content = read(fpath) + + if "rmsnorm=self.input_layernorm" in content: + print(" [5] fused_rmsnorm: already patched") + return True + + # Find the attention call in the decoder forward method that does NOT have rmsnorm + # This varies by SDK version. We look for the self.self_attn( call in the forward method. + # The Llama model passes rmsnorm=self.input_layernorm; Mistral does not. 
+
+    # Strategy: find the self.self_attn( call in a decoder forward method
+    # and add the rmsnorm parameter
+
+    # Look for the pattern where self_attn is called with hidden_states
+    attn_call_pattern = "self.self_attn(\n"
+    idx = content.find(attn_call_pattern)
+    if idx == -1:
+        attn_call_pattern = "self.self_attn("
+        idx = content.find(attn_call_pattern)
+
+    if idx == -1:
+        print(
+            "  [5] fused_rmsnorm: WARNING - cannot find self.self_attn call, skipping"
+        )
+        return True  # Non-fatal -- fused_rmsnorm is optional
+
+    # Find the closing paren of the self_attn call
+    paren_depth = 0
+    start = idx + len("self.self_attn(")
+    for i in range(start, len(content)):
+        if content[i] == "(":
+            paren_depth += 1
+        elif content[i] == ")":
+            if paren_depth == 0:
+                # Insert rmsnorm before this closing paren
+                call_body = content[start:i]
+                if "rmsnorm" not in call_body:
+                    # Match the indentation of the existing params: look back
+                    # from ')' to the last newline to get the closing paren's
+                    # indent, then place the new kwarg one level deeper.
+                    last_nl = content.rfind("\n", start, i)
+                    if last_nl != -1:
+                        # Get indent of the closing paren
+                        close_indent = ""
+                        for c in content[last_nl + 1 : i]:
+                            if c in " \t":
+                                close_indent += c
+                            else:
+                                break
+                        # Param indent is typically close_indent + 4 spaces
+                        param_indent = close_indent + "    "
+                    else:
+                        param_indent = "        "
+                        close_indent = "    "
+                    insert = (
+                        f"{param_indent}rmsnorm=self.input_layernorm,\n{close_indent}"
+                    )
+                    # Replace the whitespace preceding ')': find where it starts
+                    ws_start = i
+                    while ws_start > start and content[ws_start - 1] in " \t\n":
+                        ws_start -= 1
+                    # Check if there's already a trailing comma
+                    pre_ws = content[ws_start - 1] if ws_start > start else ""
+                    if pre_ws == ",":
+                        new_content = content[:ws_start] + "\n" + insert + content[i:]
+                    else:
+                        new_content = content[:ws_start] + ",\n" + insert + content[i:]
+                    if not dry_run:
+                        backup(fpath)
+                        write(fpath, new_content)
+                    print("  [5] fused_rmsnorm: PATCHED")
+                else:
+                    print("  [5] fused_rmsnorm: already present")
+                return True
+            paren_depth -= 1
+
+    print("  [5] fused_rmsnorm: WARNING - could not find closing paren, skipping")
+    return True
+
+
+# ---------------------------------------------------------------------------
+# Patch 6: Multi-KV TKG kernel + adapter
+# ---------------------------------------------------------------------------
+
+
+def _fix_nki030_kernel(fpath):
+    """Fix kernel for NKI 0.3.0: remove *, and add defaults to params after the first defaulted one."""
+    content = read(fpath)
+    if "*," not in content:
+        return  # Already fixed
+
+    content = content.replace("    *,\n", "")
+
+    # Find the function signature
+    func_start = content.find("def attention_block_tkg(")
+    if func_start == -1:
+        return
+    paren_depth = 0
+    in_func = False
+    func_end = func_start
+    for i in range(func_start, len(content)):
+        if content[i] == "(":
+            paren_depth += 1
+            in_func = True
+        elif content[i] == ")":
+            paren_depth -= 1
+            if in_func and paren_depth == 0:
+                func_end = i
+                break
+
+    sig = content[func_start : func_end + 1]
+    lines = sig.split("\n")
+    new_lines = []
+    seen_default = False
+
+    for line in lines:
+        stripped = line.strip()
+        if (
+            stripped.startswith("#")
+            or stripped == ""
+            or stripped.startswith("def ")
+            or stripped == ")"
+            or stripped == "),"
+        ):
+            new_lines.append(line)
+            continue
+
+        has_default = "=" in stripped and ":" in stripped
+        if has_default:
+            seen_default = True
+            new_lines.append(line)
+            continue
+
+        if not seen_default:
+            new_lines.append(line)
+            continue
+
+        # Add a default based on the type annotation
+        if "Optional[" in stripped or ": nl.ndarray" in stripped:
+            default = "None"
+        elif ": bool" in stripped:
+            default = "False"
+        elif ": float" in stripped:
+            default = "0.0"
+        elif ": int" in stripped:
+            default = "0"
+        else:
+            default = "None"
+
+        if stripped.endswith(","):
+            line = line.rstrip().rstrip(",") + f" = {default},"
+        else:
+            line = line.rstrip() + f" = {default}"
+        new_lines.append(line)
+
+    new_sig = "\n".join(new_lines)
+    content = content[:func_start] + new_sig + content[func_end + 1 :]
+    write(fpath, content)
+
+
+def patch_multi_kv_tkg(paths, dry_run=False):
+    """Install the Leanstral forked multi-KV TKG kernel and adapter monkeypatch."""
+    # Step 6a: Copy the forked kernel to nkilib
+    kernel_src = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "attention_block_tkg_multi_kv.py"
+    )
+    kernel_dst = os.path.join(
+        paths["nkilib_transformer"], "attention_block_tkg_multi_kv.py"
+    )
+
+    if not os.path.exists(kernel_src):
+        print(f"  [6a] multi-KV kernel: ERROR - source not found at {kernel_src}")
+        return False
+
+    if not dry_run:
+        shutil.copy2(kernel_src, kernel_dst)
+        print("  [6a] multi-KV kernel: copied to nkilib")
+
+    # Step 6a.1: Apply the NKI 0.3.0 compatibility fix to the kernel.
+    # NKI 0.3.0 does not support keyword-only arguments (after *,), so we
+    # remove *, and add defaults to the params that need them.
+    if not dry_run:
+        _fix_nki030_kernel(kernel_dst)
+        print("  [6a] multi-KV kernel: NKI 0.3.0 fix applied")
+
+    # Step 6b: Apply the adapter monkeypatch to attention_base.py
+    fpath = paths["attention_base"]
+    content = read(fpath)
+
+    PATCH_MARKER = "# MULTI_KV_TKG_PATCH_APPLIED"
+    if PATCH_MARKER in content:
+        print("  [6b] multi-KV adapter: already patched")
+        return True
+
+    # Read the adapter code from our local file
+    adapter_src = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "multi_kv_adapter.py"
+    )
+
+    if os.path.exists(adapter_src):
+        adapter_code = read(adapter_src)
+    else:
+        print(f"  [6b] multi-KV adapter: ERROR - source not found at {adapter_src}")
+        return False
+
+    if not dry_run:
+        backup(fpath)
+        with open(fpath, "a") as f:
+            f.write("\n\n" + PATCH_MARKER + "\n")
+            f.write(adapter_code)
+    print("  [6b] multi-KV adapter: PATCHED")
+    return True
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Apply Mistral NKI optimization patches for Ministral-3-14B (Leanstral)"
+    )
+    parser.add_argument(
+        "--venv",
+        default=DEFAULT_VENV,
+        help=f"Path to Neuron venv (default: {DEFAULT_VENV})",
+    )
+    parser.add_argument(
+        "--dry-run", action="store_true", help="Show what would be patched"
+    )
+    parser.add_argument(
+        "--skip-tkg",
+        action="store_true",
+        help="Skip TKG kernel patches (patches 1-5 only, for baseline NKI QKV testing)",
+    )
+    args = parser.parse_args()
+
+    venv = args.venv
+    sp = os.path.join(venv, "lib", "python3.12", "site-packages")
+    if not os.path.isdir(sp):
+        print(f"ERROR: site-packages not found at {sp}")
+        print("Make sure you're running on a DLAMI 20260410 (SDK 2.29) instance")
+        sys.exit(1)
+
+    paths = resolve_paths(venv)
+    print(f"Patching SDK 2.29 at: {venv}")
+    if args.dry_run:
+        print("(DRY RUN - no files will be modified)\n")
+    else:
+        print()
+
+    results = []
+    results.append(("rms_norm_eps", patch_rms_norm_eps(paths, args.dry_run)))
+    results.append(("nkilib_eps", patch_nkilib_eps(paths, args.dry_run)))
+    results.append(("neuronxcc_eps", patch_neuronxcc_eps(paths, args.dry_run)))
+    results.append(("fused_qkv", patch_fused_qkv(paths, args.dry_run)))
+    results.append(("fused_rmsnorm", patch_fused_rmsnorm(paths, args.dry_run)))
+
+    if not args.skip_tkg:
+        results.append(("multi_kv_tkg", patch_multi_kv_tkg(paths, args.dry_run)))
+
+    print("\n--- Summary ---")
+    ok = all(r[1] for r in results)
+    for name, success in results:
+        print(f"  {name}: {'OK' if success else 'FAILED'}")
+
+    if ok:
+        print("\nAll patches applied successfully.")
+        if not args.skip_tkg:
+            print("\nTo start vLLM with full NKI optimization:")
+            print("  python -m vllm.entrypoints.openai.api_server \\")
+            print("    --model /path/to/Ministral-3-14B-Instruct-2512 \\")
+            print("    --tensor-parallel-size 4 --max-model-len 4096 \\")
+            print("    --max-num-seqs 1 --no-enable-prefix-caching \\")
+            print("    --block-size 8 \\")
+            print('    --additional-config \'{"override_neuron_config": {')
+            print('      "fused_qkv": true, "qkv_nki_kernel_enabled": true,')
+            print('      "qkv_kernel_enabled": true,')
+            print('      "attn_block_tkg_nki_kernel_enabled": true,')
+            print('      "attn_block_tkg_nki_kernel_cache_update": true')
+            print("    }}'")
+    else:
+        print("\nSome patches failed. Check errors above.")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/contrib/models/Ministral-3-14B-Instruct-2512/test/__init__.py b/contrib/models/Ministral-3-14B-Instruct-2512/test/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/contrib/models/Ministral-3-14B-Instruct-2512/test/integration/__init__.py b/contrib/models/Ministral-3-14B-Instruct-2512/test/integration/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/contrib/models/Ministral-3-14B-Instruct-2512/test/integration/test_model.py b/contrib/models/Ministral-3-14B-Instruct-2512/test/integration/test_model.py
new file mode 100644
index 00000000..3d8c0c4f
--- /dev/null
+++ b/contrib/models/Ministral-3-14B-Instruct-2512/test/integration/test_model.py
@@ -0,0 +1,282 @@
+#!/usr/bin/env python3
+"""
+Integration tests for Ministral-3-14B-Instruct-2512 (Leanstral) on NeuronX.
+
+Tests require:
+  - trn2.3xlarge instance with SDK 2.29
+  - NEURON_PLATFORM_TARGET_OVERRIDE=trn2
+  - Model checkpoint at MODEL_PATH
+  - Pre-compiled model at COMPILED_MODEL_PATH (or will compile on first run)
+
+Run:
+  export NEURON_PLATFORM_TARGET_OVERRIDE=trn2
+  export NEURON_COMPILE_CACHE_URL=""
+  source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate
+  pytest test/integration/test_model.py -v --capture=tee-sys
+"""
+
+import os
+import sys
+import time
+from pathlib import Path
+
+import pytest
+import torch
+
+# Add src directory to path
+sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))
+from modeling_leanstral import build_inference_config, get_model_cls
+
+# ---- Configuration ----
+# Override via environment variables if needed
+MODEL_PATH = os.environ.get(
+    "LEANSTRAL_MODEL_PATH", "/mnt/models/Ministral-3-14B-Instruct-2512"
+)
+COMPILED_MODEL_PATH = os.environ.get(
+    "LEANSTRAL_COMPILED_PATH", "/mnt/models/compiled_leanstral_contrib"
+)
+TP_DEGREE = int(os.environ.get("LEANSTRAL_TP_DEGREE", "4"))
+SEQ_LEN = 2048
+N_POSITIONS = 4096
+VISION_SEQ_LEN = 4096
+
+TEXT_PROMPT = (
+    "The theory of general relativity, proposed by Albert Einstein in 1915, "
+    "fundamentally changed"
+)
+NUM_DECODE_STEPS = 10
+
+
+# ---- Fixtures ----
+
+
+@pytest.fixture(scope="module")
+def compiled_model():
+    """Build, compile (if needed), and load the Leanstral model."""
+    config = build_inference_config(
+        model_path=MODEL_PATH,
+        tp_degree=TP_DEGREE,
+        batch_size=1,
+        seq_len=SEQ_LEN,
+        n_positions=N_POSITIONS,
+        vision_seq_len=VISION_SEQ_LEN,
+        enable_tkg_kernel=True,
+    )
+    ModelCls = get_model_cls()
+    model = ModelCls(MODEL_PATH, config)
+
+    # Compile if not already compiled
+    model.compile(COMPILED_MODEL_PATH)
+    model.load(COMPILED_MODEL_PATH)
+
+    # Enable vision encoder
+    model.enable_vision_encoder()
+
+    # Warmup: run one short generation to populate caches
+    from neuronx_distributed_inference.utils.hf_adapter import (
+        HuggingFaceGenerationAdapter,
+    )
+
+    adapter = HuggingFaceGenerationAdapter(model)
+    warmup_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
+    _ = adapter.generate(
+        input_ids=warmup_ids,
+        attention_mask=torch.ones_like(warmup_ids),
+        max_new_tokens=2,
+        do_sample=False,
+    )
+
+    return model
+
+
+@pytest.fixture(scope="module")
+def tokenizer():
+    """Load the tokenizer."""
+    from tokenizers import Tokenizer
+
+    return Tokenizer.from_file(os.path.join(MODEL_PATH, "tokenizer.json"))
+
+
+# ---- Helper functions ----
+
+
+def extract_logits(outputs):
+    """Extract logits from model output (handles various output formats)."""
+    if hasattr(outputs, "logits") and outputs.logits is not None:
+        return outputs.logits
+    elif isinstance(outputs, torch.Tensor):
+        return outputs
+    elif isinstance(outputs, (tuple, list)):
+        return outputs[0]
+    else:
+        raise TypeError(f"Cannot extract logits from output of type {type(outputs)!r}")
+
+
+def greedy_decode(model, tokenizer, prompt, num_steps):
+    """Run prefill + greedy decode for num_steps tokens.
+
+    Returns (logits_list, token_list) where each logits entry is the
+    last-position logits tensor for that decode step.
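+
+    Note: each decode step feeds only the newest token, so this helper relies
+    on the compiled model's internal KV cache (indexed by seq_ids) persisting
+    across calls.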
+    """
+    encoded = tokenizer.encode(prompt)
+    input_ids = torch.tensor([encoded.ids], dtype=torch.long)
+    prompt_len = input_ids.shape[1]
+    all_logits = []
+    all_tokens = []
+
+    # Prefill
+    out = model(
+        input_ids=input_ids,
+        attention_mask=torch.ones_like(input_ids),
+        position_ids=torch.arange(prompt_len, dtype=torch.int32).unsqueeze(0),
+        seq_ids=torch.zeros(1, dtype=torch.int32),
+        sampling_params=torch.zeros(1, 3, dtype=torch.float32),
+    )
+    logits = extract_logits(out)
+    step_logits = (logits[:, -1, :] if logits.dim() == 3 else logits).float().cpu()
+    all_logits.append(step_logits)
+    next_token = step_logits.argmax(dim=-1).squeeze().item()
+    all_tokens.append(next_token)
+
+    # Decode
+    for step in range(num_steps - 1):
+        total_len = prompt_len + len(all_tokens)
+        out = model(
+            input_ids=torch.tensor([[all_tokens[-1]]], dtype=torch.long),
+            attention_mask=torch.ones(1, total_len, dtype=torch.int32),
+            position_ids=torch.tensor([[total_len - 1]], dtype=torch.int32),
+            seq_ids=torch.zeros(1, dtype=torch.int32),
+            sampling_params=torch.zeros(1, 3, dtype=torch.float32),
+        )
+        logits = extract_logits(out)
+        step_logits = (
+            (logits[:, -1, :] if logits.dim() == 3 else logits[:1]).float().cpu()
+        )
+        all_logits.append(step_logits)
+        next_token = step_logits.argmax(dim=-1).squeeze().item()
+        all_tokens.append(next_token)
+
+    return all_logits, all_tokens
+
+
+# ---- Tests ----
+
+
+def test_smoke(compiled_model):
+    """Smoke test: model loads and has expected attributes."""
+    assert compiled_model is not None
+    assert hasattr(compiled_model, "config")
+    assert hasattr(compiled_model, "cpu_projector")
+    assert hasattr(compiled_model, "vision_encoder_model")
+    assert compiled_model.config.text_config.num_hidden_layers == 40
+    assert compiled_model.config.text_config.hidden_size == 5120
+    print("Smoke test passed: model loaded with correct config")
+
+
+def test_text_generation(compiled_model, tokenizer):
+    """Test text-only generation produces coherent output."""
+    logits_list, tokens = greedy_decode(
+        compiled_model, tokenizer, TEXT_PROMPT, NUM_DECODE_STEPS
+    )
+    text = tokenizer.decode(tokens)
+    print(f"Generated ({NUM_DECODE_STEPS} tokens): {text}")
+
+    # Basic sanity checks. The vocab size is read from the loaded config
+    # rather than hardcoded, so the test tracks whichever checkpoint is used.
+    vocab_size = compiled_model.config.text_config.vocab_size
+    assert len(tokens) == NUM_DECODE_STEPS, (
+        f"Expected {NUM_DECODE_STEPS} tokens, got {len(tokens)}"
+    )
+    assert all(isinstance(t, int) for t in tokens), "All tokens must be integers"
+    assert all(0 <= t < vocab_size for t in tokens), "Token IDs must be in vocab range"
+
+    # Logits shape check
+    for i, logits in enumerate(logits_list):
+        assert logits.shape[-1] == vocab_size, (
+            f"Step {i}: logits vocab dim = {logits.shape[-1]}, expected {vocab_size}"
+        )
+
+    print(f"Text generation test passed: {NUM_DECODE_STEPS} valid tokens generated")
+
+
+def test_output_coherence(compiled_model, tokenizer):
+    """Test that generated text is not repetitive gibberish."""
+    _, tokens = greedy_decode(compiled_model, tokenizer, TEXT_PROMPT, NUM_DECODE_STEPS)
+    text = tokenizer.decode(tokens)
+
+    # Check for excessive repetition
+    words = text.split()
+    if len(words) >= 5:
+        max_repeat = 5
+        for i in range(len(words) - max_repeat):
+            repeated = all(words[i + j] == words[i] for j in range(max_repeat))
+            assert not repeated, (
+                f"Excessive word repetition detected at position {i}: "
+                f"'{words[i]}' repeated {max_repeat}+ times"
+            )
+
+    # Check that we're not just producing the same token repeatedly
+    unique_tokens = set(tokens)
+    assert len(unique_tokens) >= min(3, NUM_DECODE_STEPS), (
+        f"Only {len(unique_tokens)} unique tokens in {NUM_DECODE_STEPS} steps -- "
+        f"possible degenerate generation"
+    )
+
+    print(f"Coherence test passed: {len(unique_tokens)} unique tokens, text: {text}")
+
+
+def test_logit_validity(compiled_model, tokenizer):
+    """Test that logits are finite and have reasonable distribution."""
+    logits_list, _ = greedy_decode(
+        compiled_model, tokenizer, TEXT_PROMPT, NUM_DECODE_STEPS
+    )
+
+    for step, logits in enumerate(logits_list):
+        # Must be finite
+        assert torch.isfinite(logits).all(), f"Step {step}: non-finite logits detected"
+
+        # Must not be all zeros
+        assert logits.abs().sum() > 0, f"Step {step}: all-zero logits"
+
+        # Softmax should produce a valid probability distribution
+        probs = torch.softmax(logits.squeeze(), dim=-1)
+        prob_sum = probs.sum().item()
+        assert abs(prob_sum - 1.0) < 0.01, (
+            f"Step {step}: softmax sum = {prob_sum}, expected ~1.0"
+        )
+
+    print(f"Logit validity test passed: all {len(logits_list)} steps have valid logits")
+
+
+def test_throughput(compiled_model, tokenizer):
+    """Measure and report decode throughput."""
+    num_tokens = 20
+
+    # Warmup
+    greedy_decode(compiled_model, tokenizer, "Hello", 3)
+
+    # Timed run
+    start = time.perf_counter()
+    _, tokens = greedy_decode(compiled_model, tokenizer, TEXT_PROMPT, num_tokens)
+    elapsed = time.perf_counter() - start
+
+    throughput = num_tokens / elapsed
+    text = tokenizer.decode(tokens)
+    print(f"Throughput: {throughput:.1f} tok/s ({num_tokens} tokens in {elapsed:.2f}s)")
+    print(f"Generated: {text}")
+
+    # Minimum throughput sanity check (very conservative)
+    assert throughput > 5.0, (
+        f"Throughput {throughput:.1f} tok/s is below minimum threshold of 5 tok/s"
+    )
+
+
+# ---- Main ----
+
+if __name__ == "__main__":
+    print("=" * 70)
+    print("Ministral-3-14B-Instruct-2512 (Leanstral) Integration Tests")
+    print("=" * 70)
+    print(f"Model path: {MODEL_PATH}")
+    print(f"Compiled path: {COMPILED_MODEL_PATH}")
+    print(f"TP degree: {TP_DEGREE}")
+    print()
+    pytest.main([__file__, "-v", "--capture=tee-sys"])
diff --git a/contrib/models/Ministral-3-14B-Instruct-2512/test/unit/__init__.py b/contrib/models/Ministral-3-14B-Instruct-2512/test/unit/__init__.py
new file mode 100644
index 00000000..e69de29b