
Commit fdd2f1b

Add Support for Gemma-3-4b-it
1 parent 5bcfe13 · commit fdd2f1b

38 files changed: +6654 −163 lines
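One way to exercise the new attention unit test locally might be through pytest's Python API, sketched below. The test file path is a hypothetical guess based on the `models.experimental.gemma3_4b` imports in the diff, and `MESH_DEVICE` is set to one of the keys used in the `mesh_device` parametrization; both are assumptions, not taken from this commit.

# Minimal invocation sketch (assumed path and device key, see note above).
import os
import pytest

os.environ["MESH_DEVICE"] = "N150"  # any of N150 / N300 / T3K / TG per the parametrization
pytest.main(["-svv", "models/experimental/gemma3_4b/tests/test_attention.py"])  # hypothetical path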
Lines changed: 279 additions & 0 deletions
@@ -0,0 +1,279 @@
"""Gemma-3-4b-it Test for Text Attention"""


# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0
import os

import pytest
import torch
from loguru import logger

import ttnn
from models.experimental.gemma3_4b.tt.attention import Attention
from models.tt_transformers.tt.common import PagedAttentionConfig, precompute_freqs
from models.tt_transformers.tt.rope import RotarySetup
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull

from models.tt_transformers.tt.model_config import ModelArgs


@torch.no_grad()
@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
    "mesh_device",
    [
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    "paged_attention",
    (
        True,
        # False,
    ),
    ids=(
        "paged_attention",
        # "default_attention",
    ),
)
@pytest.mark.parametrize(
    "page_params",
    [{"page_block_size": 32, "page_max_num_blocks": 1024}],
)
@pytest.mark.parametrize(
    "batch_size",
    (1,),
)
@pytest.mark.parametrize(
    "max_seq_len",
    (1,),  # For decode-only unit test, there's no need to run with large sequence lengths
)
def test_attention_inference(
    max_seq_len,
    batch_size,
    paged_attention,
    page_params,
    mesh_device,
    reset_seeds,
    # ensure_gc,
):
    dtype = ttnn.bfloat16
    pcc = 0.99

    model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128)
    model_args.n_layers = 1  # For the unit test, just run a single layer

    state_dict = model_args.load_state_dict()

    first_layer_prefix = model_args.get_state_dict_prefix("Attention", 0) + "."
    # Ref model needs partial state dict, but our models use full state dict keys as cached weight names
    # partial_state_dict = {
    #     k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix))
    # }

    reference_model = model_args.reference_attention()
    # reference_model.load_state_dict(partial_state_dict)

    seq_len = 1

    generation_start_pos = 0
    generation_length = 10
    all_tests_pass = True

    # Setup RoPE transformation matrices
    rope_setup = RotarySetup(
        mesh_device,
        batch_size,
        model_args.head_dim,
        model_args.max_seq_len,
        model_args.rope_theta,
        model_args.rope_scaling,
    )

    transformation_mats = rope_setup.get_both_trans_mats()

    page_table_tt = None
    paged_attention_config = None

    if paged_attention:
        paged_attention_config = PagedAttentionConfig(
            block_size=page_params["page_block_size"],
            max_num_blocks=page_params["page_max_num_blocks"],
        )

        # Implied shuffling of blocks
        permutation = torch.randperm(paged_attention_config.max_num_blocks)
        # Page table which maps virtual blocks to physical
        reverse_permutation = torch.argsort(permutation)
        page_table = reverse_permutation.reshape(
            model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size
        )
        page_table_tt = ttnn.from_torch(
            page_table,
            device=mesh_device,
            dtype=ttnn.int32,
            layout=ttnn.ROW_MAJOR_LAYOUT,
            mesh_mapper=ttnn.ShardTensor2dMesh(
                mesh_device,
                dims=(None, -2) if (model_args.is_galaxy and batch_size > 1) else (None, None),
                mesh_shape=model_args.cluster_shape,
            ),
        )

    tt_model = Attention(
        mesh_device,
        state_dict,
        weight_cache_path=model_args.weight_cache_path(dtype),
        layer_num=0,
        dtype=dtype,
        transformation_mats=transformation_mats,
        configuration=model_args,
        paged_attention_config=paged_attention_config,
    )

    cos, sin = precompute_freqs(
        model_args.head_dim,
        model_args.max_seq_len * 2,
        model_args.rope_theta,
        model_args.rope_scaling.factor if model_args.rope_scaling else None,
        model_args.rope_scaling.original_max_position_embeddings if model_args.rope_scaling else None,
    )
    freqs_cis = torch.complex(cos, sin)

    # Initial positions
    current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)])
    current_pos_tensor = ttnn.from_torch(
        current_pos,
        device=mesh_device,
        dtype=ttnn.int32,
        mesh_mapper=ttnn.ShardTensor2dMesh(
            mesh_device,
            dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None),
            mesh_shape=model_args.cluster_shape,
        ),
    )

    for i in range(generation_length):
        # 70B attention block typically sees tensors with mean 0 and std 0.03 - 0.05 in layer 1
        pt_attention_input = torch.randn(batch_size, seq_len, model_args.dim)  # Qwen2.5 0.5B sees 0.1 to 2.1

        tt_attention_input = pt_attention_input.clone()

        attention_input = model_args.prepare_residual_tensor_decode(
            tt_attention_input,
            model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"],
            force_replicated=False if model_args.is_galaxy else True,
        )

        # Get cos/sin matrices for the current position of each user
        rot_mats = rope_setup.get_rot_mats(current_pos)

        tt_out = tt_model(
            attention_input,
            current_pos_tensor,
            rot_mats=rot_mats,
            mode="decode",
            page_table=page_table_tt,
        )
        # multi-device attention module returns replicated output
        tt_out = ttnn.to_torch(
            tt_out,
            mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=model_args.cluster_shape),
        )
        tt_output_torch = tt_out[:, 0:1, : model_args.max_batch_size, : model_args.dim].view(-1, 1, model_args.dim)

        # In this test all users have the same position (if using batch > 1)
        freqs_cis_i = freqs_cis[current_pos[0], :].unsqueeze(0)

        reference_output = reference_model(pt_attention_input, current_pos[0], freqs_cis_i, mask=None)

        passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc)

        logger.info(comp_allclose(reference_output, tt_output_torch))
        logger.info(f"PCC: {pcc_message}")
        if passing:
            logger.info(f"[pos={current_pos[0]}] Attention Passed!")
        else:
            logger.warning(f"[pos={current_pos[0]}] Attention Failed!")
            all_tests_pass = False

        # Increment position
        current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch_size)])
        current_pos_tensor = ttnn.from_torch(
            current_pos,
            device=mesh_device,
            dtype=ttnn.int32,
            mesh_mapper=ttnn.ShardTensor2dMesh(
                mesh_device,
                dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None),
                mesh_shape=model_args.cluster_shape,
            ),
        )

        check_kv_cache = True
        if check_kv_cache:
            # PyTorch output --------------------------------------------------------------------
            pytorch_layer_present = [
                reference_model.cache_k.clone().permute(0, 2, 1, 3),  # [batch_size, n_kv_heads, seq, head_dim]
                reference_model.cache_v.clone().permute(0, 2, 1, 3),  # [batch_size, n_kv_heads, seq, head_dim]
            ]
            # TT hardware execution -------------------------------------------------------------
            if paged_attention:
                tt_layer_present = [
                    (
                        ttnn.to_torch(
                            cache,
                            mesh_composer=ttnn.ConcatMesh2dToTensor(
                                mesh_device,
                                dims=(1, 3) if model_args.is_galaxy else (0, 1),
                                mesh_shape=model_args.cluster_shape,
                            ),
                        )[reverse_permutation][:, : model_args.n_kv_heads, :, : model_args.head_dim]
                        .reshape(
                            model_args.max_batch_size,
                            paged_attention_config.max_num_blocks // model_args.max_batch_size,
                            model_args.n_kv_heads,
                            paged_attention_config.block_size,
                            model_args.head_dim,
                        )
                        .transpose(1, 2)
                        .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[
                            :batch_size, ...
                        ]
                    )
                    for cache in tt_model.layer_past
                ]
            else:
                tt_layer_present = [
                    ttnn.to_torch(
                        cache,
                        mesh_composer=ttnn.ConcatMesh2dToTensor(
                            mesh_device,
                            dims=(1, 0) if model_args.is_galaxy else (0, 1),
                            mesh_shape=model_args.cluster_shape,
                        ),
                    )[:batch_size, :, :, :]
                    for cache in tt_model.layer_past
                ]
            for label, cache_pt, cache_tt in zip(["K", "V"], pytorch_layer_present, tt_layer_present):
                cache_length_to_check = min(model_args.max_seq_len, generation_start_pos + i + 1)
                cache_pt = cache_pt[:, :, generation_start_pos:cache_length_to_check, :]
                cache_tt = cache_tt[:, :, generation_start_pos:cache_length_to_check, :]
                does_pass, output_pcc = comp_pcc(cache_pt, cache_tt, pcc)
                logger.info(f"{label} cache output: {output_pcc}")
                if does_pass:
                    logger.info(f"{label} cache Passed!")
                else:
                    logger.warning(f"{label} Cache Failed! PCC value is lower than {pcc}")
                    all_tests_pass = False

    if all_tests_pass:
        logger.info("Attention output Passed!")
    else:
        logger.warning("Attention output Failed!")
    assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!"
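
The paged-attention branch above scatters KV-cache blocks with a random permutation, builds the page table from its inverse, and later indexes the device cache with `reverse_permutation` to restore virtual-block order before comparing against the reference cache. A minimal, ttnn-free sketch of that round trip, using made-up block/head/dim sizes rather than values from `ModelArgs`:

# Page-table round-trip sketch in plain PyTorch (illustrative sizes only).
import torch

max_num_blocks, block_size, n_kv_heads, head_dim = 16, 4, 2, 8
max_batch_size = 1

# Reference cache laid out contiguously in virtual-block order.
virtual_cache = torch.randn(max_num_blocks, n_kv_heads, block_size, head_dim)

permutation = torch.randperm(max_num_blocks)      # implied shuffling of blocks
reverse_permutation = torch.argsort(permutation)  # maps virtual block -> physical block
page_table = reverse_permutation.reshape(max_batch_size, max_num_blocks // max_batch_size)  # as in the test; unused in this toy check

# What the device would hold: virtual block v ends up at physical slot reverse_permutation[v].
physical_cache = virtual_cache[permutation]

# Undo the shuffle the same way the test does before slicing and running comp_pcc.
recovered = physical_cache[reverse_permutation]
assert torch.equal(recovered, virtual_cache)

Because `reverse_permutation = torch.argsort(permutation)` is the exact inverse of `permutation`, the indexed read restores virtual-block order, which is why the test can then reshape to `[batch, n_kv_heads, seq, head_dim]` and slice the first `generation_start_pos + i + 1` positions directly.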
