From 9d038e9e8339fe59858972ea654ce44035e5d9d0 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Thu, 30 Apr 2026 13:34:42 +0530 Subject: [PATCH 1/5] Add Qwen3.5 4B and 9B contrib models Adds Qwen3.5 dense hybrid DeltaNet/GQA contrib model variants for 4B and 9B, including NKI DeltaNet kernels, weight conversion tests, and Trainium integration tests. This builds on Jim Burtoft's Qwen3.5/Qwen3.6 contrib work in PR #141 and PR #140; his dummy-KV plus side-channel DeltaNet state pattern is the baseline used here. --- contrib/models/Qwen3.5-4B/README.md | 71 + contrib/models/Qwen3.5-4B/src/__init__.py | 27 + .../models/Qwen3.5-4B/src/modeling_qwen35.py | 2533 +++++++++++++++++ .../Qwen3.5-4B/src/nki_kernels/__init__.py | 10 + .../src/nki_kernels/nki_deltanet.py | 337 +++ .../src/nki_kernels/nki_deltanet_chunked.py | 323 +++ .../src/nki_kernels/nki_deltanet_fused.py | 577 ++++ contrib/models/Qwen3.5-4B/test/__init__.py | 2 + .../Qwen3.5-4B/test/integration/__init__.py | 2 + .../Qwen3.5-4B/test/integration/test_model.py | 482 ++++ .../test/parity/deltanet_path_probe.py | 206 ++ .../models/Qwen3.5-4B/test/unit/__init__.py | 2 + .../Qwen3.5-4B/test/unit/test_config.py | 201 ++ .../test/unit/test_weight_conversion.py | 445 +++ contrib/models/Qwen3.5-9B/README.md | 156 + contrib/models/Qwen3.5-9B/src/__init__.py | 27 + .../models/Qwen3.5-9B/src/modeling_qwen35.py | 2528 ++++++++++++++++ .../Qwen3.5-9B/src/nki_kernels/__init__.py | 10 + .../src/nki_kernels/nki_deltanet.py | 337 +++ .../src/nki_kernels/nki_deltanet_chunked.py | 323 +++ .../src/nki_kernels/nki_deltanet_fused.py | 577 ++++ contrib/models/Qwen3.5-9B/test/__init__.py | 2 + .../Qwen3.5-9B/test/integration/__init__.py | 2 + .../Qwen3.5-9B/test/integration/test_model.py | 486 ++++ .../test/parity/deltanet_path_probe.py | 206 ++ .../models/Qwen3.5-9B/test/unit/__init__.py | 2 + .../Qwen3.5-9B/test/unit/test_config.py | 201 ++ .../test/unit/test_weight_conversion.py | 434 +++ 28 files changed, 10509 insertions(+) create mode 100644 contrib/models/Qwen3.5-4B/README.md create mode 100644 contrib/models/Qwen3.5-4B/src/__init__.py create mode 100644 contrib/models/Qwen3.5-4B/src/modeling_qwen35.py create mode 100644 contrib/models/Qwen3.5-4B/src/nki_kernels/__init__.py create mode 100644 contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet.py create mode 100644 contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_chunked.py create mode 100644 contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_fused.py create mode 100644 contrib/models/Qwen3.5-4B/test/__init__.py create mode 100644 contrib/models/Qwen3.5-4B/test/integration/__init__.py create mode 100644 contrib/models/Qwen3.5-4B/test/integration/test_model.py create mode 100644 contrib/models/Qwen3.5-4B/test/parity/deltanet_path_probe.py create mode 100644 contrib/models/Qwen3.5-4B/test/unit/__init__.py create mode 100644 contrib/models/Qwen3.5-4B/test/unit/test_config.py create mode 100644 contrib/models/Qwen3.5-4B/test/unit/test_weight_conversion.py create mode 100644 contrib/models/Qwen3.5-9B/README.md create mode 100644 contrib/models/Qwen3.5-9B/src/__init__.py create mode 100644 contrib/models/Qwen3.5-9B/src/modeling_qwen35.py create mode 100644 contrib/models/Qwen3.5-9B/src/nki_kernels/__init__.py create mode 100644 contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet.py create mode 100644 contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_chunked.py create mode 100644 contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_fused.py create mode 100644 
contrib/models/Qwen3.5-9B/test/__init__.py create mode 100644 contrib/models/Qwen3.5-9B/test/integration/__init__.py create mode 100644 contrib/models/Qwen3.5-9B/test/integration/test_model.py create mode 100644 contrib/models/Qwen3.5-9B/test/parity/deltanet_path_probe.py create mode 100644 contrib/models/Qwen3.5-9B/test/unit/__init__.py create mode 100644 contrib/models/Qwen3.5-9B/test/unit/test_config.py create mode 100644 contrib/models/Qwen3.5-9B/test/unit/test_weight_conversion.py diff --git a/contrib/models/Qwen3.5-4B/README.md b/contrib/models/Qwen3.5-4B/README.md new file mode 100644 index 00000000..1b75bee6 --- /dev/null +++ b/contrib/models/Qwen3.5-4B/README.md @@ -0,0 +1,71 @@ +# Contrib Model: Qwen3.5-4B + +NeuronX Distributed Inference implementation of Qwen3.5-4B, a dense hybrid DeltaNet + GQA decoder from Alibaba Cloud. + +This variant reuses the proven PR 140/141 architecture: + +- Standard GQA layers use NxDI `KVCacheManager`. +- DeltaNet layers return dummy KV tensors to satisfy NxDI cache plumbing. +- Real DeltaNet state is carried through layer-local `recurrent_state_buffer` and `conv_state_buffer` side-channel aliases. + +## Model Information + +| Feature | Value | +| --- | --- | +| HuggingFace ID | `Qwen/Qwen3.5-4B` | +| Model type | `qwen3_5_text` under top-level `qwen3_5` | +| Layers | 32: 24 DeltaNet + 8 GQA | +| Layer pattern | `[3 DeltaNet + 1 GQA] x 8` | +| Hidden size | 2560 | +| MLP | Dense SwiGLU, intermediate size 9216 | +| GQA attention | 16 Q heads, 4 KV heads, head_dim 256 | +| DeltaNet | 32 value heads, 16 key heads, k_dim=v_dim=128 | +| Conv kernel | 4, state stores last 3 pre-conv QKV tokens | +| RoPE | Partial RoPE, 25% of head_dim = 64 dims | +| Vocabulary | 248,320 | +| Tied embeddings | Yes | + +Derived DeltaNet shapes: + +| Tensor | Shape | +| --- | --- | +| `in_proj_qkv.weight` | `[8192, 2560]` | +| `in_proj_z.weight` | `[4096, 2560]` | +| `in_proj_a.weight` | `[32, 2560]` | +| `in_proj_b.weight` | `[32, 2560]` | +| `conv1d.weight` | `[8192, 1, 4]` | +| `recurrent_state_buffer` | `[max_batch, 32, 128, 128]` | +| `conv_state_buffer` | `[max_batch, 8192, 3]` | + +## Status + +Prepared for Trn2 bring-up. Validate TP=4, batch=1, seq_len=128 first, then increase context or reduce TP only after the baseline generates correctly. + +## Testing + +CPU unit tests: + +```bash +cd contrib/models/Qwen3.5-4B +pytest test/unit -v +``` + +Trainium integration: + +```bash +cd contrib/models/Qwen3.5-4B +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +NEURON_PLATFORM_TARGET_OVERRIDE=trn2 \ +QWEN35_MODEL_PATH=/home/ubuntu/models/Qwen3.5-4B \ +QWEN35_COMPILED_PATH=/home/ubuntu/models/qwen35_4b_traced_trn2 \ +QWEN35_TP_DEGREE=4 \ +QWEN35_SEQ_LEN=128 \ +pytest test/integration/test_model.py --capture=tee-sys -v +``` + +## Known Limitations + +- DeltaNet weights are replicated across TP ranks in v1. +- DeltaNet layers still allocate dummy KV cache through NxDI's normal cache manager. +- MoE, VL, quantization, speculation, and custom hybrid cache cleanup are out of scope. diff --git a/contrib/models/Qwen3.5-4B/src/__init__.py b/contrib/models/Qwen3.5-4B/src/__init__.py new file mode 100644 index 00000000..f8e014e6 --- /dev/null +++ b/contrib/models/Qwen3.5-4B/src/__init__.py @@ -0,0 +1,27 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from src.modeling_qwen35 import ( + NeuronGatedDeltaNet, + NeuronQwen35Attention, + NeuronQwen35DecoderLayer, + NeuronQwen35ForCausalLM, + NeuronQwen35Model, + Qwen35DecoderModelInstance, + Qwen35InferenceConfig, + Qwen35MLP, + Qwen35ModelWrapper, +) + +__all__ = [ + # Text decoder + "NeuronGatedDeltaNet", + "NeuronQwen35Attention", + "NeuronQwen35DecoderLayer", + "NeuronQwen35ForCausalLM", + "NeuronQwen35Model", + "Qwen35DecoderModelInstance", + "Qwen35InferenceConfig", + "Qwen35MLP", + "Qwen35ModelWrapper", +] diff --git a/contrib/models/Qwen3.5-4B/src/modeling_qwen35.py b/contrib/models/Qwen3.5-4B/src/modeling_qwen35.py new file mode 100644 index 00000000..3e27e419 --- /dev/null +++ b/contrib/models/Qwen3.5-4B/src/modeling_qwen35.py @@ -0,0 +1,2533 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +NxDI contrib: Qwen3.5-4B (qwen3_5 -- dense model) + +Hybrid DeltaNet + Standard Attention + Dense MLP architecture. + +24 of 32 layers use Gated DeltaNet (linear recurrent attention) +8 of 32 layers use standard GQA with KV cache + output gate +All 32 layers use a dense SwiGLU MLP (intermediate_size=9216) + +Architecture details: +- DeltaNet layers: separate in_proj_{qkv, z, a, b}, causal conv1d on QKV, gated delta rule +- Attention layers: q_proj doubled (Q + gate), partial RoPE (25% of head_dim), sigmoid output gate +- Dense MLP: standard SwiGLU (gate_proj, up_proj, down_proj) -- no MoE, no router, no experts +- KV cache: NxDI KVCacheManager for attention layers; DeltaNet layers store recurrent+conv + state as nn.Parameter buffers and return dummy KV tuples +""" + +import gc +import math +import logging +import os +import sys +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm + +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.utils import cpu_mode + +try: + from nki import jit as nki_jit # NKI 0.3.0+ (SDK 2.29) +except ImportError: + from torch_neuronx.xla_impl.ops import nki_jit # NKI 0.2.x (SDK 2.28) +from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeRMSNorm + +from src.nki_kernels.nki_deltanet import deltanet_recurrent_fwd as _deltanet_nki_kernel +from src.nki_kernels.nki_deltanet import ( + deltanet_recurrent_fwd_state as _deltanet_nki_kernel_state, +) +from src.nki_kernels.nki_deltanet_chunked import ( + deltanet_chunk_step as _deltanet_nki_chunk_step, +) +from src.nki_kernels.nki_deltanet_fused import ( + deltanet_fused_chunked_fwd as _deltanet_fused_kernel, +) +from src.nki_kernels.nki_deltanet_fused import ( + _make_lower_mask, + _make_lower_mask_diag, + _make_identity, +) + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + NeuronConfig, +) +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, + DecoderModelInstance, + ModelWrapper, +) +from 
neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding +from neuronx_distributed_inference.models.layer_boundary_marker import ( + ModuleMarkerEndWrapper, + ModuleMarkerStartWrapper, +) + +logger = logging.getLogger(__name__) + +try: + _flash_fwd_call = nki_jit()(attention_isa_kernel) +except TypeError: + from torch_neuronx.xla_impl.ops import nki_jit as _torch_xla_nki_jit + + _flash_fwd_call = _torch_xla_nki_jit()(attention_isa_kernel) + +# Option B: Direct nkilib flash attention for head_dim > 128 +USE_NKILIB_KERNEL = os.environ.get("USE_NKILIB_KERNEL", "0") == "1" + +_nkilib_flash_attn = None +if USE_NKILIB_KERNEL: + try: + import neuronxcc.nki as _nki + from neuronx_distributed_inference.modules.attention.attention_base import ( + peel_decorations as _peel_decorations, + get_platform_target as _get_platform_target, + ) + from neuronxcc.nki.compiler import ( + skip_middle_end_transformations as _skip_middle_end, + enable_stack_allocator as _enable_stack_allocator, + ) + + import importlib + + _fork_path = "/home/ubuntu/nki-library-fork/nkilib_src" + if os.path.isdir(_fork_path) and _fork_path not in sys.path: + sys.path.insert(0, _fork_path) + _to_remove = [k for k in sys.modules if k.startswith("nkilib")] + for k in _to_remove: + del sys.modules[k] + import nki.language as _stub_nl + import neuronxcc.nki.language as _real_nl + + for _attr in [ + "NKIObject", + "float8_e4m3fn", + "float8_e4m3fn_x4", + "float8_e5m2_x4", + "float4_e2m1fn_x4", + ]: + if not hasattr(_real_nl, _attr) and hasattr(_stub_nl, _attr): + setattr(_real_nl, _attr, getattr(_stub_nl, _attr)) + from nkilib.core.attention.attention_cte import ( + attention_cte as _attention_cte_raw, + _MAX_HEAD_DIM, + ) + + assert _MAX_HEAD_DIM == 256, ( + f"nkilib fork has _MAX_HEAD_DIM={_MAX_HEAD_DIM}, expected 256. " + f"System nkilib may have been loaded instead of fork." 
+ ) + logger.info( + f"Loaded nkilib attention_cte from fork (_MAX_HEAD_DIM={_MAX_HEAD_DIM})" + ) + + _raw_fn = _peel_decorations(_attention_cte_raw) + _platform = _get_platform_target() + _nkilib_flash_attn = _nki.jit( + _raw_fn, + mode="torchxla", + platform_target=_platform, + show_compiler_tb=True, + debug_kernel=True, + ) + _nkilib_flash_attn = _skip_middle_end(_nkilib_flash_attn) + _nkilib_flash_attn = _enable_stack_allocator( + _nkilib_flash_attn, log_level=logging.INFO + ) + logger.info("Option B: nkilib flash attention loaded for head_dim > 128") + except Exception as e: + logger.warning(f"Option B: Failed to load nkilib flash attention: {e}") + import traceback as _tb + + _tb.print_exc() + _nkilib_flash_attn = None + +# Option A: Detect if patch_attn_kernel was imported +NKILIB_PATCH_ACTIVE = False +try: + from importlib import import_module as _import_module + + _attn_mod = _import_module("neuronxcc.nki._pre_prod_kernels.attn_fwd") + if hasattr(_attn_mod, "_original_attention_nki_kernel_adapter"): + NKILIB_PATCH_ACTIVE = True + logger.info("Option A detected: _pre_prod_kernels patched with nkilib kernel") +except Exception: + pass + + +# ============================================================ +# Newton-Raphson Refined RMSNorm +# ============================================================ +USE_NEWTON_RMSNORM = os.environ.get("USE_NEWTON_RMSNORM") == "1" +USE_PYTHON_RMSNORM = os.environ.get("USE_PYTHON_RMSNORM") == "1" + + +class NewtonRMSNorm(nn.Module): + """RMSNorm with Newton-Raphson refined rsqrt for improved numerical accuracy.""" + + def __init__(self, hidden_size=None, eps=1e-6): + super().__init__() + self.weight = None + if hidden_size is not None: + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.hidden_size = hidden_size + self.variance_epsilon = eps + + def forward(self, hidden_states): + original_dtype = hidden_states.dtype + x = hidden_states.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + y = torch.rsqrt(variance + self.variance_epsilon) + y = y * (3.0 - (variance + self.variance_epsilon) * y * y) * 0.5 + result = x * y + if self.weight is not None: + result = result * self.weight.float() + return result.to(original_dtype) + + +def get_rmsnorm_cls(): + if cpu_mode() or USE_PYTHON_RMSNORM: + return Qwen3MoeRMSNorm + return NewtonRMSNorm if USE_NEWTON_RMSNORM else CustomRMSNorm + + +def l2norm(x, dim=-1, eps=1e-6): + return F.normalize(x, p=2, dim=dim, eps=eps) + + +# ============================================================ +# Gated DeltaNet Module (Linear Recurrent Attention) +# ============================================================ + + +class NeuronGatedDeltaNet(nn.Module): + """ + Gated DeltaNet linear attention for Neuron. + + Replaces standard attention for 24 of 32 layers in Qwen3.5-4B. + Uses a chunk-based linear recurrence instead of KV cache. 
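+
+    Per-token recurrence (a plain-notation summary of the update implemented
+    in _recurrent_step below; q_t and k_t are L2-normalized, q_t is scaled by
+    head_k_dim ** -0.5, alpha_t = exp(g_t), beta_t = sigmoid(b_t), and the
+    per-head state S_t has shape (head_k_dim, head_v_dim)):
+
+        S'    = alpha_t * S_{t-1}                      # decay previous state
+        v_mem = S'.T @ k_t                             # memory's prediction for k_t
+        S_t   = S' + beta_t * outer(k_t, v_t - v_mem)  # delta-rule write
+        o_t   = S_t.T @ q_t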
+
+    HF weight layout (4B dense):
+    - in_proj_qkv.weight: (key_dim*2 + value_dim, hidden_size) = (8192, 2560)
+    - in_proj_z.weight: (value_dim, hidden_size) = (4096, 2560)
+    - in_proj_a.weight: (num_v_heads, hidden_size) = (32, 2560)
+    - in_proj_b.weight: (num_v_heads, hidden_size) = (32, 2560)
+    - conv1d.weight: (conv_dim, 1, conv_kernel_size) = (8192, 1, 4)
+    - A_log: (num_v_heads,) = (32,)
+    - dt_bias: (num_v_heads,) = (32,)
+    - norm.weight: (head_v_dim,) = (128,)
+    - out_proj.weight: (hidden_size, value_dim) = (2560, 4096)
+    """
+
+    def __init__(self, config, layer_idx: int):
+        super().__init__()
+        tc = config
+
+        self.hidden_size = tc.hidden_size  # 2560
+        self.num_v_heads = tc.linear_num_value_heads  # 32
+        self.num_k_heads = tc.linear_num_key_heads  # 16
+        self.head_k_dim = tc.linear_key_head_dim  # 128
+        self.head_v_dim = tc.linear_value_head_dim  # 128
+        self.key_dim = self.head_k_dim * self.num_k_heads  # 2048
+        self.value_dim = self.head_v_dim * self.num_v_heads  # 4096
+        self.conv_kernel_size = tc.linear_conv_kernel_dim  # 4
+        self.layer_idx = layer_idx
+        self.rms_norm_eps = tc.rms_norm_eps
+
+        # KV cache dummy shape info
+        self.head_dim = tc.head_dim  # 256
+        tp_degree = tc.neuron_config.tp_degree
+        raw_kv_heads = tc.num_key_value_heads
+        if raw_kv_heads < tp_degree:
+            replicated_kv_heads = tp_degree
+        else:
+            replicated_kv_heads = raw_kv_heads
+        self.kv_heads_per_rank = replicated_kv_heads // tp_degree
+
+        # Conv1d on concatenated QKV (NOT Z)
+        self.conv_dim = self.key_dim * 2 + self.value_dim  # 8192
+        self.conv1d = nn.Conv1d(
+            in_channels=self.conv_dim,
+            out_channels=self.conv_dim,
+            bias=False,
+            kernel_size=self.conv_kernel_size,
+            groups=self.conv_dim,
+            padding=self.conv_kernel_size - 1,
+        )
+
+        # Input projections (nn.Linear — NOT sharded by NxDI TP, replicated on all ranks)
+        self.in_proj_qkv = nn.Linear(
+            self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False
+        )
+        self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
+        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
+        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
+
+        # Decay parameters
+        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
+        self.A_log = nn.Parameter(torch.zeros(self.num_v_heads))
+
+        # Output norm and projection
+        self.norm = Qwen3MoeRMSNorm(self.head_v_dim, eps=self.rms_norm_eps)
+        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
+
+        # State buffers for CTE -> TKG carry-over
+        alloc_batch_size = getattr(config.neuron_config, "max_batch_size", 1)
+        self._phase_batch_size = getattr(config.neuron_config, "batch_size", 1)
+        self.recurrent_state_buffer = nn.Parameter(
+            torch.zeros(
+                alloc_batch_size,
+                self.num_v_heads,
+                self.head_k_dim,
+                self.head_v_dim,
+                dtype=config.neuron_config.torch_dtype,
+            ),
+            requires_grad=False,
+        )
+        self.conv_state_buffer = nn.Parameter(
+            torch.zeros(
+                alloc_batch_size,
+                self.conv_dim,
+                self.conv_kernel_size - 1,
+                dtype=config.neuron_config.torch_dtype,
+            ),
+            requires_grad=False,
+        )
+
+    def _recurrent_step(self, query, key, value, g, beta, recurrent_state):
+        """Single-step recurrent update for token generation."""
+        query = l2norm(query, dim=-1)
+        key = l2norm(key, dim=-1)
+        scale = 1.0 / (query.shape[-1] ** 0.5)
+        query = query * scale
+
+        q_t = query[:, :, 0]
+        k_t = key[:, :, 0]
+        v_t = value[:, :, 0]
+        g_t = g[:, :, 0].exp().unsqueeze(-1).unsqueeze(-1)
+        beta_t = beta[:, :, 0].unsqueeze(-1)
+
+        new_state = recurrent_state * g_t
+        kv_mem = (new_state *
k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + new_state = new_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + output = (new_state * q_t.unsqueeze(-1)).sum(dim=-2) + + return output.unsqueeze(2), new_state + + def _nki_recurrent_forward(self, query, key, value, g, beta): + """Full-sequence recurrent forward using NKI kernel for context encoding.""" + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + BH = B * H + query_flat = query.reshape(BH, S, k_dim).contiguous() + key_flat = key.reshape(BH, S, k_dim).contiguous() + value_flat = value.reshape(BH, S, v_dim).contiguous() + + g_flat = g.reshape(BH, S).unsqueeze(-1).expand(-1, -1, v_dim).contiguous() + beta_flat = beta.reshape(BH, S).unsqueeze(-1).expand(-1, -1, v_dim).contiguous() + + outputs = [] + states = [] + for bh in range(BH): + out_bh, state_bh = _deltanet_nki_kernel_state( + query_flat[bh], + key_flat[bh], + value_flat[bh], + g_flat[bh], + beta_flat[bh], + ) + outputs.append(out_bh) + states.append(state_bh) + + output = torch.stack(outputs, dim=0) + output = output.reshape(B, H, S, v_dim) + + final_state = torch.stack(states, dim=0) + final_state = final_state.reshape(B, H, k_dim, v_dim) + + return output, final_state + + def _nki_chunked_forward( + self, query, key, value, g, beta, output_final_state=False + ): + """Chunked NKI kernel forward for context encoding (prefill).""" + chunk_size = 128 + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + num_chunks = total_seq_len // chunk_size + g_reshaped = g.reshape(B, H, num_chunks, chunk_size) + g_cs = g_reshaped.cumsum(dim=-1) + g_last_per_chunk = g_cs[:, :, :, -1:] + g_last_expanded = g_last_per_chunk.expand(-1, -1, -1, chunk_size) + + query_chunks = query.reshape(B, H, num_chunks, chunk_size, k_dim) + key_chunks = key.reshape(B, H, num_chunks, chunk_size, k_dim) + value_chunks = value.reshape(B, H, num_chunks, chunk_size, v_dim) + + beta_chunks = ( + beta.reshape(B, H, num_chunks, chunk_size) + .unsqueeze(-1) + .expand(-1, -1, -1, -1, v_dim) + ) + gc_chunks = g_cs.unsqueeze(-1).expand(-1, -1, -1, -1, v_dim) + gl_chunks = g_last_expanded.unsqueeze(-1).expand(-1, -1, -1, -1, v_dim) + + BH = B * H + query_chunks = query_chunks.reshape( + BH, num_chunks, chunk_size, k_dim + ).contiguous() + key_chunks = key_chunks.reshape(BH, num_chunks, chunk_size, k_dim).contiguous() + value_chunks = value_chunks.reshape( + BH, num_chunks, chunk_size, v_dim + ).contiguous() + beta_chunks = beta_chunks.reshape( + BH, num_chunks, chunk_size, v_dim + ).contiguous() + gc_chunks = gc_chunks.reshape(BH, num_chunks, chunk_size, v_dim).contiguous() + gl_chunks = gl_chunks.reshape(BH, num_chunks, chunk_size, v_dim).contiguous() + + device = query.device + lower_mask = torch.tril( + torch.ones(chunk_size, chunk_size, dtype=torch.float32, device=device), + diagonal=-1, + ) + identity_mat = torch.eye(chunk_size, dtype=torch.float32, device=device) + lower_mask_diag = torch.tril( + torch.ones(chunk_size, chunk_size, dtype=torch.float32, device=device), 
+ diagonal=0, + ) + + all_outputs = [] + all_states = [] + for bh in range(BH): + state = torch.zeros(k_dim, v_dim, dtype=torch.float32, device=device) + + head_chunks = [] + for c_idx in range(num_chunks): + q_chunk = query_chunks[bh, c_idx].contiguous() + k_chunk = key_chunks[bh, c_idx].contiguous() + v_chunk = value_chunks[bh, c_idx].contiguous() + beta_chunk = beta_chunks[bh, c_idx].contiguous() + gc_chunk = gc_chunks[bh, c_idx].contiguous() + gl_chunk = gl_chunks[bh, c_idx].contiguous() + + out_chunk, state = _deltanet_nki_chunk_step( + q_chunk, + k_chunk, + v_chunk, + beta_chunk, + gc_chunk, + gl_chunk, + state, + lower_mask, + identity_mat, + lower_mask_diag, + ) + head_chunks.append(out_chunk) + + head_output = torch.cat(head_chunks, dim=0) + all_outputs.append(head_output) + all_states.append(state) + + output = torch.stack(all_outputs, dim=0) + output = output.reshape(B, H, total_seq_len, v_dim) + output = output[:, :, :S] + + if output_final_state: + final_state = torch.stack(all_states, dim=0) + last_recurrent_state = final_state.reshape(B, H, k_dim, v_dim) + else: + last_recurrent_state = None + + return output, last_recurrent_state + + def _fused_chunked_forward( + self, query, key, value, g, beta, output_final_state=False + ): + """Fused single-kernel chunked forward for CTE — SSD-style. + + Processes all chunks in a single NKI kernel call per (B,H) pair. + State persists in SBUF across chunks (no HBM round-trips). + Cumsum of g computed in-kernel via tensor_tensor_scan. + + This is the optimized version of _nki_chunked_forward with: + 1. Single kernel call per (B,H) instead of B*H*num_chunks + 2. State in SBUF across all chunks (biggest perf win) + 3. In-kernel cumsum (avoids PyTorch cumsum overhead) + 4. tensor_scalar for broadcasts (no explicit loops) + """ + chunk_size = 128 + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + # Pad sequence to multiple of chunk_size + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + BH = B * H + # Flatten to (BH, S, dim) for per-(b,h) kernel calls + query_flat = query.reshape(BH, total_seq_len, k_dim).contiguous() + key_flat = key.reshape(BH, total_seq_len, k_dim).contiguous() + value_flat = value.reshape(BH, total_seq_len, v_dim).contiguous() + + # g and beta: (BH, S) -> (BH, S, 1) for the kernel's (S, 1) input layout + g_flat = g.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + beta_flat = beta.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + + # Create constant mask tensors (shared across all B*H calls) + device = query.device + lower_mask = torch.tensor( + _make_lower_mask(), dtype=torch.float32, device=device + ) + identity_mat = torch.tensor( + _make_identity(), dtype=torch.float32, device=device + ) + lower_mask_diag = torch.tensor( + _make_lower_mask_diag(), dtype=torch.float32, device=device + ) + + all_outputs = [] + all_states = [] + for bh in range(BH): + out_bh, state_bh = _deltanet_fused_kernel( + query_flat[bh], # (S, 128) + key_flat[bh], # (S, 128) + value_flat[bh], # (S, 128) + g_flat[bh], # (S, 1) — RAW g, not cumsum + beta_flat[bh], # (S, 1) — sigmoid(b) + lower_mask, # (128, 128) + identity_mat, # (128, 128) + lower_mask_diag, # (128, 
128) + ) + all_outputs.append(out_bh) + all_states.append(state_bh) + + output = torch.stack(all_outputs, dim=0) + output = output.reshape(B, H, total_seq_len, v_dim) + output = output[:, :, :S] + + if output_final_state: + final_state = torch.stack(all_states, dim=0) + last_recurrent_state = final_state.reshape(B, H, k_dim, v_dim) + else: + last_recurrent_state = None + + return output, last_recurrent_state + + def _sequential_forward(self, query, key, value, g, beta, output_final_state=False): + """Sequential full-sequence gated delta rule for CTE. + + Uses the same per-step recurrence as _recurrent_step but loops over the + full sequence. Avoids the slice-assignment loop in _chunk_forward that + may compile incorrectly on Neuron/XLA. + """ + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + state = query.new_zeros(B, H, k_dim, v_dim) + all_outputs = [] + for t in range(S): + q_t = query[:, :, t] # (B, H, K) + k_t = key[:, :, t] # (B, H, K) + v_t = value[:, :, t] # (B, H, V) + beta_t = beta[:, :, t].unsqueeze(-1) # (B, H, 1) + g_t = g[:, :, t].exp().unsqueeze(-1).unsqueeze(-1) # (B, H, 1, 1) + + # Gated delta rule + state = state * g_t + kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2) # (B, H, V) + delta = (v_t - kv_mem) * beta_t # (B, H, V) + state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) # (B, H, K, V) + + o_t = (state * q_t.unsqueeze(-1)).sum(dim=-2) # (B, H, V) + all_outputs.append(o_t.unsqueeze(2)) + + output = torch.cat(all_outputs, dim=2) # (B, H, S, V) + final_state = state if output_final_state else None + return output, final_state + + def _chunk_forward(self, query, key, value, g, beta, output_final_state=False): + """Chunk-based forward for context encoding (prefill).""" + chunk_size = 64 + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + + num_chunks = total_seq_len // chunk_size + query = query.reshape(B, H, num_chunks, chunk_size, k_dim) + key = key.reshape(B, H, num_chunks, chunk_size, k_dim) + value = value.reshape(B, H, num_chunks, chunk_size, v_dim) + k_beta = k_beta.reshape(B, H, num_chunks, chunk_size, k_dim) + v_beta = v_beta.reshape(B, H, num_chunks, chunk_size, v_dim) + g = g.reshape(B, H, num_chunks, chunk_size) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=0, + ) + + g = g.cumsum(dim=-1) + decay_mask = (g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().tril() + + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + last_recurrent_state = torch.zeros( + B, H, k_dim, v_dim, dtype=query.dtype, device=query.device + ) + core_attn_out = 
torch.zeros_like(value) + mask2 = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=1, + ) + + for i in range(num_chunks): + q_i = query[:, :, i] + k_i = key[:, :, i] + v_i = value[:, :, i] + + attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_( + mask2, 0 + ) + + v_prime = k_cumdecay[:, :, i] @ last_recurrent_state + v_new = v_i - v_prime + + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn_i @ v_new + + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + ( + k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None] + ).transpose(-1, -2) + @ v_new + ) + + core_attn_out = core_attn_out.reshape(B, H, -1, v_dim) + core_attn_out = core_attn_out[:, :, :S] + + if not output_final_state: + last_recurrent_state = None + + return core_attn_out, last_recurrent_state + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + position_ids=None, + past_key_value=None, + **kwargs, + ): + """Forward pass compatible with NxDI decoder layer interface.""" + batch_size, seq_len, _ = hidden_states.shape + + seq_ids = kwargs.get("seq_ids", None) + is_decode = past_key_value is not None + + # Padding mask for DeltaNet: [B, S, 1] with 1.0 for real tokens, 0.0 for padding. + # Passed from get_model_output where it's computed from input_ids != pad_token_id. + # Embeddings are already zeroed for padding tokens; this mask additionally + # zeros the decay gate so the recurrent state is preserved unchanged + # through padding positions (no spurious decay). + valid_mask_1d = kwargs.get("deltanet_padding_mask", None) # [B, S, 1] or None + + # Project inputs + deltanet_fp32 = os.environ.get("DELTANET_FP32") == "1" + if deltanet_fp32: + hs_f32 = hidden_states.float() + qkv = F.linear(hs_f32, self.in_proj_qkv.weight.float()).to( + hidden_states.dtype + ) + z = F.linear(hs_f32, self.in_proj_z.weight.float()).to(hidden_states.dtype) + b = F.linear(hs_f32, self.in_proj_b.weight.float()).to(hidden_states.dtype) + a = F.linear(hs_f32, self.in_proj_a.weight.float()).to(hidden_states.dtype) + else: + qkv = self.in_proj_qkv(hidden_states) + z = self.in_proj_z(hidden_states) + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + # Split QKV + query = qkv[..., : self.key_dim] + key = qkv[..., self.key_dim : self.key_dim * 2] + value = qkv[..., self.key_dim * 2 :] + + # Causal Conv1d on QKV + mixed = torch.cat([query, key, value], dim=-1) + mixed = mixed.transpose(1, 2) + + if is_decode: + if seq_ids is not None: + conv_state = torch.index_select(self.conv_state_buffer, 0, seq_ids) + else: + conv_state = self.conv_state_buffer[:batch_size] + conv_input = torch.cat([conv_state, mixed], dim=-1) + + w = self.conv1d.weight.squeeze(1) + conv_out = torch.zeros_like(mixed) + for k in range(4): + conv_out = ( + conv_out + + w[:, k].unsqueeze(0).unsqueeze(-1) * conv_input[:, :, k : k + 1] + ) + mixed_post_conv = F.silu(conv_out) + + new_conv_state = torch.cat([conv_state[:, :, 1:], mixed], dim=-1) + alloc_bs = self.conv_state_buffer.shape[0] + if seq_ids is not None: + # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement + # Add buffer dependency for input_output_alias + new_conv_state = ( + new_conv_state.to(self.conv_state_buffer.dtype) + + self.conv_state_buffer * 0 + ) + elif batch_size < alloc_bs: + pad_size = alloc_bs - batch_size + new_conv_state = torch.cat( + [ + new_conv_state, + 
self.conv_state_buffer[batch_size:] * 0,
+                    ],
+                    dim=0,
+                )
+            else:
+                new_conv_state = new_conv_state + self.conv_state_buffer * 0
+        else:
+            mixed_post_conv = F.silu(self.conv1d(mixed)[:, :, :seq_len])
+
+            if valid_mask_1d is not None:
+                # valid_mask_1d is [B, S, 1]; count valid tokens per batch
+                num_valid = (
+                    valid_mask_1d.squeeze(-1).sum(dim=-1, keepdim=True).long()
+                )  # [B, 1]
+                idx_base = num_valid - 3
+                idx_base = idx_base.clamp(min=0)
+                offsets = torch.arange(3, device=mixed.device).unsqueeze(0)
+                gather_idx = idx_base + offsets  # [B, 3]
+                gather_idx = gather_idx.unsqueeze(1).expand(-1, self.conv_dim, -1)
+                new_conv_state = torch.gather(mixed, 2, gather_idx)
+            else:
+                new_conv_state = mixed[:, :, -3:].contiguous()
+
+            alloc_bs = self.conv_state_buffer.shape[0]
+            if seq_ids is not None:
+                # BS=1 optimization: scatter to index 0 = direct replacement
+                new_conv_state = (
+                    new_conv_state.to(self.conv_state_buffer.dtype)
+                    + self.conv_state_buffer * 0
+                )
+            elif batch_size < alloc_bs:
+                pad_size = alloc_bs - batch_size
+                new_conv_state = torch.cat(
+                    [
+                        new_conv_state,
+                        torch.zeros(
+                            pad_size,
+                            self.conv_dim,
+                            self.conv_kernel_size - 1,
+                            dtype=new_conv_state.dtype,
+                            device=new_conv_state.device,
+                        ),
+                    ],
+                    dim=0,
+                )
+                new_conv_state = new_conv_state + self.conv_state_buffer * 0
+            else:
+                new_conv_state = new_conv_state + self.conv_state_buffer * 0
+
+        mixed_post_conv = mixed_post_conv.transpose(1, 2)
+
+        # Zero out conv1d output for padding positions.
+        # Conv1d with kernel_size=4 leaks real token info into the first
+        # few padding positions. Zeroing here ensures Q, K, V are exactly
+        # zero for all padding positions so the recurrence is unaffected.
+        if valid_mask_1d is not None:
+            mixed_post_conv = (
+                mixed_post_conv * valid_mask_1d
+            )  # [B, S, conv_dim] * [B, S, 1]
+
+        query = mixed_post_conv[..., : self.key_dim]
+        key = mixed_post_conv[..., self.key_dim : self.key_dim * 2]
+        value = mixed_post_conv[..., self.key_dim * 2 :]
+
+        # Reshape to heads
+        query = query.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim)
+        key = key.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim)
+        value = value.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim)
+
+        # Compute gating
+        beta = b.sigmoid()
+        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+
+        if valid_mask_1d is not None:
+            # Zero g for padding → alpha=exp(0)=1 → state preserved through padding
+            # Zero beta for padding → no state update from padding tokens
+            mask_2d = valid_mask_1d.squeeze(-1).float()  # [B, S]
+            g = g * mask_2d.unsqueeze(-1)
+            beta = beta * mask_2d.unsqueeze(-1)
+
+        # Expand K heads to match V heads (16 -> 32) using expand+reshape
+        if self.num_v_heads // self.num_k_heads > 1:
+            rep = self.num_v_heads // self.num_k_heads  # 2
+            query = (
+                query.unsqueeze(3)
+                .expand(-1, -1, -1, rep, -1)
+                .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim)
+            )
+            key = (
+                key.unsqueeze(3)
+                .expand(-1, -1, -1, rep, -1)
+                .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim)
+            )
+
+        # Transpose to (B, H, S, dim)
+        query = query.transpose(1, 2).contiguous().float()
+        key = key.transpose(1, 2).contiguous().float()
+        value = value.transpose(1, 2).contiguous().float()
+        g = g.transpose(1, 2).contiguous().float()
+        beta = beta.transpose(1, 2).contiguous().float()
+
+        if is_decode:
+            # TKG: single-step recurrent update
+            if seq_ids is not None:
+                recurrent_state = torch.index_select(
+                    self.recurrent_state_buffer, 0, seq_ids
+                ).float()
+            else:
+
recurrent_state = self.recurrent_state_buffer[:batch_size].float() + + output, new_state = self._recurrent_step( + query, key, value, g, beta, recurrent_state + ) + new_state_bf16 = new_state.to(self.recurrent_state_buffer.dtype) + alloc_bs = self.recurrent_state_buffer.shape[0] + if seq_ids is not None: + # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement + # Add buffer dependency for input_output_alias + new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0 + elif batch_size < alloc_bs: + new_rec_state = torch.cat( + [ + new_state_bf16, + self.recurrent_state_buffer[batch_size:] * 0, + ], + dim=0, + ) + else: + new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0 + else: + # CTE: fused NKI kernel by default (PyTorch _chunk_forward can hit + # neuronx-cc codegen ICE NCC_INLA001 with these DeltaNet dimensions). + # Override with env vars for debugging/benchmarking. + use_nki_fused = os.environ.get("USE_NKI_FUSED", "1") != "0" + use_nki_chunked = os.environ.get("USE_NKI_CHUNKED") == "1" + use_nki = os.environ.get("USE_NKI") == "1" + use_sequential = os.environ.get("DELTANET_SEQUENTIAL") == "1" + use_pytorch_chunk = os.environ.get("USE_PYTORCH_CHUNK") == "1" + + if use_pytorch_chunk: + output, final_state = self._chunk_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki_chunked: + output, final_state = self._nki_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki: + output, final_state = self._nki_recurrent_forward( + query, key, value, g, beta + ) + elif use_sequential: + output, final_state = self._sequential_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki_fused: + output, final_state = self._fused_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + else: + output, final_state = self._fused_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + + if final_state is not None: + final_state_bf16 = final_state.to(self.recurrent_state_buffer.dtype) + alloc_bs = self.recurrent_state_buffer.shape[0] + if seq_ids is not None: + # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement + # Add buffer dependency for input_output_alias + new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0 + elif batch_size < alloc_bs: + new_rec_state = torch.cat( + [ + final_state_bf16, + torch.zeros( + alloc_bs - batch_size, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + dtype=final_state_bf16.dtype, + device=final_state_bf16.device, + ), + ], + dim=0, + ) + new_rec_state = new_rec_state + self.recurrent_state_buffer * 0 + else: + new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0 + else: + new_rec_state = self.recurrent_state_buffer * 1 + + # Output: norm, gate, project + output = output.to(hidden_states.dtype) + output = output.transpose(1, 2).contiguous() + output = output.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + output = self.norm(output) + z_gate = z.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + output = output * F.silu(z_gate) + output = output.reshape(batch_size, seq_len, self.value_dim) + output = self.out_proj(output) + + # Return dummy KV for KVCacheManager + dummy_k = torch.zeros( + batch_size, + self.kv_heads_per_rank, + seq_len, + self.head_dim, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + dummy_v = torch.zeros_like(dummy_k) + + return output, (dummy_k, dummy_v), new_rec_state, 
new_conv_state
+
+
+# ============================================================
+# InferenceConfig (Dense -- no MoE)
+# ============================================================
+
+
+class Qwen35InferenceConfig(InferenceConfig):
+    """Config for Qwen3.5-4B (dense) with hybrid DeltaNet + Attention."""
+
+    def __init__(self, *args, **kwargs):
+        # Set defaults BEFORE super().__init__() because it calls validate_config()
+        # which checks get_required_attributes(). These can be overridden by
+        # kwargs or load_config.
+
+        # Layer types for hybrid dispatch: [3 DeltaNet + 1 GQA] repeated.
+        if "layer_types" not in kwargs and not any(
+            hasattr(a, "layer_types") for a in args if hasattr(a, "__dict__")
+        ):
+            num_layers = kwargs.get("num_hidden_layers", 32)
+            if num_layers % 4 != 0:
+                raise ValueError(
+                    f"Qwen3.5 hybrid layer count must be divisible by 4, got {num_layers}"
+                )
+            layer_types = []
+            for _ in range(num_layers // 4):
+                layer_types.extend(
+                    [
+                        "linear_attention",
+                        "linear_attention",
+                        "linear_attention",
+                        "full_attention",
+                    ]
+                )
+            kwargs.setdefault("layer_types", layer_types)
+
+        # DeltaNet-specific config defaults
+        kwargs.setdefault("linear_num_value_heads", 32)
+        kwargs.setdefault("linear_num_key_heads", 16)
+        kwargs.setdefault("linear_key_head_dim", 128)
+        kwargs.setdefault("linear_value_head_dim", 128)
+        kwargs.setdefault("linear_conv_kernel_dim", 4)
+
+        super().__init__(*args, **kwargs)
+
+        # Attention output gate
+        self.attn_output_gate = getattr(self, "attn_output_gate", True)
+
+        # Partial RoPE
+        self.partial_rotary_factor = getattr(self, "partial_rotary_factor", 0.25)
+        self.rope_dim = int(self.head_dim * self.partial_rotary_factor)  # 64
+
+        # mRoPE (multimodal RoPE) for VL support
+        rope_params = getattr(self, "rope_parameters", {}) or {}
+        self.mrope_section = rope_params.get("mrope_section", [11, 11, 10])
+        self.mrope_interleaved = rope_params.get("mrope_interleaved", True)
+
+        # Standard HF config attributes expected by NxDI
+        if not hasattr(self, "output_attentions"):
+            self.output_attentions = False
+        if not hasattr(self, "output_hidden_states"):
+            self.output_hidden_states = False
+
+    def get_required_attributes(self) -> List[str]:
+        return [
+            "head_dim",
+            "hidden_act",
+            "hidden_size",
+            "intermediate_size",
+            "max_position_embeddings",
+            "num_attention_heads",
+            "num_hidden_layers",
+            "num_key_value_heads",
+            "rms_norm_eps",
+            "rope_theta",
+            "vocab_size",
+            # DeltaNet-specific
+            "linear_num_value_heads",
+            "linear_num_key_heads",
+            "linear_key_head_dim",
+            "linear_value_head_dim",
+            "linear_conv_kernel_dim",
+            "layer_types",
+        ]
+
+    @classmethod
+    def get_neuron_config_cls(cls):
+        return NeuronConfig
+
+
+# ============================================================
+# Attention (standard GQA for 8 of 32 layers)
+# With output gate: q_proj is 2x sized, split into (query, gate)
+# With partial RoPE: only first rope_dim dimensions get rotary
+# ============================================================
+
+
+class Qwen35MRoPEEmbedding(nn.Module):
+    """Multimodal Rotary Position Embedding (mRoPE) for Qwen3.5.
+
+    Handles 3D position information (temporal, height, width) for VL models.
+    Position IDs have shape (3, batch_size, seq_len) for T/H/W dimensions.
+    For text-only (2D position_ids), broadcasts to 3D with identical positions.
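+
+    For example, a text-only batch can pass
+    position_ids.unsqueeze(0).expand(3, -1, -1) so the T/H/W axes carry the
+    same positions; the per-axis frequency bands then tile all 32 pairs and
+    the result matches standard 1D RoPE over the rope_dim dimensions.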
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.head_dim = config.head_dim  # 256
+        self.rope_dim = config.rope_dim  # 64
+        self.mrope_section = config.mrope_section  # [11, 11, 10]
+        self.mrope_interleaved = getattr(config, "mrope_interleaved", True)
+        self.rope_theta = config.rope_theta
+
+        # Validate mrope_section sums to rope_dim // 2 = 32
+        assert sum(self.mrope_section) == self.rope_dim // 2, (
+            f"mrope_section {self.mrope_section} sums to {sum(self.mrope_section)}, "
+            f"expected {self.rope_dim // 2}"
+        )
+
+    def forward(self, x, position_ids_3d):
+        """Compute cos/sin from 3D position IDs.
+
+        Args:
+            x: hidden_states (for device/dtype inference)
+            position_ids_3d: (3, batch_size, seq_len) -- T, H, W positions
+
+        Returns:
+            cos: (batch_size, seq_len, rope_dim)
+            sin: (batch_size, seq_len, rope_dim)
+        """
+        device = x.device
+        dtype = torch.float32
+
+        sections = self.mrope_section  # [11, 11, 10]
+        cos_parts = []
+        sin_parts = []
+
+        freq_offset = 0
+        for axis_idx, section_size in enumerate(sections):
+            pos = position_ids_3d[axis_idx].float()  # (batch, seq_len)
+
+            dim_pairs = section_size  # number of (cos, sin) pairs for this axis
+            # Each axis owns its own band of the 32 frequency pairs (blocked
+            # mRoPE sections): T -> pairs 0..10, H -> 11..21, W -> 22..31.
+            freqs = 1.0 / (
+                self.rope_theta
+                ** (
+                    torch.arange(
+                        freq_offset * 2,
+                        (freq_offset + dim_pairs) * 2,
+                        2,
+                        dtype=dtype,
+                        device=device,
+                    )
+                    / self.rope_dim
+                )
+            )  # (dim_pairs,)
+            freq_offset += dim_pairs
+
+            # freqs: (dim_pairs,), pos: (B, S) -> angles: (B, S, dim_pairs)
+            angles = pos.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0)
+
+            cos_parts.append(angles.cos())
+            sin_parts.append(angles.sin())
+
+        # Concatenate: (B, S, 32)
+        cos = torch.cat(cos_parts, dim=-1)
+        sin = torch.cat(sin_parts, dim=-1)
+
+        if self.mrope_interleaved:
+            # Interleave to (B, S, 64): [c0, c0, c1, c1, ...] for rotate_half
+            cos = cos.repeat_interleave(2, dim=-1)
+            sin = sin.repeat_interleave(2, dim=-1)
+        else:
+            cos = torch.cat([cos, cos], dim=-1)
+            sin = torch.cat([sin, sin], dim=-1)
+
+        return cos, sin
+
+
+class NeuronQwen35Attention(NeuronAttentionBase):
+    """Standard GQA attention for Qwen3.5 with output gate and partial RoPE.
+
+    16 Q heads, 4 KV heads (4:1 GQA), head_dim=256 for the 4B dense model.
+    q_proj is doubled (query + gate), split at load time.
+    Only the first rope_dim=64 of head_dim=256 dims get rotary encoding.
+
+    Uses NeuronAttentionBase infrastructure for QKV projection, KV cache,
+    RoPE, and attention computation. Overrides forward() to insert the
+    sigmoid output gate between attention output and o_proj.
+    """
+
+    def __init__(self, config):
+        # Partial RoPE: create mRoPE embedding with rope_dim (64)
+        self.rope_dim = config.rope_dim  # 64 = head_dim * partial_rotary_factor
+
+        # Create QK norm modules (will be passed to base class)
+        rms_norm_eps = config.rms_norm_eps
+        q_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps)
+        k_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps)
+
+        # Partial RoPE: use standard RotaryEmbedding.
+        # For VL with 3D mRoPE positions, cos/sin are pre-computed externally in
+        # get_model_output() using Qwen35MRoPEEmbedding and passed as cos_cache/sin_cache.
+ rotary_emb = RotaryEmbedding( + self.rope_dim, # Only 64 dims get rotary embedding + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim, + rotary_emb=rotary_emb, + rms_norm_eps=rms_norm_eps, + use_qk_norm=False, + q_layernorm=q_ln, + k_layernorm=k_ln, + ) + + # Separate mRoPE module for VL 3D position_ids + self.mrope_emb = Qwen35MRoPEEmbedding(config) + + # Output gate projection: hidden_size -> num_heads * head_dim + # Populated from the second half of q_proj during state dict conversion. + self.output_gate_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * config.head_dim, + bias=False, + gather_output=False, + ) + + def apply_rotary_embedding( + self, Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope + ): + """Partial RoPE: only apply rotary embedding to first rope_dim dimensions. + + Q shape: (B, H, S, head_dim) where head_dim=256 + cos/sin shape: (B, S, rope_dim) where rope_dim=64 (from RotaryEmbedding(dim=64)) + + Split Q/K along last dim into: + q_rope (first 64 dims) -- apply RoPE + q_pass (remaining 192 dims) -- pass through unchanged + """ + from neuronx_distributed_inference.modules.attention.utils import ( + apply_rotary_pos_emb, + ) + + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + + # Split into rope and pass-through portions + Q_orig_dtype = Q.dtype + q_rope = Q[..., : self.rope_dim] # (B, H, S, 64) + q_pass = Q[..., self.rope_dim :] # (B, H, S, 192) + k_rope = K[..., : self.rope_dim] + k_pass = K[..., self.rope_dim :] + + # Apply RoPE only to the rope portion + q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos_cache, sin_cache) + + # Concatenate back (ensure bf16 is maintained) + Q = torch.cat([q_rope, q_pass], dim=-1).to(Q_orig_dtype) + K = torch.cat([k_rope, k_pass], dim=-1).to(Q_orig_dtype) + + return Q, K, cos_cache, sin_cache + + def perform_prefill(self, Q, K, V, q_len, bsz, attention_mask=None): + """Prefill path with NKI flash attention for head_dim=256.""" + head_dim = Q.shape[-1] + + # Option B: nkilib flash attention for head_dim > 128 + if _nkilib_flash_attn is not None: + q_contig = Q.contiguous() + k_contig = K.contiguous() + v_contig = V.contiguous() + scale = 1.0 / math.sqrt(head_dim) + result = _nkilib_flash_attn( + q_contig, k_contig, v_contig, scale=scale, use_causal_mask=True + ) + return result, None + + # Option A: kernel patched globally + if NKILIB_PATCH_ACTIVE: + return _flash_fwd_call(Q, K, V, use_causal_mask=True), None + + # Fallback: softmax path (use 3D tensors to avoid compiler ICE with 4D patterns) + if head_dim > 128: + # GQA: expand K/V heads to match Q heads + num_q_heads = Q.shape[1] + num_kv_heads = K.shape[1] + if num_q_heads != num_kv_heads: + kv_rep = num_q_heads // num_kv_heads + K = ( + K.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(bsz, num_q_heads, q_len, head_dim) + ) + V = ( + V.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(bsz, num_q_heads, q_len, head_dim) + ) + # Reshape to 3D (B*H, S, d) to avoid neuronx-cc codegen ICE with 4D + # attention weight tensors (NCC_INLA001: Expected 2D tensor but got 4D AP) + Q_3d = Q.reshape(bsz * num_q_heads, q_len, head_dim) + K_3d = K.reshape(bsz * num_q_heads, q_len, head_dim) + V_3d = 
V.reshape(bsz * num_q_heads, q_len, head_dim) + attn_weights = torch.bmm(Q_3d, K_3d.transpose(-1, -2)) / math.sqrt(head_dim) + # Build causal mask for 3D: (1, S, S) broadcast over B*H + causal_mask = torch.triu( + torch.full( + (q_len, q_len), + -65504.0, + dtype=attn_weights.dtype, + device=attn_weights.device, + ), + diagonal=1, + ).unsqueeze(0) + attn_weights = attn_weights + causal_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + Q.dtype + ) + attn_output = torch.bmm(attn_weights, V_3d) + # Reshape back to 4D (B, H, S, d) + return attn_output.reshape(bsz, num_q_heads, q_len, head_dim), None + + return _flash_fwd_call(Q, K, V, use_causal_mask=True), None + + def forward( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + adapter_ids=None, + active_mask=None, + **kwargs, + ): + """Forward with output gate applied BEFORE o_proj. + + Override NeuronAttentionBase.forward() to insert the sigmoid gate + between the attention output and o_proj, matching the HF reference: + gate = sigmoid(gate_proj(pre_attn_hidden)) + attn_output = attn_output * gate + attn_output = o_proj(attn_output) + """ + bsz, q_len, _ = hidden_states.shape + + # Use standard 2D position_ids for prep_qkv_tensors. + rope_pos_ids = position_ids + + # Compute gate from input hidden states (before QKV projection) + gate = self.output_gate_proj(hidden_states) # (B, S, num_heads * head_dim) + + # Standard QKV prep (projections, QK norm, RoPE) + Q, K, V, cos_cache, sin_cache, _residual = self.prep_qkv_tensors( + rope_pos_ids, + hidden_states, + past_key_value, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rmsnorm=rmsnorm, + ) + + if past_key_value is None: + # Context encoding (prefill) + attn_output, _flash_strategy = self.perform_prefill( + Q, K, V, q_len, bsz, attention_mask + ) + else: + # Token generation (decode) + tkg_mask = attention_mask + if tkg_mask is not None and tkg_mask.ndim == 2: + tkg_mask = tkg_mask.unsqueeze(1).unsqueeze(2) # (B, S) -> (B, 1, 1, S) + attn_output = self.compute_for_token_gen( + Q, K, V, position_ids, past_key_value, tkg_mask, active_mask + ) + + # attn_output is (B, H, S, head_dim) -- transpose to (B, S, H*head_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + + # Apply sigmoid output gate BEFORE o_proj (matching HF reference) + attn_output = attn_output * torch.sigmoid(gate) + + # Apply o_proj + attn_output = self.get_o_proj()(attn_output, adapter_ids=adapter_ids) + + # Ensure K, V are in model dtype (bf16) for KV cache update + # (prevents mixed-precision dynamic-update-slice in neuronx-cc) + K = K.to(self.torch_dtype) + V = V.to(self.torch_dtype) + past_key_value = (K, V) + return attn_output, past_key_value, cos_cache, sin_cache + + +# ============================================================ +# Dense MLP (replaces MoE) +# ============================================================ + + +class Qwen35MLP(nn.Module): + """Dense SwiGLU MLP for Qwen3.5-4B. 
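+
+    Sharding: gate_proj and up_proj are ColumnParallelLinear
+    (gather_output=False) and down_proj is RowParallelLinear
+    (input_is_parallel=True), so the SwiGLU product stays sharded across
+    TP ranks until down_proj's output all-reduce.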
+ + gate_proj: hidden_size -> intermediate_size (2560 -> 9216) + up_proj: hidden_size -> intermediate_size (2560 -> 9216) + down_proj: intermediate_size -> hidden_size (9216 -> 2560) + + output = down_proj(silu(gate_proj(x)) * up_proj(x)) + """ + + def __init__(self, config): + super().__init__() + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + ) + self.down_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + ) + + def forward(self, hidden_states): + gate = self.gate_proj(hidden_states) + up = self.up_proj(hidden_states) + hidden_states = F.silu(gate) * up + hidden_states = self.down_proj(hidden_states) + return hidden_states + + +# ============================================================ +# Decoder Layer (hybrid dispatch -- DeltaNet or GQA + Dense MLP) +# ============================================================ + + +class NeuronQwen35DecoderLayer(nn.Module): + """Hybrid decoder layer: dispatches to DeltaNet or standard attention. + Uses dense MLP for all layers (no MoE). + """ + + def __init__(self, config: Qwen35InferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_type = config.layer_types[layer_idx] + self.layer_idx = layer_idx + self.config = config + + # Attention (DeltaNet or standard GQA) + if self.layer_type == "linear_attention": + self.linear_attn = NeuronGatedDeltaNet(config, layer_idx) + else: + self.self_attn = NeuronQwen35Attention(config=config) + + # Dense MLP (all layers) + self.mlp = Qwen35MLP(config) + + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + position_ids=None, + past_key_value=None, + padding_mask=None, + cos_cache=None, + sin_cache=None, + **kwargs, + ): + residual = hidden_states + + hidden_states = ModuleMarkerStartWrapper()(hidden_states) + hidden_states = self.input_layernorm(hidden_states) + + if self.layer_type == "linear_attention": + # DeltaNet path + attn_out, dummy_kv, new_rec_state, new_conv_state = self.linear_attn( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + **kwargs, + ) + hidden_states = residual + attn_out + present_key_value = dummy_kv + deltanet_states = (new_rec_state, new_conv_state) + else: + deltanet_states = None + # Standard attention path + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + cos_cache=cos_cache, + sin_cache=sin_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Dense MLP FFN + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + hidden_states = ModuleMarkerEndWrapper()(hidden_states) + outputs = ( + hidden_states, + present_key_value, + cos_cache, + sin_cache, + None, + deltanet_states, + ) + return outputs + + +# ============================================================ +# Model +# 
============================================================ + + +class NeuronQwen35Model(NeuronBaseModel): + def setup_attr_for_model(self, config: Qwen35InferenceConfig): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: Qwen35InferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + self.layers = nn.ModuleList( + [ + NeuronQwen35DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = get_rmsnorm_cls()(self.hidden_size, eps=config.rms_norm_eps) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=False if self.on_device_sampling else True, + bias=False, + ) + + # mRoPE embedding for VL + self.mrope_emb = Qwen35MRoPEEmbedding(config) + + @property + def _deltanet_state_params(self): + """Return DeltaNet state nn.Parameters in alias order.""" + params = [] + for layer in self.layers: + if hasattr(layer, "linear_attn"): + params.append(layer.linear_attn.recurrent_state_buffer) + params.append(layer.linear_attn.conv_state_buffer) + return params + + def encode_vision_to_input(self, inputs_embeds, vision_embeddings, vision_mask): + """Scatter vision embeddings into text input embeddings at image token positions.""" + _, max_positions, embedding_dim = inputs_embeds.shape + h_new = inputs_embeds.clone() + vision_flat = vision_embeddings.view(-1, embedding_dim) + positions_flat = vision_mask.view(-1) + h_new.view(-1, embedding_dim).index_put_( + (positions_flat,), vision_flat, accumulate=False + ) + return h_new + + def get_model_output( + self, + input_ids=None, + seq_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + active_mask=None, + inputs_embeds=None, + prev_hidden=None, + adapter_ids=None, + rotary_position_ids=None, + update_cache=False, + is_for_context_encoding=False, + vision_embeddings=None, + vision_mask=None, + local_attn_mask=None, + windowed_context_encoding_window_idx=-1, + padding_mask=None, + **kwargs, + ): + """Override to collect DeltaNet state tensors from decoder layers.""" + batch_size, seq_length = input_ids.shape[:2] + if self.config.neuron_config.layer_boundary_markers: + input_ids = ModuleMarkerStartWrapper()(input_ids) + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][1].shape[2] + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # CRITICAL: Zero out embeddings for padding tokens so DeltaNet recurrence + # is not polluted. DeltaNet has no attention mask -- it processes all + # sequence positions through a linear recurrence. Padding tokens have + # real embedding vectors which corrupt the recurrence state. + # The mask is [B, S, 1] float with 1.0 for real tokens, 0.0 for padding. 
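+        # On the TKG path every fed token is real, so the multiply below is
+        # gated on is_for_context_encoding.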
+ deltanet_padding_mask = ( + (input_ids != self.padding_idx).unsqueeze(-1).to(inputs_embeds.dtype) + ) + if is_for_context_encoding: + inputs_embeds = inputs_embeds * deltanet_padding_mask + + # Vision embedding injection + if (vision_embeddings is not None) and (vision_mask is not None): + if vision_embeddings.dtype != self.config.neuron_config.torch_dtype: + vision_embeddings = vision_embeddings.to( + self.config.neuron_config.torch_dtype + ) + if is_for_context_encoding: + inputs_embeds = self.encode_vision_to_input( + inputs_embeds, vision_embeddings, vision_mask + ) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + hidden_states = inputs_embeds + + # Get KV cache for TKG + cache_size = self.n_positions + if not is_for_context_encoding: + if self.kv_mgr is not None: + past_key_values = self.kv_mgr.get_cache( + seq_ids=seq_ids, + seq_len=cache_size, + is_for_context_encoding=is_for_context_encoding, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + + # Decoder layers + next_decoder_cache = () + deltanet_state_tensors = [] + cos_cache = None + sin_cache = None + + # Convert 2D attention_mask to 4D causal mask for CTE + if ( + attention_mask is not None + and attention_mask.ndim == 2 + and is_for_context_encoding + ): + causal = torch.ones( + (seq_length, seq_length), + dtype=torch.bool, + device=attention_mask.device, + ).tril() + padding_4d = attention_mask[:, None, None, :].to(torch.bool) + attention_mask = (causal[None, None, :, :] & padding_4d).to( + attention_mask.dtype + ) + + # Pre-compute mRoPE cos/sin + if rotary_position_ids is not None and rotary_position_ids.ndim == 3: + cos_cache, sin_cache = self.mrope_emb(inputs_embeds, rotary_position_ids) + + for idx, decoder_layer in enumerate(self.layers): + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + layer_outputs = decoder_layer( + hidden_states, + seq_ids=seq_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + active_mask=active_mask, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rotary_position_ids=rotary_position_ids, + kv_mgr=self.kv_mgr, + get_kv_per_layer=False, + update_kv_per_layer=False, + idx=idx, + is_for_context_encoding=is_for_context_encoding, + seq_len=cache_size, + residual=None, + local_mask=local_attn_mask, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + padding_mask=padding_mask, + deltanet_padding_mask=deltanet_padding_mask, + **kwargs, + ) + + hidden_states = layer_outputs[0] + kv = layer_outputs[1] + next_decoder_cache += (kv,) + cos_cache, sin_cache = layer_outputs[2:4] + + # Collect DeltaNet state tensors + deltanet_states = layer_outputs[5] if len(layer_outputs) > 5 else None + if deltanet_states is not None: + deltanet_state_tensors.append(deltanet_states[0]) + deltanet_state_tensors.append(deltanet_states[1]) + + # Update KV cache + if update_cache: + next_decoder_cache = self.kv_mgr.update_cache( + is_for_context_encoding=is_for_context_encoding, + seq_ids=seq_ids, + position_ids=position_ids, + new_key_values=next_decoder_cache, + seq_len=cache_size, + 
windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + self._deltanet_updated_states = deltanet_state_tensors + + return (hidden_states, next_decoder_cache) + + def forward( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden=None, + adapter_ids=None, + accepted_indices=None, + current_length=None, + medusa_mask=None, + scatter_index=None, + slot_mapping=None, + active_block_table=None, + num_queries=None, + computed_context_lens=None, + tile_q_indices=None, + tile_block_tables=None, + tile_masks=None, + inputs_embeds=None, + kv_cache=None, + active_mask=None, + rotary_position_id=None, + vision_embeddings=None, + vision_mask=None, + ): + """Override base forward to append DeltaNet state tensors to output.""" + prev_hidden = self.set_none_if_empty(prev_hidden) + adapter_ids = self.set_none_if_empty(adapter_ids) + accepted_indices = self.set_none_if_empty(accepted_indices) + current_length = self.set_none_if_empty(current_length) + medusa_mask = self.set_none_if_empty(medusa_mask) + scatter_index = self.set_none_if_empty(scatter_index) + slot_mapping = self.set_none_if_empty(slot_mapping) + active_block_table = self.set_none_if_empty(active_block_table) + num_queries = self.set_none_if_empty(num_queries) + computed_context_lens = self.set_none_if_empty(computed_context_lens) + tile_q_indices = self.set_none_if_empty(tile_q_indices) + tile_block_tables = self.set_none_if_empty(tile_block_tables) + tile_masks = self.set_none_if_empty(tile_masks) + inputs_embeds = self.set_none_if_empty(inputs_embeds) + kv_cache = self.set_none_if_empty(kv_cache) + active_mask = self.set_none_if_empty(active_mask) + rotary_position_id = self.set_none_if_empty(rotary_position_id) + vision_embeddings = self.set_none_if_empty(vision_embeddings) + vision_mask = self.set_none_if_empty(vision_mask) + + is_for_context_encoding = position_ids.shape[-1] != 1 and not ( + hasattr(self.neuron_config, "speculation_length") + and position_ids.shape[-1] == self.neuron_config.speculation_length + ) + + seq_ids = seq_ids.to(torch.int32) + attn_mask = attention_mask + + hidden_states, updated_kv_cache = self.get_model_output( + input_ids=input_ids, + seq_ids=seq_ids, + attention_mask=attn_mask, + position_ids=position_ids, + active_mask=active_mask, + inputs_embeds=inputs_embeds, + adapter_ids=adapter_ids, + rotary_position_ids=rotary_position_id, + update_cache=True, + is_for_context_encoding=is_for_context_encoding, + padding_mask=None, + active_block_table=active_block_table, + scatter_index=slot_mapping + if getattr(self, "is_block_kv_layout", False) + else scatter_index, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + ) + + batch_size = input_ids.shape[0] + if not getattr(self, "sliced_hidden", False): + if not is_for_context_encoding: + pass + else: + index = torch.max(position_ids, dim=1, keepdim=True).indices + index = index.unsqueeze(1).expand(batch_size, 1, self.hidden_size) + hidden_states = torch.gather(hidden_states, dim=1, index=index) + + logits = self.lm_head(hidden_states) + logits = logits.float() + + if hasattr(self.lm_head, "pad_size"): + if self.lm_head.gather_output: + rank_id = torch.tensor(0, device=logits.device, dtype=torch.int32) + world_size = 1 + else: + from neuronx_distributed.parallel_layers import parallel_state + + rank_id = self.rank_util.get_rank() + world_size = torch.distributed.get_world_size( + group=self.lm_head.tensor_parallel_group + ) 
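+            # Mask the vocab padding columns implied by lm_head.pad_size so the
+            # padded logits can never be selected downstream.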
+ from neuronx_distributed_inference.models.model_base import ( + mask_padded_logits, + ) + + logits = mask_padded_logits( + logits, rank_id, world_size, pad_size=self.lm_head.pad_size + ) + + if self.on_device_sampling: + res = self._sample_on_device( + logits, sampling_params, False, is_for_context_encoding + ) + else: + res = logits + + outputs = [res] + if self.neuron_config.output_logits: + outputs += [logits] + outputs += updated_kv_cache + + # Append DeltaNet state tensors (for input_output_aliases) + if hasattr(self, "_deltanet_updated_states"): + outputs += self._deltanet_updated_states + + return outputs + + +# ============================================================ +# State Dict Converter (Dense -- no MoE weight handling) +# ============================================================ + + +def convert_qwen35_hf_to_neuron_state_dict(neuron_state_dict, config): + """Convert HF Qwen3.5 (dense) weights to NxDI format. + + Weight mappings per layer type: + + DeltaNet layers (linear_attention): + HF: layers.X.linear_attn.{in_proj_qkv, in_proj_z, in_proj_a, in_proj_b, + conv1d, A_log, dt_bias, norm, out_proj} + NxDI: same names (no remapping needed) + + Full attention layers: + HF: layers.X.self_attn.q_proj.weight: (num_heads*head_dim*2, hidden) -- doubled for gate + NxDI: layers.X.self_attn.Wqkv.weight (fused Q+K+V, gate separated) + layers.X.self_attn.output_gate_proj.weight (gate part) + HF: layers.X.self_attn.{k_proj, v_proj, o_proj, q_norm, k_norm} + NxDI: layers.X.self_attn.{..., q_layernorm, k_layernorm} + + Dense MLP (all layers): + HF: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + NxDI: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight (same names) + """ + # Add rank_util + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + # CRITICAL: Convert (1+weight) RMSNorm weights to standard RMSNorm weights. + # Qwen3.5 uses RMSNorm with `output = norm(x) * (1 + weight)` where weight + # is initialized to zeros. Standard NxDI RMSNorm uses `output = norm(x) * weight` + # where weight is initialized to ones. To convert: new_weight = old_weight + 1.0 + norm_keys_to_convert = [] + for l in range(config.num_hidden_layers): + norm_keys_to_convert.append(f"layers.{l}.input_layernorm.weight") + norm_keys_to_convert.append(f"layers.{l}.post_attention_layernorm.weight") + if config.layer_types[l] == "full_attention": + norm_keys_to_convert.append(f"layers.{l}.self_attn.q_norm.weight") + norm_keys_to_convert.append(f"layers.{l}.self_attn.k_norm.weight") + norm_keys_to_convert.append("norm.weight") + + for nk in norm_keys_to_convert: + if nk in neuron_state_dict: + old_val = neuron_state_dict[nk] + neuron_state_dict[nk] = old_val.float() + 1.0 + if "layers.0." in nk or nk == "norm.weight": + logger.debug( + f"[NORM FIX] {nk}: mean {old_val.float().mean():.4f} -> {neuron_state_dict[nk].mean():.4f}" + ) + else: + if "layers.0." 
in nk or nk == "norm.weight": + logger.warning(f"[NORM FIX] key not found: {nk}") + + for l in range(config.num_hidden_layers): + layer_type = config.layer_types[l] + + # === Attention layers === + if layer_type == "full_attention": + neuron_state_dict[f"layers.{l}.self_attn.rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + # QK norms: q_norm -> q_layernorm, k_norm -> k_layernorm + q_norm_key = f"layers.{l}.self_attn.q_norm.weight" + k_norm_key = f"layers.{l}.self_attn.k_norm.weight" + if q_norm_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.q_layernorm.weight"] = ( + neuron_state_dict.pop(q_norm_key).detach().clone() + ) + if k_norm_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.k_layernorm.weight"] = ( + neuron_state_dict.pop(k_norm_key).detach().clone() + ) + + # q_proj is doubled: (num_heads * head_dim * 2, hidden_size) + # INTERLEAVED: [head0_query(head_dim) | head0_gate(head_dim) | head1_query(head_dim) | ...] + q_proj_key = f"layers.{l}.self_attn.q_proj.weight" + if q_proj_key in neuron_state_dict: + q_proj_w = neuron_state_dict.pop(q_proj_key) + num_heads = config.num_attention_heads + head_dim = config.head_dim + q_proj_w = q_proj_w.reshape(num_heads, head_dim * 2, config.hidden_size) + query_w = q_proj_w[:, :head_dim, :] + gate_w = q_proj_w[:, head_dim:, :] + query_w = query_w.reshape(num_heads * head_dim, config.hidden_size) + gate_w = gate_w.reshape(num_heads * head_dim, config.hidden_size) + + neuron_state_dict[q_proj_key] = query_w + neuron_state_dict[f"layers.{l}.self_attn.output_gate_proj.weight"] = ( + gate_w + ) + + # Fuse QKV + if config.neuron_config.fused_qkv: + q_key = f"layers.{l}.self_attn.q_proj.weight" + k_key = f"layers.{l}.self_attn.k_proj.weight" + v_key = f"layers.{l}.self_attn.v_proj.weight" + if q_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = torch.cat( + [ + neuron_state_dict[q_key], + neuron_state_dict[k_key], + neuron_state_dict[v_key], + ] + ) + del neuron_state_dict[q_key] + del neuron_state_dict[k_key] + del neuron_state_dict[v_key] + + # Dense MLP: no weight conversion needed -- HF and NxDI use same names + # HF: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + # NxDI: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + + gc.collect() + + return neuron_state_dict + + +# ============================================================ +# Custom ModelWrapper and DecoderModelInstance for DeltaNet state aliasing +# ============================================================ + + +class Qwen35DecoderModelInstance(DecoderModelInstance): + """Custom DecoderModelInstance that adds DeltaNet state buffers to input_output_aliases.""" + + def get(self, bucket_rank, **kwargs): + """Override to add DeltaNet state aliases after KV cache aliases.""" + module, input_output_aliases = super().get(bucket_rank, **kwargs) + + num_output_from_trace = 1 if not self.neuron_config.output_logits else 2 + + if module.kv_mgr is not None: + num_kv = len(module.kv_mgr.past_key_values) + else: + num_kv = 0 + + state_start_idx = num_output_from_trace + num_kv + + if hasattr(module, "_deltanet_state_params"): + for i, param in enumerate(module._deltanet_state_params): + input_output_aliases[param] = state_start_idx + i + + return module, input_output_aliases + + +class Qwen35ModelWrapper(ModelWrapper): + """Custom ModelWrapper for VL support with mRoPE and vision inputs.""" + + def get_model_instance(self): + return Qwen35DecoderModelInstance( + 
model_cls=self.model_cls, + config=self.config, + **self.model_init_kwargs, + ) + + def input_generator(self): + """Generate inputs including mrope_position_ids, vision_embeddings, and vision_mask.""" + base_inputs = super().input_generator() + extended_inputs = [] + + for bucket_inputs in base_inputs: + input_ids = bucket_inputs[0] + batch_size = input_ids.shape[0] + n_active_tokens = input_ids.shape[1] + + is_cte = n_active_tokens > 1 + + if is_cte: + mrope_position_ids = ( + torch.arange(0, n_active_tokens, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + + vision_embeddings = torch.zeros( + (batch_size, n_active_tokens, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + (batch_size, n_active_tokens, 1), + fill_value=n_active_tokens - 1, + dtype=torch.int32, + ) + else: + mrope_position_ids = torch.zeros((0,), dtype=torch.int32) + vision_embeddings = torch.zeros( + (0,), dtype=self.config.neuron_config.torch_dtype + ) + vision_mask = torch.zeros((0,), dtype=torch.int32) + + padded = list(bucket_inputs) + while len(padded) < 21: + padded.append(torch.zeros((0,), dtype=torch.int32)) + padded.append(mrope_position_ids) # position 21 + padded.append(vision_embeddings) # position 22 + padded.append(vision_mask) # position 23 + + extended_inputs.append(tuple(padded)) + + return extended_inputs + + def pad_inputs(self, *args, pad_type="first_fit"): + """Override to pad mrope_position_ids and vision inputs to bucket size.""" + orig_mrope = args[21] if len(args) >= 22 else None + orig_vis_emb = args[22] if len(args) >= 23 else None + orig_vis_mask = args[23] if len(args) >= 24 else None + + padded_args = super().pad_inputs(*args, pad_type=pad_type) + + if len(padded_args) >= 24 and orig_mrope is not None: + padded_seq_len = padded_args[0].shape[1] + batch_size = padded_args[0].shape[0] + is_cte = padded_seq_len > 1 + + if is_cte: + current_mrope = orig_mrope + current_vis_emb = orig_vis_emb + current_vis_mask = orig_vis_mask + + if ( + current_mrope.ndim == 3 + and current_mrope.shape[-1] != padded_seq_len + ): + orig_len = current_mrope.shape[-1] + pad_size = padded_seq_len - orig_len + last_pos = current_mrope[:, :, -1:] + pad_offsets = torch.arange( + 1, pad_size + 1, dtype=current_mrope.dtype + ) + pad_offsets = ( + pad_offsets.unsqueeze(0).unsqueeze(0).expand(3, batch_size, -1) + ) + mrope_pad = last_pos + pad_offsets + mrope_position_ids = torch.cat([current_mrope, mrope_pad], dim=-1) + elif current_mrope.ndim == 3: + mrope_position_ids = current_mrope + else: + mrope_position_ids = ( + torch.arange(0, padded_seq_len, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + + if ( + current_vis_emb is not None + and current_vis_emb.ndim == 3 + and current_vis_emb.shape[1] < padded_seq_len + ): + pad_emb = torch.zeros( + ( + batch_size, + padded_seq_len - current_vis_emb.shape[1], + current_vis_emb.shape[2], + ), + dtype=current_vis_emb.dtype, + ) + vision_embeddings = torch.cat([current_vis_emb, pad_emb], dim=1) + elif current_vis_emb is not None and current_vis_emb.ndim == 3: + vision_embeddings = current_vis_emb[:, :padded_seq_len] + else: + vision_embeddings = torch.zeros( + (batch_size, padded_seq_len, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + + if ( + current_vis_mask is not None + and current_vis_mask.ndim == 3 + and current_vis_mask.shape[1] < padded_seq_len + ): + pad_mask = torch.full( + 
(batch_size, padded_seq_len - current_vis_mask.shape[1], 1), + fill_value=padded_seq_len - 1, + dtype=torch.int32, + ) + vision_mask = torch.cat([current_vis_mask, pad_mask], dim=1) + elif current_vis_mask is not None and current_vis_mask.ndim == 3: + vision_mask = current_vis_mask[:, :padded_seq_len] + else: + vision_mask = torch.full( + (batch_size, padded_seq_len, 1), + fill_value=padded_seq_len - 1, + dtype=torch.int32, + ) + + padded_args = ( + *padded_args[:21], + mrope_position_ids, + vision_embeddings, + vision_mask, + ) + + padded_args = list(padded_args) + padded_args[23] = padded_args[23].clamp(max=padded_seq_len - 1) + padded_args = tuple(padded_args) + + return padded_args + + +# ============================================================ +# Top-Level Model +# ============================================================ + + +class NeuronQwen35ForCausalLM(NeuronBaseForCausalLM): + _model_cls = NeuronQwen35Model + + def get_model_wrapper_cls(self): + """Return custom ModelWrapper with DeltaNet state aliasing.""" + return Qwen35ModelWrapper + + @staticmethod + def load_hf_model(model_path, **kwargs): + """Load HF model weights. + + The model is a VL model (Qwen3_5ForConditionalGeneration) but we + only need the text backbone. + """ + from transformers import AutoModelForCausalLM + + kwargs.setdefault("trust_remote_code", True) + return AutoModelForCausalLM.from_pretrained(model_path, **kwargs) + + @classmethod + def get_config_cls(cls): + return Qwen35InferenceConfig + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + """Copy embed_tokens weight to lm_head for tied embeddings.""" + state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone() + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict, config): + """Strip VL wrapper prefix and convert to NxDI format.""" + new_sd = {} + for k, v in state_dict.items(): + if k.startswith("language_model."): + new_k = k.replace("language_model.", "", 1) + new_sd[new_k] = v + elif k.startswith("model.language_model."): + new_k = k.replace("model.language_model.", "", 1) + new_sd[new_k] = v + elif k.startswith("model.visual") or k.startswith("visual"): + continue # Skip vision encoder + elif k.startswith("model."): + new_sd[k.replace("model.", "", 1)] = v + elif k.startswith("mtp."): + continue # Skip MTP + elif k.startswith("lm_head."): + new_sd[k] = v + else: + new_sd[k] = v + + if ( + getattr(config, "tie_word_embeddings", False) + and "lm_head.weight" not in new_sd + and "embed_tokens.weight" in new_sd + ): + new_sd["lm_head.weight"] = new_sd["embed_tokens.weight"] + + return convert_qwen35_hf_to_neuron_state_dict(new_sd, config) + + def enable_context_encoding(self): + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def enable_token_generation(self): + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + super().enable_token_generation() + + def _copy_past_key_values(self, outputs): + """Override to also copy DeltaNet state buffers on CPU.""" + super()._copy_past_key_values(outputs) + + num_output_from_trace = 1 + if ( + self.neuron_config.output_logits + and self.neuron_config.on_device_sampling_config + ): + num_output_from_trace = 2 + + if ( + hasattr(self, "token_generation_model") + and self.token_generation_model is not None + ): + tkg_model = self.token_generation_model.model + cte_model = self.context_encoding_model.model + else: + return + + if tkg_model.kv_mgr is not None: + num_kv = len(tkg_model.kv_mgr.past_key_values) + else: + num_kv = 0 + + 
state_start = num_output_from_trace + num_kv + + tkg_params = getattr(tkg_model, "_deltanet_state_params", []) + cte_params = getattr(cte_model, "_deltanet_state_params", []) + + if len(tkg_params) > 0 and state_start + len(tkg_params) <= len(outputs): + for i, (tkg_param, cte_param) in enumerate(zip(tkg_params, cte_params)): + new_state = outputs[state_start + i] + tkg_param.data = new_state + cte_param.data = new_state + + def get_required_kwargs(self): + """Return extra kwargs for HF generation loop.""" + return ["llava_args"] + + def _get_model_outputs( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + medusa_args, + llava_args, + slot_mapping=None, + block_table=None, + full_context_lens=None, + computed_context_lens=None, + tf_args=None, + ): + """Override to pass all 24 positional args explicitly.""" + is_prefill = self._is_prefill(position_ids) + + seq_len = input_ids.shape[1] + batch_size = input_ids.shape[0] + + if llava_args and len(llava_args) >= 2: + vision_embeddings = llava_args[0] + vision_mask = llava_args[1] + if len(llava_args) >= 3: + mrope_position_ids = llava_args[2] + else: + mrope_position_ids = None + elif is_prefill: + vision_embeddings = torch.zeros( + (batch_size, seq_len, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + (batch_size, seq_len, 1), + fill_value=seq_len - 1, + dtype=torch.int32, + ) + mrope_position_ids = None + else: + vision_embeddings = torch.zeros((0,), dtype=torch.float32) + vision_mask = torch.zeros((0,), dtype=torch.int32) + mrope_position_ids = None + + if is_prefill: + if mrope_position_ids is None: + mrope_position_ids = ( + torch.arange(0, seq_len, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + else: + mrope_position_ids = torch.zeros((0,), dtype=torch.int32) + + empties = [torch.empty(0) for _ in range(14)] + + if self._is_prefill(position_ids): + ctx_bs = self.context_encoding_model.neuron_config.batch_size + output_logits = [] + + for cb in range(0, batch_size, ctx_bs): + cb_end = min(cb + ctx_bs, batch_size) + actual_chunk = cb_end - cb + + chunk_input_ids = input_ids[cb:cb_end] + chunk_attn_mask = attention_mask[cb:cb_end] + chunk_pos_ids = position_ids[cb:cb_end] + chunk_seq_ids = seq_ids[cb:cb_end] + chunk_sampling = sampling_params[cb:cb_end] + chunk_prev_hidden = ( + prev_hidden[cb:cb_end] + if prev_hidden is not None + and hasattr(prev_hidden, "ndim") + and prev_hidden.ndim > 0 + and prev_hidden.shape[0] > 0 + else prev_hidden + ) + chunk_adapter_ids = ( + adapter_ids[cb:cb_end] + if adapter_ids is not None + and hasattr(adapter_ids, "ndim") + and adapter_ids.ndim > 0 + and adapter_ids.shape[0] > 0 + else adapter_ids + ) + + if mrope_position_ids.ndim == 3: + chunk_mrope = mrope_position_ids[:, cb:cb_end, :] + else: + chunk_mrope = mrope_position_ids + + if vision_embeddings.ndim == 3: + chunk_vis_emb = vision_embeddings[cb:cb_end] + chunk_vis_mask = vision_mask[cb:cb_end] + else: + chunk_vis_emb = vision_embeddings + chunk_vis_mask = vision_mask + + if actual_chunk < ctx_bs: + pad_n = ctx_bs - actual_chunk + chunk_input_ids = torch.cat( + [chunk_input_ids, chunk_input_ids[:1].expand(pad_n, -1)], dim=0 + ) + chunk_attn_mask = torch.cat( + [chunk_attn_mask, chunk_attn_mask[:1].expand(pad_n, -1)], dim=0 + ) + chunk_pos_ids = torch.cat( + [chunk_pos_ids, chunk_pos_ids[:1].expand(pad_n, -1)], dim=0 + ) + pad_seq = torch.arange( + batch_size, batch_size + 
pad_n, dtype=chunk_seq_ids.dtype + ) + chunk_seq_ids = torch.cat([chunk_seq_ids, pad_seq], dim=0) + chunk_sampling = torch.cat( + [chunk_sampling, chunk_sampling[:1].expand(pad_n, -1)], dim=0 + ) + if ( + chunk_prev_hidden is not None + and hasattr(chunk_prev_hidden, "ndim") + and chunk_prev_hidden.ndim > 0 + and chunk_prev_hidden.shape[0] > 0 + ): + chunk_prev_hidden = torch.cat( + [ + chunk_prev_hidden, + chunk_prev_hidden[:1].expand(pad_n, -1), + ], + dim=0, + ) + if ( + chunk_adapter_ids is not None + and hasattr(chunk_adapter_ids, "ndim") + and chunk_adapter_ids.ndim > 0 + and chunk_adapter_ids.shape[0] > 0 + ): + chunk_adapter_ids = torch.cat( + [ + chunk_adapter_ids, + chunk_adapter_ids[:1].expand(pad_n, -1), + ], + dim=0, + ) + if chunk_mrope.ndim == 3: + chunk_mrope = torch.cat( + [chunk_mrope, chunk_mrope[:, :1, :].expand(-1, pad_n, -1)], + dim=1, + ) + if chunk_vis_emb.ndim == 3: + chunk_vis_emb = torch.cat( + [ + chunk_vis_emb, + torch.zeros( + (pad_n,) + chunk_vis_emb.shape[1:], + dtype=chunk_vis_emb.dtype, + ), + ], + dim=0, + ) + chunk_vis_mask = torch.cat( + [ + chunk_vis_mask, + torch.full( + (pad_n,) + chunk_vis_mask.shape[1:], + fill_value=seq_len - 1, + dtype=chunk_vis_mask.dtype, + ), + ], + dim=0, + ) + + chunk_out = self.context_encoding_model( + chunk_input_ids, + chunk_attn_mask, + chunk_pos_ids, + chunk_seq_ids, + chunk_sampling, + chunk_prev_hidden, + chunk_adapter_ids, + *empties, + chunk_mrope, + chunk_vis_emb, + chunk_vis_mask, + ) + if actual_chunk < ctx_bs: + chunk_out = chunk_out[:actual_chunk] + output_logits.append(chunk_out) + + outputs = ( + torch.cat(output_logits, dim=0) + if len(output_logits) > 1 + else output_logits[0] + ) + self.kv_cache_populated = True + is_run_on_neuron = self.context_encoding_model.is_neuron() + else: + outputs = self.token_generation_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + *empties, + mrope_position_ids, + vision_embeddings, + vision_mask, + ) + is_run_on_neuron = self.token_generation_model.is_neuron() + + return outputs, is_run_on_neuron + + def get_compiler_args(self): + if self.compile_tag == CONTEXT_ENCODING_MODEL_TAG: + optimization_level = "-O1" + else: + optimization_level = "-O1" + + compiler_args = ( + "--enable-saturate-infinity " + "--enable-mixed-precision-accumulation " + f"--model-type transformer {optimization_level} " + "--auto-cast=none " + ) + return compiler_args diff --git a/contrib/models/Qwen3.5-4B/src/nki_kernels/__init__.py b/contrib/models/Qwen3.5-4B/src/nki_kernels/__init__.py new file mode 100644 index 00000000..22cdc372 --- /dev/null +++ b/contrib/models/Qwen3.5-4B/src/nki_kernels/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Custom NKI kernels for Qwen3.5-4B DeltaNet layers. + +Contains three kernel implementations: +- nki_deltanet: Per-token recurrent kernel (used for token generation) +- nki_deltanet_chunked: Per-chunk kernel (legacy, superseded by fused) +- nki_deltanet_fused: Fused single-kernel chunked forward (used for context encoding) +""" diff --git a/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet.py b/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet.py new file mode 100644 index 00000000..e6740aa1 --- /dev/null +++ b/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet.py @@ -0,0 +1,337 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""NKI kernels for DeltaNet gated delta rule recurrent forward. + +NKI v3 (SDK 2.29, NKI 0.3.0). Processes a SINGLE (batch, head) pair per kernel call. +The caller loops over (B, H) in PyTorch and calls this kernel for each pair. + +Input layout: All inputs are 2D contiguous tensors (S, 128). +Each call processes one (batch, head) element's full sequence. + +k_dim = v_dim = 128, which matches SBUF tile partition dimension exactly. +g and beta are scalars per token, expanded to (S, 128) by the caller. + +Two kernel variants: + deltanet_recurrent_fwd -- returns output only (original) + deltanet_recurrent_fwd_state -- returns (output, final_state) for CTE->TKG carry-over +""" + +import nki +import nki.isa as nisa +import nki.language as nl + +# Partition dimension max (NeuronCore SBUF tile width) +P_MAX = 128 + +# Shuffle mask: broadcast partition 0 to all partitions in a 32-wide group +_BROADCAST_MASK = [0] * 32 + + +@nki.jit +def deltanet_recurrent_fwd( + query: nl.ndarray, # (S, 128) float32 + key: nl.ndarray, # (S, 128) float32 + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 128) float32, log-decay broadcast to 128 + beta_in: nl.ndarray, # (S, 128) float32, write gate broadcast to 128 +) -> nl.ndarray: + """NKI kernel for DeltaNet recurrent forward -- single (batch, head). + + Iterates over sequence tokens with sequential_range. + State matrix (128 x 128) lives in SBUF. + + Args: + query: (S, 128) float32 + key: (S, 128) float32 + value: (S, 128) float32 + g_in: (S, 128) float32 + beta_in: (S, 128) float32 + + Returns: + output: (S, 128) float32 + """ + seq_len, dim = query.shape + + # Output tensor in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + + # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1 + seq_stride = dim + + # Initialize recurrent state in SBUF: (128, 128) + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state, value=0.0) + + # Sequential loop over tokens (state-dependent) + for t in nl.sequential_range(seq_len): + tok_offset = t * seq_stride + + # ---- Load inputs for token t ---- + q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_t, + src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_t, + src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_t, + src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_t, + src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_t, + src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + # ---- Step 1: Decay state -- state = state * exp(g_t) ---- + exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0) + + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_g, + engine=nisa.vector_engine, + ) + nisa.tensor_copy(dst=state, src=state_decayed) + + # ---- Step 2: Read memory -- kv_mem = state^T @ k_t ---- + kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, 
buffer=nl.psum) + nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t) + kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum) + + # ---- Step 3: delta = (v_t - kv_mem) * beta_t ---- + v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract) + + delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=delta, + data=v_sub, + op0=nl.multiply, + operand0=beta_t, + engine=nisa.vector_engine, + ) + + # ---- Step 4: state += outer(k_t, delta) ---- + # Broadcast multiply: outer[i,j] = k_t[i] * delta[j] + # 1) Transpose delta (128,1) -> (1,128) in PSUM + # 2) Copy PSUM (1,128) -> SBUF (128,128) -- partition broadcast + # 3) Multiply by k_t (128,1) which broadcasts across free dim + # This avoids the nc_matmul P=1 outer product (wastes 127/128 TE lanes). + + # Transpose delta to get values along free dimension + delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=delta_row_psum, data=delta) + + # Copy PSUM (1, 128) -> SBUF (1, 128) first (NKI 0.3.0 requires matching P dims) + delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum) + + # Broadcast (1, 128) SBUF -> (128, 128) SBUF via nc_stream_shuffle + # Each partition row gets the same delta values + delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=delta_row_sb[0:1, 0:P_MAX], + dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + # Element-wise multiply: outer[i,j] = delta_broadcast[i,j] * k_t[i,0] + # tensor_scalar broadcasts (P,1) k_t across all F columns + outer_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=outer_prod, + data=delta_broadcast, + op0=nl.multiply, + operand0=k_t, + engine=nisa.vector_engine, + ) + + # Accumulate into state + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add) + nisa.tensor_copy(dst=state, src=state_new) + + # ---- Step 5: o_t = state^T @ q_t ---- + o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t) + o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=o_t, src=o_t_psum) + + # ---- Store output for token t ---- + nisa.dma_copy( + dst=output.ap(pattern=[[1, dim]], offset=tok_offset), + src=o_t, + ) + + return output + + +@nki.jit +def deltanet_recurrent_fwd_state( + query: nl.ndarray, # (S, 128) float32 + key: nl.ndarray, # (S, 128) float32 + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 128) float32, log-decay broadcast to 128 + beta_in: nl.ndarray, # (S, 128) float32, write gate broadcast to 128 +): + """NKI kernel for DeltaNet recurrent forward with final state output. + + Same recurrence as deltanet_recurrent_fwd, but ALSO writes the final + recurrent state (128, 128) to an output HBM buffer. This enables + CTE -> TKG state carry-over. 
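+
+    Per-token recurrence (S is the 128x128 SBUF state tile):
+        S'  = exp(g_t) * S            # decay
+        d_t = beta_t * (v_t - S'^T @ k_t)
+        S   = S' + outer(k_t, d_t)    # rank-1 write
+        o_t = S^T @ q_t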
+ + Returns: + output: (S, 128) float32 -- per-token output + final_state: (128, 128) float32 -- recurrent state after last token + """ + seq_len, dim = query.shape + + # Output tensors in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + final_state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1 + seq_stride = dim + + # Initialize recurrent state in SBUF: (128, 128) + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state, value=0.0) + + # Sequential loop over tokens (state-dependent) + for t in nl.sequential_range(seq_len): + tok_offset = t * seq_stride + + # ---- Load inputs for token t ---- + q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_t, + src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_t, + src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_t, + src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_t, + src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_t, + src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + # ---- Step 1: Decay state -- state = state * exp(g_t) ---- + exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0) + + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_g, + engine=nisa.vector_engine, + ) + nisa.tensor_copy(dst=state, src=state_decayed) + + # ---- Step 2: Read memory -- kv_mem = state^T @ k_t ---- + kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t) + kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum) + + # ---- Step 3: delta = (v_t - kv_mem) * beta_t ---- + v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract) + + delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=delta, + data=v_sub, + op0=nl.multiply, + operand0=beta_t, + engine=nisa.vector_engine, + ) + + # ---- Step 4: state += outer(k_t, delta) ---- + # Broadcast multiply: outer[i,j] = k_t[i] * delta[j] + delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=delta_row_psum, data=delta) + + # Copy PSUM (1, 128) -> SBUF (1, 128) first (NKI 0.3.0 requires matching P dims) + delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum) + + # Broadcast (1, 128) SBUF -> (128, 128) SBUF via nc_stream_shuffle + delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=delta_row_sb[0:1, 0:P_MAX], + dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + outer_prod = nl.ndarray((P_MAX, dim), 
dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=outer_prod, + data=delta_broadcast, + op0=nl.multiply, + operand0=k_t, + engine=nisa.vector_engine, + ) + + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add) + nisa.tensor_copy(dst=state, src=state_new) + + # ---- Step 5: o_t = state^T @ q_t ---- + o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t) + o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=o_t, src=o_t_psum) + + # ---- Store output for token t ---- + nisa.dma_copy( + dst=output.ap(pattern=[[1, dim]], offset=tok_offset), + src=o_t, + ) + + # ---- Write final state to HBM ---- + # state is (128, 128) in SBUF, copy to final_state in HBM + # Use dma_copy with full tile: P_MAX rows, dim cols + nisa.dma_copy(dst=final_state, src=state) + + return output, final_state diff --git a/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_chunked.py b/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_chunked.py new file mode 100644 index 00000000..88f0cc1b --- /dev/null +++ b/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_chunked.py @@ -0,0 +1,323 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""NKI per-chunk DeltaNet kernel for CTE (context encoding / prefill). + +Single-chunk kernel: processes one chunk (128 tokens) with Neumann-series +power-doubling for intra-chunk correction. The caller loops over chunks +in PyTorch, passing state between calls. + +Each kernel call: + - Takes one chunk of data: q, k, v, beta, g_cumsum, g_last (all 128x128) + - Takes recurrent state_in (128x128) + - Returns chunk output (128x128) and state_out (128x128) + +No sequence-indexed DMA inside the kernel -- all inputs/outputs are full tiles. +This avoids the DMA OOB issue seen with nl.sequential_range + slice indexing +in the NxDI model compilation context. + +NKI v3 (SDK 2.29, NKI 0.3.0). Uses nki.* namespace. +""" + +import nki +import nki.isa as nisa +import nki.language as nl + +P_MAX = 128 + + +@nki.jit +def deltanet_chunk_step( + query, # (128, 128) float32 -- one chunk, l2-normed+scaled + key, # (128, 128) float32 -- one chunk, l2-normed + value, # (128, 128) float32 -- one chunk + beta_broadcast, # (128, 128) float32 -- write gate broadcast to 128 + g_cumsum, # (128, 128) float32 -- cumsum of g within chunk, broadcast + g_last, # (128, 128) float32 -- g_cumsum[-1], constant in chunk, broadcast + state_in, # (128, 128) float32 -- recurrent state from previous chunk + lower_mask, # (128, 128) float32 -- strict lower triangular + identity, # (128, 128) float32 -- identity matrix + lower_mask_diag, # (128, 128) float32 -- lower tri with diagonal +): + """Process one chunk of DeltaNet. 
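+
+    Per-chunk math (same framework as nki_deltanet_fused.py):
+        A = -QK_decay * lower_mask
+        N = (I+A)(I+A^2)(I+A^4)...(I+A^64)   [6 power-doubling rounds]
+        v_new = N @ v_beta - (N @ (k_beta * exp(gc))) @ state_in
+        attn_intra = (q @ k^T) * decay_mask * lower_mask_diag
+        output = (q * exp(gc)) @ state_in + attn_intra @ v_new
+        state_out = exp(g_last) * (state_in + (k * exp(-gc))^T @ v_new)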
+ + Returns: + output: (128, 128) float32 -- chunk output + state_out: (128, 128) float32 -- updated recurrent state + """ + C, dim = query.shape # C = 128, dim = 128 + + # Output tensors in HBM + output = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.shared_hbm) + state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # Load all inputs into SBUF + q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=q_c, src=query) + + k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=k_c, src=key) + + v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=v_c, src=value) + + beta_c = nl.ndarray((P_MAX, dim), dtype=beta_broadcast.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=beta_c, src=beta_broadcast) + + gc_c = nl.ndarray((P_MAX, dim), dtype=g_cumsum.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=gc_c, src=g_cumsum) + + gl_c = nl.ndarray((P_MAX, dim), dtype=g_last.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=gl_c, src=g_last) + + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=state, src=state_in) + + # Load masks + eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=eye, src=identity) + + Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask, src=lower_mask) + + Lmask_d = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask_d, src=lower_mask_diag) + + # ============================================================ + # k_beta = K * beta, v_beta = V * beta + # ============================================================ + k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=k_beta, data1=k_c, data2=beta_c, op=nl.multiply) + + v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_beta, data1=v_c, data2=beta_c, op=nl.multiply) + + # ============================================================ + # exp(g_cumsum) and exp(-g_cumsum) + # ============================================================ + exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_gc, op=nl.exp, data=gc_c, bias=None, scale=1.0) + + neg_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_gc, + data=gc_c, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + exp_neg_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_neg_gc, op=nl.exp, data=neg_gc, bias=None, scale=1.0) + + # exp(g_last) for state decay + exp_gl = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_gl, op=nl.exp, data=gl_c, bias=None, scale=1.0) + + # ============================================================ + # Phase 1: Build A matrix (intra-chunk correction) + # QK = k_beta @ k^T -- contract over features + # ============================================================ + kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kb_T_psum, stationary=k_beta, moving=eye) + kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kb_T, src=kb_T_psum) + + k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=k_T_psum, stationary=k_c, moving=eye) + k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_T, 
src=k_T_psum) + + QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T) + QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK, src=QK_psum) + + # ============================================================ + # Decay mask: QK_decay[i,j] = QK[i,j] * exp(gc[i]) * exp(-gc[j]) + # ============================================================ + QK_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=QK_row, data1=QK, data2=exp_gc, op=nl.multiply) + + QK_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_r_T_psum, stationary=QK_row, moving=eye) + QK_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_r_T, src=QK_r_T_psum) + + QK_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=QK_r_T_col, data1=QK_r_T, data2=exp_neg_gc, op=nl.multiply) + + QK_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_d_psum, stationary=QK_r_T_col, moving=eye) + QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_decay, src=QK_d_psum) + + # A = -QK_decay * lower_mask + neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_QK_decay, + data=QK_decay, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + A = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=A, data1=neg_QK_decay, data2=Lmask, op=nl.multiply) + + # ============================================================ + # Neumann power-doubling: N = (I+A)(I+A^2)...(I+A^{64}) + # ============================================================ + P_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=P_acc, data1=eye, data2=A, op=nl.add) + + A_pow = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=A_pow, src=A) + + for _round in nl.sequential_range(6): + Ap_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Ap_T_psum, stationary=A_pow, moving=eye) + Ap_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=Ap_T, src=Ap_T_psum) + + Ap_sq_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Ap_sq_psum, stationary=Ap_T, moving=A_pow) + nisa.tensor_copy(dst=A_pow, src=Ap_sq_psum) + + IpA = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=IpA, data1=eye, data2=A_pow, op=nl.add) + + IpA_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=IpA_T_psum, stationary=IpA, moving=eye) + IpA_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=IpA_T, src=IpA_T_psum) + + Pacc_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Pacc_psum, stationary=IpA_T, moving=P_acc) + nisa.tensor_copy(dst=P_acc, src=Pacc_psum) + + # ============================================================ + # Apply N: value_corr = N @ v_beta, k_cumdecay = N @ (k_beta * exp_gc) + # ============================================================ + N_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=N_T_psum, stationary=P_acc, moving=eye) + N_T = nl.ndarray((P_MAX, P_MAX), 
dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=N_T, src=N_T_psum) + + vc_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vc_psum, stationary=N_T, moving=v_beta) + value_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=value_corr, src=vc_psum) + + kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=kb_exp_gc, data1=k_beta, data2=exp_gc, op=nl.multiply) + + kcd_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kcd_psum, stationary=N_T, moving=kb_exp_gc) + k_cumdecay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_cumdecay, src=kcd_psum) + + # ============================================================ + # Phase 2: Inter-chunk state propagation + # attn_intra = (q @ k^T) * decay_mask * lower_mask_diag + # ============================================================ + q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=q_T_psum, stationary=q_c, moving=eye) + q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=q_T, src=q_T_psum) + + qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T) + qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_raw, src=qk_psum) + + qk_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=qk_row, data1=qk_raw, data2=exp_gc, op=nl.multiply) + + qk_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_r_T_psum, stationary=qk_row, moving=eye) + qk_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_r_T, src=qk_r_T_psum) + + qk_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=qk_r_T_col, data1=qk_r_T, data2=exp_neg_gc, op=nl.multiply) + + qk_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_d_psum, stationary=qk_r_T_col, moving=eye) + qk_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_decay, src=qk_d_psum) + + attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=attn_intra, data1=qk_decay, data2=Lmask_d, op=nl.multiply) + + # ============================================================ + # v_prime = k_cumdecay @ state + # ============================================================ + kcd_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kcd_T_psum, stationary=k_cumdecay, moving=eye) + kcd_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kcd_T, src=kcd_T_psum) + + vp_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vp_psum, stationary=kcd_T, moving=state) + v_prime = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=v_prime, src=vp_psum) + + v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_new, data1=value_corr, data2=v_prime, op=nl.subtract) + + # ============================================================ + # attn_inter = (q * exp(g_cumsum)) @ state + # ============================================================ + q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, 
buffer=nl.sbuf) + nisa.tensor_tensor(dst=q_exp, data1=q_c, data2=exp_gc, op=nl.multiply) + + qe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qe_T_psum, stationary=q_exp, moving=eye) + qe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qe_T, src=qe_T_psum) + + ai_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state) + attn_inter = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=attn_inter, src=ai_psum) + + # ============================================================ + # attn_intra @ v_new + # ============================================================ + ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=ai_T_psum, stationary=attn_intra, moving=eye) + ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ai_T, src=ai_T_psum) + + intra_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=intra_psum, stationary=ai_T, moving=v_new) + intra_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=intra_out, src=intra_psum) + + # ============================================================ + # chunk_output = attn_inter + intra_out + # ============================================================ + chunk_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add) + + nisa.dma_copy(dst=output, src=chunk_out) + + # ============================================================ + # State update: state_new = exp(g_last) * (state + k_raw_decay^T @ v_new) + # ============================================================ + k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=k_raw_decay, data1=k_c, data2=exp_neg_gc, op=nl.multiply) + + kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_psum, stationary=k_raw_decay, moving=v_new) + kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_outer, src=kv_psum) + + state_plus = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_plus, data1=state, data2=kv_outer, op=nl.add) + + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_new, data1=state_plus, data2=exp_gl, op=nl.multiply) + + nisa.dma_copy(dst=state_out, src=state_new) + + return output, state_out diff --git a/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_fused.py b/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_fused.py new file mode 100644 index 00000000..3447a138 --- /dev/null +++ b/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_fused.py @@ -0,0 +1,577 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Fused single-kernel DeltaNet chunked forward for CTE (context encoding). + +SSD-style architecture: processes ALL chunks for one (batch, head) pair in +a single NKI kernel call. State (128x128) persists in SBUF across chunks — +no HBM round-trips for inter-chunk state propagation. + +Key optimizations over nki_deltanet_chunked.py: + 1. Single kernel call per (B,H) instead of B*H*num_chunks calls + 2. State in SBUF across all chunks (no HBM state read/write per chunk) + 3. 
In-kernel cumsum via tensor_tensor_scan (no PyTorch cumsum) + 4. Masks and constants loaded once, reused across chunks + 5. Uses tensor_scalar for partition-broadcast (no explicit broadcast loops) + 6. nc_transpose (Vector Engine) for all 128x128 transposes instead of + nc_matmul(moving=eye) (Tensor Engine) — frees TE for actual math + +NKI 0.3.0 (SDK 2.29). k_dim = v_dim = 128 = P_MAX exactly. +Chunk size = 128 = P_MAX (one tile per chunk). + +Mathematical framework (same as nki_deltanet_chunked.py): + Per-chunk Neumann-series power-doubling for intra-chunk correction: + A = -QK_decay * lower_mask + N = (I+A)(I+A^2)(I+A^4)...(I+A^64) [6 rounds] + value_corr = N @ v_beta + k_cumdecay = N @ (k_beta * exp(gc)) + + Inter-chunk state propagation: + v_prime = k_cumdecay @ state + v_new = value_corr - v_prime + attn_inter = (q * exp(gc)) @ state + attn_intra = (q @ k^T) * decay_mask * lower_mask_diag + output = attn_inter + attn_intra @ v_new + state = exp(g_last) * (state + k_raw_decay^T @ v_new) +""" + +import numpy as np + +import nki +import nki.isa as nisa +import nki.language as nl + +P_MAX = 128 # Partition dim = chunk_size = k_dim = v_dim +CHUNK_SIZE = 128 + +# Broadcast partition 0 to all partitions in a 32-wide group +_BROADCAST_MASK = [0] * 32 + + +def _make_lower_mask(): + """Strict lower triangular (128x128) as numpy constant.""" + return np.tril(np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=-1) + + +def _make_lower_mask_diag(): + """Lower triangular with diagonal (128x128) as numpy constant.""" + return np.tril(np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=0) + + +def _make_identity(): + """Identity matrix (128x128) as numpy constant.""" + return np.eye(CHUNK_SIZE, dtype=np.float32) + + +@nki.jit +def deltanet_fused_chunked_fwd( + query: nl.ndarray, # (S, 128) float32 — l2-normed and scaled + key: nl.ndarray, # (S, 128) float32 — l2-normed + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 1) float32 — per-token log-decay (NOT cumsum) + beta_in: nl.ndarray, # (S, 1) float32 — per-token write gate + lower_mask: nl.ndarray, # (128, 128) float32 — strict lower tri + identity: nl.ndarray, # (128, 128) float32 — identity + lower_mask_diag: nl.ndarray, # (128, 128) float32 — lower tri with diag +): + """Fused chunked DeltaNet forward — single kernel call per (batch, head). + + Processes all chunks sequentially within the kernel, keeping the recurrent + state (128x128) in SBUF across chunks. Returns per-token output and + final state. 
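+
+    Illustrative caller loop (a sketch; names and shapes are hypothetical):
+
+        masks = lower_mask, identity, lower_mask_diag  # from the _make_* helpers
+        for b in range(B):
+            for h in range(H):
+                out_bh, state_bh = deltanet_fused_chunked_fwd(
+                    q[b, h], k[b, h], v[b, h], g[b, h], beta[b, h], *masks
+                )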
+ + Input requirements: + - S must be divisible by 128 (pad before calling) + - query must be l2-normed and scaled by 1/sqrt(k_dim) + - key must be l2-normed + - g_in is RAW log-decay (cumsum computed in-kernel via tensor_tensor_scan) + - beta_in is sigmoid(b) (write gate) + + Returns: + output: (S, 128) float32 + final_state: (128, 128) float32 + """ + seq_len = query.shape[0] + dim = query.shape[1] # 128 + num_chunks = seq_len // CHUNK_SIZE + + # Output tensors in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + final_state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # ================================================================ + # Load constant masks into SBUF once (reused across all chunks) + # ================================================================ + eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=eye, src=identity) + + Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask, src=lower_mask) + + Lmask_d = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask_d, src=lower_mask_diag) + + # Ones vector for cumsum scan: (1, CHUNK_SIZE) + ones_1xC = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=ones_1xC, value=1.0) + + # Zero initial for cumsum scan + zero_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=zero_11, value=0.0) + + # ================================================================ + # Initialize recurrent state in SBUF — persists across ALL chunks + # ================================================================ + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state, value=0.0) + + # ================================================================ + # Sequential chunk processing + # ================================================================ + for i_chunk in nl.sequential_range(num_chunks): + chunk_start = i_chunk * CHUNK_SIZE + + # ---- Load chunk data from HBM ---- + q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_c, + src=query[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_c, + src=key[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_c, + src=value[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + # g: (CHUNK_SIZE, 1) — raw log-decay per token + g_chunk_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_chunk_p[0:CHUNK_SIZE, 0:1], + src=g_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + # beta: (CHUNK_SIZE, 1) — write gate scalar per token + beta_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_p[0:CHUNK_SIZE, 0:1], + src=beta_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + # ---- In-kernel cumsum of g via tensor_tensor_scan ---- + # Need g as (1, CHUNK_SIZE) for scan along free dim. 
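+        # tensor_tensor_scan recurses along the free (column) axis only,
+        # while g arrives with tokens on the partition axis, so a
+        # transpose round-trip is required before and after the scan.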
+ # Transpose: (CHUNK_SIZE, 1) -> (1, CHUNK_SIZE) via nc_transpose + g_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=g_padded, value=0.0) + nisa.tensor_copy( + dst=g_padded[0:CHUNK_SIZE, 0:1], + src=g_chunk_p[0:CHUNK_SIZE, 0:1], + ) + + g_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=g_tp_psum, data=g_padded) + + g_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=g_row[0:1, 0:CHUNK_SIZE], + src=g_tp_psum[0:1, 0:CHUNK_SIZE], + ) + + # cumsum: gc_row[t] = 1.0 * gc_row[t-1] + g_row[t] + gc_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor_scan( + dst=gc_row[0:1, 0:CHUNK_SIZE], + data0=ones_1xC[0:1, 0:CHUNK_SIZE], + data1=g_row[0:1, 0:CHUNK_SIZE], + initial=zero_11[0:1, 0:1], + op0=nl.multiply, + op1=nl.add, + ) + + # Transpose gc back to (CHUNK_SIZE, 1) partition layout + gc_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=gc_padded, value=0.0) + nisa.tensor_copy( + dst=gc_padded[0:1, 0:CHUNK_SIZE], + src=gc_row[0:1, 0:CHUNK_SIZE], + ) + + gc_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=gc_tp_psum, data=gc_padded) + + # gc_p: (P_MAX, 1) — cumulative sum of g per token in this chunk + gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gc_p[0:CHUNK_SIZE, 0:1], + src=gc_tp_psum[0:CHUNK_SIZE, 0:1], + ) + + # g_last = gc[-1] (scalar) — needed for state decay + gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gl_11[0:1, 0:1], + src=gc_row[0:1, CHUNK_SIZE - 1 : CHUNK_SIZE], + ) + + # ---- Compute exp(gc), exp(-gc), exp(g_last) as (P_MAX, 1) scalars ---- + # These (P_MAX, 1) tensors are used with tensor_scalar to broadcast + # across the free dimension without explicit (P_MAX, dim) copies. 
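+        # Reference semantics of the next few ops (an illustrative sketch;
+        # these are not kernel variable names):
+        #   exp_gc     = exp(gc)      per-token cumulative decay
+        #   exp_neg_gc = exp(-gc)     its reciprocal, for column scaling
+        #   exp_g_last = exp(gc[-1])  whole-chunk decay for the state update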
+ + exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_gc_p, + data=gc_p, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + exp_neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_neg_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=neg_gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + # exp(g_last): scalar, then broadcast to (P_MAX, 1) + exp_gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_11, + op=nl.exp, + data=gl_11, + bias=None, + scale=1.0, + ) + + exp_gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=exp_gl_11[0:1, 0:1], + dst=exp_gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + # ============================================================ + # k_beta = K * beta, v_beta = V * beta + # tensor_scalar broadcasts beta_p (P_MAX, 1) across free dim + # ============================================================ + k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_beta, + data=k_c, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=v_beta, + data=v_c, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + # ============================================================ + # Phase 1: Build A matrix (intra-chunk correction) + # Transpose K and K_beta for matmul + # ============================================================ + kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kb_T_psum, data=k_beta) + kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kb_T, src=kb_T_psum) + + k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=k_T_psum, data=k_c) + k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_T, src=k_T_psum) + + # QK = k_beta^T @ k (contract over features) + QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T) + QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK, src=QK_psum) + + # ============================================================ + # Decay mask: QK_decay[i,j] = QK[i,j] * exp(gc[i]) * exp(-gc[j]) + # + # Row scaling: QK_row[i,:] = QK[i,:] * exp(gc[i]) + # Then transpose, column scale, transpose back. + # Uses tensor_scalar with (P_MAX,1) operand for row scaling. 
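+        # The identity used here: the decay mask is
+        #   D[i, j] = exp(gc[i] - gc[j]) = exp(gc[i]) * exp(-gc[j]),
+        # so one row scale by exp(gc) plus one column scale by exp(-gc)
+        # reproduces D without materializing an outer product of the gates.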
+ # ============================================================ + QK_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=QK_row, + data=QK, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + # Transpose to scale columns (now rows in transposed view) + QK_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=QK_r_T_psum, data=QK_row) + QK_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_r_T, src=QK_r_T_psum) + + QK_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=QK_r_T_col, + data=QK_r_T, + op0=nl.multiply, + operand0=exp_neg_gc_p, + engine=nisa.vector_engine, + ) + + # Transpose back + QK_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=QK_d_psum, data=QK_r_T_col) + QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_decay, src=QK_d_psum) + + # A = -QK_decay * lower_mask + neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_QK_decay, + data=QK_decay, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + A_mat = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=A_mat, data1=neg_QK_decay, data2=Lmask, op=nl.multiply) + + # ============================================================ + # Neumann power-doubling: N = (I+A)(I+A^2)...(I+A^{64}) + # 6 rounds → resolves rank up to 2^6 = 64 (sufficient for chunk=128) + # ============================================================ + P_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=P_acc, data1=eye, data2=A_mat, op=nl.add) + + A_pow = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=A_pow, src=A_mat) + + for _round in nl.sequential_range(6): + # A_pow = A_pow^2: transpose A_pow, then matmul + Ap_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=Ap_T_psum, data=A_pow) + Ap_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=Ap_T, src=Ap_T_psum) + + Ap_sq_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Ap_sq_psum, stationary=Ap_T, moving=A_pow) + nisa.tensor_copy(dst=A_pow, src=Ap_sq_psum) + + # P_acc = (I + A_pow) @ P_acc: transpose IpA, then matmul + IpA = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=IpA, data1=eye, data2=A_pow, op=nl.add) + + IpA_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=IpA_T_psum, data=IpA) + IpA_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=IpA_T, src=IpA_T_psum) + + Pacc_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Pacc_psum, stationary=IpA_T, moving=P_acc) + nisa.tensor_copy(dst=P_acc, src=Pacc_psum) + + # ============================================================ + # Apply N: value_corr = N @ v_beta + # k_cumdecay = N @ (k_beta * exp(gc)) + # ============================================================ + N_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=N_T_psum, data=P_acc) + N_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=N_T, src=N_T_psum) + + vc_psum = 
nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vc_psum, stationary=N_T, moving=v_beta) + value_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=value_corr, src=vc_psum) + + # k_beta * exp(gc): row-scaled + kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=kb_exp_gc, + data=k_beta, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + kcd_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kcd_psum, stationary=N_T, moving=kb_exp_gc) + k_cumdecay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_cumdecay, src=kcd_psum) + + # ============================================================ + # Phase 2: Inter-chunk state propagation + # attn_intra = (q @ k^T) * decay_mask * lower_mask_diag + # ============================================================ + q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=q_T_psum, data=q_c) + q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=q_T, src=q_T_psum) + + qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T) + qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_raw, src=qk_psum) + + # Row-scale by exp(gc) + qk_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=qk_row, + data=qk_raw, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + # Transpose, column-scale by exp(-gc), transpose back + qk_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=qk_r_T_psum, data=qk_row) + qk_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_r_T, src=qk_r_T_psum) + + qk_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=qk_r_T_col, + data=qk_r_T, + op0=nl.multiply, + operand0=exp_neg_gc_p, + engine=nisa.vector_engine, + ) + + qk_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=qk_d_psum, data=qk_r_T_col) + qk_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_decay, src=qk_d_psum) + + attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=attn_intra, data1=qk_decay, data2=Lmask_d, op=nl.multiply + ) + + # ============================================================ + # v_prime = k_cumdecay @ state (state is in SBUF!) 
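+        # v_prime is the part of each corrected value that the running
+        # state can already reconstruct; subtracting it below leaves only
+        # the new information (the delta) to be written into the state.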
+ # ============================================================ + kcd_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kcd_T_psum, data=k_cumdecay) + kcd_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kcd_T, src=kcd_T_psum) + + vp_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vp_psum, stationary=kcd_T, moving=state) + v_prime = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=v_prime, src=vp_psum) + + v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_new, data1=value_corr, data2=v_prime, op=nl.subtract) + + # ============================================================ + # attn_inter = (q * exp(gc)) @ state (state is in SBUF!) + # ============================================================ + q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_exp, + data=q_c, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + qe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=qe_T_psum, data=q_exp) + qe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qe_T, src=qe_T_psum) + + ai_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state) + attn_inter = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=attn_inter, src=ai_psum) + + # ============================================================ + # attn_intra @ v_new + # ============================================================ + ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=ai_T_psum, data=attn_intra) + ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ai_T, src=ai_T_psum) + + intra_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=intra_psum, stationary=ai_T, moving=v_new) + intra_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=intra_out, src=intra_psum) + + # ============================================================ + # chunk_output = attn_inter + intra_out + # ============================================================ + chunk_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add) + + # Store output chunk to HBM + nisa.dma_copy( + dst=output[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + src=chunk_out, + ) + + # ============================================================ + # State update: state = exp(g_last) * (state + k_raw_decay^T @ v_new) + # state is updated IN-PLACE in SBUF — no HBM round-trip! 
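+        # exp(-gc) in k_raw_decay rebases each key to the chunk end, so the
+        # single exp(g_last) factor applies the correct per-token decay:
+        #   exp(g_last) * exp(-gc[t]) = exp(g_last - gc[t]).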
+ # ============================================================ + + # k_raw_decay = k * exp(-gc) + k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_raw_decay, + data=k_c, + op0=nl.multiply, + operand0=exp_neg_gc_p, + engine=nisa.vector_engine, + ) + + # k_raw_decay^T @ v_new → (dim, dim) outer product sum + # nc_matmul: result[M,N] = sum_K stationary[K,M] * moving[K,N] + # stationary=k_raw_decay (P_MAX, dim), moving=v_new (P_MAX, dim) + # Result: sum over tokens -> (dim, dim) + kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_psum, stationary=k_raw_decay, moving=v_new) + kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_outer, src=kv_psum) + + # state = state + kv_outer + state_plus = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_plus, data1=state, data2=kv_outer, op=nl.add) + + # state = state_plus * exp(g_last) + # tensor_scalar broadcasts exp_gl_p (P_MAX, 1) across free dim + nisa.tensor_scalar( + dst=state, + data=state_plus, + op0=nl.multiply, + operand0=exp_gl_p, + engine=nisa.vector_engine, + ) + + # ---- Write final state to HBM ---- + nisa.dma_copy(dst=final_state_out, src=state) + + return output, final_state_out diff --git a/contrib/models/Qwen3.5-4B/test/__init__.py b/contrib/models/Qwen3.5-4B/test/__init__.py new file mode 100644 index 00000000..04f8b7b7 --- /dev/null +++ b/contrib/models/Qwen3.5-4B/test/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/contrib/models/Qwen3.5-4B/test/integration/__init__.py b/contrib/models/Qwen3.5-4B/test/integration/__init__.py new file mode 100644 index 00000000..04f8b7b7 --- /dev/null +++ b/contrib/models/Qwen3.5-4B/test/integration/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/contrib/models/Qwen3.5-4B/test/integration/test_model.py b/contrib/models/Qwen3.5-4B/test/integration/test_model.py new file mode 100644 index 00000000..3bdafc3e --- /dev/null +++ b/contrib/models/Qwen3.5-4B/test/integration/test_model.py @@ -0,0 +1,482 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Integration tests for Qwen3.5-4B on Neuron. + +Tests compilation, loading, inference accuracy, and performance using +the full 4B model with pre-downloaded HuggingFace weights on a trn2 instance. + +These tests use the same Qwen35* classes and QWEN35_* env vars because the +underlying `qwen3_5` dense hybrid architecture is shared across Qwen3.5/3.6. + +Note: A mini model option is not provided because DeltaNet layers require NKI +kernels that only execute on Neuron devices, and the hybrid DeltaNet + GQA +architecture should be validated at TP=4 before attempting TP=2. 
+ +Environment variables: + QWEN35_MODEL_PATH Path to HF model weights (required) + QWEN35_COMPILED_PATH Path to compiled artifacts (default: /tmp/qwen35_4b_traced) + QWEN35_TP_DEGREE Tensor parallelism degree (default: 4) + QWEN35_SEQ_LEN Max sequence length (default: 128) + TTFT_THRESHOLD_MS Max TTFT in ms (default: 5000) + THROUGHPUT_THRESHOLD Min throughput in tok/s (default: 5.0) + +Prerequisites: + - trn2.3xlarge or larger with TP >= 4 NeuronCores available + - NXDI installed (neuronx_distributed_inference) + - HuggingFace weights downloaded to QWEN35_MODEL_PATH + - SDK 2.29+ (NKI 0.3.0 required for DeltaNet kernels) + +Usage: + # Full model (trn2.3xlarge, TP=4): + QWEN35_MODEL_PATH=/mnt/models/Qwen3.5-4B \\ + QWEN35_COMPILED_PATH=/mnt/models/qwen35_4b_traced \\ + pytest test/integration/test_model.py --capture=tee-sys +""" + +import gc +import os +import sys +import time + +import pytest +import torch + +# Ensure the contrib root (Qwen3.5-4B/) is on sys.path +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +# ── Configuration from environment ────────────────────────────────────── + +MODEL_PATH = os.environ.get("QWEN35_MODEL_PATH", "") +COMPILED_PATH = os.environ.get("QWEN35_COMPILED_PATH", "/tmp/qwen35_4b_traced") +TP_DEGREE = int(os.environ.get("QWEN35_TP_DEGREE", "4")) +SEQ_LEN = int(os.environ.get("QWEN35_SEQ_LEN", "128")) +TTFT_THRESHOLD_MS = float(os.environ.get("TTFT_THRESHOLD_MS", "5000")) +THROUGHPUT_THRESHOLD = float(os.environ.get("THROUGHPUT_THRESHOLD", "5.0")) + +requires_model_path = pytest.mark.skipif( + not MODEL_PATH, + reason=( + "QWEN35_MODEL_PATH not set. Integration tests require the full 4B model " + "weights. Set QWEN35_MODEL_PATH=/path/to/Qwen3.5-4B to run these tests." 
+ ), +) + + +# ── Fixtures ──────────────────────────────────────────────────────────── + + +@pytest.fixture(scope="module") +def model_path(): + """Return path to model weights.""" + return MODEL_PATH + + +@pytest.fixture(scope="module") +def compiled_model(model_path): + """Compile and load the model on Neuron.""" + import json + + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM + + neuron_config = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=SEQ_LEN, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + enable_bucketing=False, + flash_decoding_enabled=False, + logical_nc_config=2, + save_sharded_checkpoint=True, + ) + + # Read config.json directly (model_type 'qwen3_5' may not be in + # AutoConfig registry for all transformers versions) + with open(os.path.join(model_path, "config.json")) as f: + full_config = json.load(f) + text_config = full_config.get("text_config", full_config) + + config_dict = dict(text_config) + config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) + config_dict["tie_word_embeddings"] = full_config.get( + "tie_word_embeddings", + text_config.get("tie_word_embeddings", False), + ) + if "rope_parameters" in text_config: + config_dict["rope_theta"] = text_config["rope_parameters"].get( + "rope_theta", 10000000 + ) + + inf_config = Qwen35InferenceConfig( + neuron_config=neuron_config, + **config_dict, + ) + + # Compile if no existing artifacts + compiled_path = COMPILED_PATH + neff_path = os.path.join(compiled_path, "model.pt") + if not os.path.exists(neff_path): + print(f"Compiling to {compiled_path}...") + model = NeuronQwen35ForCausalLM(model_path, inf_config) + model.compile(compiled_path) + del model + gc.collect() + + # Load + print(f"Loading from {compiled_path}...") + model = NeuronQwen35ForCausalLM(compiled_path) + model.load(compiled_path) + return model + + +@pytest.fixture(scope="module") +def tokenizer(model_path): + """Load tokenizer.""" + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(model_path, padding_side="right") + if tok.pad_token is None: + tok.pad_token = tok.eos_token + return tok + + +@pytest.fixture(scope="module") +def generation_config(tokenizer): + """Create generation config.""" + from transformers import GenerationConfig + + return GenerationConfig( + do_sample=True, + top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + +def _generate(model, tokenizer, generation_config, prompt, max_new_tokens=20): + """Generate text using the NXDI model.""" + import transformers + + from neuronx_distributed_inference.utils.hf_adapter import ( + HuggingFaceGenerationAdapter, + ) + + inputs = tokenizer(prompt, padding=True, return_tensors="pt") + gen_model = HuggingFaceGenerationAdapter(model) + gen_model.generation_config.transformers_version = transformers.__version__ + generation_config.transformers_version = transformers.__version__ + outputs = gen_model.generate( + inputs.input_ids, + generation_config=generation_config, + attention_mask=inputs.attention_mask, + max_new_tokens=max_new_tokens, + ) + return outputs[0].tolist(), tokenizer.decode(outputs[0], skip_special_tokens=True) + + +def _is_repetitive(text, max_repeat=5): + """Check for excessive word repetition.""" + words = text.split() + if len(words) < max_repeat: + return 
False + for i in range(len(words) - max_repeat + 1): + if len(set(words[i : i + max_repeat])) == 1: + return True + return False + + +# ── Smoke Tests ───────────────────────────────────────────────────────── + + +@requires_model_path +def test_model_loads(compiled_model): + """Model compiles and loads successfully.""" + assert compiled_model is not None + assert hasattr(compiled_model, "neuron_config") + print(" Model loaded successfully") + + +@requires_model_path +def test_model_generates(compiled_model, tokenizer, generation_config): + """Model generates at least 5 tokens.""" + tokens, text = _generate( + compiled_model, + tokenizer, + generation_config, + "Hello, I am a language model", + max_new_tokens=20, + ) + input_len = len(tokenizer.encode("Hello, I am a language model")) + new_tokens = len(tokens) - input_len + assert new_tokens >= 5, f"Expected >= 5 new tokens, got {new_tokens}" + print(f" Generated {new_tokens} tokens: {text[:100]}...") + + +# ── Accuracy Tests ────────────────────────────────────────────────────── + + +@requires_model_path +def test_output_coherence(compiled_model, tokenizer, generation_config): + """Output should contain multiple words and not be excessively repetitive.""" + _, text = _generate( + compiled_model, + tokenizer, + generation_config, + "The capital of France is", + max_new_tokens=30, + ) + generated = text[len("The capital of France is") :].strip() + words = generated.split() + assert len(words) >= 3, f"Expected >= 3 words, got {len(words)}: '{generated}'" + assert not _is_repetitive(generated), ( + f"Output is excessively repetitive: '{generated}'" + ) + print(f" Output coherent ({len(words)} words): {generated[:80]}...") + + +@requires_model_path +def test_top_token_valid(compiled_model, tokenizer, generation_config): + """First generated token should be a valid decodable token.""" + tokens, _ = _generate( + compiled_model, + tokenizer, + generation_config, + "Hello!", + max_new_tokens=1, + ) + input_len = len(tokenizer.encode("Hello!")) + first_new = tokens[input_len] + assert 0 <= first_new < tokenizer.vocab_size, ( + f"Token {first_new} out of vocab range" + ) + decoded = tokenizer.decode([first_new]) + assert len(decoded) > 0, f"Token {first_new} decoded to empty string" + print(f" First token: {first_new} -> '{decoded}'") + + +@requires_model_path +def test_capital_of_france(compiled_model, tokenizer, generation_config): + """'The capital of France is' should produce 'Paris' in the response.""" + tokens, text = _generate( + compiled_model, + tokenizer, + generation_config, + "The capital of France is", + max_new_tokens=30, + ) + generated = text[len("The capital of France is") :].strip() + assert "paris" in generated.lower(), ( + f"Expected 'Paris' in output, got: '{generated}'" + ) + print(f" Capital of France: {generated}") + + +# ── Performance Tests ─────────────────────────────────────────────────── + + +@requires_model_path +def test_performance_ttft(compiled_model, tokenizer, generation_config): + """Time to first token should be within threshold.""" + prompt = "Hello, I am a language model" + + # Warmup + _generate(compiled_model, tokenizer, generation_config, prompt, max_new_tokens=1) + + # Measure + times = [] + for _ in range(3): + t0 = time.perf_counter() + _generate( + compiled_model, tokenizer, generation_config, prompt, max_new_tokens=1 + ) + times.append((time.perf_counter() - t0) * 1000) + + avg_ms = sum(times) / len(times) + print(f" TTFT: {avg_ms:.1f} ms (threshold: {TTFT_THRESHOLD_MS} ms)") + assert avg_ms < 
TTFT_THRESHOLD_MS, ( + f"TTFT {avg_ms:.1f}ms > threshold {TTFT_THRESHOLD_MS}ms" + ) + + +@requires_model_path +def test_performance_throughput(compiled_model, tokenizer, generation_config): + """Throughput should meet minimum threshold.""" + prompt = "Once upon a time" + num_new_tokens = 20 + + # Warmup + _generate(compiled_model, tokenizer, generation_config, prompt, max_new_tokens=5) + + # Measure + t0 = time.perf_counter() + tokens, _ = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=num_new_tokens, + ) + elapsed = time.perf_counter() - t0 + + input_len = len(tokenizer.encode(prompt)) + actual_new = len(tokens) - input_len + throughput = actual_new / elapsed if elapsed > 0 else 0 + + print( + f" Throughput: {throughput:.1f} tok/s ({actual_new} tokens in {elapsed:.2f}s)" + ) + print(f" Threshold: {THROUGHPUT_THRESHOLD} tok/s") + assert throughput > THROUGHPUT_THRESHOLD, ( + f"Throughput {throughput:.1f} tok/s < threshold {THROUGHPUT_THRESHOLD}" + ) + + +# ── Multi-Prompt Quality Test ────────────────────────────────────────── + + +@requires_model_path +def test_multi_prompt_generation(compiled_model, tokenizer, generation_config): + """Multiple prompts should produce coherent outputs.""" + prompts = [ + "The capital of France is", + "def fibonacci(n):", + "The largest ocean on Earth is", + "To make a chocolate cake, you need", + ] + + for prompt in prompts: + _, text = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=30, + ) + generated = text[len(prompt) :].strip() + words = generated.split() + assert len(words) >= 2, ( + f"Prompt '{prompt}' generated too few words: '{generated}'" + ) + assert not _is_repetitive(generated), ( + f"Prompt '{prompt}' produced repetitive output: '{generated}'" + ) + print(f" '{prompt[:30]}...' -> {generated[:60]}...") + + +# ── Standalone runner ─────────────────────────────────────────────────── + +if __name__ == "__main__": + print("=" * 60) + print("Qwen3.5-4B Integration Tests") + print("=" * 60) + + if not MODEL_PATH: + print("\nQWEN35_MODEL_PATH not set. 
Provide the model path to run tests:") + print(" QWEN35_MODEL_PATH=/path/to/Qwen3.5-4B \\") + print(" QWEN35_COMPILED_PATH=/mnt/models/qwen35_9b_traced \\") + print(" python -m pytest test/integration/test_model.py --capture=tee-sys") + sys.exit(0) + + # Setup + from transformers import AutoTokenizer, GenerationConfig as GenConfig + + tok = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right") + if tok.pad_token is None: + tok.pad_token = tok.eos_token + gen_cfg = GenConfig( + do_sample=True, + top_k=1, + pad_token_id=tok.pad_token_id, + eos_token_id=tok.eos_token_id, + ) + + # Build model + import json + + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM + + nc = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=SEQ_LEN, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + enable_bucketing=False, + flash_decoding_enabled=False, + logical_nc_config=2, + save_sharded_checkpoint=True, + ) + + with open(os.path.join(MODEL_PATH, "config.json")) as f: + full_config = json.load(f) + text_config = full_config.get("text_config", full_config) + config_dict = dict(text_config) + config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) + config_dict["tie_word_embeddings"] = full_config.get( + "tie_word_embeddings", + text_config.get("tie_word_embeddings", False), + ) + if "rope_parameters" in text_config: + config_dict["rope_theta"] = text_config["rope_parameters"].get( + "rope_theta", 10000000 + ) + ic = Qwen35InferenceConfig(neuron_config=nc, **config_dict) + + cp = COMPILED_PATH + if not os.path.exists(os.path.join(cp, "model.pt")): + print(f"Compiling to {cp}...") + m = NeuronQwen35ForCausalLM(MODEL_PATH, ic) + m.compile(cp) + del m + gc.collect() + + print(f"Loading from {cp}...") + model = NeuronQwen35ForCausalLM(cp) + model.load(cp) + + tests = [ + ("model_loads", lambda: test_model_loads(model)), + ("model_generates", lambda: test_model_generates(model, tok, gen_cfg)), + ("output_coherence", lambda: test_output_coherence(model, tok, gen_cfg)), + ("top_token_valid", lambda: test_top_token_valid(model, tok, gen_cfg)), + ("capital_of_france", lambda: test_capital_of_france(model, tok, gen_cfg)), + ("performance_ttft", lambda: test_performance_ttft(model, tok, gen_cfg)), + ( + "performance_throughput", + lambda: test_performance_throughput(model, tok, gen_cfg), + ), + ( + "multi_prompt_generation", + lambda: test_multi_prompt_generation(model, tok, gen_cfg), + ), + ] + + passed = 0 + for name, fn in tests: + print(f"\n--- {name} ---") + try: + fn() + print(f" PASS") + passed += 1 + except Exception as e: + print(f" FAIL: {e}") + + print(f"\n{'=' * 60}") + print(f"Results: {passed}/{len(tests)} passed") + print(f"{'=' * 60}") diff --git a/contrib/models/Qwen3.5-4B/test/parity/deltanet_path_probe.py b/contrib/models/Qwen3.5-4B/test/parity/deltanet_path_probe.py new file mode 100644 index 00000000..e3cb9f11 --- /dev/null +++ b/contrib/models/Qwen3.5-4B/test/parity/deltanet_path_probe.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""DeltaNet path parity probe for Qwen3.5-4B. + +Run on a Neuron instance after weights are available. 
This intentionally is not +part of normal pytest collection because it can compile NKI kernels and requires +the full checkpoint. + +Example: + cd contrib/models/Qwen3.5-4B + QWEN35_MODEL_PATH=/mnt/models/Qwen3.5-4B \\ + python test/parity/deltanet_path_probe.py --layer-idx 0 --seq-len 128 +""" + +import argparse +import json +import os +import sys +from contextlib import contextmanager + +import torch +import torch.nn.functional as F + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from neuronx_distributed_inference.models.config import NeuronConfig +from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronGatedDeltaNet + + +@contextmanager +def patched_env(**updates): + old = {k: os.environ.get(k) for k in updates} + for k, v in updates.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = str(v) + try: + yield + finally: + for k, v in old.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v + + +def cosine(a, b): + return F.cosine_similarity(a.float().flatten(), b.float().flatten(), dim=0).item() + + +def max_abs(a, b): + return (a.float() - b.float()).abs().max().item() + + +def load_config(model_path, tp_degree): + with open(os.path.join(model_path, "config.json")) as f: + full_config = json.load(f) + text_config = full_config.get("text_config", full_config) + config_dict = dict(text_config) + config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) + config_dict["tie_word_embeddings"] = full_config.get( + "tie_word_embeddings", + text_config.get("tie_word_embeddings", False), + ) + if "rope_parameters" in text_config: + config_dict["rope_theta"] = text_config["rope_parameters"].get( + "rope_theta", 10000000 + ) + + neuron_config = NeuronConfig( + tp_degree=tp_degree, + batch_size=1, + max_batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + enable_bucketing=False, + flash_decoding_enabled=False, + logical_nc_config=2, + ) + return Qwen35InferenceConfig(neuron_config=neuron_config, **config_dict) + + +def strip_prefix_state_dict(state_dict): + stripped = {} + for k, v in state_dict.items(): + if k.startswith("language_model."): + stripped[k.replace("language_model.", "", 1)] = v + elif k.startswith("model.language_model."): + stripped[k.replace("model.language_model.", "", 1)] = v + elif k.startswith("model."): + stripped[k.replace("model.", "", 1)] = v + else: + stripped[k] = v + return stripped + + +def load_deltanet_layer_weights(module, model_path, layer_idx): + from transformers import AutoModelForCausalLM + + hf_model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + low_cpu_mem_usage=True, + ) + state_dict = strip_prefix_state_dict(hf_model.state_dict()) + prefix = f"layers.{layer_idx}.linear_attn." 
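+    # Copy only the keys this module declares; the side-channel state
+    # buffers (*_buffer) are created by the module itself and are filtered
+    # out of the missing-key check below.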
+ layer_sd = {} + for name in module.state_dict().keys(): + key = prefix + name + if key in state_dict: + layer_sd[name] = state_dict[key] + + missing, unexpected = module.load_state_dict(layer_sd, strict=False) + missing = [m for m in missing if not m.endswith("_buffer")] + if missing or unexpected: + raise RuntimeError(f"weight load mismatch: missing={missing}, unexpected={unexpected}") + del hf_model + + +def run_path(module, hidden_states, mode): + env = { + "USE_NKI_FUSED": "0", + "USE_NKI_CHUNKED": None, + "USE_NKI": None, + "DELTANET_SEQUENTIAL": None, + "USE_PYTORCH_CHUNK": None, + } + if mode == "sequential": + env["DELTANET_SEQUENTIAL"] = "1" + elif mode == "fused": + env["USE_NKI_FUSED"] = "1" + elif mode == "chunk": + env["USE_PYTORCH_CHUNK"] = "1" + elif mode == "nki_recurrent": + env["USE_NKI"] = "1" + else: + raise ValueError(f"unknown mode: {mode}") + + with patched_env(**env): + with torch.no_grad(): + out, _dummy_kv, rec_state, conv_state = module(hidden_states) + return out.detach().cpu(), rec_state.detach().cpu(), conv_state.detach().cpu() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", default=os.environ.get("QWEN35_MODEL_PATH")) + parser.add_argument("--layer-idx", type=int, default=0) + parser.add_argument("--seq-len", type=int, default=128) + parser.add_argument("--tp-degree", type=int, default=4) + parser.add_argument( + "--compare", + nargs="+", + default=["fused"], + choices=["fused", "chunk", "nki_recurrent"], + ) + parser.add_argument("--device", default="cpu", choices=["cpu", "xla"]) + args = parser.parse_args() + + if not args.model_path: + raise SystemExit("Set QWEN35_MODEL_PATH or pass --model-path") + + device = torch.device("cpu") + if args.device == "xla": + import torch_xla.core.xla_model as xm + + device = xm.xla_device() + + config = load_config(args.model_path, args.tp_degree) + if config.layer_types[args.layer_idx] != "linear_attention": + raise SystemExit(f"layer {args.layer_idx} is {config.layer_types[args.layer_idx]}, not DeltaNet") + + module = NeuronGatedDeltaNet(config, args.layer_idx).to(device) + load_deltanet_layer_weights(module, args.model_path, args.layer_idx) + module = module.to(device=device, dtype=torch.bfloat16).eval() + + torch.manual_seed(0) + hidden_states = torch.randn( + 1, + args.seq_len, + config.hidden_size, + dtype=torch.bfloat16, + device=device, + ) + + ref = run_path(module, hidden_states, "sequential") + print(f"reference=sequential layer={args.layer_idx} seq_len={args.seq_len}") + for mode in args.compare: + cur = run_path(module, hidden_states, mode) + print(f"\nmode={mode}") + for label, ref_t, cur_t in zip(("output", "recurrent_state", "conv_state"), ref, cur): + print( + f"{label}: cosine={cosine(ref_t, cur_t):.6f} " + f"max_abs={max_abs(ref_t, cur_t):.6f} " + f"shape={tuple(cur_t.shape)}" + ) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen3.5-4B/test/unit/__init__.py b/contrib/models/Qwen3.5-4B/test/unit/__init__.py new file mode 100644 index 00000000..04f8b7b7 --- /dev/null +++ b/contrib/models/Qwen3.5-4B/test/unit/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/contrib/models/Qwen3.5-4B/test/unit/test_config.py b/contrib/models/Qwen3.5-4B/test/unit/test_config.py new file mode 100644 index 00000000..5e672c13 --- /dev/null +++ b/contrib/models/Qwen3.5-4B/test/unit/test_config.py @@ -0,0 +1,201 @@ +# Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for Qwen3.5-4B inference configuration. + +CPU-only tests that validate config parsing, layer type setup, +DeltaNet parameter defaults, RoPE configuration, and weight conversion logic. +""" + +import os +import sys +import unittest +from unittest.mock import MagicMock + +import torch + +# Ensure the contrib root (Qwen3.5-4B/) is on sys.path +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from src.modeling_qwen35 import ( + Qwen35InferenceConfig, + convert_qwen35_hf_to_neuron_state_dict, +) +from neuronx_distributed_inference.models.config import NeuronConfig + + +def _make_config(**overrides): + """Create a Qwen35InferenceConfig with reasonable defaults.""" + neuron_config = NeuronConfig( + tp_degree=overrides.pop("tp_degree", 4), + batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + ) + defaults = dict( + hidden_size=2560, + num_hidden_layers=32, + num_attention_heads=16, + num_key_value_heads=4, + head_dim=256, + intermediate_size=9216, + vocab_size=248320, + rms_norm_eps=1e-6, + max_position_embeddings=262144, + rope_theta=10000000, + hidden_act="silu", + tie_word_embeddings=True, + # DeltaNet-specific + linear_num_value_heads=32, + linear_num_key_heads=16, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_conv_kernel_dim=4, + ) + defaults.update(overrides) + config = Qwen35InferenceConfig(neuron_config=neuron_config, **defaults) + return config + + +class TestConfigParsing(unittest.TestCase): + """Test basic config attribute initialization.""" + + def test_hidden_size(self): + config = _make_config() + self.assertEqual(config.hidden_size, 2560) + + def test_num_hidden_layers(self): + config = _make_config() + self.assertEqual(config.num_hidden_layers, 32) + + def test_num_attention_heads(self): + config = _make_config() + self.assertEqual(config.num_attention_heads, 16) + + def test_num_key_value_heads(self): + config = _make_config() + self.assertEqual(config.num_key_value_heads, 4) + + def test_head_dim(self): + config = _make_config() + self.assertEqual(config.head_dim, 256) + + def test_intermediate_size(self): + config = _make_config() + self.assertEqual(config.intermediate_size, 9216) + + def test_vocab_size(self): + config = _make_config() + self.assertEqual(config.vocab_size, 248320) + + def test_hidden_act(self): + config = _make_config() + self.assertEqual(config.hidden_act, "silu") + + +class TestLayerTypes(unittest.TestCase): + """Test hybrid layer type assignment (3 DeltaNet + 1 GQA) x 8.""" + + def test_layer_types_length(self): + config = _make_config() + self.assertEqual(len(config.layer_types), 32) + + def test_layer_types_pattern(self): + """Every 4th layer (3, 7, 11, ...) 
should be full_attention.""" + config = _make_config() + for i in range(32): + expected = "full_attention" if i % 4 == 3 else "linear_attention" + self.assertEqual(config.layer_types[i], expected, f"Layer {i} mismatch") + + def test_deltanet_layer_count(self): + config = _make_config() + dn_count = sum(1 for t in config.layer_types if t == "linear_attention") + self.assertEqual(dn_count, 24) + + def test_gqa_layer_count(self): + config = _make_config() + gqa_count = sum(1 for t in config.layer_types if t == "full_attention") + self.assertEqual(gqa_count, 8) + + +class TestDeltaNetConfig(unittest.TestCase): + """Test DeltaNet-specific configuration defaults.""" + + def test_linear_num_value_heads(self): + config = _make_config() + self.assertEqual(config.linear_num_value_heads, 32) + + def test_linear_num_key_heads(self): + config = _make_config() + self.assertEqual(config.linear_num_key_heads, 16) + + def test_linear_key_head_dim(self): + config = _make_config() + self.assertEqual(config.linear_key_head_dim, 128) + + def test_linear_value_head_dim(self): + config = _make_config() + self.assertEqual(config.linear_value_head_dim, 128) + + def test_linear_conv_kernel_dim(self): + config = _make_config() + self.assertEqual(config.linear_conv_kernel_dim, 4) + + +class TestRoPEConfig(unittest.TestCase): + """Test partial RoPE configuration.""" + + def test_partial_rotary_factor(self): + config = _make_config() + self.assertAlmostEqual(config.partial_rotary_factor, 0.25) + + def test_rope_dim(self): + """rope_dim = head_dim * partial_rotary_factor = 256 * 0.25 = 64.""" + config = _make_config() + self.assertEqual(config.rope_dim, 64) + + def test_attn_output_gate(self): + config = _make_config() + self.assertTrue(config.attn_output_gate) + + def test_mrope_section(self): + config = _make_config() + self.assertEqual(config.mrope_section, [11, 11, 10]) + + def test_mrope_interleaved(self): + config = _make_config() + self.assertTrue(config.mrope_interleaved) + + +class TestNeuronConfig(unittest.TestCase): + """Test Neuron-specific configuration settings.""" + + def test_neuron_config_cls(self): + """Qwen3.5-4B is dense -- uses NeuronConfig, NOT MoENeuronConfig.""" + self.assertEqual( + Qwen35InferenceConfig.get_neuron_config_cls(), + NeuronConfig, + ) + + def test_required_attributes(self): + config = _make_config() + required = config.get_required_attributes() + self.assertIn("hidden_size", required) + self.assertIn("num_hidden_layers", required) + self.assertIn("linear_num_value_heads", required) + self.assertIn("linear_key_head_dim", required) + self.assertIn("layer_types", required) + + def test_output_attentions_default(self): + config = _make_config() + self.assertFalse(config.output_attentions) + + def test_output_hidden_states_default(self): + config = _make_config() + self.assertFalse(config.output_hidden_states) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.5-4B/test/unit/test_weight_conversion.py b/contrib/models/Qwen3.5-4B/test/unit/test_weight_conversion.py new file mode 100644 index 00000000..fc21cf8f --- /dev/null +++ b/contrib/models/Qwen3.5-4B/test/unit/test_weight_conversion.py @@ -0,0 +1,445 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for Qwen3.5-4B HF-to-NxDI weight conversion. 
+ +CPU-only tests that validate: +- RMSNorm (+1 convention) weight conversion +- GQA q_proj interleaved split (query + gate) +- QK norm key renaming (q_norm -> q_layernorm, k_norm -> k_layernorm) +- Fused QKV concatenation +- DeltaNet layer weights pass through unchanged +- VL wrapper prefix stripping +- rank_util injection +""" + +import os +import sys +import unittest + +import torch + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from src.modeling_qwen35 import ( + Qwen35InferenceConfig, + NeuronQwen35ForCausalLM, + convert_qwen35_hf_to_neuron_state_dict, +) +from neuronx_distributed_inference.models.config import NeuronConfig + + +def _make_mini_config(num_layers=4, tp_degree=2, fused_qkv=True): + """Create a small Qwen35InferenceConfig for testing.""" + neuron_config = NeuronConfig( + tp_degree=tp_degree, + batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + fused_qkv=fused_qkv, + ) + config = Qwen35InferenceConfig( + neuron_config=neuron_config, + hidden_size=256, + num_hidden_layers=num_layers, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=64, + intermediate_size=512, + vocab_size=1000, + rms_norm_eps=1e-6, + max_position_embeddings=4096, + rope_theta=10000, + hidden_act="silu", + linear_num_value_heads=8, + linear_num_key_heads=4, + linear_key_head_dim=32, + linear_value_head_dim=32, + linear_conv_kernel_dim=4, + ) + return config + + +def _make_mini_state_dict(config): + """Create a minimal HF-style state dict for conversion testing.""" + sd = {} + H = config.hidden_size # 256 + I = config.intermediate_size # 512 + V = config.vocab_size # 1000 + num_heads = config.num_attention_heads # 4 + num_kv = config.num_key_value_heads # 2 + head_dim = config.head_dim # 64 + + sd["embed_tokens.weight"] = torch.randn(V, H, dtype=torch.bfloat16) * 0.02 + sd["lm_head.weight"] = torch.randn(V, H, dtype=torch.bfloat16) * 0.02 + sd["norm.weight"] = torch.zeros(H, dtype=torch.bfloat16) # +1 convention: zeros + + for l in range(config.num_hidden_layers): + sd[f"layers.{l}.input_layernorm.weight"] = torch.zeros(H, dtype=torch.bfloat16) + sd[f"layers.{l}.post_attention_layernorm.weight"] = torch.zeros( + H, dtype=torch.bfloat16 + ) + + # Dense MLP (all layers) + sd[f"layers.{l}.mlp.gate_proj.weight"] = ( + torch.randn(I, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.mlp.up_proj.weight"] = ( + torch.randn(I, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.mlp.down_proj.weight"] = ( + torch.randn(H, I, dtype=torch.bfloat16) * 0.02 + ) + + if config.layer_types[l] == "full_attention": + # GQA layer: q_proj is interleaved [head0_q | head0_gate | head1_q | ...] 
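+            # Per head h: rows [2*h*hd, 2*h*hd + hd) hold the query and
+            # rows [2*h*hd + hd, 2*(h+1)*hd) hold the output gate
+            # (hd = head_dim), giving num_heads * head_dim * 2 rows total.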
+ q_proj = ( + torch.randn(num_heads * head_dim * 2, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.q_proj.weight"] = q_proj + sd[f"layers.{l}.self_attn.k_proj.weight"] = ( + torch.randn(num_kv * head_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.v_proj.weight"] = ( + torch.randn(num_kv * head_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.o_proj.weight"] = ( + torch.randn(H, num_heads * head_dim, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.q_norm.weight"] = torch.zeros( + head_dim, dtype=torch.bfloat16 + ) + sd[f"layers.{l}.self_attn.k_norm.weight"] = torch.zeros( + head_dim, dtype=torch.bfloat16 + ) + else: + # DeltaNet layer: minimal required weights + key_dim = config.linear_num_key_heads * config.linear_key_head_dim # 128 + value_dim = ( + config.linear_num_value_heads * config.linear_value_head_dim + ) # 256 + conv_dim = key_dim * 2 + value_dim # 512 + sd[f"layers.{l}.linear_attn.in_proj_qkv.weight"] = ( + torch.randn(conv_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.linear_attn.in_proj_z.weight"] = ( + torch.randn(value_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.linear_attn.in_proj_a.weight"] = ( + torch.randn(config.linear_num_value_heads, H, dtype=torch.bfloat16) + * 0.02 + ) + sd[f"layers.{l}.linear_attn.in_proj_b.weight"] = ( + torch.randn(config.linear_num_value_heads, H, dtype=torch.bfloat16) + * 0.02 + ) + sd[f"layers.{l}.linear_attn.conv1d.weight"] = ( + torch.randn( + conv_dim, 1, config.linear_conv_kernel_dim, dtype=torch.bfloat16 + ) + * 0.02 + ) + sd[f"layers.{l}.linear_attn.A_log"] = torch.randn( + config.linear_num_value_heads, dtype=torch.bfloat16 + ) + sd[f"layers.{l}.linear_attn.dt_bias"] = torch.randn( + config.linear_num_value_heads, dtype=torch.bfloat16 + ) + sd[f"layers.{l}.linear_attn.norm.weight"] = ( + torch.randn(value_dim, dtype=torch.bfloat16) * 0.5 + ) + sd[f"layers.{l}.linear_attn.out_proj.weight"] = ( + torch.randn(H, value_dim, dtype=torch.bfloat16) * 0.02 + ) + + return sd + + +class TestNormConversion(unittest.TestCase): + """Test (+1 convention) RMSNorm weight conversion.""" + + def test_norm_weight_adds_one(self): + """Weights initialized to zero should become 1.0 after conversion.""" + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + # norm.weight was zeros -> should now be ones + torch.testing.assert_close( + result["norm.weight"], + torch.ones_like(result["norm.weight"]), + ) + + def test_input_layernorm_adds_one(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + w = result[f"layers.{l}.input_layernorm.weight"] + self.assertTrue( + torch.allclose(w, torch.ones_like(w)), + f"Layer {l} input_layernorm not converted", + ) + + def test_post_attn_layernorm_adds_one(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + w = result[f"layers.{l}.post_attention_layernorm.weight"] + self.assertTrue( + torch.allclose(w, torch.ones_like(w)), + f"Layer {l} post_attention_layernorm not converted", + ) + + def test_qk_norm_adds_one(self): + """Q/K norms on GQA layers should also get +1 applied.""" + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, 
config) + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + q_w = result[f"layers.{l}.self_attn.q_layernorm.weight"] + k_w = result[f"layers.{l}.self_attn.k_layernorm.weight"] + self.assertTrue( + torch.allclose(q_w, torch.ones_like(q_w)), + f"Layer {l} q_layernorm not converted", + ) + self.assertTrue( + torch.allclose(k_w, torch.ones_like(k_w)), + f"Layer {l} k_layernorm not converted", + ) + + +class TestQProjSplit(unittest.TestCase): + """Test q_proj interleaved split into query + gate.""" + + def test_q_proj_split_shapes(self): + """q_proj (num_heads * head_dim * 2, H) -> separate query and gate.""" + config = _make_mini_config(fused_qkv=False) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + # After split: q_proj should be (num_heads * head_dim, H) = (256, 256) + q_w = result[f"layers.{l}.self_attn.q_proj.weight"] + gate_w = result[f"layers.{l}.self_attn.output_gate_proj.weight"] + expected_shape = ( + config.num_attention_heads * config.head_dim, + config.hidden_size, + ) + self.assertEqual( + q_w.shape, expected_shape, f"Layer {l} q_proj shape wrong" + ) + self.assertEqual( + gate_w.shape, expected_shape, f"Layer {l} gate shape wrong" + ) + + def test_q_proj_deinterleave_correct(self): + """Verify the interleaved split correctly separates query and gate.""" + config = _make_mini_config(fused_qkv=False) + sd = _make_mini_state_dict(config) + + # Create a known pattern: head0 query is 1s, head0 gate is 2s, etc. + l = 3 # First full_attention layer (layer 3) + num_heads = config.num_attention_heads + head_dim = config.head_dim + H = config.hidden_size + + interleaved = torch.zeros(num_heads * head_dim * 2, H, dtype=torch.bfloat16) + for h in range(num_heads): + interleaved[h * head_dim * 2 : h * head_dim * 2 + head_dim, :] = float( + h + 1 + ) # query + interleaved[h * head_dim * 2 + head_dim : (h + 1) * head_dim * 2, :] = ( + float(h + 100) + ) # gate + + sd[f"layers.{l}.self_attn.q_proj.weight"] = interleaved + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + q_w = result[f"layers.{l}.self_attn.q_proj.weight"] + gate_w = result[f"layers.{l}.self_attn.output_gate_proj.weight"] + + for h in range(num_heads): + q_head = q_w[h * head_dim : (h + 1) * head_dim, :] + gate_head = gate_w[h * head_dim : (h + 1) * head_dim, :] + self.assertTrue( + torch.all(q_head == float(h + 1)), f"Head {h} query values wrong" + ) + self.assertTrue( + torch.all(gate_head == float(h + 100)), f"Head {h} gate values wrong" + ) + + +class TestQKNormRename(unittest.TestCase): + """Test q_norm -> q_layernorm and k_norm -> k_layernorm renaming.""" + + def test_old_keys_removed(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + self.assertNotIn(f"layers.{l}.self_attn.q_norm.weight", result) + self.assertNotIn(f"layers.{l}.self_attn.k_norm.weight", result) + + def test_new_keys_present(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + self.assertIn(f"layers.{l}.self_attn.q_layernorm.weight", result) + self.assertIn(f"layers.{l}.self_attn.k_layernorm.weight", 
result) + + +class TestFusedQKV(unittest.TestCase): + """Test fused QKV concatenation for attention layers.""" + + def test_fused_qkv_shape(self): + config = _make_mini_config(fused_qkv=True) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + fused_key = f"layers.{l}.self_attn.Wqkv.weight" + self.assertIn(fused_key, result, f"Layer {l} missing Wqkv") + + q_dim = config.num_attention_heads * config.head_dim + k_dim = config.num_key_value_heads * config.head_dim + v_dim = config.num_key_value_heads * config.head_dim + expected_rows = q_dim + k_dim + v_dim + self.assertEqual(result[fused_key].shape[0], expected_rows) + + def test_fused_qkv_removes_individual_keys(self): + config = _make_mini_config(fused_qkv=True) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + self.assertNotIn(f"layers.{l}.self_attn.q_proj.weight", result) + self.assertNotIn(f"layers.{l}.self_attn.k_proj.weight", result) + self.assertNotIn(f"layers.{l}.self_attn.v_proj.weight", result) + + +class TestDeltaNetPassthrough(unittest.TestCase): + """Test that DeltaNet layer weights pass through conversion unchanged.""" + + def test_deltanet_weights_unchanged(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + # Record original DeltaNet weights + originals = {} + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "linear_attention": + key = f"layers.{l}.linear_attn.in_proj_qkv.weight" + originals[key] = sd[key].clone() + + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for key, orig in originals.items(): + self.assertIn(key, result, f"Missing: {key}") + torch.testing.assert_close( + result[key], orig, msg=f"DeltaNet weight changed: {key}" + ) + + def test_deltanet_norm_not_converted(self): + """DeltaNet layers use standard RMSNorm (NOT +1 convention). 
+ The norm weight should NOT be changed.""" + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + # Set DeltaNet norm to a known non-zero value + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "linear_attention": + sd[f"layers.{l}.linear_attn.norm.weight"] = torch.full( + (config.linear_num_value_heads * config.linear_value_head_dim,), + 0.87, + dtype=torch.bfloat16, + ) + + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "linear_attention": + w = result[f"layers.{l}.linear_attn.norm.weight"] + # Should still be ~0.87, NOT 1.87 + self.assertTrue( + torch.allclose(w, torch.full_like(w, 0.87), atol=0.01), + f"Layer {l} DeltaNet norm was incorrectly modified", + ) + + +class TestRankUtil(unittest.TestCase): + """Test rank_util tensor injection.""" + + def test_rank_util_present(self): + config = _make_mini_config(tp_degree=4) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + self.assertIn("rank_util.rank", result) + expected = torch.arange(0, 4, dtype=torch.int32) + torch.testing.assert_close(result["rank_util.rank"], expected) + + def test_gqa_layer_rank_util(self): + config = _make_mini_config(tp_degree=4) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + key = f"layers.{l}.self_attn.rank_util.rank" + self.assertIn(key, result) + expected = torch.arange(0, 4, dtype=torch.int32) + torch.testing.assert_close(result[key], expected) + + +class TestVLPrefixStripping(unittest.TestCase): + """Test VL wrapper prefix stripping in convert_hf_to_neuron_state_dict.""" + + def test_language_model_prefix_stripped(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + # Wrap with VL prefix + vl_sd = {} + for k, v in sd.items(): + vl_sd[f"language_model.{k}"] = v + vl_sd["visual.encoder.weight"] = torch.zeros(10) # should be skipped + vl_sd["mtp.something"] = torch.zeros(5) # should be skipped + + result = NeuronQwen35ForCausalLM.convert_hf_to_neuron_state_dict(vl_sd, config) + self.assertNotIn("visual.encoder.weight", result) + self.assertNotIn("mtp.something", result) + self.assertIn("norm.weight", result) + + def test_model_language_model_prefix_stripped(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + vl_sd = {} + for k, v in sd.items(): + vl_sd[f"model.language_model.{k}"] = v + + result = NeuronQwen35ForCausalLM.convert_hf_to_neuron_state_dict(vl_sd, config) + self.assertIn("norm.weight", result) + + def test_tied_embeddings_synthesize_lm_head(self): + config = _make_mini_config() + config.tie_word_embeddings = True + sd = _make_mini_state_dict(config) + embed = sd["embed_tokens.weight"] + del sd["lm_head.weight"] + + result = NeuronQwen35ForCausalLM.convert_hf_to_neuron_state_dict(sd, config) + self.assertIn("lm_head.weight", result) + torch.testing.assert_close(result["lm_head.weight"], embed) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.5-9B/README.md b/contrib/models/Qwen3.5-9B/README.md new file mode 100644 index 00000000..0b8dc6b3 --- /dev/null +++ b/contrib/models/Qwen3.5-9B/README.md @@ -0,0 +1,156 @@ +# Contrib Model: Qwen3.5-9B + +NeuronX Distributed Inference implementation of Qwen3.5-9B, a dense hybrid DeltaNet + GQA decoder from Alibaba Cloud. 
+ +This variant is forked from the proven Qwen3.5-2B contrib implementation in PR 141. It keeps the same working cache architecture: + +- Standard GQA layers use NxDI `KVCacheManager`. +- DeltaNet layers return dummy KV tensors to satisfy NxDI cache plumbing. +- Real DeltaNet state is carried through layer-local `recurrent_state_buffer` and `conv_state_buffer` side-channel aliases. + +## Model Information + +| Feature | Value | +| --- | --- | +| HuggingFace ID | `Qwen/Qwen3.5-9B` | +| Model type | `qwen3_5_text` under top-level `qwen3_5` | +| Layers | 32: 24 DeltaNet + 8 GQA | +| Layer pattern | `[3 DeltaNet + 1 GQA] x 8` | +| Hidden size | 4096 | +| MLP | Dense SwiGLU, intermediate size 12288 | +| GQA attention | 16 Q heads, 4 KV heads, head_dim 256 | +| DeltaNet | 32 value heads, 16 key heads, k_dim=v_dim=128 | +| Conv kernel | 4, state stores last 3 pre-conv QKV tokens | +| RoPE | Partial RoPE, 25% of head_dim = 64 dims | +| Vocabulary | 248,320 | +| Tied embeddings | No | + +Derived DeltaNet shapes: + +| Tensor | Shape | +| --- | --- | +| `in_proj_qkv.weight` | `[8192, 4096]` | +| `in_proj_z.weight` | `[4096, 4096]` | +| `in_proj_a.weight` | `[32, 4096]` | +| `in_proj_b.weight` | `[32, 4096]` | +| `conv1d.weight` | `[8192, 1, 4]` | +| `recurrent_state_buffer` | `[max_batch, 32, 128, 128]` | +| `conv_state_buffer` | `[max_batch, 8192, 3]` | + +## Status + +This 9B contrib is prepared for bring-up. The implementation should be validated on Trn2 before TP=2 or Trn1 experiments. + +Validated baseline: + +- Qwen3.5-2B PR 141: trn2.3xlarge, TP=4, LNC=2, SDK 2.29, NKI 0.3. + +Unvalidated for this folder until run: + +- Qwen3.5-9B compile and generation +- TP=2 +- Trn1 +- long-context HBM limits + +## Usage + +```python +import json +import os +import torch +from transformers import AutoTokenizer, GenerationConfig +from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig +from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter + +from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM + +model_path = "/mnt/models/Qwen3.5-9B" +compiled_path = "/mnt/models/qwen35_9b_traced" + +neuron_config = NeuronConfig( + tp_degree=4, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + logical_nc_config=2, + enable_bucketing=False, + flash_decoding_enabled=False, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + save_sharded_checkpoint=True, +) + +with open(os.path.join(model_path, "config.json")) as f: + hf_config = json.load(f) +text_config = hf_config.get("text_config", hf_config) +config_dict = dict(text_config) +config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) +config_dict["tie_word_embeddings"] = hf_config.get( + "tie_word_embeddings", + text_config.get("tie_word_embeddings", False), +) +if "rope_parameters" in text_config: + config_dict["rope_theta"] = text_config["rope_parameters"].get( + "rope_theta", 10000000 + ) + +config = Qwen35InferenceConfig(neuron_config=neuron_config, **config_dict) + +model = NeuronQwen35ForCausalLM(model_path, config) +model.compile(compiled_path) + +model = NeuronQwen35ForCausalLM(compiled_path) +model.load(compiled_path) + +tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right") +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + +gen_config = GenerationConfig( + do_sample=True, + top_k=1, + pad_token_id=tokenizer.pad_token_id, + 
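# top_k=1 with do_sample=True is effectively greedy decoding,
+    # matching the OnDeviceSamplingConfig(top_k=1) used at compile time above.
+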
eos_token_id=tokenizer.eos_token_id, +) + +inputs = tokenizer("The capital of France is", return_tensors="pt") +gen_model = HuggingFaceGenerationAdapter(model) +outputs = gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + max_new_tokens=32, +) +print(tokenizer.decode(outputs[0], skip_special_tokens=True)) +``` + +## Testing + +CPU unit tests: + +```bash +cd contrib/models/Qwen3.5-9B +pytest test/unit -v +``` + +Trainium integration: + +```bash +cd contrib/models/Qwen3.5-9B +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +QWEN35_MODEL_PATH=/mnt/models/Qwen3.5-9B \ +QWEN35_COMPILED_PATH=/mnt/models/qwen35_9b_traced \ +QWEN35_TP_DEGREE=4 \ +QWEN35_SEQ_LEN=128 \ +pytest test/integration/test_model.py --capture=tee-sys -v +``` + +## Known Limitations + +1. SDK 2.29+ and NKI 0.3 are expected. +2. DeltaNet weights are replicated across TP ranks in v1. +3. Dummy KV wastes HBM for DeltaNet layers. +4. First-token and multi-token logit parity are expected to show the same BF16 recurrent divergence reported by PR 141 until the DeltaNet precision work is done. +5. Hybrid cache, DeltaNet TP sharding, quantization, speculative decoding, and MoE are out of scope for first bring-up. diff --git a/contrib/models/Qwen3.5-9B/src/__init__.py b/contrib/models/Qwen3.5-9B/src/__init__.py new file mode 100644 index 00000000..f8e014e6 --- /dev/null +++ b/contrib/models/Qwen3.5-9B/src/__init__.py @@ -0,0 +1,27 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from src.modeling_qwen35 import ( + NeuronGatedDeltaNet, + NeuronQwen35Attention, + NeuronQwen35DecoderLayer, + NeuronQwen35ForCausalLM, + NeuronQwen35Model, + Qwen35DecoderModelInstance, + Qwen35InferenceConfig, + Qwen35MLP, + Qwen35ModelWrapper, +) + +__all__ = [ + # Text decoder + "NeuronGatedDeltaNet", + "NeuronQwen35Attention", + "NeuronQwen35DecoderLayer", + "NeuronQwen35ForCausalLM", + "NeuronQwen35Model", + "Qwen35DecoderModelInstance", + "Qwen35InferenceConfig", + "Qwen35MLP", + "Qwen35ModelWrapper", +] diff --git a/contrib/models/Qwen3.5-9B/src/modeling_qwen35.py b/contrib/models/Qwen3.5-9B/src/modeling_qwen35.py new file mode 100644 index 00000000..959c3170 --- /dev/null +++ b/contrib/models/Qwen3.5-9B/src/modeling_qwen35.py @@ -0,0 +1,2528 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +NxDI contrib: Qwen3.5-9B (qwen3_5 -- dense model) + +Hybrid DeltaNet + Standard Attention + Dense MLP architecture. 
+ +24 of 32 layers use Gated DeltaNet (linear recurrent attention) +8 of 32 layers use standard GQA with KV cache + output gate +All 32 layers use a dense SwiGLU MLP (intermediate_size=12288) + +Architecture details: +- DeltaNet layers: separate in_proj_{qkv, z, a, b}, causal conv1d on QKV, gated delta rule +- Attention layers: q_proj doubled (Q + gate), partial RoPE (25% of head_dim), sigmoid output gate +- Dense MLP: standard SwiGLU (gate_proj, up_proj, down_proj) -- no MoE, no router, no experts +- KV cache: NxDI KVCacheManager for attention layers; DeltaNet layers store recurrent+conv + state as nn.Parameter buffers and return dummy KV tuples +""" + +import gc +import math +import logging +import os +import sys +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm + +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.utils import cpu_mode + +try: + from nki import jit as nki_jit # NKI 0.3.0+ (SDK 2.29) +except ImportError: + from torch_neuronx.xla_impl.ops import nki_jit # NKI 0.2.x (SDK 2.28) +from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeRMSNorm + +from src.nki_kernels.nki_deltanet import deltanet_recurrent_fwd as _deltanet_nki_kernel +from src.nki_kernels.nki_deltanet import ( + deltanet_recurrent_fwd_state as _deltanet_nki_kernel_state, +) +from src.nki_kernels.nki_deltanet_chunked import ( + deltanet_chunk_step as _deltanet_nki_chunk_step, +) +from src.nki_kernels.nki_deltanet_fused import ( + deltanet_fused_chunked_fwd as _deltanet_fused_kernel, +) +from src.nki_kernels.nki_deltanet_fused import ( + _make_lower_mask, + _make_lower_mask_diag, + _make_identity, +) + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + NeuronConfig, +) +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, + DecoderModelInstance, + ModelWrapper, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding +from neuronx_distributed_inference.models.layer_boundary_marker import ( + ModuleMarkerEndWrapper, + ModuleMarkerStartWrapper, +) + +logger = logging.getLogger(__name__) + +try: + _flash_fwd_call = nki_jit()(attention_isa_kernel) +except TypeError: + from torch_neuronx.xla_impl.ops import nki_jit as _torch_xla_nki_jit + + _flash_fwd_call = _torch_xla_nki_jit()(attention_isa_kernel) + +# Option B: Direct nkilib flash attention for head_dim > 128 +USE_NKILIB_KERNEL = os.environ.get("USE_NKILIB_KERNEL", "0") == "1" + +_nkilib_flash_attn = None +if USE_NKILIB_KERNEL: + try: + import neuronxcc.nki as _nki + from neuronx_distributed_inference.modules.attention.attention_base import ( + peel_decorations as _peel_decorations, + get_platform_target as _get_platform_target, + ) + from neuronxcc.nki.compiler import ( + skip_middle_end_transformations as _skip_middle_end, + 
enable_stack_allocator as _enable_stack_allocator, + ) + + import importlib + + _fork_path = "/home/ubuntu/nki-library-fork/nkilib_src" + if os.path.isdir(_fork_path) and _fork_path not in sys.path: + sys.path.insert(0, _fork_path) + _to_remove = [k for k in sys.modules if k.startswith("nkilib")] + for k in _to_remove: + del sys.modules[k] + import nki.language as _stub_nl + import neuronxcc.nki.language as _real_nl + + for _attr in [ + "NKIObject", + "float8_e4m3fn", + "float8_e4m3fn_x4", + "float8_e5m2_x4", + "float4_e2m1fn_x4", + ]: + if not hasattr(_real_nl, _attr) and hasattr(_stub_nl, _attr): + setattr(_real_nl, _attr, getattr(_stub_nl, _attr)) + from nkilib.core.attention.attention_cte import ( + attention_cte as _attention_cte_raw, + _MAX_HEAD_DIM, + ) + + assert _MAX_HEAD_DIM == 256, ( + f"nkilib fork has _MAX_HEAD_DIM={_MAX_HEAD_DIM}, expected 256. " + f"System nkilib may have been loaded instead of fork." + ) + logger.info( + f"Loaded nkilib attention_cte from fork (_MAX_HEAD_DIM={_MAX_HEAD_DIM})" + ) + + _raw_fn = _peel_decorations(_attention_cte_raw) + _platform = _get_platform_target() + _nkilib_flash_attn = _nki.jit( + _raw_fn, + mode="torchxla", + platform_target=_platform, + show_compiler_tb=True, + debug_kernel=True, + ) + _nkilib_flash_attn = _skip_middle_end(_nkilib_flash_attn) + _nkilib_flash_attn = _enable_stack_allocator( + _nkilib_flash_attn, log_level=logging.INFO + ) + logger.info("Option B: nkilib flash attention loaded for head_dim > 128") + except Exception as e: + logger.warning(f"Option B: Failed to load nkilib flash attention: {e}") + import traceback as _tb + + _tb.print_exc() + _nkilib_flash_attn = None + +# Option A: Detect if patch_attn_kernel was imported +NKILIB_PATCH_ACTIVE = False +try: + from importlib import import_module as _import_module + + _attn_mod = _import_module("neuronxcc.nki._pre_prod_kernels.attn_fwd") + if hasattr(_attn_mod, "_original_attention_nki_kernel_adapter"): + NKILIB_PATCH_ACTIVE = True + logger.info("Option A detected: _pre_prod_kernels patched with nkilib kernel") +except Exception: + pass + + +# ============================================================ +# Newton-Raphson Refined RMSNorm +# ============================================================ +USE_NEWTON_RMSNORM = os.environ.get("USE_NEWTON_RMSNORM") == "1" +USE_PYTHON_RMSNORM = os.environ.get("USE_PYTHON_RMSNORM") == "1" + + +class NewtonRMSNorm(nn.Module): + """RMSNorm with Newton-Raphson refined rsqrt for improved numerical accuracy.""" + + def __init__(self, hidden_size=None, eps=1e-6): + super().__init__() + self.weight = None + if hidden_size is not None: + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.hidden_size = hidden_size + self.variance_epsilon = eps + + def forward(self, hidden_states): + original_dtype = hidden_states.dtype + x = hidden_states.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + y = torch.rsqrt(variance + self.variance_epsilon) + y = y * (3.0 - (variance + self.variance_epsilon) * y * y) * 0.5 + result = x * y + if self.weight is not None: + result = result * self.weight.float() + return result.to(original_dtype) + + +def get_rmsnorm_cls(): + if cpu_mode() or USE_PYTHON_RMSNORM: + return Qwen3MoeRMSNorm + return NewtonRMSNorm if USE_NEWTON_RMSNORM else CustomRMSNorm + + +def l2norm(x, dim=-1, eps=1e-6): + return F.normalize(x, p=2, dim=dim, eps=eps) + + +# ============================================================ +# Gated DeltaNet Module (Linear Recurrent Attention) +# 
============================================================
+
+
+class NeuronGatedDeltaNet(nn.Module):
+    """
+    Gated DeltaNet linear attention for Neuron.
+
+    Replaces standard attention for 24 of 32 layers in Qwen3.5-9B.
+    Uses a linear recurrence (chunked at prefill) instead of a KV cache.
+
+    HF weight layout (9B dense):
+    - in_proj_qkv.weight: (key_dim*2 + value_dim, hidden_size) = (8192, 4096)
+    - in_proj_z.weight: (value_dim, hidden_size) = (4096, 4096)
+    - in_proj_a.weight: (num_v_heads, hidden_size) = (32, 4096)
+    - in_proj_b.weight: (num_v_heads, hidden_size) = (32, 4096)
+    - conv1d.weight: (conv_dim, 1, conv_kernel_size) = (8192, 1, 4)
+    - A_log: (num_v_heads,) = (32,)
+    - dt_bias: (num_v_heads,) = (32,)
+    - norm.weight: (head_v_dim,) = (128,)
+    - out_proj.weight: (hidden_size, value_dim) = (4096, 4096)
+    """
+
+    def __init__(self, config, layer_idx: int):
+        super().__init__()
+        tc = config
+
+        self.hidden_size = tc.hidden_size  # 4096
+        self.num_v_heads = tc.linear_num_value_heads  # 32
+        self.num_k_heads = tc.linear_num_key_heads  # 16
+        self.head_k_dim = tc.linear_key_head_dim  # 128
+        self.head_v_dim = tc.linear_value_head_dim  # 128
+        self.key_dim = self.head_k_dim * self.num_k_heads  # 2048
+        self.value_dim = self.head_v_dim * self.num_v_heads  # 4096
+        self.conv_kernel_size = tc.linear_conv_kernel_dim  # 4
+        self.layer_idx = layer_idx
+        self.rms_norm_eps = tc.rms_norm_eps
+
+        # KV cache dummy shape info
+        self.head_dim = tc.head_dim  # 256
+        tp_degree = tc.neuron_config.tp_degree
+        raw_kv_heads = tc.num_key_value_heads
+        if raw_kv_heads < tp_degree:
+            replicated_kv_heads = tp_degree
+        else:
+            replicated_kv_heads = raw_kv_heads
+        self.kv_heads_per_rank = replicated_kv_heads // tp_degree
+
+        # Conv1d on concatenated QKV (NOT Z)
+        self.conv_dim = self.key_dim * 2 + self.value_dim  # 8192
+        self.conv1d = nn.Conv1d(
+            in_channels=self.conv_dim,
+            out_channels=self.conv_dim,
+            bias=False,
+            kernel_size=self.conv_kernel_size,
+            groups=self.conv_dim,
+            padding=self.conv_kernel_size - 1,
+        )
+
+        # Input projections (nn.Linear -- NOT sharded by NxDI TP, replicated on all ranks)
+        self.in_proj_qkv = nn.Linear(
+            self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False
+        )
+        self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
+        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
+        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
+
+        # Decay parameters
+        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
+        self.A_log = nn.Parameter(torch.zeros(self.num_v_heads))
+
+        # Output norm and projection
+        self.norm = Qwen3MoeRMSNorm(self.head_v_dim, eps=self.rms_norm_eps)
+        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
+
+        # State buffers for CTE -> TKG carry-over
+        alloc_batch_size = getattr(config.neuron_config, "max_batch_size", 1)
+        self._phase_batch_size = getattr(config.neuron_config, "batch_size", 1)
+        self.recurrent_state_buffer = nn.Parameter(
+            torch.zeros(
+                alloc_batch_size,
+                self.num_v_heads,
+                self.head_k_dim,
+                self.head_v_dim,
+                dtype=config.neuron_config.torch_dtype,
+            ),
+            requires_grad=False,
+        )
+        self.conv_state_buffer = nn.Parameter(
+            torch.zeros(
+                alloc_batch_size,
+                self.conv_dim,
+                self.conv_kernel_size - 1,
+                dtype=config.neuron_config.torch_dtype,
+            ),
+            requires_grad=False,
+        )
+
+    def _recurrent_step(self, query, key, value, g, beta, recurrent_state):
+        """Single-step recurrent update for token generation."""
+        query = l2norm(query, dim=-1)
+        key = l2norm(key, dim=-1)
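+        # Gated delta rule, one step below (per-head state S has shape K x V):
+        #   S     = exp(g_t) * S            decay the running state
+        #   err   = v_t - S^T k_t           prediction error against memory
+        #   S    += k_t (beta_t * err)^T    rank-1 delta correction
+        #   o_t   = S^T q_t                 readout for this token
+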
scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + q_t = query[:, :, 0] + k_t = key[:, :, 0] + v_t = value[:, :, 0] + g_t = g[:, :, 0].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, 0].unsqueeze(-1) + + new_state = recurrent_state * g_t + kv_mem = (new_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + new_state = new_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + output = (new_state * q_t.unsqueeze(-1)).sum(dim=-2) + + return output.unsqueeze(2), new_state + + def _nki_recurrent_forward(self, query, key, value, g, beta): + """Full-sequence recurrent forward using NKI kernel for context encoding.""" + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + BH = B * H + query_flat = query.reshape(BH, S, k_dim).contiguous() + key_flat = key.reshape(BH, S, k_dim).contiguous() + value_flat = value.reshape(BH, S, v_dim).contiguous() + + g_flat = g.reshape(BH, S).unsqueeze(-1).expand(-1, -1, v_dim).contiguous() + beta_flat = beta.reshape(BH, S).unsqueeze(-1).expand(-1, -1, v_dim).contiguous() + + outputs = [] + states = [] + for bh in range(BH): + out_bh, state_bh = _deltanet_nki_kernel_state( + query_flat[bh], + key_flat[bh], + value_flat[bh], + g_flat[bh], + beta_flat[bh], + ) + outputs.append(out_bh) + states.append(state_bh) + + output = torch.stack(outputs, dim=0) + output = output.reshape(B, H, S, v_dim) + + final_state = torch.stack(states, dim=0) + final_state = final_state.reshape(B, H, k_dim, v_dim) + + return output, final_state + + def _nki_chunked_forward( + self, query, key, value, g, beta, output_final_state=False + ): + """Chunked NKI kernel forward for context encoding (prefill).""" + chunk_size = 128 + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + num_chunks = total_seq_len // chunk_size + g_reshaped = g.reshape(B, H, num_chunks, chunk_size) + g_cs = g_reshaped.cumsum(dim=-1) + g_last_per_chunk = g_cs[:, :, :, -1:] + g_last_expanded = g_last_per_chunk.expand(-1, -1, -1, chunk_size) + + query_chunks = query.reshape(B, H, num_chunks, chunk_size, k_dim) + key_chunks = key.reshape(B, H, num_chunks, chunk_size, k_dim) + value_chunks = value.reshape(B, H, num_chunks, chunk_size, v_dim) + + beta_chunks = ( + beta.reshape(B, H, num_chunks, chunk_size) + .unsqueeze(-1) + .expand(-1, -1, -1, -1, v_dim) + ) + gc_chunks = g_cs.unsqueeze(-1).expand(-1, -1, -1, -1, v_dim) + gl_chunks = g_last_expanded.unsqueeze(-1).expand(-1, -1, -1, -1, v_dim) + + BH = B * H + query_chunks = query_chunks.reshape( + BH, num_chunks, chunk_size, k_dim + ).contiguous() + key_chunks = key_chunks.reshape(BH, num_chunks, chunk_size, k_dim).contiguous() + value_chunks = value_chunks.reshape( + BH, num_chunks, chunk_size, v_dim + ).contiguous() + beta_chunks = beta_chunks.reshape( + BH, num_chunks, chunk_size, v_dim + ).contiguous() + gc_chunks = gc_chunks.reshape(BH, num_chunks, chunk_size, v_dim).contiguous() + gl_chunks = gl_chunks.reshape(BH, num_chunks, chunk_size, v_dim).contiguous() + + device = query.device + lower_mask = 
torch.tril( + torch.ones(chunk_size, chunk_size, dtype=torch.float32, device=device), + diagonal=-1, + ) + identity_mat = torch.eye(chunk_size, dtype=torch.float32, device=device) + lower_mask_diag = torch.tril( + torch.ones(chunk_size, chunk_size, dtype=torch.float32, device=device), + diagonal=0, + ) + + all_outputs = [] + all_states = [] + for bh in range(BH): + state = torch.zeros(k_dim, v_dim, dtype=torch.float32, device=device) + + head_chunks = [] + for c_idx in range(num_chunks): + q_chunk = query_chunks[bh, c_idx].contiguous() + k_chunk = key_chunks[bh, c_idx].contiguous() + v_chunk = value_chunks[bh, c_idx].contiguous() + beta_chunk = beta_chunks[bh, c_idx].contiguous() + gc_chunk = gc_chunks[bh, c_idx].contiguous() + gl_chunk = gl_chunks[bh, c_idx].contiguous() + + out_chunk, state = _deltanet_nki_chunk_step( + q_chunk, + k_chunk, + v_chunk, + beta_chunk, + gc_chunk, + gl_chunk, + state, + lower_mask, + identity_mat, + lower_mask_diag, + ) + head_chunks.append(out_chunk) + + head_output = torch.cat(head_chunks, dim=0) + all_outputs.append(head_output) + all_states.append(state) + + output = torch.stack(all_outputs, dim=0) + output = output.reshape(B, H, total_seq_len, v_dim) + output = output[:, :, :S] + + if output_final_state: + final_state = torch.stack(all_states, dim=0) + last_recurrent_state = final_state.reshape(B, H, k_dim, v_dim) + else: + last_recurrent_state = None + + return output, last_recurrent_state + + def _fused_chunked_forward( + self, query, key, value, g, beta, output_final_state=False + ): + """Fused single-kernel chunked forward for CTE — SSD-style. + + Processes all chunks in a single NKI kernel call per (B,H) pair. + State persists in SBUF across chunks (no HBM round-trips). + Cumsum of g computed in-kernel via tensor_tensor_scan. + + This is the optimized version of _nki_chunked_forward with: + 1. Single kernel call per (B,H) instead of B*H*num_chunks + 2. State in SBUF across all chunks (biggest perf win) + 3. In-kernel cumsum (avoids PyTorch cumsum overhead) + 4. 
tensor_scalar for broadcasts (no explicit loops) + """ + chunk_size = 128 + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + # Pad sequence to multiple of chunk_size + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + BH = B * H + # Flatten to (BH, S, dim) for per-(b,h) kernel calls + query_flat = query.reshape(BH, total_seq_len, k_dim).contiguous() + key_flat = key.reshape(BH, total_seq_len, k_dim).contiguous() + value_flat = value.reshape(BH, total_seq_len, v_dim).contiguous() + + # g and beta: (BH, S) -> (BH, S, 1) for the kernel's (S, 1) input layout + g_flat = g.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + beta_flat = beta.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + + # Create constant mask tensors (shared across all B*H calls) + device = query.device + lower_mask = torch.tensor( + _make_lower_mask(), dtype=torch.float32, device=device + ) + identity_mat = torch.tensor( + _make_identity(), dtype=torch.float32, device=device + ) + lower_mask_diag = torch.tensor( + _make_lower_mask_diag(), dtype=torch.float32, device=device + ) + + all_outputs = [] + all_states = [] + for bh in range(BH): + out_bh, state_bh = _deltanet_fused_kernel( + query_flat[bh], # (S, 128) + key_flat[bh], # (S, 128) + value_flat[bh], # (S, 128) + g_flat[bh], # (S, 1) — RAW g, not cumsum + beta_flat[bh], # (S, 1) — sigmoid(b) + lower_mask, # (128, 128) + identity_mat, # (128, 128) + lower_mask_diag, # (128, 128) + ) + all_outputs.append(out_bh) + all_states.append(state_bh) + + output = torch.stack(all_outputs, dim=0) + output = output.reshape(B, H, total_seq_len, v_dim) + output = output[:, :, :S] + + if output_final_state: + final_state = torch.stack(all_states, dim=0) + last_recurrent_state = final_state.reshape(B, H, k_dim, v_dim) + else: + last_recurrent_state = None + + return output, last_recurrent_state + + def _sequential_forward(self, query, key, value, g, beta, output_final_state=False): + """Sequential full-sequence gated delta rule for CTE. + + Uses the same per-step recurrence as _recurrent_step but loops over the + full sequence. Avoids the slice-assignment loop in _chunk_forward that + may compile incorrectly on Neuron/XLA. 
+ """ + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + state = query.new_zeros(B, H, k_dim, v_dim) + all_outputs = [] + for t in range(S): + q_t = query[:, :, t] # (B, H, K) + k_t = key[:, :, t] # (B, H, K) + v_t = value[:, :, t] # (B, H, V) + beta_t = beta[:, :, t].unsqueeze(-1) # (B, H, 1) + g_t = g[:, :, t].exp().unsqueeze(-1).unsqueeze(-1) # (B, H, 1, 1) + + # Gated delta rule + state = state * g_t + kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2) # (B, H, V) + delta = (v_t - kv_mem) * beta_t # (B, H, V) + state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) # (B, H, K, V) + + o_t = (state * q_t.unsqueeze(-1)).sum(dim=-2) # (B, H, V) + all_outputs.append(o_t.unsqueeze(2)) + + output = torch.cat(all_outputs, dim=2) # (B, H, S, V) + final_state = state if output_final_state else None + return output, final_state + + def _chunk_forward(self, query, key, value, g, beta, output_final_state=False): + """Chunk-based forward for context encoding (prefill).""" + chunk_size = 64 + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + + num_chunks = total_seq_len // chunk_size + query = query.reshape(B, H, num_chunks, chunk_size, k_dim) + key = key.reshape(B, H, num_chunks, chunk_size, k_dim) + value = value.reshape(B, H, num_chunks, chunk_size, v_dim) + k_beta = k_beta.reshape(B, H, num_chunks, chunk_size, k_dim) + v_beta = v_beta.reshape(B, H, num_chunks, chunk_size, v_dim) + g = g.reshape(B, H, num_chunks, chunk_size) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=0, + ) + + g = g.cumsum(dim=-1) + decay_mask = (g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().tril() + + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + last_recurrent_state = torch.zeros( + B, H, k_dim, v_dim, dtype=query.dtype, device=query.device + ) + core_attn_out = torch.zeros_like(value) + mask2 = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=1, + ) + + for i in range(num_chunks): + q_i = query[:, :, i] + k_i = key[:, :, i] + v_i = value[:, :, i] + + attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_( + mask2, 0 + ) + + v_prime = k_cumdecay[:, :, i] @ last_recurrent_state + v_new = v_i - v_prime + + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn_i @ v_new + + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + ( + k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None] + ).transpose(-1, -2) + @ v_new + ) + + core_attn_out 
= core_attn_out.reshape(B, H, -1, v_dim) + core_attn_out = core_attn_out[:, :, :S] + + if not output_final_state: + last_recurrent_state = None + + return core_attn_out, last_recurrent_state + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + position_ids=None, + past_key_value=None, + **kwargs, + ): + """Forward pass compatible with NxDI decoder layer interface.""" + batch_size, seq_len, _ = hidden_states.shape + + seq_ids = kwargs.get("seq_ids", None) + is_decode = past_key_value is not None + + # Padding mask for DeltaNet: [B, S, 1] with 1.0 for real tokens, 0.0 for padding. + # Passed from get_model_output where it's computed from input_ids != pad_token_id. + # Embeddings are already zeroed for padding tokens; this mask additionally + # zeros the decay gate so the recurrent state is preserved unchanged + # through padding positions (no spurious decay). + valid_mask_1d = kwargs.get("deltanet_padding_mask", None) # [B, S, 1] or None + + # Project inputs + deltanet_fp32 = os.environ.get("DELTANET_FP32") == "1" + if deltanet_fp32: + hs_f32 = hidden_states.float() + qkv = F.linear(hs_f32, self.in_proj_qkv.weight.float()).to( + hidden_states.dtype + ) + z = F.linear(hs_f32, self.in_proj_z.weight.float()).to(hidden_states.dtype) + b = F.linear(hs_f32, self.in_proj_b.weight.float()).to(hidden_states.dtype) + a = F.linear(hs_f32, self.in_proj_a.weight.float()).to(hidden_states.dtype) + else: + qkv = self.in_proj_qkv(hidden_states) + z = self.in_proj_z(hidden_states) + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + # Split QKV + query = qkv[..., : self.key_dim] + key = qkv[..., self.key_dim : self.key_dim * 2] + value = qkv[..., self.key_dim * 2 :] + + # Causal Conv1d on QKV + mixed = torch.cat([query, key, value], dim=-1) + mixed = mixed.transpose(1, 2) + + if is_decode: + if seq_ids is not None: + conv_state = torch.index_select(self.conv_state_buffer, 0, seq_ids) + else: + conv_state = self.conv_state_buffer[:batch_size] + conv_input = torch.cat([conv_state, mixed], dim=-1) + + w = self.conv1d.weight.squeeze(1) + conv_out = torch.zeros_like(mixed) + for k in range(4): + conv_out = ( + conv_out + + w[:, k].unsqueeze(0).unsqueeze(-1) * conv_input[:, :, k : k + 1] + ) + mixed_post_conv = F.silu(conv_out) + + new_conv_state = torch.cat([conv_state[:, :, 1:], mixed], dim=-1) + alloc_bs = self.conv_state_buffer.shape[0] + if seq_ids is not None: + # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement + # Add buffer dependency for input_output_alias + new_conv_state = ( + new_conv_state.to(self.conv_state_buffer.dtype) + + self.conv_state_buffer * 0 + ) + elif batch_size < alloc_bs: + pad_size = alloc_bs - batch_size + new_conv_state = torch.cat( + [ + new_conv_state, + self.conv_state_buffer[batch_size:] * 0, + ], + dim=0, + ) + else: + new_conv_state = new_conv_state + self.conv_state_buffer * 0 + else: + mixed_post_conv = F.silu(self.conv1d(mixed)[:, :, :seq_len]) + + if valid_mask_1d is not None: + # valid_mask_1d is [B, S, 1]; count valid tokens per batch + num_valid = ( + valid_mask_1d.squeeze(-1).sum(dim=-1, keepdim=True).long() + ) # [B, 1] + idx_base = num_valid - 3 + idx_base = idx_base.clamp(min=0) + offsets = torch.arange(3, device=mixed.device).unsqueeze(0) + gather_idx = idx_base + offsets # [B, 3] + gather_idx = gather_idx.unsqueeze(1).expand(-1, self.conv_dim, -1) + new_conv_state = torch.gather(mixed, 2, gather_idx) + else: + new_conv_state = mixed[:, :, -3:].contiguous() + + alloc_bs = 
self.conv_state_buffer.shape[0]
+            if seq_ids is not None:
+                # BS=1 optimization: scatter to index 0 = direct replacement
+                new_conv_state = (
+                    new_conv_state.to(self.conv_state_buffer.dtype)
+                    + self.conv_state_buffer * 0
+                )
+            elif batch_size < alloc_bs:
+                pad_size = alloc_bs - batch_size
+                new_conv_state = torch.cat(
+                    [
+                        new_conv_state,
+                        torch.zeros(
+                            pad_size,
+                            self.conv_dim,
+                            self.conv_kernel_size - 1,
+                            dtype=new_conv_state.dtype,
+                            device=new_conv_state.device,
+                        ),
+                    ],
+                    dim=0,
+                )
+                new_conv_state = new_conv_state + self.conv_state_buffer * 0
+            else:
+                new_conv_state = new_conv_state + self.conv_state_buffer * 0
+
+        mixed_post_conv = mixed_post_conv.transpose(1, 2)
+
+        # Zero out conv1d output for padding positions.
+        # Conv1d with kernel_size=4 leaks real token info into the first
+        # few padding positions. Zeroing here ensures Q, K, V are exactly
+        # zero for all padding positions so the recurrence is unaffected.
+        if valid_mask_1d is not None:
+            mixed_post_conv = (
+                mixed_post_conv * valid_mask_1d
+            )  # [B, S, conv_dim] * [B, S, 1]
+
+        query = mixed_post_conv[..., : self.key_dim]
+        key = mixed_post_conv[..., self.key_dim : self.key_dim * 2]
+        value = mixed_post_conv[..., self.key_dim * 2 :]
+
+        # Reshape to heads
+        query = query.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim)
+        key = key.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim)
+        value = value.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim)
+
+        # Compute gating
+        beta = b.sigmoid()
+        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+
+        if valid_mask_1d is not None:
+            # Zero g for padding → alpha=exp(0)=1 → state preserved through padding
+            # Zero beta for padding → no state update from padding tokens
+            mask_2d = valid_mask_1d.squeeze(-1).float()  # [B, S]
+            g = g * mask_2d.unsqueeze(-1)
+            beta = beta * mask_2d.unsqueeze(-1)
+
+        # Expand K heads to match V heads (16 -> 32) using expand+reshape
+        if self.num_v_heads // self.num_k_heads > 1:
+            rep = self.num_v_heads // self.num_k_heads  # 2
+            query = (
+                query.unsqueeze(3)
+                .expand(-1, -1, -1, rep, -1)
+                .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim)
+            )
+            key = (
+                key.unsqueeze(3)
+                .expand(-1, -1, -1, rep, -1)
+                .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim)
+            )
+
+        # Transpose to (B, H, S, dim)
+        query = query.transpose(1, 2).contiguous().float()
+        key = key.transpose(1, 2).contiguous().float()
+        value = value.transpose(1, 2).contiguous().float()
+        g = g.transpose(1, 2).contiguous().float()
+        beta = beta.transpose(1, 2).contiguous().float()
+
+        if is_decode:
+            # TKG: single-step recurrent update
+            if seq_ids is not None:
+                recurrent_state = torch.index_select(
+                    self.recurrent_state_buffer, 0, seq_ids
+                ).float()
+            else:
+                recurrent_state = self.recurrent_state_buffer[:batch_size].float()
+
+            output, new_state = self._recurrent_step(
+                query, key, value, g, beta, recurrent_state
+            )
+            new_state_bf16 = new_state.to(self.recurrent_state_buffer.dtype)
+            alloc_bs = self.recurrent_state_buffer.shape[0]
+            if seq_ids is not None:
+                # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement
+                # Add buffer dependency for input_output_alias
+                new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0
+            elif batch_size < alloc_bs:
+                new_rec_state = torch.cat(
+                    [
+                        new_state_bf16,
+                        self.recurrent_state_buffer[batch_size:] * 0,
+                    ],
+                    dim=0,
+                )
+            else:
+                new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0
+        else:
+            # CTE: fused NKI
kernel by default (PyTorch _chunk_forward can hit + # neuronx-cc codegen ICE NCC_INLA001 with these DeltaNet dimensions). + # Override with env vars for debugging/benchmarking. + use_nki_fused = os.environ.get("USE_NKI_FUSED", "1") != "0" + use_nki_chunked = os.environ.get("USE_NKI_CHUNKED") == "1" + use_nki = os.environ.get("USE_NKI") == "1" + use_sequential = os.environ.get("DELTANET_SEQUENTIAL") == "1" + use_pytorch_chunk = os.environ.get("USE_PYTORCH_CHUNK") == "1" + + if use_pytorch_chunk: + output, final_state = self._chunk_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki_chunked: + output, final_state = self._nki_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki: + output, final_state = self._nki_recurrent_forward( + query, key, value, g, beta + ) + elif use_sequential: + output, final_state = self._sequential_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki_fused: + output, final_state = self._fused_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + else: + output, final_state = self._fused_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + + if final_state is not None: + final_state_bf16 = final_state.to(self.recurrent_state_buffer.dtype) + alloc_bs = self.recurrent_state_buffer.shape[0] + if seq_ids is not None: + # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement + # Add buffer dependency for input_output_alias + new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0 + elif batch_size < alloc_bs: + new_rec_state = torch.cat( + [ + final_state_bf16, + torch.zeros( + alloc_bs - batch_size, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + dtype=final_state_bf16.dtype, + device=final_state_bf16.device, + ), + ], + dim=0, + ) + new_rec_state = new_rec_state + self.recurrent_state_buffer * 0 + else: + new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0 + else: + new_rec_state = self.recurrent_state_buffer * 1 + + # Output: norm, gate, project + output = output.to(hidden_states.dtype) + output = output.transpose(1, 2).contiguous() + output = output.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + output = self.norm(output) + z_gate = z.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + output = output * F.silu(z_gate) + output = output.reshape(batch_size, seq_len, self.value_dim) + output = self.out_proj(output) + + # Return dummy KV for KVCacheManager + dummy_k = torch.zeros( + batch_size, + self.kv_heads_per_rank, + seq_len, + self.head_dim, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + dummy_v = torch.zeros_like(dummy_k) + + return output, (dummy_k, dummy_v), new_rec_state, new_conv_state + + +# ============================================================ +# InferenceConfig (Dense -- no MoE) +# ============================================================ + + +class Qwen35InferenceConfig(InferenceConfig): + """Config for Qwen3.5-9B (dense) with hybrid DeltaNet + Attention.""" + + def __init__(self, *args, **kwargs): + # Set defaults BEFORE super().__init__() because it calls validate_config() + # which checks get_required_attributes(). These can be overridden by + # kwargs or load_config. + + # Layer types for hybrid dispatch: [3 DeltaNet + 1 GQA] repeated. 
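+        # For the default 32 layers this puts "full_attention" at indices
+        # 3, 7, 11, ..., 31 and "linear_attention" at the other 24 layers.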
+ if "layer_types" not in kwargs and not any( + hasattr(a, "layer_types") for a in args if hasattr(a, "__dict__") + ): + num_layers = kwargs.get("num_hidden_layers", 32) + if num_layers % 4 != 0: + raise ValueError( + f"Qwen3.5 hybrid layer count must be divisible by 4, got {num_layers}" + ) + layer_types = [] + for _ in range(num_layers // 4): + layer_types.extend( + [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + ] + ) + kwargs.setdefault("layer_types", layer_types) + + # DeltaNet-specific config defaults + kwargs.setdefault("linear_num_value_heads", 32) + kwargs.setdefault("linear_num_key_heads", 16) + kwargs.setdefault("linear_key_head_dim", 128) + kwargs.setdefault("linear_value_head_dim", 128) + kwargs.setdefault("linear_conv_kernel_dim", 4) + + super().__init__(*args, **kwargs) + + # Attention output gate + self.attn_output_gate = getattr(self, "attn_output_gate", True) + + # Partial RoPE + self.partial_rotary_factor = getattr(self, "partial_rotary_factor", 0.25) + self.rope_dim = int(self.head_dim * self.partial_rotary_factor) # 64 + + # mRoPE (multimodal RoPE) for VL support + rope_params = getattr(self, "rope_parameters", {}) or {} + self.mrope_section = rope_params.get("mrope_section", [11, 11, 10]) + self.mrope_interleaved = rope_params.get("mrope_interleaved", True) + + # Standard HF config attributes expected by NxDI + if not hasattr(self, "output_attentions"): + self.output_attentions = False + if not hasattr(self, "output_hidden_states"): + self.output_hidden_states = False + + def get_required_attributes(self) -> List[str]: + return [ + "head_dim", + "hidden_act", + "hidden_size", + "intermediate_size", + "max_position_embeddings", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "rms_norm_eps", + "rope_theta", + "vocab_size", + # DeltaNet-specific + "linear_num_value_heads", + "linear_num_key_heads", + "linear_key_head_dim", + "linear_value_head_dim", + "linear_conv_kernel_dim", + "layer_types", + ] + + @classmethod + def get_neuron_config_cls(cls): + return NeuronConfig + + +# ============================================================ +# Attention (standard GQA for 16 of 64 layers) +# With output gate: q_proj is 2x sized, split into (query, gate) +# With partial RoPE: only first rope_dim dimensions get rotary +# ============================================================ + + +class Qwen35MRoPEEmbedding(nn.Module): + """Multimodal Rotary Position Embedding (mRoPE) for Qwen3.5. + + Handles 3D position information (temporal, height, width) for VL models. + Position IDs have shape (3, batch_size, seq_len) for T/H/W dimensions. + For text-only (2D position_ids), broadcasts to 3D with identical positions. + """ + + def __init__(self, config): + super().__init__() + self.head_dim = config.head_dim # 256 + self.rope_dim = config.rope_dim # 64 + self.mrope_section = config.mrope_section # [11, 11, 10] + self.mrope_interleaved = getattr(config, "mrope_interleaved", True) + self.rope_theta = config.rope_theta + + # Validate mrope_section sums to rope_dim // 2 = 32 + assert sum(self.mrope_section) == self.rope_dim // 2, ( + f"mrope_section {self.mrope_section} sums to {sum(self.mrope_section)}, " + f"expected {self.rope_dim // 2}" + ) + + def forward(self, x, position_ids_3d): + """Compute cos/sin from 3D position IDs. 
+ + Args: + x: hidden_states (for device/dtype inference) + position_ids_3d: (3, batch_size, seq_len) -- T, H, W positions + + Returns: + cos: (batch_size, seq_len, rope_dim) + sin: (batch_size, seq_len, rope_dim) + """ + device = x.device + dtype = torch.float32 + + sections = self.mrope_section # [11, 11, 10] + cos_parts = [] + sin_parts = [] + + freq_offset = 0 + for axis_idx, section_size in enumerate(sections): + pos = position_ids_3d[axis_idx].float() # (batch, seq_len) + + dim_pairs = section_size # number of (cos, sin) pairs for this axis + freqs = 1.0 / ( + self.rope_theta + ** ( + torch.arange(0, dim_pairs * 2, 2, dtype=dtype, device=device) + / (self.rope_dim) + ) + ) # (dim_pairs,) + + # freqs: (dim_pairs,), pos: (B, S) -> angles: (B, S, dim_pairs) + angles = pos.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0) + + cos_parts.append(angles.cos()) + sin_parts.append(angles.sin()) + + # Concatenate: (B, S, 32) + cos = torch.cat(cos_parts, dim=-1) + sin = torch.cat(sin_parts, dim=-1) + + if self.mrope_interleaved: + # Interleave to (B, S, 64): [c0, c0, c1, c1, ...] for rotate_half + cos = cos.repeat_interleave(2, dim=-1) + sin = sin.repeat_interleave(2, dim=-1) + else: + cos = torch.cat([cos, cos], dim=-1) + sin = torch.cat([sin, sin], dim=-1) + + return cos, sin + + +class NeuronQwen35Attention(NeuronAttentionBase): + """Standard GQA attention for Qwen3.5 with output gate and partial RoPE. + + 16 Q heads, 4 KV heads (4:1 GQA), head_dim=256 for the 9B dense model. + q_proj is doubled (query + gate), split at load time. + Only first rope_dim=64 of head_dim=256 gets rotary encoding. + + Uses NeuronAttentionBase infrastructure for QKV projection, KV cache, + RoPE, and attention computation. Overrides forward() to insert the + sigmoid output gate between attention output and o_proj. + """ + + def __init__(self, config): + # Partial RoPE: create mRoPE embedding with rope_dim (64) + self.rope_dim = config.rope_dim # 64 = head_dim * partial_rotary_factor + + # Create QK norm modules (will be passed to base class) + rms_norm_eps = config.rms_norm_eps + q_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps) + k_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps) + + # Partial RoPE: use standard RotaryEmbedding. + # For VL with 3D mRoPE positions, cos/sin are pre-computed externally in + # get_model_output() using Qwen35MRoPEEmbedding and passed as cos_cache/sin_cache. + rotary_emb = RotaryEmbedding( + self.rope_dim, # Only 64 dims get rotary embedding + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim, + rotary_emb=rotary_emb, + rms_norm_eps=rms_norm_eps, + use_qk_norm=False, + q_layernorm=q_ln, + k_layernorm=k_ln, + ) + + # Separate mRoPE module for VL 3D position_ids + self.mrope_emb = Qwen35MRoPEEmbedding(config) + + # Output gate projection: hidden_size -> num_heads * head_dim + # Populated from the second half of q_proj during state dict conversion. + self.output_gate_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * config.head_dim, + bias=False, + gather_output=False, + ) + + def apply_rotary_embedding( + self, Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope + ): + """Partial RoPE: only apply rotary embedding to first rope_dim dimensions. 
+ + Q shape: (B, H, S, head_dim) where head_dim=256 + cos/sin shape: (B, S, rope_dim) where rope_dim=64 (from RotaryEmbedding(dim=64)) + + Split Q/K along last dim into: + q_rope (first 64 dims) -- apply RoPE + q_pass (remaining 192 dims) -- pass through unchanged + """ + from neuronx_distributed_inference.modules.attention.utils import ( + apply_rotary_pos_emb, + ) + + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + + # Split into rope and pass-through portions + Q_orig_dtype = Q.dtype + q_rope = Q[..., : self.rope_dim] # (B, H, S, 64) + q_pass = Q[..., self.rope_dim :] # (B, H, S, 192) + k_rope = K[..., : self.rope_dim] + k_pass = K[..., self.rope_dim :] + + # Apply RoPE only to the rope portion + q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos_cache, sin_cache) + + # Concatenate back (ensure bf16 is maintained) + Q = torch.cat([q_rope, q_pass], dim=-1).to(Q_orig_dtype) + K = torch.cat([k_rope, k_pass], dim=-1).to(Q_orig_dtype) + + return Q, K, cos_cache, sin_cache + + def perform_prefill(self, Q, K, V, q_len, bsz, attention_mask=None): + """Prefill path with NKI flash attention for head_dim=256.""" + head_dim = Q.shape[-1] + + # Option B: nkilib flash attention for head_dim > 128 + if _nkilib_flash_attn is not None: + q_contig = Q.contiguous() + k_contig = K.contiguous() + v_contig = V.contiguous() + scale = 1.0 / math.sqrt(head_dim) + result = _nkilib_flash_attn( + q_contig, k_contig, v_contig, scale=scale, use_causal_mask=True + ) + return result, None + + # Option A: kernel patched globally + if NKILIB_PATCH_ACTIVE: + return _flash_fwd_call(Q, K, V, use_causal_mask=True), None + + # Fallback: softmax path (use 3D tensors to avoid compiler ICE with 4D patterns) + if head_dim > 128: + # GQA: expand K/V heads to match Q heads + num_q_heads = Q.shape[1] + num_kv_heads = K.shape[1] + if num_q_heads != num_kv_heads: + kv_rep = num_q_heads // num_kv_heads + K = ( + K.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(bsz, num_q_heads, q_len, head_dim) + ) + V = ( + V.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(bsz, num_q_heads, q_len, head_dim) + ) + # Reshape to 3D (B*H, S, d) to avoid neuronx-cc codegen ICE with 4D + # attention weight tensors (NCC_INLA001: Expected 2D tensor but got 4D AP) + Q_3d = Q.reshape(bsz * num_q_heads, q_len, head_dim) + K_3d = K.reshape(bsz * num_q_heads, q_len, head_dim) + V_3d = V.reshape(bsz * num_q_heads, q_len, head_dim) + attn_weights = torch.bmm(Q_3d, K_3d.transpose(-1, -2)) / math.sqrt(head_dim) + # Build causal mask for 3D: (1, S, S) broadcast over B*H + causal_mask = torch.triu( + torch.full( + (q_len, q_len), + -65504.0, + dtype=attn_weights.dtype, + device=attn_weights.device, + ), + diagonal=1, + ).unsqueeze(0) + attn_weights = attn_weights + causal_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + Q.dtype + ) + attn_output = torch.bmm(attn_weights, V_3d) + # Reshape back to 4D (B, H, S, d) + return attn_output.reshape(bsz, num_q_heads, q_len, head_dim), None + + return _flash_fwd_call(Q, K, V, use_causal_mask=True), None + + def forward( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + adapter_ids=None, + active_mask=None, + **kwargs, + ): + """Forward with output gate applied BEFORE o_proj. 
+ + Override NeuronAttentionBase.forward() to insert the sigmoid gate + between the attention output and o_proj, matching the HF reference: + gate = sigmoid(gate_proj(pre_attn_hidden)) + attn_output = attn_output * gate + attn_output = o_proj(attn_output) + """ + bsz, q_len, _ = hidden_states.shape + + # Use standard 2D position_ids for prep_qkv_tensors. + rope_pos_ids = position_ids + + # Compute gate from input hidden states (before QKV projection) + gate = self.output_gate_proj(hidden_states) # (B, S, num_heads * head_dim) + + # Standard QKV prep (projections, QK norm, RoPE) + Q, K, V, cos_cache, sin_cache, _residual = self.prep_qkv_tensors( + rope_pos_ids, + hidden_states, + past_key_value, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rmsnorm=rmsnorm, + ) + + if past_key_value is None: + # Context encoding (prefill) + attn_output, _flash_strategy = self.perform_prefill( + Q, K, V, q_len, bsz, attention_mask + ) + else: + # Token generation (decode) + tkg_mask = attention_mask + if tkg_mask is not None and tkg_mask.ndim == 2: + tkg_mask = tkg_mask.unsqueeze(1).unsqueeze(2) # (B, S) -> (B, 1, 1, S) + attn_output = self.compute_for_token_gen( + Q, K, V, position_ids, past_key_value, tkg_mask, active_mask + ) + + # attn_output is (B, H, S, head_dim) -- transpose to (B, S, H*head_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + + # Apply sigmoid output gate BEFORE o_proj (matching HF reference) + attn_output = attn_output * torch.sigmoid(gate) + + # Apply o_proj + attn_output = self.get_o_proj()(attn_output, adapter_ids=adapter_ids) + + # Ensure K, V are in model dtype (bf16) for KV cache update + # (prevents mixed-precision dynamic-update-slice in neuronx-cc) + K = K.to(self.torch_dtype) + V = V.to(self.torch_dtype) + past_key_value = (K, V) + return attn_output, past_key_value, cos_cache, sin_cache + + +# ============================================================ +# Dense MLP (replaces MoE) +# ============================================================ + + +class Qwen35MLP(nn.Module): + """Dense SwiGLU MLP for Qwen3.5-9B. + + gate_proj: hidden_size -> intermediate_size (4096 -> 12288) + up_proj: hidden_size -> intermediate_size (4096 -> 12288) + down_proj: intermediate_size -> hidden_size (12288 -> 4096) + + output = down_proj(silu(gate_proj(x)) * up_proj(x)) + """ + + def __init__(self, config): + super().__init__() + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + ) + self.down_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + ) + + def forward(self, hidden_states): + gate = self.gate_proj(hidden_states) + up = self.up_proj(hidden_states) + hidden_states = F.silu(gate) * up + hidden_states = self.down_proj(hidden_states) + return hidden_states + + +# ============================================================ +# Decoder Layer (hybrid dispatch -- DeltaNet or GQA + Dense MLP) +# ============================================================ + + +class NeuronQwen35DecoderLayer(nn.Module): + """Hybrid decoder layer: dispatches to DeltaNet or standard attention. + Uses dense MLP for all layers (no MoE). 
+ """ + + def __init__(self, config: Qwen35InferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_type = config.layer_types[layer_idx] + self.layer_idx = layer_idx + self.config = config + + # Attention (DeltaNet or standard GQA) + if self.layer_type == "linear_attention": + self.linear_attn = NeuronGatedDeltaNet(config, layer_idx) + else: + self.self_attn = NeuronQwen35Attention(config=config) + + # Dense MLP (all layers) + self.mlp = Qwen35MLP(config) + + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + position_ids=None, + past_key_value=None, + padding_mask=None, + cos_cache=None, + sin_cache=None, + **kwargs, + ): + residual = hidden_states + + hidden_states = ModuleMarkerStartWrapper()(hidden_states) + hidden_states = self.input_layernorm(hidden_states) + + if self.layer_type == "linear_attention": + # DeltaNet path + attn_out, dummy_kv, new_rec_state, new_conv_state = self.linear_attn( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + **kwargs, + ) + hidden_states = residual + attn_out + present_key_value = dummy_kv + deltanet_states = (new_rec_state, new_conv_state) + else: + deltanet_states = None + # Standard attention path + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + cos_cache=cos_cache, + sin_cache=sin_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Dense MLP FFN + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + hidden_states = ModuleMarkerEndWrapper()(hidden_states) + outputs = ( + hidden_states, + present_key_value, + cos_cache, + sin_cache, + None, + deltanet_states, + ) + return outputs + + +# ============================================================ +# Model +# ============================================================ + + +class NeuronQwen35Model(NeuronBaseModel): + def setup_attr_for_model(self, config: Qwen35InferenceConfig): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: Qwen35InferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + self.layers = nn.ModuleList( + [ + NeuronQwen35DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = get_rmsnorm_cls()(self.hidden_size, eps=config.rms_norm_eps) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=False if self.on_device_sampling else True, + bias=False, + ) + + # mRoPE embedding 
for VL + self.mrope_emb = Qwen35MRoPEEmbedding(config) + + @property + def _deltanet_state_params(self): + """Return DeltaNet state nn.Parameters in alias order.""" + params = [] + for layer in self.layers: + if hasattr(layer, "linear_attn"): + params.append(layer.linear_attn.recurrent_state_buffer) + params.append(layer.linear_attn.conv_state_buffer) + return params + + def encode_vision_to_input(self, inputs_embeds, vision_embeddings, vision_mask): + """Scatter vision embeddings into text input embeddings at image token positions.""" + _, max_positions, embedding_dim = inputs_embeds.shape + h_new = inputs_embeds.clone() + vision_flat = vision_embeddings.view(-1, embedding_dim) + positions_flat = vision_mask.view(-1) + h_new.view(-1, embedding_dim).index_put_( + (positions_flat,), vision_flat, accumulate=False + ) + return h_new + + def get_model_output( + self, + input_ids=None, + seq_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + active_mask=None, + inputs_embeds=None, + prev_hidden=None, + adapter_ids=None, + rotary_position_ids=None, + update_cache=False, + is_for_context_encoding=False, + vision_embeddings=None, + vision_mask=None, + local_attn_mask=None, + windowed_context_encoding_window_idx=-1, + padding_mask=None, + **kwargs, + ): + """Override to collect DeltaNet state tensors from decoder layers.""" + batch_size, seq_length = input_ids.shape[:2] + if self.config.neuron_config.layer_boundary_markers: + input_ids = ModuleMarkerStartWrapper()(input_ids) + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][1].shape[2] + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # CRITICAL: Zero out embeddings for padding tokens so DeltaNet recurrence + # is not polluted. DeltaNet has no attention mask -- it processes all + # sequence positions through a linear recurrence. Padding tokens have + # real embedding vectors which corrupt the recurrence state. + # The mask is [B, S, 1] float with 1.0 for real tokens, 0.0 for padding. 
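+        # Example: a right-padded row [t0, t1, t2, PAD, PAD] yields mask rows
+        # [[1.], [1.], [1.], [0.], [0.]] along S, so PAD positions contribute
+        # zero vectors (not real embeddings) to the DeltaNet recurrence.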
+ deltanet_padding_mask = ( + (input_ids != self.padding_idx).unsqueeze(-1).to(inputs_embeds.dtype) + ) + if is_for_context_encoding: + inputs_embeds = inputs_embeds * deltanet_padding_mask + + # Vision embedding injection + if (vision_embeddings is not None) and (vision_mask is not None): + if vision_embeddings.dtype != self.config.neuron_config.torch_dtype: + vision_embeddings = vision_embeddings.to( + self.config.neuron_config.torch_dtype + ) + if is_for_context_encoding: + inputs_embeds = self.encode_vision_to_input( + inputs_embeds, vision_embeddings, vision_mask + ) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + hidden_states = inputs_embeds + + # Get KV cache for TKG + cache_size = self.n_positions + if not is_for_context_encoding: + if self.kv_mgr is not None: + past_key_values = self.kv_mgr.get_cache( + seq_ids=seq_ids, + seq_len=cache_size, + is_for_context_encoding=is_for_context_encoding, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + + # Decoder layers + next_decoder_cache = () + deltanet_state_tensors = [] + cos_cache = None + sin_cache = None + + # Convert 2D attention_mask to 4D causal mask for CTE + if ( + attention_mask is not None + and attention_mask.ndim == 2 + and is_for_context_encoding + ): + causal = torch.ones( + (seq_length, seq_length), + dtype=torch.bool, + device=attention_mask.device, + ).tril() + padding_4d = attention_mask[:, None, None, :].to(torch.bool) + attention_mask = (causal[None, None, :, :] & padding_4d).to( + attention_mask.dtype + ) + + # Pre-compute mRoPE cos/sin + if rotary_position_ids is not None and rotary_position_ids.ndim == 3: + cos_cache, sin_cache = self.mrope_emb(inputs_embeds, rotary_position_ids) + + for idx, decoder_layer in enumerate(self.layers): + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + layer_outputs = decoder_layer( + hidden_states, + seq_ids=seq_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + active_mask=active_mask, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rotary_position_ids=rotary_position_ids, + kv_mgr=self.kv_mgr, + get_kv_per_layer=False, + update_kv_per_layer=False, + idx=idx, + is_for_context_encoding=is_for_context_encoding, + seq_len=cache_size, + residual=None, + local_mask=local_attn_mask, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + padding_mask=padding_mask, + deltanet_padding_mask=deltanet_padding_mask, + **kwargs, + ) + + hidden_states = layer_outputs[0] + kv = layer_outputs[1] + next_decoder_cache += (kv,) + cos_cache, sin_cache = layer_outputs[2:4] + + # Collect DeltaNet state tensors + deltanet_states = layer_outputs[5] if len(layer_outputs) > 5 else None + if deltanet_states is not None: + deltanet_state_tensors.append(deltanet_states[0]) + deltanet_state_tensors.append(deltanet_states[1]) + + # Update KV cache + if update_cache: + next_decoder_cache = self.kv_mgr.update_cache( + is_for_context_encoding=is_for_context_encoding, + seq_ids=seq_ids, + position_ids=position_ids, + new_key_values=next_decoder_cache, + seq_len=cache_size, + 
windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + self._deltanet_updated_states = deltanet_state_tensors + + return (hidden_states, next_decoder_cache) + + def forward( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden=None, + adapter_ids=None, + accepted_indices=None, + current_length=None, + medusa_mask=None, + scatter_index=None, + slot_mapping=None, + active_block_table=None, + num_queries=None, + computed_context_lens=None, + tile_q_indices=None, + tile_block_tables=None, + tile_masks=None, + inputs_embeds=None, + kv_cache=None, + active_mask=None, + rotary_position_id=None, + vision_embeddings=None, + vision_mask=None, + ): + """Override base forward to append DeltaNet state tensors to output.""" + prev_hidden = self.set_none_if_empty(prev_hidden) + adapter_ids = self.set_none_if_empty(adapter_ids) + accepted_indices = self.set_none_if_empty(accepted_indices) + current_length = self.set_none_if_empty(current_length) + medusa_mask = self.set_none_if_empty(medusa_mask) + scatter_index = self.set_none_if_empty(scatter_index) + slot_mapping = self.set_none_if_empty(slot_mapping) + active_block_table = self.set_none_if_empty(active_block_table) + num_queries = self.set_none_if_empty(num_queries) + computed_context_lens = self.set_none_if_empty(computed_context_lens) + tile_q_indices = self.set_none_if_empty(tile_q_indices) + tile_block_tables = self.set_none_if_empty(tile_block_tables) + tile_masks = self.set_none_if_empty(tile_masks) + inputs_embeds = self.set_none_if_empty(inputs_embeds) + kv_cache = self.set_none_if_empty(kv_cache) + active_mask = self.set_none_if_empty(active_mask) + rotary_position_id = self.set_none_if_empty(rotary_position_id) + vision_embeddings = self.set_none_if_empty(vision_embeddings) + vision_mask = self.set_none_if_empty(vision_mask) + + is_for_context_encoding = position_ids.shape[-1] != 1 and not ( + hasattr(self.neuron_config, "speculation_length") + and position_ids.shape[-1] == self.neuron_config.speculation_length + ) + + seq_ids = seq_ids.to(torch.int32) + attn_mask = attention_mask + + hidden_states, updated_kv_cache = self.get_model_output( + input_ids=input_ids, + seq_ids=seq_ids, + attention_mask=attn_mask, + position_ids=position_ids, + active_mask=active_mask, + inputs_embeds=inputs_embeds, + adapter_ids=adapter_ids, + rotary_position_ids=rotary_position_id, + update_cache=True, + is_for_context_encoding=is_for_context_encoding, + padding_mask=None, + active_block_table=active_block_table, + scatter_index=slot_mapping + if getattr(self, "is_block_kv_layout", False) + else scatter_index, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + ) + + batch_size = input_ids.shape[0] + if not getattr(self, "sliced_hidden", False): + if not is_for_context_encoding: + pass + else: + index = torch.max(position_ids, dim=1, keepdim=True).indices + index = index.unsqueeze(1).expand(batch_size, 1, self.hidden_size) + hidden_states = torch.gather(hidden_states, dim=1, index=index) + + logits = self.lm_head(hidden_states) + logits = logits.float() + + if hasattr(self.lm_head, "pad_size"): + if self.lm_head.gather_output: + rank_id = torch.tensor(0, device=logits.device, dtype=torch.int32) + world_size = 1 + else: + from neuronx_distributed.parallel_layers import parallel_state + + rank_id = self.rank_util.get_rank() + world_size = torch.distributed.get_world_size( + group=self.lm_head.tensor_parallel_group + ) 
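+
+            # lm_head padded the vocab dim by pad_size for TP divisibility;
+            # mask_padded_logits uses rank_id/world_size to locate and mask
+            # this rank's vocab-padding columns in the logits.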
+ from neuronx_distributed_inference.models.model_base import ( + mask_padded_logits, + ) + + logits = mask_padded_logits( + logits, rank_id, world_size, pad_size=self.lm_head.pad_size + ) + + if self.on_device_sampling: + res = self._sample_on_device( + logits, sampling_params, False, is_for_context_encoding + ) + else: + res = logits + + outputs = [res] + if self.neuron_config.output_logits: + outputs += [logits] + outputs += updated_kv_cache + + # Append DeltaNet state tensors (for input_output_aliases) + if hasattr(self, "_deltanet_updated_states"): + outputs += self._deltanet_updated_states + + return outputs + + +# ============================================================ +# State Dict Converter (Dense -- no MoE weight handling) +# ============================================================ + + +def convert_qwen35_hf_to_neuron_state_dict(neuron_state_dict, config): + """Convert HF Qwen3.5 (dense) weights to NxDI format. + + Weight mappings per layer type: + + DeltaNet layers (linear_attention): + HF: layers.X.linear_attn.{in_proj_qkv, in_proj_z, in_proj_a, in_proj_b, + conv1d, A_log, dt_bias, norm, out_proj} + NxDI: same names (no remapping needed) + + Full attention layers: + HF: layers.X.self_attn.q_proj.weight: (num_heads*head_dim*2, hidden) -- doubled for gate + NxDI: layers.X.self_attn.Wqkv.weight (fused Q+K+V, gate separated) + layers.X.self_attn.output_gate_proj.weight (gate part) + HF: layers.X.self_attn.{k_proj, v_proj, o_proj, q_norm, k_norm} + NxDI: layers.X.self_attn.{..., q_layernorm, k_layernorm} + + Dense MLP (all layers): + HF: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + NxDI: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight (same names) + """ + # Add rank_util + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + # CRITICAL: Convert (1+weight) RMSNorm weights to standard RMSNorm weights. + # Qwen3.5 uses RMSNorm with `output = norm(x) * (1 + weight)` where weight + # is initialized to zeros. Standard NxDI RMSNorm uses `output = norm(x) * weight` + # where weight is initialized to ones. To convert: new_weight = old_weight + 1.0 + norm_keys_to_convert = [] + for l in range(config.num_hidden_layers): + norm_keys_to_convert.append(f"layers.{l}.input_layernorm.weight") + norm_keys_to_convert.append(f"layers.{l}.post_attention_layernorm.weight") + if config.layer_types[l] == "full_attention": + norm_keys_to_convert.append(f"layers.{l}.self_attn.q_norm.weight") + norm_keys_to_convert.append(f"layers.{l}.self_attn.k_norm.weight") + norm_keys_to_convert.append("norm.weight") + + for nk in norm_keys_to_convert: + if nk in neuron_state_dict: + old_val = neuron_state_dict[nk] + neuron_state_dict[nk] = old_val.float() + 1.0 + if "layers.0." in nk or nk == "norm.weight": + logger.debug( + f"[NORM FIX] {nk}: mean {old_val.float().mean():.4f} -> {neuron_state_dict[nk].mean():.4f}" + ) + else: + if "layers.0." 
in nk or nk == "norm.weight": + logger.warning(f"[NORM FIX] key not found: {nk}") + + for l in range(config.num_hidden_layers): + layer_type = config.layer_types[l] + + # === Attention layers === + if layer_type == "full_attention": + neuron_state_dict[f"layers.{l}.self_attn.rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + # QK norms: q_norm -> q_layernorm, k_norm -> k_layernorm + q_norm_key = f"layers.{l}.self_attn.q_norm.weight" + k_norm_key = f"layers.{l}.self_attn.k_norm.weight" + if q_norm_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.q_layernorm.weight"] = ( + neuron_state_dict.pop(q_norm_key).detach().clone() + ) + if k_norm_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.k_layernorm.weight"] = ( + neuron_state_dict.pop(k_norm_key).detach().clone() + ) + + # q_proj is doubled: (num_heads * head_dim * 2, hidden_size) + # INTERLEAVED: [head0_query(head_dim) | head0_gate(head_dim) | head1_query(head_dim) | ...] + q_proj_key = f"layers.{l}.self_attn.q_proj.weight" + if q_proj_key in neuron_state_dict: + q_proj_w = neuron_state_dict.pop(q_proj_key) + num_heads = config.num_attention_heads + head_dim = config.head_dim + q_proj_w = q_proj_w.reshape(num_heads, head_dim * 2, config.hidden_size) + query_w = q_proj_w[:, :head_dim, :] + gate_w = q_proj_w[:, head_dim:, :] + query_w = query_w.reshape(num_heads * head_dim, config.hidden_size) + gate_w = gate_w.reshape(num_heads * head_dim, config.hidden_size) + + neuron_state_dict[q_proj_key] = query_w + neuron_state_dict[f"layers.{l}.self_attn.output_gate_proj.weight"] = ( + gate_w + ) + + # Fuse QKV + if config.neuron_config.fused_qkv: + q_key = f"layers.{l}.self_attn.q_proj.weight" + k_key = f"layers.{l}.self_attn.k_proj.weight" + v_key = f"layers.{l}.self_attn.v_proj.weight" + if q_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = torch.cat( + [ + neuron_state_dict[q_key], + neuron_state_dict[k_key], + neuron_state_dict[v_key], + ] + ) + del neuron_state_dict[q_key] + del neuron_state_dict[k_key] + del neuron_state_dict[v_key] + + # Dense MLP: no weight conversion needed -- HF and NxDI use same names + # HF: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + # NxDI: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + + gc.collect() + + return neuron_state_dict + + +# ============================================================ +# Custom ModelWrapper and DecoderModelInstance for DeltaNet state aliasing +# ============================================================ + + +class Qwen35DecoderModelInstance(DecoderModelInstance): + """Custom DecoderModelInstance that adds DeltaNet state buffers to input_output_aliases.""" + + def get(self, bucket_rank, **kwargs): + """Override to add DeltaNet state aliases after KV cache aliases.""" + module, input_output_aliases = super().get(bucket_rank, **kwargs) + + num_output_from_trace = 1 if not self.neuron_config.output_logits else 2 + + if module.kv_mgr is not None: + num_kv = len(module.kv_mgr.past_key_values) + else: + num_kv = 0 + + state_start_idx = num_output_from_trace + num_kv + + if hasattr(module, "_deltanet_state_params"): + for i, param in enumerate(module._deltanet_state_params): + input_output_aliases[param] = state_start_idx + i + + return module, input_output_aliases + + +class Qwen35ModelWrapper(ModelWrapper): + """Custom ModelWrapper for VL support with mRoPE and vision inputs.""" + + def get_model_instance(self): + return Qwen35DecoderModelInstance( + 
model_cls=self.model_cls, + config=self.config, + **self.model_init_kwargs, + ) + + def input_generator(self): + """Generate inputs including mrope_position_ids, vision_embeddings, and vision_mask.""" + base_inputs = super().input_generator() + extended_inputs = [] + + for bucket_inputs in base_inputs: + input_ids = bucket_inputs[0] + batch_size = input_ids.shape[0] + n_active_tokens = input_ids.shape[1] + + is_cte = n_active_tokens > 1 + + if is_cte: + mrope_position_ids = ( + torch.arange(0, n_active_tokens, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + + vision_embeddings = torch.zeros( + (batch_size, n_active_tokens, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + (batch_size, n_active_tokens, 1), + fill_value=n_active_tokens - 1, + dtype=torch.int32, + ) + else: + mrope_position_ids = torch.zeros((0,), dtype=torch.int32) + vision_embeddings = torch.zeros( + (0,), dtype=self.config.neuron_config.torch_dtype + ) + vision_mask = torch.zeros((0,), dtype=torch.int32) + + padded = list(bucket_inputs) + while len(padded) < 21: + padded.append(torch.zeros((0,), dtype=torch.int32)) + padded.append(mrope_position_ids) # position 21 + padded.append(vision_embeddings) # position 22 + padded.append(vision_mask) # position 23 + + extended_inputs.append(tuple(padded)) + + return extended_inputs + + def pad_inputs(self, *args, pad_type="first_fit"): + """Override to pad mrope_position_ids and vision inputs to bucket size.""" + orig_mrope = args[21] if len(args) >= 22 else None + orig_vis_emb = args[22] if len(args) >= 23 else None + orig_vis_mask = args[23] if len(args) >= 24 else None + + padded_args = super().pad_inputs(*args, pad_type=pad_type) + + if len(padded_args) >= 24 and orig_mrope is not None: + padded_seq_len = padded_args[0].shape[1] + batch_size = padded_args[0].shape[0] + is_cte = padded_seq_len > 1 + + if is_cte: + current_mrope = orig_mrope + current_vis_emb = orig_vis_emb + current_vis_mask = orig_vis_mask + + if ( + current_mrope.ndim == 3 + and current_mrope.shape[-1] != padded_seq_len + ): + orig_len = current_mrope.shape[-1] + pad_size = padded_seq_len - orig_len + last_pos = current_mrope[:, :, -1:] + pad_offsets = torch.arange( + 1, pad_size + 1, dtype=current_mrope.dtype + ) + pad_offsets = ( + pad_offsets.unsqueeze(0).unsqueeze(0).expand(3, batch_size, -1) + ) + mrope_pad = last_pos + pad_offsets + mrope_position_ids = torch.cat([current_mrope, mrope_pad], dim=-1) + elif current_mrope.ndim == 3: + mrope_position_ids = current_mrope + else: + mrope_position_ids = ( + torch.arange(0, padded_seq_len, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + + if ( + current_vis_emb is not None + and current_vis_emb.ndim == 3 + and current_vis_emb.shape[1] < padded_seq_len + ): + pad_emb = torch.zeros( + ( + batch_size, + padded_seq_len - current_vis_emb.shape[1], + current_vis_emb.shape[2], + ), + dtype=current_vis_emb.dtype, + ) + vision_embeddings = torch.cat([current_vis_emb, pad_emb], dim=1) + elif current_vis_emb is not None and current_vis_emb.ndim == 3: + vision_embeddings = current_vis_emb[:, :padded_seq_len] + else: + vision_embeddings = torch.zeros( + (batch_size, padded_seq_len, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + + if ( + current_vis_mask is not None + and current_vis_mask.ndim == 3 + and current_vis_mask.shape[1] < padded_seq_len + ): + pad_mask = torch.full( + 
(batch_size, padded_seq_len - current_vis_mask.shape[1], 1), + fill_value=padded_seq_len - 1, + dtype=torch.int32, + ) + vision_mask = torch.cat([current_vis_mask, pad_mask], dim=1) + elif current_vis_mask is not None and current_vis_mask.ndim == 3: + vision_mask = current_vis_mask[:, :padded_seq_len] + else: + vision_mask = torch.full( + (batch_size, padded_seq_len, 1), + fill_value=padded_seq_len - 1, + dtype=torch.int32, + ) + + padded_args = ( + *padded_args[:21], + mrope_position_ids, + vision_embeddings, + vision_mask, + ) + + padded_args = list(padded_args) + padded_args[23] = padded_args[23].clamp(max=padded_seq_len - 1) + padded_args = tuple(padded_args) + + return padded_args + + +# ============================================================ +# Top-Level Model +# ============================================================ + + +class NeuronQwen35ForCausalLM(NeuronBaseForCausalLM): + _model_cls = NeuronQwen35Model + + def get_model_wrapper_cls(self): + """Return custom ModelWrapper with DeltaNet state aliasing.""" + return Qwen35ModelWrapper + + @staticmethod + def load_hf_model(model_path, **kwargs): + """Load HF model weights. + + The model is a VL model (Qwen3_5ForConditionalGeneration) but we + only need the text backbone. + """ + from transformers import AutoModelForCausalLM + + kwargs.setdefault("trust_remote_code", True) + return AutoModelForCausalLM.from_pretrained(model_path, **kwargs) + + @classmethod + def get_config_cls(cls): + return Qwen35InferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict, config): + """Strip VL wrapper prefix and convert to NxDI format.""" + new_sd = {} + for k, v in state_dict.items(): + if k.startswith("language_model."): + new_k = k.replace("language_model.", "", 1) + new_sd[new_k] = v + elif k.startswith("model.language_model."): + new_k = k.replace("model.language_model.", "", 1) + new_sd[new_k] = v + elif k.startswith("model.visual") or k.startswith("visual"): + continue # Skip vision encoder + elif k.startswith("model."): + new_sd[k.replace("model.", "", 1)] = v + elif k.startswith("mtp."): + continue # Skip MTP + elif k.startswith("lm_head."): + new_sd[k] = v + else: + new_sd[k] = v + + if ( + getattr(config, "tie_word_embeddings", False) + and "lm_head.weight" not in new_sd + and "embed_tokens.weight" in new_sd + ): + new_sd["lm_head.weight"] = new_sd["embed_tokens.weight"] + + return convert_qwen35_hf_to_neuron_state_dict(new_sd, config) + + def enable_context_encoding(self): + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def enable_token_generation(self): + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + super().enable_token_generation() + + def _copy_past_key_values(self, outputs): + """Override to also copy DeltaNet state buffers on CPU.""" + super()._copy_past_key_values(outputs) + + num_output_from_trace = 1 + if ( + self.neuron_config.output_logits + and self.neuron_config.on_device_sampling_config + ): + num_output_from_trace = 2 + + if ( + hasattr(self, "token_generation_model") + and self.token_generation_model is not None + ): + tkg_model = self.token_generation_model.model + cte_model = self.context_encoding_model.model + else: + return + + if tkg_model.kv_mgr is not None: + num_kv = len(tkg_model.kv_mgr.past_key_values) + else: + num_kv = 0 + + state_start = num_output_from_trace + num_kv + + tkg_params = getattr(tkg_model, "_deltanet_state_params", []) + cte_params = getattr(cte_model, "_deltanet_state_params", []) + + if len(tkg_params) > 0 and 
state_start + len(tkg_params) <= len(outputs): + for i, (tkg_param, cte_param) in enumerate(zip(tkg_params, cte_params)): + new_state = outputs[state_start + i] + tkg_param.data = new_state + cte_param.data = new_state + + def get_required_kwargs(self): + """Return extra kwargs for HF generation loop.""" + return ["llava_args"] + + def _get_model_outputs( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + medusa_args, + llava_args, + slot_mapping=None, + block_table=None, + full_context_lens=None, + computed_context_lens=None, + tf_args=None, + ): + """Override to pass all 24 positional args explicitly.""" + is_prefill = self._is_prefill(position_ids) + + seq_len = input_ids.shape[1] + batch_size = input_ids.shape[0] + + if llava_args and len(llava_args) >= 2: + vision_embeddings = llava_args[0] + vision_mask = llava_args[1] + if len(llava_args) >= 3: + mrope_position_ids = llava_args[2] + else: + mrope_position_ids = None + elif is_prefill: + vision_embeddings = torch.zeros( + (batch_size, seq_len, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + (batch_size, seq_len, 1), + fill_value=seq_len - 1, + dtype=torch.int32, + ) + mrope_position_ids = None + else: + vision_embeddings = torch.zeros((0,), dtype=torch.float32) + vision_mask = torch.zeros((0,), dtype=torch.int32) + mrope_position_ids = None + + if is_prefill: + if mrope_position_ids is None: + mrope_position_ids = ( + torch.arange(0, seq_len, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + else: + mrope_position_ids = torch.zeros((0,), dtype=torch.int32) + + empties = [torch.empty(0) for _ in range(14)] + + if self._is_prefill(position_ids): + ctx_bs = self.context_encoding_model.neuron_config.batch_size + output_logits = [] + + for cb in range(0, batch_size, ctx_bs): + cb_end = min(cb + ctx_bs, batch_size) + actual_chunk = cb_end - cb + + chunk_input_ids = input_ids[cb:cb_end] + chunk_attn_mask = attention_mask[cb:cb_end] + chunk_pos_ids = position_ids[cb:cb_end] + chunk_seq_ids = seq_ids[cb:cb_end] + chunk_sampling = sampling_params[cb:cb_end] + chunk_prev_hidden = ( + prev_hidden[cb:cb_end] + if prev_hidden is not None + and hasattr(prev_hidden, "ndim") + and prev_hidden.ndim > 0 + and prev_hidden.shape[0] > 0 + else prev_hidden + ) + chunk_adapter_ids = ( + adapter_ids[cb:cb_end] + if adapter_ids is not None + and hasattr(adapter_ids, "ndim") + and adapter_ids.ndim > 0 + and adapter_ids.shape[0] > 0 + else adapter_ids + ) + + if mrope_position_ids.ndim == 3: + chunk_mrope = mrope_position_ids[:, cb:cb_end, :] + else: + chunk_mrope = mrope_position_ids + + if vision_embeddings.ndim == 3: + chunk_vis_emb = vision_embeddings[cb:cb_end] + chunk_vis_mask = vision_mask[cb:cb_end] + else: + chunk_vis_emb = vision_embeddings + chunk_vis_mask = vision_mask + + if actual_chunk < ctx_bs: + pad_n = ctx_bs - actual_chunk + chunk_input_ids = torch.cat( + [chunk_input_ids, chunk_input_ids[:1].expand(pad_n, -1)], dim=0 + ) + chunk_attn_mask = torch.cat( + [chunk_attn_mask, chunk_attn_mask[:1].expand(pad_n, -1)], dim=0 + ) + chunk_pos_ids = torch.cat( + [chunk_pos_ids, chunk_pos_ids[:1].expand(pad_n, -1)], dim=0 + ) + pad_seq = torch.arange( + batch_size, batch_size + pad_n, dtype=chunk_seq_ids.dtype + ) + chunk_seq_ids = torch.cat([chunk_seq_ids, pad_seq], dim=0) + chunk_sampling = torch.cat( + [chunk_sampling, chunk_sampling[:1].expand(pad_n, -1)], dim=0 + ) + if ( + 
chunk_prev_hidden is not None + and hasattr(chunk_prev_hidden, "ndim") + and chunk_prev_hidden.ndim > 0 + and chunk_prev_hidden.shape[0] > 0 + ): + chunk_prev_hidden = torch.cat( + [ + chunk_prev_hidden, + chunk_prev_hidden[:1].expand(pad_n, -1), + ], + dim=0, + ) + if ( + chunk_adapter_ids is not None + and hasattr(chunk_adapter_ids, "ndim") + and chunk_adapter_ids.ndim > 0 + and chunk_adapter_ids.shape[0] > 0 + ): + chunk_adapter_ids = torch.cat( + [ + chunk_adapter_ids, + chunk_adapter_ids[:1].expand(pad_n, -1), + ], + dim=0, + ) + if chunk_mrope.ndim == 3: + chunk_mrope = torch.cat( + [chunk_mrope, chunk_mrope[:, :1, :].expand(-1, pad_n, -1)], + dim=1, + ) + if chunk_vis_emb.ndim == 3: + chunk_vis_emb = torch.cat( + [ + chunk_vis_emb, + torch.zeros( + (pad_n,) + chunk_vis_emb.shape[1:], + dtype=chunk_vis_emb.dtype, + ), + ], + dim=0, + ) + chunk_vis_mask = torch.cat( + [ + chunk_vis_mask, + torch.full( + (pad_n,) + chunk_vis_mask.shape[1:], + fill_value=seq_len - 1, + dtype=chunk_vis_mask.dtype, + ), + ], + dim=0, + ) + + chunk_out = self.context_encoding_model( + chunk_input_ids, + chunk_attn_mask, + chunk_pos_ids, + chunk_seq_ids, + chunk_sampling, + chunk_prev_hidden, + chunk_adapter_ids, + *empties, + chunk_mrope, + chunk_vis_emb, + chunk_vis_mask, + ) + if actual_chunk < ctx_bs: + chunk_out = chunk_out[:actual_chunk] + output_logits.append(chunk_out) + + outputs = ( + torch.cat(output_logits, dim=0) + if len(output_logits) > 1 + else output_logits[0] + ) + self.kv_cache_populated = True + is_run_on_neuron = self.context_encoding_model.is_neuron() + else: + outputs = self.token_generation_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + *empties, + mrope_position_ids, + vision_embeddings, + vision_mask, + ) + is_run_on_neuron = self.token_generation_model.is_neuron() + + return outputs, is_run_on_neuron + + def get_compiler_args(self): + if self.compile_tag == CONTEXT_ENCODING_MODEL_TAG: + optimization_level = "-O1" + else: + optimization_level = "-O1" + + compiler_args = ( + "--enable-saturate-infinity " + "--enable-mixed-precision-accumulation " + f"--model-type transformer {optimization_level} " + "--auto-cast=none " + ) + return compiler_args diff --git a/contrib/models/Qwen3.5-9B/src/nki_kernels/__init__.py b/contrib/models/Qwen3.5-9B/src/nki_kernels/__init__.py new file mode 100644 index 00000000..6472e49c --- /dev/null +++ b/contrib/models/Qwen3.5-9B/src/nki_kernels/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Custom NKI kernels for Qwen3.5-9B DeltaNet layers. + +Contains three kernel implementations: +- nki_deltanet: Per-token recurrent kernel (used for token generation) +- nki_deltanet_chunked: Per-chunk kernel (legacy, superseded by fused) +- nki_deltanet_fused: Fused single-kernel chunked forward (used for context encoding) +""" diff --git a/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet.py b/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet.py new file mode 100644 index 00000000..e6740aa1 --- /dev/null +++ b/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet.py @@ -0,0 +1,337 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""NKI kernels for DeltaNet gated delta rule recurrent forward. + +NKI v3 (SDK 2.29, NKI 0.3.0). Processes a SINGLE (batch, head) pair per kernel call. 
+The caller loops over (B, H) in PyTorch and calls this kernel for each pair. + +Input layout: All inputs are 2D contiguous tensors (S, 128). +Each call processes one (batch, head) element's full sequence. + +k_dim = v_dim = 128, which matches SBUF tile partition dimension exactly. +g and beta are scalars per token, expanded to (S, 128) by the caller. + +Two kernel variants: + deltanet_recurrent_fwd -- returns output only (original) + deltanet_recurrent_fwd_state -- returns (output, final_state) for CTE->TKG carry-over +""" + +import nki +import nki.isa as nisa +import nki.language as nl + +# Partition dimension max (NeuronCore SBUF tile width) +P_MAX = 128 + +# Shuffle mask: broadcast partition 0 to all partitions in a 32-wide group +_BROADCAST_MASK = [0] * 32 + + +@nki.jit +def deltanet_recurrent_fwd( + query: nl.ndarray, # (S, 128) float32 + key: nl.ndarray, # (S, 128) float32 + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 128) float32, log-decay broadcast to 128 + beta_in: nl.ndarray, # (S, 128) float32, write gate broadcast to 128 +) -> nl.ndarray: + """NKI kernel for DeltaNet recurrent forward -- single (batch, head). + + Iterates over sequence tokens with sequential_range. + State matrix (128 x 128) lives in SBUF. + + Args: + query: (S, 128) float32 + key: (S, 128) float32 + value: (S, 128) float32 + g_in: (S, 128) float32 + beta_in: (S, 128) float32 + + Returns: + output: (S, 128) float32 + """ + seq_len, dim = query.shape + + # Output tensor in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + + # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1 + seq_stride = dim + + # Initialize recurrent state in SBUF: (128, 128) + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state, value=0.0) + + # Sequential loop over tokens (state-dependent) + for t in nl.sequential_range(seq_len): + tok_offset = t * seq_stride + + # ---- Load inputs for token t ---- + q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_t, + src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_t, + src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_t, + src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_t, + src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_t, + src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + # ---- Step 1: Decay state -- state = state * exp(g_t) ---- + exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0) + + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_g, + engine=nisa.vector_engine, + ) + nisa.tensor_copy(dst=state, src=state_decayed) + + # ---- Step 2: Read memory -- kv_mem = state^T @ k_t ---- + kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t) + kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum) + + # 
---- Step 3: delta = (v_t - kv_mem) * beta_t ---- + v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract) + + delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=delta, + data=v_sub, + op0=nl.multiply, + operand0=beta_t, + engine=nisa.vector_engine, + ) + + # ---- Step 4: state += outer(k_t, delta) ---- + # Broadcast multiply: outer[i,j] = k_t[i] * delta[j] + # 1) Transpose delta (128,1) -> (1,128) in PSUM + # 2) Copy PSUM (1,128) -> SBUF (128,128) -- partition broadcast + # 3) Multiply by k_t (128,1) which broadcasts across free dim + # This avoids the nc_matmul P=1 outer product (wastes 127/128 TE lanes). + + # Transpose delta to get values along free dimension + delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=delta_row_psum, data=delta) + + # Copy PSUM (1, 128) -> SBUF (1, 128) first (NKI 0.3.0 requires matching P dims) + delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum) + + # Broadcast (1, 128) SBUF -> (128, 128) SBUF via nc_stream_shuffle + # Each partition row gets the same delta values + delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=delta_row_sb[0:1, 0:P_MAX], + dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + # Element-wise multiply: outer[i,j] = delta_broadcast[i,j] * k_t[i,0] + # tensor_scalar broadcasts (P,1) k_t across all F columns + outer_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=outer_prod, + data=delta_broadcast, + op0=nl.multiply, + operand0=k_t, + engine=nisa.vector_engine, + ) + + # Accumulate into state + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add) + nisa.tensor_copy(dst=state, src=state_new) + + # ---- Step 5: o_t = state^T @ q_t ---- + o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t) + o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=o_t, src=o_t_psum) + + # ---- Store output for token t ---- + nisa.dma_copy( + dst=output.ap(pattern=[[1, dim]], offset=tok_offset), + src=o_t, + ) + + return output + + +@nki.jit +def deltanet_recurrent_fwd_state( + query: nl.ndarray, # (S, 128) float32 + key: nl.ndarray, # (S, 128) float32 + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 128) float32, log-decay broadcast to 128 + beta_in: nl.ndarray, # (S, 128) float32, write gate broadcast to 128 +): + """NKI kernel for DeltaNet recurrent forward with final state output. + + Same recurrence as deltanet_recurrent_fwd, but ALSO writes the final + recurrent state (128, 128) to an output HBM buffer. This enables + CTE -> TKG state carry-over. 
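+
+    Per-token recurrence (a restatement of Steps 1-5 in the loop body):
+        S   <- exp(g_t) * S                  # decay
+        d_t  = beta_t * (v_t - S^T @ k_t)    # delta rule
+        S   <- S + outer(k_t, d_t)           # rank-1 write
+        o_t  = S^T @ q_t                     # read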
+ + Returns: + output: (S, 128) float32 -- per-token output + final_state: (128, 128) float32 -- recurrent state after last token + """ + seq_len, dim = query.shape + + # Output tensors in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + final_state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1 + seq_stride = dim + + # Initialize recurrent state in SBUF: (128, 128) + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state, value=0.0) + + # Sequential loop over tokens (state-dependent) + for t in nl.sequential_range(seq_len): + tok_offset = t * seq_stride + + # ---- Load inputs for token t ---- + q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_t, + src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_t, + src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_t, + src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_t, + src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_t, + src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + # ---- Step 1: Decay state -- state = state * exp(g_t) ---- + exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0) + + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_g, + engine=nisa.vector_engine, + ) + nisa.tensor_copy(dst=state, src=state_decayed) + + # ---- Step 2: Read memory -- kv_mem = state^T @ k_t ---- + kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t) + kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum) + + # ---- Step 3: delta = (v_t - kv_mem) * beta_t ---- + v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract) + + delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=delta, + data=v_sub, + op0=nl.multiply, + operand0=beta_t, + engine=nisa.vector_engine, + ) + + # ---- Step 4: state += outer(k_t, delta) ---- + # Broadcast multiply: outer[i,j] = k_t[i] * delta[j] + delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=delta_row_psum, data=delta) + + # Copy PSUM (1, 128) -> SBUF (1, 128) first (NKI 0.3.0 requires matching P dims) + delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum) + + # Broadcast (1, 128) SBUF -> (128, 128) SBUF via nc_stream_shuffle + delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=delta_row_sb[0:1, 0:P_MAX], + dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + outer_prod = nl.ndarray((P_MAX, dim), 
dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=outer_prod, + data=delta_broadcast, + op0=nl.multiply, + operand0=k_t, + engine=nisa.vector_engine, + ) + + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add) + nisa.tensor_copy(dst=state, src=state_new) + + # ---- Step 5: o_t = state^T @ q_t ---- + o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t) + o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=o_t, src=o_t_psum) + + # ---- Store output for token t ---- + nisa.dma_copy( + dst=output.ap(pattern=[[1, dim]], offset=tok_offset), + src=o_t, + ) + + # ---- Write final state to HBM ---- + # state is (128, 128) in SBUF, copy to final_state in HBM + # Use dma_copy with full tile: P_MAX rows, dim cols + nisa.dma_copy(dst=final_state, src=state) + + return output, final_state diff --git a/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_chunked.py b/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_chunked.py new file mode 100644 index 00000000..88f0cc1b --- /dev/null +++ b/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_chunked.py @@ -0,0 +1,323 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""NKI per-chunk DeltaNet kernel for CTE (context encoding / prefill). + +Single-chunk kernel: processes one chunk (128 tokens) with Neumann-series +power-doubling for intra-chunk correction. The caller loops over chunks +in PyTorch, passing state between calls. + +Each kernel call: + - Takes one chunk of data: q, k, v, beta, g_cumsum, g_last (all 128x128) + - Takes recurrent state_in (128x128) + - Returns chunk output (128x128) and state_out (128x128) + +No sequence-indexed DMA inside the kernel -- all inputs/outputs are full tiles. +This avoids the DMA OOB issue seen with nl.sequential_range + slice indexing +in the NxDI model compilation context. + +NKI v3 (SDK 2.29, NKI 0.3.0). Uses nki.* namespace. +""" + +import nki +import nki.isa as nisa +import nki.language as nl + +P_MAX = 128 + + +@nki.jit +def deltanet_chunk_step( + query, # (128, 128) float32 -- one chunk, l2-normed+scaled + key, # (128, 128) float32 -- one chunk, l2-normed + value, # (128, 128) float32 -- one chunk + beta_broadcast, # (128, 128) float32 -- write gate broadcast to 128 + g_cumsum, # (128, 128) float32 -- cumsum of g within chunk, broadcast + g_last, # (128, 128) float32 -- g_cumsum[-1], constant in chunk, broadcast + state_in, # (128, 128) float32 -- recurrent state from previous chunk + lower_mask, # (128, 128) float32 -- strict lower triangular + identity, # (128, 128) float32 -- identity matrix + lower_mask_diag, # (128, 128) float32 -- lower tri with diagonal +): + """Process one chunk of DeltaNet. 
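+
+    Note on exactness: A is strictly lower triangular (128x128), so
+    A^128 = 0 and (I - A)(I+A)(I+A^2)...(I+A^64) = I - A^128 = I.
+    The (I+A) init plus six doubling rounds therefore yield
+    N = (I - A)^{-1} exactly -- no series truncation error.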
+ + Returns: + output: (128, 128) float32 -- chunk output + state_out: (128, 128) float32 -- updated recurrent state + """ + C, dim = query.shape # C = 128, dim = 128 + + # Output tensors in HBM + output = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.shared_hbm) + state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # Load all inputs into SBUF + q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=q_c, src=query) + + k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=k_c, src=key) + + v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=v_c, src=value) + + beta_c = nl.ndarray((P_MAX, dim), dtype=beta_broadcast.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=beta_c, src=beta_broadcast) + + gc_c = nl.ndarray((P_MAX, dim), dtype=g_cumsum.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=gc_c, src=g_cumsum) + + gl_c = nl.ndarray((P_MAX, dim), dtype=g_last.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=gl_c, src=g_last) + + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=state, src=state_in) + + # Load masks + eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=eye, src=identity) + + Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask, src=lower_mask) + + Lmask_d = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask_d, src=lower_mask_diag) + + # ============================================================ + # k_beta = K * beta, v_beta = V * beta + # ============================================================ + k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=k_beta, data1=k_c, data2=beta_c, op=nl.multiply) + + v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_beta, data1=v_c, data2=beta_c, op=nl.multiply) + + # ============================================================ + # exp(g_cumsum) and exp(-g_cumsum) + # ============================================================ + exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_gc, op=nl.exp, data=gc_c, bias=None, scale=1.0) + + neg_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_gc, + data=gc_c, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + exp_neg_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_neg_gc, op=nl.exp, data=neg_gc, bias=None, scale=1.0) + + # exp(g_last) for state decay + exp_gl = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_gl, op=nl.exp, data=gl_c, bias=None, scale=1.0) + + # ============================================================ + # Phase 1: Build A matrix (intra-chunk correction) + # QK = k_beta @ k^T -- contract over features + # ============================================================ + kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kb_T_psum, stationary=k_beta, moving=eye) + kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kb_T, src=kb_T_psum) + + k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=k_T_psum, stationary=k_c, moving=eye) + k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_T, 
src=k_T_psum) + + QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T) + QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK, src=QK_psum) + + # ============================================================ + # Decay mask: QK_decay[i,j] = QK[i,j] * exp(gc[i]) * exp(-gc[j]) + # ============================================================ + QK_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=QK_row, data1=QK, data2=exp_gc, op=nl.multiply) + + QK_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_r_T_psum, stationary=QK_row, moving=eye) + QK_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_r_T, src=QK_r_T_psum) + + QK_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=QK_r_T_col, data1=QK_r_T, data2=exp_neg_gc, op=nl.multiply) + + QK_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_d_psum, stationary=QK_r_T_col, moving=eye) + QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_decay, src=QK_d_psum) + + # A = -QK_decay * lower_mask + neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_QK_decay, + data=QK_decay, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + A = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=A, data1=neg_QK_decay, data2=Lmask, op=nl.multiply) + + # ============================================================ + # Neumann power-doubling: N = (I+A)(I+A^2)...(I+A^{64}) + # ============================================================ + P_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=P_acc, data1=eye, data2=A, op=nl.add) + + A_pow = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=A_pow, src=A) + + for _round in nl.sequential_range(6): + Ap_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Ap_T_psum, stationary=A_pow, moving=eye) + Ap_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=Ap_T, src=Ap_T_psum) + + Ap_sq_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Ap_sq_psum, stationary=Ap_T, moving=A_pow) + nisa.tensor_copy(dst=A_pow, src=Ap_sq_psum) + + IpA = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=IpA, data1=eye, data2=A_pow, op=nl.add) + + IpA_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=IpA_T_psum, stationary=IpA, moving=eye) + IpA_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=IpA_T, src=IpA_T_psum) + + Pacc_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Pacc_psum, stationary=IpA_T, moving=P_acc) + nisa.tensor_copy(dst=P_acc, src=Pacc_psum) + + # ============================================================ + # Apply N: value_corr = N @ v_beta, k_cumdecay = N @ (k_beta * exp_gc) + # ============================================================ + N_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=N_T_psum, stationary=P_acc, moving=eye) + N_T = nl.ndarray((P_MAX, P_MAX), 
dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=N_T, src=N_T_psum) + + vc_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vc_psum, stationary=N_T, moving=v_beta) + value_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=value_corr, src=vc_psum) + + kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=kb_exp_gc, data1=k_beta, data2=exp_gc, op=nl.multiply) + + kcd_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kcd_psum, stationary=N_T, moving=kb_exp_gc) + k_cumdecay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_cumdecay, src=kcd_psum) + + # ============================================================ + # Phase 2: Inter-chunk state propagation + # attn_intra = (q @ k^T) * decay_mask * lower_mask_diag + # ============================================================ + q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=q_T_psum, stationary=q_c, moving=eye) + q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=q_T, src=q_T_psum) + + qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T) + qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_raw, src=qk_psum) + + qk_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=qk_row, data1=qk_raw, data2=exp_gc, op=nl.multiply) + + qk_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_r_T_psum, stationary=qk_row, moving=eye) + qk_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_r_T, src=qk_r_T_psum) + + qk_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=qk_r_T_col, data1=qk_r_T, data2=exp_neg_gc, op=nl.multiply) + + qk_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_d_psum, stationary=qk_r_T_col, moving=eye) + qk_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_decay, src=qk_d_psum) + + attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=attn_intra, data1=qk_decay, data2=Lmask_d, op=nl.multiply) + + # ============================================================ + # v_prime = k_cumdecay @ state + # ============================================================ + kcd_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kcd_T_psum, stationary=k_cumdecay, moving=eye) + kcd_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kcd_T, src=kcd_T_psum) + + vp_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vp_psum, stationary=kcd_T, moving=state) + v_prime = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=v_prime, src=vp_psum) + + v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_new, data1=value_corr, data2=v_prime, op=nl.subtract) + + # ============================================================ + # attn_inter = (q * exp(g_cumsum)) @ state + # ============================================================ + q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, 
buffer=nl.sbuf) + nisa.tensor_tensor(dst=q_exp, data1=q_c, data2=exp_gc, op=nl.multiply) + + qe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qe_T_psum, stationary=q_exp, moving=eye) + qe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qe_T, src=qe_T_psum) + + ai_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state) + attn_inter = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=attn_inter, src=ai_psum) + + # ============================================================ + # attn_intra @ v_new + # ============================================================ + ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=ai_T_psum, stationary=attn_intra, moving=eye) + ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ai_T, src=ai_T_psum) + + intra_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=intra_psum, stationary=ai_T, moving=v_new) + intra_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=intra_out, src=intra_psum) + + # ============================================================ + # chunk_output = attn_inter + intra_out + # ============================================================ + chunk_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add) + + nisa.dma_copy(dst=output, src=chunk_out) + + # ============================================================ + # State update: state_new = exp(g_last) * (state + k_raw_decay^T @ v_new) + # ============================================================ + k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=k_raw_decay, data1=k_c, data2=exp_neg_gc, op=nl.multiply) + + kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_psum, stationary=k_raw_decay, moving=v_new) + kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_outer, src=kv_psum) + + state_plus = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_plus, data1=state, data2=kv_outer, op=nl.add) + + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_new, data1=state_plus, data2=exp_gl, op=nl.multiply) + + nisa.dma_copy(dst=state_out, src=state_new) + + return output, state_out diff --git a/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_fused.py b/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_fused.py new file mode 100644 index 00000000..3447a138 --- /dev/null +++ b/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_fused.py @@ -0,0 +1,577 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Fused single-kernel DeltaNet chunked forward for CTE (context encoding). + +SSD-style architecture: processes ALL chunks for one (batch, head) pair in +a single NKI kernel call. State (128x128) persists in SBUF across chunks — +no HBM round-trips for inter-chunk state propagation. + +Key optimizations over nki_deltanet_chunked.py: + 1. Single kernel call per (B,H) instead of B*H*num_chunks calls + 2. State in SBUF across all chunks (no HBM state read/write per chunk) + 3. 
In-kernel cumsum via tensor_tensor_scan (no PyTorch cumsum) + 4. Masks and constants loaded once, reused across chunks + 5. Uses tensor_scalar for partition-broadcast (no explicit broadcast loops) + 6. nc_transpose (Vector Engine) for all 128x128 transposes instead of + nc_matmul(moving=eye) (Tensor Engine) — frees TE for actual math + +NKI 0.3.0 (SDK 2.29). k_dim = v_dim = 128 = P_MAX exactly. +Chunk size = 128 = P_MAX (one tile per chunk). + +Mathematical framework (same as nki_deltanet_chunked.py): + Per-chunk Neumann-series power-doubling for intra-chunk correction: + A = -QK_decay * lower_mask + N = (I+A)(I+A^2)(I+A^4)...(I+A^64) [6 rounds] + value_corr = N @ v_beta + k_cumdecay = N @ (k_beta * exp(gc)) + + Inter-chunk state propagation: + v_prime = k_cumdecay @ state + v_new = value_corr - v_prime + attn_inter = (q * exp(gc)) @ state + attn_intra = (q @ k^T) * decay_mask * lower_mask_diag + output = attn_inter + attn_intra @ v_new + state = exp(g_last) * (state + k_raw_decay^T @ v_new) +""" + +import numpy as np + +import nki +import nki.isa as nisa +import nki.language as nl + +P_MAX = 128 # Partition dim = chunk_size = k_dim = v_dim +CHUNK_SIZE = 128 + +# Broadcast partition 0 to all partitions in a 32-wide group +_BROADCAST_MASK = [0] * 32 + + +def _make_lower_mask(): + """Strict lower triangular (128x128) as numpy constant.""" + return np.tril(np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=-1) + + +def _make_lower_mask_diag(): + """Lower triangular with diagonal (128x128) as numpy constant.""" + return np.tril(np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=0) + + +def _make_identity(): + """Identity matrix (128x128) as numpy constant.""" + return np.eye(CHUNK_SIZE, dtype=np.float32) + + +@nki.jit +def deltanet_fused_chunked_fwd( + query: nl.ndarray, # (S, 128) float32 — l2-normed and scaled + key: nl.ndarray, # (S, 128) float32 — l2-normed + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 1) float32 — per-token log-decay (NOT cumsum) + beta_in: nl.ndarray, # (S, 1) float32 — per-token write gate + lower_mask: nl.ndarray, # (128, 128) float32 — strict lower tri + identity: nl.ndarray, # (128, 128) float32 — identity + lower_mask_diag: nl.ndarray, # (128, 128) float32 — lower tri with diag +): + """Fused chunked DeltaNet forward — single kernel call per (batch, head). + + Processes all chunks sequentially within the kernel, keeping the recurrent + state (128x128) in SBUF across chunks. Returns per-token output and + final state. 
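+
+    A CPU reference sketch of the per-token recurrence this kernel evaluates
+    in chunked form (a sketch, assuming the decay-then-write gated delta rule
+    of the sequential fallback path; q, k already normed/scaled as above):
+
+        S = torch.zeros(128, 128)                       # (k_dim, v_dim) state
+        for t in range(seq_len):
+            S = g[t].exp() * S                          # gated decay
+            S = S + beta[t] * torch.outer(k[t], v[t] - S.T @ k[t])  # delta write
+            out[t] = S.T @ q[t]                         # read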
+ + Input requirements: + - S must be divisible by 128 (pad before calling) + - query must be l2-normed and scaled by 1/sqrt(k_dim) + - key must be l2-normed + - g_in is RAW log-decay (cumsum computed in-kernel via tensor_tensor_scan) + - beta_in is sigmoid(b) (write gate) + + Returns: + output: (S, 128) float32 + final_state: (128, 128) float32 + """ + seq_len = query.shape[0] + dim = query.shape[1] # 128 + num_chunks = seq_len // CHUNK_SIZE + + # Output tensors in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + final_state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # ================================================================ + # Load constant masks into SBUF once (reused across all chunks) + # ================================================================ + eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=eye, src=identity) + + Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask, src=lower_mask) + + Lmask_d = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask_d, src=lower_mask_diag) + + # Ones vector for cumsum scan: (1, CHUNK_SIZE) + ones_1xC = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=ones_1xC, value=1.0) + + # Zero initial for cumsum scan + zero_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=zero_11, value=0.0) + + # ================================================================ + # Initialize recurrent state in SBUF — persists across ALL chunks + # ================================================================ + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state, value=0.0) + + # ================================================================ + # Sequential chunk processing + # ================================================================ + for i_chunk in nl.sequential_range(num_chunks): + chunk_start = i_chunk * CHUNK_SIZE + + # ---- Load chunk data from HBM ---- + q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_c, + src=query[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_c, + src=key[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_c, + src=value[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + # g: (CHUNK_SIZE, 1) — raw log-decay per token + g_chunk_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_chunk_p[0:CHUNK_SIZE, 0:1], + src=g_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + # beta: (CHUNK_SIZE, 1) — write gate scalar per token + beta_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_p[0:CHUNK_SIZE, 0:1], + src=beta_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + # ---- In-kernel cumsum of g via tensor_tensor_scan ---- + # Need g as (1, CHUNK_SIZE) for scan along free dim. 
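+        # (The (CHUNK_SIZE, 1) column is zero-padded into a full 128x128 tile
+        # below so it can go through nc_transpose; only row 0 of the
+        # transposed result carries data.)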
+ # Transpose: (CHUNK_SIZE, 1) -> (1, CHUNK_SIZE) via nc_transpose + g_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=g_padded, value=0.0) + nisa.tensor_copy( + dst=g_padded[0:CHUNK_SIZE, 0:1], + src=g_chunk_p[0:CHUNK_SIZE, 0:1], + ) + + g_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=g_tp_psum, data=g_padded) + + g_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=g_row[0:1, 0:CHUNK_SIZE], + src=g_tp_psum[0:1, 0:CHUNK_SIZE], + ) + + # cumsum: gc_row[t] = 1.0 * gc_row[t-1] + g_row[t] + gc_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor_scan( + dst=gc_row[0:1, 0:CHUNK_SIZE], + data0=ones_1xC[0:1, 0:CHUNK_SIZE], + data1=g_row[0:1, 0:CHUNK_SIZE], + initial=zero_11[0:1, 0:1], + op0=nl.multiply, + op1=nl.add, + ) + + # Transpose gc back to (CHUNK_SIZE, 1) partition layout + gc_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=gc_padded, value=0.0) + nisa.tensor_copy( + dst=gc_padded[0:1, 0:CHUNK_SIZE], + src=gc_row[0:1, 0:CHUNK_SIZE], + ) + + gc_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=gc_tp_psum, data=gc_padded) + + # gc_p: (P_MAX, 1) — cumulative sum of g per token in this chunk + gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gc_p[0:CHUNK_SIZE, 0:1], + src=gc_tp_psum[0:CHUNK_SIZE, 0:1], + ) + + # g_last = gc[-1] (scalar) — needed for state decay + gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gl_11[0:1, 0:1], + src=gc_row[0:1, CHUNK_SIZE - 1 : CHUNK_SIZE], + ) + + # ---- Compute exp(gc), exp(-gc), exp(g_last) as (P_MAX, 1) scalars ---- + # These (P_MAX, 1) tensors are used with tensor_scalar to broadcast + # across the free dimension without explicit (P_MAX, dim) copies. 
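+        # i.e. with a (P_MAX, 1) operand, dst[p, f] = data[p, f] * operand0[p, 0]
+        # for every free index f: one scalar per partition.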
+ + exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_gc_p, + data=gc_p, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + exp_neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_neg_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=neg_gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + # exp(g_last): scalar, then broadcast to (P_MAX, 1) + exp_gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_11, + op=nl.exp, + data=gl_11, + bias=None, + scale=1.0, + ) + + exp_gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=exp_gl_11[0:1, 0:1], + dst=exp_gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + # ============================================================ + # k_beta = K * beta, v_beta = V * beta + # tensor_scalar broadcasts beta_p (P_MAX, 1) across free dim + # ============================================================ + k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_beta, + data=k_c, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=v_beta, + data=v_c, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + # ============================================================ + # Phase 1: Build A matrix (intra-chunk correction) + # Transpose K and K_beta for matmul + # ============================================================ + kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kb_T_psum, data=k_beta) + kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kb_T, src=kb_T_psum) + + k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=k_T_psum, data=k_c) + k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_T, src=k_T_psum) + + # QK = k_beta^T @ k (contract over features) + QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T) + QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK, src=QK_psum) + + # ============================================================ + # Decay mask: QK_decay[i,j] = QK[i,j] * exp(gc[i]) * exp(-gc[j]) + # + # Row scaling: QK_row[i,:] = QK[i,:] * exp(gc[i]) + # Then transpose, column scale, transpose back. + # Uses tensor_scalar with (P_MAX,1) operand for row scaling. 
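+        # (A (P_MAX, 1) operand scales per partition/row only, so the
+        # per-column exp(-gc[j]) factor is applied by transposing,
+        # row-scaling, and transposing back.)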
+ # ============================================================ + QK_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=QK_row, + data=QK, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + # Transpose to scale columns (now rows in transposed view) + QK_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=QK_r_T_psum, data=QK_row) + QK_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_r_T, src=QK_r_T_psum) + + QK_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=QK_r_T_col, + data=QK_r_T, + op0=nl.multiply, + operand0=exp_neg_gc_p, + engine=nisa.vector_engine, + ) + + # Transpose back + QK_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=QK_d_psum, data=QK_r_T_col) + QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_decay, src=QK_d_psum) + + # A = -QK_decay * lower_mask + neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_QK_decay, + data=QK_decay, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + A_mat = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=A_mat, data1=neg_QK_decay, data2=Lmask, op=nl.multiply) + + # ============================================================ + # Neumann power-doubling: N = (I+A)(I+A^2)...(I+A^{64}) + # 6 rounds → resolves rank up to 2^6 = 64 (sufficient for chunk=128) + # ============================================================ + P_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=P_acc, data1=eye, data2=A_mat, op=nl.add) + + A_pow = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=A_pow, src=A_mat) + + for _round in nl.sequential_range(6): + # A_pow = A_pow^2: transpose A_pow, then matmul + Ap_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=Ap_T_psum, data=A_pow) + Ap_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=Ap_T, src=Ap_T_psum) + + Ap_sq_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Ap_sq_psum, stationary=Ap_T, moving=A_pow) + nisa.tensor_copy(dst=A_pow, src=Ap_sq_psum) + + # P_acc = (I + A_pow) @ P_acc: transpose IpA, then matmul + IpA = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=IpA, data1=eye, data2=A_pow, op=nl.add) + + IpA_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=IpA_T_psum, data=IpA) + IpA_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=IpA_T, src=IpA_T_psum) + + Pacc_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Pacc_psum, stationary=IpA_T, moving=P_acc) + nisa.tensor_copy(dst=P_acc, src=Pacc_psum) + + # ============================================================ + # Apply N: value_corr = N @ v_beta + # k_cumdecay = N @ (k_beta * exp(gc)) + # ============================================================ + N_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=N_T_psum, data=P_acc) + N_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=N_T, src=N_T_psum) + + vc_psum = 
nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vc_psum, stationary=N_T, moving=v_beta) + value_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=value_corr, src=vc_psum) + + # k_beta * exp(gc): row-scaled + kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=kb_exp_gc, + data=k_beta, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + kcd_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kcd_psum, stationary=N_T, moving=kb_exp_gc) + k_cumdecay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_cumdecay, src=kcd_psum) + + # ============================================================ + # Phase 2: Inter-chunk state propagation + # attn_intra = (q @ k^T) * decay_mask * lower_mask_diag + # ============================================================ + q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=q_T_psum, data=q_c) + q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=q_T, src=q_T_psum) + + qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T) + qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_raw, src=qk_psum) + + # Row-scale by exp(gc) + qk_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=qk_row, + data=qk_raw, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + # Transpose, column-scale by exp(-gc), transpose back + qk_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=qk_r_T_psum, data=qk_row) + qk_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_r_T, src=qk_r_T_psum) + + qk_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=qk_r_T_col, + data=qk_r_T, + op0=nl.multiply, + operand0=exp_neg_gc_p, + engine=nisa.vector_engine, + ) + + qk_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=qk_d_psum, data=qk_r_T_col) + qk_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_decay, src=qk_d_psum) + + attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=attn_intra, data1=qk_decay, data2=Lmask_d, op=nl.multiply + ) + + # ============================================================ + # v_prime = k_cumdecay @ state (state is in SBUF!) 
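+        # v_prime is the portion of each corrected value already explained by
+        # the carried-in state; v_new = value_corr - v_prime is the residual
+        # that the state update below absorbs.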
+ # ============================================================ + kcd_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kcd_T_psum, data=k_cumdecay) + kcd_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kcd_T, src=kcd_T_psum) + + vp_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vp_psum, stationary=kcd_T, moving=state) + v_prime = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=v_prime, src=vp_psum) + + v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_new, data1=value_corr, data2=v_prime, op=nl.subtract) + + # ============================================================ + # attn_inter = (q * exp(gc)) @ state (state is in SBUF!) + # ============================================================ + q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_exp, + data=q_c, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + qe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=qe_T_psum, data=q_exp) + qe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qe_T, src=qe_T_psum) + + ai_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state) + attn_inter = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=attn_inter, src=ai_psum) + + # ============================================================ + # attn_intra @ v_new + # ============================================================ + ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=ai_T_psum, data=attn_intra) + ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ai_T, src=ai_T_psum) + + intra_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=intra_psum, stationary=ai_T, moving=v_new) + intra_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=intra_out, src=intra_psum) + + # ============================================================ + # chunk_output = attn_inter + intra_out + # ============================================================ + chunk_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add) + + # Store output chunk to HBM + nisa.dma_copy( + dst=output[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + src=chunk_out, + ) + + # ============================================================ + # State update: state = exp(g_last) * (state + k_raw_decay^T @ v_new) + # state is updated IN-PLACE in SBUF — no HBM round-trip! 
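+        # Decay bookkeeping: token i is weighted by exp(-gc[i]) here and the
+        # whole sum by exp(g_last) below, so each write carries its remaining
+        # in-chunk decay exp(g_last - gc[i]), while the old state decays by
+        # the full exp(g_last).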
+ # ============================================================ + + # k_raw_decay = k * exp(-gc) + k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_raw_decay, + data=k_c, + op0=nl.multiply, + operand0=exp_neg_gc_p, + engine=nisa.vector_engine, + ) + + # k_raw_decay^T @ v_new → (dim, dim) outer product sum + # nc_matmul: result[M,N] = sum_K stationary[K,M] * moving[K,N] + # stationary=k_raw_decay (P_MAX, dim), moving=v_new (P_MAX, dim) + # Result: sum over tokens -> (dim, dim) + kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_psum, stationary=k_raw_decay, moving=v_new) + kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_outer, src=kv_psum) + + # state = state + kv_outer + state_plus = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_plus, data1=state, data2=kv_outer, op=nl.add) + + # state = state_plus * exp(g_last) + # tensor_scalar broadcasts exp_gl_p (P_MAX, 1) across free dim + nisa.tensor_scalar( + dst=state, + data=state_plus, + op0=nl.multiply, + operand0=exp_gl_p, + engine=nisa.vector_engine, + ) + + # ---- Write final state to HBM ---- + nisa.dma_copy(dst=final_state_out, src=state) + + return output, final_state_out diff --git a/contrib/models/Qwen3.5-9B/test/__init__.py b/contrib/models/Qwen3.5-9B/test/__init__.py new file mode 100644 index 00000000..04f8b7b7 --- /dev/null +++ b/contrib/models/Qwen3.5-9B/test/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/contrib/models/Qwen3.5-9B/test/integration/__init__.py b/contrib/models/Qwen3.5-9B/test/integration/__init__.py new file mode 100644 index 00000000..04f8b7b7 --- /dev/null +++ b/contrib/models/Qwen3.5-9B/test/integration/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/contrib/models/Qwen3.5-9B/test/integration/test_model.py b/contrib/models/Qwen3.5-9B/test/integration/test_model.py new file mode 100644 index 00000000..ff8a25c9 --- /dev/null +++ b/contrib/models/Qwen3.5-9B/test/integration/test_model.py @@ -0,0 +1,486 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Integration tests for Qwen3.5-9B on Neuron. + +Tests compilation, loading, inference accuracy, and performance using +the full 9B model with pre-downloaded HuggingFace weights on a trn2 instance. + +These tests use the same Qwen35* classes and QWEN35_* env vars because the +underlying `qwen3_5` dense hybrid architecture is shared across Qwen3.5/3.6. + +Note: A mini model option is not provided because DeltaNet layers require NKI +kernels that only execute on Neuron devices, and the hybrid DeltaNet + GQA +architecture should be validated at TP=4 before attempting TP=2. 
+ +Environment variables: + QWEN35_MODEL_PATH Path to HF model weights (required) + QWEN35_COMPILED_PATH Path to compiled artifacts (default: /tmp/qwen35_9b_traced) + QWEN35_TP_DEGREE Tensor parallelism degree (default: 4) + QWEN35_SEQ_LEN Max sequence length (default: 128) + TTFT_THRESHOLD_MS Max TTFT in ms (default: 5000) + THROUGHPUT_THRESHOLD Min throughput in tok/s (default: 5.0) + +Prerequisites: + - trn2.3xlarge or larger with TP >= 4 NeuronCores available + - NXDI installed (neuronx_distributed_inference) + - HuggingFace weights downloaded to QWEN35_MODEL_PATH + - SDK 2.29+ (NKI 0.3.0 required for DeltaNet kernels) + +Usage: + # Full model (trn2.3xlarge, TP=4): + QWEN35_MODEL_PATH=/mnt/models/Qwen3.5-9B \\ + QWEN35_COMPILED_PATH=/mnt/models/qwen35_9b_traced \\ + pytest test/integration/test_model.py --capture=tee-sys +""" + +import gc +import os +import sys +import time + +import pytest +import torch + +# Ensure the contrib root (Qwen3.5-9B/) is on sys.path +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +# ── Configuration from environment ────────────────────────────────────── + +MODEL_PATH = os.environ.get("QWEN35_MODEL_PATH", "") +COMPILED_PATH = os.environ.get("QWEN35_COMPILED_PATH", "/tmp/qwen35_9b_traced") +TP_DEGREE = int(os.environ.get("QWEN35_TP_DEGREE", "4")) +SEQ_LEN = int(os.environ.get("QWEN35_SEQ_LEN", "128")) +TTFT_THRESHOLD_MS = float(os.environ.get("TTFT_THRESHOLD_MS", "5000")) +THROUGHPUT_THRESHOLD = float(os.environ.get("THROUGHPUT_THRESHOLD", "5.0")) + +requires_model_path = pytest.mark.skipif( + not MODEL_PATH, + reason=( + "QWEN35_MODEL_PATH not set. Integration tests require the full 9B model " + "weights. Set QWEN35_MODEL_PATH=/path/to/Qwen3.5-9B to run these tests." 
+ ), +) + + +# ── Fixtures ──────────────────────────────────────────────────────────── + + +@pytest.fixture(scope="module") +def model_path(): + """Return path to model weights.""" + return MODEL_PATH + + +@pytest.fixture(scope="module") +def compiled_model(model_path): + """Compile and load the model on Neuron.""" + import json + + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM + + neuron_config = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=SEQ_LEN, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + enable_bucketing=False, + flash_decoding_enabled=False, + logical_nc_config=2, + save_sharded_checkpoint=True, + ) + + # Read config.json directly (model_type 'qwen3_5' may not be in + # AutoConfig registry for all transformers versions) + with open(os.path.join(model_path, "config.json")) as f: + full_config = json.load(f) + text_config = full_config.get("text_config", full_config) + + config_dict = dict(text_config) + config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) + config_dict["tie_word_embeddings"] = full_config.get( + "tie_word_embeddings", + text_config.get("tie_word_embeddings", False), + ) + if "rope_parameters" in text_config: + config_dict["rope_theta"] = text_config["rope_parameters"].get( + "rope_theta", 10000000 + ) + + inf_config = Qwen35InferenceConfig( + neuron_config=neuron_config, + **config_dict, + ) + + # Compile if no existing artifacts + compiled_path = COMPILED_PATH + neff_path = os.path.join(compiled_path, "model.pt") + if not os.path.exists(neff_path): + print(f"Compiling to {compiled_path}...") + model = NeuronQwen35ForCausalLM(model_path, inf_config) + model.compile(compiled_path) + del model + gc.collect() + + # Load + print(f"Loading from {compiled_path}...") + model = NeuronQwen35ForCausalLM(compiled_path) + model.load(compiled_path) + return model + + +@pytest.fixture(scope="module") +def tokenizer(model_path): + """Load tokenizer.""" + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(model_path, padding_side="right") + if tok.pad_token is None: + tok.pad_token = tok.eos_token + return tok + + +@pytest.fixture(scope="module") +def generation_config(tokenizer): + """Create generation config.""" + from transformers import GenerationConfig + + return GenerationConfig( + do_sample=True, + top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + +def _generate(model, tokenizer, generation_config, prompt, max_new_tokens=20): + """Generate text using the NXDI model.""" + import transformers + + from neuronx_distributed_inference.utils.hf_adapter import ( + HuggingFaceGenerationAdapter, + ) + + inputs = tokenizer(prompt, padding=True, return_tensors="pt") + gen_model = HuggingFaceGenerationAdapter(model) + gen_model.generation_config.transformers_version = transformers.__version__ + generation_config.transformers_version = transformers.__version__ + outputs = gen_model.generate( + inputs.input_ids, + generation_config=generation_config, + attention_mask=inputs.attention_mask, + max_new_tokens=max_new_tokens, + ) + return outputs[0].tolist(), tokenizer.decode(outputs[0], skip_special_tokens=True) + + +def _is_repetitive(text, max_repeat=5): + """Check for excessive word repetition.""" + words = text.split() + if len(words) < max_repeat: + return 
False + for i in range(len(words) - max_repeat + 1): + if len(set(words[i : i + max_repeat])) == 1: + return True + return False + + +# ── Smoke Tests ───────────────────────────────────────────────────────── + + +@requires_model_path +def test_model_loads(compiled_model): + """Model compiles and loads successfully.""" + assert compiled_model is not None + assert hasattr(compiled_model, "neuron_config") + print(" Model loaded successfully") + + +@requires_model_path +def test_model_generates(compiled_model, tokenizer, generation_config): + """Model generates at least 5 tokens.""" + tokens, text = _generate( + compiled_model, + tokenizer, + generation_config, + "Hello, I am a language model", + max_new_tokens=20, + ) + input_len = len(tokenizer.encode("Hello, I am a language model")) + new_tokens = len(tokens) - input_len + assert new_tokens >= 5, f"Expected >= 5 new tokens, got {new_tokens}" + print(f" Generated {new_tokens} tokens: {text[:100]}...") + + +# ── Accuracy Tests ────────────────────────────────────────────────────── + + +@requires_model_path +def test_output_coherence(compiled_model, tokenizer, generation_config): + """Output should contain multiple words and not be excessively repetitive.""" + _, text = _generate( + compiled_model, + tokenizer, + generation_config, + "The capital of France is", + max_new_tokens=30, + ) + generated = text[len("The capital of France is") :].strip() + words = generated.split() + assert len(words) >= 3, f"Expected >= 3 words, got {len(words)}: '{generated}'" + assert not _is_repetitive(generated), ( + f"Output is excessively repetitive: '{generated}'" + ) + print(f" Output coherent ({len(words)} words): {generated[:80]}...") + + +@requires_model_path +def test_top_token_valid(compiled_model, tokenizer, generation_config): + """First generated token should be a valid decodable token.""" + tokens, _ = _generate( + compiled_model, + tokenizer, + generation_config, + "Hello!", + max_new_tokens=1, + ) + input_len = len(tokenizer.encode("Hello!")) + first_new = tokens[input_len] + assert 0 <= first_new < tokenizer.vocab_size, ( + f"Token {first_new} out of vocab range" + ) + decoded = tokenizer.decode([first_new]) + assert len(decoded) > 0, f"Token {first_new} decoded to empty string" + print(f" First token: {first_new} -> '{decoded}'") + + +@requires_model_path +def test_simple_factual_generation(compiled_model, tokenizer, generation_config): + """A simple factual prompt should produce the expected entity.""" + prompt = "The largest ocean on Earth is" + _, text = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=30, + ) + generated = text[len(prompt) :].strip() + assert "pacific" in generated.lower(), ( + f"Expected 'Pacific' in output, got: '{generated}'" + ) + print(f" Simple factual answer: {generated}") + + +# ── Performance Tests ─────────────────────────────────────────────────── + + +@requires_model_path +def test_performance_ttft(compiled_model, tokenizer, generation_config): + """Time to first token should be within threshold.""" + prompt = "Hello, I am a language model" + + # Warmup + _generate(compiled_model, tokenizer, generation_config, prompt, max_new_tokens=1) + + # Measure + times = [] + for _ in range(3): + t0 = time.perf_counter() + _generate( + compiled_model, tokenizer, generation_config, prompt, max_new_tokens=1 + ) + times.append((time.perf_counter() - t0) * 1000) + + avg_ms = sum(times) / len(times) + print(f" TTFT: {avg_ms:.1f} ms (threshold: {TTFT_THRESHOLD_MS} ms)") + assert avg_ms < 
TTFT_THRESHOLD_MS, ( + f"TTFT {avg_ms:.1f}ms > threshold {TTFT_THRESHOLD_MS}ms" + ) + + +@requires_model_path +def test_performance_throughput(compiled_model, tokenizer, generation_config): + """Throughput should meet minimum threshold.""" + prompt = "Once upon a time" + num_new_tokens = 20 + + # Warmup + _generate(compiled_model, tokenizer, generation_config, prompt, max_new_tokens=5) + + # Measure + t0 = time.perf_counter() + tokens, _ = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=num_new_tokens, + ) + elapsed = time.perf_counter() - t0 + + input_len = len(tokenizer.encode(prompt)) + actual_new = len(tokens) - input_len + throughput = actual_new / elapsed if elapsed > 0 else 0 + + print( + f" Throughput: {throughput:.1f} tok/s ({actual_new} tokens in {elapsed:.2f}s)" + ) + print(f" Threshold: {THROUGHPUT_THRESHOLD} tok/s") + assert throughput > THROUGHPUT_THRESHOLD, ( + f"Throughput {throughput:.1f} tok/s < threshold {THROUGHPUT_THRESHOLD}" + ) + + +# ── Multi-Prompt Quality Test ────────────────────────────────────────── + + +@requires_model_path +def test_multi_prompt_generation(compiled_model, tokenizer, generation_config): + """Multiple prompts should produce coherent outputs.""" + prompts = [ + "The capital of France is", + "def fibonacci(n):", + "The largest ocean on Earth is", + "To make a chocolate cake, you need", + ] + + for prompt in prompts: + _, text = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=30, + ) + generated = text[len(prompt) :].strip() + words = generated.split() + assert len(words) >= 2, ( + f"Prompt '{prompt}' generated too few words: '{generated}'" + ) + assert not _is_repetitive(generated), ( + f"Prompt '{prompt}' produced repetitive output: '{generated}'" + ) + print(f" '{prompt[:30]}...' -> {generated[:60]}...") + + +# ── Standalone runner ─────────────────────────────────────────────────── + +if __name__ == "__main__": + print("=" * 60) + print("Qwen3.5-9B Integration Tests") + print("=" * 60) + + if not MODEL_PATH: + print("\nQWEN35_MODEL_PATH not set. 
Provide the model path to run tests:") + print(" QWEN35_MODEL_PATH=/path/to/Qwen3.5-9B \\") + print(" QWEN35_COMPILED_PATH=/mnt/models/qwen35_9b_traced \\") + print(" python -m pytest test/integration/test_model.py --capture=tee-sys") + sys.exit(0) + + # Setup + from transformers import AutoTokenizer, GenerationConfig as GenConfig + + tok = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right") + if tok.pad_token is None: + tok.pad_token = tok.eos_token + gen_cfg = GenConfig( + do_sample=True, + top_k=1, + pad_token_id=tok.pad_token_id, + eos_token_id=tok.eos_token_id, + ) + + # Build model + import json + + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM + + nc = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=SEQ_LEN, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + enable_bucketing=False, + flash_decoding_enabled=False, + logical_nc_config=2, + save_sharded_checkpoint=True, + ) + + with open(os.path.join(MODEL_PATH, "config.json")) as f: + full_config = json.load(f) + text_config = full_config.get("text_config", full_config) + config_dict = dict(text_config) + config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) + config_dict["tie_word_embeddings"] = full_config.get( + "tie_word_embeddings", + text_config.get("tie_word_embeddings", False), + ) + if "rope_parameters" in text_config: + config_dict["rope_theta"] = text_config["rope_parameters"].get( + "rope_theta", 10000000 + ) + ic = Qwen35InferenceConfig(neuron_config=nc, **config_dict) + + cp = COMPILED_PATH + if not os.path.exists(os.path.join(cp, "model.pt")): + print(f"Compiling to {cp}...") + m = NeuronQwen35ForCausalLM(MODEL_PATH, ic) + m.compile(cp) + del m + gc.collect() + + print(f"Loading from {cp}...") + model = NeuronQwen35ForCausalLM(cp) + model.load(cp) + + tests = [ + ("model_loads", lambda: test_model_loads(model)), + ("model_generates", lambda: test_model_generates(model, tok, gen_cfg)), + ("output_coherence", lambda: test_output_coherence(model, tok, gen_cfg)), + ("top_token_valid", lambda: test_top_token_valid(model, tok, gen_cfg)), + ( + "simple_factual_generation", + lambda: test_simple_factual_generation(model, tok, gen_cfg), + ), + ("performance_ttft", lambda: test_performance_ttft(model, tok, gen_cfg)), + ( + "performance_throughput", + lambda: test_performance_throughput(model, tok, gen_cfg), + ), + ( + "multi_prompt_generation", + lambda: test_multi_prompt_generation(model, tok, gen_cfg), + ), + ] + + passed = 0 + for name, fn in tests: + print(f"\n--- {name} ---") + try: + fn() + print(f" PASS") + passed += 1 + except Exception as e: + print(f" FAIL: {e}") + + print(f"\n{'=' * 60}") + print(f"Results: {passed}/{len(tests)} passed") + print(f"{'=' * 60}") diff --git a/contrib/models/Qwen3.5-9B/test/parity/deltanet_path_probe.py b/contrib/models/Qwen3.5-9B/test/parity/deltanet_path_probe.py new file mode 100644 index 00000000..2ee510d1 --- /dev/null +++ b/contrib/models/Qwen3.5-9B/test/parity/deltanet_path_probe.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""DeltaNet path parity probe for Qwen3.5-9B. + +Run on a Neuron instance after weights are available. 
This intentionally is not +part of normal pytest collection because it can compile NKI kernels and requires +the full checkpoint. + +Example: + cd contrib/models/Qwen3.5-9B + QWEN35_MODEL_PATH=/mnt/models/Qwen3.5-9B \\ + python test/parity/deltanet_path_probe.py --layer-idx 0 --seq-len 128 +""" + +import argparse +import json +import os +import sys +from contextlib import contextmanager + +import torch +import torch.nn.functional as F + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from neuronx_distributed_inference.models.config import NeuronConfig +from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronGatedDeltaNet + + +@contextmanager +def patched_env(**updates): + old = {k: os.environ.get(k) for k in updates} + for k, v in updates.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = str(v) + try: + yield + finally: + for k, v in old.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v + + +def cosine(a, b): + return F.cosine_similarity(a.float().flatten(), b.float().flatten(), dim=0).item() + + +def max_abs(a, b): + return (a.float() - b.float()).abs().max().item() + + +def load_config(model_path, tp_degree): + with open(os.path.join(model_path, "config.json")) as f: + full_config = json.load(f) + text_config = full_config.get("text_config", full_config) + config_dict = dict(text_config) + config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) + config_dict["tie_word_embeddings"] = full_config.get( + "tie_word_embeddings", + text_config.get("tie_word_embeddings", False), + ) + if "rope_parameters" in text_config: + config_dict["rope_theta"] = text_config["rope_parameters"].get( + "rope_theta", 10000000 + ) + + neuron_config = NeuronConfig( + tp_degree=tp_degree, + batch_size=1, + max_batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + enable_bucketing=False, + flash_decoding_enabled=False, + logical_nc_config=2, + ) + return Qwen35InferenceConfig(neuron_config=neuron_config, **config_dict) + + +def strip_prefix_state_dict(state_dict): + stripped = {} + for k, v in state_dict.items(): + if k.startswith("language_model."): + stripped[k.replace("language_model.", "", 1)] = v + elif k.startswith("model.language_model."): + stripped[k.replace("model.language_model.", "", 1)] = v + elif k.startswith("model."): + stripped[k.replace("model.", "", 1)] = v + else: + stripped[k] = v + return stripped + + +def load_deltanet_layer_weights(module, model_path, layer_idx): + from transformers import AutoModelForCausalLM + + hf_model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + low_cpu_mem_usage=True, + ) + state_dict = strip_prefix_state_dict(hf_model.state_dict()) + prefix = f"layers.{layer_idx}.linear_attn." 
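+    # Pull only this layer's linear_attn.* tensors. The module's *_buffer
+    # entries are allocated DeltaNet state rather than checkpoint weights,
+    # so they are filtered out of the strict-missing check below.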
+ layer_sd = {} + for name in module.state_dict().keys(): + key = prefix + name + if key in state_dict: + layer_sd[name] = state_dict[key] + + missing, unexpected = module.load_state_dict(layer_sd, strict=False) + missing = [m for m in missing if not m.endswith("_buffer")] + if missing or unexpected: + raise RuntimeError(f"weight load mismatch: missing={missing}, unexpected={unexpected}") + del hf_model + + +def run_path(module, hidden_states, mode): + env = { + "USE_NKI_FUSED": "0", + "USE_NKI_CHUNKED": None, + "USE_NKI": None, + "DELTANET_SEQUENTIAL": None, + "USE_PYTORCH_CHUNK": None, + } + if mode == "sequential": + env["DELTANET_SEQUENTIAL"] = "1" + elif mode == "fused": + env["USE_NKI_FUSED"] = "1" + elif mode == "chunk": + env["USE_PYTORCH_CHUNK"] = "1" + elif mode == "nki_recurrent": + env["USE_NKI"] = "1" + else: + raise ValueError(f"unknown mode: {mode}") + + with patched_env(**env): + with torch.no_grad(): + out, _dummy_kv, rec_state, conv_state = module(hidden_states) + return out.detach().cpu(), rec_state.detach().cpu(), conv_state.detach().cpu() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", default=os.environ.get("QWEN35_MODEL_PATH")) + parser.add_argument("--layer-idx", type=int, default=0) + parser.add_argument("--seq-len", type=int, default=128) + parser.add_argument("--tp-degree", type=int, default=4) + parser.add_argument( + "--compare", + nargs="+", + default=["fused"], + choices=["fused", "chunk", "nki_recurrent"], + ) + parser.add_argument("--device", default="cpu", choices=["cpu", "xla"]) + args = parser.parse_args() + + if not args.model_path: + raise SystemExit("Set QWEN35_MODEL_PATH or pass --model-path") + + device = torch.device("cpu") + if args.device == "xla": + import torch_xla.core.xla_model as xm + + device = xm.xla_device() + + config = load_config(args.model_path, args.tp_degree) + if config.layer_types[args.layer_idx] != "linear_attention": + raise SystemExit(f"layer {args.layer_idx} is {config.layer_types[args.layer_idx]}, not DeltaNet") + + module = NeuronGatedDeltaNet(config, args.layer_idx).to(device) + load_deltanet_layer_weights(module, args.model_path, args.layer_idx) + module = module.to(device=device, dtype=torch.bfloat16).eval() + + torch.manual_seed(0) + hidden_states = torch.randn( + 1, + args.seq_len, + config.hidden_size, + dtype=torch.bfloat16, + device=device, + ) + + ref = run_path(module, hidden_states, "sequential") + print(f"reference=sequential layer={args.layer_idx} seq_len={args.seq_len}") + for mode in args.compare: + cur = run_path(module, hidden_states, mode) + print(f"\nmode={mode}") + for label, ref_t, cur_t in zip(("output", "recurrent_state", "conv_state"), ref, cur): + print( + f"{label}: cosine={cosine(ref_t, cur_t):.6f} " + f"max_abs={max_abs(ref_t, cur_t):.6f} " + f"shape={tuple(cur_t.shape)}" + ) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen3.5-9B/test/unit/__init__.py b/contrib/models/Qwen3.5-9B/test/unit/__init__.py new file mode 100644 index 00000000..04f8b7b7 --- /dev/null +++ b/contrib/models/Qwen3.5-9B/test/unit/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/contrib/models/Qwen3.5-9B/test/unit/test_config.py b/contrib/models/Qwen3.5-9B/test/unit/test_config.py new file mode 100644 index 00000000..9837e307 --- /dev/null +++ b/contrib/models/Qwen3.5-9B/test/unit/test_config.py @@ -0,0 +1,201 @@ +# Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for Qwen3.5-9B inference configuration. + +CPU-only tests that validate config parsing, layer type setup, +DeltaNet parameter defaults, RoPE configuration, and weight conversion logic. +""" + +import os +import sys +import unittest +from unittest.mock import MagicMock + +import torch + +# Ensure the contrib root (Qwen3.5-9B/) is on sys.path +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from src.modeling_qwen35 import ( + Qwen35InferenceConfig, + convert_qwen35_hf_to_neuron_state_dict, +) +from neuronx_distributed_inference.models.config import NeuronConfig + + +def _make_config(**overrides): + """Create a Qwen35InferenceConfig with reasonable defaults.""" + neuron_config = NeuronConfig( + tp_degree=overrides.pop("tp_degree", 4), + batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + ) + defaults = dict( + hidden_size=4096, + num_hidden_layers=32, + num_attention_heads=16, + num_key_value_heads=4, + head_dim=256, + intermediate_size=12288, + vocab_size=248320, + rms_norm_eps=1e-6, + max_position_embeddings=262144, + rope_theta=10000000, + hidden_act="silu", + tie_word_embeddings=False, + # DeltaNet-specific + linear_num_value_heads=32, + linear_num_key_heads=16, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_conv_kernel_dim=4, + ) + defaults.update(overrides) + config = Qwen35InferenceConfig(neuron_config=neuron_config, **defaults) + return config + + +class TestConfigParsing(unittest.TestCase): + """Test basic config attribute initialization.""" + + def test_hidden_size(self): + config = _make_config() + self.assertEqual(config.hidden_size, 4096) + + def test_num_hidden_layers(self): + config = _make_config() + self.assertEqual(config.num_hidden_layers, 32) + + def test_num_attention_heads(self): + config = _make_config() + self.assertEqual(config.num_attention_heads, 16) + + def test_num_key_value_heads(self): + config = _make_config() + self.assertEqual(config.num_key_value_heads, 4) + + def test_head_dim(self): + config = _make_config() + self.assertEqual(config.head_dim, 256) + + def test_intermediate_size(self): + config = _make_config() + self.assertEqual(config.intermediate_size, 12288) + + def test_vocab_size(self): + config = _make_config() + self.assertEqual(config.vocab_size, 248320) + + def test_hidden_act(self): + config = _make_config() + self.assertEqual(config.hidden_act, "silu") + + +class TestLayerTypes(unittest.TestCase): + """Test hybrid layer type assignment (3 DeltaNet + 1 GQA) x 8.""" + + def test_layer_types_length(self): + config = _make_config() + self.assertEqual(len(config.layer_types), 32) + + def test_layer_types_pattern(self): + """Every 4th layer (3, 7, 11, ...) 
should be full_attention.""" + config = _make_config() + for i in range(32): + expected = "full_attention" if i % 4 == 3 else "linear_attention" + self.assertEqual(config.layer_types[i], expected, f"Layer {i} mismatch") + + def test_deltanet_layer_count(self): + config = _make_config() + dn_count = sum(1 for t in config.layer_types if t == "linear_attention") + self.assertEqual(dn_count, 24) + + def test_gqa_layer_count(self): + config = _make_config() + gqa_count = sum(1 for t in config.layer_types if t == "full_attention") + self.assertEqual(gqa_count, 8) + + +class TestDeltaNetConfig(unittest.TestCase): + """Test DeltaNet-specific configuration defaults.""" + + def test_linear_num_value_heads(self): + config = _make_config() + self.assertEqual(config.linear_num_value_heads, 32) + + def test_linear_num_key_heads(self): + config = _make_config() + self.assertEqual(config.linear_num_key_heads, 16) + + def test_linear_key_head_dim(self): + config = _make_config() + self.assertEqual(config.linear_key_head_dim, 128) + + def test_linear_value_head_dim(self): + config = _make_config() + self.assertEqual(config.linear_value_head_dim, 128) + + def test_linear_conv_kernel_dim(self): + config = _make_config() + self.assertEqual(config.linear_conv_kernel_dim, 4) + + +class TestRoPEConfig(unittest.TestCase): + """Test partial RoPE configuration.""" + + def test_partial_rotary_factor(self): + config = _make_config() + self.assertAlmostEqual(config.partial_rotary_factor, 0.25) + + def test_rope_dim(self): + """rope_dim = head_dim * partial_rotary_factor = 256 * 0.25 = 64.""" + config = _make_config() + self.assertEqual(config.rope_dim, 64) + + def test_attn_output_gate(self): + config = _make_config() + self.assertTrue(config.attn_output_gate) + + def test_mrope_section(self): + config = _make_config() + self.assertEqual(config.mrope_section, [11, 11, 10]) + + def test_mrope_interleaved(self): + config = _make_config() + self.assertTrue(config.mrope_interleaved) + + +class TestNeuronConfig(unittest.TestCase): + """Test Neuron-specific configuration settings.""" + + def test_neuron_config_cls(self): + """Qwen3.5-9B is dense -- uses NeuronConfig, NOT MoENeuronConfig.""" + self.assertEqual( + Qwen35InferenceConfig.get_neuron_config_cls(), + NeuronConfig, + ) + + def test_required_attributes(self): + config = _make_config() + required = config.get_required_attributes() + self.assertIn("hidden_size", required) + self.assertIn("num_hidden_layers", required) + self.assertIn("linear_num_value_heads", required) + self.assertIn("linear_key_head_dim", required) + self.assertIn("layer_types", required) + + def test_output_attentions_default(self): + config = _make_config() + self.assertFalse(config.output_attentions) + + def test_output_hidden_states_default(self): + config = _make_config() + self.assertFalse(config.output_hidden_states) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.5-9B/test/unit/test_weight_conversion.py b/contrib/models/Qwen3.5-9B/test/unit/test_weight_conversion.py new file mode 100644 index 00000000..aa527068 --- /dev/null +++ b/contrib/models/Qwen3.5-9B/test/unit/test_weight_conversion.py @@ -0,0 +1,434 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for Qwen3.5-9B HF-to-NxDI weight conversion. 
+ +CPU-only tests that validate: +- RMSNorm (+1 convention) weight conversion +- GQA q_proj interleaved split (query + gate) +- QK norm key renaming (q_norm -> q_layernorm, k_norm -> k_layernorm) +- Fused QKV concatenation +- DeltaNet layer weights pass through unchanged +- VL wrapper prefix stripping +- rank_util injection +""" + +import os +import sys +import unittest + +import torch + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from src.modeling_qwen35 import ( + Qwen35InferenceConfig, + NeuronQwen35ForCausalLM, + convert_qwen35_hf_to_neuron_state_dict, +) +from neuronx_distributed_inference.models.config import NeuronConfig + + +def _make_mini_config(num_layers=4, tp_degree=2, fused_qkv=True): + """Create a small Qwen35InferenceConfig for testing.""" + neuron_config = NeuronConfig( + tp_degree=tp_degree, + batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + fused_qkv=fused_qkv, + ) + config = Qwen35InferenceConfig( + neuron_config=neuron_config, + hidden_size=256, + num_hidden_layers=num_layers, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=64, + intermediate_size=512, + vocab_size=1000, + rms_norm_eps=1e-6, + max_position_embeddings=4096, + rope_theta=10000, + hidden_act="silu", + linear_num_value_heads=8, + linear_num_key_heads=4, + linear_key_head_dim=32, + linear_value_head_dim=32, + linear_conv_kernel_dim=4, + ) + return config + + +def _make_mini_state_dict(config): + """Create a minimal HF-style state dict for conversion testing.""" + sd = {} + H = config.hidden_size # 256 + I = config.intermediate_size # 512 + V = config.vocab_size # 1000 + num_heads = config.num_attention_heads # 4 + num_kv = config.num_key_value_heads # 2 + head_dim = config.head_dim # 64 + + sd["embed_tokens.weight"] = torch.randn(V, H, dtype=torch.bfloat16) * 0.02 + sd["lm_head.weight"] = torch.randn(V, H, dtype=torch.bfloat16) * 0.02 + sd["norm.weight"] = torch.zeros(H, dtype=torch.bfloat16) # +1 convention: zeros + + for l in range(config.num_hidden_layers): + sd[f"layers.{l}.input_layernorm.weight"] = torch.zeros(H, dtype=torch.bfloat16) + sd[f"layers.{l}.post_attention_layernorm.weight"] = torch.zeros( + H, dtype=torch.bfloat16 + ) + + # Dense MLP (all layers) + sd[f"layers.{l}.mlp.gate_proj.weight"] = ( + torch.randn(I, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.mlp.up_proj.weight"] = ( + torch.randn(I, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.mlp.down_proj.weight"] = ( + torch.randn(H, I, dtype=torch.bfloat16) * 0.02 + ) + + if config.layer_types[l] == "full_attention": + # GQA layer: q_proj is interleaved [head0_q | head0_gate | head1_q | ...] 
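+            # With head_dim=64: rows 0-63 = head0 query, rows 64-127 = head0
+            # gate, rows 128-191 = head1 query, and so on; each head owns a
+            # contiguous 2*head_dim row block.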
+ q_proj = ( + torch.randn(num_heads * head_dim * 2, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.q_proj.weight"] = q_proj + sd[f"layers.{l}.self_attn.k_proj.weight"] = ( + torch.randn(num_kv * head_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.v_proj.weight"] = ( + torch.randn(num_kv * head_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.o_proj.weight"] = ( + torch.randn(H, num_heads * head_dim, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.q_norm.weight"] = torch.zeros( + head_dim, dtype=torch.bfloat16 + ) + sd[f"layers.{l}.self_attn.k_norm.weight"] = torch.zeros( + head_dim, dtype=torch.bfloat16 + ) + else: + # DeltaNet layer: minimal required weights + key_dim = config.linear_num_key_heads * config.linear_key_head_dim # 128 + value_dim = ( + config.linear_num_value_heads * config.linear_value_head_dim + ) # 256 + conv_dim = key_dim * 2 + value_dim # 512 + sd[f"layers.{l}.linear_attn.in_proj_qkv.weight"] = ( + torch.randn(conv_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.linear_attn.in_proj_z.weight"] = ( + torch.randn(value_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.linear_attn.in_proj_a.weight"] = ( + torch.randn(config.linear_num_value_heads, H, dtype=torch.bfloat16) + * 0.02 + ) + sd[f"layers.{l}.linear_attn.in_proj_b.weight"] = ( + torch.randn(config.linear_num_value_heads, H, dtype=torch.bfloat16) + * 0.02 + ) + sd[f"layers.{l}.linear_attn.conv1d.weight"] = ( + torch.randn( + conv_dim, 1, config.linear_conv_kernel_dim, dtype=torch.bfloat16 + ) + * 0.02 + ) + sd[f"layers.{l}.linear_attn.A_log"] = torch.randn( + config.linear_num_value_heads, dtype=torch.bfloat16 + ) + sd[f"layers.{l}.linear_attn.dt_bias"] = torch.randn( + config.linear_num_value_heads, dtype=torch.bfloat16 + ) + sd[f"layers.{l}.linear_attn.norm.weight"] = ( + torch.randn(value_dim, dtype=torch.bfloat16) * 0.5 + ) + sd[f"layers.{l}.linear_attn.out_proj.weight"] = ( + torch.randn(H, value_dim, dtype=torch.bfloat16) * 0.02 + ) + + return sd + + +class TestNormConversion(unittest.TestCase): + """Test (+1 convention) RMSNorm weight conversion.""" + + def test_norm_weight_adds_one(self): + """Weights initialized to zero should become 1.0 after conversion.""" + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + # norm.weight was zeros -> should now be ones + torch.testing.assert_close( + result["norm.weight"], + torch.ones_like(result["norm.weight"]), + ) + + def test_input_layernorm_adds_one(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + w = result[f"layers.{l}.input_layernorm.weight"] + self.assertTrue( + torch.allclose(w, torch.ones_like(w)), + f"Layer {l} input_layernorm not converted", + ) + + def test_post_attn_layernorm_adds_one(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + w = result[f"layers.{l}.post_attention_layernorm.weight"] + self.assertTrue( + torch.allclose(w, torch.ones_like(w)), + f"Layer {l} post_attention_layernorm not converted", + ) + + def test_qk_norm_adds_one(self): + """Q/K norms on GQA layers should also get +1 applied.""" + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, 
config)
+        for l in range(config.num_hidden_layers):
+            if config.layer_types[l] == "full_attention":
+                q_w = result[f"layers.{l}.self_attn.q_layernorm.weight"]
+                k_w = result[f"layers.{l}.self_attn.k_layernorm.weight"]
+                self.assertTrue(
+                    torch.allclose(q_w, torch.ones_like(q_w)),
+                    f"Layer {l} q_layernorm not converted",
+                )
+                self.assertTrue(
+                    torch.allclose(k_w, torch.ones_like(k_w)),
+                    f"Layer {l} k_layernorm not converted",
+                )
+
+
+class TestQProjSplit(unittest.TestCase):
+    """Test q_proj interleaved split into query + gate."""
+
+    def test_q_proj_split_shapes(self):
+        """q_proj (num_heads * head_dim * 2, H) -> separate query and gate."""
+        config = _make_mini_config(fused_qkv=False)
+        sd = _make_mini_state_dict(config)
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+
+        for l in range(config.num_hidden_layers):
+            if config.layer_types[l] == "full_attention":
+                # After split: q_proj should be (num_heads * head_dim, H) = (256, 256)
+                q_w = result[f"layers.{l}.self_attn.q_proj.weight"]
+                gate_w = result[f"layers.{l}.self_attn.output_gate_proj.weight"]
+                expected_shape = (
+                    config.num_attention_heads * config.head_dim,
+                    config.hidden_size,
+                )
+                self.assertEqual(
+                    q_w.shape, expected_shape, f"Layer {l} q_proj shape wrong"
+                )
+                self.assertEqual(
+                    gate_w.shape, expected_shape, f"Layer {l} gate shape wrong"
+                )
+
+    def test_q_proj_deinterleave_correct(self):
+        """Verify the interleaved split correctly separates query and gate."""
+        config = _make_mini_config(fused_qkv=False)
+        sd = _make_mini_state_dict(config)
+
+        # Create a known pattern: head h query is (h + 1)s and head h gate is
+        # (h + 100)s, so head0 query is 1s and head0 gate is 100s.
+        l = 3  # First full_attention layer (layer 3)
+        num_heads = config.num_attention_heads
+        head_dim = config.head_dim
+        H = config.hidden_size
+
+        interleaved = torch.zeros(num_heads * head_dim * 2, H, dtype=torch.bfloat16)
+        for h in range(num_heads):
+            interleaved[h * head_dim * 2 : h * head_dim * 2 + head_dim, :] = float(
+                h + 1
+            )  # query
+            interleaved[h * head_dim * 2 + head_dim : (h + 1) * head_dim * 2, :] = (
+                float(h + 100)
+            )  # gate
+
+        sd[f"layers.{l}.self_attn.q_proj.weight"] = interleaved
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+
+        q_w = result[f"layers.{l}.self_attn.q_proj.weight"]
+        gate_w = result[f"layers.{l}.self_attn.output_gate_proj.weight"]
+
+        for h in range(num_heads):
+            q_head = q_w[h * head_dim : (h + 1) * head_dim, :]
+            gate_head = gate_w[h * head_dim : (h + 1) * head_dim, :]
+            self.assertTrue(
+                torch.all(q_head == float(h + 1)), f"Head {h} query values wrong"
+            )
+            self.assertTrue(
+                torch.all(gate_head == float(h + 100)), f"Head {h} gate values wrong"
+            )
+
+
+class TestQKNormRename(unittest.TestCase):
+    """Test q_norm -> q_layernorm and k_norm -> k_layernorm renaming."""
+
+    def test_old_keys_removed(self):
+        config = _make_mini_config()
+        sd = _make_mini_state_dict(config)
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+        for l in range(config.num_hidden_layers):
+            if config.layer_types[l] == "full_attention":
+                self.assertNotIn(f"layers.{l}.self_attn.q_norm.weight", result)
+                self.assertNotIn(f"layers.{l}.self_attn.k_norm.weight", result)
+
+    def test_new_keys_present(self):
+        config = _make_mini_config()
+        sd = _make_mini_state_dict(config)
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+        for l in range(config.num_hidden_layers):
+            if config.layer_types[l] == "full_attention":
+                self.assertIn(f"layers.{l}.self_attn.q_layernorm.weight", result)
+                self.assertIn(f"layers.{l}.self_attn.k_layernorm.weight",
result) + + +class TestFusedQKV(unittest.TestCase): + """Test fused QKV concatenation for attention layers.""" + + def test_fused_qkv_shape(self): + config = _make_mini_config(fused_qkv=True) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + fused_key = f"layers.{l}.self_attn.Wqkv.weight" + self.assertIn(fused_key, result, f"Layer {l} missing Wqkv") + + q_dim = config.num_attention_heads * config.head_dim + k_dim = config.num_key_value_heads * config.head_dim + v_dim = config.num_key_value_heads * config.head_dim + expected_rows = q_dim + k_dim + v_dim + self.assertEqual(result[fused_key].shape[0], expected_rows) + + def test_fused_qkv_removes_individual_keys(self): + config = _make_mini_config(fused_qkv=True) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + self.assertNotIn(f"layers.{l}.self_attn.q_proj.weight", result) + self.assertNotIn(f"layers.{l}.self_attn.k_proj.weight", result) + self.assertNotIn(f"layers.{l}.self_attn.v_proj.weight", result) + + +class TestDeltaNetPassthrough(unittest.TestCase): + """Test that DeltaNet layer weights pass through conversion unchanged.""" + + def test_deltanet_weights_unchanged(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + # Record original DeltaNet weights + originals = {} + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "linear_attention": + key = f"layers.{l}.linear_attn.in_proj_qkv.weight" + originals[key] = sd[key].clone() + + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for key, orig in originals.items(): + self.assertIn(key, result, f"Missing: {key}") + torch.testing.assert_close( + result[key], orig, msg=f"DeltaNet weight changed: {key}" + ) + + def test_deltanet_norm_not_converted(self): + """DeltaNet layers use standard RMSNorm (NOT +1 convention). 
+ The norm weight should NOT be changed.""" + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + # Set DeltaNet norm to a known non-zero value + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "linear_attention": + sd[f"layers.{l}.linear_attn.norm.weight"] = torch.full( + (config.linear_num_value_heads * config.linear_value_head_dim,), + 0.87, + dtype=torch.bfloat16, + ) + + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "linear_attention": + w = result[f"layers.{l}.linear_attn.norm.weight"] + # Should still be ~0.87, NOT 1.87 + self.assertTrue( + torch.allclose(w, torch.full_like(w, 0.87), atol=0.01), + f"Layer {l} DeltaNet norm was incorrectly modified", + ) + + +class TestRankUtil(unittest.TestCase): + """Test rank_util tensor injection.""" + + def test_rank_util_present(self): + config = _make_mini_config(tp_degree=4) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + self.assertIn("rank_util.rank", result) + expected = torch.arange(0, 4, dtype=torch.int32) + torch.testing.assert_close(result["rank_util.rank"], expected) + + def test_gqa_layer_rank_util(self): + config = _make_mini_config(tp_degree=4) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + key = f"layers.{l}.self_attn.rank_util.rank" + self.assertIn(key, result) + expected = torch.arange(0, 4, dtype=torch.int32) + torch.testing.assert_close(result[key], expected) + + +class TestVLPrefixStripping(unittest.TestCase): + """Test VL wrapper prefix stripping in convert_hf_to_neuron_state_dict.""" + + def test_language_model_prefix_stripped(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + # Wrap with VL prefix + vl_sd = {} + for k, v in sd.items(): + vl_sd[f"language_model.{k}"] = v + vl_sd["visual.encoder.weight"] = torch.zeros(10) # should be skipped + vl_sd["mtp.something"] = torch.zeros(5) # should be skipped + + result = NeuronQwen35ForCausalLM.convert_hf_to_neuron_state_dict(vl_sd, config) + self.assertNotIn("visual.encoder.weight", result) + self.assertNotIn("mtp.something", result) + self.assertIn("norm.weight", result) + + def test_model_language_model_prefix_stripped(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + vl_sd = {} + for k, v in sd.items(): + vl_sd[f"model.language_model.{k}"] = v + + result = NeuronQwen35ForCausalLM.convert_hf_to_neuron_state_dict(vl_sd, config) + self.assertIn("norm.weight", result) + + +if __name__ == "__main__": + unittest.main() From 27f6183c16bacf507a588fa96774e1d0ec17065a Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Mon, 4 May 2026 17:35:53 +0530 Subject: [PATCH 2/5] Stabilize Qwen3.5 DeltaNet generation --- .../models/Qwen3.5-4B/src/modeling_qwen35.py | 102 ++++++++++++------ .../Qwen3.5-4B/test/integration/test_model.py | 23 +++- .../test/unit/test_deltanet_decay.py | 68 ++++++++++++ .../models/Qwen3.5-9B/src/modeling_qwen35.py | 102 ++++++++++++------ .../Qwen3.5-9B/test/integration/test_model.py | 23 +++- .../test/unit/test_deltanet_decay.py | 68 ++++++++++++ 6 files changed, 318 insertions(+), 68 deletions(-) create mode 100644 contrib/models/Qwen3.5-4B/test/unit/test_deltanet_decay.py create mode 100644 contrib/models/Qwen3.5-9B/test/unit/test_deltanet_decay.py diff --git 
a/contrib/models/Qwen3.5-4B/src/modeling_qwen35.py b/contrib/models/Qwen3.5-4B/src/modeling_qwen35.py index 3e27e419..bf3d4066 100644 --- a/contrib/models/Qwen3.5-4B/src/modeling_qwen35.py +++ b/contrib/models/Qwen3.5-4B/src/modeling_qwen35.py @@ -221,6 +221,33 @@ def l2norm(x, dim=-1, eps=1e-6): return F.normalize(x, p=2, dim=dim, eps=eps) +FUSED_DELTANET_DECAY_MIN = -60.0 +FUSED_DELTANET_DECAY_MAX = 0.0 + + +def _bound_fused_deltanet_log_decay( + g, batch_size, num_heads, total_seq_len, chunk_size +): + """Bound cumulative DeltaNet decay before the fused NKI kernel. + + The fused kernel internally computes both exp(cumsum(g)) and exp(-cumsum(g)). + Large negative cumulative decays make the second term overflow even though + the true pairwise decay exp(gc_i - gc_j) is bounded by one. Return + equivalent per-token deltas whose per-chunk cumulative sum is clamped. + """ + num_chunks = total_seq_len // chunk_size + g_chunks = g.reshape(batch_size, num_heads, num_chunks, chunk_size) + g_cumsum = g_chunks.cumsum(dim=-1).clamp( + min=FUSED_DELTANET_DECAY_MIN, + max=FUSED_DELTANET_DECAY_MAX, + ) + g_first = g_cumsum[..., :1] + g_rest = g_cumsum[..., 1:] - g_cumsum[..., :-1] + return torch.cat([g_first, g_rest], dim=-1).reshape( + batch_size, num_heads, total_seq_len + ) + + # ============================================================ # Gated DeltaNet Module (Linear Recurrent Attention) # ============================================================ @@ -521,6 +548,7 @@ def _fused_chunked_forward( beta = F.pad(beta, (0, pad_size)) g = F.pad(g, (0, pad_size)) total_seq_len = S + pad_size + g = _bound_fused_deltanet_log_decay(g, B, H, total_seq_len, chunk_size) BH = B * H # Flatten to (BH, S, dim) for per-(b,h) kernel calls @@ -1122,40 +1150,36 @@ def forward(self, x, position_ids_3d): device = x.device dtype = torch.float32 - sections = self.mrope_section # [11, 11, 10] - cos_parts = [] - sin_parts = [] - - freq_offset = 0 - for axis_idx, section_size in enumerate(sections): - pos = position_ids_3d[axis_idx].float() # (batch, seq_len) - - dim_pairs = section_size # number of (cos, sin) pairs for this axis - freqs = 1.0 / ( - self.rope_theta - ** ( - torch.arange(0, dim_pairs * 2, 2, dtype=dtype, device=device) - / (self.rope_dim) - ) - ) # (dim_pairs,) - - # freqs: (dim_pairs,), pos: (B, S) -> angles: (B, S, dim_pairs) - angles = pos.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0) - - cos_parts.append(angles.cos()) - sin_parts.append(angles.sin()) + if position_ids_3d.ndim == 2: + position_ids_3d = position_ids_3d[None, ...].expand( + 3, position_ids_3d.shape[0], -1 + ) - # Concatenate: (B, S, 32) - cos = torch.cat(cos_parts, dim=-1) - sin = torch.cat(sin_parts, dim=-1) + inv_freq = 1.0 / ( + self.rope_theta + ** ( + torch.arange(0, self.rope_dim, 2, dtype=dtype, device=device) + / self.rope_dim + ) + ) + inv_freq = inv_freq[None, None, :, None].expand( + 3, position_ids_3d.shape[1], -1, 1 + ) + positions = position_ids_3d[:, :, None, :].float() + freqs = (inv_freq.float() @ positions).transpose(2, 3) + # Match HF Qwen3.5 mRoPE layout exactly: start from the temporal + # frequencies, then splice H/W frequencies into interleaved positions. + freqs_t = freqs[0] if self.mrope_interleaved: - # Interleave to (B, S, 64): [c0, c0, c1, c1, ...] 
for rotate_half - cos = cos.repeat_interleave(2, dim=-1) - sin = sin.repeat_interleave(2, dim=-1) - else: - cos = torch.cat([cos, cos], dim=-1) - sin = torch.cat([sin, sin], dim=-1) + for dim, offset in enumerate((1, 2), start=1): + length = self.mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + + emb = torch.cat((freqs_t, freqs_t), dim=-1) + cos = emb.cos().to(dtype=x.dtype) + sin = emb.sin().to(dtype=x.dtype) return cos, sin @@ -1631,16 +1655,28 @@ def get_model_output( if is_for_context_encoding: inputs_embeds = inputs_embeds * deltanet_padding_mask - # Vision embedding injection + # Vision embedding injection. Text-only calls still pass dummy vision + # tensors to keep the traced input signature stable; those tensors have + # one dummy entry per text token and must not overwrite text embeddings. if (vision_embeddings is not None) and (vision_mask is not None): if vision_embeddings.dtype != self.config.neuron_config.torch_dtype: vision_embeddings = vision_embeddings.to( self.config.neuron_config.torch_dtype ) - if is_for_context_encoding: + has_real_vision_inputs = ( + vision_embeddings.ndim == 3 + and vision_mask.ndim == 3 + and vision_embeddings.shape[1] != seq_length + ) + if is_for_context_encoding and has_real_vision_inputs: inputs_embeds = self.encode_vision_to_input( inputs_embeds, vision_embeddings, vision_mask ) + elif is_for_context_encoding and vision_embeddings.numel() > 0: + inputs_embeds = inputs_embeds + vision_embeddings.sum() * 0 + inputs_embeds = ( + inputs_embeds + vision_mask.sum().to(inputs_embeds.dtype) * 0 + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device diff --git a/contrib/models/Qwen3.5-4B/test/integration/test_model.py b/contrib/models/Qwen3.5-4B/test/integration/test_model.py index 3bdafc3e..6de43de4 100644 --- a/contrib/models/Qwen3.5-4B/test/integration/test_model.py +++ b/contrib/models/Qwen3.5-4B/test/integration/test_model.py @@ -256,7 +256,7 @@ def test_top_token_valid(compiled_model, tokenizer, generation_config): ) input_len = len(tokenizer.encode("Hello!")) first_new = tokens[input_len] - assert 0 <= first_new < tokenizer.vocab_size, ( + assert 0 <= first_new < len(tokenizer), ( f"Token {first_new} out of vocab range" ) decoded = tokenizer.decode([first_new]) @@ -264,6 +264,27 @@ def test_top_token_valid(compiled_model, tokenizer, generation_config): print(f" First token: {first_new} -> '{decoded}'") +@requires_model_path +def test_olympics_prompt_no_invalid_tokens( + compiled_model, tokenizer, generation_config +): + """Regression test for NaN logits producing the int32-min token id.""" + prompt = "Give me a summary of the 2020 Olympics in 100 tokens." 
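+    # Regression prompt: NaN logits previously surfaced as the int32-min
+    # token id, which the range check below flags as invalid.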
+ tokens, _ = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=32, + ) + input_len = len(tokenizer.encode(prompt)) + generated = tokens[input_len:] + invalid = [token for token in generated if token < 0 or token >= len(tokenizer)] + + assert len(generated) >= 5, f"Expected >= 5 generated tokens, got {generated}" + assert not invalid, f"Generated invalid token ids: {invalid}" + + @requires_model_path def test_capital_of_france(compiled_model, tokenizer, generation_config): """'The capital of France is' should produce 'Paris' in the response.""" diff --git a/contrib/models/Qwen3.5-4B/test/unit/test_deltanet_decay.py b/contrib/models/Qwen3.5-4B/test/unit/test_deltanet_decay.py new file mode 100644 index 00000000..416a431a --- /dev/null +++ b/contrib/models/Qwen3.5-4B/test/unit/test_deltanet_decay.py @@ -0,0 +1,68 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for fused DeltaNet log-decay bounding.""" + +import os +import sys +import unittest + +import torch + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from src.modeling_qwen35 import ( + FUSED_DELTANET_DECAY_MAX, + FUSED_DELTANET_DECAY_MIN, + _bound_fused_deltanet_log_decay, +) + + +def _chunked_cumsum(g, batch_size, num_heads, total_seq_len, chunk_size): + num_chunks = total_seq_len // chunk_size + return g.reshape(batch_size, num_heads, num_chunks, chunk_size).cumsum(dim=-1) + + +class TestFusedDeltaNetDecayBounding(unittest.TestCase): + def test_preserves_non_extreme_decay(self): + batch_size, num_heads, total_seq_len, chunk_size = 2, 3, 16, 8 + g = torch.full( + (batch_size, num_heads, total_seq_len), + -0.125, + dtype=torch.float32, + ) + + bounded = _bound_fused_deltanet_log_decay( + g, batch_size, num_heads, total_seq_len, chunk_size + ) + + torch.testing.assert_close(bounded, g) + + def test_bounds_per_chunk_cumulative_decay(self): + batch_size, num_heads, total_seq_len, chunk_size = 2, 3, 16, 8 + g = torch.full( + (batch_size, num_heads, total_seq_len), + -10.0, + dtype=torch.float32, + ) + + bounded = _bound_fused_deltanet_log_decay( + g, batch_size, num_heads, total_seq_len, chunk_size + ) + bounded_cumsum = _chunked_cumsum( + bounded, batch_size, num_heads, total_seq_len, chunk_size + ) + expected_cumsum = _chunked_cumsum( + g, batch_size, num_heads, total_seq_len, chunk_size + ).clamp(min=FUSED_DELTANET_DECAY_MIN, max=FUSED_DELTANET_DECAY_MAX) + + torch.testing.assert_close(bounded_cumsum, expected_cumsum) + self.assertGreaterEqual(float(bounded_cumsum.min()), FUSED_DELTANET_DECAY_MIN) + self.assertLessEqual(float(bounded_cumsum.max()), FUSED_DELTANET_DECAY_MAX) + self.assertTrue(torch.isfinite(bounded).all()) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.5-9B/src/modeling_qwen35.py b/contrib/models/Qwen3.5-9B/src/modeling_qwen35.py index 959c3170..657df76f 100644 --- a/contrib/models/Qwen3.5-9B/src/modeling_qwen35.py +++ b/contrib/models/Qwen3.5-9B/src/modeling_qwen35.py @@ -221,6 +221,33 @@ def l2norm(x, dim=-1, eps=1e-6): return F.normalize(x, p=2, dim=dim, eps=eps) +FUSED_DELTANET_DECAY_MIN = -60.0 +FUSED_DELTANET_DECAY_MAX = 0.0 + + +def _bound_fused_deltanet_log_decay( + g, batch_size, num_heads, total_seq_len, chunk_size +): + """Bound cumulative DeltaNet decay before the fused NKI kernel. 
+ + The fused kernel internally computes both exp(cumsum(g)) and exp(-cumsum(g)). + Large negative cumulative decays make the second term overflow even though + the true pairwise decay exp(gc_i - gc_j) is bounded by one. Return + equivalent per-token deltas whose per-chunk cumulative sum is clamped. + """ + num_chunks = total_seq_len // chunk_size + g_chunks = g.reshape(batch_size, num_heads, num_chunks, chunk_size) + g_cumsum = g_chunks.cumsum(dim=-1).clamp( + min=FUSED_DELTANET_DECAY_MIN, + max=FUSED_DELTANET_DECAY_MAX, + ) + g_first = g_cumsum[..., :1] + g_rest = g_cumsum[..., 1:] - g_cumsum[..., :-1] + return torch.cat([g_first, g_rest], dim=-1).reshape( + batch_size, num_heads, total_seq_len + ) + + # ============================================================ # Gated DeltaNet Module (Linear Recurrent Attention) # ============================================================ @@ -521,6 +548,7 @@ def _fused_chunked_forward( beta = F.pad(beta, (0, pad_size)) g = F.pad(g, (0, pad_size)) total_seq_len = S + pad_size + g = _bound_fused_deltanet_log_decay(g, B, H, total_seq_len, chunk_size) BH = B * H # Flatten to (BH, S, dim) for per-(b,h) kernel calls @@ -1122,40 +1150,36 @@ def forward(self, x, position_ids_3d): device = x.device dtype = torch.float32 - sections = self.mrope_section # [11, 11, 10] - cos_parts = [] - sin_parts = [] - - freq_offset = 0 - for axis_idx, section_size in enumerate(sections): - pos = position_ids_3d[axis_idx].float() # (batch, seq_len) - - dim_pairs = section_size # number of (cos, sin) pairs for this axis - freqs = 1.0 / ( - self.rope_theta - ** ( - torch.arange(0, dim_pairs * 2, 2, dtype=dtype, device=device) - / (self.rope_dim) - ) - ) # (dim_pairs,) - - # freqs: (dim_pairs,), pos: (B, S) -> angles: (B, S, dim_pairs) - angles = pos.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0) - - cos_parts.append(angles.cos()) - sin_parts.append(angles.sin()) + if position_ids_3d.ndim == 2: + position_ids_3d = position_ids_3d[None, ...].expand( + 3, position_ids_3d.shape[0], -1 + ) - # Concatenate: (B, S, 32) - cos = torch.cat(cos_parts, dim=-1) - sin = torch.cat(sin_parts, dim=-1) + inv_freq = 1.0 / ( + self.rope_theta + ** ( + torch.arange(0, self.rope_dim, 2, dtype=dtype, device=device) + / self.rope_dim + ) + ) + inv_freq = inv_freq[None, None, :, None].expand( + 3, position_ids_3d.shape[1], -1, 1 + ) + positions = position_ids_3d[:, :, None, :].float() + freqs = (inv_freq.float() @ positions).transpose(2, 3) + # Match HF Qwen3.5 mRoPE layout exactly: start from the temporal + # frequencies, then splice H/W frequencies into interleaved positions. + freqs_t = freqs[0] if self.mrope_interleaved: - # Interleave to (B, S, 64): [c0, c0, c1, c1, ...] for rotate_half - cos = cos.repeat_interleave(2, dim=-1) - sin = sin.repeat_interleave(2, dim=-1) - else: - cos = torch.cat([cos, cos], dim=-1) - sin = torch.cat([sin, sin], dim=-1) + for dim, offset in enumerate((1, 2), start=1): + length = self.mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + + emb = torch.cat((freqs_t, freqs_t), dim=-1) + cos = emb.cos().to(dtype=x.dtype) + sin = emb.sin().to(dtype=x.dtype) return cos, sin @@ -1631,16 +1655,28 @@ def get_model_output( if is_for_context_encoding: inputs_embeds = inputs_embeds * deltanet_padding_mask - # Vision embedding injection + # Vision embedding injection. 
Text-only calls still pass dummy vision + # tensors to keep the traced input signature stable; those tensors have + # one dummy entry per text token and must not overwrite text embeddings. if (vision_embeddings is not None) and (vision_mask is not None): if vision_embeddings.dtype != self.config.neuron_config.torch_dtype: vision_embeddings = vision_embeddings.to( self.config.neuron_config.torch_dtype ) - if is_for_context_encoding: + has_real_vision_inputs = ( + vision_embeddings.ndim == 3 + and vision_mask.ndim == 3 + and vision_embeddings.shape[1] != seq_length + ) + if is_for_context_encoding and has_real_vision_inputs: inputs_embeds = self.encode_vision_to_input( inputs_embeds, vision_embeddings, vision_mask ) + elif is_for_context_encoding and vision_embeddings.numel() > 0: + inputs_embeds = inputs_embeds + vision_embeddings.sum() * 0 + inputs_embeds = ( + inputs_embeds + vision_mask.sum().to(inputs_embeds.dtype) * 0 + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device diff --git a/contrib/models/Qwen3.5-9B/test/integration/test_model.py b/contrib/models/Qwen3.5-9B/test/integration/test_model.py index ff8a25c9..b1cf876c 100644 --- a/contrib/models/Qwen3.5-9B/test/integration/test_model.py +++ b/contrib/models/Qwen3.5-9B/test/integration/test_model.py @@ -256,7 +256,7 @@ def test_top_token_valid(compiled_model, tokenizer, generation_config): ) input_len = len(tokenizer.encode("Hello!")) first_new = tokens[input_len] - assert 0 <= first_new < tokenizer.vocab_size, ( + assert 0 <= first_new < len(tokenizer), ( f"Token {first_new} out of vocab range" ) decoded = tokenizer.decode([first_new]) @@ -264,6 +264,27 @@ def test_top_token_valid(compiled_model, tokenizer, generation_config): print(f" First token: {first_new} -> '{decoded}'") +@requires_model_path +def test_olympics_prompt_no_invalid_tokens( + compiled_model, tokenizer, generation_config +): + """Regression test for NaN logits producing the int32-min token id.""" + prompt = "Give me a summary of the 2020 Olympics in 100 tokens." + tokens, _ = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=32, + ) + input_len = len(tokenizer.encode(prompt)) + generated = tokens[input_len:] + invalid = [token for token in generated if token < 0 or token >= len(tokenizer)] + + assert len(generated) >= 5, f"Expected >= 5 generated tokens, got {generated}" + assert not invalid, f"Generated invalid token ids: {invalid}" + + @requires_model_path def test_simple_factual_generation(compiled_model, tokenizer, generation_config): """A simple factual prompt should produce the expected entity.""" diff --git a/contrib/models/Qwen3.5-9B/test/unit/test_deltanet_decay.py b/contrib/models/Qwen3.5-9B/test/unit/test_deltanet_decay.py new file mode 100644 index 00000000..416a431a --- /dev/null +++ b/contrib/models/Qwen3.5-9B/test/unit/test_deltanet_decay.py @@ -0,0 +1,68 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for fused DeltaNet log-decay bounding.""" + +import os +import sys +import unittest + +import torch + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from src.modeling_qwen35 import ( + FUSED_DELTANET_DECAY_MAX, + FUSED_DELTANET_DECAY_MIN, + _bound_fused_deltanet_log_decay, +) + + +def _chunked_cumsum(g, batch_size, num_heads, total_seq_len, chunk_size): + num_chunks = total_seq_len // chunk_size + return g.reshape(batch_size, num_heads, num_chunks, chunk_size).cumsum(dim=-1) + + +class TestFusedDeltaNetDecayBounding(unittest.TestCase): + def test_preserves_non_extreme_decay(self): + batch_size, num_heads, total_seq_len, chunk_size = 2, 3, 16, 8 + g = torch.full( + (batch_size, num_heads, total_seq_len), + -0.125, + dtype=torch.float32, + ) + + bounded = _bound_fused_deltanet_log_decay( + g, batch_size, num_heads, total_seq_len, chunk_size + ) + + torch.testing.assert_close(bounded, g) + + def test_bounds_per_chunk_cumulative_decay(self): + batch_size, num_heads, total_seq_len, chunk_size = 2, 3, 16, 8 + g = torch.full( + (batch_size, num_heads, total_seq_len), + -10.0, + dtype=torch.float32, + ) + + bounded = _bound_fused_deltanet_log_decay( + g, batch_size, num_heads, total_seq_len, chunk_size + ) + bounded_cumsum = _chunked_cumsum( + bounded, batch_size, num_heads, total_seq_len, chunk_size + ) + expected_cumsum = _chunked_cumsum( + g, batch_size, num_heads, total_seq_len, chunk_size + ).clamp(min=FUSED_DELTANET_DECAY_MIN, max=FUSED_DELTANET_DECAY_MAX) + + torch.testing.assert_close(bounded_cumsum, expected_cumsum) + self.assertGreaterEqual(float(bounded_cumsum.min()), FUSED_DELTANET_DECAY_MIN) + self.assertLessEqual(float(bounded_cumsum.max()), FUSED_DELTANET_DECAY_MAX) + self.assertTrue(torch.isfinite(bounded).all()) + + +if __name__ == "__main__": + unittest.main() From 7bb3c8006c78b9e1791a5df35d642adc15d0103f Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Mon, 4 May 2026 17:46:55 +0530 Subject: [PATCH 3/5] Update Qwen3.5 contrib documentation --- contrib/models/Qwen3.5-4B/README.md | 99 ++++++++++++++++++++++++++++- contrib/models/Qwen3.5-9B/README.md | 34 ++++++---- 2 files changed, 119 insertions(+), 14 deletions(-) diff --git a/contrib/models/Qwen3.5-4B/README.md b/contrib/models/Qwen3.5-4B/README.md index 1b75bee6..9910cd12 100644 --- a/contrib/models/Qwen3.5-4B/README.md +++ b/contrib/models/Qwen3.5-4B/README.md @@ -37,9 +37,97 @@ Derived DeltaNet shapes: | `recurrent_state_buffer` | `[max_batch, 32, 128, 128]` | | `conv_state_buffer` | `[max_batch, 8192, 3]` | +## Compatibility + +| Instance | Neuron SDK / environment | TP | dtype | seq_len | Status | +| --- | --- | --- | --- | --- | --- | +| `trn2.48xlarge` | PyTorch 2.9 NxDI inference env | 4 | BF16 | 160 | Unit and integration tests pass | + ## Status -Prepared for Trn2 bring-up. Validate TP=4, batch=1, seq_len=128 first, then increase context or reduce TP only after the baseline generates correctly. +Validated on Trn2 with TP=4, batch=1, and seq_len=160. TP=2, Trn1, long-context HBM limits, and quantized inference are not validated for this contrib model. 
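+
+Before re-validating a new configuration, the fused-kernel decay bounding can
+be sanity-checked on CPU. A minimal sketch (shapes are illustrative; the
+helper and constants are the ones defined in `src/modeling_qwen35.py`):
+
+```python
+import torch
+
+from src.modeling_qwen35 import (
+    FUSED_DELTANET_DECAY_MIN,
+    _bound_fused_deltanet_log_decay,
+)
+
+batch, heads, seq, chunk = 1, 4, 16, 8
+g = torch.full((batch, heads, seq), -10.0)  # strongly negative log-decay
+
+bounded = _bound_fused_deltanet_log_decay(g, batch, heads, seq, chunk)
+
+# Per-chunk cumulative sums are clamped to [FUSED_DELTANET_DECAY_MIN, 0],
+# so the exp(-cumsum(g)) term inside the fused NKI kernel stays finite.
+cumsum = bounded.reshape(batch, heads, seq // chunk, chunk).cumsum(dim=-1)
+assert float(cumsum.min()) >= FUSED_DELTANET_DECAY_MIN
+assert float(cumsum.max()) <= 0.0
+```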
+ +## Compatible Checkpoints + +- [Qwen/Qwen3.5-4B](https://huggingface.co/Qwen/Qwen3.5-4B) + +## Usage + +```python +import json +import os +import torch +from transformers import AutoTokenizer, GenerationConfig +from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, +) +from neuronx_distributed_inference.utils.hf_adapter import ( + HuggingFaceGenerationAdapter, +) + +from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM + +model_path = "/mnt/models/Qwen3.5-4B" +compiled_path = "/mnt/models/qwen35_4b_traced" + +neuron_config = NeuronConfig( + tp_degree=4, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=160, + torch_dtype=torch.bfloat16, + logical_nc_config=2, + enable_bucketing=False, + flash_decoding_enabled=False, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + save_sharded_checkpoint=True, +) + +with open(os.path.join(model_path, "config.json")) as f: + hf_config = json.load(f) +text_config = hf_config.get("text_config", hf_config) +config_dict = dict(text_config) +config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) +config_dict["tie_word_embeddings"] = hf_config.get( + "tie_word_embeddings", + text_config.get("tie_word_embeddings", True), +) +if "rope_parameters" in text_config: + config_dict["rope_theta"] = text_config["rope_parameters"].get( + "rope_theta", 10000000 + ) + +config = Qwen35InferenceConfig(neuron_config=neuron_config, **config_dict) + +model = NeuronQwen35ForCausalLM(model_path, config) +model.compile(compiled_path) + +model = NeuronQwen35ForCausalLM(compiled_path) +model.load(compiled_path) + +tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right") +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + +gen_config = GenerationConfig( + do_sample=True, + top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, +) + +inputs = tokenizer("The capital of France is", return_tensors="pt") +gen_model = HuggingFaceGenerationAdapter(model) +outputs = gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + max_new_tokens=32, +) +print(tokenizer.decode(outputs[0], skip_special_tokens=True)) +``` ## Testing @@ -60,10 +148,17 @@ NEURON_PLATFORM_TARGET_OVERRIDE=trn2 \ QWEN35_MODEL_PATH=/home/ubuntu/models/Qwen3.5-4B \ QWEN35_COMPILED_PATH=/home/ubuntu/models/qwen35_4b_traced_trn2 \ QWEN35_TP_DEGREE=4 \ -QWEN35_SEQ_LEN=128 \ +QWEN35_SEQ_LEN=160 \ pytest test/integration/test_model.py --capture=tee-sys -v ``` +Validated results on `trn2.48xlarge`: + +- Unit tests: `45 passed` +- Integration tests: `9 passed` +- TTFT: `83.2 ms` +- Throughput: `68.1 tok/s` + ## Known Limitations - DeltaNet weights are replicated across TP ranks in v1. diff --git a/contrib/models/Qwen3.5-9B/README.md b/contrib/models/Qwen3.5-9B/README.md index 0b8dc6b3..638d3649 100644 --- a/contrib/models/Qwen3.5-9B/README.md +++ b/contrib/models/Qwen3.5-9B/README.md @@ -37,20 +37,23 @@ Derived DeltaNet shapes: | `recurrent_state_buffer` | `[max_batch, 32, 128, 128]` | | `conv_state_buffer` | `[max_batch, 8192, 3]` | +## Compatibility + +| Instance | Neuron SDK / environment | TP | dtype | seq_len | Status | +| --- | --- | --- | --- | --- | --- | +| `trn2.48xlarge` | PyTorch 2.9 NxDI inference env | 4 | BF16 | 160 | Unit and integration tests pass | + ## Status -This 9B contrib is prepared for bring-up. 
The implementation should be validated on Trn2 before TP=2 or Trn1 experiments. +Validated on Trn2 with TP=4, batch=1, and seq_len=160. TP=2, Trn1, long-context HBM limits, and quantized inference are not validated for this contrib model. Validated baseline: - Qwen3.5-2B PR 141: trn2.3xlarge, TP=4, LNC=2, SDK 2.29, NKI 0.3. -Unvalidated for this folder until run: +## Compatible Checkpoints -- Qwen3.5-9B compile and generation -- TP=2 -- Trn1 -- long-context HBM limits +- [Qwen/Qwen3.5-9B](https://huggingface.co/Qwen/Qwen3.5-9B) ## Usage @@ -72,7 +75,7 @@ neuron_config = NeuronConfig( batch_size=1, ctx_batch_size=1, tkg_batch_size=1, - seq_len=128, + seq_len=160, torch_dtype=torch.bfloat16, logical_nc_config=2, enable_bucketing=False, @@ -140,17 +143,24 @@ Trainium integration: cd contrib/models/Qwen3.5-9B source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate -QWEN35_MODEL_PATH=/mnt/models/Qwen3.5-9B \ -QWEN35_COMPILED_PATH=/mnt/models/qwen35_9b_traced \ +NEURON_PLATFORM_TARGET_OVERRIDE=trn2 \ +QWEN35_MODEL_PATH=/home/ubuntu/models/Qwen3.5-9B \ +QWEN35_COMPILED_PATH=/home/ubuntu/models/qwen35_9b_traced_trn2 \ QWEN35_TP_DEGREE=4 \ -QWEN35_SEQ_LEN=128 \ +QWEN35_SEQ_LEN=160 \ pytest test/integration/test_model.py --capture=tee-sys -v ``` +Validated results on `trn2.48xlarge`: + +- Unit tests: `44 passed` +- Integration tests: `9 passed` +- TTFT: `88.1 ms` +- Throughput: `49.6 tok/s` + ## Known Limitations 1. SDK 2.29+ and NKI 0.3 are expected. 2. DeltaNet weights are replicated across TP ranks in v1. 3. Dummy KV wastes HBM for DeltaNet layers. -4. First-token and multi-token logit parity are expected to show the same BF16 recurrent divergence reported by PR 141 until the DeltaNet precision work is done. -5. Hybrid cache, DeltaNet TP sharding, quantization, speculative decoding, and MoE are out of scope for first bring-up. +4. Hybrid cache, DeltaNet TP sharding, quantization, speculative decoding, and MoE are out of scope for first bring-up. 
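+
+Limitation 3 can be estimated without hardware. A rough sketch, assuming the
+standard `[batch, kv_heads, seq_len, head_dim]` KV cache layout (the helper
+name is illustrative, and actual NxDI allocation may differ):
+
+```python
+def dummy_kv_bytes_per_deltanet_layer(num_kv_heads, head_dim, batch, seq_len,
+                                      dtype_bytes=2):
+    """Bytes spent on the dummy K and V caches of one DeltaNet layer (BF16 = 2 bytes)."""
+    per_tensor = batch * num_kv_heads * seq_len * head_dim * dtype_bytes
+    return 2 * per_tensor  # one K cache plus one V cache
+```
+
+Summing over the DeltaNet layers gives an upper bound on the HBM a dedicated
+hybrid cache could reclaim; the DeltaNet recurrent and convolution states
+still need their own storage.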
From 8e1288ad7995160a8879c37a50ee4bbbbcc3c016 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Tue, 5 May 2026 17:52:00 +0530 Subject: [PATCH 4/5] Scope Qwen3.5 hybrid cache to 4B and 9B --- .../models/Qwen3.5-4B/src/modeling_qwen35.py | 307 +++++++++++++++- .../src/nki_kernels/nki_deltanet_fused.py | 63 +++- .../Qwen3.5-4B/test/integration/test_model.py | 107 ++++++ .../test/unit/test_hybrid_cache_manager.py | 335 +++++++++++++++++ .../models/Qwen3.5-9B/src/modeling_qwen35.py | 307 +++++++++++++++- .../src/nki_kernels/nki_deltanet_fused.py | 63 +++- .../Qwen3.5-9B/test/integration/test_model.py | 107 ++++++ .../test/unit/test_hybrid_cache_manager.py | 341 ++++++++++++++++++ 8 files changed, 1582 insertions(+), 48 deletions(-) create mode 100644 contrib/models/Qwen3.5-4B/test/unit/test_hybrid_cache_manager.py create mode 100644 contrib/models/Qwen3.5-9B/test/unit/test_hybrid_cache_manager.py diff --git a/contrib/models/Qwen3.5-4B/src/modeling_qwen35.py b/contrib/models/Qwen3.5-4B/src/modeling_qwen35.py index bf3d4066..791420cb 100644 --- a/contrib/models/Qwen3.5-4B/src/modeling_qwen35.py +++ b/contrib/models/Qwen3.5-4B/src/modeling_qwen35.py @@ -84,6 +84,7 @@ NeuronAttentionBase, ) from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding +from neuronx_distributed_inference.modules.kvcache.kv_cache_manager import KVCacheManager from neuronx_distributed_inference.models.layer_boundary_marker import ( ModuleMarkerEndWrapper, ModuleMarkerStartWrapper, @@ -221,7 +222,7 @@ def l2norm(x, dim=-1, eps=1e-6): return F.normalize(x, p=2, dim=dim, eps=eps) -FUSED_DELTANET_DECAY_MIN = -60.0 +FUSED_DELTANET_DECAY_MIN = -20.0 FUSED_DELTANET_DECAY_MAX = 0.0 @@ -286,6 +287,7 @@ def __init__(self, config, layer_idx: int): self.conv_kernel_size = tc.linear_conv_kernel_dim # 4 self.layer_idx = layer_idx self.rms_norm_eps = tc.rms_norm_eps + self.use_hybrid_cache_manager = getattr(tc, "use_hybrid_cache_manager", False) # KV cache dummy shape info self.head_dim = tc.head_dim # 256 @@ -747,6 +749,11 @@ def forward( # zeros the decay gate so the recurrent state is preserved unchanged # through padding positions (no spurious decay). 
valid_mask_1d = kwargs.get("deltanet_padding_mask", None) # [B, S, 1] or None + hybrid_cache_active = self.use_hybrid_cache_manager + recurrent_state_cache = None + conv_state_cache = None + if hybrid_cache_active and past_key_value is not None: + recurrent_state_cache, conv_state_cache = past_key_value # Project inputs deltanet_fp32 = os.environ.get("DELTANET_FP32") == "1" @@ -774,7 +781,9 @@ def forward( mixed = mixed.transpose(1, 2) if is_decode: - if seq_ids is not None: + if conv_state_cache is not None: + conv_state = conv_state_cache[:batch_size] + elif seq_ids is not None: conv_state = torch.index_select(self.conv_state_buffer, 0, seq_ids) else: conv_state = self.conv_state_buffer[:batch_size] @@ -791,7 +800,9 @@ def forward( new_conv_state = torch.cat([conv_state[:, :, 1:], mixed], dim=-1) alloc_bs = self.conv_state_buffer.shape[0] - if seq_ids is not None: + if hybrid_cache_active: + new_conv_state = new_conv_state.to(self.conv_state_buffer.dtype) + elif seq_ids is not None: # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement # Add buffer dependency for input_output_alias new_conv_state = ( @@ -827,7 +838,9 @@ def forward( new_conv_state = mixed[:, :, -3:].contiguous() alloc_bs = self.conv_state_buffer.shape[0] - if seq_ids is not None: + if hybrid_cache_active: + new_conv_state = new_conv_state.to(self.conv_state_buffer.dtype) + elif seq_ids is not None: # BS=1 optimization: scatter to index 0 = direct replacement new_conv_state = ( new_conv_state.to(self.conv_state_buffer.dtype) @@ -906,7 +919,9 @@ def forward( if is_decode: # TKG: single-step recurrent update - if seq_ids is not None: + if recurrent_state_cache is not None: + recurrent_state = recurrent_state_cache[:batch_size].float() + elif seq_ids is not None: recurrent_state = torch.index_select( self.recurrent_state_buffer, 0, seq_ids ).float() @@ -918,7 +933,9 @@ def forward( ) new_state_bf16 = new_state.to(self.recurrent_state_buffer.dtype) alloc_bs = self.recurrent_state_buffer.shape[0] - if seq_ids is not None: + if hybrid_cache_active: + new_rec_state = new_state_bf16 + elif seq_ids is not None: # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement # Add buffer dependency for input_output_alias new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0 @@ -970,7 +987,9 @@ def forward( if final_state is not None: final_state_bf16 = final_state.to(self.recurrent_state_buffer.dtype) alloc_bs = self.recurrent_state_buffer.shape[0] - if seq_ids is not None: + if hybrid_cache_active: + new_rec_state = final_state_bf16 + elif seq_ids is not None: # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement # Add buffer dependency for input_output_alias new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0 @@ -1005,6 +1024,9 @@ def forward( output = output.reshape(batch_size, seq_len, self.value_dim) output = self.out_proj(output) + if hybrid_cache_active: + return output, (new_rec_state, new_conv_state), new_rec_state, new_conv_state + # Return dummy KV for KVCacheManager dummy_k = torch.zeros( batch_size, @@ -1059,6 +1081,7 @@ def __init__(self, *args, **kwargs): kwargs.setdefault("linear_key_head_dim", 128) kwargs.setdefault("linear_value_head_dim", 128) kwargs.setdefault("linear_conv_kernel_dim", 4) + kwargs.setdefault("use_hybrid_cache_manager", False) super().__init__(*args, **kwargs) @@ -1512,7 +1535,11 @@ def forward( ) hidden_states = residual + attn_out present_key_value = dummy_kv - deltanet_states = (new_rec_state, 
new_conv_state) + deltanet_states = ( + None + if getattr(self.config, "use_hybrid_cache_manager", False) + else (new_rec_state, new_conv_state) + ) else: deltanet_states = None # Standard attention path @@ -1545,6 +1572,240 @@ def forward( return outputs +# ============================================================ +# Hybrid Cache Manager (opt-in) +# ============================================================ + + +class HybridDeltaNetCacheManager(KVCacheManager): + """Layer-type-aware cache manager for Qwen3.5/Qwen3.6 hybrid dense models.""" + + def __init__(self, config: Qwen35InferenceConfig, num_kv_head, **kwargs): + self.layer_types = list(config.layer_types) + self._validate_hybrid_config(config) + super().__init__(config, num_kv_head=num_kv_head, **kwargs) + + dtype = ( + config.neuron_config.attention_dtype + if config.neuron_config.attention_dtype is not None + else config.neuron_config.torch_dtype + ) + cache_dtype = getattr(self, "cache_dtype", dtype) + max_batch_size = ( + config.neuron_config.kv_cache_batch_size + + config.neuron_config.kv_cache_padding_size + ) + recurrent_shape = [ + max_batch_size, + config.linear_num_value_heads, + config.linear_key_head_dim, + config.linear_value_head_dim, + ] + conv_dim = ( + 2 * config.linear_num_key_heads * config.linear_key_head_dim + + config.linear_num_value_heads * config.linear_value_head_dim + ) + conv_shape = [ + max_batch_size, + conv_dim, + config.linear_conv_kernel_dim - 1, + ] + + params = [] + for layer_idx, layer_type in enumerate(self.layer_types): + if layer_type == "linear_attention": + params.append( + nn.Parameter(torch.zeros(recurrent_shape, dtype=dtype), requires_grad=False) + ) + params.append( + nn.Parameter(torch.zeros(conv_shape, dtype=dtype), requires_grad=False) + ) + else: + k_shape = self.k_shapes[layer_idx] if hasattr(self, "k_shapes") else self.k_shape + v_shape = self.v_shapes[layer_idx] if hasattr(self, "v_shapes") else self.v_shape + params.append( + nn.Parameter(torch.zeros(k_shape, dtype=cache_dtype), requires_grad=False) + ) + params.append( + nn.Parameter(torch.zeros(v_shape, dtype=cache_dtype), requires_grad=False) + ) + + self.past_key_values = nn.ParameterList(params) + + @staticmethod + def _validate_hybrid_config(config: Qwen35InferenceConfig): + nc = config.neuron_config + unsupported = [] + if nc.is_block_kv_layout: + unsupported.append("block KV layout") + if getattr(nc, "kv_quant_config", None) is not None or getattr(nc, "kv_cache_quant", False): + unsupported.append("KV cache quantization") + if nc.enable_fused_speculation or nc.speculation_length > 0 or nc.is_medusa: + unsupported.append("speculative decoding") + if getattr(nc, "enable_eagle_speculation", False) or getattr(nc, "is_eagle_draft", False): + unsupported.append("EAGLE speculation") + if nc.flash_decoding_enabled: + unsupported.append("flash decoding") + if nc.attention_dp_degree > 1: + unsupported.append("attention data parallelism") + if nc.kv_cache_tiling: + unsupported.append("KV cache tiling") + if nc.padding_side != "right": + unsupported.append("left padding") + if nc.is_continuous_batching: + unsupported.append("continuous batching") + if unsupported: + raise ValueError( + "HybridDeltaNetCacheManager v1 does not support: " + + ", ".join(unsupported) + ) + + def _is_deltanet_layer(self, idx: int) -> bool: + return self.layer_types[idx] == "linear_attention" + + def get_seq_length(self, past_key_values=None): + for idx, layer_type in enumerate(self.layer_types): + if layer_type != "linear_attention": + if 
past_key_values is None:
+                    _, v_cache = self._fetch_cache(idx)
+                elif len(past_key_values) == len(self.past_key_values):
+                    v_cache = past_key_values[2 * idx + 1]
+                else:
+                    v_cache = past_key_values[idx][1]
+                return v_cache.shape[2]
+        return 0
+
+    def get_deltanet_state_by_layer_id(self, idx, kvcache_buffer=None, seq_ids=None):
+        recurrent_state, conv_state = self._fetch_cache(idx, kvcache_buffer)
+        if seq_ids is not None:
+            cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids)
+            recurrent_state = torch.index_select(recurrent_state, dim=0, index=cache_idx)
+            conv_state = torch.index_select(conv_state, dim=0, index=cache_idx)
+        elif self.kv_cache_padding_size > 0:
+            recurrent_state = recurrent_state[: -self.kv_cache_padding_size]
+            conv_state = conv_state[: -self.kv_cache_padding_size]
+        return recurrent_state, conv_state
+
+    def get_cache(
+        self,
+        seq_len: int,
+        skip_slice=False,
+        kvcache_buffer=None,
+        seq_ids=None,
+        windowed_context_encoding_window_idx=-1,
+        **kwargs,
+    ):
+        past_key_values = []
+        for idx in range(len(self.past_key_values) // 2):
+            if self._is_deltanet_layer(idx):
+                past_key_values.append(
+                    list(self.get_deltanet_state_by_layer_id(idx, kvcache_buffer, seq_ids))
+                )
+            else:
+                past_key_values.append(
+                    list(
+                        self.get_kv_by_layer_id(
+                            idx=idx,
+                            skip_slice=skip_slice,
+                            seq_len=seq_len,
+                            kvcache_buffer=kvcache_buffer,
+                            seq_ids=seq_ids,
+                            windowed_context_encoding_window_idx=windowed_context_encoding_window_idx,
+                            **kwargs,
+                        )
+                    )
+                )
+        return past_key_values
+
+    def update_cache(
+        self,
+        is_for_context_encoding: bool,
+        seq_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        new_key_values: List[torch.Tensor],
+        seq_len: int,
+        scatter_index=None,
+        kv_active_mask=None,
+        kvcache_buffer=None,
+        windowed_context_encoding_window_idx: int = -1,
+        **kwargs,
+    ):
+        updated_cache = []
+        for idx, kv_per_layer in enumerate(new_key_values):
+            if self._is_deltanet_layer(idx):
+                # DeltaNet layers store (recurrent_state, conv_state).
+                cache_a, cache_b = self.update_deltanet_state_by_layer_id(
+                    idx=idx,
+                    seq_ids=seq_ids,
+                    state_per_layer=kv_per_layer,
+                    kvcache_buffer=kvcache_buffer,
+                )
+            else:
+                # Attention layers store (k_cache, v_cache).
+                cache_a, cache_b = self.update_kv_by_layer_id(
+                    idx=idx,
+                    is_for_context_encoding=is_for_context_encoding,
+                    seq_ids=seq_ids,
+                    position_ids=position_ids,
+                    kv_per_layer=kv_per_layer,
+                    seq_len=seq_len,
+                    scatter_index=scatter_index,
+                    kv_active_mask=kv_active_mask,
+                    kvcache_buffer=kvcache_buffer,
+                    windowed_context_encoding_window_idx=windowed_context_encoding_window_idx,
+                    **kwargs,
+                )
+            updated_cache.append(cache_a)
+            updated_cache.append(cache_b)
+        return updated_cache
+
+    def update_deltanet_state_by_layer_id(
+        self,
+        idx: int,
+        seq_ids: torch.Tensor,
+        state_per_layer: Tuple[torch.Tensor, torch.Tensor],
+        kvcache_buffer=None,
+    ):
+        latest_recurrent, latest_conv = state_per_layer
+        recurrent_cache, conv_cache = self._fetch_cache(idx, kvcache_buffer)
+        latest_recurrent = latest_recurrent.to(recurrent_cache.dtype)
+        latest_conv = latest_conv.to(conv_cache.dtype)
+
+        if seq_ids is not None:
+            cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids)
+            recurrent_index = cache_idx.view(-1, 1, 1, 1).expand_as(latest_recurrent)
+            conv_index = cache_idx.view(-1, 1, 1).expand_as(latest_conv)
+            recurrent_cache = torch.scatter(
+                input=recurrent_cache,
+                dim=0,
+                index=recurrent_index,
+                src=latest_recurrent,
+            )
+            conv_cache = torch.scatter(
+                input=conv_cache,
+                dim=0,
+                index=conv_index,
+                src=latest_conv,
+            )
+            return recurrent_cache, conv_cache
+
+        if latest_recurrent.shape[0] ==
recurrent_cache.shape[0]: + return ( + latest_recurrent + recurrent_cache * 0, + latest_conv + conv_cache * 0, + ) + + pad_size = recurrent_cache.shape[0] - latest_recurrent.shape[0] + if pad_size > 0: + latest_recurrent = torch.cat( + [latest_recurrent, recurrent_cache[latest_recurrent.shape[0] :] * 0], + dim=0, + ) + latest_conv = torch.cat( + [latest_conv, conv_cache[latest_conv.shape[0] :] * 0], + dim=0, + ) + return latest_recurrent + recurrent_cache * 0, latest_conv + conv_cache * 0 + + # ============================================================ # Model # ============================================================ @@ -1590,6 +1851,19 @@ def init_model(self, config: Qwen35InferenceConfig): # mRoPE embedding for VL self.mrope_emb = Qwen35MRoPEEmbedding(config) + def init_inference_optimization(self, config: Qwen35InferenceConfig): + super().init_inference_optimization(config) + if getattr(config, "use_hybrid_cache_manager", False): + self.kv_mgr = HybridDeltaNetCacheManager( + config, + num_kv_head=self.num_key_value_heads, + global_rank=self.rank_util, + attention_chunk_size=self.attention_chunk_size, + sliding_window=self.sliding_window, + windowed_context_encoding_size=self.windowed_context_encoding_size, + layer_to_cache_size_mapping=self.layer_to_cache_size_mapping, + ) + @property def _deltanet_state_params(self): """Return DeltaNet state nn.Parameters in alias order.""" @@ -1639,7 +1913,10 @@ def get_model_output( past_key_values_length = 0 if past_key_values is not None: - past_key_values_length = past_key_values[0][1].shape[2] + if hasattr(self.kv_mgr, "get_seq_length"): + past_key_values_length = self.kv_mgr.get_seq_length(past_key_values) + else: + past_key_values_length = past_key_values[0][1].shape[2] if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -1909,7 +2186,10 @@ def forward( outputs += updated_kv_cache # Append DeltaNet state tensors (for input_output_aliases) - if hasattr(self, "_deltanet_updated_states"): + if ( + not getattr(self.config, "use_hybrid_cache_manager", False) + and hasattr(self, "_deltanet_updated_states") + ): outputs += self._deltanet_updated_states return outputs @@ -2061,7 +2341,10 @@ def get(self, bucket_rank, **kwargs): state_start_idx = num_output_from_trace + num_kv - if hasattr(module, "_deltanet_state_params"): + if ( + not getattr(module.config, "use_hybrid_cache_manager", False) + and hasattr(module, "_deltanet_state_params") + ): for i, param in enumerate(module._deltanet_state_params): input_output_aliases[param] = state_start_idx + i @@ -2301,6 +2584,8 @@ def enable_token_generation(self): def _copy_past_key_values(self, outputs): """Override to also copy DeltaNet state buffers on CPU.""" super()._copy_past_key_values(outputs) + if getattr(self.config, "use_hybrid_cache_manager", False): + return num_output_from_trace = 1 if ( diff --git a/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_fused.py b/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_fused.py index 3447a138..4d02423d 100644 --- a/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_fused.py +++ b/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_fused.py @@ -312,14 +312,21 @@ def deltanet_fused_chunked_fwd( # ============================================================ # Decay mask: QK_decay[i,j] = QK[i,j] * exp(gc[i]) * exp(-gc[j]) # + # Apply the strict causal mask before the split exp(gc) / exp(-gc) + # scaling. 
Upper-triangular entries are mathematically unused, but + # scaling them first can create very large finite values that poison + # later matmuls before the mask is applied. + # ============================================================ + QK_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=QK_masked, data1=QK, data2=Lmask, op=nl.multiply) + # Row scaling: QK_row[i,:] = QK[i,:] * exp(gc[i]) # Then transpose, column scale, transpose back. # Uses tensor_scalar with (P_MAX,1) operand for row scaling. - # ============================================================ QK_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_scalar( dst=QK_row, - data=QK, + data=QK_masked, op0=nl.multiply, operand0=exp_gc_p, engine=nisa.vector_engine, @@ -435,11 +442,16 @@ def deltanet_fused_chunked_fwd( qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_copy(dst=qk_raw, src=qk_psum) + # Mask before split scaling for the same reason as the A matrix above: + # upper-triangular decay factors are unused and can be numerically huge. + qk_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=qk_masked, data1=qk_raw, data2=Lmask_d, op=nl.multiply) + # Row-scale by exp(gc) qk_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_scalar( dst=qk_row, - data=qk_raw, + data=qk_masked, op0=nl.multiply, operand0=exp_gc_p, engine=nisa.vector_engine, @@ -538,13 +550,40 @@ def deltanet_fused_chunked_fwd( # state is updated IN-PLACE in SBUF — no HBM round-trip! # ============================================================ - # k_raw_decay = k * exp(-gc) + # k_raw_decay contributes as exp(g_last) * (k * exp(-gc))^T @ v_new. + # Compute the equivalent form with one bounded exponential, + # k * exp(g_last - gc), so the factor is always <= 1 for valid + # causal positions. + gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gl_11[0:1, 0:1], + dst=gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gl_minus_gc_p, + data1=gl_p, + data2=gc_p, + op=nl.subtract, + ) + exp_gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_minus_gc_p, + op=nl.exp, + data=gl_minus_gc_p, + bias=None, + scale=1.0, + ) + + # k_raw_decay = k * exp(g_last - gc) k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_scalar( dst=k_raw_decay, data=k_c, op0=nl.multiply, - operand0=exp_neg_gc_p, + operand0=exp_gl_minus_gc_p, engine=nisa.vector_engine, ) @@ -557,19 +596,17 @@ def deltanet_fused_chunked_fwd( kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_copy(dst=kv_outer, src=kv_psum) - # state = state + kv_outer - state_plus = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_tensor(dst=state_plus, data1=state, data2=kv_outer, op=nl.add) - - # state = state_plus * exp(g_last) - # tensor_scalar broadcasts exp_gl_p (P_MAX, 1) across free dim + # state = state * exp(g_last) + kv_outer + # tensor_scalar broadcasts exp_gl_p (P_MAX, 1) across free dim. 
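+            # Algebraic equivalence with the previous formulation:
+            #   (state + K^T diag(exp(-gc)) @ V) * exp(g_last)
+            #     = state * exp(g_last) + K^T diag(exp(g_last - gc)) @ V,
+            # and g_last - gc <= 0 at causal positions, so no unbounded
+            # exp(-gc) factor ever materializes.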
+ state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_scalar( - dst=state, - data=state_plus, + dst=state_decayed, + data=state, op0=nl.multiply, operand0=exp_gl_p, engine=nisa.vector_engine, ) + nisa.tensor_tensor(dst=state, data1=state_decayed, data2=kv_outer, op=nl.add) # ---- Write final state to HBM ---- nisa.dma_copy(dst=final_state_out, src=state) diff --git a/contrib/models/Qwen3.5-4B/test/integration/test_model.py b/contrib/models/Qwen3.5-4B/test/integration/test_model.py index 6de43de4..2c1353c0 100644 --- a/contrib/models/Qwen3.5-4B/test/integration/test_model.py +++ b/contrib/models/Qwen3.5-4B/test/integration/test_model.py @@ -36,7 +36,10 @@ """ import gc +import json import os +import shutil +import subprocess import sys import time @@ -56,6 +59,8 @@ SEQ_LEN = int(os.environ.get("QWEN35_SEQ_LEN", "128")) TTFT_THRESHOLD_MS = float(os.environ.get("TTFT_THRESHOLD_MS", "5000")) THROUGHPUT_THRESHOLD = float(os.environ.get("THROUGHPUT_THRESHOLD", "5.0")) +USE_HYBRID_CACHE = os.environ.get("QWEN35_USE_HYBRID_CACHE", "0") == "1" +RECORD_HBM = os.environ.get("QWEN35_RECORD_HBM", "0") == "1" requires_model_path = pytest.mark.skipif( not MODEL_PATH, @@ -64,6 +69,13 @@ "weights. Set QWEN35_MODEL_PATH=/path/to/Qwen3.5-4B to run these tests." ), ) +requires_hbm_recording = pytest.mark.skipif( + not RECORD_HBM, + reason=( + "QWEN35_RECORD_HBM=1 not set. This optional test records Neuron HBM " + "usage for dummy-KV vs hybrid-cache comparisons." + ), +) # ── Fixtures ──────────────────────────────────────────────────────────── @@ -119,6 +131,7 @@ def compiled_model(model_path): inf_config = Qwen35InferenceConfig( neuron_config=neuron_config, + use_hybrid_cache_manager=USE_HYBRID_CACHE, **config_dict, ) @@ -195,6 +208,73 @@ def _is_repetitive(text, max_repeat=5): return False +def _parse_peak_neuron_memory(stdout): + peak_device = 0 + peak_tensors = 0 + samples = 0 + for line in stdout.splitlines(): + line = line.strip() + if not line: + continue + try: + report = json.loads(line) + except json.JSONDecodeError: + continue + for runtime in report.get("neuron_runtime_data", []): + memory_used = runtime.get("report", {}).get("memory_used", {}) + used = memory_used.get("neuron_runtime_used_bytes", {}) + peak_device = max(peak_device, int(used.get("neuron_device", 0) or 0)) + nc_usage = ( + used.get("usage_breakdown", {}).get("neuroncore_memory_usage", {}) + ) + tensor_bytes = sum( + int(core.get("tensors", 0) or 0) for core in nc_usage.values() + ) + peak_tensors = max(peak_tensors, tensor_bytes) + samples += 1 + return peak_device, peak_tensors, samples + + +def _capture_neuron_hbm(tmp_path, fn): + if shutil.which("neuron-monitor") is None: + pytest.skip("neuron-monitor is not available") + + monitor_config = { + "period": "0.5s", + "neuron_runtimes": [ + { + "tag_filter": ".*", + "metrics": [{"type": "memory_used", "period": "0.5s"}], + } + ], + } + config_path = tmp_path / "neuron-monitor.json" + config_path.write_text(json.dumps(monitor_config)) + + proc = subprocess.Popen( + ["neuron-monitor", "--config-file", str(config_path)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + try: + time.sleep(1.0) + result = fn() + time.sleep(1.0) + finally: + proc.terminate() + try: + stdout, stderr = proc.communicate(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + stdout, stderr = proc.communicate(timeout=5) + + peak_device, peak_tensors, samples = _parse_peak_neuron_memory(stdout) + assert samples > 0, f"neuron-monitor produced no 
runtime samples: {stderr}" + assert peak_device > 0, "Expected non-zero Neuron device HBM usage" + return result, peak_device, peak_tensors, samples + + # ── Smoke Tests ───────────────────────────────────────────────────────── @@ -362,6 +442,33 @@ def test_performance_throughput(compiled_model, tokenizer, generation_config): ) +@requires_model_path +@requires_hbm_recording +def test_hybrid_cache_hbm_snapshot(compiled_model, tokenizer, generation_config, tmp_path): + """Record peak Neuron HBM for dummy-KV vs hybrid-cache comparison runs.""" + prompt = "Give me a summary of the 2020 Olympics in 100 tokens." + max_new_tokens = int(os.environ.get("QWEN35_HBM_NEW_TOKENS", "32")) + + (_, text), peak_device, peak_tensors, samples = _capture_neuron_hbm( + tmp_path, + lambda: _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=max_new_tokens, + ), + ) + + mode = "hybrid" if USE_HYBRID_CACHE else "dummy_kv" + print( + " HBM " + f"mode={mode} peak_device_bytes={peak_device} " + f"peak_tensor_bytes={peak_tensors} samples={samples}" + ) + assert len(text) > len(prompt) + + # ── Multi-Prompt Quality Test ────────────────────────────────────────── diff --git a/contrib/models/Qwen3.5-4B/test/unit/test_hybrid_cache_manager.py b/contrib/models/Qwen3.5-4B/test/unit/test_hybrid_cache_manager.py new file mode 100644 index 00000000..c5941945 --- /dev/null +++ b/contrib/models/Qwen3.5-4B/test/unit/test_hybrid_cache_manager.py @@ -0,0 +1,335 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys +import unittest +from math import prod +from unittest.mock import patch + +import torch + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from neuronx_distributed_inference.models.config import NeuronConfig +from src.modeling_qwen35 import HybridDeltaNetCacheManager, Qwen35InferenceConfig + + +def _make_config(**overrides): + neuron_overrides = overrides.pop("neuron_overrides", {}) + neuron_kwargs = dict( + tp_degree=overrides.pop("tp_degree", 4), + batch_size=1, + max_batch_size=2, + kv_cache_batch_size=2, + seq_len=16, + torch_dtype=torch.bfloat16, + ) + neuron_kwargs.update(neuron_overrides) + neuron_config = NeuronConfig(**neuron_kwargs) + defaults = dict( + hidden_size=2560, + num_hidden_layers=32, + num_attention_heads=16, + num_key_value_heads=4, + head_dim=256, + intermediate_size=9216, + vocab_size=248320, + rms_norm_eps=1e-6, + max_position_embeddings=262144, + rope_theta=10000000, + hidden_act="silu", + tie_word_embeddings=True, + linear_num_value_heads=32, + linear_num_key_heads=16, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_conv_kernel_dim=4, + use_hybrid_cache_manager=True, + ) + defaults.update(overrides) + return Qwen35InferenceConfig(neuron_config=neuron_config, **defaults) + + +def _numel(shape): + return prod(int(dim) for dim in shape) + + +def _managed_cache_numel(mgr): + return sum(param.numel() for param in mgr.past_key_values) + + +def _deltanet_state_numel(config, max_batch_size): + recurrent = ( + max_batch_size + * config.linear_num_value_heads + * config.linear_key_head_dim + * config.linear_value_head_dim + ) + conv_dim = ( + 2 * config.linear_num_key_heads * config.linear_key_head_dim + + config.linear_num_value_heads * config.linear_value_head_dim + ) + conv = max_batch_size * conv_dim * (config.linear_conv_kernel_dim - 1) + return recurrent 
+ conv + + +class TestHybridDeltaNetCacheManager(unittest.TestCase): + def test_allocates_per_layer_cache_shapes(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + self.assertEqual(len(mgr.past_key_values), config.num_hidden_layers * 2) + self.assertEqual(list(mgr.past_key_values[0].shape), [2, 32, 128, 128]) + self.assertEqual(list(mgr.past_key_values[1].shape), [2, 8192, 3]) + self.assertEqual(mgr.layer_types[3], "full_attention") + self.assertEqual(mgr.past_key_values[6].dim(), 4) + self.assertEqual(mgr.past_key_values[7].shape[2], 16) + + def test_get_cache_slices_only_full_attention_layers(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + cache = mgr.get_cache(seq_len=4, seq_ids=torch.tensor([1])) + recurrent_state, conv_state = cache[0] + full_k, full_v = cache[3] + + self.assertEqual(list(recurrent_state.shape), [1, 32, 128, 128]) + self.assertEqual(list(conv_state.shape), [1, 8192, 3]) + self.assertEqual(full_k.shape[0], 2) + self.assertEqual(full_v.shape[0], 2) + self.assertEqual(full_k.shape[2], 4) + self.assertEqual(full_v.shape[2], 4) + + def test_get_seq_length_uses_first_full_attention_layer(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + nested_cache = mgr.get_cache(seq_len=5, seq_ids=torch.tensor([0])) + flat_cache = [tensor for layer_cache in nested_cache for tensor in layer_cache] + + self.assertEqual(nested_cache[0][1].shape[2], 3) + self.assertEqual(mgr.get_seq_length(nested_cache), 5) + self.assertEqual(mgr.get_seq_length(flat_cache), 5) + + def test_get_cache_selects_deltanet_state_rows_by_seq_ids(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + with torch.no_grad(): + mgr.past_key_values[0][0, ...].fill_(7) + mgr.past_key_values[0][1, ...].fill_(13) + mgr.past_key_values[1][0, ...].fill_(17) + mgr.past_key_values[1][1, ...].fill_(19) + + recurrent_state, conv_state = mgr.get_cache( + seq_len=4, + seq_ids=torch.tensor([1, 0]), + )[0] + + self.assertTrue(torch.all(recurrent_state[0] == 13)) + self.assertTrue(torch.all(recurrent_state[1] == 7)) + self.assertTrue(torch.all(conv_state[0] == 19)) + self.assertTrue(torch.all(conv_state[1] == 17)) + + def test_deltanet_update_scatters_by_seq_id(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones((1, 32, 128, 128), dtype=torch.bfloat16) + conv = torch.ones((1, 8192, 3), dtype=torch.bfloat16) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=torch.tensor([1]), + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 0)) + self.assertTrue(torch.all(updated_conv[0] == 0)) + self.assertTrue(torch.all(updated_recurrent[1] == 1)) + self.assertTrue(torch.all(updated_conv[1] == 1)) + + def test_deltanet_full_batch_update_replaces_state_cache(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones((2, 32, 128, 128), dtype=torch.bfloat16) + conv = torch.ones((2, 8192, 3), dtype=torch.bfloat16) + recurrent[0].fill_(3) + recurrent[1].fill_(5) + conv[0].fill_(11) + conv[1].fill_(13) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=None, + 
state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 3)) + self.assertTrue(torch.all(updated_recurrent[1] == 5)) + self.assertTrue(torch.all(updated_conv[0] == 11)) + self.assertTrue(torch.all(updated_conv[1] == 13)) + + def test_deltanet_full_batch_update_scatters_non_identity_seq_ids(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones((2, 32, 128, 128), dtype=torch.bfloat16) + conv = torch.ones((2, 8192, 3), dtype=torch.bfloat16) + recurrent[0].fill_(3) + recurrent[1].fill_(5) + conv[0].fill_(11) + conv[1].fill_(13) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=torch.tensor([1, 0]), + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 5)) + self.assertTrue(torch.all(updated_recurrent[1] == 3)) + self.assertTrue(torch.all(updated_conv[0] == 13)) + self.assertTrue(torch.all(updated_conv[1] == 11)) + + def test_deltanet_update_maps_out_of_range_seq_id_to_padding_row(self): + config = _make_config(neuron_overrides={"kv_cache_padding_size": 1}) + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones((1, 32, 128, 128), dtype=torch.bfloat16) + conv = torch.ones((1, 8192, 3), dtype=torch.bfloat16) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=torch.tensor([99]), + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 0)) + self.assertTrue(torch.all(updated_recurrent[1] == 0)) + self.assertTrue(torch.all(updated_recurrent[2] == 1)) + self.assertTrue(torch.all(updated_conv[2] == 1)) + + def test_deltanet_state_shapes_do_not_scale_with_sequence_length(self): + short_config = _make_config(neuron_overrides={"seq_len": 128}) + long_config = _make_config(neuron_overrides={"seq_len": 2048}) + short_mgr = HybridDeltaNetCacheManager( + short_config, num_kv_head=short_config.num_key_value_heads + ) + long_mgr = HybridDeltaNetCacheManager( + long_config, num_kv_head=long_config.num_key_value_heads + ) + + self.assertEqual(short_mgr.past_key_values[0].shape, long_mgr.past_key_values[0].shape) + self.assertEqual(short_mgr.past_key_values[1].shape, long_mgr.past_key_values[1].shape) + self.assertLess(short_mgr.past_key_values[7].shape[2], long_mgr.past_key_values[7].shape[2]) + + def test_get_cache_trims_padding_row_without_seq_ids(self): + config = _make_config(neuron_overrides={"kv_cache_padding_size": 1}) + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + recurrent_state, conv_state = mgr.get_cache(seq_len=4)[0] + + self.assertEqual(list(recurrent_state.shape), [2, 32, 128, 128]) + self.assertEqual(list(conv_state.shape), [2, 8192, 3]) + + def test_update_cache_dispatches_deltanet_and_full_attention_layers(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + new_key_values = [] + for idx in range(4): + first = mgr.past_key_values[2 * idx] + second = mgr.past_key_values[2 * idx + 1] + new_key_values.append( + ( + torch.full_like(first, fill_value=idx + 1), + torch.full_like(second, fill_value=idx + 11), + ) + ) + + position_ids = torch.arange(16, dtype=torch.long).unsqueeze(0).expand(2, -1) + full_k_update = torch.full_like(mgr.past_key_values[6], fill_value=4) + full_v_update = torch.full_like(mgr.past_key_values[7], fill_value=14) + with patch.object( 
+ mgr, "update_kv_by_layer_id", return_value=(full_k_update, full_v_update) + ) as update_kv: + updated = mgr.update_cache( + is_for_context_encoding=True, + seq_ids=torch.tensor([0, 1], dtype=torch.int32), + position_ids=position_ids, + new_key_values=new_key_values, + seq_len=16, + ) + + self.assertEqual(update_kv.call_count, 1) + self.assertEqual(update_kv.call_args.kwargs["idx"], 3) + self.assertTrue(torch.all(updated[0] == 1)) + self.assertTrue(torch.all(updated[1] == 11)) + self.assertTrue(torch.all(updated[6] == 4)) + self.assertTrue(torch.all(updated[7] == 14)) + + def test_managed_cache_removes_dummy_kv_for_deltanet_layers(self): + config = _make_config(neuron_overrides={"seq_len": 1024}) + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + max_batch_size = ( + config.neuron_config.kv_cache_batch_size + + config.neuron_config.kv_cache_padding_size + ) + full_kv_per_layer = _numel(mgr.k_shape) + _numel(mgr.v_shape) + deltanet_layers = config.layer_types.count("linear_attention") + legacy_total_numel = ( + full_kv_per_layer * config.num_hidden_layers + + _deltanet_state_numel(config, max_batch_size) * deltanet_layers + ) + expected_savings = full_kv_per_layer * deltanet_layers + + self.assertEqual( + legacy_total_numel - _managed_cache_numel(mgr), + expected_savings, + ) + self.assertLess(_managed_cache_numel(mgr), legacy_total_numel) + + def test_rejects_unsupported_hybrid_modes(self): + unsupported_cases = [ + ({"padding_side": "left"}, "left padding"), + ({"flash_decoding_enabled": True}, "flash decoding"), + ] + + for neuron_overrides, expected_error in unsupported_cases: + with self.subTest(expected_error=expected_error): + config = _make_config(neuron_overrides=neuron_overrides) + with self.assertRaisesRegex(ValueError, expected_error): + HybridDeltaNetCacheManager( + config, num_kv_head=config.num_key_value_heads + ) + + config = _make_config() + config.neuron_config.kv_cache_quant = True + with self.assertRaisesRegex(ValueError, "KV cache quantization"): + HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + config = _make_config( + neuron_overrides={ + "attention_dp_degree": 2, + "batch_size": 2, + "ctx_batch_size": 2, + "tkg_batch_size": 2, + "max_batch_size": 2, + "kv_cache_batch_size": 2, + "is_continuous_batching": True, + } + ) + with self.assertRaisesRegex(ValueError, "attention data parallelism"): + HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + config = _make_config() + config.neuron_config.kv_cache_tiling = True + with self.assertRaisesRegex(ValueError, "KV cache tiling"): + HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + def test_legacy_config_default_is_disabled(self): + config = _make_config(use_hybrid_cache_manager=False) + self.assertFalse(config.use_hybrid_cache_manager) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.5-9B/src/modeling_qwen35.py b/contrib/models/Qwen3.5-9B/src/modeling_qwen35.py index 657df76f..b1622d86 100644 --- a/contrib/models/Qwen3.5-9B/src/modeling_qwen35.py +++ b/contrib/models/Qwen3.5-9B/src/modeling_qwen35.py @@ -84,6 +84,7 @@ NeuronAttentionBase, ) from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding +from neuronx_distributed_inference.modules.kvcache.kv_cache_manager import KVCacheManager from neuronx_distributed_inference.models.layer_boundary_marker import ( ModuleMarkerEndWrapper, ModuleMarkerStartWrapper, @@ -221,7 +222,7 @@ def l2norm(x, 
dim=-1, eps=1e-6): return F.normalize(x, p=2, dim=dim, eps=eps) -FUSED_DELTANET_DECAY_MIN = -60.0 +FUSED_DELTANET_DECAY_MIN = -20.0 FUSED_DELTANET_DECAY_MAX = 0.0 @@ -286,6 +287,7 @@ def __init__(self, config, layer_idx: int): self.conv_kernel_size = tc.linear_conv_kernel_dim # 4 self.layer_idx = layer_idx self.rms_norm_eps = tc.rms_norm_eps + self.use_hybrid_cache_manager = getattr(tc, "use_hybrid_cache_manager", False) # KV cache dummy shape info self.head_dim = tc.head_dim # 256 @@ -747,6 +749,11 @@ def forward( # zeros the decay gate so the recurrent state is preserved unchanged # through padding positions (no spurious decay). valid_mask_1d = kwargs.get("deltanet_padding_mask", None) # [B, S, 1] or None + hybrid_cache_active = self.use_hybrid_cache_manager + recurrent_state_cache = None + conv_state_cache = None + if hybrid_cache_active and past_key_value is not None: + recurrent_state_cache, conv_state_cache = past_key_value # Project inputs deltanet_fp32 = os.environ.get("DELTANET_FP32") == "1" @@ -774,7 +781,9 @@ def forward( mixed = mixed.transpose(1, 2) if is_decode: - if seq_ids is not None: + if conv_state_cache is not None: + conv_state = conv_state_cache[:batch_size] + elif seq_ids is not None: conv_state = torch.index_select(self.conv_state_buffer, 0, seq_ids) else: conv_state = self.conv_state_buffer[:batch_size] @@ -791,7 +800,9 @@ def forward( new_conv_state = torch.cat([conv_state[:, :, 1:], mixed], dim=-1) alloc_bs = self.conv_state_buffer.shape[0] - if seq_ids is not None: + if hybrid_cache_active: + new_conv_state = new_conv_state.to(self.conv_state_buffer.dtype) + elif seq_ids is not None: # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement # Add buffer dependency for input_output_alias new_conv_state = ( @@ -827,7 +838,9 @@ def forward( new_conv_state = mixed[:, :, -3:].contiguous() alloc_bs = self.conv_state_buffer.shape[0] - if seq_ids is not None: + if hybrid_cache_active: + new_conv_state = new_conv_state.to(self.conv_state_buffer.dtype) + elif seq_ids is not None: # BS=1 optimization: scatter to index 0 = direct replacement new_conv_state = ( new_conv_state.to(self.conv_state_buffer.dtype) @@ -906,7 +919,9 @@ def forward( if is_decode: # TKG: single-step recurrent update - if seq_ids is not None: + if recurrent_state_cache is not None: + recurrent_state = recurrent_state_cache[:batch_size].float() + elif seq_ids is not None: recurrent_state = torch.index_select( self.recurrent_state_buffer, 0, seq_ids ).float() @@ -918,7 +933,9 @@ def forward( ) new_state_bf16 = new_state.to(self.recurrent_state_buffer.dtype) alloc_bs = self.recurrent_state_buffer.shape[0] - if seq_ids is not None: + if hybrid_cache_active: + new_rec_state = new_state_bf16 + elif seq_ids is not None: # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement # Add buffer dependency for input_output_alias new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0 @@ -970,7 +987,9 @@ def forward( if final_state is not None: final_state_bf16 = final_state.to(self.recurrent_state_buffer.dtype) alloc_bs = self.recurrent_state_buffer.shape[0] - if seq_ids is not None: + if hybrid_cache_active: + new_rec_state = final_state_bf16 + elif seq_ids is not None: # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement # Add buffer dependency for input_output_alias new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0 @@ -1005,6 +1024,9 @@ def forward( output = output.reshape(batch_size, seq_len, self.value_dim) output = 
self.out_proj(output) + if hybrid_cache_active: + return output, (new_rec_state, new_conv_state), new_rec_state, new_conv_state + # Return dummy KV for KVCacheManager dummy_k = torch.zeros( batch_size, @@ -1059,6 +1081,7 @@ def __init__(self, *args, **kwargs): kwargs.setdefault("linear_key_head_dim", 128) kwargs.setdefault("linear_value_head_dim", 128) kwargs.setdefault("linear_conv_kernel_dim", 4) + kwargs.setdefault("use_hybrid_cache_manager", False) super().__init__(*args, **kwargs) @@ -1512,7 +1535,11 @@ def forward( ) hidden_states = residual + attn_out present_key_value = dummy_kv - deltanet_states = (new_rec_state, new_conv_state) + deltanet_states = ( + None + if getattr(self.config, "use_hybrid_cache_manager", False) + else (new_rec_state, new_conv_state) + ) else: deltanet_states = None # Standard attention path @@ -1545,6 +1572,240 @@ def forward( return outputs +# ============================================================ +# Hybrid Cache Manager (opt-in) +# ============================================================ + + +class HybridDeltaNetCacheManager(KVCacheManager): + """Layer-type-aware cache manager for Qwen3.5/Qwen3.6 hybrid dense models.""" + + def __init__(self, config: Qwen35InferenceConfig, num_kv_head, **kwargs): + self.layer_types = list(config.layer_types) + self._validate_hybrid_config(config) + super().__init__(config, num_kv_head=num_kv_head, **kwargs) + + dtype = ( + config.neuron_config.attention_dtype + if config.neuron_config.attention_dtype is not None + else config.neuron_config.torch_dtype + ) + cache_dtype = getattr(self, "cache_dtype", dtype) + max_batch_size = ( + config.neuron_config.kv_cache_batch_size + + config.neuron_config.kv_cache_padding_size + ) + recurrent_shape = [ + max_batch_size, + config.linear_num_value_heads, + config.linear_key_head_dim, + config.linear_value_head_dim, + ] + conv_dim = ( + 2 * config.linear_num_key_heads * config.linear_key_head_dim + + config.linear_num_value_heads * config.linear_value_head_dim + ) + conv_shape = [ + max_batch_size, + conv_dim, + config.linear_conv_kernel_dim - 1, + ] + + params = [] + for layer_idx, layer_type in enumerate(self.layer_types): + if layer_type == "linear_attention": + params.append( + nn.Parameter(torch.zeros(recurrent_shape, dtype=dtype), requires_grad=False) + ) + params.append( + nn.Parameter(torch.zeros(conv_shape, dtype=dtype), requires_grad=False) + ) + else: + k_shape = self.k_shapes[layer_idx] if hasattr(self, "k_shapes") else self.k_shape + v_shape = self.v_shapes[layer_idx] if hasattr(self, "v_shapes") else self.v_shape + params.append( + nn.Parameter(torch.zeros(k_shape, dtype=cache_dtype), requires_grad=False) + ) + params.append( + nn.Parameter(torch.zeros(v_shape, dtype=cache_dtype), requires_grad=False) + ) + + self.past_key_values = nn.ParameterList(params) + + @staticmethod + def _validate_hybrid_config(config: Qwen35InferenceConfig): + nc = config.neuron_config + unsupported = [] + if nc.is_block_kv_layout: + unsupported.append("block KV layout") + if getattr(nc, "kv_quant_config", None) is not None or getattr(nc, "kv_cache_quant", False): + unsupported.append("KV cache quantization") + if nc.enable_fused_speculation or nc.speculation_length > 0 or nc.is_medusa: + unsupported.append("speculative decoding") + if getattr(nc, "enable_eagle_speculation", False) or getattr(nc, "is_eagle_draft", False): + unsupported.append("EAGLE speculation") + if nc.flash_decoding_enabled: + unsupported.append("flash decoding") + if nc.attention_dp_degree > 1: + 
unsupported.append("attention data parallelism") + if nc.kv_cache_tiling: + unsupported.append("KV cache tiling") + if nc.padding_side != "right": + unsupported.append("left padding") + if nc.is_continuous_batching: + unsupported.append("continuous batching") + if unsupported: + raise ValueError( + "HybridDeltaNetCacheManager v1 does not support: " + + ", ".join(unsupported) + ) + + def _is_deltanet_layer(self, idx: int) -> bool: + return self.layer_types[idx] == "linear_attention" + + def get_seq_length(self, past_key_values=None): + for idx, layer_type in enumerate(self.layer_types): + if layer_type != "linear_attention": + if past_key_values is None: + _, v_cache = self._fetch_cache(idx) + elif len(past_key_values) == len(self.past_key_values): + v_cache = past_key_values[2 * idx + 1] + else: + v_cache = past_key_values[idx][1] + return v_cache.shape[2] + return 0 + + def get_deltanet_state_by_layer_id(self, idx, kvcache_buffer=None, seq_ids=None): + recurrent_state, conv_state = self._fetch_cache(idx, kvcache_buffer) + if seq_ids is not None: + cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids) + recurrent_state = torch.index_select(recurrent_state, dim=0, index=cache_idx) + conv_state = torch.index_select(conv_state, dim=0, index=cache_idx) + elif self.kv_cache_padding_size > 0: + recurrent_state = recurrent_state[: -self.kv_cache_padding_size] + conv_state = conv_state[: -self.kv_cache_padding_size] + return recurrent_state, conv_state + + def get_cache( + self, + seq_len: int, + skip_slice=False, + kvcache_buffer=None, + seq_ids=None, + windowed_context_encoding_window_idx=-1, + **kwargs, + ): + past_key_values = [] + for idx in range(len(self.past_key_values) // 2): + if self._is_deltanet_layer(idx): + past_key_values.append( + list(self.get_deltanet_state_by_layer_id(idx, kvcache_buffer, seq_ids)) + ) + else: + past_key_values.append( + list( + self.get_kv_by_layer_id( + idx=idx, + skip_slice=skip_slice, + seq_len=seq_len, + kvcache_buffer=kvcache_buffer, + seq_ids=seq_ids, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + ) + ) + return past_key_values + + def update_cache( + self, + is_for_context_encoding: bool, + seq_ids: torch.Tensor, + position_ids: torch.Tensor, + new_key_values: List[torch.Tensor], + seq_len: int, + scatter_index=None, + kv_active_mask=None, + kvcache_buffer=None, + windowed_context_encoding_window_idx: int = -1, + **kwargs, + ): + updated_cache = [] + for idx, kv_per_layer in enumerate(new_key_values): + if self._is_deltanet_layer(idx): + recurrent_state, conv_state = self.update_deltanet_state_by_layer_id( + idx=idx, + seq_ids=seq_ids, + state_per_layer=kv_per_layer, + kvcache_buffer=kvcache_buffer, + ) + else: + recurrent_state, conv_state = self.update_kv_by_layer_id( + idx=idx, + is_for_context_encoding=is_for_context_encoding, + seq_ids=seq_ids, + position_ids=position_ids, + kv_per_layer=kv_per_layer, + seq_len=seq_len, + scatter_index=scatter_index, + kv_active_mask=kv_active_mask, + kvcache_buffer=kvcache_buffer, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + updated_cache.append(recurrent_state) + updated_cache.append(conv_state) + return updated_cache + + def update_deltanet_state_by_layer_id( + self, + idx: int, + seq_ids: torch.Tensor, + state_per_layer: Tuple[torch.Tensor, torch.Tensor], + kvcache_buffer=None, + ): + latest_recurrent, latest_conv = state_per_layer + recurrent_cache, conv_cache = self._fetch_cache(idx, kvcache_buffer) + 
latest_recurrent = latest_recurrent.to(recurrent_cache.dtype) + latest_conv = latest_conv.to(conv_cache.dtype) + + if seq_ids is not None: + cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids) + recurrent_index = cache_idx.view(-1, 1, 1, 1).expand_as(latest_recurrent) + conv_index = cache_idx.view(-1, 1, 1).expand_as(latest_conv) + recurrent_cache = torch.scatter( + input=recurrent_cache, + dim=0, + index=recurrent_index, + src=latest_recurrent, + ) + conv_cache = torch.scatter( + input=conv_cache, + dim=0, + index=conv_index, + src=latest_conv, + ) + return recurrent_cache, conv_cache + + if latest_recurrent.shape[0] == recurrent_cache.shape[0]: + return ( + latest_recurrent + recurrent_cache * 0, + latest_conv + conv_cache * 0, + ) + + pad_size = recurrent_cache.shape[0] - latest_recurrent.shape[0] + if pad_size > 0: + latest_recurrent = torch.cat( + [latest_recurrent, recurrent_cache[latest_recurrent.shape[0] :] * 0], + dim=0, + ) + latest_conv = torch.cat( + [latest_conv, conv_cache[latest_conv.shape[0] :] * 0], + dim=0, + ) + return latest_recurrent + recurrent_cache * 0, latest_conv + conv_cache * 0 + + # ============================================================ # Model # ============================================================ @@ -1590,6 +1851,19 @@ def init_model(self, config: Qwen35InferenceConfig): # mRoPE embedding for VL self.mrope_emb = Qwen35MRoPEEmbedding(config) + def init_inference_optimization(self, config: Qwen35InferenceConfig): + super().init_inference_optimization(config) + if getattr(config, "use_hybrid_cache_manager", False): + self.kv_mgr = HybridDeltaNetCacheManager( + config, + num_kv_head=self.num_key_value_heads, + global_rank=self.rank_util, + attention_chunk_size=self.attention_chunk_size, + sliding_window=self.sliding_window, + windowed_context_encoding_size=self.windowed_context_encoding_size, + layer_to_cache_size_mapping=self.layer_to_cache_size_mapping, + ) + @property def _deltanet_state_params(self): """Return DeltaNet state nn.Parameters in alias order.""" @@ -1639,7 +1913,10 @@ def get_model_output( past_key_values_length = 0 if past_key_values is not None: - past_key_values_length = past_key_values[0][1].shape[2] + if hasattr(self.kv_mgr, "get_seq_length"): + past_key_values_length = self.kv_mgr.get_seq_length(past_key_values) + else: + past_key_values_length = past_key_values[0][1].shape[2] if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -1909,7 +2186,10 @@ def forward( outputs += updated_kv_cache # Append DeltaNet state tensors (for input_output_aliases) - if hasattr(self, "_deltanet_updated_states"): + if ( + not getattr(self.config, "use_hybrid_cache_manager", False) + and hasattr(self, "_deltanet_updated_states") + ): outputs += self._deltanet_updated_states return outputs @@ -2061,7 +2341,10 @@ def get(self, bucket_rank, **kwargs): state_start_idx = num_output_from_trace + num_kv - if hasattr(module, "_deltanet_state_params"): + if ( + not getattr(module.config, "use_hybrid_cache_manager", False) + and hasattr(module, "_deltanet_state_params") + ): for i, param in enumerate(module._deltanet_state_params): input_output_aliases[param] = state_start_idx + i @@ -2296,6 +2579,8 @@ def enable_token_generation(self): def _copy_past_key_values(self, outputs): """Override to also copy DeltaNet state buffers on CPU.""" super()._copy_past_key_values(outputs) + if getattr(self.config, "use_hybrid_cache_manager", False): + return num_output_from_trace = 1 if ( diff --git 
a/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_fused.py b/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_fused.py index 3447a138..4d02423d 100644 --- a/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_fused.py +++ b/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_fused.py @@ -312,14 +312,21 @@ def deltanet_fused_chunked_fwd( # ============================================================ # Decay mask: QK_decay[i,j] = QK[i,j] * exp(gc[i]) * exp(-gc[j]) # + # Apply the strict causal mask before the split exp(gc) / exp(-gc) + # scaling. Upper-triangular entries are mathematically unused, but + # scaling them first can create very large finite values that poison + # later matmuls before the mask is applied. + # ============================================================ + QK_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=QK_masked, data1=QK, data2=Lmask, op=nl.multiply) + # Row scaling: QK_row[i,:] = QK[i,:] * exp(gc[i]) # Then transpose, column scale, transpose back. # Uses tensor_scalar with (P_MAX,1) operand for row scaling. - # ============================================================ QK_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_scalar( dst=QK_row, - data=QK, + data=QK_masked, op0=nl.multiply, operand0=exp_gc_p, engine=nisa.vector_engine, @@ -435,11 +442,16 @@ def deltanet_fused_chunked_fwd( qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_copy(dst=qk_raw, src=qk_psum) + # Mask before split scaling for the same reason as the A matrix above: + # upper-triangular decay factors are unused and can be numerically huge. + qk_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=qk_masked, data1=qk_raw, data2=Lmask_d, op=nl.multiply) + # Row-scale by exp(gc) qk_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_scalar( dst=qk_row, - data=qk_raw, + data=qk_masked, op0=nl.multiply, operand0=exp_gc_p, engine=nisa.vector_engine, @@ -538,13 +550,40 @@ def deltanet_fused_chunked_fwd( # state is updated IN-PLACE in SBUF — no HBM round-trip! # ============================================================ - # k_raw_decay = k * exp(-gc) + # k_raw_decay contributes as exp(g_last) * (k * exp(-gc))^T @ v_new. + # Compute the equivalent form with one bounded exponential, + # k * exp(g_last - gc), so the factor is always <= 1 for valid + # causal positions. 
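+    # Why this is safe: g is a per-token log-decay (g <= 0), so its
+    # running sum gc is non-increasing and g_last = gc[-1] <= gc[i] for
+    # every in-chunk position i. Hence exp(g_last - gc[i]) <= 1, whereas
+    # the old exp(-gc[i]) factor grows without bound within long chunks.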
+ gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gl_11[0:1, 0:1], + dst=gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gl_minus_gc_p, + data1=gl_p, + data2=gc_p, + op=nl.subtract, + ) + exp_gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_minus_gc_p, + op=nl.exp, + data=gl_minus_gc_p, + bias=None, + scale=1.0, + ) + + # k_raw_decay = k * exp(g_last - gc) k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_scalar( dst=k_raw_decay, data=k_c, op0=nl.multiply, - operand0=exp_neg_gc_p, + operand0=exp_gl_minus_gc_p, engine=nisa.vector_engine, ) @@ -557,19 +596,17 @@ def deltanet_fused_chunked_fwd( kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_copy(dst=kv_outer, src=kv_psum) - # state = state + kv_outer - state_plus = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_tensor(dst=state_plus, data1=state, data2=kv_outer, op=nl.add) - - # state = state_plus * exp(g_last) - # tensor_scalar broadcasts exp_gl_p (P_MAX, 1) across free dim + # state = state * exp(g_last) + kv_outer + # tensor_scalar broadcasts exp_gl_p (P_MAX, 1) across free dim. + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_scalar( - dst=state, - data=state_plus, + dst=state_decayed, + data=state, op0=nl.multiply, operand0=exp_gl_p, engine=nisa.vector_engine, ) + nisa.tensor_tensor(dst=state, data1=state_decayed, data2=kv_outer, op=nl.add) # ---- Write final state to HBM ---- nisa.dma_copy(dst=final_state_out, src=state) diff --git a/contrib/models/Qwen3.5-9B/test/integration/test_model.py b/contrib/models/Qwen3.5-9B/test/integration/test_model.py index b1cf876c..725bca2e 100644 --- a/contrib/models/Qwen3.5-9B/test/integration/test_model.py +++ b/contrib/models/Qwen3.5-9B/test/integration/test_model.py @@ -36,7 +36,10 @@ """ import gc +import json import os +import shutil +import subprocess import sys import time @@ -56,6 +59,8 @@ SEQ_LEN = int(os.environ.get("QWEN35_SEQ_LEN", "128")) TTFT_THRESHOLD_MS = float(os.environ.get("TTFT_THRESHOLD_MS", "5000")) THROUGHPUT_THRESHOLD = float(os.environ.get("THROUGHPUT_THRESHOLD", "5.0")) +USE_HYBRID_CACHE = os.environ.get("QWEN35_USE_HYBRID_CACHE", "0") == "1" +RECORD_HBM = os.environ.get("QWEN35_RECORD_HBM", "0") == "1" requires_model_path = pytest.mark.skipif( not MODEL_PATH, @@ -64,6 +69,13 @@ "weights. Set QWEN35_MODEL_PATH=/path/to/Qwen3.5-9B to run these tests." ), ) +requires_hbm_recording = pytest.mark.skipif( + not RECORD_HBM, + reason=( + "QWEN35_RECORD_HBM=1 not set. This optional test records Neuron HBM " + "usage for dummy-KV vs hybrid-cache comparisons." 
+ ), +) # ── Fixtures ──────────────────────────────────────────────────────────── @@ -119,6 +131,7 @@ def compiled_model(model_path): inf_config = Qwen35InferenceConfig( neuron_config=neuron_config, + use_hybrid_cache_manager=USE_HYBRID_CACHE, **config_dict, ) @@ -195,6 +208,73 @@ def _is_repetitive(text, max_repeat=5): return False +def _parse_peak_neuron_memory(stdout): + peak_device = 0 + peak_tensors = 0 + samples = 0 + for line in stdout.splitlines(): + line = line.strip() + if not line: + continue + try: + report = json.loads(line) + except json.JSONDecodeError: + continue + for runtime in report.get("neuron_runtime_data", []): + memory_used = runtime.get("report", {}).get("memory_used", {}) + used = memory_used.get("neuron_runtime_used_bytes", {}) + peak_device = max(peak_device, int(used.get("neuron_device", 0) or 0)) + nc_usage = ( + used.get("usage_breakdown", {}).get("neuroncore_memory_usage", {}) + ) + tensor_bytes = sum( + int(core.get("tensors", 0) or 0) for core in nc_usage.values() + ) + peak_tensors = max(peak_tensors, tensor_bytes) + samples += 1 + return peak_device, peak_tensors, samples + + +def _capture_neuron_hbm(tmp_path, fn): + if shutil.which("neuron-monitor") is None: + pytest.skip("neuron-monitor is not available") + + monitor_config = { + "period": "0.5s", + "neuron_runtimes": [ + { + "tag_filter": ".*", + "metrics": [{"type": "memory_used", "period": "0.5s"}], + } + ], + } + config_path = tmp_path / "neuron-monitor.json" + config_path.write_text(json.dumps(monitor_config)) + + proc = subprocess.Popen( + ["neuron-monitor", "--config-file", str(config_path)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + try: + time.sleep(1.0) + result = fn() + time.sleep(1.0) + finally: + proc.terminate() + try: + stdout, stderr = proc.communicate(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + stdout, stderr = proc.communicate(timeout=5) + + peak_device, peak_tensors, samples = _parse_peak_neuron_memory(stdout) + assert samples > 0, f"neuron-monitor produced no runtime samples: {stderr}" + assert peak_device > 0, "Expected non-zero Neuron device HBM usage" + return result, peak_device, peak_tensors, samples + + # ── Smoke Tests ───────────────────────────────────────────────────────── @@ -363,6 +443,33 @@ def test_performance_throughput(compiled_model, tokenizer, generation_config): ) +@requires_model_path +@requires_hbm_recording +def test_hybrid_cache_hbm_snapshot(compiled_model, tokenizer, generation_config, tmp_path): + """Record peak Neuron HBM for dummy-KV vs hybrid-cache comparison runs.""" + prompt = "Give me a summary of the 2020 Olympics in 100 tokens." 
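+    # QWEN35_HBM_NEW_TOKENS bounds the decode length; the default of 32
+    # keeps the run short while neuron-monitor (sampling every 0.5s per
+    # the config above) still observes both context encoding and several
+    # decode steps in its memory_used reports.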
+ max_new_tokens = int(os.environ.get("QWEN35_HBM_NEW_TOKENS", "32")) + + (_, text), peak_device, peak_tensors, samples = _capture_neuron_hbm( + tmp_path, + lambda: _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=max_new_tokens, + ), + ) + + mode = "hybrid" if USE_HYBRID_CACHE else "dummy_kv" + print( + " HBM " + f"mode={mode} peak_device_bytes={peak_device} " + f"peak_tensor_bytes={peak_tensors} samples={samples}" + ) + assert len(text) > len(prompt) + + # ── Multi-Prompt Quality Test ────────────────────────────────────────── diff --git a/contrib/models/Qwen3.5-9B/test/unit/test_hybrid_cache_manager.py b/contrib/models/Qwen3.5-9B/test/unit/test_hybrid_cache_manager.py new file mode 100644 index 00000000..503dc24b --- /dev/null +++ b/contrib/models/Qwen3.5-9B/test/unit/test_hybrid_cache_manager.py @@ -0,0 +1,341 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys +import unittest +from math import prod +from unittest.mock import patch + +import torch + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from neuronx_distributed_inference.models.config import NeuronConfig +from src.modeling_qwen35 import HybridDeltaNetCacheManager, Qwen35InferenceConfig + + +def _make_config(**overrides): + neuron_overrides = overrides.pop("neuron_overrides", {}) + neuron_kwargs = dict( + tp_degree=overrides.pop("tp_degree", 4), + batch_size=1, + max_batch_size=2, + kv_cache_batch_size=2, + seq_len=16, + torch_dtype=torch.bfloat16, + ) + neuron_kwargs.update(neuron_overrides) + neuron_config = NeuronConfig(**neuron_kwargs) + defaults = dict( + hidden_size=4096, + num_hidden_layers=32, + num_attention_heads=16, + num_key_value_heads=4, + head_dim=256, + intermediate_size=12288, + vocab_size=248320, + rms_norm_eps=1e-6, + max_position_embeddings=262144, + rope_theta=10000000, + hidden_act="silu", + tie_word_embeddings=False, + linear_num_value_heads=32, + linear_num_key_heads=16, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_conv_kernel_dim=4, + use_hybrid_cache_manager=True, + ) + defaults.update(overrides) + return Qwen35InferenceConfig(neuron_config=neuron_config, **defaults) + + +def _numel(shape): + return prod(int(dim) for dim in shape) + + +def _managed_cache_numel(mgr): + return sum(param.numel() for param in mgr.past_key_values) + + +def _deltanet_state_numel(config, max_batch_size): + recurrent = ( + max_batch_size + * config.linear_num_value_heads + * config.linear_key_head_dim + * config.linear_value_head_dim + ) + conv_dim = ( + 2 * config.linear_num_key_heads * config.linear_key_head_dim + + config.linear_num_value_heads * config.linear_value_head_dim + ) + conv = max_batch_size * conv_dim * (config.linear_conv_kernel_dim - 1) + return recurrent + conv + + +class TestHybridDeltaNetCacheManager(unittest.TestCase): + def test_allocates_per_layer_cache_shapes(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + self.assertEqual(len(mgr.past_key_values), config.num_hidden_layers * 2) + self.assertEqual( + list(mgr.past_key_values[0].shape), + [2, 32, 128, 128], + ) + self.assertEqual( + list(mgr.past_key_values[1].shape), + [2, 8192, 3], + ) + self.assertEqual(mgr.layer_types[3], "full_attention") + self.assertEqual(mgr.past_key_values[6].dim(), 4) + 
self.assertEqual(mgr.past_key_values[7].shape[2], 16) + + def test_get_cache_slices_only_full_attention_layers(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + cache = mgr.get_cache(seq_len=4, seq_ids=torch.tensor([1])) + recurrent_state, conv_state = cache[0] + full_k, full_v = cache[3] + + self.assertEqual(list(recurrent_state.shape), [1, 32, 128, 128]) + self.assertEqual(list(conv_state.shape), [1, 8192, 3]) + self.assertEqual(full_k.shape[0], 2) + self.assertEqual(full_v.shape[0], 2) + self.assertEqual(full_k.shape[2], 4) + self.assertEqual(full_v.shape[2], 4) + + def test_get_seq_length_uses_first_full_attention_layer(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + nested_cache = mgr.get_cache(seq_len=5, seq_ids=torch.tensor([0])) + flat_cache = [tensor for layer_cache in nested_cache for tensor in layer_cache] + + self.assertEqual(nested_cache[0][1].shape[2], 3) + self.assertEqual(mgr.get_seq_length(nested_cache), 5) + self.assertEqual(mgr.get_seq_length(flat_cache), 5) + + def test_get_cache_selects_deltanet_state_rows_by_seq_ids(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + with torch.no_grad(): + mgr.past_key_values[0][0, ...].fill_(7) + mgr.past_key_values[0][1, ...].fill_(13) + mgr.past_key_values[1][0, ...].fill_(17) + mgr.past_key_values[1][1, ...].fill_(19) + + recurrent_state, conv_state = mgr.get_cache( + seq_len=4, + seq_ids=torch.tensor([1, 0]), + )[0] + + self.assertTrue(torch.all(recurrent_state[0] == 13)) + self.assertTrue(torch.all(recurrent_state[1] == 7)) + self.assertTrue(torch.all(conv_state[0] == 19)) + self.assertTrue(torch.all(conv_state[1] == 17)) + + def test_deltanet_update_scatters_by_seq_id(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones((1, 32, 128, 128), dtype=torch.bfloat16) + conv = torch.ones((1, 8192, 3), dtype=torch.bfloat16) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=torch.tensor([1]), + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 0)) + self.assertTrue(torch.all(updated_conv[0] == 0)) + self.assertTrue(torch.all(updated_recurrent[1] == 1)) + self.assertTrue(torch.all(updated_conv[1] == 1)) + + def test_deltanet_full_batch_update_replaces_state_cache(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones((2, 32, 128, 128), dtype=torch.bfloat16) + conv = torch.ones((2, 8192, 3), dtype=torch.bfloat16) + recurrent[0].fill_(3) + recurrent[1].fill_(5) + conv[0].fill_(11) + conv[1].fill_(13) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=None, + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 3)) + self.assertTrue(torch.all(updated_recurrent[1] == 5)) + self.assertTrue(torch.all(updated_conv[0] == 11)) + self.assertTrue(torch.all(updated_conv[1] == 13)) + + def test_deltanet_full_batch_update_scatters_non_identity_seq_ids(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones((2, 32, 128, 128), dtype=torch.bfloat16) + conv = torch.ones((2, 8192, 3), dtype=torch.bfloat16) + 
recurrent[0].fill_(3) + recurrent[1].fill_(5) + conv[0].fill_(11) + conv[1].fill_(13) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=torch.tensor([1, 0]), + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 5)) + self.assertTrue(torch.all(updated_recurrent[1] == 3)) + self.assertTrue(torch.all(updated_conv[0] == 13)) + self.assertTrue(torch.all(updated_conv[1] == 11)) + + def test_deltanet_update_maps_out_of_range_seq_id_to_padding_row(self): + config = _make_config(neuron_overrides={"kv_cache_padding_size": 1}) + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones((1, 32, 128, 128), dtype=torch.bfloat16) + conv = torch.ones((1, 8192, 3), dtype=torch.bfloat16) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=torch.tensor([99]), + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 0)) + self.assertTrue(torch.all(updated_recurrent[1] == 0)) + self.assertTrue(torch.all(updated_recurrent[2] == 1)) + self.assertTrue(torch.all(updated_conv[2] == 1)) + + def test_deltanet_state_shapes_do_not_scale_with_sequence_length(self): + short_config = _make_config(neuron_overrides={"seq_len": 128}) + long_config = _make_config(neuron_overrides={"seq_len": 2048}) + short_mgr = HybridDeltaNetCacheManager( + short_config, num_kv_head=short_config.num_key_value_heads + ) + long_mgr = HybridDeltaNetCacheManager( + long_config, num_kv_head=long_config.num_key_value_heads + ) + + self.assertEqual(short_mgr.past_key_values[0].shape, long_mgr.past_key_values[0].shape) + self.assertEqual(short_mgr.past_key_values[1].shape, long_mgr.past_key_values[1].shape) + self.assertLess(short_mgr.past_key_values[7].shape[2], long_mgr.past_key_values[7].shape[2]) + + def test_get_cache_trims_padding_row_without_seq_ids(self): + config = _make_config(neuron_overrides={"kv_cache_padding_size": 1}) + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + recurrent_state, conv_state = mgr.get_cache(seq_len=4)[0] + + self.assertEqual(list(recurrent_state.shape), [2, 32, 128, 128]) + self.assertEqual(list(conv_state.shape), [2, 8192, 3]) + + def test_update_cache_dispatches_deltanet_and_full_attention_layers(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + new_key_values = [] + for idx in range(4): + first = mgr.past_key_values[2 * idx] + second = mgr.past_key_values[2 * idx + 1] + new_key_values.append( + ( + torch.full_like(first, fill_value=idx + 1), + torch.full_like(second, fill_value=idx + 11), + ) + ) + + position_ids = torch.arange(16, dtype=torch.long).unsqueeze(0).expand(2, -1) + full_k_update = torch.full_like(mgr.past_key_values[6], fill_value=4) + full_v_update = torch.full_like(mgr.past_key_values[7], fill_value=14) + with patch.object( + mgr, "update_kv_by_layer_id", return_value=(full_k_update, full_v_update) + ) as update_kv: + updated = mgr.update_cache( + is_for_context_encoding=True, + seq_ids=torch.tensor([0, 1], dtype=torch.int32), + position_ids=position_ids, + new_key_values=new_key_values, + seq_len=16, + ) + + self.assertEqual(update_kv.call_count, 1) + self.assertEqual(update_kv.call_args.kwargs["idx"], 3) + self.assertTrue(torch.all(updated[0] == 1)) + self.assertTrue(torch.all(updated[1] == 11)) + self.assertTrue(torch.all(updated[6] == 4)) + self.assertTrue(torch.all(updated[7] 
== 14)) + + def test_managed_cache_removes_dummy_kv_for_deltanet_layers(self): + config = _make_config(neuron_overrides={"seq_len": 1024}) + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + max_batch_size = ( + config.neuron_config.kv_cache_batch_size + + config.neuron_config.kv_cache_padding_size + ) + full_kv_per_layer = _numel(mgr.k_shape) + _numel(mgr.v_shape) + deltanet_layers = config.layer_types.count("linear_attention") + legacy_total_numel = ( + full_kv_per_layer * config.num_hidden_layers + + _deltanet_state_numel(config, max_batch_size) * deltanet_layers + ) + expected_savings = full_kv_per_layer * deltanet_layers + + self.assertEqual( + legacy_total_numel - _managed_cache_numel(mgr), + expected_savings, + ) + self.assertLess(_managed_cache_numel(mgr), legacy_total_numel) + + def test_rejects_unsupported_hybrid_modes(self): + unsupported_cases = [ + ({"padding_side": "left"}, "left padding"), + ({"flash_decoding_enabled": True}, "flash decoding"), + ] + + for neuron_overrides, expected_error in unsupported_cases: + with self.subTest(expected_error=expected_error): + config = _make_config(neuron_overrides=neuron_overrides) + with self.assertRaisesRegex(ValueError, expected_error): + HybridDeltaNetCacheManager( + config, num_kv_head=config.num_key_value_heads + ) + + config = _make_config() + config.neuron_config.kv_cache_quant = True + with self.assertRaisesRegex(ValueError, "KV cache quantization"): + HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + config = _make_config( + neuron_overrides={ + "attention_dp_degree": 2, + "batch_size": 2, + "ctx_batch_size": 2, + "tkg_batch_size": 2, + "max_batch_size": 2, + "kv_cache_batch_size": 2, + "is_continuous_batching": True, + } + ) + with self.assertRaisesRegex(ValueError, "attention data parallelism"): + HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + config = _make_config() + config.neuron_config.kv_cache_tiling = True + with self.assertRaisesRegex(ValueError, "KV cache tiling"): + HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + def test_legacy_config_default_is_disabled(self): + config = _make_config(use_hybrid_cache_manager=False) + self.assertFalse(config.use_hybrid_cache_manager) + + +if __name__ == "__main__": + unittest.main() From c1795b0000d84c2aa2978b35adf4e6bcc69bc670 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Tue, 5 May 2026 21:17:32 +0530 Subject: [PATCH 5/5] Keep Qwen3.5 contrib on stable dummy KV path --- contrib/models/Qwen3.5-4B/README.md | 3 +- .../models/Qwen3.5-4B/src/modeling_qwen35.py | 330 +------------- .../src/nki_kernels/nki_deltanet_fused.py | 419 ++++++++++++------ .../Qwen3.5-4B/test/integration/test_model.py | 142 ++---- .../test/unit/test_deltanet_decay.py | 71 +-- .../test/unit/test_hybrid_cache_manager.py | 335 -------------- contrib/models/Qwen3.5-9B/README.md | 3 +- .../models/Qwen3.5-9B/src/modeling_qwen35.py | 330 +------------- .../src/nki_kernels/nki_deltanet_fused.py | 419 ++++++++++++------ .../Qwen3.5-9B/test/integration/test_model.py | 142 ++---- .../test/unit/test_deltanet_decay.py | 71 +-- .../test/unit/test_hybrid_cache_manager.py | 341 -------------- 12 files changed, 694 insertions(+), 1912 deletions(-) delete mode 100644 contrib/models/Qwen3.5-4B/test/unit/test_hybrid_cache_manager.py delete mode 100644 contrib/models/Qwen3.5-9B/test/unit/test_hybrid_cache_manager.py diff --git a/contrib/models/Qwen3.5-4B/README.md b/contrib/models/Qwen3.5-4B/README.md 
index 9910cd12..61f7f16a 100644 --- a/contrib/models/Qwen3.5-4B/README.md +++ b/contrib/models/Qwen3.5-4B/README.md @@ -163,4 +163,5 @@ Validated results on `trn2.48xlarge`: - DeltaNet weights are replicated across TP ranks in v1. - DeltaNet layers still allocate dummy KV cache through NxDI's normal cache manager. -- MoE, VL, quantization, speculation, and custom hybrid cache cleanup are out of scope. +- MoE, VL, quantization, and speculation are out of scope. +- A custom hybrid cache manager to remove dummy KV HBM overhead is planned as a follow-up PR. diff --git a/contrib/models/Qwen3.5-4B/src/modeling_qwen35.py b/contrib/models/Qwen3.5-4B/src/modeling_qwen35.py index 791420cb..574fdbe0 100644 --- a/contrib/models/Qwen3.5-4B/src/modeling_qwen35.py +++ b/contrib/models/Qwen3.5-4B/src/modeling_qwen35.py @@ -84,7 +84,6 @@ NeuronAttentionBase, ) from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding -from neuronx_distributed_inference.modules.kvcache.kv_cache_manager import KVCacheManager from neuronx_distributed_inference.models.layer_boundary_marker import ( ModuleMarkerEndWrapper, ModuleMarkerStartWrapper, @@ -222,33 +221,6 @@ def l2norm(x, dim=-1, eps=1e-6): return F.normalize(x, p=2, dim=dim, eps=eps) -FUSED_DELTANET_DECAY_MIN = -20.0 -FUSED_DELTANET_DECAY_MAX = 0.0 - - -def _bound_fused_deltanet_log_decay( - g, batch_size, num_heads, total_seq_len, chunk_size -): - """Bound cumulative DeltaNet decay before the fused NKI kernel. - - The fused kernel internally computes both exp(cumsum(g)) and exp(-cumsum(g)). - Large negative cumulative decays make the second term overflow even though - the true pairwise decay exp(gc_i - gc_j) is bounded by one. Return - equivalent per-token deltas whose per-chunk cumulative sum is clamped. - """ - num_chunks = total_seq_len // chunk_size - g_chunks = g.reshape(batch_size, num_heads, num_chunks, chunk_size) - g_cumsum = g_chunks.cumsum(dim=-1).clamp( - min=FUSED_DELTANET_DECAY_MIN, - max=FUSED_DELTANET_DECAY_MAX, - ) - g_first = g_cumsum[..., :1] - g_rest = g_cumsum[..., 1:] - g_cumsum[..., :-1] - return torch.cat([g_first, g_rest], dim=-1).reshape( - batch_size, num_heads, total_seq_len - ) - - # ============================================================ # Gated DeltaNet Module (Linear Recurrent Attention) # ============================================================ @@ -287,7 +259,6 @@ def __init__(self, config, layer_idx: int): self.conv_kernel_size = tc.linear_conv_kernel_dim # 4 self.layer_idx = layer_idx self.rms_norm_eps = tc.rms_norm_eps - self.use_hybrid_cache_manager = getattr(tc, "use_hybrid_cache_manager", False) # KV cache dummy shape info self.head_dim = tc.head_dim # 256 @@ -550,7 +521,8 @@ def _fused_chunked_forward( beta = F.pad(beta, (0, pad_size)) g = F.pad(g, (0, pad_size)) total_seq_len = S + pad_size - g = _bound_fused_deltanet_log_decay(g, B, H, total_seq_len, chunk_size) + # Pass raw per-token log-decay. The fused NKI kernel forms decay as + # exp(cumsum(g)_i - cumsum(g)_j), so no pre-kernel clamp is needed. BH = B * H # Flatten to (BH, S, dim) for per-(b,h) kernel calls @@ -749,11 +721,6 @@ def forward( # zeros the decay gate so the recurrent state is preserved unchanged # through padding positions (no spurious decay). 
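         # Shape convention (illustrative sketch, not an NxDI API): for a
         # right-padded batch, deltanet_padding_mask[b] is 1.0 over real
         # tokens and 0.0 over padding, e.g. S=4 with 3 valid tokens gives
         # [[1.], [1.], [1.], [0.]].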
valid_mask_1d = kwargs.get("deltanet_padding_mask", None) # [B, S, 1] or None - hybrid_cache_active = self.use_hybrid_cache_manager - recurrent_state_cache = None - conv_state_cache = None - if hybrid_cache_active and past_key_value is not None: - recurrent_state_cache, conv_state_cache = past_key_value # Project inputs deltanet_fp32 = os.environ.get("DELTANET_FP32") == "1" @@ -781,9 +748,7 @@ def forward( mixed = mixed.transpose(1, 2) if is_decode: - if conv_state_cache is not None: - conv_state = conv_state_cache[:batch_size] - elif seq_ids is not None: + if seq_ids is not None: conv_state = torch.index_select(self.conv_state_buffer, 0, seq_ids) else: conv_state = self.conv_state_buffer[:batch_size] @@ -800,9 +765,7 @@ def forward( new_conv_state = torch.cat([conv_state[:, :, 1:], mixed], dim=-1) alloc_bs = self.conv_state_buffer.shape[0] - if hybrid_cache_active: - new_conv_state = new_conv_state.to(self.conv_state_buffer.dtype) - elif seq_ids is not None: + if seq_ids is not None: # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement # Add buffer dependency for input_output_alias new_conv_state = ( @@ -838,9 +801,7 @@ def forward( new_conv_state = mixed[:, :, -3:].contiguous() alloc_bs = self.conv_state_buffer.shape[0] - if hybrid_cache_active: - new_conv_state = new_conv_state.to(self.conv_state_buffer.dtype) - elif seq_ids is not None: + if seq_ids is not None: # BS=1 optimization: scatter to index 0 = direct replacement new_conv_state = ( new_conv_state.to(self.conv_state_buffer.dtype) @@ -919,9 +880,7 @@ def forward( if is_decode: # TKG: single-step recurrent update - if recurrent_state_cache is not None: - recurrent_state = recurrent_state_cache[:batch_size].float() - elif seq_ids is not None: + if seq_ids is not None: recurrent_state = torch.index_select( self.recurrent_state_buffer, 0, seq_ids ).float() @@ -933,9 +892,7 @@ def forward( ) new_state_bf16 = new_state.to(self.recurrent_state_buffer.dtype) alloc_bs = self.recurrent_state_buffer.shape[0] - if hybrid_cache_active: - new_rec_state = new_state_bf16 - elif seq_ids is not None: + if seq_ids is not None: # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement # Add buffer dependency for input_output_alias new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0 @@ -987,9 +944,7 @@ def forward( if final_state is not None: final_state_bf16 = final_state.to(self.recurrent_state_buffer.dtype) alloc_bs = self.recurrent_state_buffer.shape[0] - if hybrid_cache_active: - new_rec_state = final_state_bf16 - elif seq_ids is not None: + if seq_ids is not None: # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement # Add buffer dependency for input_output_alias new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0 @@ -1024,9 +979,6 @@ def forward( output = output.reshape(batch_size, seq_len, self.value_dim) output = self.out_proj(output) - if hybrid_cache_active: - return output, (new_rec_state, new_conv_state), new_rec_state, new_conv_state - # Return dummy KV for KVCacheManager dummy_k = torch.zeros( batch_size, @@ -1081,7 +1033,6 @@ def __init__(self, *args, **kwargs): kwargs.setdefault("linear_key_head_dim", 128) kwargs.setdefault("linear_value_head_dim", 128) kwargs.setdefault("linear_conv_kernel_dim", 4) - kwargs.setdefault("use_hybrid_cache_manager", False) super().__init__(*args, **kwargs) @@ -1535,11 +1486,7 @@ def forward( ) hidden_states = residual + attn_out present_key_value = dummy_kv - deltanet_states = ( - None - if 
getattr(self.config, "use_hybrid_cache_manager", False) - else (new_rec_state, new_conv_state) - ) + deltanet_states = (new_rec_state, new_conv_state) else: deltanet_states = None # Standard attention path @@ -1572,240 +1519,6 @@ def forward( return outputs -# ============================================================ -# Hybrid Cache Manager (opt-in) -# ============================================================ - - -class HybridDeltaNetCacheManager(KVCacheManager): - """Layer-type-aware cache manager for Qwen3.5/Qwen3.6 hybrid dense models.""" - - def __init__(self, config: Qwen35InferenceConfig, num_kv_head, **kwargs): - self.layer_types = list(config.layer_types) - self._validate_hybrid_config(config) - super().__init__(config, num_kv_head=num_kv_head, **kwargs) - - dtype = ( - config.neuron_config.attention_dtype - if config.neuron_config.attention_dtype is not None - else config.neuron_config.torch_dtype - ) - cache_dtype = getattr(self, "cache_dtype", dtype) - max_batch_size = ( - config.neuron_config.kv_cache_batch_size - + config.neuron_config.kv_cache_padding_size - ) - recurrent_shape = [ - max_batch_size, - config.linear_num_value_heads, - config.linear_key_head_dim, - config.linear_value_head_dim, - ] - conv_dim = ( - 2 * config.linear_num_key_heads * config.linear_key_head_dim - + config.linear_num_value_heads * config.linear_value_head_dim - ) - conv_shape = [ - max_batch_size, - conv_dim, - config.linear_conv_kernel_dim - 1, - ] - - params = [] - for layer_idx, layer_type in enumerate(self.layer_types): - if layer_type == "linear_attention": - params.append( - nn.Parameter(torch.zeros(recurrent_shape, dtype=dtype), requires_grad=False) - ) - params.append( - nn.Parameter(torch.zeros(conv_shape, dtype=dtype), requires_grad=False) - ) - else: - k_shape = self.k_shapes[layer_idx] if hasattr(self, "k_shapes") else self.k_shape - v_shape = self.v_shapes[layer_idx] if hasattr(self, "v_shapes") else self.v_shape - params.append( - nn.Parameter(torch.zeros(k_shape, dtype=cache_dtype), requires_grad=False) - ) - params.append( - nn.Parameter(torch.zeros(v_shape, dtype=cache_dtype), requires_grad=False) - ) - - self.past_key_values = nn.ParameterList(params) - - @staticmethod - def _validate_hybrid_config(config: Qwen35InferenceConfig): - nc = config.neuron_config - unsupported = [] - if nc.is_block_kv_layout: - unsupported.append("block KV layout") - if getattr(nc, "kv_quant_config", None) is not None or getattr(nc, "kv_cache_quant", False): - unsupported.append("KV cache quantization") - if nc.enable_fused_speculation or nc.speculation_length > 0 or nc.is_medusa: - unsupported.append("speculative decoding") - if getattr(nc, "enable_eagle_speculation", False) or getattr(nc, "is_eagle_draft", False): - unsupported.append("EAGLE speculation") - if nc.flash_decoding_enabled: - unsupported.append("flash decoding") - if nc.attention_dp_degree > 1: - unsupported.append("attention data parallelism") - if nc.kv_cache_tiling: - unsupported.append("KV cache tiling") - if nc.padding_side != "right": - unsupported.append("left padding") - if nc.is_continuous_batching: - unsupported.append("continuous batching") - if unsupported: - raise ValueError( - "HybridDeltaNetCacheManager v1 does not support: " - + ", ".join(unsupported) - ) - - def _is_deltanet_layer(self, idx: int) -> bool: - return self.layer_types[idx] == "linear_attention" - - def get_seq_length(self, past_key_values=None): - for idx, layer_type in enumerate(self.layer_types): - if layer_type != "linear_attention": - if 
past_key_values is None: - _, v_cache = self._fetch_cache(idx) - elif len(past_key_values) == len(self.past_key_values): - v_cache = past_key_values[2 * idx + 1] - else: - v_cache = past_key_values[idx][1] - return v_cache.shape[2] - return 0 - - def get_deltanet_state_by_layer_id(self, idx, kvcache_buffer=None, seq_ids=None): - recurrent_state, conv_state = self._fetch_cache(idx, kvcache_buffer) - if seq_ids is not None: - cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids) - recurrent_state = torch.index_select(recurrent_state, dim=0, index=cache_idx) - conv_state = torch.index_select(conv_state, dim=0, index=cache_idx) - elif self.kv_cache_padding_size > 0: - recurrent_state = recurrent_state[: -self.kv_cache_padding_size] - conv_state = conv_state[: -self.kv_cache_padding_size] - return recurrent_state, conv_state - - def get_cache( - self, - seq_len: int, - skip_slice=False, - kvcache_buffer=None, - seq_ids=None, - windowed_context_encoding_window_idx=-1, - **kwargs, - ): - past_key_values = [] - for idx in range(len(self.past_key_values) // 2): - if self._is_deltanet_layer(idx): - past_key_values.append( - list(self.get_deltanet_state_by_layer_id(idx, kvcache_buffer, seq_ids)) - ) - else: - past_key_values.append( - list( - self.get_kv_by_layer_id( - idx=idx, - skip_slice=skip_slice, - seq_len=seq_len, - kvcache_buffer=kvcache_buffer, - seq_ids=seq_ids, - windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, - **kwargs, - ) - ) - ) - return past_key_values - - def update_cache( - self, - is_for_context_encoding: bool, - seq_ids: torch.Tensor, - position_ids: torch.Tensor, - new_key_values: List[torch.Tensor], - seq_len: int, - scatter_index=None, - kv_active_mask=None, - kvcache_buffer=None, - windowed_context_encoding_window_idx: int = -1, - **kwargs, - ): - updated_cache = [] - for idx, kv_per_layer in enumerate(new_key_values): - if self._is_deltanet_layer(idx): - recurrent_state, conv_state = self.update_deltanet_state_by_layer_id( - idx=idx, - seq_ids=seq_ids, - state_per_layer=kv_per_layer, - kvcache_buffer=kvcache_buffer, - ) - else: - recurrent_state, conv_state = self.update_kv_by_layer_id( - idx=idx, - is_for_context_encoding=is_for_context_encoding, - seq_ids=seq_ids, - position_ids=position_ids, - kv_per_layer=kv_per_layer, - seq_len=seq_len, - scatter_index=scatter_index, - kv_active_mask=kv_active_mask, - kvcache_buffer=kvcache_buffer, - windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, - **kwargs, - ) - updated_cache.append(recurrent_state) - updated_cache.append(conv_state) - return updated_cache - - def update_deltanet_state_by_layer_id( - self, - idx: int, - seq_ids: torch.Tensor, - state_per_layer: Tuple[torch.Tensor, torch.Tensor], - kvcache_buffer=None, - ): - latest_recurrent, latest_conv = state_per_layer - recurrent_cache, conv_cache = self._fetch_cache(idx, kvcache_buffer) - latest_recurrent = latest_recurrent.to(recurrent_cache.dtype) - latest_conv = latest_conv.to(conv_cache.dtype) - - if seq_ids is not None: - cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids) - recurrent_index = cache_idx.view(-1, 1, 1, 1).expand_as(latest_recurrent) - conv_index = cache_idx.view(-1, 1, 1).expand_as(latest_conv) - recurrent_cache = torch.scatter( - input=recurrent_cache, - dim=0, - index=recurrent_index, - src=latest_recurrent, - ) - conv_cache = torch.scatter( - input=conv_cache, - dim=0, - index=conv_index, - src=latest_conv, - ) - return recurrent_cache, conv_cache - - if latest_recurrent.shape[0] == 
recurrent_cache.shape[0]: - return ( - latest_recurrent + recurrent_cache * 0, - latest_conv + conv_cache * 0, - ) - - pad_size = recurrent_cache.shape[0] - latest_recurrent.shape[0] - if pad_size > 0: - latest_recurrent = torch.cat( - [latest_recurrent, recurrent_cache[latest_recurrent.shape[0] :] * 0], - dim=0, - ) - latest_conv = torch.cat( - [latest_conv, conv_cache[latest_conv.shape[0] :] * 0], - dim=0, - ) - return latest_recurrent + recurrent_cache * 0, latest_conv + conv_cache * 0 - - # ============================================================ # Model # ============================================================ @@ -1851,19 +1564,6 @@ def init_model(self, config: Qwen35InferenceConfig): # mRoPE embedding for VL self.mrope_emb = Qwen35MRoPEEmbedding(config) - def init_inference_optimization(self, config: Qwen35InferenceConfig): - super().init_inference_optimization(config) - if getattr(config, "use_hybrid_cache_manager", False): - self.kv_mgr = HybridDeltaNetCacheManager( - config, - num_kv_head=self.num_key_value_heads, - global_rank=self.rank_util, - attention_chunk_size=self.attention_chunk_size, - sliding_window=self.sliding_window, - windowed_context_encoding_size=self.windowed_context_encoding_size, - layer_to_cache_size_mapping=self.layer_to_cache_size_mapping, - ) - @property def _deltanet_state_params(self): """Return DeltaNet state nn.Parameters in alias order.""" @@ -2186,10 +1886,7 @@ def forward( outputs += updated_kv_cache # Append DeltaNet state tensors (for input_output_aliases) - if ( - not getattr(self.config, "use_hybrid_cache_manager", False) - and hasattr(self, "_deltanet_updated_states") - ): + if hasattr(self, "_deltanet_updated_states"): outputs += self._deltanet_updated_states return outputs @@ -2341,10 +2038,7 @@ def get(self, bucket_rank, **kwargs): state_start_idx = num_output_from_trace + num_kv - if ( - not getattr(module.config, "use_hybrid_cache_manager", False) - and hasattr(module, "_deltanet_state_params") - ): + if hasattr(module, "_deltanet_state_params"): for i, param in enumerate(module._deltanet_state_params): input_output_aliases[param] = state_start_idx + i @@ -2584,8 +2278,6 @@ def enable_token_generation(self): def _copy_past_key_values(self, outputs): """Override to also copy DeltaNet state buffers on CPU.""" super()._copy_past_key_values(outputs) - if getattr(self.config, "use_hybrid_cache_manager", False): - return num_output_from_trace = 1 if ( diff --git a/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_fused.py b/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_fused.py index 4d02423d..b13e2e95 100644 --- a/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_fused.py +++ b/contrib/models/Qwen3.5-4B/src/nki_kernels/nki_deltanet_fused.py @@ -15,14 +15,17 @@ 5. Uses tensor_scalar for partition-broadcast (no explicit broadcast loops) 6. nc_transpose (Vector Engine) for all 128x128 transposes instead of nc_matmul(moving=eye) (Tensor Engine) — frees TE for actual math + 7. Forms decay as exp(cumsum(g)_i - cumsum(g)_j), never as split + exp(cumsum(g)_i) * exp(-cumsum(g)_j) NKI 0.3.0 (SDK 2.29). k_dim = v_dim = 128 = P_MAX exactly. Chunk size = 128 = P_MAX (one tile per chunk). 
-Mathematical framework (same as nki_deltanet_chunked.py): - Per-chunk Neumann-series power-doubling for intra-chunk correction: +Mathematical framework: + Per-chunk blocked triangular solve for intra-chunk correction: A = -QK_decay * lower_mask - N = (I+A)(I+A^2)(I+A^4)...(I+A^64) [6 rounds] + N = inv(I - A), computed by 64x64 forward substitution plus one + 64->128 lower-block merge value_corr = N @ v_beta k_cumdecay = N @ (k_beta * exp(gc)) @@ -32,7 +35,7 @@ attn_inter = (q * exp(gc)) @ state attn_intra = (q @ k^T) * decay_mask * lower_mask_diag output = attn_inter + attn_intra @ v_new - state = exp(g_last) * (state + k_raw_decay^T @ v_new) + state = state * exp(g_last) + (k * exp(g_last - gc))^T @ v_new """ import numpy as np @@ -219,9 +222,15 @@ def deltanet_fused_chunked_fwd( src=gc_row[0:1, CHUNK_SIZE - 1 : CHUNK_SIZE], ) - # ---- Compute exp(gc), exp(-gc), exp(g_last) as (P_MAX, 1) scalars ---- - # These (P_MAX, 1) tensors are used with tensor_scalar to broadcast - # across the free dimension without explicit (P_MAX, dim) copies. + # ---- Build stable decay factors from cumulative log-decay ---- + # + # Pairwise decays are computed as exp(gc[i] - gc[j]) under the causal + # mask. The older split form exp(gc[i]) * exp(-gc[j]) is algebraically + # equivalent, but can overflow/underflow before the multiply. + # + # The one-vector exp(gc) and exp(g_last) factors are still required by + # the chunk recurrence; these are non-positive GDN decays and therefore + # bounded above by one. exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) nisa.activation( @@ -232,24 +241,15 @@ def deltanet_fused_chunked_fwd( scale=1.0, ) - neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_scalar( - dst=neg_gc_p, - data=gc_p, - op0=nl.multiply, - operand0=-1.0, - engine=nisa.vector_engine, - ) - exp_neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) - nisa.activation( - dst=exp_neg_gc_p[0:P_MAX, 0:1], - op=nl.exp, - data=neg_gc_p[0:P_MAX, 0:1], - bias=None, - scale=1.0, - ) + # g_last: scalar, then broadcast raw and exp(g_last) to (P_MAX, 1) + gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gl_11[0:1, 0:1], + dst=gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) - # exp(g_last): scalar, then broadcast to (P_MAX, 1) exp_gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) nisa.activation( dst=exp_gl_11, @@ -267,6 +267,84 @@ def deltanet_fused_chunked_fwd( shuffle_mask=_BROADCAST_MASK, ) + # Broadcast gc row-wise so row i, column j can form gc[i] - gc[j]. + gc_row_broadcast = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gc_row[0:1, 0:P_MAX], + dst=gc_row_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + # Strict-lower decay for the KKT solve: exp(gc[i] - gc[j]) where i > j. + # Mask before exp by zeroing non-causal differences, then mask again + # after exp so exp(0) from non-causal positions does not contribute. 
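A reference-only PyTorch sketch of this mask-before-exp construction (assumed chunk size 128; this mirrors the math, not the NKI calls):

```python
import torch

C = 128
g = torch.full((C,), -1.0)                  # strong per-token log-decay
gc = torch.cumsum(g, dim=0)                 # cumulative, down to -128
strict = torch.tril(torch.ones(C, C), -1)   # Lmask analogue (i > j)
diag = torch.tril(torch.ones(C, C))         # Lmask_d analogue (i >= j)

g_diff = gc[:, None] - gc[None, :]          # gc[i] - gc[j]
# Mask the differences before exp (non-causal entries become exp(0) = 1),
# then mask again after exp so those ones drop out of later matmuls.
decay_strict = torch.exp(g_diff * strict) * strict
decay_diag = torch.exp(g_diff * diag) * diag
assert decay_strict.max() <= 1.0 and decay_diag.max() <= 1.0

# The split form overflows float32 before any mask can help:
print(torch.isinf(torch.exp(-gc)).any())    # tensor(True): exp(128) > fp32 max
```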
+ gc_col_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=gc_col_strict, + data=Lmask, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gc_row_strict, data1=gc_row_broadcast, data2=Lmask, op=nl.multiply + ) + g_diff_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_strict, + data1=gc_col_strict, + data2=gc_row_strict, + op=nl.subtract, + ) + decay_strict_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=decay_strict_raw, + op=nl.exp, + data=g_diff_strict, + bias=None, + scale=1.0, + ) + decay_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_strict, data1=decay_strict_raw, data2=Lmask, op=nl.multiply + ) + + # Lower-with-diagonal decay for intra-chunk attention: exp(gc[i] - gc[j]) + # where i >= j. + gc_col_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=gc_col_diag, + data=Lmask_d, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gc_row_diag, data1=gc_row_broadcast, data2=Lmask_d, op=nl.multiply + ) + g_diff_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_diag, + data1=gc_col_diag, + data2=gc_row_diag, + op=nl.subtract, + ) + decay_diag_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=decay_diag_raw, + op=nl.exp, + data=g_diff_diag, + bias=None, + scale=1.0, + ) + decay_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_diag, data1=decay_diag_raw, data2=Lmask_d, op=nl.multiply + ) + # ============================================================ # k_beta = K * beta, v_beta = V * beta # tensor_scalar broadcasts beta_p (P_MAX, 1) across free dim @@ -309,49 +387,11 @@ def deltanet_fused_chunked_fwd( QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_copy(dst=QK, src=QK_psum) - # ============================================================ - # Decay mask: QK_decay[i,j] = QK[i,j] * exp(gc[i]) * exp(-gc[j]) - # - # Apply the strict causal mask before the split exp(gc) / exp(-gc) - # scaling. Upper-triangular entries are mathematically unused, but - # scaling them first can create very large finite values that poison - # later matmuls before the mask is applied. - # ============================================================ - QK_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_tensor(dst=QK_masked, data1=QK, data2=Lmask, op=nl.multiply) - - # Row scaling: QK_row[i,:] = QK[i,:] * exp(gc[i]) - # Then transpose, column scale, transpose back. - # Uses tensor_scalar with (P_MAX,1) operand for row scaling. 
- QK_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_scalar( - dst=QK_row, - data=QK_masked, - op0=nl.multiply, - operand0=exp_gc_p, - engine=nisa.vector_engine, - ) - - # Transpose to scale columns (now rows in transposed view) - QK_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) - nisa.nc_transpose(dst=QK_r_T_psum, data=QK_row) - QK_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=QK_r_T, src=QK_r_T_psum) - - QK_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_scalar( - dst=QK_r_T_col, - data=QK_r_T, - op0=nl.multiply, - operand0=exp_neg_gc_p, - engine=nisa.vector_engine, - ) - - # Transpose back - QK_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) - nisa.nc_transpose(dst=QK_d_psum, data=QK_r_T_col) + # QK_decay[i,j] = QK[i,j] * exp(gc[i] - gc[j]) for i > j. + # This is the same causal decay as the split-exp form, but numerically + # bounded by construction. QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=QK_decay, src=QK_d_psum) + nisa.tensor_tensor(dst=QK_decay, data1=QK, data2=decay_strict, op=nl.multiply) # A = -QK_decay * lower_mask neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) @@ -366,38 +406,184 @@ def deltanet_fused_chunked_fwd( nisa.tensor_tensor(dst=A_mat, data1=neg_QK_decay, data2=Lmask, op=nl.multiply) # ============================================================ - # Neumann power-doubling: N = (I+A)(I+A^2)...(I+A^{64}) - # 6 rounds → resolves rank up to 2^6 = 64 (sufficient for chunk=128) + # Stable triangular solve: N = inv(I - A_mat) + # + # A_mat is strictly lower triangular. Solve two 64x64 diagonal + # blocks row-by-row: + # N[i, :] = e_i + sum_{j<i} A_mat[i, j] * N[j, :] + # then merge the lower 64x64 off-diagonal block with one matmul. + # attn_intra = (q @ k^T) * decay_diag, whose exp(gc[i] - gc[j]) factor + # is already masked to i >= j. attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_tensor( - dst=attn_intra, data1=qk_decay, data2=Lmask_d, op=nl.multiply - ) + nisa.tensor_tensor(dst=attn_intra, data1=qk_raw, data2=decay_diag, op=nl.multiply) # ============================================================ # v_prime = k_cumdecay @ state (state is in SBUF!) @@ -550,17 +700,8 @@ def deltanet_fused_chunked_fwd( # state is updated IN-PLACE in SBUF — no HBM round-trip! # ============================================================ - # k_raw_decay contributes as exp(g_last) * (k * exp(-gc))^T @ v_new. - # Compute the equivalent form with one bounded exponential, - # k * exp(g_last - gc), so the factor is always <= 1 for valid - # causal positions. - gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) - for i_shuf in nl.static_range(P_MAX // 32): - nisa.nc_stream_shuffle( - src=gl_11[0:1, 0:1], - dst=gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], - shuffle_mask=_BROADCAST_MASK, - ) + # k_raw_decay contributes as k * exp(g_last - gc), with one bounded + # exponential instead of exp(g_last) * exp(-gc).
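The blocked solve replacing the Neumann series can be cross-checked on CPU. A hedged PyTorch reference under assumed shapes (two 64x64 forward substitutions plus one lower-block matmul; not the kernel code):

```python
import torch

C, half = 128, 64
A = torch.tril(torch.randn(C, C, dtype=torch.float64), diagonal=-1) * 0.1

def forward_subst(a):
    """Rows of N = inv(I - a) for strictly lower-triangular a."""
    n = torch.eye(a.shape[0], dtype=a.dtype)
    for i in range(1, a.shape[0]):
        # N[i, :] = e_i + sum_{j<i} a[i, j] * N[j, :]
        n[i] += a[i, :i] @ n[:i]
    return n

N = torch.zeros(C, C, dtype=torch.float64)
N[:half, :half] = forward_subst(A[:half, :half])
N[half:, half:] = forward_subst(A[half:, half:])
# 64->128 lower-block merge: N21 = N22 @ A21 @ N11
N[half:, :half] = N[half:, half:] @ A[half:, :half] @ N[:half, :half]

ref = torch.linalg.inv(torch.eye(C, dtype=torch.float64) - A)
print(torch.allclose(N, ref))               # True
```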
gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_tensor( dst=gl_minus_gc_p, diff --git a/contrib/models/Qwen3.5-4B/test/integration/test_model.py b/contrib/models/Qwen3.5-4B/test/integration/test_model.py index 2c1353c0..e168757f 100644 --- a/contrib/models/Qwen3.5-4B/test/integration/test_model.py +++ b/contrib/models/Qwen3.5-4B/test/integration/test_model.py @@ -36,10 +36,7 @@ """ import gc -import json import os -import shutil -import subprocess import sys import time @@ -59,8 +56,6 @@ SEQ_LEN = int(os.environ.get("QWEN35_SEQ_LEN", "128")) TTFT_THRESHOLD_MS = float(os.environ.get("TTFT_THRESHOLD_MS", "5000")) THROUGHPUT_THRESHOLD = float(os.environ.get("THROUGHPUT_THRESHOLD", "5.0")) -USE_HYBRID_CACHE = os.environ.get("QWEN35_USE_HYBRID_CACHE", "0") == "1" -RECORD_HBM = os.environ.get("QWEN35_RECORD_HBM", "0") == "1" requires_model_path = pytest.mark.skipif( not MODEL_PATH, @@ -69,13 +64,6 @@ "weights. Set QWEN35_MODEL_PATH=/path/to/Qwen3.5-4B to run these tests." ), ) -requires_hbm_recording = pytest.mark.skipif( - not RECORD_HBM, - reason=( - "QWEN35_RECORD_HBM=1 not set. This optional test records Neuron HBM " - "usage for dummy-KV vs hybrid-cache comparisons." - ), -) # ── Fixtures ──────────────────────────────────────────────────────────── @@ -131,7 +119,6 @@ def compiled_model(model_path): inf_config = Qwen35InferenceConfig( neuron_config=neuron_config, - use_hybrid_cache_manager=USE_HYBRID_CACHE, **config_dict, ) @@ -197,6 +184,16 @@ def _generate(model, tokenizer, generation_config, prompt, max_new_tokens=20): return outputs[0].tolist(), tokenizer.decode(outputs[0], skip_special_tokens=True) +def _make_repeated_stress_prompt(tokenizer, target_tokens=133): + """Build a repeated prompt near the target token length.""" + seed = ( + "Repeat this stability phrase for DeltaNet recurrent decoding. " + "Repeat this stability phrase for DeltaNet recurrent decoding. 
" + ) + ids = tokenizer.encode(seed * 32, add_special_tokens=False)[:target_tokens] + return tokenizer.decode(ids, skip_special_tokens=True) + + def _is_repetitive(text, max_repeat=5): """Check for excessive word repetition.""" words = text.split() @@ -208,73 +205,6 @@ def _is_repetitive(text, max_repeat=5): return False -def _parse_peak_neuron_memory(stdout): - peak_device = 0 - peak_tensors = 0 - samples = 0 - for line in stdout.splitlines(): - line = line.strip() - if not line: - continue - try: - report = json.loads(line) - except json.JSONDecodeError: - continue - for runtime in report.get("neuron_runtime_data", []): - memory_used = runtime.get("report", {}).get("memory_used", {}) - used = memory_used.get("neuron_runtime_used_bytes", {}) - peak_device = max(peak_device, int(used.get("neuron_device", 0) or 0)) - nc_usage = ( - used.get("usage_breakdown", {}).get("neuroncore_memory_usage", {}) - ) - tensor_bytes = sum( - int(core.get("tensors", 0) or 0) for core in nc_usage.values() - ) - peak_tensors = max(peak_tensors, tensor_bytes) - samples += 1 - return peak_device, peak_tensors, samples - - -def _capture_neuron_hbm(tmp_path, fn): - if shutil.which("neuron-monitor") is None: - pytest.skip("neuron-monitor is not available") - - monitor_config = { - "period": "0.5s", - "neuron_runtimes": [ - { - "tag_filter": ".*", - "metrics": [{"type": "memory_used", "period": "0.5s"}], - } - ], - } - config_path = tmp_path / "neuron-monitor.json" - config_path.write_text(json.dumps(monitor_config)) - - proc = subprocess.Popen( - ["neuron-monitor", "--config-file", str(config_path)], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - try: - time.sleep(1.0) - result = fn() - time.sleep(1.0) - finally: - proc.terminate() - try: - stdout, stderr = proc.communicate(timeout=5) - except subprocess.TimeoutExpired: - proc.kill() - stdout, stderr = proc.communicate(timeout=5) - - peak_device, peak_tensors, samples = _parse_peak_neuron_memory(stdout) - assert samples > 0, f"neuron-monitor produced no runtime samples: {stderr}" - assert peak_device > 0, "Expected non-zero Neuron device HBM usage" - return result, peak_device, peak_tensors, samples - - # ── Smoke Tests ───────────────────────────────────────────────────────── @@ -365,6 +295,31 @@ def test_olympics_prompt_no_invalid_tokens( assert not invalid, f"Generated invalid token ids: {invalid}" +@requires_model_path +def test_repeated_stress_prompt_no_invalid_tokens( + compiled_model, tokenizer, generation_config +): + """Regression test for repeated 129/133-token prompts that exposed NaNs.""" + prompt = _make_repeated_stress_prompt(tokenizer, target_tokens=133) + prompt_len = len(tokenizer.encode(prompt)) + assert 129 <= prompt_len <= 133, f"Unexpected stress prompt length: {prompt_len}" + + tokens, text = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=16, + ) + generated = tokens[prompt_len:] + invalid = [token for token in generated if token < 0 or token >= len(tokenizer)] + + assert len(generated) >= 5, f"Expected >= 5 generated tokens, got {generated}" + assert not invalid, f"Generated invalid token ids: {invalid}" + assert len(text) > len(prompt) + print(f" Repeated stress prompt length: {prompt_len} tokens") + + @requires_model_path def test_capital_of_france(compiled_model, tokenizer, generation_config): """'The capital of France is' should produce 'Paris' in the response.""" @@ -442,33 +397,6 @@ def test_performance_throughput(compiled_model, tokenizer, generation_config): ) 
-@requires_model_path -@requires_hbm_recording -def test_hybrid_cache_hbm_snapshot(compiled_model, tokenizer, generation_config, tmp_path): - """Record peak Neuron HBM for dummy-KV vs hybrid-cache comparison runs.""" - prompt = "Give me a summary of the 2020 Olympics in 100 tokens." - max_new_tokens = int(os.environ.get("QWEN35_HBM_NEW_TOKENS", "32")) - - (_, text), peak_device, peak_tensors, samples = _capture_neuron_hbm( - tmp_path, - lambda: _generate( - compiled_model, - tokenizer, - generation_config, - prompt, - max_new_tokens=max_new_tokens, - ), - ) - - mode = "hybrid" if USE_HYBRID_CACHE else "dummy_kv" - print( - " HBM " - f"mode={mode} peak_device_bytes={peak_device} " - f"peak_tensor_bytes={peak_tensors} samples={samples}" - ) - assert len(text) > len(prompt) - - # ── Multi-Prompt Quality Test ────────────────────────────────────────── diff --git a/contrib/models/Qwen3.5-4B/test/unit/test_deltanet_decay.py b/contrib/models/Qwen3.5-4B/test/unit/test_deltanet_decay.py index 416a431a..f3d7d8bc 100644 --- a/contrib/models/Qwen3.5-4B/test/unit/test_deltanet_decay.py +++ b/contrib/models/Qwen3.5-4B/test/unit/test_deltanet_decay.py @@ -1,67 +1,34 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -"""Unit tests for fused DeltaNet log-decay bounding.""" +"""Unit tests for fused DeltaNet log-decay stability structure.""" -import os -import sys +import pathlib import unittest -import torch -_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) -if _CONTRIB_ROOT not in sys.path: - sys.path.insert(0, _CONTRIB_ROOT) +_CONTRIB_ROOT = pathlib.Path(__file__).resolve().parents[2] +_SRC_ROOT = _CONTRIB_ROOT / "src" -from src.modeling_qwen35 import ( - FUSED_DELTANET_DECAY_MAX, - FUSED_DELTANET_DECAY_MIN, - _bound_fused_deltanet_log_decay, -) +class TestFusedDeltaNetDecayStability(unittest.TestCase): + def test_fused_kernel_uses_exp_of_differences(self): + kernel_source = ( + _SRC_ROOT / "nki_kernels" / "nki_deltanet_fused.py" + ).read_text() -def _chunked_cumsum(g, batch_size, num_heads, total_seq_len, chunk_size): - num_chunks = total_seq_len // chunk_size - return g.reshape(batch_size, num_heads, num_chunks, chunk_size).cumsum(dim=-1) + self.assertIn("decay_strict", kernel_source) + self.assertIn("decay_diag", kernel_source) + self.assertIn("gl_minus_gc_p", kernel_source) + self.assertNotIn("exp_neg_gc_p", kernel_source) + self.assertNotIn("operand0=exp_neg_gc_p", kernel_source) + def test_modeling_does_not_clamp_fused_decay_inputs(self): + modeling_source = (_SRC_ROOT / "modeling_qwen35.py").read_text() -class TestFusedDeltaNetDecayBounding(unittest.TestCase): - def test_preserves_non_extreme_decay(self): - batch_size, num_heads, total_seq_len, chunk_size = 2, 3, 16, 8 - g = torch.full( - (batch_size, num_heads, total_seq_len), - -0.125, - dtype=torch.float32, - ) - - bounded = _bound_fused_deltanet_log_decay( - g, batch_size, num_heads, total_seq_len, chunk_size - ) - - torch.testing.assert_close(bounded, g) - - def test_bounds_per_chunk_cumulative_decay(self): - batch_size, num_heads, total_seq_len, chunk_size = 2, 3, 16, 8 - g = torch.full( - (batch_size, num_heads, total_seq_len), - -10.0, - dtype=torch.float32, - ) - - bounded = _bound_fused_deltanet_log_decay( - g, batch_size, num_heads, total_seq_len, chunk_size - ) - bounded_cumsum = _chunked_cumsum( - bounded, batch_size, num_heads, total_seq_len, chunk_size - ) - expected_cumsum = _chunked_cumsum( - g, batch_size, num_heads, total_seq_len, 
chunk_size - ).clamp(min=FUSED_DELTANET_DECAY_MIN, max=FUSED_DELTANET_DECAY_MAX) - - torch.testing.assert_close(bounded_cumsum, expected_cumsum) - self.assertGreaterEqual(float(bounded_cumsum.min()), FUSED_DELTANET_DECAY_MIN) - self.assertLessEqual(float(bounded_cumsum.max()), FUSED_DELTANET_DECAY_MAX) - self.assertTrue(torch.isfinite(bounded).all()) + self.assertNotIn("_bound_fused_deltanet_log_decay", modeling_source) + self.assertNotIn("FUSED_DELTANET_DECAY_MIN", modeling_source) + self.assertIn("exp(cumsum(g)_i - cumsum(g)_j)", modeling_source) if __name__ == "__main__": diff --git a/contrib/models/Qwen3.5-4B/test/unit/test_hybrid_cache_manager.py b/contrib/models/Qwen3.5-4B/test/unit/test_hybrid_cache_manager.py deleted file mode 100644 index c5941945..00000000 --- a/contrib/models/Qwen3.5-4B/test/unit/test_hybrid_cache_manager.py +++ /dev/null @@ -1,335 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -import os -import sys -import unittest -from math import prod -from unittest.mock import patch - -import torch - -_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) -if _CONTRIB_ROOT not in sys.path: - sys.path.insert(0, _CONTRIB_ROOT) - -from neuronx_distributed_inference.models.config import NeuronConfig -from src.modeling_qwen35 import HybridDeltaNetCacheManager, Qwen35InferenceConfig - - -def _make_config(**overrides): - neuron_overrides = overrides.pop("neuron_overrides", {}) - neuron_kwargs = dict( - tp_degree=overrides.pop("tp_degree", 4), - batch_size=1, - max_batch_size=2, - kv_cache_batch_size=2, - seq_len=16, - torch_dtype=torch.bfloat16, - ) - neuron_kwargs.update(neuron_overrides) - neuron_config = NeuronConfig(**neuron_kwargs) - defaults = dict( - hidden_size=2560, - num_hidden_layers=32, - num_attention_heads=16, - num_key_value_heads=4, - head_dim=256, - intermediate_size=9216, - vocab_size=248320, - rms_norm_eps=1e-6, - max_position_embeddings=262144, - rope_theta=10000000, - hidden_act="silu", - tie_word_embeddings=True, - linear_num_value_heads=32, - linear_num_key_heads=16, - linear_key_head_dim=128, - linear_value_head_dim=128, - linear_conv_kernel_dim=4, - use_hybrid_cache_manager=True, - ) - defaults.update(overrides) - return Qwen35InferenceConfig(neuron_config=neuron_config, **defaults) - - -def _numel(shape): - return prod(int(dim) for dim in shape) - - -def _managed_cache_numel(mgr): - return sum(param.numel() for param in mgr.past_key_values) - - -def _deltanet_state_numel(config, max_batch_size): - recurrent = ( - max_batch_size - * config.linear_num_value_heads - * config.linear_key_head_dim - * config.linear_value_head_dim - ) - conv_dim = ( - 2 * config.linear_num_key_heads * config.linear_key_head_dim - + config.linear_num_value_heads * config.linear_value_head_dim - ) - conv = max_batch_size * conv_dim * (config.linear_conv_kernel_dim - 1) - return recurrent + conv - - -class TestHybridDeltaNetCacheManager(unittest.TestCase): - def test_allocates_per_layer_cache_shapes(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - - self.assertEqual(len(mgr.past_key_values), config.num_hidden_layers * 2) - self.assertEqual(list(mgr.past_key_values[0].shape), [2, 32, 128, 128]) - self.assertEqual(list(mgr.past_key_values[1].shape), [2, 8192, 3]) - self.assertEqual(mgr.layer_types[3], "full_attention") - self.assertEqual(mgr.past_key_values[6].dim(), 4) - 
self.assertEqual(mgr.past_key_values[7].shape[2], 16) - - def test_get_cache_slices_only_full_attention_layers(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - - cache = mgr.get_cache(seq_len=4, seq_ids=torch.tensor([1])) - recurrent_state, conv_state = cache[0] - full_k, full_v = cache[3] - - self.assertEqual(list(recurrent_state.shape), [1, 32, 128, 128]) - self.assertEqual(list(conv_state.shape), [1, 8192, 3]) - self.assertEqual(full_k.shape[0], 2) - self.assertEqual(full_v.shape[0], 2) - self.assertEqual(full_k.shape[2], 4) - self.assertEqual(full_v.shape[2], 4) - - def test_get_seq_length_uses_first_full_attention_layer(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - - nested_cache = mgr.get_cache(seq_len=5, seq_ids=torch.tensor([0])) - flat_cache = [tensor for layer_cache in nested_cache for tensor in layer_cache] - - self.assertEqual(nested_cache[0][1].shape[2], 3) - self.assertEqual(mgr.get_seq_length(nested_cache), 5) - self.assertEqual(mgr.get_seq_length(flat_cache), 5) - - def test_get_cache_selects_deltanet_state_rows_by_seq_ids(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - with torch.no_grad(): - mgr.past_key_values[0][0, ...].fill_(7) - mgr.past_key_values[0][1, ...].fill_(13) - mgr.past_key_values[1][0, ...].fill_(17) - mgr.past_key_values[1][1, ...].fill_(19) - - recurrent_state, conv_state = mgr.get_cache( - seq_len=4, - seq_ids=torch.tensor([1, 0]), - )[0] - - self.assertTrue(torch.all(recurrent_state[0] == 13)) - self.assertTrue(torch.all(recurrent_state[1] == 7)) - self.assertTrue(torch.all(conv_state[0] == 19)) - self.assertTrue(torch.all(conv_state[1] == 17)) - - def test_deltanet_update_scatters_by_seq_id(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - recurrent = torch.ones((1, 32, 128, 128), dtype=torch.bfloat16) - conv = torch.ones((1, 8192, 3), dtype=torch.bfloat16) - - updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( - idx=0, - seq_ids=torch.tensor([1]), - state_per_layer=(recurrent, conv), - ) - - self.assertTrue(torch.all(updated_recurrent[0] == 0)) - self.assertTrue(torch.all(updated_conv[0] == 0)) - self.assertTrue(torch.all(updated_recurrent[1] == 1)) - self.assertTrue(torch.all(updated_conv[1] == 1)) - - def test_deltanet_full_batch_update_replaces_state_cache(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - recurrent = torch.ones((2, 32, 128, 128), dtype=torch.bfloat16) - conv = torch.ones((2, 8192, 3), dtype=torch.bfloat16) - recurrent[0].fill_(3) - recurrent[1].fill_(5) - conv[0].fill_(11) - conv[1].fill_(13) - - updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( - idx=0, - seq_ids=None, - state_per_layer=(recurrent, conv), - ) - - self.assertTrue(torch.all(updated_recurrent[0] == 3)) - self.assertTrue(torch.all(updated_recurrent[1] == 5)) - self.assertTrue(torch.all(updated_conv[0] == 11)) - self.assertTrue(torch.all(updated_conv[1] == 13)) - - def test_deltanet_full_batch_update_scatters_non_identity_seq_ids(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - recurrent = torch.ones((2, 32, 128, 128), dtype=torch.bfloat16) - conv = torch.ones((2, 8192, 3), dtype=torch.bfloat16) - 
recurrent[0].fill_(3) - recurrent[1].fill_(5) - conv[0].fill_(11) - conv[1].fill_(13) - - updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( - idx=0, - seq_ids=torch.tensor([1, 0]), - state_per_layer=(recurrent, conv), - ) - - self.assertTrue(torch.all(updated_recurrent[0] == 5)) - self.assertTrue(torch.all(updated_recurrent[1] == 3)) - self.assertTrue(torch.all(updated_conv[0] == 13)) - self.assertTrue(torch.all(updated_conv[1] == 11)) - - def test_deltanet_update_maps_out_of_range_seq_id_to_padding_row(self): - config = _make_config(neuron_overrides={"kv_cache_padding_size": 1}) - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - recurrent = torch.ones((1, 32, 128, 128), dtype=torch.bfloat16) - conv = torch.ones((1, 8192, 3), dtype=torch.bfloat16) - - updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( - idx=0, - seq_ids=torch.tensor([99]), - state_per_layer=(recurrent, conv), - ) - - self.assertTrue(torch.all(updated_recurrent[0] == 0)) - self.assertTrue(torch.all(updated_recurrent[1] == 0)) - self.assertTrue(torch.all(updated_recurrent[2] == 1)) - self.assertTrue(torch.all(updated_conv[2] == 1)) - - def test_deltanet_state_shapes_do_not_scale_with_sequence_length(self): - short_config = _make_config(neuron_overrides={"seq_len": 128}) - long_config = _make_config(neuron_overrides={"seq_len": 2048}) - short_mgr = HybridDeltaNetCacheManager( - short_config, num_kv_head=short_config.num_key_value_heads - ) - long_mgr = HybridDeltaNetCacheManager( - long_config, num_kv_head=long_config.num_key_value_heads - ) - - self.assertEqual(short_mgr.past_key_values[0].shape, long_mgr.past_key_values[0].shape) - self.assertEqual(short_mgr.past_key_values[1].shape, long_mgr.past_key_values[1].shape) - self.assertLess(short_mgr.past_key_values[7].shape[2], long_mgr.past_key_values[7].shape[2]) - - def test_get_cache_trims_padding_row_without_seq_ids(self): - config = _make_config(neuron_overrides={"kv_cache_padding_size": 1}) - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - - recurrent_state, conv_state = mgr.get_cache(seq_len=4)[0] - - self.assertEqual(list(recurrent_state.shape), [2, 32, 128, 128]) - self.assertEqual(list(conv_state.shape), [2, 8192, 3]) - - def test_update_cache_dispatches_deltanet_and_full_attention_layers(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - new_key_values = [] - for idx in range(4): - first = mgr.past_key_values[2 * idx] - second = mgr.past_key_values[2 * idx + 1] - new_key_values.append( - ( - torch.full_like(first, fill_value=idx + 1), - torch.full_like(second, fill_value=idx + 11), - ) - ) - - position_ids = torch.arange(16, dtype=torch.long).unsqueeze(0).expand(2, -1) - full_k_update = torch.full_like(mgr.past_key_values[6], fill_value=4) - full_v_update = torch.full_like(mgr.past_key_values[7], fill_value=14) - with patch.object( - mgr, "update_kv_by_layer_id", return_value=(full_k_update, full_v_update) - ) as update_kv: - updated = mgr.update_cache( - is_for_context_encoding=True, - seq_ids=torch.tensor([0, 1], dtype=torch.int32), - position_ids=position_ids, - new_key_values=new_key_values, - seq_len=16, - ) - - self.assertEqual(update_kv.call_count, 1) - self.assertEqual(update_kv.call_args.kwargs["idx"], 3) - self.assertTrue(torch.all(updated[0] == 1)) - self.assertTrue(torch.all(updated[1] == 11)) - self.assertTrue(torch.all(updated[6] == 4)) - self.assertTrue(torch.all(updated[7] 
== 14)) - - def test_managed_cache_removes_dummy_kv_for_deltanet_layers(self): - config = _make_config(neuron_overrides={"seq_len": 1024}) - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - max_batch_size = ( - config.neuron_config.kv_cache_batch_size - + config.neuron_config.kv_cache_padding_size - ) - full_kv_per_layer = _numel(mgr.k_shape) + _numel(mgr.v_shape) - deltanet_layers = config.layer_types.count("linear_attention") - legacy_total_numel = ( - full_kv_per_layer * config.num_hidden_layers - + _deltanet_state_numel(config, max_batch_size) * deltanet_layers - ) - expected_savings = full_kv_per_layer * deltanet_layers - - self.assertEqual( - legacy_total_numel - _managed_cache_numel(mgr), - expected_savings, - ) - self.assertLess(_managed_cache_numel(mgr), legacy_total_numel) - - def test_rejects_unsupported_hybrid_modes(self): - unsupported_cases = [ - ({"padding_side": "left"}, "left padding"), - ({"flash_decoding_enabled": True}, "flash decoding"), - ] - - for neuron_overrides, expected_error in unsupported_cases: - with self.subTest(expected_error=expected_error): - config = _make_config(neuron_overrides=neuron_overrides) - with self.assertRaisesRegex(ValueError, expected_error): - HybridDeltaNetCacheManager( - config, num_kv_head=config.num_key_value_heads - ) - - config = _make_config() - config.neuron_config.kv_cache_quant = True - with self.assertRaisesRegex(ValueError, "KV cache quantization"): - HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - - config = _make_config( - neuron_overrides={ - "attention_dp_degree": 2, - "batch_size": 2, - "ctx_batch_size": 2, - "tkg_batch_size": 2, - "max_batch_size": 2, - "kv_cache_batch_size": 2, - "is_continuous_batching": True, - } - ) - with self.assertRaisesRegex(ValueError, "attention data parallelism"): - HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - - config = _make_config() - config.neuron_config.kv_cache_tiling = True - with self.assertRaisesRegex(ValueError, "KV cache tiling"): - HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - - def test_legacy_config_default_is_disabled(self): - config = _make_config(use_hybrid_cache_manager=False) - self.assertFalse(config.use_hybrid_cache_manager) - - -if __name__ == "__main__": - unittest.main() diff --git a/contrib/models/Qwen3.5-9B/README.md b/contrib/models/Qwen3.5-9B/README.md index 638d3649..6999fb97 100644 --- a/contrib/models/Qwen3.5-9B/README.md +++ b/contrib/models/Qwen3.5-9B/README.md @@ -163,4 +163,5 @@ Validated results on `trn2.48xlarge`: 1. SDK 2.29+ and NKI 0.3 are expected. 2. DeltaNet weights are replicated across TP ranks in v1. 3. Dummy KV wastes HBM for DeltaNet layers. -4. Hybrid cache, DeltaNet TP sharding, quantization, speculative decoding, and MoE are out of scope for first bring-up. +4. DeltaNet TP sharding, quantization, speculative decoding, and MoE are out of scope for first bring-up. +5. A custom hybrid cache manager to remove dummy KV HBM overhead is planned as a follow-up PR. 
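Until that follow-up lands, DeltaNet state rides alongside the dummy KV as extra traced outputs aliased onto the layer-local buffers. A schematic sketch of the aliasing pattern both modeling files use (illustrative tensors; the `* 0` add exists only to create a data dependency on the aliased buffer):

```python
import torch

buffer = torch.zeros(4, 8, 8)           # stand-in for recurrent_state_buffer
new_state = torch.randn(1, 8, 8)        # freshly computed state, batch size 1

# BS=1 fast path from the modeling code: adding `buffer * 0` is numerically
# a no-op, but it makes the traced output depend on the aliased buffer,
# which input_output_aliases needs in order to update the buffer in place.
aliased = new_state + buffer[:1] * 0
print(torch.equal(aliased, new_state))  # True

# General path: scatter rows back into the buffer by seq_ids.
seq_ids = torch.tensor([2])
idx = seq_ids.view(-1, 1, 1).expand_as(new_state)
updated = torch.scatter(buffer, 0, idx, new_state.to(buffer.dtype))
print(torch.equal(updated[2], new_state[0]))  # True
```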
diff --git a/contrib/models/Qwen3.5-9B/src/modeling_qwen35.py b/contrib/models/Qwen3.5-9B/src/modeling_qwen35.py index b1622d86..56fea6eb 100644 --- a/contrib/models/Qwen3.5-9B/src/modeling_qwen35.py +++ b/contrib/models/Qwen3.5-9B/src/modeling_qwen35.py @@ -84,7 +84,6 @@ NeuronAttentionBase, ) from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding -from neuronx_distributed_inference.modules.kvcache.kv_cache_manager import KVCacheManager from neuronx_distributed_inference.models.layer_boundary_marker import ( ModuleMarkerEndWrapper, ModuleMarkerStartWrapper, @@ -222,33 +221,6 @@ def l2norm(x, dim=-1, eps=1e-6): return F.normalize(x, p=2, dim=dim, eps=eps) -FUSED_DELTANET_DECAY_MIN = -20.0 -FUSED_DELTANET_DECAY_MAX = 0.0 - - -def _bound_fused_deltanet_log_decay( - g, batch_size, num_heads, total_seq_len, chunk_size -): - """Bound cumulative DeltaNet decay before the fused NKI kernel. - - The fused kernel internally computes both exp(cumsum(g)) and exp(-cumsum(g)). - Large negative cumulative decays make the second term overflow even though - the true pairwise decay exp(gc_i - gc_j) is bounded by one. Return - equivalent per-token deltas whose per-chunk cumulative sum is clamped. - """ - num_chunks = total_seq_len // chunk_size - g_chunks = g.reshape(batch_size, num_heads, num_chunks, chunk_size) - g_cumsum = g_chunks.cumsum(dim=-1).clamp( - min=FUSED_DELTANET_DECAY_MIN, - max=FUSED_DELTANET_DECAY_MAX, - ) - g_first = g_cumsum[..., :1] - g_rest = g_cumsum[..., 1:] - g_cumsum[..., :-1] - return torch.cat([g_first, g_rest], dim=-1).reshape( - batch_size, num_heads, total_seq_len - ) - - # ============================================================ # Gated DeltaNet Module (Linear Recurrent Attention) # ============================================================ @@ -287,7 +259,6 @@ def __init__(self, config, layer_idx: int): self.conv_kernel_size = tc.linear_conv_kernel_dim # 4 self.layer_idx = layer_idx self.rms_norm_eps = tc.rms_norm_eps - self.use_hybrid_cache_manager = getattr(tc, "use_hybrid_cache_manager", False) # KV cache dummy shape info self.head_dim = tc.head_dim # 256 @@ -550,7 +521,8 @@ def _fused_chunked_forward( beta = F.pad(beta, (0, pad_size)) g = F.pad(g, (0, pad_size)) total_seq_len = S + pad_size - g = _bound_fused_deltanet_log_decay(g, B, H, total_seq_len, chunk_size) + # Pass raw per-token log-decay. The fused NKI kernel forms decay as + # exp(cumsum(g)_i - cumsum(g)_j), so no pre-kernel clamp is needed. BH = B * H # Flatten to (BH, S, dim) for per-(b,h) kernel calls @@ -749,11 +721,6 @@ def forward( # zeros the decay gate so the recurrent state is preserved unchanged # through padding positions (no spurious decay). 
valid_mask_1d = kwargs.get("deltanet_padding_mask", None) # [B, S, 1] or None - hybrid_cache_active = self.use_hybrid_cache_manager - recurrent_state_cache = None - conv_state_cache = None - if hybrid_cache_active and past_key_value is not None: - recurrent_state_cache, conv_state_cache = past_key_value # Project inputs deltanet_fp32 = os.environ.get("DELTANET_FP32") == "1" @@ -781,9 +748,7 @@ def forward( mixed = mixed.transpose(1, 2) if is_decode: - if conv_state_cache is not None: - conv_state = conv_state_cache[:batch_size] - elif seq_ids is not None: + if seq_ids is not None: conv_state = torch.index_select(self.conv_state_buffer, 0, seq_ids) else: conv_state = self.conv_state_buffer[:batch_size] @@ -800,9 +765,7 @@ def forward( new_conv_state = torch.cat([conv_state[:, :, 1:], mixed], dim=-1) alloc_bs = self.conv_state_buffer.shape[0] - if hybrid_cache_active: - new_conv_state = new_conv_state.to(self.conv_state_buffer.dtype) - elif seq_ids is not None: + if seq_ids is not None: # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement # Add buffer dependency for input_output_alias new_conv_state = ( @@ -838,9 +801,7 @@ def forward( new_conv_state = mixed[:, :, -3:].contiguous() alloc_bs = self.conv_state_buffer.shape[0] - if hybrid_cache_active: - new_conv_state = new_conv_state.to(self.conv_state_buffer.dtype) - elif seq_ids is not None: + if seq_ids is not None: # BS=1 optimization: scatter to index 0 = direct replacement new_conv_state = ( new_conv_state.to(self.conv_state_buffer.dtype) @@ -919,9 +880,7 @@ def forward( if is_decode: # TKG: single-step recurrent update - if recurrent_state_cache is not None: - recurrent_state = recurrent_state_cache[:batch_size].float() - elif seq_ids is not None: + if seq_ids is not None: recurrent_state = torch.index_select( self.recurrent_state_buffer, 0, seq_ids ).float() @@ -933,9 +892,7 @@ def forward( ) new_state_bf16 = new_state.to(self.recurrent_state_buffer.dtype) alloc_bs = self.recurrent_state_buffer.shape[0] - if hybrid_cache_active: - new_rec_state = new_state_bf16 - elif seq_ids is not None: + if seq_ids is not None: # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement # Add buffer dependency for input_output_alias new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0 @@ -987,9 +944,7 @@ def forward( if final_state is not None: final_state_bf16 = final_state.to(self.recurrent_state_buffer.dtype) alloc_bs = self.recurrent_state_buffer.shape[0] - if hybrid_cache_active: - new_rec_state = final_state_bf16 - elif seq_ids is not None: + if seq_ids is not None: # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement # Add buffer dependency for input_output_alias new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0 @@ -1024,9 +979,6 @@ def forward( output = output.reshape(batch_size, seq_len, self.value_dim) output = self.out_proj(output) - if hybrid_cache_active: - return output, (new_rec_state, new_conv_state), new_rec_state, new_conv_state - # Return dummy KV for KVCacheManager dummy_k = torch.zeros( batch_size, @@ -1081,7 +1033,6 @@ def __init__(self, *args, **kwargs): kwargs.setdefault("linear_key_head_dim", 128) kwargs.setdefault("linear_value_head_dim", 128) kwargs.setdefault("linear_conv_kernel_dim", 4) - kwargs.setdefault("use_hybrid_cache_manager", False) super().__init__(*args, **kwargs) @@ -1535,11 +1486,7 @@ def forward( ) hidden_states = residual + attn_out present_key_value = dummy_kv - deltanet_states = ( - None - if 
getattr(self.config, "use_hybrid_cache_manager", False) - else (new_rec_state, new_conv_state) - ) + deltanet_states = (new_rec_state, new_conv_state) else: deltanet_states = None # Standard attention path @@ -1572,240 +1519,6 @@ def forward( return outputs -# ============================================================ -# Hybrid Cache Manager (opt-in) -# ============================================================ - - -class HybridDeltaNetCacheManager(KVCacheManager): - """Layer-type-aware cache manager for Qwen3.5/Qwen3.6 hybrid dense models.""" - - def __init__(self, config: Qwen35InferenceConfig, num_kv_head, **kwargs): - self.layer_types = list(config.layer_types) - self._validate_hybrid_config(config) - super().__init__(config, num_kv_head=num_kv_head, **kwargs) - - dtype = ( - config.neuron_config.attention_dtype - if config.neuron_config.attention_dtype is not None - else config.neuron_config.torch_dtype - ) - cache_dtype = getattr(self, "cache_dtype", dtype) - max_batch_size = ( - config.neuron_config.kv_cache_batch_size - + config.neuron_config.kv_cache_padding_size - ) - recurrent_shape = [ - max_batch_size, - config.linear_num_value_heads, - config.linear_key_head_dim, - config.linear_value_head_dim, - ] - conv_dim = ( - 2 * config.linear_num_key_heads * config.linear_key_head_dim - + config.linear_num_value_heads * config.linear_value_head_dim - ) - conv_shape = [ - max_batch_size, - conv_dim, - config.linear_conv_kernel_dim - 1, - ] - - params = [] - for layer_idx, layer_type in enumerate(self.layer_types): - if layer_type == "linear_attention": - params.append( - nn.Parameter(torch.zeros(recurrent_shape, dtype=dtype), requires_grad=False) - ) - params.append( - nn.Parameter(torch.zeros(conv_shape, dtype=dtype), requires_grad=False) - ) - else: - k_shape = self.k_shapes[layer_idx] if hasattr(self, "k_shapes") else self.k_shape - v_shape = self.v_shapes[layer_idx] if hasattr(self, "v_shapes") else self.v_shape - params.append( - nn.Parameter(torch.zeros(k_shape, dtype=cache_dtype), requires_grad=False) - ) - params.append( - nn.Parameter(torch.zeros(v_shape, dtype=cache_dtype), requires_grad=False) - ) - - self.past_key_values = nn.ParameterList(params) - - @staticmethod - def _validate_hybrid_config(config: Qwen35InferenceConfig): - nc = config.neuron_config - unsupported = [] - if nc.is_block_kv_layout: - unsupported.append("block KV layout") - if getattr(nc, "kv_quant_config", None) is not None or getattr(nc, "kv_cache_quant", False): - unsupported.append("KV cache quantization") - if nc.enable_fused_speculation or nc.speculation_length > 0 or nc.is_medusa: - unsupported.append("speculative decoding") - if getattr(nc, "enable_eagle_speculation", False) or getattr(nc, "is_eagle_draft", False): - unsupported.append("EAGLE speculation") - if nc.flash_decoding_enabled: - unsupported.append("flash decoding") - if nc.attention_dp_degree > 1: - unsupported.append("attention data parallelism") - if nc.kv_cache_tiling: - unsupported.append("KV cache tiling") - if nc.padding_side != "right": - unsupported.append("left padding") - if nc.is_continuous_batching: - unsupported.append("continuous batching") - if unsupported: - raise ValueError( - "HybridDeltaNetCacheManager v1 does not support: " - + ", ".join(unsupported) - ) - - def _is_deltanet_layer(self, idx: int) -> bool: - return self.layer_types[idx] == "linear_attention" - - def get_seq_length(self, past_key_values=None): - for idx, layer_type in enumerate(self.layer_types): - if layer_type != "linear_attention": - if 
past_key_values is None: - _, v_cache = self._fetch_cache(idx) - elif len(past_key_values) == len(self.past_key_values): - v_cache = past_key_values[2 * idx + 1] - else: - v_cache = past_key_values[idx][1] - return v_cache.shape[2] - return 0 - - def get_deltanet_state_by_layer_id(self, idx, kvcache_buffer=None, seq_ids=None): - recurrent_state, conv_state = self._fetch_cache(idx, kvcache_buffer) - if seq_ids is not None: - cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids) - recurrent_state = torch.index_select(recurrent_state, dim=0, index=cache_idx) - conv_state = torch.index_select(conv_state, dim=0, index=cache_idx) - elif self.kv_cache_padding_size > 0: - recurrent_state = recurrent_state[: -self.kv_cache_padding_size] - conv_state = conv_state[: -self.kv_cache_padding_size] - return recurrent_state, conv_state - - def get_cache( - self, - seq_len: int, - skip_slice=False, - kvcache_buffer=None, - seq_ids=None, - windowed_context_encoding_window_idx=-1, - **kwargs, - ): - past_key_values = [] - for idx in range(len(self.past_key_values) // 2): - if self._is_deltanet_layer(idx): - past_key_values.append( - list(self.get_deltanet_state_by_layer_id(idx, kvcache_buffer, seq_ids)) - ) - else: - past_key_values.append( - list( - self.get_kv_by_layer_id( - idx=idx, - skip_slice=skip_slice, - seq_len=seq_len, - kvcache_buffer=kvcache_buffer, - seq_ids=seq_ids, - windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, - **kwargs, - ) - ) - ) - return past_key_values - - def update_cache( - self, - is_for_context_encoding: bool, - seq_ids: torch.Tensor, - position_ids: torch.Tensor, - new_key_values: List[torch.Tensor], - seq_len: int, - scatter_index=None, - kv_active_mask=None, - kvcache_buffer=None, - windowed_context_encoding_window_idx: int = -1, - **kwargs, - ): - updated_cache = [] - for idx, kv_per_layer in enumerate(new_key_values): - if self._is_deltanet_layer(idx): - recurrent_state, conv_state = self.update_deltanet_state_by_layer_id( - idx=idx, - seq_ids=seq_ids, - state_per_layer=kv_per_layer, - kvcache_buffer=kvcache_buffer, - ) - else: - recurrent_state, conv_state = self.update_kv_by_layer_id( - idx=idx, - is_for_context_encoding=is_for_context_encoding, - seq_ids=seq_ids, - position_ids=position_ids, - kv_per_layer=kv_per_layer, - seq_len=seq_len, - scatter_index=scatter_index, - kv_active_mask=kv_active_mask, - kvcache_buffer=kvcache_buffer, - windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, - **kwargs, - ) - updated_cache.append(recurrent_state) - updated_cache.append(conv_state) - return updated_cache - - def update_deltanet_state_by_layer_id( - self, - idx: int, - seq_ids: torch.Tensor, - state_per_layer: Tuple[torch.Tensor, torch.Tensor], - kvcache_buffer=None, - ): - latest_recurrent, latest_conv = state_per_layer - recurrent_cache, conv_cache = self._fetch_cache(idx, kvcache_buffer) - latest_recurrent = latest_recurrent.to(recurrent_cache.dtype) - latest_conv = latest_conv.to(conv_cache.dtype) - - if seq_ids is not None: - cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids) - recurrent_index = cache_idx.view(-1, 1, 1, 1).expand_as(latest_recurrent) - conv_index = cache_idx.view(-1, 1, 1).expand_as(latest_conv) - recurrent_cache = torch.scatter( - input=recurrent_cache, - dim=0, - index=recurrent_index, - src=latest_recurrent, - ) - conv_cache = torch.scatter( - input=conv_cache, - dim=0, - index=conv_index, - src=latest_conv, - ) - return recurrent_cache, conv_cache - - if latest_recurrent.shape[0] == 
recurrent_cache.shape[0]: - return ( - latest_recurrent + recurrent_cache * 0, - latest_conv + conv_cache * 0, - ) - - pad_size = recurrent_cache.shape[0] - latest_recurrent.shape[0] - if pad_size > 0: - latest_recurrent = torch.cat( - [latest_recurrent, recurrent_cache[latest_recurrent.shape[0] :] * 0], - dim=0, - ) - latest_conv = torch.cat( - [latest_conv, conv_cache[latest_conv.shape[0] :] * 0], - dim=0, - ) - return latest_recurrent + recurrent_cache * 0, latest_conv + conv_cache * 0 - - # ============================================================ # Model # ============================================================ @@ -1851,19 +1564,6 @@ def init_model(self, config: Qwen35InferenceConfig): # mRoPE embedding for VL self.mrope_emb = Qwen35MRoPEEmbedding(config) - def init_inference_optimization(self, config: Qwen35InferenceConfig): - super().init_inference_optimization(config) - if getattr(config, "use_hybrid_cache_manager", False): - self.kv_mgr = HybridDeltaNetCacheManager( - config, - num_kv_head=self.num_key_value_heads, - global_rank=self.rank_util, - attention_chunk_size=self.attention_chunk_size, - sliding_window=self.sliding_window, - windowed_context_encoding_size=self.windowed_context_encoding_size, - layer_to_cache_size_mapping=self.layer_to_cache_size_mapping, - ) - @property def _deltanet_state_params(self): """Return DeltaNet state nn.Parameters in alias order.""" @@ -2186,10 +1886,7 @@ def forward( outputs += updated_kv_cache # Append DeltaNet state tensors (for input_output_aliases) - if ( - not getattr(self.config, "use_hybrid_cache_manager", False) - and hasattr(self, "_deltanet_updated_states") - ): + if hasattr(self, "_deltanet_updated_states"): outputs += self._deltanet_updated_states return outputs @@ -2341,10 +2038,7 @@ def get(self, bucket_rank, **kwargs): state_start_idx = num_output_from_trace + num_kv - if ( - not getattr(module.config, "use_hybrid_cache_manager", False) - and hasattr(module, "_deltanet_state_params") - ): + if hasattr(module, "_deltanet_state_params"): for i, param in enumerate(module._deltanet_state_params): input_output_aliases[param] = state_start_idx + i @@ -2579,8 +2273,6 @@ def enable_token_generation(self): def _copy_past_key_values(self, outputs): """Override to also copy DeltaNet state buffers on CPU.""" super()._copy_past_key_values(outputs) - if getattr(self.config, "use_hybrid_cache_manager", False): - return num_output_from_trace = 1 if ( diff --git a/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_fused.py b/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_fused.py index 4d02423d..b13e2e95 100644 --- a/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_fused.py +++ b/contrib/models/Qwen3.5-9B/src/nki_kernels/nki_deltanet_fused.py @@ -15,14 +15,17 @@ 5. Uses tensor_scalar for partition-broadcast (no explicit broadcast loops) 6. nc_transpose (Vector Engine) for all 128x128 transposes instead of nc_matmul(moving=eye) (Tensor Engine) — frees TE for actual math + 7. Forms decay as exp(cumsum(g)_i - cumsum(g)_j), never as split + exp(cumsum(g)_i) * exp(-cumsum(g)_j) NKI 0.3.0 (SDK 2.29). k_dim = v_dim = 128 = P_MAX exactly. Chunk size = 128 = P_MAX (one tile per chunk). 
-Mathematical framework (same as nki_deltanet_chunked.py): - Per-chunk Neumann-series power-doubling for intra-chunk correction: +Mathematical framework: + Per-chunk blocked triangular solve for intra-chunk correction: A = -QK_decay * lower_mask - N = (I+A)(I+A^2)(I+A^4)...(I+A^64) [6 rounds] + N = inv(I - A), computed by 64x64 forward substitution plus one + 64->128 lower-block merge value_corr = N @ v_beta k_cumdecay = N @ (k_beta * exp(gc)) @@ -32,7 +35,7 @@ attn_inter = (q * exp(gc)) @ state attn_intra = (q @ k^T) * decay_mask * lower_mask_diag output = attn_inter + attn_intra @ v_new - state = exp(g_last) * (state + k_raw_decay^T @ v_new) + state = state * exp(g_last) + (k * exp(g_last - gc))^T @ v_new """ import numpy as np @@ -219,9 +222,15 @@ def deltanet_fused_chunked_fwd( src=gc_row[0:1, CHUNK_SIZE - 1 : CHUNK_SIZE], ) - # ---- Compute exp(gc), exp(-gc), exp(g_last) as (P_MAX, 1) scalars ---- - # These (P_MAX, 1) tensors are used with tensor_scalar to broadcast - # across the free dimension without explicit (P_MAX, dim) copies. + # ---- Build stable decay factors from cumulative log-decay ---- + # + # Pairwise decays are computed as exp(gc[i] - gc[j]) under the causal + # mask. The older split form exp(gc[i]) * exp(-gc[j]) is algebraically + # equivalent, but can overflow/underflow before the multiply. + # + # The one-vector exp(gc) and exp(g_last) factors are still required by + # the chunk recurrence; these are non-positive GDN decays and therefore + # bounded above by one. exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) nisa.activation( @@ -232,24 +241,15 @@ def deltanet_fused_chunked_fwd( scale=1.0, ) - neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_scalar( - dst=neg_gc_p, - data=gc_p, - op0=nl.multiply, - operand0=-1.0, - engine=nisa.vector_engine, - ) - exp_neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) - nisa.activation( - dst=exp_neg_gc_p[0:P_MAX, 0:1], - op=nl.exp, - data=neg_gc_p[0:P_MAX, 0:1], - bias=None, - scale=1.0, - ) + # g_last: scalar, then broadcast raw and exp(g_last) to (P_MAX, 1) + gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gl_11[0:1, 0:1], + dst=gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) - # exp(g_last): scalar, then broadcast to (P_MAX, 1) exp_gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) nisa.activation( dst=exp_gl_11, @@ -267,6 +267,84 @@ def deltanet_fused_chunked_fwd( shuffle_mask=_BROADCAST_MASK, ) + # Broadcast gc row-wise so row i, column j can form gc[i] - gc[j]. + gc_row_broadcast = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gc_row[0:1, 0:P_MAX], + dst=gc_row_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + # Strict-lower decay for the KKT solve: exp(gc[i] - gc[j]) where i > j. + # Mask before exp by zeroing non-causal differences, then mask again + # after exp so exp(0) from non-causal positions does not contribute. 
+ gc_col_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=gc_col_strict, + data=Lmask, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gc_row_strict, data1=gc_row_broadcast, data2=Lmask, op=nl.multiply + ) + g_diff_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_strict, + data1=gc_col_strict, + data2=gc_row_strict, + op=nl.subtract, + ) + decay_strict_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=decay_strict_raw, + op=nl.exp, + data=g_diff_strict, + bias=None, + scale=1.0, + ) + decay_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_strict, data1=decay_strict_raw, data2=Lmask, op=nl.multiply + ) + + # Lower-with-diagonal decay for intra-chunk attention: exp(gc[i] - gc[j]) + # where i >= j. + gc_col_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=gc_col_diag, + data=Lmask_d, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gc_row_diag, data1=gc_row_broadcast, data2=Lmask_d, op=nl.multiply + ) + g_diff_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_diag, + data1=gc_col_diag, + data2=gc_row_diag, + op=nl.subtract, + ) + decay_diag_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=decay_diag_raw, + op=nl.exp, + data=g_diff_diag, + bias=None, + scale=1.0, + ) + decay_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_diag, data1=decay_diag_raw, data2=Lmask_d, op=nl.multiply + ) + # ============================================================ # k_beta = K * beta, v_beta = V * beta # tensor_scalar broadcasts beta_p (P_MAX, 1) across free dim @@ -309,49 +387,11 @@ def deltanet_fused_chunked_fwd( QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_copy(dst=QK, src=QK_psum) - # ============================================================ - # Decay mask: QK_decay[i,j] = QK[i,j] * exp(gc[i]) * exp(-gc[j]) - # - # Apply the strict causal mask before the split exp(gc) / exp(-gc) - # scaling. Upper-triangular entries are mathematically unused, but - # scaling them first can create very large finite values that poison - # later matmuls before the mask is applied. - # ============================================================ - QK_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_tensor(dst=QK_masked, data1=QK, data2=Lmask, op=nl.multiply) - - # Row scaling: QK_row[i,:] = QK[i,:] * exp(gc[i]) - # Then transpose, column scale, transpose back. - # Uses tensor_scalar with (P_MAX,1) operand for row scaling. 
- QK_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_scalar( - dst=QK_row, - data=QK_masked, - op0=nl.multiply, - operand0=exp_gc_p, - engine=nisa.vector_engine, - ) - - # Transpose to scale columns (now rows in transposed view) - QK_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) - nisa.nc_transpose(dst=QK_r_T_psum, data=QK_row) - QK_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=QK_r_T, src=QK_r_T_psum) - - QK_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_scalar( - dst=QK_r_T_col, - data=QK_r_T, - op0=nl.multiply, - operand0=exp_neg_gc_p, - engine=nisa.vector_engine, - ) - - # Transpose back - QK_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) - nisa.nc_transpose(dst=QK_d_psum, data=QK_r_T_col) + # QK_decay[i,j] = QK[i,j] * exp(gc[i] - gc[j]) for i > j. + # This is the same causal decay as the split-exp form, but numerically + # bounded by construction. QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=QK_decay, src=QK_d_psum) + nisa.tensor_tensor(dst=QK_decay, data1=QK, data2=decay_strict, op=nl.multiply) # A = -QK_decay * lower_mask neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) @@ -366,38 +406,184 @@ def deltanet_fused_chunked_fwd( nisa.tensor_tensor(dst=A_mat, data1=neg_QK_decay, data2=Lmask, op=nl.multiply) # ============================================================ - # Neumann power-doubling: N = (I+A)(I+A^2)...(I+A^{64}) - # 6 rounds → resolves rank up to 2^6 = 64 (sufficient for chunk=128) + # Stable triangular solve: N = inv(I - A_mat) + # + # A_mat is strictly lower triangular. Solve two 64x64 diagonal + # blocks row-by-row: + # N[i, :] = e_i + sum_{j<i} A[i, j] * N[j, :] + # then merge them with one 64->128 lower off-diagonal block matmul. + # decay_diag already carries the causal decay exp(gc[i] - gc[j]) for i >= j. attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_tensor( - dst=attn_intra, data1=qk_decay, data2=Lmask_d, op=nl.multiply - ) + nisa.tensor_tensor(dst=attn_intra, data1=qk_raw, data2=decay_diag, op=nl.multiply) # ============================================================ # v_prime = k_cumdecay @ state (state is in SBUF!) @@ -550,17 +700,8 @@ def deltanet_fused_chunked_fwd( # state is updated IN-PLACE in SBUF — no HBM round-trip! # ============================================================ - # k_raw_decay contributes as exp(g_last) * (k * exp(-gc))^T @ v_new. - # Compute the equivalent form with one bounded exponential, - # k * exp(g_last - gc), so the factor is always <= 1 for valid - # causal positions. - gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) - for i_shuf in nl.static_range(P_MAX // 32): - nisa.nc_stream_shuffle( - src=gl_11[0:1, 0:1], - dst=gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], - shuffle_mask=_BROADCAST_MASK, - ) + # k_raw_decay contributes as k * exp(g_last - gc), with one bounded + # exponential instead of exp(g_last) * exp(-gc).
gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_tensor( dst=gl_minus_gc_p, diff --git a/contrib/models/Qwen3.5-9B/test/integration/test_model.py b/contrib/models/Qwen3.5-9B/test/integration/test_model.py index 725bca2e..ef77ea18 100644 --- a/contrib/models/Qwen3.5-9B/test/integration/test_model.py +++ b/contrib/models/Qwen3.5-9B/test/integration/test_model.py @@ -36,10 +36,7 @@ """ import gc -import json import os -import shutil -import subprocess import sys import time @@ -59,8 +56,6 @@ SEQ_LEN = int(os.environ.get("QWEN35_SEQ_LEN", "128")) TTFT_THRESHOLD_MS = float(os.environ.get("TTFT_THRESHOLD_MS", "5000")) THROUGHPUT_THRESHOLD = float(os.environ.get("THROUGHPUT_THRESHOLD", "5.0")) -USE_HYBRID_CACHE = os.environ.get("QWEN35_USE_HYBRID_CACHE", "0") == "1" -RECORD_HBM = os.environ.get("QWEN35_RECORD_HBM", "0") == "1" requires_model_path = pytest.mark.skipif( not MODEL_PATH, @@ -69,13 +64,6 @@ "weights. Set QWEN35_MODEL_PATH=/path/to/Qwen3.5-9B to run these tests." ), ) -requires_hbm_recording = pytest.mark.skipif( - not RECORD_HBM, - reason=( - "QWEN35_RECORD_HBM=1 not set. This optional test records Neuron HBM " - "usage for dummy-KV vs hybrid-cache comparisons." - ), -) # ── Fixtures ──────────────────────────────────────────────────────────── @@ -131,7 +119,6 @@ def compiled_model(model_path): inf_config = Qwen35InferenceConfig( neuron_config=neuron_config, - use_hybrid_cache_manager=USE_HYBRID_CACHE, **config_dict, ) @@ -197,6 +184,16 @@ def _generate(model, tokenizer, generation_config, prompt, max_new_tokens=20): return outputs[0].tolist(), tokenizer.decode(outputs[0], skip_special_tokens=True) +def _make_repeated_stress_prompt(tokenizer, target_tokens=133): + """Build a repeated prompt near the target token length.""" + seed = ( + "Repeat this stability phrase for DeltaNet recurrent decoding. " + "Repeat this stability phrase for DeltaNet recurrent decoding. 
" + ) + ids = tokenizer.encode(seed * 32, add_special_tokens=False)[:target_tokens] + return tokenizer.decode(ids, skip_special_tokens=True) + + def _is_repetitive(text, max_repeat=5): """Check for excessive word repetition.""" words = text.split() @@ -208,73 +205,6 @@ def _is_repetitive(text, max_repeat=5): return False -def _parse_peak_neuron_memory(stdout): - peak_device = 0 - peak_tensors = 0 - samples = 0 - for line in stdout.splitlines(): - line = line.strip() - if not line: - continue - try: - report = json.loads(line) - except json.JSONDecodeError: - continue - for runtime in report.get("neuron_runtime_data", []): - memory_used = runtime.get("report", {}).get("memory_used", {}) - used = memory_used.get("neuron_runtime_used_bytes", {}) - peak_device = max(peak_device, int(used.get("neuron_device", 0) or 0)) - nc_usage = ( - used.get("usage_breakdown", {}).get("neuroncore_memory_usage", {}) - ) - tensor_bytes = sum( - int(core.get("tensors", 0) or 0) for core in nc_usage.values() - ) - peak_tensors = max(peak_tensors, tensor_bytes) - samples += 1 - return peak_device, peak_tensors, samples - - -def _capture_neuron_hbm(tmp_path, fn): - if shutil.which("neuron-monitor") is None: - pytest.skip("neuron-monitor is not available") - - monitor_config = { - "period": "0.5s", - "neuron_runtimes": [ - { - "tag_filter": ".*", - "metrics": [{"type": "memory_used", "period": "0.5s"}], - } - ], - } - config_path = tmp_path / "neuron-monitor.json" - config_path.write_text(json.dumps(monitor_config)) - - proc = subprocess.Popen( - ["neuron-monitor", "--config-file", str(config_path)], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - try: - time.sleep(1.0) - result = fn() - time.sleep(1.0) - finally: - proc.terminate() - try: - stdout, stderr = proc.communicate(timeout=5) - except subprocess.TimeoutExpired: - proc.kill() - stdout, stderr = proc.communicate(timeout=5) - - peak_device, peak_tensors, samples = _parse_peak_neuron_memory(stdout) - assert samples > 0, f"neuron-monitor produced no runtime samples: {stderr}" - assert peak_device > 0, "Expected non-zero Neuron device HBM usage" - return result, peak_device, peak_tensors, samples - - # ── Smoke Tests ───────────────────────────────────────────────────────── @@ -365,6 +295,31 @@ def test_olympics_prompt_no_invalid_tokens( assert not invalid, f"Generated invalid token ids: {invalid}" +@requires_model_path +def test_repeated_stress_prompt_no_invalid_tokens( + compiled_model, tokenizer, generation_config +): + """Regression test for repeated 129/133-token prompts that exposed NaNs.""" + prompt = _make_repeated_stress_prompt(tokenizer, target_tokens=133) + prompt_len = len(tokenizer.encode(prompt)) + assert 129 <= prompt_len <= 133, f"Unexpected stress prompt length: {prompt_len}" + + tokens, text = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=16, + ) + generated = tokens[prompt_len:] + invalid = [token for token in generated if token < 0 or token >= len(tokenizer)] + + assert len(generated) >= 5, f"Expected >= 5 generated tokens, got {generated}" + assert not invalid, f"Generated invalid token ids: {invalid}" + assert len(text) > len(prompt) + print(f" Repeated stress prompt length: {prompt_len} tokens") + + @requires_model_path def test_simple_factual_generation(compiled_model, tokenizer, generation_config): """A simple factual prompt should produce the expected entity.""" @@ -443,33 +398,6 @@ def test_performance_throughput(compiled_model, tokenizer, generation_config): ) 
-@requires_model_path -@requires_hbm_recording -def test_hybrid_cache_hbm_snapshot(compiled_model, tokenizer, generation_config, tmp_path): - """Record peak Neuron HBM for dummy-KV vs hybrid-cache comparison runs.""" - prompt = "Give me a summary of the 2020 Olympics in 100 tokens." - max_new_tokens = int(os.environ.get("QWEN35_HBM_NEW_TOKENS", "32")) - - (_, text), peak_device, peak_tensors, samples = _capture_neuron_hbm( - tmp_path, - lambda: _generate( - compiled_model, - tokenizer, - generation_config, - prompt, - max_new_tokens=max_new_tokens, - ), - ) - - mode = "hybrid" if USE_HYBRID_CACHE else "dummy_kv" - print( - " HBM " - f"mode={mode} peak_device_bytes={peak_device} " - f"peak_tensor_bytes={peak_tensors} samples={samples}" - ) - assert len(text) > len(prompt) - - # ── Multi-Prompt Quality Test ────────────────────────────────────────── diff --git a/contrib/models/Qwen3.5-9B/test/unit/test_deltanet_decay.py b/contrib/models/Qwen3.5-9B/test/unit/test_deltanet_decay.py index 416a431a..f3d7d8bc 100644 --- a/contrib/models/Qwen3.5-9B/test/unit/test_deltanet_decay.py +++ b/contrib/models/Qwen3.5-9B/test/unit/test_deltanet_decay.py @@ -1,67 +1,34 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -"""Unit tests for fused DeltaNet log-decay bounding.""" +"""Unit tests for fused DeltaNet log-decay stability structure.""" -import os -import sys +import pathlib import unittest -import torch -_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) -if _CONTRIB_ROOT not in sys.path: - sys.path.insert(0, _CONTRIB_ROOT) +_CONTRIB_ROOT = pathlib.Path(__file__).resolve().parents[2] +_SRC_ROOT = _CONTRIB_ROOT / "src" -from src.modeling_qwen35 import ( - FUSED_DELTANET_DECAY_MAX, - FUSED_DELTANET_DECAY_MIN, - _bound_fused_deltanet_log_decay, -) +class TestFusedDeltaNetDecayStability(unittest.TestCase): + def test_fused_kernel_uses_exp_of_differences(self): + kernel_source = ( + _SRC_ROOT / "nki_kernels" / "nki_deltanet_fused.py" + ).read_text() -def _chunked_cumsum(g, batch_size, num_heads, total_seq_len, chunk_size): - num_chunks = total_seq_len // chunk_size - return g.reshape(batch_size, num_heads, num_chunks, chunk_size).cumsum(dim=-1) + self.assertIn("decay_strict", kernel_source) + self.assertIn("decay_diag", kernel_source) + self.assertIn("gl_minus_gc_p", kernel_source) + self.assertNotIn("exp_neg_gc_p", kernel_source) + self.assertNotIn("operand0=exp_neg_gc_p", kernel_source) + def test_modeling_does_not_clamp_fused_decay_inputs(self): + modeling_source = (_SRC_ROOT / "modeling_qwen35.py").read_text() -class TestFusedDeltaNetDecayBounding(unittest.TestCase): - def test_preserves_non_extreme_decay(self): - batch_size, num_heads, total_seq_len, chunk_size = 2, 3, 16, 8 - g = torch.full( - (batch_size, num_heads, total_seq_len), - -0.125, - dtype=torch.float32, - ) - - bounded = _bound_fused_deltanet_log_decay( - g, batch_size, num_heads, total_seq_len, chunk_size - ) - - torch.testing.assert_close(bounded, g) - - def test_bounds_per_chunk_cumulative_decay(self): - batch_size, num_heads, total_seq_len, chunk_size = 2, 3, 16, 8 - g = torch.full( - (batch_size, num_heads, total_seq_len), - -10.0, - dtype=torch.float32, - ) - - bounded = _bound_fused_deltanet_log_decay( - g, batch_size, num_heads, total_seq_len, chunk_size - ) - bounded_cumsum = _chunked_cumsum( - bounded, batch_size, num_heads, total_seq_len, chunk_size - ) - expected_cumsum = _chunked_cumsum( - g, batch_size, num_heads, total_seq_len, 
chunk_size - ).clamp(min=FUSED_DELTANET_DECAY_MIN, max=FUSED_DELTANET_DECAY_MAX) - - torch.testing.assert_close(bounded_cumsum, expected_cumsum) - self.assertGreaterEqual(float(bounded_cumsum.min()), FUSED_DELTANET_DECAY_MIN) - self.assertLessEqual(float(bounded_cumsum.max()), FUSED_DELTANET_DECAY_MAX) - self.assertTrue(torch.isfinite(bounded).all()) + self.assertNotIn("_bound_fused_deltanet_log_decay", modeling_source) + self.assertNotIn("FUSED_DELTANET_DECAY_MIN", modeling_source) + self.assertIn("exp(cumsum(g)_i - cumsum(g)_j)", modeling_source) if __name__ == "__main__": diff --git a/contrib/models/Qwen3.5-9B/test/unit/test_hybrid_cache_manager.py b/contrib/models/Qwen3.5-9B/test/unit/test_hybrid_cache_manager.py deleted file mode 100644 index 503dc24b..00000000 --- a/contrib/models/Qwen3.5-9B/test/unit/test_hybrid_cache_manager.py +++ /dev/null @@ -1,341 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -import os -import sys -import unittest -from math import prod -from unittest.mock import patch - -import torch - -_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) -if _CONTRIB_ROOT not in sys.path: - sys.path.insert(0, _CONTRIB_ROOT) - -from neuronx_distributed_inference.models.config import NeuronConfig -from src.modeling_qwen35 import HybridDeltaNetCacheManager, Qwen35InferenceConfig - - -def _make_config(**overrides): - neuron_overrides = overrides.pop("neuron_overrides", {}) - neuron_kwargs = dict( - tp_degree=overrides.pop("tp_degree", 4), - batch_size=1, - max_batch_size=2, - kv_cache_batch_size=2, - seq_len=16, - torch_dtype=torch.bfloat16, - ) - neuron_kwargs.update(neuron_overrides) - neuron_config = NeuronConfig(**neuron_kwargs) - defaults = dict( - hidden_size=4096, - num_hidden_layers=32, - num_attention_heads=16, - num_key_value_heads=4, - head_dim=256, - intermediate_size=12288, - vocab_size=248320, - rms_norm_eps=1e-6, - max_position_embeddings=262144, - rope_theta=10000000, - hidden_act="silu", - tie_word_embeddings=False, - linear_num_value_heads=32, - linear_num_key_heads=16, - linear_key_head_dim=128, - linear_value_head_dim=128, - linear_conv_kernel_dim=4, - use_hybrid_cache_manager=True, - ) - defaults.update(overrides) - return Qwen35InferenceConfig(neuron_config=neuron_config, **defaults) - - -def _numel(shape): - return prod(int(dim) for dim in shape) - - -def _managed_cache_numel(mgr): - return sum(param.numel() for param in mgr.past_key_values) - - -def _deltanet_state_numel(config, max_batch_size): - recurrent = ( - max_batch_size - * config.linear_num_value_heads - * config.linear_key_head_dim - * config.linear_value_head_dim - ) - conv_dim = ( - 2 * config.linear_num_key_heads * config.linear_key_head_dim - + config.linear_num_value_heads * config.linear_value_head_dim - ) - conv = max_batch_size * conv_dim * (config.linear_conv_kernel_dim - 1) - return recurrent + conv - - -class TestHybridDeltaNetCacheManager(unittest.TestCase): - def test_allocates_per_layer_cache_shapes(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - - self.assertEqual(len(mgr.past_key_values), config.num_hidden_layers * 2) - self.assertEqual( - list(mgr.past_key_values[0].shape), - [2, 32, 128, 128], - ) - self.assertEqual( - list(mgr.past_key_values[1].shape), - [2, 8192, 3], - ) - self.assertEqual(mgr.layer_types[3], "full_attention") - self.assertEqual(mgr.past_key_values[6].dim(), 4) - 
self.assertEqual(mgr.past_key_values[7].shape[2], 16) - - def test_get_cache_slices_only_full_attention_layers(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - - cache = mgr.get_cache(seq_len=4, seq_ids=torch.tensor([1])) - recurrent_state, conv_state = cache[0] - full_k, full_v = cache[3] - - self.assertEqual(list(recurrent_state.shape), [1, 32, 128, 128]) - self.assertEqual(list(conv_state.shape), [1, 8192, 3]) - self.assertEqual(full_k.shape[0], 2) - self.assertEqual(full_v.shape[0], 2) - self.assertEqual(full_k.shape[2], 4) - self.assertEqual(full_v.shape[2], 4) - - def test_get_seq_length_uses_first_full_attention_layer(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - - nested_cache = mgr.get_cache(seq_len=5, seq_ids=torch.tensor([0])) - flat_cache = [tensor for layer_cache in nested_cache for tensor in layer_cache] - - self.assertEqual(nested_cache[0][1].shape[2], 3) - self.assertEqual(mgr.get_seq_length(nested_cache), 5) - self.assertEqual(mgr.get_seq_length(flat_cache), 5) - - def test_get_cache_selects_deltanet_state_rows_by_seq_ids(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - with torch.no_grad(): - mgr.past_key_values[0][0, ...].fill_(7) - mgr.past_key_values[0][1, ...].fill_(13) - mgr.past_key_values[1][0, ...].fill_(17) - mgr.past_key_values[1][1, ...].fill_(19) - - recurrent_state, conv_state = mgr.get_cache( - seq_len=4, - seq_ids=torch.tensor([1, 0]), - )[0] - - self.assertTrue(torch.all(recurrent_state[0] == 13)) - self.assertTrue(torch.all(recurrent_state[1] == 7)) - self.assertTrue(torch.all(conv_state[0] == 19)) - self.assertTrue(torch.all(conv_state[1] == 17)) - - def test_deltanet_update_scatters_by_seq_id(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - recurrent = torch.ones((1, 32, 128, 128), dtype=torch.bfloat16) - conv = torch.ones((1, 8192, 3), dtype=torch.bfloat16) - - updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( - idx=0, - seq_ids=torch.tensor([1]), - state_per_layer=(recurrent, conv), - ) - - self.assertTrue(torch.all(updated_recurrent[0] == 0)) - self.assertTrue(torch.all(updated_conv[0] == 0)) - self.assertTrue(torch.all(updated_recurrent[1] == 1)) - self.assertTrue(torch.all(updated_conv[1] == 1)) - - def test_deltanet_full_batch_update_replaces_state_cache(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - recurrent = torch.ones((2, 32, 128, 128), dtype=torch.bfloat16) - conv = torch.ones((2, 8192, 3), dtype=torch.bfloat16) - recurrent[0].fill_(3) - recurrent[1].fill_(5) - conv[0].fill_(11) - conv[1].fill_(13) - - updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( - idx=0, - seq_ids=None, - state_per_layer=(recurrent, conv), - ) - - self.assertTrue(torch.all(updated_recurrent[0] == 3)) - self.assertTrue(torch.all(updated_recurrent[1] == 5)) - self.assertTrue(torch.all(updated_conv[0] == 11)) - self.assertTrue(torch.all(updated_conv[1] == 13)) - - def test_deltanet_full_batch_update_scatters_non_identity_seq_ids(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - recurrent = torch.ones((2, 32, 128, 128), dtype=torch.bfloat16) - conv = torch.ones((2, 8192, 3), dtype=torch.bfloat16) - 
recurrent[0].fill_(3) - recurrent[1].fill_(5) - conv[0].fill_(11) - conv[1].fill_(13) - - updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( - idx=0, - seq_ids=torch.tensor([1, 0]), - state_per_layer=(recurrent, conv), - ) - - self.assertTrue(torch.all(updated_recurrent[0] == 5)) - self.assertTrue(torch.all(updated_recurrent[1] == 3)) - self.assertTrue(torch.all(updated_conv[0] == 13)) - self.assertTrue(torch.all(updated_conv[1] == 11)) - - def test_deltanet_update_maps_out_of_range_seq_id_to_padding_row(self): - config = _make_config(neuron_overrides={"kv_cache_padding_size": 1}) - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - recurrent = torch.ones((1, 32, 128, 128), dtype=torch.bfloat16) - conv = torch.ones((1, 8192, 3), dtype=torch.bfloat16) - - updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( - idx=0, - seq_ids=torch.tensor([99]), - state_per_layer=(recurrent, conv), - ) - - self.assertTrue(torch.all(updated_recurrent[0] == 0)) - self.assertTrue(torch.all(updated_recurrent[1] == 0)) - self.assertTrue(torch.all(updated_recurrent[2] == 1)) - self.assertTrue(torch.all(updated_conv[2] == 1)) - - def test_deltanet_state_shapes_do_not_scale_with_sequence_length(self): - short_config = _make_config(neuron_overrides={"seq_len": 128}) - long_config = _make_config(neuron_overrides={"seq_len": 2048}) - short_mgr = HybridDeltaNetCacheManager( - short_config, num_kv_head=short_config.num_key_value_heads - ) - long_mgr = HybridDeltaNetCacheManager( - long_config, num_kv_head=long_config.num_key_value_heads - ) - - self.assertEqual(short_mgr.past_key_values[0].shape, long_mgr.past_key_values[0].shape) - self.assertEqual(short_mgr.past_key_values[1].shape, long_mgr.past_key_values[1].shape) - self.assertLess(short_mgr.past_key_values[7].shape[2], long_mgr.past_key_values[7].shape[2]) - - def test_get_cache_trims_padding_row_without_seq_ids(self): - config = _make_config(neuron_overrides={"kv_cache_padding_size": 1}) - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - - recurrent_state, conv_state = mgr.get_cache(seq_len=4)[0] - - self.assertEqual(list(recurrent_state.shape), [2, 32, 128, 128]) - self.assertEqual(list(conv_state.shape), [2, 8192, 3]) - - def test_update_cache_dispatches_deltanet_and_full_attention_layers(self): - config = _make_config() - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - new_key_values = [] - for idx in range(4): - first = mgr.past_key_values[2 * idx] - second = mgr.past_key_values[2 * idx + 1] - new_key_values.append( - ( - torch.full_like(first, fill_value=idx + 1), - torch.full_like(second, fill_value=idx + 11), - ) - ) - - position_ids = torch.arange(16, dtype=torch.long).unsqueeze(0).expand(2, -1) - full_k_update = torch.full_like(mgr.past_key_values[6], fill_value=4) - full_v_update = torch.full_like(mgr.past_key_values[7], fill_value=14) - with patch.object( - mgr, "update_kv_by_layer_id", return_value=(full_k_update, full_v_update) - ) as update_kv: - updated = mgr.update_cache( - is_for_context_encoding=True, - seq_ids=torch.tensor([0, 1], dtype=torch.int32), - position_ids=position_ids, - new_key_values=new_key_values, - seq_len=16, - ) - - self.assertEqual(update_kv.call_count, 1) - self.assertEqual(update_kv.call_args.kwargs["idx"], 3) - self.assertTrue(torch.all(updated[0] == 1)) - self.assertTrue(torch.all(updated[1] == 11)) - self.assertTrue(torch.all(updated[6] == 4)) - self.assertTrue(torch.all(updated[7] 
== 14)) - - def test_managed_cache_removes_dummy_kv_for_deltanet_layers(self): - config = _make_config(neuron_overrides={"seq_len": 1024}) - mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - max_batch_size = ( - config.neuron_config.kv_cache_batch_size - + config.neuron_config.kv_cache_padding_size - ) - full_kv_per_layer = _numel(mgr.k_shape) + _numel(mgr.v_shape) - deltanet_layers = config.layer_types.count("linear_attention") - legacy_total_numel = ( - full_kv_per_layer * config.num_hidden_layers - + _deltanet_state_numel(config, max_batch_size) * deltanet_layers - ) - expected_savings = full_kv_per_layer * deltanet_layers - - self.assertEqual( - legacy_total_numel - _managed_cache_numel(mgr), - expected_savings, - ) - self.assertLess(_managed_cache_numel(mgr), legacy_total_numel) - - def test_rejects_unsupported_hybrid_modes(self): - unsupported_cases = [ - ({"padding_side": "left"}, "left padding"), - ({"flash_decoding_enabled": True}, "flash decoding"), - ] - - for neuron_overrides, expected_error in unsupported_cases: - with self.subTest(expected_error=expected_error): - config = _make_config(neuron_overrides=neuron_overrides) - with self.assertRaisesRegex(ValueError, expected_error): - HybridDeltaNetCacheManager( - config, num_kv_head=config.num_key_value_heads - ) - - config = _make_config() - config.neuron_config.kv_cache_quant = True - with self.assertRaisesRegex(ValueError, "KV cache quantization"): - HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - - config = _make_config( - neuron_overrides={ - "attention_dp_degree": 2, - "batch_size": 2, - "ctx_batch_size": 2, - "tkg_batch_size": 2, - "max_batch_size": 2, - "kv_cache_batch_size": 2, - "is_continuous_batching": True, - } - ) - with self.assertRaisesRegex(ValueError, "attention data parallelism"): - HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - - config = _make_config() - config.neuron_config.kv_cache_tiling = True - with self.assertRaisesRegex(ValueError, "KV cache tiling"): - HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) - - def test_legacy_config_default_is_disabled(self): - config = _make_config(use_hybrid_cache_manager=False) - self.assertFalse(config.use_hybrid_cache_manager) - - -if __name__ == "__main__": - unittest.main()
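With the hybrid cache manager and its tests removed, every configuration above converges on a single state-passing contract: `forward` appends `_deltanet_updated_states` after the updated KV cache outputs, and `get(...)` aliases each entry of `_deltanet_state_params` to the output slot starting at `state_start_idx = num_output_from_trace + num_kv`. A condensed sketch of that alias bookkeeping follows; `build_input_output_aliases` and its argument names are illustrative, not an NxDI API.

```python
from typing import Dict, Sequence

import torch

def build_input_output_aliases(
    num_output_from_trace: int,
    kv_cache_params: Sequence[torch.nn.Parameter],
    deltanet_state_params: Sequence[torch.nn.Parameter],
) -> Dict[torch.nn.Parameter, int]:
    """Trace outputs are laid out as [logits, *updated_kv, *deltanet_states].
    Aliasing each buffer to its output slot lets the runtime write KV cache
    and DeltaNet state back in place instead of returning fresh copies."""
    aliases: Dict[torch.nn.Parameter, int] = {}
    for i, param in enumerate(kv_cache_params):
        aliases[param] = num_output_from_trace + i
    # DeltaNet recurrent/conv state slots begin right after the KV tensors,
    # mirroring state_start_idx = num_output_from_trace + num_kv above.
    state_start_idx = num_output_from_trace + len(kv_cache_params)
    for i, param in enumerate(deltanet_state_params):
        aliases[param] = state_start_idx + i
    return aliases
```

On device this aliasing makes the state write-back in-place; on CPU there is no such aliasing, which is why the `_copy_past_key_values` override that mirrors the DeltaNet state buffers survives this cleanup.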