
Commit ade214f

Add experimental Gemma-3-1b-it Model bringup
1 parent ee72d26 · commit ade214f

File tree: 14 files changed (+2026 -42 lines)

Lines changed: 277 additions & 0 deletions
@@ -0,0 +1,277 @@
""" Test for Gemma-3-1b-it Attention """
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0
import os

import pytest
import torch
from loguru import logger

import ttnn
from models.experimental.gemma3_1b.tt.attention import Attention
from models.tt_transformers.tt.common import PagedAttentionConfig, precompute_freqs
from models.tt_transformers.tt.rope import RotarySetup
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull

from models.tt_transformers.tt.model_config import ModelArgs


@torch.no_grad()
@skip_for_grayskull("Requires wormhole_b0 to run")
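# Mesh shape is chosen from the MESH_DEVICE env var (N150/N300/T3K/TG); otherwise all available devices are used.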
@pytest.mark.parametrize(
    "mesh_device",
    [
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    "paged_attention",
    (
        True,
        False,
    ),
    ids=(
        "paged_attention",
        "default_attention",
    ),
)
@pytest.mark.parametrize(
    "page_params",
    [{"page_block_size": 32, "page_max_num_blocks": 1024}],
)
@pytest.mark.parametrize(
    "batch_size",
    (1,),
)
@pytest.mark.parametrize(
    "max_seq_len",
    (256,),  # For decode-only unit test, there's no need to run with large sequence lengths
)
def test_attention_inference(
    max_seq_len,
    batch_size,
    paged_attention,
    page_params,
    mesh_device,
    reset_seeds,
):
    dtype = ttnn.bfloat16
    pcc = 0.99

    model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len)
    model_args.n_layers = 1  # For the unit test, just run a single layer

    state_dict = model_args.load_state_dict()

    first_layer_prefix = model_args.get_state_dict_prefix("Attention", 0) + "."
    # Ref model needs partial state dict, but our models use full state dict keys as cached weight names
    partial_state_dict = {
        k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix))
    }

    reference_model = model_args.reference_attention()
    reference_model.load_state_dict(partial_state_dict)

    seq_len = 1

    generation_start_pos = 0
    generation_length = 10
    all_tests_pass = True

    # Setup RoPE transformation matrices
    rope_setup = RotarySetup(
        mesh_device,
        batch_size,
        model_args.head_dim,
        model_args.max_seq_len,
        model_args.rope_theta,
        model_args.rope_scaling_factor,
        model_args.orig_context_len,
    )

    transformation_mats = rope_setup.get_both_trans_mats()

    page_table_tt = None
    paged_attention_config = None

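    # For paged attention, build a randomly shuffled page table mapping each user's virtual KV-cache blocks to physical blocks.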
    if paged_attention:
        paged_attention_config = PagedAttentionConfig(
            block_size=page_params["page_block_size"],
            max_num_blocks=page_params["page_max_num_blocks"],
        )

        # Implied shuffling of blocks
        permutation = torch.randperm(paged_attention_config.max_num_blocks)
        # Page table which maps virtual blocks to physical
        reverse_permutation = torch.argsort(permutation)
        page_table = reverse_permutation.reshape(
            model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size
        )
        page_table_tt = ttnn.from_torch(
            page_table,
            device=mesh_device,
            dtype=ttnn.int32,
            layout=ttnn.ROW_MAJOR_LAYOUT,
            mesh_mapper=ttnn.ShardTensor2dMesh(
                mesh_device,
                dims=(None, -2) if (model_args.is_galaxy and batch_size > 1) else (None, None),
                mesh_shape=model_args.cluster_shape,
            ),
        )

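    # Instantiate the TT attention module under test (layer 0), loading weights from the full state dict.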
    tt_model = Attention(
        mesh_device,
        state_dict,
        weight_cache_path=model_args.weight_cache_path(dtype),
        layer_num=0,
        dtype=dtype,
        transformation_mats=transformation_mats,
        configuration=model_args,
        paged_attention_config=paged_attention_config,
    )

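    # Precompute RoPE frequencies for the reference model; freqs_cis is the complex (cos + i*sin) form it consumes.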
    cos, sin = precompute_freqs(
        model_args.head_dim,
        model_args.max_seq_len * 2,
        model_args.rope_theta,
        model_args.rope_scaling_factor,
        model_args.orig_context_len,
    )
    freqs_cis = torch.complex(cos, sin)

    # Initial positions
    current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)])
    current_pos_tensor = ttnn.from_torch(
        current_pos,
        device=mesh_device,
        dtype=ttnn.int32,
        mesh_mapper=ttnn.ShardTensor2dMesh(
            mesh_device,
            dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None),
            mesh_shape=model_args.cluster_shape,
        ),
    )

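    # Decode loop: run generation_length single-token steps, checking TT output against the reference at each position.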
    for i in range(generation_length):
        # 70B attention block typically sees tensors with mean 0 and std 0.03 - 0.05 in layer 1
        pt_attention_input = torch.randn(batch_size, seq_len, model_args.dim)  # Qwen2.5 0.5B sees 0.1 to 2.1

        tt_attention_input = pt_attention_input.clone()

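        # Shard/replicate the decode input across the mesh in the memory config expected by the attention module.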
        attention_input = model_args.prepare_residual_tensor_decode(
            tt_attention_input,
            model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"],
            force_replicated=False if model_args.is_galaxy else True,
        )

        # Get cos/sin matrices for the current position of each user
        rot_mats = rope_setup.get_rot_mats(current_pos)

        tt_out = tt_model(
            attention_input,
            current_pos_tensor,
            rot_mats=rot_mats,
            mode="decode",
            page_table=page_table_tt,
        )
        # multi-device attention module returns replicated output
        tt_out = ttnn.to_torch(
            tt_out,
            mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=model_args.cluster_shape),
        )
        tt_output_torch = tt_out[:, 0:1, : model_args.max_batch_size, : model_args.dim].view(-1, 1, model_args.dim)

        # In this test all users have the same position (if using batch > 1)
        freqs_cis_i = freqs_cis[current_pos[0], :].unsqueeze(0)

        reference_output = reference_model(pt_attention_input, current_pos[0], freqs_cis_i, mask=None)

        passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc)

        logger.info(comp_allclose(reference_output, tt_output_torch))
        logger.info(f"PCC: {pcc_message}")
        if passing:
            logger.info(f"[pos={current_pos[0]}] Attention Passed!")
        else:
            logger.warning(f"[pos={current_pos[0]}] Attention Failed!")
            all_tests_pass = False

        # Increment position
        current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch_size)])
        current_pos_tensor = ttnn.from_torch(
            current_pos,
            device=mesh_device,
            dtype=ttnn.int32,
            mesh_mapper=ttnn.ShardTensor2dMesh(
                mesh_device,
                dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None),
                mesh_shape=model_args.cluster_shape,
            ),
        )

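        # Also verify the TT K/V caches against the reference model's caches after this decode step.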
        check_kv_cache = True
        if check_kv_cache:
            # PyTorch output --------------------------------------------------------------------
            pytorch_layer_present = [
                reference_model.cache_k.clone().permute(0, 2, 1, 3),  # [batch_size, n_kv_heads, seq, head_dim]
                reference_model.cache_v.clone().permute(0, 2, 1, 3),  # [batch_size, n_kv_heads, seq, head_dim]
            ]
            # TT hardware execution -------------------------------------------------------------
            if paged_attention:
                tt_layer_present = [
                    (
                        ttnn.to_torch(
                            cache,
                            mesh_composer=ttnn.ConcatMesh2dToTensor(
                                mesh_device,
                                dims=(1, 3) if model_args.is_galaxy else (0, 1),
                                mesh_shape=model_args.cluster_shape,
                            ),
                        )[reverse_permutation][:, : model_args.n_kv_heads, :, : model_args.head_dim]
                        .reshape(
                            model_args.max_batch_size,
                            paged_attention_config.max_num_blocks // model_args.max_batch_size,
                            model_args.n_kv_heads,
                            paged_attention_config.block_size,
                            model_args.head_dim,
                        )
                        .transpose(1, 2)
                        .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[
                            :batch_size, ...
                        ]
                    )
                    for cache in tt_model.layer_past
                ]
            else:
                tt_layer_present = [
                    ttnn.to_torch(
                        cache,
                        mesh_composer=ttnn.ConcatMesh2dToTensor(
                            mesh_device,
                            dims=(1, 0) if model_args.is_galaxy else (0, 1),
                            mesh_shape=model_args.cluster_shape,
                        ),
                    )[:batch_size, :, :, :]
                    for cache in tt_model.layer_past
                ]
            for label, cache_pt, cache_tt in zip(["K", "V"], pytorch_layer_present, tt_layer_present):
                cache_length_to_check = min(model_args.max_seq_len, generation_start_pos + i + 1)
                cache_pt = cache_pt[:, :, generation_start_pos:cache_length_to_check, :]
                cache_tt = cache_tt[:, :, generation_start_pos:cache_length_to_check, :]
                does_pass, output_pcc = comp_pcc(cache_pt, cache_tt, pcc)
                logger.info(f"{label} cache output: {output_pcc}")
                if does_pass:
                    logger.info(f"{label} cache Passed!")
                else:
                    logger.warning(f"{label} Cache Failed! PCC value is lower than {pcc}")
                    all_tests_pass = False

    if all_tests_pass:
        logger.info("Attention output Passed!")
    else:
        logger.warning("Attention output Failed!")
    assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!"
