diff --git a/models/experimental/gemma3/tests/test_attention.py b/models/experimental/gemma3/tests/test_attention.py new file mode 100644 index 000000000000..5e1c1a905cde --- /dev/null +++ b/models/experimental/gemma3/tests/test_attention.py @@ -0,0 +1,289 @@ +"""Gemma-3 Test for Text Attention""" + + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.experimental.gemma3.tt.attention import Attention +from models.tt_transformers.tt.common import PagedAttentionConfig, precompute_freqs +from models.tt_transformers.tt.rope import RotarySetup +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.ccl import TT_CCL + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False, + ), + ids=( + "paged_attention", + # "default_attention", + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (1,), # For decode-only unit test, there's no need to run with large sequence lengths +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_attention_inference( + max_seq_len, + batch_size, + paged_attention, + page_params, + mesh_device, + reset_seeds, + device_params, + # ensure_gc, +): + dtype = ttnn.bfloat16 + pcc = 0.99 + + model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128) + model_args.n_layers = 6 # For the unit test, just run a single layer + + state_dict = model_args.load_state_dict() + + first_layer_prefix = model_args.get_state_dict_prefix("Attention", 0) + "." 
+ # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + # partial_state_dict = { + # k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + # } + + reference_model = model_args.reference_attention() + # reference_model.load_state_dict(partial_state_dict) + + seq_len = 1 + + generation_start_pos = 0 + generation_length = 10 + all_tests_pass = True + + # Setup RoPE transformation matrices + rope_setup = RotarySetup( + mesh_device, + batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_theta, + model_args.rope_scaling, + ) + + transformation_mats = rope_setup.get_both_trans_mats() + + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, -2) if (model_args.is_galaxy and batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + tt_ccl = TT_CCL(mesh_device) + tt_model = Attention( + mesh_device, + tt_ccl, + state_dict, + weight_cache_path=model_args.weight_cache_path(dtype), + layer_num=0, + dtype=dtype, + transformation_mats=transformation_mats, + configuration=model_args, + paged_attention_config=paged_attention_config, + ) + + cos, sin = precompute_freqs( + model_args.head_dim, + model_args.max_seq_len * 2, + model_args.rope_theta, + model_args.rope_scaling.factor if model_args.rope_scaling else None, + model_args.rope_scaling.original_max_position_embeddings if model_args.rope_scaling else None, + rope_type="linear", + ) + freqs_cis = torch.complex(cos, sin) + + # Initial positions + current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + for i in range(generation_length): + # 70B attention block typically sees tensors with mean 0 and std 0.03 - 0.05 in layer 1 + pt_attention_input = torch.randn(batch_size, seq_len, model_args.dim) # Qwen2.5 0.5B sees 0.1 to 2.1 + + tt_attention_input = pt_attention_input.clone() + + attention_input = model_args.prepare_residual_tensor_decode( + tt_attention_input, + model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"], + force_replicated=False if model_args.is_galaxy else True, + ) + + # Get cos/sin matrices for the current position of each user + rot_mats = rope_setup.get_rot_mats(current_pos) + + tt_out = tt_model( + attention_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table_tt, + ) + # multi-device attention module returns replicated output + tt_out = ttnn.to_torch( + tt_out, + mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), 
mesh_shape=model_args.cluster_shape), + ) + tt_output_torch = tt_out[:, 0:1, : model_args.max_batch_size, : model_args.dim].view(-1, 1, model_args.dim) + + # In this test all users have the same position (if using batch > 1) + freqs_cis_i = freqs_cis[current_pos[0], :].unsqueeze(0) + + reference_output = reference_model(pt_attention_input, current_pos[0], freqs_cis_i, mask=None) + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + if passing: + logger.info(f"[pos={current_pos[0]}] Attention Passed!") + else: + logger.warning(f"[pos={current_pos[0]}] Attention Failed!") + all_tests_pass = False + + # Increment position + current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + check_kv_cache = True + if check_kv_cache: + # PyTorch output -------------------------------------------------------------------- + pytorch_layer_present = [ + reference_model.cache_k.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] + reference_model.cache_v.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] + ] + # TT hardware execution ------------------------------------------------------------- + if paged_attention: + tt_layer_present = [ + ( + ttnn.to_torch( + cache, + mesh_composer=ttnn.ConcatMesh2dToTensor( + mesh_device, + dims=(1, 3) if model_args.is_galaxy else (0, 1), + mesh_shape=model_args.cluster_shape, + ), + )[reverse_permutation][:, : model_args.n_kv_heads, :, : model_args.head_dim] + .reshape( + model_args.max_batch_size, + paged_attention_config.max_num_blocks // model_args.max_batch_size, + model_args.n_kv_heads, + paged_attention_config.block_size, + model_args.head_dim, + ) + .transpose(1, 2) + .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ + :batch_size, ... + ] + ) + for cache in tt_model.layer_past + ] + else: + tt_layer_present = [ + ttnn.to_torch( + cache, + mesh_composer=ttnn.ConcatMesh2dToTensor( + mesh_device, + dims=(1, 0) if model_args.is_galaxy else (0, 1), + mesh_shape=model_args.cluster_shape, + ), + )[:batch_size, :, :, :] + for cache in tt_model.layer_past + ] + for label, cache_pt, cache_tt in zip(["K", "V"], pytorch_layer_present, tt_layer_present): + cache_length_to_check = min(model_args.max_seq_len, generation_start_pos + i + 1) + cache_pt = cache_pt[:, :, generation_start_pos:cache_length_to_check, :] + cache_tt = cache_tt[:, :, generation_start_pos:cache_length_to_check, :] + does_pass, output_pcc = comp_pcc(cache_pt, cache_tt, pcc) + logger.info(f"{label} cache output: {output_pcc}") + if does_pass: + logger.info(f"{label} cache Passed!") + else: + logger.warning(f"{label} Cache Failed! PCC value is lower than {pcc}") + all_tests_pass = False + + if all_tests_pass: + logger.info("Attention output Passed!") + else: + logger.warning("Attention output Failed!") + assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" 
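For context on the paged-attention setup used throughout these tests: the page table is derived from a random permutation of the physical blocks, and `torch.argsort` of that permutation recovers the inverse (virtual-to-physical) mapping. Below is a small standalone sketch of that indexing trick with made-up sizes; it is an illustration only, not part of the test code.

```python
import torch

# Illustrative sizes only (the tests use page_block_size=32, page_max_num_blocks=1024)
max_num_blocks = 16
max_batch_size = 2

# Physical blocks are handed out in a shuffled order ("implied shuffling of blocks").
permutation = torch.randperm(max_num_blocks)
# argsort of a permutation is its inverse: reverse_permutation[v] is the
# physical block holding virtual block v.
reverse_permutation = torch.argsort(permutation)
assert torch.equal(permutation[reverse_permutation], torch.arange(max_num_blocks))

# Page table: one row per user, mapping that user's virtual block index to a
# physical block id, exactly as reshaped in the tests above.
page_table = reverse_permutation.reshape(max_batch_size, max_num_blocks // max_batch_size)
print(page_table.shape)  # torch.Size([2, 8])
```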
diff --git a/models/experimental/gemma3/tests/test_decoder.py b/models/experimental/gemma3/tests/test_decoder.py new file mode 100644 index 000000000000..0a40ff780bb7 --- /dev/null +++ b/models/experimental/gemma3/tests/test_decoder.py @@ -0,0 +1,229 @@ +"""Gemma3 Test for Text Decoder""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +from loguru import logger +import os +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.experimental.gemma3.tt.decoder import TransformerBlock +from models.utility_functions import ( + comp_pcc, + comp_allclose, +) +from models.utility_functions import skip_for_grayskull +from models.tt_transformers.tt.common import PagedAttentionConfig +from models.tt_transformers.tt.rope import RotarySetup +from models.tt_transformers.tt.ccl import TT_CCL +from models.tt_transformers.tt.common import PagedAttentionConfig, precompute_freqs + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False + ), + ids=( + "paged_attention", + # "default_attention" + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (256,), # For decode-only unit test, there's no need to run with large sequence lengths +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_decoder_inference( + max_seq_len, + batch_size, + paged_attention, + page_params, + mesh_device, + device_params, + reset_seeds, +): + dtype = ttnn.bfloat16 + + pcc_required = 0.85 + model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) + model_args.n_layers = 1 + + state_dict = model_args.load_state_dict() + + reference_model = model_args.reference_decoder() + + generation_start_pos = 0 + generation_length = 3 + all_tests_pass = True + + rope_setup = RotarySetup( + mesh_device, + batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_theta, + model_args.rope_scaling, + ) + transformation_mats = rope_setup.get_both_trans_mats() + + # Prepare page table for paged attention + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, -2) if (model_args.is_galaxy and batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + # Initialize TT model + tt_ccl = 
TT_CCL(mesh_device) + tt_model = TransformerBlock( + args=model_args, + mesh_device=mesh_device, + tt_ccl=tt_ccl, + dtype=dtype, + state_dict=state_dict, + layer_num=0, + weight_cache_path=model_args.weight_cache_path(dtype), + transformation_mats=transformation_mats, + paged_attention_config=paged_attention_config, + ) + + seqlen = 1 + + cos, sin = precompute_freqs( + model_args.head_dim, + model_args.max_seq_len * 2, + model_args.rope_theta, + model_args.rope_scaling.factor if model_args.rope_scaling else None, + model_args.rope_scaling.original_max_position_embeddings if model_args.rope_scaling else None, + rope_type="linear", + ) + freqs_cis = torch.complex(cos, sin) + + # Initial positions + current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + for i in range(generation_length): + pt_decode_input = (torch.rand(batch_size, seqlen, model_args.dim) * 2) - 1 + logger.info(f"[Decoder] Generating token {i}") + + tt_decode_input = pt_decode_input.clone() + + decode_input = model_args.prepare_residual_tensor_decode( + tt_decode_input, + # ttnn.DRAM_MEMORY_CONFIG, + model_args.model_config["DECODE_RESIDUAL_MEMCFG"], + ) + + # Get cos/sin matrices for the current position of each user + rot_mat_global = rope_setup.get_rot_mats(current_pos) + rot_mat_local = rope_setup.get_rot_mats(current_pos) + + # Run TT model + tt_out = tt_model( + decode_input, + current_pos_tensor, + rot_mats_global=rot_mat_global, + rot_mats_local=rot_mat_local, + mode="decode", + page_table=page_table_tt, + ) + tt_out = ttnn.to_torch( + tt_out, + mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=model_args.cluster_shape), + ) + + tt_output_torch = tt_out[:, 0:1, : model_args.max_batch_size, : model_args.dim].view(-1, 1, model_args.dim) + + # Reference model + ref_output = reference_model(pt_decode_input, current_pos[0], None, mask=None) + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + ref_output = ref_output[non_zero_indices] + + passing, pcc_message = comp_pcc(ref_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(ref_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("Decoder Block Passed!") + else: + logger.warning("Decoder Block Failed!") + all_tests_pass = False + + # Increment position + current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + if all_tests_pass: + logger.info(f"All {generation_length} decode iterations Passed!") + else: + logger.warning("One or more iterations of decode Failed!") + assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
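Each of these unit tests gates on `comp_pcc`, i.e. a Pearson correlation coefficient between the flattened reference and TT outputs compared against a threshold (0.99 for attention, 0.85 for the decoder above). The snippet below is a rough standalone approximation of that check; the real helper in `models.utility_functions` additionally handles NaNs, dtype promotion and edge cases.

```python
import torch

def pcc(golden: torch.Tensor, computed: torch.Tensor) -> float:
    """Pearson correlation coefficient between two tensors, flattened."""
    g = golden.flatten().to(torch.float64)
    c = computed.flatten().to(torch.float64)
    g, c = g - g.mean(), c - c.mean()
    return float((g @ c) / (g.norm() * c.norm() + 1e-12))

golden = torch.randn(1, 1, 32, 1152)
computed = golden + 0.01 * torch.randn_like(golden)  # small perturbation
value = pcc(golden, computed)
print(f"PCC: {value:.5f}, passing: {value >= 0.99}")
```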
diff --git a/models/experimental/gemma3/tests/test_embedding.py b/models/experimental/gemma3/tests/test_embedding.py new file mode 100644 index 000000000000..751fbbbd824d --- /dev/null +++ b/models/experimental/gemma3/tests/test_embedding.py @@ -0,0 +1,100 @@ +"""Gemma3 test for Text Embedding""" + + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.embedding import Embedding, ScaledEmbedding +from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize("use_scaled_embedding", (False, True)) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (128,), # For decode-only unit test, there's no need to run with large sequence lengths +) +def test_embedding(max_seq_len, batch_size, mesh_device, reset_seeds, ensure_gc, use_scaled_embedding): + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len, cache_hf=True) + model_args.n_layers = 1 + + state_dict = model_args.load_state_dict() + tokenizer = model_args.tokenizer + + if use_scaled_embedding: + model_args.embed_scale = model_args.dim**0.5 + logger.info(f"Using scaled embedding with scale {model_args.embed_scale}") + + reference_emb = model_args.reference_embedding() + layer_name = "tok_embeddings.weight" + reference_emb.load_state_dict({"emb.weight": state_dict[layer_name]}) + + emb_kwargs = { + "mesh_device": mesh_device, + "args": model_args, + "weight_cache_path": model_args.weight_cache_path(dtype), + "state_dict": state_dict, + "dtype": dtype, + } + if use_scaled_embedding: + emb_kwargs["embed_scale"] = model_args.embed_scale + emb_cls = ScaledEmbedding + else: + emb_cls = Embedding + + tt_emb = emb_cls(**emb_kwargs) + + prompts = ["Joy"] * 32 + pt_input = torch.tensor([model_args.encode_prompt(prompt, instruct=False) for prompt in prompts]) + reference_output = reference_emb(pt_input) + logger.info(f"reference_output: {reference_output.shape}") + + tt_input = ttnn.from_torch( + pt_input.squeeze(1), + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + ) + tt_output = tt_emb(tt_input) + tt_output = ttnn.multiply(tt_output, model_args.embed_scale) + tt_output_torch = ttnn.to_torch( + tt_output, + mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(0, -1), mesh_shape=model_args.cluster_shape), + )[:32].view(reference_output.shape) + logger.info(f"tt_output_torch: {tt_output_torch.shape}") + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + if passing: + logger.info("embedding Passed!") + else: + logger.warning("embedding Failed!") + + assert passing, f"embedding output does not meet PCC requirement {0.99}." 
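The scaled-embedding variant exercised above multiplies the looked-up embeddings by `dim ** 0.5`, the Gemma convention; the test applies the same factor to the TT output via `ttnn.multiply` before comparing. A toy PyTorch reference of that behaviour (illustrative sizes, not the module the test imports):

```python
import torch
import torch.nn as nn

class RefScaledEmbedding(nn.Module):
    """Token embedding whose output is scaled by sqrt(hidden_dim), as in Gemma models."""

    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.scale = dim**0.5

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.emb(tokens) * self.scale

emb = RefScaledEmbedding(vocab_size=262_144, dim=1152)  # illustrative Gemma3-like sizes
tokens = torch.randint(0, 262_144, (1, 8))
print(emb(tokens).shape)  # torch.Size([1, 8, 1152])
```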
diff --git a/models/experimental/gemma3/tests/test_lm_head.py b/models/experimental/gemma3/tests/test_lm_head.py new file mode 100644 index 000000000000..fdecdec31ebe --- /dev/null +++ b/models/experimental/gemma3/tests/test_lm_head.py @@ -0,0 +1,106 @@ +"""Gemma3 Test for lm_head""" + + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.experimental.gemma3.tt.lm_head import LMHead +from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +from models.tt_transformers.tt.ccl import TT_CCL + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "seq_len", + (32,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_lm_head_inference(seq_len, batch_size, mesh_device, reset_seeds): + dtype = ttnn.bfloat8_b + + model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=seq_len) + model_args.n_layers = 1 + state_dict = model_args.load_state_dict() + + state_dict_prefix = model_args.get_state_dict_prefix("", None) + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + partial_state_dict = { + "weight": state_dict[f"{state_dict_prefix}output.weight"], + } + + model_args.WEIGHTS_DTYPE = dtype + reference_model = model_args.reference_lm_head() + reference_model.load_state_dict(partial_state_dict) + + tt_ccl = TT_CCL(mesh_device) + tt_model = LMHead( + args=model_args, + mesh_device=mesh_device, + tt_ccl=tt_ccl, + dtype=dtype, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + max_columns_per_device=model_args.max_columns_per_device_lm_head, + ) + + torch_input = torch.randn(1, 1, seq_len, model_args.dim) + reference_output = reference_model(torch_input) + tt_input = ttnn.from_torch( + torch_input, + device=mesh_device, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, dims=(None, 3) if model_args.is_galaxy else (None, None), mesh_shape=model_args.cluster_shape + ), + dtype=ttnn.bfloat16, + memory_config=model_args.model_config["LM_HEAD_INPUT_MEMCFG"], + layout=ttnn.TILE_LAYOUT, + ) + + logger.info("Run LM_Head") + tt_output = tt_model(tt_input) + tt_output_torch = ttnn.to_torch( + tt_output, + mesh_composer=ttnn.ConcatMesh2dToTensor( + mesh_device, model_args.cluster_shape, dims=(3, 1) if model_args.is_galaxy else (1, 3) + ), + ) + tt_output_torch = tt_output_torch[:, 0:1, :, : model_args.vocab_size] + + pcc_required = 0.99 + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + if passing: + logger.info("LM_Head Passed!") + else: + logger.warning("LM_Head Failed!") + + assert passing, f"LM_Head output does not meet PCC requirement {pcc_required}: {pcc_message}." 
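The LM head under test splits its `dim x vocab_size` projection into column chunks (bounded by `max_columns_per_device_lm_head`) so each device multiplies against only a slice of the vocabulary, and the partial logits are concatenated afterwards. A single-process sketch of why that column split is exact, using invented toy sizes:

```python
import torch

dim, vocab_size, max_cols = 64, 1024, 256  # toy sizes for illustration
weight = torch.randn(dim, vocab_size) * 0.02
x = torch.randn(1, 32, dim)

# Project against each column chunk independently, then concatenate the partial logits.
chunks = torch.split(weight, max_cols, dim=-1)
logits = torch.cat([x @ w for w in chunks], dim=-1)

assert torch.allclose(logits, x @ weight, atol=1e-4)
print(logits.shape)  # torch.Size([1, 32, 1024])
```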
diff --git a/models/experimental/gemma3/tests/test_mlp.py b/models/experimental/gemma3/tests/test_mlp.py new file mode 100644 index 000000000000..02cec5c19101 --- /dev/null +++ b/models/experimental/gemma3/tests/test_mlp.py @@ -0,0 +1,115 @@ +"""Gemma3 Test for Text MLP""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +from loguru import logger + +import torch +import pytest +import os +import ttnn + +from models.tt_transformers.tests.test_utils import get_ref_model_dype +from models.experimental.gemma3.tt.mlp import MLP +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +from models.tt_transformers.tt.ccl import TT_CCL +from models.tt_transformers.tt.model_config import ModelArgs + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_mlp_inference(seq_len, batch_size, reset_seeds, mesh_device, device_params): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + # # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + # first_layer_prefix = "layers.0.feed_forward" + first_layer_prefix = tt_model_args.get_state_dict_prefix("MLP", 0) + + partial_state_dict = { + k[len(first_layer_prefix) + 1 :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model = tt_model_args.reference_mlp() # Gemma3 MLP + reference_model.load_state_dict(partial_state_dict) + + tt_ccl = TT_CCL(mesh_device) + tt_model = MLP( + mesh_device=mesh_device, + tt_ccl=tt_ccl, + args=tt_model_args, + state_dict=state_dict, + weight_cache_path=tt_model_args.weight_cache_path(dtype), + layer_num=0, + dtype=dtype, + model_config=tt_model_args.get_model_config(), + ) + + torch_input = torch.randn( + 1, 1, seq_len, tt_model_args.dim, dtype=get_ref_model_dype(reference_model, tt_model_args.model_name) + ) + reference_output = reference_model(torch_input) + + tt_input = ttnn.from_torch( + torch_input, + device=mesh_device, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, 3) if tt_model_args.is_galaxy else (None, None), + mesh_shape=tt_model_args.cluster_shape, + ), # When both dims are None, the mapper used is `ReplicateTensorToMesh` + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + ) + + logger.info("Run MLP") + tt_output = tt_model(tt_input, mode) + + tt_output_torch = ttnn.to_torch( + tt_output, + mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=tt_model_args.cluster_shape), + ) + + # tt_output_torch = tt_output_torch[:, :1, :, :] + + pcc_required = 0.99 + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch[0])) + logger.info(f"PCC: {pcc_message}") + if passing: + logger.info("MLP Passed!") + 
else: + logger.warning("MLP Failed!") + + assert passing, f"MLP output does not meet PCC requirement {pcc_required}: {pcc_message}." diff --git a/models/experimental/gemma3/tests/test_rmsnorm.py b/models/experimental/gemma3/tests/test_rmsnorm.py new file mode 100644 index 000000000000..2e734273505b --- /dev/null +++ b/models/experimental/gemma3/tests/test_rmsnorm.py @@ -0,0 +1,169 @@ +"""Gemma3 Test for Text RMSNorm""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +from loguru import logger + +import torch +import pytest +import os + +import ttnn +from models.experimental.gemma3.tt.rmsnorm import RMSNorm +from models.tt_transformers.tt.distributed_norm import DistributedNorm + +from models.tt_transformers.tt.ccl import TT_CCL +from models.utility_functions import comp_allclose, skip_for_grayskull +from models.tt_transformers.tt.model_config import ModelArgs + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "tt_layer_name, torch_layer_name, dim", + ( + ("norm", "norm", 1152), + ("layers.0.attention_norm", "layers.0.input_layernorm", 1152), + ("layers.0.ffn_norm", "layers.0.post_attention_layernorm", 1152), + ("layers.0.pre_feedforward_layernorm", "layers.0.pre_feedforward_layernorm", 1152), + ("layers.0.post_feedforward_layernorm", "layers.0.post_feedforward_layernorm", 1152), + ("layers.0.attention.q_norm", "layers.0.self_attn.q_norm", 256), + ("layers.0.attention.k_norm", "layers.0.self_attn.k_norm", 256), + ), +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_rmsnorm_inference( + seq_len, batch_size, reset_seeds, mesh_device, tt_layer_name, torch_layer_name, device_params, dim +): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + mesh_device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + reference_model = tt_model_args.reference_transformer(wrap=False) # Gemma3 Entire Model + reference_model = reference_model.model.get_submodule(torch_layer_name) + + state_dict_prefix = "" + first_layer_prefix = state_dict_prefix + tt_layer_name + "." 
+    partial_state_dict = {
+        k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix))
+    }
+
+    reference_model.load_state_dict(partial_state_dict)
+    tt_ccl = TT_CCL(mesh_device)
+    if "q_norm" in tt_layer_name or "k_norm" in tt_layer_name:
+        tt_model = RMSNorm(
+            device=mesh_device,
+            dim=dim,
+            state_dict=state_dict,
+            state_dict_prefix=state_dict_prefix,
+            weight_key=tt_layer_name,
+            weight_dtype=dtype,
+            is_distributed=False,
+            sharded_program_config=None,
+            sharded_output_config=None,
+            tt_ccl=tt_ccl,
+        )
+    else:
+        tt_inner_norm = RMSNorm(
+            device=mesh_device,
+            dim=tt_model_args.dim,
+            state_dict=state_dict,
+            state_dict_prefix=state_dict_prefix,
+            weight_key=tt_layer_name,
+            weight_dtype=dtype,
+            is_distributed=tt_model_args.is_distributed_norm,
+            sharded_program_config=tt_model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"],
+            sharded_output_config=tt_model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"],
+            tt_ccl=tt_ccl,
+        )
+
+        # Wrap it in DistributedNorm
+        tt_model = DistributedNorm(tt_inner_norm, tt_model_args, tt_ccl, TG=tt_model_args.is_galaxy)
+    if "q_norm" in tt_layer_name or "k_norm" in tt_layer_name:
+        input = torch.rand(1, 1, dim)
+    else:
+        input = torch.rand(1, 1, 32, dim)
+
+    reference_output = reference_model(input)
+
+    # DistributedNorm inputs are fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode)
+    if "q_norm" in tt_layer_name or "k_norm" in tt_layer_name:
+        tt_input = ttnn.from_torch(
+            input,
+            device=mesh_device,
+            dtype=dtype,
+            layout=ttnn.TILE_LAYOUT,
+            mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
+            memory_config=(
+                tt_model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"]
+                if mode == "decode"
+                else ttnn.DRAM_MEMORY_CONFIG
+            ),
+        )
+    else:
+        tt_input = ttnn.from_torch(
+            input,
+            device=mesh_device,
+            dtype=dtype,
+            layout=ttnn.TILE_LAYOUT,
+            mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=(None, -1), mesh_shape=tt_model_args.cluster_shape),
+            memory_config=(
+                tt_model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"]
+                if mode == "decode"
+                else ttnn.DRAM_MEMORY_CONFIG
+            ),
+        )
+
+    tt_output = tt_model(tt_input, mode=mode)
+
+    # DistributedNorm outputs are replicated across devices
+    tt_output_torch = ttnn.to_torch(
+        tt_output,
+        mesh_composer=ttnn.ConcatMesh2dToTensor(
+            mesh_device, dims=(0, 2) if tt_model_args.is_galaxy else (2, 0), mesh_shape=tt_model_args.cluster_shape
+        ),
+    )[:1, :, :]
+
+    # tt_output_torch = tt_output_torch.view(1, 1, tt_model_args.dim)
+
+    logger.info(comp_allclose(reference_output, tt_output_torch))
+
+    # comp_pcc is not imported at module level in this file, so pull it in here for the check
+    from models.utility_functions import comp_pcc
+
+    pcc_required = 0.99
+    passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required)
+    logger.info(f"PCC: {torch_layer_name} , {pcc_message}")
+
+    if passing:
+        logger.info("rms_norm Passed!")
+    else:
+        logger.warning("rms_norm Failed!")
+
+    assert passing, f"rms_norm output does not meet PCC requirement {pcc_required}."
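For orientation, the layers checked above implement root-mean-square normalisation; Gemma-style norms keep a zero-initialised weight and scale by `(1 + weight)`. The snippet below is a compact reference sketch under that assumption, not the `RMSNorm`/`DistributedNorm` modules used in the test:

```python
import torch
import torch.nn as nn

class RefRMSNorm(nn.Module):
    """x * rsqrt(mean(x**2) + eps), scaled by (1 + weight) as in Gemma-style norms."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))  # zero-centred scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(var + self.eps) * (1.0 + self.weight)

norm = RefRMSNorm(dim=256)  # e.g. the q_norm / k_norm width parametrised above
print(norm(torch.rand(1, 1, 256)).shape)  # torch.Size([1, 1, 256])
```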
diff --git a/models/experimental/gemma3/tests/vision_tests/test_end2end.py b/models/experimental/gemma3/tests/vision_tests/test_end2end.py new file mode 100644 index 000000000000..802b253f38ac --- /dev/null +++ b/models/experimental/gemma3/tests/vision_tests/test_end2end.py @@ -0,0 +1,756 @@ +""" End-to-end test for Gemma3 vision-text pipeline.""" +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 +import torch +import pytest +from loguru import logger +import os +import ttnn +from models.tt_transformers.tt.common import ( + encode_prompt_hf, + sample_host, + PagedAttentionConfig, + preprocess_inputs_prefill, +) +from models.tt_transformers.tt.model_config import DecodersPrecision + +from models.experimental.gemma3.tt.text_model import Gemma3Transformer +from models.experimental.gemma3.tt.gemma_vision_crossattention import TtGemmaTransformerVision +from models.experimental.gemma3.tt.gemma3_generator import Gemma3Generator +from models.utility_functions import ( + comp_pcc, + comp_allclose, +) +from models.utility_functions import skip_for_grayskull, skip_for_blackhole +from models.tt_transformers.tt.model_config import HfModelWrapper + +from models.tt_transformers.tt.model_config import ModelArgs +from transformers import AutoProcessor, AutoTokenizer + +import re + + +def parse_chat_output(text): + """Parse chat output format from generated text.""" + pattern = r"<\|(?Puser|assistant)\|>\s*(?P.*?)(?=<\|(?:user|assistant|end)\|>|$)" + matches = re.finditer(pattern, text, re.DOTALL) + return [(match.group("role"), match.group("message").strip()) for match in matches] + + +def display_chat(logger, conversation): + """Display chat conversation in formatted output.""" + for role, message in conversation: + if role == "user": + logger.info(f"👤 User: {message}") + elif role == "assistant": + logger.info(f"🤖 Assistant: {message}") + + +def setup_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations): + """Setup model arguments and configuration.""" + instruct = True if weights == "instruct" else False + + model_args = ModelArgs( + mesh_device=mesh_device, + instruct=instruct, + optimizations=optimizations, + max_seq_len=max_seq_len, + max_batch_size=batch_size, + ) + + return model_args, instruct + + +def setup_prompts_and_tokenizer(model_args, instruct): + """Setup prompts and tokenizer for the test.""" + prompts = ["Write a essay about Lion"] * model_args.max_batch_size + tokenizer = model_args.tokenizer + + if instruct: + encoded_prompts = encode_prompt_hf(tokenizer=tokenizer, prompt_text=prompts[0]) + else: + encoded_prompts = [model_args.encode_prompt(prompt, instruct=False) for prompt in prompts] + + return prompts, tokenizer, encoded_prompts + + +def setup_reference_model(model_args, run_ref_pt): + """Setup reference PyTorch model and embedding.""" + if run_ref_pt: + reference_transformer_model = model_args.reference_transformer(wrap=False) + reference_model = HfModelWrapper(reference_transformer_model, model_args.head_dim) + logger.info("Finished loading reference model.") + embd = model_args.reference_embedding(reference_transformer_model) + else: + reference_model = None + embd = model_args.reference_embedding() + + return reference_model, embd + + +def setup_paged_attention(paged_attention, page_params, model_args, mesh_device): + """Setup paged attention configuration and page table.""" + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + 
block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, -2) if model_args.max_batch_size > 1 else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + return paged_attention_config, page_table_tt + + +# ============================================================================= +# NEW E2E PIPELINE COMPONENTS - Following SOLID Principles +# ============================================================================= + + +def setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations): + """Setup model arguments for vision-enabled model (Single Responsibility).""" + instruct = True if weights == "instruct" else False + + model_args = ModelArgs( + mesh_device=mesh_device, + instruct=instruct, + optimizations=optimizations, + max_seq_len=max_seq_len, + max_batch_size=batch_size, + ) + return model_args, instruct + + +def setup_vision_prompts_and_tokenizer(model_args, instruct): + """Setup multimodal prompts and tokenizer for vision-enabled model.""" + # Create multimodal messages similar to test_end2end.py + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What is your favorite condiment? There are so many condiments to choose from, each bringing its unique flavor and texture to enhance different dishes. Do you prefer the classic taste of ketchup, the creamy richness of mayonnaise, the spicy kick of mustard, or perhaps something more exotic like sriracha or hoisin sauce? Maybe you enjoy the tangy zest of salsa or the smooth and savory taste of aioli. Share what your favorite condiment is and why you love it. 
Does it remind you of a specific dish or meal?", + }, + ], + } + ] + + messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", + }, + {"type": "text", "text": "Describe this image in detail."}, + ], + } + ] + + tokenizer = model_args.tokenizer + return messages, tokenizer + + +def setup_vision_reference_model(model_args, run_ref_pt): + """Setup reference vision-enabled model (Open/Closed Principle).""" + if run_ref_pt: + reference_transformer_model = model_args.reference_vision_transformer(wrap=False) + reference_model = HfModelWrapper(reference_transformer_model, model_args.head_dim) + logger.info("Finished loading reference vision model.") + embd = model_args.reference_embedding(reference_transformer_model) + else: + reference_model = None + embd = model_args.reference_embedding() + + return reference_model, embd + + +def process_real_vision_inputs(messages, model_args): + """Process real image inputs using AutoProcessor (Interface Segregation).""" + model_id = model_args.CKPT_DIR + + try: + # Try loading processor (works for models that has preprocessor_config.json) + processor = AutoProcessor.from_pretrained(model_id) + except OSError: + # Fallback to tokenizer + processor = AutoTokenizer.from_pretrained(model_id) + + # Process the multimodal messages similar to test_end2end.py + encoded = processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to(torch.bfloat16) + + input_ids = encoded["input_ids"] + pixel_values = encoded["pixel_values"] if "pixel_values" in encoded else None + attention_mask = encoded["attention_mask"] + + # logger.info(f"Processed vision inputs - input_ids: {input_ids.shape}, pixel_values: {pixel_values.shape}") + + return { + "input_ids": input_ids, + "pixel_values": pixel_values, + "attention_mask": attention_mask, + "processor": processor, + "input_prompts": messages, + } + + +# Legacy function removed - vision model now part of multimodal model + + +def load_separate_models_like_test_end2end(model_args, mesh_device, dtype, paged_attention, page_params): + """Load separate vision and text models following test_end2end.py pattern.""" + state_dict = model_args.load_state_dict() + vision_prefix = "vision_tower.vision_model." 
+ + # Setup paged attention config (exactly like test_end2end.py) + paged_attention_config = None + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + + # Load vision model (exactly like test_end2end.py) + if model_args.is_multimodal: + vision_model = TtGemmaTransformerVision( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=vision_prefix, + dtype=dtype, + configuration=model_args, + weight_cache_path=model_args.weight_cache_path(dtype), + ) + else: + vision_model = None + # Load text model (exactly like test_end2end.py) + text_model = Gemma3Transformer( + args=model_args, + mesh_device=mesh_device, + dtype=dtype, + state_dict=state_dict, + weight_cache_path=model_args.weight_cache_path(dtype), + paged_attention_config=paged_attention_config, + ) + + logger.info("Separate vision and text models loaded like test_end2end.py") + return vision_model, text_model + + +def run_generation_exactly_like_test_end2end( + vision_model, text_model, processed_inputs, model_args, page_table=None, paged_attention_config=None, max_gen_len=20 +): + """Run generation following the EXACT pattern from test_end2end.py.""" + input_ids = processed_inputs["input_ids"] + pixel_values = processed_inputs["pixel_values"] + input_prompts = processed_inputs["input_prompts"] + processor = processed_inputs["processor"] + + logger.info("Running generation exactly like test_end2end.py...") + + # Process vision (exactly like test_end2end.py) + logger.info("Running Vision Model...") + + # Create Generator (exactly like test_end2end.py) + generator = Gemma3Generator([text_model], [model_args], text_model.mesh_device, tokenizer=model_args.tokenizer) + + # Setup KV cache (exactly like test_end2end.py) + tt_kv_cache = [[l.attention.layer_past for l in text_model.layers]] if paged_attention_config else None + + # Get embeddings and combine with vision (exactly like test_end2end.py) + # host_embedding = model_args.reference_embedding() + + # # Text generation setup (exactly like test_end2end.py) + input_tokens_prefill = input_ids + batch_size = input_tokens_prefill.shape[0] + # seq_len = input_tokens_prefill.shape[1] + model_args.tokenizer = processor + + ( + input_tokens_prefill_pt, + encoded_prompts, + decoding_pos, + prefill_lens, + ) = preprocess_inputs_prefill( + input_prompts, + model_args.tokenizer, + [model_args], + instruct=True, + max_generated_tokens=max_gen_len, + max_prefill_len=8192, + ) + + input_tokens_prefill_pt = torch.stack(input_tokens_prefill_pt).view(batch_size, -1) + + logger.info("Running prefill...") + logits = generator.prefill_forward_text( + input_tokens_prefill_pt, + page_table=page_table, + kv_cache=tt_kv_cache, + prompt_lens=decoding_pos, + vision_model=vision_model, + processed_inputs=processed_inputs, + ) + + # Get first token (exactly like test_end2end.py) + prefilled_token = torch.argmax(logits, dim=-1) + logger.info(f"Prefilled token: {prefilled_token}") + + # Initialize generation (exactly like test_end2end.py) + all_outputs = [encoded_prompts[0][: prefill_lens[0]]] + all_outputs[0].append(int(prefilled_token[0].item())) + + current_pos = torch.tensor([decoding_pos[0]]) + out_tok = prefilled_token + generation_length = 150 + + results = [] + + # Decode loop (exactly like test_end2end.py) + logger.info("Starting decode loop...") + for iteration in range(generation_length): + logger.info(f"[Text] Decoding token {iteration}, current_pos: 
{current_pos.item()}") + + # Run decode (exactly like test_end2end.py) + logits = generator.decode_forward_text( + out_tok, + current_pos, + enable_trace=False, + page_table=page_table, + kv_cache=tt_kv_cache, + ) + + # Sample next token (exactly like test_end2end.py) + _, out_tok = sample_host( + logits, + temperature=0, + top_p=0.9, + ) + + token_id = out_tok[0].item() + decoded_token = model_args.tokenizer.decode([token_id]) + logger.info(f"Generated token {iteration}: ID={token_id}, text='{decoded_token}'") + + # Create result object + result = type("TokenResult", (), {"token": token_id, "text": decoded_token})() + + results.append(result) + + all_outputs[0].append(token_id) + current_pos += 1 + + # Early stopping (exactly like test_end2end.py) + if len(all_outputs[0]) >= 5 and all(t == all_outputs[0][-1] for t in all_outputs[0][-5:]): + logger.warning(f"Detected exact repetition of token {all_outputs[0][-1]} five times in a row. Stopping.") + break + + # Final response (exactly like test_end2end.py) + response = model_args.tokenizer.decode(all_outputs[0], skip_special_tokens=True) + logger.info(f"📝 Final Generated Response:\n{response}") + logger.info(f"📝 Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") + chat = parse_chat_output(response) + display_chat(logger, chat) + + logger.info(f"Generated {len(results)} tokens successfully") + return results + + +# Legacy function removed - vision processing now handled in multimodal model + + +def validate_e2e_outputs(results, expected_min_tokens=1): + """Validate end-to-end pipeline outputs.""" + if not results: + logger.error("No results generated from E2E pipeline") + return False + + if len(results) < expected_min_tokens: + logger.warning(f"Generated only {len(results)} tokens, expected at least {expected_min_tokens}") + return False + + # Check if tokens are valid + for result in results: + if not hasattr(result, "token") or not hasattr(result, "text"): + logger.error("Invalid result format") + return False + + logger.info("E2E pipeline validation passed") + return True + + +# ============================================================================= +# EXISTING FUNCTIONS (Unchanged for backward compatibility) +# ============================================================================= + + +def create_position_tensor(current_pos, model_args, mesh_device): + """Create position tensor for the model.""" + return ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, 0) if (model_args.is_galaxy and model_args.max_batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + +def convert_tt_output_to_torch(tt_out, model_args, mesh_device): + """Convert TTNN tensor to PyTorch tensor.""" + mesh_composer = ttnn.ConcatMesh2dToTensor( + mesh_device, dims=(3, 1) if model_args.is_galaxy else (1, -1), mesh_shape=model_args.cluster_shape + ) + tt_output_torch = ( + ttnn.to_torch(tt_out, mesh_composer=mesh_composer) + .permute(2, 1, 0, 3) + .squeeze(2)[: model_args.max_batch_size, 0:1, : model_args.vocab_size] + ) + ttnn.deallocate(tt_out) + return tt_output_torch + + +def process_token_generation( + i, + encoded_prompts, + encoded_prompts_tensor, + embd, + batch, + seqlen, + all_outputs, + all_outputs_ref, + run_ref_pt, + ref_output, + tt_output_torch, +): + """Process token generation for both prefill and decode phases.""" + if i in range(len(encoded_prompts)): + # While in "prefill" mode, use the prompt tokens as the output 
+ all_outputs.append(encoded_prompts[i]) + if run_ref_pt: + all_outputs_ref.append(encoded_prompts[i]) + + tt_decode_input = embd(encoded_prompts_tensor[:, i]).view(batch, seqlen, -1) + if run_ref_pt: + pt_decode_input = embd(encoded_prompts_tensor[:, i]).view(batch, seqlen, -1) + else: + pt_decode_input = None + else: + # Greedy decode (temperature = 0) the generated token and save it to print out later + # Exact copy of original logic (including commented sections) + # if run_ref_pt: + # # Sample from reference model first + _, pt_out_tok = sample_host(ref_output, temperature=0, top_p=0.8) + pt_decode_input = embd(pt_out_tok) + all_outputs_ref.append(pt_out_tok.squeeze(1).tolist()[0]) + + # Use the same token for TT model (teacher forcing) + tt_decode_input = pt_decode_input + # all_outputs.append(pt_out_tok.squeeze(1).tolist()[0]) + # else: + # If not running reference model, sample from TT model directly + _, tt_out_tok = sample_host(tt_output_torch, temperature=0, top_p=0.8) + tt_decode_input = embd(tt_out_tok) + all_outputs.append(tt_out_tok.squeeze(1).tolist()[0]) + + return tt_decode_input, pt_decode_input + + +def validate_outputs(run_ref_pt, ref_output, tt_output_torch, pcc, all_outputs, all_outputs_ref, tokenizer, logger): + """Validate model outputs and compute PCC.""" + if run_ref_pt: + passing, pcc_message = comp_pcc(ref_output, tt_output_torch, pcc) + + # Decode the output tokens back to text + decoded_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in all_outputs] + logger.info(f'TTNN Decoded Outputs: {"".join(decoded_texts)}') + decoded_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in all_outputs_ref] + logger.info(f'Torch Decoded Outputs: {"".join(decoded_texts)}') + + logger.info(comp_allclose(ref_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("Model Passed!") + else: + logger.warning("Model Failed!") + + return passing + return True + + +def run_generation_loop( + tt_model, + model_args, + mesh_device, + reference_model, + embd, + encoded_prompts_tensor, + generation_length, + generation_start_pos, + batch, + seqlen, + page_table_tt, + run_ref_pt, + pcc, + tokenizer, + logger, + parse_chat, + encoded_prompts, +): + """Run the main token generation loop.""" + all_outputs = [] + all_outputs_ref = [] if run_ref_pt else [] + if run_ref_pt: + all_tests_pass = True + + # Initial setup + current_pos = torch.tensor([generation_start_pos for _ in range(batch)]) + current_pos_tensor = create_position_tensor(current_pos, model_args, mesh_device) + + # Select the first token from the prompts for initial decoding + pt_decode_input = embd(encoded_prompts_tensor[:, 0]).view(batch, seqlen, -1) + tt_decode_input = pt_decode_input + + for i in range(generation_length): + logger.info(f"[Model] Generating token {i}") + + # Prepare input + decode_input = model_args.prepare_residual_tensor_decode( + tt_decode_input, + model_args.model_config["DECODE_RESIDUAL_MEMCFG"], + ) + + # Get rotation matrices + rot_mats_global = tt_model.rope_setup.get_rot_mats(current_pos) + rot_mats_local = tt_model.rope_setup_local.get_rot_mats(current_pos) + rot_mats = [rot_mats_global, rot_mats_local] + + # Run TT model + tt_out = tt_model( + decode_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table_tt, + ) + + # Convert output + tt_output_torch = convert_tt_output_to_torch(tt_out, model_args, mesh_device) + + # Run reference model if needed + ref_output = None + if run_ref_pt: + 
ref_output = reference_model(pt_decode_input, current_pos[0]) + + # Update position + current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch)]) + current_pos_tensor = create_position_tensor(current_pos, model_args, mesh_device) + + # Process token generation + tt_decode_input, pt_decode_input = process_token_generation( + i, + encoded_prompts, + encoded_prompts_tensor, + embd, + batch, + seqlen, + all_outputs, + all_outputs_ref, + run_ref_pt, + ref_output, + tt_output_torch, + ) + + # Validate outputs + passing = validate_outputs( + run_ref_pt, ref_output, tt_output_torch, pcc, all_outputs, all_outputs_ref, tokenizer, logger + ) + + # Note: Individual PCC failures don't affect overall test result (matching original behavior) + # if not passing: + # all_tests_pass = False + + # Display chat if enabled + if parse_chat: + conversation = parse_chat_output(tokenizer.decode(all_outputs).replace("\n", "\\n")) + display_chat(logger, conversation) + + if run_ref_pt: + return all_tests_pass + else: + return True # If not running reference model, always pass + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@skip_for_blackhole("Failing on DRAM harvested P100a, see #21419") +@pytest.mark.timeout(1800) +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize( + "weights, layers", + [ + ("instruct", None), + ], + ids=["full"], +) +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False, + ), + ids=( + "paged_attention", + # "default_attention", + ), +) +@pytest.mark.parametrize( + "device_params", + [ + { + "fabric_config": ttnn.FabricConfig.FABRIC_1D, + "trace_region_size": 30000000, + "num_command_queues": 1, + "l1_small_size": 10 * 1024, + } + ], + indirect=True, +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (1024 * 8,), # Use smaller seq_len like test_end2end.py to avoid memory issues +) +@pytest.mark.parametrize( + "optimizations", + [ + lambda model_args: DecodersPrecision.accuracy(model_args.n_layers, model_args.model_name), + ], + ids=["accuracy"], +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "N150x4": (1, 4), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_e2e_vision_text_pipeline( + weights, + layers, + max_seq_len, + batch_size, + paged_attention, + page_params, + optimizations, + mesh_device, + reset_seeds, + request, + device_params, +): + """Test end-to-end vision-text pipeline using proper Generator methods.""" + logger.info("Starting E2E vision-text pipeline test") + + # Use bfloat8_b like test_end2end.py for better memory efficiency + dtype = ttnn.bfloat8_b + # Setup vision-enabled model configuration + model_args, instruct = setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations) + + if layers is not None: + model_args.n_layers = layers + + # Setup vision prompts and tokenizer + messages, tokenizer = setup_vision_prompts_and_tokenizer(model_args, instruct) + + # Process real vision inputs from images + processed_inputs = process_real_vision_inputs(messages, model_args) + + # Load separate models following test_end2end.py pattern + logger.info("Loading separate vision and text models like test_end2end.py...") + vision_model, text_model = load_separate_models_like_test_end2end( + model_args, 
mesh_device, dtype, paged_attention, page_params + ) + + # Setup page table for paged attention (exactly like test_end2end.py) + page_table_tt = None + paged_attention_config = None + + # Prepare page table for paged attention (exactly like test_end2end.py) + page_table = None + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, -2) if batch_size > 1 else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + # Run generation following EXACT test_end2end.py pattern + logger.info("Running generation following EXACT test_end2end.py pattern...") + results = run_generation_exactly_like_test_end2end( + vision_model, text_model, processed_inputs, model_args, page_table, paged_attention_config, max_gen_len=10 + ) + + # Validate results + validation_passed = validate_e2e_outputs(results, expected_min_tokens=1) + + # Final validation + if validation_passed and len(results) > 0: + logger.info("✅ E2E vision-text pipeline test PASSED!") + logger.info(f"Successfully generated {len(results)} tokens") + + # Log generated tokens for debugging + for i, result in enumerate(results[:5]): + logger.info(f"Token {i}: {result.token} -> '{result.text}'") + else: + logger.error("❌ E2E pipeline test failed") + assert False, f"E2E pipeline failed - generated {len(results)} tokens, validation: {validation_passed}" diff --git a/models/experimental/gemma3/tests/vision_tests/test_mmp.py b/models/experimental/gemma3/tests/vision_tests/test_mmp.py new file mode 100644 index 000000000000..ed947aa9d899 --- /dev/null +++ b/models/experimental/gemma3/tests/vision_tests/test_mmp.py @@ -0,0 +1,100 @@ +"""Gemma3 Test for multi-modal-projector""" + + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.experimental.gemma3.tt.mmp import TtGemma3MultiModalProjector + +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (1152,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +def test_multi_modal_inference(seq_len, batch_size, reset_seeds, mesh_device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + mesh_device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + reference_model = 
tt_model_args.reference_vision_multi_modal() + + # create input tensor for multi_modal_projector layer + patches_per_image = 64 + num_patches = patches_per_image * patches_per_image + input = torch.randn((batch_size, num_patches, seq_len)) + reference_output = reference_model(input) + + # DistributedNorm inputs are fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode) + tt_input = ttnn.from_torch( + input, + device=mesh_device, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + tt_model = TtGemma3MultiModalProjector( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix="multi_modal_projector", + image_size=tt_model_args.image_size, + patch_size=tt_model_args.vision_patch_size, + hidden_size=tt_model_args.vision_hidden_dim, + mm_tokens_per_image=tt_model_args.mm_tokens_per_image, + weight_cache_path=tt_model_args.weight_cache_path(dtype), + layer_norm_eps=1e-06, # layer_norm_eps + dtype=dtype, + configuration=tt_model_args, + ) + tt_output = tt_model(tt_input) + + tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ + :, :, :, : tt_output.shape[-1] + ] + tt_output_torch = tt_output_torch.view(reference_output.shape) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + pcc_required = 0.9999 + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/experimental/gemma3/tests/vision_tests/test_patch_embedding.py b/models/experimental/gemma3/tests/vision_tests/test_patch_embedding.py new file mode 100644 index 000000000000..72cf892842e2 --- /dev/null +++ b/models/experimental/gemma3/tests/vision_tests/test_patch_embedding.py @@ -0,0 +1,111 @@ +"""Gemma3 test for Vision Patch Embedding""" + + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +from loguru import logger + +import ttnn +import torch +from models.tt_transformers.tt.model_config import ModelArgs + +from models.experimental.gemma3.tt.gemma_conv2d_patch import TtGemmaConv2dPatch +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +from ttnn import ConcatMeshToTensor + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_conv2d_inference( + mesh_device, + reset_seeds, +): + pcc_required = 0.9999 + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + tt_layer_prefix = "visual.embeddings.patch_embedding." + first_layer_prefix = "visual.embeddings.patch_embedding._linear." 
+ partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + num_devices = model_args.num_devices + + B, NCH, H, W = (1, 3, model_args.image_size, model_args.image_size) + in_channels, out_channels, kernel_size, stride, bias = ( + 3, + model_args.vision_dim, + model_args.vision_patch_size, + model_args.vision_patch_size, + True, + ) + + assert NCH == in_channels, "Number of channels in input tensor should match in_channels for the Conv2d patch." + assert type(kernel_size) == int, "Only symmetric kernel_size is currently supported." + assert kernel_size == stride, "Only same kernel_size and stride are currently supported." + + assert H % kernel_size == 0, "Height should be divisible by kernel_size." + assert W % kernel_size == 0, "Width should be divisible by kernel_size." + + input_tensor = torch.randn((B, NCH, H, W)) + logger.info(f"Input tensor shape: {input_tensor.shape}") + + ##### Perform the torch ops ##### + reference_model = model_args.reference_siglip_patch_embed() + # reference_model.load_state_dict(partial_state_dict) + reference_output = reference_model(input_tensor) + del reference_model + + tt_model = TtGemmaConv2dPatch( + mesh_device, + state_dict, + tt_layer_prefix, + dtype, + in_channels, + out_channels, + kernel_size, + stride, + bias, + ) + tt_output = tt_model(input_tensor) + + logger.info("Checking outputs") + out = ttnn.from_device(tt_output) + tt_output_torch = ttnn.to_torch(out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=3)) + + tt_output_torch = tt_output_torch[0, ..., :out_channels] + + logger.info(f"Reference output shape: {reference_output.shape}") + logger.info(f"TT output shape: {tt_output_torch.shape}") + + # TT output: [B, HW, C] + B, HW, C = tt_output_torch.shape + H = W = int(HW**0.5) + assert H * W == HW, "HW is not a perfect square — can't reshape" + tt_output_torch = tt_output_torch.permute(0, 2, 1) + tt_output_torch = tt_output_torch.reshape(B, C, H, W) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
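Editor's note (illustration, not part of the diff): the [B, HW, C] layout checked in the patch-embedding test above follows from the standard equivalence between a stride == kernel_size Conv2d and an unfold-plus-linear projection, which is presumably why the cached weight key ends in "_linear". A minimal sketch; the helper name and toy shapes below are assumptions:

import torch

def conv_patch_embed_as_linear(x, conv_weight, conv_bias, patch):
    # x: [B, C, H, W]; conv_weight: [out_ch, C, patch, patch]; conv_bias: [out_ch]
    B, C, H, W = x.shape
    # Unfold into non-overlapping patches: [B, C*patch*patch, num_patches] -> [B, num_patches, C*patch*patch]
    patches = torch.nn.functional.unfold(x, kernel_size=patch, stride=patch).transpose(1, 2)
    weight = conv_weight.reshape(conv_weight.shape[0], -1)  # [out_ch, C*patch*patch]
    return patches @ weight.T + conv_bias  # [B, num_patches, out_ch], i.e. the [B, HW, C] layout above

# Self-check against torch's Conv2d (assumed toy sizes):
# conv = torch.nn.Conv2d(3, 8, kernel_size=4, stride=4)
# x = torch.randn(1, 3, 16, 16)
# ref = conv(x).flatten(2).transpose(1, 2)
# assert torch.allclose(ref, conv_patch_embed_as_linear(x, conv.weight, conv.bias, 4), atol=1e-5)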
diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_attention.py b/models/experimental/gemma3/tests/vision_tests/test_vision_attention.py new file mode 100644 index 000000000000..42daa76f3bd8 --- /dev/null +++ b/models/experimental/gemma3/tests/vision_tests/test_vision_attention.py @@ -0,0 +1,102 @@ +"""Gemma3 Test for Vision Attention""" + + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.load_checkpoints import ( # convert_vision_hf_to_meta, + convert_hf_qkv_to_meta_format, + convert_vision_hf_to_meta, +) +from models.tt_transformers.tt.model_config import ModelArgs + +from models.tt_transformers.tt.ccl import TT_CCL +from models.experimental.gemma3.tt.gemma_image_attention import TtGemmaImageAttention +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_attention_inference(batch, num_chunks, mesh_device, reset_seeds, device_params): + dtype = ttnn.bfloat16 + pcc_required = 0.99 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "visual.encoder.layers.0.attn." + # partial_state_dict = { + # k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + # } + + dim = model_args.vision_dim + + reference_model = model_args.reference_vision_attention() + # reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + hidden_size = model_args.vision_dim + n_heads = model_args.vision_attn_n_heads + head_dim = hidden_size // n_heads + seq_len = model_args.vision_chunk_ntok + tt_ccl = TT_CCL(mesh_device) + tt_model = TtGemmaImageAttention( + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=model_args, + ) + + pt_attention_input = torch.randn(batch, seq_len, dim) + + attention_input = model_args.prepare_residual_tensor_prefill( + pt_attention_input, + force_replicated=True, + ) + + tt_out = tt_model(attention_input) + + # Doing contract in tt is correct!! + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[0, :, :, :] + # tt_output_torch = ttnn.to_torch(tt_out, device=mesh_device) + + reference_output = reference_model(pt_attention_input)[0] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
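Editor's note (illustration, not part of the diff): every test in this series gates on comp_pcc. Conceptually that is a Pearson correlation between the flattened reference and TT tensors; the real helper in models/utility_functions.py may handle NaN or degenerate cases differently, so treat this as a sketch only:

import torch

def pearson_pcc(reference: torch.Tensor, actual: torch.Tensor) -> float:
    # Flatten, center, and correlate the two outputs
    a = reference.detach().float().flatten()
    b = actual.detach().float().flatten()
    a, b = a - a.mean(), b - b.mean()
    denom = a.norm() * b.norm()
    return float((a @ b) / denom) if denom > 0 else 1.0

# Usage mirroring the assertions above:
# assert pearson_pcc(reference_output, tt_output_torch) >= pcc_required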
diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_cross_attention_transformer.py b/models/experimental/gemma3/tests/vision_tests/test_vision_cross_attention_transformer.py new file mode 100644 index 000000000000..751047cd2487 --- /dev/null +++ b/models/experimental/gemma3/tests/vision_tests/test_vision_cross_attention_transformer.py @@ -0,0 +1,117 @@ +"""Gemma3 Test for Vision Transformer""" + + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs + +from models.experimental.gemma3.tt.gemma_vision_crossattention import TtGemmaTransformerVision +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +# @pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize("bsz", [1]) +@pytest.mark.parametrize( + "device_params", + [ + { + "fabric_config": ttnn.FabricConfig.FABRIC_1D, + "trace_region_size": 30000000, + "num_command_queues": 1, + "l1_small_size": 24576, + } + ], + indirect=True, +) +def test_gemma_vision( + mesh_device, + reset_seeds, + bsz, + device_params, +): + pcc_required = 0.90 + dtype = ttnn.bfloat16 + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + vision_first_layer_prefix = "vision_tower.vision_model." + vision_partial_state_dict = { + k[len(vision_first_layer_prefix) :]: v + for k, v in state_dict.items() + if (k.startswith(vision_first_layer_prefix)) + } + + reference_vision_model = model_args.reference_vision_model() + # reference_vision_model.load_state_dict(vision_partial_state_dict) + + mmp_first_layer_prefix = "multi_modal_projector." + + image_size = model_args.image_size + in_channels = model_args.vision_in_channels + + input_tensor = torch.rand((bsz, in_channels, image_size, image_size)) + + reference_mmp = model_args.reference_vision_multi_modal() + + reference_output = get_image_features( + reference_vision_model, + reference_mmp, + input_tensor, + ) + + test_gemma_vision = TtGemmaTransformerVision( + mesh_device, + state_dict, + state_dict_prefix="vision_tower.vision_model.", + dtype=dtype, + configuration=model_args, + return_intermediate=False, + ) + + test_output = test_gemma_vision(input_tensor) + + logger.info("Checking outputs") + out = ttnn.from_device(test_output) + tt_output_torch = ttnn.to_torch(out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1)).squeeze(0)[ + ..., : model_args.dim + ] + + # tt_output_torch = tt_output_torch.view(1, 256, 2560) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" + + +def get_image_features(vision_tower, projector, input_tensor): + """ + Get image features from the vision tower and projector. 
+ """ + vision_token = vision_tower(input_tensor).last_hidden_state + image_features = projector(vision_token) + return image_features diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_embedding.py b/models/experimental/gemma3/tests/vision_tests/test_vision_embedding.py new file mode 100644 index 000000000000..3e6fb98642c7 --- /dev/null +++ b/models/experimental/gemma3/tests/vision_tests/test_vision_embedding.py @@ -0,0 +1,89 @@ +"""Gemma3 Test for Vision Embedding""" + + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs + +from models.experimental.gemma3.tt.siglip_vision_embedding import TtSiglipVisionEmbeddings +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +from ttnn import ConcatMeshToTensor + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize("bsz", [1]) +def test_vision_embedding_integration( + mesh_device, + reset_seeds, + bsz, +): + pcc_required = 0.9999 + dtype = ttnn.bfloat16 + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "visual.embeddings." + # partial_state_dict = { + # k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + # } + + image_size = model_args.image_size + patch_size = model_args.vision_patch_size + hidden_dim = model_args.vision_dim + dim = model_args.vision_dim + in_channels = 3 + + input_tensor = torch.randn((bsz, in_channels, image_size, image_size)) + + reference_model = model_args.reference_vision_embedding() + # reference_model.load_state_dict(partial_state_dict) + reference_output = reference_model(input_tensor) + + vision_embed = TtSiglipVisionEmbeddings( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + dtype=dtype, + image_size=image_size, + patch_size=patch_size, + num_channels=in_channels, + hidden_dim=hidden_dim, + bias=True, + ) + + embeddings = vision_embed(input_tensor) + ##### Check the outputs ##### + logger.info("Checking outputs") + out = ttnn.from_device(embeddings) + tt_output_torch = ttnn.to_torch(out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=-1)) + + # Only select output from one device + tt_output_torch = tt_output_torch[..., :dim] + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + # To get RTOL values + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_layernorm.py b/models/experimental/gemma3/tests/vision_tests/test_vision_layernorm.py new file mode 100644 index 000000000000..c2b2d66891a3 --- /dev/null +++ b/models/experimental/gemma3/tests/vision_tests/test_vision_layernorm.py @@ -0,0 +1,100 @@ +"""Gemma3 Test for Vision Layernorm""" + + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.llama_layernorm import TtLayerNorm # Updated import for LayerNorm +from models.utility_functions import comp_allclose, comp_pcc, nearest_32, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize("layer_name", [("layer_norm1"), ("layer_norm2")]) +def test_layernorm_inference(mesh_device, reset_seeds, layer_name): + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device) + width = model_args.vision_dim + num_chunks = 4 + seq_len = nearest_32(model_args.vision_chunk_ntok) * num_chunks + + # Load full state dict + state_dict = model_args.load_state_dict() + + # Prefix for vision MLP weights — consistent with HF checkpoint + if layer_name == "layer_norm1": + first_layer_prefix = "visual.encoder.layers.0.ln_1." + else: + first_layer_prefix = "visual.encoder.layers.0.ln_2." + + model_args.WEIGHTS_DTYPE = dtype + # Reference HF MLP (from Gemma3 vision tower) + reference_model = model_args.reference_vision_layernorm(layer_name) + # reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + # Initialize the custom LayerNorm model + tt_model = TtLayerNorm( + device=mesh_device, + dim=width, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + weight_dtype=dtype, + eps=model_args.norm_eps, + ) + + # Generate random input + torch_input = torch.rand(1, seq_len, width) # Adjusted dimensions for LayerNorm + + # Reference output using PyTorch's LayerNorm + reference_output = reference_model(torch_input) + + # Convert input to ttnn tensor + tt_input = ttnn.from_torch( + torch_input, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + ) + + logger.info("Compilation pass for LayerNorm") + tt_output = tt_model(tt_input) + + tt_output_torch = ttnn.to_torch( + tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1) + ) # Adjusted dim for LayerNorm + tt_outputs = torch.chunk(tt_output_torch, model_args.num_devices, dim=-1) + + # Compare outputs + pcc_required = 0.99 + for idx, tt_output_torch in enumerate(tt_outputs): + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + reference_output_comp = reference_output.clone() + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output_comp = reference_output_comp[non_zero_indices] + + logger.info(comp_allclose(reference_output_comp, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than {pcc_required} for 
some of the outputs. Check Warnings!" diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_mlp.py b/models/experimental/gemma3/tests/vision_tests/test_vision_mlp.py new file mode 100644 index 000000000000..880e436b747f --- /dev/null +++ b/models/experimental/gemma3/tests/vision_tests/test_vision_mlp.py @@ -0,0 +1,101 @@ +"""Gemma3 Test for Vision MLP""" + + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.experimental.gemma3.tt.gemma_image_mlp import TtGemmaImageFeedForward +from models.utility_functions import comp_allclose, comp_pcc, nearest_32, skip_for_grayskull +from models.tt_transformers.tt.ccl import TT_CCL + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "device_params", + [ + { + "fabric_config": ttnn.FabricConfig.FABRIC_1D, + "trace_region_size": 30000000, + "num_command_queues": 1, + "l1_small_size": 24576, + } + ], + indirect=True, +) +def test_mlp_inference(batch, num_chunks, mesh_device, reset_seeds, device_params): + dtype = ttnn.bfloat16 + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + # first_layer_prefix = model_args.get_state_dict_prefix("MLP", 0, is_vision=True) + # partial_state_dict = { + # k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + # } + first_layer_prefix = "visual.encoder.layers.0.mlp." + model_args.WEIGHTS_DTYPE = dtype + + dim = model_args.vision_dim + seq_len = nearest_32(model_args.vision_chunk_ntok) * num_chunks + reference_model = model_args.reference_vision_mlp() + # reference_model.load_state_dict(partial_state_dict) + tt_ccl = TT_CCL(mesh_device) + tt_model = TtGemmaImageFeedForward( + mesh_device=mesh_device, + tt_ccl=tt_ccl, + args=model_args, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + ) + torch_input = torch.randn(1, batch, seq_len, dim) + reference_output = reference_model(torch_input).squeeze() + tt_input = ttnn.from_torch( + torch_input, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + ) + + tt_output = tt_model(tt_input) + + tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ + :, :1, :, : + ].squeeze() + + pcc_required = 0.99 + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
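Editor's note (illustration, not part of the diff): the layernorm and MLP tests above round seq_len with nearest_32 because TILE_LAYOUT tensors are tiled in 32x32 blocks. A sketch of the rounding they rely on; the real helper lives in models/utility_functions.py and may differ in detail:

def round_up_to_tile(n: int, tile: int = 32) -> int:
    # Smallest multiple of the tile size that is >= n
    return ((n + tile - 1) // tile) * tile

# round_up_to_tile(4097) == 4128, round_up_to_tile(4096) == 4096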
diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_pipeline.py b/models/experimental/gemma3/tests/vision_tests/test_vision_pipeline.py new file mode 100644 index 000000000000..48e062c4bb7a --- /dev/null +++ b/models/experimental/gemma3/tests/vision_tests/test_vision_pipeline.py @@ -0,0 +1,94 @@ +"""Gemma3 Test for Vision Model""" + + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs + +from models.experimental.gemma3.tt.gemma_vision_model import TtSiglipGemmaVisionModel +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize("bsz", [1]) +@pytest.mark.parametrize( + "device_params", + [ + { + "fabric_config": ttnn.FabricConfig.FABRIC_1D, + "trace_region_size": 30000000, + "num_command_queues": 1, + "l1_small_size": 24576, + } + ], + indirect=True, +) +def test_gemma_vision( + mesh_device, + reset_seeds, + bsz, + device_params, +): + pcc_required = 0.94 + dtype = ttnn.bfloat16 + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "visual." + # partial_state_dict = { + # k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + # } + + image_size = model_args.image_size + in_channels = model_args.vision_in_channels + + input_tensor = torch.rand((bsz, in_channels, image_size, image_size)) + + reference_model = model_args.reference_vision_model() + # reference_model.load_state_dict(partial_state_dict) + reference_output = reference_model(input_tensor).last_hidden_state + + test_gemma_vision = TtSiglipGemmaVisionModel( + mesh_device, + state_dict, + state_dict_prefix=first_layer_prefix, + dtype=dtype, + configuration=model_args, + return_intermediate=False, + ) + + test_output = test_gemma_vision(input_tensor) + + logger.info("Checking outputs") + out = ttnn.from_device(test_output) + tt_output_torch = ttnn.to_torch(out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1)).squeeze(0)[ + ..., :1152 + ] + + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
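Editor's note (illustration, not part of the diff): these tests drop exactly-zero elements before logging comp_allclose so that padded or unpopulated positions do not skew the abs/rel error summary. A sketch of that masking:

import torch

def filter_padding_zeros(tt_out: torch.Tensor, ref_out: torch.Tensor):
    # Keep only positions where the TT output is non-zero, mirroring the .ne(0).nonzero() filter above
    mask = tt_out.ne(0)
    return tt_out[mask], ref_out[mask]

# tt_vals, ref_vals = filter_padding_zeros(tt_output_torch, reference_output)
# logger.info(comp_allclose(ref_vals, tt_vals))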
diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_rmsnorm.py b/models/experimental/gemma3/tests/vision_tests/test_vision_rmsnorm.py new file mode 100644 index 000000000000..a920f04980ad --- /dev/null +++ b/models/experimental/gemma3/tests/vision_tests/test_vision_rmsnorm.py @@ -0,0 +1,131 @@ +"""Gemma3 test for Vision RMSNorm""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +from loguru import logger + +import torch +import pytest +import os + +import ttnn +from models.experimental.gemma3.tt.gemma_vision_rmsnorm import RMSNorm + +from models.tt_transformers.tt.ccl import TT_CCL + + +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +from models.tt_transformers.tt.model_config import ModelArgs + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "device_params", + [ + { + "fabric_config": ttnn.FabricConfig.FABRIC_1D, + "trace_region_size": 30000000, + "num_command_queues": 1, + "l1_small_size": 24576, + } + ], + indirect=True, +) +def test_rmsnorm_inference(seq_len, batch_size, reset_seeds, mesh_device, device_params): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + mesh_device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + reference_model = tt_model_args.reference_vision_rms_norm() # Gemma3 RMSNorm + first_layer_prefix = "multi_modal_projector.mm_soft_emb_norm." 
+ + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + # reference_model.load_state_dict(partial_state_dict) + + tt_model = RMSNorm( + device=mesh_device, + dim=1152, + state_dict=state_dict, + state_dict_prefix="", + weight_key="multi_modal_projector.mm_soft_emb_norm", + weight_dtype=dtype, + is_distributed=False, + sharded_program_config=tt_model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=tt_model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], + ) + + tt_ccl = TT_CCL(mesh_device) + # Wrap it in DistributedNorm + # tt_model = DistributedNorm(tt_inner_norm, tt_model_args, TG=tt_model_args.is_galaxy, tt_ccl = tt_ccl) + + input = torch.rand(1, 1, 1152) + + reference_output = reference_model(input) + + # DistributedNorm inputs are fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode) + tt_input = ttnn.from_torch( + input, + device=mesh_device, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + memory_config=( + tt_model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG + ), + ) + + tt_output = tt_model(tt_input, mode=mode) + + # DistributedNorm outputs are replicated across devices + tt_output_torch = ttnn.to_torch( + tt_output, + mesh_composer=ttnn.ConcatMesh2dToTensor( + mesh_device, dims=(0, 2) if tt_model_args.is_galaxy else (2, 0), mesh_shape=tt_model_args.cluster_shape + ), + )[:1, :, :] + tt_output_torch = tt_output_torch[..., :1152].squeeze(0) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("rms_norm Passed!") + else: + logger.warning("rms_norm Failed!") + + assert passing, f"rms_norm output does not meet PCC requirement {0.99}." 
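Editor's note (illustration, not part of the diff): the reference model checked above normalizes by the root-mean-square over the last dimension. Gemma-style RMSNorms conventionally scale by (1 + weight); that detail is an assumption here, and tt_model_args.reference_vision_rms_norm() remains the ground truth:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * (1.0 + weight)  # assumed Gemma-style (1 + w) scaling

# x = torch.rand(1, 1, 1152); w = torch.zeros(1152)
# y = rms_norm_reference(x, w)  # with w == 0 this reduces to plain RMS normalization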
diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_transformer.py b/models/experimental/gemma3/tests/vision_tests/test_vision_transformer.py new file mode 100644 index 000000000000..2f7ef9521c93 --- /dev/null +++ b/models/experimental/gemma3/tests/vision_tests/test_vision_transformer.py @@ -0,0 +1,116 @@ +"""Gemma3 test for Vision Transformer submodule""" + + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + +from models.experimental.gemma3.tt.gemma_image_transformer import TtGemmaImageTransformer + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_image_transformer_inference(batch, num_chunks, mesh_device, device_params): + pcc_required = 0.95 + + model_args = ModelArgs(mesh_device) + dtype = ttnn.bfloat16 + + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + n_layers = model_args.vision_n_layers + first_layer_prefix = "visual.encoder." + + # gated = True + + # partial_state_dict = { + # k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + # } + + dim = model_args.vision_dim + seq_len = model_args.vision_chunk_ntok - 1 + + reference_model = model_args.reference_vision_encoder() + # reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + all_tests_pass = True + + tt_model = TtGemmaImageTransformer( + mesh_device, + state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=model_args, + layers=n_layers, + block_key="layers", + ) + + # Create PT input + pt_attention_input = torch.randn(batch, seq_len, dim) + attention_mask = torch.zeros(batch, 1, seq_len, seq_len) + + attention_input = model_args.prepare_residual_tensor_prefill( + pt_attention_input, + force_replicated=True, + ) + tt_mask = ttnn.from_torch( + attention_mask, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + with torch.no_grad(): + tt_out = tt_model(attention_input, mask=tt_mask) + + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[0, :, :, :] + + reference_output = reference_model(pt_attention_input, attention_mask=attention_mask)[0] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + if not passing: + logger.warning(f"PCC value -- {pcc_message} -- is lower than {pcc_required} for the output.") + else: + logger.info(f"PCC: {pcc_message}") + + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + 
logger.info(comp_allclose(reference_output, tt_output_torch)) + + all_tests_pass = all_tests_pass and passing + + assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_transformer_block.py b/models/experimental/gemma3/tests/vision_tests/test_vision_transformer_block.py new file mode 100644 index 000000000000..a9938f99d2e2 --- /dev/null +++ b/models/experimental/gemma3/tests/vision_tests/test_vision_transformer_block.py @@ -0,0 +1,103 @@ +"""Gemma3 Test for Vision Transformer block""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.experimental.gemma3.tt.gemma_image_block import TtGemmaImageTransformerBlock +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "gated", + (True, False), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_block_inference(batch, num_chunks, mesh_device, reset_seeds, gated, device_params): + dtype = ttnn.bfloat16 + pcc_required = 0.99 + gated = False + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + if gated: + first_layer_prefix = "visual.encoder.layers.0." + else: + first_layer_prefix = "visual.encoder.layers.0." 
+ + dim = model_args.vision_dim + heads = model_args.vision_attn_n_heads + seq_len = model_args.vision_chunk_ntok - 1 + + reference_model = model_args.reference_vision_encoder_block() + # reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + tt_model = TtGemmaImageTransformerBlock( + mesh_device, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=model_args, + gated=gated, + ) + + pt_attention_input = torch.randn(batch, seq_len, dim) + attention_mask = torch.zeros(batch, 1, seq_len, seq_len) + + attention_input = model_args.prepare_residual_tensor_prefill( + pt_attention_input, + force_replicated=True, + ) + tt_mask = ttnn.from_torch( + attention_mask, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + tt_out = tt_model(attention_input, mask=tt_mask) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[0, :, :, :] + + reference_output = reference_model(pt_attention_input, attention_mask=attention_mask)[0] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/experimental/gemma3_4b/tt/attention.py b/models/experimental/gemma3/tt/attention.py similarity index 89% rename from models/experimental/gemma3_4b/tt/attention.py rename to models/experimental/gemma3/tt/attention.py index 81208d36aff2..f070e2d8224e 100644 --- a/models/experimental/gemma3_4b/tt/attention.py +++ b/models/experimental/gemma3/tt/attention.py @@ -1,7 +1,7 @@ """ source: models/tt_transformers/tt/attention.py -This is the attention implementation of the Gemma-3-4b-it +This is the attention implementation of the Gemma3 We have re-used the Attention implementation of the TT-Transformers with few modifications. 
This implementation has Changes in Datatype (Bfloat16) that supports the RMSNorm, @@ -20,7 +20,7 @@ import ttnn from models.common.lightweightmodule import LightweightModule -from models.experimental.gemma3_4b.tt.rmsnorm import RMSNorm +from models.experimental.gemma3.tt.rmsnorm import RMSNorm from models.tt_transformers.tt.ccl import tt_all_gather, tt_all_reduce from models.tt_transformers.tt.model_config import OpGroup, TensorGroup @@ -41,7 +41,8 @@ def __init__( ): super().__init__() - self.state_dict = state_dict + self.layer_idx = layer_num + self.configuration = configuration self.mesh_device = mesh_device self.tt_ccl = tt_ccl self.num_devices = configuration.num_devices @@ -107,6 +108,9 @@ def __init__( self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 self.transformation_mats = transformation_mats + self.is_sliding = ( + configuration.layer_types[layer_num] == "sliding_attention" if configuration.layer_types else False + ) self.model_config = configuration.get_model_config() self.ccl_topology = configuration.ccl_topology() @@ -124,22 +128,22 @@ def __init__( decoder_id=layer_num, tensor=TensorGroup.KV_CACHE ) self.li_qkv_decode_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration + decoder_id=layer_num, op=OpGroup.LI_QKV_DECODE, configuration=configuration ) self.sdpa_decode_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration + decoder_id=layer_num, op=OpGroup.SDPA_DECODE, configuration=configuration ) self.li_o_decode_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration + decoder_id=layer_num, op=OpGroup.LI_O_DECODE, configuration=configuration ) self.sdpa_prefill_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration + decoder_id=layer_num, op=OpGroup.SDPA_PREFILL, configuration=configuration ) self.li_qkv_prefill_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration + decoder_id=layer_num, op=OpGroup.LI_QKV_PREFILL, configuration=configuration ) self.li_o_prefill_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration + decoder_id=layer_num, op=OpGroup.LI_O_PREFILL, configuration=configuration ) layer_name = configuration.get_state_dict_prefix(self.__class__.__name__, layer_num) @@ -160,14 +164,14 @@ def __init__( self.wqkv_bias_prefill = None # Create combined QKV bias if present in state dict - if f"{wq_str}.bias" in self.state_dict: + if f"{wq_str}.bias" in state_dict: qkv_bias = torch.concat( [ torch.concat( [ - torch.chunk(self.state_dict[f"{wq_str}.bias"], configuration.num_devices)[i], - torch.chunk(self.state_dict[f"{wk_str}.bias"], configuration.num_devices)[i], - torch.chunk(self.state_dict[f"{wv_str}.bias"], configuration.num_devices)[i], + torch.chunk(state_dict[f"{wq_str}.bias"], configuration.num_devices)[i], + torch.chunk(state_dict[f"{wk_str}.bias"], configuration.num_devices)[i], + torch.chunk(state_dict[f"{wv_str}.bias"], configuration.num_devices)[i], ], dim=-1, ) @@ -226,9 +230,9 @@ def __init__( qkv_list = [] for i in 
range(self.num_devices_per_group): # Chunk weights - wq_selected = torch.chunk(self.state_dict[f"{wq_str}.weight"], self.num_devices_per_group, dim=0)[i] - wk_selected = torch.chunk(self.state_dict[f"{wk_str}.weight"], self.num_devices_per_group, dim=0)[i] - wv_selected = torch.chunk(self.state_dict[f"{wv_str}.weight"], self.num_devices_per_group, dim=0)[i] + wq_selected = torch.chunk(state_dict[f"{wq_str}.weight"], self.num_devices_per_group, dim=0)[i] + wk_selected = torch.chunk(state_dict[f"{wk_str}.weight"], self.num_devices_per_group, dim=0)[i] + wv_selected = torch.chunk(state_dict[f"{wv_str}.weight"], self.num_devices_per_group, dim=0)[i] # Transpose the selected chunks wq = torch.transpose(wq_selected, -2, -1) @@ -262,12 +266,12 @@ def norm_reshard(x, norm, mode): x = ttnn.to_memory_config(x, mem_cfg, dtype=x.dtype) return x - if f"{q_norm_str}.weight" in self.state_dict: + if f"{q_norm_str}.weight" in state_dict: fn_q_norm = RMSNorm( device=self.mesh_device, dim=self.head_dim, eps=configuration.norm_eps, - state_dict=self.state_dict, + state_dict=state_dict, state_dict_prefix=None, # we already prefix q_norm_str weight_cache_path=None if configuration.dummy_weights else weight_cache_path, weight_dtype=ttnn.bfloat16, @@ -282,12 +286,12 @@ def norm_reshard(x, norm, mode): else: self.q_norm = lambda x, mode: x - if f"{k_norm_str}.weight" in self.state_dict: + if f"{k_norm_str}.weight" in state_dict: fn_k_norm = RMSNorm( device=self.mesh_device, dim=self.head_dim, eps=configuration.norm_eps, - state_dict=self.state_dict, + state_dict=state_dict, state_dict_prefix=None, # we already prefix k_norm_str weight_cache_path=None if configuration.dummy_weights else weight_cache_path, weight_dtype=ttnn.bfloat16, @@ -304,7 +308,7 @@ def norm_reshard(x, norm, mode): # For ring topology we can use all gather matmul for wo self.use_fused_all_gather_matmul = self.model_config["USE_FUSED_ALL_GATHER_MATMUL"] - pt_wo = self.state_dict[f"{wo_str}.weight"].transpose(-1, -2).unsqueeze(0).unsqueeze(0) + pt_wo = state_dict[f"{wo_str}.weight"].transpose(-1, -2).unsqueeze(0).unsqueeze(0) wo_mem_config = configuration.create_dram_sharded_mem_config( (configuration.n_heads * configuration.head_dim) // configuration.num_devices, configuration.dim @@ -515,6 +519,7 @@ def forward_decode( # For example, a prompt w/ 1 user vs, the same prompt repeated N times for N users, will produce different outputs # This is because the SDPA op in decode mode has different number of reductions depending on batch size # Which leads to slightly different outputs from attention (due to accumulated errors) + if page_table: attn_output_1G4D = ttnn.transformer.paged_scaled_dot_product_attention_decode( q_heads_1BQD, @@ -526,6 +531,7 @@ def forward_decode( program_config=self.model_config["SDPA_DECODE_PROGCFG"], compute_kernel_config=self.sdpa_decode_compute_kernel_cfg, memory_config=ttnn.DRAM_MEMORY_CONFIG, + sliding_window=self.configuration.sliding_window if self.is_sliding else 0, ) else: attn_output_1G4D = ttnn.transformer.scaled_dot_product_attention_decode( @@ -537,6 +543,7 @@ def forward_decode( program_config=self.model_config["SDPA_DECODE_PROGCFG"], compute_kernel_config=self.sdpa_decode_compute_kernel_cfg, memory_config=ttnn.DRAM_MEMORY_CONFIG, # FIXME: why not L1 height sharded e.g. SCORES_BATCHED_MM_OUTPUT_MEMCFG? 
+ sliding_window=self.configuration.sliding_window if self.is_sliding else 0, ) ttnn.deallocate(q_heads_1BQD) @@ -557,32 +564,49 @@ def forward_decode( attn_output_cat, self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"] ) - # TODO: #26349 - # Fused AGMM currently has a PCC bug on small shapes - # Using the non-fused version is a temporary workaround - - all_gather_output = ttnn.experimental.all_gather_async( - attn_output_cat, - persistent_output_buffer=None, - dim=3, - multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(), - num_links=1, - memory_config=self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"], - barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(), - chunks_per_sync=10, - num_workers_per_link=2, - num_buffers_per_channel=2, - ) + # Fused AGMM only valid for ring topology + if self.ccl_topology == ttnn.Topology.Ring: + _, dense_out_sharded = ttnn.experimental.all_gather_matmul_async( + attn_output_cat, + self.wo, + persistent_output_buffer=None, + dim=3, + multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(), + all_gather_core_grid_offset=(0, 4), + barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(), + num_links=1, + memory_config_ag=self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"], + memory_config_mm=self.model_config["DECODE_RESIDUAL_MEMCFG"], + program_config=self.model_config["ATTN_ALL_GATHER_MATMUL_PROGCFG"], + compute_kernel_config=self.compute_kernel_config_hifi2, + chunks_per_sync=10, + num_workers_per_link=2, + num_buffers_per_channel=2, + ) + else: + all_gather_output = ttnn.experimental.all_gather_async( + attn_output_cat, + persistent_output_buffer=None, + dim=3, + multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(), + num_links=1, + topology=self.ccl_topology, + memory_config=self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"], + barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(), + chunks_per_sync=10, + num_workers_per_link=2, + num_buffers_per_channel=2, + ) - dense_out_sharded = ttnn.linear( - all_gather_output, - self.wo, - memory_config=self.model_config["DECODE_RESIDUAL_MEMCFG"], - program_config=self.model_config["ATTN_ALL_GATHER_MATMUL_PROGCFG"], - compute_kernel_config=self.li_o_decode_compute_kernel_cfg, - ) + dense_out_sharded = ttnn.linear( + all_gather_output, + self.wo, + memory_config=self.model_config["DECODE_RESIDUAL_MEMCFG"], + program_config=self.model_config["ATTN_ALL_GATHER_MATMUL_PROGCFG"], + compute_kernel_config=self.li_o_decode_compute_kernel_cfg, + ) - ttnn.deallocate(all_gather_output) + ttnn.deallocate(all_gather_output) ttnn.deallocate(attn_output_cat) dense_out_sharded = ttnn.to_memory_config(dense_out_sharded, self.model_config["DECODE_RESIDUAL_MEMCFG"]) return dense_out_sharded @@ -618,7 +642,7 @@ def forward_decode( core_grid=ttnn.CoreGrid(y=4, x=8) if self.TG else None, program_config=self.model_config["ATTN_OUTPUT_PROGCFG"] if not self.TG else None, memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - dtype=ttnn.bfloat8_b if self.TG else ttnn.bfloat16, + dtype=ttnn.bfloat8_b if self.TG else None, compute_kernel_config=self.li_o_decode_compute_kernel_cfg, ) @@ -808,7 +832,7 @@ def forward_prefill( ttnn.deallocate(v_fill) # SDPA - q_heads_1QSD_8b = ttnn.typecast(q_heads_1QSD, dtype=self.activation_dtype or ttnn.bfloat16) + q_heads_1QSD_8b = ttnn.typecast(q_heads_1QSD, dtype=self.activation_dtype or ttnn.bfloat8_b) ttnn.deallocate(q_heads_1QSD) if chunk_start_idx is not 
None: @@ -818,6 +842,8 @@ def forward_prefill( values_BKSD, page_table, chunk_start_idx, + attn_mask=attn_mask, + is_causal=True, compute_kernel_config=self.sdpa_prefill_compute_kernel_cfg, program_config=self.model_config["SDPA_PROGCFG"](seq_len), ) @@ -826,7 +852,6 @@ def forward_prefill( q_heads_1QSD_8b, k_heads_1KSD_8b, v_heads_1VSD_8b, - is_causal=True, scale=self.scale, compute_kernel_config=self.sdpa_prefill_compute_kernel_cfg, program_config=self.model_config["SDPA_PROGCFG"](seq_len), @@ -871,7 +896,7 @@ def forward_prefill( attn_output_11SH, self.wo, compute_kernel_config=self.li_o_prefill_compute_kernel_cfg, - dtype=self.activation_dtype or ttnn.bfloat16, + dtype=self.activation_dtype or ttnn.bfloat8_b, memory_config=ttnn.DRAM_MEMORY_CONFIG, program_config=self.model_config["WO_PREFILL_PROGCFG"](seq_len), ) diff --git a/models/experimental/gemma3_4b/tt/decoder.py b/models/experimental/gemma3/tt/decoder.py similarity index 79% rename from models/experimental/gemma3_4b/tt/decoder.py rename to models/experimental/gemma3/tt/decoder.py index f96d9ed914dd..259007843641 100644 --- a/models/experimental/gemma3_4b/tt/decoder.py +++ b/models/experimental/gemma3/tt/decoder.py @@ -1,27 +1,27 @@ """ source: models/tt_transformers/tt/decoder.py -This is the Decoder block for the gemma 3-4b-it model +This is the Decoder block for the Gemma3 model We couldn't use the existing implementation in TT-Transformers because the usage of submodules is different -In Gemma-3-4b-it, The decoder Block has Additional pre_feedforward_layernorm and post_feedforward_layernorm, +In Gemma3, The decoder Block has Additional pre_feedforward_layernorm and post_feedforward_layernorm, And the logic of implementation is different from the existing implementation in TT-Transformers. 
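Editor's note (illustration, not part of the diff): the block order this docstring describes matches the HF Gemma3 reference decoder; a minimal sketch, with the function and argument names below being placeholders rather than this module's API:

def gemma3_decoder_block_reference(x, attn, mlp, input_norm, post_attn_norm, pre_ff_norm, post_ff_norm):
    # Attention sub-block: norm -> attention -> norm -> residual add
    residual = x
    x = post_attn_norm(attn(input_norm(x)))
    x = residual + x
    # Feed-forward sub-block: the extra pre/post feed-forward norms wrap the MLP
    residual = x
    x = post_ff_norm(mlp(pre_ff_norm(x)))
    return residual + x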
""" # SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC # SPDX-License-Identifier: Apache-2.0 - import ttnn from models.common.lightweightmodule import LightweightModule from models.tt_transformers.tt.distributed_norm import DistributedNorm -from models.experimental.gemma3_4b.tt.rmsnorm import RMSNorm +from models.experimental.gemma3.tt.rmsnorm import RMSNorm -from models.experimental.gemma3_4b.tt.attention import Attention +from models.experimental.gemma3.tt.attention import Attention as DefaultAttention -from models.experimental.gemma3_4b.tt.mlp import MLP +from models.experimental.gemma3.tt.mlp import MLP from models.tt_transformers.tt.model_config import TensorGroup +from models.tt_transformers.tt.ccl import tt_all_reduce class TransformerBlock(LightweightModule): @@ -35,7 +35,6 @@ def __init__( layer_num, weight_cache_path, transformation_mats, - transformation_mats_local=None, paged_attention_config=None, use_paged_kv_cache=False, attention_class=None, @@ -44,8 +43,8 @@ def __init__( self.state_dict = state_dict self.mesh_device = mesh_device - self.tt_ccl = tt_ccl + self.tt_ccl = tt_ccl self.args = args self.hidden_size = args.dim self.n_heads = args.n_heads @@ -58,26 +57,25 @@ def __init__( self.model_config = args.get_model_config() self.layer_num = layer_num + self.num_devices = args.num_devices - self.is_attention_sliding = ( - self.args.layer_types[layer_num] == "sliding_attention" if self.args.layer_types else False - ) + ActualAttentionClass = attention_class if attention_class is not None else DefaultAttention - self.attention = Attention( + self.attention = ActualAttentionClass( mesh_device=mesh_device, - tt_ccl=self.tt_ccl, + tt_ccl=tt_ccl, state_dict=state_dict, weight_cache_path=weight_cache_path, layer_num=layer_num, dtype=dtype, - transformation_mats=transformation_mats_local if self.is_attention_sliding else transformation_mats, + transformation_mats=transformation_mats, configuration=args, paged_attention_config=paged_attention_config, use_paged_kv_cache=use_paged_kv_cache, ) self.feed_forward = MLP( mesh_device=mesh_device, - tt_ccl=self.tt_ccl, + tt_ccl=tt_ccl, args=args, state_dict=state_dict, weight_cache_path=weight_cache_path, @@ -86,7 +84,7 @@ def __init__( model_config=self.model_config, ) - self.attention_norm = DistributedNorm( # input_layernorm + self.attention_norm = DistributedNorm( RMSNorm( device=mesh_device, dim=args.dim, @@ -188,19 +186,19 @@ def forward( kv_cache=None, ) -> ttnn.Tensor: TG = self.args.is_galaxy - # x is fractured across devices and interleaved in DRAM (for prefill) and sharded in L1 (for decode) skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG assert ( hidden_states.memory_config() == skip_mem_cfg ), f"decoder input memcfg mismatch: {hidden_states.memory_config()} != {skip_mem_cfg}" - - # Choose the correct rotation matrices based on the mode - rot_mats = rot_mats_local if self.is_attention_sliding else rot_mats_global residual = hidden_states attn_in = self.attention_norm(hidden_states, mode) + rot_mats = ( + rot_mats_local if (hasattr(self.attention, "is_sliding") and self.attention.is_sliding) else rot_mats_global + ) + attn_out = self.attention.forward( attn_in, current_pos, @@ -218,6 +216,22 @@ def forward( ttnn.deallocate(attn_out) ttnn.deallocate(attn_in) + if self.num_devices > 1: + hidden_states = tt_all_reduce( + hidden_states, + self.mesh_device, + self.tt_ccl, + cluster_axis=0, + dim=3, + num_reduce_scatter_links=self.args.num_reduce_scatter_links, + 
num_all_gather_links=self.args.num_all_gather_links, + topology=ttnn.Topology.Ring, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=self.args.ccl_dtype, + ) + + hidden_states = ttnn.div(hidden_states, self.num_devices) + hidden_states = ttnn.add(hidden_states, residual, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16) residual = hidden_states @@ -235,6 +249,22 @@ def forward( hidden_states = self.post_ff_norm(hidden_states, mode) + if self.num_devices > 1: + hidden_states = tt_all_reduce( + hidden_states, + self.mesh_device, + self.tt_ccl, + cluster_axis=0, + dim=3, + num_reduce_scatter_links=self.args.num_reduce_scatter_links, + num_all_gather_links=self.args.num_all_gather_links, + topology=ttnn.Topology.Ring, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=self.args.ccl_dtype, + ) + + hidden_states = ttnn.div(hidden_states, self.num_devices) + hidden_states = ttnn.add( hidden_states, residual, diff --git a/models/experimental/gemma3/tt/gemma3_generator.py b/models/experimental/gemma3/tt/gemma3_generator.py new file mode 100644 index 000000000000..a8ea740d4685 --- /dev/null +++ b/models/experimental/gemma3/tt/gemma3_generator.py @@ -0,0 +1,1302 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass + +import torch +from llama_models.llama3.api.datatypes import InterleavedTextMedia, StopReason +from llama_models.llama3.reference_impl.generation import ( + ChatPrediction, + CompletionPrediction, + TokenResult, + sample_top_p, +) +from loguru import logger + +import ttnn +from models.tt_transformers.tt.common import ( + copy_host_to_device, + get_block_size, + get_max_prefill_chunk_size, + get_padded_prefill_len, + num_blocks_in_seq, +) + + +@dataclass(frozen=True) +class SamplingParams: + """ + Used in Generator decode forward functions for greedy decoding / sampling on device. + The same data class exists in vLLM at vllm/worker/tt_model_runner.py. + """ + + temperature: float + top_k: int + top_p: float + + +class Gemma3Generator: + def __init__(self, model, model_args, mesh_device, tokenizer=None, formatter=None): + """ + Creating a LlamaVision wrapper requires only a mesh_device and model_args. + With model_args you have the checkpoint location, can specify max batch size + and max seqlen, and other model specific parameters. + + LlamaVision is general to text and chat. + + For bringup, make this class general to any backend implementation, as long as it takes torch tensors and returns torch tensors. 
+ + """ + self.model = model + self.model_args = model_args + self.mesh_device = mesh_device + self.tokenizer = tokenizer + self.formatter = formatter + self.data_parallel = len(self.model) + self.prev_page_table = None + + # Note: This function is called by vLLM + def prefill_forward_text( + self, tokens: torch.Tensor, page_table=None, kv_cache=None, prompt_lens=None, empty_slots=None, **kwargs + ): + if page_table is not None: + assert isinstance(page_table, torch.Tensor), "page_table mush be torch.Tensor" + + batch_size, batch_seq_len = tokens.shape + max_batch_size_per_model = self.model_args[0].max_batch_size + + # Each model expected to run the same model, safe to use 1st vocab size + output_logits = torch.zeros(batch_size, 1, self.model_args[0].vocab_size) + prompt_lens = prompt_lens if prompt_lens is not None else torch.tensor([batch_seq_len] * batch_size) + + if empty_slots is None: + empty_slots = list(range(batch_size)) + + out_list = [] + for idx, user_id in enumerate(empty_slots): + model_id = user_id // max_batch_size_per_model + group_user_id = user_id % max_batch_size_per_model if page_table is None else 0 + seq_len = int(prompt_lens[idx]) + last_token_idx = seq_len - 1 + prefill_seq_len = get_padded_prefill_len(seq_len) + + logger.info(f"Prefilling User {user_id + 1} up to {seq_len} tokens") + + # Extracting data for the current user + # If page_table is not provided, we keep track of the relative/model user_id through group_user_id + prefill_ids = torch.cat( + [tokens[idx : idx + 1, :seq_len], torch.zeros(1, prefill_seq_len - seq_len).long()], dim=-1 + ) + page_table_user = ( + self._get_prefill_user_page_table(page_table[idx : idx + 1], kv_cache[model_id], seq_len) + if page_table is not None + else None + ) + model_kv_cache = kv_cache[model_id] if kv_cache is not None else None + + logits = self.prefill_forward_single_user_text( + prefill_ids, + page_table=page_table_user, + user_id=group_user_id, + last_token_idx=last_token_idx, + kv_cache=model_kv_cache, + model_id=model_id, + **kwargs, + ) + out_list.append(logits) + + for idx, out in enumerate(out_list): + seq_len = int(prompt_lens[idx]) + last_token_idx = seq_len - 1 + user_id = empty_slots[idx] + model_id = user_id // max_batch_size_per_model + + # Since we give unpadded_seq_len, only the tile containing the last token is returned + output_logits[idx] = self.model[model_id].process_output_prefill(out, last_token_idx=(last_token_idx % 32)) + + logger.info(f"Finished prefill for all users up to {batch_seq_len} tokens, Starting decode...") + return output_logits + + def prefill_forward_single_user_text( + self, tokens, page_table, user_id, last_token_idx, kv_cache=None, model_id=-1, **kwargs + ): + seq_len = tokens.shape[-1] + use_chunked_prefill = seq_len > self.model_args[model_id].max_prefill_chunk_size + if use_chunked_prefill: + """ + Chunked prefill requires paged attention. There are some strange constraints which we must meet: + - page_table, which is used in SDPA, must match batch size of inputs, which is 1. This is because SDPA + checks that page table batch dim matches input batch dim. Therefore we must slice the page table for the current user. + - page_table must also have enough entries in each chunk, so it will be padded with zeros if necessary. + - chunked_page_table is the slice of the page table for the current chunk. This is used by paged_fill_cache + to keep it otherwise unaware that it is operating on a chunk. + - due to the above point, we must always set user_id to 0 for chunked prefill. 
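Editor's note (worked example with assumed numbers, not part of the diff): the chunk bookkeeping described above reduces to simple integer arithmetic.

# Assumed values purely for illustration
seq_len, chunk_size, last_token_idx = 8192, 2048, 5000
chunk_starts = list(range(0, seq_len, chunk_size))              # [0, 2048, 4096, 6144]
last_chunk_start = (last_token_idx // chunk_size) * chunk_size  # 4096
last_token_idx_in_chunk = last_token_idx % chunk_size           # 904
# Only the chunk starting at 4096 has to return logits; earlier chunks just fill the KV cache.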
+ """ + assert page_table is not None, "page_table must be provided for chunked prefill" + assert kv_cache is not None, "kv_cache must be provided for chunked prefill" + assert ( + last_token_idx is not None and last_token_idx < seq_len + ), "last_token_idx must be provided and less than seq_len" + chunk_size = get_max_prefill_chunk_size(seq_len, self.model_args[model_id].max_prefill_chunk_size) + block_size = get_block_size(kv_cache) + last_token_idx_in_chunk = last_token_idx % chunk_size + # Calculate which chunk contains the last_token_idx + last_chunk_start = (last_token_idx // chunk_size) * chunk_size + page_table_user = page_table[user_id : user_id + 1, :] + # Pad page table to match number of blocks in seq_len + num_padding_blocks = num_blocks_in_seq(seq_len, block_size) - page_table_user.shape[1] + page_table_user_padded = torch.cat( + [page_table_user, torch.zeros(1, num_padding_blocks, dtype=torch.int32)], dim=-1 + ) + CHUNK_USER_ID = 0 + + for chunk_start in range(0, seq_len, chunk_size): + chunk_end = chunk_start + chunk_size + assert ( + chunk_end <= seq_len + ), f"Chunk end should be less than seq_len, got chunk_end={chunk_end} and seq_len={seq_len}" + chunk_tokens = tokens[:, chunk_start:chunk_end] + chunk_page_table = page_table_user[:, chunk_start // block_size : chunk_end // block_size] + + ( + chunk_prefill_input, + chunk_rot_mats_global_prefill, + chunk_rot_mats_local_prefill, + page_table_tt, + chunk_page_table_tt, + ) = self.model[model_id].prepare_inputs_prefill( + chunk_tokens, + start_pos=chunk_start, + page_table=page_table_user_padded, + chunk_page_table=chunk_page_table, + **kwargs, + ) + tt_logits = self.model[model_id].ttnn_prefill_forward( + chunk_prefill_input, + rot_mats_global=chunk_rot_mats_global_prefill, + rot_mats_local=chunk_rot_mats_local_prefill, + user_id=CHUNK_USER_ID, + page_table=page_table_tt, + chunk_page_table=chunk_page_table_tt, + chunk_start_idx=chunk_start, + get_last_token=(last_token_idx_in_chunk // 32) * 32, + kv_cache=kv_cache, + **kwargs, + ) + + if chunk_start == last_chunk_start: + return tt_logits + else: + del tt_logits + else: + ( + prefill_input, + rot_mats_global_prefill, + rot_mats_local_prefill, + page_table_tt, + _, + ) = self.model[model_id].prepare_inputs_prefill( + tokens, + page_table=page_table, + **kwargs, + ) + + tt_logits = self.model[model_id].ttnn_prefill_forward( + prefill_input, + rot_mats_global=rot_mats_global_prefill, + rot_mats_local=rot_mats_local_prefill, + user_id=user_id, + page_table=page_table_tt, + get_last_token=(last_token_idx // 32) * 32, + kv_cache=kv_cache, + ) + return tt_logits + + # Note: This function is called by vLLM + def decode_forward_text( + self, + tokens, + start_pos, + page_table=None, + kv_cache=None, + enable_trace=True, + read_from_device=True, + sampling_params: SamplingParams = None, # Should be None if not greedy decoding / sampling on device. 
+ ): + assert ( + sampling_params is None or sampling_params.temperature == 0 + ), "Currently only supporting greedy decoding (temperature=0) on device" + argmax_on_device = sampling_params is not None and sampling_params.temperature == 0 + + B = tokens.shape[0] + tokens = torch.chunk(tokens, self.data_parallel, 0) + start_pos = torch.chunk(start_pos, self.data_parallel, 0) + page_table = torch.chunk(page_table, self.data_parallel, 0) if page_table is not None else None + + decode_kwargs = { + "current_pos": start_pos, + "tokens": tokens, + "page_table": page_table, + "kv_cache": kv_cache, + "argmax_on_device": argmax_on_device, + } + if enable_trace: + tt_decode_output = self._easy_trace_text(**decode_kwargs) + else: + tt_decode_output = self._decode_forward_no_trace_text(**decode_kwargs) + + if read_from_device: + to_host = self.read_decode_output(tt_decode_output) + return self.process_decode_output_host(to_host, is_tokens=(sampling_params is not None)) + + return tt_decode_output + + def _decode_forward_no_trace_text( + self, + tokens, + current_pos, + page_table=None, + kv_cache=None, + argmax_on_device=False, + ): + """ + Performs text decode step. + Returns tt_logits on device + """ + tt_logits = [] + + tt_tokens = [] + tt_current_pos = [] + tt_rot_mat_idxs_global = [] + tt_rot_mat_idxs_local = [] + tt_page_table = [] + for i in range(self.data_parallel): + user_page_table = page_table[i] if page_table is not None else None + model_i = self.model[i] + ( + tt_tokens_i, + tt_current_pos_i, + tt_rot_mat_idxs_global_i, + tt_rot_mat_idxs_local_i, + tt_page_table_i, + ) = model_i.prepare_inputs_decode(tokens[i], current_pos[i], user_page_table) + tt_tokens.append(tt_tokens_i) + tt_current_pos.append(tt_current_pos_i) + tt_rot_mat_idxs_global.append(tt_rot_mat_idxs_global_i) + tt_rot_mat_idxs_local.append(tt_rot_mat_idxs_local_i) + tt_page_table.append(tt_page_table_i) + + if ( + hasattr(self.model[i], "device_decode_sliding_mask") + and self.model[i].device_decode_sliding_mask is not None + ): + self.model[i].update_attention_masks(current_pos[i]) + + for i in range(self.data_parallel): + user_kv_cache = kv_cache[i] if kv_cache is not None else None + tt_logits_i = self.model[i].ttnn_decode_forward( + tt_tokens[i], + tt_current_pos[i], + rot_mat_idxs_global=tt_rot_mat_idxs_global[i], + rot_mat_idxs_local=tt_rot_mat_idxs_local[i], + page_table=tt_page_table[i], + kv_cache=user_kv_cache, + argmax_on_device=argmax_on_device, + ) + tt_logits.append(tt_logits_i) + + return tt_logits + + def _capture_trace_text( + self, + tokens, + current_pos, + page_table=None, + kv_cache=None, + argmax_on_device=False, + ): + """ + Captures a trace for the decode_forward method. 
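+
+ The capture runs in three steps: a compile pass via _decode_forward_no_trace_text,
+ host-side inputs copied to persistent device tensors with copy_host_to_device, and
+ a ttnn.begin/end_trace_capture of ttnn_decode_forward per data-parallel submesh.
+ Returns the per-submesh trace ids, the traced output tensors, and the persistent
+ device input tensors that later replays overwrite.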
+ """ + + # Compile run + self._decode_forward_no_trace_text( + tokens, current_pos, page_table=page_table, kv_cache=kv_cache, argmax_on_device=argmax_on_device + ) + logger.info("Done Compiling Model") + + # Get inputs ready for trace run + device_inputs = [] + tt_out_trace = [] + trace_ids = {} + for i in range(self.data_parallel): + user_page_table = page_table[i] if page_table is not None else None + host_inputs = self.model[i].prepare_decode_inputs_host( + tokens[i], current_pos[i], page_table=user_page_table + ) + + device_inputs_i = copy_host_to_device(host_inputs, mesh_device=self.model_args[i].mesh_device) + device_inputs.append(device_inputs_i) + + for i in range(self.data_parallel): + trace_id = ttnn.begin_trace_capture(self.model_args[i].mesh_device, cq_id=0) + trace_ids[i] = trace_id + user_kv_cache = kv_cache[i] if kv_cache is not None else None + tt_out_trace.append( + self.model[i].ttnn_decode_forward( + *device_inputs[i], kv_cache=user_kv_cache, argmax_on_device=argmax_on_device + ) + ) + ttnn.end_trace_capture(self.model_args[i].mesh_device, trace_id, cq_id=0) + logger.info("Done Capturing Decode Trace") + return trace_ids, tt_out_trace, *device_inputs + + def _easy_trace_text( + self, + tokens, + current_pos, + page_table=None, + kv_cache=None, + argmax_on_device=False, + ): + """ + Tracing is easy! Just call this method and we'll handle tracing for you. + """ + if not hasattr(self, "trace_ids_text"): + trace_ids, tt_out_trace, *device_inputs = self._capture_trace_text( + tokens, current_pos, page_table=page_table, kv_cache=kv_cache, argmax_on_device=argmax_on_device + ) + self.trace_ids_text = trace_ids + self.trace_inputs_text = device_inputs + self.trace_output_text = tt_out_trace + + reset_inputs = not argmax_on_device + if self.prev_page_table is None or any( + not torch.equal(prev, curr) for prev, curr in zip(self.prev_page_table, page_table) + ): + reset_inputs = True + self.prev_page_table = page_table + + if reset_inputs: + for i in range(self.data_parallel): + user_page_table = page_table[i] if page_table is not None else None + host_inputs_i = self.model[i].prepare_decode_inputs_host(tokens[i], current_pos[i], user_page_table) + + copy_host_to_device( + host_tensors=host_inputs_i, + device_tensors=self.trace_inputs_text[i], + ) + for i in range(self.data_parallel): + if ( + hasattr(self.model[i], "device_decode_sliding_mask") + and self.model[i].device_decode_sliding_mask is not None + ): + self.model[i].update_attention_masks(current_pos[i]) + + for i, trace_id in self.trace_ids_text.items(): + ttnn.execute_trace(self.model_args[i].mesh_device, trace_id, cq_id=0, blocking=False) + + return self.trace_output_text + + def _prefill_forward_single_user( + self, + vision_images, + vision_mask, + tokens, + xattn_caches, + user_id, + total_len, + prefill_len, + page_table=None, + kv_cache=None, + cross_page_table=None, + model_id=-1, + ): + """ + Performs vision encode step then text prefill. 
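+ If vision_images is None the call is text-only and the returned cross-attention
+ and row-mask tensors are all None; otherwise compute_vision_tokens_masks produces
+ both the prefill-time and decode-time masks that later decode steps reuse.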
+ Returns (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits) + """ + B = tokens.shape[0] + last_token_idx = prefill_len - 1 + + text_only_inference = vision_images is None + if not text_only_inference: + ( + vision_tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + ) = self.model[model_id].compute_vision_tokens_masks( + batch_images=[vision_images], + batch_masks=[vision_mask], + total_len=total_len, + prefill_len=prefill_len, + ) + + if cross_page_table is not None: + num_vision_tokens = vision_tokens.shape[2] + cross_page_table = self._get_prefill_user_page_table(cross_page_table, kv_cache, num_vision_tokens) + else: + ( + vision_tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + ) = (None, None, None, None, None) + + if page_table is not None: + page_table = self._get_prefill_user_page_table(page_table, kv_cache, prefill_len) + + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_full_text_mask_expand_11SD, + rot_mats, + tt_page_table, + tt_cross_page_table, + ) = self.model[model_id].prepare_inputs_prefill( + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + prefill_len=prefill_len, + page_table=page_table, + cross_page_table=cross_page_table, + text_only_inference=text_only_inference, + ) + + tt_logits = self.model[model_id].ttnn_prefill_forward( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_full_text_mask_expand_11SD, + xattn_caches, + rot_mats, + user_id, + vision_tokens, + page_table=tt_page_table, + kv_cache=kv_cache, + get_last_token=(last_token_idx // 32) * 32, + cross_page_table=tt_cross_page_table, + text_only_inference=text_only_inference, + ) + + del tt_page_table + del tt_cross_page_table + + return ( + xattn_caches, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + tt_logits, + ) + + # Note: This function is called by vLLM + def prefill_forward( + self, + vision_images, + vision_masks, + tokens: torch.Tensor, + xattn_caches, + total_lens, + prompt_lens, + page_table=None, + kv_cache=None, + cross_page_table=None, + empty_slots=None, + ): + """ + Batched version of _prefill_forward_single_user for vision model. 
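+
+ Users are routed to data-parallel model instances by slot index:
+ model_id = user_id // max_batch_size_per_model, and when no page table is given
+ group_user_id = user_id % max_batch_size_per_model (with a page table it is 0).
+ For example, with max_batch_size_per_model=32, user 40 runs on model 1 as local user 8.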
+ """ + if page_table is not None: + assert isinstance(page_table, torch.Tensor), "page_table mush be torch.Tensor" + if cross_page_table is not None: + assert isinstance(cross_page_table, torch.Tensor), "cross_page_table mush be torch.Tensor" + + batch_size, batch_seq_len = tokens.shape + max_batch_size_per_model = self.model_args[0].max_batch_size + + output_logits = torch.zeros(batch_size, 1, self.model_args[0].vocab_size) + + out_list = [] + prefill_output_xattn_masks = [] + prefill_output_full_text_row_masked_out_masks = [] + decode_output_xattn_masks = [] + decode_output_full_text_row_masked_out_masks = [] + + if empty_slots is None: + empty_slots = list(range(batch_size)) + + for idx, user_id in enumerate(empty_slots): + model_id = user_id // max_batch_size_per_model + group_user_id = user_id % max_batch_size_per_model if page_table is None else 0 + seq_len = int(prompt_lens[idx]) + + logger.info(f"Prefilling User {user_id + 1} up to {seq_len} tokens") + + user_page_table = page_table[idx : idx + 1] if page_table is not None else None + user_cross_page_table = cross_page_table[idx : idx + 1] if kv_cache is not None else None + model_kv_cache = kv_cache[model_id] if kv_cache is not None else None + model_xattn_cache = xattn_caches[model_id] if xattn_caches is not None else None + + ( + model_xattn_cache, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + logits, + ) = self._prefill_forward_single_user( + vision_images=vision_images[idx], + vision_mask=vision_masks[idx], + tokens=tokens[idx : idx + 1, :seq_len], # Keep batch dimension + xattn_caches=model_xattn_cache, + user_id=group_user_id, + total_len=total_lens[idx], + prefill_len=seq_len, + page_table=user_page_table, + kv_cache=model_kv_cache, + cross_page_table=user_cross_page_table, + model_id=model_id, + ) + + if xattn_caches is not None: + xattn_caches[model_id] = model_xattn_cache + + out_list.append(logits) + prefill_output_xattn_masks.append(prefill_cross_attention_masks) + prefill_output_full_text_row_masked_out_masks.append(prefill_full_text_row_masked_out_mask) + decode_output_xattn_masks.append(decode_cross_attention_masks) + decode_output_full_text_row_masked_out_masks.append(decode_full_text_row_masked_out_mask) + + # We gather prefill output at the end of prefill to reduce unnecessary device sync + for idx, user_id in enumerate(empty_slots): + model_id = user_id // max_batch_size_per_model + + last_token_idx = prompt_lens[idx] - 1 + output_logits[idx] = self.model[model_id].process_output_prefill( + out_list[idx], 1, last_token_idx=(last_token_idx % 32) + ) + + logger.info(f"Finished prefill for all users up to {batch_seq_len} tokens, Starting decode...") + + return ( + output_logits, + prefill_output_xattn_masks, + prefill_output_full_text_row_masked_out_masks, + decode_output_xattn_masks, + decode_output_full_text_row_masked_out_masks, + ) + + # Note: This function is called by vLLM + def decode_forward( + self, + start_pos, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + xattn_caches=None, + page_table=None, + kv_cache=None, + cross_page_table=None, + enable_trace=True, + read_from_device=True, + ): + B = tokens.shape[0] + data_parallel = min(B, self.data_parallel) + batch_per_device = B // data_parallel + tokens = torch.chunk(tokens, self.data_parallel, 0) + start_pos = torch.chunk(start_pos, 
self.data_parallel, 0) + prefill_cross_attention_masks = [ + prefill_cross_attention_masks[i * batch_per_device : (i + 1) * batch_per_device] + for i in range(data_parallel) + ] + prefill_full_text_row_masked_out_mask = [ + prefill_full_text_row_masked_out_mask[i * batch_per_device : (i + 1) * batch_per_device] + for i in range(data_parallel) + ] + decode_cross_attention_masks = [ + decode_cross_attention_masks[i * batch_per_device : (i + 1) * batch_per_device] + for i in range(data_parallel) + ] + decode_full_text_row_masked_out_mask = [ + decode_full_text_row_masked_out_mask[i * batch_per_device : (i + 1) * batch_per_device] + for i in range(data_parallel) + ] + page_table = torch.chunk(page_table, self.data_parallel, 0) if page_table is not None else None + cross_page_table = ( + torch.chunk(cross_page_table, self.data_parallel, 0) if cross_page_table is not None else None + ) + + decode_kwargs = { + "position_id": start_pos, + "tokens": tokens, + "prefill_cross_attention_masks": prefill_cross_attention_masks, + "prefill_full_text_row_masked_out_mask": prefill_full_text_row_masked_out_mask, + "decode_cross_attention_masks": decode_cross_attention_masks, + "decode_full_text_row_masked_out_mask": decode_full_text_row_masked_out_mask, + "xattn_caches": xattn_caches, + "page_table": page_table, + "kv_cache": kv_cache, + "cross_page_table": cross_page_table, + } + if enable_trace: + tt_logits = self._easy_trace(**decode_kwargs) + else: + tt_logits = self._decode_forward_no_trace(**decode_kwargs) + + if read_from_device: + to_host = self.read_decode_output(tt_logits) + return self.process_decode_output_host(to_host) + else: + return tt_logits + + # Note: This function is called by vLLM + def read_decode_output(self, tt_out, async_read=False): + """ + Input tt_out is a list of ttnn device tensors + """ + if not async_read: + return [out.cpu() for out in tt_out] + + host_outputs = [] + read_events = [] + for i in range(self.data_parallel): + host_outputs.append(tt_out[i].cpu(blocking=False)) + read_events.append(ttnn.record_event(self.model[i].mesh_device, 0)) + + return host_outputs, read_events + + # Note: This function is called by vLLM + def process_decode_output_host(self, tt_out, is_tokens=False): + """ + Converts the input ttnn host tensors to a torch tensor. + The input can be logits (if is_tokens=False) or tokens (if is_tokens=True). + """ + max_batch_size_per_model = self.model_args[0].max_batch_size + + logits = [] + for i in range(self.data_parallel): + logits_i = self.model[i].process_output_decode( + tt_out[i], max_batch_size_per_model, S=1, is_tokens=is_tokens + ) + logits.append(logits_i) + + return torch.cat(logits, 0) + + def _decode_forward_no_trace( + self, + position_id, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + xattn_caches=None, + page_table=None, + kv_cache=None, + cross_page_table=None, + ): + """ + Performs text decode step. 
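+ Inputs are split per data-parallel submesh, prepared with prepare_inputs_decode,
+ and each submesh runs ttnn_decode_forward independently.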
+ Returns tt_logits on device + """ + + # forward_decode should be traced callable + # decorator does compilation, capture, execute + tt_h = [] + tt_xattn_mask = [] + tt_full_text_mask_expand_1NSH = [] + tt_full_text_mask_expand_11SD = [] + tt_position_id = [] + tt_rot_mats = [] + tt_page_table = [] + tt_cross_page_table = [] + + for i in range(self.data_parallel): + B, S = tokens[i].shape + assert S == 1 + + user_page_table = page_table[i] if page_table is not None else None + user_cross_page_table = cross_page_table[i] if cross_page_table is not None else None + ( + tt_h_i, + tt_xattn_mask_i, + tt_full_text_mask_expand_1NSH_i, + tt_full_text_mask_expand_11SD_i, + tt_position_id_i, + tt_rot_mats_i, + tt_page_table_i, + tt_cross_page_table_i, + ) = self.model[i].prepare_inputs_decode( + tokens[i], + prefill_cross_attention_masks[i], + prefill_full_text_row_masked_out_mask[i], + decode_cross_attention_masks[i], + decode_full_text_row_masked_out_mask[i], + position_id=position_id[i], + page_table=user_page_table, + cross_page_table=user_cross_page_table, + ) + + tt_h.append(tt_h_i) + tt_xattn_mask.append(tt_xattn_mask_i) + tt_full_text_mask_expand_1NSH.append(tt_full_text_mask_expand_1NSH_i) + tt_full_text_mask_expand_11SD.append(tt_full_text_mask_expand_11SD_i) + tt_position_id.append(tt_position_id_i) + tt_rot_mats.append(tt_rot_mats_i) + tt_page_table.append(tt_page_table_i) + tt_cross_page_table.append(tt_cross_page_table_i) + + tt_logits = [] + for i in range(self.data_parallel): + user_kv_cache = kv_cache[i] if kv_cache is not None else None + xattn_cache = xattn_caches[i] if xattn_caches is not None else None + tt_logits_i = self.model[i].ttnn_decode_forward( + tt_h[i], + tt_xattn_mask[i], + tt_full_text_mask_expand_1NSH[i], + tt_full_text_mask_expand_11SD[i], + xattn_cache, + tt_position_id[i], + tt_rot_mats[i], + page_table=tt_page_table[i], + kv_cache=user_kv_cache, + cross_page_table=tt_cross_page_table[i], + ) + tt_logits.append(tt_logits_i) + + return tt_logits + + def _capture_trace( + self, + position_id, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + xattn_caches, + page_table=None, + kv_cache=None, + cross_page_table=None, + ): + """ + Captures a trace for the decode_forward method. 
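+
+ Unlike the text-only path, inputs are prepared on host and the
+ transform_decode_inputs_device step is captured inside the trace, so replays only
+ need fresh host tensors copied into the persistent device inputs returned here.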
+ """ + tt_h = [] + tt_xattn_mask = [] + tt_full_text_mask_expand_1NSH = [] + tt_full_text_mask_expand_11SD = [] + tt_position_id = [] + tt_rot_mats = [] + tt_page_table = [] + tt_cross_page_table = [] + for i in range(self.data_parallel): + user_page_table = page_table[i] if page_table is not None else None + user_cross_page_table = cross_page_table[i] if cross_page_table is not None else None + ( + tt_h_i, + tt_xattn_mask_i, + tt_full_text_mask_expand_1NSH_i, + tt_full_text_mask_expand_11SD_i, + tt_position_id_i, + tt_rot_mats_i, + tt_page_table_i, + tt_cross_page_table_i, + ) = self.model[i].prepare_inputs_decode( + tokens[i], + prefill_cross_attention_masks[i], + prefill_full_text_row_masked_out_mask[i], + decode_cross_attention_masks[i], + decode_full_text_row_masked_out_mask[i], + position_id=position_id[i], + page_table=user_page_table, + cross_page_table=user_cross_page_table, + ) + + tt_h.append(tt_h_i) + tt_xattn_mask.append(tt_xattn_mask_i) + tt_full_text_mask_expand_1NSH.append(tt_full_text_mask_expand_1NSH_i) + tt_full_text_mask_expand_11SD.append(tt_full_text_mask_expand_11SD_i) + tt_position_id.append(tt_position_id_i) + tt_rot_mats.append(tt_rot_mats_i) + tt_page_table.append(tt_page_table_i) + tt_cross_page_table.append(tt_cross_page_table_i) + + # Compile run + for i in range(self.data_parallel): + user_kv_cache = kv_cache[i] if kv_cache is not None else None + xattn_cache = xattn_caches[i] if xattn_caches is not None else None + # tt_logits_rm unused later, no need to make a list + tt_logits_rm = self.model[i].ttnn_decode_forward( + tt_h[i], + tt_xattn_mask[i], + tt_full_text_mask_expand_1NSH[i], + tt_full_text_mask_expand_11SD[i], + xattn_cache, + tt_position_id[i], + tt_rot_mats[i], + page_table=tt_page_table[i], + kv_cache=user_kv_cache, + cross_page_table=tt_cross_page_table[i], + ) + logger.info("Done Compiling Model") + + # Get inputs ready for trace run + tt_h = [] + tt_xattn_mask = [] + tt_full_text_mask_expand_1NSH = [] + tt_full_text_mask_expand_11SD = [] + tt_position_id = [] + tt_rope_id = [] + tt_page_table = [] + tt_cross_page_table = [] + for i in range(self.data_parallel): + user_page_table = page_table[i] if page_table is not None else None + user_cross_page_table = cross_page_table[i] if cross_page_table is not None else None + ( + tt_h_i, + tt_xattn_mask_i, + tt_full_text_mask_expand_1NSH_i, + tt_full_text_mask_expand_11SD_i, + tt_position_id_i, + tt_rope_id_i, + tt_page_table_i, + tt_cross_page_table_i, + ) = self.model[i].prepare_decode_inputs_host( + tokens[i], + prefill_cross_attention_masks[i], + prefill_full_text_row_masked_out_mask[i], + decode_cross_attention_masks[i], + decode_full_text_row_masked_out_mask[i], + position_id[i], + page_table=user_page_table, + cross_page_table=user_cross_page_table, + ) + + ( + tt_h_i, + tt_xattn_mask_i, + tt_full_text_mask_expand_1NSH_i, + tt_full_text_mask_expand_11SD_i, + tt_position_id_i, + tt_rope_id_i, + tt_page_table_i, + tt_cross_page_table_i, + ) = copy_host_to_device( + ( + tt_h_i, + tt_xattn_mask_i, + tt_full_text_mask_expand_1NSH_i, + tt_full_text_mask_expand_11SD_i, + tt_position_id_i, + tt_rope_id_i, + tt_page_table_i, + tt_cross_page_table_i, + ), + mesh_device=self.model_args[i].mesh_device, + ) + + tt_h.append(tt_h_i) + tt_xattn_mask.append(tt_xattn_mask_i) + tt_full_text_mask_expand_1NSH.append(tt_full_text_mask_expand_1NSH_i) + tt_full_text_mask_expand_11SD.append(tt_full_text_mask_expand_11SD_i) + tt_position_id.append(tt_position_id_i) + tt_rope_id.append(tt_rope_id_i) + 
tt_page_table.append(tt_page_table_i) + tt_cross_page_table.append(tt_cross_page_table_i) + + tt_h_trace_input = tt_h + + tt_logits_rm = [] + trace_ids = {} + # Do on-device transformations of inputs before forward + for i in range(self.data_parallel): + trace_id = ttnn.begin_trace_capture(self.model_args[i].mesh_device, cq_id=0) + trace_ids[i] = trace_id + B = tokens[i].shape[0] + user_kv_cache = kv_cache[i] if kv_cache is not None else None + xattn_cache = xattn_caches[i] if xattn_caches is not None else None + ( + tt_h_transform, + tt_rot_mats, + tt_xattn_mask_transform, + tt_full_text_mask_expand_1NSH_transform, + tt_full_text_mask_expand_11SD_transform, + ) = self.model[i].transform_decode_inputs_device( + tt_h[i], + tt_rope_id[i], + tt_xattn_mask[i], + tt_full_text_mask_expand_1NSH[i], + tt_full_text_mask_expand_11SD[i], + B=B, + ) + + tt_logits_rm_i = self.model[i].ttnn_decode_forward( + tt_h_transform, + tt_xattn_mask_transform, + tt_full_text_mask_expand_1NSH_transform, + tt_full_text_mask_expand_11SD_transform, + xattn_cache, + tt_position_id[i], + tt_rot_mats, + page_table=tt_page_table[i], + kv_cache=user_kv_cache, + cross_page_table=tt_cross_page_table[i], + ) + tt_logits_rm.append(tt_logits_rm_i) + ttnn.end_trace_capture(self.model_args[i].mesh_device, trace_id, cq_id=0) + logger.info("Done Capturing Decode Trace") + + return ( + trace_ids, + tt_logits_rm, + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_full_text_mask_expand_11SD, + tt_position_id, + tt_rope_id, + tt_page_table, + tt_cross_page_table, + ) + + def _decode_forward_trace( + self, + position_id, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + page_table, + cross_page_table, + trace_ids, + trace_logits_rm, + trace_h, + trace_xattn_mask, + trace_full_text_mask_expand_1NSH, + trace_full_text_mask_expand_11SD, + trace_position_id, + trace_rope_id, + trace_page_table, + trace_cross_page_table, + ): + """ + Executes the trace for the decode_forward method but does not read back outputs. 
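+
+ Host inputs for the new step are written into the persistent trace input tensors
+ with copy_host_to_device, then each submesh trace is replayed with
+ ttnn.execute_trace(..., blocking=False); the caller reads the returned trace
+ output tensors back separately.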
+ """ + for i in range(self.data_parallel): + user_page_table = page_table[i] if page_table is not None else None + user_cross_page_table = cross_page_table[i] if cross_page_table is not None else None + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_full_text_mask_expand_11SD, + tt_position_id, + tt_rope_id, + tt_page_table, + tt_cross_page_table, + ) = self.model[i].prepare_decode_inputs_host( + tokens[i], + prefill_cross_attention_masks[i], + prefill_full_text_row_masked_out_mask[i], + decode_cross_attention_masks[i], + decode_full_text_row_masked_out_mask[i], + position_id=position_id[i], + page_table=user_page_table, + cross_page_table=user_cross_page_table, + ) + + copy_host_to_device( + host_tensors=( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_full_text_mask_expand_11SD, + tt_position_id, + tt_rope_id, + tt_page_table, + tt_cross_page_table, + ), + device_tensors=( + trace_h[i], + trace_xattn_mask[i], + trace_full_text_mask_expand_1NSH[i], + trace_full_text_mask_expand_11SD[i], + trace_position_id[i], + trace_rope_id[i], + trace_page_table[i], + trace_cross_page_table[i], + ), + ) + for i, trace_id in trace_ids.items(): + ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False) + + return trace_logits_rm + + def _easy_trace( + self, + position_id, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + xattn_caches=None, + page_table=None, + kv_cache=None, + cross_page_table=None, + ): + """ + Tracing is easy! Just call this method and we'll handle tracing for you. + """ + if not hasattr(self, "trace_ids"): + ( + trace_ids, + tt_logits_rm, + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_full_text_mask_expand_11SD, + tt_position_id, + tt_rope_id, + tt_page_table, + tt_cross_page_table, + ) = self._capture_trace( + position_id, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + xattn_caches, + page_table=page_table, + kv_cache=kv_cache, + cross_page_table=cross_page_table, + ) + self.trace_ids = trace_ids + self.trace_inputs = { + "tt_h": tt_h, + "tt_xattn_mask": tt_xattn_mask, + "tt_full_text_mask_expand_1NSH": tt_full_text_mask_expand_1NSH, + "tt_full_text_mask_expand_11SD": tt_full_text_mask_expand_11SD, + "tt_position_id": tt_position_id, + "tt_rope_id": tt_rope_id, + "tt_page_table": tt_page_table, + "tt_cross_page_table": tt_cross_page_table, + } + self.trace_outputs = { + "tt_logits_rm": tt_logits_rm, + } + + trace_logits_rm = self._decode_forward_trace( + position_id, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + page_table, + cross_page_table, + self.trace_ids, + self.trace_outputs["tt_logits_rm"], + self.trace_inputs["tt_h"], + self.trace_inputs["tt_xattn_mask"], + self.trace_inputs["tt_full_text_mask_expand_1NSH"], + self.trace_inputs["tt_full_text_mask_expand_11SD"], + self.trace_inputs["tt_position_id"], + self.trace_inputs["tt_rope_id"], + self.trace_inputs["tt_page_table"], + self.trace_inputs["tt_cross_page_table"], + ) + + return trace_logits_rm + + def generate( + self, + model_input, + max_gen_len: int, + temperature: float = 0.6, + top_p: float = 0.9, + ): + # Do initial prefill + vision_images = model_input.vision.images + vision_mask = model_input.vision.mask + prompt_tokens = 
model_input.tokens + prefill_len = len(prompt_tokens) + total_len = prefill_len + max_gen_len # Prepares mask for full length of output + + prompt_tokens_tensor = torch.tensor(prompt_tokens, dtype=torch.long).reshape(1, -1) # B, S + # Suboptimal to allocate caches every time + model_id = 0 + xattn_caches = self.model[model_id].setup_cache(self.model_args[model_id].max_batch_size) + ( + xattn_caches, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + logits, + ) = self._prefill_forward_single_user( + vision_images, + vision_mask, + prompt_tokens_tensor, + xattn_caches, + user_id=0, + total_len=total_len, + prefill_len=prefill_len, + model_id=model_id, + ) + + last_token_idx = prefill_len - 1 + logits = self.model[model_id].process_output_prefill(logits, 1, last_token_idx=(last_token_idx % 32)) + logits = logits.view(1, 1, self.model_args[model_id].vocab_size) + + prefill_output_xattn_masks = [[] for _ in range(self.data_parallel)] + prefill_output_full_text_row_masked_out_masks = [[] for _ in range(self.data_parallel)] + decode_output_xattn_masks = [[] for _ in range(self.data_parallel)] + decode_output_full_text_row_masked_out_masks = [[] for _ in range(self.data_parallel)] + + prefill_output_xattn_masks[model_id].append(prefill_cross_attention_masks) + prefill_output_full_text_row_masked_out_masks[model_id].append(prefill_full_text_row_masked_out_mask) + decode_output_xattn_masks[model_id].append(decode_cross_attention_masks) + decode_output_full_text_row_masked_out_masks[model_id].append(decode_full_text_row_masked_out_mask) + + def sample(logits): + if temperature > 0: + probs = torch.softmax(logits[:, -1] / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits[:, -1], dim=-1) + next_token = next_token.reshape(-1) + return next_token, self.tokenizer.decode(next_token.tolist()) + + next_token, text = sample(logits) + + yield TokenResult( + token=next_token[0].item(), + text=text, + ) + + for gen_idx in range(max_gen_len - 1): + position_id = torch.tensor([prefill_len + gen_idx]) + next_token_tensor = next_token.reshape(1, 1) # B, S + + logits = self.decode_forward( + position_id, + next_token_tensor, + prefill_output_xattn_masks, + prefill_output_full_text_row_masked_out_masks, + decode_output_xattn_masks, + decode_output_full_text_row_masked_out_masks, + [xattn_caches], + enable_trace=False, + ) + next_token, text = sample(logits) + yield TokenResult( + token=next_token[0].item(), + text=text, + ) + + def chat_completion( + self, + messages, + temperature=0.6, + top_p: float = 0.9, + max_gen_len=None, + ): + model_id = 0 + if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model[model_id].configuration.max_seq_len: + max_gen_len = self.model[model_id].configuration.max_seq_len - 1 + + tokens = [] + + stop_reason = None + for result in self.generate( + model_input=self.formatter.encode_dialog_prompt(messages, tool_prompt_format=False), + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ): + tokens.append(result.token) + if result.text == "<|eot_id|>": + stop_reason = StopReason.end_of_turn + elif result.text == "<|eom_id|>": + stop_reason = StopReason.end_of_message + + if stop_reason is None: + stop_reason = StopReason.out_of_tokens + + message = self.formatter.decode_assistant_message(tokens, stop_reason) + + return ChatPrediction(generation=message) + + def text_completion( + self, + content: 
InterleavedTextMedia, + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len=None, + ): + model_id = 0 + if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model[model_id].configuration.max_seq_len: + max_gen_len = self.model[model_id].configuration.max_seq_len - 1 + + model_input = self.formatter.encode_content(content) + + tokens = [] + + for result in self.generate( + model_input=model_input, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ): + tokens.append(result.token) + + generation = self.tokenizer.decode(tokens) + + return CompletionPrediction(generation=generation) + + def _get_prefill_user_page_table(self, page_table, kv_cache, prefill_len): + # Ensure page_table is not padded with extra blocks for paged_fill_cache to work properly + block_size = get_block_size(kv_cache) + num_blocks = num_blocks_in_seq(prefill_len, block_size) + return page_table[:, :num_blocks] + + ## Destructor + + def __del__(self): + # Workaround for issue #19052 + if self.data_parallel > 1: + for m in self.model: + ttnn.close_mesh_device(m.mesh_device) + + if hasattr(super(Gemma3Generator, self), "__del__"): + super().__del__() + + +def create_submeshes(mesh_device, data_parallel): + if not isinstance(mesh_device, ttnn.MeshDevice) or data_parallel == 1: + return [mesh_device] + + num_rows, num_cols = mesh_device.shape + num_devices = num_rows * num_cols + assert num_devices % data_parallel == 0, f"Unsupported device split: {num_devices} devices, {data_parallel} groups" + + if num_rows == 8 and num_cols == 4 and num_cols % data_parallel == 0: + submeshes = mesh_device.create_submeshes(ttnn.MeshShape(num_rows, num_cols // data_parallel)) + for submesh in submeshes: + submesh.reshape(ttnn.MeshShape(1, num_devices // data_parallel)) + return submeshes + + return mesh_device.create_submeshes(ttnn.MeshShape(1, num_devices // data_parallel)) diff --git a/models/experimental/gemma3/tt/gemma_conv2d_patch.py b/models/experimental/gemma3/tt/gemma_conv2d_patch.py new file mode 100644 index 000000000000..a27810dbe09d --- /dev/null +++ b/models/experimental/gemma3/tt/gemma_conv2d_patch.py @@ -0,0 +1,123 @@ +""" +source: models/tt_transformers/tt/multimodal/llama_conv2d_patch.py +This is the Conv2dPath of Gemma3 +We have reused the exisiting Conv2dPath of TtLlamaConv2dPath with few modifications. +We have added a check for weight to convert 4D to 2D +""" +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.utility_functions import nearest_32 + + +class TtGemmaConv2dPatch(LightweightModule): + """Conv2D Patching layer. + Column parallel over unfolded input. + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. 
+ Input: (bsz, in_channels, width, height) + Output: (bsz, num_tokens, out_channels) + """ + + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + dtype, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + bias, + ): + super().__init__() + + self.mesh_device = mesh_device + self.num_devices = self.mesh_device.get_num_devices() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + + self.bias = ( + ttnn.as_tensor( + torch.reshape(state_dict[f"{state_dict_prefix}_linear.bias"], (1, -1)), + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + if bias + else None + ) + + self._unfold = torch.nn.Unfold(kernel_size=self.kernel_size, stride=self.stride) + + weight = state_dict[f"{state_dict_prefix}_linear.weight"] + if weight.ndim == 4: + weight = weight.view(out_channels, -1) + pad_len = nearest_32(weight.shape[-1]) - weight.shape[-1] + padding = torch.zeros(self.out_channels, pad_len, dtype=weight.dtype) + padded_weight = torch.cat([weight, padding], dim=-1) + padded_weight = padded_weight.permute(1, 0).reshape(1, 1, -1, self.out_channels) + + self._linear_weight = ttnn.as_tensor( + padded_weight, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + self.compute_kernel_config = ttnn.init_device_compute_kernel_config( + mesh_device.arch(), + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + def forward(self, x: torch.Tensor): + x = self._unfold(x) + x = x.permute(0, 2, 1) + + # Need to pad the last dimension of x to be a multiple of a tile + pad_len = nearest_32(x.shape[-1]) - x.shape[-1] + padding = torch.zeros((x.shape[0], x.shape[1], pad_len), dtype=x.dtype, device=x.device) + x = torch.cat([x, padding], dim=-1) + + x = ttnn.as_tensor( + x, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + out = ttnn.linear( + x, + self._linear_weight, + bias=self.bias, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + core_grid=ttnn.CoreGrid(y=8, x=8), + ) + + ttnn.deallocate(x) + + return out diff --git a/models/experimental/gemma3/tt/gemma_image_attention.py b/models/experimental/gemma3/tt/gemma_image_attention.py new file mode 100644 index 000000000000..40c05e439552 --- /dev/null +++ b/models/experimental/gemma3/tt/gemma_image_attention.py @@ -0,0 +1,401 @@ +""" +source: models/tt_transformers/tt/multimodal/llama_image_attention.py + +This is the ImageAttention block for Gemma3 +We have reused the TTLlamaImageAttention with some modification. +We have made the linears (Q,K,V) to be executed separately and added bias support for O_projection, along with few +configuration changes. 
+""" + + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.utility_functions import nearest_32 + + +class TtGemmaImageAttention(LightweightModule): + def __init__( + self, + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.num_devices = configuration.num_devices + self.tt_ccl = tt_ccl + self.hidden_size = configuration.vision_dim + self.n_heads = configuration.vision_attn_n_heads + self.head_dim = self.hidden_size // self.n_heads + self.n_kv_heads = self.n_heads + + self.n_local_heads = self.n_heads // configuration.num_devices + self.n_local_kv_heads = self.n_kv_heads // configuration.num_devices + + self.dtype = dtype + + self.grid_size = configuration.max_grid_size + + self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 + self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 + self.compute_kernel_config_sdpa = configuration.compute_kernel_config_sdpa + self.configuration = configuration + + self.model_config = configuration.get_model_config() + + if configuration.dummy_weights or (weight_cache_path is None): + cache_name = lambda _: None + else: + cache_name = lambda name: weight_cache_path / (f"{state_dict_prefix}{name}") + + wq_str = f"{state_dict_prefix}wq.weight" + wk_str = f"{state_dict_prefix}wk.weight" + wv_str = f"{state_dict_prefix}wv.weight" + wo_str = f"{state_dict_prefix}wo.weight" + + # when splitting the devices, we need to make sure that the number of heads is divisible by the number of devices + assert self.n_heads % configuration.num_devices == 0 + assert self.n_kv_heads % configuration.num_devices == 0 + + # Pad head_dim to multiple of 32 + def pad_head_dim(weight, heads_out=True): + # Pad head dim to multiple of 32 + # heads_out means that the output dim of this weight contains heads. 
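+ # Illustrative example (values hypothetical): if head_dim were 72,
+ # nearest_32(72) = 96, so 24 zero columns are appended per head and the padded
+ # projection becomes n_heads * 96 wide instead of n_heads * 72.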
+ dim = weight.shape[1] + assert weight.shape[0] == dim + padded_head_dim = nearest_32(self.head_dim) + padding_size = padded_head_dim - self.head_dim + if padding_size > 0: + if heads_out: + weight = weight.transpose(-1, -2) + weight = weight.reshape(dim, self.n_heads, self.head_dim) + padding = torch.zeros(dim, self.n_heads, padding_size, dtype=weight.dtype) + weight = torch.cat([weight, padding], dim=-1) + weight = weight.reshape(dim, self.n_heads * padded_head_dim) + if heads_out: + weight = weight.transpose(-1, -2) + return weight + + wq_padded = pad_head_dim(self.state_dict[wq_str]) + wk_padded = pad_head_dim(self.state_dict[wk_str]) + wv_padded = pad_head_dim(self.state_dict[wv_str]) + wo_padded = pad_head_dim(self.state_dict[wo_str], heads_out=False) + wq_chunked, wk_chunked, wv_chunked = ( + torch.chunk(w, configuration.num_devices) for w in [wq_padded, wk_padded, wv_padded] + ) + + # for Gemma + self.wq = ttnn.as_tensor( + tensor=wq_padded, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + cache_file_name=cache_name("wq_sharded"), + preprocess=lambda x: x.transpose(-2, -1), + ) + + self.wk = ttnn.as_tensor( + tensor=wk_padded, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + cache_file_name=cache_name("wk_sharded"), + preprocess=lambda x: x.transpose(-2, -1), + ) + + self.wv = ttnn.as_tensor( + tensor=wv_padded, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + cache_file_name=cache_name("wv_sharded"), + preprocess=lambda x: x.transpose(-2, -1), + ) + + self.wqkv = ttnn.as_tensor( + torch.concat( + [ + torch.concat( + [ + torch.transpose( + wq_chunked[i], + -2, + -1, + ), + torch.transpose( + wk_chunked[i], + -2, + -1, + ), + torch.transpose( + wv_chunked[i], + -2, + -1, + ), + ], + dim=-1, + ) + for i in range(configuration.num_devices) + ], + dim=-1, + ), + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + dtype=self.dtype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + cache_file_name=cache_name("wqkv_sharded"), + ) + + bq_str = f"{state_dict_prefix}wq.bias" + bk_str = f"{state_dict_prefix}wk.bias" + bv_str = f"{state_dict_prefix}wv.bias" + bo_str = f"{state_dict_prefix}wo.bias" + + if bq_str in self.state_dict: + + def pad_head_dim_bias(bias): + # Pad 1D bias to match padded head dim + dim = bias.shape[0] + assert ( + dim == self.n_heads * self.head_dim + ), f"Expected bias of shape ({self.n_heads} * {self.head_dim}) = {self.n_heads * self.head_dim}, but got {dim}" + + padded_head_dim = nearest_32(self.head_dim) + padding_size = padded_head_dim - self.head_dim + + if padding_size > 0: + bias = bias.view(self.n_heads, self.head_dim) + padding = torch.zeros(self.n_heads, padding_size, dtype=bias.dtype) + bias = torch.cat([bias, padding], dim=-1) + bias = bias.view(self.n_heads * padded_head_dim) + + return bias + + bq_padded = pad_head_dim_bias(self.state_dict[bq_str]) + bk_padded = pad_head_dim_bias(self.state_dict[bk_str]) + bv_padded = pad_head_dim_bias(self.state_dict[bv_str]) + + bq_chunked, bk_chunked, bv_chunked = ( + torch.chunk(b, configuration.num_devices) for b in [bq_padded, bk_padded, bv_padded] + ) + + 
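+ # The fused bias follows the same per-device layout as wqkv:
+ # [bq_0 | bk_0 | bv_0 | bq_1 | bk_1 | bv_1 | ...], so sharding on the last dim
+ # gives each device its own contiguous (q, k, v) bias slice.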
self.bqkv = ttnn.as_tensor( + torch.concat( + [ + torch.concat( + [ + bq_chunked[i], + bk_chunked[i], + bv_chunked[i], + ], + dim=-1, + ) + for i in range(configuration.num_devices) + ], + dim=-1, + ), + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + dtype=self.dtype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + cache_file_name=cache_name("bqkv_sharded"), + ) + + # for Gemma + self.bq = ttnn.as_tensor( + tensor=bq_padded, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + cache_file_name=cache_name("bq_sharded"), + ) + + self.bk = ttnn.as_tensor( + tensor=bk_padded, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + cache_file_name=cache_name("bk_sharded"), + ) + + self.bv = ttnn.as_tensor( + tensor=bv_padded, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + cache_file_name=cache_name("bv_sharded"), + ) + + else: + self.bqkv = None + + self.wo = ttnn.as_tensor( + torch.transpose( + wo_padded, + -2, + -1, + ), + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=self.dtype, + layout=ttnn.TILE_LAYOUT, + # cache_file_name=cache_name("wo_sharded"), + ) + + if bo_str in self.state_dict: + self.bo = ttnn.as_tensor( + self.state_dict[bo_str], + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=self.dtype, + layout=ttnn.TILE_LAYOUT, + cache_file_name=cache_name("bo_sharded"), + ) + else: + self.bo = None + + self.scale = self.head_dim**-0.5 + + def forward(self, x_11SH, mask=None): + seq_len = x_11SH.shape[-2] + + MAX_MM_SEQ_LEN = ( + seq_len if "gemma-3" in self.configuration.base_model_name else self.configuration.VISION_MAX_MM_SEQ + ) + + if seq_len > MAX_MM_SEQ_LEN: + x_11SH = ttnn.reshape(x_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) + + q_heads_1QSD = ttnn.linear( + x_11SH, + self.wq, + bias=self.bq, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config_hifi4, + program_config=None + if "gemma-3" in self.configuration.base_model_name + else self.model_config["IMAGE_ATTN_QKV_PROGCFG"](seq_len, MAX_MM_SEQ_LEN), + ) + + q_heads_1QSD = ttnn.transpose(ttnn.reshape(q_heads_1QSD, (1, seq_len, self.n_local_heads, -1)), 1, 2) + + k_heads_1KSD = ttnn.linear( + x_11SH, + self.wk, + bias=self.bk, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config_hifi4, + program_config=None + if "gemma-3" in self.configuration.base_model_name + else self.model_config["IMAGE_ATTN_QKV_PROGCFG"](seq_len, MAX_MM_SEQ_LEN), + ) + + k_heads_1KSD = ttnn.transpose(ttnn.reshape(k_heads_1KSD, (1, seq_len, self.n_local_heads, -1)), 1, 2) + + v_heads_1VSD = ttnn.linear( + x_11SH, + self.wv, + bias=self.bv, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config_hifi4, + program_config=None + if "gemma-3" in self.configuration.base_model_name + else self.model_config["IMAGE_ATTN_QKV_PROGCFG"](seq_len, 
MAX_MM_SEQ_LEN), + ) + v_heads_1VSD = ttnn.transpose(ttnn.reshape(v_heads_1VSD, (1, seq_len, self.n_local_heads, -1)), 1, 2) + + # TODO: get this from model_config + sdpa_cfg = ttnn.SDPAProgramConfig( + compute_with_storage_grid_size=(8, 8), q_chunk_size=128, k_chunk_size=128, exp_approx_mode=False + ) + attn_output_1QSD = ttnn.transformer.scaled_dot_product_attention( + q_heads_1QSD, + k_heads_1KSD, + v_heads_1VSD, + is_causal=False, + scale=self.scale, + attn_mask=mask, + program_config=sdpa_cfg, + compute_kernel_config=self.compute_kernel_config_sdpa, + ) + # deallocate keys and values + ttnn.deallocate(q_heads_1QSD) + ttnn.deallocate(k_heads_1KSD) + ttnn.deallocate(v_heads_1VSD) + + ### + # Output matmul + ### + attn_output_11SH = ttnn.experimental.nlp_concat_heads( + attn_output_1QSD, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + ttnn.deallocate(attn_output_1QSD) + + # reshaping long sequence to matmul fit on device + if seq_len > MAX_MM_SEQ_LEN: + attn_output_11SH = ttnn.reshape(attn_output_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) + + # if self.num_devices > 1: + # # self.bo = ttnn.all_gather(self.bo, dim=3, num_links=1) + # attn_output_11SH = ttnn.all_gather(attn_output_11SH, dim=3, num_links=1) + if self.num_devices > 1: # replace with reduce_scatter and all_gather + attn_output_11SH = ttnn.experimental.all_gather_async( + attn_output_11SH, + persistent_output_buffer=None, + dim=3, + multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(), + num_links=1, + topology=ttnn.Topology.Linear, + barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(), + chunks_per_sync=10, + num_workers_per_link=2, + num_buffers_per_channel=2, + ) + output_11SH = ttnn.linear( + attn_output_11SH, + self.wo, + bias=self.bo, + compute_kernel_config=self.compute_kernel_config_hifi4, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + program_config=None + if "gemma-3" in self.configuration.base_model_name + else self.model_config["IMAGE_ATTN_QKV_PROGCFG"](seq_len, MAX_MM_SEQ_LEN), + ) + if seq_len > MAX_MM_SEQ_LEN: + output_11SH = ttnn.reshape(output_11SH, [1, 1, seq_len, -1]) + ttnn.deallocate(attn_output_11SH) + + return output_11SH diff --git a/models/experimental/gemma3_4b/tt/gemma_image_block.py b/models/experimental/gemma3/tt/gemma_image_block.py similarity index 77% rename from models/experimental/gemma3_4b/tt/gemma_image_block.py rename to models/experimental/gemma3/tt/gemma_image_block.py index 18cf86935792..e0eb0b88017c 100644 --- a/models/experimental/gemma3_4b/tt/gemma_image_block.py +++ b/models/experimental/gemma3/tt/gemma_image_block.py @@ -1,7 +1,7 @@ """ source: models/tt_transformers/tt/multimodal/llama_image_block.py -This is the ImageTransformer block for Gemma-3-4b-it. +This is the ImageTransformer block for Gemma3. 
We have reused the TtLlamaImageTransformerBlock with incorporating the TtGemmaImageAttention and TtGemmaImageFeedForward """ @@ -13,16 +13,16 @@ import ttnn from models.common.lightweightmodule import LightweightModule -from models.experimental.gemma3_4b.tt.gemma_image_attention import TtGemmaImageAttention -from models.experimental.gemma3_4b.tt.gemma_image_mlp import TtGemmaImageFeedForward +from models.experimental.gemma3.tt.gemma_image_attention import TtGemmaImageAttention +from models.experimental.gemma3.tt.gemma_image_mlp import TtGemmaImageFeedForward from models.tt_transformers.tt.multimodal.llama_layernorm import TtLayerNorm +from models.tt_transformers.tt.ccl import TT_CCL class TtGemmaImageTransformerBlock(LightweightModule): def __init__( self, mesh_device, - tt_ccl, state_dict, state_dict_prefix, weight_cache_path, @@ -34,10 +34,10 @@ def __init__( self.state_dict = state_dict self.mesh_device = mesh_device - self.tt_ccl = tt_ccl self.num_devices = configuration.num_devices self.hidden_size = configuration.vision_dim self.gated = gated + self.tt_ccl = TT_CCL(mesh_device) self.ln_1 = TtLayerNorm( device=mesh_device, @@ -51,7 +51,7 @@ def __init__( self.attn = TtGemmaImageAttention( mesh_device, - tt_ccl, + self.tt_ccl, state_dict, state_dict_prefix=f"{state_dict_prefix}attn.", weight_cache_path=weight_cache_path, @@ -71,7 +71,7 @@ def __init__( self.mlp = TtGemmaImageFeedForward( mesh_device=mesh_device, - tt_ccl=tt_ccl, + tt_ccl=self.tt_ccl, args=configuration, state_dict=state_dict, state_dict_prefix=f"{state_dict_prefix}mlp.", @@ -105,12 +105,29 @@ def forward(self, x_11SH, mask=None): attn_out = self.attn(self.ln_1(x_11SH), mask=mask) if self.gated: attn_out = ttnn.mul(attn_out, ttnn.tanh(self.gate_attn)) + if self.num_devices > 1: # replace with reduce_scatter and all_gather + attn_out = ttnn.experimental.all_gather_async( + attn_out, + persistent_output_buffer=None, + dim=3, + multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(), + num_links=1, + topology=ttnn.Topology.Linear, + barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(), + chunks_per_sync=10, + num_workers_per_link=2, + num_buffers_per_channel=2, + ) + # if self.num_devices > 1: + # # attn_out = ttnn.all_gather(attn_out, dim=3, num_links=1) res = ttnn.add(x_11SH, attn_out) + mlp_out = self.mlp(self.ln_2(res)) if self.gated: mlp_out = ttnn.mul(mlp_out, ttnn.tanh(self.gate_ffn)) out = ttnn.add(res, mlp_out) + ttnn.deallocate(mlp_out) ttnn.deallocate(attn_out) ttnn.deallocate(res) diff --git a/models/experimental/gemma3/tt/gemma_image_mlp.py b/models/experimental/gemma3/tt/gemma_image_mlp.py new file mode 100644 index 000000000000..ed256073c442 --- /dev/null +++ b/models/experimental/gemma3/tt/gemma_image_mlp.py @@ -0,0 +1,135 @@ +""" +source: models/tt_transformers/tt/multimodal/llama_image_mlp.py +This is the FeedForward submodule for vision block in Gemma3 +We have reused the TtLlamaImageFeedForward with few changes in CoreGrid and program_config configurations +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +class TtGemmaImageFeedForward(LightweightModule): + def __init__( + self, + mesh_device, + tt_ccl, + args, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.tt_ccl = tt_ccl + self.args = args + 
self.model_config = args.get_model_config() + torch_weight = lambda name, suffix: torch.transpose( + self.state_dict[f"{state_dict_prefix}{name}.{suffix}"], -2, -1 + ) + torch_bias = lambda name, suffix: self.state_dict[f"{state_dict_prefix}{name}.{suffix}"] + + if args.dummy_weights: + cache_name = lambda *_: None + else: + cache_name = lambda name, suffix: weight_cache_path / (state_dict_prefix + f"{name}.{suffix}") + + as_interleaved_tensor = lambda name, suffix, type, dim: ttnn.as_tensor( + ( + torch_weight(name, suffix) if suffix == "weight" else torch_bias(name, suffix) + ), # Grab only the wX part of the name + dtype=type, + device=self.mesh_device, + mesh_mapper=( + ttnn.ShardTensorToMesh(self.mesh_device, dim=dim) + if dim is not None + else ttnn.ReplicateTensorToMesh(self.mesh_device) + ), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + cache_file_name=cache_name(name, suffix), + ) + + # Sharded weights + self.c_fc_weight = as_interleaved_tensor("c_fc", "weight", dtype, dim=-1) + self.c_fc_bias = as_interleaved_tensor("c_fc", "bias", ttnn.bfloat16, dim=-1) + self.c_fc_bias = ttnn.reshape(self.c_fc_bias, [1, -1]) + self.c_proj_weight = as_interleaved_tensor("c_proj", "weight", dtype, dim=-2) + self.c_proj_bias = as_interleaved_tensor("c_proj", "bias", ttnn.bfloat16, dim=None) + + def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: + """ + w1 -> gate_proj + w2 -> down_proj + w3 -> up_proj + HF reference: self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + """ + seq_len = x.shape[-2] + + # Depends on whether we are padding or not + MAX_MM_SEQ_LEN = seq_len if "gemma-3" in self.args.base_model_name else self.args.VISION_MAX_MM_SEQ + + x_in = x + if seq_len >= MAX_MM_SEQ_LEN: # Too big to compute. Set different program configs based on seqlen + # Reshape input to to fit on device and parallelize computation + x_in = ttnn.reshape(x_in, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) + pc_1 = self.model_config["IMAGE_MLP_FC_PROGCFG"](seq_len, MAX_MM_SEQ_LEN) + pc_2 = self.model_config["IMAGE_MLP_PROJ_PROGCFG"](seq_len, MAX_MM_SEQ_LEN) + + # These use HiFi2; this drops 1 bit of the activations but would be FLOP-bound on 12 cores with HiFi4 + c_fc_out = ttnn.linear( + x_in, + self.c_fc_weight, + bias=self.c_fc_bias, + compute_kernel_config=self.args.compute_kernel_config_hifi4, + # core_grid=ttnn.CoreGrid(y=8, x=8) if not pc_1 else None, + dtype=ttnn.bfloat16, + # program_config=pc_1, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + activation="gelu", # NOTE: activation must be passed to linear here, not in program config! 
Bad output otherwise + ) + + c_proj_out = ttnn.linear( + c_fc_out, + self.c_proj_weight, + compute_kernel_config=self.args.compute_kernel_config_hifi4, + # core_grid=ttnn.CoreGrid(y=8, x=8) if not pc_2 else None, + dtype=ttnn.bfloat16, + # program_config=pc_2, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + # NOTE: Need to reshape to 4D so that fast_reduce_nc hsa a dim1 to work on + c_proj_out = ttnn.reshape(c_proj_out, [1, 1, seq_len, -1]) + + # All reduce + if self.args.num_devices > 1: # replace with reduce_scatter and all_gather + w2_out_gathered = ttnn.experimental.all_gather_async( + c_proj_out, + persistent_output_buffer=None, + dim=1, + multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(), + num_links=1, + topology=ttnn.Topology.Linear, + barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(), + chunks_per_sync=10, + num_workers_per_link=2, + num_buffers_per_channel=2, + ) + pre_bias_output = ttnn.experimental.fast_reduce_nc( + w2_out_gathered, dims=[1], output=None, compute_kernel_config=None + ) + else: + pre_bias_output = c_proj_out + + output = ttnn.add(pre_bias_output, self.c_proj_bias) + return output diff --git a/models/experimental/gemma3_4b/tt/gemma_image_transformer.py b/models/experimental/gemma3/tt/gemma_image_transformer.py similarity index 90% rename from models/experimental/gemma3_4b/tt/gemma_image_transformer.py rename to models/experimental/gemma3/tt/gemma_image_transformer.py index e2e379be45b6..4e0d4101ee96 100644 --- a/models/experimental/gemma3_4b/tt/gemma_image_transformer.py +++ b/models/experimental/gemma3/tt/gemma_image_transformer.py @@ -1,7 +1,7 @@ """ source: models/tt_transformers/tt/multimodal/llama_image_transformer.py -This is the Entire ImageTransformer for Gemma-3-4b-it. +This is the Entire ImageTransformer for Gemma3. We have adapted the TtGemmaImageTransformerBlock from TtLlamaImageTransformerBlock with changes incorporating the GemmaImageAttention and GemmaImageFeedForward """ @@ -12,14 +12,13 @@ from tqdm import tqdm from models.common.lightweightmodule import LightweightModule -from models.experimental.gemma3_4b.tt.gemma_image_block import TtGemmaImageTransformerBlock +from models.experimental.gemma3.tt.gemma_image_block import TtGemmaImageTransformerBlock class TtGemmaImageTransformer(LightweightModule): def __init__( self, mesh_device, - tt_ccl, state_dict, state_dict_prefix, weight_cache_path, @@ -33,13 +32,11 @@ def __init__( self.state_dict = state_dict self.mesh_device = mesh_device - self.tt_ccl = tt_ccl self.gated = gated self.resblocks = [ TtGemmaImageTransformerBlock( mesh_device=mesh_device, - tt_ccl=self.tt_ccl, state_dict=state_dict, state_dict_prefix=f"{state_dict_prefix}{block_key}.{i}.", weight_cache_path=weight_cache_path, diff --git a/models/experimental/gemma3/tt/gemma_vision_crossattention.py b/models/experimental/gemma3/tt/gemma_vision_crossattention.py new file mode 100644 index 000000000000..b6e7f95785ad --- /dev/null +++ b/models/experimental/gemma3/tt/gemma_vision_crossattention.py @@ -0,0 +1,66 @@ +""" +This is the Vision Transformer Block for Gemma3. 
+This involves vision followed by MultiModalProjector processing +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +from models.common.lightweightmodule import LightweightModule +from models.experimental.gemma3.tt.gemma_vision_model import TtSiglipGemmaVisionModel +from models.experimental.gemma3.tt.mmp import TtGemma3MultiModalProjector + + +class TtGemmaTransformerVision(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + dtype, + configuration, + weight_cache_path=None, + return_intermediate=None, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.model_config = configuration.get_model_config() + + self.dim = configuration.dim + self.vision_dim = configuration.vision_dim + self.image_res = configuration.image_size + self.patch_size = configuration.vision_patch_size + self.configuration = configuration + + self.vision_encoder = TtSiglipGemmaVisionModel( + mesh_device, + state_dict, + state_dict_prefix=configuration.state_dict_vision_prefix, + weight_cache_path=configuration.weight_cache_path(dtype), + dtype=dtype, + configuration=configuration, + return_intermediate=return_intermediate, + ) + + self.mmp = TtGemma3MultiModalProjector( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix="multi_modal_projector", + image_size=self.image_res, + patch_size=self.patch_size, + hidden_size=configuration.vision_hidden_dim, + mm_tokens_per_image=configuration.mm_tokens_per_image, + weight_cache_path=configuration.weight_cache_path(dtype), + layer_norm_eps=1e-06, # layer_norm_eps + dtype=dtype, + configuration=configuration, + ) + + def forward(self, images): + vision_tokens = self.vision_encoder(images)[0, :, :, :] + + vision_tokens = self.mmp(vision_tokens) + return vision_tokens diff --git a/models/experimental/gemma3_4b/tt/gemma_vision_model.py b/models/experimental/gemma3/tt/gemma_vision_model.py similarity index 85% rename from models/experimental/gemma3_4b/tt/gemma_vision_model.py rename to models/experimental/gemma3/tt/gemma_vision_model.py index bd50330d0675..4524426e9ae5 100644 --- a/models/experimental/gemma3_4b/tt/gemma_vision_model.py +++ b/models/experimental/gemma3/tt/gemma_vision_model.py @@ -1,5 +1,5 @@ """ -This is the Vision Tower Model for Gemma-3-4b-it. +This is the Vision Tower Model for Gemma3. 
""" # SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC @@ -9,8 +9,8 @@ import torch import ttnn from models.common.lightweightmodule import LightweightModule -from models.experimental.gemma3_4b.tt.siglip_vision_embedding import TtSiglipVisionEmbeddings -from models.experimental.gemma3_4b.tt.gemma_image_transformer import TtGemmaImageTransformer +from models.experimental.gemma3.tt.siglip_vision_embedding import TtSiglipVisionEmbeddings +from models.experimental.gemma3.tt.gemma_image_transformer import TtGemmaImageTransformer from models.tt_transformers.tt.multimodal.llama_layernorm import TtLayerNorm @@ -18,7 +18,6 @@ class TtSiglipGemmaVisionModel(LightweightModule): def __init__( self, mesh_device, - tt_ccl, state_dict, state_dict_prefix, dtype, @@ -29,7 +28,6 @@ def __init__( super().__init__() self.state_dict = state_dict self.mesh_device = mesh_device - self.tt_ccl = tt_ccl self.image_size = configuration.image_size self.patch_size = configuration.vision_patch_size @@ -43,8 +41,6 @@ def __init__( self.n_global_layers = configuration.vision_n_global_layers self.return_intermediate = return_intermediate - self.prepare_residual_tensor_prefill = configuration.prepare_residual_tensor_prefill - self.embeddings = TtSiglipVisionEmbeddings( mesh_device=mesh_device, state_dict=state_dict, @@ -60,7 +56,6 @@ def __init__( # transformer self.encoder = TtGemmaImageTransformer( mesh_device=mesh_device, - tt_ccl=self.tt_ccl, state_dict=state_dict, state_dict_prefix=f"{state_dict_prefix}encoder.", weight_cache_path=configuration.weight_cache_path(dtype), @@ -70,6 +65,8 @@ def __init__( block_key="layers", ) + self.prepare_residual_tensor_prefill = configuration.prepare_residual_tensor_prefill + self.ln_post = TtLayerNorm( device=mesh_device, dim=self.width, @@ -88,24 +85,18 @@ def forward(self, images): bsz, in_channel, h, w = images.shape x = self.embeddings(images) - - x = ttnn.to_torch(x) attention_mask = torch.zeros(bsz, 1, x.shape[1], x.shape[1]) - attention_input = self.prepare_residual_tensor_prefill( - x, - force_replicated=True, - ) tt_mask = ttnn.from_torch( attention_mask, device=self.mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) x = self.encoder( - attention_input, + x, mask=tt_mask, ) diff --git a/models/experimental/gemma3/tt/gemma_vision_rmsnorm.py b/models/experimental/gemma3/tt/gemma_vision_rmsnorm.py new file mode 100644 index 000000000000..83836adc2c07 --- /dev/null +++ b/models/experimental/gemma3/tt/gemma_vision_rmsnorm.py @@ -0,0 +1,172 @@ +""" +This is the modified version of the RMSNorm for Gemma3 model. + +We have modified the RMSNorm implementation equivalent to RMSNorm in Gemma3 Models. +We have handled the unit offset addition in the RMSNorm implementation directly into the TTNN Weights +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule + +TILE = 32 +SHARD_HEIGHT = TILE # Current ttnn.rms_norm implementation requires shard height to be a single tile + + +class RMSNorm(LightweightModule): + """ + RMSNorm supporting replication over a MeshDevice and sharding within devices. + + This class implements a Root Mean Square Normalization (RMSNorm) that can be + distributed across multiple devices and cores. If the `device` parameter is a + MeshDevice, the weights and computations are replicated across all devices in + the mesh. 
Expects an interleaved input tensor, can optionally output a sharded tensor. + + Args: + device: The device or MeshDevice on which to perform the computations. + state_dict: The state dictionary containing the model parameters. + dim: Input dimension (e.g. model hidden dimension size). + layer_num: The layer number to determine the weight key in the state dictionary. + weight_key: The key for retrieving the weight from the state dictionary. + weight_cache_path: Optional path for caching the tilized weights. + weight_memory_config: Configuration for the weight memory, default is DRAM_MEMORY_CONFIG. + weight_dtype: The data type for the tensors, bfp8_b hits >0.999 PCC in the models we tested. + model_config: Optional configuration dictionary for the model. + eps (float): Small value to avoid division by zero in normalization, default is 1e-05. + + If model_config is provided, it must specify SHARDED_NORM_INPUT_MEMCFG, SHARDED_NORM_PRGM_CFG + and SHARDED_NORM_OUTPUT_MEMCFG. If not provided, default configurations will be generated. + """ + + def __init__( + self, + device, + dim, + state_dict, + weight_key, + layer_num=None, + state_dict_prefix=None, + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps: float = 1e-06, + add_unit_offset=True, + sharded_program_config=None, + sharded_output_config=None, + output_mem_config=None, + ccl_topology=ttnn.Topology.Ring, + ): + super().__init__() + self.eps = eps + self.is_distributed = is_distributed + self.ccl_topology = ccl_topology + + if state_dict_prefix: + weight_name = f"{state_dict_prefix}{weight_key}.weight" + else: + if layer_num is None: + weight_name = f"{weight_key}.weight" + else: + weight_name = f"layers.{layer_num}.{weight_key}.weight" + + torch_weight = ( + state_dict[weight_name].unsqueeze(0).view(1, 1, dim).reshape([1, 1, dim // SHARD_HEIGHT, SHARD_HEIGHT]) + ) + if add_unit_offset: + torch_weight = torch_weight + 1.0 + + # # Add offset before caching + cache_name = None if weight_cache_path is None else weight_cache_path / weight_name + + # Compatibility with models that don't use mesh devices (e.g. 
single-chip Mistral-7b) + is_mesh_device = device.__class__.__name__ == "MeshDevice" + + self.weight = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=cache_name, + mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + ) + + if self.is_distributed: + self.weight_distributed = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=cache_name, + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape)) + if is_mesh_device + else None, + ) + + self.sharded_output_config = sharded_output_config + self.sharded_program_config = sharded_program_config + self.output_mem_config = output_mem_config + + self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=False, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) -> ttnn.Tensor: + # If input is sharded do sharded RMSNorm and optionally return sharded output + program_config = self.sharded_program_config if in_sharded else None + memory_config = self.sharded_output_config if out_sharded else None + distributed = self.is_distributed and self.is_distributed(mode) + norm = self._distributed_rmsnorm + weight = self.weight_distributed if distributed else self.weight + + if in_sharded: + assert not distributed, "Distributed RMSNorm does not support sharded inputs" + else: + assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor" + + x = norm( + x, + epsilon=self.eps, + weight=weight, + program_config=program_config, + memory_config=memory_config, + compute_kernel_config=self.compute_kernel_config_hifi2, + ) + + if in_sharded and not out_sharded: + return ttnn.sharded_to_interleaved(x) + else: + return x + + def _distributed_rmsnorm( + self, inp, epsilon=1e-6, weight=None, program_config=None, memory_config=None, compute_kernel_config=None + ): + inp = ttnn.sharded_to_interleaved(inp) + + xnorm = ttnn.pow(inp, 2) + + xnorm = ttnn.mean(xnorm, dim=-1, keepdim=True) + + xnorm = ttnn.rsqrt(xnorm + epsilon) + + xnorm = ttnn.multiply(inp, xnorm) + + weight = ttnn.reshape(weight, [1, 1, 1, -1]) + + output = ttnn.multiply(xnorm, weight) + + if memory_config is not None: + output = ttnn.to_memory_config(output, memory_config) + + ttnn.deallocate(xnorm) + ttnn.deallocate(inp) + + return output diff --git a/models/experimental/gemma3_4b/tt/lm_head.py b/models/experimental/gemma3/tt/lm_head.py similarity index 97% rename from models/experimental/gemma3_4b/tt/lm_head.py rename to models/experimental/gemma3/tt/lm_head.py index 57f5cf36211a..5169245137fa 100644 --- a/models/experimental/gemma3_4b/tt/lm_head.py +++ b/models/experimental/gemma3/tt/lm_head.py @@ -1,7 +1,7 @@ """ source: models/tt_transformers/tt/lm_head.py -This is the LMHead module for the Gemma-3-4B-it model. +This is the LMHead module for the Gemma3 model. 
""" # SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC @@ -31,11 +31,12 @@ def __init__( super().__init__() self.args = args self.mesh_device = mesh_device + self.tt_ccl = tt_ccl self.dtype = dtype self.vocab_size = args.vocab_size self.padded_vocab_size = args.padded_vocab_size self.num_devices = args.num_devices - self.tt_ccl = tt_ccl + size_per_device = self.vocab_size // self.num_devices if args.is_galaxy: @@ -146,15 +147,15 @@ def forward(self, x: ttnn.Tensor): memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, dtype=ttnn.bfloat8_b, ) - outputs.append(ttnn.sharded_to_interleaved(output, memory_config=ttnn.L1_MEMORY_CONFIG)) + outputs.append(ttnn.sharded_to_interleaved(output, memory_config=ttnn.DRAM_MEMORY_CONFIG)) # Concatenate the outputs - output = ttnn.concat(outputs, dim=-1, memory_config=ttnn.L1_MEMORY_CONFIG) + output = ttnn.concat(outputs, dim=-1, memory_config=ttnn.DRAM_MEMORY_CONFIG) output = tt_all_reduce( output, - mesh_device=self.mesh_device, - tt_ccl=self.tt_ccl, + self.mesh_device, + self.tt_ccl, cluster_axis=1, dim=3 if self.args.is_galaxy else 0, num_reduce_scatter_links=self.args.num_reduce_scatter_links, diff --git a/models/experimental/gemma3_4b/tt/mlp.py b/models/experimental/gemma3/tt/mlp.py similarity index 96% rename from models/experimental/gemma3_4b/tt/mlp.py rename to models/experimental/gemma3/tt/mlp.py index 2c55572bdfa2..440b1ad1b7f1 100644 --- a/models/experimental/gemma3_4b/tt/mlp.py +++ b/models/experimental/gemma3/tt/mlp.py @@ -1,13 +1,12 @@ """ source: models/tt_transformers/tt/mlp.py -This is the implementation of MLP (feed-forward) submodule of Gemma-3-4b-it. +This is the implementation of MLP (feed-forward) submodule of Gemma3. We have re-used the MLP implementation of the TT-Transformers library with few modifications. This implementation has changes in Data Type (bfloat16). """ - # SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC # SPDX-License-Identifier: Apache-2.0 @@ -36,15 +35,15 @@ def __init__( ): super().__init__() - self.state_dict = state_dict self.mesh_device = mesh_device self.tt_ccl = tt_ccl self.args = args + self.num_devices = args.num_devices self.dim = args.dim self.model_config = model_config self.layer_num = layer_num state_dict_prefix = state_dict_prefix or args.get_state_dict_prefix(self.__class__.__name__, layer_num) - torch_weight = lambda name: torch.transpose(self.state_dict[f"{state_dict_prefix}.{name}.weight"], -2, -1) + torch_weight = lambda name: torch.transpose(state_dict[f"{state_dict_prefix}.{name}.weight"], -2, -1) pad_hidden_dim = lambda tensor, dim: pad_to_size(tensor, dim=dim, size=args.hidden_dim) # If pading was applied (e.g. 
via env var), add the unpadded hidden dim to the cache name to avoid loading incorrect weights hidden_dim_string = f".hidden_dim_{args.hidden_dim}" if args.hidden_dim != args.unpadded_hidden_dim else "" @@ -108,7 +107,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: decoder_id=layer_num, tensor=TensorGroup.ACTIVATION ) li_ff1_3_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=self.args + decoder_id=layer_num, op=OpGroup.LI_FF1_FF3, configuration=self.args ) if mode == "decode": # Sharded config @@ -217,7 +216,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: w1_out, w3_out, input_tensor_a_activations=[self.activation_type], - dtype=activation_dtype or ttnn.bfloat16, + dtype=activation_dtype or ttnn.bfloat8_b, memory_config=w1_out.memory_config(), ) @@ -249,7 +248,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: w2_in = ttnn.to_memory_config(w2_in, ttnn.L1_MEMORY_CONFIG) li_ff2_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=self.args + decoder_id=layer_num, op=OpGroup.LI_FF2, configuration=self.args ) w2_out = ttnn.linear( w2_in, @@ -261,6 +260,9 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: core_grid=None, # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_2 else None, ) ttnn.deallocate(w2_in) + + w2_out = ttnn.multiply(w2_out, self.num_devices) + # if mode == "decode" and not TG: # w2_out = ttnn.sharded_to_interleaved(w2_out, ttnn.DRAM_MEMORY_CONFIG) w2_out_reduced = tt_all_reduce( @@ -281,6 +283,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: use_composite=True if self.dim == 8192 else False, topology=self.args.ccl_topology(), ) + w2_out_reduced = ttnn.div(w2_out_reduced, self.num_devices) # Ensure dim 0 and 1 are 1 original_shape = w2_out_reduced.shape diff --git a/models/experimental/gemma3_4b/tt/mmp.py b/models/experimental/gemma3/tt/mmp.py similarity index 95% rename from models/experimental/gemma3_4b/tt/mmp.py rename to models/experimental/gemma3/tt/mmp.py index d1b4a600c563..3e445e99606f 100644 --- a/models/experimental/gemma3_4b/tt/mmp.py +++ b/models/experimental/gemma3/tt/mmp.py @@ -1,5 +1,5 @@ """ -This is the implmentation of MultiModalprojector for Gemma-3-4b-it model. +This is the implmentation of MultiModalprojector for Gemma3 model. There is no Independent MultiModalprojector support in TT-Transformers. 
""" @@ -11,8 +11,7 @@ import ttnn from models.common.lightweightmodule import LightweightModule - -from models.experimental.gemma3_4b.tt.rmsnorm import RMSNorm +from models.experimental.gemma3.tt.gemma_vision_rmsnorm import RMSNorm class TtGemma3MultiModalProjector(LightweightModule): @@ -126,4 +125,7 @@ def forward(self, vision_outputs: ttnn.Tensor) -> ttnn.Tensor: self.mm_input_projection_weight = ttnn.to_layout(self.mm_input_projection_weight, ttnn.TILE_LAYOUT) projected_vision_outputs = ttnn.matmul(normed_vision_outputs, self.mm_input_projection_weight) + ttnn.deallocate(pooled_vision_outputs) + ttnn.deallocate(normed_vision_outputs) + return projected_vision_outputs diff --git a/models/experimental/gemma3/tt/rmsnorm.py b/models/experimental/gemma3/tt/rmsnorm.py new file mode 100644 index 000000000000..a61f13836f2d --- /dev/null +++ b/models/experimental/gemma3/tt/rmsnorm.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 +import ttnn +from models.common.lightweightmodule import LightweightModule + +TILE = 32 +SHARD_HEIGHT = TILE # Current ttnn.rms_norm implementation requires shard height to be a single tile + + +class RMSNorm(LightweightModule): + """ + RMSNorm supporting replication over a MeshDevice and sharding within devices. + + This class implements a Root Mean Square Normalization (RMSNorm) that can be + distributed across multiple devices and cores. If the `device` parameter is a + MeshDevice, the weights and computations are replicated across all devices in + the mesh. Expects an interleaved input tensor, can optionally output a sharded tensor. + + Args: + device: The device or MeshDevice on which to perform the computations. + state_dict: The state dictionary containing the model parameters. + dim: Input dimension (e.g. model hidden dimension size). + layer_num: The layer number to determine the weight key in the state dictionary. + weight_key: The key for retrieving the weight from the state dictionary. + weight_cache_path: Optional path for caching the tilized weights. + weight_memory_config: Configuration for the weight memory, default is DRAM_MEMORY_CONFIG. + weight_dtype: The data type for the tensors, bfp8_b hits >0.999 PCC in the models we tested. + model_config: Optional configuration dictionary for the model. + eps (float): Small value to avoid division by zero in normalization, default is 1e-05. + + If model_config is provided, it must specify SHARDED_NORM_INPUT_MEMCFG, SHARDED_NORM_PRGM_CFG + and SHARDED_NORM_OUTPUT_MEMCFG. If not provided, default configurations will be generated. 
+ """ + + def __init__( + self, + device, + dim, + state_dict, + weight_key, + layer_num=None, + state_dict_prefix=None, + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps: float = 1e-06, + add_unit_offset=False, + sharded_program_config=None, + sharded_output_config=None, + output_mem_config=None, + ccl_topology=ttnn.Topology.Ring, + tt_ccl=None, + ): + super().__init__() + self.device = device + self.eps = eps + self.is_distributed = is_distributed + self.ccl_topology = ccl_topology + self.tt_ccl = tt_ccl + + if state_dict_prefix: + weight_name = f"{state_dict_prefix}{weight_key}.weight" + else: + if layer_num is None: + weight_name = f"{weight_key}.weight" + else: + weight_name = f"layers.{layer_num}.{weight_key}.weight" + + torch_weight = ( + state_dict[weight_name].unsqueeze(0).view(1, 1, dim).reshape([1, 1, dim // SHARD_HEIGHT, SHARD_HEIGHT]) + ) + + # Add offset before caching + if add_unit_offset: + torch_weight = torch_weight + 1.0 + + cache_name = None if weight_cache_path is None else weight_cache_path / weight_name + + # Compatibility with models that don't use mesh devices (e.g. single-chip Mistral-7b) + is_mesh_device = device.__class__.__name__ == "MeshDevice" + + self.weight = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=cache_name, + mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + ) + + if self.is_distributed: + self.weight_distributed = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=( + None if weight_cache_path is None else weight_cache_path / (weight_name + "_distributed") + ), + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape)) + if is_mesh_device + else None, + ) + + self.sharded_output_config = sharded_output_config + self.sharded_program_config = sharded_program_config + self.output_mem_config = output_mem_config + + self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=False, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) -> ttnn.Tensor: + # If input is sharded do sharded RMSNorm and optionally return sharded output + program_config = self.sharded_program_config if in_sharded else None + memory_config = self.sharded_output_config if out_sharded else None + distributed = self.is_distributed and self.is_distributed(mode) + norm = self._distributed_rmsnorm if distributed else ttnn.rms_norm + weight = self.weight_distributed if distributed else self.weight + + if in_sharded: + assert not distributed, "Distributed RMSNorm does not support sharded inputs" + else: + assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor" + + x = norm( + x, + epsilon=self.eps, + weight=weight, + program_config=program_config, + memory_config=memory_config, + compute_kernel_config=self.compute_kernel_config_hifi2, + ) + + if in_sharded and not out_sharded: + return ttnn.sharded_to_interleaved(x) + else: + return x + + def _distributed_rmsnorm( + self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None + ): + assert program_config is None, "Distributed RMSNorm does not support sharded 
inputs" + assert memory_config is None, "Distributed RMSNorm does not support sharded outputs" + + # Run distributed rmsnorm part 1 + tt_stats = ttnn.rms_norm_pre_all_gather(inp, compute_kernel_config=compute_kernel_config, dtype=ttnn.bfloat16) + # AllGather stats + tt_stats = ttnn.experimental.all_gather_async( + tt_stats, + persistent_output_buffer=None, + dim=3, + multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(), + num_links=1, + topology=self.ccl_topology, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(), + chunks_per_sync=10, + num_workers_per_link=2, + num_buffers_per_channel=2, + ) + # Run distributed rmsnorm part 2 + tt_out = ttnn.rms_norm_post_all_gather( + inp, + tt_stats, + epsilon=epsilon, + weight=weight, + compute_kernel_config=compute_kernel_config, + ) + tt_stats.deallocate(True) + + return tt_out diff --git a/models/experimental/gemma3_4b/tt/siglip_vision_embedding.py b/models/experimental/gemma3/tt/siglip_vision_embedding.py similarity index 94% rename from models/experimental/gemma3_4b/tt/siglip_vision_embedding.py rename to models/experimental/gemma3/tt/siglip_vision_embedding.py index 9483951a78f9..b4522bea810b 100644 --- a/models/experimental/gemma3_4b/tt/siglip_vision_embedding.py +++ b/models/experimental/gemma3/tt/siglip_vision_embedding.py @@ -1,5 +1,5 @@ """ -This is the VisionEmbedding implementation for the Gemma-3-4b-it +This is the VisionEmbedding implementation for the Gemma3 This implementation combines patch_conv followed by Embeddings as a submodule. """ @@ -9,10 +9,10 @@ import torch + import ttnn from models.common.lightweightmodule import LightweightModule - -from models.experimental.gemma3_4b.tt.gemma_conv2d_patch import TtGemmaConv2dPatch +from models.experimental.gemma3.tt.gemma_conv2d_patch import TtGemmaConv2dPatch class TtSiglipVisionEmbeddings(LightweightModule): diff --git a/models/experimental/gemma3_4b/tt/gemma_text_model.py b/models/experimental/gemma3/tt/text_model.py similarity index 65% rename from models/experimental/gemma3_4b/tt/gemma_text_model.py rename to models/experimental/gemma3/tt/text_model.py index 9a433e488507..c0b033b15419 100644 --- a/models/experimental/gemma3_4b/tt/gemma_text_model.py +++ b/models/experimental/gemma3/tt/text_model.py @@ -1,6 +1,6 @@ """ -This is the end-to-end implementation of the Gemma-3-4b-it model. +This is the end-to-end implementation of the Gemma3 model. 
""" @@ -9,20 +9,21 @@ # SPDX-License-Identifier: Apache-2.0 import ttnn -from models.experimental.gemma3_4b.tt.rmsnorm import RMSNorm +from models.experimental.gemma3.tt.rmsnorm import RMSNorm -# from models.tt_transformers.tt.model import Transformer from models.common.lightweightmodule import LightweightModule -from models.tt_transformers.tt.embedding import Embedding +from models.tt_transformers.tt.embedding import Embedding, ScaledEmbedding from models.tt_transformers.tt.rope import RotarySetup -from models.experimental.gemma3_4b.tt.decoder import TransformerBlock +from models.experimental.gemma3.tt.decoder import TransformerBlock from models.tt_transformers.tt.distributed_norm import DistributedNorm from tqdm import tqdm import torch -from models.experimental.gemma3_4b.tt.lm_head import LMHead +from models.experimental.gemma3.tt.lm_head import LMHead from models.tt_transformers.tt.model_config import TensorGroup from models.tt_transformers.tt.common import copy_host_to_device +from models.utility_functions import nearest_32 +from models.tt_transformers.tt.ccl import TT_CCL class Gemma3Transformer(LightweightModule): @@ -31,55 +32,63 @@ def __init__( args, dtype, mesh_device, - tt_ccl, state_dict, weight_cache_path, paged_attention_config=None, use_paged_kv_cache=False, + attention_class=None, + rope_setup_class=None, + attn_mask=None, ): super().__init__() self.args = args + self.paged_attention_config = paged_attention_config self.vocab_size = args.vocab_size + self.tt_ccl = TT_CCL(mesh_device) assert self.vocab_size > 0 self.n_layers = args.n_layers self.mesh_device = mesh_device - self.tt_ccl = tt_ccl self.dtype = dtype self.model_config = args.get_model_config() self.grid_size = self.args.max_grid_size state_dict_prefix = args.get_state_dict_prefix("", None) - - self.embd = Embedding( - mesh_device=mesh_device, - args=args, - weight_cache_path=args.weight_cache_path(dtype), - state_dict=state_dict, - dtype=ttnn.bfloat16, # Row major layout requires bfloat16 - ) - - self.rope_setup = RotarySetup( - mesh_device, - args.max_batch_size, - args.head_dim, - args.max_seq_len, - args.rope_theta, - args.rope_scaling, + self.tt_ccl = TT_CCL(self.mesh_device) + + embd_kwargs = { + "mesh_device": mesh_device, + "args": args, + "weight_cache_path": args.weight_cache_path(dtype), + "state_dict": state_dict, + "dtype": ttnn.bfloat16, # Row major layout requires bfloat16 + } + if self.args.embed_scale is not None: + embd_cls = ScaledEmbedding + embd_kwargs["embed_scale"] = self.args.embed_scale + else: + embd_cls = Embedding + self.embd = embd_cls(**embd_kwargs) + + ActualRopeSetupClass = rope_setup_class if rope_setup_class is not None else RotarySetup + self.rope_setup = ActualRopeSetupClass( + device=mesh_device, + batch_size=args.max_batch_size, + head_dim=args.head_dim, + max_seq_len=args.max_seq_len, + rope_theta=args.rope_theta, + rope_scaling=args.rope_scaling, ) if args.rope_theta_local: - self.rope_setup_local = RotarySetup( + self.rope_local_setup = RotarySetup( mesh_device, args.max_batch_size, args.head_dim, args.max_seq_len, args.rope_theta_local, - None, + rope_scaling=None, ) - else: - self.rope_setup_local = None - trans_mats_dict = self.rope_setup.get_both_trans_mats() - trans_mats_dict_local = self.rope_setup_local.get_both_trans_mats() + self.trans_mats_dict = self.rope_setup.get_both_trans_mats() self.layers = [ TransformerBlock( @@ -90,14 +99,13 @@ def __init__( state_dict=state_dict, weight_cache_path=weight_cache_path, layer_num=i, - transformation_mats=trans_mats_dict, - 
transformation_mats_local=trans_mats_dict_local, + transformation_mats=self.trans_mats_dict, paged_attention_config=paged_attention_config, use_paged_kv_cache=use_paged_kv_cache, + attention_class=attention_class, ) for i in tqdm(range(self.n_layers)) ] - self.cross_attention_layers = self.layers self.norm = DistributedNorm( RMSNorm( device=mesh_device, @@ -131,7 +139,30 @@ def __init__( max_columns_per_device=self.args.max_columns_per_device_lm_head, ) - self.embed_scale = args.dim**0.5 + self.host_embed = self.args.reference_embedding() + + def setup_cache(self, max_batch_size): + self.cache_is_setup = True + + # Prepare xattn_caches + chunk_length = nearest_32(self.args.vision_chunk_ntok) + vision_seq_len = self.args.vision_max_num_chunks * chunk_length + xattn_cache = [ + [ + ttnn.from_torch( + torch.zeros(max_batch_size, self.args.n_kv_heads, vision_seq_len, self.args.head_dim), + device=self.mesh_device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=1), + ) + for _ in range(2) + ] + for l in range(len(self.cross_attention_layers)) + ] + + return xattn_cache def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): """ @@ -139,21 +170,67 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag tensors on device. TODO: Debate whether this function is responsible for padding """ + if not kwargs.get("processed_inputs", None): + tokens = tokens.reshape(1, 1, 1, -1) + S = tokens.shape[-1] + tokens = ttnn.from_torch( + tokens, + device=self.mesh_device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + tokens_embd = self.embd(tokens) + else: + S = tokens.shape[-1] - assert tokens.dim() == 2, "tokens must be a 2D tensor" - tokens = tokens.reshape(1, 1, 1, -1) - S = tokens.shape[-1] - tokens = ttnn.from_torch( - tokens, - device=self.mesh_device, - dtype=ttnn.uint32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - tokens_embd = self.embd(tokens) - tokens_embd = ttnn.multiply(tokens_embd, self.embed_scale) - tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) + tokens_embd = ttnn.from_torch( + tokens.reshape(1, 1, 1, -1), + device=self.mesh_device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + tokens_embd = self.embd(tokens_embd) + + pixel_values = kwargs["processed_inputs"]["pixel_values"] + input_ids = kwargs["processed_inputs"]["input_ids"] + if pixel_values is not None: + vision_model = kwargs["vision_model"] + vision_output = vision_model(pixel_values) + + tokens_embd = ttnn.to_torch( + tokens_embd, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1) + ) + + comp_vision_output = ttnn.to_torch( + vision_output, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1) + )[:, :, :, : vision_output.shape[-1]] + + comp_vision_output = torch.nn.functional.pad( + comp_vision_output, (0, 0, 0, tokens_embd.shape[1] - comp_vision_output.shape[1]), "constant", 0 + ) + input_ids = torch.nn.functional.pad( + input_ids, (0, tokens_embd.shape[1] - input_ids.shape[1]), "constant", 0 + ) + image_features = comp_vision_output.squeeze(0) + special_image_mask = (input_ids == self.args.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(tokens_embd) + image_features = 
image_features.to(tokens_embd.device, tokens_embd.dtype) + tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features) + + tokens_embd = ttnn.from_torch( + tokens_embd, + dtype=ttnn.bfloat16, + device=self.mesh_device, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh( + self.mesh_device, dims=(None, -1), mesh_shape=list(self.mesh_device.shape) + ), + ) + tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) # Slice the rot mats to the prefill seqlen assert ( self.rope_setup.cos_matrix.shape[2] >= start_pos + S @@ -164,10 +241,10 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :], ] - if self.rope_setup_local is not None: + if hasattr(self, "rope_local_setup"): tt_rot_mats_prefill_local = [ - self.rope_setup_local.cos_matrix[:, :, start_pos : start_pos + S, :], - self.rope_setup_local.sin_matrix[:, :, start_pos : start_pos + S, :], + self.rope_local_setup.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_local_setup.sin_matrix[:, :, start_pos : start_pos + S, :], ] else: tt_rot_mats_prefill_local = None @@ -194,7 +271,13 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag else: tt_chunk_page_table = None - return tokens_embd, tt_rot_mats_prefill_global, tt_rot_mats_prefill_local, tt_page_table, tt_chunk_page_table + return ( + tokens_embd, + tt_rot_mats_prefill_global, + tt_rot_mats_prefill_local, + tt_page_table, + tt_chunk_page_table, + ) def prepare_inputs_decode(self, *inputs): """ @@ -230,8 +313,8 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None): current_pos, torch.tensor(0, dtype=torch.int64) ) # Ensure position indices are non-negative rope_idxs_global = self.rope_setup.get_rot_idxs(rot_current_pos, on_host=True) - if self.rope_setup_local is not None: - rope_idxs_local = self.rope_setup_local.get_rot_idxs(rot_current_pos, on_host=True) + if hasattr(self, "rope_local_setup"): + rope_idxs_local = self.rope_local_setup.get_rot_idxs(rot_current_pos, on_host=True) else: rope_idxs_local = None @@ -257,6 +340,7 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None): mesh_shape=self.args.cluster_shape, ), ) + return tokens, current_pos_tt, rope_idxs_global, rope_idxs_local, page_table def _transform_decode_inputs_device(self, tokens): @@ -269,7 +353,6 @@ def _transform_decode_inputs_device(self, tokens): Embed tokens """ tt_tokens = self.embd(tokens) - tt_tokens = ttnn.multiply(tt_tokens, self.embed_scale) tt_tokens = ttnn.unsqueeze_to_4D(tt_tokens) tt_tokens = ttnn.to_memory_config( tt_tokens, @@ -277,34 +360,33 @@ def _transform_decode_inputs_device(self, tokens): ) return tt_tokens - def concat_device_output(self, tt_out): - """ - Concatenate the output of the devices into a single tensor. - """ - torch_out_tensors = [ttnn.to_torch(x) for x in ttnn.get_device_tensors(tt_out.cpu())] - if self.args.is_galaxy: - row_dim, col_dim = (3, 1) - else: - row_dim, col_dim = (1, -1) - - rows, cols = self.args.cluster_shape - mesh_shape = [torch_out_tensors[i : i + cols] for i in range(0, len(torch_out_tensors), cols)] - row_concatenated = [torch.cat(row, dim=col_dim) for row in mesh_shape] - return torch.cat(row_concatenated, dim=row_dim) - def process_output_prefill(self, tt_out, last_token_idx): """ Input is ttnn device tensor of logits. Output is torch logits tensor. 
NOTE: In this model, prefill always uses get_last_token """ - return self.concat_device_output(tt_out)[0, 0, last_token_idx, : self.vocab_size] + logits = ttnn.to_torch( + tt_out, + mesh_composer=ttnn.ConcatMesh2dToTensor( + self.mesh_device, dims=(3, 1) if self.args.is_galaxy else (1, -1), mesh_shape=self.args.cluster_shape + ), + )[0, 0, last_token_idx, : self.vocab_size] + return logits def process_output_decode(self, tt_out, B, S=1, is_tokens=False): """ Input is ttnn device tensor of logits if is_tokens=False, otherwise tokens. Output is the corresponding torch tensor. """ if is_tokens: - return self.concat_device_output(tt_out)[0, 0, :B, 0] + tt_out = ttnn.to_torch( + tt_out, # tt_out.cpu(blocking=True, cq_id=1), + mesh_composer=ttnn.ConcatMesh2dToTensor( + self.mesh_device, + dims=(3, 1) if self.args.is_galaxy else (1, -1), + mesh_shape=self.args.cluster_shape, + ), + )[0, 0, :B, 0] + return tt_out if self.args.num_devices > 1: tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() @@ -329,6 +411,18 @@ def ttnn_prefill_forward( This method will take device tensors and any other args to run forward. It returns ttnn device tensors. """ + if hasattr(self.args, "sliding_window") and self.args.sliding_window is not None: + mask = torch.triu(torch.full((1, 1, x.shape[-2], x.shape[-2]), -float("inf")), diagonal=1) + sliding_mask = mask + torch.tril( + torch.full((1, 1, x.shape[-2], x.shape[-2]), -float("inf")), + diagonal=-self.args.sliding_window, + ) + sliding_attn_mask = ttnn.from_torch( + sliding_mask, device=self.mesh_device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16 + ) + else: + sliding_attn_mask = None + return self.forward( x, current_pos=None, @@ -359,6 +453,24 @@ def _increment_decode_positions_device(self, current_pos, rot_mat_idxs_global, r if rot_mat_idxs_local is not None: ttnn.plus_one(rot_mat_idxs_local) + def update_attention_masks(self, current_pos): + torch_mask = torch.concat( + [ + self.decode_sliding_mask_mat[i, :, current_pos[i].item() : current_pos[i].item() + 1, :].unsqueeze(0) + for i in range(self.decode_sliding_mask_mat.shape[0]) + ], + axis=0, + ).transpose(1, 2) + sliding_window_causal_mask = ttnn.as_tensor( + torch_mask, + dtype=ttnn.bfloat4_b, + layout=ttnn.TILE_LAYOUT, + device=None, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + ttnn.copy_host_to_device_tensor(sliding_window_causal_mask, self.device_decode_sliding_mask) + def ttnn_decode_forward( self, x, @@ -374,11 +486,9 @@ def ttnn_decode_forward( It returns ttnn device tensors. """ rot_mats_global = self.rope_setup.get_rot_mats(rot_mat_idxs_global) - if self.rope_setup_local is not None: - rot_mats_local = self.rope_setup_local.get_rot_mats(rot_mat_idxs_local) - else: - rot_mats_local = None - + rot_mats_local = ( + self.rope_local_setup.get_rot_mats(rot_mat_idxs_local) if rot_mat_idxs_local is not None else None + ) x_embed = self._transform_decode_inputs_device(x) tt_logits = self.forward( @@ -422,9 +532,7 @@ def ttnn_decode_forward( ttnn.copy(tt_logits.reshape(x.shape), x) elif not self.args.is_galaxy: # Send output logits to DRAM so L1 is not reserved for ttnn tracing and can be used by subsequent operations - # TODO Investigate why moving to DRAM fails, it never should! 
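# --- Illustrative aside (not part of this patch) ----------------------------------
# A small torch sketch of the sliding-window causal mask that ttnn_prefill_forward
# above builds from self.args.sliding_window: query position i may attend to keys j
# with i - window < j <= i. S and window below are assumed example values.
import torch

S, window = 6, 3
causal = torch.triu(torch.full((S, S), float("-inf")), diagonal=1)  # mask future keys (j > i)
sliding = causal + torch.tril(torch.full((S, S), float("-inf")), diagonal=-window)  # also mask j <= i - window
print((sliding == 0).int())  # 1 where attention is allowed, e.g. row 4 -> keys {2, 3, 4}
# -----------------------------------------------------------------------------------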
- # tt_logits = ttnn.to_memory_config(tt_logits, ttnn.DRAM_MEMORY_CONFIG) - pass + tt_logits = ttnn.to_memory_config(tt_logits, ttnn.DRAM_MEMORY_CONFIG) return tt_logits diff --git a/models/experimental/gemma3_4b/conftest.py b/models/experimental/gemma3_4b/conftest.py deleted file mode 100644 index 21430b096255..000000000000 --- a/models/experimental/gemma3_4b/conftest.py +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 - -# Import the device_params fixture from tt_transformers -from models.tt_transformers.conftest import device_params # noqa: F401 diff --git a/models/experimental/gemma3_4b/tt/gemma_model.py b/models/experimental/gemma3_4b/tt/gemma_model.py deleted file mode 100644 index 8b33502ccdc9..000000000000 --- a/models/experimental/gemma3_4b/tt/gemma_model.py +++ /dev/null @@ -1,108 +0,0 @@ -""" -This is the Gemma3 end-to-end model. -""" - -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 - -import ttnn -import torch -from models.experimental.gemma3_4b.tt.gemma_text_model import Gemma3Transformer -from models.experimental.gemma3_4b.tt.gemma_vision_model import TtSiglipGemmaVisionModel -from models.experimental.gemma3_4b.tt.mmp import TtGemma3MultiModalProjector -from models.tt_transformers.tt.ccl import TT_CCL - - -class TtGemma3Model(Gemma3Transformer): - def __init__( - self, - args, - dtype, - mesh_device, - state_dict, - weight_cache_path, - paged_attention_config=None, - use_paged_kv_cache=False, - ): - self.tt_ccl = TT_CCL(mesh_device) - - super().__init__( - args=args, - dtype=dtype, - mesh_device=mesh_device, - tt_ccl=self.tt_ccl, - state_dict=state_dict, - weight_cache_path=weight_cache_path, - paged_attention_config=paged_attention_config, - use_paged_kv_cache=use_paged_kv_cache, - ) - - self.vision_encoder = TtSiglipGemmaVisionModel( - mesh_device, - state_dict=state_dict, - tt_ccl=self.tt_ccl, - state_dict_prefix=args.state_dict_vision_prefix, - weight_cache_path=args.weight_cache_path(dtype), - dtype=dtype, - configuration=args, - ) - - self.mmp = TtGemma3MultiModalProjector( - mesh_device=mesh_device, - state_dict=state_dict, - state_dict_prefix="multi_modal_projector", - image_size=args.image_size, - patch_size=args.vision_patch_size, - hidden_size=args.vision_hidden_dim, - mm_tokens_per_image=args.mm_tokens_per_image, - weight_cache_path=args.weight_cache_path(dtype), - layer_norm_eps=1e-06, # layer_norm_eps - dtype=dtype, - configuration=args, - ) - - def prepare_inputs_prefill(self, pt_tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): - """ - Inputs are torch tensors or python types. This function returns ttnn - tensors on device. 
- TODO: Debate whether this function is responsible for padding - """ - - tokens_embd, *kwargs_out = super().prepare_inputs_prefill( - pt_tokens, start_pos, page_table, chunk_page_table, **kwargs - ) - - if kwargs.get("pixel_values") is not None: - vision_output = self.compute_vision_token(kwargs["pixel_values"]) - - # TODO: Move tokens merging to device - - tokens_embd = ttnn.to_torch(tokens_embd) - comp_vision_output = ttnn.to_torch(ttnn.from_device(vision_output)) - comp_vision_output = torch.nn.functional.pad( - comp_vision_output, (0, 0, 0, tokens_embd.shape[1] - comp_vision_output.shape[1]), "constant", 0 - ) - - input_ids = torch.nn.functional.pad( - pt_tokens, (0, tokens_embd.shape[1] - pt_tokens.shape[1]), "constant", 0 - ) - image_features = comp_vision_output.squeeze(0) - special_image_mask = (input_ids == self.args.image_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(tokens_embd) - image_features = image_features.to(tokens_embd.device, tokens_embd.dtype) - tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features) - tokens_embd = ttnn.from_torch( - tokens_embd, - dtype=ttnn.bfloat16, - device=self.mesh_device, - layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - - return tokens_embd, *kwargs_out - - def compute_vision_token(self, pixel_values): - vision_tokens = self.vision_encoder(pixel_values)[0, :, :, :] - vision_output = self.mmp(vision_tokens) - return vision_output diff --git a/models/tt_transformers/tt/common.py b/models/tt_transformers/tt/common.py index 3d3cf1ecd177..66922dc0dbb1 100644 --- a/models/tt_transformers/tt/common.py +++ b/models/tt_transformers/tt/common.py @@ -3,10 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 import math +import os import re from enum import Enum from types import SimpleNamespace -from typing import Optional +from typing import Callable, Optional import torch from llama_models.llama3.api.datatypes import ImageMedia @@ -248,7 +249,10 @@ def encode_prompt_hf(tokenizer, prompt_text, system_prompt_text=None): chat.append({"role": "user", "content": prompt_text}) return tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=True) else: - return tokenizer.apply_chat_template(prompt_text, add_generation_prompt=True, tokenize=True) + output = tokenizer.apply_chat_template([prompt_text], add_generation_prompt=True, tokenize=True) + if len(output) == 1: + output = output[0] + return output def compute_llama3_parameters(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): @@ -367,6 +371,42 @@ def get_prefill_rot_mat(head_dim, mesh_device, seq_len, theta, scale_factor, ori rot_mats = [cos_gathereds, sin_gathereds] return rot_mats +def compute_linear_parameters(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): + """Linear scaling for rotary embeddings.""" + freqs /= scale_factor + return freqs + + +def apply_scaling(freqs: torch.Tensor, scale_factor: float, orig_context_len: int, rope_type="llama3"): + # FIXME: Llama-3.x specific scaling - we need to support yarn for Qwen2.5 models + + if rope_type == "linear": + freqs = compute_linear_parameters(freqs, scale_factor, orig_context_len) + elif rope_type == "llama3": + freqs = compute_llama3_parameters(freqs, scale_factor, orig_context_len) + + return freqs + + +def precompute_freqs(dim: int, end: int, theta, scale_factor, orig_context_len, rope_type="llama3"): + """ + Precompute the frequency tensor for sine and cosine values with given dimensions. 
+ + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 500000.0. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tensors containing cosine and sine values. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end) + if scale_factor is not None: + freqs = apply_scaling(freqs, scale_factor, orig_context_len, rope_type=rope_type) + freqs = torch.outer(t, freqs).float() + return torch.cos(freqs), torch.sin(freqs) + # Add-Multiply method of rotary embeddings for prefill def get_rot_transformation_mat(dhead): @@ -673,7 +713,11 @@ def create_tt_model( state_dict=None, num_layers=None, ): - from models.tt_transformers.tt.model import Transformer + if "HF_MODEL" in os.environ and "gemma-3" in os.environ["HF_MODEL"].lower(): + from models.experimental.gemma3.tt.text_model import Gemma3Transformer as Transformer + else: + from models.tt_transformers.tt.model import Transformer + from models.tt_transformers.tt.model_config import ModelArgs tt_model_args = ModelArgs( @@ -741,7 +785,7 @@ def hf_multimodal_encode(messages, processor): **encoded, tokens=encoded["input_ids"].squeeze(0), vision=SimpleNamespace( - images=encoded.get("pixel_values", None), + images=encoded["pixel_values"], mask=None, ), ) diff --git a/models/tt_transformers/tt/generator.py b/models/tt_transformers/tt/generator.py index d1f4524a728d..534e7a9260cf 100644 --- a/models/tt_transformers/tt/generator.py +++ b/models/tt_transformers/tt/generator.py @@ -210,7 +210,6 @@ def prefill_forward_single_user_text( chunk_start_idx=chunk_start, get_last_token=(last_token_idx_in_chunk // 32) * 32, kv_cache=kv_cache, - **kwargs, ) if chunk_start == last_chunk_start: @@ -218,13 +217,9 @@ def prefill_forward_single_user_text( else: del tt_logits else: - ( - prefill_input, - rot_mats_global_prefill, - rot_mats_local_prefill, - page_table_tt, - _, - ) = self.model[model_id].prepare_inputs_prefill( + (prefill_input, rot_mats_global_prefill, rot_mats_local_prefill, page_table_tt, _) = self.model[ + model_id + ].prepare_inputs_prefill( tokens, page_table=page_table, **kwargs, @@ -299,7 +294,6 @@ def _decode_forward_no_trace_text( tt_rot_mat_idxs_global = [] tt_rot_mat_idxs_local = [] tt_page_table = [] - for i in range(self.data_parallel): user_page_table = page_table[i] if page_table is not None else None model_i = self.model[i] @@ -410,6 +404,12 @@ def _easy_trace_text( host_tensors=host_inputs_i, device_tensors=self.trace_inputs_text[i], ) + for i in range(self.data_parallel): + if ( + hasattr(self.model[i], "device_decode_sliding_mask") + and self.model[i].device_decode_sliding_mask is not None + ): + self.model[i].update_attention_masks(current_pos[i]) for i, trace_id in self.trace_ids_text.items(): ttnn.execute_trace(self.model_args[i].mesh_device, trace_id, cq_id=0, blocking=False) diff --git a/models/tt_transformers/tt/model.py b/models/tt_transformers/tt/model.py index f9bbbf2020b1..cd689e051f46 100644 --- a/models/tt_transformers/tt/model.py +++ b/models/tt_transformers/tt/model.py @@ -188,7 +188,13 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag else: tt_chunk_page_table = None - return tokens_embd, tt_rot_mats_prefill_global, tt_rot_mats_prefill_local, tt_page_table, tt_chunk_page_table + return ( + tokens_embd, + tt_rot_mats_prefill_global, + tt_rot_mats_prefill_local, + tt_page_table, + tt_chunk_page_table, + ) def 
prepare_inputs_decode(self, *inputs): """ diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index 682895c91853..1eea6bab32cb 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -574,7 +574,9 @@ def __init__( if max_prefill_chunk_size_div1024 is None: # TODO Improve this to be more general to more devices and models MAX_PREFILL_CHUNK_SIZES_DIV1024 = { - "gemma-3-4b": {"N150": 8, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, + "gemma-3-1b": {"N150": 128, "N300": None, "T3K": None, "TG": None, "P150x4": None}, + "gemma-3-4b": {"N150": 128, "N300": 128, "T3K": None, "TG": None, "P150x4": 128}, + "gemma-3-27b": {"N150": None, "N300": None, "T3K": 128, "TG": 128, "P150x4": 128}, "Llama-3.2-1B": {"N150": 128, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, "Llama-3.2-3B": {"N150": 8, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, "Llama-3.1-8B": {"N150": 4, "N300": 64, "T3K": 128, "TG": 128, "P150x4": 128}, @@ -610,9 +612,9 @@ def __init__( self.max_prefill_chunk_size = max_prefill_chunk_size_div1024 * 1024 if ( - self.base_model_name in ["Llama-3.1-8B", "Llama-3.2-11B", "Mistral-7B", "gemma-3-4b"] + self.base_model_name in ["Llama-3.1-8B", "Llama-3.2-11B", "Mistral-7B", "gemma-3-4b", "gemma-3-1b"] and self.device_name == "N150" - ) or (self.base_model_name in ["Qwen2.5-7B"] and self.device_name == "N300"): + ) or (self.base_model_name in ["Qwen2.5-7B", "gemma-3-4b"] and self.device_name == "N300"): logger.info(f"Reducing prefill_len_cutoff to 512 for {self.model_name} on {self.device_name}") self.prefill_len_cutoff = 512 elif self.base_model_name in ["Mixtral-8x7B"] and self.device_name == "T3K": @@ -1451,6 +1453,7 @@ def prepare_residual_tensor_prefill(self, x_bsh, force_replicated=False): xs_1BSH = ttnn.from_torch( x_1BSH, device=self.mesh_device, + # dtype=ttnn.bfloat8_b, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, @@ -1502,8 +1505,6 @@ def _set_params_from_dict(self, config, is_hf=False): # Try to get text_config, if it doesn't exist everything is text config text_config = config.get("text_config", config) self.eos_token_id = None if isinstance(eos_token_id, int) else eos_token_id - layer_types = text_config["layer_types"] if "layer_types" in text_config else None - # Common params with different names between Meta and HF self.dim = text_config.get("dim", text_config.get("hidden_size")) self.n_heads = text_config.get("n_heads", text_config.get("num_attention_heads")) @@ -1513,10 +1514,7 @@ def _set_params_from_dict(self, config, is_hf=False): # they are calculated in HF but not calculated in Meta self.n_layers -= len(text_config.get("cross_attention_layers", ())) - self.sliding_window_pattern = ( - [lt == "sliding_attention" for lt in layer_types] if layer_types is not None else [False] * self.n_layers - ) - + self.sliding_window = text_config.get("sliding_window", 0) self.full_model_n_layers = self.n_layers self.norm_eps = text_config.get("norm_eps", text_config.get("rms_norm_eps")) self.vocab_size = text_config["vocab_size"] @@ -1609,7 +1607,7 @@ def _set_params_from_dict(self, config, is_hf=False): ) self.query_pre_attn_scalar = text_config.get("query_pre_attn_scalar", None) - + self.sliding_window = text_config.get("sliding_window", None) # Configurable MLP activation type self.mlp_activation_type = self._get_hidden_activation_type(text_config) @@ -1634,7 +1632,9 @@ def vision_chunk_ntok(self): """ Returns the number of tokens per 
chunk, accounting for the extra class token """ - return (self.vision_chunk_size // self.vision_patch_size) ** 2 + 1 + if self.is_llama_vision(): + return (self.vision_chunk_size // self.vision_patch_size) ** 2 + 1 + return (self.image_size // self.vision_patch_size) ** 2 + 1 def _set_model_params(self, checkpoint_dir): if self.checkpoint_type == CheckpointType.Meta: @@ -2533,6 +2533,10 @@ def reference_vision_transformer(self, wrap=True, load_checkpoint=False): if self.cached_hf_model is None: model = model_cls.from_pretrained(self.CKPT_DIR, local_files_only=os.getenv("CI") == "true") self.cached_hf_model = model + if "gemma-3" in self.model_name: + from transformers import Gemma3ForConditionalGeneration + + model = Gemma3ForConditionalGeneration.from_pretrained(self.CKPT_DIR) else: model = self.cached_hf_model model.model.layers = model.model.layers[: self.n_layers] @@ -2567,7 +2571,7 @@ def reference_embedding(self, reference_model=None): model = self.reference_transformer(wrap=False) layer = model.model.embed_tokens else: - layer = reference_model.model.model.embed_tokens + layer = reference_model.model.embed_tokens layer._load_state_dict = layer.load_state_dict layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) @@ -2581,15 +2585,9 @@ def reference_decoder(self): else: model = self.reference_transformer(wrap=False) layer = model.model.layers[0] - use_position_embeddings = layer.__class__.__name__ != "Phi3DecoderLayer" - model_name_env = os.getenv("HF_MODEL") - if hasattr(model.model, "rotary_emb_local"): - rotary_emb_local = model.model.rotary_emb_local - else: - rotary_emb_local = None - wrapper = HfDecoderWrapper( - layer, self.head_dim, model.model.rotary_emb if use_position_embeddings else None, rotary_emb_local - ) + rotary_emb_local = getattr(model.model, "rotary_emb_local", None) + wrapper = HfDecoderWrapper(layer, self.head_dim, model.model.rotary_emb, rotary_emb_local=rotary_emb_local) + return wrapper def reference_attention(self): @@ -2797,7 +2795,6 @@ def forward(self, x, start_pos, freqs_cis_i, mask=None): position_ids=position_ids, attention_mask=mask, ) - output = result[0] return output diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp index 712def37d7cc..4a11a8680b4d 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp @@ -62,6 +62,7 @@ void MAIN { constexpr uint32_t q_heads_parallel_factor = get_compile_time_arg_val(26); constexpr bool use_half_tile = get_compile_time_arg_val(27); constexpr uint32_t scale_fp32 = get_compile_time_arg_val(28); + constexpr uint32_t sliding_window = get_compile_time_arg_val(29); constexpr uint32_t q_chunk_tiles = Sq_chunk_t * DHt; constexpr uint32_t out_chunk_tiles = Sq_chunk_t * vDHt; @@ -73,6 +74,7 @@ void MAIN { constexpr uint32_t cb_k_in = tt::CBIndex::c_1; constexpr uint32_t cb_v_in = tt::CBIndex::c_2; constexpr uint32_t cb_mask_in = tt::CBIndex::c_3; + constexpr uint32_t cb_sliding_window_mask_in = tt::CBIndex::c_13; // Separate buffer for sliding window mask constexpr uint32_t cb_attention_sink = tt::CBIndex::c_4; constexpr uint32_t cb_identity_scale_in = tt::CBIndex::c_5; constexpr uint32_t cb_m_in = tt::CBIndex::c_6; @@ -144,8 +146,13 @@ void MAIN { auto k_chunk_size_dynamic = 
Sk_chunk_t_dynamic * tt::constants::TILE_HEIGHT; // Get the sequence length assignment - auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end] = - get_runtime_args(cur_pos, cur_batch, core_num_in_reduce, num_cores_per_head, k_chunk_size_dynamic); + auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end, window_start_unaligned, window_start_chunk] = get_runtime_args( + cur_pos, + cur_batch, + core_num_in_reduce, + num_cores_per_head, + k_chunk_size_dynamic, + sliding_window > 0 ? std::optional(sliding_window) : std::nullopt); if (k_chunk_start == k_chunk_end) { return; // early exit because no computes needs to be done } @@ -270,13 +277,22 @@ void MAIN { // OPTIMIZATION: Add the attention mask directly on top of DST if chunk sizes are dynamic #ifdef DYNAMIC_CHUNK_SIZE - bool add_mask_fusion = - is_causal && k_chunk == k_chunk_end - 1 && apply_mask_at_last_chunk || use_attention_mask; + bool add_causal_mask_fusion = is_causal && k_chunk == k_chunk_end - 1 && apply_mask_at_last_chunk; + bool add_sliding_window_mask_fusion = k_chunk == window_start_chunk; + bool add_mask_fusion = add_causal_mask_fusion || use_attention_mask || add_sliding_window_mask_fusion; #else bool add_mask_fusion = false; + bool add_causal_mask_fusion = false; + bool add_sliding_window_mask_fusion = false; #endif /* QK = Q_CHUNK @ K_CHUNK */ + // Determine which mask buffer to use for fusion + uint32_t mask_cb_to_use = cb_mask_in; // Default to causal mask buffer + if (add_sliding_window_mask_fusion) { + mask_cb_to_use = cb_sliding_window_mask_in; // Use sliding window mask buffer + } + cb_matmul_blocks( cb_q_in, cb_k_in, @@ -292,7 +308,7 @@ void MAIN { qk_subblock_w_dynamic, true, add_mask_fusion, - cb_mask_in, + mask_cb_to_use, cb_zero_in); /* QK += MASK */ @@ -309,6 +325,12 @@ void MAIN { add_block_inplace(cb_qk_im, cb_mask_in, qk_chunk_tiles_dynamic); } } + + // Apply sliding window mask to the first chunk (only on the core that processes it) + if (k_chunk == window_start_chunk && window_start_unaligned > 0) { + reconfig_data_format(cb_qk_im, cb_sliding_window_mask_in); + add_block_inplace(cb_qk_im, cb_sliding_window_mask_in, qk_chunk_tiles_dynamic); + } } /** diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/dataflow_common.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/dataflow_common.hpp index dded84fd5f02..37e50aec5015 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/dataflow_common.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/dataflow_common.hpp @@ -136,6 +136,76 @@ void fill_tile_partial(uint32_t cb_id, uint32_t tile_id, uint32_t cur_pos_in_til } } +template +void fill_tile_partial_sliding_window( + uint32_t cb_id, uint32_t tile_id, uint32_t window_start_pos_in_tile, uint32_t partial_val) { + /* + For sliding window mask: fill positions 0 to window_start_pos_in_tile - 1 with partial_val (-inf) + This is the inverse of fill_tile_partial which fills from cur_pos_in_tile + 1 to end + + Example: if window_start_pos_in_tile = 5, then positions 0,1,2,3,4 are filled with -inf + and positions 5,6,7,...,31 remain as 0 (allowed) + */ + constexpr int num_faces = (tile_bytes == 1024) ? 
2 : 4; + + fill_tile(cb_id, tile_id, 0); + if (window_start_pos_in_tile == 0 || partial_val == 0) { + return; // No masking needed if window starts at position 0 or no mask value + } + + const uint16_t datum_val = partial_val >> 16; + volatile tt_l1_ptr uint16_t* uint16_ptr = + reinterpret_cast(get_write_ptr(cb_id) + tile_id * tile_bytes); + volatile tt_l1_ptr uint32_t* uint32_ptr = + reinterpret_cast(get_write_ptr(cb_id) + tile_id * tile_bytes); + + // Determine which faces to fill completely (before the window_start_pos_in_tile) + int face_start = (window_start_pos_in_tile < 15) ? 0 : 1; // Last face to fill completely + + // Fill complete faces (faces 0, 2, 4, 6... for faces before face_start) + if (face_start == 1) { + constexpr int num_uint32_datums_tile_face = (16 * 16) / 2; + for (int k = 0; k < num_faces; k += 2) { + uint32_t uint32_face_idx = k << 7; + for (int j = 0; j < num_uint32_datums_tile_face; j++) { + uint32_ptr[uint32_face_idx + j] = partial_val; + } + } + } + + // Fill partial face (the face containing window_start_pos_in_tile) + uint32_t fill_end_pos_in_face = window_start_pos_in_tile % 16; // Position to stop filling (exclusive) + + // Optimize performance by filling 2 uint16 datums in each write + bool is_odd_end_pos = fill_end_pos_in_face % 2 == 1; + uint32_t fill_end_pos_in_uint32_face = fill_end_pos_in_face >> 1; + constexpr uint32_t num_cols_in_face = 16; + constexpr uint32_t num_rows_in_face = 16; + constexpr uint32_t num_cols_in_uint32_face = num_cols_in_face >> 1; + + // Fill the face containing window_start_pos_in_tile + int target_face = (window_start_pos_in_tile < 16) ? 0 : 1; + for (int k = target_face; k < num_faces; k += 2) { + uint32_t uint16_face_idx = k << 8; + uint32_t uint32_face_idx = k << 7; + + for (uint32_t face_row_idx = 0; face_row_idx < num_rows_in_face; face_row_idx++) { + // Fill uint32 pairs from start to fill_end_pos_in_uint32_face + for (uint32_t uint32_face_col_idx = 0; uint32_face_col_idx < fill_end_pos_in_uint32_face; + uint32_face_col_idx++) { + uint32_ptr[uint32_face_idx + (uint32_face_col_idx + num_cols_in_uint32_face * face_row_idx)] = + partial_val; + } + + // Handle the odd position if fill_end_pos_in_face is odd + if (is_odd_end_pos && fill_end_pos_in_face > 0) { + uint16_ptr[uint16_face_idx + ((fill_end_pos_in_face - 1) + num_cols_in_face * face_row_idx)] = + datum_val; + } + } + } +} + /****************************************************************************** * Attention Mask Functions * ******************************************************************************/ @@ -265,6 +335,69 @@ void generate_mask(uint32_t k_num_chunks, uint32_t Sk_chunk_t, uint32_t cur_pos) cb_push_back(cb_mask_in, total_read_tiles); } +template +void generate_sliding_window_mask(uint32_t k_num_chunks, uint32_t Sk_chunk_t, uint32_t window_start) { + /* + Generate sliding window mask for the first chunk: + - Mask positions < window_start with -inf (sliding window start) + - Allow positions >= window_start + + This mask is applied only to the first chunk to enforce sliding window constraint. 
+ */ + + // the cb_mask in is of size PNHt * Sk_chunk_t + uint32_t total_read_tiles = PNHt * Sk_chunk_t; + uint32_t window_start_in_chunk = window_start % (Sk_chunk_t * 32); + uint32_t window_start_in_chunk_t = window_start_in_chunk / 32; + uint32_t window_start_in_tile = window_start_in_chunk % 32; + constexpr uint32_t NEG_INF = 0xFF80FF80; // TODO: Make sure this is -inf + + cb_reserve_back(cb_mask_in, total_read_tiles); + + uint64_t noc_read_addr_base = get_noc_addr(get_read_ptr(cb_mask_in)); + uint32_t q_write_ptr_base = get_read_ptr(cb_mask_in); + constexpr uint32_t tile_bytes = get_tile_size(cb_mask_in); + + for (uint32_t i = 0; i < Sk_chunk_t; ++i) { + if (i < window_start_in_chunk_t) { + // Tile is completely before sliding window - fill with -inf + if (i == 0) { + fill_tile(cb_mask_in, i, NEG_INF); + } else { + copy_tile(noc_read_addr_base, q_write_ptr_base, 0, i); + } + } else if (i == window_start_in_chunk_t) { + // Tile contains sliding window start - partial mask at beginning + fill_tile_partial_sliding_window(cb_mask_in, i, window_start_in_tile, NEG_INF); + } else { + // Tile is within sliding window - fill with zeros (allow) + if (i == window_start_in_chunk_t + 1) { + fill_tile(cb_mask_in, i, 0); + } else { + // Copy from the first allowed tile + copy_tile( + noc_read_addr_base, + q_write_ptr_base, + window_start_in_chunk_t + 1, + i); // copy from cb_mask_in[cur_pos_in_chunk_t+1] to cb_mask_in[i] + if (i == Sk_chunk_t - 1) { + noc_async_read_barrier(); + } + } + } + + // Copy to all heads + for (uint32_t j = 1; j < PNHt; ++j) { + copy_tile(noc_read_addr_base, q_write_ptr_base, i, j * Sk_chunk_t + i); + if (j == PNHt - 1) { + noc_async_read_barrier(); + } + } + } + + cb_push_back(cb_mask_in, total_read_tiles); +} + /****************************************************************************** * Writer Kernel Specific Functions * ******************************************************************************/ diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp index 166659edbcfb..4e82c183e88f 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp @@ -8,7 +8,6 @@ #include "ttnn/operations/transformer/sdpa_decode/device/kernels/rt_args_common.hpp" #include "dataflow_common.hpp" -// #include "debug/dprint.h" void kernel_main() { /* @@ -46,8 +45,9 @@ void kernel_main() { constexpr bool is_cur_pos_tensor_sharded = get_compile_time_arg_val(27); constexpr bool is_page_table_sharded = get_compile_time_arg_val(28); constexpr uint32_t q_page_size_bytes = get_compile_time_arg_val(29); + constexpr uint32_t sliding_window = get_compile_time_arg_val(30); - constexpr auto k_args = TensorAccessorArgs<30>(); + constexpr auto k_args = TensorAccessorArgs<31>(); constexpr auto q_args = TensorAccessorArgs(); constexpr auto v_args = TensorAccessorArgs(); constexpr auto mask_args = TensorAccessorArgs(); @@ -113,8 +113,13 @@ void kernel_main() { auto k_chunk_size_dynamic = Sk_chunk_t_dynamic * tt::constants::TILE_HEIGHT; // Sequence length assignment - auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end] = - get_runtime_args(cur_pos, cur_batch, core_num_in_reduce, num_cores_per_head, k_chunk_size_dynamic); + auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end, window_start_unaligned, 
window_start_chunk] = get_runtime_args( + cur_pos, + cur_batch, + core_num_in_reduce, + num_cores_per_head, + k_chunk_size_dynamic, + sliding_window > 0 ? std::optional(sliding_window) : std::nullopt); if (k_chunk_start == k_chunk_end) { return; // early exit because no computes needs to be done diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp index 922395c2a349..9493e73bfa9a 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp @@ -36,8 +36,9 @@ void kernel_main() { constexpr bool is_causal = get_compile_time_arg_val(22) == 1; constexpr uint32_t max_dynamic_chunk_size = get_compile_time_arg_val(23); constexpr uint32_t q_heads_parallel_factor = get_compile_time_arg_val(24); + constexpr uint32_t sliding_window = get_compile_time_arg_val(25); - constexpr auto out_args = TensorAccessorArgs<25>(); + constexpr auto out_args = TensorAccessorArgs<26>(); uint32_t arg_idx = 0; const uint32_t out_addr = get_arg_val(arg_idx++); @@ -81,8 +82,13 @@ void kernel_main() { auto k_chunk_size_dynamic = Sk_chunk_t_dynamic * tt::constants::TILE_HEIGHT; // Sequence length assignment - auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end] = - get_runtime_args(cur_pos, cur_batch, core_num_in_reduce, num_cores_per_head, k_chunk_size_dynamic); + auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end, window_start_unaligned, window_start_chunk] = get_runtime_args( + cur_pos, + cur_batch, + core_num_in_reduce, + num_cores_per_head, + k_chunk_size_dynamic, + sliding_window > 0 ? 
std::optional(sliding_window) : std::nullopt); if (k_chunk_start == k_chunk_end) { return; // early exit because no computes needs to be done @@ -118,6 +124,7 @@ void kernel_main() { constexpr uint32_t cb_l_in = tt::CBIndex::c_7; constexpr uint32_t cb_mask_in = tt::CBIndex::c_3; + constexpr uint32_t cb_sliding_window_mask_in = tt::CBIndex::c_13; // Separate buffer for sliding window mask constexpr uint32_t cb_identity_scale_in = tt::CBIndex::c_5; constexpr uint32_t cb_col_identity = tt::CBIndex::c_11; constexpr uint32_t cb_zero_in = tt::CBIndex::c_12; @@ -132,6 +139,12 @@ void kernel_main() { generate_reduce_scaler(cb_zero_in, zero_scalar_packed); generate_bcast_col_scalar(cb_col_identity, identity_scalar_packed); + if (k_chunk_start == window_start_chunk && window_start_unaligned > 0) { + // If this core processes the first chunk and we need to apply sliding window mask, generate it here + generate_sliding_window_mask( + k_num_chunks, Sk_chunk_t_dynamic, window_start_unaligned); + } + if (is_worker) { ASSERT(num_heads_per_core == 1); // if there are workers, then head must be split across workers so there // should not be more than one head per core diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/rt_args_common.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/rt_args_common.hpp index ec69a4e56b4a..95d1624cd031 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/rt_args_common.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/rt_args_common.hpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include +#include #include inline uint32_t nearest_n(uint32_t x, uint32_t n) { return ((x + n - 1) / n) * n; } @@ -29,33 +30,63 @@ inline uint8_t nearest_pow_of_2_up_to_8(uint32_t x) { return (result > max) ? max : result; } -inline std::tuple get_runtime_args( - int cur_pos, int cur_batch, int core_num, int num_cores_per_batch, uint32_t k_chunk_size) { - uint32_t valid_seq_len = nearest_n(cur_pos + 1, k_chunk_size); +inline std::tuple get_runtime_args( + int cur_pos, + int cur_batch, + int core_num, + int num_cores_per_batch, + uint32_t k_chunk_size, + std::optional sliding_window = std::nullopt) { + uint32_t window_start = 0; + uint32_t window_start_unaligned = 0; // Keep track of the actual window start for masking + uint32_t valid_seq_len; + + if (sliding_window.has_value() && sliding_window.value() > 0) { + // Calculate actual window bounds + uint32_t window_end = cur_pos + 1; // exclusive end + window_start_unaligned = (window_end > sliding_window.value()) ? 
(window_end - sliding_window.value()) : 0; + + // Round window_start down to chunk boundary to ensure we capture the full window + uint32_t window_start_aligned = (window_start_unaligned / k_chunk_size) * k_chunk_size; + + // Round window_end up to chunk boundary to ensure we capture the full window + uint32_t window_end_aligned = nearest_n(window_end, k_chunk_size); + + // Calculate valid_seq_len based on the sliding window range + valid_seq_len = window_end_aligned - window_start_aligned; + window_start = window_start_aligned; // Use aligned start for chunk calculations + } else { + // Standard behavior: process from beginning up to cur_pos + valid_seq_len = nearest_n(cur_pos + 1, k_chunk_size); + window_start = 0; + window_start_unaligned = 0; + } + uint32_t pst_value = valid_seq_len / tt::constants::TILE_HEIGHT; + uint32_t window_start_chunk = window_start / k_chunk_size; uint32_t num_chunks_value = valid_seq_len / k_chunk_size; - uint32_t k_chunk_start = 0; - uint32_t k_chunk_end = 0; + uint32_t k_chunk_start = window_start_chunk; + uint32_t k_chunk_end = window_start_chunk; + // Distribute active chunks among cores if (num_cores_per_batch > int(num_chunks_value)) { - int chunks_per_core = 1; - if (core_num >= int(num_chunks_value)) { - chunks_per_core = 0; - } - k_chunk_start = (num_chunks_value - core_num - 1) * chunks_per_core; - k_chunk_end = (num_chunks_value - core_num) * chunks_per_core; + int chunks_per_core = (core_num < int(num_chunks_value)) ? 1 : 0; + k_chunk_start = window_start_chunk + (num_chunks_value - core_num - 1) * chunks_per_core; + k_chunk_end = window_start_chunk + (num_chunks_value - core_num) * chunks_per_core; } else { int chunks_per_core = num_chunks_value / num_cores_per_batch; int residuals = num_chunks_value % num_cores_per_batch; int reversed_core_num = num_cores_per_batch - core_num - 1; - k_chunk_start = reversed_core_num * chunks_per_core + std::min(residuals, reversed_core_num); + k_chunk_start = + window_start_chunk + reversed_core_num * chunks_per_core + std::min(residuals, reversed_core_num); k_chunk_end = k_chunk_start + chunks_per_core; if (reversed_core_num < residuals) { k_chunk_end += 1; } } - return {pst_value, num_chunks_value, k_chunk_start, k_chunk_end}; + + return {pst_value, num_chunks_value, k_chunk_start, k_chunk_end, window_start_unaligned, window_start_chunk}; } template diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp index 95a50c07af8a..90c95bb4cbeb 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp @@ -1,3 +1,4 @@ + // SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
// // SPDX-License-Identifier: Apache-2.0 @@ -320,6 +321,10 @@ operation::ProgramWithCallbacks ScaledDotProductAttentionDecode::create_program( if (not scale.has_value()) { scale = 1.0f / std::sqrt(static_cast(input_tensor_q.padded_shape()[-1])); } + auto sliding_window = this->sliding_window; + if (not sliding_window.has_value()) { + sliding_window = 0; + } return detail::sdpa_decode_multi_core( input_tensor_q, @@ -338,7 +343,8 @@ operation::ProgramWithCallbacks ScaledDotProductAttentionDecode::create_program( this->k_chunk_size, this->share_cache, this->use_mla.value_or(false), - this->head_dim_v.value_or(0)); + this->head_dim_v.value_or(0), + sliding_window); } operation::Hash ScaledDotProductAttentionDecode::compute_program_hash( @@ -356,6 +362,7 @@ operation::Hash ScaledDotProductAttentionDecode::compute_program_hash( this->is_causal, this->use_mla, this->head_dim_v, + this->sliding_window, has_attn_mask, has_cur_pos, input_tensors, diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp index 1ad7e59d89e4..8957ef894f2a 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp @@ -17,6 +17,7 @@ struct ScaledDotProductAttentionDecode { const bool is_causal; std::vector cur_pos; const std::optional scale; + const std::optional sliding_window; const tt::tt_metal::MemoryConfig output_mem_config; const std::optional program_config; const DeviceComputeKernelConfig compute_kernel_config; diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp index acf291323942..724864697849 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp @@ -11,7 +11,6 @@ #include "sdpa_decode_op.hpp" #include #include -#include #include #include "ttnn/operation.hpp" #include @@ -40,7 +39,8 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( const uint32_t k_chunk_size, std::optional share_cache, bool use_mla, - uint32_t head_dim_v) { + uint32_t head_dim_v, + std::optional sliding_window) { /* Q: 1 x B x PNH x DH K: B x NKV x S x DH @@ -421,7 +421,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( if (use_cur_pos_tensor) { auto pos_buffer = cur_pos_tensor.value().buffer(); tt::DataFormat pos_df = tt_metal::datatype_to_dataformat_converter(cur_pos_tensor.value().dtype()); - pos_tensor_tile_size = tt_metal::detail::TileSize(pos_df); + pos_tensor_tile_size = tt::tile_size(pos_df); index_stick_size = pos_buffer->aligned_page_size(); // cb pos @@ -533,6 +533,14 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( .set_tile_dims(CBIndex::c_12, scalar_tile); CreateCircularBuffer(program, core_grid, c_zero_config); + // sliding window mask input (conditionally created based on sliding_window) + if (sliding_window.has_value() && sliding_window.value() > 0) { + auto c_sliding_window_mask_config = CircularBufferConfig(qk_tiles * mask_tile_size, {{CBIndex::c_13, mask_df}}) + .set_page_size(CBIndex::c_13, mask_tile_size) + .set_tile_dims(CBIndex::c_13, mask_tile); + CreateCircularBuffer(program, core_grid, c_sliding_window_mask_config); + } + // cb_qk_im auto c_intermed0_config = CircularBufferConfig(qk_tiles * 
im_tile_size, {{CBIndex::c_24, im_df}}) .set_page_size(CBIndex::c_24, im_tile_size) @@ -660,7 +668,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( for (uint32_t i = 0; i < num_active_cores; ++i) { CoreCoord core = core_group[i]; - uint32_t worker_id_for_reduce = i % num_cores_per_head - 1; + uint32_t worker_id_for_reduce = (i % num_cores_per_head) - 1; bool do_reduce = (worker_id_for_reduce == -1); if (do_reduce) { reduce_core_noc_x = core.x; @@ -686,7 +694,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( for (uint32_t i = 0; i < num_active_cores; ++i) { CoreCoord core = core_group[i]; - uint32_t worker_id_for_output = i % num_cores_per_batch - 1; + uint32_t worker_id_for_output = (i % num_cores_per_batch) - 1; bool do_output = (worker_id_for_output == -1); if (do_output) { output_core_noc_x = core.x; @@ -742,6 +750,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( is_cur_pos_tensor_sharded, is_page_table_sharded, full_tile.get_tile_size(q_df), + sliding_window.value_or(0), // Add sliding_window to compile-time args }; tt_metal::TensorAccessorArgs(input_tensor_k.buffer()).append_to(reader_compile_time_args_common); tt_metal::TensorAccessorArgs(input_tensor_q.buffer()).append_to(reader_compile_time_args_common); @@ -784,6 +793,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( is_causal, max_dynamic_chunk_size, q_heads_parallel_factor, + sliding_window.value_or(0), // Add sliding_window to writer compile-time args }; tt_metal::TensorAccessorArgs(output_tensor.buffer()).append_to(writer_compile_time_args_common); @@ -817,6 +827,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( q_heads_parallel_factor, use_half_tile, scale_union.u, + sliding_window.value_or(0), // Add sliding_window to compute compile-time args }; // Determine granularity for compute loops @@ -900,8 +911,8 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( // Set rt args for (uint32_t i = 0; i < num_active_cores; ++i) { CoreCoord core = core_group[i]; - uint32_t worker_id_for_reduce = i % num_cores_per_head - 1; - uint32_t worker_id_for_output = i % num_cores_per_batch - 1; + uint32_t worker_id_for_reduce = (i % num_cores_per_head) - 1; + uint32_t worker_id_for_output = (i % num_cores_per_batch) - 1; bool do_reduce = (worker_id_for_reduce == -1); bool do_output = (worker_id_for_output == -1); @@ -1046,8 +1057,8 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( // Set rt args for (uint32_t i = 0; i < num_active_cores; ++i) { CoreCoord core = core_group[i]; - uint32_t worker_id_for_reduce = (num_cores_per_head == 0) ? -1 : i % num_cores_per_head - 1; - uint32_t worker_id_for_output = i % num_cores_per_batch - 1; + uint32_t worker_id_for_reduce = (num_cores_per_head == 0) ? -1 : (i % num_cores_per_head) - 1; + uint32_t worker_id_for_output = (i % num_cores_per_batch) - 1; bool do_reduce = (worker_id_for_reduce == -1); bool do_output = (worker_id_for_output == -1); uint32_t cur_head = (num_cores_per_head == 0) ? 
0 : (i % num_cores_per_batch) / num_cores_per_head; diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.hpp index 64d31123ed16..d44758c77699 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.hpp @@ -27,6 +27,7 @@ tt::tt_metal::operation::ProgramWithCallbacks sdpa_decode_multi_core( uint32_t k_chunk_size, std::optional share_cache, bool mla = false, - uint32_t head_dim_v = 0); + uint32_t head_dim_v = 0, + std::optional sliding_window = std::nullopt); } // namespace ttnn::operations::transformer::detail diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp index 5d8f00d36357..c440025b8dbf 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp @@ -41,6 +41,7 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke( const std::optional& cur_pos_tensor, const std::optional& attention_sink, std::optional scale, + std::optional sliding_window, const std::optional& memory_config, std::optional program_config, std::optional compute_kernel_config) { @@ -74,6 +75,7 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke( .is_causal = is_causal, .cur_pos = cur_pos, .scale = scale, + .sliding_window = sliding_window, .output_mem_config = memory_config.value_or(operation::DEFAULT_OUTPUT_MEMORY_CONFIG), .program_config = program_config, .compute_kernel_config = kernel_config_val, @@ -95,6 +97,7 @@ ttnn::Tensor ExecutePagedScaledDotProductAttentionDecode::invoke( const std::optional& cur_pos_tensor, const std::optional& attention_sink, std::optional scale, + std::optional sliding_window, const std::optional& memory_config, std::optional program_config, std::optional compute_kernel_config) { @@ -125,6 +128,7 @@ ttnn::Tensor ExecutePagedScaledDotProductAttentionDecode::invoke( .is_causal = is_causal, .cur_pos = std::vector(), .scale = scale, + .sliding_window = sliding_window, .output_mem_config = memory_config.value_or(operation::DEFAULT_OUTPUT_MEMORY_CONFIG), .program_config = program_config, .compute_kernel_config = kernel_config_val, @@ -146,6 +150,7 @@ ttnn::Tensor ExecuteFlashMultiLatentAttentionDecode::invoke( const std::optional& cur_pos_tensor, const std::optional& attention_sink, std::optional scale, + std::optional sliding_window, const std::optional& memory_config, std::optional program_config, std::optional compute_kernel_config) { @@ -179,6 +184,7 @@ ttnn::Tensor ExecuteFlashMultiLatentAttentionDecode::invoke( .is_causal = is_causal, .cur_pos = cur_pos, .scale = scale, + .sliding_window = sliding_window, .output_mem_config = memory_config.value_or(operation::DEFAULT_OUTPUT_MEMORY_CONFIG), .program_config = program_config, .compute_kernel_config = kernel_config_val, @@ -202,6 +208,7 @@ ttnn::Tensor ExecutePagedFlashMultiLatentAttentionDecode::invoke( const std::optional& cur_pos_tensor, const std::optional& attention_sink, std::optional scale, + std::optional sliding_window, const std::optional& memory_config, std::optional program_config, std::optional compute_kernel_config) { @@ -232,6 +239,7 @@ ttnn::Tensor ExecutePagedFlashMultiLatentAttentionDecode::invoke( .is_causal = is_causal, .cur_pos = std::vector(), .scale = scale, + 
.sliding_window = sliding_window, .output_mem_config = memory_config.value_or(operation::DEFAULT_OUTPUT_MEMORY_CONFIG), .program_config = program_config, .compute_kernel_config = kernel_config_val, diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.hpp index f62899179b64..4423599261ed 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.hpp @@ -22,6 +22,7 @@ struct ExecuteScaledDotProductAttentionDecode { const std::optional& cur_pos_tensor = std::nullopt, const std::optional& attention_sink = std::nullopt, std::optional scale = std::nullopt, + std::optional sliding_window = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional program_config = std::nullopt, std::optional compute_kernel_config = std::nullopt); @@ -38,6 +39,7 @@ struct ExecutePagedScaledDotProductAttentionDecode { const std::optional& cur_pos_tensor = std::nullopt, const std::optional& attention_sink = std::nullopt, std::optional scale = std::nullopt, + std::optional sliding_window = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional program_config = std::nullopt, std::optional compute_kernel_config = std::nullopt); @@ -54,6 +56,7 @@ struct ExecuteFlashMultiLatentAttentionDecode { const std::optional& cur_pos_tensor = std::nullopt, const std::optional& attention_sink = std::nullopt, std::optional scale = std::nullopt, + std::optional sliding_window = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional program_config = std::nullopt, std::optional compute_kernel_config = std::nullopt); @@ -70,6 +73,7 @@ struct ExecutePagedFlashMultiLatentAttentionDecode { const std::optional& cur_pos_tensor = std::nullopt, const std::optional& attention_sink = std::nullopt, std::optional scale = std::nullopt, + std::optional sliding_window = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional program_config = std::nullopt, std::optional compute_kernel_config = std::nullopt); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode_pybind.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode_pybind.cpp index 6f3189923f61..32356a463394 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode_pybind.cpp @@ -63,6 +63,7 @@ void py_bind_sdpa_decode(py::module& module) { const std::optional& cur_pos_tensor, const std::optional& attention_sink, std::optional scale, + std::optional sliding_window, const std::optional& memory_config, std::optional program_config, std::optional compute_kernel_config) { @@ -76,6 +77,7 @@ void py_bind_sdpa_decode(py::module& module) { cur_pos_tensor, attention_sink, scale, + sliding_window, memory_config, program_config, compute_kernel_config); @@ -90,6 +92,7 @@ void py_bind_sdpa_decode(py::module& module) { py::arg("cur_pos_tensor").noconvert() = std::nullopt, py::arg("attention_sink").noconvert() = std::nullopt, py::arg("scale").noconvert() = std::nullopt, + py::arg("sliding_window").noconvert() = std::nullopt, py::arg("memory_config").noconvert() = std::nullopt, py::arg("program_config").noconvert() = std::nullopt, py::arg("compute_kernel_config").noconvert() = std::nullopt}); @@ -110,6 +113,7 @@ void py_bind_sdpa_decode(py::module& module) { const std::optional& cur_pos_tensor, const std::optional& 
attention_sink, std::optional scale, + std::optional sliding_window, const std::optional& memory_config, std::optional program_config, std::optional compute_kernel_config) { @@ -123,6 +127,7 @@ void py_bind_sdpa_decode(py::module& module) { cur_pos_tensor, attention_sink, scale, + sliding_window, memory_config, program_config, compute_kernel_config); @@ -137,6 +142,7 @@ void py_bind_sdpa_decode(py::module& module) { py::arg("cur_pos_tensor").noconvert() = std::nullopt, py::arg("attention_sink").noconvert() = std::nullopt, py::arg("scale").noconvert() = std::nullopt, + py::arg("sliding_window").noconvert() = std::nullopt, py::arg("memory_config").noconvert() = std::nullopt, py::arg("program_config").noconvert() = std::nullopt, py::arg("compute_kernel_config").noconvert() = std::nullopt}); @@ -157,6 +163,7 @@ void py_bind_sdpa_decode(py::module& module) { const std::optional& cur_pos_tensor, const std::optional& attention_sink, std::optional scale, + std::optional sliding_window, const std::optional& memory_config, std::optional program_config, std::optional compute_kernel_config) { @@ -170,6 +177,7 @@ void py_bind_sdpa_decode(py::module& module) { cur_pos_tensor, attention_sink, scale, + sliding_window, memory_config, program_config, compute_kernel_config); @@ -184,6 +192,7 @@ void py_bind_sdpa_decode(py::module& module) { py::arg("cur_pos_tensor").noconvert() = std::nullopt, py::arg("attention_sink").noconvert() = std::nullopt, py::arg("scale").noconvert() = std::nullopt, + py::arg("sliding_window").noconvert() = std::nullopt, py::arg("memory_config").noconvert() = std::nullopt, py::arg("program_config").noconvert() = std::nullopt, py::arg("compute_kernel_config").noconvert() = std::nullopt}); @@ -204,6 +213,7 @@ void py_bind_sdpa_decode(py::module& module) { const std::optional& cur_pos_tensor, const std::optional& attention_sink, std::optional scale, + std::optional sliding_window, const std::optional& memory_config, std::optional program_config, std::optional compute_kernel_config) { @@ -217,6 +227,7 @@ void py_bind_sdpa_decode(py::module& module) { cur_pos_tensor, attention_sink, scale, + sliding_window, memory_config, program_config, compute_kernel_config); @@ -231,6 +242,7 @@ void py_bind_sdpa_decode(py::module& module) { py::arg("cur_pos_tensor").noconvert() = std::nullopt, py::arg("attention_sink").noconvert() = std::nullopt, py::arg("scale").noconvert() = std::nullopt, + py::arg("sliding_window").noconvert() = std::nullopt, py::arg("memory_config").noconvert() = std::nullopt, py::arg("program_config").noconvert() = std::nullopt, py::arg("compute_kernel_config").noconvert() = std::nullopt});
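
Reviewer notes (illustrative, not part of the patch). The sliding-window arithmetic introduced above in get_runtime_args and generate_sliding_window_mask can be checked host-side with the following minimal NumPy sketch. Function and variable names mirror the kernel code; the chunk size of 256 positions and window of 4096 keys in the example are arbitrary, and none of this is device code.

import numpy as np

TILE_HEIGHT = 32

def nearest_n(x, n):
    # Round x up to a multiple of n (same as nearest_n in rt_args_common.hpp).
    return ((x + n - 1) // n) * n

def sliding_window_bounds(cur_pos, k_chunk_size, sliding_window=None):
    # Mirrors the new branch in get_runtime_args: returns
    # (valid_seq_len, window_start_unaligned, window_start_chunk).
    window_end = cur_pos + 1  # exclusive end of the attended range
    if sliding_window and sliding_window > 0:
        window_start_unaligned = max(window_end - sliding_window, 0)
        window_start_aligned = (window_start_unaligned // k_chunk_size) * k_chunk_size
        valid_seq_len = nearest_n(window_end, k_chunk_size) - window_start_aligned
    else:
        window_start_unaligned = 0
        window_start_aligned = 0
        valid_seq_len = nearest_n(window_end, k_chunk_size)
    return valid_seq_len, window_start_unaligned, window_start_aligned // k_chunk_size

def first_chunk_mask(window_start_unaligned, k_chunk_size):
    # Additive mask generated for the chunk containing the window start:
    # -inf for key positions before the window, 0 for positions inside it.
    pos_in_chunk = window_start_unaligned % k_chunk_size
    mask = np.zeros(k_chunk_size, dtype=np.float32)
    mask[:pos_in_chunk] = -np.inf
    return mask

# Example: cur_pos=5000, chunks of 8 tiles (256 positions), window of 4096 keys.
valid_len, ws, ws_chunk = sliding_window_bounds(5000, 8 * TILE_HEIGHT, 4096)
assert (valid_len, ws, ws_chunk) == (4352, 905, 3)
# Within chunk 3 the mask switches from -inf to 0 at position 905 % 256 == 137,
# i.e. tile 137 // 32 == 4 with intra-tile offset 137 % 32 == 9 -- the values the
# kernel computes as window_start_in_chunk_t and window_start_in_tile.
assert np.isneginf(first_chunk_mask(ws, 256)[136]) and first_chunk_mask(ws, 256)[137] == 0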
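
In the same spirit, a hedged sketch of how the per-core chunk ranges shift when window_start_chunk is non-zero, following the modified distribution loop in get_runtime_args. The core and chunk counts here are illustrative and continue the example above (17 active chunks from a 4352-position window, 8 cores per batch).

def assign_chunks(core_num, num_cores_per_batch, num_chunks, window_start_chunk):
    # Returns the [k_chunk_start, k_chunk_end) range a core processes,
    # offset by window_start_chunk as in the updated get_runtime_args.
    if num_cores_per_batch > num_chunks:
        per_core = 1 if core_num < num_chunks else 0
        start = window_start_chunk + (num_chunks - core_num - 1) * per_core
        end = window_start_chunk + (num_chunks - core_num) * per_core
    else:
        per_core, residuals = divmod(num_chunks, num_cores_per_batch)
        rev = num_cores_per_batch - core_num - 1
        start = window_start_chunk + rev * per_core + min(residuals, rev)
        end = start + per_core + (1 if rev < residuals else 0)
    return start, end

# With 17 active chunks, 8 cores per batch and window_start_chunk=3, every range
# starts at or after chunk 3 and together the cores cover chunks 3..19.
ranges = [assign_chunks(c, 8, 17, 3) for c in range(8)]
assert min(s for s, _ in ranges) == 3 and max(e for _, e in ranges) == 20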
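
Finally, a hedged example of how the newly bound sliding_window keyword would be passed from Python. The operation name and the tensor shapes below assume the existing ttnn.transformer decode entry points and are placeholders; only the sliding_window argument itself is introduced by this change, and omitting it (or passing None) keeps the previous full-causal behaviour.

import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Placeholder decode-shaped inputs: Q is 1 x B x PNH x DH, K/V are B x NKV x S x DH.
q = ttnn.from_torch(torch.randn(1, 1, 8, 128), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
k = ttnn.from_torch(torch.randn(1, 1, 8192, 128), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
v = ttnn.from_torch(torch.randn(1, 1, 8192, 128), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

out = ttnn.transformer.scaled_dot_product_attention_decode(
    q,
    k,
    v,
    cur_pos=[5000],
    sliding_window=4096,  # new optional argument; keys older than 4096 positions are masked out
)

ttnn.close_device(device)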