diff --git a/examples/models/qwen3_5_35B_A3B/config.py b/examples/models/qwen3_5_35B_A3B/config.py new file mode 100644 index 0000000..9dd3ddf --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/config.py @@ -0,0 +1,79 @@ +from dataclasses import dataclass, field +from typing import List + +import numpy as np +import torch.distributed as dist +from neuronxcc.nki.language import bfloat16 +from transformers import AutoConfig + +DTYPE = bfloat16 + +# Layer types for Qwen3.5 hybrid architecture +FULL_ATTENTION = "full_attention" +LINEAR_ATTENTION = "linear_attention" + + +@dataclass +class Config: + hidden_size: int + num_heads: int # full attention Q heads + head_dim: int # full attention head dim + num_kv_heads: int # full attention KV heads + num_layers: int + num_experts_per_tok: int + num_experts: int + intermediate_size: int # moe expert intermediate (per device) + shared_expert_intermediate_size: int # shared expert intermediate (per device) + vocab_size: int + # Linear attention params + linear_num_key_heads: int + linear_num_value_heads: int + linear_key_head_dim: int + linear_value_head_dim: int + linear_conv_kernel_dim: int + # Layer type info + layer_types: List[str] = field(default_factory=list) + # RoPE + partial_rotary_factor: float = 0.25 + rope_theta: float = 10000000.0 + # Sequence + context_len: int = None + max_new_tokens: int = None + max_batch_size: int = 1 + max_seq_len: int = 4096 + # Norm + norm_eps: float = 1e-6 + dtype: np.dtype = DTYPE + additional_compiler_args_nkipy: str = "--lnc 1" + + +def get_config(model_name, context_len, max_new_tokens): + hf_config = AutoConfig.from_pretrained(model_name) + # Qwen3.5 is multimodal; text config is nested + text_cfg = hf_config.text_config if hasattr(hf_config, "text_config") else hf_config + + ws = dist.get_world_size() + config = Config( + hidden_size=text_cfg.hidden_size, + num_heads=text_cfg.num_attention_heads, + head_dim=text_cfg.head_dim, + num_kv_heads=text_cfg.num_key_value_heads, + num_layers=text_cfg.num_hidden_layers, + num_experts_per_tok=text_cfg.num_experts_per_tok, + num_experts=text_cfg.num_experts, + intermediate_size=text_cfg.moe_intermediate_size // ws, + shared_expert_intermediate_size=text_cfg.shared_expert_intermediate_size // ws, + vocab_size=text_cfg.vocab_size, + linear_num_key_heads=text_cfg.linear_num_key_heads, + linear_num_value_heads=text_cfg.linear_num_value_heads, + linear_key_head_dim=text_cfg.linear_key_head_dim, + linear_value_head_dim=text_cfg.linear_value_head_dim, + linear_conv_kernel_dim=text_cfg.linear_conv_kernel_dim, + layer_types=list(text_cfg.layer_types), + partial_rotary_factor=text_cfg.partial_rotary_factor, + rope_theta=text_cfg.rope_parameters.get("rope_theta", 10000000.0), + norm_eps=text_cfg.rms_norm_eps, + context_len=context_len, + max_new_tokens=max_new_tokens, + ) + return config diff --git a/examples/models/qwen3_5_35B_A3B/evaluate.py b/examples/models/qwen3_5_35B_A3B/evaluate.py new file mode 100644 index 0000000..980693c --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/evaluate.py @@ -0,0 +1,163 @@ +"""Benchmarking and accuracy validation for NKIPy Qwen3.5-35B-A3B.""" + +import json +import os +import time + +import torch + + +# --------------------------------------------------------------------------- +# Benchmarking +# --------------------------------------------------------------------------- + + +def _percentile(data, pct): + if not data: + return 0.0 + sorted_data = sorted(data) + k = (len(sorted_data) - 1) * (pct / 100.0) + f = int(k) + c = f + 1 + if c >= 
len(sorted_data): + return sorted_data[f] + return sorted_data[f] + (k - f) * (sorted_data[c] - sorted_data[f]) + + +def _run_once(model, input_ids): + token_times = [] + start = time.perf_counter() + for i, _token_id in enumerate(model.generate(input_ids)): + now = time.perf_counter() + if i == 0: + ttft_ms = (now - start) * 1000.0 + prev = now + else: + token_times.append((now - prev) * 1000.0) + prev = now + + end = time.perf_counter() + return { + "ttft_ms": ttft_ms, + "decode_latencies_ms": token_times, + "num_tokens": len(token_times) + 1, + "total_time_ms": (end - start) * 1000.0, + } + + +def benchmark_generation(model, input_ids, num_warmup=2, num_runs=5): + total_runs = num_warmup + num_runs + run_reports = [] + + for run_idx in range(total_runs): + is_warmup = run_idx < num_warmup + label = ( + f"warmup {run_idx + 1}/{num_warmup}" + if is_warmup + else f"run {run_idx - num_warmup + 1}/{num_runs}" + ) + print(f"[benchmark] {label}...") + + report = _run_once(model, input_ids) + + if not is_warmup: + run_reports.append(report) + throughput = ( + report["num_tokens"] / (report["total_time_ms"] / 1000.0) + if report["total_time_ms"] > 0 + else 0.0 + ) + print( + f" TTFT={report['ttft_ms']:.1f}ms " + f"tokens={report['num_tokens']} " + f"throughput={throughput:.1f} tok/s" + ) + + n = len(run_reports) + if n == 0: + return {} + + avg_ttft = sum(r["ttft_ms"] for r in run_reports) / n + avg_total = sum(r["total_time_ms"] for r in run_reports) / n + num_tokens = run_reports[0]["num_tokens"] + + all_decode = [] + for r in run_reports: + all_decode.extend(r["decode_latencies_ms"]) + + throughput = num_tokens / (avg_total / 1000.0) if avg_total > 0 else 0.0 + + result = { + "ttft_ms": round(avg_ttft, 3), + "decode_latency_p50_ms": round(_percentile(all_decode, 50), 3), + "decode_latency_p90_ms": round(_percentile(all_decode, 90), 3), + "decode_latency_p99_ms": round(_percentile(all_decode, 99), 3), + "num_tokens": num_tokens, + "total_time_ms": round(avg_total, 3), + "throughput_tokens_per_sec": round(throughput, 2), + } + + print("\n=== Benchmark Results ===") + print(f" TTFT: {result['ttft_ms']:.1f} ms") + print(f" Decode latency (p50): {result['decode_latency_p50_ms']:.1f} ms") + print(f" Decode latency (p90): {result['decode_latency_p90_ms']:.1f} ms") + print(f" Decode latency (p99): {result['decode_latency_p99_ms']:.1f} ms") + print(f" Tokens generated: {result['num_tokens']}") + print( + f" Throughput: {result['throughput_tokens_per_sec']:.1f} tokens/sec" + ) + print("=========================\n") + + return result + + +def save_benchmark_report(result, path="benchmark_report.json"): + with open(path, "w") as f: + json.dump(result, f, indent=2) + print(f"[benchmark] Report saved to {path}") + + +# --------------------------------------------------------------------------- +# CLI driver +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + import argparse + import sys + + parser = argparse.ArgumentParser( + description="Evaluate NKIPy Qwen3.5-35B-A3B: benchmark" + ) + + parser.add_argument("-n", "--max-new-tokens", type=int, default=16) + parser.add_argument("prompt", nargs="?", default="The capital of France is") + parser.add_argument("--checkpoint", default="./qwen3_5_shards") + parser.add_argument("--model", default="Qwen/Qwen3.5-35B-A3B") + + mode = parser.add_mutually_exclusive_group(required=True) + mode.add_argument("--benchmark", action="store_true") + + parser.add_argument("--benchmark-warmup", type=int, default=2) + 
parser.add_argument("--benchmark-runs", type=int, default=5) + parser.add_argument( + "--benchmark-output", type=str, default="benchmark_report.json" + ) + + args = parser.parse_args() + + import torch.distributed as dist + + from qwen3_5 import load_model + + model, input_ids, _ = load_model(args) + + if args.benchmark: + dist.barrier() + result = benchmark_generation( + model, + input_ids, + num_warmup=args.benchmark_warmup, + num_runs=args.benchmark_runs, + ) + if dist.get_rank() == 0: + save_benchmark_report(result, args.benchmark_output) diff --git a/examples/models/qwen3_5_35B_A3B/kernels/__init__.py b/examples/models/qwen3_5_35B_A3B/kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/models/qwen3_5_35B_A3B/kernels/attention.py b/examples/models/qwen3_5_35B_A3B/kernels/attention.py new file mode 100644 index 0000000..675d9eb --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/kernels/attention.py @@ -0,0 +1,164 @@ +from typing import Optional + +import neuronxcc.nki.language as nl +import nkipy.core.typing as nt +import nkipy.distributed.collectives as cc +import numpy as np +import torch.distributed as dist +from nkipy.core import tensor_apis + +from .rmsnorm import rmsnorm_kernel +from .rope import apply_rotary_emb_kernel, compute_cos_sin_cache +from .softmax import softmax_kernel + + +def repeat_kv_kernel(x, n_rep: int): + if n_rep == 1: + return x + return np.repeat(x, n_rep, axis=2) + + +def attention_kernel( + x, + qkv_weight, + q_norm_weight, + k_norm_weight, + norm_eps, + n_heads, + head_dim, + n_kv_heads, + partial_rotary_factor, + rope_theta, + cache_k, + cache_v, + start_pos: Optional[nt.tensor], + o_weight, +): + """Full attention kernel for Qwen3.5 with output gating and partial RoPE. + + Key differences from Qwen3: + - Q projection includes gate (2x wider): query + gate + - Partial RoPE: only first rotary_dim dims get rotary embeddings + - Output gating: output = sigmoid(gate) * attention_output + """ + is_prefill = start_pos is None + batch_size, seq_len, _ = x.shape + + n_local_heads = n_heads // dist.get_world_size() + assert n_local_heads > 0 + n_local_kv_heads = max(1, n_kv_heads // dist.get_world_size()) + n_rep = n_local_heads // n_local_kv_heads + + rotary_dim = int(head_dim * partial_rotary_factor) + + # QKV projection (Q includes gate, so 2x head_dim) + split_axis = x.ndim - 1 + split0 = n_local_heads * head_dim * 2 # query + gate + split1 = split0 + n_local_kv_heads * head_dim + splits = [split0, split1] + xqg, xk, xv = np.split(np.matmul(x, qkv_weight), splits, axis=split_axis) + + # Split query and gate from Q projection + xqg = xqg.reshape(batch_size, seq_len, n_local_heads, head_dim * 2) + xq = xqg[:, :, :, :head_dim] + gate = xqg[:, :, :, head_dim:] + gate = gate.reshape(batch_size, seq_len, -1) # (B, S, n_local_heads * head_dim) + + xk = xk.reshape(batch_size, seq_len, n_local_kv_heads, head_dim) + xv = xv.reshape(batch_size, seq_len, n_local_kv_heads, head_dim) + + # QK RMSNorm + xq = rmsnorm_kernel(xq, q_norm_weight, norm_eps) + xk = rmsnorm_kernel(xk, k_norm_weight, norm_eps) + + # Partial RoPE: only apply to first rotary_dim dims + max_seq_len = cache_k.shape[1] + freqs_cos, freqs_sin = compute_cos_sin_cache( + rotary_dim, max_seq_len, base=rope_theta, dtype=nl.bfloat16 + ) + + if is_prefill: + freqs_cos = freqs_cos[0:seq_len] + freqs_sin = freqs_sin[0:seq_len] + else: + freqs_cos = tensor_apis.constant(freqs_cos) + freqs_sin = tensor_apis.constant(freqs_sin) + freqs_cos = freqs_cos[start_pos] + freqs_sin = 
freqs_sin[start_pos] + + # Split rotary and pass-through dimensions + xq_rot = xq[:, :, :, :rotary_dim] + xq_pass = xq[:, :, :, rotary_dim:] + xk_rot = xk[:, :, :, :rotary_dim] + xk_pass = xk[:, :, :, rotary_dim:] + + xq_rot, xk_rot = apply_rotary_emb_kernel(xq_rot, xk_rot, freqs_cos, freqs_sin) + + xq = np.concatenate([xq_rot, xq_pass], axis=-1) + xk = np.concatenate([xk_rot, xk_pass], axis=-1) + + # KV cache update + if is_prefill: + cache_k[:, :seq_len] = xk + cache_v[:, :seq_len] = xv + else: + assert seq_len == 1 + cache_k[:, start_pos] = xk + cache_v[:, start_pos] = xv + + # GQA: repeat KV heads + keys = repeat_kv_kernel(cache_k, n_rep) + values = repeat_kv_kernel(cache_v, n_rep) + + # Transpose for attention: BSHD -> BHSD + xq = xq.transpose(0, 2, 1, 3) + keys = keys.transpose(0, 2, 1, 3) + values = values.transpose(0, 2, 1, 3) + + # Attention scores + k_seq_len = keys.shape[2] + scores = (xq @ keys.transpose(0, 1, 3, 2)) / np.float32(np.sqrt(head_dim)) + scores = scores.astype(nl.bfloat16) + + # Causal mask + causal_mask = np.triu(np.ones((k_seq_len, k_seq_len)) * -100000, k=1).astype( + scores.dtype + ) + causal_mask = tensor_apis.constant(causal_mask) + if is_prefill: + scores = scores + np.expand_dims( + causal_mask[:seq_len, :k_seq_len], axis=[0, 1] + ) + else: + scores = scores + np.expand_dims( + causal_mask[start_pos, :k_seq_len], axis=[0, 1] + ) + + attention_weights = softmax_kernel(scores) + + # Apply attention + output = attention_weights @ values + + # Transpose back: BHSD -> BSHD and flatten + output = output.transpose(0, 2, 1, 3) + output = output.reshape(batch_size, seq_len, -1) + + # Output gating: sigmoid(gate) * output + gate_sigmoid = 1.0 / (1.0 + np.exp(-gate.astype(np.float32))) + gate_sigmoid = gate_sigmoid.astype(output.dtype) + output = output * gate_sigmoid + + # Output projection + output_to_be_reduced = np.matmul(output, o_weight) + + # All-reduce for tensor parallelism (skip for TP=1) + if dist.get_world_size() > 1: + output = cc.all_reduce( + output_to_be_reduced, + replica_groups=[list(range(dist.get_world_size()))], + reduce_op=np.add, + ) + else: + output = output_to_be_reduced + + return output diff --git a/examples/models/qwen3_5_35B_A3B/kernels/feedforward.py b/examples/models/qwen3_5_35B_A3B/kernels/feedforward.py new file mode 100644 index 0000000..d80261f --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/kernels/feedforward.py @@ -0,0 +1,22 @@ +import numpy as np + + +def silu_kernel_(x): + """SiLU (Swish) activation function: x * sigmoid(x).""" + return x * (1 / (1 + np.exp(-x))) + + +def feedforward_kernel(x, gate_up_weight, down_weight): + """Feed-forward network kernel with SiLU activation and gating.""" + mm_gup = np.matmul(x, gate_up_weight) + xg, x_V = np.split(mm_gup, 2, axis=-1) + swish = silu_kernel_(xg) + x0 = swish * x_V + return x0 @ down_weight + + +def shared_expert_kernel(x, gate_proj_weight, up_proj_weight, down_proj_weight): + """Shared expert FFN: SiLU(x @ gate) * (x @ up) @ down.""" + gate = silu_kernel_(np.matmul(x, gate_proj_weight)) + up = np.matmul(x, up_proj_weight) + return np.matmul(gate * up, down_proj_weight) diff --git a/examples/models/qwen3_5_35B_A3B/kernels/linear_attention.py b/examples/models/qwen3_5_35B_A3B/kernels/linear_attention.py new file mode 100644 index 0000000..6797634 --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/kernels/linear_attention.py @@ -0,0 +1,256 @@ +from typing import Optional + +import nkipy.core.typing as nt +import nkipy.distributed.collectives as cc +import numpy as np +import 
torch.distributed as dist +from nkipy.core import tensor_apis + + +def sigmoid(x): + return 1.0 / (1.0 + np.exp(-x.astype(np.float32))) + + +def silu(x): + x_f32 = x.astype(np.float32) + return x_f32 * sigmoid(x_f32) + + +def softplus(x): + return np.log(1.0 + np.exp(x)) + + +def l2norm(x, eps=1e-6): + inv_norm = 1.0 / np.sqrt(np.sum(x * x, axis=-1, keepdims=True) + eps) + return x * inv_norm + + +def rmsnorm_gated(x, weight, gate, eps): + """RMSNormGated: weight * rmsnorm(x) * silu(gate) + + Used in the GatedDeltaNet output normalization. + weight is initialized to ones (not the 1+w convention). + """ + x_f32 = x.astype(np.float32) + weight_f32 = weight.astype(np.float32) + variance = np.mean(x_f32 * x_f32, axis=-1, keepdims=True) + x_normed = x_f32 / np.sqrt(variance + eps) + x_normed = weight_f32 * x_normed + x_normed = x_normed * silu(gate) + return x_normed.astype(x.dtype) + + +def causal_conv1d_prefill(x, conv_weight, kernel_size): + """Causal 1D depthwise convolution for prefill. + + Args: + x: (B, C, S) input (traced tensor) + conv_weight: (C, kernel_size) depthwise conv weights (traced tensor) + kernel_size: int (compile-time constant) + + Returns: + output: (B, C, S) convolved output + conv_state: (B, C, kernel_size) state for decode + """ + B, C, S = x.shape + # Create zero padding as a compile-time constant, then promote to traced + zeros_pad = tensor_apis.constant( + np.zeros((B, C, kernel_size - 1), dtype=np.float32).astype(x.dtype) + ) + padded = np.concatenate([zeros_pad, x], axis=2) + + # Depthwise convolution via shifted sums + # PyTorch conv1d convention: weight[0] is the most recent timestep + # With left-padding of K-1 zeros, output[t] = sum_k weight[k] * padded[t + K - 1 - k] + # Which means shift = k (not kernel_size - 1 - k) + output = padded[:, :, 0 : S] * np.expand_dims( + conv_weight[:, 0], axis=(0, 2) + ) + for k in range(1, kernel_size): + output = output + padded[:, :, k : S + k] * np.expand_dims( + conv_weight[:, k], axis=(0, 2) + ) + + # Save conv state (last kernel_size values of padded input) + conv_state = padded[:, :, -(kernel_size):] + + # SiLU activation + output = output * sigmoid(output) + return output, conv_state + + +def causal_conv1d_decode(x, conv_state, conv_weight, kernel_size): + """Causal 1D depthwise convolution for single-token decode. + + Args: + x: (B, C, 1) input token + conv_state: (B, C, kernel_size) previous state + conv_weight: (C, kernel_size) depthwise conv weights + kernel_size: int + + Returns: + output: (B, C, 1) convolved output + new_conv_state: (B, C, kernel_size) updated state + """ + # Shift state left, append new token + new_conv_state = np.concatenate([conv_state[:, :, 1:], x], axis=2) + # Dot product along kernel dimension + output = np.sum( + new_conv_state * np.expand_dims(conv_weight, axis=0), axis=2, keepdims=True + ) + # SiLU activation + output = output * sigmoid(output) + return output, new_conv_state + + +def gated_delta_net_kernel( + x, + qkv_weight, + z_weight, + b_weight, + a_weight, + conv_weight, + dt_bias, + A_log, + norm_weight, + out_weight, + norm_eps, + num_k_heads, + num_v_heads, + head_k_dim, + head_v_dim, + conv_kernel_size, + conv_state, + recurrent_state, + start_pos: Optional[nt.tensor], +): + """Gated Delta Net (linear attention) kernel for Qwen3.5. + + Implements the recurrent formulation of the gated delta rule. 
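+
+    In math form, each iteration of the time loop below performs, per head,
+    with state S of shape (head_k_dim, head_v_dim):
+
+        S     <- exp(g_t) * S                 # gated decay
+        delta <- beta_t * (v_t - S^T @ k_t)   # delta-rule correction
+        S     <- S + outer(k_t, delta)        # rank-1 state update
+        o_t   <- S^T @ q_t                    # read-out for this step
+
+    The symbols match the code's q_t, k_t, v_t, beta_t and g_t; the final
+    state is written back into `recurrent_state` for the next call.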
+ """ + is_prefill = start_pos is None + batch_size, seq_len, _ = x.shape + + ws = dist.get_world_size() + n_local_k_heads = num_k_heads // ws + n_local_v_heads = num_v_heads // ws + key_dim_local = n_local_k_heads * head_k_dim + value_dim_local = n_local_v_heads * head_v_dim + v_per_k = n_local_v_heads // n_local_k_heads + + # Projections + mixed_qkv = np.matmul(x, qkv_weight) # (B, S, key_dim_local*2 + value_dim_local) + z = np.matmul(x, z_weight) # (B, S, value_dim_local) + b = np.matmul(x, b_weight) # (B, S, n_local_v_heads) + a = np.matmul(x, a_weight) # (B, S, n_local_v_heads) + + # Causal conv1d + mixed_qkv_t = mixed_qkv.transpose(0, 2, 1) # (B, C, S) + if is_prefill: + mixed_qkv_t, new_conv_state = causal_conv1d_prefill( + mixed_qkv_t, conv_weight, conv_kernel_size + ) + else: + mixed_qkv_t, new_conv_state = causal_conv1d_decode( + mixed_qkv_t, conv_state, conv_weight, conv_kernel_size + ) + + # Update conv state + conv_state[:] = new_conv_state + + mixed_qkv = mixed_qkv_t.transpose(0, 2, 1) # (B, S, C) + + # Split into Q, K, V + query, key, value = np.split( + mixed_qkv, + [key_dim_local, key_dim_local * 2], + axis=-1, + ) + + # Reshape to head dims + query = query.reshape(batch_size, seq_len, n_local_k_heads, head_k_dim) + key = key.reshape(batch_size, seq_len, n_local_k_heads, head_k_dim) + value = value.reshape(batch_size, seq_len, n_local_v_heads, head_v_dim) + + # L2 normalize Q and K + query = l2norm(query.astype(np.float32)).astype(query.dtype) + key = l2norm(key.astype(np.float32)).astype(key.dtype) + + # Compute gating: g = -exp(A_log) * softplus(a + dt_bias) + beta = sigmoid(b) # (B, S, n_local_v_heads) + g = -np.exp(A_log.astype(np.float32)) * softplus( + a.astype(np.float32) + dt_bias.astype(np.float32) + ) + + # Repeat K heads to match V heads if needed + if v_per_k > 1: + query = np.repeat(query, v_per_k, axis=2) + key = np.repeat(key, v_per_k, axis=2) + + # Scale query + scale = 1.0 / np.sqrt(np.float32(head_k_dim)) + + # Transpose to (B, heads, S, dim) for recurrence + query = query.transpose(0, 2, 1, 3).astype(np.float32) * scale + key = key.transpose(0, 2, 1, 3).astype(np.float32) + value = value.transpose(0, 2, 1, 3).astype(np.float32) + beta = beta.transpose(0, 2, 1).astype(np.float32) # (B, n_v_heads, S) + g = g.transpose(0, 2, 1).astype(np.float32) # (B, n_v_heads, S) + + # Recurrent gated delta rule + # recurrent_state: (B, n_local_v_heads, head_k_dim, head_v_dim) + state = recurrent_state.astype(np.float32) + + # Collect outputs in a list to avoid assigning traced tensors to numpy slices + output_steps = [] + + for i in range(seq_len): + q_t = query[:, :, i, :] # (B, n_v_heads, head_k_dim) + k_t = key[:, :, i, :] + v_t = value[:, :, i, :] # (B, n_v_heads, head_v_dim) + g_t = np.expand_dims(np.expand_dims(np.exp(g[:, :, i]), -1), -1) + beta_t = np.expand_dims(beta[:, :, i], -1) # (B, n_v_heads, 1) + + # Decay state + state = state * g_t + # Retrieve from state + kv_mem = np.sum(state * np.expand_dims(k_t, -1), axis=-2) + # Delta update + delta = (v_t - kv_mem) * beta_t + state = state + np.expand_dims(k_t, -1) * np.expand_dims(delta, -2) + # Query state -> (B, n_v_heads, head_v_dim) + step_out = np.sum(state * np.expand_dims(q_t, -1), axis=-2) + # Add sequence dimension: (B, n_v_heads, 1, head_v_dim) + output_steps.append(np.expand_dims(step_out, 2)) + + # Update recurrent state + recurrent_state[:] = state.astype(recurrent_state.dtype) + + # Concatenate along sequence dim: (B, n_v_heads, seq_len, head_v_dim) + core_output = np.concatenate(output_steps, 
axis=2) + + # Transpose back: (B, heads, S, dim) -> (B, S, heads, dim) -> (B, S, heads*dim) + core_output = core_output.transpose(0, 2, 1, 3) + + # RMSNormGated: reshape to 2D for norm, then back + core_flat = core_output.reshape(-1, head_v_dim) + z_flat = z.reshape(batch_size, seq_len, n_local_v_heads, head_v_dim) + z_flat = z_flat.reshape(-1, head_v_dim) + core_flat = rmsnorm_gated(core_flat, norm_weight, z_flat, norm_eps) + core_output = core_flat.reshape(batch_size, seq_len, -1) + + # Output projection + output_to_be_reduced = np.matmul(core_output, out_weight) + + # All-reduce (skip for TP=1, evaluated at trace time) + if dist.get_world_size() > 1: + output = cc.all_reduce( + output_to_be_reduced, + replica_groups=[list(range(dist.get_world_size()))], + reduce_op=np.add, + ) + else: + output = output_to_be_reduced + + # Cast back to input dtype (bfloat16) to match hidden_states + return output.astype(x.dtype) diff --git a/examples/models/qwen3_5_35B_A3B/kernels/rmsnorm.py b/examples/models/qwen3_5_35B_A3B/kernels/rmsnorm.py new file mode 100644 index 0000000..581ff89 --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/kernels/rmsnorm.py @@ -0,0 +1,21 @@ +import numpy as np + + +def rmsnorm_kernel( + x, + weight, + eps: float, + compute_dtype=np.float32, +): + """RMSNorm for Qwen3.5: output = (1 + weight) * (x / rms(x))""" + original_dtype = x.dtype + x = x.astype(compute_dtype) + weight = weight.astype(compute_dtype) + z = np.square(x) + z = np.mean(z, axis=-1, keepdims=True) + z = (z + eps).astype(x.dtype) + z = x / np.sqrt(z) + # Qwen3.5 uses (1 + weight) scaling (weight initialized to 0) + res = z * (1.0 + weight) + res = res.astype(original_dtype) + return res diff --git a/examples/models/qwen3_5_35B_A3B/kernels/rope.py b/examples/models/qwen3_5_35B_A3B/kernels/rope.py new file mode 100644 index 0000000..a2576a6 --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/kernels/rope.py @@ -0,0 +1,45 @@ +import numpy as np + + +def compute_cos_sin_cache( + rotary_dim: int, max_seq_len: int, base: float = 10000000.0, dtype=np.float32 +): + """Compute cosine and sine cache for partial RoPE. + + For Qwen3.5, rotary_dim = head_dim * partial_rotary_factor (e.g. 256 * 0.25 = 64). + The cache has rotary_dim/2 frequencies. + """ + freqs = 1.0 / ( + base ** (np.arange(0, rotary_dim, 2)[: (rotary_dim // 2)] / rotary_dim) + ) + t = np.arange(max_seq_len, dtype=np.float32) + freqs = np.outer(t, freqs) + return ( + np.cos(freqs, dtype=np.float32).astype(dtype), + np.sin(freqs, dtype=np.float32).astype(dtype), + ) + + +def apply_rotary_emb_kernel(xq, xk, freqs_cos, freqs_sin): + """Apply rotary position embedding to query and key tensors. + + Uses the rotate_half convention: (x * cos) + (rotate_half(x) * sin) + where rotate_half(x) = [-x2, x1] for x = [x1, x2]. 
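+
+    Spelled out per half (this is exactly what the element-wise code below
+    computes):
+
+        out_first_half  = x1 * cos - x2 * sin
+        out_second_half = x1 * sin + x2 * cos
+
+    where x1 and x2 are the first and second halves of the last dimension.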
+ """ + freqs_cos = np.expand_dims(freqs_cos, axis=(0, 2)) + freqs_sin = np.expand_dims(freqs_sin, axis=(0, 2)) + + half_h = xq.shape[-1] // 2 + xq0 = xq[:, :, :, :half_h] + xq1 = xq[:, :, :, half_h:] + xk0 = xk[:, :, :, :half_h] + xk1 = xk[:, :, :, half_h:] + + xq_out_0 = xq0 * freqs_cos - xq1 * freqs_sin + xq_out_1 = xq0 * freqs_sin + xq1 * freqs_cos + xk_out_0 = xk0 * freqs_cos - xk1 * freqs_sin + xk_out_1 = xk0 * freqs_sin + xk1 * freqs_cos + + xq_out = np.concatenate([xq_out_0, xq_out_1], axis=-1) + xk_out = np.concatenate([xk_out_0, xk_out_1], axis=-1) + return xq_out, xk_out diff --git a/examples/models/qwen3_5_35B_A3B/kernels/sampling.py b/examples/models/qwen3_5_35B_A3B/kernels/sampling.py new file mode 100644 index 0000000..fa81c89 --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/kernels/sampling.py @@ -0,0 +1,35 @@ +import numpy as np +import torch.distributed as dist +import nkipy.distributed.collectives as cc +from nkipy.core import tensor_apis + +from config import Config + +from .rmsnorm import rmsnorm_kernel + + +def greedy_sampling(h, norm_weight, lm_head_weight, configs: Config): + """On-device greedy sampling: RMSNorm -> lm_head matmul -> all_gather -> topk. + + Returns a (B, 1) uint32 tensor of global token IDs. + + Works around the neuronx-cc bug where `all_gather` in the same kernel as + a prior `topk` corrupts the topk output (see bug_topk_allgather_dynamic_index.md). + The fix is to all_gather the full logits tensor first, then topk once on + the gathered result -- topk is never upstream of all_gather in the graph, + and no dynamic indexing is needed because topk over full-vocab logits + directly returns the global winner ID. + """ + h = rmsnorm_kernel(h, norm_weight, configs.norm_eps) + logits = h[:, -1, :] @ lm_head_weight # (B, vocab_per_device) + logits = logits.astype(np.float32) + + if dist.get_world_size() > 1: + logits = cc.all_gather( + logits, + all_gather_dim=1, + replica_groups=[list(range(dist.get_world_size()))], + ) # (B, vocab_total) + + _, next_id = tensor_apis.topk(logits, k=1, axis=1) # (B, 1) uint32 global + return next_id diff --git a/examples/models/qwen3_5_35B_A3B/kernels/softmax.py b/examples/models/qwen3_5_35B_A3B/kernels/softmax.py new file mode 100644 index 0000000..2f9a948 --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/kernels/softmax.py @@ -0,0 +1,6 @@ +import numpy as np + + +def softmax_kernel(x): + exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True)) + return exp_x / np.sum(exp_x, axis=-1, keepdims=True) diff --git a/examples/models/qwen3_5_35B_A3B/kernels/transformer_layer.py b/examples/models/qwen3_5_35B_A3B/kernels/transformer_layer.py new file mode 100644 index 0000000..d0e1aae --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/kernels/transformer_layer.py @@ -0,0 +1,220 @@ +import nkipy.distributed.collectives as cc +import numpy as np +import torch.distributed as dist +from nkipy.core import tensor_apis + +from config import Config + +from .attention import attention_kernel +from .feedforward import feedforward_kernel, shared_expert_kernel +from .linear_attention import gated_delta_net_kernel +from .rmsnorm import rmsnorm_kernel +from .softmax import softmax_kernel + + +def _moe_block( + norm_z, + z, + router_weight, + gate_up_weight, + down_weight, + shared_gate_proj_weight, + shared_up_proj_weight, + shared_down_proj_weight, + shared_expert_gate_weight, + configs: Config, +): + """MoE block with routed experts + shared expert (common to both layer types).""" + B, L, D = z.shape + n_experts = router_weight.shape[-1] + 
top_k = configs.num_experts_per_tok + + # Router scores + router_logits = np.matmul(norm_z, router_weight) + + # Routed expert output + routed_output = np.empty_like(z) + for b in range(B): + for t in range(L): + token_input = norm_z[b, t, :] + token_logits = router_logits[b, t] + token_logits = softmax_kernel(token_logits) + top_k_logits, top_k_indices = tensor_apis.topk(token_logits, k=top_k) + top_k_logits /= np.sum(top_k_logits, axis=-1, keepdims=True) + + token_output = tensor_apis.zeros((D), dtype=routed_output.dtype) + for e in range(top_k): + expert_idx = top_k_indices[e] + weight = top_k_logits[e] + expert_output = feedforward_kernel( + token_input, gate_up_weight[expert_idx], down_weight[expert_idx] + ) + token_output += weight * expert_output + routed_output[b, t] = token_output + + # Shared expert output + norm_z_flat = norm_z.reshape(-1, D) + shared_output = shared_expert_kernel( + norm_z_flat, + shared_gate_proj_weight, + shared_up_proj_weight, + shared_down_proj_weight, + ) + + # Shared expert gating + shared_gate = 1.0 / ( + 1.0 + np.exp(-(np.matmul(norm_z_flat, shared_expert_gate_weight)).astype(np.float32)) + ) + shared_gate = shared_gate.astype(shared_output.dtype) + shared_output = shared_gate * shared_output + shared_output = shared_output.reshape(B, L, D) + + # Combine routed + shared + output = routed_output + shared_output + + # All-reduce for tensor parallelism (skip for TP=1) + if dist.get_world_size() > 1: + output = cc.all_reduce( + output, replica_groups=[list(range(dist.get_world_size()))], reduce_op=np.add + ) + + final = z + output + return final.astype(z.dtype) + + +def transformer_layer_full_attn( + x, + start_pos, + # Attention weights + qkv_weight, + o_weight, + input_weight, + q_norm_weight, + k_norm_weight, + cache_k, + cache_v, + # MoE weights + post_attention_weight, + router_weight, + gate_up_weight, + down_weight, + shared_gate_proj_weight, + shared_up_proj_weight, + shared_down_proj_weight, + shared_expert_gate_weight, + configs: Config, +): + """Transformer layer with full attention (every 4th layer in Qwen3.5).""" + # Pre-attention RMSNorm + norm_x = rmsnorm_kernel(x, input_weight, configs.norm_eps) + + # Full attention with output gating and partial RoPE + h1 = attention_kernel( + norm_x, + qkv_weight, + q_norm_weight, + k_norm_weight, + configs.norm_eps, + configs.num_heads, + configs.head_dim, + configs.num_kv_heads, + configs.partial_rotary_factor, + configs.rope_theta, + cache_k, + cache_v, + start_pos=start_pos, + o_weight=o_weight, + ) + + z = x + h1 + + # Pre-MoE RMSNorm + norm_z = rmsnorm_kernel(z, post_attention_weight, configs.norm_eps) + + # MoE with shared expert + return _moe_block( + norm_z, + z, + router_weight, + gate_up_weight, + down_weight, + shared_gate_proj_weight, + shared_up_proj_weight, + shared_down_proj_weight, + shared_expert_gate_weight, + configs, + ) + + +def transformer_layer_linear_attn( + x, + start_pos, + # Linear attention weights + qkv_weight, + z_weight, + b_weight, + a_weight, + conv_weight, + dt_bias, + A_log, + linear_norm_weight, + out_weight, + input_weight, + conv_state, + recurrent_state, + # MoE weights + post_attention_weight, + router_weight, + gate_up_weight, + down_weight, + shared_gate_proj_weight, + shared_up_proj_weight, + shared_down_proj_weight, + shared_expert_gate_weight, + configs: Config, +): + """Transformer layer with linear attention (GatedDeltaNet) for Qwen3.5.""" + # Pre-attention RMSNorm + norm_x = rmsnorm_kernel(x, input_weight, configs.norm_eps) + + # Gated Delta Net (linear 
attention) + h1 = gated_delta_net_kernel( + norm_x, + qkv_weight, + z_weight, + b_weight, + a_weight, + conv_weight, + dt_bias, + A_log, + linear_norm_weight, + out_weight, + configs.norm_eps, + configs.linear_num_key_heads, + configs.linear_num_value_heads, + configs.linear_key_head_dim, + configs.linear_value_head_dim, + configs.linear_conv_kernel_dim, + conv_state, + recurrent_state, + start_pos=start_pos, + ) + + z = x + h1 + + # Pre-MoE RMSNorm + norm_z = rmsnorm_kernel(z, post_attention_weight, configs.norm_eps) + + # MoE with shared expert + return _moe_block( + norm_z, + z, + router_weight, + gate_up_weight, + down_weight, + shared_gate_proj_weight, + shared_up_proj_weight, + shared_down_proj_weight, + shared_expert_gate_weight, + configs, + ) diff --git a/examples/models/qwen3_5_35B_A3B/qwen3_5.py b/examples/models/qwen3_5_35B_A3B/qwen3_5.py new file mode 100644 index 0000000..167211a --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/qwen3_5.py @@ -0,0 +1,519 @@ +import argparse +import os +import sys +import time + +import numpy as np +import torch +import torch.distributed as dist +from config import FULL_ATTENTION, LINEAR_ATTENTION, Config, get_config +from kernels.sampling import greedy_sampling +from kernels.transformer_layer import ( + transformer_layer_full_attn, + transformer_layer_linear_attn, +) +from nkipy.runtime import DeviceKernel, DeviceTensor +from safetensors.torch import load_file +from transformers import AutoTokenizer +from utils import print_log + +BUILD_DIR = "./build" + + +class Qwen35Model: + def __init__(self, model_weights, config: Config): + self.config = config + self.tok_embedding = model_weights.get("tok_embedding") + + # Kernels (compiled lazily) + self.kernel_cte_full_attn = None + self.kernel_cte_linear_attn = None + self.kernel_tkg_full_attn = None + self.kernel_tkg_linear_attn = None + self.kernel_cte_greedy_sampling = None + self.kernel_tkg_greedy_sampling = None + + self.norm_weight = None + self.lm_head_weight = None + + self._prepare_tensors(model_weights) + self._prepare_kernels() + + def _prepare_tensors(self, weights): + t = time.time() + print_log("Preparing Tensors") + + ws = dist.get_world_size() + n_local_kv_heads = max(1, self.config.num_kv_heads // ws) + n_local_v_heads = self.config.linear_num_value_heads // ws + n_local_k_heads = self.config.linear_num_key_heads // ws + key_dim_local = n_local_k_heads * self.config.linear_key_head_dim + value_dim_local = n_local_v_heads * self.config.linear_value_head_dim + conv_dim_local = key_dim_local * 2 + value_dim_local + + self.layer_tensors = [] + for layer_id in range(self.config.num_layers): + layer_type = self.config.layer_types[layer_id] + layer_dict = {} + + # Common MoE weights + for key in [ + "input_weight", + "post_attention_weight", + "router_weight", + "gate_up_weight", + "down_weight", + "shared_gate_proj_weight", + "shared_up_proj_weight", + "shared_down_proj_weight", + "shared_expert_gate_weight", + ]: + w = weights.get(f"layers.{layer_id}.{key}") + layer_dict[key] = DeviceTensor.from_torch( + w, f"{key}_L{layer_id}" + ) + + if layer_type == FULL_ATTENTION: + # Full attention weights + KV cache + for key in ["qkv_weight", "o_weight", "q_norm_weight", "k_norm_weight"]: + w = weights.get(f"layers.{layer_id}.{key}") + layer_dict[key] = DeviceTensor.from_torch( + w, f"{key}_L{layer_id}" + ) + + cache_k = np.zeros( + (self.config.max_batch_size, self.config.max_seq_len, + n_local_kv_heads, self.config.head_dim), + dtype=self.config.dtype, + ) + cache_v = np.zeros_like(cache_k) + 
layer_dict["cache_k"] = DeviceTensor.from_numpy( + cache_k, f"cache_k_L{layer_id}" + ) + layer_dict["cache_v"] = DeviceTensor.from_numpy( + cache_v, f"cache_v_L{layer_id}" + ) + + else: # LINEAR_ATTENTION + for key in [ + "linear_qkv_weight", "linear_z_weight", + "linear_b_weight", "linear_a_weight", + "linear_conv_weight", "linear_dt_bias", + "linear_A_log", "linear_norm_weight", + "linear_out_weight", + ]: + w = weights.get(f"layers.{layer_id}.{key}") + layer_dict[key] = DeviceTensor.from_torch( + w, f"{key}_L{layer_id}" + ) + + conv_state = np.zeros( + (self.config.max_batch_size, conv_dim_local, + self.config.linear_conv_kernel_dim), + dtype=self.config.dtype, + ) + recurrent_state = np.zeros( + (self.config.max_batch_size, n_local_v_heads, + self.config.linear_key_head_dim, + self.config.linear_value_head_dim), + dtype=self.config.dtype, + ) + layer_dict["conv_state"] = DeviceTensor.from_numpy( + conv_state, f"conv_state_L{layer_id}" + ) + layer_dict["recurrent_state"] = DeviceTensor.from_numpy( + recurrent_state, f"recurrent_state_L{layer_id}" + ) + + layer_dict["layer_type"] = layer_type + self.layer_tensors.append(layer_dict) + + self.norm_weight = DeviceTensor.from_torch( + weights.get("norm_weight"), "norm_weight" + ) + self.lm_head_weight = DeviceTensor.from_torch( + weights.get("lm_head_weight"), "lm_head_weight" + ) + + print_log(f"--> Finished Preparing Tensors in {time.time() - t:.2f}s") + + def _find_first_layer_of_type(self, layer_type): + for i, lt in enumerate(self.layer_tensors): + if lt["layer_type"] == layer_type: + return i + return None + + def _prepare_kernels(self): + t = time.time() + print_log("Preparing kernels") + + x_context = DeviceTensor.from_numpy( + np.empty( + (self.config.max_batch_size, self.config.context_len, + self.config.hidden_size), + dtype=self.config.dtype, + ), + "x_context", + ) + x_token = DeviceTensor.from_numpy( + np.empty( + (self.config.max_batch_size, 1, self.config.hidden_size), + dtype=self.config.dtype, + ), + "x_token", + ) + start_pos = DeviceTensor.from_numpy( + np.empty(shape=(1), dtype=np.int32), "start_pos" + ) + + # --- Compile full attention kernels --- + fa_idx = self._find_first_layer_of_type(FULL_ATTENTION) + if fa_idx is not None: + fa = self.layer_tensors[fa_idx] + fa_common = dict( + qkv_weight=fa["qkv_weight"], + o_weight=fa["o_weight"], + input_weight=fa["input_weight"], + q_norm_weight=fa["q_norm_weight"], + k_norm_weight=fa["k_norm_weight"], + post_attention_weight=fa["post_attention_weight"], + router_weight=fa["router_weight"], + gate_up_weight=fa["gate_up_weight"], + down_weight=fa["down_weight"], + shared_gate_proj_weight=fa["shared_gate_proj_weight"], + shared_up_proj_weight=fa["shared_up_proj_weight"], + shared_down_proj_weight=fa["shared_down_proj_weight"], + shared_expert_gate_weight=fa["shared_expert_gate_weight"], + configs=self.config, + build_dir=BUILD_DIR, + additional_compiler_args=self.config.additional_compiler_args_nkipy, + ) + + self.kernel_cte_full_attn = DeviceKernel.compile_and_load( + transformer_layer_full_attn, + name="cte_full_attn", + x=x_context, + start_pos=None, + cache_k=fa["cache_k"], + cache_v=fa["cache_v"], + **fa_common, + ) + self.kernel_tkg_full_attn = DeviceKernel.compile_and_load( + transformer_layer_full_attn, + name="tkg_full_attn", + x=x_token, + start_pos=start_pos, + cache_k=fa["cache_k"], + cache_v=fa["cache_v"], + **fa_common, + ) + + # --- Compile linear attention kernels --- + la_idx = self._find_first_layer_of_type(LINEAR_ATTENTION) + if la_idx is not None: + la = 
self.layer_tensors[la_idx] + la_common = dict( + qkv_weight=la["linear_qkv_weight"], + z_weight=la["linear_z_weight"], + b_weight=la["linear_b_weight"], + a_weight=la["linear_a_weight"], + conv_weight=la["linear_conv_weight"], + dt_bias=la["linear_dt_bias"], + A_log=la["linear_A_log"], + linear_norm_weight=la["linear_norm_weight"], + out_weight=la["linear_out_weight"], + input_weight=la["input_weight"], + post_attention_weight=la["post_attention_weight"], + router_weight=la["router_weight"], + gate_up_weight=la["gate_up_weight"], + down_weight=la["down_weight"], + shared_gate_proj_weight=la["shared_gate_proj_weight"], + shared_up_proj_weight=la["shared_up_proj_weight"], + shared_down_proj_weight=la["shared_down_proj_weight"], + shared_expert_gate_weight=la["shared_expert_gate_weight"], + configs=self.config, + build_dir=BUILD_DIR, + additional_compiler_args=self.config.additional_compiler_args_nkipy, + ) + + self.kernel_cte_linear_attn = DeviceKernel.compile_and_load( + transformer_layer_linear_attn, + name="cte_linear_attn", + x=x_context, + start_pos=None, + conv_state=la["conv_state"], + recurrent_state=la["recurrent_state"], + **la_common, + ) + self.kernel_tkg_linear_attn = DeviceKernel.compile_and_load( + transformer_layer_linear_attn, + name="tkg_linear_attn", + x=x_token, + start_pos=start_pos, + conv_state=la["conv_state"], + recurrent_state=la["recurrent_state"], + **la_common, + ) + + # --- Compile on-device greedy sampling kernels --- + # RMSNorm + lm_head matmul + all_gather + global topk all on device. + # Output is (B,) uint32 global token id (topk over gathered logits collapses + # the batch axis to 1-D on the device side; we reshape to (B, 1) on host). + ws = dist.get_world_size() + vocab_per_device = self.lm_head_weight.numpy().shape[1] + self.vocab_per_device = vocab_per_device + + self.d_next_id_ctx = DeviceTensor.from_numpy( + np.empty((self.config.max_batch_size,), dtype=np.uint32), + "next_id_ctx", + ) + self.d_next_id_tok = DeviceTensor.from_numpy( + np.empty((self.config.max_batch_size,), dtype=np.uint32), + "next_id_tok", + ) + + self.kernel_cte_greedy_sampling = DeviceKernel.compile_and_load( + greedy_sampling, + name="cte_greedy_sampling", + h=x_context, + norm_weight=self.norm_weight, + lm_head_weight=self.lm_head_weight, + configs=self.config, + build_dir=BUILD_DIR, + additional_compiler_args=self.config.additional_compiler_args_nkipy, + ) + self.kernel_tkg_greedy_sampling = DeviceKernel.compile_and_load( + greedy_sampling, + name="tkg_greedy_sampling", + h=x_token, + norm_weight=self.norm_weight, + lm_head_weight=self.lm_head_weight, + configs=self.config, + build_dir=BUILD_DIR, + additional_compiler_args=self.config.additional_compiler_args_nkipy, + ) + + print_log( + f"--> Finished Kernel Compilation and Loading in {time.time() - t:.2f}s" + ) + + def _run_layer(self, kernel_fa, kernel_la, layer_idx, hidden_states, t_start_pos): + lt = self.layer_tensors[layer_idx] + layer_type = lt["layer_type"] + + # Common MoE inputs + moe_inputs = { + "input_weight": lt["input_weight"], + "post_attention_weight": lt["post_attention_weight"], + "router_weight": lt["router_weight"], + "gate_up_weight": lt["gate_up_weight"], + "down_weight": lt["down_weight"], + "shared_gate_proj_weight": lt["shared_gate_proj_weight"], + "shared_up_proj_weight": lt["shared_up_proj_weight"], + "shared_down_proj_weight": lt["shared_down_proj_weight"], + "shared_expert_gate_weight": lt["shared_expert_gate_weight"], + } + + if layer_type == FULL_ATTENTION: + inputs = { + "x": hidden_states, + 
"qkv_weight": lt["qkv_weight"], + "o_weight": lt["o_weight"], + "q_norm_weight": lt["q_norm_weight"], + "k_norm_weight": lt["k_norm_weight"], + "cache_k.must_alias_input": lt["cache_k"], + "cache_v.must_alias_input": lt["cache_v"], + **moe_inputs, + } + if t_start_pos is not None: + inputs["start_pos"] = t_start_pos + outputs = { + "output0": hidden_states, + "cache_k": lt["cache_k"], + "cache_v": lt["cache_v"], + } + kernel_fa(inputs=inputs, outputs=outputs) + else: + inputs = { + "x": hidden_states, + "qkv_weight": lt["linear_qkv_weight"], + "z_weight": lt["linear_z_weight"], + "b_weight": lt["linear_b_weight"], + "a_weight": lt["linear_a_weight"], + "conv_weight": lt["linear_conv_weight"], + "dt_bias": lt["linear_dt_bias"], + "A_log": lt["linear_A_log"], + "linear_norm_weight": lt["linear_norm_weight"], + "out_weight": lt["linear_out_weight"], + "conv_state.must_alias_input": lt["conv_state"], + "recurrent_state.must_alias_input": lt["recurrent_state"], + **moe_inputs, + } + if t_start_pos is not None: + inputs["start_pos"] = t_start_pos + outputs = { + "output0": hidden_states, + "conv_state": lt["conv_state"], + "recurrent_state": lt["recurrent_state"], + } + kernel_la(inputs=inputs, outputs=outputs) + + def _sample_token(self, kernel, hidden_states, d_next_id): + """Run on-device greedy sampling; return (B, 1) torch.int token IDs. + + The kernel does RMSNorm + lm_head matmul + all_gather + global topk entirely + on device. We read back (B,) uint32 and reshape to (B, 1) torch int32 to + match the previous argmax interface. + """ + kernel( + inputs={ + "h": hidden_states, + "norm_weight": self.norm_weight, + "lm_head_weight": self.lm_head_weight, + }, + outputs={"output0": d_next_id}, + ) + next_id_np = d_next_id.numpy().astype(np.int32).reshape(-1, 1) + return torch.from_numpy(next_id_np) + + def generate(self, input_ids): + context_len = self.config.context_len + + # Reset GDN states (conv_state, recurrent_state) before each generation + # Unlike KV cache (position-addressed), GDN state is accumulated, so must be zeroed. 
+ for lt in self.layer_tensors: + if lt["layer_type"] == LINEAR_ATTENTION: + lt["conv_state"].write_from_numpy( + np.zeros(lt["conv_state"].numpy().shape, dtype=lt["conv_state"].numpy().dtype) + ) + lt["recurrent_state"].write_from_numpy( + np.zeros(lt["recurrent_state"].numpy().shape, dtype=lt["recurrent_state"].numpy().dtype) + ) + + hidden_states = DeviceTensor.from_torch( + self.tok_embedding[input_ids], "hidden_states" + ) + + # --- Prefill (context phase) --- + for i in range(self.config.num_layers): + self._run_layer( + self.kernel_cte_full_attn, + self.kernel_cte_linear_attn, + i, + hidden_states, + None, + ) + + next_id_torch = self._sample_token( + self.kernel_cte_greedy_sampling, hidden_states, self.d_next_id_ctx + ) + yield next_id_torch + + # --- Decode (token-by-token) --- + for pos in range(context_len, context_len + self.config.max_new_tokens): + t_start_pos = DeviceTensor.from_numpy( + np.array([pos], dtype=np.int32) + ) + hidden_states = DeviceTensor.from_torch( + self.tok_embedding[next_id_torch], "h0/res1" + ) + t_res1 = hidden_states + + for i in range(self.config.num_layers): + self._run_layer( + self.kernel_tkg_full_attn, + self.kernel_tkg_linear_attn, + i, + hidden_states, + t_start_pos, + ) + + next_id_torch = self._sample_token( + self.kernel_tkg_greedy_sampling, t_res1, self.d_next_id_tok + ) + yield next_id_torch + + +def load_model(args): + os.environ["TOKENIZERS_PARALLELISM"] = "true" + os.environ["OMP_NUM_THREADS"] = "1" + os.environ["NEURON_RT_ROOT_COMM_ID"] = "localhost:61239" + + dist.init_process_group() + torch.set_num_threads(128 // dist.get_world_size()) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + os.environ["NEURON_RT_VISIBLE_CORES"] = str(dist.get_rank()) + + tokenizer = AutoTokenizer.from_pretrained(args.model) + model_inputs = tokenizer(args.prompt, return_tensors="np") + input_ids = model_inputs["input_ids"] + config = get_config(args.model, input_ids.shape[1], args.max_new_tokens) + + print_log("Loading Model Weights") + + shard_path = os.path.join( + args.checkpoint, f"shard_{dist.get_rank()}.safetensors" + ) + weights = load_file(shard_path, device="cpu") + + model = Qwen35Model(weights, config) + + # Warming + start = time.time() + print_log("Warming model") + t = 0 + for id in model.generate(input_ids): + if t == 1: + break + t += 1 + print_log(f"--> Finished warming the model in {time.time() - start:.2f}s") + + return model, input_ids, tokenizer + + +# EOS tokens for Qwen3.5 +EOS_TOKENS = {248044} + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-n", "--max-new-tokens", type=int, default=16) + parser.add_argument("prompt", nargs="?", default="The capital of France is") + parser.add_argument( + "--checkpoint", default="./qwen3_5_shards" + ) + parser.add_argument("--model", default="Qwen/Qwen3.5-35B-A3B") + args = parser.parse_args() + + model, input_ids, tokenizer = load_model(args) + + dist.barrier() + start = time.time() + t = 0 + if dist.get_rank() == 0: + print(f"\n{args.prompt}", end="") + for id in model.generate(input_ids): + if t == 0: + first_token_time = time.time() + t += 1 + output_id = id[0].tolist() + if output_id[-1] in EOS_TOKENS: + print_log("Found EOS token, stop early") + break + if dist.get_rank() == 0: + print(tokenizer.decode(output_id), end="") + sys.stdout.flush() + + end_time = time.time() + + ttft = first_token_time - start + decoding_time = end_time - first_token_time + tokens_per_second = t / decoding_time + if dist.get_rank() == 0: + print(f"\nTime to first token: 
{ttft:.2f}s") + print(f"Decoding tokens per second: {tokens_per_second:.2f}") + + +if __name__ == "__main__": + main() diff --git a/examples/models/qwen3_5_35B_A3B/tensor_preparation.py b/examples/models/qwen3_5_35B_A3B/tensor_preparation.py new file mode 100644 index 0000000..8c1c776 --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/tensor_preparation.py @@ -0,0 +1,422 @@ +#!/usr/bin/env python3 +"""Pre-shard Qwen3.5-35B-A3B weights for tensor-parallel inference on Trainium. + +Handles the hybrid architecture: +- Full attention layers: Q (with gate), K, V, O projections + QK norms +- Linear attention layers: GatedDeltaNet projections, conv, state params +- MoE layers: routed experts + shared expert +""" + +import fnmatch +import os + +import numpy as np +import torch +from safetensors.torch import save_file + +# TP plan +base_model_tp_plan = { + # Full attention + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + # Linear attention (GatedDeltaNet) + "layers.*.linear_attn.in_proj_qkv": "colwise", + "layers.*.linear_attn.in_proj_z": "colwise", + "layers.*.linear_attn.in_proj_b": "colwise", + "layers.*.linear_attn.in_proj_a": "colwise", + "layers.*.linear_attn.out_proj": "rowwise", + # MoE routed experts (3D tensors) + "layers.*.mlp.experts.gate_up_proj": "colwise", + "layers.*.mlp.experts.down_proj": "rowwise", + # Shared expert + "layers.*.mlp.shared_expert.gate_proj": "colwise", + "layers.*.mlp.shared_expert.up_proj": "colwise", + "layers.*.mlp.shared_expert.down_proj": "rowwise", + # LM head + "lm_head": "colwise", +} + +_STATE_ITEMS = None +_WORLD_SIZE = None +_OUTPUT_DIR = None +_CONFIG = None +_LAYER_TYPES = None + + +def get_split_style(name: str): + for pat, style in base_model_tp_plan.items(): + if fnmatch.fnmatch(name, f"*{pat}*"): + return style + return None + + +def get_split_dim(name: str, tensor: torch.Tensor): + style = get_split_style(name) + if style is None: + return None + ndim = tensor.ndim + if ndim == 2: + return 0 if style == "colwise" else 1 + elif ndim == 3: + return 1 if style == "colwise" else 2 + else: + return None + + +def shard_1d_by_heads(tensor, rank, world_size): + """Shard a 1D tensor (e.g., dt_bias, A_log) evenly across ranks.""" + n = tensor.shape[0] + chunk_size = n // world_size + return tensor[rank * chunk_size : (rank + 1) * chunk_size] + + +def shard_conv_weight(tensor, rank, world_size): + """Shard conv1d weight: (channels, 1, kernel_size) -> shard channels.""" + n_channels = tensor.shape[0] + chunk_size = n_channels // world_size + return tensor[rank * chunk_size : (rank + 1) * chunk_size] + + +def build_and_save_shard(rank, head_dim): + shard = {} + for name, tensor in _STATE_ITEMS: + t = tensor + dim = get_split_dim(name, tensor) + + # Special: 1D params sharded by head count + if any( + k in name + for k in ["linear_attn.dt_bias", "linear_attn.A_log"] + ): + shard[name] = shard_1d_by_heads(t, rank, _WORLD_SIZE) + continue + + # Special: in_proj_qkv needs interleaved Q/K/V sharding + # Weight shape (conv_dim, hidden) where conv_dim = key_dim*2 + value_dim + # Naive colwise shard gives rank 0 ALL of Q, rank 1 ALL of K, etc. + # Correct: each rank gets its portion of Q, K, AND V. 
+ if "linear_attn.in_proj_qkv" in name: + key_dim = _CONFIG.linear_num_key_heads * _CONFIG.linear_key_head_dim + value_dim = _CONFIG.linear_num_value_heads * _CONFIG.linear_value_head_dim + q_part = t[:key_dim] + k_part = t[key_dim : key_dim * 2] + v_part = t[key_dim * 2 :] + q_local = q_part.chunk(_WORLD_SIZE, dim=0)[rank] + k_local = k_part.chunk(_WORLD_SIZE, dim=0)[rank] + v_local = v_part.chunk(_WORLD_SIZE, dim=0)[rank] + shard[name] = torch.cat([q_local, k_local, v_local], dim=0) + continue + + # Special: conv1d weight needs same interleaved sharding as in_proj_qkv + # Weight shape (conv_dim, 1, kernel_size), channels match QKV ordering + if "linear_attn.conv1d.weight" in name: + key_dim = _CONFIG.linear_num_key_heads * _CONFIG.linear_key_head_dim + value_dim = _CONFIG.linear_num_value_heads * _CONFIG.linear_value_head_dim + q_ch = t[:key_dim] + k_ch = t[key_dim : key_dim * 2] + v_ch = t[key_dim * 2 :] + q_local = q_ch.chunk(_WORLD_SIZE, dim=0)[rank] + k_local = k_ch.chunk(_WORLD_SIZE, dim=0)[rank] + v_local = v_ch.chunk(_WORLD_SIZE, dim=0)[rank] + shard[name] = torch.cat([q_local, k_local, v_local], dim=0) + continue + + # Don't shard + if dim is None or t.numel() == 1 or t.size(dim) % _WORLD_SIZE != 0: + shard[name] = t + continue + + # KV projection: shard till head dim, then replicate + if ("k_proj" in name or "v_proj" in name) and ( + t.shape[dim] // _WORLD_SIZE < head_dim + ): + n_kv_heads = tensor.shape[0] // head_dim + tensor_r = tensor.reshape(-1, head_dim, tensor.shape[1]) + head_index = np.floor(n_kv_heads * rank / _WORLD_SIZE).astype(int) + part = tensor_r[head_index] + shard[name] = part + continue + + # 3D expert gate_up_proj: shard gate and up separately + if "experts.gate_up_proj" in name and t.ndim == 3: + intermediate = t.shape[1] // 2 + gate = t[:, :intermediate, :] + up = t[:, intermediate:, :] + gate_part = gate.chunk(_WORLD_SIZE, dim=1)[rank] + up_part = up.chunk(_WORLD_SIZE, dim=1)[rank] + part = torch.cat([gate_part, up_part], dim=1) + shard[name] = part + continue + + # Normal shard + part = t.chunk(_WORLD_SIZE, dim=dim)[rank] + shard[name] = part + + processed_shard = post_process_shard(shard, rank) + + for name, part in processed_shard.items(): + processed_shard[name] = part.contiguous() + + path = os.path.join(_OUTPUT_DIR, f"shard_{rank}.safetensors") + save_file(processed_shard, path) + del shard, processed_shard + + +def post_process_shard(shard, rank): + """Transform shard weights into the target format for inference.""" + processed = {} + n_layers = _CONFIG.num_hidden_layers + + # Token embeddings and final norm/head + if "model.norm.weight" in shard: + processed["norm_weight"] = shard["model.norm.weight"] + if "lm_head.weight" in shard: + processed["lm_head_weight"] = shard["lm_head.weight"].T + if "model.embed_tokens.weight" in shard: + processed["tok_embedding"] = shard["model.embed_tokens.weight"] + + for layer_id in range(n_layers): + prefix = f"model.layers.{layer_id}" + layer_type = _LAYER_TYPES[layer_id] + + # --- Common: layer norms --- + if f"{prefix}.input_layernorm.weight" in shard: + processed[f"layers.{layer_id}.input_weight"] = shard[ + f"{prefix}.input_layernorm.weight" + ] + processed[f"layers.{layer_id}.post_attention_weight"] = shard[ + f"{prefix}.post_attention_layernorm.weight" + ] + + # --- Full attention layer --- + if layer_type == "full_attention": + q_weight = shard.get(f"{prefix}.self_attn.q_proj.weight") + k_weight = shard.get(f"{prefix}.self_attn.k_proj.weight") + v_weight = shard.get(f"{prefix}.self_attn.v_proj.weight") + 
o_weight = shard.get(f"{prefix}.self_attn.o_proj.weight") + + if q_weight is not None and k_weight is not None and v_weight is not None: + # q_weight includes gate (2x head_dim per head) + qkv_weight = torch.cat( + [q_weight.T, k_weight.T, v_weight.T], axis=-1 + ) + processed[f"layers.{layer_id}.qkv_weight"] = qkv_weight + + if o_weight is not None: + processed[f"layers.{layer_id}.o_weight"] = o_weight.T + + if f"{prefix}.self_attn.q_norm.weight" in shard: + processed[f"layers.{layer_id}.q_norm_weight"] = shard[ + f"{prefix}.self_attn.q_norm.weight" + ] + processed[f"layers.{layer_id}.k_norm_weight"] = shard[ + f"{prefix}.self_attn.k_norm.weight" + ] + + # --- Linear attention layer (GatedDeltaNet) --- + elif layer_type == "linear_attention": + # QKV projection + w = shard.get(f"{prefix}.linear_attn.in_proj_qkv.weight") + if w is not None: + processed[f"layers.{layer_id}.linear_qkv_weight"] = w.T + + # Z projection (gate for RMSNormGated) + w = shard.get(f"{prefix}.linear_attn.in_proj_z.weight") + if w is not None: + processed[f"layers.{layer_id}.linear_z_weight"] = w.T + + # Beta projection + w = shard.get(f"{prefix}.linear_attn.in_proj_b.weight") + if w is not None: + processed[f"layers.{layer_id}.linear_b_weight"] = w.T + + # Alpha projection + w = shard.get(f"{prefix}.linear_attn.in_proj_a.weight") + if w is not None: + processed[f"layers.{layer_id}.linear_a_weight"] = w.T + + # Conv1d weight: (channels, 1, kernel_size) -> (channels, kernel_size) + w = shard.get(f"{prefix}.linear_attn.conv1d.weight") + if w is not None: + processed[f"layers.{layer_id}.linear_conv_weight"] = w.squeeze(1) + + # dt_bias + w = shard.get(f"{prefix}.linear_attn.dt_bias") + if w is not None: + processed[f"layers.{layer_id}.linear_dt_bias"] = w + + # A_log + w = shard.get(f"{prefix}.linear_attn.A_log") + if w is not None: + processed[f"layers.{layer_id}.linear_A_log"] = w + + # RMSNormGated weight + w = shard.get(f"{prefix}.linear_attn.norm.weight") + if w is not None: + processed[f"layers.{layer_id}.linear_norm_weight"] = w + + # Output projection + w = shard.get(f"{prefix}.linear_attn.out_proj.weight") + if w is not None: + processed[f"layers.{layer_id}.linear_out_weight"] = w.T + + # --- MoE --- + # Router weight + if f"{prefix}.mlp.gate.weight" in shard: + processed[f"layers.{layer_id}.router_weight"] = shard[ + f"{prefix}.mlp.gate.weight" + ].T + + # Routed experts (transformers 5.0+ 3D format) + gate_up_key = f"{prefix}.mlp.experts.gate_up_proj" + down_key = f"{prefix}.mlp.experts.down_proj" + if gate_up_key in shard: + processed[f"layers.{layer_id}.gate_up_weight"] = shard[ + gate_up_key + ].transpose(1, 2) + processed[f"layers.{layer_id}.down_weight"] = shard[ + down_key + ].transpose(1, 2) + else: + # Pre-5.0 format: separate expert tensors + num_experts = 0 + while f"{prefix}.mlp.experts.{num_experts}.gate_proj.weight" in shard: + num_experts += 1 + + if num_experts > 0: + gate_up_weights = [] + down_weights = [] + for expert_id in range(num_experts): + gate_w = shard.get( + f"{prefix}.mlp.experts.{expert_id}.gate_proj.weight" + ) + up_w = shard.get( + f"{prefix}.mlp.experts.{expert_id}.up_proj.weight" + ) + down_w = shard.get( + f"{prefix}.mlp.experts.{expert_id}.down_proj.weight" + ) + if gate_w is not None and up_w is not None: + gate_up_weights.append( + torch.cat([gate_w.T, up_w.T], axis=-1) + ) + if down_w is not None: + down_weights.append(down_w.T) + + if gate_up_weights: + processed[f"layers.{layer_id}.gate_up_weight"] = torch.stack( + gate_up_weights + ) + if down_weights: + 
processed[f"layers.{layer_id}.down_weight"] = torch.stack( + down_weights + ) + + # Shared expert + for proj_name in ["gate_proj", "up_proj", "down_proj"]: + key = f"{prefix}.mlp.shared_expert.{proj_name}.weight" + if key in shard: + processed[f"layers.{layer_id}.shared_{proj_name}_weight"] = shard[ + key + ].T + + # Shared expert gate + key = f"{prefix}.mlp.shared_expert_gate.weight" + if key in shard: + processed[f"layers.{layer_id}.shared_expert_gate_weight"] = shard[key].T + + return processed + + +def preshard_model( + model_name: str, + output_dir: str, + world_size: int, + head_dim: int, + dtype: torch.dtype = torch.bfloat16, +): + global _STATE_ITEMS, _WORLD_SIZE, _OUTPUT_DIR, _CONFIG, _LAYER_TYPES + + os.makedirs(output_dir, exist_ok=True) + _WORLD_SIZE = world_size + _OUTPUT_DIR = output_dir + + print(f"[1/3] Loading full model `{model_name}` onto CPU...") + + from transformers import AutoConfig, AutoModelForCausalLM + + # Load config to get layer types + hf_config = AutoConfig.from_pretrained(model_name) + text_cfg = ( + hf_config.text_config if hasattr(hf_config, "text_config") else hf_config + ) + _LAYER_TYPES = list(text_cfg.layer_types) + _CONFIG = text_cfg + + # Try loading as CausalLM first (text-only) + try: + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="cpu", + dtype=dtype, + low_cpu_mem_usage=False, + ) + except Exception: + # Fallback: load full multimodal model and extract language model + from transformers import AutoModel + + full_model = AutoModel.from_pretrained( + model_name, + device_map="cpu", + dtype=dtype, + low_cpu_mem_usage=False, + ) + model = full_model.language_model + + _STATE_ITEMS = list(model.state_dict().items()) + + print(f"[2/3] Splitting, post-processing, and saving {_WORLD_SIZE} shards...") + + for rank in range(_WORLD_SIZE): + build_and_save_shard(rank, head_dim) + + print(f"[3/3] Done! {_WORLD_SIZE} post-processed shards saved in {_OUTPUT_DIR}.") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Pre-shard Qwen3.5-35B-A3B for tensor-parallel inference." + ) + parser.add_argument( + "--model-name", + required=True, + help="HF repo or local path, e.g. Qwen/Qwen3.5-35B-A3B", + ) + parser.add_argument("--output-dir", default="qwen3_5_shards") + parser.add_argument( + "--world-size", type=int, required=True, help="Number of tensor-parallel ranks" + ) + parser.add_argument( + "--dtype", + choices=["f32", "f16", "bf16"], + default="bf16", + help="Data type to load/save", + ) + parser.add_argument("--head-dim", type=int, default=256, help="The head dim size") + + args = parser.parse_args() + dtype = {"f32": torch.float32, "f16": torch.float16, "bf16": torch.bfloat16}[ + args.dtype + ] + + preshard_model( + args.model_name, + args.output_dir, + args.world_size, + args.head_dim, + dtype=dtype, + ) diff --git a/examples/models/qwen3_5_35B_A3B/test.sh b/examples/models/qwen3_5_35B_A3B/test.sh new file mode 100644 index 0000000..da1df39 --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/test.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Test script for Qwen3.5-35B-A3B on Trainium +# Usage: bash test.sh + +set -e + +echo "==========================================" +echo "Qwen3.5-35B-A3B Test Script" +echo "==========================================" + +# Step 1: Clean compilation cache +echo "" +echo "[1/3] Cleaning compilation cache..." +rm -rf build/ 2>/dev/null || true +echo "Done" + +# Step 2: Check and prepare weights +echo "" +echo "[2/3] Checking weights..." 
+ +WEIGHTS_PATH="./qwen3_5_shards" +TP_DEGREE=4 + +if [ ! -d "$WEIGHTS_PATH" ]; then + echo "Weights not found. Downloading and converting..." + python tensor_preparation.py --model-name Qwen/Qwen3.5-35B-A3B --world-size "$TP_DEGREE" --head-dim 256 --output-dir="$WEIGHTS_PATH" + echo "Done" +else + echo "Weights found at $WEIGHTS_PATH" +fi + +# Step 3: Run inference +echo "" +echo "[3/3] Running Qwen3.5 inference..." +echo "==========================================" + +export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=16 +export MALLOC_ARENA_MAX=2 +export MALLOC_TRIM_THRESHOLD_=-1 +export MALLOC_MMAP_THRESHOLD_=131072 +torchrun --nproc-per-node "$TP_DEGREE" qwen3_5.py -n 256 --checkpoint "$WEIGHTS_PATH" --model Qwen/Qwen3.5-35B-A3B "what is capital city of Austria?" + +echo "" +echo "==========================================" +echo "Test passed!" +echo "==========================================" diff --git a/examples/models/qwen3_5_35B_A3B/test_0_8b.py b/examples/models/qwen3_5_35B_A3B/test_0_8b.py new file mode 100644 index 0000000..5ce240b --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/test_0_8b.py @@ -0,0 +1,416 @@ +"""End-to-end test: Qwen3.5-0.8B on Trainium with TP=1 (no sharding). + +Tests GDN correctness without any TP complexity. +The 0.8B model is dense (no MoE), same GDN architecture as 35B-A3B. + +Usage: + # Step 1: Prepare weights (only needed once) + uv run python test_0_8b.py --prepare + + # Step 2: Run inference + validate + uv run torchrun --nproc_per_node 1 test_0_8b.py --run "The capital of France is" +""" + +import argparse +import os +import sys +import time + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +import ml_dtypes +import numpy as np +import torch + +MODEL_NAME = "Qwen/Qwen3.5-0.8B" +WEIGHTS_DIR = "./qwen3_5_0_8b_shards" +BUILD_DIR = "./build_0_8b" + +bf16 = ml_dtypes.bfloat16 + + +# ========================================================================= +# Weight preparation (no TP, just rename/transpose) +# ========================================================================= +def prepare_weights(): + from safetensors.torch import save_file + from transformers import AutoConfig, AutoModelForCausalLM + + config = AutoConfig.from_pretrained(MODEL_NAME) + text_cfg = config.text_config if hasattr(config, "text_config") else config + + print(f"Loading {MODEL_NAME}...") + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, dtype=torch.bfloat16, device_map="cpu", low_cpu_mem_usage=False + ) + + sd = model.state_dict() + processed = {} + n_layers = text_cfg.num_hidden_layers + layer_types = list(text_cfg.layer_types) + + # Global + processed["tok_embedding"] = sd["model.embed_tokens.weight"] + processed["norm_weight"] = sd["model.norm.weight"] + processed["lm_head_weight"] = sd["lm_head.weight"].T + + for i in range(n_layers): + p = f"model.layers.{i}" + lt = layer_types[i] + + # Norms + processed[f"layers.{i}.input_weight"] = sd[f"{p}.input_layernorm.weight"] + processed[f"layers.{i}.post_attention_weight"] = sd[f"{p}.post_attention_layernorm.weight"] + + if lt == "linear_attention": + # GDN weights - just transpose projections + processed[f"layers.{i}.linear_qkv_weight"] = sd[f"{p}.linear_attn.in_proj_qkv.weight"].T + processed[f"layers.{i}.linear_z_weight"] = sd[f"{p}.linear_attn.in_proj_z.weight"].T + processed[f"layers.{i}.linear_b_weight"] = sd[f"{p}.linear_attn.in_proj_b.weight"].T + processed[f"layers.{i}.linear_a_weight"] = sd[f"{p}.linear_attn.in_proj_a.weight"].T + processed[f"layers.{i}.linear_conv_weight"] = 
sd[f"{p}.linear_attn.conv1d.weight"].squeeze(1) + processed[f"layers.{i}.linear_dt_bias"] = sd[f"{p}.linear_attn.dt_bias"] + processed[f"layers.{i}.linear_A_log"] = sd[f"{p}.linear_attn.A_log"] + processed[f"layers.{i}.linear_norm_weight"] = sd[f"{p}.linear_attn.norm.weight"] + processed[f"layers.{i}.linear_out_weight"] = sd[f"{p}.linear_attn.out_proj.weight"].T + else: + # Full attention + q_w = sd[f"{p}.self_attn.q_proj.weight"] + k_w = sd[f"{p}.self_attn.k_proj.weight"] + v_w = sd[f"{p}.self_attn.v_proj.weight"] + processed[f"layers.{i}.qkv_weight"] = torch.cat([q_w.T, k_w.T, v_w.T], dim=-1) + processed[f"layers.{i}.o_weight"] = sd[f"{p}.self_attn.o_proj.weight"].T + processed[f"layers.{i}.q_norm_weight"] = sd[f"{p}.self_attn.q_norm.weight"] + processed[f"layers.{i}.k_norm_weight"] = sd[f"{p}.self_attn.k_norm.weight"] + + # Dense FFN (not MoE) + processed[f"layers.{i}.gate_proj_weight"] = sd[f"{p}.mlp.gate_proj.weight"].T + processed[f"layers.{i}.up_proj_weight"] = sd[f"{p}.mlp.up_proj.weight"].T + processed[f"layers.{i}.down_proj_weight"] = sd[f"{p}.mlp.down_proj.weight"].T + + for k, v in processed.items(): + processed[k] = v.contiguous() + + os.makedirs(WEIGHTS_DIR, exist_ok=True) + save_file(processed, os.path.join(WEIGHTS_DIR, "shard_0.safetensors")) + print(f"Saved {len(processed)} tensors to {WEIGHTS_DIR}/shard_0.safetensors") + del model + + +# ========================================================================= +# Model (TP=1, dense FFN) +# ========================================================================= +def run_inference(prompt, max_new_tokens=64): + import torch.distributed as dist + from nkipy.runtime import DeviceKernel, DeviceTensor + from safetensors.torch import load_file + from transformers import AutoConfig, AutoTokenizer + + from kernels.rmsnorm import rmsnorm_kernel + from kernels.attention import attention_kernel + from kernels.linear_attention import gated_delta_net_kernel + from kernels.softmax import softmax_kernel + from kernels.feedforward import silu_kernel_ + import nkipy.distributed.collectives as cc + from nkipy.core import tensor_apis + import nkipy.core.typing as nt + from typing import Optional + + os.environ["NEURON_RT_ROOT_COMM_ID"] = "localhost:61239" + os.environ["OMP_NUM_THREADS"] = "1" + dist.init_process_group() + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + os.environ["NEURON_RT_VISIBLE_CORES"] = "0" + + config = AutoConfig.from_pretrained(MODEL_NAME) + tc = config.text_config + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + input_ids_np = tokenizer(prompt, return_tensors="np")["input_ids"] + context_len = input_ids_np.shape[1] + + weights = load_file(os.path.join(WEIGHTS_DIR, "shard_0.safetensors"), device="cpu") + tok_embedding = weights["tok_embedding"] + + # --- Define kernels --- + def dense_ffn(x, gate_w, up_w, down_w): + gate = silu_kernel_(np.matmul(x, gate_w)) + up = np.matmul(x, up_w) + return np.matmul(gate * up, down_w) + + def linear_attn_layer( + x, input_weight, qkv_w, z_w, b_w, a_w, conv_w, dt_bias, A_log, + linear_norm_w, out_w, post_attention_weight, gate_w, up_w, down_w, + conv_state, recurrent_state, start_pos: Optional[nt.tensor], + num_k_heads, num_v_heads, head_k_dim, head_v_dim, conv_kernel_size, norm_eps, + ): + norm_x = rmsnorm_kernel(x, input_weight, norm_eps) + h1 = gated_delta_net_kernel( + norm_x, qkv_w, z_w, b_w, a_w, conv_w, dt_bias, A_log, linear_norm_w, + out_w, norm_eps, num_k_heads, num_v_heads, head_k_dim, head_v_dim, + conv_kernel_size, conv_state, recurrent_state, 
start_pos=start_pos, + ) + z = x + h1 + norm_z = rmsnorm_kernel(z, post_attention_weight, norm_eps) + ffn_out = dense_ffn(norm_z, gate_w, up_w, down_w) + # No all-reduce needed for TP=1, but the GDN kernel still has it (identity for ws=1) + return (z + ffn_out).astype(x.dtype) + + def full_attn_layer( + x, input_weight, qkv_weight, o_weight, q_norm_weight, k_norm_weight, + post_attention_weight, gate_w, up_w, down_w, + cache_k, cache_v, + start_pos: Optional[nt.tensor], + num_heads, head_dim, num_kv_heads, norm_eps, + ): + norm_x = rmsnorm_kernel(x, input_weight, norm_eps) + h1 = attention_kernel( + norm_x, qkv_weight, q_norm_weight, k_norm_weight, norm_eps, + num_heads, head_dim, num_kv_heads, 0.25, 10000000.0, + cache_k, cache_v, + start_pos=start_pos, o_weight=o_weight, + ) + z = x + h1 + norm_z = rmsnorm_kernel(z, post_attention_weight, norm_eps) + ffn_out = dense_ffn(norm_z, gate_w, up_w, down_w) + return (z + ffn_out).astype(x.dtype) + + def compute_logits(h, norm_weight, lm_head_weight, norm_eps): + """Compute logits (argmax done on CPU to avoid NKIPy argmax bug on large dims).""" + h = rmsnorm_kernel(h, norm_weight, norm_eps) + logits = h[:, -1, :] @ lm_head_weight + return logits.astype(np.float32) + + # --- Prepare device tensors --- + print("Preparing tensors...") + layer_types = list(tc.layer_types) + n_layers = tc.num_hidden_layers + norm_eps = tc.rms_norm_eps + + n_kv = tc.num_key_value_heads + n_v = tc.linear_num_value_heads + n_k = tc.linear_num_key_heads + hk = tc.linear_key_head_dim + hv = tc.linear_value_head_dim + head_dim = tc.head_dim + max_seq = 4096 + + layers = [] + for i in range(n_layers): + lt = layer_types[i] + d = {"type": lt} + for key in ["input_weight", "post_attention_weight", "gate_proj_weight", + "up_proj_weight", "down_proj_weight"]: + d[key] = DeviceTensor.from_torch(weights[f"layers.{i}.{key}"], f"{key}_L{i}") + + if lt == "linear_attention": + for key in ["linear_qkv_weight", "linear_z_weight", "linear_b_weight", + "linear_a_weight", "linear_conv_weight", "linear_dt_bias", + "linear_A_log", "linear_norm_weight", "linear_out_weight"]: + d[key] = DeviceTensor.from_torch(weights[f"layers.{i}.{key}"], f"{key}_L{i}") + conv_dim = n_k * hk * 2 + n_v * hv + d["conv_state"] = DeviceTensor.from_numpy( + np.zeros((1, conv_dim, tc.linear_conv_kernel_dim), dtype=bf16), f"cs_L{i}") + d["recurrent_state"] = DeviceTensor.from_numpy( + np.zeros((1, n_v, hk, hv), dtype=bf16), f"rs_L{i}") + else: + for key in ["qkv_weight", "o_weight", "q_norm_weight", "k_norm_weight"]: + d[key] = DeviceTensor.from_torch(weights[f"layers.{i}.{key}"], f"{key}_L{i}") + d["cache_k"] = DeviceTensor.from_numpy( + np.zeros((1, max_seq, n_kv, head_dim), dtype=bf16), f"ck_L{i}") + d["cache_v"] = DeviceTensor.from_numpy( + np.zeros((1, max_seq, n_kv, head_dim), dtype=bf16), f"cv_L{i}") + layers.append(d) + + d_norm = DeviceTensor.from_torch(weights["norm_weight"], "norm_w") + d_lm = DeviceTensor.from_torch(weights["lm_head_weight"], "lm_head") + + # --- Compile kernels --- + print("Compiling kernels...") + + x_ctx = DeviceTensor.from_numpy(np.empty((1, context_len, tc.hidden_size), dtype=bf16), "x_ctx") + x_tok = DeviceTensor.from_numpy(np.empty((1, 1, tc.hidden_size), dtype=bf16), "x_tok") + d_sp = DeviceTensor.from_numpy(np.empty((1,), dtype=np.int32), "sp") + + # Find first of each type + la_idx = next(i for i, l in enumerate(layers) if l["type"] == "linear_attention") + fa_idx = next(i for i, l in enumerate(layers) if l["type"] == "full_attention") + la, fa = layers[la_idx], layers[fa_idx] 
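+    # Only one kernel is compiled per (phase, layer type); per-layer weights and
+    # cache/state tensors are supplied as call-time inputs, so every layer of a
+    # given type reuses the same compiled kernel.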
+ + common_la = dict( + num_k_heads=n_k, num_v_heads=n_v, head_k_dim=hk, head_v_dim=hv, + conv_kernel_size=tc.linear_conv_kernel_dim, norm_eps=norm_eps, + ) + common_fa = dict( + num_heads=tc.num_attention_heads, head_dim=head_dim, + num_kv_heads=n_kv, norm_eps=norm_eps, + ) + + def make_la_args(x, sp, la): + return dict( + x=x, input_weight=la["input_weight"], + qkv_w=la["linear_qkv_weight"], z_w=la["linear_z_weight"], + b_w=la["linear_b_weight"], a_w=la["linear_a_weight"], + conv_w=la["linear_conv_weight"], dt_bias=la["linear_dt_bias"], + A_log=la["linear_A_log"], linear_norm_w=la["linear_norm_weight"], + out_w=la["linear_out_weight"], post_attention_weight=la["post_attention_weight"], + gate_w=la["gate_proj_weight"], up_w=la["up_proj_weight"], + down_w=la["down_proj_weight"], + conv_state=la["conv_state"], recurrent_state=la["recurrent_state"], + start_pos=sp, **common_la, + ) + + def make_fa_args(x, sp, fa): + return dict( + x=x, input_weight=fa["input_weight"], + qkv_weight=fa["qkv_weight"], o_weight=fa["o_weight"], + q_norm_weight=fa["q_norm_weight"], k_norm_weight=fa["k_norm_weight"], + post_attention_weight=fa["post_attention_weight"], + gate_w=fa["gate_proj_weight"], up_w=fa["up_proj_weight"], + down_w=fa["down_proj_weight"], + cache_k=fa["cache_k"], cache_v=fa["cache_v"], + start_pos=sp, **common_fa, + ) + + k_cte_la = DeviceKernel.compile_and_load(linear_attn_layer, name="cte_la", + build_dir=BUILD_DIR, **make_la_args(x_ctx, None, la)) + k_tkg_la = DeviceKernel.compile_and_load(linear_attn_layer, name="tkg_la", + build_dir=BUILD_DIR, **make_la_args(x_tok, d_sp, la)) + k_cte_fa = DeviceKernel.compile_and_load(full_attn_layer, name="cte_fa", + build_dir=BUILD_DIR, **make_fa_args(x_ctx, None, fa)) + k_tkg_fa = DeviceKernel.compile_and_load(full_attn_layer, name="tkg_fa", + build_dir=BUILD_DIR, **make_fa_args(x_tok, d_sp, fa)) + vocab_size = tc.vocab_size + d_logits_ctx = DeviceTensor.from_numpy(np.empty((1, vocab_size), dtype=np.float32), "logits_ctx") + d_logits_tok = DeviceTensor.from_numpy(np.empty((1, vocab_size), dtype=np.float32), "logits_tok") + k_cte_sample = DeviceKernel.compile_and_load(compute_logits, name="cte_samp", + h=x_ctx, norm_weight=d_norm, lm_head_weight=d_lm, norm_eps=norm_eps, + build_dir=BUILD_DIR) + k_tkg_sample = DeviceKernel.compile_and_load(compute_logits, name="tkg_samp", + h=x_tok, norm_weight=d_norm, lm_head_weight=d_lm, norm_eps=norm_eps, + build_dir=BUILD_DIR) + + print(f"Compilation done. 
Generating {max_new_tokens} tokens...\n") + + # --- Generate --- + def run_layer(kernel_la, kernel_fa, idx, h, sp_tensor): + l = layers[idx] + if l["type"] == "linear_attention": + inp = { + "x": h, "input_weight": l["input_weight"], + "qkv_w": l["linear_qkv_weight"], "z_w": l["linear_z_weight"], + "b_w": l["linear_b_weight"], "a_w": l["linear_a_weight"], + "conv_w": l["linear_conv_weight"], "dt_bias": l["linear_dt_bias"], + "A_log": l["linear_A_log"], "linear_norm_w": l["linear_norm_weight"], + "out_w": l["linear_out_weight"], "post_attention_weight": l["post_attention_weight"], + "gate_w": l["gate_proj_weight"], "up_w": l["up_proj_weight"], + "down_w": l["down_proj_weight"], + "conv_state.must_alias_input": l["conv_state"], + "recurrent_state.must_alias_input": l["recurrent_state"], + } + if sp_tensor is not None: + inp["start_pos"] = sp_tensor + out = {"output0": h, "conv_state": l["conv_state"], "recurrent_state": l["recurrent_state"]} + kernel_la(inputs=inp, outputs=out) + else: + inp = { + "x": h, "input_weight": l["input_weight"], + "qkv_weight": l["qkv_weight"], "o_weight": l["o_weight"], + "q_norm_weight": l["q_norm_weight"], "k_norm_weight": l["k_norm_weight"], + "post_attention_weight": l["post_attention_weight"], + "gate_w": l["gate_proj_weight"], "up_w": l["up_proj_weight"], + "down_w": l["down_proj_weight"], + "cache_k.must_alias_input": l["cache_k"], "cache_v.must_alias_input": l["cache_v"], + } + if sp_tensor is not None: + inp["start_pos"] = sp_tensor + out = {"output0": h, "cache_k": l["cache_k"], "cache_v": l["cache_v"]} + kernel_fa(inputs=inp, outputs=out) + + # Reset GDN states + for l in layers: + if l["type"] == "linear_attention": + l["conv_state"].write_from_numpy(np.zeros(l["conv_state"].numpy().shape, dtype=bf16)) + l["recurrent_state"].write_from_numpy(np.zeros(l["recurrent_state"].numpy().shape, dtype=bf16)) + + h = DeviceTensor.from_torch(tok_embedding[input_ids_np], "h") + next_id = DeviceTensor.from_numpy(np.array([[0]], dtype=np.uint32), "nid") + + # Prefill + for i in range(n_layers): + run_layer(k_cte_la, k_cte_fa, i, h, None) + k_cte_sample(inputs={"h": h, "norm_weight": d_norm, "lm_head_weight": d_lm}, + outputs={"output0": d_logits_ctx}) + + # CPU argmax (NKIPy argmax has bug on large vocab dims) + logits = d_logits_ctx.numpy().flatten().astype(np.float32) + first_tid = int(np.argmax(logits)) + top5 = np.argsort(logits)[-5:][::-1] + print(f"First token: {first_tid} = '{tokenizer.decode([first_tid])}'") + print(f"Top 5: {[(t, tokenizer.decode([t]), f'{logits[t]:.2f}') for t in top5]}") + + generated = [first_tid] + + # Decode + start = time.time() + for pos in range(context_len, context_len + max_new_tokens - 1): + sp = DeviceTensor.from_numpy(np.array([pos], dtype=np.int32)) + tid = generated[-1] + if tid >= vocab_size: + print(f"Warning: token {tid} >= vocab_size {vocab_size}, clamping") + tid = vocab_size - 1 + nid_torch = torch.tensor([[tid]], dtype=torch.int) + h = DeviceTensor.from_torch(tok_embedding[nid_torch.numpy()], "h_dec") + for i in range(n_layers): + run_layer(k_tkg_la, k_tkg_fa, i, h, sp) + k_tkg_sample(inputs={"h": h, "norm_weight": d_norm, "lm_head_weight": d_lm}, + outputs={"output0": d_logits_tok}) + tid = int(np.argmax(d_logits_tok.numpy().flatten().astype(np.float32))) + generated.append(tid) + if tid in {248044}: # EOS + break + + elapsed = time.time() - start + text = tokenizer.decode(generated, skip_special_tokens=True) + print(f"{prompt}{text}") + print(f"\nTokens: {len(generated)}, Time: {elapsed:.2f}s, Speed: 
{len(generated)/elapsed:.1f} tok/s") + + # --- Validate against HF --- + print("\n--- Validating against HF reference ---") + import transformers.integrations.moe as _m + _m.is_grouped_mm_available = lambda: False + from transformers import AutoModelForCausalLM + + hf_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16, device_map="cpu") + hf_model.eval() + hf_inputs = tokenizer(prompt, return_tensors="pt") + with torch.no_grad(): + hf_out = hf_model.generate(hf_inputs["input_ids"], max_new_tokens=max_new_tokens, do_sample=False) + hf_text = tokenizer.decode(hf_out[0][hf_inputs["input_ids"].shape[1]:], skip_special_tokens=True) + print(f"HF: {prompt}{hf_text}") + print(f"NKIPy: {prompt}{text}") + + # Token-level comparison + hf_ids = hf_out[0][hf_inputs["input_ids"].shape[1]:].tolist() + n_match = sum(1 for a, b in zip(generated, hf_ids) if a == b) + n_total = min(len(generated), len(hf_ids)) + print(f"\nToken match: {n_match}/{n_total} ({100*n_match/max(n_total,1):.0f}%)") + if n_match == n_total: + print("PERFECT MATCH!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--prepare", action="store_true", help="Prepare weights") + parser.add_argument("--run", action="store_true", help="Run inference") + parser.add_argument("prompt", nargs="?", default="who are you? ") + parser.add_argument("-n", "--max-new-tokens", type=int, default=32) + args = parser.parse_args() + + if args.prepare: + prepare_weights() + elif args.run: + run_inference(args.prompt, args.max_new_tokens) + else: + print("Use --prepare or --run") diff --git a/examples/models/qwen3_5_35B_A3B/test_0_8b.sh b/examples/models/qwen3_5_35B_A3B/test_0_8b.sh new file mode 100755 index 0000000..72d60b8 --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/test_0_8b.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Test script for Qwen3.5-0.8B on Trainium (TP=1, no sharding) +# Usage: bash test_0_8b.sh + +set -e + +echo "==========================================" +echo "Qwen3.5-0.8B Test Script (TP=1)" +echo "==========================================" + +# Step 1: Clean compilation cache +echo "" +echo "[1/3] Cleaning compilation cache..." +rm -rf build_0_8b/ 2>/dev/null || true +echo "Done" + +# Step 2: Check and prepare weights +echo "" +echo "[2/3] Checking weights..." + +WEIGHTS_PATH="./qwen3_5_0_8b_shards" + +if [ ! -d "$WEIGHTS_PATH" ]; then + echo "Weights not found. Downloading and converting..." + python test_0_8b.py --prepare + echo "Done" +else + echo "Weights found at $WEIGHTS_PATH" +fi + +# Step 3: Run inference + validate against HF +echo "" +echo "[3/3] Running Qwen3.5-0.8B inference (TP=1)..." +echo "==========================================" + +export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=16 +export MALLOC_ARENA_MAX=2 +export MALLOC_TRIM_THRESHOLD_=-1 +export MALLOC_MMAP_THRESHOLD_=131072 + +torchrun --nproc-per-node 1 test_0_8b.py -- --run -n 32 "who are you" + +echo "" +echo "==========================================" +echo "Test passed!" 
+echo "==========================================" diff --git a/examples/models/qwen3_5_35B_A3B/utils.py b/examples/models/qwen3_5_35B_A3B/utils.py new file mode 100644 index 0000000..ad9cacf --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/utils.py @@ -0,0 +1,15 @@ +import sys + +import ml_dtypes +import numpy as np +import torch.distributed as dist + +bfloat16 = np.dtype(ml_dtypes.bfloat16) + + +def print_log(msg, rank_list=[0], verbose=0): + if not dist.is_initialized(): + print(msg) + elif dist.get_rank() in rank_list: + print(f"[RANK {dist.get_rank()}] {msg}") + sys.stdout.flush() diff --git a/examples/models/qwen3_5_35B_A3B/validate_layers.py b/examples/models/qwen3_5_35B_A3B/validate_layers.py new file mode 100644 index 0000000..d337d91 --- /dev/null +++ b/examples/models/qwen3_5_35B_A3B/validate_layers.py @@ -0,0 +1,284 @@ +"""Layer-by-layer validation: compare HF (CPU) vs NKIPy (Trainium) outputs. + +Usage: + uv run torchrun --nproc_per_node 8 validate_layers.py \ + --checkpoint ./qwen3_5_shards --model Qwen/Qwen3.5-35B-A3B + +Finds the first layer where NKIPy diverges from the HF reference. +""" + +import argparse +import os +import sys +import time + +import numpy as np +import torch +import torch.distributed as dist +from config import FULL_ATTENTION, LINEAR_ATTENTION, Config, get_config +from nkipy.runtime import DeviceKernel, DeviceTensor +from safetensors.torch import load_file +from transformers import AutoModelForCausalLM, AutoTokenizer +from utils import print_log + + +def compare_tensors(name, hf_tensor, nkipy_tensor, atol=0.05, rtol=0.05): + """Compare two tensors, return (passed, stats_str).""" + hf = hf_tensor.float().cpu() + nk = torch.from_numpy(nkipy_tensor).float().cpu() if isinstance(nkipy_tensor, np.ndarray) else nkipy_tensor.float().cpu() + + # Align shapes + if hf.shape != nk.shape: + return False, f"SHAPE MISMATCH: HF {hf.shape} vs NKIPy {nk.shape}" + + abs_diff = (hf - nk).abs() + max_abs = abs_diff.max().item() + mean_abs = abs_diff.mean().item() + + # Relative error (avoid div by zero) + denom = hf.abs().clamp(min=1e-8) + rel_diff = (abs_diff / denom) + max_rel = rel_diff.max().item() + mean_rel = rel_diff.mean().item() + + # Cosine similarity + cos_sim = torch.nn.functional.cosine_similarity( + hf.flatten().unsqueeze(0), nk.flatten().unsqueeze(0) + ).item() + + passed = cos_sim > 0.99 and max_abs < 5.0 + status = "PASS" if passed else "FAIL" + + stats = ( + f"[{status}] {name}: " + f"cos_sim={cos_sim:.6f} " + f"max_abs={max_abs:.4f} mean_abs={mean_abs:.4f} " + f"max_rel={max_rel:.4f} mean_rel={mean_rel:.4f}" + ) + return passed, stats + + +def run_hf_reference(model_name, prompt, dtype=torch.bfloat16): + """Run HF model on CPU, capture outputs after each layer.""" + print_log("Loading HF reference model on CPU...") + tokenizer = AutoTokenizer.from_pretrained(model_name) + + # Disable grouped_mm which has alignment issues on CPU + import transformers.integrations.moe as _moe_mod + _moe_mod.is_grouped_mm_available = lambda: False + + model = AutoModelForCausalLM.from_pretrained( + model_name, dtype=dtype, device_map="cpu", low_cpu_mem_usage=True + ) + model.eval() + + inputs = tokenizer(prompt, return_tensors="pt") + input_ids = inputs["input_ids"] + + # Register hooks to capture layer outputs + layer_outputs = {} + + def make_hook(layer_idx): + def hook(module, inp, out): + # DecoderLayer returns hidden_states (possibly as tensor directly) + if isinstance(out, tuple): + layer_outputs[layer_idx] = out[0].detach().clone() + else: + 
layer_outputs[layer_idx] = out.detach().clone() + return hook + + # Hook embedding + def embed_hook(module, inp, out): + layer_outputs["embedding"] = out.detach().clone() + + model.model.embed_tokens.register_forward_hook(embed_hook) + + for i, layer in enumerate(model.model.layers): + layer.register_forward_hook(make_hook(i)) + + # Hook final norm + def norm_hook(module, inp, out): + layer_outputs["final_norm"] = out.detach().clone() + model.model.norm.register_forward_hook(norm_hook) + + print_log("Running HF forward pass...") + with torch.no_grad(): + outputs = model(input_ids) + + # Get logits for the last token + logits = outputs.logits[:, -1, :] # (1, vocab_size) + layer_outputs["logits"] = logits.detach().clone() + + # Get top-5 predictions + top5_vals, top5_ids = logits.topk(5, dim=-1) + print_log(f"HF top-5 token IDs: {top5_ids[0].tolist()}") + print_log(f"HF top-5 logit values: {[f'{v:.2f}' for v in top5_vals[0].tolist()]}") + decoded = [tokenizer.decode([tid]) for tid in top5_ids[0].tolist()] + print_log(f"HF top-5 tokens: {decoded}") + + # Clean up HF model to free memory + del model + import gc + gc.collect() + + return layer_outputs, input_ids + + +def run_nkipy_prefill_with_capture(args, input_ids_np): + """Run NKIPy prefill, capture hidden_states after each layer.""" + from qwen3_5 import Qwen35Model + + config = get_config(args.model, input_ids_np.shape[1], 1) + + shard_path = os.path.join(args.checkpoint, f"shard_{dist.get_rank()}.safetensors") + weights = load_file(shard_path, device="cpu") + + model = Qwen35Model(weights, config) + + # Capture embedding output + nkipy_outputs = {} + hidden_states = DeviceTensor.from_torch( + model.tok_embedding[input_ids_np], "hidden_states" + ) + nkipy_outputs["embedding"] = hidden_states.torch().clone() + + # Run prefill layer by layer, capturing outputs + for i in range(config.num_layers): + model._run_layer( + model.kernel_cte_full_attn, + model.kernel_cte_linear_attn, + i, + hidden_states, + None, # prefill: start_pos=None + ) + # Read back hidden_states from device + nkipy_outputs[i] = hidden_states.torch().clone() + + # Run sampling to get logits (we can compare top token) + next_id = DeviceTensor.from_numpy(np.array([[0]], dtype=np.uint32), "next_id") + model.kernel_cte_greedy_sampling( + inputs={ + "h": hidden_states, + "norm_weight": model.norm_weight, + "lm_head_weight": model.lm_head_weight, + }, + outputs={"output0": next_id}, + ) + next_id_val = next_id.torch().item() + nkipy_outputs["next_token_id"] = next_id_val + + return nkipy_outputs + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint", default="./qwen3_5_shards") + parser.add_argument("--model", default="Qwen/Qwen3.5-35B-A3B") + parser.add_argument("prompt", nargs="?", default="The capital of France is") + args = parser.parse_args() + + # Distributed setup + os.environ["TOKENIZERS_PARALLELISM"] = "true" + os.environ["OMP_NUM_THREADS"] = "1" + os.environ["NEURON_RT_ROOT_COMM_ID"] = "localhost:61239" + dist.init_process_group() + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + os.environ["NEURON_RT_VISIBLE_CORES"] = str(dist.get_rank()) + + rank = dist.get_rank() + is_rank0 = rank == 0 + + # --- Step 1: Run HF reference (only on rank 0) --- + hf_outputs = None + input_ids_pt = None + if is_rank0: + hf_outputs, input_ids_pt = run_hf_reference(args.model, args.prompt) + print_log(f"HF captured {len(hf_outputs)} outputs") + + dist.barrier() + + # --- Step 2: Run NKIPy on Trainium --- + # All ranks need to tokenize + 
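+    # (each rank drives its own tensor-parallel shard, so every rank needs the input_ids)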
tokenizer = AutoTokenizer.from_pretrained(args.model) + model_inputs = tokenizer(args.prompt, return_tensors="np") + input_ids_np = model_inputs["input_ids"] + + print_log("Running NKIPy prefill with layer capture...") + nkipy_outputs = run_nkipy_prefill_with_capture(args, input_ids_np) + + dist.barrier() + + # --- Step 3: Compare (rank 0 only) --- + if is_rank0: + config = get_config(args.model, input_ids_np.shape[1], 1) + + print("\n" + "=" * 70) + print("LAYER-BY-LAYER COMPARISON: HF (CPU) vs NKIPy (Trainium)") + print("=" * 70) + + # Compare embedding + passed, stats = compare_tensors( + "embedding", hf_outputs["embedding"], nkipy_outputs["embedding"] + ) + print(stats) + if not passed: + print(">>> DIVERGENCE at embedding! Stopping.") + return + + # Compare each layer + first_fail_layer = None + for layer_idx in range(config.num_layers): + layer_type = config.layer_types[layer_idx] + label = f"layer {layer_idx:2d} ({layer_type[:6]})" + + if layer_idx not in hf_outputs: + print(f"[SKIP] {label}: no HF output captured") + continue + + passed, stats = compare_tensors( + label, hf_outputs[layer_idx], nkipy_outputs[layer_idx] + ) + print(stats) + + if not passed and first_fail_layer is None: + first_fail_layer = layer_idx + + # Print detailed info about the failing layer + hf_t = hf_outputs[layer_idx].float() + nk_t = nkipy_outputs[layer_idx].float() + print(f" HF range: [{hf_t.min():.4f}, {hf_t.max():.4f}], mean={hf_t.mean():.4f}, std={hf_t.std():.4f}") + print(f" NKI range: [{nk_t.min():.4f}, {nk_t.max():.4f}], mean={nk_t.mean():.4f}, std={nk_t.std():.4f}") + + # Show where the biggest differences are + diff = (hf_t - nk_t).abs() + flat_idx = diff.flatten().topk(5).indices + print(f" Top-5 abs diff locations (flat idx): {flat_idx.tolist()}") + + # Also check if previous layer was OK + if layer_idx > 0 and (layer_idx - 1) in hf_outputs: + prev_passed, _ = compare_tensors( + "prev", hf_outputs[layer_idx - 1], nkipy_outputs[layer_idx - 1] + ) + if prev_passed: + print(f" >>> Layer {layer_idx - 1} was OK. Bug is IN layer {layer_idx} ({layer_type}).") + + # Compare next token + if "next_token_id" in nkipy_outputs: + hf_next = hf_outputs["logits"].argmax(dim=-1).item() + nk_next = nkipy_outputs["next_token_id"] + match = "MATCH" if hf_next == nk_next else "MISMATCH" + hf_tok = tokenizer.decode([hf_next]) + nk_tok = tokenizer.decode([nk_next]) + print(f"\nNext token: HF={hf_next} ('{hf_tok}') vs NKIPy={nk_next} ('{nk_tok}') [{match}]") + + print("\n" + "=" * 70) + if first_fail_layer is not None: + print(f"FIRST DIVERGENCE: Layer {first_fail_layer} ({config.layer_types[first_fail_layer]})") + else: + print("ALL LAYERS MATCH!") + print("=" * 70) + + +if __name__ == "__main__": + main()