# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

"""Test for Gemma-3-1b-it Attention"""

import os

import pytest
import torch
from loguru import logger

import ttnn
from models.experimental.gemma3_1b.tt.attention import Attention
from models.tt_transformers.tt.common import PagedAttentionConfig, precompute_freqs
from models.tt_transformers.tt.model_config import ModelArgs
from models.tt_transformers.tt.rope import RotarySetup
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull


@torch.no_grad()
@skip_for_grayskull("Requires wormhole_b0 to run")
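# Pick the device mesh shape from the MESH_DEVICE env var (N150: 1x1, N300: 1x2, T3K: 1x8, TG: 8x4);
# if it is unset, default to all available devices.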
@pytest.mark.parametrize(
    "mesh_device",
    [
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    "paged_attention",
    (
        True,
        False,
    ),
    ids=(
        "paged_attention",
        "default_attention",
    ),
)
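# Paged KV-cache geometry: 32-token blocks and 1024 blocks in total, i.e. a 32K-token pool shared across users.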
@pytest.mark.parametrize(
    "page_params",
    [{"page_block_size": 32, "page_max_num_blocks": 1024}],
)
@pytest.mark.parametrize(
    "batch_size",
    (1,),
)
@pytest.mark.parametrize(
    "max_seq_len",
    (256,),  # For a decode-only unit test, there's no need to run with large sequence lengths
)
def test_attention_inference(
    max_seq_len,
    batch_size,
    paged_attention,
    page_params,
    mesh_device,
    reset_seeds,
):
    dtype = ttnn.bfloat16
    pcc = 0.99

    model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len)
    model_args.n_layers = 1  # For the unit test, just run a single layer

    state_dict = model_args.load_state_dict()

    first_layer_prefix = model_args.get_state_dict_prefix("Attention", 0) + "."
    # Ref model needs partial state dict, but our models use full state dict keys as cached weight names
    partial_state_dict = {
        k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix))
    }

    reference_model = model_args.reference_attention()
    reference_model.load_state_dict(partial_state_dict)

    seq_len = 1

    generation_start_pos = 0
    generation_length = 10
    all_tests_pass = True

    # Setup RoPE transformation matrices
    rope_setup = RotarySetup(
        mesh_device,
        batch_size,
        model_args.head_dim,
        model_args.max_seq_len,
        model_args.rope_theta,
        model_args.rope_scaling_factor,
        model_args.orig_context_len,
    )

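    # Transformation matrices used by the TT rotary-embedding op on device (one set for prefill, one for decode)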
    transformation_mats = rope_setup.get_both_trans_mats()

    page_table_tt = None
    paged_attention_config = None

    if paged_attention:
        paged_attention_config = PagedAttentionConfig(
            block_size=page_params["page_block_size"],
            max_num_blocks=page_params["page_max_num_blocks"],
        )

        # Random permutation emulates a shuffled (non-contiguous) allocation of physical blocks
        permutation = torch.randperm(paged_attention_config.max_num_blocks)
        # Page table that maps each user's virtual block index to a physical block
        reverse_permutation = torch.argsort(permutation)
        page_table = reverse_permutation.reshape(
            model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size
        )
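        # Page table shape: [max_batch_size, max_num_blocks // max_batch_size]. On a multi-user galaxy mesh
        # the batch dimension is sharded across one mesh axis; otherwise the table is replicated to every device.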
        page_table_tt = ttnn.from_torch(
            page_table,
            device=mesh_device,
            dtype=ttnn.int32,
            layout=ttnn.ROW_MAJOR_LAYOUT,
            mesh_mapper=ttnn.ShardTensor2dMesh(
                mesh_device,
                dims=(None, -2) if (model_args.is_galaxy and batch_size > 1) else (None, None),
                mesh_shape=model_args.cluster_shape,
            ),
        )

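    # Build the TT attention module under test. It takes the full state dict and caches the converted
    # device weights under weight_cache_path.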
    tt_model = Attention(
        mesh_device,
        state_dict,
        weight_cache_path=model_args.weight_cache_path(dtype),
        layer_num=0,
        dtype=dtype,
        transformation_mats=transformation_mats,
        configuration=model_args,
        paged_attention_config=paged_attention_config,
    )

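    # Precompute the rotary frequencies consumed by the reference model; the TT model instead uses the
    # cos/sin matrices produced by rope_setup.get_rot_mats() inside the decode loop below.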
    cos, sin = precompute_freqs(
        model_args.head_dim,
        model_args.max_seq_len * 2,
        model_args.rope_theta,
        model_args.rope_scaling_factor,
        model_args.orig_context_len,
    )
    freqs_cis = torch.complex(cos, sin)

    # Initial positions
    current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)])
    current_pos_tensor = ttnn.from_torch(
        current_pos,
        device=mesh_device,
        dtype=ttnn.int32,
        mesh_mapper=ttnn.ShardTensor2dMesh(
            mesh_device,
            dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None),
            mesh_shape=model_args.cluster_shape,
        ),
    )

    for i in range(generation_length):
        # Random unit-scale input. For reference, a 70B attention block typically sees tensors with
        # mean 0 and std 0.03-0.05 in layer 1, while Qwen2.5-0.5B sees 0.1 to 2.1.
        pt_attention_input = torch.randn(batch_size, seq_len, model_args.dim)

        tt_attention_input = pt_attention_input.clone()

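        # Move the [batch_size, 1, dim] input onto the mesh with the sharded attention-input memory config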
        attention_input = model_args.prepare_residual_tensor_decode(
            tt_attention_input,
            model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"],
            force_replicated=not model_args.is_galaxy,
        )

        # Get cos/sin matrices for the current position of each user
        rot_mats = rope_setup.get_rot_mats(current_pos)

        tt_out = tt_model(
            attention_input,
            current_pos_tensor,
            rot_mats=rot_mats,
            mode="decode",
            page_table=page_table_tt,
        )
        # multi-device attention module returns replicated output
        tt_out = ttnn.to_torch(
            tt_out,
            mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=model_args.cluster_shape),
        )
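        # Keep a single replicated copy and trim to the valid batch/dim region, giving a [batch_size, 1, dim] tensor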
        tt_output_torch = tt_out[:, 0:1, : model_args.max_batch_size, : model_args.dim].view(-1, 1, model_args.dim)

        # In this test all users have the same position (if using batch > 1)
        freqs_cis_i = freqs_cis[current_pos[0], :].unsqueeze(0)

        reference_output = reference_model(pt_attention_input, current_pos[0], freqs_cis_i, mask=None)

        passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc)

        logger.info(comp_allclose(reference_output, tt_output_torch))
        logger.info(f"PCC: {pcc_message}")
        if passing:
            logger.info(f"[pos={current_pos[0].item()}] Attention Passed!")
        else:
            logger.warning(f"[pos={current_pos[0].item()}] Attention Failed!")
            all_tests_pass = False

        # Increment position
        current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch_size)])
        current_pos_tensor = ttnn.from_torch(
            current_pos,
            device=mesh_device,
            dtype=ttnn.int32,
            mesh_mapper=ttnn.ShardTensor2dMesh(
                mesh_device,
                dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None),
                mesh_shape=model_args.cluster_shape,
            ),
        )

    check_kv_cache = True
    if check_kv_cache:
        # PyTorch output --------------------------------------------------------------------
        pytorch_layer_present = [
            reference_model.cache_k.clone().permute(0, 2, 1, 3),  # [batch_size, n_kv_heads, seq, head_dim]
            reference_model.cache_v.clone().permute(0, 2, 1, 3),  # [batch_size, n_kv_heads, seq, head_dim]
        ]
        # TT hardware execution -------------------------------------------------------------
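        # Convert the device KV cache back to a contiguous [batch_size, n_kv_heads, seq, head_dim] tensor.
        # In the paged case, indexing with reverse_permutation (the flattened page table) puts the physical
        # blocks back into virtual order, and the reshape/transpose folds the block dimension into seq.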
        if paged_attention:
            tt_layer_present = [
                (
                    ttnn.to_torch(
                        cache,
                        mesh_composer=ttnn.ConcatMesh2dToTensor(
                            mesh_device,
                            dims=(1, 3) if model_args.is_galaxy else (0, 1),
                            mesh_shape=model_args.cluster_shape,
                        ),
                    )[reverse_permutation][:, : model_args.n_kv_heads, :, : model_args.head_dim]
                    .reshape(
                        model_args.max_batch_size,
                        paged_attention_config.max_num_blocks // model_args.max_batch_size,
                        model_args.n_kv_heads,
                        paged_attention_config.block_size,
                        model_args.head_dim,
                    )
                    .transpose(1, 2)
                    .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[
                        :batch_size, ...
                    ]
                )
                for cache in tt_model.layer_past
            ]
        else:
            tt_layer_present = [
                ttnn.to_torch(
                    cache,
                    mesh_composer=ttnn.ConcatMesh2dToTensor(
                        mesh_device,
                        dims=(1, 0) if model_args.is_galaxy else (0, 1),
                        mesh_shape=model_args.cluster_shape,
                    ),
                )[:batch_size, :, :, :]
                for cache in tt_model.layer_past
            ]
        for label, cache_pt, cache_tt in zip(["K", "V"], pytorch_layer_present, tt_layer_present):
            cache_length_to_check = min(model_args.max_seq_len, generation_start_pos + i + 1)
            cache_pt = cache_pt[:, :, generation_start_pos:cache_length_to_check, :]
            cache_tt = cache_tt[:, :, generation_start_pos:cache_length_to_check, :]
            does_pass, output_pcc = comp_pcc(cache_pt, cache_tt, pcc)
            logger.info(f"{label} cache output: {output_pcc}")
            if does_pass:
                logger.info(f"{label} cache Passed!")
            else:
                logger.warning(f"{label} cache Failed! PCC value is lower than {pcc}")
                all_tests_pass = False

    if all_tests_pass:
        logger.info("Attention output Passed!")
    else:
        logger.warning("Attention output Failed!")
    assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!"