79 changes: 79 additions & 0 deletions examples/models/qwen3_5_35B_A3B/config.py
@@ -0,0 +1,79 @@
from dataclasses import dataclass, field
from typing import List, Optional

import numpy as np
import torch.distributed as dist
from neuronxcc.nki.language import bfloat16
from transformers import AutoConfig

DTYPE = bfloat16

# Layer types for Qwen3.5 hybrid architecture
FULL_ATTENTION = "full_attention"
LINEAR_ATTENTION = "linear_attention"


@dataclass
class Config:
hidden_size: int
num_heads: int # full attention Q heads
head_dim: int # full attention head dim
num_kv_heads: int # full attention KV heads
num_layers: int
num_experts_per_tok: int
num_experts: int
intermediate_size: int # moe expert intermediate (per device)
shared_expert_intermediate_size: int # shared expert intermediate (per device)
vocab_size: int
# Linear attention params
linear_num_key_heads: int
linear_num_value_heads: int
linear_key_head_dim: int
linear_value_head_dim: int
linear_conv_kernel_dim: int
# Layer type info
layer_types: List[str] = field(default_factory=list)
# RoPE
partial_rotary_factor: float = 0.25
rope_theta: float = 10000000.0
# Sequence
    context_len: Optional[int] = None
    max_new_tokens: Optional[int] = None
max_batch_size: int = 1
max_seq_len: int = 4096
# Norm
norm_eps: float = 1e-6
dtype: np.dtype = DTYPE
additional_compiler_args_nkipy: str = "--lnc 1"


def get_config(model_name, context_len, max_new_tokens):
hf_config = AutoConfig.from_pretrained(model_name)
# Qwen3.5 is multimodal; text config is nested
text_cfg = hf_config.text_config if hasattr(hf_config, "text_config") else hf_config

    # Expert intermediate sizes are per device: the HF sizes are divided
    # across the tensor-parallel world size.
    ws = dist.get_world_size()
config = Config(
hidden_size=text_cfg.hidden_size,
num_heads=text_cfg.num_attention_heads,
head_dim=text_cfg.head_dim,
num_kv_heads=text_cfg.num_key_value_heads,
num_layers=text_cfg.num_hidden_layers,
num_experts_per_tok=text_cfg.num_experts_per_tok,
num_experts=text_cfg.num_experts,
intermediate_size=text_cfg.moe_intermediate_size // ws,
shared_expert_intermediate_size=text_cfg.shared_expert_intermediate_size // ws,
vocab_size=text_cfg.vocab_size,
linear_num_key_heads=text_cfg.linear_num_key_heads,
linear_num_value_heads=text_cfg.linear_num_value_heads,
linear_key_head_dim=text_cfg.linear_key_head_dim,
linear_value_head_dim=text_cfg.linear_value_head_dim,
linear_conv_kernel_dim=text_cfg.linear_conv_kernel_dim,
layer_types=list(text_cfg.layer_types),
partial_rotary_factor=text_cfg.partial_rotary_factor,
rope_theta=text_cfg.rope_parameters.get("rope_theta", 10000000.0),
norm_eps=text_cfg.rms_norm_eps,
context_len=context_len,
max_new_tokens=max_new_tokens,
)
return config
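
For orientation, a minimal sketch of calling get_config directly. It assumes torch.distributed can be initialized (the per-device expert sizes divide by the world size) and that the Hugging Face config for the model name is reachable; the module name config mirrors the file above:

import torch.distributed as dist

from config import get_config

if not dist.is_initialized():
    # env:// rendezvous; assumes torchrun-style environment variables.
    dist.init_process_group()

cfg = get_config(
    "Qwen/Qwen3.5-35B-A3B",
    context_len=1024,
    max_new_tokens=64,
)
print(cfg.layer_types[:4], cfg.intermediate_size)
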
163 changes: 163 additions & 0 deletions examples/models/qwen3_5_35B_A3B/evaluate.py
@@ -0,0 +1,163 @@
"""Benchmarking and accuracy validation for NKIPy Qwen3.5-35B-A3B."""

import json
import time


# ---------------------------------------------------------------------------
# Benchmarking
# ---------------------------------------------------------------------------


def _percentile(data, pct):
    """Return the pct-th percentile of data, linearly interpolating
    between the two nearest ranks."""
    if not data:
return 0.0
sorted_data = sorted(data)
k = (len(sorted_data) - 1) * (pct / 100.0)
f = int(k)
c = f + 1
if c >= len(sorted_data):
return sorted_data[f]
return sorted_data[f] + (k - f) * (sorted_data[c] - sorted_data[f])
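
# Worked example (illustrative): _percentile([10, 20, 30, 40], 90)
# gives k = 3 * 0.9 = 2.7, f = 2, c = 3, and returns
# 30 + 0.7 * (40 - 30) = 37.0.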


def _run_once(model, input_ids):
    """Generate once, recording time-to-first-token and per-token
    decode latencies in milliseconds."""
    token_times = []
    ttft_ms = 0.0  # stays 0.0 if generate() yields no tokens
    start = time.perf_counter()
    for i, _token_id in enumerate(model.generate(input_ids)):
        now = time.perf_counter()
        if i == 0:
            # First token covers prefill plus the first decode step.
            ttft_ms = (now - start) * 1000.0
        else:
            token_times.append((now - prev) * 1000.0)
        prev = now

    end = time.perf_counter()
    return {
        "ttft_ms": ttft_ms,
        "decode_latencies_ms": token_times,
        "num_tokens": len(token_times) + 1,
        "total_time_ms": (end - start) * 1000.0,
    }
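
# Example of the returned report (illustrative values):
#   {"ttft_ms": 112.4,
#    "decode_latencies_ms": [18.2, 18.5, ...],
#    "num_tokens": 16,
#    "total_time_ms": 401.7}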


def benchmark_generation(model, input_ids, num_warmup=2, num_runs=5):
total_runs = num_warmup + num_runs
run_reports = []

for run_idx in range(total_runs):
is_warmup = run_idx < num_warmup
label = (
f"warmup {run_idx + 1}/{num_warmup}"
if is_warmup
else f"run {run_idx - num_warmup + 1}/{num_runs}"
)
print(f"[benchmark] {label}...")

report = _run_once(model, input_ids)

if not is_warmup:
run_reports.append(report)
throughput = (
report["num_tokens"] / (report["total_time_ms"] / 1000.0)
if report["total_time_ms"] > 0
else 0.0
)
print(
f" TTFT={report['ttft_ms']:.1f}ms "
f"tokens={report['num_tokens']} "
f"throughput={throughput:.1f} tok/s"
)

n = len(run_reports)
if n == 0:
return {}

avg_ttft = sum(r["ttft_ms"] for r in run_reports) / n
avg_total = sum(r["total_time_ms"] for r in run_reports) / n
    # The token count is fixed by max-new-tokens, so take it from the first run.
    num_tokens = run_reports[0]["num_tokens"]

all_decode = []
for r in run_reports:
all_decode.extend(r["decode_latencies_ms"])

    # End-to-end throughput, with prefill (TTFT) included in the denominator.
    throughput = num_tokens / (avg_total / 1000.0) if avg_total > 0 else 0.0

result = {
"ttft_ms": round(avg_ttft, 3),
"decode_latency_p50_ms": round(_percentile(all_decode, 50), 3),
"decode_latency_p90_ms": round(_percentile(all_decode, 90), 3),
"decode_latency_p99_ms": round(_percentile(all_decode, 99), 3),
"num_tokens": num_tokens,
"total_time_ms": round(avg_total, 3),
"throughput_tokens_per_sec": round(throughput, 2),
}

print("\n=== Benchmark Results ===")
print(f" TTFT: {result['ttft_ms']:.1f} ms")
print(f" Decode latency (p50): {result['decode_latency_p50_ms']:.1f} ms")
print(f" Decode latency (p90): {result['decode_latency_p90_ms']:.1f} ms")
print(f" Decode latency (p99): {result['decode_latency_p99_ms']:.1f} ms")
print(f" Tokens generated: {result['num_tokens']}")
print(
f" Throughput: {result['throughput_tokens_per_sec']:.1f} tokens/sec"
)
print("=========================\n")

return result


def save_benchmark_report(result, path="benchmark_report.json"):
with open(path, "w") as f:
json.dump(result, f, indent=2)
print(f"[benchmark] Report saved to {path}")


# ---------------------------------------------------------------------------
# CLI driver
# ---------------------------------------------------------------------------

if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(
description="Evaluate NKIPy Qwen3.5-35B-A3B: benchmark"
)

parser.add_argument("-n", "--max-new-tokens", type=int, default=16)
parser.add_argument("prompt", nargs="?", default="The capital of France is")
parser.add_argument("--checkpoint", default="./qwen3_5_shards")
parser.add_argument("--model", default="Qwen/Qwen3.5-35B-A3B")

mode = parser.add_mutually_exclusive_group(required=True)
mode.add_argument("--benchmark", action="store_true")

parser.add_argument("--benchmark-warmup", type=int, default=2)
parser.add_argument("--benchmark-runs", type=int, default=5)
parser.add_argument(
"--benchmark-output", type=str, default="benchmark_report.json"
)

args = parser.parse_args()

import torch.distributed as dist

from qwen3_5 import load_model

model, input_ids, _ = load_model(args)

if args.benchmark:
dist.barrier()
result = benchmark_generation(
model,
input_ids,
num_warmup=args.benchmark_warmup,
num_runs=args.benchmark_runs,
)
if dist.get_rank() == 0:
save_benchmark_report(result, args.benchmark_output)
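
A typical launch pairs the script with a distributed runner. As an illustrative example, assuming a torchrun launch (the process count depends on the target instance):

torchrun --nproc-per-node=8 evaluate.py --benchmark --benchmark-runs 5 "The capital of France is"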