forked from tenstorrent/tt-metal
Add Experimental Support for Gemma variants [1B, 27B] #27
Status: Open
MohammedTaherMcW wants to merge 16 commits into main from mcw/gemma_3_27b/pr_1_experimental
Commits (16)
7de6963 - Add Base commit for Gemma3 (MohammedTaherMcW)
c7d3ea4 - Add Gemma Text and Vision model support (MohammedTaherMcW)
3659202 - experimantal gemma27b CCL changes (MohammedTaherMcW)
7d7e79b - Add test_end2end script Multidevice support for Gemma (MohammedTaherMcW)
c0afe26 - Fix submodule tests for Gemma (MohammedTaherMcW)
54c7170 - Fix Rebase issue (MohammedTaherMcW)
f175b08 - Add Gemma-3-1b-it support (MohammedTaherMcW)
530d180 - Remove experimental Gemma-3-4b-it (MohammedTaherMcW)
bd90541 - Fix end to end Gemma model (MohammedTaherMcW)
69297ec - Fix Repetition issue (MohammedTaherMcW)
a1f090b - Fix Trace issue (MohammedTaherMcW)
b98bb19 - Modify Attention mask logic (MohammedTaherMcW)
8da698c - Fix Gemma vision generator (jennychristopher)
0d27fda - Add sliding window mask support in SDPA_decode (MohammedTaherMcW)
7ee5ddc - RMS Norm Fix (Bhuvanesh194Sankar)
ac90423 - Addressed the review comments (Bhuvanesh194Sankar)
@@ -0,0 +1,289 @@
"""Gemma-3 Test for Text Attention"""

# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import os

import pytest
import torch
from loguru import logger

import ttnn
from models.experimental.gemma3.tt.attention import Attention
from models.tt_transformers.tt.common import PagedAttentionConfig, precompute_freqs
from models.tt_transformers.tt.rope import RotarySetup
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull

from models.tt_transformers.tt.model_config import ModelArgs
from models.tt_transformers.tt.ccl import TT_CCL


@torch.no_grad()
@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
    "mesh_device",
    [
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    "paged_attention",
    (
        True,
        # False,
    ),
    ids=(
        "paged_attention",
        # "default_attention",
    ),
)
@pytest.mark.parametrize(
    "page_params",
    [{"page_block_size": 32, "page_max_num_blocks": 1024}],
)
@pytest.mark.parametrize(
    "batch_size",
    (1,),
)
@pytest.mark.parametrize(
    "max_seq_len",
    (1,),  # For decode-only unit test, there's no need to run with large sequence lengths
)
@pytest.mark.parametrize(
    "device_params",
    [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
    indirect=True,
)
def test_attention_inference(
    max_seq_len,
    batch_size,
    paged_attention,
    page_params,
    mesh_device,
    reset_seeds,
    device_params,
    # ensure_gc,
):
    dtype = ttnn.bfloat16
    pcc = 0.99

    model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128)
    model_args.n_layers = 6  # For the unit test, run only a few layers

    state_dict = model_args.load_state_dict()

    first_layer_prefix = model_args.get_state_dict_prefix("Attention", 0) + "."
    # Ref model needs partial state dict, but our models use full state dict keys as cached weight names
    # partial_state_dict = {
    #     k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix))
    # }

    reference_model = model_args.reference_attention()
    # reference_model.load_state_dict(partial_state_dict)

    seq_len = 1

    generation_start_pos = 0
    generation_length = 10
    all_tests_pass = True

    # Setup RoPE transformation matrices
    rope_setup = RotarySetup(
        mesh_device,
        batch_size,
        model_args.head_dim,
        model_args.max_seq_len,
        model_args.rope_theta,
        model_args.rope_scaling,
    )

    transformation_mats = rope_setup.get_both_trans_mats()

    page_table_tt = None
    paged_attention_config = None

    if paged_attention:
        paged_attention_config = PagedAttentionConfig(
            block_size=page_params["page_block_size"],
            max_num_blocks=page_params["page_max_num_blocks"],
        )

        # Implied shuffling of blocks
        permutation = torch.randperm(paged_attention_config.max_num_blocks)
        # Page table which maps virtual blocks to physical
        reverse_permutation = torch.argsort(permutation)
        page_table = reverse_permutation.reshape(
            model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size
        )
        page_table_tt = ttnn.from_torch(
            page_table,
            device=mesh_device,
            dtype=ttnn.int32,
            layout=ttnn.ROW_MAJOR_LAYOUT,
            mesh_mapper=ttnn.ShardTensor2dMesh(
                mesh_device,
                dims=(None, -2) if (model_args.is_galaxy and batch_size > 1) else (None, None),
                mesh_shape=model_args.cluster_shape,
            ),
        )

    tt_ccl = TT_CCL(mesh_device)
    tt_model = Attention(
        mesh_device,
        tt_ccl,
        state_dict,
        weight_cache_path=model_args.weight_cache_path(dtype),
        layer_num=0,
        dtype=dtype,
        transformation_mats=transformation_mats,
        configuration=model_args,
        paged_attention_config=paged_attention_config,
    )

    cos, sin = precompute_freqs(
        model_args.head_dim,
        model_args.max_seq_len * 2,
        model_args.rope_theta,
        model_args.rope_scaling.factor if model_args.rope_scaling else None,
        model_args.rope_scaling.original_max_position_embeddings if model_args.rope_scaling else None,
        rope_type="linear",
    )
    freqs_cis = torch.complex(cos, sin)

    # Initial positions
    current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)])
    current_pos_tensor = ttnn.from_torch(
        current_pos,
        device=mesh_device,
        dtype=ttnn.int32,
        mesh_mapper=ttnn.ShardTensor2dMesh(
            mesh_device,
            dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None),
            mesh_shape=model_args.cluster_shape,
        ),
    )

    for i in range(generation_length):
        # 70B attention block typically sees tensors with mean 0 and std 0.03 - 0.05 in layer 1
        pt_attention_input = torch.randn(batch_size, seq_len, model_args.dim)  # Qwen2.5 0.5B sees 0.1 to 2.1

        tt_attention_input = pt_attention_input.clone()

        attention_input = model_args.prepare_residual_tensor_decode(
            tt_attention_input,
            model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"],
            force_replicated=False if model_args.is_galaxy else True,
        )

        # Get cos/sin matrices for the current position of each user
        rot_mats = rope_setup.get_rot_mats(current_pos)

        tt_out = tt_model(
            attention_input,
            current_pos_tensor,
            rot_mats=rot_mats,
            mode="decode",
            page_table=page_table_tt,
        )
        # multi-device attention module returns replicated output
        tt_out = ttnn.to_torch(
            tt_out,
            mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=model_args.cluster_shape),
        )
        tt_output_torch = tt_out[:, 0:1, : model_args.max_batch_size, : model_args.dim].view(-1, 1, model_args.dim)

        # In this test all users have the same position (if using batch > 1)
        freqs_cis_i = freqs_cis[current_pos[0], :].unsqueeze(0)

        reference_output = reference_model(pt_attention_input, current_pos[0], freqs_cis_i, mask=None)

        passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc)

        logger.info(comp_allclose(reference_output, tt_output_torch))
        logger.info(f"PCC: {pcc_message}")
        if passing:
            logger.info(f"[pos={current_pos[0]}] Attention Passed!")
        else:
            logger.warning(f"[pos={current_pos[0]}] Attention Failed!")
            all_tests_pass = False

        # Increment position
        current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch_size)])
        current_pos_tensor = ttnn.from_torch(
            current_pos,
            device=mesh_device,
            dtype=ttnn.int32,
            mesh_mapper=ttnn.ShardTensor2dMesh(
                mesh_device,
                dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None),
                mesh_shape=model_args.cluster_shape,
            ),
        )

        check_kv_cache = True
        if check_kv_cache:
            # PyTorch output --------------------------------------------------------------------
            pytorch_layer_present = [
                reference_model.cache_k.clone().permute(0, 2, 1, 3),  # [batch_size, n_kv_heads, seq, head_dim]
                reference_model.cache_v.clone().permute(0, 2, 1, 3),  # [batch_size, n_kv_heads, seq, head_dim]
            ]
            # TT hardware execution -------------------------------------------------------------
            if paged_attention:
                tt_layer_present = [
                    (
                        ttnn.to_torch(
                            cache,
                            mesh_composer=ttnn.ConcatMesh2dToTensor(
                                mesh_device,
                                dims=(1, 3) if model_args.is_galaxy else (0, 1),
                                mesh_shape=model_args.cluster_shape,
                            ),
                        )[reverse_permutation][:, : model_args.n_kv_heads, :, : model_args.head_dim]
                        .reshape(
                            model_args.max_batch_size,
                            paged_attention_config.max_num_blocks // model_args.max_batch_size,
                            model_args.n_kv_heads,
                            paged_attention_config.block_size,
                            model_args.head_dim,
                        )
                        .transpose(1, 2)
                        .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[
                            :batch_size, ...
                        ]
                    )
                    for cache in tt_model.layer_past
                ]
            else:
                tt_layer_present = [
                    ttnn.to_torch(
                        cache,
                        mesh_composer=ttnn.ConcatMesh2dToTensor(
                            mesh_device,
                            dims=(1, 0) if model_args.is_galaxy else (0, 1),
                            mesh_shape=model_args.cluster_shape,
                        ),
                    )[:batch_size, :, :, :]
                    for cache in tt_model.layer_past
                ]
            for label, cache_pt, cache_tt in zip(["K", "V"], pytorch_layer_present, tt_layer_present):
                cache_length_to_check = min(model_args.max_seq_len, generation_start_pos + i + 1)
                cache_pt = cache_pt[:, :, generation_start_pos:cache_length_to_check, :]
                cache_tt = cache_tt[:, :, generation_start_pos:cache_length_to_check, :]
                does_pass, output_pcc = comp_pcc(cache_pt, cache_tt, pcc)
                logger.info(f"{label} cache output: {output_pcc}")
                if does_pass:
                    logger.info(f"{label} cache Passed!")
                else:
                    logger.warning(f"{label} Cache Failed! PCC value is lower than {pcc}")
                    all_tests_pass = False

    if all_tests_pass:
        logger.info("Attention output Passed!")
    else:
        logger.warning("Attention output Failed!")
    assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!"
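A detail in the paged-attention setup above that is easy to miss: the page table is simply the inverse of a random block permutation, so torch.argsort(permutation) maps each virtual block index back to its shuffled physical block. The snippet below is a minimal standalone illustration of that construction (plain PyTorch, no ttnn or mesh device required); the small max_num_blocks value is chosen only for readability.

# Standalone sketch of the page-table construction used in the test above.
# Blocks are shuffled with a random permutation, and argsort recovers the
# inverse mapping from virtual block index to physical block index.
import torch

max_num_blocks = 8
permutation = torch.randperm(max_num_blocks)      # physical -> virtual shuffle
reverse_permutation = torch.argsort(permutation)  # virtual -> physical page table

# Sanity check: composing the two mappings gives back the identity.
assert torch.equal(permutation[reverse_permutation], torch.arange(max_num_blocks))

# Reshaped per user exactly as in the test (here: a single user).
max_batch_size = 1
page_table = reverse_permutation.reshape(max_batch_size, max_num_blocks // max_batch_size)
print(page_table)

The same reverse_permutation is what the KV-cache check later uses to un-shuffle the on-device cache before comparing it with the reference cache.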
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Review comment:
You're loading the first layer, which is a sliding attention layer. The rotary setup and precompute_freqs are configured for a global attention layer, so this should not match the reference output, should it?
Review comment:
Maybe another test with a global attention layer would also be interesting?
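Taking up that suggestion, one way to cover both layer types would be to pick the RoPE parameters per layer before building RotarySetup and precompute_freqs. The sketch below is only illustrative: GemmaRopeConfig, rope_local_theta, and sliding_window_pattern are hypothetical names, and the 10k/1M theta values and the 5-local-to-1-global cadence are assumptions about the Gemma-3 configuration, not APIs or values taken from this PR.

# Hypothetical helper: choose RoPE parameters per layer type so a unit test can
# be parametrized over both sliding-window and global attention layers.
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class GemmaRopeConfig:
    rope_theta: float = 1_000_000.0        # global-attention layers (assumed value)
    rope_local_theta: float = 10_000.0     # sliding-window layers (assumed value)
    sliding_window_pattern: int = 6        # every 6th layer is global (assumed cadence)
    rope_scaling: Optional[Any] = None     # scaling assumed to apply to global layers only


def is_sliding_layer(layer_num: int, cfg: GemmaRopeConfig) -> bool:
    """A layer is sliding-window unless it falls on the global-attention cadence."""
    return (layer_num + 1) % cfg.sliding_window_pattern != 0


def rope_params_for_layer(layer_num: int, cfg: GemmaRopeConfig) -> Tuple[float, Optional[Any]]:
    """Return the (theta, scaling) pair to feed into the RoPE setup for this layer."""
    if is_sliding_layer(layer_num, cfg):
        return cfg.rope_local_theta, None
    return cfg.rope_theta, cfg.rope_scaling


if __name__ == "__main__":
    cfg = GemmaRopeConfig()
    for layer in (0, 5):
        theta, scaling = rope_params_for_layer(layer, cfg)
        kind = "sliding" if is_sliding_layer(layer, cfg) else "global"
        print(f"layer {layer}: {kind} attention, rope_theta={theta}, scaling={scaling}")

In the unit test itself, a parametrized layer_num and the matching (theta, scaling) pair would then be passed to RotarySetup, precompute_freqs, and Attention, so the TT output is compared against the correct reference for each layer type.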