diff --git a/models/experimental/gemma3/tests/test_attention.py b/models/experimental/gemma3/tests/test_attention.py deleted file mode 100644 index 5e1c1a905cde..000000000000 --- a/models/experimental/gemma3/tests/test_attention.py +++ /dev/null @@ -1,289 +0,0 @@ -"""Gemma-3 Test for Text Attention""" - - -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 -import os - -import pytest -import torch -from loguru import logger - -import ttnn -from models.experimental.gemma3.tt.attention import Attention -from models.tt_transformers.tt.common import PagedAttentionConfig, precompute_freqs -from models.tt_transformers.tt.rope import RotarySetup -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull - -from models.tt_transformers.tt.model_config import ModelArgs -from models.tt_transformers.tt.ccl import TT_CCL - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -@pytest.mark.parametrize( - "paged_attention", - ( - True, - # False, - ), - ids=( - "paged_attention", - # "default_attention", - ), -) -@pytest.mark.parametrize( - "page_params", - [{"page_block_size": 32, "page_max_num_blocks": 1024}], -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -@pytest.mark.parametrize( - "max_seq_len", - (1,), # For decode-only unit test, there's no need to run with large sequence lengths -) -@pytest.mark.parametrize( - "device_params", - [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], - indirect=True, -) -def test_attention_inference( - max_seq_len, - batch_size, - paged_attention, - page_params, - mesh_device, - reset_seeds, - device_params, - # ensure_gc, -): - dtype = ttnn.bfloat16 - pcc = 0.99 - - model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128) - model_args.n_layers = 6 # For the unit test, just run a single layer - - state_dict = model_args.load_state_dict() - - first_layer_prefix = model_args.get_state_dict_prefix("Attention", 0) + "." 
- # Ref model needs partial state dict, but our models use full state dict keys as cached weight names - # partial_state_dict = { - # k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) - # } - - reference_model = model_args.reference_attention() - # reference_model.load_state_dict(partial_state_dict) - - seq_len = 1 - - generation_start_pos = 0 - generation_length = 10 - all_tests_pass = True - - # Setup RoPE transformation matrices - rope_setup = RotarySetup( - mesh_device, - batch_size, - model_args.head_dim, - model_args.max_seq_len, - model_args.rope_theta, - model_args.rope_scaling, - ) - - transformation_mats = rope_setup.get_both_trans_mats() - - page_table_tt = None - paged_attention_config = None - - if paged_attention: - paged_attention_config = PagedAttentionConfig( - block_size=page_params["page_block_size"], - max_num_blocks=page_params["page_max_num_blocks"], - ) - - # Implied shuffling of blocks - permutation = torch.randperm(paged_attention_config.max_num_blocks) - # Page table which maps virtual blocks to physical - reverse_permutation = torch.argsort(permutation) - page_table = reverse_permutation.reshape( - model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size - ) - page_table_tt = ttnn.from_torch( - page_table, - device=mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, -2) if (model_args.is_galaxy and batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - tt_ccl = TT_CCL(mesh_device) - tt_model = Attention( - mesh_device, - tt_ccl, - state_dict, - weight_cache_path=model_args.weight_cache_path(dtype), - layer_num=0, - dtype=dtype, - transformation_mats=transformation_mats, - configuration=model_args, - paged_attention_config=paged_attention_config, - ) - - cos, sin = precompute_freqs( - model_args.head_dim, - model_args.max_seq_len * 2, - model_args.rope_theta, - model_args.rope_scaling.factor if model_args.rope_scaling else None, - model_args.rope_scaling.original_max_position_embeddings if model_args.rope_scaling else None, - rope_type="linear", - ) - freqs_cis = torch.complex(cos, sin) - - # Initial positions - current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)]) - current_pos_tensor = ttnn.from_torch( - current_pos, - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - for i in range(generation_length): - # 70B attention block typically sees tensors with mean 0 and std 0.03 - 0.05 in layer 1 - pt_attention_input = torch.randn(batch_size, seq_len, model_args.dim) # Qwen2.5 0.5B sees 0.1 to 2.1 - - tt_attention_input = pt_attention_input.clone() - - attention_input = model_args.prepare_residual_tensor_decode( - tt_attention_input, - model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"], - force_replicated=False if model_args.is_galaxy else True, - ) - - # Get cos/sin matrices for the current position of each user - rot_mats = rope_setup.get_rot_mats(current_pos) - - tt_out = tt_model( - attention_input, - current_pos_tensor, - rot_mats=rot_mats, - mode="decode", - page_table=page_table_tt, - ) - # multi-device attention module returns replicated output - tt_out = ttnn.to_torch( - tt_out, - mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), 
mesh_shape=model_args.cluster_shape), - ) - tt_output_torch = tt_out[:, 0:1, : model_args.max_batch_size, : model_args.dim].view(-1, 1, model_args.dim) - - # In this test all users have the same position (if using batch > 1) - freqs_cis_i = freqs_cis[current_pos[0], :].unsqueeze(0) - - reference_output = reference_model(pt_attention_input, current_pos[0], freqs_cis_i, mask=None) - - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) - - logger.info(comp_allclose(reference_output, tt_output_torch)) - logger.info(f"PCC: {pcc_message}") - if passing: - logger.info(f"[pos={current_pos[0]}] Attention Passed!") - else: - logger.warning(f"[pos={current_pos[0]}] Attention Failed!") - all_tests_pass = False - - # Increment position - current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch_size)]) - current_pos_tensor = ttnn.from_torch( - current_pos, - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - check_kv_cache = True - if check_kv_cache: - # PyTorch output -------------------------------------------------------------------- - pytorch_layer_present = [ - reference_model.cache_k.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] - reference_model.cache_v.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] - ] - # TT hardware execution ------------------------------------------------------------- - if paged_attention: - tt_layer_present = [ - ( - ttnn.to_torch( - cache, - mesh_composer=ttnn.ConcatMesh2dToTensor( - mesh_device, - dims=(1, 3) if model_args.is_galaxy else (0, 1), - mesh_shape=model_args.cluster_shape, - ), - )[reverse_permutation][:, : model_args.n_kv_heads, :, : model_args.head_dim] - .reshape( - model_args.max_batch_size, - paged_attention_config.max_num_blocks // model_args.max_batch_size, - model_args.n_kv_heads, - paged_attention_config.block_size, - model_args.head_dim, - ) - .transpose(1, 2) - .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ - :batch_size, ... - ] - ) - for cache in tt_model.layer_past - ] - else: - tt_layer_present = [ - ttnn.to_torch( - cache, - mesh_composer=ttnn.ConcatMesh2dToTensor( - mesh_device, - dims=(1, 0) if model_args.is_galaxy else (0, 1), - mesh_shape=model_args.cluster_shape, - ), - )[:batch_size, :, :, :] - for cache in tt_model.layer_past - ] - for label, cache_pt, cache_tt in zip(["K", "V"], pytorch_layer_present, tt_layer_present): - cache_length_to_check = min(model_args.max_seq_len, generation_start_pos + i + 1) - cache_pt = cache_pt[:, :, generation_start_pos:cache_length_to_check, :] - cache_tt = cache_tt[:, :, generation_start_pos:cache_length_to_check, :] - does_pass, output_pcc = comp_pcc(cache_pt, cache_tt, pcc) - logger.info(f"{label} cache output: {output_pcc}") - if does_pass: - logger.info(f"{label} cache Passed!") - else: - logger.warning(f"{label} Cache Failed! PCC value is lower than {pcc}") - all_tests_pass = False - - if all_tests_pass: - logger.info("Attention output Passed!") - else: - logger.warning("Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" 
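Several of the deleted tests build their paged-attention page table from a random block permutation ("Implied shuffling of blocks"). A standalone sketch of that virtual-to-physical block mapping, using small illustrative sizes rather than the test's page_params, is:

import torch

# Illustrative sizes; the deleted tests use page_block_size=32 and page_max_num_blocks=1024.
max_num_blocks = 8
max_batch_size = 2

# Shuffle physical blocks, then invert the permutation so that
# page_table[user, virtual_block] gives the physical block index.
permutation = torch.randperm(max_num_blocks)
reverse_permutation = torch.argsort(permutation)
page_table = reverse_permutation.reshape(max_batch_size, max_num_blocks // max_batch_size)

# argsort really does invert the permutation: composing the two gives the identity.
assert torch.equal(permutation[reverse_permutation], torch.arange(max_num_blocks))
print(page_table)  # one row of virtual->physical block ids per user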
diff --git a/models/experimental/gemma3/tests/test_decoder.py b/models/experimental/gemma3/tests/test_decoder.py deleted file mode 100644 index 0a40ff780bb7..000000000000 --- a/models/experimental/gemma3/tests/test_decoder.py +++ /dev/null @@ -1,229 +0,0 @@ -"""Gemma3 Test for Text Decoder""" - -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 - -import torch -import pytest -from loguru import logger -import os -import ttnn -from models.tt_transformers.tt.model_config import ModelArgs -from models.experimental.gemma3.tt.decoder import TransformerBlock -from models.utility_functions import ( - comp_pcc, - comp_allclose, -) -from models.utility_functions import skip_for_grayskull -from models.tt_transformers.tt.common import PagedAttentionConfig -from models.tt_transformers.tt.rope import RotarySetup -from models.tt_transformers.tt.ccl import TT_CCL -from models.tt_transformers.tt.common import PagedAttentionConfig, precompute_freqs - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -@pytest.mark.parametrize( - "paged_attention", - ( - True, - # False - ), - ids=( - "paged_attention", - # "default_attention" - ), -) -@pytest.mark.parametrize( - "page_params", - [{"page_block_size": 32, "page_max_num_blocks": 1024}], -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -@pytest.mark.parametrize( - "max_seq_len", - (256,), # For decode-only unit test, there's no need to run with large sequence lengths -) -@pytest.mark.parametrize( - "device_params", - [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], - indirect=True, -) -def test_decoder_inference( - max_seq_len, - batch_size, - paged_attention, - page_params, - mesh_device, - device_params, - reset_seeds, -): - dtype = ttnn.bfloat16 - - pcc_required = 0.85 - model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) - model_args.n_layers = 1 - - state_dict = model_args.load_state_dict() - - reference_model = model_args.reference_decoder() - - generation_start_pos = 0 - generation_length = 3 - all_tests_pass = True - - rope_setup = RotarySetup( - mesh_device, - batch_size, - model_args.head_dim, - model_args.max_seq_len, - model_args.rope_theta, - model_args.rope_scaling, - ) - transformation_mats = rope_setup.get_both_trans_mats() - - # Prepare page table for paged attention - page_table_tt = None - paged_attention_config = None - - if paged_attention: - paged_attention_config = PagedAttentionConfig( - block_size=page_params["page_block_size"], - max_num_blocks=page_params["page_max_num_blocks"], - ) - # Implied shuffling of blocks - permutation = torch.randperm(paged_attention_config.max_num_blocks) - # Page table which maps virtual blocks to physical - reverse_permutation = torch.argsort(permutation) - page_table = reverse_permutation.reshape( - model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size - ) - page_table_tt = ttnn.from_torch( - page_table, - device=mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, -2) if (model_args.is_galaxy and batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - # Initialize TT model - tt_ccl = 
TT_CCL(mesh_device) - tt_model = TransformerBlock( - args=model_args, - mesh_device=mesh_device, - tt_ccl=tt_ccl, - dtype=dtype, - state_dict=state_dict, - layer_num=0, - weight_cache_path=model_args.weight_cache_path(dtype), - transformation_mats=transformation_mats, - paged_attention_config=paged_attention_config, - ) - - seqlen = 1 - - cos, sin = precompute_freqs( - model_args.head_dim, - model_args.max_seq_len * 2, - model_args.rope_theta, - model_args.rope_scaling.factor if model_args.rope_scaling else None, - model_args.rope_scaling.original_max_position_embeddings if model_args.rope_scaling else None, - rope_type="linear", - ) - freqs_cis = torch.complex(cos, sin) - - # Initial positions - current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)]) - current_pos_tensor = ttnn.from_torch( - current_pos, - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - for i in range(generation_length): - pt_decode_input = (torch.rand(batch_size, seqlen, model_args.dim) * 2) - 1 - logger.info(f"[Decoder] Generating token {i}") - - tt_decode_input = pt_decode_input.clone() - - decode_input = model_args.prepare_residual_tensor_decode( - tt_decode_input, - # ttnn.DRAM_MEMORY_CONFIG, - model_args.model_config["DECODE_RESIDUAL_MEMCFG"], - ) - - # Get cos/sin matrices for the current position of each user - rot_mat_global = rope_setup.get_rot_mats(current_pos) - rot_mat_local = rope_setup.get_rot_mats(current_pos) - - # Run TT model - tt_out = tt_model( - decode_input, - current_pos_tensor, - rot_mats_global=rot_mat_global, - rot_mats_local=rot_mat_local, - mode="decode", - page_table=page_table_tt, - ) - tt_out = ttnn.to_torch( - tt_out, - mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=model_args.cluster_shape), - ) - - tt_output_torch = tt_out[:, 0:1, : model_args.max_batch_size, : model_args.dim].view(-1, 1, model_args.dim) - - # Reference model - ref_output = reference_model(pt_decode_input, current_pos[0], None, mask=None) - non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) - tt_output_torch = tt_output_torch[non_zero_indices] - ref_output = ref_output[non_zero_indices] - - passing, pcc_message = comp_pcc(ref_output, tt_output_torch, pcc_required) - - logger.info(comp_allclose(ref_output, tt_output_torch)) - logger.info(f"PCC: {pcc_message}") - - if passing: - logger.info("Decoder Block Passed!") - else: - logger.warning("Decoder Block Failed!") - all_tests_pass = False - - # Increment position - current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch_size)]) - current_pos_tensor = ttnn.from_torch( - current_pos, - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - if all_tests_pass: - logger.info(f"All {generation_length} decode iterations Passed!") - else: - logger.warning("One or more iterations of decode Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
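The attention and decoder tests both precompute rotary tables via precompute_freqs(...) and then form freqs_cis = torch.complex(cos, sin). Below is a minimal sketch of the standard RoPE precomputation; it ignores the linear rope-scaling arguments and is not necessarily identical to the tt_transformers implementation (the head_dim and theta values are illustrative only).

import torch

def precompute_freqs_sketch(head_dim: int, seq_len: int, theta: float = 10000.0):
    # One rotation frequency per pair of channels in the head dimension.
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), freqs)  # [seq_len, head_dim // 2]
    return torch.cos(angles), torch.sin(angles)

cos, sin = precompute_freqs_sketch(head_dim=256, seq_len=512)
freqs_cis = torch.complex(cos, sin)  # complex rotation factors, as built in the tests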
diff --git a/models/experimental/gemma3/tests/test_embedding.py b/models/experimental/gemma3/tests/test_embedding.py deleted file mode 100644 index 751fbbbd824d..000000000000 --- a/models/experimental/gemma3/tests/test_embedding.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Gemma3 test for Text Embedding""" - - -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 -import os - -import pytest -import torch -from loguru import logger - -import ttnn -from models.tt_transformers.tt.embedding import Embedding, ScaledEmbedding -from models.tt_transformers.tt.model_config import ModelArgs -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize("use_scaled_embedding", (False, True)) -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -@pytest.mark.parametrize( - "max_seq_len", - (128,), # For decode-only unit test, there's no need to run with large sequence lengths -) -def test_embedding(max_seq_len, batch_size, mesh_device, reset_seeds, ensure_gc, use_scaled_embedding): - dtype = ttnn.bfloat16 - - model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len, cache_hf=True) - model_args.n_layers = 1 - - state_dict = model_args.load_state_dict() - tokenizer = model_args.tokenizer - - if use_scaled_embedding: - model_args.embed_scale = model_args.dim**0.5 - logger.info(f"Using scaled embedding with scale {model_args.embed_scale}") - - reference_emb = model_args.reference_embedding() - layer_name = "tok_embeddings.weight" - reference_emb.load_state_dict({"emb.weight": state_dict[layer_name]}) - - emb_kwargs = { - "mesh_device": mesh_device, - "args": model_args, - "weight_cache_path": model_args.weight_cache_path(dtype), - "state_dict": state_dict, - "dtype": dtype, - } - if use_scaled_embedding: - emb_kwargs["embed_scale"] = model_args.embed_scale - emb_cls = ScaledEmbedding - else: - emb_cls = Embedding - - tt_emb = emb_cls(**emb_kwargs) - - prompts = ["Joy"] * 32 - pt_input = torch.tensor([model_args.encode_prompt(prompt, instruct=False) for prompt in prompts]) - reference_output = reference_emb(pt_input) - logger.info(f"reference_output: {reference_output.shape}") - - tt_input = ttnn.from_torch( - pt_input.squeeze(1), - device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - dtype=ttnn.uint32, - layout=ttnn.ROW_MAJOR_LAYOUT, - ) - tt_output = tt_emb(tt_input) - tt_output = ttnn.multiply(tt_output, model_args.embed_scale) - tt_output_torch = ttnn.to_torch( - tt_output, - mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(0, -1), mesh_shape=model_args.cluster_shape), - )[:32].view(reference_output.shape) - logger.info(f"tt_output_torch: {tt_output_torch.shape}") - - passing, pcc_message = comp_pcc(reference_output, tt_output_torch) - - logger.info(comp_allclose(reference_output, tt_output_torch)) - logger.info(f"PCC: {pcc_message}") - if passing: - logger.info("embedding Passed!") - else: - logger.warning("embedding Failed!") - - assert passing, f"embedding output does not meet PCC requirement {0.99}." 
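The embedding test above exercises a scaled variant whose output is multiplied by dim ** 0.5 (embed_scale). A plain-PyTorch sketch of that behaviour, with illustrative sizes and not the ttnn ScaledEmbedding implementation, is:

import torch
import torch.nn as nn

class ScaledEmbeddingRef(nn.Module):
    """Token embedding whose output is scaled, defaulting to sqrt(dim) as in the test."""

    def __init__(self, vocab_size, dim, embed_scale=None):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.embed_scale = embed_scale if embed_scale is not None else dim**0.5

    def forward(self, token_ids):
        return self.emb(token_ids) * self.embed_scale

tokens = torch.tensor([[1, 2, 3]])
out = ScaledEmbeddingRef(vocab_size=1000, dim=64)(tokens)  # shape [1, 3, 64]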
diff --git a/models/experimental/gemma3/tests/test_lm_head.py b/models/experimental/gemma3/tests/test_lm_head.py deleted file mode 100644 index fdecdec31ebe..000000000000 --- a/models/experimental/gemma3/tests/test_lm_head.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Gemma3 Test for lm_head""" - - -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 - -import os - -import pytest -import torch -from loguru import logger - -import ttnn -from models.experimental.gemma3.tt.lm_head import LMHead -from models.tt_transformers.tt.model_config import ModelArgs -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull -from models.tt_transformers.tt.ccl import TT_CCL - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "seq_len", - (32,), -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -def test_lm_head_inference(seq_len, batch_size, mesh_device, reset_seeds): - dtype = ttnn.bfloat8_b - - model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=seq_len) - model_args.n_layers = 1 - state_dict = model_args.load_state_dict() - - state_dict_prefix = model_args.get_state_dict_prefix("", None) - # Ref model needs partial state dict, but our models use full state dict keys as cached weight names - partial_state_dict = { - "weight": state_dict[f"{state_dict_prefix}output.weight"], - } - - model_args.WEIGHTS_DTYPE = dtype - reference_model = model_args.reference_lm_head() - reference_model.load_state_dict(partial_state_dict) - - tt_ccl = TT_CCL(mesh_device) - tt_model = LMHead( - args=model_args, - mesh_device=mesh_device, - tt_ccl=tt_ccl, - dtype=dtype, - state_dict=state_dict, - state_dict_prefix=state_dict_prefix, - weight_cache_path=model_args.weight_cache_path(dtype), - max_columns_per_device=model_args.max_columns_per_device_lm_head, - ) - - torch_input = torch.randn(1, 1, seq_len, model_args.dim) - reference_output = reference_model(torch_input) - tt_input = ttnn.from_torch( - torch_input, - device=mesh_device, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, dims=(None, 3) if model_args.is_galaxy else (None, None), mesh_shape=model_args.cluster_shape - ), - dtype=ttnn.bfloat16, - memory_config=model_args.model_config["LM_HEAD_INPUT_MEMCFG"], - layout=ttnn.TILE_LAYOUT, - ) - - logger.info("Run LM_Head") - tt_output = tt_model(tt_input) - tt_output_torch = ttnn.to_torch( - tt_output, - mesh_composer=ttnn.ConcatMesh2dToTensor( - mesh_device, model_args.cluster_shape, dims=(3, 1) if model_args.is_galaxy else (1, 3) - ), - ) - tt_output_torch = tt_output_torch[:, 0:1, :, : model_args.vocab_size] - - pcc_required = 0.99 - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) - - non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) - tt_output_torch = tt_output_torch[non_zero_indices] - reference_output = reference_output[non_zero_indices] - - logger.info(comp_allclose(reference_output, tt_output_torch)) - logger.info(f"PCC: {pcc_message}") - if passing: - logger.info("LM_Head Passed!") - else: - logger.warning("LM_Head Failed!") - - assert passing, f"LM_Head output does not meet PCC requirement {pcc_required}: {pcc_message}." 
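The LM head test above splits its vocab projection across devices (max_columns_per_device_lm_head, ShardTensor2dMesh / ConcatMesh2dToTensor). A toy single-host sketch of that column-parallel split, assuming simple even chunking rather than the actual ttnn sharding logic:

import torch

dim, vocab_size, num_shards = 64, 1000, 4      # illustrative sizes
x = torch.randn(1, 1, 32, dim)                 # [1, 1, seq_len, dim], as in the test input
w = torch.randn(dim, vocab_size)               # full output projection

# Split the projection column-wise, project against each shard, then concat the logits.
w_shards = torch.chunk(w, num_shards, dim=-1)
logits = torch.cat([x @ w_i for w_i in w_shards], dim=-1)

# Column-parallel projection reproduces the unsharded result.
assert torch.allclose(logits, x @ w, atol=1e-5)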
diff --git a/models/experimental/gemma3/tests/test_mlp.py b/models/experimental/gemma3/tests/test_mlp.py deleted file mode 100644 index 02cec5c19101..000000000000 --- a/models/experimental/gemma3/tests/test_mlp.py +++ /dev/null @@ -1,115 +0,0 @@ -"""Gemma3 Test for Text MLP""" - -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 - -from loguru import logger - -import torch -import pytest -import os -import ttnn - -from models.tt_transformers.tests.test_utils import get_ref_model_dype -from models.experimental.gemma3.tt.mlp import MLP -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull -from models.tt_transformers.tt.ccl import TT_CCL -from models.tt_transformers.tt.model_config import ModelArgs - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -@pytest.mark.parametrize( - "seq_len", - (128,), -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -@pytest.mark.parametrize( - "device_params", - [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], - indirect=True, -) -def test_mlp_inference(seq_len, batch_size, reset_seeds, mesh_device, device_params): - dtype = ttnn.bfloat16 - mode = "decode" if seq_len <= 32 else "prefill" - - tt_model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128) - - tt_model_args.n_layers = 1 - state_dict = tt_model_args.load_state_dict() - - # # Ref model needs partial state dict, but our models use full state dict keys as cached weight names - # first_layer_prefix = "layers.0.feed_forward" - first_layer_prefix = tt_model_args.get_state_dict_prefix("MLP", 0) - - partial_state_dict = { - k[len(first_layer_prefix) + 1 :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) - } - - reference_model = tt_model_args.reference_mlp() # Gemma3 MLP - reference_model.load_state_dict(partial_state_dict) - - tt_ccl = TT_CCL(mesh_device) - tt_model = MLP( - mesh_device=mesh_device, - tt_ccl=tt_ccl, - args=tt_model_args, - state_dict=state_dict, - weight_cache_path=tt_model_args.weight_cache_path(dtype), - layer_num=0, - dtype=dtype, - model_config=tt_model_args.get_model_config(), - ) - - torch_input = torch.randn( - 1, 1, seq_len, tt_model_args.dim, dtype=get_ref_model_dype(reference_model, tt_model_args.model_name) - ) - reference_output = reference_model(torch_input) - - tt_input = ttnn.from_torch( - torch_input, - device=mesh_device, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, 3) if tt_model_args.is_galaxy else (None, None), - mesh_shape=tt_model_args.cluster_shape, - ), # When both dims are None, the mapper used is `ReplicateTensorToMesh` - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - layout=ttnn.TILE_LAYOUT, - ) - - logger.info("Run MLP") - tt_output = tt_model(tt_input, mode) - - tt_output_torch = ttnn.to_torch( - tt_output, - mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=tt_model_args.cluster_shape), - ) - - # tt_output_torch = tt_output_torch[:, :1, :, :] - - pcc_required = 0.99 - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) - - logger.info(comp_allclose(reference_output, tt_output_torch[0])) - logger.info(f"PCC: {pcc_message}") - if passing: - logger.info("MLP Passed!") 
- else: - logger.warning("MLP Failed!") - - assert passing, f"MLP output does not meet PCC requirement {pcc_required}: {pcc_message}." diff --git a/models/experimental/gemma3/tests/test_rmsnorm.py b/models/experimental/gemma3/tests/test_rmsnorm.py deleted file mode 100644 index 2e734273505b..000000000000 --- a/models/experimental/gemma3/tests/test_rmsnorm.py +++ /dev/null @@ -1,169 +0,0 @@ -"""Gemma3 Test for Text RMSNorm""" - -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 - -from loguru import logger - -import torch -import pytest -import os - -import ttnn -from models.experimental.gemma3.tt.rmsnorm import RMSNorm -from models.tt_transformers.tt.distributed_norm import DistributedNorm - -from models.tt_transformers.tt.ccl import TT_CCL -from models.utility_functions import comp_allclose, skip_for_grayskull -from models.tt_transformers.tt.model_config import ModelArgs - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -@pytest.mark.parametrize( - "tt_layer_name, torch_layer_name, dim", - ( - ("norm", "norm", 1152), - ("layers.0.attention_norm", "layers.0.input_layernorm", 1152), - ("layers.0.ffn_norm", "layers.0.post_attention_layernorm", 1152), - ("layers.0.pre_feedforward_layernorm", "layers.0.pre_feedforward_layernorm", 1152), - ("layers.0.post_feedforward_layernorm", "layers.0.post_feedforward_layernorm", 1152), - ("layers.0.attention.q_norm", "layers.0.self_attn.q_norm", 256), - ("layers.0.attention.k_norm", "layers.0.self_attn.k_norm", 256), - ), -) -@pytest.mark.parametrize( - "seq_len", - (128,), -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -@pytest.mark.parametrize( - "device_params", - [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], - indirect=True, -) -def test_rmsnorm_inference( - seq_len, batch_size, reset_seeds, mesh_device, tt_layer_name, torch_layer_name, device_params, dim -): - dtype = ttnn.bfloat16 - mode = "decode" if seq_len <= 32 else "prefill" - - tt_model_args = ModelArgs( - mesh_device, - max_batch_size=batch_size, - max_seq_len=128, - ) - - tt_model_args.n_layers = 1 - state_dict = tt_model_args.load_state_dict() - reference_model = tt_model_args.reference_transformer(wrap=False) # Gemma3 Entire Model - reference_model = reference_model.model.get_submodule(torch_layer_name) - - state_dict_prefix = "" - first_layer_prefix = state_dict_prefix + tt_layer_name + "." 
- partial_state_dict = { - k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) - } - - reference_model.load_state_dict(partial_state_dict) - tt_ccl = TT_CCL(mesh_device) - if "q_norm" in tt_layer_name or "k_norm" in tt_layer_name: - tt_model = RMSNorm( - device=mesh_device, - dim=dim, - state_dict=state_dict, - state_dict_prefix=state_dict_prefix, - weight_key=tt_layer_name, - weight_dtype=dtype, - is_distributed=False, - sharded_program_config=None, - sharded_output_config=None, - tt_ccl=tt_ccl, - ) - else: - tt_inner_norm = RMSNorm( - device=mesh_device, - dim=tt_model_args.dim, - state_dict=state_dict, - state_dict_prefix=state_dict_prefix, - weight_key=tt_layer_name, - weight_dtype=dtype, - is_distributed=tt_model_args.is_distributed_norm, - sharded_program_config=tt_model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], - sharded_output_config=tt_model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], - tt_ccl=tt_ccl, - ) - - # Wrap it in DistributedNorm - tt_model = DistributedNorm(tt_inner_norm, tt_model_args, tt_ccl, TG=tt_model_args.is_galaxy) - if "q_norm" in tt_layer_name or "k_norm" in tt_layer_name: - input = torch.rand(1, 1, dim) - else: - input = torch.rand(1, 1, 32, dim) - - reference_output = reference_model(input) - - # DistributedNorm inputs are fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode) - if "q_norm" in tt_layer_name or "k_norm" in tt_layer_name: - tt_input = ttnn.from_torch( - input, - device=mesh_device, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - memory_config=( - tt_model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] - if mode == "decode" - else ttnn.DRAM_MEMORY_CONFIG - ), - ) - else: - tt_input = ttnn.from_torch( - input, - device=mesh_device, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=(None, -1), mesh_shape=tt_model_args.cluster_shape), - memory_config=( - tt_model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] - if mode == "decode" - else ttnn.DRAM_MEMORY_CONFIG - ), - ) - - tt_output = tt_model(tt_input, mode=mode) - - # DistributedNorm outputs are replicated across devices - tt_output_torch = ttnn.to_torch( - tt_output, - mesh_composer=ttnn.ConcatMesh2dToTensor( - mesh_device, dims=(0, 2) if tt_model_args.is_galaxy else (2, 0), mesh_shape=tt_model_args.cluster_shape - ), - )[:1, :, :] - - # tt_output_torch = tt_output_torch.view(1, 1, tt_model_args.dim) - - logger.info(comp_allclose(reference_output, tt_output_torch)) - pcc_message = "RMSNORM" - logger.info(f"PCC: {torch_layer_name} , {pcc_message}") - - passing = 0.99 - if passing: - logger.info("rms_norm Passed!") - else: - logger.warning("rms_norm Failed!") - - assert passing, f"rms_norm output does not meet PCC requirement {0.99}." 
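The RMSNorm test above relies on comp_allclose for its logging; for reference, here is a compact sketch of a Gemma-style RMSNorm (using the unit-offset weight that the attention code later in this diff refers to as rms_norm_add_unit_offset) together with a Pearson-correlation check in the spirit of comp_pcc. Both helpers are illustrative assumptions, not the tt_transformers utilities themselves.

import torch

def rmsnorm_ref(x, weight, eps=1e-6):
    # Gemma-style RMSNorm: scale by reciprocal RMS, then by (1 + weight) (unit offset).
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * (1.0 + weight)

def pcc(a, b):
    # Pearson correlation coefficient over the flattened tensors.
    stacked = torch.stack([a.flatten().float(), b.flatten().float()])
    return torch.corrcoef(stacked)[0, 1].item()

x = torch.rand(1, 1, 32, 1152)   # 1152 matches the dim parametrized in the test above
w = torch.zeros(1152)            # zero weight leaves a pure RMS normalization
out = rmsnorm_ref(x, w)
assert pcc(out, x) > 0.99        # per-row positive rescaling keeps the correlation near 1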
diff --git a/models/experimental/gemma3/tests/vision_tests/test_end2end.py b/models/experimental/gemma3/tests/vision_tests/test_end2end.py deleted file mode 100644 index 802b253f38ac..000000000000 --- a/models/experimental/gemma3/tests/vision_tests/test_end2end.py +++ /dev/null @@ -1,756 +0,0 @@ -""" End-to-end test for Gemma3 vision-text pipeline.""" -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 -import torch -import pytest -from loguru import logger -import os -import ttnn -from models.tt_transformers.tt.common import ( - encode_prompt_hf, - sample_host, - PagedAttentionConfig, - preprocess_inputs_prefill, -) -from models.tt_transformers.tt.model_config import DecodersPrecision - -from models.experimental.gemma3.tt.text_model import Gemma3Transformer -from models.experimental.gemma3.tt.gemma_vision_crossattention import TtGemmaTransformerVision -from models.experimental.gemma3.tt.gemma3_generator import Gemma3Generator -from models.utility_functions import ( - comp_pcc, - comp_allclose, -) -from models.utility_functions import skip_for_grayskull, skip_for_blackhole -from models.tt_transformers.tt.model_config import HfModelWrapper - -from models.tt_transformers.tt.model_config import ModelArgs -from transformers import AutoProcessor, AutoTokenizer - -import re - - -def parse_chat_output(text): - """Parse chat output format from generated text.""" - pattern = r"<\|(?Puser|assistant)\|>\s*(?P.*?)(?=<\|(?:user|assistant|end)\|>|$)" - matches = re.finditer(pattern, text, re.DOTALL) - return [(match.group("role"), match.group("message").strip()) for match in matches] - - -def display_chat(logger, conversation): - """Display chat conversation in formatted output.""" - for role, message in conversation: - if role == "user": - logger.info(f"👤 User: {message}") - elif role == "assistant": - logger.info(f"🤖 Assistant: {message}") - - -def setup_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations): - """Setup model arguments and configuration.""" - instruct = True if weights == "instruct" else False - - model_args = ModelArgs( - mesh_device=mesh_device, - instruct=instruct, - optimizations=optimizations, - max_seq_len=max_seq_len, - max_batch_size=batch_size, - ) - - return model_args, instruct - - -def setup_prompts_and_tokenizer(model_args, instruct): - """Setup prompts and tokenizer for the test.""" - prompts = ["Write a essay about Lion"] * model_args.max_batch_size - tokenizer = model_args.tokenizer - - if instruct: - encoded_prompts = encode_prompt_hf(tokenizer=tokenizer, prompt_text=prompts[0]) - else: - encoded_prompts = [model_args.encode_prompt(prompt, instruct=False) for prompt in prompts] - - return prompts, tokenizer, encoded_prompts - - -def setup_reference_model(model_args, run_ref_pt): - """Setup reference PyTorch model and embedding.""" - if run_ref_pt: - reference_transformer_model = model_args.reference_transformer(wrap=False) - reference_model = HfModelWrapper(reference_transformer_model, model_args.head_dim) - logger.info("Finished loading reference model.") - embd = model_args.reference_embedding(reference_transformer_model) - else: - reference_model = None - embd = model_args.reference_embedding() - - return reference_model, embd - - -def setup_paged_attention(paged_attention, page_params, model_args, mesh_device): - """Setup paged attention configuration and page table.""" - page_table_tt = None - paged_attention_config = None - - if paged_attention: - paged_attention_config = PagedAttentionConfig( - 
block_size=page_params["page_block_size"], - max_num_blocks=page_params["page_max_num_blocks"], - ) - # Implied shuffling of blocks - permutation = torch.randperm(paged_attention_config.max_num_blocks) - # Page table which maps virtual blocks to physical - reverse_permutation = torch.argsort(permutation) - page_table = reverse_permutation.reshape( - model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size - ) - page_table_tt = ttnn.from_torch( - page_table, - device=mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, -2) if model_args.max_batch_size > 1 else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - return paged_attention_config, page_table_tt - - -# ============================================================================= -# NEW E2E PIPELINE COMPONENTS - Following SOLID Principles -# ============================================================================= - - -def setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations): - """Setup model arguments for vision-enabled model (Single Responsibility).""" - instruct = True if weights == "instruct" else False - - model_args = ModelArgs( - mesh_device=mesh_device, - instruct=instruct, - optimizations=optimizations, - max_seq_len=max_seq_len, - max_batch_size=batch_size, - ) - return model_args, instruct - - -def setup_vision_prompts_and_tokenizer(model_args, instruct): - """Setup multimodal prompts and tokenizer for vision-enabled model.""" - # Create multimodal messages similar to test_end2end.py - messages = [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What is your favorite condiment? There are so many condiments to choose from, each bringing its unique flavor and texture to enhance different dishes. Do you prefer the classic taste of ketchup, the creamy richness of mayonnaise, the spicy kick of mustard, or perhaps something more exotic like sriracha or hoisin sauce? Maybe you enjoy the tangy zest of salsa or the smooth and savory taste of aioli. Share what your favorite condiment is and why you love it. 
Does it remind you of a specific dish or meal?", - }, - ], - } - ] - - messages = [ - { - "role": "user", - "content": [ - { - "type": "image", - "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", - }, - {"type": "text", "text": "Describe this image in detail."}, - ], - } - ] - - tokenizer = model_args.tokenizer - return messages, tokenizer - - -def setup_vision_reference_model(model_args, run_ref_pt): - """Setup reference vision-enabled model (Open/Closed Principle).""" - if run_ref_pt: - reference_transformer_model = model_args.reference_vision_transformer(wrap=False) - reference_model = HfModelWrapper(reference_transformer_model, model_args.head_dim) - logger.info("Finished loading reference vision model.") - embd = model_args.reference_embedding(reference_transformer_model) - else: - reference_model = None - embd = model_args.reference_embedding() - - return reference_model, embd - - -def process_real_vision_inputs(messages, model_args): - """Process real image inputs using AutoProcessor (Interface Segregation).""" - model_id = model_args.CKPT_DIR - - try: - # Try loading processor (works for models that has preprocessor_config.json) - processor = AutoProcessor.from_pretrained(model_id) - except OSError: - # Fallback to tokenizer - processor = AutoTokenizer.from_pretrained(model_id) - - # Process the multimodal messages similar to test_end2end.py - encoded = processor.apply_chat_template( - messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" - ).to(torch.bfloat16) - - input_ids = encoded["input_ids"] - pixel_values = encoded["pixel_values"] if "pixel_values" in encoded else None - attention_mask = encoded["attention_mask"] - - # logger.info(f"Processed vision inputs - input_ids: {input_ids.shape}, pixel_values: {pixel_values.shape}") - - return { - "input_ids": input_ids, - "pixel_values": pixel_values, - "attention_mask": attention_mask, - "processor": processor, - "input_prompts": messages, - } - - -# Legacy function removed - vision model now part of multimodal model - - -def load_separate_models_like_test_end2end(model_args, mesh_device, dtype, paged_attention, page_params): - """Load separate vision and text models following test_end2end.py pattern.""" - state_dict = model_args.load_state_dict() - vision_prefix = "vision_tower.vision_model." 
- - # Setup paged attention config (exactly like test_end2end.py) - paged_attention_config = None - if paged_attention: - paged_attention_config = PagedAttentionConfig( - block_size=page_params["page_block_size"], - max_num_blocks=page_params["page_max_num_blocks"], - ) - - # Load vision model (exactly like test_end2end.py) - if model_args.is_multimodal: - vision_model = TtGemmaTransformerVision( - mesh_device=mesh_device, - state_dict=state_dict, - state_dict_prefix=vision_prefix, - dtype=dtype, - configuration=model_args, - weight_cache_path=model_args.weight_cache_path(dtype), - ) - else: - vision_model = None - # Load text model (exactly like test_end2end.py) - text_model = Gemma3Transformer( - args=model_args, - mesh_device=mesh_device, - dtype=dtype, - state_dict=state_dict, - weight_cache_path=model_args.weight_cache_path(dtype), - paged_attention_config=paged_attention_config, - ) - - logger.info("Separate vision and text models loaded like test_end2end.py") - return vision_model, text_model - - -def run_generation_exactly_like_test_end2end( - vision_model, text_model, processed_inputs, model_args, page_table=None, paged_attention_config=None, max_gen_len=20 -): - """Run generation following the EXACT pattern from test_end2end.py.""" - input_ids = processed_inputs["input_ids"] - pixel_values = processed_inputs["pixel_values"] - input_prompts = processed_inputs["input_prompts"] - processor = processed_inputs["processor"] - - logger.info("Running generation exactly like test_end2end.py...") - - # Process vision (exactly like test_end2end.py) - logger.info("Running Vision Model...") - - # Create Generator (exactly like test_end2end.py) - generator = Gemma3Generator([text_model], [model_args], text_model.mesh_device, tokenizer=model_args.tokenizer) - - # Setup KV cache (exactly like test_end2end.py) - tt_kv_cache = [[l.attention.layer_past for l in text_model.layers]] if paged_attention_config else None - - # Get embeddings and combine with vision (exactly like test_end2end.py) - # host_embedding = model_args.reference_embedding() - - # # Text generation setup (exactly like test_end2end.py) - input_tokens_prefill = input_ids - batch_size = input_tokens_prefill.shape[0] - # seq_len = input_tokens_prefill.shape[1] - model_args.tokenizer = processor - - ( - input_tokens_prefill_pt, - encoded_prompts, - decoding_pos, - prefill_lens, - ) = preprocess_inputs_prefill( - input_prompts, - model_args.tokenizer, - [model_args], - instruct=True, - max_generated_tokens=max_gen_len, - max_prefill_len=8192, - ) - - input_tokens_prefill_pt = torch.stack(input_tokens_prefill_pt).view(batch_size, -1) - - logger.info("Running prefill...") - logits = generator.prefill_forward_text( - input_tokens_prefill_pt, - page_table=page_table, - kv_cache=tt_kv_cache, - prompt_lens=decoding_pos, - vision_model=vision_model, - processed_inputs=processed_inputs, - ) - - # Get first token (exactly like test_end2end.py) - prefilled_token = torch.argmax(logits, dim=-1) - logger.info(f"Prefilled token: {prefilled_token}") - - # Initialize generation (exactly like test_end2end.py) - all_outputs = [encoded_prompts[0][: prefill_lens[0]]] - all_outputs[0].append(int(prefilled_token[0].item())) - - current_pos = torch.tensor([decoding_pos[0]]) - out_tok = prefilled_token - generation_length = 150 - - results = [] - - # Decode loop (exactly like test_end2end.py) - logger.info("Starting decode loop...") - for iteration in range(generation_length): - logger.info(f"[Text] Decoding token {iteration}, current_pos: 
{current_pos.item()}") - - # Run decode (exactly like test_end2end.py) - logits = generator.decode_forward_text( - out_tok, - current_pos, - enable_trace=False, - page_table=page_table, - kv_cache=tt_kv_cache, - ) - - # Sample next token (exactly like test_end2end.py) - _, out_tok = sample_host( - logits, - temperature=0, - top_p=0.9, - ) - - token_id = out_tok[0].item() - decoded_token = model_args.tokenizer.decode([token_id]) - logger.info(f"Generated token {iteration}: ID={token_id}, text='{decoded_token}'") - - # Create result object - result = type("TokenResult", (), {"token": token_id, "text": decoded_token})() - - results.append(result) - - all_outputs[0].append(token_id) - current_pos += 1 - - # Early stopping (exactly like test_end2end.py) - if len(all_outputs[0]) >= 5 and all(t == all_outputs[0][-1] for t in all_outputs[0][-5:]): - logger.warning(f"Detected exact repetition of token {all_outputs[0][-1]} five times in a row. Stopping.") - break - - # Final response (exactly like test_end2end.py) - response = model_args.tokenizer.decode(all_outputs[0], skip_special_tokens=True) - logger.info(f"📝 Final Generated Response:\n{response}") - logger.info(f"📝 Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") - chat = parse_chat_output(response) - display_chat(logger, chat) - - logger.info(f"Generated {len(results)} tokens successfully") - return results - - -# Legacy function removed - vision processing now handled in multimodal model - - -def validate_e2e_outputs(results, expected_min_tokens=1): - """Validate end-to-end pipeline outputs.""" - if not results: - logger.error("No results generated from E2E pipeline") - return False - - if len(results) < expected_min_tokens: - logger.warning(f"Generated only {len(results)} tokens, expected at least {expected_min_tokens}") - return False - - # Check if tokens are valid - for result in results: - if not hasattr(result, "token") or not hasattr(result, "text"): - logger.error("Invalid result format") - return False - - logger.info("E2E pipeline validation passed") - return True - - -# ============================================================================= -# EXISTING FUNCTIONS (Unchanged for backward compatibility) -# ============================================================================= - - -def create_position_tensor(current_pos, model_args, mesh_device): - """Create position tensor for the model.""" - return ttnn.from_torch( - current_pos, - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, 0) if (model_args.is_galaxy and model_args.max_batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - -def convert_tt_output_to_torch(tt_out, model_args, mesh_device): - """Convert TTNN tensor to PyTorch tensor.""" - mesh_composer = ttnn.ConcatMesh2dToTensor( - mesh_device, dims=(3, 1) if model_args.is_galaxy else (1, -1), mesh_shape=model_args.cluster_shape - ) - tt_output_torch = ( - ttnn.to_torch(tt_out, mesh_composer=mesh_composer) - .permute(2, 1, 0, 3) - .squeeze(2)[: model_args.max_batch_size, 0:1, : model_args.vocab_size] - ) - ttnn.deallocate(tt_out) - return tt_output_torch - - -def process_token_generation( - i, - encoded_prompts, - encoded_prompts_tensor, - embd, - batch, - seqlen, - all_outputs, - all_outputs_ref, - run_ref_pt, - ref_output, - tt_output_torch, -): - """Process token generation for both prefill and decode phases.""" - if i in range(len(encoded_prompts)): - # While in "prefill" mode, use the prompt tokens as the output 
- all_outputs.append(encoded_prompts[i]) - if run_ref_pt: - all_outputs_ref.append(encoded_prompts[i]) - - tt_decode_input = embd(encoded_prompts_tensor[:, i]).view(batch, seqlen, -1) - if run_ref_pt: - pt_decode_input = embd(encoded_prompts_tensor[:, i]).view(batch, seqlen, -1) - else: - pt_decode_input = None - else: - # Greedy decode (temperature = 0) the generated token and save it to print out later - # Exact copy of original logic (including commented sections) - # if run_ref_pt: - # # Sample from reference model first - _, pt_out_tok = sample_host(ref_output, temperature=0, top_p=0.8) - pt_decode_input = embd(pt_out_tok) - all_outputs_ref.append(pt_out_tok.squeeze(1).tolist()[0]) - - # Use the same token for TT model (teacher forcing) - tt_decode_input = pt_decode_input - # all_outputs.append(pt_out_tok.squeeze(1).tolist()[0]) - # else: - # If not running reference model, sample from TT model directly - _, tt_out_tok = sample_host(tt_output_torch, temperature=0, top_p=0.8) - tt_decode_input = embd(tt_out_tok) - all_outputs.append(tt_out_tok.squeeze(1).tolist()[0]) - - return tt_decode_input, pt_decode_input - - -def validate_outputs(run_ref_pt, ref_output, tt_output_torch, pcc, all_outputs, all_outputs_ref, tokenizer, logger): - """Validate model outputs and compute PCC.""" - if run_ref_pt: - passing, pcc_message = comp_pcc(ref_output, tt_output_torch, pcc) - - # Decode the output tokens back to text - decoded_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in all_outputs] - logger.info(f'TTNN Decoded Outputs: {"".join(decoded_texts)}') - decoded_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in all_outputs_ref] - logger.info(f'Torch Decoded Outputs: {"".join(decoded_texts)}') - - logger.info(comp_allclose(ref_output, tt_output_torch)) - logger.info(f"PCC: {pcc_message}") - - if passing: - logger.info("Model Passed!") - else: - logger.warning("Model Failed!") - - return passing - return True - - -def run_generation_loop( - tt_model, - model_args, - mesh_device, - reference_model, - embd, - encoded_prompts_tensor, - generation_length, - generation_start_pos, - batch, - seqlen, - page_table_tt, - run_ref_pt, - pcc, - tokenizer, - logger, - parse_chat, - encoded_prompts, -): - """Run the main token generation loop.""" - all_outputs = [] - all_outputs_ref = [] if run_ref_pt else [] - if run_ref_pt: - all_tests_pass = True - - # Initial setup - current_pos = torch.tensor([generation_start_pos for _ in range(batch)]) - current_pos_tensor = create_position_tensor(current_pos, model_args, mesh_device) - - # Select the first token from the prompts for initial decoding - pt_decode_input = embd(encoded_prompts_tensor[:, 0]).view(batch, seqlen, -1) - tt_decode_input = pt_decode_input - - for i in range(generation_length): - logger.info(f"[Model] Generating token {i}") - - # Prepare input - decode_input = model_args.prepare_residual_tensor_decode( - tt_decode_input, - model_args.model_config["DECODE_RESIDUAL_MEMCFG"], - ) - - # Get rotation matrices - rot_mats_global = tt_model.rope_setup.get_rot_mats(current_pos) - rot_mats_local = tt_model.rope_setup_local.get_rot_mats(current_pos) - rot_mats = [rot_mats_global, rot_mats_local] - - # Run TT model - tt_out = tt_model( - decode_input, - current_pos_tensor, - rot_mats=rot_mats, - mode="decode", - page_table=page_table_tt, - ) - - # Convert output - tt_output_torch = convert_tt_output_to_torch(tt_out, model_args, mesh_device) - - # Run reference model if needed - ref_output = None - if run_ref_pt: - 
ref_output = reference_model(pt_decode_input, current_pos[0]) - - # Update position - current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch)]) - current_pos_tensor = create_position_tensor(current_pos, model_args, mesh_device) - - # Process token generation - tt_decode_input, pt_decode_input = process_token_generation( - i, - encoded_prompts, - encoded_prompts_tensor, - embd, - batch, - seqlen, - all_outputs, - all_outputs_ref, - run_ref_pt, - ref_output, - tt_output_torch, - ) - - # Validate outputs - passing = validate_outputs( - run_ref_pt, ref_output, tt_output_torch, pcc, all_outputs, all_outputs_ref, tokenizer, logger - ) - - # Note: Individual PCC failures don't affect overall test result (matching original behavior) - # if not passing: - # all_tests_pass = False - - # Display chat if enabled - if parse_chat: - conversation = parse_chat_output(tokenizer.decode(all_outputs).replace("\n", "\\n")) - display_chat(logger, conversation) - - if run_ref_pt: - return all_tests_pass - else: - return True # If not running reference model, always pass - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@skip_for_blackhole("Failing on DRAM harvested P100a, see #21419") -@pytest.mark.timeout(1800) -@pytest.mark.models_performance_bare_metal -@pytest.mark.parametrize( - "weights, layers", - [ - ("instruct", None), - ], - ids=["full"], -) -@pytest.mark.parametrize( - "paged_attention", - ( - True, - # False, - ), - ids=( - "paged_attention", - # "default_attention", - ), -) -@pytest.mark.parametrize( - "device_params", - [ - { - "fabric_config": ttnn.FabricConfig.FABRIC_1D, - "trace_region_size": 30000000, - "num_command_queues": 1, - "l1_small_size": 10 * 1024, - } - ], - indirect=True, -) -@pytest.mark.parametrize( - "page_params", - [{"page_block_size": 32, "page_max_num_blocks": 1024}], -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -@pytest.mark.parametrize( - "max_seq_len", - (1024 * 8,), # Use smaller seq_len like test_end2end.py to avoid memory issues -) -@pytest.mark.parametrize( - "optimizations", - [ - lambda model_args: DecodersPrecision.accuracy(model_args.n_layers, model_args.model_name), - ], - ids=["accuracy"], -) -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "N150x4": (1, 4), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -def test_e2e_vision_text_pipeline( - weights, - layers, - max_seq_len, - batch_size, - paged_attention, - page_params, - optimizations, - mesh_device, - reset_seeds, - request, - device_params, -): - """Test end-to-end vision-text pipeline using proper Generator methods.""" - logger.info("Starting E2E vision-text pipeline test") - - # Use bfloat8_b like test_end2end.py for better memory efficiency - dtype = ttnn.bfloat8_b - # Setup vision-enabled model configuration - model_args, instruct = setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations) - - if layers is not None: - model_args.n_layers = layers - - # Setup vision prompts and tokenizer - messages, tokenizer = setup_vision_prompts_and_tokenizer(model_args, instruct) - - # Process real vision inputs from images - processed_inputs = process_real_vision_inputs(messages, model_args) - - # Load separate models following test_end2end.py pattern - logger.info("Loading separate vision and text models like test_end2end.py...") - vision_model, text_model = load_separate_models_like_test_end2end( - model_args, 
mesh_device, dtype, paged_attention, page_params - ) - - # Setup page table for paged attention (exactly like test_end2end.py) - page_table_tt = None - paged_attention_config = None - - # Prepare page table for paged attention (exactly like test_end2end.py) - page_table = None - if paged_attention: - paged_attention_config = PagedAttentionConfig( - block_size=page_params["page_block_size"], - max_num_blocks=page_params["page_max_num_blocks"], - ) - # Implied shuffling of blocks - permutation = torch.randperm(paged_attention_config.max_num_blocks) - # Page table which maps virtual blocks to physical - reverse_permutation = torch.argsort(permutation) - page_table = reverse_permutation.reshape( - model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size - ) - page_table_tt = ttnn.from_torch( - page_table, - device=mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, -2) if batch_size > 1 else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - # Run generation following EXACT test_end2end.py pattern - logger.info("Running generation following EXACT test_end2end.py pattern...") - results = run_generation_exactly_like_test_end2end( - vision_model, text_model, processed_inputs, model_args, page_table, paged_attention_config, max_gen_len=10 - ) - - # Validate results - validation_passed = validate_e2e_outputs(results, expected_min_tokens=1) - - # Final validation - if validation_passed and len(results) > 0: - logger.info("✅ E2E vision-text pipeline test PASSED!") - logger.info(f"Successfully generated {len(results)} tokens") - - # Log generated tokens for debugging - for i, result in enumerate(results[:5]): - logger.info(f"Token {i}: {result.token} -> '{result.text}'") - else: - logger.error("❌ E2E pipeline test failed") - assert False, f"E2E pipeline failed - generated {len(results)} tokens, validation: {validation_passed}" diff --git a/models/experimental/gemma3/tt/attention.py b/models/experimental/gemma3/tt/attention.py deleted file mode 100644 index f070e2d8224e..000000000000 --- a/models/experimental/gemma3/tt/attention.py +++ /dev/null @@ -1,961 +0,0 @@ -""" -source: models/tt_transformers/tt/attention.py - -This is the attention implementation of the Gemma3 - -We have re-used the Attention implementation of the TT-Transformers with few modifications. -This implementation has Changes in Datatype (Bfloat16) that supports the RMSNorm, -Sliding Window support. 
- -""" - -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 - -import math - -import torch - -import ttnn -from models.common.lightweightmodule import LightweightModule - -from models.experimental.gemma3.tt.rmsnorm import RMSNorm -from models.tt_transformers.tt.ccl import tt_all_gather, tt_all_reduce -from models.tt_transformers.tt.model_config import OpGroup, TensorGroup - - -class Attention(LightweightModule): - def __init__( - self, - mesh_device, - tt_ccl, - state_dict, - weight_cache_path, - layer_num, - dtype, - transformation_mats, - configuration, - paged_attention_config=None, - use_paged_kv_cache=False, - ): - super().__init__() - - self.layer_idx = layer_num - self.configuration = configuration - self.mesh_device = mesh_device - self.tt_ccl = tt_ccl - self.num_devices = configuration.num_devices - self.TG = self.num_devices == 32 - self.hidden_size = configuration.dim - self.n_heads = configuration.n_heads - self.head_dim = configuration.head_dim - self.max_seq_len = configuration.max_seq_len - self.max_batch_size = configuration.max_batch_size - self.n_kv_heads = configuration.n_kv_heads - self.paged_attention_config = paged_attention_config - self.min_kv_prefill_shard_seqlen = configuration.min_kv_prefill_shard_seqlen - self.ccl_dtype = configuration.ccl_dtype - self.num_reduce_scatter_links = configuration.num_reduce_scatter_links - self.num_all_gather_links = configuration.num_all_gather_links - self.MAX_QKV_MM_SEQ_LEN = configuration.MAX_QKV_MM_SEQ_LEN - self.tile_size = configuration.tile_size - self.rms_norm_add_unit_offset = configuration.rms_norm_add_unit_offset - self.num_device_groups = self.num_devices // self.n_kv_heads - self.num_devices_per_group = self.n_kv_heads if self.TG else self.num_devices - self.batch_size_per_device_group = ( - max(self.max_batch_size // self.num_device_groups, 1) if self.TG else self.max_batch_size - ) - - self.n_local_heads = self.n_heads // self.num_devices_per_group - self.n_local_kv_heads = self.n_kv_heads // self.num_devices_per_group - - self.arch_name = configuration.arch_name - # TODO: Fix this once all-gather supports < tile_size - if self.TG: - weight = torch.zeros(1, 32, 8, 32) - for i in range(32): - col = i % 4 # This determines which group of 8 to select - weight[:, i, :, col * 8 : (col + 1) * 8] = torch.eye(8) - - self.slice_mat = ttnn.from_torch( - weight, - dtype=ttnn.bfloat4_b, - layout=ttnn.TILE_LAYOUT, - device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=1), - ) - user_selection_matrix = torch.eye(8, 8) - user_selection_matrix = torch.nn.functional.pad(user_selection_matrix, (0, 24), "constant", 0) # (8, 32) - user_selection_matrix = [user_selection_matrix] * 4 - user_selection_matrix = torch.block_diag(*user_selection_matrix) # (32, 128) - self.user_selection_matrix = ttnn.from_torch( - user_selection_matrix, - dtype=ttnn.bfloat4_b, - layout=ttnn.TILE_LAYOUT, - device=self.mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - - self.dtype = dtype - - self.max_seq_len = configuration.max_seq_len - self.grid_size = configuration.max_grid_size - - self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 - self.compute_kernel_config_hifi2_fp16 = configuration.compute_kernel_config_hifi2_fp16 - - self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 - - self.transformation_mats = transformation_mats - self.is_sliding = ( - configuration.layer_types[layer_num] == 
"sliding_attention" if configuration.layer_types else False - ) - - self.model_config = configuration.get_model_config() - self.ccl_topology = configuration.ccl_topology() - self.is_multichip = configuration.is_multichip - self.activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=layer_num, tensor=TensorGroup.ACTIVATION - ) - self.wqkv_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=layer_num, tensor=TensorGroup.WQKV - ) - self.wo_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=layer_num, tensor=TensorGroup.WO - ) - self.kv_cache_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=layer_num, tensor=TensorGroup.KV_CACHE - ) - self.li_qkv_decode_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.LI_QKV_DECODE, configuration=configuration - ) - self.sdpa_decode_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.SDPA_DECODE, configuration=configuration - ) - self.li_o_decode_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.LI_O_DECODE, configuration=configuration - ) - self.sdpa_prefill_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.SDPA_PREFILL, configuration=configuration - ) - self.li_qkv_prefill_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.LI_QKV_PREFILL, configuration=configuration - ) - self.li_o_prefill_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.LI_O_PREFILL, configuration=configuration - ) - - layer_name = configuration.get_state_dict_prefix(self.__class__.__name__, layer_num) - if configuration.dummy_weights or (weight_cache_path is None): - cache_name = lambda _: None - else: - cache_name = lambda name: weight_cache_path / (f"{layer_name}.{name}") - - wq_str = f"{layer_name}.wq" - wk_str = f"{layer_name}.wk" - wv_str = f"{layer_name}.wv" - wo_str = f"{layer_name}.wo" - q_norm_str = f"{layer_name}.q_norm" - k_norm_str = f"{layer_name}.k_norm" - - # Initialize bias tensors as None - self.wqkv_bias_decode = None - self.wqkv_bias_prefill = None - - # Create combined QKV bias if present in state dict - if f"{wq_str}.bias" in state_dict: - qkv_bias = torch.concat( - [ - torch.concat( - [ - torch.chunk(state_dict[f"{wq_str}.bias"], configuration.num_devices)[i], - torch.chunk(state_dict[f"{wk_str}.bias"], configuration.num_devices)[i], - torch.chunk(state_dict[f"{wv_str}.bias"], configuration.num_devices)[i], - ], - dim=-1, - ) - for i in range(configuration.num_devices) - ], - dim=-1, - ) - # Prefill can use broadcasting on the bias add so wants a 1d tensor - self.wqkv_bias_prefill = ttnn.as_tensor( - qkv_bias, - device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - layout=ttnn.TILE_LAYOUT, - cache_file_name=cache_name("wqkv_bias_prefill_sharded"), - ) - # as_tensor returns (32, dim) which is incorrect, this reshape updates the padded size to the correct size - self.wqkv_bias_prefill = ttnn.reshape( - self.wqkv_bias_prefill, - (1, 1, 1, self.wqkv_bias_prefill.shape[-1]), - (1, 1, self.wqkv_bias_prefill.shape[-2], 
self.wqkv_bias_prefill.shape[-1]), - ) - - # Broadcasting does not seem to be supported inside execute_trace so expand to the whole batch size - # Create a list of bias tensors for each multiple of tile_size up to max_batch_size - self.wqkv_bias_decode = [] - for batch_size in range( - configuration.tile_size, - configuration.tile_padded_batch_rows + configuration.tile_size, - configuration.tile_size, - ): - qkv_bias_decode = qkv_bias.unsqueeze(0).expand(batch_size, -1) - bias_tensor = ttnn.as_tensor( - qkv_bias_decode, - device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - layout=ttnn.TILE_LAYOUT, - cache_file_name=cache_name(f"wqkv_bias_decode_sharded_{batch_size}"), - ) - self.wqkv_bias_decode.append(bias_tensor) - - # when splitting the devices, we need to make sure that the number of heads is divisible by the number of devices - assert self.n_heads % self.num_devices_per_group == 0 - assert self.n_kv_heads % self.num_devices_per_group == 0 - assert configuration.qkv_size % self.num_devices_per_group == 0 - assert configuration.dim % self.num_devices_per_group == 0 - - # wqkv: 4096 x 3072 (2 devices): width-sharded on 12 banks, 3072 over 12 banks. - wqkv_mem_config = configuration.create_dram_sharded_mem_config( - configuration.dim, configuration.qkv_size // configuration.num_devices - ) - - qkv_list = [] - for i in range(self.num_devices_per_group): - # Chunk weights - wq_selected = torch.chunk(state_dict[f"{wq_str}.weight"], self.num_devices_per_group, dim=0)[i] - wk_selected = torch.chunk(state_dict[f"{wk_str}.weight"], self.num_devices_per_group, dim=0)[i] - wv_selected = torch.chunk(state_dict[f"{wv_str}.weight"], self.num_devices_per_group, dim=0)[i] - - # Transpose the selected chunks - wq = torch.transpose(wq_selected, -2, -1) - wk = torch.transpose(wk_selected, -2, -1) - wv = torch.transpose(wv_selected, -2, -1) - - qkv = torch.cat([wq, wk, wv], dim=-1) - qkv_list.append(qkv) - - qkv_cat = torch.cat(qkv_list, dim=-1).unsqueeze(0).unsqueeze(0) - - self.wqkv = ttnn.as_tensor( - qkv_cat, - dtype=self.wqkv_dtype, - layout=ttnn.TILE_LAYOUT, - device=self.mesh_device, - memory_config=ttnn.DRAM_MEMORY_CONFIG if self.TG else wqkv_mem_config, - mesh_mapper=ttnn.ShardTensor2dMesh( - self.mesh_device, dims=(3, 2) if self.TG else (2, 3), mesh_shape=configuration.cluster_shape - ), - cache_file_name=cache_name("wqkv_sharded_2d"), - ) - - def norm_reshard(x, norm, mode): - """Hack until RMSNorm supports height-sharded output config""" - if mode == "decode": - mem_cfg = x.memory_config() - x = ttnn.to_memory_config(x, ttnn.L1_MEMORY_CONFIG, dtype=x.dtype) - x = norm(x, mode) - if mode == "decode": - x = ttnn.to_memory_config(x, mem_cfg, dtype=x.dtype) - return x - - if f"{q_norm_str}.weight" in state_dict: - fn_q_norm = RMSNorm( - device=self.mesh_device, - dim=self.head_dim, - eps=configuration.norm_eps, - state_dict=state_dict, - state_dict_prefix=None, # we already prefix q_norm_str - weight_cache_path=None if configuration.dummy_weights else weight_cache_path, - weight_dtype=ttnn.bfloat16, - weight_key=q_norm_str, - add_unit_offset=self.rms_norm_add_unit_offset, - is_distributed=False, - sharded_program_config=None, # FIXME: add height-sharded support. self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"], - sharded_output_config=None, # FIXME: add height-sharded support. 
self.model_config["CREATE_QKV_DECODE_SHARD"] - tt_ccl=self.tt_ccl, - ) - self.q_norm = lambda x, mode: norm_reshard(x, fn_q_norm, mode) - else: - self.q_norm = lambda x, mode: x - - if f"{k_norm_str}.weight" in state_dict: - fn_k_norm = RMSNorm( - device=self.mesh_device, - dim=self.head_dim, - eps=configuration.norm_eps, - state_dict=state_dict, - state_dict_prefix=None, # we already prefix k_norm_str - weight_cache_path=None if configuration.dummy_weights else weight_cache_path, - weight_dtype=ttnn.bfloat16, - weight_key=k_norm_str, - add_unit_offset=self.rms_norm_add_unit_offset, - is_distributed=False, - sharded_program_config=None, # FIXME: add height-sharded support. self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"], - sharded_output_config=None, # FIXME: add height-sharded support. self.model_config["CREATE_QKV_DECODE_SHARD"], - tt_ccl=self.tt_ccl, - ) - self.k_norm = lambda x, mode: norm_reshard(x, fn_k_norm, mode) - else: - self.k_norm = lambda x, mode: x - - # For ring topology we can use all gather matmul for wo - self.use_fused_all_gather_matmul = self.model_config["USE_FUSED_ALL_GATHER_MATMUL"] - pt_wo = state_dict[f"{wo_str}.weight"].transpose(-1, -2).unsqueeze(0).unsqueeze(0) - - wo_mem_config = configuration.create_dram_sharded_mem_config( - (configuration.n_heads * configuration.head_dim) // configuration.num_devices, configuration.dim - ) - - self.wo = ttnn.as_tensor( - pt_wo, - dtype=self.wo_dtype, - layout=ttnn.TILE_LAYOUT, - device=self.mesh_device, - memory_config=ttnn.DRAM_MEMORY_CONFIG if (self.use_fused_all_gather_matmul or self.TG) else wo_mem_config, - mesh_mapper=ttnn.ShardTensor2dMesh( - self.mesh_device, - dims=(2, 3) if (self.use_fused_all_gather_matmul or self.TG) else (3, 2), - mesh_shape=configuration.cluster_shape, - ), - cache_file_name=( - cache_name("wo_width_sharded_2d") if (self.use_fused_all_gather_matmul or self.TG) else cache_name("wo") - ), - ) - if not use_paged_kv_cache: - # vLLM provides its own kv cache - self.init_kv_cache(configuration, weight_cache_path) - - if configuration.query_pre_attn_scalar is not None: - self.scale = configuration.query_pre_attn_scalar**-0.5 - else: - self.scale = self.head_dim**-0.5 - - def init_kv_cache(self, configuration, weight_cache_path): - """ - Generates empty KV cache and pushed to device memory - """ - - if self.paged_attention_config: - cache_k = torch.zeros( - ( - self.paged_attention_config.max_num_blocks, - self.n_local_kv_heads, - self.paged_attention_config.block_size, - self.head_dim, - ) - ) - cache_v = torch.zeros( - ( - self.paged_attention_config.max_num_blocks, - self.n_local_kv_heads, - self.paged_attention_config.block_size, - self.head_dim, - ) - ) - else: - cache_k = torch.zeros( - ( - self.batch_size_per_device_group, - self.n_local_kv_heads, - self.max_seq_len, - self.head_dim, - ) - ) - cache_v = torch.zeros( - ( - self.batch_size_per_device_group, - self.n_local_kv_heads, - self.max_seq_len, - self.head_dim, - ) - ) - - self.layer_past = [ - ttnn.as_tensor( - k_or_v, - dtype=self.kv_cache_dtype, - layout=self.model_config["ATTN_W_LAYOUT_TILE"], - device=self.mesh_device, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - cache_file_name=( - f"{weight_cache_path}/kvcache_{k_or_v.shape}" - if weight_cache_path and not configuration.dummy_weights - else None - ), - ) - for k_or_v in [cache_k, cache_v] - ] - - def forward_decode( - self, - x: ttnn.Tensor, - current_pos, - rot_mats=None, - page_table=None, - kv_cache=None, - ) -> 
ttnn.Tensor: - """ - x: (seq_len, 1, batch, dim) - current_pos: (batch_size), current token position in the sequence for each user - """ - - ### - # QKV matmuls - # Use HiFi2 for DRAM-sharded matmuls as they are otherwise flop-bound. Loses 1 bit of activation precision. - ### - - xqkv_fused_sharded = ttnn.linear( - x, - self.wqkv, - # bias=self.wqkv_bias, - memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - program_config=self.model_config["XQKV_DECODE_PROGCFG"], - compute_kernel_config=self.li_qkv_decode_compute_kernel_cfg, - dtype=self.ccl_dtype if self.TG else self.activation_dtype or ttnn.bfloat16, - ) - # FIXME: File bug against dram-sharded matmuls with bias - if self.wqkv_bias_decode: - # select the bias tensor based on the number of tiles in the rows - # WARNING: must not change the batch size between compiling and executing a trace - num_tiles = int(math.ceil(xqkv_fused_sharded.shape[-2] / self.tile_size)) - xqkv_fused_sharded = xqkv_fused_sharded + self.wqkv_bias_decode[num_tiles - 1] - - ttnn.deallocate(x) - xqkv_fused = tt_all_reduce( - xqkv_fused_sharded, - self.mesh_device, - self.tt_ccl, - cluster_axis=1, - num_reduce_scatter_links=self.num_reduce_scatter_links, - num_all_gather_links=self.num_all_gather_links, - memory_config=self.model_config["QKV_OUT_GATHERED_MEMCFG"](list(self.mesh_device.shape)[1]), - sharded=True, - dtype=self.ccl_dtype, - topology=self.ccl_topology, - ) - - if self.TG: - # TODO: Slice the fused_query_key_value tensor get batch=8 - xqkv_fused = ttnn.matmul( - self.slice_mat, - xqkv_fused, - dtype=ttnn.bfloat16, - memory_config=self.model_config["CREATE_HEAD_INPUT_MEMCFG"], - ) - else: - # bfloat16 is required by nlp_create_qkv_heads_decode - xqkv_fused = ttnn.sharded_to_interleaved(xqkv_fused_sharded, ttnn.L1_MEMORY_CONFIG, ttnn.bfloat16) - - ttnn.deallocate(xqkv_fused_sharded) - - # Reshape such that true unpadded batch is tracked in shape - fqkv_shape = xqkv_fused.shape - xqkv_fused = ttnn.reshape( - xqkv_fused, (1, 1, self.batch_size_per_device_group, fqkv_shape[3]), (1, 1, 32, fqkv_shape[3]) - ) - - ### - # Reshape and rotary embeddings - ### - ( - q_heads_pre_rot_1BQD, - k_heads_pre_rot_1BKD, - v_heads_1BKD, - ) = ttnn.experimental.nlp_create_qkv_heads_decode( - xqkv_fused, - num_heads=self.n_local_heads, - num_kv_heads=self.n_local_kv_heads, - memory_config=self.model_config["CREATE_QKV_DECODE_SHARD"], - ) - - q_heads_pre_rot_1BQD = self.q_norm(q_heads_pre_rot_1BQD, mode="decode") - k_heads_pre_rot_1BKD = self.k_norm(k_heads_pre_rot_1BKD, mode="decode") - - ttnn.deallocate(xqkv_fused) - - # Q Rotary Embeddings - q_heads_1BQD = ttnn.experimental.rotary_embedding_llama( - q_heads_pre_rot_1BQD, rot_mats[0], rot_mats[1], self.transformation_mats["decode"], is_decode_mode=True - ) - - # K Rotary Embeddings - k_heads_1BKD = ttnn.experimental.rotary_embedding_llama( - k_heads_pre_rot_1BKD, rot_mats[0], rot_mats[1], self.transformation_mats["decode"], is_decode_mode=True - ) - - ttnn.deallocate(q_heads_pre_rot_1BQD) - ttnn.deallocate(k_heads_pre_rot_1BKD) - - ### - # KV update - ### - if kv_cache: - keys = kv_cache[0] - values = kv_cache[1] - else: - keys = self.layer_past[0] - values = self.layer_past[1] - # k_heads, [seqlen, n_kv_heads, bsz, head_dim] - # v_heads [seqlen, n_kv_heads, bsz, head_dim] - # keys, [max_batch_size, n_kv_heads // configuration.num_devices, max_seq_len, head_dim] - ttnn.experimental.paged_update_cache(keys, k_heads_1BKD, update_idxs_tensor=current_pos, page_table=page_table) - ttnn.experimental.paged_update_cache( - values, 
v_heads_1BKD, update_idxs_tensor=current_pos, page_table=page_table - ) - - ttnn.deallocate(k_heads_1BKD) - ttnn.deallocate(v_heads_1BKD) - - # NOTE: Varying the batch size will result in slightly different outputs. - # For example, a prompt w/ 1 user vs, the same prompt repeated N times for N users, will produce different outputs - # This is because the SDPA op in decode mode has different number of reductions depending on batch size - # Which leads to slightly different outputs from attention (due to accumulated errors) - - if page_table: - attn_output_1G4D = ttnn.transformer.paged_scaled_dot_product_attention_decode( - q_heads_1BQD, - keys, - values, - cur_pos_tensor=current_pos, - page_table_tensor=page_table, - scale=self.scale, - program_config=self.model_config["SDPA_DECODE_PROGCFG"], - compute_kernel_config=self.sdpa_decode_compute_kernel_cfg, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - sliding_window=self.configuration.sliding_window if self.is_sliding else 0, - ) - else: - attn_output_1G4D = ttnn.transformer.scaled_dot_product_attention_decode( - q_heads_1BQD, - keys, - values, - cur_pos_tensor=current_pos, - scale=self.scale, - program_config=self.model_config["SDPA_DECODE_PROGCFG"], - compute_kernel_config=self.sdpa_decode_compute_kernel_cfg, - memory_config=ttnn.DRAM_MEMORY_CONFIG, # FIXME: why not L1 height sharded e.g. SCORES_BATCHED_MM_OUTPUT_MEMCFG? - sliding_window=self.configuration.sliding_window if self.is_sliding else 0, - ) - - ttnn.deallocate(q_heads_1BQD) - - attn_output_11BH = ttnn.to_memory_config( - attn_output_1G4D, - memory_config=self.model_config["SCORES_BATCHED_MM_OUTPUT_MEMCFG"](self.batch_size_per_device_group), - ) - attn_output_cat = ttnn.experimental.nlp_concat_heads_decode( - attn_output_11BH, - num_heads=self.n_local_heads, - ) - ttnn.deallocate(attn_output_11BH) - ttnn.deallocate(attn_output_1G4D) - - if self.use_fused_all_gather_matmul: - attn_output_cat = ttnn.to_memory_config( - attn_output_cat, self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"] - ) - - # Fused AGMM only valid for ring topology - if self.ccl_topology == ttnn.Topology.Ring: - _, dense_out_sharded = ttnn.experimental.all_gather_matmul_async( - attn_output_cat, - self.wo, - persistent_output_buffer=None, - dim=3, - multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(), - all_gather_core_grid_offset=(0, 4), - barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(), - num_links=1, - memory_config_ag=self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"], - memory_config_mm=self.model_config["DECODE_RESIDUAL_MEMCFG"], - program_config=self.model_config["ATTN_ALL_GATHER_MATMUL_PROGCFG"], - compute_kernel_config=self.compute_kernel_config_hifi2, - chunks_per_sync=10, - num_workers_per_link=2, - num_buffers_per_channel=2, - ) - else: - all_gather_output = ttnn.experimental.all_gather_async( - attn_output_cat, - persistent_output_buffer=None, - dim=3, - multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(), - num_links=1, - topology=self.ccl_topology, - memory_config=self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"], - barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(), - chunks_per_sync=10, - num_workers_per_link=2, - num_buffers_per_channel=2, - ) - - dense_out_sharded = ttnn.linear( - all_gather_output, - self.wo, - memory_config=self.model_config["DECODE_RESIDUAL_MEMCFG"], - program_config=self.model_config["ATTN_ALL_GATHER_MATMUL_PROGCFG"], - 
compute_kernel_config=self.li_o_decode_compute_kernel_cfg, - ) - - ttnn.deallocate(all_gather_output) - ttnn.deallocate(attn_output_cat) - dense_out_sharded = ttnn.to_memory_config(dense_out_sharded, self.model_config["DECODE_RESIDUAL_MEMCFG"]) - return dense_out_sharded - - else: - attn_output = tt_all_gather( - attn_output_cat, - self.mesh_device, - self.tt_ccl, - dim=2, - cluster_axis=1, - num_links=2, - memory_config=self.model_config["GATHER_USERS_MEMCFG"](list(self.mesh_device.shape)[1]), - sharded=True, - # dtype=self.ccl_dtype, # Running bf16 until we have SDPA output bfp8 df; otherwise we have two sharded to interleaved/interleaved to sharded conversions - ) - if self.TG: - attn_output = ttnn.to_memory_config(attn_output, ttnn.L1_MEMORY_CONFIG) - # user_selection_matrix = [1, 1, 32, 128] - # user_selection_matrix @ activation -> [1, 1, 32, 128] * [1, 1, 128, 2048] -> [1, 1, 32, 2048] - attn_output = ttnn.matmul( - self.user_selection_matrix, - attn_output, - core_grid=ttnn.CoreGrid(y=4, x=8), - dtype=ttnn.bfloat16, - memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - ) - - # TODO: Fix this once self.TG supports dram-sharded matmuls - dense_out_sharded = ttnn.matmul( - attn_output, - self.wo, - core_grid=ttnn.CoreGrid(y=4, x=8) if self.TG else None, - program_config=self.model_config["ATTN_OUTPUT_PROGCFG"] if not self.TG else None, - memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - dtype=ttnn.bfloat8_b if self.TG else None, - compute_kernel_config=self.li_o_decode_compute_kernel_cfg, - ) - - ttnn.deallocate(attn_output_cat) - - # All reduce - dense_out_reduced = tt_all_reduce( - dense_out_sharded, - self.mesh_device, - self.tt_ccl, - cluster_axis=0, - num_reduce_scatter_links=self.num_reduce_scatter_links, - num_all_gather_links=self.num_all_gather_links, - dim=0 if (self.TG and self.hidden_size < 8192) else 3, - topology=self.ccl_topology, - memory_config=( - ( - self.model_config["SELF_OUT_REDUCE_SCATTER_MEMCFG"] - if self.hidden_size == 8192 - else self.model_config["SELF_OUT_GATHERED_MEMCFG"](list(self.mesh_device.shape)[0]) - ) - if self.TG - else self.model_config["DECODE_RESIDUAL_MEMCFG"] - ), - sharded=True, - dtype=self.ccl_dtype, - use_composite=True if self.hidden_size == 8192 else False, - ) - - if not self.TG: - dense_out_reduced = ttnn.to_memory_config( - dense_out_reduced, self.model_config["DECODE_RESIDUAL_MEMCFG"] - ) - - return dense_out_reduced - - def forward_prefill( - self, - x_11SH, - rot_mats, - user_id: int = 0, - page_table=None, - chunk_page_table=None, - chunk_start_idx=None, - kv_cache=None, - ): - seq_len = x_11SH.shape[-2] - assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128" - ### - # QKV matmuls - ### - - # reshaping long sequence to matmul fit on device - if seq_len > self.MAX_QKV_MM_SEQ_LEN: - if seq_len % self.MAX_QKV_MM_SEQ_LEN != 0: - raise ValueError(f"seq_len {seq_len} must be divisible by {self.MAX_QKV_MM_SEQ_LEN}") - x_11SH = ttnn.reshape(x_11SH, [1, seq_len // self.MAX_QKV_MM_SEQ_LEN, self.MAX_QKV_MM_SEQ_LEN, -1]) - - xqkv_fused = ttnn.linear( - x_11SH, - self.wqkv, - dtype=self.ccl_dtype if self.TG else self.activation_dtype or ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.li_qkv_prefill_compute_kernel_cfg, - program_config=self.model_config["XQKV_PREFILL_PROGCFG"](seq_len), - ) - - # FIXME: surely ttnn.linear bias should work? 
- if self.wqkv_bias_prefill is not None: - xqkv_fused = xqkv_fused + self.wqkv_bias_prefill - - xqkv_fused = tt_all_reduce( - xqkv_fused, - self.mesh_device, - self.tt_ccl, - cluster_axis=1, - num_reduce_scatter_links=self.num_reduce_scatter_links, - num_all_gather_links=self.num_all_gather_links, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - dtype=self.ccl_dtype, - ) - - if seq_len > self.MAX_QKV_MM_SEQ_LEN: - xqkv_fused = ttnn.reshape(xqkv_fused, [1, 1, seq_len, -1]) - - ttnn.deallocate(x_11SH) - - # split qkv into heads - ( - q_heads_1QSD_pre_rot, - k_heads_1KSD_pre_rot, - v_heads_1VSD, - ) = ttnn.experimental.nlp_create_qkv_heads( - xqkv_fused, - num_heads=self.n_local_heads, - num_kv_heads=self.n_local_kv_heads, - transpose_k_heads=False, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) - - q_heads_1QSD_pre_rot = self.q_norm(q_heads_1QSD_pre_rot, mode="prefill") - k_heads_1KSD_pre_rot = self.k_norm(k_heads_1KSD_pre_rot, mode="prefill") - - ttnn.deallocate(xqkv_fused) - - ### - # Rotary embeddings - ### - - if q_heads_1QSD_pre_rot.dtype != ttnn.bfloat16: # Rotary embeddings require bfloat16 inputs - q_heads_1QSD_pre_rot = ttnn.typecast(q_heads_1QSD_pre_rot, dtype=ttnn.bfloat16) - - q_heads_1QSD = ttnn.experimental.rotary_embedding_llama( - q_heads_1QSD_pre_rot, - rot_mats[0], - rot_mats[1], - self.transformation_mats["prefill"], - is_decode_mode=False, - ) - ttnn.deallocate(q_heads_1QSD_pre_rot) - - if k_heads_1KSD_pre_rot.dtype != ttnn.bfloat16: # Rotary embeddings require bfloat16 inputs - k_heads_1KSD_pre_rot = ttnn.typecast(k_heads_1KSD_pre_rot, dtype=ttnn.bfloat16) - - k_heads_1KSD = ttnn.experimental.rotary_embedding_llama( - k_heads_1KSD_pre_rot, - rot_mats[0], - rot_mats[1], - self.transformation_mats["prefill"], - is_decode_mode=False, - ) - ttnn.deallocate(k_heads_1KSD_pre_rot) - - # Fill KV-Cache - if kv_cache: - keys_BKSD, values_BKSD = kv_cache[0], kv_cache[1] - else: - keys_BKSD, values_BKSD = self.layer_past[0], self.layer_past[1] - k_heads_1KSD_8b = ttnn.typecast(k_heads_1KSD, dtype=keys_BKSD.dtype) - ttnn.deallocate(k_heads_1KSD) - - # sharding k_fill to deal with update_cache memory limitation - if seq_len >= self.min_kv_prefill_shard_seqlen and not self.TG and not page_table: - k_fill = ttnn.interleaved_to_sharded(k_heads_1KSD_8b, self.model_config["KV_PREFILL_MEM_CFG"](seq_len)) - else: - k_fill = k_heads_1KSD_8b - - v_heads_1VSD_8b = ttnn.typecast(v_heads_1VSD, dtype=values_BKSD.dtype) - - ttnn.deallocate(v_heads_1VSD) - - # sharding v_fill to deal with update_cache memory limitation - if seq_len >= self.min_kv_prefill_shard_seqlen and not self.TG and not page_table: - v_fill = ttnn.interleaved_to_sharded(v_heads_1VSD_8b, self.model_config["KV_PREFILL_MEM_CFG"](seq_len)) - else: - v_fill = v_heads_1VSD_8b - - if self.TG: - k_fill = self.prefill_prepare_tensor_for_kv_cache(k_fill, user_id) - v_fill = self.prefill_prepare_tensor_for_kv_cache(v_fill, user_id) - if page_table: - # In the case that the tokens have been padded along the seq len dimension, we need to fill the cache with the unpadded k/v values. - # Assume that the page table does not have padding, so we can use it to get the unpadded page len. - block_size = keys_BKSD.shape[2] - # If chunked prefill, use chunk_page_table if given, otherwise use page_table. 
- fill_page_table = chunk_page_table if chunk_page_table is not None else page_table - - page_len = fill_page_table.shape[1] * block_size - k_fill_sliced = k_fill[:, :, :page_len, :] if page_len < k_fill.shape[2] else k_fill - v_fill_sliced = v_fill[:, :, :page_len, :] if page_len < v_fill.shape[2] else v_fill - ttnn.experimental.paged_fill_cache(keys_BKSD, k_fill_sliced, fill_page_table, batch_idx=user_id) - ttnn.experimental.paged_fill_cache(values_BKSD, v_fill_sliced, fill_page_table, batch_idx=user_id) - else: - ttnn.fill_cache( - keys_BKSD, - k_fill, - user_id % self.batch_size_per_device_group, - ) - ttnn.fill_cache( - values_BKSD, - v_fill, - user_id % self.batch_size_per_device_group, - ) - - if seq_len >= self.min_kv_prefill_shard_seqlen and not self.TG and not page_table: - ttnn.deallocate(k_fill) - ttnn.deallocate(v_fill) - - # SDPA - q_heads_1QSD_8b = ttnn.typecast(q_heads_1QSD, dtype=self.activation_dtype or ttnn.bfloat8_b) - ttnn.deallocate(q_heads_1QSD) - - if chunk_start_idx is not None: - attn_output_84SD = ttnn.transformer.chunked_scaled_dot_product_attention( - q_heads_1QSD_8b, - keys_BKSD, - values_BKSD, - page_table, - chunk_start_idx, - is_causal=True, - compute_kernel_config=self.sdpa_prefill_compute_kernel_cfg, - program_config=self.model_config["SDPA_PROGCFG"](seq_len), - ) - else: - attn_output_84SD = ttnn.transformer.scaled_dot_product_attention( - q_heads_1QSD_8b, - k_heads_1KSD_8b, - v_heads_1VSD_8b, - scale=self.scale, - compute_kernel_config=self.sdpa_prefill_compute_kernel_cfg, - program_config=self.model_config["SDPA_PROGCFG"](seq_len), - ) - - # deallocate keys and values - ttnn.deallocate(q_heads_1QSD_8b) - ttnn.deallocate(k_heads_1KSD_8b) - ttnn.deallocate(v_heads_1VSD_8b) - - attn_output_1QSD = ttnn.reshape(attn_output_84SD, [1, self.n_local_heads, -1, self.head_dim]) - - ### - # Output matmul - ### - attn_output_11SH = ttnn.experimental.nlp_concat_heads( - attn_output_1QSD, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) - ttnn.deallocate(attn_output_1QSD) - # reshaping long sequence to matmul fit on device - if seq_len > 1024: - attn_output_11SH = ttnn.reshape(attn_output_11SH, [1, seq_len // 1024, 1024, -1]) - - # Non fused All Gather Matmul - if self.use_fused_all_gather_matmul: # is true for Ring topology - attn_output_11SH = ttnn.experimental.all_gather_async( - attn_output_11SH, - persistent_output_buffer=None, - dim=3, - multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(), - num_links=1, - topology=self.ccl_topology, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(), - chunks_per_sync=10, - num_workers_per_link=2, - num_buffers_per_channel=2, - ) - - output_11SH = ttnn.linear( - attn_output_11SH, - self.wo, - compute_kernel_config=self.li_o_prefill_compute_kernel_cfg, - dtype=self.activation_dtype or ttnn.bfloat8_b, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - program_config=self.model_config["WO_PREFILL_PROGCFG"](seq_len), - ) - - if seq_len > 1024: - output_11SH = ttnn.reshape(output_11SH, [1, 1, seq_len, -1]) - ttnn.deallocate(attn_output_11SH) - - # Reduce-scatter - if not self.use_fused_all_gather_matmul: - output_11SH = tt_all_reduce( - output_11SH, - self.mesh_device, - self.tt_ccl, - cluster_axis=0, - dim=0 if self.TG else 3, - num_reduce_scatter_links=self.num_reduce_scatter_links, - num_all_gather_links=self.num_all_gather_links, - topology=self.ccl_topology, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - dtype=self.ccl_dtype, - ) -
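For reference, the page-table bookkeeping used when filling the paged KV cache above can be illustrated with a small torch-only sketch; the block size and tensor shapes below are toy values chosen for illustration and are not part of the ttnn API.

import torch

# Toy illustration of how the unpadded fill length is derived from the page table
# (mirrors page_len = fill_page_table.shape[1] * block_size in the code above).
block_size = 32                                   # stands in for keys_BKSD.shape[2]
fill_page_table = torch.tensor([[3, 0, 2, 1]])    # one user mapped to 4 physical blocks
page_len = fill_page_table.shape[1] * block_size  # 4 blocks * 32 = 128 valid positions

k_fill = torch.randn(1, 4, 256, 64)               # padded K: [1, n_kv_heads, padded_seq, head_dim]
k_fill_sliced = k_fill[:, :, :page_len, :] if page_len < k_fill.shape[2] else k_fill
assert k_fill_sliced.shape[2] == 128              # only the unpadded positions are written to the cache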
- return output_11SH - - def forward( - self, - x, - current_pos, - rot_mats=None, - user_id=0, - mode="decode", - page_table=None, - chunk_page_table=None, - chunk_start_idx=None, - kv_cache=None, - ): - if mode == "prefill": - return self.forward_prefill( - x, - rot_mats, - user_id, - page_table=page_table, - chunk_page_table=chunk_page_table, - chunk_start_idx=chunk_start_idx, - kv_cache=kv_cache, - ) - else: - return self.forward_decode(x, current_pos, rot_mats, page_table=page_table, kv_cache=kv_cache) - - def prefill_prepare_tensor_for_kv_cache(self, key_or_value_layer, user_id): - tensor_copy = ttnn.clone(key_or_value_layer) - # key_or_value_layer.deallocate(True) - # Get all tensors from multi-device tensor - tensors = ttnn.get_device_tensors(tensor_copy) - # Get only tensors from specific column chips - # Get every 4th tensor starting from user_id // 8 - single_column_tensors = tensors[user_id // self.batch_size_per_device_group :: 4] - # Create multi-device tensor - multi_device_tensor = ttnn.combine_device_tensors(single_column_tensors) - - return multi_device_tensor diff --git a/models/experimental/gemma3/tt/decoder.py b/models/experimental/gemma3/tt/decoder.py deleted file mode 100644 index 259007843641..000000000000 --- a/models/experimental/gemma3/tt/decoder.py +++ /dev/null @@ -1,277 +0,0 @@ -""" -source: models/tt_transformers/tt/decoder.py - -This is the Decoder block for the Gemma3 model -We couldn't use the existing implementation in TT-Transformers because the usage of submodules is different - -In Gemma3, The decoder Block has Additional pre_feedforward_layernorm and post_feedforward_layernorm, -And the logic of implementation is different from the existing implementation in TT-Transformers. -""" - -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 -import ttnn - -from models.common.lightweightmodule import LightweightModule -from models.tt_transformers.tt.distributed_norm import DistributedNorm -from models.experimental.gemma3.tt.rmsnorm import RMSNorm - -from models.experimental.gemma3.tt.attention import Attention as DefaultAttention - -from models.experimental.gemma3.tt.mlp import MLP -from models.tt_transformers.tt.model_config import TensorGroup -from models.tt_transformers.tt.ccl import tt_all_reduce - - -class TransformerBlock(LightweightModule): - def __init__( - self, - args, - mesh_device, - tt_ccl, - dtype, - state_dict, - layer_num, - weight_cache_path, - transformation_mats, - paged_attention_config=None, - use_paged_kv_cache=False, - attention_class=None, - ): - super().__init__() - - self.state_dict = state_dict - self.mesh_device = mesh_device - - self.tt_ccl = tt_ccl - self.args = args - self.hidden_size = args.dim - self.n_heads = args.n_heads - self.head_dim = self.hidden_size // self.n_heads - self.max_seq_len = args.max_seq_len - self.dim = args.dim - self.max_batch_size = args.max_batch_size - self.n_kv_heads = args.n_kv_heads - self.current = 0 - self.model_config = args.get_model_config() - - self.layer_num = layer_num - self.num_devices = args.num_devices - - ActualAttentionClass = attention_class if attention_class is not None else DefaultAttention - - self.attention = ActualAttentionClass( - mesh_device=mesh_device, - tt_ccl=tt_ccl, - state_dict=state_dict, - weight_cache_path=weight_cache_path, - layer_num=layer_num, - dtype=dtype, - transformation_mats=transformation_mats, - configuration=args, - paged_attention_config=paged_attention_config, - use_paged_kv_cache=use_paged_kv_cache, - ) - 
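For orientation, the extra pre/post feed-forward layernorms described in the module docstring give the Gemma3 block a "sandwich" dataflow; the plain-Python sketch below mirrors the ordering implemented in forward() further down, with the norm, attention, and MLP modules as placeholder callables rather than ttnn modules.

def gemma3_block_reference(x, attn, mlp, attn_norm, post_attn_norm, pre_ff_norm, post_ff_norm):
    # Attention sub-block: input norm -> attention -> post-attention norm -> residual add
    residual = x
    h = post_attn_norm(attn(attn_norm(x)))
    x = residual + h
    # Feed-forward sub-block: pre-FF norm -> MLP -> post-FF norm -> residual add
    residual = x
    h = post_ff_norm(mlp(pre_ff_norm(x)))
    return residual + h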
self.feed_forward = MLP( - mesh_device=mesh_device, - tt_ccl=tt_ccl, - args=args, - state_dict=state_dict, - weight_cache_path=weight_cache_path, - layer_num=layer_num, - dtype=dtype, - model_config=self.model_config, - ) - - self.attention_norm = DistributedNorm( - RMSNorm( - device=mesh_device, - dim=args.dim, - eps=args.norm_eps, - state_dict=state_dict, - state_dict_prefix=args.get_state_dict_prefix("", layer_num), - weight_cache_path=None if args.dummy_weights else weight_cache_path, - weight_dtype=ttnn.bfloat16, - weight_key="attention_norm", - is_distributed=self.args.is_distributed_norm, - add_unit_offset=self.args.rms_norm_add_unit_offset, - sharded_program_config=self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"], - sharded_output_config=self.model_config["SHARDED_ATTN_INPUT_MEMCFG"], - ccl_topology=self.args.ccl_topology(), - tt_ccl=self.tt_ccl, - ), - args, - tt_ccl=self.tt_ccl, - TG=args.is_galaxy, - ) - - self.ff_norm = DistributedNorm( # post_attention_layernorm - RMSNorm( - device=mesh_device, - dim=args.dim, - eps=args.norm_eps, - state_dict=state_dict, - state_dict_prefix=args.get_state_dict_prefix("", layer_num), - weight_cache_path=None if args.dummy_weights else weight_cache_path, - weight_dtype=ttnn.bfloat16, - weight_key="ffn_norm", - is_distributed=self.args.is_distributed_norm, - add_unit_offset=self.args.rms_norm_add_unit_offset, - sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], - sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], - ccl_topology=self.args.ccl_topology(), - tt_ccl=self.tt_ccl, - ), - args, - tt_ccl=self.tt_ccl, - TG=args.is_galaxy, - ) - - self.pre_ff_norm = DistributedNorm( # pre_feedforward_layernorm - RMSNorm( - device=mesh_device, - dim=args.dim, - eps=args.norm_eps, - state_dict=state_dict, - state_dict_prefix=args.get_state_dict_prefix("", layer_num), - weight_cache_path=None if args.dummy_weights else weight_cache_path, - weight_dtype=ttnn.bfloat16, - weight_key="pre_feedforward_layernorm", - is_distributed=self.args.is_distributed_norm, - add_unit_offset=self.args.rms_norm_add_unit_offset, - sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], - sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], - ccl_topology=self.args.ccl_topology(), - tt_ccl=self.tt_ccl, - ), - args, - tt_ccl=self.tt_ccl, - TG=args.is_galaxy, - ) - - self.post_ff_norm = DistributedNorm( # post_feedforward_layernorm - RMSNorm( - device=mesh_device, - dim=args.dim, - eps=args.norm_eps, - state_dict=state_dict, - state_dict_prefix=args.get_state_dict_prefix("", layer_num), - weight_cache_path=None if args.dummy_weights else weight_cache_path, - weight_dtype=ttnn.bfloat16, - weight_key="post_feedforward_layernorm", - is_distributed=self.args.is_distributed_norm, - add_unit_offset=self.args.rms_norm_add_unit_offset, - sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], - sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], - ccl_topology=self.args.ccl_topology(), - tt_ccl=self.tt_ccl, - ), - args, - tt_ccl=self.tt_ccl, - TG=args.is_galaxy, - ) - - def forward( - self, - hidden_states: ttnn.Tensor, - current_pos, - rot_mats_global=None, - rot_mats_local=None, - user_id=0, - mode="decode", - page_table=None, - chunk_page_table=None, - chunk_start_idx=None, - kv_cache=None, - ) -> ttnn.Tensor: - TG = self.args.is_galaxy - skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG - - assert ( - hidden_states.memory_config() 
== skip_mem_cfg - ), f"decoder input memcfg mismatch: {hidden_states.memory_config()} != {skip_mem_cfg}" - residual = hidden_states - - attn_in = self.attention_norm(hidden_states, mode) - - rot_mats = ( - rot_mats_local if (hasattr(self.attention, "is_sliding") and self.attention.is_sliding) else rot_mats_global - ) - - attn_out = self.attention.forward( - attn_in, - current_pos, - rot_mats, - user_id, - mode, - page_table=page_table, - chunk_page_table=chunk_page_table, - chunk_start_idx=chunk_start_idx, - kv_cache=kv_cache, - ) - - hidden_states = self.ff_norm(attn_out, mode) - - ttnn.deallocate(attn_out) - ttnn.deallocate(attn_in) - - if self.num_devices > 1: - hidden_states = tt_all_reduce( - hidden_states, - self.mesh_device, - self.tt_ccl, - cluster_axis=0, - dim=3, - num_reduce_scatter_links=self.args.num_reduce_scatter_links, - num_all_gather_links=self.args.num_all_gather_links, - topology=ttnn.Topology.Ring, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - dtype=self.args.ccl_dtype, - ) - - hidden_states = ttnn.div(hidden_states, self.num_devices) - - hidden_states = ttnn.add(hidden_states, residual, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16) - - residual = hidden_states - - hidden_states = self.pre_ff_norm(hidden_states, mode) - - if TG and mode == "decode": - hidden_states = ttnn.to_memory_config(hidden_states, memory_config=self.model_config["MLP_ACT_MEMCFG"]) - - hidden_states = self.feed_forward.forward(hidden_states, mode) - - activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=self.layer_num, tensor=TensorGroup.ACTIVATION - ) - - hidden_states = self.post_ff_norm(hidden_states, mode) - - if self.num_devices > 1: - hidden_states = tt_all_reduce( - hidden_states, - self.mesh_device, - self.tt_ccl, - cluster_axis=0, - dim=3, - num_reduce_scatter_links=self.args.num_reduce_scatter_links, - num_all_gather_links=self.args.num_all_gather_links, - topology=ttnn.Topology.Ring, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - dtype=self.args.ccl_dtype, - ) - - hidden_states = ttnn.div(hidden_states, self.num_devices) - - hidden_states = ttnn.add( - hidden_states, - residual, - memory_config=skip_mem_cfg, - dtype=self.args.ccl_dtype - if TG and not self.args.is_distributed_norm(mode) - else activation_dtype or ttnn.bfloat16, - ) - - return hidden_states diff --git a/models/experimental/gemma3/tt/gemma3_generator.py b/models/experimental/gemma3/tt/gemma3_generator.py deleted file mode 100644 index a8ea740d4685..000000000000 --- a/models/experimental/gemma3/tt/gemma3_generator.py +++ /dev/null @@ -1,1302 +0,0 @@ -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 - -from dataclasses import dataclass - -import torch -from llama_models.llama3.api.datatypes import InterleavedTextMedia, StopReason -from llama_models.llama3.reference_impl.generation import ( - ChatPrediction, - CompletionPrediction, - TokenResult, - sample_top_p, -) -from loguru import logger - -import ttnn -from models.tt_transformers.tt.common import ( - copy_host_to_device, - get_block_size, - get_max_prefill_chunk_size, - get_padded_prefill_len, - num_blocks_in_seq, -) - - -@dataclass(frozen=True) -class SamplingParams: - """ - Used in Generator decode forward functions for greedy decoding / sampling on device. - The same data class exists in vLLM at vllm/worker/tt_model_runner.py. 
- """ - - temperature: float - top_k: int - top_p: float - - -class Gemma3Generator: - def __init__(self, model, model_args, mesh_device, tokenizer=None, formatter=None): - """ - Creating a LlamaVision wrapper requires only a mesh_device and model_args. - With model_args you have the checkpoint location, can specify max batch size - and max seqlen, and other model specific parameters. - - LlamaVision is general to text and chat. - - For bringup, make this class general to any backend implementation, as long as it takes torch tensors and returns torch tensors. - - """ - self.model = model - self.model_args = model_args - self.mesh_device = mesh_device - self.tokenizer = tokenizer - self.formatter = formatter - self.data_parallel = len(self.model) - self.prev_page_table = None - - # Note: This function is called by vLLM - def prefill_forward_text( - self, tokens: torch.Tensor, page_table=None, kv_cache=None, prompt_lens=None, empty_slots=None, **kwargs - ): - if page_table is not None: - assert isinstance(page_table, torch.Tensor), "page_table mush be torch.Tensor" - - batch_size, batch_seq_len = tokens.shape - max_batch_size_per_model = self.model_args[0].max_batch_size - - # Each model expected to run the same model, safe to use 1st vocab size - output_logits = torch.zeros(batch_size, 1, self.model_args[0].vocab_size) - prompt_lens = prompt_lens if prompt_lens is not None else torch.tensor([batch_seq_len] * batch_size) - - if empty_slots is None: - empty_slots = list(range(batch_size)) - - out_list = [] - for idx, user_id in enumerate(empty_slots): - model_id = user_id // max_batch_size_per_model - group_user_id = user_id % max_batch_size_per_model if page_table is None else 0 - seq_len = int(prompt_lens[idx]) - last_token_idx = seq_len - 1 - prefill_seq_len = get_padded_prefill_len(seq_len) - - logger.info(f"Prefilling User {user_id + 1} up to {seq_len} tokens") - - # Extracting data for the current user - # If page_table is not provided, we keep track of the relative/model user_id through group_user_id - prefill_ids = torch.cat( - [tokens[idx : idx + 1, :seq_len], torch.zeros(1, prefill_seq_len - seq_len).long()], dim=-1 - ) - page_table_user = ( - self._get_prefill_user_page_table(page_table[idx : idx + 1], kv_cache[model_id], seq_len) - if page_table is not None - else None - ) - model_kv_cache = kv_cache[model_id] if kv_cache is not None else None - - logits = self.prefill_forward_single_user_text( - prefill_ids, - page_table=page_table_user, - user_id=group_user_id, - last_token_idx=last_token_idx, - kv_cache=model_kv_cache, - model_id=model_id, - **kwargs, - ) - out_list.append(logits) - - for idx, out in enumerate(out_list): - seq_len = int(prompt_lens[idx]) - last_token_idx = seq_len - 1 - user_id = empty_slots[idx] - model_id = user_id // max_batch_size_per_model - - # Since we give unpadded_seq_len, only the tile containing the last token is returned - output_logits[idx] = self.model[model_id].process_output_prefill(out, last_token_idx=(last_token_idx % 32)) - - logger.info(f"Finished prefill for all users up to {batch_seq_len} tokens, Starting decode...") - return output_logits - - def prefill_forward_single_user_text( - self, tokens, page_table, user_id, last_token_idx, kv_cache=None, model_id=-1, **kwargs - ): - seq_len = tokens.shape[-1] - use_chunked_prefill = seq_len > self.model_args[model_id].max_prefill_chunk_size - if use_chunked_prefill: - """ - Chunked prefill requires paged attention. 
There are some strange constraints which we must meet: - - page_table, which is used in SDPA, must match batch size of inputs, which is 1. This is because SDPA - checks that page table batch dim matches input batch dim. Therefore we must slice the page table for the current user. - - page_table must also have enough entries in each chunk, so it will be padded with zeros if necessary. - - chunked_page_table is the slice of the page table for the current chunk. This is used by paged_fill_cache - to keep it otherwise unaware that it is operating on a chunk. - - due to the above point, we must always set user_id to 0 for chunked prefill. - """ - assert page_table is not None, "page_table must be provided for chunked prefill" - assert kv_cache is not None, "kv_cache must be provided for chunked prefill" - assert ( - last_token_idx is not None and last_token_idx < seq_len - ), "last_token_idx must be provided and less than seq_len" - chunk_size = get_max_prefill_chunk_size(seq_len, self.model_args[model_id].max_prefill_chunk_size) - block_size = get_block_size(kv_cache) - last_token_idx_in_chunk = last_token_idx % chunk_size - # Calculate which chunk contains the last_token_idx - last_chunk_start = (last_token_idx // chunk_size) * chunk_size - page_table_user = page_table[user_id : user_id + 1, :] - # Pad page table to match number of blocks in seq_len - num_padding_blocks = num_blocks_in_seq(seq_len, block_size) - page_table_user.shape[1] - page_table_user_padded = torch.cat( - [page_table_user, torch.zeros(1, num_padding_blocks, dtype=torch.int32)], dim=-1 - ) - CHUNK_USER_ID = 0 - - for chunk_start in range(0, seq_len, chunk_size): - chunk_end = chunk_start + chunk_size - assert ( - chunk_end <= seq_len - ), f"Chunk end should be less than seq_len, got chunk_end={chunk_end} and seq_len={seq_len}" - chunk_tokens = tokens[:, chunk_start:chunk_end] - chunk_page_table = page_table_user[:, chunk_start // block_size : chunk_end // block_size] - - ( - chunk_prefill_input, - chunk_rot_mats_global_prefill, - chunk_rot_mats_local_prefill, - page_table_tt, - chunk_page_table_tt, - ) = self.model[model_id].prepare_inputs_prefill( - chunk_tokens, - start_pos=chunk_start, - page_table=page_table_user_padded, - chunk_page_table=chunk_page_table, - **kwargs, - ) - tt_logits = self.model[model_id].ttnn_prefill_forward( - chunk_prefill_input, - rot_mats_global=chunk_rot_mats_global_prefill, - rot_mats_local=chunk_rot_mats_local_prefill, - user_id=CHUNK_USER_ID, - page_table=page_table_tt, - chunk_page_table=chunk_page_table_tt, - chunk_start_idx=chunk_start, - get_last_token=(last_token_idx_in_chunk // 32) * 32, - kv_cache=kv_cache, - **kwargs, - ) - - if chunk_start == last_chunk_start: - return tt_logits - else: - del tt_logits - else: - ( - prefill_input, - rot_mats_global_prefill, - rot_mats_local_prefill, - page_table_tt, - _, - ) = self.model[model_id].prepare_inputs_prefill( - tokens, - page_table=page_table, - **kwargs, - ) - - tt_logits = self.model[model_id].ttnn_prefill_forward( - prefill_input, - rot_mats_global=rot_mats_global_prefill, - rot_mats_local=rot_mats_local_prefill, - user_id=user_id, - page_table=page_table_tt, - get_last_token=(last_token_idx // 32) * 32, - kv_cache=kv_cache, - ) - return tt_logits - - # Note: This function is called by vLLM - def decode_forward_text( - self, - tokens, - start_pos, - page_table=None, - kv_cache=None, - enable_trace=True, - read_from_device=True, - sampling_params: SamplingParams = None, # Should be None if not greedy decoding / sampling on device. 
- ): - assert ( - sampling_params is None or sampling_params.temperature == 0 - ), "Currently only supporting greedy decoding (temperature=0) on device" - argmax_on_device = sampling_params is not None and sampling_params.temperature == 0 - - B = tokens.shape[0] - tokens = torch.chunk(tokens, self.data_parallel, 0) - start_pos = torch.chunk(start_pos, self.data_parallel, 0) - page_table = torch.chunk(page_table, self.data_parallel, 0) if page_table is not None else None - - decode_kwargs = { - "current_pos": start_pos, - "tokens": tokens, - "page_table": page_table, - "kv_cache": kv_cache, - "argmax_on_device": argmax_on_device, - } - if enable_trace: - tt_decode_output = self._easy_trace_text(**decode_kwargs) - else: - tt_decode_output = self._decode_forward_no_trace_text(**decode_kwargs) - - if read_from_device: - to_host = self.read_decode_output(tt_decode_output) - return self.process_decode_output_host(to_host, is_tokens=(sampling_params is not None)) - - return tt_decode_output - - def _decode_forward_no_trace_text( - self, - tokens, - current_pos, - page_table=None, - kv_cache=None, - argmax_on_device=False, - ): - """ - Performs text decode step. - Returns tt_logits on device - """ - tt_logits = [] - - tt_tokens = [] - tt_current_pos = [] - tt_rot_mat_idxs_global = [] - tt_rot_mat_idxs_local = [] - tt_page_table = [] - for i in range(self.data_parallel): - user_page_table = page_table[i] if page_table is not None else None - model_i = self.model[i] - ( - tt_tokens_i, - tt_current_pos_i, - tt_rot_mat_idxs_global_i, - tt_rot_mat_idxs_local_i, - tt_page_table_i, - ) = model_i.prepare_inputs_decode(tokens[i], current_pos[i], user_page_table) - tt_tokens.append(tt_tokens_i) - tt_current_pos.append(tt_current_pos_i) - tt_rot_mat_idxs_global.append(tt_rot_mat_idxs_global_i) - tt_rot_mat_idxs_local.append(tt_rot_mat_idxs_local_i) - tt_page_table.append(tt_page_table_i) - - if ( - hasattr(self.model[i], "device_decode_sliding_mask") - and self.model[i].device_decode_sliding_mask is not None - ): - self.model[i].update_attention_masks(current_pos[i]) - - for i in range(self.data_parallel): - user_kv_cache = kv_cache[i] if kv_cache is not None else None - tt_logits_i = self.model[i].ttnn_decode_forward( - tt_tokens[i], - tt_current_pos[i], - rot_mat_idxs_global=tt_rot_mat_idxs_global[i], - rot_mat_idxs_local=tt_rot_mat_idxs_local[i], - page_table=tt_page_table[i], - kv_cache=user_kv_cache, - argmax_on_device=argmax_on_device, - ) - tt_logits.append(tt_logits_i) - - return tt_logits - - def _capture_trace_text( - self, - tokens, - current_pos, - page_table=None, - kv_cache=None, - argmax_on_device=False, - ): - """ - Captures a trace for the decode_forward method. 
- """ - - # Compile run - self._decode_forward_no_trace_text( - tokens, current_pos, page_table=page_table, kv_cache=kv_cache, argmax_on_device=argmax_on_device - ) - logger.info("Done Compiling Model") - - # Get inputs ready for trace run - device_inputs = [] - tt_out_trace = [] - trace_ids = {} - for i in range(self.data_parallel): - user_page_table = page_table[i] if page_table is not None else None - host_inputs = self.model[i].prepare_decode_inputs_host( - tokens[i], current_pos[i], page_table=user_page_table - ) - - device_inputs_i = copy_host_to_device(host_inputs, mesh_device=self.model_args[i].mesh_device) - device_inputs.append(device_inputs_i) - - for i in range(self.data_parallel): - trace_id = ttnn.begin_trace_capture(self.model_args[i].mesh_device, cq_id=0) - trace_ids[i] = trace_id - user_kv_cache = kv_cache[i] if kv_cache is not None else None - tt_out_trace.append( - self.model[i].ttnn_decode_forward( - *device_inputs[i], kv_cache=user_kv_cache, argmax_on_device=argmax_on_device - ) - ) - ttnn.end_trace_capture(self.model_args[i].mesh_device, trace_id, cq_id=0) - logger.info("Done Capturing Decode Trace") - return trace_ids, tt_out_trace, *device_inputs - - def _easy_trace_text( - self, - tokens, - current_pos, - page_table=None, - kv_cache=None, - argmax_on_device=False, - ): - """ - Tracing is easy! Just call this method and we'll handle tracing for you. - """ - if not hasattr(self, "trace_ids_text"): - trace_ids, tt_out_trace, *device_inputs = self._capture_trace_text( - tokens, current_pos, page_table=page_table, kv_cache=kv_cache, argmax_on_device=argmax_on_device - ) - self.trace_ids_text = trace_ids - self.trace_inputs_text = device_inputs - self.trace_output_text = tt_out_trace - - reset_inputs = not argmax_on_device - if self.prev_page_table is None or any( - not torch.equal(prev, curr) for prev, curr in zip(self.prev_page_table, page_table) - ): - reset_inputs = True - self.prev_page_table = page_table - - if reset_inputs: - for i in range(self.data_parallel): - user_page_table = page_table[i] if page_table is not None else None - host_inputs_i = self.model[i].prepare_decode_inputs_host(tokens[i], current_pos[i], user_page_table) - - copy_host_to_device( - host_tensors=host_inputs_i, - device_tensors=self.trace_inputs_text[i], - ) - for i in range(self.data_parallel): - if ( - hasattr(self.model[i], "device_decode_sliding_mask") - and self.model[i].device_decode_sliding_mask is not None - ): - self.model[i].update_attention_masks(current_pos[i]) - - for i, trace_id in self.trace_ids_text.items(): - ttnn.execute_trace(self.model_args[i].mesh_device, trace_id, cq_id=0, blocking=False) - - return self.trace_output_text - - def _prefill_forward_single_user( - self, - vision_images, - vision_mask, - tokens, - xattn_caches, - user_id, - total_len, - prefill_len, - page_table=None, - kv_cache=None, - cross_page_table=None, - model_id=-1, - ): - """ - Performs vision encode step then text prefill. 
- Returns (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits) - """ - B = tokens.shape[0] - last_token_idx = prefill_len - 1 - - text_only_inference = vision_images is None - if not text_only_inference: - ( - vision_tokens, - prefill_cross_attention_masks, - prefill_full_text_row_masked_out_mask, - decode_cross_attention_masks, - decode_full_text_row_masked_out_mask, - ) = self.model[model_id].compute_vision_tokens_masks( - batch_images=[vision_images], - batch_masks=[vision_mask], - total_len=total_len, - prefill_len=prefill_len, - ) - - if cross_page_table is not None: - num_vision_tokens = vision_tokens.shape[2] - cross_page_table = self._get_prefill_user_page_table(cross_page_table, kv_cache, num_vision_tokens) - else: - ( - vision_tokens, - prefill_cross_attention_masks, - prefill_full_text_row_masked_out_mask, - decode_cross_attention_masks, - decode_full_text_row_masked_out_mask, - ) = (None, None, None, None, None) - - if page_table is not None: - page_table = self._get_prefill_user_page_table(page_table, kv_cache, prefill_len) - - ( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - tt_full_text_mask_expand_11SD, - rot_mats, - tt_page_table, - tt_cross_page_table, - ) = self.model[model_id].prepare_inputs_prefill( - tokens, - prefill_cross_attention_masks, - prefill_full_text_row_masked_out_mask, - prefill_len=prefill_len, - page_table=page_table, - cross_page_table=cross_page_table, - text_only_inference=text_only_inference, - ) - - tt_logits = self.model[model_id].ttnn_prefill_forward( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - tt_full_text_mask_expand_11SD, - xattn_caches, - rot_mats, - user_id, - vision_tokens, - page_table=tt_page_table, - kv_cache=kv_cache, - get_last_token=(last_token_idx // 32) * 32, - cross_page_table=tt_cross_page_table, - text_only_inference=text_only_inference, - ) - - del tt_page_table - del tt_cross_page_table - - return ( - xattn_caches, - prefill_cross_attention_masks, - prefill_full_text_row_masked_out_mask, - decode_cross_attention_masks, - decode_full_text_row_masked_out_mask, - tt_logits, - ) - - # Note: This function is called by vLLM - def prefill_forward( - self, - vision_images, - vision_masks, - tokens: torch.Tensor, - xattn_caches, - total_lens, - prompt_lens, - page_table=None, - kv_cache=None, - cross_page_table=None, - empty_slots=None, - ): - """ - Batched version of _prefill_forward_single_user for vision model. 
- """ - if page_table is not None: - assert isinstance(page_table, torch.Tensor), "page_table mush be torch.Tensor" - if cross_page_table is not None: - assert isinstance(cross_page_table, torch.Tensor), "cross_page_table mush be torch.Tensor" - - batch_size, batch_seq_len = tokens.shape - max_batch_size_per_model = self.model_args[0].max_batch_size - - output_logits = torch.zeros(batch_size, 1, self.model_args[0].vocab_size) - - out_list = [] - prefill_output_xattn_masks = [] - prefill_output_full_text_row_masked_out_masks = [] - decode_output_xattn_masks = [] - decode_output_full_text_row_masked_out_masks = [] - - if empty_slots is None: - empty_slots = list(range(batch_size)) - - for idx, user_id in enumerate(empty_slots): - model_id = user_id // max_batch_size_per_model - group_user_id = user_id % max_batch_size_per_model if page_table is None else 0 - seq_len = int(prompt_lens[idx]) - - logger.info(f"Prefilling User {user_id + 1} up to {seq_len} tokens") - - user_page_table = page_table[idx : idx + 1] if page_table is not None else None - user_cross_page_table = cross_page_table[idx : idx + 1] if kv_cache is not None else None - model_kv_cache = kv_cache[model_id] if kv_cache is not None else None - model_xattn_cache = xattn_caches[model_id] if xattn_caches is not None else None - - ( - model_xattn_cache, - prefill_cross_attention_masks, - prefill_full_text_row_masked_out_mask, - decode_cross_attention_masks, - decode_full_text_row_masked_out_mask, - logits, - ) = self._prefill_forward_single_user( - vision_images=vision_images[idx], - vision_mask=vision_masks[idx], - tokens=tokens[idx : idx + 1, :seq_len], # Keep batch dimension - xattn_caches=model_xattn_cache, - user_id=group_user_id, - total_len=total_lens[idx], - prefill_len=seq_len, - page_table=user_page_table, - kv_cache=model_kv_cache, - cross_page_table=user_cross_page_table, - model_id=model_id, - ) - - if xattn_caches is not None: - xattn_caches[model_id] = model_xattn_cache - - out_list.append(logits) - prefill_output_xattn_masks.append(prefill_cross_attention_masks) - prefill_output_full_text_row_masked_out_masks.append(prefill_full_text_row_masked_out_mask) - decode_output_xattn_masks.append(decode_cross_attention_masks) - decode_output_full_text_row_masked_out_masks.append(decode_full_text_row_masked_out_mask) - - # We gather prefill output at the end of prefill to reduce unnecessary device sync - for idx, user_id in enumerate(empty_slots): - model_id = user_id // max_batch_size_per_model - - last_token_idx = prompt_lens[idx] - 1 - output_logits[idx] = self.model[model_id].process_output_prefill( - out_list[idx], 1, last_token_idx=(last_token_idx % 32) - ) - - logger.info(f"Finished prefill for all users up to {batch_seq_len} tokens, Starting decode...") - - return ( - output_logits, - prefill_output_xattn_masks, - prefill_output_full_text_row_masked_out_masks, - decode_output_xattn_masks, - decode_output_full_text_row_masked_out_masks, - ) - - # Note: This function is called by vLLM - def decode_forward( - self, - start_pos, - tokens, - prefill_cross_attention_masks, - prefill_full_text_row_masked_out_mask, - decode_cross_attention_masks, - decode_full_text_row_masked_out_mask, - xattn_caches=None, - page_table=None, - kv_cache=None, - cross_page_table=None, - enable_trace=True, - read_from_device=True, - ): - B = tokens.shape[0] - data_parallel = min(B, self.data_parallel) - batch_per_device = B // data_parallel - tokens = torch.chunk(tokens, self.data_parallel, 0) - start_pos = torch.chunk(start_pos, 
self.data_parallel, 0) - prefill_cross_attention_masks = [ - prefill_cross_attention_masks[i * batch_per_device : (i + 1) * batch_per_device] - for i in range(data_parallel) - ] - prefill_full_text_row_masked_out_mask = [ - prefill_full_text_row_masked_out_mask[i * batch_per_device : (i + 1) * batch_per_device] - for i in range(data_parallel) - ] - decode_cross_attention_masks = [ - decode_cross_attention_masks[i * batch_per_device : (i + 1) * batch_per_device] - for i in range(data_parallel) - ] - decode_full_text_row_masked_out_mask = [ - decode_full_text_row_masked_out_mask[i * batch_per_device : (i + 1) * batch_per_device] - for i in range(data_parallel) - ] - page_table = torch.chunk(page_table, self.data_parallel, 0) if page_table is not None else None - cross_page_table = ( - torch.chunk(cross_page_table, self.data_parallel, 0) if cross_page_table is not None else None - ) - - decode_kwargs = { - "position_id": start_pos, - "tokens": tokens, - "prefill_cross_attention_masks": prefill_cross_attention_masks, - "prefill_full_text_row_masked_out_mask": prefill_full_text_row_masked_out_mask, - "decode_cross_attention_masks": decode_cross_attention_masks, - "decode_full_text_row_masked_out_mask": decode_full_text_row_masked_out_mask, - "xattn_caches": xattn_caches, - "page_table": page_table, - "kv_cache": kv_cache, - "cross_page_table": cross_page_table, - } - if enable_trace: - tt_logits = self._easy_trace(**decode_kwargs) - else: - tt_logits = self._decode_forward_no_trace(**decode_kwargs) - - if read_from_device: - to_host = self.read_decode_output(tt_logits) - return self.process_decode_output_host(to_host) - else: - return tt_logits - - # Note: This function is called by vLLM - def read_decode_output(self, tt_out, async_read=False): - """ - Input tt_out is a list of ttnn device tensors - """ - if not async_read: - return [out.cpu() for out in tt_out] - - host_outputs = [] - read_events = [] - for i in range(self.data_parallel): - host_outputs.append(tt_out[i].cpu(blocking=False)) - read_events.append(ttnn.record_event(self.model[i].mesh_device, 0)) - - return host_outputs, read_events - - # Note: This function is called by vLLM - def process_decode_output_host(self, tt_out, is_tokens=False): - """ - Converts the input ttnn host tensors to a torch tensor. - The input can be logits (if is_tokens=False) or tokens (if is_tokens=True). - """ - max_batch_size_per_model = self.model_args[0].max_batch_size - - logits = [] - for i in range(self.data_parallel): - logits_i = self.model[i].process_output_decode( - tt_out[i], max_batch_size_per_model, S=1, is_tokens=is_tokens - ) - logits.append(logits_i) - - return torch.cat(logits, 0) - - def _decode_forward_no_trace( - self, - position_id, - tokens, - prefill_cross_attention_masks, - prefill_full_text_row_masked_out_mask, - decode_cross_attention_masks, - decode_full_text_row_masked_out_mask, - xattn_caches=None, - page_table=None, - kv_cache=None, - cross_page_table=None, - ): - """ - Performs text decode step. 
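decode_forward above distributes one global decode batch over the data-parallel model replicas: dense tensors are split with torch.chunk, while the per-user mask lists are sliced into contiguous groups of batch_per_device. A standalone sketch of that partitioning with made-up sizes (not the class itself):

import torch

batch_size, data_parallel = 8, 2                            # hypothetical sizes
batch_per_device = batch_size // data_parallel

tokens = torch.arange(batch_size).reshape(batch_size, 1)    # stand-in for (B, 1) decode tokens
masks = [f"mask_user_{u}" for u in range(batch_size)]       # stand-in for per-user mask objects

token_chunks = torch.chunk(tokens, data_parallel, 0)        # one tensor per replica
mask_chunks = [
    masks[i * batch_per_device : (i + 1) * batch_per_device]  # one python list per replica
    for i in range(data_parallel)
]

assert [t.shape[0] for t in token_chunks] == [4, 4]
assert mask_chunks[1][0] == "mask_user_4"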
- Returns tt_logits on device - """ - - # forward_decode should be traced callable - # decorator does compilation, capture, execute - tt_h = [] - tt_xattn_mask = [] - tt_full_text_mask_expand_1NSH = [] - tt_full_text_mask_expand_11SD = [] - tt_position_id = [] - tt_rot_mats = [] - tt_page_table = [] - tt_cross_page_table = [] - - for i in range(self.data_parallel): - B, S = tokens[i].shape - assert S == 1 - - user_page_table = page_table[i] if page_table is not None else None - user_cross_page_table = cross_page_table[i] if cross_page_table is not None else None - ( - tt_h_i, - tt_xattn_mask_i, - tt_full_text_mask_expand_1NSH_i, - tt_full_text_mask_expand_11SD_i, - tt_position_id_i, - tt_rot_mats_i, - tt_page_table_i, - tt_cross_page_table_i, - ) = self.model[i].prepare_inputs_decode( - tokens[i], - prefill_cross_attention_masks[i], - prefill_full_text_row_masked_out_mask[i], - decode_cross_attention_masks[i], - decode_full_text_row_masked_out_mask[i], - position_id=position_id[i], - page_table=user_page_table, - cross_page_table=user_cross_page_table, - ) - - tt_h.append(tt_h_i) - tt_xattn_mask.append(tt_xattn_mask_i) - tt_full_text_mask_expand_1NSH.append(tt_full_text_mask_expand_1NSH_i) - tt_full_text_mask_expand_11SD.append(tt_full_text_mask_expand_11SD_i) - tt_position_id.append(tt_position_id_i) - tt_rot_mats.append(tt_rot_mats_i) - tt_page_table.append(tt_page_table_i) - tt_cross_page_table.append(tt_cross_page_table_i) - - tt_logits = [] - for i in range(self.data_parallel): - user_kv_cache = kv_cache[i] if kv_cache is not None else None - xattn_cache = xattn_caches[i] if xattn_caches is not None else None - tt_logits_i = self.model[i].ttnn_decode_forward( - tt_h[i], - tt_xattn_mask[i], - tt_full_text_mask_expand_1NSH[i], - tt_full_text_mask_expand_11SD[i], - xattn_cache, - tt_position_id[i], - tt_rot_mats[i], - page_table=tt_page_table[i], - kv_cache=user_kv_cache, - cross_page_table=tt_cross_page_table[i], - ) - tt_logits.append(tt_logits_i) - - return tt_logits - - def _capture_trace( - self, - position_id, - tokens, - prefill_cross_attention_masks, - prefill_full_text_row_masked_out_mask, - decode_cross_attention_masks, - decode_full_text_row_masked_out_mask, - xattn_caches, - page_table=None, - kv_cache=None, - cross_page_table=None, - ): - """ - Captures a trace for the decode_forward method. 
- """ - tt_h = [] - tt_xattn_mask = [] - tt_full_text_mask_expand_1NSH = [] - tt_full_text_mask_expand_11SD = [] - tt_position_id = [] - tt_rot_mats = [] - tt_page_table = [] - tt_cross_page_table = [] - for i in range(self.data_parallel): - user_page_table = page_table[i] if page_table is not None else None - user_cross_page_table = cross_page_table[i] if cross_page_table is not None else None - ( - tt_h_i, - tt_xattn_mask_i, - tt_full_text_mask_expand_1NSH_i, - tt_full_text_mask_expand_11SD_i, - tt_position_id_i, - tt_rot_mats_i, - tt_page_table_i, - tt_cross_page_table_i, - ) = self.model[i].prepare_inputs_decode( - tokens[i], - prefill_cross_attention_masks[i], - prefill_full_text_row_masked_out_mask[i], - decode_cross_attention_masks[i], - decode_full_text_row_masked_out_mask[i], - position_id=position_id[i], - page_table=user_page_table, - cross_page_table=user_cross_page_table, - ) - - tt_h.append(tt_h_i) - tt_xattn_mask.append(tt_xattn_mask_i) - tt_full_text_mask_expand_1NSH.append(tt_full_text_mask_expand_1NSH_i) - tt_full_text_mask_expand_11SD.append(tt_full_text_mask_expand_11SD_i) - tt_position_id.append(tt_position_id_i) - tt_rot_mats.append(tt_rot_mats_i) - tt_page_table.append(tt_page_table_i) - tt_cross_page_table.append(tt_cross_page_table_i) - - # Compile run - for i in range(self.data_parallel): - user_kv_cache = kv_cache[i] if kv_cache is not None else None - xattn_cache = xattn_caches[i] if xattn_caches is not None else None - # tt_logits_rm unused later, no need to make a list - tt_logits_rm = self.model[i].ttnn_decode_forward( - tt_h[i], - tt_xattn_mask[i], - tt_full_text_mask_expand_1NSH[i], - tt_full_text_mask_expand_11SD[i], - xattn_cache, - tt_position_id[i], - tt_rot_mats[i], - page_table=tt_page_table[i], - kv_cache=user_kv_cache, - cross_page_table=tt_cross_page_table[i], - ) - logger.info("Done Compiling Model") - - # Get inputs ready for trace run - tt_h = [] - tt_xattn_mask = [] - tt_full_text_mask_expand_1NSH = [] - tt_full_text_mask_expand_11SD = [] - tt_position_id = [] - tt_rope_id = [] - tt_page_table = [] - tt_cross_page_table = [] - for i in range(self.data_parallel): - user_page_table = page_table[i] if page_table is not None else None - user_cross_page_table = cross_page_table[i] if cross_page_table is not None else None - ( - tt_h_i, - tt_xattn_mask_i, - tt_full_text_mask_expand_1NSH_i, - tt_full_text_mask_expand_11SD_i, - tt_position_id_i, - tt_rope_id_i, - tt_page_table_i, - tt_cross_page_table_i, - ) = self.model[i].prepare_decode_inputs_host( - tokens[i], - prefill_cross_attention_masks[i], - prefill_full_text_row_masked_out_mask[i], - decode_cross_attention_masks[i], - decode_full_text_row_masked_out_mask[i], - position_id[i], - page_table=user_page_table, - cross_page_table=user_cross_page_table, - ) - - ( - tt_h_i, - tt_xattn_mask_i, - tt_full_text_mask_expand_1NSH_i, - tt_full_text_mask_expand_11SD_i, - tt_position_id_i, - tt_rope_id_i, - tt_page_table_i, - tt_cross_page_table_i, - ) = copy_host_to_device( - ( - tt_h_i, - tt_xattn_mask_i, - tt_full_text_mask_expand_1NSH_i, - tt_full_text_mask_expand_11SD_i, - tt_position_id_i, - tt_rope_id_i, - tt_page_table_i, - tt_cross_page_table_i, - ), - mesh_device=self.model_args[i].mesh_device, - ) - - tt_h.append(tt_h_i) - tt_xattn_mask.append(tt_xattn_mask_i) - tt_full_text_mask_expand_1NSH.append(tt_full_text_mask_expand_1NSH_i) - tt_full_text_mask_expand_11SD.append(tt_full_text_mask_expand_11SD_i) - tt_position_id.append(tt_position_id_i) - tt_rope_id.append(tt_rope_id_i) - 
tt_page_table.append(tt_page_table_i) - tt_cross_page_table.append(tt_cross_page_table_i) - - tt_h_trace_input = tt_h - - tt_logits_rm = [] - trace_ids = {} - # Do on-device transformations of inputs before forward - for i in range(self.data_parallel): - trace_id = ttnn.begin_trace_capture(self.model_args[i].mesh_device, cq_id=0) - trace_ids[i] = trace_id - B = tokens[i].shape[0] - user_kv_cache = kv_cache[i] if kv_cache is not None else None - xattn_cache = xattn_caches[i] if xattn_caches is not None else None - ( - tt_h_transform, - tt_rot_mats, - tt_xattn_mask_transform, - tt_full_text_mask_expand_1NSH_transform, - tt_full_text_mask_expand_11SD_transform, - ) = self.model[i].transform_decode_inputs_device( - tt_h[i], - tt_rope_id[i], - tt_xattn_mask[i], - tt_full_text_mask_expand_1NSH[i], - tt_full_text_mask_expand_11SD[i], - B=B, - ) - - tt_logits_rm_i = self.model[i].ttnn_decode_forward( - tt_h_transform, - tt_xattn_mask_transform, - tt_full_text_mask_expand_1NSH_transform, - tt_full_text_mask_expand_11SD_transform, - xattn_cache, - tt_position_id[i], - tt_rot_mats, - page_table=tt_page_table[i], - kv_cache=user_kv_cache, - cross_page_table=tt_cross_page_table[i], - ) - tt_logits_rm.append(tt_logits_rm_i) - ttnn.end_trace_capture(self.model_args[i].mesh_device, trace_id, cq_id=0) - logger.info("Done Capturing Decode Trace") - - return ( - trace_ids, - tt_logits_rm, - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - tt_full_text_mask_expand_11SD, - tt_position_id, - tt_rope_id, - tt_page_table, - tt_cross_page_table, - ) - - def _decode_forward_trace( - self, - position_id, - tokens, - prefill_cross_attention_masks, - prefill_full_text_row_masked_out_mask, - decode_cross_attention_masks, - decode_full_text_row_masked_out_mask, - page_table, - cross_page_table, - trace_ids, - trace_logits_rm, - trace_h, - trace_xattn_mask, - trace_full_text_mask_expand_1NSH, - trace_full_text_mask_expand_11SD, - trace_position_id, - trace_rope_id, - trace_page_table, - trace_cross_page_table, - ): - """ - Executes the trace for the decode_forward method but does not read back outputs. 
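The capture/execute pair above follows the usual ttnn pattern: one eager run to compile, one run recorded between begin_trace_capture and end_trace_capture, and then each later decode step only refreshes the persistent device inputs (via copy_host_to_device) and replays the trace. A condensed sketch of that control flow; run_step and refresh_inputs are placeholders for the prepare/forward calls, and real hardware is of course required:

import ttnn

def decode_with_trace(mesh_device, run_step, refresh_inputs, num_steps):
    run_step()                                              # compile run, not captured
    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
    outputs = run_step()                                    # ops recorded into the trace
    ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)

    for _ in range(num_steps):
        refresh_inputs()                                    # overwrite the captured input tensors in place
        ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)
        yield outputs                                       # same device tensors, now holding fresh results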
- """ - for i in range(self.data_parallel): - user_page_table = page_table[i] if page_table is not None else None - user_cross_page_table = cross_page_table[i] if cross_page_table is not None else None - ( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - tt_full_text_mask_expand_11SD, - tt_position_id, - tt_rope_id, - tt_page_table, - tt_cross_page_table, - ) = self.model[i].prepare_decode_inputs_host( - tokens[i], - prefill_cross_attention_masks[i], - prefill_full_text_row_masked_out_mask[i], - decode_cross_attention_masks[i], - decode_full_text_row_masked_out_mask[i], - position_id=position_id[i], - page_table=user_page_table, - cross_page_table=user_cross_page_table, - ) - - copy_host_to_device( - host_tensors=( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - tt_full_text_mask_expand_11SD, - tt_position_id, - tt_rope_id, - tt_page_table, - tt_cross_page_table, - ), - device_tensors=( - trace_h[i], - trace_xattn_mask[i], - trace_full_text_mask_expand_1NSH[i], - trace_full_text_mask_expand_11SD[i], - trace_position_id[i], - trace_rope_id[i], - trace_page_table[i], - trace_cross_page_table[i], - ), - ) - for i, trace_id in trace_ids.items(): - ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False) - - return trace_logits_rm - - def _easy_trace( - self, - position_id, - tokens, - prefill_cross_attention_masks, - prefill_full_text_row_masked_out_mask, - decode_cross_attention_masks, - decode_full_text_row_masked_out_mask, - xattn_caches=None, - page_table=None, - kv_cache=None, - cross_page_table=None, - ): - """ - Tracing is easy! Just call this method and we'll handle tracing for you. - """ - if not hasattr(self, "trace_ids"): - ( - trace_ids, - tt_logits_rm, - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - tt_full_text_mask_expand_11SD, - tt_position_id, - tt_rope_id, - tt_page_table, - tt_cross_page_table, - ) = self._capture_trace( - position_id, - tokens, - prefill_cross_attention_masks, - prefill_full_text_row_masked_out_mask, - decode_cross_attention_masks, - decode_full_text_row_masked_out_mask, - xattn_caches, - page_table=page_table, - kv_cache=kv_cache, - cross_page_table=cross_page_table, - ) - self.trace_ids = trace_ids - self.trace_inputs = { - "tt_h": tt_h, - "tt_xattn_mask": tt_xattn_mask, - "tt_full_text_mask_expand_1NSH": tt_full_text_mask_expand_1NSH, - "tt_full_text_mask_expand_11SD": tt_full_text_mask_expand_11SD, - "tt_position_id": tt_position_id, - "tt_rope_id": tt_rope_id, - "tt_page_table": tt_page_table, - "tt_cross_page_table": tt_cross_page_table, - } - self.trace_outputs = { - "tt_logits_rm": tt_logits_rm, - } - - trace_logits_rm = self._decode_forward_trace( - position_id, - tokens, - prefill_cross_attention_masks, - prefill_full_text_row_masked_out_mask, - decode_cross_attention_masks, - decode_full_text_row_masked_out_mask, - page_table, - cross_page_table, - self.trace_ids, - self.trace_outputs["tt_logits_rm"], - self.trace_inputs["tt_h"], - self.trace_inputs["tt_xattn_mask"], - self.trace_inputs["tt_full_text_mask_expand_1NSH"], - self.trace_inputs["tt_full_text_mask_expand_11SD"], - self.trace_inputs["tt_position_id"], - self.trace_inputs["tt_rope_id"], - self.trace_inputs["tt_page_table"], - self.trace_inputs["tt_cross_page_table"], - ) - - return trace_logits_rm - - def generate( - self, - model_input, - max_gen_len: int, - temperature: float = 0.6, - top_p: float = 0.9, - ): - # Do initial prefill - vision_images = model_input.vision.images - vision_mask = model_input.vision.mask - prompt_tokens = 
model_input.tokens - prefill_len = len(prompt_tokens) - total_len = prefill_len + max_gen_len # Prepares mask for full length of output - - prompt_tokens_tensor = torch.tensor(prompt_tokens, dtype=torch.long).reshape(1, -1) # B, S - # Suboptimal to allocate caches every time - model_id = 0 - xattn_caches = self.model[model_id].setup_cache(self.model_args[model_id].max_batch_size) - ( - xattn_caches, - prefill_cross_attention_masks, - prefill_full_text_row_masked_out_mask, - decode_cross_attention_masks, - decode_full_text_row_masked_out_mask, - logits, - ) = self._prefill_forward_single_user( - vision_images, - vision_mask, - prompt_tokens_tensor, - xattn_caches, - user_id=0, - total_len=total_len, - prefill_len=prefill_len, - model_id=model_id, - ) - - last_token_idx = prefill_len - 1 - logits = self.model[model_id].process_output_prefill(logits, 1, last_token_idx=(last_token_idx % 32)) - logits = logits.view(1, 1, self.model_args[model_id].vocab_size) - - prefill_output_xattn_masks = [[] for _ in range(self.data_parallel)] - prefill_output_full_text_row_masked_out_masks = [[] for _ in range(self.data_parallel)] - decode_output_xattn_masks = [[] for _ in range(self.data_parallel)] - decode_output_full_text_row_masked_out_masks = [[] for _ in range(self.data_parallel)] - - prefill_output_xattn_masks[model_id].append(prefill_cross_attention_masks) - prefill_output_full_text_row_masked_out_masks[model_id].append(prefill_full_text_row_masked_out_mask) - decode_output_xattn_masks[model_id].append(decode_cross_attention_masks) - decode_output_full_text_row_masked_out_masks[model_id].append(decode_full_text_row_masked_out_mask) - - def sample(logits): - if temperature > 0: - probs = torch.softmax(logits[:, -1] / temperature, dim=-1) - next_token = sample_top_p(probs, top_p) - else: - next_token = torch.argmax(logits[:, -1], dim=-1) - next_token = next_token.reshape(-1) - return next_token, self.tokenizer.decode(next_token.tolist()) - - next_token, text = sample(logits) - - yield TokenResult( - token=next_token[0].item(), - text=text, - ) - - for gen_idx in range(max_gen_len - 1): - position_id = torch.tensor([prefill_len + gen_idx]) - next_token_tensor = next_token.reshape(1, 1) # B, S - - logits = self.decode_forward( - position_id, - next_token_tensor, - prefill_output_xattn_masks, - prefill_output_full_text_row_masked_out_masks, - decode_output_xattn_masks, - decode_output_full_text_row_masked_out_masks, - [xattn_caches], - enable_trace=False, - ) - next_token, text = sample(logits) - yield TokenResult( - token=next_token[0].item(), - text=text, - ) - - def chat_completion( - self, - messages, - temperature=0.6, - top_p: float = 0.9, - max_gen_len=None, - ): - model_id = 0 - if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model[model_id].configuration.max_seq_len: - max_gen_len = self.model[model_id].configuration.max_seq_len - 1 - - tokens = [] - - stop_reason = None - for result in self.generate( - model_input=self.formatter.encode_dialog_prompt(messages, tool_prompt_format=False), - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - ): - tokens.append(result.token) - if result.text == "<|eot_id|>": - stop_reason = StopReason.end_of_turn - elif result.text == "<|eom_id|>": - stop_reason = StopReason.end_of_message - - if stop_reason is None: - stop_reason = StopReason.out_of_tokens - - message = self.formatter.decode_assistant_message(tokens, stop_reason) - - return ChatPrediction(generation=message) - - def text_completion( - self, - content: 
InterleavedTextMedia, - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len=None, - ): - model_id = 0 - if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model[model_id].configuration.max_seq_len: - max_gen_len = self.model[model_id].configuration.max_seq_len - 1 - - model_input = self.formatter.encode_content(content) - - tokens = [] - - for result in self.generate( - model_input=model_input, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - ): - tokens.append(result.token) - - generation = self.tokenizer.decode(tokens) - - return CompletionPrediction(generation=generation) - - def _get_prefill_user_page_table(self, page_table, kv_cache, prefill_len): - # Ensure page_table is not padded with extra blocks for paged_fill_cache to work properly - block_size = get_block_size(kv_cache) - num_blocks = num_blocks_in_seq(prefill_len, block_size) - return page_table[:, :num_blocks] - - ## Destructor - - def __del__(self): - # Workaround for issue #19052 - if self.data_parallel > 1: - for m in self.model: - ttnn.close_mesh_device(m.mesh_device) - - if hasattr(super(Gemma3Generator, self), "__del__"): - super().__del__() - - -def create_submeshes(mesh_device, data_parallel): - if not isinstance(mesh_device, ttnn.MeshDevice) or data_parallel == 1: - return [mesh_device] - - num_rows, num_cols = mesh_device.shape - num_devices = num_rows * num_cols - assert num_devices % data_parallel == 0, f"Unsupported device split: {num_devices} devices, {data_parallel} groups" - - if num_rows == 8 and num_cols == 4 and num_cols % data_parallel == 0: - submeshes = mesh_device.create_submeshes(ttnn.MeshShape(num_rows, num_cols // data_parallel)) - for submesh in submeshes: - submesh.reshape(ttnn.MeshShape(1, num_devices // data_parallel)) - return submeshes - - return mesh_device.create_submeshes(ttnn.MeshShape(1, num_devices // data_parallel)) diff --git a/models/experimental/gemma3/tt/lm_head.py b/models/experimental/gemma3/tt/lm_head.py deleted file mode 100644 index 5169245137fa..000000000000 --- a/models/experimental/gemma3/tt/lm_head.py +++ /dev/null @@ -1,169 +0,0 @@ -""" -source: models/tt_transformers/tt/lm_head.py - -This is the LMHead module for the Gemma3 model. 
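_get_prefill_user_page_table above trims a user's page table so that paged_fill_cache only sees blocks the prefill can actually touch. A minimal stand-alone version, assuming num_blocks_in_seq is a ceiling division of the sequence length by the KV block size:

import torch

def trim_page_table(page_table: torch.Tensor, prefill_len: int, block_size: int) -> torch.Tensor:
    """Keep only the first ceil(prefill_len / block_size) blocks per user."""
    num_blocks = (prefill_len + block_size - 1) // block_size  # presumed num_blocks_in_seq behaviour
    return page_table[:, :num_blocks]

pt = torch.arange(16).reshape(1, 16)                        # 16 allocated blocks for one user
assert trim_page_table(pt, prefill_len=100, block_size=32).shape[-1] == 4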
-""" -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 - -import math - -import torch - -import ttnn -from models.common.lightweightmodule import LightweightModule -from models.tt_transformers.tt.ccl import tt_all_reduce - - -class LMHead(LightweightModule): - def __init__( - self, - args, - mesh_device, - tt_ccl, - dtype, - state_dict, - state_dict_prefix, - weight_cache_path, - max_columns_per_device, # too many columns per device lead to L1 OOM - ): - super().__init__() - self.args = args - self.mesh_device = mesh_device - self.tt_ccl = tt_ccl - self.dtype = dtype - self.vocab_size = args.vocab_size - self.padded_vocab_size = args.padded_vocab_size - self.num_devices = args.num_devices - - size_per_device = self.vocab_size // self.num_devices - - if args.is_galaxy: - size_per_device = self.padded_vocab_size // self.num_devices - num_splits = math.ceil(size_per_device / max_columns_per_device) - - split_sizes = [min(size_per_device, max_columns_per_device)] * (num_splits - 1) - split_sizes.append(size_per_device - sum(split_sizes)) # remaining columns - - # Split the output weights - torch_output_weights = state_dict[f"{state_dict_prefix}output.weight"].permute(1, 0) - - self.output_weights = [] - if args.is_galaxy: - cache_file_name = ( - None if args.dummy_weights else weight_cache_path / f"output_lm_head_{num_splits}_split_shard_0" - ) - padded_lm_head = torch.zeros(1, 1, args.dim, self.padded_vocab_size) - padded_lm_head[:, :, :, : self.vocab_size] = torch_output_weights - - memory_config = ( - ttnn.DRAM_MEMORY_CONFIG - if args.dim == 2048 - else args.create_dram_sharded_mem_config(k=args.dim // 4, n=self.padded_vocab_size // 8) - ) - self.output_weights.append( # (2k, 16k) 128* 1024 - ttnn.as_tensor( - padded_lm_head, - device=mesh_device, - mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=(3, 2), mesh_shape=args.cluster_shape), - layout=ttnn.TILE_LAYOUT, - dtype=dtype, - memory_config=memory_config, - cache_file_name=cache_file_name, - ) - ) - else: - for i, split_size in enumerate(split_sizes): - # Create a list to store the split tensors for each device - device_splits = [] - for device in range(self.num_devices): - start = device * size_per_device + sum(split_sizes[:i]) - end = start + split_size - device_splits.append(torch_output_weights[:, start:end]) - - # Concatenate the splits from all devices - combined_split = torch.cat(device_splits, dim=-1) - - cache_file_name = ( - None - if args.dummy_weights - else weight_cache_path / f"output_lm_head_{num_splits}_split_shard_{i}_{combined_split.shape[-1]}" - ) - memory_config = args.create_dram_sharded_mem_config( - k=args.dim, n=math.ceil(combined_split.shape[-1] / self.num_devices) - ) - self.output_weights.append( - ttnn.as_tensor( - combined_split, - device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), - layout=ttnn.TILE_LAYOUT, - dtype=dtype, - memory_config=memory_config, - cache_file_name=cache_file_name, - ) - ) - - self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi2, - math_approx_mode=False, - fp32_dest_acc_en=False, - packer_l1_acc=True, - ) - if args.is_galaxy: - self.program_configs = [ - ( - None - if args.dim == 2048 - else args.dram_matmul_config( - args.tile_padded_batch_rows, # (8k, 128k) -> (2k, 16k) - args.dim // 4, - 16 * 1024, - args.lm_head_core_grid.num_cores, - ) - ) - ] - - else: - self.program_configs = [ - args.dram_matmul_config( - args.tile_padded_batch_rows, - args.dim, - 
split_size, - args.lm_head_core_grid.num_cores, - ) - for split_size in split_sizes - ] - - def forward(self, x: ttnn.Tensor): - outputs = [] - for weight, pc in zip(self.output_weights, self.program_configs): - output = ttnn.linear( - x, - weight, - compute_kernel_config=self.compute_kernel_config, - program_config=pc, - memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - dtype=ttnn.bfloat8_b, - ) - outputs.append(ttnn.sharded_to_interleaved(output, memory_config=ttnn.DRAM_MEMORY_CONFIG)) - - # Concatenate the outputs - output = ttnn.concat(outputs, dim=-1, memory_config=ttnn.DRAM_MEMORY_CONFIG) - - output = tt_all_reduce( - output, - self.mesh_device, - self.tt_ccl, - cluster_axis=1, - dim=3 if self.args.is_galaxy else 0, - num_reduce_scatter_links=self.args.num_reduce_scatter_links, - num_all_gather_links=self.args.num_all_gather_links, - memory_config=ttnn.L1_MEMORY_CONFIG, - dtype=self.args.ccl_dtype, - sharded=False, - use_composite=True, - ) - - return output diff --git a/models/experimental/gemma3/tt/mlp.py b/models/experimental/gemma3/tt/mlp.py deleted file mode 100644 index 440b1ad1b7f1..000000000000 --- a/models/experimental/gemma3/tt/mlp.py +++ /dev/null @@ -1,300 +0,0 @@ -""" -source: models/tt_transformers/tt/mlp.py - -This is the implementation of MLP (feed-forward) submodule of Gemma3. - -We have re-used the MLP implementation of the TT-Transformers library with few modifications. -This implementation has changes in Data Type (bfloat16). -""" - -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 - -import torch - -import ttnn -from models.common.lightweightmodule import LightweightModule -from models.tt_transformers.tt.ccl import tt_all_reduce -from models.tt_transformers.tt.common import pad_to_size -from models.tt_transformers.tt.model_config import OpGroup, TensorGroup - - -class MLP(LightweightModule): - def __init__( - self, - mesh_device, - tt_ccl, - args, - state_dict, - weight_cache_path, - layer_num, - dtype, - model_config, - state_dict_prefix=None, - ): - super().__init__() - - self.mesh_device = mesh_device - self.tt_ccl = tt_ccl - self.args = args - self.num_devices = args.num_devices - self.dim = args.dim - self.model_config = model_config - self.layer_num = layer_num - state_dict_prefix = state_dict_prefix or args.get_state_dict_prefix(self.__class__.__name__, layer_num) - torch_weight = lambda name: torch.transpose(state_dict[f"{state_dict_prefix}.{name}.weight"], -2, -1) - pad_hidden_dim = lambda tensor, dim: pad_to_size(tensor, dim=dim, size=args.hidden_dim) - # If pading was applied (e.g. via env var), add the unpadded hidden dim to the cache name to avoid loading incorrect weights - hidden_dim_string = f".hidden_dim_{args.hidden_dim}" if args.hidden_dim != args.unpadded_hidden_dim else "" - - if args.dummy_weights: - cache_name = lambda _: None - else: - cache_name = lambda name: weight_cache_path / f"{state_dict_prefix}.{name}{hidden_dim_string}" - - w1_w3_mem_config = args.create_dram_sharded_mem_config(args.dim, args.hidden_dim // args.num_devices) - w2_mem_config = args.create_dram_sharded_mem_config(args.hidden_dim // args.num_devices, args.dim) - - # TODO Clean up this code. 
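For reference, the column-splitting in the LM head above caps each DRAM-sharded matmul at max_columns_per_device output columns per device, with the remainder carried by the last split. A worked example of that arithmetic with made-up sizes (not the real Gemma-3 configuration):

import math

vocab_size, num_devices, max_columns_per_device = 128_256, 8, 12_288   # illustrative only

size_per_device = vocab_size // num_devices                             # 16032 columns per device
num_splits = math.ceil(size_per_device / max_columns_per_device)        # 2 matmul splits
split_sizes = [min(size_per_device, max_columns_per_device)] * (num_splits - 1)
split_sizes.append(size_per_device - sum(split_sizes))                  # remaining columns

assert split_sizes == [12_288, 3_744] and sum(split_sizes) == size_per_device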
With sharding, we load the normal weights and then shard them - as_sharded_tensor = lambda name, type, dims: ttnn.as_tensor( - pad_hidden_dim( - torch_weight(name[:2]), dims[0] if args.is_galaxy else dims[-1] - ), # Grab only the wX part of the name - dtype=type, - device=self.mesh_device, - mesh_mapper=ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=args.cluster_shape), - layout=ttnn.TILE_LAYOUT, - memory_config=( - ttnn.DRAM_MEMORY_CONFIG if args.is_galaxy else w2_mem_config if "w2" in name else w1_w3_mem_config - ), - cache_file_name=cache_name(name), - ) - - # Sharded weights - w1_dims = (-1, -2) if args.is_galaxy else (-2, -1) - w2_dims = (-2, -1) if args.is_galaxy else (-1, -2) - - layer_num = max(layer_num, 0) # cross_block uses the configutation of the first decoder - - ff1_3_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=layer_num, tensor=TensorGroup.FF1_FF3 - ) - ff2_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=layer_num, tensor=TensorGroup.FF2 - ) - - self.w1 = as_sharded_tensor( - "w1_sharded", ff1_3_dtype, dims=w1_dims - ) # bfp4 normally ok here but sub .99 pcc for llama 3.1 weights - self.w2 = as_sharded_tensor("w2_sharded", ff2_dtype, dims=w2_dims) - self.w3 = as_sharded_tensor("w3_sharded", ff1_3_dtype, dims=w1_dims) - - # Default activation is SILU - self.activation_type = self.args.mlp_activation_type - - def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: - """ - w1 -> gate_proj - w2 -> down_proj - w3 -> up_proj - HF reference: self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - """ - seq_len = x.shape[-2] - TG = self.args.is_galaxy - layer_num = max(self.layer_num, 0) # cross_block uses the configutation of the first decoder - activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=layer_num, tensor=TensorGroup.ACTIVATION - ) - li_ff1_3_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.LI_FF1_FF3, configuration=self.args - ) - - if mode == "decode": # Sharded config - if TG: # TODO: Fix this when TG supports DRAM sharded matmuls - pc_1 = self.model_config["FF1_3_TG_PROGCFG"] if self.dim >= 4096 else None - pc_2 = self.model_config["FF2_TG_PROGCFG"] if self.dim >= 4096 else None - pc_3 = self.model_config["FF1_3_TG_PROGCFG"] if self.dim >= 4096 else None - else: - pc_1 = self.model_config["DECODE_MLP_W1_W3_PRG_CONFIG"] - pc_2 = self.model_config["DECODE_MLP_W2_PRG_CONFIG"] - pc_3 = self.model_config["DECODE_MLP_W1_W3_PRG_CONFIG"] - else: # Update the program configs based for prefill - if seq_len >= self.args.prefill_len_cutoff: # 512 if Blackhole, 1024 if Wormhole - # Reshape input to to fit on device and parallelize computation - x = ttnn.reshape(x, [1, seq_len // self.args.prefill_len_cutoff, self.args.prefill_len_cutoff, -1]) - pc_1 = self.model_config["PREFILL_MLP_W1_W3_PRG_CONFIG"](seq_len) - pc_2 = self.model_config["PREFILL_MLP_W2_PRG_CONFIG"](seq_len) - pc_3 = self.model_config["PREFILL_MLP_W1_W3_PRG_CONFIG"](seq_len) - - # In decode mode (seqlen <= 32) do DRAM sharded matmuls - # These use HiFi2; this drops 1 bit of the activations but would be FLOP-bound on 12 cores with HiFi4 - memory_config = ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG - w1_out = ttnn.linear( - x, - self.w1, - dtype=ttnn.bfloat8_b if TG else activation_dtype or ttnn.bfloat16, - core_grid=None, # FIXME: validate on TG ttnn.CoreGrid(y=8, 
x=8) if not pc_1 else None, - compute_kernel_config=li_ff1_3_compute_kernel_cfg, - program_config=pc_1, - memory_config=memory_config, - ) - - w3_out = ttnn.linear( - x, - self.w3, - dtype=ttnn.bfloat8_b if TG else activation_dtype or ttnn.bfloat16, - core_grid=None, # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_3 else None, - compute_kernel_config=li_ff1_3_compute_kernel_cfg, - program_config=pc_3, - memory_config=memory_config, - ) - ttnn.deallocate(x) - - if TG: - # if mode == "decode" and self.dim!=8192: - # w1_out = ttnn.to_memory_config(w1_out, ttnn.DRAM_MEMORY_CONFIG) - # w3_out = ttnn.to_memory_config(w3_out, ttnn.DRAM_MEMORY_CONFIG) - if self.dim == 8192 or mode == "prefill": - input_mem_cfg = w1_out.memory_config() - - cluster_axis = 1 - w1_out = ttnn.experimental.reduce_scatter_minimal_async( - w1_out, - persistent_output_buffers=None, - dim=3, - multi_device_global_semaphore=self.tt_ccl.get_and_cycle_rs_semaphore_handles(cluster_axis), - barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(cluster_axis), - num_links=self.args.num_reduce_scatter_links, - cluster_axis=cluster_axis, - memory_config=self.model_config["FF1_OUT_REDUCE_SCATTER_MEMCFG"] if mode == "decode" else None, - intermediate_memory_config=ttnn.DRAM_MEMORY_CONFIG, - topology=ttnn.Topology.Linear, - chunks_per_sync=10, - num_workers_per_link=2, - num_buffers_per_channel=2, - ) - - w3_out = ttnn.experimental.reduce_scatter_minimal_async( - w3_out, - persistent_output_buffers=None, - dim=3, - multi_device_global_semaphore=self.tt_ccl.get_and_cycle_rs_semaphore_handles(cluster_axis), - barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(cluster_axis), - num_links=1, - cluster_axis=cluster_axis, - memory_config=self.model_config["FF1_OUT_REDUCE_SCATTER_MEMCFG"] if mode == "decode" else None, - intermediate_memory_config=ttnn.DRAM_MEMORY_CONFIG, - topology=ttnn.Topology.Linear, - chunks_per_sync=10, - num_workers_per_link=2, - num_buffers_per_channel=2, - ) - else: - w1_out = tt_all_reduce( - w1_out, - self.mesh_device, - self.tt_ccl, - cluster_axis=1, - num_all_gather_links=2, - sharded=True if mode == "decode" else False, - topology=self.args.ccl_topology(), - memory_config=self.model_config["FF1_OUT_GATHERED_MEMCFG"] if mode == "decode" else None, - ) - w3_out = tt_all_reduce( - w3_out, - self.mesh_device, - self.tt_ccl, - cluster_axis=1, - num_all_gather_links=2, - sharded=True if mode == "decode" else False, - topology=self.args.ccl_topology(), - memory_config=self.model_config["FF1_OUT_GATHERED_MEMCFG"] if mode == "decode" else None, - ) - - w2_in = ttnn.mul( - w1_out, - w3_out, - input_tensor_a_activations=[self.activation_type], - dtype=activation_dtype or ttnn.bfloat8_b, - memory_config=w1_out.memory_config(), - ) - - if mode == "decode" and not TG: - # w2 may use a different core grid, this is a no-op if they already match - w2_in = ttnn.to_memory_config(w2_in, self.model_config["SHARDED_MLP2_INPUT_MEMCFG"]) - - ttnn.deallocate(w3_out) - ttnn.deallocate(w1_out) - - if TG and (self.dim == 8192 or mode == "prefill"): - cluster_axis = 1 - w2_in = ttnn.experimental.all_gather_async( - w2_in, - persistent_output_buffer=None, - dim=3, - multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(cluster_axis), - num_links=2, - cluster_axis=1, - topology=ttnn.Topology.Linear, - memory_config=input_mem_cfg, - barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(cluster_axis), - chunks_per_sync=10, - num_workers_per_link=2, - 
num_buffers_per_channel=2, - ) - - if mode == "decode": - w2_in = ttnn.to_memory_config(w2_in, ttnn.L1_MEMORY_CONFIG) - - li_ff2_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.LI_FF2, configuration=self.args - ) - w2_out = ttnn.linear( - w2_in, - self.w2, - compute_kernel_config=li_ff2_compute_kernel_cfg, - dtype=self.args.ccl_dtype if TG else activation_dtype or ttnn.bfloat16, - program_config=pc_2, - memory_config=memory_config, - core_grid=None, # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_2 else None, - ) - ttnn.deallocate(w2_in) - - w2_out = ttnn.multiply(w2_out, self.num_devices) - - # if mode == "decode" and not TG: - # w2_out = ttnn.sharded_to_interleaved(w2_out, ttnn.DRAM_MEMORY_CONFIG) - w2_out_reduced = tt_all_reduce( - w2_out, - self.mesh_device, - self.tt_ccl, - cluster_axis=0, - dim=0 if (TG and self.dim < 8192) else 3, - num_reduce_scatter_links=self.args.num_reduce_scatter_links, - num_all_gather_links=self.args.num_all_gather_links, - sharded=(mode == "decode"), - memory_config=( - (self.model_config["FF2_OUT_REDUCE_SCATTER_MEMCFG"] if TG else w2_out.memory_config()) - if mode == "decode" - else ttnn.DRAM_MEMORY_CONFIG - ), - dtype=self.args.ccl_dtype, - use_composite=True if self.dim == 8192 else False, - topology=self.args.ccl_topology(), - ) - w2_out_reduced = ttnn.div(w2_out_reduced, self.num_devices) - - # Ensure dim 0 and 1 are 1 - original_shape = w2_out_reduced.shape - w2_out_reduced = ttnn.reshape( - w2_out_reduced, (1, 1, original_shape[-4] * original_shape[-3] * original_shape[-2], original_shape[-1]) - ) - if mode == "decode": - w2_out_reduced = ttnn.to_memory_config( - w2_out_reduced, - self.model_config["SHARDED_ATTN_INPUT_MEMCFG"] if TG else self.model_config["DECODE_RESIDUAL_MEMCFG"], - ) - - # ttnn.deallocate(w2_out) - return w2_out_reduced diff --git a/models/experimental/gemma3/tt/rmsnorm.py b/models/experimental/gemma3/tt/rmsnorm.py deleted file mode 100644 index a61f13836f2d..000000000000 --- a/models/experimental/gemma3/tt/rmsnorm.py +++ /dev/null @@ -1,179 +0,0 @@ -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 -import ttnn -from models.common.lightweightmodule import LightweightModule - -TILE = 32 -SHARD_HEIGHT = TILE # Current ttnn.rms_norm implementation requires shard height to be a single tile - - -class RMSNorm(LightweightModule): - """ - RMSNorm supporting replication over a MeshDevice and sharding within devices. - - This class implements a Root Mean Square Normalization (RMSNorm) that can be - distributed across multiple devices and cores. If the `device` parameter is a - MeshDevice, the weights and computations are replicated across all devices in - the mesh. Expects an interleaved input tensor, can optionally output a sharded tensor. - - Args: - device: The device or MeshDevice on which to perform the computations. - state_dict: The state dictionary containing the model parameters. - dim: Input dimension (e.g. model hidden dimension size). - layer_num: The layer number to determine the weight key in the state dictionary. - weight_key: The key for retrieving the weight from the state dictionary. - weight_cache_path: Optional path for caching the tilized weights. - weight_memory_config: Configuration for the weight memory, default is DRAM_MEMORY_CONFIG. - weight_dtype: The data type for the tensors, bfp8_b hits >0.999 PCC in the models we tested. - model_config: Optional configuration dictionary for the model. 
- eps (float): Small value to avoid division by zero in normalization, default is 1e-05. - - If model_config is provided, it must specify SHARDED_NORM_INPUT_MEMCFG, SHARDED_NORM_PRGM_CFG - and SHARDED_NORM_OUTPUT_MEMCFG. If not provided, default configurations will be generated. - """ - - def __init__( - self, - device, - dim, - state_dict, - weight_key, - layer_num=None, - state_dict_prefix=None, - weight_cache_path=None, - weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, - weight_dtype=ttnn.bfloat16, - is_distributed=None, - eps: float = 1e-06, - add_unit_offset=False, - sharded_program_config=None, - sharded_output_config=None, - output_mem_config=None, - ccl_topology=ttnn.Topology.Ring, - tt_ccl=None, - ): - super().__init__() - self.device = device - self.eps = eps - self.is_distributed = is_distributed - self.ccl_topology = ccl_topology - self.tt_ccl = tt_ccl - - if state_dict_prefix: - weight_name = f"{state_dict_prefix}{weight_key}.weight" - else: - if layer_num is None: - weight_name = f"{weight_key}.weight" - else: - weight_name = f"layers.{layer_num}.{weight_key}.weight" - - torch_weight = ( - state_dict[weight_name].unsqueeze(0).view(1, 1, dim).reshape([1, 1, dim // SHARD_HEIGHT, SHARD_HEIGHT]) - ) - - # Add offset before caching - if add_unit_offset: - torch_weight = torch_weight + 1.0 - - cache_name = None if weight_cache_path is None else weight_cache_path / weight_name - - # Compatibility with models that don't use mesh devices (e.g. single-chip Mistral-7b) - is_mesh_device = device.__class__.__name__ == "MeshDevice" - - self.weight = ttnn.as_tensor( - torch_weight, - device=device, - dtype=weight_dtype, - layout=ttnn.ROW_MAJOR_LAYOUT, - memory_config=weight_memory_config, - cache_file_name=cache_name, - mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, - ) - - if self.is_distributed: - self.weight_distributed = ttnn.as_tensor( - torch_weight, - device=device, - dtype=weight_dtype, - layout=ttnn.ROW_MAJOR_LAYOUT, - memory_config=weight_memory_config, - cache_file_name=( - None if weight_cache_path is None else weight_cache_path / (weight_name + "_distributed") - ), - mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape)) - if is_mesh_device - else None, - ) - - self.sharded_output_config = sharded_output_config - self.sharded_program_config = sharded_program_config - self.output_mem_config = output_mem_config - - self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi2, - math_approx_mode=False, - fp32_dest_acc_en=True, - packer_l1_acc=True, - ) - - def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) -> ttnn.Tensor: - # If input is sharded do sharded RMSNorm and optionally return sharded output - program_config = self.sharded_program_config if in_sharded else None - memory_config = self.sharded_output_config if out_sharded else None - distributed = self.is_distributed and self.is_distributed(mode) - norm = self._distributed_rmsnorm if distributed else ttnn.rms_norm - weight = self.weight_distributed if distributed else self.weight - - if in_sharded: - assert not distributed, "Distributed RMSNorm does not support sharded inputs" - else: - assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor" - - x = norm( - x, - epsilon=self.eps, - weight=weight, - program_config=program_config, - memory_config=memory_config, - compute_kernel_config=self.compute_kernel_config_hifi2, - ) - - if in_sharded and not out_sharded: 
- return ttnn.sharded_to_interleaved(x) - else: - return x - - def _distributed_rmsnorm( - self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None - ): - assert program_config is None, "Distributed RMSNorm does not support sharded inputs" - assert memory_config is None, "Distributed RMSNorm does not support sharded outputs" - - # Run distributed rmsnorm part 1 - tt_stats = ttnn.rms_norm_pre_all_gather(inp, compute_kernel_config=compute_kernel_config, dtype=ttnn.bfloat16) - # AllGather stats - tt_stats = ttnn.experimental.all_gather_async( - tt_stats, - persistent_output_buffer=None, - dim=3, - multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(), - num_links=1, - topology=self.ccl_topology, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(), - chunks_per_sync=10, - num_workers_per_link=2, - num_buffers_per_channel=2, - ) - # Run distributed rmsnorm part 2 - tt_out = ttnn.rms_norm_post_all_gather( - inp, - tt_stats, - epsilon=epsilon, - weight=weight, - compute_kernel_config=compute_kernel_config, - ) - tt_stats.deallocate(True) - - return tt_out diff --git a/models/experimental/gemma3/tt/text_model.py b/models/experimental/gemma3/tt/text_model.py deleted file mode 100644 index c0b033b15419..000000000000 --- a/models/experimental/gemma3/tt/text_model.py +++ /dev/null @@ -1,594 +0,0 @@ -""" - -This is the end-to-end implementation of the Gemma3 model. - -""" - -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 - -import ttnn -from models.experimental.gemma3.tt.rmsnorm import RMSNorm - -from models.common.lightweightmodule import LightweightModule -from models.tt_transformers.tt.embedding import Embedding, ScaledEmbedding -from models.tt_transformers.tt.rope import RotarySetup - -from models.experimental.gemma3.tt.decoder import TransformerBlock -from models.tt_transformers.tt.distributed_norm import DistributedNorm -from tqdm import tqdm -import torch -from models.experimental.gemma3.tt.lm_head import LMHead -from models.tt_transformers.tt.model_config import TensorGroup -from models.tt_transformers.tt.common import copy_host_to_device -from models.utility_functions import nearest_32 -from models.tt_transformers.tt.ccl import TT_CCL - - -class Gemma3Transformer(LightweightModule): - def __init__( - self, - args, - dtype, - mesh_device, - state_dict, - weight_cache_path, - paged_attention_config=None, - use_paged_kv_cache=False, - attention_class=None, - rope_setup_class=None, - attn_mask=None, - ): - super().__init__() - self.args = args - self.paged_attention_config = paged_attention_config - self.vocab_size = args.vocab_size - self.tt_ccl = TT_CCL(mesh_device) - assert self.vocab_size > 0 - self.n_layers = args.n_layers - self.mesh_device = mesh_device - self.dtype = dtype - self.model_config = args.get_model_config() - self.grid_size = self.args.max_grid_size - state_dict_prefix = args.get_state_dict_prefix("", None) - self.tt_ccl = TT_CCL(self.mesh_device) - - embd_kwargs = { - "mesh_device": mesh_device, - "args": args, - "weight_cache_path": args.weight_cache_path(dtype), - "state_dict": state_dict, - "dtype": ttnn.bfloat16, # Row major layout requires bfloat16 - } - if self.args.embed_scale is not None: - embd_cls = ScaledEmbedding - embd_kwargs["embed_scale"] = self.args.embed_scale - else: - embd_cls = Embedding - self.embd = embd_cls(**embd_kwargs) - - ActualRopeSetupClass = rope_setup_class if 
rope_setup_class is not None else RotarySetup - self.rope_setup = ActualRopeSetupClass( - device=mesh_device, - batch_size=args.max_batch_size, - head_dim=args.head_dim, - max_seq_len=args.max_seq_len, - rope_theta=args.rope_theta, - rope_scaling=args.rope_scaling, - ) - - if args.rope_theta_local: - self.rope_local_setup = RotarySetup( - mesh_device, - args.max_batch_size, - args.head_dim, - args.max_seq_len, - args.rope_theta_local, - rope_scaling=None, - ) - - self.trans_mats_dict = self.rope_setup.get_both_trans_mats() - - self.layers = [ - TransformerBlock( - args=args, - mesh_device=mesh_device, - tt_ccl=self.tt_ccl, - dtype=dtype, - state_dict=state_dict, - weight_cache_path=weight_cache_path, - layer_num=i, - transformation_mats=self.trans_mats_dict, - paged_attention_config=paged_attention_config, - use_paged_kv_cache=use_paged_kv_cache, - attention_class=attention_class, - ) - for i in tqdm(range(self.n_layers)) - ] - self.norm = DistributedNorm( - RMSNorm( - device=mesh_device, - dim=args.dim, - eps=args.norm_eps, - state_dict=state_dict, - state_dict_prefix=args.get_state_dict_prefix("", None), - weight_cache_path=None if args.dummy_weights else weight_cache_path, - weight_dtype=ttnn.bfloat16, - weight_key="norm", - add_unit_offset=self.args.rms_norm_add_unit_offset, - is_distributed=self.args.is_distributed_norm, - sharded_program_config=self.model_config["SHARDED_NORM_LM_HEAD_PRGM_CFG"], - sharded_output_config=self.model_config["LM_HEAD_INPUT_MEMCFG"], - ccl_topology=self.args.ccl_topology(), - tt_ccl=self.tt_ccl, - ), - args, - self.tt_ccl, - args.is_galaxy, - ) - - self.lm_head = LMHead( - args=args, - mesh_device=mesh_device, - tt_ccl=self.tt_ccl, - dtype=dtype, - state_dict=state_dict, - state_dict_prefix=state_dict_prefix, - weight_cache_path=weight_cache_path, - max_columns_per_device=self.args.max_columns_per_device_lm_head, - ) - - self.host_embed = self.args.reference_embedding() - - def setup_cache(self, max_batch_size): - self.cache_is_setup = True - - # Prepare xattn_caches - chunk_length = nearest_32(self.args.vision_chunk_ntok) - vision_seq_len = self.args.vision_max_num_chunks * chunk_length - xattn_cache = [ - [ - ttnn.from_torch( - torch.zeros(max_batch_size, self.args.n_kv_heads, vision_seq_len, self.args.head_dim), - device=self.mesh_device, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - dtype=ttnn.bfloat16, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=1), - ) - for _ in range(2) - ] - for l in range(len(self.cross_attention_layers)) - ] - - return xattn_cache - - def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): - """ - Inputs are torch tensors or python types. This function returns ttnn - tensors on device. 
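The text model above keeps two rotary setups: rope_setup built from the global rope_theta and rope_local_setup built from rope_theta_local for the sliding-window layers. Both use the standard RoPE frequency schedule and differ only in the base; a small sketch of that computation (the theta values here are placeholders, not the checkpoint's actual config):

import torch

def rope_inv_freq(head_dim: int, theta: float) -> torch.Tensor:
    """Standard RoPE inverse frequencies theta^(-2i/d) over the even channel indices."""
    return 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

global_inv_freq = rope_inv_freq(128, theta=1_000_000.0)    # placeholder global base
local_inv_freq = rope_inv_freq(128, theta=10_000.0)        # placeholder local (sliding-window) base
assert global_inv_freq[-1] < local_inv_freq[-1]            # larger base -> slower-rotating high channels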
- TODO: Debate whether this function is responsible for padding - """ - if not kwargs.get("processed_inputs", None): - tokens = tokens.reshape(1, 1, 1, -1) - S = tokens.shape[-1] - tokens = ttnn.from_torch( - tokens, - device=self.mesh_device, - dtype=ttnn.uint32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - tokens_embd = self.embd(tokens) - else: - S = tokens.shape[-1] - - tokens_embd = ttnn.from_torch( - tokens.reshape(1, 1, 1, -1), - device=self.mesh_device, - dtype=ttnn.uint32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - - tokens_embd = self.embd(tokens_embd) - - pixel_values = kwargs["processed_inputs"]["pixel_values"] - input_ids = kwargs["processed_inputs"]["input_ids"] - if pixel_values is not None: - vision_model = kwargs["vision_model"] - vision_output = vision_model(pixel_values) - - tokens_embd = ttnn.to_torch( - tokens_embd, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1) - ) - - comp_vision_output = ttnn.to_torch( - vision_output, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1) - )[:, :, :, : vision_output.shape[-1]] - - comp_vision_output = torch.nn.functional.pad( - comp_vision_output, (0, 0, 0, tokens_embd.shape[1] - comp_vision_output.shape[1]), "constant", 0 - ) - input_ids = torch.nn.functional.pad( - input_ids, (0, tokens_embd.shape[1] - input_ids.shape[1]), "constant", 0 - ) - image_features = comp_vision_output.squeeze(0) - special_image_mask = (input_ids == self.args.image_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(tokens_embd) - image_features = image_features.to(tokens_embd.device, tokens_embd.dtype) - tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features) - - tokens_embd = ttnn.from_torch( - tokens_embd, - dtype=ttnn.bfloat16, - device=self.mesh_device, - layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ShardTensor2dMesh( - self.mesh_device, dims=(None, -1), mesh_shape=list(self.mesh_device.shape) - ), - ) - - tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) - # Slice the rot mats to the prefill seqlen - assert ( - self.rope_setup.cos_matrix.shape[2] >= start_pos + S - ), f"Padded prefill end idx {start_pos + S} exceeds max seq len {self.rope_setup.cos_matrix.shape[2]}" - - tt_rot_mats_prefill_global = [ - self.rope_setup.cos_matrix[:, :, start_pos : start_pos + S, :], - self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :], - ] - - if hasattr(self, "rope_local_setup"): - tt_rot_mats_prefill_local = [ - self.rope_local_setup.cos_matrix[:, :, start_pos : start_pos + S, :], - self.rope_local_setup.sin_matrix[:, :, start_pos : start_pos + S, :], - ] - else: - tt_rot_mats_prefill_local = None - - if page_table is not None: - tt_page_table = ttnn.from_torch( - page_table, - device=self.mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - else: - tt_page_table = None - - if chunk_page_table is not None: - tt_chunk_page_table = ttnn.from_torch( - chunk_page_table, - device=self.mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - else: - tt_chunk_page_table = None - - return ( - tokens_embd, - tt_rot_mats_prefill_global, - tt_rot_mats_prefill_local, - tt_page_table, - tt_chunk_page_table, - ) - - def prepare_inputs_decode(self, *inputs): - """ - Inputs are torch tensors or python types. 
This function returns ttnn - tensors on device. - Its implementation can take advantage of a few other functions which the - model must implement. - """ - host_inputs = self.prepare_decode_inputs_host(*inputs) - device_inputs = copy_host_to_device(host_inputs, mesh_device=self.mesh_device) # Helper function - return device_inputs - - def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None): - """ - Inputs are torch tensors or python types. Outputs are ttnn tensors on host. - NOTE: Tokens and current_pos are padded to batch - """ - B = tokens.shape[0] - assert current_pos.shape[0] == B, "Batch size mismatch" - assert B == self.args.max_batch_size, "Batch size must be equal to max_batch_size" - - # Necessary padding to be full tile sized when on device - tokens = torch.nn.functional.pad(tokens.view(-1), (0, 32 - len(tokens)), "constant", 0) - tokens = ttnn.from_torch( - tokens, - device=None, - dtype=ttnn.uint32, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - tokens = ttnn.unsqueeze_to_4D(tokens) - - rot_current_pos = torch.maximum( - current_pos, torch.tensor(0, dtype=torch.int64) - ) # Ensure position indices are non-negative - rope_idxs_global = self.rope_setup.get_rot_idxs(rot_current_pos, on_host=True) - if hasattr(self, "rope_local_setup"): - rope_idxs_local = self.rope_local_setup.get_rot_idxs(rot_current_pos, on_host=True) - else: - rope_idxs_local = None - - current_pos_tt = ttnn.from_torch( - current_pos, - device=None, - dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - self.mesh_device, - dims=(None, 0) if (self.args.is_galaxy and B > 1) else (None, None), - mesh_shape=self.args.cluster_shape, - ), - ) - - if page_table is not None: - page_table = ttnn.from_torch( - page_table, - device=None, - dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - self.mesh_device, - dims=(None, -2) if (self.args.is_galaxy and B > 1) else (None, None), - mesh_shape=self.args.cluster_shape, - ), - ) - - return tokens, current_pos_tt, rope_idxs_global, rope_idxs_local, page_table - - def _transform_decode_inputs_device(self, tokens): - """ - Inputs are ttnn tensors on device. This function applies any on-device - transformations which should happen before forward decode. - For example: tilize, reshape, shard. - Return transformed device tensors - - Embed tokens - """ - tt_tokens = self.embd(tokens) - tt_tokens = ttnn.unsqueeze_to_4D(tt_tokens) - tt_tokens = ttnn.to_memory_config( - tt_tokens, - self.args.model_config["DECODE_RESIDUAL_MEMCFG"], - ) - return tt_tokens - - def process_output_prefill(self, tt_out, last_token_idx): - """ - Input is ttnn device tensor of logits. Output is torch logits tensor. - NOTE: In this model, prefill always uses get_last_token - """ - logits = ttnn.to_torch( - tt_out, - mesh_composer=ttnn.ConcatMesh2dToTensor( - self.mesh_device, dims=(3, 1) if self.args.is_galaxy else (1, -1), mesh_shape=self.args.cluster_shape - ), - )[0, 0, last_token_idx, : self.vocab_size] - return logits - - def process_output_decode(self, tt_out, B, S=1, is_tokens=False): - """ - Input is ttnn device tensor of logits if is_tokens=False, otherwise tokens. Output is the corresponding torch tensor. 
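prepare_inputs_prefill above splices the vision tower's output into the token embeddings wherever the input ids equal the image placeholder token, using expand plus masked_scatter. A toy torch-only sketch of that merge (token id, shapes, and values are made up):

import torch

image_token_id = 99                                        # stand-in for args.image_token_index
input_ids = torch.tensor([[1, image_token_id, image_token_id, 7]])
text_embeds = torch.zeros(1, 4, 8)                         # (B, S, D) token embeddings
image_features = torch.ones(2, 8)                          # one row per image placeholder

mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(text_embeds)
merged = text_embeds.masked_scatter(mask, image_features.to(text_embeds.dtype))

assert merged[0, 1].eq(1).all() and merged[0, 0].eq(0).all()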
- """ - if is_tokens: - tt_out = ttnn.to_torch( - tt_out, # tt_out.cpu(blocking=True, cq_id=1), - mesh_composer=ttnn.ConcatMesh2dToTensor( - self.mesh_device, - dims=(3, 1) if self.args.is_galaxy else (1, -1), - mesh_shape=self.args.cluster_shape, - ), - )[0, 0, :B, 0] - return tt_out - - if self.args.num_devices > 1: - tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() - else: - tt_out = ttnn.to_torch(tt_out).float() - tt_out = tt_out[:, :, :B, : self.vocab_size].view(B, S, -1) - return tt_out - - def ttnn_prefill_forward( - self, - x, - rot_mats_global=None, - rot_mats_local=None, - user_id=0, - page_table=None, - chunk_page_table=None, - chunk_start_idx=None, - get_last_token=-1, - kv_cache=None, - ): - """ - This method will take device tensors and any other args to run forward. - It returns ttnn device tensors. - """ - if hasattr(self.args, "sliding_window") and self.args.sliding_window is not None: - mask = torch.triu(torch.full((1, 1, x.shape[-2], x.shape[-2]), -float("inf")), diagonal=1) - sliding_mask = mask + torch.tril( - torch.full((1, 1, x.shape[-2], x.shape[-2]), -float("inf")), - diagonal=-self.args.sliding_window, - ) - sliding_attn_mask = ttnn.from_torch( - sliding_mask, device=self.mesh_device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16 - ) - else: - sliding_attn_mask = None - - return self.forward( - x, - current_pos=None, - rot_mats_global=rot_mats_global, - rot_mats_local=rot_mats_local, - user_id=user_id, - mode="prefill", - page_table=page_table, - chunk_page_table=chunk_page_table, - chunk_start_idx=chunk_start_idx, - get_last_token=get_last_token, - kv_cache=kv_cache, - ) - - def _increment_decode_positions_device(self, current_pos, rot_mat_idxs_global, rot_mat_idxs_local): - # ttnn.ne currently requires the input to be in TILE_LAYOUT - current_pos_tiled = ttnn.to_layout(current_pos, layout=ttnn.TILE_LAYOUT) - # Update only active positions (current_pos != -1) - predicate = ttnn.ne(current_pos_tiled, -1) - result = ttnn.where( - predicate, - ttnn.add(current_pos_tiled, 1), - current_pos_tiled, - ) - ttnn.copy(ttnn.to_layout(result, layout=ttnn.ROW_MAJOR_LAYOUT), current_pos) - - ttnn.plus_one(rot_mat_idxs_global) - if rot_mat_idxs_local is not None: - ttnn.plus_one(rot_mat_idxs_local) - - def update_attention_masks(self, current_pos): - torch_mask = torch.concat( - [ - self.decode_sliding_mask_mat[i, :, current_pos[i].item() : current_pos[i].item() + 1, :].unsqueeze(0) - for i in range(self.decode_sliding_mask_mat.shape[0]) - ], - axis=0, - ).transpose(1, 2) - sliding_window_causal_mask = ttnn.as_tensor( - torch_mask, - dtype=ttnn.bfloat4_b, - layout=ttnn.TILE_LAYOUT, - device=None, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - ttnn.copy_host_to_device_tensor(sliding_window_causal_mask, self.device_decode_sliding_mask) - - def ttnn_decode_forward( - self, - x, - current_pos, - rot_mat_idxs_global=None, - rot_mat_idxs_local=None, - page_table=None, - kv_cache=None, - argmax_on_device=False, - ): - """ - This method will take device tensors and any other args to run forward. - It returns ttnn device tensors. 
- """ - rot_mats_global = self.rope_setup.get_rot_mats(rot_mat_idxs_global) - rot_mats_local = ( - self.rope_local_setup.get_rot_mats(rot_mat_idxs_local) if rot_mat_idxs_local is not None else None - ) - x_embed = self._transform_decode_inputs_device(x) - - tt_logits = self.forward( - x_embed, - current_pos, - rot_mats_global=rot_mats_global, - rot_mats_local=rot_mats_local, - mode="decode", - page_table=page_table, - kv_cache=kv_cache, - ) - - # Gather the output across all devices and untilize the tensor (for argmax) - if self.args.num_devices > 1: - cluster_axis = 0 if self.args.is_galaxy else None - num_links = 2 if self.args.is_galaxy else 1 - tt_logits = ttnn.experimental.all_gather_async( - tt_logits, - persistent_output_buffer=None, - dim=3, - multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(cluster_axis), - num_links=num_links, - memory_config=tt_logits.memory_config(), - cluster_axis=cluster_axis, - topology=self.args.ccl_topology(), - barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(cluster_axis), - chunks_per_sync=10, - num_workers_per_link=2, - num_buffers_per_channel=2, - ) - - tt_logits = ttnn.untilize(tt_logits, use_multicore=True) - - if argmax_on_device: - tt_logits = ttnn.argmax(tt_logits, dim=3, keepdim=True, use_multicore=True) - - # Update device tensors for the next iteration - self._increment_decode_positions_device(current_pos, rot_mat_idxs_global, rot_mat_idxs_local) - - # Update input tokens with sampled tokens for the next iteration - ttnn.copy(tt_logits.reshape(x.shape), x) - elif not self.args.is_galaxy: - # Send output logits to DRAM so L1 is not reserved for ttnn tracing and can be used by subsequent operations - tt_logits = ttnn.to_memory_config(tt_logits, ttnn.DRAM_MEMORY_CONFIG) - - return tt_logits - - def forward( - self, - x: ttnn.Tensor, - current_pos, - rot_mats_global=None, - rot_mats_local=None, - user_id=0, - mode="decode", - page_table=None, - chunk_page_table=None, - chunk_start_idx=None, - get_last_token=-1, - kv_cache=None, - ): - for i, layer in enumerate(self.layers): - # No-op if callers already provide the right memory config - activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=i, tensor=TensorGroup.ACTIVATION - ) - if mode == "decode" and not self.args.is_galaxy: - x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"], activation_dtype) - elif activation_dtype is not None and x.dtype != activation_dtype: - x = ttnn.typecast(x, activation_dtype) - - x = layer( - x, - current_pos, - rot_mats_global=rot_mats_global, - rot_mats_local=rot_mats_local, - user_id=user_id, - mode=mode, - page_table=page_table, - chunk_page_table=chunk_page_table, - chunk_start_idx=chunk_start_idx, - kv_cache=kv_cache[i] if kv_cache is not None else None, - ) - - if mode == "prefill" and get_last_token == -1: - return x - - # Slicing the tensor to the nearest ceiling/floor multiples of 32 for the prefill_len, to get the last token - if get_last_token != -1: - x = ttnn.slice(x, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, x.shape[-1])) - - # Output norm - x = self.norm(x, mode=mode) - - if mode == "prefill" and self.model_config["LM_HEAD_INPUT_MEMCFG"].is_sharded(): - x = ttnn.interleaved_to_sharded(x, self.model_config["LM_HEAD_INPUT_MEMCFG"]) - - x = self.lm_head(x) - - if mode == "prefill": - x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG) - # x = ttnn.to_memory_config(x, 
memory_config=ttnn.DRAM_MEMORY_CONFIG) - return x diff --git a/models/tt_transformers/demo/simple_vision_demo.py b/models/tt_transformers/demo/simple_vision_demo.py index 7623dce74865..fb32456c314d 100644 --- a/models/tt_transformers/demo/simple_vision_demo.py +++ b/models/tt_transformers/demo/simple_vision_demo.py @@ -14,6 +14,7 @@ from pkg_resources import resource_filename from models.tt_transformers.tt.generator import create_submeshes +from models.tt_transformers.tt.model_config import ModelArgs IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/")) @@ -27,7 +28,9 @@ import ttnn from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf from models.perf.benchmarking_utils import BenchmarkProfiler +from models.tt_transformers.tt.common import hf_multimodal_encode from models.tt_transformers.tt.generator import Generator +from models.tt_transformers.tt.model_config import CheckpointType def get_batch_sampler(temperature, top_p, tokenizer): @@ -60,11 +63,11 @@ def create_multimodal_model( use_paged_kv_cache=False, checkpoint=None, ): - from models.tt_transformers.tt.model_config import ModelArgs + from models.tt_transformers.tt.multimodal.gemma3.gemma_e2e_model import TtGemmaModel from models.tt_transformers.tt.multimodal.llama_vision_model import CrossAttentionTransformer tt_model_args = ModelArgs(mesh_device, max_batch_size=max_batch_size) - assert tt_model_args.is_llama_vision(), "This model is multimodal" + assert tt_model_args.is_multimodal, "This model is multimodal" # limit length or we'll run out of space tt_model_args.max_seq_len = max_seq_len @@ -76,14 +79,25 @@ def create_multimodal_model( if checkpoint is None: checkpoint = tt_model_args.load_state_dict() - model = CrossAttentionTransformer( - mesh_device, - state_dict=checkpoint, - weight_cache_path=tt_model_args.weight_cache_path(dtype), - dtype=dtype, - configuration=tt_model_args, - use_paged_kv_cache=use_paged_kv_cache, - ) + + if tt_model_args.base_model_name.startswith("gemma-3"): + model = TtGemmaModel( + mesh_device=mesh_device, + state_dict=checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(ttnn.bfloat8_b), + dtype=ttnn.bfloat8_b, + args=tt_model_args, + use_paged_kv_cache=use_paged_kv_cache, + ) + else: + model = CrossAttentionTransformer( + mesh_device, + state_dict=checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=tt_model_args, + use_paged_kv_cache=use_paged_kv_cache, + ) return tt_model_args, model, checkpoint @@ -135,7 +149,7 @@ def prepare_generator_args( ) @pytest.mark.parametrize( "test_type,max_seq_len", - (("normal", 512),), + (("normal", 2048),), ids=["normal"], ) @pytest.mark.parametrize( @@ -157,7 +171,9 @@ def prepare_generator_args( ], ) @pytest.mark.parametrize( - "device_params", [{"fabric_config": True, "trace_region_size": 17000000, "num_command_queues": 2}], indirect=True + "device_params", + [{"fabric_config": True, "trace_region_size": 32617088, "num_command_queues": 2, "l1_small_size": 24576}], + indirect=True, ) def test_multimodal_demo_text( mesh_device, @@ -178,24 +194,27 @@ def test_multimodal_demo_text( Simple multimodal demo with limited dependence on reference code. 
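The `create_multimodal_model` change above routes Gemma-3 checkpoints to the new `TtGemmaModel` while every other multimodal checkpoint keeps using `CrossAttentionTransformer`. A condensed sketch of that selection; the class names, import paths and `base_model_name` check are taken from the diff, the helper itself and its signature are simplified for illustration:

```python
def select_multimodal_model_class(base_model_name: str):
    """Pick the TT model class for a multimodal checkpoint, as done in create_multimodal_model."""
    if base_model_name.startswith("gemma-3"):
        # Gemma-3 uses the end-to-end text+vision model with its own vision tower.
        from models.tt_transformers.tt.multimodal.gemma3.gemma_e2e_model import TtGemmaModel
        return TtGemmaModel
    # Llama-style vision checkpoints keep the cross-attention architecture.
    from models.tt_transformers.tt.multimodal.llama_vision_model import CrossAttentionTransformer
    return CrossAttentionTransformer
```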
""" num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1 + tt_model_args = ModelArgs(mesh_device, max_batch_size=max_batch_size) - if num_devices == 2: - if max_batch_size == 1: - pytest.skip( - "Batch size=1 on N300 mesh experiences ND hangs: https://github.com/tenstorrent/tt-metal/issues/28247" - ) - if max_batch_size not in (4, 16): - pytest.skip(f"Batch size={max_batch_size} is not tested for N300 mesh") + # llama model only support on T3K right now and will skip if ran on N300 and N150 + if tt_model_args.is_llama_vision() is True: + if num_devices == 2: + if max_batch_size == 1: + pytest.skip( + "Batch size=1 on N300 mesh experiences ND hangs: https://github.com/tenstorrent/tt-metal/issues/28247" + ) + if max_batch_size not in (4, 16): + pytest.skip(f"Batch size={max_batch_size} is not tested for N300 mesh") if num_devices == 8 and max_batch_size not in (1, 4, 32): pytest.skip(f"Batch size={max_batch_size} is not tested for T3K mesh") - logger.info("Start profiler") profiler = BenchmarkProfiler() profiler.start("run") + assert not ( + max_batch_size == 32 and os.environ.get("MESH_DEVICE") == "N150" + ), "Run models with batch size of 1 when MESH_DEVICE is N150" - ckpt_dir = os.environ["LLAMA_DIR"] - tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") - + num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1 max_batch_size *= data_parallel # input batch_size is interpreted as size per DP group model_args, model = prepare_generator_args( @@ -204,11 +223,26 @@ def test_multimodal_demo_text( max_batch_size=max_batch_size, max_seq_len=max_seq_len, ) + + HF_MODEL = model_args[0].checkpoint_type == CheckpointType.HuggingFace + + if not HF_MODEL: + ckpt_dir = os.environ["LLAMA_DIR"] + tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") + + tokenizer = Tokenizer(model_path=tokenizer_path) + formatter = ChatFormat(tokenizer) + else: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_args[0].CKPT_DIR) + generator = Generator(model, model_args, mesh_device) - tokenizer = Tokenizer(model_path=tokenizer_path) - formatter = ChatFormat(tokenizer) - xattn_caches = [model.setup_cache(model_args[i].max_batch_size) for i, model in enumerate(generator.model)] + xattn_caches = [ + model.setup_cache(model_args[i].max_batch_size) if not HF_MODEL else None + for i, model in enumerate(generator.model) + ] # Create random images for trace capture with specific dimensions trace_img_560x560 = create_random_image(560, 560) @@ -266,10 +300,11 @@ def test_multimodal_demo_text( total_users = len(dialogs) num_batches = total_users // max_batch_size - sampler = get_batch_sampler(temperature, top_p, tokenizer) + sampler = get_batch_sampler(temperature, top_p, model_args[0].tokenizer) _num_prefill_tokens = 0 _num_decode_tokens = 0 + prompt_encoder = hf_multimodal_encode if HF_MODEL else formatter.encode_dialog_prompt for iter_num in range(warmup_iters + 1): logger.info(f"Iteration {iter_num}") current_dialogs = trace_dialogs + dialogs @@ -279,9 +314,14 @@ def test_multimodal_demo_text( for msg in dialog: logger.info(f"{msg.role.capitalize()}: {msg.content}\n") batch_model_input = [ - formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) for dialog in batch_dialogs + prompt_encoder(dialog, processor) if HF_MODEL else prompt_encoder(dialog, tool_prompt_format=False) + for dialog in batch_dialogs ] + if HF_MODEL: + # Use the processor's tokenizer instead of model_args tokenizer to 
ensure consistency + tokenizer = processor.tokenizer + # Do initial prefill vision_images = [ model_input.vision.images if model_input.vision else None for model_input in batch_model_input @@ -294,7 +334,8 @@ def test_multimodal_demo_text( total_lens = prefill_lens + max_gen_len # Create padded tokens tensor for batch - pad_id = tokenizer.pad_id + stop_tokens = model_args[0].tokenizer.stop_tokens + pad_id = tokenizer.pad_token_id if HF_MODEL else tokenizer.pad_id bsz = len(prompt_tokens) tokens = torch.full((bsz, max(total_lens)), pad_id, dtype=torch.long) @@ -374,6 +415,12 @@ def test_multimodal_demo_text( profiler.end("compile_decode", iteration=batch_idx) # Disable checking for eot until I have more robust code for batch > 1 + # if HF_MODEL: + # if next_tokens in stop_tokens: + # break + # else: + # # Disable checking for eot until I have more robust code for batch > 1 + # pass # if text in ["<|eot_id|>", "<|eom_id|>"]: # break _num_decode_tokens += ( @@ -381,12 +428,16 @@ def test_multimodal_demo_text( ) # gen_idx is (num_tokens - 1) to avoid counting compile iter # Log full text output for each user in batch - vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256] + if HF_MODEL: + # For HF models, get vision tokens from the processor if they exist + vision_tokens = [] + else: + vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256] for user_id in range(max_batch_size): # Remove <|image|> tokens since they break the tokenizer tokens_out = [ - t if t not in vision_tokens else tokenizer.pad_id + t if t not in vision_tokens else pad_id for t in tokens[user_id].tolist()[: position_id[user_id] + 2] ] text = tokenizer.decode(tokens_out) diff --git a/models/experimental/gemma3/tests/vision_tests/test_mmp.py b/models/tt_transformers/tests/multimodal/vision_tests/test_mmp.py similarity index 95% rename from models/experimental/gemma3/tests/vision_tests/test_mmp.py rename to models/tt_transformers/tests/multimodal/vision_tests/test_mmp.py index ed947aa9d899..4b537c22f730 100644 --- a/models/experimental/gemma3/tests/vision_tests/test_mmp.py +++ b/models/tt_transformers/tests/multimodal/vision_tests/test_mmp.py @@ -1,7 +1,7 @@ """Gemma3 Test for multi-modal-projector""" -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 @@ -13,8 +13,7 @@ import ttnn from models.tt_transformers.tt.model_config import ModelArgs -from models.experimental.gemma3.tt.mmp import TtGemma3MultiModalProjector - +from models.tt_transformers.tt.multimodal.gemma3.multi_modal_projector import TtGemma3MultiModalProjector from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull diff --git a/models/experimental/gemma3/tests/vision_tests/test_patch_embedding.py b/models/tt_transformers/tests/multimodal/vision_tests/test_patch_embedding.py similarity index 96% rename from models/experimental/gemma3/tests/vision_tests/test_patch_embedding.py rename to models/tt_transformers/tests/multimodal/vision_tests/test_patch_embedding.py index 72cf892842e2..1d3c2d7aa9ae 100644 --- a/models/experimental/gemma3/tests/vision_tests/test_patch_embedding.py +++ b/models/tt_transformers/tests/multimodal/vision_tests/test_patch_embedding.py @@ -1,20 +1,19 @@ """Gemma3 test for Vision Patch Embedding""" -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
# SPDX-License-Identifier: Apache-2.0 import os import pytest +import torch from loguru import logger import ttnn -import torch from models.tt_transformers.tt.model_config import ModelArgs - -from models.experimental.gemma3.tt.gemma_conv2d_patch import TtGemmaConv2dPatch +from models.tt_transformers.tt.multimodal.gemma3.gemma_conv2d_patch import TtGemmaConv2dPatch from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull from ttnn import ConcatMeshToTensor diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_attention.py b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_attention.py similarity index 95% rename from models/experimental/gemma3/tests/vision_tests/test_vision_attention.py rename to models/tt_transformers/tests/multimodal/vision_tests/test_vision_attention.py index 42daa76f3bd8..f4b313c00076 100644 --- a/models/experimental/gemma3/tests/vision_tests/test_vision_attention.py +++ b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_attention.py @@ -1,7 +1,7 @@ """Gemma3 Test for Vision Attention""" -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 import os @@ -11,14 +11,13 @@ from loguru import logger import ttnn +from models.tt_transformers.tt.ccl import TT_CCL from models.tt_transformers.tt.load_checkpoints import ( # convert_vision_hf_to_meta, convert_hf_qkv_to_meta_format, convert_vision_hf_to_meta, ) from models.tt_transformers.tt.model_config import ModelArgs - -from models.tt_transformers.tt.ccl import TT_CCL -from models.experimental.gemma3.tt.gemma_image_attention import TtGemmaImageAttention +from models.tt_transformers.tt.multimodal.gemma3.gemma_image_attention import TtGemmaImageAttention from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_cross_attention_transformer.py b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_cross_attention_transformer.py similarity index 95% rename from models/experimental/gemma3/tests/vision_tests/test_vision_cross_attention_transformer.py rename to models/tt_transformers/tests/multimodal/vision_tests/test_vision_cross_attention_transformer.py index 751047cd2487..dc9444602a7d 100644 --- a/models/experimental/gemma3/tests/vision_tests/test_vision_cross_attention_transformer.py +++ b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_cross_attention_transformer.py @@ -1,7 +1,7 @@ """Gemma3 Test for Vision Transformer""" -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
# SPDX-License-Identifier: Apache-2.0 @@ -13,8 +13,7 @@ import ttnn from models.tt_transformers.tt.model_config import ModelArgs - -from models.experimental.gemma3.tt.gemma_vision_crossattention import TtGemmaTransformerVision +from models.tt_transformers.tt.multimodal.gemma3.gemma_vision_model import TtGemmaTransformerVision from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_embedding.py b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_embedding.py similarity index 92% rename from models/experimental/gemma3/tests/vision_tests/test_vision_embedding.py rename to models/tt_transformers/tests/multimodal/vision_tests/test_vision_embedding.py index 3e6fb98642c7..ad1d36a97a74 100644 --- a/models/experimental/gemma3/tests/vision_tests/test_vision_embedding.py +++ b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_embedding.py @@ -1,7 +1,7 @@ """Gemma3 Test for Vision Embedding""" -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 @@ -13,8 +13,7 @@ import ttnn from models.tt_transformers.tt.model_config import ModelArgs - -from models.experimental.gemma3.tt.siglip_vision_embedding import TtSiglipVisionEmbeddings +from models.tt_transformers.tt.multimodal.gemma3.gemma_vision_embedding import TtGemmaVisionEmbeddings from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull from ttnn import ConcatMeshToTensor @@ -57,7 +56,7 @@ def test_vision_embedding_integration( # reference_model.load_state_dict(partial_state_dict) reference_output = reference_model(input_tensor) - vision_embed = TtSiglipVisionEmbeddings( + vision_embed = TtGemmaVisionEmbeddings( mesh_device=mesh_device, state_dict=state_dict, state_dict_prefix=first_layer_prefix, diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_layernorm.py b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_layernorm.py similarity index 98% rename from models/experimental/gemma3/tests/vision_tests/test_vision_layernorm.py rename to models/tt_transformers/tests/multimodal/vision_tests/test_vision_layernorm.py index c2b2d66891a3..24dc75b44a6e 100644 --- a/models/experimental/gemma3/tests/vision_tests/test_vision_layernorm.py +++ b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_layernorm.py @@ -1,7 +1,7 @@ """Gemma3 Test for Vision Layernorm""" -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_mlp.py b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_mlp.py similarity index 95% rename from models/experimental/gemma3/tests/vision_tests/test_vision_mlp.py rename to models/tt_transformers/tests/multimodal/vision_tests/test_vision_mlp.py index 880e436b747f..be6fe8a648d2 100644 --- a/models/experimental/gemma3/tests/vision_tests/test_vision_mlp.py +++ b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_mlp.py @@ -1,7 +1,7 @@ """Gemma3 Test for Vision MLP""" -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
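The relocated vision tests above and below all follow the same pattern: run the torch reference module and the TT module on the same input, then compare the outputs with `comp_pcc`/`comp_allclose` against a threshold (typically 0.99). A self-contained sketch of such a check, assuming a plain Pearson correlation over flattened tensors rather than the exact `comp_pcc` implementation:

```python
import torch

def pcc(golden: torch.Tensor, actual: torch.Tensor) -> float:
    """Pearson correlation coefficient between two tensors, flattened to 1D."""
    g = golden.flatten().to(torch.float32)
    a = actual.flatten().to(torch.float32)
    g = g - g.mean()
    a = a - a.mean()
    return float((g @ a) / (g.norm() * a.norm() + 1e-12))

reference_output = torch.randn(1, 1, 64, 128)
tt_output = reference_output + 0.01 * torch.randn_like(reference_output)  # stand-in for the TT result
assert pcc(reference_output, tt_output) > 0.99, "PCC below required threshold"
```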
# SPDX-License-Identifier: Apache-2.0 @@ -12,10 +12,10 @@ from loguru import logger import ttnn +from models.tt_transformers.tt.ccl import TT_CCL from models.tt_transformers.tt.model_config import ModelArgs -from models.experimental.gemma3.tt.gemma_image_mlp import TtGemmaImageFeedForward +from models.tt_transformers.tt.multimodal.gemma3.gemma_image_mlp import TtGemmaImageFeedForward from models.utility_functions import comp_allclose, comp_pcc, nearest_32, skip_for_grayskull -from models.tt_transformers.tt.ccl import TT_CCL @skip_for_grayskull("Requires wormhole_b0 to run") diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_pipeline.py b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_pipeline.py similarity index 93% rename from models/experimental/gemma3/tests/vision_tests/test_vision_pipeline.py rename to models/tt_transformers/tests/multimodal/vision_tests/test_vision_pipeline.py index 48e062c4bb7a..c74b101e593c 100644 --- a/models/experimental/gemma3/tests/vision_tests/test_vision_pipeline.py +++ b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_pipeline.py @@ -1,7 +1,7 @@ """Gemma3 Test for Vision Model""" -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 @@ -13,8 +13,7 @@ import ttnn from models.tt_transformers.tt.model_config import ModelArgs - -from models.experimental.gemma3.tt.gemma_vision_model import TtSiglipGemmaVisionModel +from models.tt_transformers.tt.multimodal.gemma3.gemma_vision_block import TtGemmaVisionModel from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull @@ -66,7 +65,7 @@ def test_gemma_vision( # reference_model.load_state_dict(partial_state_dict) reference_output = reference_model(input_tensor).last_hidden_state - test_gemma_vision = TtSiglipGemmaVisionModel( + test_gemma_vision = TtGemmaVisionModel( mesh_device, state_dict, state_dict_prefix=first_layer_prefix, diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_rmsnorm.py b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_rmsnorm.py similarity index 95% rename from models/experimental/gemma3/tests/vision_tests/test_vision_rmsnorm.py rename to models/tt_transformers/tests/multimodal/vision_tests/test_vision_rmsnorm.py index a920f04980ad..c42a40d5cd86 100644 --- a/models/experimental/gemma3/tests/vision_tests/test_vision_rmsnorm.py +++ b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_rmsnorm.py @@ -1,23 +1,16 @@ """Gemma3 test for Vision RMSNorm""" -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC - -# SPDX-License-Identifier: Apache-2.0 - -from loguru import logger +import os -import torch import pytest -import os +import torch +from loguru import logger import ttnn -from models.experimental.gemma3.tt.gemma_vision_rmsnorm import RMSNorm - from models.tt_transformers.tt.ccl import TT_CCL - - -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.gemma3.gemma_vision_rmsnorm import RMSNorm +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull @torch.no_grad() diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_transformer.py b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_transformer.py similarity index 95% rename from models/experimental/gemma3/tests/vision_tests/test_vision_transformer.py 
rename to models/tt_transformers/tests/multimodal/vision_tests/test_vision_transformer.py index 2f7ef9521c93..cf41727c0395 100644 --- a/models/experimental/gemma3/tests/vision_tests/test_vision_transformer.py +++ b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_transformer.py @@ -1,7 +1,7 @@ """Gemma3 test for Vision Transformer submodule""" -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 import os @@ -12,10 +12,9 @@ import ttnn from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.gemma3.gemma_image_transformer import TtGemmaImageTransformer from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull -from models.experimental.gemma3.tt.gemma_image_transformer import TtGemmaImageTransformer - @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( diff --git a/models/experimental/gemma3/tests/vision_tests/test_vision_transformer_block.py b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_transformer_block.py similarity index 95% rename from models/experimental/gemma3/tests/vision_tests/test_vision_transformer_block.py rename to models/tt_transformers/tests/multimodal/vision_tests/test_vision_transformer_block.py index a9938f99d2e2..76fb6e9be6d3 100644 --- a/models/experimental/gemma3/tests/vision_tests/test_vision_transformer_block.py +++ b/models/tt_transformers/tests/multimodal/vision_tests/test_vision_transformer_block.py @@ -1,6 +1,6 @@ """Gemma3 Test for Vision Transformer block""" -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 @@ -12,7 +12,7 @@ import ttnn from models.tt_transformers.tt.model_config import ModelArgs -from models.experimental.gemma3.tt.gemma_image_block import TtGemmaImageTransformerBlock +from models.tt_transformers.tt.multimodal.gemma3.gemma_image_block import TtGemmaImageTransformerBlock from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull diff --git a/models/tt_transformers/tests/test_attention.py b/models/tt_transformers/tests/test_attention.py index 364475a6e8c9..8745b72d22d2 100644 --- a/models/tt_transformers/tests/test_attention.py +++ b/models/tt_transformers/tests/test_attention.py @@ -143,6 +143,7 @@ def test_attention_inference( model_args.rope_theta, model_args.rope_scaling.factor if model_args.rope_scaling else None, model_args.rope_scaling.original_max_position_embeddings if model_args.rope_scaling else None, + rope_type=model_args.rope_scaling.rope_type.value, ) freqs_cis = torch.complex(cos, sin) @@ -164,7 +165,8 @@ def test_attention_inference( pt_attention_input = torch.randn( batch_size, seq_len, model_args.dim, dtype=get_ref_model_dype(reference_model, model_args.model_name) ) # Qwen2.5 0.5B sees 0.1 to 2.1 - + if "gemma" in os.environ.get("HF_MODEL"): + pt_attention_input = pt_attention_input.to(torch.bfloat16) tt_attention_input = pt_attention_input.clone() attention_input = model_args.prepare_residual_tensor_decode( diff --git a/models/tt_transformers/tests/test_attention_prefill.py b/models/tt_transformers/tests/test_attention_prefill.py index bfa054ad4100..8f61be47e744 100644 --- a/models/tt_transformers/tests/test_attention_prefill.py +++ b/models/tt_transformers/tests/test_attention_prefill.py @@ -145,6 +145,8 @@ def test_attention_inference( ) * 2 ) - 1 + if "gemma" in os.environ.get("HF_MODEL"): + 
pt_attention_input = pt_attention_input.to(torch.bfloat16) tt_attention_input = pt_attention_input.clone() attention_input = model_args.prepare_residual_tensor_prefill( tt_attention_input, diff --git a/models/tt_transformers/tests/test_decoder.py b/models/tt_transformers/tests/test_decoder.py index 8c97e45ebd9f..4507385ef050 100644 --- a/models/tt_transformers/tests/test_decoder.py +++ b/models/tt_transformers/tests/test_decoder.py @@ -9,11 +9,10 @@ import ttnn from models.common.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull -from models.tt_transformers.tests.test_utils import get_ref_model_dype from models.tt_transformers.tt.ccl import TT_CCL from models.tt_transformers.tt.common import PagedAttentionConfig, precompute_freqs from models.tt_transformers.tt.decoder import TransformerBlock -from models.tt_transformers.tt.model_config import CheckpointType, ModelArgs +from models.tt_transformers.tt.model_config import ModelArgs from models.tt_transformers.tt.rope import RotarySetup @@ -148,17 +147,15 @@ def test_decoder_inference( seqlen = 1 - if model_args.checkpoint_type == CheckpointType.Meta: - cos, sin = precompute_freqs( - model_args.head_dim, - model_args.max_seq_len * 2, - model_args.rope_theta, - model_args.rope_scaling.factor if model_args.rope_scaling else None, - model_args.rope_scaling.original_max_position_embeddings if model_args.rope_scaling else None, - ) - freqs_cis = torch.complex(cos, sin) - else: - freqs_cis = None + cos, sin = precompute_freqs( + model_args.head_dim, + model_args.max_seq_len * 2, + model_args.rope_theta, + model_args.rope_scaling.factor if model_args.rope_scaling else None, + model_args.rope_scaling.original_max_position_embeddings if model_args.rope_scaling else None, + rope_type="linear", + ) + freqs_cis = torch.complex(cos, sin) # Initial positions current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)]) @@ -178,10 +175,14 @@ def test_decoder_inference( # input = torch.randn(1, 32, 4096) pt_decode_input = ( torch.rand( - batch_size, seqlen, model_args.dim, dtype=get_ref_model_dype(reference_model, model_args.model_name) + batch_size, + seqlen, + model_args.dim, ) * 2 ) - 1 + if "gemma" in os.environ.get("HF_MODEL"): + pt_decode_input = pt_decode_input.to(torch.bfloat16) tt_decode_input = pt_decode_input.clone() decode_input = model_args.prepare_residual_tensor_decode( diff --git a/models/tt_transformers/tests/test_decoder_prefill.py b/models/tt_transformers/tests/test_decoder_prefill.py index f647b8377e54..3353aa2988d6 100644 --- a/models/tt_transformers/tests/test_decoder_prefill.py +++ b/models/tt_transformers/tests/test_decoder_prefill.py @@ -66,7 +66,7 @@ def test_decoder_inference( "Mistral-7B models do not support max_seq_len > 256. 
See issue: https://github.com/tenstorrent/tt-metal/issues/19806" ) - dtype = ttnn.bfloat8_b + dtype = ttnn.bfloat16 batch_size = 1 # For prefill we only support batch_size = 1 model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len, cache_hf=True) @@ -165,6 +165,8 @@ def test_decoder_inference( ) * 2 ) - 1 + if "gemma" in os.environ.get("HF_MODEL"): + pt_decode_input = pt_decode_input.to(torch.bfloat16) tt_decode_input = pt_decode_input.clone() decode_input = model_args.prepare_residual_tensor_prefill( tt_decode_input, diff --git a/models/tt_transformers/tests/test_lm_head.py b/models/tt_transformers/tests/test_lm_head.py index bb070668dc7b..c44595cd4cdf 100644 --- a/models/tt_transformers/tests/test_lm_head.py +++ b/models/tt_transformers/tests/test_lm_head.py @@ -65,6 +65,8 @@ def test_lm_head_inference(seq_len, batch_size, mesh_device, reset_seeds): ) torch_input = torch.randn(1, 1, seq_len, model_args.dim) + if "gemma" in os.environ.get("HF_MODEL"): + torch_input = torch_input.to(torch.bfloat16) reference_output = reference_model(torch_input) tt_input = ttnn.from_torch( torch_input, diff --git a/models/tt_transformers/tests/test_model.py b/models/tt_transformers/tests/test_model.py index b6738c3e576c..51432fa0f7c8 100644 --- a/models/tt_transformers/tests/test_model.py +++ b/models/tt_transformers/tests/test_model.py @@ -18,13 +18,22 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.timeout(1800) @pytest.mark.models_performance_bare_metal +# @pytest.mark.parametrize( +# "weights, layers", +# [ +# ("random", 1), +# ("instruct", None), +# ], +# ids=["quick", "full"], +# ) + + @pytest.mark.parametrize( "weights, layers", [ - ("random", 1), ("instruct", None), ], - ids=["quick", "full"], + ids=["full"], ) @pytest.mark.parametrize( "paged_attention", @@ -181,17 +190,19 @@ def test_model_inference( model_args.n_layers = layers state_dict = model_args.load_state_dict() state_dict_prefix = model_args.get_state_dict_prefix("", None) + # print(state_dict.keys()) reference_state_dict = { k[len(state_dict_prefix) :]: v for k, v in state_dict.items() if ( - any([f"{state_dict_prefix}layers.{i}." in k for i in range(model_args.n_layers)]) - or any( - [ + ( + any(f"{state_dict_prefix}layers.{i}." 
in k for i in range(model_args.n_layers)) + or any( f"{state_dict_prefix}{name}" in k for name in ["tok_embeddings.weight", "norm.weight", "output.weight"] - ] + ) ) + and not (k.startswith("visual.") or k.startswith("multi_modal_projector.")) ) } @@ -308,6 +319,7 @@ def test_model_inference( decode_input, current_pos_tensor, rot_mats_global=rot_mats, + rot_mats_local=rot_mats, mode="decode", page_table=page_table_tt, ) diff --git a/models/tt_transformers/tt/attention.py b/models/tt_transformers/tt/attention.py index 057f0b453be9..2cc6eb810fe2 100644 --- a/models/tt_transformers/tt/attention.py +++ b/models/tt_transformers/tt/attention.py @@ -30,6 +30,8 @@ def __init__( super().__init__() self.mesh_device = mesh_device + self.layer_idx = layer_num + self.configuration = configuration self.tt_ccl = tt_ccl self.num_devices = configuration.num_devices self.TG = self.num_devices == 32 @@ -516,6 +518,7 @@ def forward_decode( program_config=self.model_config["SDPA_DECODE_PROGCFG"], compute_kernel_config=self.sdpa_decode_compute_kernel_cfg, memory_config=ttnn.DRAM_MEMORY_CONFIG, + sliding_window=self.configuration.sliding_window if self.is_sliding else 0, ) else: attn_output_1G4D = ttnn.transformer.scaled_dot_product_attention_decode( @@ -527,6 +530,7 @@ def forward_decode( program_config=self.model_config["SDPA_DECODE_PROGCFG"], compute_kernel_config=self.sdpa_decode_compute_kernel_cfg, memory_config=ttnn.DRAM_MEMORY_CONFIG, # FIXME: why not L1 height sharded e.g. SCORES_BATCHED_MM_OUTPUT_MEMCFG? + sliding_window=self.configuration.sliding_window if self.is_sliding else 0, ) ttnn.deallocate(q_heads_1BQD) @@ -833,7 +837,6 @@ def forward_prefill( q_heads_1QSD_8b, k_heads_1KSD_8b, v_heads_1VSD_8b, - is_causal=True, scale=self.scale, compute_kernel_config=self.sdpa_prefill_compute_kernel_cfg, program_config=self.model_config["SDPA_PROGCFG"](seq_len), @@ -927,7 +930,13 @@ def forward( kv_cache=kv_cache, ) else: - return self.forward_decode(x, current_pos, rot_mats, page_table=page_table, kv_cache=kv_cache) + return self.forward_decode( + x, + current_pos, + rot_mats, + page_table=page_table, + kv_cache=kv_cache, + ) def prefill_prepare_tensor_for_kv_cache(self, key_or_value_layer, user_id): tensor_copy = ttnn.clone(key_or_value_layer) diff --git a/models/tt_transformers/tt/common.py b/models/tt_transformers/tt/common.py index 66922dc0dbb1..9dc64d2778be 100644 --- a/models/tt_transformers/tt/common.py +++ b/models/tt_transformers/tt/common.py @@ -3,11 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 import math -import os import re from enum import Enum from types import SimpleNamespace -from typing import Callable, Optional +from typing import Optional import torch from llama_models.llama3.api.datatypes import ImageMedia @@ -371,6 +370,7 @@ def get_prefill_rot_mat(head_dim, mesh_device, seq_len, theta, scale_factor, ori rot_mats = [cos_gathereds, sin_gathereds] return rot_mats + def compute_linear_parameters(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): """Linear scaling for rotary embeddings.""" freqs /= scale_factor @@ -713,11 +713,7 @@ def create_tt_model( state_dict=None, num_layers=None, ): - if "HF_MODEL" in os.environ and "gemma-3" in os.environ["HF_MODEL"].lower(): - from models.experimental.gemma3.tt.text_model import Gemma3Transformer as Transformer - else: - from models.tt_transformers.tt.model import Transformer - + from models.tt_transformers.tt.model import Transformer from models.tt_transformers.tt.model_config import ModelArgs tt_model_args = ModelArgs( @@ -785,7 
+781,7 @@ def hf_multimodal_encode(messages, processor): **encoded, tokens=encoded["input_ids"].squeeze(0), vision=SimpleNamespace( - images=encoded["pixel_values"], + images=encoded.get("pixel_values", None), mask=None, ), ) diff --git a/models/tt_transformers/tt/decoder.py b/models/tt_transformers/tt/decoder.py index 722c22c27493..9069c8aace4a 100644 --- a/models/tt_transformers/tt/decoder.py +++ b/models/tt_transformers/tt/decoder.py @@ -46,6 +46,7 @@ def __init__( self.model_config = args.get_model_config() self.is_mixture_of_experts = False self.layer_num = layer_num + self.num_devices = args.num_devices ActualAttentionClass = attention_class if attention_class is not None else DefaultAttention @@ -136,6 +137,7 @@ def __init__( tt_ccl=self.tt_ccl, TG=args.is_galaxy, ) + if f"layers.{layer_num}.pre_feedforward_layernorm.weight" in state_dict: self.pre_ff_norm = DistributedNorm( # pre_feedforward_layernorm RMSNorm( @@ -143,12 +145,12 @@ def __init__( dim=args.dim, eps=args.norm_eps, state_dict=state_dict, - add_unit_offset=self.args.rms_norm_add_unit_offset, state_dict_prefix=args.get_state_dict_prefix("", layer_num), weight_cache_path=None if args.dummy_weights else weight_cache_path, weight_dtype=ttnn.bfloat16, weight_key="pre_feedforward_layernorm", is_distributed=self.args.is_distributed_norm, + add_unit_offset=self.args.rms_norm_add_unit_offset, sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], ccl_topology=self.args.ccl_topology(), @@ -168,13 +170,13 @@ def __init__( device=mesh_device, dim=args.dim, eps=args.norm_eps, - add_unit_offset=self.args.rms_norm_add_unit_offset, state_dict=state_dict, state_dict_prefix=args.get_state_dict_prefix("", layer_num), weight_cache_path=None if args.dummy_weights else weight_cache_path, weight_dtype=ttnn.bfloat16, weight_key="post_feedforward_layernorm", is_distributed=self.args.is_distributed_norm, + add_unit_offset=self.args.rms_norm_add_unit_offset, sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], ccl_topology=self.args.ccl_topology(), @@ -203,6 +205,7 @@ def forward( ) -> ttnn.Tensor: TG = self.args.is_galaxy residual = x + # x is fractured across devices and interleaved in DRAM (for prefill) and sharded in L1 (for decode) skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG assert ( @@ -228,65 +231,71 @@ def forward( chunk_start_idx=chunk_start_idx, kv_cache=kv_cache, ) + if self.pre_ff_norm == None: + # Here x and attn_out are both fractured across devices + attn_out = ttnn.add(x, attn_out, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16 if TG else None) + + residual = attn_out + + # Norms take fractured inputs and output replicated across devices + hidden_states = self.ff_norm(attn_out, mode) - if self.pre_ff_norm is None: - hidden_states = ttnn.add( - residual, attn_out, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16 if TG else None - ) - residual = hidden_states - if mode == "prefill": - x.deallocate(True) - else: - hidden_states = attn_out - hidden_states = self.ff_norm(hidden_states, mode) if self.pre_ff_norm is not None: - # The output of the ff_norm is replicated across the device - # but the residual is fractured across the devices + # NOTE: The output of ff_norm is gathered, while the input to the add is fractured. 
+ # To align them, we use tt_all_reduce to fracture the hidden_states across devices. + # Since tt_all_reduce performs a sum across devices, we divide by num_devices + # to restore the original values. if self.num_devices > 1: hidden_states = tt_all_reduce( hidden_states, self.mesh_device, - tt_ccl=self.tt_ccl, + self.tt_ccl, cluster_axis=0, dim=3, num_reduce_scatter_links=self.args.num_reduce_scatter_links, num_all_gather_links=self.args.num_all_gather_links, - topology=ttnn.Topology.Ring, + topology=self.args.ccl_topology(), memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=self.args.ccl_dtype, ) hidden_states = ttnn.div(hidden_states, self.num_devices) - hidden_states = ttnn.add( - residual, hidden_states, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16 if TG else None - ) + + hidden_states = ttnn.add(residual, hidden_states, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16) + residual = hidden_states + hidden_states = self.pre_ff_norm(hidden_states, mode) - ttnn.deallocate(attn_out) + if mode == "prefill": + x.deallocate(True) + + # ttnn.deallocate(attn_out) if TG and mode == "decode": hidden_states = ttnn.to_memory_config(hidden_states, memory_config=self.model_config["MLP_ACT_MEMCFG"]) # MLP takes replicated inputs and produces fractured outputs - hidden_states = self.feed_forward.forward(hidden_states, mode) - + # takes and residual are both fractured across devices activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( decoder_id=self.layer_num, tensor=TensorGroup.ACTIVATION ) - if self.post_ff_norm is not None: - hidden_states = self.post_ff_norm(hidden_states, mode) # Gathered + hidden_states = self.post_ff_norm(hidden_states, mode) + # NOTE: The output of post_ff_norm is gathered, while the inputs to the add is fractured. + # To align them, we use tt_all_reduce to fracture the hidden_states across devices. + # Since tt_all_reduce performs a sum across devices, we divide by num_devices + # to restore the original values. 
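The comment added above is the key to the decoder change: the norm output is already replicated (gathered) on every device, so a sum-style all-reduce scales every value by the device count, and the subsequent divide restores the original values. A small torch simulation of that reasoning; plain tensors stand in for the per-device copies and `tt_all_reduce` itself is not modelled:

```python
import torch

num_devices = 4
# The norm output is replicated: every device holds the same full tensor.
per_device = [torch.full((2, 3), 5.0) for _ in range(num_devices)]

# A sum-based all-reduce leaves every device with the element-wise sum of all copies,
# i.e. the original values scaled by num_devices.
reduced = torch.stack(per_device).sum(dim=0)
assert torch.allclose(reduced, torch.full((2, 3), 5.0 * num_devices))

# Dividing by num_devices recovers the original replicated values,
# which is what the decoder does after tt_all_reduce.
restored = reduced / num_devices
assert torch.allclose(restored, per_device[0])
```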
if self.num_devices > 1: hidden_states = tt_all_reduce( hidden_states, self.mesh_device, - tt_ccl=self.tt_ccl, + self.tt_ccl, cluster_axis=0, dim=3, num_reduce_scatter_links=self.args.num_reduce_scatter_links, num_all_gather_links=self.args.num_all_gather_links, - topology=ttnn.Topology.Ring, + topology=self.args.ccl_topology(), memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=self.args.ccl_dtype, ) diff --git a/models/tt_transformers/tt/generator.py b/models/tt_transformers/tt/generator.py index 534e7a9260cf..03fc6a6453ad 100644 --- a/models/tt_transformers/tt/generator.py +++ b/models/tt_transformers/tt/generator.py @@ -61,13 +61,7 @@ def __init__(self, model, model_args, mesh_device, processor=None, tokenizer=Non # Note: This function is called by vLLM def prefill_forward_text( - self, - tokens: torch.Tensor, - page_table=None, - kv_cache=None, - prompt_lens=None, - empty_slots=None, - **kwargs, + self, tokens: torch.Tensor, page_table=None, kv_cache=None, prompt_lens=None, empty_slots=None, **kwargs ): if page_table is not None: assert isinstance(page_table, torch.Tensor), "page_table mush be torch.Tensor" @@ -411,6 +405,13 @@ def _easy_trace_text( ): self.model[i].update_attention_masks(current_pos[i]) + for i in range(self.data_parallel): + if ( + hasattr(self.model[i], "device_decode_sliding_mask") + and self.model[i].device_decode_sliding_mask is not None + ): + self.model[i].update_attention_masks(current_pos[i]) + for i, trace_id in self.trace_ids_text.items(): ttnn.execute_trace(self.model_args[i].mesh_device, trace_id, cq_id=0, blocking=False) diff --git a/models/tt_transformers/tt/mlp.py b/models/tt_transformers/tt/mlp.py index 22fb0ac7a996..16c10ffb1e96 100644 --- a/models/tt_transformers/tt/mlp.py +++ b/models/tt_transformers/tt/mlp.py @@ -29,6 +29,7 @@ def __init__( self.mesh_device = mesh_device self.tt_ccl = tt_ccl self.args = args + self.num_devices = args.num_devices self.dim = args.dim self.model_config = model_config self.layer_num = layer_num @@ -252,6 +253,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: core_grid=None, # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_2 else None, ) ttnn.deallocate(w2_in) + w2_out = ttnn.multiply(w2_out, self.num_devices) # if mode == "decode" and not TG: # w2_out = ttnn.sharded_to_interleaved(w2_out, ttnn.DRAM_MEMORY_CONFIG) w2_out_reduced = tt_all_reduce( @@ -272,6 +274,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: use_composite=True if self.dim == 8192 else False, topology=self.args.ccl_topology(), ) + w2_out_reduced = ttnn.div(w2_out_reduced, self.num_devices) # Ensure dim 0 and 1 are 1 original_shape = w2_out_reduced.shape diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index 1eea6bab32cb..71f1fd2cd5b4 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -1190,6 +1190,11 @@ def __init__( use_height_and_width_as_shard_shape=True, ) + self.model_config["LM_HEAD_OUTPUT_MEMCFG"] = ( + ttnn.DRAM_MEMORY_CONFIG if self.base_model_name.startswith("gemma-3") else ttnn.L1_MEMORY_CONFIG + ) + self.lm_head_dtype = ttnn.bfloat16 if self.base_model_name.startswith("gemma-3") else None + # Vision model configs self.model_config["IMAGE_MLP_FC_PROGCFG"] = lambda seq_len, max_seq: self.matmul_config( m=min(seq_len, max_seq), @@ -1492,6 +1497,7 @@ def _get_hidden_activation_type(self, config): def _set_model_specific_params(self): # Gemma3 specific params + self.attention_mask = False is_gemma3 = "gemma-3" in 
self.base_model_name.lower() if is_gemma3: self.rms_norm_add_unit_offset = True @@ -1792,7 +1798,10 @@ def is_llama_vision(self): return ("llama" in self.CKPT_DIR.lower()) and ("vision" in self.CKPT_DIR.lower()) def get_state_dict_prefix(self, module_name, layer_num, is_vision=False): - text_prefix = self.state_dict_text_prefix + if "gemma-3" in self.model_name: + text_prefix = "" + else: + text_prefix = self.state_dict_text_prefix vision_prefix = self.state_dict_vision_prefix layer_prefix = f"layers.{layer_num}." if layer_num is not None else "" @@ -2401,7 +2410,7 @@ def create_tokenizer(self): # Add meta-compatible stop token list to the HF tokenizer if not "stop_tokens" in tokenizer.__dict__: - tokenizer.stop_tokens = [tokenizer.eos_token_id] + tokenizer.stop_tokens = self.eos_token_id if self.eos_token_id is not None else [tokenizer.eos_token_id] # Phi-3-mini uses "<|end|>" as EOS token if "phi-3-mini" in self.base_model_name.lower(): tokenizer.stop_tokens.append(tokenizer.encode("<|end|>")[0]) @@ -2571,7 +2580,7 @@ def reference_embedding(self, reference_model=None): model = self.reference_transformer(wrap=False) layer = model.model.embed_tokens else: - layer = reference_model.model.embed_tokens + layer = reference_model.model.model.embed_tokens layer._load_state_dict = layer.load_state_dict layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) diff --git a/models/experimental/gemma3/tt/gemma_conv2d_patch.py b/models/tt_transformers/tt/multimodal/gemma3/gemma_conv2d_patch.py similarity index 100% rename from models/experimental/gemma3/tt/gemma_conv2d_patch.py rename to models/tt_transformers/tt/multimodal/gemma3/gemma_conv2d_patch.py diff --git a/models/tt_transformers/tt/multimodal/gemma3/gemma_e2e_model.py b/models/tt_transformers/tt/multimodal/gemma3/gemma_e2e_model.py new file mode 100644 index 000000000000..41047dd3f92f --- /dev/null +++ b/models/tt_transformers/tt/multimodal/gemma3/gemma_e2e_model.py @@ -0,0 +1,127 @@ +import ttnn + +# from models.tt_transformers.tt.common import create_causal_mask, create_sliding_window_causal_mask +from models.tt_transformers.tt.model import Transformer +from models.tt_transformers.tt.multimodal.gemma3.gemma_vision_model import TtGemmaTransformerVision + + +class TtGemmaModel(Transformer): + def __init__( + self, + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=None, + use_paged_kv_cache=False, + ): + super().__init__( + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=paged_attention_config, + use_paged_kv_cache=use_paged_kv_cache, + ) + + self.vision_model = TtGemmaTransformerVision( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=args.state_dict_vision_prefix, + dtype=dtype, + configuration=args, + weight_cache_path=weight_cache_path, + ) + + def prepare_inputs_prefill(self, pt_tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): + """ + Inputs are torch tensors or python types. This function returns ttnn + tensors on device. 
+ TODO: Debate whether this function is responsible for padding + """ + + S = pt_tokens.shape[-1] + tokens = ttnn.from_torch( + pt_tokens.reshape(1, 1, 1, -1), + device=self.mesh_device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + tokens_embd = self.embd(tokens) + vision_output = self.compute_vision_token(**kwargs) + + if vision_output is not None: + tokens_embd = ttnn.to_torch(tokens_embd, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1)) + + comp_vision_output = ttnn.to_torch( + vision_output, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0) + )[: vision_output.shape[0], :] + + image_features = comp_vision_output.squeeze(0) + special_image_mask = (pt_tokens == self.args.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(tokens_embd) + image_features = image_features.to(tokens_embd.device, tokens_embd.dtype) + tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features) + + tokens_embd = self.args.prepare_residual_tensor_prefill( + tokens_embd, + ) + + tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) + + # Slice the rot mats to the prefill seqlen + assert ( + self.rope_setup.cos_matrix.shape[2] >= start_pos + S + ), f"Padded prefill end idx {start_pos + S} exceeds max seq len {self.rope_setup.cos_matrix.shape[2]}" + + tt_rot_mats_prefill_global = [ + self.rope_setup.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + + if hasattr(self, "rope_local_setup"): + tt_rot_mats_prefill_local = [ + self.rope_local_setup.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_local_setup.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + else: + tt_rot_mats_prefill_local = None + + if page_table is not None: + tt_page_table = ttnn.from_torch( + page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_page_table = None + + if chunk_page_table is not None: + tt_chunk_page_table = ttnn.from_torch( + chunk_page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_chunk_page_table = None + + return ( + tokens_embd, + tt_rot_mats_prefill_global, + tt_rot_mats_prefill_local, + tt_page_table, + tt_chunk_page_table, + ) + + def compute_vision_token(self, pixel_values=None): + if pixel_values is None: + return None + vision_output = self.vision_model(pixel_values) + return vision_output diff --git a/models/experimental/gemma3/tt/gemma_image_attention.py b/models/tt_transformers/tt/multimodal/gemma3/gemma_image_attention.py similarity index 98% rename from models/experimental/gemma3/tt/gemma_image_attention.py rename to models/tt_transformers/tt/multimodal/gemma3/gemma_image_attention.py index 40c05e439552..3bbfeef4d7ba 100644 --- a/models/experimental/gemma3/tt/gemma_image_attention.py +++ b/models/tt_transformers/tt/multimodal/gemma3/gemma_image_attention.py @@ -41,6 +41,8 @@ def __init__( self.head_dim = self.hidden_size // self.n_heads self.n_kv_heads = self.n_heads + self.configuration = configuration + self.n_local_heads = self.n_heads // configuration.num_devices self.n_local_kv_heads = self.n_kv_heads // configuration.num_devices @@ -367,9 +369,6 @@ def forward(self, x_11SH, mask=None): if seq_len > MAX_MM_SEQ_LEN: attn_output_11SH = 
ttnn.reshape(attn_output_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) - # if self.num_devices > 1: - # # self.bo = ttnn.all_gather(self.bo, dim=3, num_links=1) - # attn_output_11SH = ttnn.all_gather(attn_output_11SH, dim=3, num_links=1) if self.num_devices > 1: # replace with reduce_scatter and all_gather attn_output_11SH = ttnn.experimental.all_gather_async( attn_output_11SH, @@ -377,7 +376,7 @@ def forward(self, x_11SH, mask=None): dim=3, multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(), num_links=1, - topology=ttnn.Topology.Linear, + topology=self.configuration.ccl_topology(), barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(), chunks_per_sync=10, num_workers_per_link=2, diff --git a/models/experimental/gemma3/tt/gemma_image_block.py b/models/tt_transformers/tt/multimodal/gemma3/gemma_image_block.py similarity index 95% rename from models/experimental/gemma3/tt/gemma_image_block.py rename to models/tt_transformers/tt/multimodal/gemma3/gemma_image_block.py index e0eb0b88017c..59b4cfa7b78b 100644 --- a/models/experimental/gemma3/tt/gemma_image_block.py +++ b/models/tt_transformers/tt/multimodal/gemma3/gemma_image_block.py @@ -12,11 +12,10 @@ import ttnn from models.common.lightweightmodule import LightweightModule - -from models.experimental.gemma3.tt.gemma_image_attention import TtGemmaImageAttention -from models.experimental.gemma3.tt.gemma_image_mlp import TtGemmaImageFeedForward -from models.tt_transformers.tt.multimodal.llama_layernorm import TtLayerNorm from models.tt_transformers.tt.ccl import TT_CCL +from models.tt_transformers.tt.multimodal.gemma3.gemma_image_attention import TtGemmaImageAttention +from models.tt_transformers.tt.multimodal.gemma3.gemma_image_mlp import TtGemmaImageFeedForward +from models.tt_transformers.tt.multimodal.llama_layernorm import TtLayerNorm class TtGemmaImageTransformerBlock(LightweightModule): diff --git a/models/experimental/gemma3/tt/gemma_image_mlp.py b/models/tt_transformers/tt/multimodal/gemma3/gemma_image_mlp.py similarity index 95% rename from models/experimental/gemma3/tt/gemma_image_mlp.py rename to models/tt_transformers/tt/multimodal/gemma3/gemma_image_mlp.py index ed256073c442..ab8f88a7ba6b 100644 --- a/models/experimental/gemma3/tt/gemma_image_mlp.py +++ b/models/tt_transformers/tt/multimodal/gemma3/gemma_image_mlp.py @@ -132,4 +132,12 @@ def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: pre_bias_output = c_proj_out output = ttnn.add(pre_bias_output, self.c_proj_bias) + + ttnn.deallocate(c_fc_out) + ttnn.deallocate(c_proj_out) + ttnn.deallocate(pre_bias_output) + # Deallocate input tensor to free memory + ttnn.deallocate(x_in) + # Reshape output back to original shape + return output diff --git a/models/experimental/gemma3/tt/gemma_image_transformer.py b/models/tt_transformers/tt/multimodal/gemma3/gemma_image_transformer.py similarity index 95% rename from models/experimental/gemma3/tt/gemma_image_transformer.py rename to models/tt_transformers/tt/multimodal/gemma3/gemma_image_transformer.py index 4e0d4101ee96..1f10b2c1d8d7 100644 --- a/models/experimental/gemma3/tt/gemma_image_transformer.py +++ b/models/tt_transformers/tt/multimodal/gemma3/gemma_image_transformer.py @@ -12,7 +12,7 @@ from tqdm import tqdm from models.common.lightweightmodule import LightweightModule -from models.experimental.gemma3.tt.gemma_image_block import TtGemmaImageTransformerBlock +from models.tt_transformers.tt.multimodal.gemma3.gemma_image_block import TtGemmaImageTransformerBlock class 
TtGemmaImageTransformer(LightweightModule): diff --git a/models/experimental/gemma3/tt/gemma_vision_model.py b/models/tt_transformers/tt/multimodal/gemma3/gemma_vision_block.py similarity index 91% rename from models/experimental/gemma3/tt/gemma_vision_model.py rename to models/tt_transformers/tt/multimodal/gemma3/gemma_vision_block.py index 4524426e9ae5..6bad9cee5ddd 100644 --- a/models/experimental/gemma3/tt/gemma_vision_model.py +++ b/models/tt_transformers/tt/multimodal/gemma3/gemma_vision_block.py @@ -7,14 +7,15 @@ # SPDX-License-Identifier: Apache-2.0 import torch + import ttnn from models.common.lightweightmodule import LightweightModule -from models.experimental.gemma3.tt.siglip_vision_embedding import TtSiglipVisionEmbeddings -from models.experimental.gemma3.tt.gemma_image_transformer import TtGemmaImageTransformer +from models.tt_transformers.tt.multimodal.gemma3.gemma_image_transformer import TtGemmaImageTransformer +from models.tt_transformers.tt.multimodal.gemma3.gemma_vision_embedding import TtGemmaVisionEmbeddings from models.tt_transformers.tt.multimodal.llama_layernorm import TtLayerNorm -class TtSiglipGemmaVisionModel(LightweightModule): +class TtGemmaVisionModel(LightweightModule): def __init__( self, mesh_device, @@ -41,7 +42,7 @@ def __init__( self.n_global_layers = configuration.vision_n_global_layers self.return_intermediate = return_intermediate - self.embeddings = TtSiglipVisionEmbeddings( + self.embeddings = TtGemmaVisionEmbeddings( mesh_device=mesh_device, state_dict=state_dict, state_dict_prefix=f"{state_dict_prefix}embeddings.", diff --git a/models/experimental/gemma3/tt/siglip_vision_embedding.py b/models/tt_transformers/tt/multimodal/gemma3/gemma_vision_embedding.py similarity index 94% rename from models/experimental/gemma3/tt/siglip_vision_embedding.py rename to models/tt_transformers/tt/multimodal/gemma3/gemma_vision_embedding.py index b4522bea810b..5d901cc007a2 100644 --- a/models/experimental/gemma3/tt/siglip_vision_embedding.py +++ b/models/tt_transformers/tt/multimodal/gemma3/gemma_vision_embedding.py @@ -12,10 +12,10 @@ import ttnn from models.common.lightweightmodule import LightweightModule -from models.experimental.gemma3.tt.gemma_conv2d_patch import TtGemmaConv2dPatch +from models.tt_transformers.tt.multimodal.gemma3.gemma_conv2d_patch import TtGemmaConv2dPatch -class TtSiglipVisionEmbeddings(LightweightModule): +class TtGemmaVisionEmbeddings(LightweightModule): def __init__( self, mesh_device, diff --git a/models/experimental/gemma3/tt/gemma_vision_crossattention.py b/models/tt_transformers/tt/multimodal/gemma3/gemma_vision_model.py similarity index 86% rename from models/experimental/gemma3/tt/gemma_vision_crossattention.py rename to models/tt_transformers/tt/multimodal/gemma3/gemma_vision_model.py index b6e7f95785ad..495814a9d505 100644 --- a/models/experimental/gemma3/tt/gemma_vision_crossattention.py +++ b/models/tt_transformers/tt/multimodal/gemma3/gemma_vision_model.py @@ -1,5 +1,5 @@ """ -This is the Vision Transformer Block for Gemma3. +This is the Vision Transformer Block for Gemma-3-4b-it. 
This involves vision followed by MultiModalProjector processing """ @@ -7,9 +7,10 @@ # SPDX-License-Identifier: Apache-2.0 + from models.common.lightweightmodule import LightweightModule -from models.experimental.gemma3.tt.gemma_vision_model import TtSiglipGemmaVisionModel -from models.experimental.gemma3.tt.mmp import TtGemma3MultiModalProjector +from models.tt_transformers.tt.multimodal.gemma3.gemma_vision_block import TtGemmaVisionModel +from models.tt_transformers.tt.multimodal.gemma3.multi_modal_projector import TtGemma3MultiModalProjector class TtGemmaTransformerVision(LightweightModule): @@ -35,7 +36,7 @@ def __init__( self.patch_size = configuration.vision_patch_size self.configuration = configuration - self.vision_encoder = TtSiglipGemmaVisionModel( + self.vision_encoder = TtGemmaVisionModel( mesh_device, state_dict, state_dict_prefix=configuration.state_dict_vision_prefix, diff --git a/models/experimental/gemma3/tt/gemma_vision_rmsnorm.py b/models/tt_transformers/tt/multimodal/gemma3/gemma_vision_rmsnorm.py similarity index 100% rename from models/experimental/gemma3/tt/gemma_vision_rmsnorm.py rename to models/tt_transformers/tt/multimodal/gemma3/gemma_vision_rmsnorm.py diff --git a/models/experimental/gemma3/tt/mmp.py b/models/tt_transformers/tt/multimodal/gemma3/multi_modal_projector.py similarity index 98% rename from models/experimental/gemma3/tt/mmp.py rename to models/tt_transformers/tt/multimodal/gemma3/multi_modal_projector.py index 3e445e99606f..9ffa22bd22e9 100644 --- a/models/experimental/gemma3/tt/mmp.py +++ b/models/tt_transformers/tt/multimodal/gemma3/multi_modal_projector.py @@ -11,7 +11,7 @@ import ttnn from models.common.lightweightmodule import LightweightModule -from models.experimental.gemma3.tt.gemma_vision_rmsnorm import RMSNorm +from models.tt_transformers.tt.multimodal.gemma3.gemma_vision_rmsnorm import RMSNorm class TtGemma3MultiModalProjector(LightweightModule):
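For context on `TtGemmaModel.prepare_inputs_prefill` above: the projected image features are scattered into the text embedding sequence at the positions of the image placeholder token via `masked_scatter`. A torch-only sketch of that merge, with a made-up `image_token_index` and embedding width purely for illustration:

```python
import torch

image_token_index = 7          # hypothetical id of the image placeholder token
dim = 8                        # hypothetical embedding width

# A short prompt where positions 1 and 2 are image placeholders.
tokens = torch.tensor([[3, image_token_index, image_token_index, 11]])
tokens_embd = torch.zeros(tokens.shape[0], tokens.shape[1], dim)

# One projected vision feature per image placeholder, e.g. from the multi-modal projector.
image_features = torch.randn(2, dim)

# Mask out the embedding slots belonging to image positions and scatter the features in.
special_image_mask = (tokens == image_token_index).unsqueeze(-1).expand_as(tokens_embd)
tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features.to(tokens_embd.dtype))

assert torch.allclose(tokens_embd[0, 1], image_features[0])
assert torch.allclose(tokens_embd[0, 2], image_features[1])
```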