diff --git a/contrib/models/Qwen3.5-2B/README.md b/contrib/models/Qwen3.5-2B/README.md
new file mode 100644
index 00000000..d75f71c6
--- /dev/null
+++ b/contrib/models/Qwen3.5-2B/README.md
@@ -0,0 +1,256 @@
+# Contrib Model: Qwen3.5-2B
+
+NeuronX Distributed Inference implementation of Qwen3.5-2B, a 2B-parameter dense model from Alibaba Cloud with a hybrid DeltaNet + GQA attention architecture.
+
+## Model Information
+
+- **HuggingFace ID:** `Qwen/Qwen3.5-2B`
+- **Model Type:** Decoder-only hybrid DeltaNet/GQA transformer
+- **Parameters:** ~2B (BF16)
+- **Architecture:** 24 layers (18 DeltaNet linear attention + 6 standard GQA), dense SwiGLU MLP, partial RoPE, tied embeddings
+- **License:** Apache 2.0
+
+### Key Architecture Details
+
+| Feature | Value |
+|---------|-------|
+| Layers | 24 (18 DeltaNet + 6 GQA, pattern: [3 DeltaNet + 1 GQA] x 6) |
+| Hidden Size | 2048 |
+| GQA Attention | 8 Q heads, 2 KV heads, head_dim=256 |
+| DeltaNet Attention | 16 value heads, 16 key heads, k_dim=v_dim=128 |
+| MLP | Dense SwiGLU (intermediate_size=6144) |
+| Position Encoding | Partial RoPE (25% of head_dim), mRoPE for VL |
+| Vocabulary | 248,320 |
+| Tied Embeddings | Yes |
+
+The DeltaNet layers use linear recurrent attention (gated delta rule) instead of softmax attention, requiring custom NKI kernels for execution on Neuron. A fused single-kernel chunked forward handles context encoding (CTE), while a per-token recurrent kernel handles token generation (TKG).
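+
+For intuition, here is a minimal PyTorch sketch of the per-token gated delta rule (illustrative only, not the NKI kernel; it omits the L2 normalization and scaling applied to q/k):
+
+```python
+import torch
+
+def gated_delta_step(S, q, k, v, g, beta):
+    """One recurrent step per head. S: (K, V) state; q, k: (K,); v: (V,); g, beta: scalars."""
+    S = S * g.exp()                 # decay the running state
+    delta = beta * (v - k @ S)      # prediction error, scaled by the write gate
+    S = S + torch.outer(k, delta)   # rank-1 state update (the "delta rule")
+    return q @ S, S                 # read-out and new state
+```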
+
+## Validation Results
+
+**Validated:** 2026-04-23
+**Instance:** trn2.3xlarge (TP=4, LNC=2)
+**SDK:** Neuron SDK 2.29, PyTorch 2.9, NKI 0.3.0
+
+### Benchmark Results
+
+All benchmarks on trn2.3xlarge, TP=4, LNC=2, BF16. Chat-formatted prompt (~19 input tokens). Throughput is total tokens/sec across all batch items.
+
+#### Batch Size Scaling (seq_len=128)
+
+| Batch Size | TTFT (ms) | Throughput (tok/s) | Per-Request (tok/s) |
+|:----------:|:---------:|:------------------:|:-------------------:|
+| 1 | 157.8 | 114.5 | 114.5 |
+| 2 | 72.0 | 233.1 | 116.5 |
+| 4 | 104.4 | 329.6 | 82.4 |
+| 8 | 185.6 | 409.5 | 51.2 |
+
+#### Sequence Length Scaling (BS=1)
+
+| seq_len | TTFT (ms) | Throughput (tok/s) |
+|:-------:|:---------:|:------------------:|
+| 128 | 157.8 | 114.5 |
+| 512 | 54.3 | 138.1 |
+| 1024 | 102.7 | 125.3 |
+| 2048 | 199.7 | 106.5 |
+| 4096 | 401.7 | 80.3 |
+
+### Accuracy Validation
+
+9/9 integration tests pass. Accuracy is validated through:
+
+1. **First-token logit comparison** against pre-computed CPU BF16 reference logits:
+ - Cosine similarity: 0.9156 (threshold: 0.85) on TP shard 0
+ - Top-1 token agreement: True (both CPU and Neuron predict "Paris")
+ - Top-5 overlap: 4/5 (threshold: 3)
+
+2. **Multi-prompt coherence tests** with chat-formatted prompts:
+ - Factual Q&A: "What is the capital of France?" produces correct answer
+ - Code generation: "Write a Python fibonacci function" produces valid code
+ - Knowledge: "What is the largest ocean on Earth?" produces correct answer
+ - List generation: "List two ingredients for a chocolate cake" produces valid list
+
+**Note on multi-token logit validation:** DeltaNet layers (18 of 24) use NKI linear recurrent kernels that produce higher BF16 numerical divergence than standard GQA. Autoregressive sequences diverge after the first generated token, making multi-token `logit_validation()` inapplicable. The first-token logits are validated where CPU and Neuron process identical input prefixes. Additionally, the model outputs TP-sharded logits (vocab/tp_degree) because `ModelWrapper` does not call `_gather_along_dim`, so comparison uses the TP shard 0 slice.
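+
+A minimal sketch of that shard-0 comparison, assuming a reference file in the format produced under "Generating CPU Reference Logits" below and a `(batch, vocab/tp_degree)` first-token logits tensor from the Neuron model (names are illustrative):
+
+```python
+import torch
+import torch.nn.functional as F
+
+def shard0_cosine(neuron_logits, ref_path, tp_degree=4):
+    """Cosine similarity between Neuron TP-shard-0 logits and the CPU reference slice."""
+    ref = torch.load(ref_path)
+    cpu_logits = ref["expected_logits"][0]        # first generated step: (batch, vocab)
+    vocab_per_rank = cpu_logits.shape[-1] // tp_degree
+    cpu_shard0 = cpu_logits[:, :vocab_per_rank]   # slice matching TP shard 0
+    return F.cosine_similarity(
+        cpu_shard0.float().flatten(), neuron_logits.float().flatten(), dim=0
+    ).item()
+```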
+
+## Usage
+
+```python
+import json
+import os
+import torch
+from transformers import AutoTokenizer, GenerationConfig
+from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig
+from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter
+
+from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM
+
+model_path = "/path/to/Qwen3.5-2B"
+compiled_path = "/scratch/qwen35_2b_traced/"
+
+neuron_config = NeuronConfig(
+ tp_degree=4,
+ batch_size=1,
+ ctx_batch_size=1,
+ tkg_batch_size=1,
+ seq_len=128,
+ torch_dtype=torch.bfloat16,
+ logical_nc_config=2,
+ enable_bucketing=False,
+ flash_decoding_enabled=False,
+ on_device_sampling_config=OnDeviceSamplingConfig(top_k=1),
+ save_sharded_checkpoint=True,
+)
+
+# Read config.json directly (model_type 'qwen3_5' may not be
+# registered in all transformers versions)
+with open(os.path.join(model_path, "config.json")) as f:
+ hf_config = json.load(f)
+text_config = hf_config.get("text_config", hf_config)
+config_dict = dict(text_config)
+config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044)
+
+config = Qwen35InferenceConfig(
+ neuron_config=neuron_config,
+ **config_dict,
+)
+
+# Compile
+model = NeuronQwen35ForCausalLM(model_path, config)
+model.compile(compiled_path)
+
+# Load
+model = NeuronQwen35ForCausalLM(compiled_path)
+model.load(compiled_path)
+
+# Generate with chat template (recommended)
+tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
+gen_config = GenerationConfig(
+ do_sample=True, top_k=1,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+)
+
+messages = [{"role": "user", "content": "What is the capital of France?"}]
+text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+inputs = tokenizer(text, padding=True, return_tensors="pt")
+gen_model = HuggingFaceGenerationAdapter(model)
+outputs = gen_model.generate(
+ inputs.input_ids,
+ generation_config=gen_config,
+ attention_mask=inputs.attention_mask,
+ max_new_tokens=80,
+)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```
+
+**Note:** Qwen3.5-2B is a chat model. Use `tokenizer.apply_chat_template()` for best results. Raw text prompts may produce repetitive, echo-like output.
+
+**Note on `seq_len`:** The `seq_len` parameter is the total sequence budget (input + generated tokens). Do not pad inputs to `max_length=seq_len`. Use `padding=True` for automatic minimal padding.
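+
+For example, with the `tokenizer` and `text` from the Usage snippet above (illustrative):
+
+```python
+# Correct: minimal padding; the rest of the seq_len budget remains for generation
+inputs = tokenizer(text, padding=True, return_tensors="pt")
+
+# Incorrect: padding to seq_len consumes the entire budget before generation starts
+# inputs = tokenizer(text, padding="max_length", max_length=128, return_tensors="pt")
+```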
+
+## Compatibility Matrix
+
+| Instance | TP | SDK 2.29 | SDK 2.28 |
+|----------|-----|----------|----------|
+| trn2.3xlarge (LNC=2) | 4 | VALIDATED | Not tested |
+
+### Tested Configurations (trn2.3xlarge, TP=4, LNC=2)
+
+| Batch Size | seq_len | Status |
+|:----------:|:-------:|:------:|
+| 1 | 128 | VALIDATED |
+| 2 | 128 | VALIDATED |
+| 4 | 128 | VALIDATED |
+| 8 | 128 | VALIDATED |
+| 1 | 512 | VALIDATED |
+| 1 | 1024 | VALIDATED |
+| 1 | 2048 | VALIDATED |
+| 1 | 4096 | VALIDATED |
+| 2 | 1024 | VALIDATED |
+| 4 | 512 | VALIDATED |
+
+## Example Checkpoints
+
+* [Qwen/Qwen3.5-2B](https://huggingface.co/Qwen/Qwen3.5-2B) (BF16, ~4 GB)
+
+## Testing Instructions
+
+### Unit Tests (CPU only)
+
+```bash
+cd contrib/models/Qwen3.5-2B/
+pytest test/unit/ -v
+```
+
+### Integration Tests (requires trn2 instance)
+
+```bash
+cd contrib/models/Qwen3.5-2B/
+# Activate SDK 2.29 environment
+source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate
+
+QWEN35_MODEL_PATH=/mnt/models/Qwen3.5-2B \
+QWEN35_COMPILED_PATH=/mnt/models/qwen35_2b_traced \
+QWEN35_LOGIT_COMPILED_PATH=/mnt/models/qwen35_2b_traced_logits \
+QWEN35_REF_LOGITS_PATH=/mnt/models/qwen35_2b_cpu_reference_logits_bf16.pt \
+pytest test/integration/test_model.py --capture=tee-sys -v
+```
+
+Environment variables:
+- `QWEN35_MODEL_PATH` — Path to HF model weights (required)
+- `QWEN35_COMPILED_PATH` — Path for compiled artifacts (default: `/tmp/qwen35_2b_traced`)
+- `QWEN35_LOGIT_COMPILED_PATH` — Path to model compiled with `output_logits=True` for logit validation (optional; test skips if not provided)
+- `QWEN35_REF_LOGITS_PATH` — Path to pre-computed CPU BF16 reference logits for logit validation (optional; test skips if not provided)
+- `QWEN35_TP_DEGREE` — Tensor parallelism degree (default: 4)
+- `QWEN35_SEQ_LEN` — Max sequence length (default: 128)
+
+#### Generating CPU Reference Logits
+
+The `qwen3_5` model type requires `transformers>=5.0`. Generate BF16 reference logits in a separate environment:
+
+```bash
+python3 -m venv /tmp/cpu_ref_venv && source /tmp/cpu_ref_venv/bin/activate
+pip install torch transformers accelerate
+python3 -c "
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+model = AutoModelForCausalLM.from_pretrained('/path/to/Qwen3.5-2B', torch_dtype=torch.bfloat16)
+tokenizer = AutoTokenizer.from_pretrained('/path/to/Qwen3.5-2B')
+inputs = tokenizer('The capital of France is', return_tensors='pt')
+gen_cfg = GenerationConfig(do_sample=False, max_new_tokens=16, min_new_tokens=16,
+ pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id)
+with torch.no_grad():
+ out = model.generate(inputs.input_ids, generation_config=gen_cfg,
+ return_dict_in_generate=True, output_scores=True)
+torch.save({'expected_logits': torch.stack(out.scores)[:16,:,:],
+ 'input_ids': inputs.input_ids, 'prompt': 'The capital of France is'},
+ '/path/to/qwen35_2b_cpu_reference_logits_bf16.pt')
+"
+deactivate
+```
+
+#### Compiling with output_logits for Logit Validation
+
+The logit validation test requires a separate compiled model with `output_logits=True`. After compiling the standard model, compile a second copy:
+
+```python
+neuron_config = NeuronConfig(
+ tp_degree=4, batch_size=1, ctx_batch_size=1, tkg_batch_size=1,
+ seq_len=128, torch_dtype=torch.bfloat16, logical_nc_config=2,
+ enable_bucketing=False, flash_decoding_enabled=False,
+ on_device_sampling_config=OnDeviceSamplingConfig(top_k=1),
+ save_sharded_checkpoint=True, output_logits=True, # <-- enables logit capture
+)
+```
+
+## Known Issues
+
+1. **SDK 2.29+ required:** The NKI DeltaNet kernels require NKI 0.3.0 (SDK 2.29).
+
+2. **PyTorch chunked forward hits a compiler ICE on 2B dimensions:** The `_chunk_forward` path creates 5D tensors that trigger a neuronx-cc codegen crash (NCC_INLA001). The fused NKI kernel is therefore the default and required CTE path. It is controlled via the `USE_NKI_FUSED` env var (enabled by default); see the example after this list.
+
+3. **No mini model test:** DeltaNet layers require NKI kernels that only execute on Neuron devices. All integration tests require a trn2 instance with full model weights.
+
+4. **Chat template required for quality output:** Raw text prompts produce repetitive, echo-like output. Always use `tokenizer.apply_chat_template()`.
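+
+For issue 2, the CTE path can be overridden for debugging or benchmarking via the env vars defined in `src/modeling_qwen35.py`, e.g. to exercise the per-chunk NKI path instead of the fused kernel:
+
+```bash
+USE_NKI_FUSED=0 USE_NKI_CHUNKED=1 pytest test/integration/test_model.py -v
+```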
+
+## Maintainer
+
+Jim Burtoft ([@jimburtoft](https://github.com/jimburtoft))
diff --git a/contrib/models/Qwen3.5-2B/src/__init__.py b/contrib/models/Qwen3.5-2B/src/__init__.py
new file mode 100644
index 00000000..7e79aa03
--- /dev/null
+++ b/contrib/models/Qwen3.5-2B/src/__init__.py
@@ -0,0 +1,41 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from src.modeling_qwen35 import (
+ NeuronGatedDeltaNet,
+ NeuronQwen35Attention,
+ NeuronQwen35DecoderLayer,
+ NeuronQwen35ForCausalLM,
+ NeuronQwen35Model,
+ Qwen35DecoderModelInstance,
+ Qwen35InferenceConfig,
+ Qwen35MLP,
+ Qwen35ModelWrapper,
+)
+from src.modeling_qwen35_vision import (
+ NeuronQwen35VisionForImageEncoding,
+ NeuronQwen35VisionModel,
+)
+from src.modeling_qwen35_vl import (
+ NeuronQwen35VLForCausalLM,
+ Qwen35VLInferenceConfig,
+)
+
+__all__ = [
+ # Text decoder
+ "NeuronGatedDeltaNet",
+ "NeuronQwen35Attention",
+ "NeuronQwen35DecoderLayer",
+ "NeuronQwen35ForCausalLM",
+ "NeuronQwen35Model",
+ "Qwen35DecoderModelInstance",
+ "Qwen35InferenceConfig",
+ "Qwen35MLP",
+ "Qwen35ModelWrapper",
+ # Vision encoder
+ "NeuronQwen35VisionForImageEncoding",
+ "NeuronQwen35VisionModel",
+ # Vision-language
+ "NeuronQwen35VLForCausalLM",
+ "Qwen35VLInferenceConfig",
+]
diff --git a/contrib/models/Qwen3.5-2B/src/modeling_qwen35.py b/contrib/models/Qwen3.5-2B/src/modeling_qwen35.py
new file mode 100644
index 00000000..27feffe2
--- /dev/null
+++ b/contrib/models/Qwen3.5-2B/src/modeling_qwen35.py
@@ -0,0 +1,2517 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+NxDI contrib: Qwen3.5-2B (qwen3_5 -- dense model)
+
+Hybrid DeltaNet + Standard Attention + Dense MLP architecture.
+Same architecture family as the larger Qwen3.5 dense models, with smaller dimensions.
+
+18 of 24 layers use Gated DeltaNet (linear recurrent attention)
+6 of 24 layers use standard GQA with KV cache + output gate
+All 24 layers use a dense SwiGLU MLP (intermediate_size=6144)
+
+Architecture details:
+- DeltaNet layers: separate in_proj_{qkv, z, a, b}, causal conv1d on QKV, gated delta rule
+- Attention layers: q_proj doubled (Q + gate), partial RoPE (25% of head_dim), sigmoid output gate
+- Dense MLP: standard SwiGLU (gate_proj, up_proj, down_proj) -- no MoE, no router, no experts
+- KV cache: NxDI KVCacheManager for attention layers; DeltaNet layers store recurrent+conv
+ state as nn.Parameter buffers and return dummy KV tuples
+"""
+
+import gc
+import math
+import logging
+import os
+import sys
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from neuronx_distributed_inference.models.model_base import (
+ NeuronBaseForCausalLM,
+ NeuronBaseModel,
+)
+from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm
+
+try:
+ from neuronxcc.nki._private_kernels.attention import attention_isa_kernel
+except ImportError:
+ from neuronxcc.nki.kernels.attention import attention_isa_kernel
+
+from neuronx_distributed.parallel_layers import parallel_state
+from neuronx_distributed.parallel_layers.layers import (
+ ColumnParallelLinear,
+ ParallelEmbedding,
+ RowParallelLinear,
+)
+from neuronx_distributed.utils import cpu_mode
+
+try:
+ from nki import jit as nki_jit # NKI 0.3.0+ (SDK 2.29)
+except ImportError:
+ from torch_neuronx.xla_impl.ops import nki_jit # NKI 0.2.x (SDK 2.28)
+from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeRMSNorm
+
+from src.nki_kernels.nki_deltanet import deltanet_recurrent_fwd as _deltanet_nki_kernel
+from src.nki_kernels.nki_deltanet import (
+ deltanet_recurrent_fwd_state as _deltanet_nki_kernel_state,
+)
+from src.nki_kernels.nki_deltanet_chunked import (
+ deltanet_chunk_step as _deltanet_nki_chunk_step,
+)
+from src.nki_kernels.nki_deltanet_fused import (
+ deltanet_fused_chunked_fwd as _deltanet_fused_kernel,
+)
+from src.nki_kernels.nki_deltanet_fused import (
+ _make_lower_mask,
+ _make_lower_mask_diag,
+ _make_identity,
+)
+
+from neuronx_distributed_inference.models.config import (
+ InferenceConfig,
+ NeuronConfig,
+)
+from neuronx_distributed_inference.models.model_wrapper import (
+ CONTEXT_ENCODING_MODEL_TAG,
+ TOKEN_GENERATION_MODEL_TAG,
+ DecoderModelInstance,
+ ModelWrapper,
+)
+from neuronx_distributed_inference.modules.attention.attention_base import (
+ NeuronAttentionBase,
+)
+from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding
+from neuronx_distributed_inference.models.layer_boundary_marker import (
+ ModuleMarkerEndWrapper,
+ ModuleMarkerStartWrapper,
+)
+
+logger = logging.getLogger(__name__)
+
+_flash_fwd_call = nki_jit()(attention_isa_kernel)
+
+# Option B: Direct nkilib flash attention for head_dim > 128
+USE_NKILIB_KERNEL = os.environ.get("USE_NKILIB_KERNEL", "0") == "1"
+
+_nkilib_flash_attn = None
+if USE_NKILIB_KERNEL:
+ try:
+ import neuronxcc.nki as _nki
+ from neuronx_distributed_inference.modules.attention.attention_base import (
+ peel_decorations as _peel_decorations,
+ get_platform_target as _get_platform_target,
+ )
+ from neuronxcc.nki.compiler import (
+ skip_middle_end_transformations as _skip_middle_end,
+ enable_stack_allocator as _enable_stack_allocator,
+ )
+
+ import importlib
+
+ _fork_path = "/home/ubuntu/nki-library-fork/nkilib_src"
+ if os.path.isdir(_fork_path) and _fork_path not in sys.path:
+ sys.path.insert(0, _fork_path)
+ _to_remove = [k for k in sys.modules if k.startswith("nkilib")]
+ for k in _to_remove:
+ del sys.modules[k]
+ import nki.language as _stub_nl
+ import neuronxcc.nki.language as _real_nl
+
+ for _attr in [
+ "NKIObject",
+ "float8_e4m3fn",
+ "float8_e4m3fn_x4",
+ "float8_e5m2_x4",
+ "float4_e2m1fn_x4",
+ ]:
+ if not hasattr(_real_nl, _attr) and hasattr(_stub_nl, _attr):
+ setattr(_real_nl, _attr, getattr(_stub_nl, _attr))
+ from nkilib.core.attention.attention_cte import (
+ attention_cte as _attention_cte_raw,
+ _MAX_HEAD_DIM,
+ )
+
+ assert _MAX_HEAD_DIM == 256, (
+ f"nkilib fork has _MAX_HEAD_DIM={_MAX_HEAD_DIM}, expected 256. "
+ f"System nkilib may have been loaded instead of fork."
+ )
+ logger.info(
+ f"Loaded nkilib attention_cte from fork (_MAX_HEAD_DIM={_MAX_HEAD_DIM})"
+ )
+
+ _raw_fn = _peel_decorations(_attention_cte_raw)
+ _platform = _get_platform_target()
+ _nkilib_flash_attn = _nki.jit(
+ _raw_fn,
+ mode="torchxla",
+ platform_target=_platform,
+ show_compiler_tb=True,
+ debug_kernel=True,
+ )
+ _nkilib_flash_attn = _skip_middle_end(_nkilib_flash_attn)
+ _nkilib_flash_attn = _enable_stack_allocator(
+ _nkilib_flash_attn, log_level=logging.INFO
+ )
+ logger.info("Option B: nkilib flash attention loaded for head_dim > 128")
+ except Exception as e:
+ logger.warning(f"Option B: Failed to load nkilib flash attention: {e}")
+ import traceback as _tb
+
+ _tb.print_exc()
+ _nkilib_flash_attn = None
+
+# Option A: Detect if patch_attn_kernel was imported
+NKILIB_PATCH_ACTIVE = False
+try:
+ from importlib import import_module as _import_module
+
+ _attn_mod = _import_module("neuronxcc.nki._pre_prod_kernels.attn_fwd")
+ if hasattr(_attn_mod, "_original_attention_nki_kernel_adapter"):
+ NKILIB_PATCH_ACTIVE = True
+ logger.info("Option A detected: _pre_prod_kernels patched with nkilib kernel")
+except Exception:
+ pass
+
+
+# ============================================================
+# Newton-Raphson Refined RMSNorm
+# ============================================================
+USE_NEWTON_RMSNORM = os.environ.get("USE_NEWTON_RMSNORM") == "1"
+USE_PYTHON_RMSNORM = os.environ.get("USE_PYTHON_RMSNORM") == "1"
+
+
+class NewtonRMSNorm(nn.Module):
+ """RMSNorm with Newton-Raphson refined rsqrt for improved numerical accuracy."""
+
+ def __init__(self, hidden_size=None, eps=1e-6):
+ super().__init__()
+ self.weight = None
+ if hidden_size is not None:
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.hidden_size = hidden_size
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ original_dtype = hidden_states.dtype
+ x = hidden_states.to(torch.float32)
+ variance = x.pow(2).mean(-1, keepdim=True)
+ y = torch.rsqrt(variance + self.variance_epsilon)
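+        # One Newton-Raphson iteration on y ~= rsqrt(a), with a = variance + eps:
+        # y <- y * (3 - a * y^2) / 2, refining the initial rsqrt estimate.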
+ y = y * (3.0 - (variance + self.variance_epsilon) * y * y) * 0.5
+ result = x * y
+ if self.weight is not None:
+ result = result * self.weight.float()
+ return result.to(original_dtype)
+
+
+def get_rmsnorm_cls():
+ if cpu_mode() or USE_PYTHON_RMSNORM:
+ return Qwen3MoeRMSNorm
+ return NewtonRMSNorm if USE_NEWTON_RMSNORM else CustomRMSNorm
+
+
+def l2norm(x, dim=-1, eps=1e-6):
+ return F.normalize(x, p=2, dim=dim, eps=eps)
+
+
+# ============================================================
+# Gated DeltaNet Module (Linear Recurrent Attention)
+# ============================================================
+
+
+class NeuronGatedDeltaNet(nn.Module):
+ """
+ Gated DeltaNet linear attention for Neuron.
+
+ Replaces standard attention for 18 of 24 layers in Qwen3.5-2B.
+ Uses a chunk-based linear recurrence instead of KV cache.
+
+ HF weight layout (2B dense):
+ - in_proj_qkv.weight: (key_dim*2 + value_dim, hidden_size) = (6144, 2048)
+ - in_proj_z.weight: (value_dim, hidden_size) = (2048, 2048)
+ - in_proj_a.weight: (num_v_heads, hidden_size) = (16, 2048)
+ - in_proj_b.weight: (num_v_heads, hidden_size) = (16, 2048)
+ - conv1d.weight: (conv_dim, 1, conv_kernel_size) = (6144, 1, 4)
+ - A_log: (num_v_heads,) = (16,)
+ - dt_bias: (num_v_heads,) = (16,)
+ - norm.weight: (head_v_dim,) = (128,)
+    - out_proj.weight: (hidden_size, value_dim) = (2048, 2048)
+ """
+
+ def __init__(self, config, layer_idx: int):
+ super().__init__()
+ tc = config
+
+        self.hidden_size = tc.hidden_size # 2048
+        self.num_v_heads = tc.linear_num_value_heads # 16
+ self.num_k_heads = tc.linear_num_key_heads # 16
+ self.head_k_dim = tc.linear_key_head_dim # 128
+ self.head_v_dim = tc.linear_value_head_dim # 128
+ self.key_dim = self.head_k_dim * self.num_k_heads # 2048
+        self.value_dim = self.head_v_dim * self.num_v_heads # 2048
+ self.conv_kernel_size = tc.linear_conv_kernel_dim # 4
+ self.layer_idx = layer_idx
+ self.rms_norm_eps = tc.rms_norm_eps
+
+ # KV cache dummy shape info
+ self.head_dim = tc.head_dim # 256
+ tp_degree = tc.neuron_config.tp_degree
+ raw_kv_heads = tc.num_key_value_heads
+ if raw_kv_heads < tp_degree:
+ replicated_kv_heads = tp_degree
+ else:
+ replicated_kv_heads = raw_kv_heads
+ self.kv_heads_per_rank = replicated_kv_heads // tp_degree
+
+ # Conv1d on concatenated QKV (NOT Z)
+        self.conv_dim = self.key_dim * 2 + self.value_dim # 6144
+ self.conv1d = nn.Conv1d(
+ in_channels=self.conv_dim,
+ out_channels=self.conv_dim,
+ bias=False,
+ kernel_size=self.conv_kernel_size,
+ groups=self.conv_dim,
+ padding=self.conv_kernel_size - 1,
+ )
+
+ # Input projections (nn.Linear — NOT sharded by NxDI TP, replicated on all ranks)
+ self.in_proj_qkv = nn.Linear(
+ self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False
+ )
+ self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
+ self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
+ self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
+
+ # Decay parameters
+ self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
+ self.A_log = nn.Parameter(torch.zeros(self.num_v_heads))
+
+ # Output norm and projection
+ self.norm = Qwen3MoeRMSNorm(self.head_v_dim, eps=self.rms_norm_eps)
+ self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
+
+ # State buffers for CTE -> TKG carry-over
+ alloc_batch_size = getattr(config.neuron_config, "max_batch_size", 1)
+ self._phase_batch_size = getattr(config.neuron_config, "batch_size", 1)
+ self.recurrent_state_buffer = nn.Parameter(
+ torch.zeros(
+ alloc_batch_size,
+ self.num_v_heads,
+ self.head_k_dim,
+ self.head_v_dim,
+ dtype=config.neuron_config.torch_dtype,
+ ),
+ requires_grad=False,
+ )
+ self.conv_state_buffer = nn.Parameter(
+ torch.zeros(
+ alloc_batch_size,
+ self.conv_dim,
+ self.conv_kernel_size - 1,
+ dtype=config.neuron_config.torch_dtype,
+ ),
+ requires_grad=False,
+ )
+
+ def _recurrent_step(self, query, key, value, g, beta, recurrent_state):
+ """Single-step recurrent update for token generation."""
+ query = l2norm(query, dim=-1)
+ key = l2norm(key, dim=-1)
+ scale = 1.0 / (query.shape[-1] ** 0.5)
+ query = query * scale
+
+ q_t = query[:, :, 0]
+ k_t = key[:, :, 0]
+ v_t = value[:, :, 0]
+ g_t = g[:, :, 0].exp().unsqueeze(-1).unsqueeze(-1)
+ beta_t = beta[:, :, 0].unsqueeze(-1)
+
+ new_state = recurrent_state * g_t
+ kv_mem = (new_state * k_t.unsqueeze(-1)).sum(dim=-2)
+ delta = (v_t - kv_mem) * beta_t
+ new_state = new_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
+ output = (new_state * q_t.unsqueeze(-1)).sum(dim=-2)
+
+ return output.unsqueeze(2), new_state
+
+ def _nki_recurrent_forward(self, query, key, value, g, beta):
+ """Full-sequence recurrent forward using NKI kernel for context encoding."""
+ query = l2norm(query, dim=-1)
+ key = l2norm(key, dim=-1)
+ B, H, S, k_dim = query.shape
+ v_dim = value.shape[-1]
+ scale = 1.0 / (k_dim**0.5)
+ query = query * scale
+
+ BH = B * H
+ query_flat = query.reshape(BH, S, k_dim).contiguous()
+ key_flat = key.reshape(BH, S, k_dim).contiguous()
+ value_flat = value.reshape(BH, S, v_dim).contiguous()
+
+ g_flat = g.reshape(BH, S).unsqueeze(-1).expand(-1, -1, v_dim).contiguous()
+ beta_flat = beta.reshape(BH, S).unsqueeze(-1).expand(-1, -1, v_dim).contiguous()
+
+ outputs = []
+ states = []
+ for bh in range(BH):
+ out_bh, state_bh = _deltanet_nki_kernel_state(
+ query_flat[bh],
+ key_flat[bh],
+ value_flat[bh],
+ g_flat[bh],
+ beta_flat[bh],
+ )
+ outputs.append(out_bh)
+ states.append(state_bh)
+
+ output = torch.stack(outputs, dim=0)
+ output = output.reshape(B, H, S, v_dim)
+
+ final_state = torch.stack(states, dim=0)
+ final_state = final_state.reshape(B, H, k_dim, v_dim)
+
+ return output, final_state
+
+ def _nki_chunked_forward(
+ self, query, key, value, g, beta, output_final_state=False
+ ):
+ """Chunked NKI kernel forward for context encoding (prefill)."""
+ chunk_size = 128
+
+ query = l2norm(query, dim=-1)
+ key = l2norm(key, dim=-1)
+ B, H, S, k_dim = query.shape
+ v_dim = value.shape[-1]
+ scale = 1.0 / (k_dim**0.5)
+ query = query * scale
+
+ pad_size = (chunk_size - S % chunk_size) % chunk_size
+ if pad_size > 0:
+ query = F.pad(query, (0, 0, 0, pad_size))
+ key = F.pad(key, (0, 0, 0, pad_size))
+ value = F.pad(value, (0, 0, 0, pad_size))
+ beta = F.pad(beta, (0, pad_size))
+ g = F.pad(g, (0, pad_size))
+ total_seq_len = S + pad_size
+
+ num_chunks = total_seq_len // chunk_size
+ g_reshaped = g.reshape(B, H, num_chunks, chunk_size)
+ g_cs = g_reshaped.cumsum(dim=-1)
+ g_last_per_chunk = g_cs[:, :, :, -1:]
+ g_last_expanded = g_last_per_chunk.expand(-1, -1, -1, chunk_size)
+
+ query_chunks = query.reshape(B, H, num_chunks, chunk_size, k_dim)
+ key_chunks = key.reshape(B, H, num_chunks, chunk_size, k_dim)
+ value_chunks = value.reshape(B, H, num_chunks, chunk_size, v_dim)
+
+ beta_chunks = (
+ beta.reshape(B, H, num_chunks, chunk_size)
+ .unsqueeze(-1)
+ .expand(-1, -1, -1, -1, v_dim)
+ )
+ gc_chunks = g_cs.unsqueeze(-1).expand(-1, -1, -1, -1, v_dim)
+ gl_chunks = g_last_expanded.unsqueeze(-1).expand(-1, -1, -1, -1, v_dim)
+
+ BH = B * H
+ query_chunks = query_chunks.reshape(
+ BH, num_chunks, chunk_size, k_dim
+ ).contiguous()
+ key_chunks = key_chunks.reshape(BH, num_chunks, chunk_size, k_dim).contiguous()
+ value_chunks = value_chunks.reshape(
+ BH, num_chunks, chunk_size, v_dim
+ ).contiguous()
+ beta_chunks = beta_chunks.reshape(
+ BH, num_chunks, chunk_size, v_dim
+ ).contiguous()
+ gc_chunks = gc_chunks.reshape(BH, num_chunks, chunk_size, v_dim).contiguous()
+ gl_chunks = gl_chunks.reshape(BH, num_chunks, chunk_size, v_dim).contiguous()
+
+ device = query.device
+ lower_mask = torch.tril(
+ torch.ones(chunk_size, chunk_size, dtype=torch.float32, device=device),
+ diagonal=-1,
+ )
+ identity_mat = torch.eye(chunk_size, dtype=torch.float32, device=device)
+ lower_mask_diag = torch.tril(
+ torch.ones(chunk_size, chunk_size, dtype=torch.float32, device=device),
+ diagonal=0,
+ )
+
+ all_outputs = []
+ all_states = []
+ for bh in range(BH):
+ state = torch.zeros(k_dim, v_dim, dtype=torch.float32, device=device)
+
+ head_chunks = []
+ for c_idx in range(num_chunks):
+ q_chunk = query_chunks[bh, c_idx].contiguous()
+ k_chunk = key_chunks[bh, c_idx].contiguous()
+ v_chunk = value_chunks[bh, c_idx].contiguous()
+ beta_chunk = beta_chunks[bh, c_idx].contiguous()
+ gc_chunk = gc_chunks[bh, c_idx].contiguous()
+ gl_chunk = gl_chunks[bh, c_idx].contiguous()
+
+ out_chunk, state = _deltanet_nki_chunk_step(
+ q_chunk,
+ k_chunk,
+ v_chunk,
+ beta_chunk,
+ gc_chunk,
+ gl_chunk,
+ state,
+ lower_mask,
+ identity_mat,
+ lower_mask_diag,
+ )
+ head_chunks.append(out_chunk)
+
+ head_output = torch.cat(head_chunks, dim=0)
+ all_outputs.append(head_output)
+ all_states.append(state)
+
+ output = torch.stack(all_outputs, dim=0)
+ output = output.reshape(B, H, total_seq_len, v_dim)
+ output = output[:, :, :S]
+
+ if output_final_state:
+ final_state = torch.stack(all_states, dim=0)
+ last_recurrent_state = final_state.reshape(B, H, k_dim, v_dim)
+ else:
+ last_recurrent_state = None
+
+ return output, last_recurrent_state
+
+ def _fused_chunked_forward(
+ self, query, key, value, g, beta, output_final_state=False
+ ):
+ """Fused single-kernel chunked forward for CTE — SSD-style.
+
+ Processes all chunks in a single NKI kernel call per (B,H) pair.
+ State persists in SBUF across chunks (no HBM round-trips).
+ Cumsum of g computed in-kernel via tensor_tensor_scan.
+
+ This is the optimized version of _nki_chunked_forward with:
+ 1. Single kernel call per (B,H) instead of B*H*num_chunks
+ 2. State in SBUF across all chunks (biggest perf win)
+ 3. In-kernel cumsum (avoids PyTorch cumsum overhead)
+ 4. tensor_scalar for broadcasts (no explicit loops)
+ """
+ chunk_size = 128
+
+ query = l2norm(query, dim=-1)
+ key = l2norm(key, dim=-1)
+ B, H, S, k_dim = query.shape
+ v_dim = value.shape[-1]
+ scale = 1.0 / (k_dim**0.5)
+ query = query * scale
+
+ # Pad sequence to multiple of chunk_size
+ pad_size = (chunk_size - S % chunk_size) % chunk_size
+ if pad_size > 0:
+ query = F.pad(query, (0, 0, 0, pad_size))
+ key = F.pad(key, (0, 0, 0, pad_size))
+ value = F.pad(value, (0, 0, 0, pad_size))
+ beta = F.pad(beta, (0, pad_size))
+ g = F.pad(g, (0, pad_size))
+ total_seq_len = S + pad_size
+
+ BH = B * H
+ # Flatten to (BH, S, dim) for per-(b,h) kernel calls
+ query_flat = query.reshape(BH, total_seq_len, k_dim).contiguous()
+ key_flat = key.reshape(BH, total_seq_len, k_dim).contiguous()
+ value_flat = value.reshape(BH, total_seq_len, v_dim).contiguous()
+
+ # g and beta: (BH, S) -> (BH, S, 1) for the kernel's (S, 1) input layout
+ g_flat = g.reshape(BH, total_seq_len).unsqueeze(-1).contiguous()
+ beta_flat = beta.reshape(BH, total_seq_len).unsqueeze(-1).contiguous()
+
+ # Create constant mask tensors (shared across all B*H calls)
+ device = query.device
+ lower_mask = torch.tensor(
+ _make_lower_mask(), dtype=torch.float32, device=device
+ )
+ identity_mat = torch.tensor(
+ _make_identity(), dtype=torch.float32, device=device
+ )
+ lower_mask_diag = torch.tensor(
+ _make_lower_mask_diag(), dtype=torch.float32, device=device
+ )
+
+ all_outputs = []
+ all_states = []
+ for bh in range(BH):
+ out_bh, state_bh = _deltanet_fused_kernel(
+ query_flat[bh], # (S, 128)
+ key_flat[bh], # (S, 128)
+ value_flat[bh], # (S, 128)
+ g_flat[bh], # (S, 1) — RAW g, not cumsum
+ beta_flat[bh], # (S, 1) — sigmoid(b)
+ lower_mask, # (128, 128)
+ identity_mat, # (128, 128)
+ lower_mask_diag, # (128, 128)
+ )
+ all_outputs.append(out_bh)
+ all_states.append(state_bh)
+
+ output = torch.stack(all_outputs, dim=0)
+ output = output.reshape(B, H, total_seq_len, v_dim)
+ output = output[:, :, :S]
+
+ if output_final_state:
+ final_state = torch.stack(all_states, dim=0)
+ last_recurrent_state = final_state.reshape(B, H, k_dim, v_dim)
+ else:
+ last_recurrent_state = None
+
+ return output, last_recurrent_state
+
+ def _sequential_forward(self, query, key, value, g, beta, output_final_state=False):
+ """Sequential full-sequence gated delta rule for CTE.
+
+ Uses the same per-step recurrence as _recurrent_step but loops over the
+ full sequence. Avoids the slice-assignment loop in _chunk_forward that
+ may compile incorrectly on Neuron/XLA.
+ """
+ query = l2norm(query, dim=-1)
+ key = l2norm(key, dim=-1)
+
+ B, H, S, k_dim = query.shape
+ v_dim = value.shape[-1]
+ scale = 1.0 / (k_dim**0.5)
+ query = query * scale
+
+ state = query.new_zeros(B, H, k_dim, v_dim)
+ all_outputs = []
+ for t in range(S):
+ q_t = query[:, :, t] # (B, H, K)
+ k_t = key[:, :, t] # (B, H, K)
+ v_t = value[:, :, t] # (B, H, V)
+ beta_t = beta[:, :, t].unsqueeze(-1) # (B, H, 1)
+ g_t = g[:, :, t].exp().unsqueeze(-1).unsqueeze(-1) # (B, H, 1, 1)
+
+ # Gated delta rule
+ state = state * g_t
+ kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2) # (B, H, V)
+ delta = (v_t - kv_mem) * beta_t # (B, H, V)
+ state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) # (B, H, K, V)
+
+ o_t = (state * q_t.unsqueeze(-1)).sum(dim=-2) # (B, H, V)
+ all_outputs.append(o_t.unsqueeze(2))
+
+ output = torch.cat(all_outputs, dim=2) # (B, H, S, V)
+ final_state = state if output_final_state else None
+ return output, final_state
+
+ def _chunk_forward(self, query, key, value, g, beta, output_final_state=False):
+ """Chunk-based forward for context encoding (prefill)."""
+ chunk_size = 64
+
+ query = l2norm(query, dim=-1)
+ key = l2norm(key, dim=-1)
+
+ B, H, S, k_dim = query.shape
+ v_dim = value.shape[-1]
+ scale = 1.0 / (k_dim**0.5)
+ query = query * scale
+
+ pad_size = (chunk_size - S % chunk_size) % chunk_size
+ if pad_size > 0:
+ query = F.pad(query, (0, 0, 0, pad_size))
+ key = F.pad(key, (0, 0, 0, pad_size))
+ value = F.pad(value, (0, 0, 0, pad_size))
+ beta = F.pad(beta, (0, pad_size))
+ g = F.pad(g, (0, pad_size))
+ total_seq_len = S + pad_size
+
+ v_beta = value * beta.unsqueeze(-1)
+ k_beta = key * beta.unsqueeze(-1)
+
+ num_chunks = total_seq_len // chunk_size
+ query = query.reshape(B, H, num_chunks, chunk_size, k_dim)
+ key = key.reshape(B, H, num_chunks, chunk_size, k_dim)
+ value = value.reshape(B, H, num_chunks, chunk_size, v_dim)
+ k_beta = k_beta.reshape(B, H, num_chunks, chunk_size, k_dim)
+ v_beta = v_beta.reshape(B, H, num_chunks, chunk_size, v_dim)
+ g = g.reshape(B, H, num_chunks, chunk_size)
+
+ mask = torch.triu(
+ torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device),
+ diagonal=0,
+ )
+
+ g = g.cumsum(dim=-1)
+ decay_mask = (g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().tril()
+
+ attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
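+        # Row-by-row forward substitution: with A the strictly lower matrix above,
+        # this loop computes (I - A)^{-1} - I; adding the identity below yields the
+        # full (I - A)^{-1} transform used by the chunked delta rule.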
+ for i in range(1, chunk_size):
+ row = attn[..., i, :i].clone()
+ sub = attn[..., :i, :i].clone()
+ attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+ attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
+
+ value = attn @ v_beta
+ k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+
+ last_recurrent_state = torch.zeros(
+ B, H, k_dim, v_dim, dtype=query.dtype, device=query.device
+ )
+ core_attn_out = torch.zeros_like(value)
+ mask2 = torch.triu(
+ torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device),
+ diagonal=1,
+ )
+
+ for i in range(num_chunks):
+ q_i = query[:, :, i]
+ k_i = key[:, :, i]
+ v_i = value[:, :, i]
+
+ attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(
+ mask2, 0
+ )
+
+ v_prime = k_cumdecay[:, :, i] @ last_recurrent_state
+ v_new = v_i - v_prime
+
+ attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+ core_attn_out[:, :, i] = attn_inter + attn_i @ v_new
+
+ last_recurrent_state = (
+ last_recurrent_state * g[:, :, i, -1, None, None].exp()
+ + (
+ k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]
+ ).transpose(-1, -2)
+ @ v_new
+ )
+
+ core_attn_out = core_attn_out.reshape(B, H, -1, v_dim)
+ core_attn_out = core_attn_out[:, :, :S]
+
+ if not output_final_state:
+ last_recurrent_state = None
+
+ return core_attn_out, last_recurrent_state
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask=None,
+ position_ids=None,
+ past_key_value=None,
+ **kwargs,
+ ):
+ """Forward pass compatible with NxDI decoder layer interface."""
+ batch_size, seq_len, _ = hidden_states.shape
+
+ seq_ids = kwargs.get("seq_ids", None)
+ is_decode = past_key_value is not None
+
+ # Padding mask for DeltaNet: [B, S, 1] with 1.0 for real tokens, 0.0 for padding.
+ # Passed from get_model_output where it's computed from input_ids != pad_token_id.
+ # Embeddings are already zeroed for padding tokens; this mask additionally
+ # zeros the decay gate so the recurrent state is preserved unchanged
+ # through padding positions (no spurious decay).
+ valid_mask_1d = kwargs.get("deltanet_padding_mask", None) # [B, S, 1] or None
+
+ # Project inputs
+ deltanet_fp32 = os.environ.get("DELTANET_FP32") == "1"
+ if deltanet_fp32:
+ hs_f32 = hidden_states.float()
+ qkv = F.linear(hs_f32, self.in_proj_qkv.weight.float()).to(
+ hidden_states.dtype
+ )
+ z = F.linear(hs_f32, self.in_proj_z.weight.float()).to(hidden_states.dtype)
+ b = F.linear(hs_f32, self.in_proj_b.weight.float()).to(hidden_states.dtype)
+ a = F.linear(hs_f32, self.in_proj_a.weight.float()).to(hidden_states.dtype)
+ else:
+ qkv = self.in_proj_qkv(hidden_states)
+ z = self.in_proj_z(hidden_states)
+ b = self.in_proj_b(hidden_states)
+ a = self.in_proj_a(hidden_states)
+
+ # Split QKV
+ query = qkv[..., : self.key_dim]
+ key = qkv[..., self.key_dim : self.key_dim * 2]
+ value = qkv[..., self.key_dim * 2 :]
+
+ # Causal Conv1d on QKV
+ mixed = torch.cat([query, key, value], dim=-1)
+ mixed = mixed.transpose(1, 2)
+
+ if is_decode:
+ if seq_ids is not None:
+ conv_state = torch.index_select(self.conv_state_buffer, 0, seq_ids)
+ else:
+ conv_state = self.conv_state_buffer[:batch_size]
+ conv_input = torch.cat([conv_state, mixed], dim=-1)
+
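+            # Depthwise causal conv over the 4-tap window (conv_kernel_size=4)
+            # for the single new token, expressed as elementwise ops rather
+            # than a Conv1d call.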
+ w = self.conv1d.weight.squeeze(1)
+ conv_out = torch.zeros_like(mixed)
+            for k in range(self.conv_kernel_size):
+ conv_out = (
+ conv_out
+ + w[:, k].unsqueeze(0).unsqueeze(-1) * conv_input[:, :, k : k + 1]
+ )
+ mixed_post_conv = F.silu(conv_out)
+
+ new_conv_state = torch.cat([conv_state[:, :, 1:], mixed], dim=-1)
+ alloc_bs = self.conv_state_buffer.shape[0]
+ if seq_ids is not None:
+ # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement
+ # Add buffer dependency for input_output_alias
+ new_conv_state = (
+ new_conv_state.to(self.conv_state_buffer.dtype)
+ + self.conv_state_buffer * 0
+ )
+ elif batch_size < alloc_bs:
+ pad_size = alloc_bs - batch_size
+ new_conv_state = torch.cat(
+ [
+ new_conv_state,
+ self.conv_state_buffer[batch_size:] * 0,
+ ],
+ dim=0,
+ )
+ else:
+ new_conv_state = new_conv_state + self.conv_state_buffer * 0
+ else:
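+            # Causal conv for prefill: padding=kernel_size-1 pads both sides;
+            # slicing to :seq_len keeps only the causal outputs.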
+ mixed_post_conv = F.silu(self.conv1d(mixed)[:, :, :seq_len])
+
+ if valid_mask_1d is not None:
+ # valid_mask_1d is [B, S, 1]; count valid tokens per batch
+ num_valid = (
+ valid_mask_1d.squeeze(-1).sum(dim=-1, keepdim=True).long()
+ ) # [B, 1]
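+                # Gather the last (conv_kernel_size - 1) = 3 valid positions per
+                # batch row as the carried conv state, skipping right-padding.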
+ idx_base = num_valid - 3
+ idx_base = idx_base.clamp(min=0)
+ offsets = torch.arange(3, device=mixed.device).unsqueeze(0)
+ gather_idx = idx_base + offsets # [B, 3]
+ gather_idx = gather_idx.unsqueeze(1).expand(-1, self.conv_dim, -1)
+ new_conv_state = torch.gather(mixed, 2, gather_idx)
+ else:
+ new_conv_state = mixed[:, :, -3:].contiguous()
+
+ alloc_bs = self.conv_state_buffer.shape[0]
+ if seq_ids is not None:
+ # BS=1 optimization: scatter to index 0 = direct replacement
+ new_conv_state = (
+ new_conv_state.to(self.conv_state_buffer.dtype)
+ + self.conv_state_buffer * 0
+ )
+ elif batch_size < alloc_bs:
+ pad_size = alloc_bs - batch_size
+ new_conv_state = torch.cat(
+ [
+ new_conv_state,
+ torch.zeros(
+ pad_size,
+ self.conv_dim,
+ self.conv_kernel_size - 1,
+ dtype=new_conv_state.dtype,
+ device=new_conv_state.device,
+ ),
+ ],
+ dim=0,
+ )
+ new_conv_state = new_conv_state + self.conv_state_buffer * 0
+ else:
+ new_conv_state = new_conv_state + self.conv_state_buffer * 0
+
+ mixed_post_conv = mixed_post_conv.transpose(1, 2)
+
+ # Zero out conv1d output for padding positions.
+ # Conv1d with kernel_size=4 leaks real token info into the first
+ # few padding positions. Zeroing here ensures Q, K, V are exactly
+ # zero for all padding positions so the recurrence is unaffected.
+ if valid_mask_1d is not None:
+ mixed_post_conv = (
+ mixed_post_conv * valid_mask_1d
+ ) # [B, S, conv_dim] * [B, S, 1]
+
+ query = mixed_post_conv[..., : self.key_dim]
+ key = mixed_post_conv[..., self.key_dim : self.key_dim * 2]
+ value = mixed_post_conv[..., self.key_dim * 2 :]
+
+ # Reshape to heads
+ query = query.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim)
+ key = key.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim)
+ value = value.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim)
+
+ # Compute gating
+ beta = b.sigmoid()
+ g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+
+ if valid_mask_1d is not None:
+ # Zero g for padding → alpha=exp(0)=1 → state preserved through padding
+ # Zero beta for padding → no state update from padding tokens
+ mask_2d = valid_mask_1d.squeeze(-1).float() # [B, S]
+ g = g * mask_2d.unsqueeze(-1)
+ beta = beta * mask_2d.unsqueeze(-1)
+
+        # Expand K heads to match V heads when num_v_heads > num_k_heads
+        # (a no-op for the 2B config, where both are 16; used by larger variants)
+        if self.num_v_heads // self.num_k_heads > 1:
+            rep = self.num_v_heads // self.num_k_heads
+ query = (
+ query.unsqueeze(3)
+ .expand(-1, -1, -1, rep, -1)
+ .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim)
+ )
+ key = (
+ key.unsqueeze(3)
+ .expand(-1, -1, -1, rep, -1)
+ .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim)
+ )
+
+ # Transpose to (B, H, S, dim)
+ query = query.transpose(1, 2).contiguous().float()
+ key = key.transpose(1, 2).contiguous().float()
+ value = value.transpose(1, 2).contiguous().float()
+ g = g.transpose(1, 2).contiguous().float()
+ beta = beta.transpose(1, 2).contiguous().float()
+
+ if is_decode:
+ # TKG: single-step recurrent update
+ if seq_ids is not None:
+ recurrent_state = torch.index_select(
+ self.recurrent_state_buffer, 0, seq_ids
+ ).float()
+ else:
+ recurrent_state = self.recurrent_state_buffer[:batch_size].float()
+
+ output, new_state = self._recurrent_step(
+ query, key, value, g, beta, recurrent_state
+ )
+ new_state_bf16 = new_state.to(self.recurrent_state_buffer.dtype)
+ alloc_bs = self.recurrent_state_buffer.shape[0]
+ if seq_ids is not None:
+ # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement
+ # Add buffer dependency for input_output_alias
+ new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0
+ elif batch_size < alloc_bs:
+ new_rec_state = torch.cat(
+ [
+ new_state_bf16,
+ self.recurrent_state_buffer[batch_size:] * 0,
+ ],
+ dim=0,
+ )
+ else:
+ new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0
+ else:
+ # CTE: fused NKI kernel by default (PyTorch _chunk_forward hits
+ # neuronx-cc codegen ICE NCC_INLA001 with 2B dimensions).
+ # Override with env vars for debugging/benchmarking.
+ use_nki_fused = os.environ.get("USE_NKI_FUSED", "1") != "0"
+ use_nki_chunked = os.environ.get("USE_NKI_CHUNKED") == "1"
+ use_nki = os.environ.get("USE_NKI") == "1"
+ use_sequential = os.environ.get("DELTANET_SEQUENTIAL") == "1"
+ use_pytorch_chunk = os.environ.get("USE_PYTORCH_CHUNK") == "1"
+
+ if use_pytorch_chunk:
+ output, final_state = self._chunk_forward(
+ query, key, value, g, beta, output_final_state=True
+ )
+ elif use_nki_chunked:
+ output, final_state = self._nki_chunked_forward(
+ query, key, value, g, beta, output_final_state=True
+ )
+ elif use_nki:
+ output, final_state = self._nki_recurrent_forward(
+ query, key, value, g, beta
+ )
+ elif use_sequential:
+ output, final_state = self._sequential_forward(
+ query, key, value, g, beta, output_final_state=True
+ )
+ elif use_nki_fused:
+ output, final_state = self._fused_chunked_forward(
+ query, key, value, g, beta, output_final_state=True
+ )
+ else:
+ output, final_state = self._fused_chunked_forward(
+ query, key, value, g, beta, output_final_state=True
+ )
+
+ if final_state is not None:
+ final_state_bf16 = final_state.to(self.recurrent_state_buffer.dtype)
+ alloc_bs = self.recurrent_state_buffer.shape[0]
+ if seq_ids is not None:
+ # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement
+ # Add buffer dependency for input_output_alias
+ new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0
+ elif batch_size < alloc_bs:
+ new_rec_state = torch.cat(
+ [
+ final_state_bf16,
+ torch.zeros(
+ alloc_bs - batch_size,
+ self.num_v_heads,
+ self.head_k_dim,
+ self.head_v_dim,
+ dtype=final_state_bf16.dtype,
+ device=final_state_bf16.device,
+ ),
+ ],
+ dim=0,
+ )
+ new_rec_state = new_rec_state + self.recurrent_state_buffer * 0
+ else:
+ new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0
+ else:
+ new_rec_state = self.recurrent_state_buffer * 1
+
+ # Output: norm, gate, project
+ output = output.to(hidden_states.dtype)
+ output = output.transpose(1, 2).contiguous()
+ output = output.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim)
+ output = self.norm(output)
+ z_gate = z.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim)
+ output = output * F.silu(z_gate)
+ output = output.reshape(batch_size, seq_len, self.value_dim)
+ output = self.out_proj(output)
+
+ # Return dummy KV for KVCacheManager
+ dummy_k = torch.zeros(
+ batch_size,
+ self.kv_heads_per_rank,
+ seq_len,
+ self.head_dim,
+ dtype=hidden_states.dtype,
+ device=hidden_states.device,
+ )
+ dummy_v = torch.zeros_like(dummy_k)
+
+ return output, (dummy_k, dummy_v), new_rec_state, new_conv_state
+
+
+# ============================================================
+# InferenceConfig (Dense -- no MoE)
+# ============================================================
+
+
+class Qwen35InferenceConfig(InferenceConfig):
+ """Config for Qwen3.5-2B (dense) with hybrid DeltaNet + Attention."""
+
+ def __init__(self, *args, **kwargs):
+ # Set defaults BEFORE super().__init__() because it calls validate_config()
+ # which checks get_required_attributes(). These can be overridden by
+ # kwargs or load_config.
+
+ # Layer types for hybrid dispatch: [3 DeltaNet + 1 GQA] x 6 = 24 layers
+ if "layer_types" not in kwargs and not any(
+ hasattr(a, "layer_types") for a in args if hasattr(a, "__dict__")
+ ):
+ layer_types = []
+ for _ in range(6):
+ layer_types.extend(
+ [
+ "linear_attention",
+ "linear_attention",
+ "linear_attention",
+ "full_attention",
+ ]
+ )
+ kwargs.setdefault("layer_types", layer_types)
+
+ # DeltaNet-specific config defaults
+ kwargs.setdefault("linear_num_value_heads", 16)
+ kwargs.setdefault("linear_num_key_heads", 16)
+ kwargs.setdefault("linear_key_head_dim", 128)
+ kwargs.setdefault("linear_value_head_dim", 128)
+ kwargs.setdefault("linear_conv_kernel_dim", 4)
+
+ super().__init__(*args, **kwargs)
+
+ # Attention output gate
+ self.attn_output_gate = getattr(self, "attn_output_gate", True)
+
+ # Partial RoPE
+ self.partial_rotary_factor = getattr(self, "partial_rotary_factor", 0.25)
+ self.rope_dim = int(self.head_dim * self.partial_rotary_factor) # 64
+
+ # mRoPE (multimodal RoPE) for VL support
+ rope_params = getattr(self, "rope_parameters", {}) or {}
+ self.mrope_section = rope_params.get("mrope_section", [11, 11, 10])
+ self.mrope_interleaved = rope_params.get("mrope_interleaved", True)
+
+ # Standard HF config attributes expected by NxDI
+ if not hasattr(self, "output_attentions"):
+ self.output_attentions = False
+ if not hasattr(self, "output_hidden_states"):
+ self.output_hidden_states = False
+
+ def get_required_attributes(self) -> List[str]:
+ return [
+ "head_dim",
+ "hidden_act",
+ "hidden_size",
+ "intermediate_size",
+ "max_position_embeddings",
+ "num_attention_heads",
+ "num_hidden_layers",
+ "num_key_value_heads",
+ "rms_norm_eps",
+ "rope_theta",
+ "vocab_size",
+ # DeltaNet-specific
+ "linear_num_value_heads",
+ "linear_num_key_heads",
+ "linear_key_head_dim",
+ "linear_value_head_dim",
+ "linear_conv_kernel_dim",
+ "layer_types",
+ ]
+
+ @classmethod
+ def get_neuron_config_cls(cls):
+ return NeuronConfig
+
+
+# ============================================================
+# Attention (standard GQA for 6 of 24 layers)
+# With output gate: q_proj is 2x sized, split into (query, gate)
+# With partial RoPE: only first rope_dim dimensions get rotary
+# ============================================================
+
+
+class Qwen35MRoPEEmbedding(nn.Module):
+ """Multimodal Rotary Position Embedding (mRoPE) for Qwen3.5.
+
+ Handles 3D position information (temporal, height, width) for VL models.
+ Position IDs have shape (3, batch_size, seq_len) for T/H/W dimensions.
+ For text-only (2D position_ids), broadcasts to 3D with identical positions.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.head_dim = config.head_dim # 256
+ self.rope_dim = config.rope_dim # 64
+ self.mrope_section = config.mrope_section # [11, 11, 10]
+ self.mrope_interleaved = getattr(config, "mrope_interleaved", True)
+ self.rope_theta = config.rope_theta
+
+ # Validate mrope_section sums to rope_dim // 2 = 32
+ assert sum(self.mrope_section) == self.rope_dim // 2, (
+ f"mrope_section {self.mrope_section} sums to {sum(self.mrope_section)}, "
+ f"expected {self.rope_dim // 2}"
+ )
+
+ def forward(self, x, position_ids_3d):
+ """Compute cos/sin from 3D position IDs.
+
+ Args:
+ x: hidden_states (for device/dtype inference)
+ position_ids_3d: (3, batch_size, seq_len) -- T, H, W positions
+
+ Returns:
+ cos: (batch_size, seq_len, rope_dim)
+ sin: (batch_size, seq_len, rope_dim)
+ """
+ device = x.device
+ dtype = torch.float32
+
+ sections = self.mrope_section # [11, 11, 10]
+ cos_parts = []
+ sin_parts = []
+
+ freq_offset = 0
+ for axis_idx, section_size in enumerate(sections):
+ pos = position_ids_3d[axis_idx].float() # (batch, seq_len)
+
+ dim_pairs = section_size # number of (cos, sin) pairs for this axis
+ freqs = 1.0 / (
+ self.rope_theta
+ ** (
+ torch.arange(0, dim_pairs * 2, 2, dtype=dtype, device=device)
+ / (self.rope_dim)
+ )
+ ) # (dim_pairs,)
+
+ # freqs: (dim_pairs,), pos: (B, S) -> angles: (B, S, dim_pairs)
+ angles = pos.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0)
+
+ cos_parts.append(angles.cos())
+ sin_parts.append(angles.sin())
+
+ # Concatenate: (B, S, 32)
+ cos = torch.cat(cos_parts, dim=-1)
+ sin = torch.cat(sin_parts, dim=-1)
+
+ if self.mrope_interleaved:
+ # Interleave to (B, S, 64): [c0, c0, c1, c1, ...] for rotate_half
+ cos = cos.repeat_interleave(2, dim=-1)
+ sin = sin.repeat_interleave(2, dim=-1)
+ else:
+ cos = torch.cat([cos, cos], dim=-1)
+ sin = torch.cat([sin, sin], dim=-1)
+
+ return cos, sin
+
+
+class NeuronQwen35Attention(NeuronAttentionBase):
+ """Standard GQA attention for Qwen3.5 with output gate and partial RoPE.
+
+ 8 Q heads, 2 KV heads (4:1 GQA), head_dim=256 for the 2B dense model.
+ q_proj is doubled (query + gate), split at load time.
+ Only first rope_dim=64 of head_dim=256 gets rotary encoding.
+
+ Uses NeuronAttentionBase infrastructure for QKV projection, KV cache,
+ RoPE, and attention computation. Overrides forward() to insert the
+ sigmoid output gate between attention output and o_proj.
+ """
+
+ def __init__(self, config):
+ # Partial RoPE: create mRoPE embedding with rope_dim (64)
+ self.rope_dim = config.rope_dim # 64 = head_dim * partial_rotary_factor
+
+ # Create QK norm modules (will be passed to base class)
+ rms_norm_eps = config.rms_norm_eps
+ q_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps)
+ k_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps)
+
+ # Partial RoPE: use standard RotaryEmbedding.
+ # For VL with 3D mRoPE positions, cos/sin are pre-computed externally in
+ # get_model_output() using Qwen35MRoPEEmbedding and passed as cos_cache/sin_cache.
+ rotary_emb = RotaryEmbedding(
+ self.rope_dim, # Only 64 dims get rotary embedding
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ )
+ super().__init__(
+ config=config,
+ hidden_size=config.hidden_size,
+ num_attention_heads=config.num_attention_heads,
+ num_key_value_heads=config.num_key_value_heads,
+ head_dim=config.head_dim,
+ rotary_emb=rotary_emb,
+ rms_norm_eps=rms_norm_eps,
+ use_qk_norm=False,
+ q_layernorm=q_ln,
+ k_layernorm=k_ln,
+ )
+
+ # Separate mRoPE module for VL 3D position_ids
+ self.mrope_emb = Qwen35MRoPEEmbedding(config)
+
+ # Output gate projection: hidden_size -> num_heads * head_dim
+ # Populated from the second half of q_proj during state dict conversion.
+ self.output_gate_proj = ColumnParallelLinear(
+ config.hidden_size,
+ config.num_attention_heads * config.head_dim,
+ bias=False,
+ gather_output=False,
+ )
+
+ def apply_rotary_embedding(
+ self, Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope
+ ):
+ """Partial RoPE: only apply rotary embedding to first rope_dim dimensions.
+
+ Q shape: (B, H, S, head_dim) where head_dim=256
+ cos/sin shape: (B, S, rope_dim) where rope_dim=64 (from RotaryEmbedding(dim=64))
+
+ Split Q/K along last dim into:
+ q_rope (first 64 dims) -- apply RoPE
+ q_pass (remaining 192 dims) -- pass through unchanged
+ """
+ from neuronx_distributed_inference.modules.attention.utils import (
+ apply_rotary_pos_emb,
+ )
+
+ if self.rotary_emb is not None:
+ if cos_cache is None or sin_cache is None:
+ cos_cache, sin_cache = self.rotary_emb(V, position_ids)
+
+ # Split into rope and pass-through portions
+ Q_orig_dtype = Q.dtype
+ q_rope = Q[..., : self.rope_dim] # (B, H, S, 64)
+ q_pass = Q[..., self.rope_dim :] # (B, H, S, 192)
+ k_rope = K[..., : self.rope_dim]
+ k_pass = K[..., self.rope_dim :]
+
+ # Apply RoPE only to the rope portion
+ q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos_cache, sin_cache)
+
+ # Concatenate back (ensure bf16 is maintained)
+ Q = torch.cat([q_rope, q_pass], dim=-1).to(Q_orig_dtype)
+ K = torch.cat([k_rope, k_pass], dim=-1).to(Q_orig_dtype)
+
+ return Q, K, cos_cache, sin_cache
+
+ def perform_prefill(self, Q, K, V, q_len, bsz, attention_mask=None):
+ """Prefill path with NKI flash attention for head_dim=256."""
+ head_dim = Q.shape[-1]
+
+ # Option B: nkilib flash attention for head_dim > 128
+ if _nkilib_flash_attn is not None:
+ q_contig = Q.contiguous()
+ k_contig = K.contiguous()
+ v_contig = V.contiguous()
+ scale = 1.0 / math.sqrt(head_dim)
+ result = _nkilib_flash_attn(
+ q_contig, k_contig, v_contig, scale=scale, use_causal_mask=True
+ )
+ return result, None
+
+ # Option A: kernel patched globally
+ if NKILIB_PATCH_ACTIVE:
+ return _flash_fwd_call(Q, K, V, use_causal_mask=True), None
+
+ # Fallback: softmax path (use 3D tensors to avoid compiler ICE with 4D patterns)
+ if head_dim > 128:
+ # GQA: expand K/V heads to match Q heads
+ num_q_heads = Q.shape[1]
+ num_kv_heads = K.shape[1]
+ if num_q_heads != num_kv_heads:
+ kv_rep = num_q_heads // num_kv_heads
+ K = (
+ K.unsqueeze(2)
+ .expand(-1, -1, kv_rep, -1, -1)
+ .reshape(bsz, num_q_heads, q_len, head_dim)
+ )
+ V = (
+ V.unsqueeze(2)
+ .expand(-1, -1, kv_rep, -1, -1)
+ .reshape(bsz, num_q_heads, q_len, head_dim)
+ )
+ # Reshape to 3D (B*H, S, d) to avoid neuronx-cc codegen ICE with 4D
+ # attention weight tensors (NCC_INLA001: Expected 2D tensor but got 4D AP)
+ Q_3d = Q.reshape(bsz * num_q_heads, q_len, head_dim)
+ K_3d = K.reshape(bsz * num_q_heads, q_len, head_dim)
+ V_3d = V.reshape(bsz * num_q_heads, q_len, head_dim)
+ attn_weights = torch.bmm(Q_3d, K_3d.transpose(-1, -2)) / math.sqrt(head_dim)
+ # Build causal mask for 3D: (1, S, S) broadcast over B*H
+ causal_mask = torch.triu(
+ torch.full(
+ (q_len, q_len),
+ -65504.0,
+ dtype=attn_weights.dtype,
+ device=attn_weights.device,
+ ),
+ diagonal=1,
+ ).unsqueeze(0)
+ attn_weights = attn_weights + causal_mask
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+ Q.dtype
+ )
+ attn_output = torch.bmm(attn_weights, V_3d)
+ # Reshape back to 4D (B, H, S, d)
+ return attn_output.reshape(bsz, num_q_heads, q_len, head_dim), None
+
+ return _flash_fwd_call(Q, K, V, use_causal_mask=True), None
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_ids=None,
+ past_key_value=None,
+ cos_cache=None,
+ sin_cache=None,
+ rmsnorm=None,
+ adapter_ids=None,
+ active_mask=None,
+ **kwargs,
+ ):
+ """Forward with output gate applied BEFORE o_proj.
+
+ Override NeuronAttentionBase.forward() to insert the sigmoid gate
+ between the attention output and o_proj, matching the HF reference:
+ gate = sigmoid(gate_proj(pre_attn_hidden))
+ attn_output = attn_output * gate
+ attn_output = o_proj(attn_output)
+ """
+ bsz, q_len, _ = hidden_states.shape
+
+ # Use standard 2D position_ids for prep_qkv_tensors.
+ rope_pos_ids = position_ids
+
+ # Compute gate from input hidden states (before QKV projection)
+ gate = self.output_gate_proj(hidden_states) # (B, S, num_heads * head_dim)
+
+ # Standard QKV prep (projections, QK norm, RoPE)
+ Q, K, V, cos_cache, sin_cache, _residual = self.prep_qkv_tensors(
+ rope_pos_ids,
+ hidden_states,
+ past_key_value,
+ adapter_ids=adapter_ids,
+ cos_cache=cos_cache,
+ sin_cache=sin_cache,
+ rmsnorm=rmsnorm,
+ )
+
+ if past_key_value is None:
+ # Context encoding (prefill)
+ attn_output, _flash_strategy = self.perform_prefill(
+ Q, K, V, q_len, bsz, attention_mask
+ )
+ else:
+ # Token generation (decode)
+ tkg_mask = attention_mask
+ if tkg_mask is not None and tkg_mask.ndim == 2:
+ tkg_mask = tkg_mask.unsqueeze(1).unsqueeze(2) # (B, S) -> (B, 1, 1, S)
+ attn_output = self.compute_for_token_gen(
+ Q, K, V, position_ids, past_key_value, tkg_mask, active_mask
+ )
+
+ # attn_output is (B, H, S, head_dim) -- transpose to (B, S, H*head_dim)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
+
+ # Apply sigmoid output gate BEFORE o_proj (matching HF reference)
+ attn_output = attn_output * torch.sigmoid(gate)
+
+ # Apply o_proj
+ attn_output = self.get_o_proj()(attn_output, adapter_ids=adapter_ids)
+
+ # Ensure K, V are in model dtype (bf16) for KV cache update
+ # (prevents mixed-precision dynamic-update-slice in neuronx-cc)
+ K = K.to(self.torch_dtype)
+ V = V.to(self.torch_dtype)
+ past_key_value = (K, V)
+ return attn_output, past_key_value, cos_cache, sin_cache
+
+
+# ============================================================
+# Dense MLP (replaces MoE)
+# ============================================================
+
+
+class Qwen35MLP(nn.Module):
+ """Dense SwiGLU MLP for Qwen3.5-2B.
+
+ gate_proj: hidden_size -> intermediate_size (2048 -> 6144)
+ up_proj: hidden_size -> intermediate_size (2048 -> 6144)
+ down_proj: intermediate_size -> hidden_size (6144 -> 2048)
+
+ output = down_proj(silu(gate_proj(x)) * up_proj(x))
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.gate_proj = ColumnParallelLinear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=False,
+ gather_output=False,
+ )
+ self.up_proj = ColumnParallelLinear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=False,
+ gather_output=False,
+ )
+ self.down_proj = RowParallelLinear(
+ config.intermediate_size,
+ config.hidden_size,
+ bias=False,
+ input_is_parallel=True,
+ )
+
+ def forward(self, hidden_states):
+ gate = self.gate_proj(hidden_states)
+ up = self.up_proj(hidden_states)
+ hidden_states = F.silu(gate) * up
+ hidden_states = self.down_proj(hidden_states)
+ return hidden_states
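+
+
+def _swiglu_reference(x, gate, up, down):
+    """Illustrative, non-parallel SwiGLU reference (a sketch; not used at runtime).
+
+    ``gate``, ``up``, and ``down`` are assumed to be plain ``nn.Linear`` modules.
+    ``Qwen35MLP`` computes the same function with gate/up sharded column-wise and
+    down sharded row-wise, so the silu-multiply runs on each rank's shard.
+    """
+    return down(F.silu(gate(x)) * up(x))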
+
+
+# ============================================================
+# Decoder Layer (hybrid dispatch -- DeltaNet or GQA + Dense MLP)
+# ============================================================
+
+
+class NeuronQwen35DecoderLayer(nn.Module):
+ """Hybrid decoder layer: dispatches to DeltaNet or standard attention.
+ Uses dense MLP for all layers (no MoE).
+ """
+
+ def __init__(self, config: Qwen35InferenceConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.layer_type = config.layer_types[layer_idx]
+ self.layer_idx = layer_idx
+ self.config = config
+
+ # Attention (DeltaNet or standard GQA)
+ if self.layer_type == "linear_attention":
+ self.linear_attn = NeuronGatedDeltaNet(config, layer_idx)
+ else:
+ self.self_attn = NeuronQwen35Attention(config=config)
+
+ # Dense MLP (all layers)
+ self.mlp = Qwen35MLP(config)
+
+ self.input_layernorm = get_rmsnorm_cls()(
+ config.hidden_size, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = get_rmsnorm_cls()(
+ config.hidden_size, eps=config.rms_norm_eps
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask=None,
+ position_ids=None,
+ past_key_value=None,
+ padding_mask=None,
+ cos_cache=None,
+ sin_cache=None,
+ **kwargs,
+ ):
+ residual = hidden_states
+
+ hidden_states = ModuleMarkerStartWrapper()(hidden_states)
+ hidden_states = self.input_layernorm(hidden_states)
+
+ if self.layer_type == "linear_attention":
+ # DeltaNet path
+ attn_out, dummy_kv, new_rec_state, new_conv_state = self.linear_attn(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ **kwargs,
+ )
+ hidden_states = residual + attn_out
+ present_key_value = dummy_kv
+ deltanet_states = (new_rec_state, new_conv_state)
+ else:
+ deltanet_states = None
+ # Standard attention path
+ hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ cos_cache=cos_cache,
+ sin_cache=sin_cache,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Dense MLP FFN
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ hidden_states = ModuleMarkerEndWrapper()(hidden_states)
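+        # The caller (NeuronQwen35Model.get_model_output) unpacks this tuple as
+        #   hidden_states, kv, cos_cache, sin_cache, _, deltanet_states
+        # where deltanet_states is None for full-attention layers and a
+        # (recurrent_state, conv_state) pair for DeltaNet layers.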
+ outputs = (
+ hidden_states,
+ present_key_value,
+ cos_cache,
+ sin_cache,
+ None,
+ deltanet_states,
+ )
+ return outputs
+
+
+# ============================================================
+# Model
+# ============================================================
+
+
+class NeuronQwen35Model(NeuronBaseModel):
+ def setup_attr_for_model(self, config: Qwen35InferenceConfig):
+ self.on_device_sampling = (
+ config.neuron_config.on_device_sampling_config is not None
+ )
+ self.tp_degree = config.neuron_config.tp_degree
+ self.hidden_size = config.hidden_size
+ self.num_attention_heads = config.num_attention_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.max_batch_size = config.neuron_config.max_batch_size
+ self.buckets = config.neuron_config.buckets
+
+ def init_model(self, config: Qwen35InferenceConfig):
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = ParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ self.padding_idx,
+ dtype=config.neuron_config.torch_dtype,
+ shard_across_embedding=True,
+ )
+ self.layers = nn.ModuleList(
+ [
+ NeuronQwen35DecoderLayer(config, layer_idx)
+ for layer_idx in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = get_rmsnorm_cls()(self.hidden_size, eps=config.rms_norm_eps)
+ self.lm_head = ColumnParallelLinear(
+ config.hidden_size,
+ config.vocab_size,
+ gather_output=False if self.on_device_sampling else True,
+ bias=False,
+ )
+
+ # mRoPE embedding for VL
+ self.mrope_emb = Qwen35MRoPEEmbedding(config)
+
+ @property
+ def _deltanet_state_params(self):
+ """Return DeltaNet state nn.Parameters in alias order."""
+ params = []
+ for layer in self.layers:
+ if hasattr(layer, "linear_attn"):
+ params.append(layer.linear_attn.recurrent_state_buffer)
+ params.append(layer.linear_attn.conv_state_buffer)
+ return params
+
+ def encode_vision_to_input(self, inputs_embeds, vision_embeddings, vision_mask):
+ """Scatter vision embeddings into text input embeddings at image token positions."""
+ _, max_positions, embedding_dim = inputs_embeds.shape
+ h_new = inputs_embeds.clone()
+ vision_flat = vision_embeddings.view(-1, embedding_dim)
+ positions_flat = vision_mask.view(-1)
+ h_new.view(-1, embedding_dim).index_put_(
+ (positions_flat,), vision_flat, accumulate=False
+ )
+ return h_new
+
+ def get_model_output(
+ self,
+ input_ids=None,
+ seq_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ past_key_values=None,
+ active_mask=None,
+ inputs_embeds=None,
+ prev_hidden=None,
+ adapter_ids=None,
+ rotary_position_ids=None,
+ update_cache=False,
+ is_for_context_encoding=False,
+ vision_embeddings=None,
+ vision_mask=None,
+ local_attn_mask=None,
+ windowed_context_encoding_window_idx=-1,
+ padding_mask=None,
+ **kwargs,
+ ):
+ """Override to collect DeltaNet state tensors from decoder layers."""
+ batch_size, seq_length = input_ids.shape[:2]
+ if self.config.neuron_config.layer_boundary_markers:
+ input_ids = ModuleMarkerStartWrapper()(input_ids)
+
+ past_key_values_length = 0
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][1].shape[2]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # CRITICAL: Zero out embeddings for padding tokens so DeltaNet recurrence
+ # is not polluted. DeltaNet has no attention mask -- it processes all
+ # sequence positions through a linear recurrence. Padding tokens have
+ # real embedding vectors which corrupt the recurrence state.
+ # The mask is [B, S, 1] float with 1.0 for real tokens, 0.0 for padding.
+ deltanet_padding_mask = (
+ (input_ids != self.padding_idx).unsqueeze(-1).to(inputs_embeds.dtype)
+ )
+ if is_for_context_encoding:
+ inputs_embeds = inputs_embeds * deltanet_padding_mask
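+        # Illustrative: with padding token id p, input_ids [[t1, t2, p, p]]
+        # yields deltanet_padding_mask [[1., 1., 0., 0.]] (shape (B, S, 1)),
+        # zeroing the embedding rows of the two padded positions.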
+
+ # Vision embedding injection
+ if (vision_embeddings is not None) and (vision_mask is not None):
+ if vision_embeddings.dtype != self.config.neuron_config.torch_dtype:
+ vision_embeddings = vision_embeddings.to(
+ self.config.neuron_config.torch_dtype
+ )
+ if is_for_context_encoding:
+ inputs_embeds = self.encode_vision_to_input(
+ inputs_embeds, vision_embeddings, vision_mask
+ )
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ hidden_states = inputs_embeds
+
+ # Get KV cache for TKG
+ cache_size = self.n_positions
+ if not is_for_context_encoding:
+ if self.kv_mgr is not None:
+ past_key_values = self.kv_mgr.get_cache(
+ seq_ids=seq_ids,
+ seq_len=cache_size,
+ is_for_context_encoding=is_for_context_encoding,
+ windowed_context_encoding_window_idx=windowed_context_encoding_window_idx,
+ **kwargs,
+ )
+
+ # Decoder layers
+ next_decoder_cache = ()
+ deltanet_state_tensors = []
+ cos_cache = None
+ sin_cache = None
+
+ # Convert 2D attention_mask to 4D causal mask for CTE
+ if (
+ attention_mask is not None
+ and attention_mask.ndim == 2
+ and is_for_context_encoding
+ ):
+ causal = torch.ones(
+ (seq_length, seq_length),
+ dtype=torch.bool,
+ device=attention_mask.device,
+ ).tril()
+ padding_4d = attention_mask[:, None, None, :].to(torch.bool)
+ attention_mask = (causal[None, None, :, :] & padding_4d).to(
+ attention_mask.dtype
+ )
+
+ # Pre-compute mRoPE cos/sin
+ if rotary_position_ids is not None and rotary_position_ids.ndim == 3:
+ cos_cache, sin_cache = self.mrope_emb(inputs_embeds, rotary_position_ids)
+
+ for idx, decoder_layer in enumerate(self.layers):
+ past_key_value = (
+ past_key_values[idx] if past_key_values is not None else None
+ )
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ seq_ids=seq_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ active_mask=active_mask,
+ adapter_ids=adapter_ids,
+ cos_cache=cos_cache,
+ sin_cache=sin_cache,
+ rotary_position_ids=rotary_position_ids,
+ kv_mgr=self.kv_mgr,
+ get_kv_per_layer=False,
+ update_kv_per_layer=False,
+ idx=idx,
+ is_for_context_encoding=is_for_context_encoding,
+ seq_len=cache_size,
+ residual=None,
+ local_mask=local_attn_mask,
+ windowed_context_encoding_window_idx=windowed_context_encoding_window_idx,
+ padding_mask=padding_mask,
+ deltanet_padding_mask=deltanet_padding_mask,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+ kv = layer_outputs[1]
+ next_decoder_cache += (kv,)
+ cos_cache, sin_cache = layer_outputs[2:4]
+
+ # Collect DeltaNet state tensors
+ deltanet_states = layer_outputs[5] if len(layer_outputs) > 5 else None
+ if deltanet_states is not None:
+ deltanet_state_tensors.append(deltanet_states[0])
+ deltanet_state_tensors.append(deltanet_states[1])
+
+ # Update KV cache
+ if update_cache:
+ next_decoder_cache = self.kv_mgr.update_cache(
+ is_for_context_encoding=is_for_context_encoding,
+ seq_ids=seq_ids,
+ position_ids=position_ids,
+ new_key_values=next_decoder_cache,
+ seq_len=cache_size,
+ windowed_context_encoding_window_idx=windowed_context_encoding_window_idx,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ self._deltanet_updated_states = deltanet_state_tensors
+
+ return (hidden_states, next_decoder_cache)
+
+ def forward(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ seq_ids,
+ sampling_params,
+ prev_hidden=None,
+ adapter_ids=None,
+ accepted_indices=None,
+ current_length=None,
+ medusa_mask=None,
+ scatter_index=None,
+ slot_mapping=None,
+ active_block_table=None,
+ num_queries=None,
+ computed_context_lens=None,
+ tile_q_indices=None,
+ tile_block_tables=None,
+ tile_masks=None,
+ inputs_embeds=None,
+ kv_cache=None,
+ active_mask=None,
+ rotary_position_id=None,
+ vision_embeddings=None,
+ vision_mask=None,
+ ):
+ """Override base forward to append DeltaNet state tensors to output."""
+ prev_hidden = self.set_none_if_empty(prev_hidden)
+ adapter_ids = self.set_none_if_empty(adapter_ids)
+ accepted_indices = self.set_none_if_empty(accepted_indices)
+ current_length = self.set_none_if_empty(current_length)
+ medusa_mask = self.set_none_if_empty(medusa_mask)
+ scatter_index = self.set_none_if_empty(scatter_index)
+ slot_mapping = self.set_none_if_empty(slot_mapping)
+ active_block_table = self.set_none_if_empty(active_block_table)
+ num_queries = self.set_none_if_empty(num_queries)
+ computed_context_lens = self.set_none_if_empty(computed_context_lens)
+ tile_q_indices = self.set_none_if_empty(tile_q_indices)
+ tile_block_tables = self.set_none_if_empty(tile_block_tables)
+ tile_masks = self.set_none_if_empty(tile_masks)
+ inputs_embeds = self.set_none_if_empty(inputs_embeds)
+ kv_cache = self.set_none_if_empty(kv_cache)
+ active_mask = self.set_none_if_empty(active_mask)
+ rotary_position_id = self.set_none_if_empty(rotary_position_id)
+ vision_embeddings = self.set_none_if_empty(vision_embeddings)
+ vision_mask = self.set_none_if_empty(vision_mask)
+
+ is_for_context_encoding = position_ids.shape[-1] != 1 and not (
+ hasattr(self.neuron_config, "speculation_length")
+ and position_ids.shape[-1] == self.neuron_config.speculation_length
+ )
+
+ seq_ids = seq_ids.to(torch.int32)
+ attn_mask = attention_mask
+
+ hidden_states, updated_kv_cache = self.get_model_output(
+ input_ids=input_ids,
+ seq_ids=seq_ids,
+ attention_mask=attn_mask,
+ position_ids=position_ids,
+ active_mask=active_mask,
+ inputs_embeds=inputs_embeds,
+ adapter_ids=adapter_ids,
+ rotary_position_ids=rotary_position_id,
+ update_cache=True,
+ is_for_context_encoding=is_for_context_encoding,
+ padding_mask=None,
+ active_block_table=active_block_table,
+ scatter_index=slot_mapping
+ if getattr(self, "is_block_kv_layout", False)
+ else scatter_index,
+ vision_embeddings=vision_embeddings,
+ vision_mask=vision_mask,
+ )
+
+ batch_size = input_ids.shape[0]
+ if not getattr(self, "sliced_hidden", False):
+ if not is_for_context_encoding:
+ pass
+ else:
+ index = torch.max(position_ids, dim=1, keepdim=True).indices
+ index = index.unsqueeze(1).expand(batch_size, 1, self.hidden_size)
+ hidden_states = torch.gather(hidden_states, dim=1, index=index)
+
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ if hasattr(self.lm_head, "pad_size"):
+ if self.lm_head.gather_output:
+ rank_id = torch.tensor(0, device=logits.device, dtype=torch.int32)
+ world_size = 1
+            else:
+                rank_id = self.rank_util.get_rank()
+                world_size = torch.distributed.get_world_size(
+                    group=self.lm_head.tensor_parallel_group
+                )
+ from neuronx_distributed_inference.models.model_base import (
+ mask_padded_logits,
+ )
+
+ logits = mask_padded_logits(
+ logits, rank_id, world_size, pad_size=self.lm_head.pad_size
+ )
+
+ if self.on_device_sampling:
+ res = self._sample_on_device(
+ logits, sampling_params, False, is_for_context_encoding
+ )
+ else:
+ res = logits
+
+ outputs = [res]
+ if self.neuron_config.output_logits:
+ outputs += [logits]
+ outputs += updated_kv_cache
+
+ # Append DeltaNet state tensors (for input_output_aliases)
+ if hasattr(self, "_deltanet_updated_states"):
+ outputs += self._deltanet_updated_states
+
+ return outputs
+
+
+# ============================================================
+# State Dict Converter (Dense -- no MoE weight handling)
+# ============================================================
+
+
+def convert_qwen35_hf_to_neuron_state_dict(neuron_state_dict, config):
+ """Convert HF Qwen3.5 (dense) weights to NxDI format.
+
+ Weight mappings per layer type:
+
+ DeltaNet layers (linear_attention):
+ HF: layers.X.linear_attn.{in_proj_qkv, in_proj_z, in_proj_a, in_proj_b,
+ conv1d, A_log, dt_bias, norm, out_proj}
+ NxDI: same names (no remapping needed)
+
+ Full attention layers:
+ HF: layers.X.self_attn.q_proj.weight: (num_heads*head_dim*2, hidden) -- doubled for gate
+ NxDI: layers.X.self_attn.Wqkv.weight (fused Q+K+V, gate separated)
+ layers.X.self_attn.output_gate_proj.weight (gate part)
+ HF: layers.X.self_attn.{k_proj, v_proj, o_proj, q_norm, k_norm}
+ NxDI: layers.X.self_attn.{..., q_layernorm, k_layernorm}
+
+ Dense MLP (all layers):
+ HF: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight
+ NxDI: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight (same names)
+ """
+ # Add rank_util
+ neuron_state_dict["rank_util.rank"] = torch.arange(
+ 0,
+ config.neuron_config.tp_degree,
+ dtype=torch.int32,
+ )
+
+ # CRITICAL: Convert (1+weight) RMSNorm weights to standard RMSNorm weights.
+ # Qwen3.5 uses RMSNorm with `output = norm(x) * (1 + weight)` where weight
+ # is initialized to zeros. Standard NxDI RMSNorm uses `output = norm(x) * weight`
+ # where weight is initialized to ones. To convert: new_weight = old_weight + 1.0
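+    # Equivalently: norm(x) * (1 + w_hf) == norm(x) * w_new with w_new = w_hf + 1,
+    # so the converted weights reproduce the HF normalization output.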
+ norm_keys_to_convert = []
+ for l in range(config.num_hidden_layers):
+ norm_keys_to_convert.append(f"layers.{l}.input_layernorm.weight")
+ norm_keys_to_convert.append(f"layers.{l}.post_attention_layernorm.weight")
+ if config.layer_types[l] == "full_attention":
+ norm_keys_to_convert.append(f"layers.{l}.self_attn.q_norm.weight")
+ norm_keys_to_convert.append(f"layers.{l}.self_attn.k_norm.weight")
+ norm_keys_to_convert.append("norm.weight")
+
+ for nk in norm_keys_to_convert:
+ if nk in neuron_state_dict:
+ old_val = neuron_state_dict[nk]
+ neuron_state_dict[nk] = old_val.float() + 1.0
+ if "layers.0." in nk or nk == "norm.weight":
+ logger.debug(
+ f"[NORM FIX] {nk}: mean {old_val.float().mean():.4f} -> {neuron_state_dict[nk].mean():.4f}"
+ )
+ else:
+ if "layers.0." in nk or nk == "norm.weight":
+ logger.warning(f"[NORM FIX] key not found: {nk}")
+
+ for l in range(config.num_hidden_layers):
+ layer_type = config.layer_types[l]
+
+ # === Attention layers ===
+ if layer_type == "full_attention":
+ neuron_state_dict[f"layers.{l}.self_attn.rank_util.rank"] = torch.arange(
+ 0,
+ config.neuron_config.tp_degree,
+ dtype=torch.int32,
+ )
+
+ # QK norms: q_norm -> q_layernorm, k_norm -> k_layernorm
+ q_norm_key = f"layers.{l}.self_attn.q_norm.weight"
+ k_norm_key = f"layers.{l}.self_attn.k_norm.weight"
+ if q_norm_key in neuron_state_dict:
+ neuron_state_dict[f"layers.{l}.self_attn.q_layernorm.weight"] = (
+ neuron_state_dict.pop(q_norm_key).detach().clone()
+ )
+ if k_norm_key in neuron_state_dict:
+ neuron_state_dict[f"layers.{l}.self_attn.k_layernorm.weight"] = (
+ neuron_state_dict.pop(k_norm_key).detach().clone()
+ )
+
+ # q_proj is doubled: (num_heads * head_dim * 2, hidden_size)
+ # INTERLEAVED: [head0_query(head_dim) | head0_gate(head_dim) | head1_query(head_dim) | ...]
+ q_proj_key = f"layers.{l}.self_attn.q_proj.weight"
+ if q_proj_key in neuron_state_dict:
+ q_proj_w = neuron_state_dict.pop(q_proj_key)
+ num_heads = config.num_attention_heads
+ head_dim = config.head_dim
+ q_proj_w = q_proj_w.reshape(num_heads, head_dim * 2, config.hidden_size)
+ query_w = q_proj_w[:, :head_dim, :]
+ gate_w = q_proj_w[:, head_dim:, :]
+ query_w = query_w.reshape(num_heads * head_dim, config.hidden_size)
+ gate_w = gate_w.reshape(num_heads * head_dim, config.hidden_size)
+
+ neuron_state_dict[q_proj_key] = query_w
+ neuron_state_dict[f"layers.{l}.self_attn.output_gate_proj.weight"] = (
+ gate_w
+ )
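+                # Illustrative: with H heads of dim d, HF q_proj rows read
+                # [h0 query(d) | h0 gate(d) | h1 query(d) | h1 gate(d) | ...];
+                # reshaping to (H, 2d, hidden) and slicing [:, :d] / [:, d:]
+                # de-interleaves them into contiguous query and gate matrices.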
+
+ # Fuse QKV
+ if config.neuron_config.fused_qkv:
+ q_key = f"layers.{l}.self_attn.q_proj.weight"
+ k_key = f"layers.{l}.self_attn.k_proj.weight"
+ v_key = f"layers.{l}.self_attn.v_proj.weight"
+ if q_key in neuron_state_dict:
+ neuron_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = torch.cat(
+ [
+ neuron_state_dict[q_key],
+ neuron_state_dict[k_key],
+ neuron_state_dict[v_key],
+ ]
+ )
+ del neuron_state_dict[q_key]
+ del neuron_state_dict[k_key]
+ del neuron_state_dict[v_key]
+
+ # Dense MLP: no weight conversion needed -- HF and NxDI use same names
+ # HF: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight
+ # NxDI: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight
+
+ gc.collect()
+
+ return neuron_state_dict
+
+
+# ============================================================
+# Custom ModelWrapper and DecoderModelInstance for DeltaNet state aliasing
+# ============================================================
+
+
+class Qwen35DecoderModelInstance(DecoderModelInstance):
+ """Custom DecoderModelInstance that adds DeltaNet state buffers to input_output_aliases."""
+
+ def get(self, bucket_rank, **kwargs):
+ """Override to add DeltaNet state aliases after KV cache aliases."""
+ module, input_output_aliases = super().get(bucket_rank, **kwargs)
+
+ num_output_from_trace = 1 if not self.neuron_config.output_logits else 2
+
+ if module.kv_mgr is not None:
+ num_kv = len(module.kv_mgr.past_key_values)
+ else:
+ num_kv = 0
+
+ state_start_idx = num_output_from_trace + num_kv
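+        # Traced output layout (see NeuronQwen35Model.forward):
+        #   [tokens or logits (1-2)] + [KV cache tensors (num_kv)]
+        #   + [DeltaNet states (recurrent + conv, 2 per linear-attention layer)]
+        # so state param i aliases output index state_start_idx + i.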
+
+ if hasattr(module, "_deltanet_state_params"):
+ for i, param in enumerate(module._deltanet_state_params):
+ input_output_aliases[param] = state_start_idx + i
+
+ return module, input_output_aliases
+
+
+class Qwen35ModelWrapper(ModelWrapper):
+ """Custom ModelWrapper for VL support with mRoPE and vision inputs."""
+
+ def get_model_instance(self):
+ return Qwen35DecoderModelInstance(
+ model_cls=self.model_cls,
+ config=self.config,
+ **self.model_init_kwargs,
+ )
+
+ def input_generator(self):
+ """Generate inputs including mrope_position_ids, vision_embeddings, and vision_mask."""
+ base_inputs = super().input_generator()
+ extended_inputs = []
+
+ for bucket_inputs in base_inputs:
+ input_ids = bucket_inputs[0]
+ batch_size = input_ids.shape[0]
+ n_active_tokens = input_ids.shape[1]
+
+ is_cte = n_active_tokens > 1
+
+ if is_cte:
+ mrope_position_ids = (
+ torch.arange(0, n_active_tokens, dtype=torch.int32)
+ .unsqueeze(0)
+ .unsqueeze(0)
+ .expand(3, batch_size, -1)
+ .contiguous()
+ )
+
+ vision_embeddings = torch.zeros(
+ (batch_size, n_active_tokens, self.config.hidden_size),
+ dtype=self.config.neuron_config.torch_dtype,
+ )
+ vision_mask = torch.full(
+ (batch_size, n_active_tokens, 1),
+ fill_value=n_active_tokens - 1,
+ dtype=torch.int32,
+ )
+ else:
+ mrope_position_ids = torch.zeros((0,), dtype=torch.int32)
+ vision_embeddings = torch.zeros(
+ (0,), dtype=self.config.neuron_config.torch_dtype
+ )
+ vision_mask = torch.zeros((0,), dtype=torch.int32)
+
+ padded = list(bucket_inputs)
+ while len(padded) < 21:
+ padded.append(torch.zeros((0,), dtype=torch.int32))
+ padded.append(mrope_position_ids) # position 21
+ padded.append(vision_embeddings) # position 22
+ padded.append(vision_mask) # position 23
+
+ extended_inputs.append(tuple(padded))
+
+ return extended_inputs
+
+ def pad_inputs(self, *args, pad_type="first_fit"):
+ """Override to pad mrope_position_ids and vision inputs to bucket size."""
+ orig_mrope = args[21] if len(args) >= 22 else None
+ orig_vis_emb = args[22] if len(args) >= 23 else None
+ orig_vis_mask = args[23] if len(args) >= 24 else None
+
+ padded_args = super().pad_inputs(*args, pad_type=pad_type)
+
+ if len(padded_args) >= 24 and orig_mrope is not None:
+ padded_seq_len = padded_args[0].shape[1]
+ batch_size = padded_args[0].shape[0]
+ is_cte = padded_seq_len > 1
+
+ if is_cte:
+ current_mrope = orig_mrope
+ current_vis_emb = orig_vis_emb
+ current_vis_mask = orig_vis_mask
+
+ if (
+ current_mrope.ndim == 3
+ and current_mrope.shape[-1] != padded_seq_len
+ ):
+ orig_len = current_mrope.shape[-1]
+ pad_size = padded_seq_len - orig_len
+ last_pos = current_mrope[:, :, -1:]
+ pad_offsets = torch.arange(
+ 1, pad_size + 1, dtype=current_mrope.dtype
+ )
+ pad_offsets = (
+ pad_offsets.unsqueeze(0).unsqueeze(0).expand(3, batch_size, -1)
+ )
+ mrope_pad = last_pos + pad_offsets
+ mrope_position_ids = torch.cat([current_mrope, mrope_pad], dim=-1)
+ elif current_mrope.ndim == 3:
+ mrope_position_ids = current_mrope
+ else:
+ mrope_position_ids = (
+ torch.arange(0, padded_seq_len, dtype=torch.int32)
+ .unsqueeze(0)
+ .unsqueeze(0)
+ .expand(3, batch_size, -1)
+ .contiguous()
+ )
+
+ if (
+ current_vis_emb is not None
+ and current_vis_emb.ndim == 3
+ and current_vis_emb.shape[1] < padded_seq_len
+ ):
+ pad_emb = torch.zeros(
+ (
+ batch_size,
+ padded_seq_len - current_vis_emb.shape[1],
+ current_vis_emb.shape[2],
+ ),
+ dtype=current_vis_emb.dtype,
+ )
+ vision_embeddings = torch.cat([current_vis_emb, pad_emb], dim=1)
+ elif current_vis_emb is not None and current_vis_emb.ndim == 3:
+ vision_embeddings = current_vis_emb[:, :padded_seq_len]
+ else:
+ vision_embeddings = torch.zeros(
+ (batch_size, padded_seq_len, self.config.hidden_size),
+ dtype=self.config.neuron_config.torch_dtype,
+ )
+
+ if (
+ current_vis_mask is not None
+ and current_vis_mask.ndim == 3
+ and current_vis_mask.shape[1] < padded_seq_len
+ ):
+ pad_mask = torch.full(
+ (batch_size, padded_seq_len - current_vis_mask.shape[1], 1),
+ fill_value=padded_seq_len - 1,
+ dtype=torch.int32,
+ )
+ vision_mask = torch.cat([current_vis_mask, pad_mask], dim=1)
+ elif current_vis_mask is not None and current_vis_mask.ndim == 3:
+ vision_mask = current_vis_mask[:, :padded_seq_len]
+ else:
+ vision_mask = torch.full(
+ (batch_size, padded_seq_len, 1),
+ fill_value=padded_seq_len - 1,
+ dtype=torch.int32,
+ )
+
+ padded_args = (
+ *padded_args[:21],
+ mrope_position_ids,
+ vision_embeddings,
+ vision_mask,
+ )
+
+ padded_args = list(padded_args)
+ padded_args[23] = padded_args[23].clamp(max=padded_seq_len - 1)
+ padded_args = tuple(padded_args)
+
+ return padded_args
+
+
+# ============================================================
+# Top-Level Model
+# ============================================================
+
+
+class NeuronQwen35ForCausalLM(NeuronBaseForCausalLM):
+ _model_cls = NeuronQwen35Model
+
+ def get_model_wrapper_cls(self):
+ """Return custom ModelWrapper with DeltaNet state aliasing."""
+ return Qwen35ModelWrapper
+
+ @staticmethod
+ def load_hf_model(model_path, **kwargs):
+ """Load HF model weights.
+
+ The model is a VL model (Qwen3_5ForConditionalGeneration) but we
+ only need the text backbone.
+ """
+ from transformers import AutoModelForCausalLM
+
+ kwargs.setdefault("trust_remote_code", True)
+ return AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
+
+ @classmethod
+ def get_config_cls(cls):
+ return Qwen35InferenceConfig
+
+ @staticmethod
+ def update_state_dict_for_tied_weights(state_dict):
+ """Copy embed_tokens weight to lm_head for tied embeddings."""
+ state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone()
+
+ @staticmethod
+ def convert_hf_to_neuron_state_dict(state_dict, config):
+ """Strip VL wrapper prefix and convert to NxDI format."""
+ new_sd = {}
+ for k, v in state_dict.items():
+ if k.startswith("language_model."):
+ new_k = k.replace("language_model.", "", 1)
+ new_sd[new_k] = v
+ elif k.startswith("model.language_model."):
+ new_k = k.replace("model.language_model.", "", 1)
+ new_sd[new_k] = v
+ elif k.startswith("model.visual") or k.startswith("visual"):
+ continue # Skip vision encoder
+ elif k.startswith("model."):
+ new_sd[k.replace("model.", "", 1)] = v
+ elif k.startswith("mtp."):
+ continue # Skip MTP
+ elif k.startswith("lm_head."):
+ new_sd[k] = v
+ else:
+ new_sd[k] = v
+
+ return convert_qwen35_hf_to_neuron_state_dict(new_sd, config)
+
+ def enable_context_encoding(self):
+ self.compile_tag = CONTEXT_ENCODING_MODEL_TAG
+ super().enable_context_encoding()
+
+ def enable_token_generation(self):
+ self.compile_tag = TOKEN_GENERATION_MODEL_TAG
+ super().enable_token_generation()
+
+ def _copy_past_key_values(self, outputs):
+ """Override to also copy DeltaNet state buffers on CPU."""
+ super()._copy_past_key_values(outputs)
+
+ num_output_from_trace = 1
+ if (
+ self.neuron_config.output_logits
+ and self.neuron_config.on_device_sampling_config
+ ):
+ num_output_from_trace = 2
+
+ if (
+ hasattr(self, "token_generation_model")
+ and self.token_generation_model is not None
+ ):
+ tkg_model = self.token_generation_model.model
+ cte_model = self.context_encoding_model.model
+ else:
+ return
+
+ if tkg_model.kv_mgr is not None:
+ num_kv = len(tkg_model.kv_mgr.past_key_values)
+ else:
+ num_kv = 0
+
+ state_start = num_output_from_trace + num_kv
+
+ tkg_params = getattr(tkg_model, "_deltanet_state_params", [])
+ cte_params = getattr(cte_model, "_deltanet_state_params", [])
+
+ if len(tkg_params) > 0 and state_start + len(tkg_params) <= len(outputs):
+ for i, (tkg_param, cte_param) in enumerate(zip(tkg_params, cte_params)):
+ new_state = outputs[state_start + i]
+ tkg_param.data = new_state
+ cte_param.data = new_state
+
+ def get_required_kwargs(self):
+ """Return extra kwargs for HF generation loop."""
+ return ["llava_args"]
+
+ def _get_model_outputs(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ seq_ids,
+ sampling_params,
+ prev_hidden,
+ adapter_ids,
+ medusa_args,
+ llava_args,
+ slot_mapping=None,
+ block_table=None,
+ full_context_lens=None,
+ computed_context_lens=None,
+ tf_args=None,
+ ):
+ """Override to pass all 24 positional args explicitly."""
+ is_prefill = self._is_prefill(position_ids)
+
+ seq_len = input_ids.shape[1]
+ batch_size = input_ids.shape[0]
+
+ if llava_args and len(llava_args) >= 2:
+ vision_embeddings = llava_args[0]
+ vision_mask = llava_args[1]
+ if len(llava_args) >= 3:
+ mrope_position_ids = llava_args[2]
+ else:
+ mrope_position_ids = None
+ elif is_prefill:
+ vision_embeddings = torch.zeros(
+ (batch_size, seq_len, self.config.hidden_size),
+ dtype=self.config.neuron_config.torch_dtype,
+ )
+ vision_mask = torch.full(
+ (batch_size, seq_len, 1),
+ fill_value=seq_len - 1,
+ dtype=torch.int32,
+ )
+ mrope_position_ids = None
+ else:
+            # Use the model dtype so the empty placeholder matches trace-time inputs
+            vision_embeddings = torch.zeros(
+                (0,), dtype=self.config.neuron_config.torch_dtype
+            )
+ vision_mask = torch.zeros((0,), dtype=torch.int32)
+ mrope_position_ids = None
+
+ if is_prefill:
+ if mrope_position_ids is None:
+ mrope_position_ids = (
+ torch.arange(0, seq_len, dtype=torch.int32)
+ .unsqueeze(0)
+ .unsqueeze(0)
+ .expand(3, batch_size, -1)
+ .contiguous()
+ )
+ else:
+ mrope_position_ids = torch.zeros((0,), dtype=torch.int32)
+
+ empties = [torch.empty(0) for _ in range(14)]
+
+        if is_prefill:
+ ctx_bs = self.context_encoding_model.neuron_config.batch_size
+ output_logits = []
+
+ for cb in range(0, batch_size, ctx_bs):
+ cb_end = min(cb + ctx_bs, batch_size)
+ actual_chunk = cb_end - cb
+
+ chunk_input_ids = input_ids[cb:cb_end]
+ chunk_attn_mask = attention_mask[cb:cb_end]
+ chunk_pos_ids = position_ids[cb:cb_end]
+ chunk_seq_ids = seq_ids[cb:cb_end]
+ chunk_sampling = sampling_params[cb:cb_end]
+ chunk_prev_hidden = (
+ prev_hidden[cb:cb_end]
+ if prev_hidden is not None
+ and hasattr(prev_hidden, "ndim")
+ and prev_hidden.ndim > 0
+ and prev_hidden.shape[0] > 0
+ else prev_hidden
+ )
+ chunk_adapter_ids = (
+ adapter_ids[cb:cb_end]
+ if adapter_ids is not None
+ and hasattr(adapter_ids, "ndim")
+ and adapter_ids.ndim > 0
+ and adapter_ids.shape[0] > 0
+ else adapter_ids
+ )
+
+ if mrope_position_ids.ndim == 3:
+ chunk_mrope = mrope_position_ids[:, cb:cb_end, :]
+ else:
+ chunk_mrope = mrope_position_ids
+
+ if vision_embeddings.ndim == 3:
+ chunk_vis_emb = vision_embeddings[cb:cb_end]
+ chunk_vis_mask = vision_mask[cb:cb_end]
+ else:
+ chunk_vis_emb = vision_embeddings
+ chunk_vis_mask = vision_mask
+
+ if actual_chunk < ctx_bs:
+ pad_n = ctx_bs - actual_chunk
+ chunk_input_ids = torch.cat(
+ [chunk_input_ids, chunk_input_ids[:1].expand(pad_n, -1)], dim=0
+ )
+ chunk_attn_mask = torch.cat(
+ [chunk_attn_mask, chunk_attn_mask[:1].expand(pad_n, -1)], dim=0
+ )
+ chunk_pos_ids = torch.cat(
+ [chunk_pos_ids, chunk_pos_ids[:1].expand(pad_n, -1)], dim=0
+ )
+ pad_seq = torch.arange(
+ batch_size, batch_size + pad_n, dtype=chunk_seq_ids.dtype
+ )
+ chunk_seq_ids = torch.cat([chunk_seq_ids, pad_seq], dim=0)
+ chunk_sampling = torch.cat(
+ [chunk_sampling, chunk_sampling[:1].expand(pad_n, -1)], dim=0
+ )
+ if (
+ chunk_prev_hidden is not None
+ and hasattr(chunk_prev_hidden, "ndim")
+ and chunk_prev_hidden.ndim > 0
+ and chunk_prev_hidden.shape[0] > 0
+ ):
+ chunk_prev_hidden = torch.cat(
+ [
+ chunk_prev_hidden,
+ chunk_prev_hidden[:1].expand(pad_n, -1),
+ ],
+ dim=0,
+ )
+ if (
+ chunk_adapter_ids is not None
+ and hasattr(chunk_adapter_ids, "ndim")
+ and chunk_adapter_ids.ndim > 0
+ and chunk_adapter_ids.shape[0] > 0
+ ):
+ chunk_adapter_ids = torch.cat(
+ [
+ chunk_adapter_ids,
+ chunk_adapter_ids[:1].expand(pad_n, -1),
+ ],
+ dim=0,
+ )
+ if chunk_mrope.ndim == 3:
+ chunk_mrope = torch.cat(
+ [chunk_mrope, chunk_mrope[:, :1, :].expand(-1, pad_n, -1)],
+ dim=1,
+ )
+ if chunk_vis_emb.ndim == 3:
+ chunk_vis_emb = torch.cat(
+ [
+ chunk_vis_emb,
+ torch.zeros(
+ (pad_n,) + chunk_vis_emb.shape[1:],
+ dtype=chunk_vis_emb.dtype,
+ ),
+ ],
+ dim=0,
+ )
+ chunk_vis_mask = torch.cat(
+ [
+ chunk_vis_mask,
+ torch.full(
+ (pad_n,) + chunk_vis_mask.shape[1:],
+ fill_value=seq_len - 1,
+ dtype=chunk_vis_mask.dtype,
+ ),
+ ],
+ dim=0,
+ )
+
+ chunk_out = self.context_encoding_model(
+ chunk_input_ids,
+ chunk_attn_mask,
+ chunk_pos_ids,
+ chunk_seq_ids,
+ chunk_sampling,
+ chunk_prev_hidden,
+ chunk_adapter_ids,
+ *empties,
+ chunk_mrope,
+ chunk_vis_emb,
+ chunk_vis_mask,
+ )
+ if actual_chunk < ctx_bs:
+ chunk_out = chunk_out[:actual_chunk]
+ output_logits.append(chunk_out)
+
+ outputs = (
+ torch.cat(output_logits, dim=0)
+ if len(output_logits) > 1
+ else output_logits[0]
+ )
+ self.kv_cache_populated = True
+ is_run_on_neuron = self.context_encoding_model.is_neuron()
+ else:
+ outputs = self.token_generation_model(
+ input_ids,
+ attention_mask,
+ position_ids,
+ seq_ids,
+ sampling_params,
+ prev_hidden,
+ adapter_ids,
+ *empties,
+ mrope_position_ids,
+ vision_embeddings,
+ vision_mask,
+ )
+ is_run_on_neuron = self.token_generation_model.is_neuron()
+
+ return outputs, is_run_on_neuron
+
+    def get_compiler_args(self):
+        # Same optimization level for both CTE and TKG graphs
+        optimization_level = "-O1"
+
+ compiler_args = (
+ "--enable-saturate-infinity "
+ "--enable-mixed-precision-accumulation "
+ f"--model-type transformer {optimization_level} "
+ "--auto-cast=none "
+ )
+ return compiler_args
diff --git a/contrib/models/Qwen3.5-2B/src/modeling_qwen35_vision.py b/contrib/models/Qwen3.5-2B/src/modeling_qwen35_vision.py
new file mode 100644
index 00000000..59b73b18
--- /dev/null
+++ b/contrib/models/Qwen3.5-2B/src/modeling_qwen35_vision.py
@@ -0,0 +1,823 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Qwen3.5 (Dense) Vision Encoder for NeuronX Distributed Inference.
+
+Ports the Qwen3.5 ViT encoder to run on Neuron. The vision encoder
+architecture is identical across model sizes, and all dimensions are read
+from config, so this module works for any Qwen3.5 dense model size, e.g.:
+- Qwen3.5-2B: depth=24, hidden=1024, out_hidden=2048
+
+The vision encoder runs as a separate compiled model from the text decoder,
+compiled and loaded via NeuronBaseForImageToText.
+"""
+
+import logging
+import math
+import os
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# CRITICAL: Use finite negative value instead of -inf for Neuron attention masks.
+# The Neuron compiler's bfloat16 handling of -inf produces NaN that bleeds from
+# padding positions into ALL positions through the transformer layers.
+# -65504.0 is large enough for softmax masking but avoids NaN overflow.
+_MASK_NEG_INF = -65504.0
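+# Illustrative: softmax over [x, -65504.0] still gives the masked slot ~zero
+# weight, while -inf can hit IEEE NaN patterns (e.g., (-inf) - (-inf), 0 * inf)
+# in intermediate compiler-generated arithmetic.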
+
+logger = logging.getLogger(__name__)
+
+# -- NxDI imports (available on Neuron instances) --
+try:
+ from neuronx_distributed_inference.models.application_base import (
+ NeuronApplicationBase,
+ )
+ from neuronx_distributed_inference.models.model_wrapper import ModelWrapper
+ from neuronx_distributed_inference.modules.attention.attention_base import (
+ NeuronAttentionBase,
+ )
+ from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding
+ from neuronx_distributed.parallel_layers import layers as nxd_layers
+except ImportError:
+ logger.warning(
+ "NxDI imports unavailable -- vision module can only be used on Neuron instances"
+ )
+
+# -- HuggingFace imports for patch embed (runs on CPU) --
+try:
+ from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
+ Qwen3_5MoeVisionPatchEmbed,
+ Qwen3_5MoeVisionPatchMerger,
+ Qwen3_5MoeVisionRotaryEmbedding,
+ )
+except ImportError:
+ try:
+ # transformers 4.57+ uses Qwen3VL* class names
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+ Qwen3VLVisionPatchEmbed as Qwen3_5MoeVisionPatchEmbed,
+ Qwen3VLVisionPatchMerger as Qwen3_5MoeVisionPatchMerger,
+ Qwen3VLVisionRotaryEmbedding as Qwen3_5MoeVisionRotaryEmbedding,
+ )
+ except ImportError:
+ try:
+ # Older transformers uses Qwen2VL* class names
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+ Qwen2VLVisionPatchEmbed as Qwen3_5MoeVisionPatchEmbed,
+ Qwen2VLVisionPatchMerger as Qwen3_5MoeVisionPatchMerger,
+ Qwen2VLVisionRotaryEmbedding as Qwen3_5MoeVisionRotaryEmbedding,
+ )
+ except ImportError:
+ Qwen3_5MoeVisionPatchEmbed = None
+ Qwen3_5MoeVisionPatchMerger = None
+ Qwen3_5MoeVisionRotaryEmbedding = None
+
+
+def apply_rotary_pos_emb_vision(q, k, cos, sin):
+ """Apply rotary position embeddings to vision Q and K tensors.
+
+ Uses rotate_half style (matching HF reference):
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+
+ Args:
+ q: (seq_len, num_heads, head_dim)
+ k: (seq_len, num_heads, head_dim)
+ cos: (seq_len, head_dim)
+ sin: (seq_len, head_dim)
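+
+    Example (illustrative; cos=1, sin=0 makes the rotation the identity):
+        >>> q = torch.randn(16, 8, 64)
+        >>> k = torch.randn(16, 8, 64)
+        >>> cos, sin = torch.ones(16, 64), torch.zeros(16, 64)
+        >>> q2, k2 = apply_rotary_pos_emb_vision(q, k, cos, sin)
+        >>> torch.equal(q2, q) and torch.equal(k2, k)
+        True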
+ """
+ cos = cos.unsqueeze(-2) # (seq_len, 1, head_dim)
+ sin = sin.unsqueeze(-2)
+
+ def rotate_half(x):
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed.to(q.dtype), k_embed.to(k.dtype)
+
+
+class NeuronQwen35VisionAttention(nn.Module):
+ """Vision attention for Qwen3.5 MoE.
+
+ Uses fused QKV linear (no bias in Neuron port for efficiency).
+ Non-causal attention with block-diagonal mask for variable-length images.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.scaling = self.head_dim**-0.5
+
+ # Fused QKV: (hidden_size -> 3 * hidden_size) with bias
+ self.qkv = nxd_layers.ColumnParallelLinear(
+ self.hidden_size,
+ 3 * self.hidden_size,
+ bias=True,
+ gather_output=True,
+ )
+ self.proj = nxd_layers.RowParallelLinear(
+ self.hidden_size,
+ self.hidden_size,
+ bias=True,
+ input_is_parallel=False,
+ )
+
+ def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
+ """
+ Args:
+ hidden_states: (seq_len, hidden_size)
+ attention_mask: (1, 1, seq_len, seq_len) block-diagonal mask
+ position_embeddings: (cos, sin) tuple
+ """
+ seq_len = hidden_states.shape[0]
+
+ # QKV projection
+ qkv = self.qkv(hidden_states) # (seq_len, 3 * hidden_size)
+ qkv = qkv.reshape(seq_len, 3, self.num_heads, self.head_dim)
+ qkv = qkv.permute(1, 0, 2, 3) # (3, seq_len, num_heads, head_dim)
+ q, k, v = qkv.unbind(0) # each (seq_len, num_heads, head_dim)
+
+ # Apply rotary embeddings
+ if position_embeddings is not None:
+ cos, sin = position_embeddings
+ q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+
+ # Reshape for batched attention: (1, num_heads, seq_len, head_dim)
+ q = q.transpose(0, 1).unsqueeze(0)
+ k = k.transpose(0, 1).unsqueeze(0)
+ v = v.transpose(0, 1).unsqueeze(0)
+
+ # Scaled dot-product attention
+ attn_weights = torch.matmul(q, k.transpose(-1, -2)) * self.scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+ attn_output = torch.matmul(attn_weights, v)
+
+ # Reshape back: (seq_len, hidden_size)
+ attn_output = attn_output.squeeze(0).transpose(0, 1).reshape(seq_len, -1)
+
+ # Output projection
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class NeuronQwen35VisionMLP(nn.Module):
+ """Vision MLP with GELU activation."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.linear_fc1 = nxd_layers.ColumnParallelLinear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ gather_output=True,
+ )
+ self.linear_fc2 = nxd_layers.RowParallelLinear(
+ config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ input_is_parallel=False,
+ )
+ self.act_fn = nn.GELU()
+
+ def forward(self, hidden_states):
+ return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_states)))
+
+
+class NeuronQwen35VisionBlock(nn.Module):
+ """Single vision transformer block: LayerNorm + Attention + LayerNorm + MLP."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+ self.attn = NeuronQwen35VisionAttention(config)
+ self.mlp = NeuronQwen35VisionMLP(config)
+
+ def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ attention_mask=attention_mask,
+ position_embeddings=position_embeddings,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+class NeuronQwen35VisionModel(nn.Module):
+ """Qwen3.5 MoE Vision Encoder for Neuron.
+
+ This is the nn.Module that gets compiled and traced onto Neuron.
+ Patch embedding, positional embedding, and rotary embedding are computed
+ on CPU in the ModelWrapper and passed as inputs.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.blocks = nn.ModuleList(
+ [NeuronQwen35VisionBlock(config) for _ in range(config.depth)]
+ )
+ # Merger: spatial_merge_size^2 * hidden_size -> out_hidden_size
+ self.merger_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
+ merger_hidden = config.hidden_size * (config.spatial_merge_size**2)
+ self.merger_fc1 = nn.Linear(merger_hidden, merger_hidden)
+ self.merger_act = nn.GELU()
+ self.merger_fc2 = nn.Linear(merger_hidden, config.out_hidden_size)
+
+ def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
+ """
+ Args:
+ hidden_states: (seq_len, hidden_size) -- after patch_embed + pos_embed
+ attention_mask: (1, 1, seq_len, seq_len) block-diagonal mask
+ position_embeddings: (cos, sin) tuple for rotary
+
+ Returns:
+ vision_embeddings: (merged_seq_len, out_hidden_size)
+ """
+ for block in self.blocks:
+ hidden_states = block(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_embeddings=position_embeddings,
+ )
+
+ # Apply merger: norm -> spatial merge -> fc1 -> gelu -> fc2
+ hidden_states = self.merger_norm(hidden_states)
+ merge_size = self.config.spatial_merge_size
+ merged_hidden = self.config.hidden_size * (merge_size**2)
+ hidden_states = hidden_states.view(-1, merged_hidden)
+ hidden_states = self.merger_fc2(self.merger_act(self.merger_fc1(hidden_states)))
+
+ return hidden_states
+
+
+class CPUVisionModel(nn.Module):
+ """CPU-only vision encoder (pure PyTorch, no Neuron dependencies).
+
+    Used when there is not enough HBM to load the vision encoder on Neuron
+    alongside the text decoder.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.blocks = nn.ModuleList(
+ [self._make_block(config) for _ in range(config.depth)]
+ )
+ self.merger_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
+ merger_hidden = config.hidden_size * (config.spatial_merge_size**2)
+ self.merger_fc1 = nn.Linear(merger_hidden, merger_hidden)
+ self.merger_act = nn.GELU()
+ self.merger_fc2 = nn.Linear(merger_hidden, config.out_hidden_size)
+
+ @staticmethod
+ def _make_block(config):
+ """Build a single vision block with standard nn.Linear (no TP)."""
+ block = nn.Module()
+ block.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+ block.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+
+ # Attention
+ attn = nn.Module()
+ attn.hidden_size = config.hidden_size
+ attn.num_heads = config.num_heads
+ attn.head_dim = config.hidden_size // config.num_heads
+ attn.scaling = attn.head_dim**-0.5
+ attn.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=True)
+ attn.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
+ block.attn = attn
+
+ # MLP
+ mlp = nn.Module()
+ mlp.linear_fc1 = nn.Linear(
+ config.hidden_size, config.intermediate_size, bias=True
+ )
+ mlp.linear_fc2 = nn.Linear(
+ config.intermediate_size, config.hidden_size, bias=True
+ )
+ mlp.act_fn = nn.GELU()
+ block.mlp = mlp
+
+ return block
+
+ def _forward_attention(self, attn, hidden_states, attention_mask, cos, sin):
+ seq_len = hidden_states.shape[0]
+ qkv = attn.qkv(hidden_states).reshape(seq_len, 3, attn.num_heads, attn.head_dim)
+ qkv = qkv.permute(1, 0, 2, 3)
+ q, k, v = qkv.unbind(0)
+
+ if cos is not None and sin is not None:
+ cos_u = cos.unsqueeze(-2)
+ sin_u = sin.unsqueeze(-2)
+
+ def rotate_half(x):
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ q = (q * cos_u) + (rotate_half(q) * sin_u)
+ k = (k * cos_u) + (rotate_half(k) * sin_u)
+
+ q = q.transpose(0, 1).unsqueeze(0)
+ k = k.transpose(0, 1).unsqueeze(0)
+ v = v.transpose(0, 1).unsqueeze(0)
+
+ attn_weights = torch.matmul(q, k.transpose(-1, -2)) * attn.scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+ out = torch.matmul(attn_weights, v)
+ out = out.squeeze(0).transpose(0, 1).reshape(seq_len, -1)
+ return attn.proj(out)
+
+ def forward(self, hidden_states, attention_mask, cos, sin):
+ for block in self.blocks:
+ hidden_states = hidden_states + self._forward_attention(
+ block.attn, block.norm1(hidden_states), attention_mask, cos, sin
+ )
+ hidden_states = hidden_states + block.mlp.linear_fc2(
+ block.mlp.act_fn(block.mlp.linear_fc1(block.norm2(hidden_states)))
+ )
+
+ hidden_states = self.merger_norm(hidden_states)
+ merge_size = self.config.spatial_merge_size
+ merged_hidden = self.config.hidden_size * (merge_size**2)
+ hidden_states = hidden_states.view(-1, merged_hidden)
+ hidden_states = self.merger_fc2(self.merger_act(self.merger_fc1(hidden_states)))
+ return hidden_states
+
+
+class NeuronQwen35VisionModelWrapper(ModelWrapper):
+ """Wraps the vision encoder for NxDI tracing.
+
+ Handles CPU-side operations that cannot be traced:
+ - Patch embedding (Conv3d)
+ - Positional embedding (Embedding + bilinear interpolation)
+ - Rotary position embedding computation
+ - Vision attention mask construction (block-diagonal)
+ - Sequence length bucketing and padding/unpadding
+
+ Supports three modes:
+ 1. NxDI traced model (parallel layers) -- standard NxDI compilation
+ 2. Pre-compiled standalone model -- loaded from torch_neuronx.trace() output
+    3. CPU-only model -- for when there is no HBM headroom for the vision encoder
+ """
+
+ def __init__(self, config, model_cls=None, **kwargs):
+ if model_cls is not None:
+ super().__init__(config, model_cls, **kwargs)
+ else:
+ # Standalone mode: no NxDI model_cls
+ nn.Module.__init__(self)
+ self.vision_config = config
+ self._compiled_model = None # Set by load_compiled() -- single bucket
+ self._compiled_buckets = None # Set by load_compiled() -- multi-bucket dict
+ self._cpu_model = None # Set by load_cpu_model()
+
+ # These HF modules run on CPU, outside the traced graph
+ if Qwen3_5MoeVisionPatchEmbed is not None:
+ self.patch_embed = Qwen3_5MoeVisionPatchEmbed(config)
+ self.pos_embed = nn.Embedding(
+ config.num_position_embeddings, config.hidden_size
+ )
+ self.num_grid_per_side = int(config.num_position_embeddings**0.5)
+ head_dim = config.hidden_size // config.num_heads
+ self.rotary_pos_emb = Qwen3_5MoeVisionRotaryEmbedding(head_dim // 2)
+ else:
+ logger.warning("HF Qwen3.5 MoE vision classes not available")
+
+ self.vision_seq_len_buckets = kwargs.get(
+ "vision_seq_len_buckets", [1024, 4096, 16384]
+ )
+
+ def load_compiled(self, compiled_model_path):
+ """Load pre-compiled standalone vision encoder(s).
+
+ Supports two modes:
+ 1. Single .pt file: Legacy mode, loads one compiled model for one bucket size.
+ 2. Directory with multiple .pt files: Multi-bucket mode. Files must be named
+ 'vision_encoder_{bucket_size}.pt' (e.g., 'vision_encoder_256.pt').
+ Falls back to single 'vision_encoder.pt' in the directory.
+
+ Args:
+ compiled_model_path: Path to a .pt file or directory containing bucket .pt files.
+ """
+ import glob as glob_module
+
+ logger.info(f"Loading pre-compiled vision encoder from {compiled_model_path}")
+
+ if os.path.isfile(compiled_model_path):
+ # Single file mode (legacy)
+ self._compiled_model = torch.jit.load(compiled_model_path)
+ self._compiled_buckets = None
+ logger.info("Vision encoder loaded successfully (single bucket)")
+ elif os.path.isdir(compiled_model_path):
+ # Directory mode: look for bucket-specific files
+ bucket_files = sorted(
+ glob_module.glob(
+ os.path.join(compiled_model_path, "vision_encoder_*.pt")
+ )
+ )
+ if bucket_files:
+ self._compiled_buckets = {}
+ for bf in bucket_files:
+ # Extract bucket size from filename: vision_encoder_256.pt -> 256
+ basename = os.path.basename(bf)
+ try:
+ bucket_size = int(
+ basename.replace("vision_encoder_", "").replace(".pt", "")
+ )
+ self._compiled_buckets[bucket_size] = torch.jit.load(bf)
+ logger.info(f" Loaded vision bucket {bucket_size} from {bf}")
+ except ValueError:
+ logger.warning(f" Skipping unrecognized file: {bf}")
+ self._compiled_model = None
+ # Update vision_seq_len_buckets to match compiled buckets
+ self.vision_seq_len_buckets = sorted(self._compiled_buckets.keys())
+ logger.info(
+ f"Vision encoder loaded with {len(self._compiled_buckets)} buckets: "
+ f"{self.vision_seq_len_buckets}"
+ )
+ else:
+ # Fall back to single vision_encoder.pt in directory
+ single_path = os.path.join(compiled_model_path, "vision_encoder.pt")
+ if os.path.exists(single_path):
+ self._compiled_model = torch.jit.load(single_path)
+ self._compiled_buckets = None
+ logger.info(
+ "Vision encoder loaded successfully (single file in dir)"
+ )
+ else:
+ raise FileNotFoundError(
+ f"No vision encoder files found in {compiled_model_path}"
+ )
+ else:
+ raise FileNotFoundError(
+ f"Vision encoder path not found: {compiled_model_path}"
+ )
+
+ def load_vision_weights_from_hf(self, model_path):
+ """Load patch_embed and pos_embed weights from HF safetensors.
+
+ Args:
+ model_path: Path to HF model directory
+ """
+ from pathlib import Path
+ from safetensors import safe_open
+
+        st_files = sorted(Path(model_path).glob("*.safetensors"))
+ loaded = 0
+ for sf_path in st_files:
+ with safe_open(str(sf_path), framework="pt") as f:
+ for key in f.keys():
+ if key == "model.visual.patch_embed.proj.weight":
+ self.patch_embed.proj.weight.data.copy_(f.get_tensor(key))
+ loaded += 1
+ elif key == "model.visual.patch_embed.proj.bias":
+ self.patch_embed.proj.bias.data.copy_(f.get_tensor(key))
+ loaded += 1
+ elif key == "model.visual.pos_embed.weight":
+ self.pos_embed.weight.data.copy_(f.get_tensor(key))
+ loaded += 1
+ logger.info(f"Loaded {loaded} CPU-side vision weight tensors from HF")
+
+ def load_cpu_model(self, model_path):
+ """Load a CPU-only vision encoder from HF safetensors.
+
+ Use this when HBM is insufficient for the Neuron-compiled vision encoder
+ (e.g., when the text decoder fills available HBM).
+
+ Args:
+ model_path: Path to HF model directory with safetensors
+ """
+ from pathlib import Path
+ from safetensors import safe_open
+
+ config = self.vision_config
+ cpu_model = CPUVisionModel(config)
+
+ # Build key mapping from HF safetensors to CPU model
+ key_map = {}
+ for i in range(config.depth):
+ hf_pre = f"model.visual.blocks.{i}"
+ loc_pre = f"blocks.{i}"
+ for suffix in [
+ "attn.qkv.weight",
+ "attn.qkv.bias",
+ "attn.proj.weight",
+ "attn.proj.bias",
+ "mlp.linear_fc1.weight",
+ "mlp.linear_fc1.bias",
+ "mlp.linear_fc2.weight",
+ "mlp.linear_fc2.bias",
+ "norm1.weight",
+ "norm1.bias",
+ "norm2.weight",
+ "norm2.bias",
+ ]:
+ key_map[f"{hf_pre}.{suffix}"] = f"{loc_pre}.{suffix}"
+
+ key_map["model.visual.merger.norm.weight"] = "merger_norm.weight"
+ key_map["model.visual.merger.norm.bias"] = "merger_norm.bias"
+ key_map["model.visual.merger.linear_fc1.weight"] = "merger_fc1.weight"
+ key_map["model.visual.merger.linear_fc1.bias"] = "merger_fc1.bias"
+ key_map["model.visual.merger.linear_fc2.weight"] = "merger_fc2.weight"
+ key_map["model.visual.merger.linear_fc2.bias"] = "merger_fc2.bias"
+
+ st_files = sorted(Path(model_path).glob("model*.safetensors"))
+ loaded = 0
+ state_dict = cpu_model.state_dict()
+
+ for sf_path in st_files:
+ with safe_open(str(sf_path), framework="pt") as f:
+ for key in f.keys():
+ if key in key_map:
+ local_key = key_map[key]
+ if local_key in state_dict:
+ state_dict[local_key].copy_(f.get_tensor(key))
+ loaded += 1
+
+ cpu_model.load_state_dict(state_dict)
+ cpu_model = cpu_model.to(torch.bfloat16).eval()
+ self._cpu_model = cpu_model
+ logger.info(
+ f"Loaded CPU vision encoder: {loaded} weights, "
+ f"{sum(p.numel() for p in cpu_model.parameters()) / 1e6:.1f}M params"
+ )
+
+ def _get_vision_bucket(self, seq_len):
+ """Find the smallest bucket that fits the sequence length."""
+ for bucket in sorted(self.vision_seq_len_buckets):
+ if seq_len <= bucket:
+ return bucket
+ return self.vision_seq_len_buckets[-1]
+
+ def rot_pos_emb(self, grid_thw):
+ """Compute rotary positional embeddings for vision tokens.
+
+        Returns: (total_tokens, head_dim // 2) tensor of rotary frequencies;
+        the caller concatenates two copies to cover the full head_dim.
+ """
+ merge_size = self.vision_config.spatial_merge_size
+ grid_thw_list = grid_thw.tolist()
+
+ max_hw = max(max(h, w) for _, h, w in grid_thw_list)
+ freq_table = self.rotary_pos_emb(max_hw)
+ device = freq_table.device
+
+ total_tokens = sum(t * h * w for t, h, w in grid_thw_list)
+ pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
+
+ offset = 0
+ for num_frames, height, width in grid_thw_list:
+ merged_h, merged_w = height // merge_size, width // merge_size
+
+ block_rows = torch.arange(merged_h, device=device)
+ block_cols = torch.arange(merged_w, device=device)
+ intra_row = torch.arange(merge_size, device=device)
+ intra_col = torch.arange(merge_size, device=device)
+
+ row_idx = (
+ block_rows[:, None, None, None] * merge_size
+ + intra_row[None, None, :, None]
+ )
+ col_idx = (
+ block_cols[None, :, None, None] * merge_size
+ + intra_col[None, None, None, :]
+ )
+
+ row_idx = row_idx.expand(
+ merged_h, merged_w, merge_size, merge_size
+ ).reshape(-1)
+ col_idx = col_idx.expand(
+ merged_h, merged_w, merge_size, merge_size
+ ).reshape(-1)
+
+ coords = torch.stack((row_idx, col_idx), dim=-1)
+ if num_frames > 1:
+ coords = coords.repeat(num_frames, 1)
+
+ num_tokens = coords.shape[0]
+ pos_ids[offset : offset + num_tokens] = coords
+ offset += num_tokens
+
+ embeddings = freq_table[pos_ids]
+ embeddings = embeddings.flatten(1)
+ return embeddings
+
+ def fast_pos_embed_interpolate(self, grid_thw):
+ """Bilinear interpolation of positional embeddings for variable resolution."""
+ grid_thw_list = grid_thw.tolist()
+ grid_ts = [row[0] for row in grid_thw_list]
+ grid_hs = [row[1] for row in grid_thw_list]
+ grid_ws = [row[2] for row in grid_thw_list]
+ device = self.pos_embed.weight.device
+
+ idx_list = [[] for _ in range(4)]
+ weight_list = [[] for _ in range(4)]
+
+ for t, h, w in grid_thw_list:
+ h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
+ w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
+
+ h_idxs_floor = h_idxs.int()
+ w_idxs_floor = w_idxs.int()
+ h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+ w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+
+ dh = h_idxs - h_idxs_floor
+ dw = w_idxs - w_idxs_floor
+
+ base_h = h_idxs_floor * self.num_grid_per_side
+ base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+ indices = [
+ (base_h[None].T + w_idxs_floor[None]).flatten(),
+ (base_h[None].T + w_idxs_ceil[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+ ]
+ weights = [
+ ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+ ((1 - dh)[None].T * dw[None]).flatten(),
+ (dh[None].T * (1 - dw)[None]).flatten(),
+ (dh[None].T * dw[None]).flatten(),
+ ]
+
+ for i in range(4):
+ idx_list[i].extend(indices[i].tolist())
+ weight_list[i].extend(weights[i].tolist())
+
+ idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
+ weight_tensor = torch.tensor(
+ weight_list, dtype=self.pos_embed.weight.dtype, device=device
+ )
+ pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
+ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+ patch_pos_embeds = patch_pos_embeds.split(
+ [h * w for h, w in zip(grid_hs, grid_ws)]
+ )
+
+ merge_size = self.vision_config.spatial_merge_size
+ patch_pos_embeds_permute = []
+ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+ pos_embed = pos_embed.repeat(t, 1)
+ pos_embed = (
+ pos_embed.view(
+ t, h // merge_size, merge_size, w // merge_size, merge_size, -1
+ )
+ .permute(0, 1, 3, 2, 4, 5)
+ .flatten(0, 4)
+ )
+ patch_pos_embeds_permute.append(pos_embed)
+
+ return torch.cat(patch_pos_embeds_permute)
+
+ def _build_vision_attention_mask(self, grid_thw, seq_len, dtype):
+ """Build block-diagonal attention mask for variable-length images.
+
+ Each image gets its own attention block (no cross-image attention).
+ """
+ cu_seqlens = torch.repeat_interleave(
+ grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+ ).cumsum(dim=0, dtype=torch.int32)
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
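+
+ # Example (illustrative): grid_thw = [[1, 2, 2], [1, 2, 2]] yields
+ # cu_seqlens = [0, 4, 8] -- an 8x8 mask with two 4x4 zero blocks on the
+ # diagonal and _MASK_NEG_INF elsewhere, so each image attends only to itself.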
+
+ # Build block-diagonal mask
+ mask = torch.full((seq_len, seq_len), _MASK_NEG_INF, dtype=dtype)
+ for i in range(len(cu_seqlens) - 1):
+ start = cu_seqlens[i].item()
+ end = cu_seqlens[i + 1].item()
+ mask[start:end, start:end] = 0.0
+
+ return mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
+
+ def forward(self, pixel_values, image_grid_thw):
+ """Run vision encoding (CPU preprocessing + Neuron traced model).
+
+ Args:
+ pixel_values: Raw pixel values from HF processor
+ image_grid_thw: (num_images, 3) -- temporal, height, width in patches
+
+ Returns:
+ vision_embeddings: (total_merged_tokens, out_hidden_size)
+ """
+ # 1. Patch embedding (CPU, Conv3d)
+ hidden_states = self.patch_embed(pixel_values)
+
+ # 2. Positional embedding (CPU, bilinear interpolation)
+ pos_embeds = self.fast_pos_embed_interpolate(image_grid_thw)
+ hidden_states = hidden_states + pos_embeds
+
+ # 3. Rotary position embeddings (CPU)
+ rotary_pos_emb = self.rot_pos_emb(image_grid_thw)
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ position_embeddings = (emb.cos(), emb.sin())
+
+ # 4. Vision attention mask (block-diagonal)
+ seq_len = hidden_states.shape[0]
+ attention_mask = self._build_vision_attention_mask(
+ image_grid_thw, seq_len, hidden_states.dtype
+ )
+
+ # 5. Bucket and pad for Neuron compilation
+ bucket_len = self._get_vision_bucket(seq_len)
+ cos, sin = position_embeddings
+ if seq_len < bucket_len:
+ pad_len = bucket_len - seq_len
+ hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len))
+ cos = F.pad(cos, (0, 0, 0, pad_len))
+ sin = F.pad(sin, (0, 0, 0, pad_len))
+ # Extend mask with _MASK_NEG_INF for padded positions (NOT -inf, which causes NaN on Neuron)
+ mask = torch.full(
+ (1, 1, bucket_len, bucket_len), _MASK_NEG_INF, dtype=hidden_states.dtype
+ )
+ mask[:, :, :seq_len, :seq_len] = attention_mask
+ attention_mask = mask
+
+ # 6. Run vision model (Neuron compiled or CPU fallback)
+ if self._compiled_buckets is not None:
+ # Multi-bucket mode: select the compiled model for this bucket
+ if bucket_len not in self._compiled_buckets:
+ raise RuntimeError(
+ f"No compiled vision encoder for bucket size {bucket_len}. "
+ f"Available buckets: {sorted(self._compiled_buckets.keys())}. "
+ f"Input seq_len={seq_len} requires bucket {bucket_len}."
+ )
+ compiled_model = self._compiled_buckets[bucket_len]
+ vision_output = compiled_model(
+ hidden_states.to(torch.bfloat16),
+ attention_mask.to(torch.bfloat16),
+ cos.to(torch.bfloat16),
+ sin.to(torch.bfloat16),
+ )
+ elif self._compiled_model is not None:
+ # Single compiled model (legacy)
+ vision_output = self._compiled_model(
+ hidden_states.to(torch.bfloat16),
+ attention_mask.to(torch.bfloat16),
+ cos.to(torch.bfloat16),
+ sin.to(torch.bfloat16),
+ )
+ elif self._cpu_model is not None:
+ # CPU-only mode: run the vision encoder on CPU. Padding is not
+ # strictly required here, but we pad anyway so the merger math
+ # matches the compiled path.
+ with torch.no_grad():
+ vision_output = self._cpu_model(
+ hidden_states.to(torch.bfloat16),
+ attention_mask.to(torch.bfloat16),
+ cos.to(torch.bfloat16),
+ sin.to(torch.bfloat16),
+ )
+ else:
+ # NxDI traced model: takes (hidden_states, attention_mask, position_embeddings)
+ vision_output = self.model(hidden_states, attention_mask, (cos, sin))
+
+ # 7. Unpad: only keep valid merged tokens
+ merge_size = self.vision_config.spatial_merge_size
+ total_merged_tokens = sum(
+ t * (h // merge_size) * (w // merge_size)
+ for t, h, w in image_grid_thw.tolist()
+ )
+ vision_output = vision_output[:total_merged_tokens]
+
+ return vision_output
+
+
+class NeuronQwen35VisionForImageEncoding(NeuronApplicationBase):
+ """Standalone application class for vision encoding (for testing)."""
+
+ model_cls = NeuronQwen35VisionModel
+ model_wrapper_cls = NeuronQwen35VisionModelWrapper
+
+ @staticmethod
+ def prepare_input_args(image_path, processor):
+ """Prepare vision inputs from an image path.
+
+ Args:
+ image_path: Path to image file
+ processor: HF AutoProcessor
+
+ Returns:
+ pixel_values, image_grid_thw
+ """
+ from PIL import Image
+
+ image = Image.open(image_path).convert("RGB")
+ inputs = processor(images=image, return_tensors="pt")
+ return inputs["pixel_values"], inputs["image_grid_thw"]
diff --git a/contrib/models/Qwen3.5-2B/src/modeling_qwen35_vl.py b/contrib/models/Qwen3.5-2B/src/modeling_qwen35_vl.py
new file mode 100644
index 00000000..64cbe71e
--- /dev/null
+++ b/contrib/models/Qwen3.5-2B/src/modeling_qwen35_vl.py
@@ -0,0 +1,665 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Qwen3.5 Vision-Language Model Orchestrator for NeuronX Distributed Inference.
+
+This is the top-level VL model that wires together:
+- The vision encoder (modeling_qwen35_vision.py)
+- The text decoder (modeling_qwen35.py, dense model with vision injection)
+
+It handles:
+- Multimodal RoPE (mRoPE) with interleaved layout
+- Vision embedding injection via scatter_by_index_put
+- Separate compilation and loading of vision and text models
+- The CTE+TKG generation loop with vision inputs
+
+Architecture follows the NxDI NeuronBaseForImageToText pattern established
+by Qwen3-VL in SDK 2.28, adapted for Qwen3.5 dense model's unique features:
+- No deepstack (Qwen3.5 does not use intermediate vision feature injection)
+- DeltaNet linear attention layers in the text decoder
+- Dense SwiGLU MLP layers in the text decoder
+- Interleaved mRoPE (THWTHW... layout) instead of Qwen3-VL's section-based layout
+"""
+
+import logging
+import os
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+logger = logging.getLogger(__name__)
+
+# NxDI imports
+try:
+ from neuronx_distributed_inference.models.image_to_text_model_base import (
+ ImageToTextInferenceConfig,
+ NeuronBaseForImageToText,
+ )
+ from neuronx_distributed_inference.models.config import NeuronConfig
+
+ HAS_NXDI_VL = True
+except ImportError:
+ HAS_NXDI_VL = False
+ logger.warning("NxDI VL base classes not available -- VL model requires SDK 2.28+")
+
+# Local imports
+try:
+ from src.modeling_qwen35 import (
+ NeuronQwen35ForCausalLM,
+ NeuronQwen35Model,
+ Qwen35InferenceConfig,
+ Qwen35ModelWrapper,
+ )
+ from src.modeling_qwen35_vision import (
+ NeuronQwen35VisionModel,
+ NeuronQwen35VisionModelWrapper,
+ )
+except ImportError:
+ from modeling_qwen35 import (
+ NeuronQwen35ForCausalLM,
+ NeuronQwen35Model,
+ Qwen35InferenceConfig,
+ Qwen35ModelWrapper,
+ )
+ from modeling_qwen35_vision import (
+ NeuronQwen35VisionModel,
+ NeuronQwen35VisionModelWrapper,
+ )
+
+
+def get_rope_index(
+ input_ids,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ attention_mask=None,
+ image_token_id=248056,
+ video_token_id=248057,
+ vision_start_token_id=248053,
+ spatial_merge_size=2,
+):
+ """Compute 3D multimodal RoPE position IDs for Qwen3.5.
+
+ Returns position_ids of shape (3, batch_size, seq_len) where:
+ - Axis 0: temporal position
+ - Axis 1: height position
+ - Axis 2: width position
+
+ For text tokens, all 3 axes have the same sequential position.
+ For vision tokens, each axis encodes the spatial/temporal grid position.
+
+ Also returns rope_deltas for use during TKG decoding.
+
+ Adapted from HuggingFace Qwen3_5Model.get_rope_index().
+ """
+ if video_grid_thw is not None:
+ video_grid_thw = torch.repeat_interleave(
+ video_grid_thw, video_grid_thw[:, 0], dim=0
+ )
+ video_grid_thw[:, 0] = 1
+
+ image_grid_thw_list = (
+ image_grid_thw.tolist() if image_grid_thw is not None else None
+ )
+ video_grid_thw_list = (
+ video_grid_thw.tolist() if video_grid_thw is not None else None
+ )
+
+ mrope_position_deltas = []
+ total_input_ids = input_ids
+
+ if attention_mask is None:
+ attention_mask = torch.ones_like(total_input_ids)
+
+ position_ids = torch.zeros(
+ 3,
+ input_ids.shape[0],
+ input_ids.shape[1],
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+
+ image_index, video_index = 0, 0
+ attention_mask = attention_mask.to(total_input_ids.device)
+
+ for i, ids in enumerate(total_input_ids):
+ ids = ids[attention_mask[i] == 1]
+ image_nums, video_nums = 0, 0
+
+ vision_start_indices = torch.argwhere(ids == vision_start_token_id).squeeze(1)
+ if len(vision_start_indices) > 0:
+ vision_tokens = ids[vision_start_indices + 1]
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = (vision_tokens == video_token_id).sum()
+
+ input_tokens = ids.tolist()
+ llm_pos_ids_list = []
+ st = 0
+ remain_images, remain_videos = image_nums, video_nums
+
+ for _ in range(image_nums + video_nums):
+ if image_token_id in input_tokens and remain_images > 0:
+ ed_image = input_tokens.index(image_token_id, st)
+ else:
+ ed_image = len(input_tokens) + 1
+ if video_token_id in input_tokens and remain_videos > 0:
+ ed_video = input_tokens.index(video_token_id, st)
+ else:
+ ed_video = len(input_tokens) + 1
+
+ if ed_image < ed_video:
+ t, h, w = image_grid_thw_list[image_index]
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+ else:
+ t, h, w = video_grid_thw_list[video_index]
+ video_index += 1
+ remain_videos -= 1
+ ed = ed_video
+
+ llm_grid_t = t
+ llm_grid_h = h // spatial_merge_size
+ llm_grid_w = w // spatial_merge_size
+
+ text_len = ed - st
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(
+ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+ )
+
+ t_index = (
+ torch.arange(llm_grid_t)
+ .view(-1, 1)
+ .expand(-1, llm_grid_h * llm_grid_w)
+ .flatten()
+ )
+ h_index = (
+ torch.arange(llm_grid_h)
+ .view(1, -1, 1)
+ .expand(llm_grid_t, -1, llm_grid_w)
+ .flatten()
+ )
+ w_index = (
+ torch.arange(llm_grid_w)
+ .view(1, 1, -1)
+ .expand(llm_grid_t, llm_grid_h, -1)
+ .flatten()
+ )
+ llm_pos_ids_list.append(
+ torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+ )
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(
+ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+ )
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
+ position_ids.device
+ )
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
+
+ mrope_position_deltas = torch.tensor(
+ mrope_position_deltas, device=input_ids.device
+ ).unsqueeze(1)
+ return position_ids, mrope_position_deltas
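+
+
+# Worked example (illustrative): ids = [T, T, <vision_start>, I, I, I, I,
+# <vision_end>, T] with one image of grid_thw = (1, 4, 4) and merge size 2
+# (2x2 = 4 placeholder tokens) yields per-axis positions:
+#     temporal: [0, 1, 2, 3, 3, 3, 3, 5, 6]
+#     height:   [0, 1, 2, 3, 3, 4, 4, 5, 6]
+#     width:    [0, 1, 2, 3, 4, 3, 4, 5, 6]
+# and rope_delta = 6 + 1 - 9 = -2, added to positions during TKG decoding.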
+
+
+class Qwen35VLInferenceConfig:
+ """Configuration for the full VL model (text + vision).
+
+ Wraps the existing Qwen35InferenceConfig for text and adds
+ vision-specific settings.
+ """
+
+ def __init__(
+ self,
+ text_config,
+ vision_config,
+ image_token_id=248056,
+ video_token_id=248057,
+ vision_start_token_id=248053,
+ vision_end_token_id=248054,
+ spatial_merge_size=2,
+ vision_seq_len_buckets=None,
+ **kwargs,
+ ):
+ """
+ Args:
+ text_config: Qwen35InferenceConfig instance for the text decoder
+ vision_config: dict with vision encoder hyperparams (depth, hidden_size, etc.)
+ image_token_id: Token ID for image placeholder tokens
+ video_token_id: Token ID for video placeholder tokens
+ vision_start_token_id: Token ID for <|vision_start|>
+ vision_end_token_id: Token ID for <|vision_end|>
+ spatial_merge_size: How many patches are merged (2 = 2x2 = 4 patches merged)
+ vision_seq_len_buckets: List of vision sequence length buckets for compilation
+ """
+ self.text_config = text_config
+ self.vision_config = vision_config
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+ self.spatial_merge_size = spatial_merge_size
+ self.vision_seq_len_buckets = vision_seq_len_buckets or [1024, 4096, 16384]
+
+
+class NeuronQwen35VLForCausalLM:
+ """Top-level VL model for Qwen3.5-2B on Neuron.
+
+ This class manages:
+ - Separate compilation/loading of vision encoder and text decoder
+ - CPU-side mRoPE computation
+ - Vision embedding injection into text decoder
+ - The CTE+TKG generation loop
+
+ Note: This is NOT a NeuronBaseForImageToText subclass because the
+ text decoder (NeuronQwen35ForCausalLM) has extensive custom overrides
+ (DeltaNet state management, custom forward, custom ModelWrapper) that
+ don't fit the base class pattern. Instead, this class composes the two
+ models and handles the VL orchestration directly.
+ """
+
+ def __init__(self, model_path, text_config, vision_config=None, processor=None):
+ """
+ Args:
+ model_path: Path to HF model directory
+ text_config: Qwen35InferenceConfig for text decoder
+ vision_config: Qwen35VLInferenceConfig (or None for text-only)
+ processor: HF AutoProcessor for image preprocessing
+ """
+ self.model_path = model_path
+ self.text_config = text_config
+ self.vl_config = vision_config
+ self.processor = processor
+
+ # Text decoder (existing implementation)
+ self.text_model = NeuronQwen35ForCausalLM(
+ model_path=model_path, config=text_config
+ )
+
+ # Vision encoder (lazy init -- only built if vl_config provided)
+ self.vision_model_wrapper = None
+ if vision_config is not None:
+ self._init_vision_model(vision_config)
+
+ # mRoPE state
+ self.rope_deltas = None
+
+ def _init_vision_model(self, vl_config):
+ """Initialize the vision encoder wrapper."""
+ from types import SimpleNamespace
+
+ vision_cfg = SimpleNamespace(**vl_config.vision_config)
+ self.vision_model_wrapper = NeuronQwen35VisionModelWrapper(
+ config=vision_cfg,
+ model_cls=None, # Standalone mode (no NxDI parallel layers)
+ vision_seq_len_buckets=vl_config.vision_seq_len_buckets,
+ )
+ self._vl_config = vl_config
+
+ def compile(self, compiled_model_path):
+ """Compile both text and vision models.
+
+ For the vision encoder, use compile_vision_encoder.py separately
+ (standalone torch_neuronx.trace compilation). Then use load() to
+ load the pre-compiled vision encoder.
+ """
+ # Compile text decoder
+ text_path = os.path.join(compiled_model_path, "text_model")
+ os.makedirs(text_path, exist_ok=True)
+ self.text_model.compile(text_path)
+
+ # Vision encoder is compiled separately via compile_vision_encoder.py
+ if self.vision_model_wrapper is not None:
+ logger.info(
+ "Vision encoder must be compiled separately using "
+ "compile_vision_encoder.py. Use load() to load the "
+ "pre-compiled vision encoder."
+ )
+
+ def load(self, compiled_model_path, vision_compiled_path=None):
+ """Load both compiled models.
+
+ Args:
+ compiled_model_path: Path to compiled text model (or parent dir)
+ vision_compiled_path: Path to compiled vision encoder .pt file.
+ If None, looks for 'vision_encoder.pt' in compiled_model_path.
+ """
+ text_path = os.path.join(compiled_model_path, "text_model")
+ if os.path.exists(text_path):
+ self.text_model.load(text_path)
+ else:
+ # Backward compatibility: text model compiled at root
+ self.text_model.load(compiled_model_path)
+
+ # Load vision encoder
+ if self.vision_model_wrapper is not None:
+ if vision_compiled_path is None:
+ vision_compiled_path = os.path.join(
+ compiled_model_path, "vision_encoder.pt"
+ )
+ if os.path.exists(vision_compiled_path):
+ self.vision_model_wrapper.load_compiled(vision_compiled_path)
+ # Also load CPU-side weights (patch_embed, pos_embed)
+ self.vision_model_wrapper.load_vision_weights_from_hf(self.model_path)
+ logger.info("Vision encoder loaded from pre-compiled model")
+ else:
+ logger.warning(
+ f"No compiled vision encoder found at {vision_compiled_path}. "
+ "Vision encoding will not be available."
+ )
+
+ # Qwen3.5 default stop token IDs (values taken from the HF config/tokenizer)
+ _DEFAULT_EOS_TOKEN_IDS = {
+ 248044, # <|endoftext|> -- text config eos_token_id
+ 248046, # <|im_end|> -- tokenizer eos_token / end of assistant turn
+ }
+
+ def generate(
+ self,
+ input_ids,
+ attention_mask=None,
+ pixel_values=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ max_new_tokens=32,
+ temperature=0.0,
+ top_p=1.0,
+ top_k=0,
+ eos_token_ids=None,
+ **kwargs,
+ ):
+ """Generate text from text and/or vision inputs.
+
+ Args:
+ input_ids: (batch_size, seq_len) token IDs
+ attention_mask: (batch_size, seq_len) attention mask
+ pixel_values: Vision pixel values from HF processor (or None for text-only)
+ image_grid_thw: (num_images, 3) grid dimensions
+ video_grid_thw: (num_videos, 3) grid dimensions
+ max_new_tokens: Maximum new tokens to generate
+ temperature: Sampling temperature (0.0 = greedy/argmax)
+ top_p: Nucleus sampling threshold (1.0 = disabled)
+ top_k: Top-k sampling (0 = disabled)
+ eos_token_ids: Set of token IDs to stop generation on
+ (default: {248044, 248046})
+
+ Returns:
+ generated_ids: (batch_size, seq_len + max_new_tokens) token IDs
+ """
+ if eos_token_ids is None:
+ eos_token_ids = self._DEFAULT_EOS_TOKEN_IDS
+
+ # Reset text model state for a fresh generation.
+ # This ensures CTE runs (not TKG) even if a prior generate() was called.
+ # DeltaNet recurrent states don't need explicit zeroing because the CTE
+ # NKI kernel always starts from zero state.
+ self.text_model.reset()
+
+ has_vision = pixel_values is not None and pixel_values.numel() > 0
+
+ # Step 1: Compute 3D mRoPE position IDs
+ if has_vision and self._vl_config is not None:
+ position_ids, self.rope_deltas = get_rope_index(
+ input_ids,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ attention_mask=attention_mask,
+ image_token_id=self._vl_config.image_token_id,
+ video_token_id=self._vl_config.video_token_id,
+ vision_start_token_id=self._vl_config.vision_start_token_id,
+ spatial_merge_size=self._vl_config.spatial_merge_size,
+ )
+ else:
+ # Text-only: use standard sequential position IDs
+ seq_len = input_ids.shape[1]
+ position_ids = torch.arange(seq_len).unsqueeze(0)
+ self.rope_deltas = None
+
+ # Step 2: Run vision encoder and prepare injection args
+ llava_args = []
+ batch_size = input_ids.shape[0]
+ if has_vision and self.vision_model_wrapper is not None:
+ # The vision encoder processes both image and video frames identically
+ # (they share the same ViT architecture). The HF processor outputs a
+ # single pixel_values tensor for images, and video frames are treated
+ # as multiple images with temporal grid > 1.
+ vision_embeddings = self.vision_model_wrapper(pixel_values, image_grid_thw)
+ # vision_embeddings: (total_merged_tokens, out_hidden_size)
+
+ # Build vision_mask: boolean mask of ALL vision token positions
+ # (both image_token_id and video_token_id placeholders)
+ image_token_id = self._vl_config.image_token_id
+ video_token_id = self._vl_config.video_token_id
+ vision_bool_mask = (input_ids == image_token_id) | (
+ input_ids == video_token_id
+ ) # (BS, seq_len)
+
+ # For batch_size=1 (primary path): extract positions from batch element 0.
+ # For batch_size>1: each element may have different image token positions;
+ # we'd need per-element scatter. Currently only batch_size=1 is supported
+ # for VL (the compiled model uses batch_size=1 for CTE).
+ if batch_size > 1:
+ logger.warning(
+ "VL generation with batch_size > 1 is not fully supported. "
+ "Using batch element 0 for vision scatter positions."
+ )
+
+ positions = (
+ vision_bool_mask[0].nonzero(as_tuple=False).squeeze(-1)
+ ) # (n_vision_tokens,)
+
+ # Reshape vision_embeddings to (1, n_vision_tokens, hidden_size)
+ n_vis = positions.shape[0]
+ hidden_size = vision_embeddings.shape[-1]
+ vis_emb = vision_embeddings[:n_vis].unsqueeze(0) # (1, n_vis, hidden)
+
+ # Pad to match input sequence length for compiled graph compatibility
+ seq_len = input_ids.shape[1]
+ pad_limit = seq_len # Must match the bucket size
+
+ # Pad vision_embeddings to (1, pad_limit, hidden_size)
+ if n_vis < pad_limit:
+ pad_emb = torch.zeros(
+ (1, pad_limit - n_vis, hidden_size),
+ dtype=vis_emb.dtype,
+ )
+ vis_emb_padded = torch.cat([vis_emb, pad_emb], dim=1)
+ else:
+ vis_emb_padded = vis_emb[:, :pad_limit]
+
+ # Pad positions to (1, pad_limit, 1) with a SAFE fill value.
+ # CRITICAL: fill_value must be a valid index (within [0, pad_limit-1]).
+ # Using pad_limit-1 targets the last position (always a padding slot)
+ # so index_put_ scatters zero embeddings there harmlessly.
+ # NOTE: Do NOT use large sentinel values (e.g., 2**30) as they cause
+ # DGE out-of-bounds crashes in the Neuron runtime.
+ positions_padded = torch.full(
+ (1, pad_limit, 1),
+ fill_value=pad_limit - 1,
+ dtype=torch.int32,
+ )
+ positions_padded[0, :n_vis, 0] = positions[:pad_limit].to(torch.int32)
+
+ llava_args = [vis_emb_padded, positions_padded]
+
+ # Append 3D mRoPE position IDs for the text model.
+ # position_ids shape: (3, batch_size, seq_len) from get_rope_index.
+ # _get_model_outputs receives this at slot 21 and pre-computes
+ # mRoPE cos/sin in get_model_output() for all decoder layers.
+ if position_ids.ndim == 3:
+ mrope_pos = position_ids[:, :, :seq_len].to(torch.int32).contiguous()
+ llava_args.append(mrope_pos)
+ else:
+ vision_embeddings = None
+
+ # Step 3: Context encoding (prefill)
+ generated_ids = input_ids.clone()
+
+ # CRITICAL: Always pass an explicit attention_mask for CTE.
+ # The base class _infer_attention_mask() assumes sequential position_ids
+ # (position_ids[i] >= i). When position_ids come from mRoPE temporal
+ # axis (non-sequential, e.g., all vision tokens share position 4),
+ # the inferred mask incorrectly masks out most of the sequence.
+ # Fix: provide a real all-ones mask for the actual token positions.
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+
+ # For slot 2 (position_ids): use SEQUENTIAL positions regardless of mRoPE.
+ # Slot 2 is only used for: (1) logit position selection via torch.max(),
+ # (2) attention mask inference (which we bypass with the explicit mask above).
+ # The actual RoPE computation uses slot 21 (rotary_position_ids) from
+ # _get_model_outputs, NOT slot 2. Using sequential slot 2 ensures
+ # correct logit selection and avoids any position_ids-related issues.
+ seq_len = input_ids.shape[1]
+ cte_position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
+
+ with torch.no_grad():
+ output = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=cte_position_ids,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=False,
+ llava_args=llava_args,
+ )
+
+ logits = output[0] if isinstance(output, tuple) else output.logits
+ next_token = self._sample_token(logits[:, -1, :], temperature, top_p, top_k)
+ generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1)
+
+ # Check EOS after first token
+ if next_token.item() in eos_token_ids:
+ return generated_ids
+
+ # Step 4: Token generation (TKG) loop
+ for _ in range(max_new_tokens - 1):
+ pos_ids = torch.tensor([[generated_ids.shape[1] - 1]])
+ if self.rope_deltas is not None:
+ pos_ids = pos_ids + self.rope_deltas
+
+ last_token = generated_ids[:, -1:]
+ with torch.no_grad():
+ output = self.text_model(
+ input_ids=last_token,
+ position_ids=pos_ids,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=False,
+ )
+ logits = output[0] if isinstance(output, tuple) else output.logits
+ next_token = self._sample_token(logits[:, -1, :], temperature, top_p, top_k)
+ generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1)
+
+ # Stop on EOS
+ if next_token.item() in eos_token_ids:
+ break
+
+ return generated_ids
+
+ @staticmethod
+ def _sample_token(logits, temperature=0.0, top_p=1.0, top_k=0):
+ """Sample a token from logits with optional temperature/top-p/top-k.
+
+ Args:
+ logits: (batch_size, vocab_size) unnormalized logits
+ temperature: Sampling temperature. 0.0 = greedy (argmax).
+ top_p: Nucleus sampling threshold. 1.0 = disabled.
+ top_k: Top-k filtering. 0 = disabled.
+
+ Returns:
+ token_id: (batch_size,) sampled token IDs
+ """
+ if temperature <= 0.0:
+ return torch.argmax(logits, dim=-1)
+
+ # Apply temperature
+ logits = logits / temperature
+
+ # Top-k filtering
+ if top_k > 0:
+ top_k = min(top_k, logits.shape[-1])
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits[indices_to_remove] = float("-inf")
+
+ # Top-p (nucleus) filtering
+ if top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cumulative_probs = torch.cumsum(
+ torch.softmax(sorted_logits, dim=-1), dim=-1
+ )
+ # Remove tokens with cumulative probability above the threshold
+ sorted_indices_to_remove = cumulative_probs > top_p
+ # Shift right so the first token above threshold is kept
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+ ..., :-1
+ ].clone()
+ sorted_indices_to_remove[..., 0] = False
+ # Scatter back to original indexing
+ indices_to_remove = sorted_indices_to_remove.scatter(
+ -1, sorted_indices, sorted_indices_to_remove
+ )
+ logits[indices_to_remove] = float("-inf")
+
+ # Sample from the filtered distribution
+ probs = torch.softmax(logits, dim=-1)
+ return torch.multinomial(probs, num_samples=1).squeeze(-1)
+
+ @staticmethod
+ def prepare_input_args(text_prompt, image_path, processor, role="user"):
+ """Prepare inputs for vision+text generation.
+
+ Args:
+ text_prompt: Text prompt string
+ image_path: Path to image file (or None for text-only)
+ processor: HF AutoProcessor
+ role: Message role (default "user")
+
+ Returns:
+ input_ids, attention_mask, vision_inputs dict
+ """
+ content = []
+ if image_path is not None:
+ import base64
+ from pathlib import Path
+
+ image_data = Path(image_path).read_bytes()
+ b64 = base64.b64encode(image_data).decode("utf-8")
+ content.append(
+ {
+ "type": "image",
+ "url": f"data:image/jpeg;base64,{b64}",
+ }
+ )
+ content.append({"type": "text", "text": text_prompt})
+
+ messages = [{"role": role, "content": content}]
+ inputs = processor.apply_chat_template(
+ messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ return_dict=True,
+ )
+
+ input_ids = inputs["input_ids"]
+ attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids))
+
+ vision_inputs = {}
+ if "pixel_values" in inputs:
+ vision_inputs["pixel_values"] = inputs["pixel_values"]
+ if "image_grid_thw" in inputs:
+ vision_inputs["image_grid_thw"] = inputs["image_grid_thw"]
+ if "video_grid_thw" in inputs:
+ vision_inputs["video_grid_thw"] = inputs["video_grid_thw"]
+
+ return input_ids, attention_mask, vision_inputs
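+
+
+# End-to-end usage (sketch; paths are illustrative and config construction
+# follows Qwen35VLInferenceConfig above):
+#     from transformers import AutoProcessor
+#     processor = AutoProcessor.from_pretrained(model_path)
+#     model = NeuronQwen35VLForCausalLM(model_path, text_config, vl_config, processor)
+#     model.load(compiled_path)
+#     input_ids, attention_mask, vision_inputs = model.prepare_input_args(
+#         "Describe this image.", "cat.jpg", processor
+#     )
+#     generated = model.generate(
+#         input_ids, attention_mask=attention_mask, max_new_tokens=64, **vision_inputs
+#     )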
diff --git a/contrib/models/Qwen3.5-2B/src/nki_kernels/__init__.py b/contrib/models/Qwen3.5-2B/src/nki_kernels/__init__.py
new file mode 100644
index 00000000..b56dfecb
--- /dev/null
+++ b/contrib/models/Qwen3.5-2B/src/nki_kernels/__init__.py
@@ -0,0 +1,10 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Custom NKI kernels for Qwen3.5-2B DeltaNet layers.
+
+Contains three kernel implementations:
+- nki_deltanet: Per-token recurrent kernel (used for token generation)
+- nki_deltanet_chunked: Per-chunk kernel (legacy, superseded by fused)
+- nki_deltanet_fused: Fused single-kernel chunked forward (used for context encoding)
+"""
diff --git a/contrib/models/Qwen3.5-2B/src/nki_kernels/nki_deltanet.py b/contrib/models/Qwen3.5-2B/src/nki_kernels/nki_deltanet.py
new file mode 100644
index 00000000..e6740aa1
--- /dev/null
+++ b/contrib/models/Qwen3.5-2B/src/nki_kernels/nki_deltanet.py
@@ -0,0 +1,337 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""NKI kernels for DeltaNet gated delta rule recurrent forward.
+
+NKI v3 (SDK 2.29, NKI 0.3.0). Processes a SINGLE (batch, head) pair per kernel call.
+The caller loops over (B, H) in PyTorch and calls this kernel for each pair.
+
+Input layout: All inputs are 2D contiguous tensors (S, 128).
+Each call processes one (batch, head) element's full sequence.
+
+k_dim = v_dim = 128, which matches SBUF tile partition dimension exactly.
+g and beta are scalars per token, expanded to (S, 128) by the caller.
+
+Two kernel variants:
+ deltanet_recurrent_fwd -- returns output only (original)
+ deltanet_recurrent_fwd_state -- returns (output, final_state) for CTE->TKG carry-over
+"""
+
+import nki
+import nki.isa as nisa
+import nki.language as nl
+
+# Partition dimension max (NeuronCore SBUF tile width)
+P_MAX = 128
+
+# Shuffle mask: broadcast partition 0 to all partitions in a 32-wide group
+_BROADCAST_MASK = [0] * 32
+
+
+@nki.jit
+def deltanet_recurrent_fwd(
+ query: nl.ndarray, # (S, 128) float32
+ key: nl.ndarray, # (S, 128) float32
+ value: nl.ndarray, # (S, 128) float32
+ g_in: nl.ndarray, # (S, 128) float32, log-decay broadcast to 128
+ beta_in: nl.ndarray, # (S, 128) float32, write gate broadcast to 128
+) -> nl.ndarray:
+ """NKI kernel for DeltaNet recurrent forward -- single (batch, head).
+
+ Iterates over sequence tokens with sequential_range.
+ State matrix (128 x 128) lives in SBUF.
+
+ Args:
+ query: (S, 128) float32
+ key: (S, 128) float32
+ value: (S, 128) float32
+ g_in: (S, 128) float32
+ beta_in: (S, 128) float32
+
+ Returns:
+ output: (S, 128) float32
+ """
+ seq_len, dim = query.shape
+
+ # Output tensor in HBM
+ output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm)
+
+ # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1
+ seq_stride = dim
+
+ # Initialize recurrent state in SBUF: (128, 128)
+ state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.memset(dst=state, value=0.0)
+
+ # Sequential loop over tokens (state-dependent)
+ for t in nl.sequential_range(seq_len):
+ tok_offset = t * seq_stride
+
+ # ---- Load inputs for token t ----
+ q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(
+ dst=q_t,
+ src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset),
+ )
+
+ k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(
+ dst=k_t,
+ src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset),
+ )
+
+ v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(
+ dst=v_t,
+ src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset),
+ )
+
+ g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(
+ dst=g_t,
+ src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset),
+ )
+
+ beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(
+ dst=beta_t,
+ src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset),
+ )
+
+ # ---- Step 1: Decay state -- state = state * exp(g_t) ----
+ exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0)
+
+ state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=state_decayed,
+ data=state,
+ op0=nl.multiply,
+ operand0=exp_g,
+ engine=nisa.vector_engine,
+ )
+ nisa.tensor_copy(dst=state, src=state_decayed)
+
+ # ---- Step 2: Read memory -- kv_mem = state^T @ k_t ----
+ kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t)
+ kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum)
+
+ # ---- Step 3: delta = (v_t - kv_mem) * beta_t ----
+ v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract)
+
+ delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=delta,
+ data=v_sub,
+ op0=nl.multiply,
+ operand0=beta_t,
+ engine=nisa.vector_engine,
+ )
+
+ # ---- Step 4: state += outer(k_t, delta) ----
+ # Broadcast multiply: outer[i,j] = k_t[i] * delta[j]
+ # 1) Transpose delta (128,1) -> (1,128) in PSUM
+ # 2) Copy PSUM (1,128) -> SBUF (128,128) -- partition broadcast
+ # 3) Multiply by k_t (128,1) which broadcasts across free dim
+ # This avoids the nc_matmul P=1 outer product (wastes 127/128 TE lanes).
+
+ # Transpose delta to get values along free dimension
+ delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=delta_row_psum, data=delta)
+
+ # Copy PSUM (1, 128) -> SBUF (1, 128) first (NKI 0.3.0 requires matching P dims)
+ delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum)
+
+ # Broadcast (1, 128) SBUF -> (128, 128) SBUF via nc_stream_shuffle
+ # Each partition row gets the same delta values
+ delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ for i_shuf in nl.static_range(P_MAX // 32):
+ nisa.nc_stream_shuffle(
+ src=delta_row_sb[0:1, 0:P_MAX],
+ dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX],
+ shuffle_mask=_BROADCAST_MASK,
+ )
+
+ # Element-wise multiply: outer[i,j] = delta_broadcast[i,j] * k_t[i,0]
+ # tensor_scalar broadcasts (P,1) k_t across all F columns
+ outer_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=outer_prod,
+ data=delta_broadcast,
+ op0=nl.multiply,
+ operand0=k_t,
+ engine=nisa.vector_engine,
+ )
+
+ # Accumulate into state
+ state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add)
+ nisa.tensor_copy(dst=state, src=state_new)
+
+ # ---- Step 5: o_t = state^T @ q_t ----
+ o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t)
+ o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=o_t, src=o_t_psum)
+
+ # ---- Store output for token t ----
+ nisa.dma_copy(
+ dst=output.ap(pattern=[[1, dim]], offset=tok_offset),
+ src=o_t,
+ )
+
+ return output
+
+
+@nki.jit
+def deltanet_recurrent_fwd_state(
+ query: nl.ndarray, # (S, 128) float32
+ key: nl.ndarray, # (S, 128) float32
+ value: nl.ndarray, # (S, 128) float32
+ g_in: nl.ndarray, # (S, 128) float32, log-decay broadcast to 128
+ beta_in: nl.ndarray, # (S, 128) float32, write gate broadcast to 128
+):
+ """NKI kernel for DeltaNet recurrent forward with final state output.
+
+ Same recurrence as deltanet_recurrent_fwd, but ALSO writes the final
+ recurrent state (128, 128) to an output HBM buffer. This enables
+ CTE -> TKG state carry-over.
+
+ Returns:
+ output: (S, 128) float32 -- per-token output
+ final_state: (128, 128) float32 -- recurrent state after last token
+ """
+ seq_len, dim = query.shape
+
+ # Output tensors in HBM
+ output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm)
+ final_state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm)
+
+ # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1
+ seq_stride = dim
+
+ # Initialize recurrent state in SBUF: (128, 128)
+ state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.memset(dst=state, value=0.0)
+
+ # Sequential loop over tokens (state-dependent)
+ for t in nl.sequential_range(seq_len):
+ tok_offset = t * seq_stride
+
+ # ---- Load inputs for token t ----
+ q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(
+ dst=q_t,
+ src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset),
+ )
+
+ k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(
+ dst=k_t,
+ src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset),
+ )
+
+ v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(
+ dst=v_t,
+ src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset),
+ )
+
+ g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(
+ dst=g_t,
+ src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset),
+ )
+
+ beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(
+ dst=beta_t,
+ src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset),
+ )
+
+ # ---- Step 1: Decay state -- state = state * exp(g_t) ----
+ exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0)
+
+ state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=state_decayed,
+ data=state,
+ op0=nl.multiply,
+ operand0=exp_g,
+ engine=nisa.vector_engine,
+ )
+ nisa.tensor_copy(dst=state, src=state_decayed)
+
+ # ---- Step 2: Read memory -- kv_mem = state^T @ k_t ----
+ kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t)
+ kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum)
+
+ # ---- Step 3: delta = (v_t - kv_mem) * beta_t ----
+ v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract)
+
+ delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=delta,
+ data=v_sub,
+ op0=nl.multiply,
+ operand0=beta_t,
+ engine=nisa.vector_engine,
+ )
+
+ # ---- Step 4: state += outer(k_t, delta) ----
+ # Broadcast multiply: outer[i,j] = k_t[i] * delta[j]
+ delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=delta_row_psum, data=delta)
+
+ # Copy PSUM (1, 128) -> SBUF (1, 128) first (NKI 0.3.0 requires matching P dims)
+ delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum)
+
+ # Broadcast (1, 128) SBUF -> (128, 128) SBUF via nc_stream_shuffle
+ delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ for i_shuf in nl.static_range(P_MAX // 32):
+ nisa.nc_stream_shuffle(
+ src=delta_row_sb[0:1, 0:P_MAX],
+ dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX],
+ shuffle_mask=_BROADCAST_MASK,
+ )
+
+ outer_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=outer_prod,
+ data=delta_broadcast,
+ op0=nl.multiply,
+ operand0=k_t,
+ engine=nisa.vector_engine,
+ )
+
+ state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add)
+ nisa.tensor_copy(dst=state, src=state_new)
+
+ # ---- Step 5: o_t = state^T @ q_t ----
+ o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t)
+ o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=o_t, src=o_t_psum)
+
+ # ---- Store output for token t ----
+ nisa.dma_copy(
+ dst=output.ap(pattern=[[1, dim]], offset=tok_offset),
+ src=o_t,
+ )
+
+ # ---- Write final state to HBM ----
+ # state is (128, 128) in SBUF, copy to final_state in HBM
+ # Use dma_copy with full tile: P_MAX rows, dim cols
+ nisa.dma_copy(dst=final_state, src=state)
+
+ return output, final_state
diff --git a/contrib/models/Qwen3.5-2B/src/nki_kernels/nki_deltanet_chunked.py b/contrib/models/Qwen3.5-2B/src/nki_kernels/nki_deltanet_chunked.py
new file mode 100644
index 00000000..88f0cc1b
--- /dev/null
+++ b/contrib/models/Qwen3.5-2B/src/nki_kernels/nki_deltanet_chunked.py
@@ -0,0 +1,323 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""NKI per-chunk DeltaNet kernel for CTE (context encoding / prefill).
+
+Single-chunk kernel: processes one chunk (128 tokens) with Neumann-series
+power-doubling for intra-chunk correction. The caller loops over chunks
+in PyTorch, passing state between calls.
+
+Each kernel call:
+ - Takes one chunk of data: q, k, v, beta, g_cumsum, g_last (all 128x128)
+ - Takes recurrent state_in (128x128)
+ - Returns chunk output (128x128) and state_out (128x128)
+
+No sequence-indexed DMA inside the kernel -- all inputs/outputs are full tiles.
+This avoids the DMA OOB issue seen with nl.sequential_range + slice indexing
+in the NxDI model compilation context.
+
+NKI v3 (SDK 2.29, NKI 0.3.0). Uses nki.* namespace.
+"""
+
+import nki
+import nki.isa as nisa
+import nki.language as nl
+
+P_MAX = 128
+
+
+@nki.jit
+def deltanet_chunk_step(
+ query, # (128, 128) float32 -- one chunk, l2-normed+scaled
+ key, # (128, 128) float32 -- one chunk, l2-normed
+ value, # (128, 128) float32 -- one chunk
+ beta_broadcast, # (128, 128) float32 -- write gate broadcast to 128
+ g_cumsum, # (128, 128) float32 -- cumsum of g within chunk, broadcast
+ g_last, # (128, 128) float32 -- g_cumsum[-1], constant in chunk, broadcast
+ state_in, # (128, 128) float32 -- recurrent state from previous chunk
+ lower_mask, # (128, 128) float32 -- strict lower triangular
+ identity, # (128, 128) float32 -- identity matrix
+ lower_mask_diag, # (128, 128) float32 -- lower tri with diagonal
+):
+ """Process one chunk of DeltaNet.
+
+ Returns:
+ output: (128, 128) float32 -- chunk output
+ state_out: (128, 128) float32 -- updated recurrent state
+ """
+ _, dim = query.shape # chunk size and feature dim are both 128
+
+ # Output tensors in HBM
+ output = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.shared_hbm)
+ state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm)
+
+ # Load all inputs into SBUF
+ q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(dst=q_c, src=query)
+
+ k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(dst=k_c, src=key)
+
+ v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(dst=v_c, src=value)
+
+ beta_c = nl.ndarray((P_MAX, dim), dtype=beta_broadcast.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(dst=beta_c, src=beta_broadcast)
+
+ gc_c = nl.ndarray((P_MAX, dim), dtype=g_cumsum.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(dst=gc_c, src=g_cumsum)
+
+ gl_c = nl.ndarray((P_MAX, dim), dtype=g_last.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(dst=gl_c, src=g_last)
+
+ state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.dma_copy(dst=state, src=state_in)
+
+ # Load masks
+ eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.dma_copy(dst=eye, src=identity)
+
+ Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.dma_copy(dst=Lmask, src=lower_mask)
+
+ Lmask_d = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.dma_copy(dst=Lmask_d, src=lower_mask_diag)
+
+ # ============================================================
+ # k_beta = K * beta, v_beta = V * beta
+ # ============================================================
+ k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=k_beta, data1=k_c, data2=beta_c, op=nl.multiply)
+
+ v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=v_beta, data1=v_c, data2=beta_c, op=nl.multiply)
+
+ # ============================================================
+ # exp(g_cumsum) and exp(-g_cumsum)
+ # ============================================================
+ exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.activation(dst=exp_gc, op=nl.exp, data=gc_c, bias=None, scale=1.0)
+
+ neg_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=neg_gc,
+ data=gc_c,
+ op0=nl.multiply,
+ operand0=-1.0,
+ engine=nisa.vector_engine,
+ )
+ exp_neg_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.activation(dst=exp_neg_gc, op=nl.exp, data=neg_gc, bias=None, scale=1.0)
+
+ # exp(g_last) for state decay
+ exp_gl = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.activation(dst=exp_gl, op=nl.exp, data=gl_c, bias=None, scale=1.0)
+
+ # ============================================================
+ # Phase 1: Build A matrix (intra-chunk correction)
+ # QK = k_beta @ k^T -- contract over features
+ # ============================================================
+ kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=kb_T_psum, stationary=k_beta, moving=eye)
+ kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=kb_T, src=kb_T_psum)
+
+ k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=k_T_psum, stationary=k_c, moving=eye)
+ k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=k_T, src=k_T_psum)
+
+ QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T)
+ QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=QK, src=QK_psum)
+
+ # ============================================================
+ # Decay mask: QK_decay[i,j] = QK[i,j] * exp(gc[i]) * exp(-gc[j])
+ # ============================================================
+ QK_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=QK_row, data1=QK, data2=exp_gc, op=nl.multiply)
+
+ QK_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=QK_r_T_psum, stationary=QK_row, moving=eye)
+ QK_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=QK_r_T, src=QK_r_T_psum)
+
+ QK_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=QK_r_T_col, data1=QK_r_T, data2=exp_neg_gc, op=nl.multiply)
+
+ QK_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=QK_d_psum, stationary=QK_r_T_col, moving=eye)
+ QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=QK_decay, src=QK_d_psum)
+
+ # A = -QK_decay * lower_mask
+ neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=neg_QK_decay,
+ data=QK_decay,
+ op0=nl.multiply,
+ operand0=-1.0,
+ engine=nisa.vector_engine,
+ )
+ A = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=A, data1=neg_QK_decay, data2=Lmask, op=nl.multiply)
+
+ # ============================================================
+ # Neumann power-doubling: N = (I+A)(I+A^2)...(I+A^{64})
+ # ============================================================
+ P_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=P_acc, data1=eye, data2=A, op=nl.add)
+
+ A_pow = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=A_pow, src=A)
+
+ for _round in nl.sequential_range(6):
+ Ap_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=Ap_T_psum, stationary=A_pow, moving=eye)
+ Ap_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=Ap_T, src=Ap_T_psum)
+
+ Ap_sq_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=Ap_sq_psum, stationary=Ap_T, moving=A_pow)
+ nisa.tensor_copy(dst=A_pow, src=Ap_sq_psum)
+
+ IpA = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=IpA, data1=eye, data2=A_pow, op=nl.add)
+
+ IpA_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=IpA_T_psum, stationary=IpA, moving=eye)
+ IpA_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=IpA_T, src=IpA_T_psum)
+
+ Pacc_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=Pacc_psum, stationary=IpA_T, moving=P_acc)
+ nisa.tensor_copy(dst=P_acc, src=Pacc_psum)
+
+ # ============================================================
+ # Apply N: value_corr = N @ v_beta, k_cumdecay = N @ (k_beta * exp_gc)
+ # ============================================================
+ N_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=N_T_psum, stationary=P_acc, moving=eye)
+ N_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=N_T, src=N_T_psum)
+
+ vc_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=vc_psum, stationary=N_T, moving=v_beta)
+ value_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=value_corr, src=vc_psum)
+
+ kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=kb_exp_gc, data1=k_beta, data2=exp_gc, op=nl.multiply)
+
+ kcd_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=kcd_psum, stationary=N_T, moving=kb_exp_gc)
+ k_cumdecay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=k_cumdecay, src=kcd_psum)
+
+ # ============================================================
+ # Phase 2: Inter-chunk state propagation
+ # attn_intra = (q @ k^T) * decay_mask * lower_mask_diag
+ # ============================================================
+ q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=q_T_psum, stationary=q_c, moving=eye)
+ q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=q_T, src=q_T_psum)
+
+ qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T)
+ qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=qk_raw, src=qk_psum)
+
+ qk_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=qk_row, data1=qk_raw, data2=exp_gc, op=nl.multiply)
+
+ qk_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=qk_r_T_psum, stationary=qk_row, moving=eye)
+ qk_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=qk_r_T, src=qk_r_T_psum)
+
+ qk_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=qk_r_T_col, data1=qk_r_T, data2=exp_neg_gc, op=nl.multiply)
+
+ qk_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=qk_d_psum, stationary=qk_r_T_col, moving=eye)
+ qk_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=qk_decay, src=qk_d_psum)
+
+ attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=attn_intra, data1=qk_decay, data2=Lmask_d, op=nl.multiply)
+
+ # ============================================================
+ # v_prime = k_cumdecay @ state
+ # ============================================================
+ kcd_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=kcd_T_psum, stationary=k_cumdecay, moving=eye)
+ kcd_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=kcd_T, src=kcd_T_psum)
+
+ vp_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=vp_psum, stationary=kcd_T, moving=state)
+ v_prime = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=v_prime, src=vp_psum)
+
+ v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=v_new, data1=value_corr, data2=v_prime, op=nl.subtract)
+
+ # ============================================================
+ # attn_inter = (q * exp(g_cumsum)) @ state
+ # ============================================================
+ q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=q_exp, data1=q_c, data2=exp_gc, op=nl.multiply)
+
+ qe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=qe_T_psum, stationary=q_exp, moving=eye)
+ qe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=qe_T, src=qe_T_psum)
+
+ ai_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state)
+ attn_inter = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=attn_inter, src=ai_psum)
+
+ # ============================================================
+ # attn_intra @ v_new
+ # ============================================================
+ ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=ai_T_psum, stationary=attn_intra, moving=eye)
+ ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=ai_T, src=ai_T_psum)
+
+ intra_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=intra_psum, stationary=ai_T, moving=v_new)
+ intra_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=intra_out, src=intra_psum)
+
+ # ============================================================
+ # chunk_output = attn_inter + intra_out
+ # ============================================================
+ chunk_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add)
+
+ nisa.dma_copy(dst=output, src=chunk_out)
+
+ # ============================================================
+ # State update: state_new = exp(g_last) * (state + k_raw_decay^T @ v_new)
+ # ============================================================
+ k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=k_raw_decay, data1=k_c, data2=exp_neg_gc, op=nl.multiply)
+
+ kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=kv_psum, stationary=k_raw_decay, moving=v_new)
+ kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=kv_outer, src=kv_psum)
+
+ state_plus = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=state_plus, data1=state, data2=kv_outer, op=nl.add)
+
+ state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=state_new, data1=state_plus, data2=exp_gl, op=nl.multiply)
+
+ nisa.dma_copy(dst=state_out, src=state_new)
+
+ return output, state_out
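+
+
+# Host-side chunk loop (sketch; tensor names are illustrative). The caller
+# passes state between calls, per the module docstring:
+#     state = torch.zeros(128, 128)
+#     outs = []
+#     for c in range(seq_len // 128):
+#         sl = slice(c * 128, (c + 1) * 128)
+#         o, state = deltanet_chunk_step(
+#             q[sl], k[sl], v[sl], beta_bc[sl], g_cumsum[sl], g_last[sl],
+#             state, lower_mask, identity, lower_mask_diag,
+#         )
+#         outs.append(o)
+#     output = torch.cat(outs)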
diff --git a/contrib/models/Qwen3.5-2B/src/nki_kernels/nki_deltanet_fused.py b/contrib/models/Qwen3.5-2B/src/nki_kernels/nki_deltanet_fused.py
new file mode 100644
index 00000000..3447a138
--- /dev/null
+++ b/contrib/models/Qwen3.5-2B/src/nki_kernels/nki_deltanet_fused.py
@@ -0,0 +1,577 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Fused single-kernel DeltaNet chunked forward for CTE (context encoding).
+
+SSD-style architecture: processes ALL chunks for one (batch, head) pair in
+a single NKI kernel call. State (128x128) persists in SBUF across chunks —
+no HBM round-trips for inter-chunk state propagation.
+
+Key optimizations over nki_deltanet_chunked.py:
+ 1. Single kernel call per (B,H) instead of B*H*num_chunks calls
+ 2. State in SBUF across all chunks (no HBM state read/write per chunk)
+ 3. In-kernel cumsum via tensor_tensor_scan (no PyTorch cumsum)
+ 4. Masks and constants loaded once, reused across chunks
+ 5. Uses tensor_scalar for partition-broadcast (no explicit broadcast loops)
+ 6. nc_transpose (Vector Engine) for all 128x128 transposes instead of
+ nc_matmul(moving=eye) (Tensor Engine) — frees TE for actual math
+
+NKI 0.3.0 (SDK 2.29). k_dim = v_dim = 128 = P_MAX exactly.
+Chunk size = 128 = P_MAX (one tile per chunk).
+
+Mathematical framework (same as nki_deltanet_chunked.py):
+ Per-chunk Neumann-series power-doubling for intra-chunk correction:
+ A = -QK_decay * lower_mask
+ N = (I+A)(I+A^2)(I+A^4)...(I+A^64) [initial factor + 6 doubling rounds]
+ value_corr = N @ v_beta
+ k_cumdecay = N @ (k_beta * exp(gc))
+
+ Inter-chunk state propagation:
+ v_prime = k_cumdecay @ state
+ v_new = value_corr - v_prime
+ attn_inter = (q * exp(gc)) @ state
+ attn_intra = (q @ k^T) * decay_mask * lower_mask_diag
+ output = attn_inter + attn_intra @ v_new
+ state = exp(g_last) * (state + k_raw_decay^T @ v_new)
+"""
+
+import numpy as np
+
+import nki
+import nki.isa as nisa
+import nki.language as nl
+
+P_MAX = 128 # Partition dim = chunk_size = k_dim = v_dim
+CHUNK_SIZE = 128
+
+# Broadcast partition 0 to all partitions in a 32-wide group
+_BROADCAST_MASK = [0] * 32
+
+
+def _make_lower_mask():
+ """Strict lower triangular (128x128) as numpy constant."""
+ return np.tril(np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=-1)
+
+
+def _make_lower_mask_diag():
+ """Lower triangular with diagonal (128x128) as numpy constant."""
+ return np.tril(np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=0)
+
+
+def _make_identity():
+ """Identity matrix (128x128) as numpy constant."""
+ return np.eye(CHUNK_SIZE, dtype=np.float32)
+
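+
+# Illustrative CPU reference for the per-chunk math above (not used by the
+# kernel; handy for unit-testing kernel output). Assumes q is pre-scaled,
+# k is l2-normed, and chunk length C <= 128 so the Neumann series is exact.
+def _numpy_reference_chunk(q, k, v, g, beta, state):
+    """Reference chunk forward. q/k/v: (C, D); g/beta: (C,); state: (D, D).
+    Returns (output (C, D), new_state (D, D))."""
+    C = q.shape[0]
+    gc = np.cumsum(g)  # in-chunk cumulative log-decay
+    k_beta = k * beta[:, None]
+    v_beta = v * beta[:, None]
+    decay = np.exp(gc[:, None] - gc[None, :])  # exp(gc_i - gc_j)
+    A = -(k_beta @ k.T) * decay * np.tril(np.ones((C, C)), k=-1)
+    N = np.linalg.inv(np.eye(C) - A)  # = I + A + A^2 + ... (A is nilpotent)
+    value_corr = N @ v_beta
+    k_cumdecay = N @ (k_beta * np.exp(gc)[:, None])
+    v_new = value_corr - k_cumdecay @ state
+    attn_inter = (q * np.exp(gc)[:, None]) @ state
+    attn_intra = (q @ k.T) * decay * np.tril(np.ones((C, C)))
+    output = attn_inter + attn_intra @ v_new
+    k_raw_decay = k * np.exp(-gc)[:, None]
+    new_state = np.exp(gc[-1]) * (state + k_raw_decay.T @ v_new)
+    return output, new_state
+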
+
+@nki.jit
+def deltanet_fused_chunked_fwd(
+ query: nl.ndarray, # (S, 128) float32 — l2-normed and scaled
+ key: nl.ndarray, # (S, 128) float32 — l2-normed
+ value: nl.ndarray, # (S, 128) float32
+ g_in: nl.ndarray, # (S, 1) float32 — per-token log-decay (NOT cumsum)
+ beta_in: nl.ndarray, # (S, 1) float32 — per-token write gate
+ lower_mask: nl.ndarray, # (128, 128) float32 — strict lower tri
+ identity: nl.ndarray, # (128, 128) float32 — identity
+ lower_mask_diag: nl.ndarray, # (128, 128) float32 — lower tri with diag
+):
+ """Fused chunked DeltaNet forward — single kernel call per (batch, head).
+
+ Processes all chunks sequentially within the kernel, keeping the recurrent
+ state (128x128) in SBUF across chunks. Returns per-token output and
+ final state.
+
+ Input requirements:
+ - S must be divisible by 128 (pad before calling)
+ - query must be l2-normed and scaled by 1/sqrt(k_dim)
+ - key must be l2-normed
+ - g_in is RAW log-decay (cumsum computed in-kernel via tensor_tensor_scan)
+ - beta_in is sigmoid(b) (write gate)
+
+ Returns:
+ output: (S, 128) float32
+ final_state: (128, 128) float32
+ """
+ seq_len = query.shape[0]
+ dim = query.shape[1] # 128
+ num_chunks = seq_len // CHUNK_SIZE
+
+ # Output tensors in HBM
+ output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm)
+ final_state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm)
+
+ # ================================================================
+ # Load constant masks into SBUF once (reused across all chunks)
+ # ================================================================
+ eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.dma_copy(dst=eye, src=identity)
+
+ Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.dma_copy(dst=Lmask, src=lower_mask)
+
+ Lmask_d = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.dma_copy(dst=Lmask_d, src=lower_mask_diag)
+
+ # Ones vector for cumsum scan: (1, CHUNK_SIZE)
+ ones_1xC = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.memset(dst=ones_1xC, value=1.0)
+
+ # Zero initial for cumsum scan
+ zero_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.memset(dst=zero_11, value=0.0)
+
+ # ================================================================
+ # Initialize recurrent state in SBUF — persists across ALL chunks
+ # ================================================================
+ state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.memset(dst=state, value=0.0)
+
+ # ================================================================
+ # Sequential chunk processing
+ # ================================================================
+ for i_chunk in nl.sequential_range(num_chunks):
+ chunk_start = i_chunk * CHUNK_SIZE
+
+ # ---- Load chunk data from HBM ----
+ q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(
+ dst=q_c,
+ src=query[chunk_start : chunk_start + CHUNK_SIZE, 0:dim],
+ )
+
+ k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(
+ dst=k_c,
+ src=key[chunk_start : chunk_start + CHUNK_SIZE, 0:dim],
+ )
+
+ v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf)
+ nisa.dma_copy(
+ dst=v_c,
+ src=value[chunk_start : chunk_start + CHUNK_SIZE, 0:dim],
+ )
+
+ # g: (CHUNK_SIZE, 1) — raw log-decay per token
+ g_chunk_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.dma_copy(
+ dst=g_chunk_p[0:CHUNK_SIZE, 0:1],
+ src=g_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1],
+ )
+
+ # beta: (CHUNK_SIZE, 1) — write gate scalar per token
+ beta_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.dma_copy(
+ dst=beta_p[0:CHUNK_SIZE, 0:1],
+ src=beta_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1],
+ )
+
+ # ---- In-kernel cumsum of g via tensor_tensor_scan ----
+ # Need g as (1, CHUNK_SIZE) for scan along free dim.
+ # Transpose: (CHUNK_SIZE, 1) -> (1, CHUNK_SIZE) via nc_transpose
+ g_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.memset(dst=g_padded, value=0.0)
+ nisa.tensor_copy(
+ dst=g_padded[0:CHUNK_SIZE, 0:1],
+ src=g_chunk_p[0:CHUNK_SIZE, 0:1],
+ )
+
+ g_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=g_tp_psum, data=g_padded)
+
+ g_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(
+ dst=g_row[0:1, 0:CHUNK_SIZE],
+ src=g_tp_psum[0:1, 0:CHUNK_SIZE],
+ )
+
+ # cumsum: gc_row[t] = 1.0 * gc_row[t-1] + g_row[t]
+ gc_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor_scan(
+ dst=gc_row[0:1, 0:CHUNK_SIZE],
+ data0=ones_1xC[0:1, 0:CHUNK_SIZE],
+ data1=g_row[0:1, 0:CHUNK_SIZE],
+ initial=zero_11[0:1, 0:1],
+ op0=nl.multiply,
+ op1=nl.add,
+ )
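+        # Python equivalent of the scan (illustrative):
+        #   acc = 0.0
+        #   for t in range(CHUNK_SIZE):
+        #       acc = 1.0 * acc + g_row[t]; gc_row[t] = acc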
+
+ # Transpose gc back to (CHUNK_SIZE, 1) partition layout
+ gc_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.memset(dst=gc_padded, value=0.0)
+ nisa.tensor_copy(
+ dst=gc_padded[0:1, 0:CHUNK_SIZE],
+ src=gc_row[0:1, 0:CHUNK_SIZE],
+ )
+
+ gc_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=gc_tp_psum, data=gc_padded)
+
+ # gc_p: (P_MAX, 1) — cumulative sum of g per token in this chunk
+ gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(
+ dst=gc_p[0:CHUNK_SIZE, 0:1],
+ src=gc_tp_psum[0:CHUNK_SIZE, 0:1],
+ )
+
+ # g_last = gc[-1] (scalar) — needed for state decay
+ gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(
+ dst=gl_11[0:1, 0:1],
+ src=gc_row[0:1, CHUNK_SIZE - 1 : CHUNK_SIZE],
+ )
+
+ # ---- Compute exp(gc), exp(-gc), exp(g_last) as (P_MAX, 1) scalars ----
+ # These (P_MAX, 1) tensors are used with tensor_scalar to broadcast
+ # across the free dimension without explicit (P_MAX, dim) copies.
+
+ exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.activation(
+ dst=exp_gc_p[0:P_MAX, 0:1],
+ op=nl.exp,
+ data=gc_p[0:P_MAX, 0:1],
+ bias=None,
+ scale=1.0,
+ )
+
+ neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=neg_gc_p,
+ data=gc_p,
+ op0=nl.multiply,
+ operand0=-1.0,
+ engine=nisa.vector_engine,
+ )
+ exp_neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.activation(
+ dst=exp_neg_gc_p[0:P_MAX, 0:1],
+ op=nl.exp,
+ data=neg_gc_p[0:P_MAX, 0:1],
+ bias=None,
+ scale=1.0,
+ )
+
+ # exp(g_last): scalar, then broadcast to (P_MAX, 1)
+ exp_gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.activation(
+ dst=exp_gl_11,
+ op=nl.exp,
+ data=gl_11,
+ bias=None,
+ scale=1.0,
+ )
+
+ exp_gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+ for i_shuf in nl.static_range(P_MAX // 32):
+ nisa.nc_stream_shuffle(
+ src=exp_gl_11[0:1, 0:1],
+ dst=exp_gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1],
+ shuffle_mask=_BROADCAST_MASK,
+ )
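+        # Each nc_stream_shuffle call above copies partition 0 of exp_gl_11 into
+        # one 32-partition group (_BROADCAST_MASK = [0] * 32); four calls cover
+        # all 128 partitions.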
+
+ # ============================================================
+ # k_beta = K * beta, v_beta = V * beta
+ # tensor_scalar broadcasts beta_p (P_MAX, 1) across free dim
+ # ============================================================
+ k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=k_beta,
+ data=k_c,
+ op0=nl.multiply,
+ operand0=beta_p,
+ engine=nisa.vector_engine,
+ )
+
+ v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=v_beta,
+ data=v_c,
+ op0=nl.multiply,
+ operand0=beta_p,
+ engine=nisa.vector_engine,
+ )
+
+ # ============================================================
+ # Phase 1: Build A matrix (intra-chunk correction)
+ # Transpose K and K_beta for matmul
+ # ============================================================
+ kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=kb_T_psum, data=k_beta)
+ kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=kb_T, src=kb_T_psum)
+
+ k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=k_T_psum, data=k_c)
+ k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=k_T, src=k_T_psum)
+
+        # QK[i,j] = <k_beta[i], k[j]>, i.e. (k_beta @ k^T), contracting over the feature dim
+ QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T)
+ QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=QK, src=QK_psum)
+
+ # ============================================================
+ # Decay mask: QK_decay[i,j] = QK[i,j] * exp(gc[i]) * exp(-gc[j])
+ #
+ # Row scaling: QK_row[i,:] = QK[i,:] * exp(gc[i])
+ # Then transpose, column scale, transpose back.
+ # Uses tensor_scalar with (P_MAX,1) operand for row scaling.
+ # ============================================================
+ QK_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=QK_row,
+ data=QK,
+ op0=nl.multiply,
+ operand0=exp_gc_p,
+ engine=nisa.vector_engine,
+ )
+
+ # Transpose to scale columns (now rows in transposed view)
+ QK_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=QK_r_T_psum, data=QK_row)
+ QK_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=QK_r_T, src=QK_r_T_psum)
+
+ QK_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=QK_r_T_col,
+ data=QK_r_T,
+ op0=nl.multiply,
+ operand0=exp_neg_gc_p,
+ engine=nisa.vector_engine,
+ )
+
+ # Transpose back
+ QK_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=QK_d_psum, data=QK_r_T_col)
+ QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=QK_decay, src=QK_d_psum)
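+        # NumPy equivalent (illustrative): QK_decay = QK * np.exp(gc[:, None] - gc[None, :])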
+
+ # A = -QK_decay * lower_mask
+ neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=neg_QK_decay,
+ data=QK_decay,
+ op0=nl.multiply,
+ operand0=-1.0,
+ engine=nisa.vector_engine,
+ )
+ A_mat = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=A_mat, data1=neg_QK_decay, data2=Lmask, op=nl.multiply)
+
+ # ============================================================
+        # Neumann power-doubling: N = (I+A)(I+A^2)...(I+A^{64}) = I + A + ... + A^127
+        # (powers of A commute, so factor order is irrelevant). 6 doubling rounds
+        # cover all powers up to A^127, which is exact for chunk=128: A is
+        # strictly lower triangular, so A^128 = 0.
+ # ============================================================
+ P_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=P_acc, data1=eye, data2=A_mat, op=nl.add)
+
+ A_pow = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=A_pow, src=A_mat)
+
+ for _round in nl.sequential_range(6):
+ # A_pow = A_pow^2: transpose A_pow, then matmul
+ Ap_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=Ap_T_psum, data=A_pow)
+ Ap_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=Ap_T, src=Ap_T_psum)
+
+ Ap_sq_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=Ap_sq_psum, stationary=Ap_T, moving=A_pow)
+ nisa.tensor_copy(dst=A_pow, src=Ap_sq_psum)
+
+ # P_acc = (I + A_pow) @ P_acc: transpose IpA, then matmul
+ IpA = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=IpA, data1=eye, data2=A_pow, op=nl.add)
+
+ IpA_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=IpA_T_psum, data=IpA)
+ IpA_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=IpA_T, src=IpA_T_psum)
+
+ Pacc_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=Pacc_psum, stationary=IpA_T, moving=P_acc)
+ nisa.tensor_copy(dst=P_acc, src=Pacc_psum)
+
+ # ============================================================
+ # Apply N: value_corr = N @ v_beta
+ # k_cumdecay = N @ (k_beta * exp(gc))
+ # ============================================================
+ N_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=N_T_psum, data=P_acc)
+ N_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=N_T, src=N_T_psum)
+
+ vc_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=vc_psum, stationary=N_T, moving=v_beta)
+ value_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=value_corr, src=vc_psum)
+
+ # k_beta * exp(gc): row-scaled
+ kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=kb_exp_gc,
+ data=k_beta,
+ op0=nl.multiply,
+ operand0=exp_gc_p,
+ engine=nisa.vector_engine,
+ )
+
+ kcd_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=kcd_psum, stationary=N_T, moving=kb_exp_gc)
+ k_cumdecay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=k_cumdecay, src=kcd_psum)
+
+ # ============================================================
+ # Phase 2: Inter-chunk state propagation
+ # attn_intra = (q @ k^T) * decay_mask * lower_mask_diag
+ # ============================================================
+ q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=q_T_psum, data=q_c)
+ q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=q_T, src=q_T_psum)
+
+ qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T)
+ qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=qk_raw, src=qk_psum)
+
+ # Row-scale by exp(gc)
+ qk_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=qk_row,
+ data=qk_raw,
+ op0=nl.multiply,
+ operand0=exp_gc_p,
+ engine=nisa.vector_engine,
+ )
+
+ # Transpose, column-scale by exp(-gc), transpose back
+ qk_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=qk_r_T_psum, data=qk_row)
+ qk_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=qk_r_T, src=qk_r_T_psum)
+
+ qk_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=qk_r_T_col,
+ data=qk_r_T,
+ op0=nl.multiply,
+ operand0=exp_neg_gc_p,
+ engine=nisa.vector_engine,
+ )
+
+ qk_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=qk_d_psum, data=qk_r_T_col)
+ qk_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=qk_decay, src=qk_d_psum)
+
+ attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(
+ dst=attn_intra, data1=qk_decay, data2=Lmask_d, op=nl.multiply
+ )
+
+ # ============================================================
+ # v_prime = k_cumdecay @ state (state is in SBUF!)
+ # ============================================================
+ kcd_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=kcd_T_psum, data=k_cumdecay)
+ kcd_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=kcd_T, src=kcd_T_psum)
+
+ vp_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=vp_psum, stationary=kcd_T, moving=state)
+ v_prime = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=v_prime, src=vp_psum)
+
+ v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=v_new, data1=value_corr, data2=v_prime, op=nl.subtract)
+
+ # ============================================================
+ # attn_inter = (q * exp(gc)) @ state (state is in SBUF!)
+ # ============================================================
+ q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=q_exp,
+ data=q_c,
+ op0=nl.multiply,
+ operand0=exp_gc_p,
+ engine=nisa.vector_engine,
+ )
+
+ qe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=qe_T_psum, data=q_exp)
+ qe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=qe_T, src=qe_T_psum)
+
+ ai_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state)
+ attn_inter = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=attn_inter, src=ai_psum)
+
+ # ============================================================
+ # attn_intra @ v_new
+ # ============================================================
+ ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_transpose(dst=ai_T_psum, data=attn_intra)
+ ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=ai_T, src=ai_T_psum)
+
+ intra_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=intra_psum, stationary=ai_T, moving=v_new)
+ intra_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=intra_out, src=intra_psum)
+
+ # ============================================================
+ # chunk_output = attn_inter + intra_out
+ # ============================================================
+ chunk_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add)
+
+ # Store output chunk to HBM
+ nisa.dma_copy(
+ dst=output[chunk_start : chunk_start + CHUNK_SIZE, 0:dim],
+ src=chunk_out,
+ )
+
+ # ============================================================
+ # State update: state = exp(g_last) * (state + k_raw_decay^T @ v_new)
+ # state is updated IN-PLACE in SBUF — no HBM round-trip!
+ # ============================================================
+
+ # k_raw_decay = k * exp(-gc)
+ k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_scalar(
+ dst=k_raw_decay,
+ data=k_c,
+ op0=nl.multiply,
+ operand0=exp_neg_gc_p,
+ engine=nisa.vector_engine,
+ )
+
+ # k_raw_decay^T @ v_new → (dim, dim) outer product sum
+ # nc_matmul: result[M,N] = sum_K stationary[K,M] * moving[K,N]
+ # stationary=k_raw_decay (P_MAX, dim), moving=v_new (P_MAX, dim)
+ # Result: sum over tokens -> (dim, dim)
+ kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+ nisa.nc_matmul(dst=kv_psum, stationary=k_raw_decay, moving=v_new)
+ kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_copy(dst=kv_outer, src=kv_psum)
+
+ # state = state + kv_outer
+ state_plus = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+ nisa.tensor_tensor(dst=state_plus, data1=state, data2=kv_outer, op=nl.add)
+
+ # state = state_plus * exp(g_last)
+ # tensor_scalar broadcasts exp_gl_p (P_MAX, 1) across free dim
+ nisa.tensor_scalar(
+ dst=state,
+ data=state_plus,
+ op0=nl.multiply,
+ operand0=exp_gl_p,
+ engine=nisa.vector_engine,
+ )
+
+ # ---- Write final state to HBM ----
+ nisa.dma_copy(dst=final_state_out, src=state)
+
+ return output, final_state_out
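+
+
+# Host-side usage sketch (illustrative; argument names are placeholders).
+# The mask/identity constants come from the helpers above and are passed in
+# as kernel inputs alongside the projected q/k/v/g/beta tensors:
+#
+#   out, state = deltanet_fused_chunked_fwd(
+#       q, k, v, g, beta,
+#       _make_lower_mask(), _make_identity(), _make_lower_mask_diag(),
+#   )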
diff --git a/contrib/models/Qwen3.5-2B/test/__init__.py b/contrib/models/Qwen3.5-2B/test/__init__.py
new file mode 100644
index 00000000..04f8b7b7
--- /dev/null
+++ b/contrib/models/Qwen3.5-2B/test/__init__.py
@@ -0,0 +1,2 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
diff --git a/contrib/models/Qwen3.5-2B/test/integration/__init__.py b/contrib/models/Qwen3.5-2B/test/integration/__init__.py
new file mode 100644
index 00000000..04f8b7b7
--- /dev/null
+++ b/contrib/models/Qwen3.5-2B/test/integration/__init__.py
@@ -0,0 +1,2 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
diff --git a/contrib/models/Qwen3.5-2B/test/integration/test_model.py b/contrib/models/Qwen3.5-2B/test/integration/test_model.py
new file mode 100644
index 00000000..d674e8ce
--- /dev/null
+++ b/contrib/models/Qwen3.5-2B/test/integration/test_model.py
@@ -0,0 +1,671 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Integration tests for Qwen3.5-2B on Neuron.
+
+Tests compilation, loading, inference accuracy, and performance using
+the full 2B model with pre-downloaded HuggingFace weights on a trn2 instance.
+
+Note: A mini model option is not provided because DeltaNet layers require NKI
+kernels that only execute on Neuron devices, and the hybrid DeltaNet + GQA
+architecture needs at least TP=4 for compilation.
+
+Environment variables:
+ QWEN35_MODEL_PATH Path to HF model weights (required)
+ QWEN35_COMPILED_PATH Path to compiled artifacts (default: /tmp/qwen35_2b_traced)
+ QWEN35_REF_LOGITS_PATH Path to CPU reference logits .pt file (for logit validation)
+ QWEN35_TP_DEGREE Tensor parallelism degree (default: 4)
+ QWEN35_SEQ_LEN Max sequence length (default: 128)
+ TTFT_THRESHOLD_MS Max TTFT in ms (default: 5000)
+ THROUGHPUT_THRESHOLD Min throughput in tok/s (default: 5.0)
+
+Prerequisites:
+ - trn2.3xlarge or larger with TP >= 4 NeuronCores available
+ - NXDI installed (neuronx_distributed_inference)
+ - HuggingFace weights downloaded to QWEN35_MODEL_PATH
+ - SDK 2.29+ (NKI 0.3.0 required for DeltaNet kernels)
+
+Usage:
+ # Full model (trn2.3xlarge, TP=4):
+ QWEN35_MODEL_PATH=/mnt/models/Qwen3.5-2B \\
+ QWEN35_COMPILED_PATH=/mnt/models/qwen35_2b_traced \\
+ pytest test/integration/test_model.py --capture=tee-sys
+"""
+
+import gc
+import os
+import sys
+import time
+
+import pytest
+import torch
+
+# Ensure the contrib root (Qwen3.5-2B/) is on sys.path
+_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+if _CONTRIB_ROOT not in sys.path:
+ sys.path.insert(0, _CONTRIB_ROOT)
+
+# ── Configuration from environment ──────────────────────────────────────
+
+MODEL_PATH = os.environ.get("QWEN35_MODEL_PATH", "")
+COMPILED_PATH = os.environ.get("QWEN35_COMPILED_PATH", "/tmp/qwen35_2b_traced")
+CPU_REFERENCE_LOGITS_PATH = os.environ.get("QWEN35_REF_LOGITS_PATH", "")
+LOGIT_COMPILED_PATH = os.environ.get("QWEN35_LOGIT_COMPILED_PATH", "")
+TP_DEGREE = int(os.environ.get("QWEN35_TP_DEGREE", "4"))
+SEQ_LEN = int(os.environ.get("QWEN35_SEQ_LEN", "128"))
+TTFT_THRESHOLD_MS = float(os.environ.get("TTFT_THRESHOLD_MS", "5000"))
+THROUGHPUT_THRESHOLD = float(os.environ.get("THROUGHPUT_THRESHOLD", "5.0"))
+
+requires_model_path = pytest.mark.skipif(
+ not MODEL_PATH,
+ reason=(
+ "QWEN35_MODEL_PATH not set. Integration tests require the full 2B model "
+ "weights. Set QWEN35_MODEL_PATH=/path/to/Qwen3.5-2B to run these tests."
+ ),
+)
+
+
+# ── Fixtures ────────────────────────────────────────────────────────────
+
+
+@pytest.fixture(scope="module")
+def model_path():
+ """Return path to model weights."""
+ return MODEL_PATH
+
+
+@pytest.fixture(scope="module")
+def compiled_model(model_path):
+ """Compile and load the model on Neuron."""
+ import json
+
+ from neuronx_distributed_inference.models.config import (
+ NeuronConfig,
+ OnDeviceSamplingConfig,
+ )
+ from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM
+
+ neuron_config = NeuronConfig(
+ tp_degree=TP_DEGREE,
+ batch_size=1,
+ ctx_batch_size=1,
+ tkg_batch_size=1,
+ seq_len=SEQ_LEN,
+ torch_dtype=torch.bfloat16,
+ on_device_sampling_config=OnDeviceSamplingConfig(top_k=1),
+ enable_bucketing=False,
+ flash_decoding_enabled=False,
+ logical_nc_config=2,
+ save_sharded_checkpoint=True,
+ )
+
+ # Read config.json directly (model_type 'qwen3_5' may not be in
+ # AutoConfig registry for all transformers versions)
+ with open(os.path.join(model_path, "config.json")) as f:
+ full_config = json.load(f)
+ text_config = full_config.get("text_config", full_config)
+
+ config_dict = dict(text_config)
+ config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044)
+ if "rope_parameters" in text_config:
+ config_dict["rope_theta"] = text_config["rope_parameters"].get(
+ "rope_theta", 10000000
+ )
+ config_dict.setdefault("tie_word_embeddings", True)
+
+ inf_config = Qwen35InferenceConfig(
+ neuron_config=neuron_config,
+ **config_dict,
+ )
+
+ # Compile if no existing artifacts
+ compiled_path = COMPILED_PATH
+ neff_path = os.path.join(compiled_path, "model.pt")
+ if not os.path.exists(neff_path):
+ print(f"Compiling to {compiled_path}...")
+ model = NeuronQwen35ForCausalLM(model_path, inf_config)
+ model.compile(compiled_path)
+ del model
+ gc.collect()
+
+ # Load
+ print(f"Loading from {compiled_path}...")
+ model = NeuronQwen35ForCausalLM(compiled_path)
+ model.load(compiled_path)
+ return model
+
+
+@pytest.fixture(scope="module")
+def tokenizer(model_path):
+ """Load tokenizer."""
+ from transformers import AutoTokenizer
+
+ tok = AutoTokenizer.from_pretrained(model_path, padding_side="right")
+ if tok.pad_token is None:
+ tok.pad_token = tok.eos_token
+ return tok
+
+
+@pytest.fixture(scope="module")
+def generation_config(tokenizer):
+ """Create generation config."""
+ from transformers import GenerationConfig
+
+ return GenerationConfig(
+ do_sample=True,
+ top_k=1,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ )
+
+
+def _generate(model, tokenizer, generation_config, prompt, max_new_tokens=20):
+ """Generate text using the NXDI model (raw text prompt)."""
+ from neuronx_distributed_inference.utils.hf_adapter import (
+ HuggingFaceGenerationAdapter,
+ )
+
+ inputs = tokenizer(prompt, padding=True, return_tensors="pt")
+ gen_model = HuggingFaceGenerationAdapter(model)
+ outputs = gen_model.generate(
+ inputs.input_ids,
+ generation_config=generation_config,
+ attention_mask=inputs.attention_mask,
+ max_new_tokens=max_new_tokens,
+ )
+ return outputs[0].tolist(), tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+
+def _chat_generate(
+ model, tokenizer, generation_config, user_message, max_new_tokens=50
+):
+ """Generate text using the NXDI model with chat template formatting.
+
+ Qwen3.5-2B is a chat model that expects <|im_start|>/<|im_end|> formatting.
+ Raw text prompts produce echoey output; chat-formatted prompts work correctly.
+ """
+ from neuronx_distributed_inference.utils.hf_adapter import (
+ HuggingFaceGenerationAdapter,
+ )
+
+ messages = [{"role": "user", "content": user_message}]
+ text = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ inputs = tokenizer(text, padding=True, return_tensors="pt")
+ gen_model = HuggingFaceGenerationAdapter(model)
+ outputs = gen_model.generate(
+ inputs.input_ids,
+ generation_config=generation_config,
+ attention_mask=inputs.attention_mask,
+ max_new_tokens=max_new_tokens,
+ )
+ full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+ # Extract just the assistant response (after the user message)
+ # The decoded text includes "user\n{msg}\nassistant\n{response}"
+ if "assistant" in full_text:
+ response = full_text.split("assistant", 1)[-1].strip()
+ else:
+ response = full_text
+ return outputs[0].tolist(), response
+
+
+def _is_repetitive(text, max_repeat=5):
+ """Check for excessive word repetition."""
+ words = text.split()
+ if len(words) < max_repeat:
+ return False
+ for i in range(len(words) - max_repeat + 1):
+ if len(set(words[i : i + max_repeat])) == 1:
+ return True
+ return False
+
+
+# ── Smoke Tests ─────────────────────────────────────────────────────────
+
+
+@requires_model_path
+def test_model_loads(compiled_model):
+ """Model compiles and loads successfully."""
+ assert compiled_model is not None
+ assert hasattr(compiled_model, "neuron_config")
+ print(" Model loaded successfully")
+
+
+@requires_model_path
+def test_model_generates(compiled_model, tokenizer, generation_config):
+ """Model generates at least 5 tokens."""
+ tokens, text = _generate(
+ compiled_model,
+ tokenizer,
+ generation_config,
+ "Hello, I am a language model",
+ max_new_tokens=20,
+ )
+ input_len = len(tokenizer.encode("Hello, I am a language model"))
+ new_tokens = len(tokens) - input_len
+ assert new_tokens >= 5, f"Expected >= 5 new tokens, got {new_tokens}"
+ print(f" Generated {new_tokens} tokens: {text[:100]}...")
+
+
+# ── Accuracy Tests ──────────────────────────────────────────────────────
+
+
+@requires_model_path
+def test_output_coherence(compiled_model, tokenizer, generation_config):
+ """Output should contain multiple words and not be excessively repetitive."""
+ _, response = _chat_generate(
+ compiled_model,
+ tokenizer,
+ generation_config,
+ "What is the capital of France?",
+ max_new_tokens=50,
+ )
+    # Strip <think> tags if present
+    clean = response
+    if "</think>" in clean:
+        clean = clean.split("</think>")[-1].strip()
+ words = clean.split()
+ assert len(words) >= 3, f"Expected >= 3 words, got {len(words)}: '{clean}'"
+ assert not _is_repetitive(clean), f"Output is excessively repetitive: '{clean}'"
+ print(f" Output coherent ({len(words)} words): {clean[:80]}...")
+
+
+@requires_model_path
+def test_top_token_valid(compiled_model, tokenizer, generation_config):
+ """First generated token should be a valid decodable token."""
+ tokens, _ = _chat_generate(
+ compiled_model,
+ tokenizer,
+ generation_config,
+ "Hello!",
+ max_new_tokens=1,
+ )
+ # Chat template adds special tokens, so input_len is the chat-formatted length
+ messages = [{"role": "user", "content": "Hello!"}]
+ chat_text = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ input_len = len(tokenizer.encode(chat_text))
+ first_new = tokens[input_len]
+ assert 0 <= first_new < tokenizer.vocab_size, (
+ f"Token {first_new} out of vocab range"
+ )
+ decoded = tokenizer.decode([first_new])
+ assert len(decoded) > 0, f"Token {first_new} decoded to empty string"
+ print(f" First token: {first_new} -> '{decoded}'")
+
+
+@requires_model_path
+def test_capital_of_france(compiled_model, tokenizer, generation_config):
+ """'What is the capital of France?' should produce 'Paris' in response."""
+ _, response = _chat_generate(
+ compiled_model,
+ tokenizer,
+ generation_config,
+ "What is the capital of France?",
+ max_new_tokens=80,
+ )
+    # Strip <think> tags if present
+    clean = response
+    if "</think>" in clean:
+        clean = clean.split("</think>")[-1].strip()
+ assert "paris" in clean.lower(), f"Expected 'Paris' in output, got: '{clean}'"
+ print(f" Capital of France: {clean[:80]}...")
+
+
+# ── Logit Validation ───────────────────────────────────────────────────
+
+# Qwen3.5-2B uses a hybrid DeltaNet + GQA architecture where 18 of 24 layers
+# are DeltaNet layers using NKI linear recurrent kernels in BF16. This produces
+# numerical divergence from CPU that prevents multi-token logit_validation()
+# (sequences diverge after the first token, making subsequent logit comparisons
+# meaningless). Instead, we validate the first generated token's logits which
+# are computed identically (same input prefix) on both CPU and Neuron.
+#
+# The model outputs TP-sharded logits (vocab_size / tp_degree) because
+# ModelWrapper does not call _gather_along_dim (unlike NeuronBaseModel).
+# Comparisons use the first TP shard (contiguous column split of lm_head).
+FIRST_TOKEN_COSINE_THRESHOLD = 0.85
+FIRST_TOKEN_TOP5_OVERLAP_THRESHOLD = 3 # out of 5
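+
+
+def _make_cpu_reference(model_path, prompt, out_path):
+    """Sketch (not called by the tests): build a CPU reference logits file in
+    the format consumed by test_logit_accuracy. The HF auto classes here are
+    assumptions; older transformers versions may not register 'qwen3_5'."""
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    tok = AutoTokenizer.from_pretrained(model_path)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path, torch_dtype=torch.bfloat16
+    )
+    input_ids = tok(prompt, return_tensors="pt").input_ids
+    with torch.no_grad():
+        # Last-position logits == logits of the first token to be generated
+        logits = model(input_ids).logits[:, -1:, :]
+    torch.save(
+        {"expected_logits": logits, "input_ids": input_ids, "prompt": prompt},
+        out_path,
+    )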
+
+
+@requires_model_path
+def test_logit_accuracy(tokenizer):
+ """Validate first-token logits against pre-computed CPU BF16 reference.
+
+ DeltaNet layers (18 of 24) use NKI linear recurrent kernels that produce
+ higher BF16 numerical divergence than standard GQA. Multi-token
+ logit_validation() is not applicable because autoregressive sequences diverge
+ after the first generated token. This test validates the first-token logits
+ where CPU and Neuron process identical input prefixes.
+
+ Metrics:
+ - Cosine similarity of first-token logit distribution (TP shard 0)
+ - Top-1 token agreement within TP shard 0
+ - Top-5 overlap between CPU and Neuron within TP shard 0
+
+ Requires:
+ - Pre-computed CPU BF16 reference logits at QWEN35_REF_LOGITS_PATH
+ - A model compiled with output_logits=True at QWEN35_LOGIT_COMPILED_PATH
+ """
+ if not CPU_REFERENCE_LOGITS_PATH or not os.path.exists(CPU_REFERENCE_LOGITS_PATH):
+ pytest.skip(
+ "CPU reference logits not found. Set QWEN35_REF_LOGITS_PATH to the "
+ "path of pre-computed CPU reference logits (.pt file)."
+ )
+ if not LOGIT_COMPILED_PATH or not os.path.exists(
+ os.path.join(LOGIT_COMPILED_PATH, "model.pt")
+ ):
+ pytest.skip(
+ "Logit-validation compiled model not found. Set QWEN35_LOGIT_COMPILED_PATH "
+ "to a model compiled with output_logits=True."
+ )
+
+ from transformers import GenerationConfig as HFGenConfig
+ from neuronx_distributed_inference.utils.hf_adapter import (
+ HuggingFaceGenerationAdapter,
+ )
+ from src.modeling_qwen35 import NeuronQwen35ForCausalLM
+
+ # Load the model compiled with output_logits=True
+ print(f" Loading logit-validation model from {LOGIT_COMPILED_PATH}...")
+ logit_model = NeuronQwen35ForCausalLM(LOGIT_COMPILED_PATH)
+ logit_model.load(LOGIT_COMPILED_PATH)
+
+ cpu_ref = torch.load(CPU_REFERENCE_LOGITS_PATH, weights_only=True)
+ cpu_logits = cpu_ref["expected_logits"] # [num_tokens, 1, full_vocab]
+ input_ids = cpu_ref["input_ids"]
+
+ print(f" CPU reference logits shape: {cpu_logits.shape}")
+ print(f" Prompt: '{cpu_ref.get('prompt', 'N/A')}'")
+
+ # Generate on Neuron to capture logits
+ # Request extra tokens because scores include CTE positions
+ # (we only need the first generated token's logits)
+ logit_gen_config = HFGenConfig(
+ do_sample=False,
+ max_new_tokens=16,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ )
+ attention_mask = torch.ones_like(input_ids)
+ gen_model = HuggingFaceGenerationAdapter(logit_model)
+ outputs = gen_model.generate(
+ input_ids,
+ generation_config=logit_gen_config,
+ attention_mask=attention_mask,
+ return_dict_in_generate=True,
+ output_scores=True,
+ )
+
+ neuron_scores = torch.stack(outputs.scores) # [total_steps, 1, tp_vocab]
+ tp_vocab = neuron_scores.shape[-1]
+ full_vocab = cpu_logits.shape[-1]
+ input_len = input_ids.shape[1]
+
+ print(
+ f" Neuron scores: {neuron_scores.shape[0]} steps, "
+ f"TP shard vocab: {tp_vocab}, full vocab: {full_vocab}"
+ )
+
+ # The first generated token's logits are at index input_len
+ # (indices 0..input_len-1 are CTE re-prediction of input tokens)
+ first_gen_idx = min(input_len, neuron_scores.shape[0] - 1)
+ neuron_first = neuron_scores[first_gen_idx, 0, :].float()
+
+ # CPU reference: position 0 = first generated token logits
+ cpu_first = cpu_logits[0, 0, :tp_vocab].float()
+
+ # --- Cosine similarity ---
+ cos_sim = torch.nn.functional.cosine_similarity(
+ cpu_first.unsqueeze(0), neuron_first.unsqueeze(0)
+ ).item()
+ print(
+ f" First-token cosine similarity (TP shard): {cos_sim:.4f} "
+ f"(threshold: {FIRST_TOKEN_COSINE_THRESHOLD})"
+ )
+
+ # --- Top-1 agreement (TP shard) ---
+ cpu_top1 = cpu_first.argmax().item()
+ neuron_top1 = neuron_first.argmax().item()
+ cpu_top1_str = tokenizer.decode([cpu_top1])
+ neuron_top1_str = tokenizer.decode([neuron_top1])
+ top1_match = cpu_top1 == neuron_top1
+ print(
+ f" TP-shard top-1: CPU={cpu_top1} ('{cpu_top1_str}'), "
+ f"Neuron={neuron_top1} ('{neuron_top1_str}'), match={top1_match}"
+ )
+
+ # --- Top-5 overlap ---
+ _, cpu_top5_idx = cpu_first.topk(5)
+ _, neuron_top5_idx = neuron_first.topk(5)
+ cpu_top5_set = set(cpu_top5_idx.tolist())
+ neuron_top5_set = set(neuron_top5_idx.tolist())
+ top5_overlap = len(cpu_top5_set & neuron_top5_set)
+ print(
+ f" Top-5 overlap: {top5_overlap}/5 "
+ f"(threshold: {FIRST_TOKEN_TOP5_OVERLAP_THRESHOLD})"
+ )
+ print(f" CPU top-5: {[tokenizer.decode([t]) for t in cpu_top5_idx.tolist()]}")
+ print(
+ f" Neuron top-5: {[tokenizer.decode([t]) for t in neuron_top5_idx.tolist()]}"
+ )
+
+ # --- Assertions ---
+ assert cos_sim >= FIRST_TOKEN_COSINE_THRESHOLD, (
+ f"First-token cosine similarity {cos_sim:.4f} < {FIRST_TOKEN_COSINE_THRESHOLD}. "
+ f"DeltaNet NKI kernels produce expected BF16 divergence but cosine should "
+ f"remain high for the first token (identical input prefix)."
+ )
+ assert top1_match, (
+ f"First-token top-1 mismatch in TP shard: "
+ f"CPU={cpu_top1} ('{cpu_top1_str}'), "
+ f"Neuron={neuron_top1} ('{neuron_top1_str}')"
+ )
+ assert top5_overlap >= FIRST_TOKEN_TOP5_OVERLAP_THRESHOLD, (
+ f"Top-5 overlap {top5_overlap}/5 < {FIRST_TOKEN_TOP5_OVERLAP_THRESHOLD}. "
+ f"CPU and Neuron top-5 token sets diverge too much."
+ )
+ print(f" PASS: First-token logit accuracy validated")
+
+
+# ── Performance Tests ───────────────────────────────────────────────────
+
+
+@requires_model_path
+def test_performance_ttft(compiled_model, tokenizer, generation_config):
+ """Time to first token should be within threshold."""
+ prompt = "Hello, I am a language model"
+
+ # Warmup
+ _generate(compiled_model, tokenizer, generation_config, prompt, max_new_tokens=1)
+
+ # Measure
+ times = []
+ for _ in range(3):
+ t0 = time.perf_counter()
+ _generate(
+ compiled_model, tokenizer, generation_config, prompt, max_new_tokens=1
+ )
+ times.append((time.perf_counter() - t0) * 1000)
+
+ avg_ms = sum(times) / len(times)
+ print(f" TTFT: {avg_ms:.1f} ms (threshold: {TTFT_THRESHOLD_MS} ms)")
+ assert avg_ms < TTFT_THRESHOLD_MS, (
+ f"TTFT {avg_ms:.1f}ms > threshold {TTFT_THRESHOLD_MS}ms"
+ )
+
+
+@requires_model_path
+def test_performance_throughput(compiled_model, tokenizer, generation_config):
+ """Throughput should meet minimum threshold."""
+ prompt = "Once upon a time"
+ num_new_tokens = 20
+
+ # Warmup
+ _generate(compiled_model, tokenizer, generation_config, prompt, max_new_tokens=5)
+
+ # Measure
+ t0 = time.perf_counter()
+ tokens, _ = _generate(
+ compiled_model,
+ tokenizer,
+ generation_config,
+ prompt,
+ max_new_tokens=num_new_tokens,
+ )
+ elapsed = time.perf_counter() - t0
+
+ input_len = len(tokenizer.encode(prompt))
+ actual_new = len(tokens) - input_len
+ throughput = actual_new / elapsed if elapsed > 0 else 0
+
+ print(
+ f" Throughput: {throughput:.1f} tok/s ({actual_new} tokens in {elapsed:.2f}s)"
+ )
+ print(f" Threshold: {THROUGHPUT_THRESHOLD} tok/s")
+ assert throughput > THROUGHPUT_THRESHOLD, (
+ f"Throughput {throughput:.1f} tok/s < threshold {THROUGHPUT_THRESHOLD}"
+ )
+
+
+# ── Multi-Prompt Quality Test ──────────────────────────────────────────
+
+
+@requires_model_path
+def test_multi_prompt_generation(compiled_model, tokenizer, generation_config):
+ """Multiple chat prompts should produce coherent outputs."""
+ user_messages = [
+ "What is the capital of France?",
+ "Write a Python fibonacci function.",
+ "What is the largest ocean on Earth?",
+ "List two ingredients for a chocolate cake.",
+ ]
+
+ for msg in user_messages:
+ _, response = _chat_generate(
+ compiled_model,
+ tokenizer,
+ generation_config,
+ msg,
+ max_new_tokens=50,
+ )
+        # Strip <think> tags if present
+        clean = response
+        if "</think>" in clean:
+            clean = clean.split("</think>")[-1].strip()
+ words = clean.split()
+ assert len(words) >= 2, f"Message '{msg}' generated too few words: '{clean}'"
+ assert not _is_repetitive(clean), (
+ f"Message '{msg}' produced repetitive output: '{clean}'"
+ )
+ print(f" '{msg[:30]}...' -> {clean[:60]}...")
+
+
+# ── Standalone runner ───────────────────────────────────────────────────
+
+if __name__ == "__main__":
+ print("=" * 60)
+ print("Qwen3.5-2B Integration Tests")
+ print("=" * 60)
+
+ if not MODEL_PATH:
+ print("\nQWEN35_MODEL_PATH not set. Provide the model path to run tests:")
+ print(" QWEN35_MODEL_PATH=/path/to/Qwen3.5-2B \\")
+ print(" QWEN35_COMPILED_PATH=/mnt/models/qwen35_2b_traced \\")
+ print(" python -m pytest test/integration/test_model.py --capture=tee-sys")
+ sys.exit(0)
+
+ # Setup
+ from transformers import AutoTokenizer, GenerationConfig as GenConfig
+
+ tok = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right")
+ if tok.pad_token is None:
+ tok.pad_token = tok.eos_token
+ gen_cfg = GenConfig(
+ do_sample=True,
+ top_k=1,
+ pad_token_id=tok.pad_token_id,
+ eos_token_id=tok.eos_token_id,
+ )
+
+ # Build model
+ import json
+
+ from neuronx_distributed_inference.models.config import (
+ NeuronConfig,
+ OnDeviceSamplingConfig,
+ )
+ from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM
+
+ nc = NeuronConfig(
+ tp_degree=TP_DEGREE,
+ batch_size=1,
+ ctx_batch_size=1,
+ tkg_batch_size=1,
+ seq_len=SEQ_LEN,
+ torch_dtype=torch.bfloat16,
+ on_device_sampling_config=OnDeviceSamplingConfig(top_k=1),
+ enable_bucketing=False,
+ flash_decoding_enabled=False,
+ logical_nc_config=2,
+ save_sharded_checkpoint=True,
+ )
+
+ with open(os.path.join(MODEL_PATH, "config.json")) as f:
+ full_config = json.load(f)
+ text_config = full_config.get("text_config", full_config)
+ config_dict = dict(text_config)
+ config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044)
+ if "rope_parameters" in text_config:
+ config_dict["rope_theta"] = text_config["rope_parameters"].get(
+ "rope_theta", 10000000
+ )
+ config_dict.setdefault("tie_word_embeddings", True)
+ ic = Qwen35InferenceConfig(neuron_config=nc, **config_dict)
+
+ cp = COMPILED_PATH
+ if not os.path.exists(os.path.join(cp, "model.pt")):
+ print(f"Compiling to {cp}...")
+ m = NeuronQwen35ForCausalLM(MODEL_PATH, ic)
+ m.compile(cp)
+ del m
+ gc.collect()
+
+ print(f"Loading from {cp}...")
+ model = NeuronQwen35ForCausalLM(cp)
+ model.load(cp)
+
+ tests = [
+ ("model_loads", lambda: test_model_loads(model)),
+ ("model_generates", lambda: test_model_generates(model, tok, gen_cfg)),
+ ("output_coherence", lambda: test_output_coherence(model, tok, gen_cfg)),
+ ("top_token_valid", lambda: test_top_token_valid(model, tok, gen_cfg)),
+ ("capital_of_france", lambda: test_capital_of_france(model, tok, gen_cfg)),
+ ("logit_accuracy", lambda: test_logit_accuracy(tok)),
+ ("performance_ttft", lambda: test_performance_ttft(model, tok, gen_cfg)),
+ (
+ "performance_throughput",
+ lambda: test_performance_throughput(model, tok, gen_cfg),
+ ),
+ (
+ "multi_prompt_generation",
+ lambda: test_multi_prompt_generation(model, tok, gen_cfg),
+ ),
+ ]
+
+ passed = 0
+ for name, fn in tests:
+ print(f"\n--- {name} ---")
+ try:
+ fn()
+ print(f" PASS")
+ passed += 1
+ except Exception as e:
+ print(f" FAIL: {e}")
+
+ print(f"\n{'=' * 60}")
+ print(f"Results: {passed}/{len(tests)} passed")
+ print(f"{'=' * 60}")
diff --git a/contrib/models/Qwen3.5-2B/test/unit/__init__.py b/contrib/models/Qwen3.5-2B/test/unit/__init__.py
new file mode 100644
index 00000000..04f8b7b7
--- /dev/null
+++ b/contrib/models/Qwen3.5-2B/test/unit/__init__.py
@@ -0,0 +1,2 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
diff --git a/contrib/models/Qwen3.5-2B/test/unit/test_config.py b/contrib/models/Qwen3.5-2B/test/unit/test_config.py
new file mode 100644
index 00000000..3b0089d0
--- /dev/null
+++ b/contrib/models/Qwen3.5-2B/test/unit/test_config.py
@@ -0,0 +1,201 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for Qwen3.5-2B inference configuration.
+
+CPU-only tests that validate config parsing, layer type setup,
+DeltaNet parameter defaults, RoPE configuration, and Neuron config selection.
+"""
+
+import os
+import sys
+import unittest
+
+import torch
+
+# Ensure the contrib root (Qwen3.5-2B/) is on sys.path
+_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+if _CONTRIB_ROOT not in sys.path:
+ sys.path.insert(0, _CONTRIB_ROOT)
+
+from src.modeling_qwen35 import Qwen35InferenceConfig
+from neuronx_distributed_inference.models.config import NeuronConfig
+
+
+def _make_config(**overrides):
+ """Create a Qwen35InferenceConfig with reasonable defaults."""
+ neuron_config = NeuronConfig(
+ tp_degree=overrides.pop("tp_degree", 4),
+ batch_size=1,
+ seq_len=128,
+ torch_dtype=torch.bfloat16,
+ )
+ defaults = dict(
+ hidden_size=2048,
+ num_hidden_layers=24,
+ num_attention_heads=8,
+ num_key_value_heads=2,
+ head_dim=256,
+ intermediate_size=6144,
+ vocab_size=248320,
+ rms_norm_eps=1e-6,
+ max_position_embeddings=262144,
+ rope_theta=10000000,
+ hidden_act="silu",
+ tie_word_embeddings=True,
+ # DeltaNet-specific
+ linear_num_value_heads=16,
+ linear_num_key_heads=16,
+ linear_key_head_dim=128,
+ linear_value_head_dim=128,
+ linear_conv_kernel_dim=4,
+ )
+ defaults.update(overrides)
+ config = Qwen35InferenceConfig(neuron_config=neuron_config, **defaults)
+ return config
+
+
+class TestConfigParsing(unittest.TestCase):
+ """Test basic config attribute initialization."""
+
+ def test_hidden_size(self):
+ config = _make_config()
+ self.assertEqual(config.hidden_size, 2048)
+
+ def test_num_hidden_layers(self):
+ config = _make_config()
+ self.assertEqual(config.num_hidden_layers, 24)
+
+ def test_num_attention_heads(self):
+ config = _make_config()
+ self.assertEqual(config.num_attention_heads, 8)
+
+ def test_num_key_value_heads(self):
+ config = _make_config()
+ self.assertEqual(config.num_key_value_heads, 2)
+
+ def test_head_dim(self):
+ config = _make_config()
+ self.assertEqual(config.head_dim, 256)
+
+ def test_intermediate_size(self):
+ config = _make_config()
+ self.assertEqual(config.intermediate_size, 6144)
+
+ def test_vocab_size(self):
+ config = _make_config()
+ self.assertEqual(config.vocab_size, 248320)
+
+ def test_hidden_act(self):
+ config = _make_config()
+ self.assertEqual(config.hidden_act, "silu")
+
+
+class TestLayerTypes(unittest.TestCase):
+ """Test hybrid layer type assignment (3 DeltaNet + 1 GQA) x 6."""
+
+ def test_layer_types_length(self):
+ config = _make_config()
+ self.assertEqual(len(config.layer_types), 24)
+
+ def test_layer_types_pattern(self):
+ """Every 4th layer (3, 7, 11, ...) should be full_attention."""
+ config = _make_config()
+ for i in range(24):
+ expected = "full_attention" if i % 4 == 3 else "linear_attention"
+ self.assertEqual(config.layer_types[i], expected, f"Layer {i} mismatch")
+
+ def test_deltanet_layer_count(self):
+ config = _make_config()
+ dn_count = sum(1 for t in config.layer_types if t == "linear_attention")
+ self.assertEqual(dn_count, 18)
+
+ def test_gqa_layer_count(self):
+ config = _make_config()
+ gqa_count = sum(1 for t in config.layer_types if t == "full_attention")
+ self.assertEqual(gqa_count, 6)
+
+
+class TestDeltaNetConfig(unittest.TestCase):
+ """Test DeltaNet-specific configuration defaults."""
+
+ def test_linear_num_value_heads(self):
+ config = _make_config()
+ self.assertEqual(config.linear_num_value_heads, 16)
+
+ def test_linear_num_key_heads(self):
+ config = _make_config()
+ self.assertEqual(config.linear_num_key_heads, 16)
+
+ def test_linear_key_head_dim(self):
+ config = _make_config()
+ self.assertEqual(config.linear_key_head_dim, 128)
+
+ def test_linear_value_head_dim(self):
+ config = _make_config()
+ self.assertEqual(config.linear_value_head_dim, 128)
+
+ def test_linear_conv_kernel_dim(self):
+ config = _make_config()
+ self.assertEqual(config.linear_conv_kernel_dim, 4)
+
+
+class TestRoPEConfig(unittest.TestCase):
+ """Test partial RoPE configuration."""
+
+ def test_partial_rotary_factor(self):
+ config = _make_config()
+ self.assertAlmostEqual(config.partial_rotary_factor, 0.25)
+
+ def test_rope_dim(self):
+ """rope_dim = head_dim * partial_rotary_factor = 256 * 0.25 = 64."""
+ config = _make_config()
+ self.assertEqual(config.rope_dim, 64)
+
+ def test_attn_output_gate(self):
+ config = _make_config()
+ self.assertTrue(config.attn_output_gate)
+
+ def test_mrope_section(self):
+ config = _make_config()
+ self.assertEqual(config.mrope_section, [11, 11, 10])
+
+ def test_mrope_interleaved(self):
+ config = _make_config()
+ self.assertTrue(config.mrope_interleaved)
+
+
+class TestNeuronConfig(unittest.TestCase):
+ """Test Neuron-specific configuration settings."""
+
+ def test_neuron_config_cls(self):
+ """Qwen3.5-2B is dense -- uses NeuronConfig, NOT MoENeuronConfig."""
+ self.assertEqual(
+ Qwen35InferenceConfig.get_neuron_config_cls(),
+ NeuronConfig,
+ )
+
+ def test_required_attributes(self):
+ config = _make_config()
+ required = config.get_required_attributes()
+ self.assertIn("hidden_size", required)
+ self.assertIn("num_hidden_layers", required)
+ self.assertIn("linear_num_value_heads", required)
+ self.assertIn("linear_key_head_dim", required)
+ self.assertIn("layer_types", required)
+
+ def test_output_attentions_default(self):
+ config = _make_config()
+ self.assertFalse(config.output_attentions)
+
+ def test_output_hidden_states_default(self):
+ config = _make_config()
+ self.assertFalse(config.output_hidden_states)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/contrib/models/Qwen3.5-2B/test/unit/test_weight_conversion.py b/contrib/models/Qwen3.5-2B/test/unit/test_weight_conversion.py
new file mode 100644
index 00000000..edc7753d
--- /dev/null
+++ b/contrib/models/Qwen3.5-2B/test/unit/test_weight_conversion.py
@@ -0,0 +1,434 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for Qwen3.5-2B HF-to-NxDI weight conversion.
+
+CPU-only tests that validate:
+- RMSNorm (+1 convention) weight conversion
+- GQA q_proj interleaved split (query + gate)
+- QK norm key renaming (q_norm -> q_layernorm, k_norm -> k_layernorm)
+- Fused QKV concatenation
+- DeltaNet layer weights pass through unchanged
+- VL wrapper prefix stripping
+- rank_util injection
+"""
+
+import os
+import sys
+import unittest
+
+import torch
+
+_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+if _CONTRIB_ROOT not in sys.path:
+ sys.path.insert(0, _CONTRIB_ROOT)
+
+from src.modeling_qwen35 import (
+ Qwen35InferenceConfig,
+ NeuronQwen35ForCausalLM,
+ convert_qwen35_hf_to_neuron_state_dict,
+)
+from neuronx_distributed_inference.models.config import NeuronConfig
+
+
+def _make_mini_config(num_layers=4, tp_degree=2, fused_qkv=True):
+ """Create a small Qwen35InferenceConfig for testing."""
+ neuron_config = NeuronConfig(
+ tp_degree=tp_degree,
+ batch_size=1,
+ seq_len=128,
+ torch_dtype=torch.bfloat16,
+ fused_qkv=fused_qkv,
+ )
+ config = Qwen35InferenceConfig(
+ neuron_config=neuron_config,
+ hidden_size=256,
+ num_hidden_layers=num_layers,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ head_dim=64,
+ intermediate_size=512,
+ vocab_size=1000,
+ rms_norm_eps=1e-6,
+ max_position_embeddings=4096,
+ rope_theta=10000,
+ hidden_act="silu",
+ linear_num_value_heads=8,
+ linear_num_key_heads=4,
+ linear_key_head_dim=32,
+ linear_value_head_dim=32,
+ linear_conv_kernel_dim=4,
+ )
+ return config
+
+
+def _make_mini_state_dict(config):
+ """Create a minimal HF-style state dict for conversion testing."""
+ sd = {}
+ H = config.hidden_size # 256
+ I = config.intermediate_size # 512
+ V = config.vocab_size # 1000
+ num_heads = config.num_attention_heads # 4
+ num_kv = config.num_key_value_heads # 2
+ head_dim = config.head_dim # 64
+
+ sd["embed_tokens.weight"] = torch.randn(V, H, dtype=torch.bfloat16) * 0.02
+ sd["lm_head.weight"] = torch.randn(V, H, dtype=torch.bfloat16) * 0.02
+ sd["norm.weight"] = torch.zeros(H, dtype=torch.bfloat16) # +1 convention: zeros
+
+ for l in range(config.num_hidden_layers):
+ sd[f"layers.{l}.input_layernorm.weight"] = torch.zeros(H, dtype=torch.bfloat16)
+ sd[f"layers.{l}.post_attention_layernorm.weight"] = torch.zeros(
+ H, dtype=torch.bfloat16
+ )
+
+ # Dense MLP (all layers)
+ sd[f"layers.{l}.mlp.gate_proj.weight"] = (
+ torch.randn(I, H, dtype=torch.bfloat16) * 0.02
+ )
+ sd[f"layers.{l}.mlp.up_proj.weight"] = (
+ torch.randn(I, H, dtype=torch.bfloat16) * 0.02
+ )
+ sd[f"layers.{l}.mlp.down_proj.weight"] = (
+ torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+ )
+
+ if config.layer_types[l] == "full_attention":
+ # GQA layer: q_proj is interleaved [head0_q | head0_gate | head1_q | ...]
+ q_proj = (
+ torch.randn(num_heads * head_dim * 2, H, dtype=torch.bfloat16) * 0.02
+ )
+ sd[f"layers.{l}.self_attn.q_proj.weight"] = q_proj
+ sd[f"layers.{l}.self_attn.k_proj.weight"] = (
+ torch.randn(num_kv * head_dim, H, dtype=torch.bfloat16) * 0.02
+ )
+ sd[f"layers.{l}.self_attn.v_proj.weight"] = (
+ torch.randn(num_kv * head_dim, H, dtype=torch.bfloat16) * 0.02
+ )
+ sd[f"layers.{l}.self_attn.o_proj.weight"] = (
+ torch.randn(H, num_heads * head_dim, dtype=torch.bfloat16) * 0.02
+ )
+ sd[f"layers.{l}.self_attn.q_norm.weight"] = torch.zeros(
+ head_dim, dtype=torch.bfloat16
+ )
+ sd[f"layers.{l}.self_attn.k_norm.weight"] = torch.zeros(
+ head_dim, dtype=torch.bfloat16
+ )
+ else:
+ # DeltaNet layer: minimal required weights
+ key_dim = config.linear_num_key_heads * config.linear_key_head_dim # 128
+ value_dim = (
+ config.linear_num_value_heads * config.linear_value_head_dim
+ ) # 256
+ conv_dim = key_dim * 2 + value_dim # 512
+ sd[f"layers.{l}.linear_attn.in_proj_qkv.weight"] = (
+ torch.randn(conv_dim, H, dtype=torch.bfloat16) * 0.02
+ )
+ sd[f"layers.{l}.linear_attn.in_proj_z.weight"] = (
+ torch.randn(value_dim, H, dtype=torch.bfloat16) * 0.02
+ )
+ sd[f"layers.{l}.linear_attn.in_proj_a.weight"] = (
+ torch.randn(config.linear_num_value_heads, H, dtype=torch.bfloat16)
+ * 0.02
+ )
+ sd[f"layers.{l}.linear_attn.in_proj_b.weight"] = (
+ torch.randn(config.linear_num_value_heads, H, dtype=torch.bfloat16)
+ * 0.02
+ )
+ sd[f"layers.{l}.linear_attn.conv1d.weight"] = (
+ torch.randn(
+ conv_dim, 1, config.linear_conv_kernel_dim, dtype=torch.bfloat16
+ )
+ * 0.02
+ )
+ sd[f"layers.{l}.linear_attn.A_log"] = torch.randn(
+ config.linear_num_value_heads, dtype=torch.bfloat16
+ )
+ sd[f"layers.{l}.linear_attn.dt_bias"] = torch.randn(
+ config.linear_num_value_heads, dtype=torch.bfloat16
+ )
+ sd[f"layers.{l}.linear_attn.norm.weight"] = (
+ torch.randn(value_dim, dtype=torch.bfloat16) * 0.5
+ )
+ sd[f"layers.{l}.linear_attn.out_proj.weight"] = (
+ torch.randn(H, value_dim, dtype=torch.bfloat16) * 0.02
+ )
+
+ return sd
+
+
+class TestNormConversion(unittest.TestCase):
+ """Test (+1 convention) RMSNorm weight conversion."""
+
+ def test_norm_weight_adds_one(self):
+ """Weights initialized to zero should become 1.0 after conversion."""
+ config = _make_mini_config()
+ sd = _make_mini_state_dict(config)
+ result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+ # norm.weight was zeros -> should now be ones
+ torch.testing.assert_close(
+ result["norm.weight"],
+ torch.ones_like(result["norm.weight"]),
+ )
+
+ def test_input_layernorm_adds_one(self):
+ config = _make_mini_config()
+ sd = _make_mini_state_dict(config)
+ result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+ for l in range(config.num_hidden_layers):
+ w = result[f"layers.{l}.input_layernorm.weight"]
+ self.assertTrue(
+ torch.allclose(w, torch.ones_like(w)),
+ f"Layer {l} input_layernorm not converted",
+ )
+
+ def test_post_attn_layernorm_adds_one(self):
+ config = _make_mini_config()
+ sd = _make_mini_state_dict(config)
+ result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+ for l in range(config.num_hidden_layers):
+ w = result[f"layers.{l}.post_attention_layernorm.weight"]
+ self.assertTrue(
+ torch.allclose(w, torch.ones_like(w)),
+ f"Layer {l} post_attention_layernorm not converted",
+ )
+
+ def test_qk_norm_adds_one(self):
+ """Q/K norms on GQA layers should also get +1 applied."""
+ config = _make_mini_config()
+ sd = _make_mini_state_dict(config)
+ result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+ for l in range(config.num_hidden_layers):
+ if config.layer_types[l] == "full_attention":
+ q_w = result[f"layers.{l}.self_attn.q_layernorm.weight"]
+ k_w = result[f"layers.{l}.self_attn.k_layernorm.weight"]
+ self.assertTrue(
+ torch.allclose(q_w, torch.ones_like(q_w)),
+ f"Layer {l} q_layernorm not converted",
+ )
+ self.assertTrue(
+ torch.allclose(k_w, torch.ones_like(k_w)),
+ f"Layer {l} k_layernorm not converted",
+ )
+
+
+class TestQProjSplit(unittest.TestCase):
+ """Test q_proj interleaved split into query + gate."""
+
+ def test_q_proj_split_shapes(self):
+ """q_proj (num_heads * head_dim * 2, H) -> separate query and gate."""
+ config = _make_mini_config(fused_qkv=False)
+ sd = _make_mini_state_dict(config)
+ result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+
+ for l in range(config.num_hidden_layers):
+ if config.layer_types[l] == "full_attention":
+ # After split: q_proj should be (num_heads * head_dim, H) = (256, 256)
+ q_w = result[f"layers.{l}.self_attn.q_proj.weight"]
+ gate_w = result[f"layers.{l}.self_attn.output_gate_proj.weight"]
+ expected_shape = (
+ config.num_attention_heads * config.head_dim,
+ config.hidden_size,
+ )
+ self.assertEqual(
+ q_w.shape, expected_shape, f"Layer {l} q_proj shape wrong"
+ )
+ self.assertEqual(
+ gate_w.shape, expected_shape, f"Layer {l} gate shape wrong"
+ )
+
+ def test_q_proj_deinterleave_correct(self):
+ """Verify the interleaved split correctly separates query and gate."""
+ config = _make_mini_config(fused_qkv=False)
+ sd = _make_mini_state_dict(config)
+
+ # Create a known pattern: head h query rows are filled with (h + 1),
+ # gate rows with (h + 100).
+ l = 3 # First full_attention layer (layer 3)
+ num_heads = config.num_attention_heads
+ head_dim = config.head_dim
+ H = config.hidden_size
+
+ interleaved = torch.zeros(num_heads * head_dim * 2, H, dtype=torch.bfloat16)
+ for h in range(num_heads):
+ interleaved[h * head_dim * 2 : h * head_dim * 2 + head_dim, :] = float(
+ h + 1
+ ) # query
+ interleaved[h * head_dim * 2 + head_dim : (h + 1) * head_dim * 2, :] = (
+ float(h + 100)
+ ) # gate
+
+ sd[f"layers.{l}.self_attn.q_proj.weight"] = interleaved
+ result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+
+ q_w = result[f"layers.{l}.self_attn.q_proj.weight"]
+ gate_w = result[f"layers.{l}.self_attn.output_gate_proj.weight"]
+
+ for h in range(num_heads):
+ q_head = q_w[h * head_dim : (h + 1) * head_dim, :]
+ gate_head = gate_w[h * head_dim : (h + 1) * head_dim, :]
+ self.assertTrue(
+ torch.all(q_head == float(h + 1)), f"Head {h} query values wrong"
+ )
+ self.assertTrue(
+ torch.all(gate_head == float(h + 100)), f"Head {h} gate values wrong"
+ )
+
+
+class TestQKNormRename(unittest.TestCase):
+ """Test q_norm -> q_layernorm and k_norm -> k_layernorm renaming."""
+
+ def test_old_keys_removed(self):
+ config = _make_mini_config()
+ sd = _make_mini_state_dict(config)
+ result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+ for l in range(config.num_hidden_layers):
+ if config.layer_types[l] == "full_attention":
+ self.assertNotIn(f"layers.{l}.self_attn.q_norm.weight", result)
+ self.assertNotIn(f"layers.{l}.self_attn.k_norm.weight", result)
+
+ def test_new_keys_present(self):
+ config = _make_mini_config()
+ sd = _make_mini_state_dict(config)
+ result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+ for l in range(config.num_hidden_layers):
+ if config.layer_types[l] == "full_attention":
+ self.assertIn(f"layers.{l}.self_attn.q_layernorm.weight", result)
+ self.assertIn(f"layers.{l}.self_attn.k_layernorm.weight", result)
+
+
+class TestFusedQKV(unittest.TestCase):
+ """Test fused QKV concatenation for attention layers."""
+
+ def test_fused_qkv_shape(self):
+ config = _make_mini_config(fused_qkv=True)
+ sd = _make_mini_state_dict(config)
+ result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+
+ for l in range(config.num_hidden_layers):
+ if config.layer_types[l] == "full_attention":
+ fused_key = f"layers.{l}.self_attn.Wqkv.weight"
+ self.assertIn(fused_key, result, f"Layer {l} missing Wqkv")
+
+ q_dim = config.num_attention_heads * config.head_dim
+ k_dim = config.num_key_value_heads * config.head_dim
+ v_dim = config.num_key_value_heads * config.head_dim
+ expected_rows = q_dim + k_dim + v_dim
+ self.assertEqual(result[fused_key].shape[0], expected_rows)
+
+ def test_fused_qkv_removes_individual_keys(self):
+ config = _make_mini_config(fused_qkv=True)
+ sd = _make_mini_state_dict(config)
+ result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+
+ for l in range(config.num_hidden_layers):
+ if config.layer_types[l] == "full_attention":
+ self.assertNotIn(f"layers.{l}.self_attn.q_proj.weight", result)
+ self.assertNotIn(f"layers.{l}.self_attn.k_proj.weight", result)
+ self.assertNotIn(f"layers.{l}.self_attn.v_proj.weight", result)
+
+
+class TestDeltaNetPassthrough(unittest.TestCase):
+ """Test that DeltaNet layer weights pass through conversion unchanged."""
+
+ def test_deltanet_weights_unchanged(self):
+ config = _make_mini_config()
+ sd = _make_mini_state_dict(config)
+
+ # Record original DeltaNet weights
+ originals = {}
+ for l in range(config.num_hidden_layers):
+ if config.layer_types[l] == "linear_attention":
+ key = f"layers.{l}.linear_attn.in_proj_qkv.weight"
+ originals[key] = sd[key].clone()
+
+ result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+
+ for key, orig in originals.items():
+ self.assertIn(key, result, f"Missing: {key}")
+ torch.testing.assert_close(
+ result[key], orig, msg=f"DeltaNet weight changed: {key}"
+ )
+
+ def test_deltanet_norm_not_converted(self):
+ """DeltaNet layers use standard RMSNorm (NOT +1 convention).
+ The norm weight should NOT be changed."""
+ config = _make_mini_config()
+ sd = _make_mini_state_dict(config)
+
+ # Set DeltaNet norm to a known non-zero value
+ for l in range(config.num_hidden_layers):
+ if config.layer_types[l] == "linear_attention":
+ sd[f"layers.{l}.linear_attn.norm.weight"] = torch.full(
+ (config.linear_num_value_heads * config.linear_value_head_dim,),
+ 0.87,
+ dtype=torch.bfloat16,
+ )
+
+ result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+
+ for l in range(config.num_hidden_layers):
+ if config.layer_types[l] == "linear_attention":
+ w = result[f"layers.{l}.linear_attn.norm.weight"]
+ # Should still be ~0.87, NOT 1.87
+ self.assertTrue(
+ torch.allclose(w, torch.full_like(w, 0.87), atol=0.01),
+ f"Layer {l} DeltaNet norm was incorrectly modified",
+ )
+
+
+class TestRankUtil(unittest.TestCase):
+ """Test rank_util tensor injection."""
+
+ def test_rank_util_present(self):
+ config = _make_mini_config(tp_degree=4)
+ sd = _make_mini_state_dict(config)
+ result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+ self.assertIn("rank_util.rank", result)
+ expected = torch.arange(0, 4, dtype=torch.int32)
+ torch.testing.assert_close(result["rank_util.rank"], expected)
+
+ def test_gqa_layer_rank_util(self):
+ config = _make_mini_config(tp_degree=4)
+ sd = _make_mini_state_dict(config)
+ result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+ for l in range(config.num_hidden_layers):
+ if config.layer_types[l] == "full_attention":
+ key = f"layers.{l}.self_attn.rank_util.rank"
+ self.assertIn(key, result)
+ expected = torch.arange(0, 4, dtype=torch.int32)
+ torch.testing.assert_close(result[key], expected)
+
+
+class TestVLPrefixStripping(unittest.TestCase):
+ """Test VL wrapper prefix stripping in convert_hf_to_neuron_state_dict."""
+
+ def test_language_model_prefix_stripped(self):
+ config = _make_mini_config()
+ sd = _make_mini_state_dict(config)
+
+ # Wrap with VL prefix
+ vl_sd = {}
+ for k, v in sd.items():
+ vl_sd[f"language_model.{k}"] = v
+ vl_sd["visual.encoder.weight"] = torch.zeros(10) # should be skipped
+ vl_sd["mtp.something"] = torch.zeros(5) # should be skipped
+
+ result = NeuronQwen35ForCausalLM.convert_hf_to_neuron_state_dict(vl_sd, config)
+ self.assertNotIn("visual.encoder.weight", result)
+ self.assertNotIn("mtp.something", result)
+ self.assertIn("norm.weight", result)
+
+ def test_model_language_model_prefix_stripped(self):
+ config = _make_mini_config()
+ sd = _make_mini_state_dict(config)
+
+ vl_sd = {}
+ for k, v in sd.items():
+ vl_sd[f"model.language_model.{k}"] = v
+
+ result = NeuronQwen35ForCausalLM.convert_hf_to_neuron_state_dict(vl_sd, config)
+ self.assertIn("norm.weight", result)
+
+
+if __name__ == "__main__":
+ unittest.main()