diff --git a/models/common/rmsnorm.py b/models/common/rmsnorm.py index 35d7ec55121e..4e24cf725d6a 100644 --- a/models/common/rmsnorm.py +++ b/models/common/rmsnorm.py @@ -85,7 +85,7 @@ def __init__( torch_weight, device=device, dtype=weight_dtype, - layout=ttnn.ROW_MAJOR_LAYOUT, + layout=ttnn.TILE_LAYOUT, memory_config=weight_memory_config, cache_file_name=cache_name, mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, @@ -96,7 +96,7 @@ def __init__( torch_weight, device=device, dtype=weight_dtype, - layout=ttnn.ROW_MAJOR_LAYOUT, + layout=ttnn.TILE_LAYOUT, memory_config=weight_memory_config, cache_file_name=cache_name, mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape)) @@ -128,6 +128,11 @@ def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) -> else: assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor" + if x.shape[-1] % weight.shape[-1] == 0: + # Reshape weight only if x's last dimension is divisible by weight's last dimension, + # to avoid padding errors in RMSNorm when dimensions are not aligned + weight = ttnn.reshape(weight, [1, 1, 1, -1]) + x = norm( x, epsilon=self.eps, diff --git a/models/tt_transformers/demo/simple_vision_demo.py b/models/tt_transformers/demo/simple_vision_demo.py index 7d21da9ca274..bcc27ce9c474 100644 --- a/models/tt_transformers/demo/simple_vision_demo.py +++ b/models/tt_transformers/demo/simple_vision_demo.py @@ -27,7 +27,9 @@ import ttnn from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf from models.perf.benchmarking_utils import BenchmarkProfiler +from models.tt_transformers.tt.common import hf_multimodal_encode from models.tt_transformers.tt.generator import Generator +from models.tt_transformers.tt.model_config import CheckpointType def get_batch_sampler(temperature, top_p, tokenizer): @@ -61,6 +63,7 @@ def create_multimodal_model( checkpoint=None, ): from models.tt_transformers.tt.model_config import ModelArgs + from models.tt_transformers.tt.multimodal.gemma.gemma_e2e_model import TtGemmaModel from models.tt_transformers.tt.multimodal.llama_vision_model import CrossAttentionTransformer tt_model_args = ModelArgs(mesh_device, max_batch_size=max_batch_size) @@ -76,14 +79,26 @@ def create_multimodal_model( if checkpoint is None: checkpoint = tt_model_args.load_state_dict() - model = CrossAttentionTransformer( - mesh_device, - state_dict=checkpoint, - weight_cache_path=tt_model_args.weight_cache_path(dtype), - dtype=dtype, - configuration=tt_model_args, - use_paged_kv_cache=use_paged_kv_cache, - ) + print(f"Loaded checkpoint for {tt_model_args.base_model_name} with {checkpoint.keys()} keys") + + if tt_model_args.base_model_name == "gemma-3-4b": + model = TtGemmaModel( + mesh_device=mesh_device, + state_dict=checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(ttnn.bfloat8_b), + dtype=ttnn.bfloat8_b, + args=tt_model_args, + use_paged_kv_cache=use_paged_kv_cache, + ) + else: + model = CrossAttentionTransformer( + mesh_device, + state_dict=checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=tt_model_args, + use_paged_kv_cache=use_paged_kv_cache, + ) return tt_model_args, model, checkpoint @@ -128,7 +143,7 @@ def prepare_generator_args( ) @pytest.mark.parametrize( "test_type,max_seq_len", - (("normal", 512),), + (("normal", 2048),), ids=["normal"], ) @pytest.mark.parametrize( @@ -148,7 +163,9 @@ def prepare_generator_args( # 4, ], ) 
-@pytest.mark.parametrize("device_params", [{"trace_region_size": 14951424, "num_command_queues": 2}], indirect=True) +@pytest.mark.parametrize( + "device_params", [{"trace_region_size": 14951424, "num_command_queues": 2, "l1_small_size": 24576}], indirect=True +) def test_multimodal_demo_text( mesh_device, warmup_iters, @@ -172,9 +189,6 @@ def test_multimodal_demo_text( profiler = BenchmarkProfiler() profiler.start("run") - ckpt_dir = os.environ["LLAMA_DIR"] - tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") - num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1 max_batch_size *= data_parallel # input batch_size is interpreted as size per DP group @@ -185,11 +199,26 @@ def test_multimodal_demo_text( max_batch_size=max_batch_size, max_seq_len=max_seq_len, ) + + HF_MODEL = model_args[0].checkpoint_type == CheckpointType.HuggingFace + + if not HF_MODEL: + ckpt_dir = os.environ["LLAMA_DIR"] + tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") + + tokenizer = Tokenizer(model_path=tokenizer_path) + formatter = ChatFormat(tokenizer) + else: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_args[0].CKPT_DIR) + generator = Generator(model, model_args, mesh_device) - tokenizer = Tokenizer(model_path=tokenizer_path) - formatter = ChatFormat(tokenizer) - xattn_caches = [model.setup_cache(model_args[i].max_batch_size) for i, model in enumerate(generator.model)] + xattn_caches = [ + model.setup_cache(model_args[i].max_batch_size) if not HF_MODEL else None + for i, model in enumerate(generator.model) + ] # Create random images for trace capture with specific dimensions trace_img_560x560 = create_random_image(560, 560) @@ -250,10 +279,12 @@ def test_multimodal_demo_text( total_users = len(dialogs) num_batches = total_users // max_batch_size - sampler = get_batch_sampler(temperature, top_p, tokenizer) + sampler = get_batch_sampler(temperature, top_p, model_args[0].tokenizer) _num_prefill_tokens = 0 _num_decode_tokens = 0 + prompt_encoder = hf_multimodal_encode if HF_MODEL else formatter.encode_dialog_prompt + for iter_num in range(warmup_iters + 1): logger.info(f"Iteration {iter_num}") current_dialogs = trace_dialogs + dialogs @@ -263,9 +294,14 @@ def test_multimodal_demo_text( for msg in dialog: print(f"{msg.role.capitalize()}: {msg.content}\n") batch_model_input = [ - formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) for dialog in batch_dialogs + prompt_encoder(dialog, processor) if HF_MODEL else prompt_encoder(dialog, tool_prompt_format=False) + for dialog in batch_dialogs ] + if HF_MODEL: + # Use the processor's tokenizer instead of model_args tokenizer to ensure consistency + tokenizer = processor.tokenizer + # Do initial prefill vision_images = [ model_input.vision.images if model_input.vision else None for model_input in batch_model_input @@ -278,7 +314,8 @@ def test_multimodal_demo_text( total_lens = prefill_lens + max_gen_len # Create padded tokens tensor for batch - pad_id = tokenizer.pad_id + stop_tokens = model_args[0].tokenizer.stop_tokens + pad_id = tokenizer.pad_token_id if HF_MODEL else tokenizer.pad_id bsz = len(prompt_tokens) tokens = torch.full((bsz, max(total_lens)), pad_id, dtype=torch.long) @@ -358,19 +395,29 @@ def test_multimodal_demo_text( profiler.end(f"compile_decode", iteration=batch_idx) # Disable checking for eot until I have more robust code for batch > 1 - # if text in ["<|eot_id|>", "<|eom_id|>"]: - # break + if HF_MODEL: + if next_tokens in stop_tokens: + 
break + else: + # Disable checking for eot until I have more robust code for batch > 1 + pass + # if text in ["<|eot_id|>", "<|eom_id|>"]: + # break _num_decode_tokens += ( gen_idx * max_batch_size ) # gen_idx is (num_tokens - 1) to avoid counting compile iter # Log full text output for each user in batch - vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256] + if HF_MODEL: + # For HF models, get vision tokens from the processor if they exist + vision_tokens = [] + else: + vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256] for user_id in range(max_batch_size): # Remove <|image|> tokens since they break the tokenizer tokens_out = [ - t if t not in vision_tokens else tokenizer.pad_id + t if t not in vision_tokens else pad_id for t in tokens[user_id].tolist()[: position_id[user_id] + 2] ] text = tokenizer.decode(tokens_out) diff --git a/models/tt_transformers/tests/multimodal/gemma/test_mmp.py b/models/tt_transformers/tests/multimodal/gemma/test_mmp.py new file mode 100644 index 000000000000..8cc699cc51d8 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/gemma/test_mmp.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.gemma.multi_modal_projector import TtGemma3MultiModalProjector +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (1152,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +def test_multi_modal_inference(seq_len, batch_size, reset_seeds, device): + print("device:", device) + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + reference_model = tt_model_args.reference_vision_multi_modal() + # first_layer_prefix = "multi_modal_projector." 
+ + # partial_state_dict = { + # k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + # } + + # reference_model.load_state_dict(partial_state_dict) + + # create input tensor for multi_modal_projector layer + patches_per_image = 64 + num_patches = patches_per_image * patches_per_image + input = torch.randn((batch_size, num_patches, seq_len)) + reference_output = reference_model(input) + + # DistributedNorm inputs are fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode) + tt_input = ttnn.from_torch( + input, + device=device, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, -1), mesh_shape=tt_model_args.cluster_shape), + # memory_config=( + # tt_model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG + # ), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + tt_model = TtGemma3MultiModalProjector( + mesh_device=device, + state_dict=state_dict, + state_dict_prefix="model.multi_modal_projector", + image_size=tt_model_args.vision_chunk_size, + patch_size=tt_model_args.vision_patch_size, + hidden_size=tt_model_args.vision_hidden_dim, + mm_tokens_per_image=tt_model_args.mm_tokens_per_image, + weight_cache_path=tt_model_args.weight_cache_path(dtype), + layer_norm_eps=1e-06, # layer_norm_eps + dtype=dtype, + configuration=tt_model_args, + ) + tt_output = tt_model(tt_input) + + tt_output_torch = ttnn.to_torch(tt_output).squeeze(0) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + pcc_required = 0.9999 + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/gemma/test_patch_embedding.py b/models/tt_transformers/tests/multimodal/gemma/test_patch_embedding.py new file mode 100644 index 000000000000..ad5a9b40a5a0 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/gemma/test_patch_embedding.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.gemma.gemma_conv2d_patch import TtGemmaConv2dPatch +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +from ttnn import ConcatMeshToTensor + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_conv2d_inference( + mesh_device, + reset_seeds, +): + pcc_required = 0.9999 + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + tt_layer_prefix = "model.vision_tower.vision_model.embeddings.patch_embedding." + first_layer_prefix = "model.vision_tower.vision_model.embeddings.patch_embedding._linear." 
+ partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + num_devices = model_args.num_devices + + B, NCH, H, W = (1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size) + in_channels, out_channels, kernel_size, stride, bias = ( + 3, + model_args.vision_dim, + model_args.vision_patch_size, + model_args.vision_patch_size, + True, + ) + + assert NCH == in_channels, "Number of channels in input tensor should match in_channels for the Conv2d patch." + assert type(kernel_size) == int, "Only symmetric kernel_size is currently supported." + assert kernel_size == stride, "Only same kernel_size and stride are currently supported." + + assert H % kernel_size == 0, "Height should be divisible by kernel_size." + assert W % kernel_size == 0, "Width should be divisible by kernel_size." + + input_tensor = torch.randn((B, NCH, H, W)) + + ##### Perform the torch ops ##### + reference_model = model_args.reference_siglip_patch_embed() + # reference_model.load_state_dict(partial_state_dict) + reference_output = reference_model(input_tensor) + del reference_model + + tt_model = TtGemmaConv2dPatch( + mesh_device, + state_dict, + tt_layer_prefix, + dtype, + in_channels, + out_channels, + kernel_size, + stride, + bias, + ) + tt_output = tt_model(input_tensor) + + logger.info("Checking outputs") + out = ttnn.from_device(tt_output) + tt_output_torch = ttnn.to_torch(out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=3)) + + tt_output_torch = tt_output_torch[0, ..., :out_channels] + + logger.info(f"Reference output shape: {reference_output.shape}") + logger.info(f"TT output shape: {tt_output_torch.shape}") + + # TT output: [B, HW, C] + B, HW, C = tt_output_torch.shape + H = W = int(HW**0.5) + assert H * W == HW, "HW is not a perfect square — can't reshape" + tt_output_torch = tt_output_torch.permute(0, 2, 1) + tt_output_torch = tt_output_torch.reshape(B, C, H, W) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/gemma/test_vision_attention.py b/models/tt_transformers/tests/multimodal/gemma/test_vision_attention.py new file mode 100644 index 000000000000..4e7dc66d3b75 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/gemma/test_vision_attention.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.load_checkpoints import ( # convert_vision_hf_to_meta, + convert_hf_qkv_to_meta_format, + convert_vision_hf_to_meta, +) +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.gemma.gemma_image_attention import TtGemmaImageAttention +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_attention_inference(batch, num_chunks, mesh_device, reset_seeds): + dtype = ttnn.bfloat16 + pcc_required = 0.99 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "model.vision_tower.vision_model.encoder.layers.0.attn." + # partial_state_dict = { + # k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + # } + + dim = model_args.vision_dim + + reference_model = model_args.reference_vision_attention() + # reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + hidden_size = model_args.vision_dim + n_heads = model_args.vision_attn_n_heads + head_dim = hidden_size // n_heads + seq_len = model_args.vision_chunk_ntok + + tt_model = TtGemmaImageAttention( + mesh_device, + state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=model_args, + ) + + pt_attention_input = torch.randn(batch, seq_len, dim) + + # attention_input = model_args.prepare_residual_tensor_prefill( + # pt_attention_input, + # force_replicated=True, + # ) + attention_input = ttnn.from_torch( + pt_attention_input.unsqueeze(0), + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + ) + + tt_out = tt_model(attention_input) + + print("TT output :", tt_out) + # Doing contract in tt is correct!! + tt_output_torch = ttnn.to_torch( + tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1), device=mesh_device + )[0, :, :, :] + + reference_output = reference_model(pt_attention_input)[0] + print("Reference output shape:", reference_output.shape) + tt_output_torch = tt_output_torch[:, :4097, :] + print("TT output shape:", tt_output_torch.shape) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/gemma/test_vision_cross_attention_transformer.py b/models/tt_transformers/tests/multimodal/gemma/test_vision_cross_attention_transformer.py new file mode 100644 index 000000000000..06ff83801943 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/gemma/test_vision_cross_attention_transformer.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.gemma.gemma_vision_model import TtGemmaTransformerVision +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize("bsz", [1]) +def test_gemma_vision( + mesh_device, + reset_seeds, + bsz, +): + pcc_required = 0.90 + dtype = ttnn.bfloat16 + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + vision_first_layer_prefix = "model.vision_tower.vision_model." + vision_partial_state_dict = { + k[len(vision_first_layer_prefix) :]: v + for k, v in state_dict.items() + if (k.startswith(vision_first_layer_prefix)) + } + + reference_vision_model = model_args.reference_vision_model() + + image_size = model_args.vision_chunk_size + in_channels = model_args.vision_in_channels + + input_tensor = torch.rand((bsz, in_channels, image_size, image_size)) + + reference_mmp = model_args.reference_vision_multi_modal() + + reference_output = get_image_features( + reference_vision_model, + reference_mmp, + input_tensor, + ) + + test_gemma_vision = TtGemmaTransformerVision( + mesh_device, + state_dict, + state_dict_prefix="model.vision_tower.vision_model.", + dtype=dtype, + configuration=model_args, + return_intermediate=False, + ) + + test_output = test_gemma_vision(input_tensor) + + logger.info("Checking outputs") + out = ttnn.from_device(test_output) + + print("out shape: ", out) + print("reference_output ", reference_output) + + tt_output_torch = ttnn.to_torch( + out, + mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0), + )[0, :, :, :] + + print("reference_output ", reference_output.shape) + print(f"TT output shape: {tt_output_torch.shape}") + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" + + +def get_image_features(vision_tower, projector, input_tensor): + """ + Get image features from the vision tower and projector. + """ + vision_token = vision_tower(input_tensor).last_hidden_state + image_features = projector(vision_token) + return image_features diff --git a/models/tt_transformers/tests/multimodal/gemma/test_vision_embedding.py b/models/tt_transformers/tests/multimodal/gemma/test_vision_embedding.py new file mode 100644 index 000000000000..6bd362126f06 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/gemma/test_vision_embedding.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+
+
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+
+import pytest
+import torch
+from loguru import logger
+
+import ttnn
+from models.tt_transformers.tt.model_config import ModelArgs
+from models.tt_transformers.tt.multimodal.gemma.siglip_vision_embedding import TtSiglipVisionEmbeddings
+from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
+from ttnn import ConcatMeshToTensor
+
+
+@skip_for_grayskull("Requires wormhole_b0 to run")
+@pytest.mark.parametrize(
+    "mesh_device",
+    [
+        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
+            os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
+        )
+    ],
+    indirect=True,
+)
+@pytest.mark.parametrize("bsz", [1])
+def test_vision_embedding_integration(
+    mesh_device,
+    reset_seeds,
+    bsz,
+):
+    pcc_required = 0.9999
+    dtype = ttnn.bfloat16
+    model_args = ModelArgs(mesh_device)
+    state_dict = model_args.load_state_dict()
+
+    first_layer_prefix = "model.vision_tower.vision_model.embeddings."
+    # partial_state_dict = {
+    #     k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix))
+    # }
+
+    image_size = model_args.vision_chunk_size
+    patch_size = model_args.vision_patch_size
+    hidden_dim = model_args.vision_dim
+    dim = model_args.vision_dim
+    in_channels = 3
+
+    input_tensor = torch.randn((bsz, in_channels, image_size, image_size))
+
+    reference_model = model_args.reference_vision_embedding()
+    # reference_model.load_state_dict(partial_state_dict)
+    reference_output = reference_model(input_tensor)
+
+    vision_embed = TtSiglipVisionEmbeddings(
+        mesh_device=mesh_device,
+        state_dict=state_dict,
+        state_dict_prefix=first_layer_prefix,
+        dtype=dtype,
+        image_size=image_size,
+        patch_size=patch_size,
+        num_channels=in_channels,
+        hidden_dim=hidden_dim,
+        bias=True,
+    )
+
+    embeddings = vision_embed(input_tensor)
+    ##### Check the outputs #####
+    logger.info("Checking outputs")
+    out = ttnn.from_device(embeddings)
+    tt_output_torch = ttnn.to_torch(out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=-1))
+
+    # Only select output from one device
+    tt_output_torch = tt_output_torch[..., :dim]
+    passing, pcc_message = comp_pcc(reference_output, tt_output_torch)
+
+    # To get RTOL values
+    non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True)
+    tt_output_torch = tt_output_torch[non_zero_indices]
+    reference_output = reference_output[non_zero_indices]
+
+    logger.info(comp_allclose(reference_output, tt_output_torch))
+    logger.info(f"PCC: {pcc_message}")
+    assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
diff --git a/models/tt_transformers/tests/multimodal/gemma/test_vision_layernorm.py b/models/tt_transformers/tests/multimodal/gemma/test_vision_layernorm.py
new file mode 100644
index 000000000000..681f1def5203
--- /dev/null
+++ b/models/tt_transformers/tests/multimodal/gemma/test_vision_layernorm.py
@@ -0,0 +1,104 @@
+# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+ +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.load_checkpoints import convert_vision_hf_to_meta # convert_vision_hf_to_meta, +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.llama_layernorm import TtLayerNorm # Updated import for LayerNorm +from models.utility_functions import comp_allclose, comp_pcc, nearest_32, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize("layer_name", [("layer_norm1"), ("layer_norm2")]) +def test_layernorm_inference(mesh_device, reset_seeds, layer_name): + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device) + width = model_args.vision_dim + num_chunks = 4 + seq_len = nearest_32(model_args.vision_chunk_ntok) * num_chunks + + # Load full state dict + state_dict = model_args.load_state_dict() + + # Prefix for vision MLP weights — consistent with HF checkpoint + if layer_name == "layer_norm1": + first_layer_prefix = "model.vision_tower.vision_model.encoder.layers.0.ln_1." + else: + first_layer_prefix = "model.vision_tower.vision_model.encoder.layers.0.ln_2." + + model_args.WEIGHTS_DTYPE = dtype + # Reference HF MLP (from Gemma3 vision tower) + reference_model = model_args.reference_vision_layernorm(layer_name) + # reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + # Initialize the custom LayerNorm model + tt_model = TtLayerNorm( + device=mesh_device, + dim=width, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + weight_dtype=dtype, + eps=model_args.norm_eps, + ) + + # Generate random input + torch_input = torch.rand(1, seq_len, width) # Adjusted dimensions for LayerNorm + + # Reference output using PyTorch's LayerNorm + reference_output = reference_model(torch_input) + + # Convert input to ttnn tensor + tt_input = ttnn.from_torch( + torch_input, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + ) + + logger.info("Compilation pass for LayerNorm") + tt_output = tt_model(tt_input) + print("tt_outputs ", tt_output) + + tt_output_torch = ttnn.to_torch( + tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1) + ) # Adjusted dim for LayerNorm + print("tt_output_torch shape ", tt_output_torch, tt_output_torch.shape) + tt_outputs = torch.chunk(tt_output_torch, model_args.num_devices, dim=-1) + # print("tt_outputs shape ", tt_outputs) + print("reference_output ", reference_output.shape) + # Compare outputs + pcc_required = 0.99 + for idx, tt_output_torch in enumerate(tt_outputs): + print("tt_output_torch ", tt_output_torch, tt_output_torch.shape) + print("reference_output ", reference_output.shape) + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + # non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + # tt_output_torch = tt_output_torch[non_zero_indices] + # reference_output = reference_output[non_zero_indices] + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than 
{pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/gemma/test_vision_mlp.py b/models/tt_transformers/tests/multimodal/gemma/test_vision_mlp.py new file mode 100644 index 000000000000..9cc743e3392b --- /dev/null +++ b/models/tt_transformers/tests/multimodal/gemma/test_vision_mlp.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.gemma.gemma_image_mlp import TtGemmaImageFeedForward +from models.utility_functions import comp_allclose, comp_pcc, nearest_32, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_mlp_inference(batch, num_chunks, mesh_device, reset_seeds): + dtype = ttnn.bfloat16 + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = model_args.get_state_dict_prefix("MLP", 0, is_vision=True) + # partial_state_dict = { + # k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + # } + model_args.WEIGHTS_DTYPE = dtype + + dim = model_args.vision_dim + seq_len = nearest_32(model_args.vision_chunk_ntok) * num_chunks + reference_model = model_args.reference_vision_mlp() + # reference_model.load_state_dict(partial_state_dict) + + tt_model = TtGemmaImageFeedForward( + mesh_device=mesh_device, + args=model_args, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + ) + torch_input = torch.randn(1, batch, seq_len, dim) + reference_output = reference_model(torch_input).squeeze() + tt_input = ttnn.from_torch( + torch_input, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + ) + + tt_output = tt_model(tt_input) + + print("TT output shape:", tt_output) + + tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ + :, :1, :, : + ].squeeze() + + pcc_required = 0.99 + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/gemma/test_vision_pipeline.py b/models/tt_transformers/tests/multimodal/gemma/test_vision_pipeline.py new file mode 100644 index 000000000000..29e7d05648ce --- /dev/null +++ b/models/tt_transformers/tests/multimodal/gemma/test_vision_pipeline.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.gemma.gemma_vision_model import TtSiglipGemmaVisionModel +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize("bsz", [1]) +def test_gemma_vision( + mesh_device, + reset_seeds, + bsz, +): + pcc_required = 0.94 + dtype = ttnn.bfloat16 + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "model.vision_tower.vision_model." + # partial_state_dict = { + # k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + # } + + image_size = model_args.vision_chunk_size + in_channels = model_args.vision_in_channels + + input_tensor = torch.rand((bsz, in_channels, image_size, image_size)) + + reference_model = model_args.reference_vision_model() + # reference_model.load_state_dict(partial_state_dict) + reference_output = reference_model(input_tensor).last_hidden_state + + test_gemma_vision = TtSiglipGemmaVisionModel( + mesh_device, + state_dict, + state_dict_prefix=first_layer_prefix, + dtype=dtype, + configuration=model_args, + return_intermediate=False, + ) + + test_output = test_gemma_vision(input_tensor) + + logger.info("Checking outputs") + out = ttnn.from_device(test_output) + tt_output_torch = ttnn.to_torch(out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1)).squeeze(0) + print("reference_output ", reference_output.shape) + print("tt_output_torch ", tt_output_torch.shape) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/gemma/test_vision_rmsnorm.py b/models/tt_transformers/tests/multimodal/gemma/test_vision_rmsnorm.py new file mode 100644 index 000000000000..40f9d697157a --- /dev/null +++ b/models/tt_transformers/tests/multimodal/gemma/test_vision_rmsnorm.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+
+import pytest
+import torch
+from loguru import logger
+
+import ttnn
+from models.tt_transformers.tt.distributed_norm import DistributedNorm
+from models.tt_transformers.tt.model_config import ModelArgs
+from models.tt_transformers.tt.multimodal.gemma.gemma_vision_rmsnorm import RMSNorm
+from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
+
+
+@torch.no_grad()
+@skip_for_grayskull("Requires wormhole_b0 to run")
+@pytest.mark.parametrize(
+    "device",
+    [
+        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
+            os.environ.get("device"), len(ttnn.get_device_ids())
+        )
+    ],
+    indirect=True,
+)
+@pytest.mark.parametrize(
+    "seq_len",
+    (128,),
+)
+@pytest.mark.parametrize(
+    "batch_size",
+    (1,),
+)
+def test_rmsnorm_inference(seq_len, batch_size, reset_seeds, device):
+    dtype = ttnn.bfloat16
+    mode = "decode" if seq_len <= 32 else "prefill"
+
+    tt_model_args = ModelArgs(
+        device,
+        max_batch_size=batch_size,
+        max_seq_len=128,
+    )
+
+    tt_model_args.n_layers = 1
+    state_dict = tt_model_args.load_state_dict()
+
+    reference_model = tt_model_args.reference_vision_rms_norm()  # Gemma3 RMSNorm
+    first_layer_prefix = "model.multi_modal_projector.mm_soft_emb_norm."
+
+    partial_state_dict = {
+        k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix))
+    }
+
+    # reference_model.load_state_dict(partial_state_dict)
+
+    tt_inner_norm = RMSNorm(
+        device=device,
+        dim=1152,
+        state_dict=state_dict,
+        state_dict_prefix="",
+        weight_key="model.multi_modal_projector.mm_soft_emb_norm",
+        weight_dtype=dtype,
+        is_distributed=False,
+        sharded_program_config=tt_model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"],
+        sharded_output_config=tt_model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"],
+    )
+
+    # Wrap it in DistributedNorm
+    tt_model = DistributedNorm(tt_inner_norm, tt_model_args, TG=tt_model_args.is_galaxy)
+
+    input = torch.rand(1, 1, 1152)
+
+    reference_output = reference_model(input)
+
+    # DistributedNorm inputs are fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode)
+    tt_input = ttnn.from_torch(
+        input,
+        device=device,
+        dtype=dtype,
+        layout=ttnn.TILE_LAYOUT,
+        mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, -1), mesh_shape=tt_model_args.cluster_shape),
+        memory_config=(
+            tt_model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG
+        ),
+    )
+
+    tt_output = tt_model(tt_input, mode=mode)
+
+    # DistributedNorm outputs are replicated across devices
+    tt_output_torch = ttnn.to_torch(
+        tt_output,
+        mesh_composer=ttnn.ConcatMesh2dToTensor(
+            device, dims=(0, 2) if tt_model_args.is_galaxy else (2, 0), mesh_shape=tt_model_args.cluster_shape
+        ),
+    )[:1, :, :].squeeze(0)
+
+    passing, pcc_message = comp_pcc(reference_output, tt_output_torch)
+
+    logger.info(comp_allclose(reference_output, tt_output_torch))
+    logger.info(f"PCC: {pcc_message}")
+
+    if passing:
+        logger.info("rms_norm Passed!")
+    else:
+        logger.warning("rms_norm Failed!")
+
+    assert passing, f"rms_norm output does not meet PCC requirement {0.99}."
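A minimal torch-only sketch (not part of the patch) of how the weight_key plumbing in the test above is assumed to resolve, on the usual RMSNorm convention of joining state_dict_prefix, weight_key and ".weight" into one checkpoint key; the helper and example dict below are hypothetical:

import torch

def resolve_norm_weight(state_dict: dict, state_dict_prefix: str, weight_key: str) -> torch.Tensor:
    # Mirrors how the test passes state_dict_prefix="" and a fully qualified weight_key,
    # so the lookup key becomes "model.multi_modal_projector.mm_soft_emb_norm.weight".
    return state_dict[f"{state_dict_prefix}{weight_key}.weight"]

example_state_dict = {"model.multi_modal_projector.mm_soft_emb_norm.weight": torch.ones(1152)}
gamma = resolve_norm_weight(example_state_dict, "", "model.multi_modal_projector.mm_soft_emb_norm")
assert gamma.shape == (1152,)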
diff --git a/models/tt_transformers/tests/multimodal/gemma/test_vision_transformer.py b/models/tt_transformers/tests/multimodal/gemma/test_vision_transformer.py new file mode 100644 index 000000000000..22074d2c1027 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/gemma/test_vision_transformer.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.gemma.gemma_image_transformer import TtGemmaImageTransformer +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_image_transformer_inference(batch, num_chunks, mesh_device): + pcc_required = 0.99 + + model_args = ModelArgs(mesh_device) + dtype = ttnn.bfloat16 + + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + n_layers = model_args.vision_n_layers + first_layer_prefix = "model.vision_tower.vision_model.encoder." + + # gated = True + + # partial_state_dict = { + # k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + # } + + dim = model_args.vision_dim + seq_len = model_args.vision_chunk_ntok - 1 + + reference_model = model_args.reference_vision_encoder() + # reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + all_tests_pass = True + + tt_model = TtGemmaImageTransformer( + mesh_device, + state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=model_args, + layers=n_layers, + block_key="layers", + ) + + # Create PT input + pt_attention_input = torch.randn(batch, seq_len, dim) + attention_mask = torch.zeros(batch, 1, seq_len, seq_len) + + attention_input = model_args.prepare_residual_tensor_prefill( + pt_attention_input, + force_replicated=True, + ) + tt_mask = ttnn.from_torch( + attention_mask, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + with torch.no_grad(): + tt_out = tt_model(attention_input, mask=tt_mask) + + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[0, :, :, :] + + reference_output = reference_model(pt_attention_input, attention_mask=attention_mask)[0] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + if not passing: + logger.warning(f"PCC value -- {pcc_message} -- is lower than {pcc_required} for the output.") + else: + logger.info(f"PCC: {pcc_message}") + + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + logger.info(comp_allclose(reference_output, tt_output_torch)) + + all_tests_pass = all_tests_pass and passing + + assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. 
Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/gemma/test_vision_transformer_block.py b/models/tt_transformers/tests/multimodal/gemma/test_vision_transformer_block.py new file mode 100644 index 000000000000..680617d847fb --- /dev/null +++ b/models/tt_transformers/tests/multimodal/gemma/test_vision_transformer_block.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.load_checkpoints import convert_vision_hf_to_meta # convert_vision_hf_to_meta, +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.gemma.gemma_image_block import TtGemmaImageTransformerBlock +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "gated", + (True, False), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_block_inference(batch, num_chunks, mesh_device, reset_seeds, gated): + dtype = ttnn.bfloat16 + pcc_required = 0.99 + gated = False + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + if gated: + first_layer_prefix = "model.vision_tower.vision_model.encoder.layers.0." + else: + first_layer_prefix = "model.vision_tower.vision_model.encoder.layers.0." 
+ # partial_state_dict = { + # k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + # } + + dim = model_args.vision_dim + heads = model_args.vision_attn_n_heads + seq_len = model_args.vision_chunk_ntok - 1 + + reference_model = model_args.reference_vision_encoder_block() + # reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + tt_model = TtGemmaImageTransformerBlock( + mesh_device, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=model_args, + gated=gated, + ) + + pt_attention_input = torch.randn(batch, seq_len, dim) + attention_mask = torch.zeros(batch, 1, seq_len, seq_len) + + attention_input = model_args.prepare_residual_tensor_prefill( + pt_attention_input, + force_replicated=True, + ) + tt_mask = ttnn.from_torch( + attention_mask, + device=mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + tt_out = tt_model(attention_input, mask=tt_mask) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[0, :, :, :] + + reference_output = reference_model(pt_attention_input, attention_mask=attention_mask)[0] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
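All of the new Gemma tests above gate on comp_pcc against a pcc_required threshold; as an illustrative torch-only sketch (not part of the patch), the PCC they report is the Pearson correlation of the flattened reference and TT outputs:

import torch

def pcc(reference: torch.Tensor, actual: torch.Tensor) -> float:
    # Stand-in for the comp_pcc helper used above, assuming it reports Pearson correlation
    # over flattened tensors; the tests assert this exceeds pcc_required (0.90 to 0.9999 here).
    ref = reference.flatten().float()
    act = actual.flatten().float()
    ref = ref - ref.mean()
    act = act - act.mean()
    return float((ref @ act) / (ref.norm() * act.norm() + 1e-12))

# Identical tensors correlate to ~1.0, comfortably above the thresholds used in these tests.
x = torch.randn(4, 8)
assert pcc(x, x.clone()) > 0.999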
diff --git a/models/tt_transformers/tests/test_decoder.py b/models/tt_transformers/tests/test_decoder.py index bb61c937f89f..cce49ac87600 100644 --- a/models/tt_transformers/tests/test_decoder.py +++ b/models/tt_transformers/tests/test_decoder.py @@ -87,6 +87,19 @@ def test_decoder_inference( model_args.rope_theta, model_args.rope_scaling, ) + + if model_args.rope_local_theta is not None: + rope_setup_local = RotarySetup( + mesh_device, + model_args.max_batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_local_theta, + None, + ) + else: + rope_setup_local = None + transformation_mats = rope_setup.get_both_trans_mats() # Prepare page table for paged attention @@ -172,12 +185,12 @@ def test_decoder_inference( # Get cos/sin matrices for the current position of each user rot_mats = rope_setup.get_rot_mats(current_pos) - + rot_mats_local = None if rope_setup_local is None else rope_setup_local.get_rot_mats(current_pos) # Run TT model tt_out = tt_model( decode_input, current_pos_tensor, - rot_mats=rot_mats, + rot_mats=[rot_mats, rot_mats_local], mode="decode", page_table=page_table_tt, ) diff --git a/models/tt_transformers/tests/test_decoder_prefill.py b/models/tt_transformers/tests/test_decoder_prefill.py index ca63f294b2d2..96409e438202 100644 --- a/models/tt_transformers/tests/test_decoder_prefill.py +++ b/models/tt_transformers/tests/test_decoder_prefill.py @@ -93,6 +93,16 @@ def test_decoder_inference( theta=model_args.rope_theta, rope_scaling=model_args.rope_scaling, ) + if model_args.rope_local_theta is not None: + rot_mats_local = get_rot_mats( + head_dim=model_args.head_dim, + device=mesh_device, + seq_len=max_seq_len, + theta=model_args.rope_local_theta, + rope_scaling=None, + ) + else: + rot_mats_local = None transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) transformation_mats_prefill = ttnn.as_tensor( transformation_mat_torch, @@ -168,7 +178,9 @@ def test_decoder_inference( attn_mask_torch = torch.triu(attn_mask, diagonal=1) ref_output = reference_model(pt_decode_input, positions[0], freqs_cis_i, mask=attn_mask_torch) # Run TT model - tt_out = tt_model(decode_input, None, rot_mats, user_id=0, mode="prefill", page_table=page_table_tt) + tt_out = tt_model( + decode_input, None, [rot_mats, rot_mats_local], user_id=0, mode="prefill", page_table=page_table_tt + ) tt_out = ttnn.to_torch( tt_out, mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=model_args.cluster_shape), diff --git a/models/tt_transformers/tests/test_embedding.py b/models/tt_transformers/tests/test_embedding.py index f6408a397bcd..b5af233ede88 100644 --- a/models/tt_transformers/tests/test_embedding.py +++ b/models/tt_transformers/tests/test_embedding.py @@ -42,7 +42,7 @@ def test_embedding(max_seq_len, batch_size, mesh_device, reset_seeds, ensure_gc) tokenizer = model_args.tokenizer reference_emb = model_args.reference_embedding() - if model_args.is_vision(): + if model_args.is_vision() and not model_args.base_model_name.startswith("gemma-3"): layer_name = "text_model.tok_embeddings.weight" else: layer_name = "tok_embeddings.weight" @@ -68,7 +68,8 @@ def test_embedding(max_seq_len, batch_size, mesh_device, reset_seeds, ensure_gc) dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, ) - tt_output = tt_emb(tt_input) + embed_scale = model_args.embed_scale + tt_output = tt_emb(tt_input, embed_scale) tt_output_torch = ttnn.to_torch( tt_output, mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(0, -1), mesh_shape=model_args.cluster_shape), diff 
--git a/models/tt_transformers/tt/attention.py b/models/tt_transformers/tt/attention.py index 47ba6a7d95fd..87d7907af88c 100644 --- a/models/tt_transformers/tt/attention.py +++ b/models/tt_transformers/tt/attention.py @@ -27,6 +27,7 @@ def __init__( use_paged_kv_cache=False, ): super().__init__() + self.is_sliding = configuration.sliding_window_pattern[layer_num] self.state_dict = state_dict self.mesh_device = mesh_device diff --git a/models/tt_transformers/tt/common.py b/models/tt_transformers/tt/common.py index 5eebf47ce735..d4d9fa773eae 100644 --- a/models/tt_transformers/tt/common.py +++ b/models/tt_transformers/tt/common.py @@ -5,9 +5,11 @@ import math import re from enum import Enum +from types import SimpleNamespace from typing import Optional import torch +from llama_models.llama3.api.datatypes import ImageMedia from loguru import logger from pydantic import BaseModel, Field @@ -33,8 +35,8 @@ def __init__(self, block_size=32, max_num_blocks=1024): class RopeScalingType(str, Enum): """Types of RoPE scaling.""" - # LINEAR = "linear" # DYNAMIC = "dynamic" + LINEAR = "linear" YARN = "yarn" LLAMA3 = "llama3" DEFAULT = "default" @@ -56,6 +58,14 @@ class RopeScalingLlama3(RopeScaling): high_freq_factor: Optional[float] = 4.0 +class RopeScalingLinear(RopeScaling): + """RoPE scaling configuration for Linear.""" + + # Linear-specific parameters + factor: float = 8.0 + original_max_position_embeddings: int = 2048 + + class RopeScalingYarn(RopeScaling): """RoPE scaling configuration for Yarn.""" @@ -72,6 +82,8 @@ def rope_scaling_model_factory(rope_scaling_params: dict) -> RopeScaling: return RopeScalingLlama3(**rope_scaling_params) elif rope_scaling_type == RopeScalingType.YARN: return RopeScalingYarn(**rope_scaling_params) + elif rope_scaling_type == RopeScalingType.LINEAR: + return RopeScalingLinear(**rope_scaling_params) elif rope_scaling_type in ["default", "mrope"]: logger.warning( f"Rope scaling type was set to {rope_scaling_type}, defaulting to no rope scaling as this rope type is not supported yet by TTT" @@ -216,9 +228,8 @@ def encode_prompt_hf(tokenizer, prompt_text, system_prompt_text=None): return tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True) -def apply_scaling(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): - # FIXME: Llama-3.x specific scaling - we need to support yarn for Qwen2.5 models - # Values obtained from grid search +def compute_llama3_parameters(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): + """Llama-3.x specific scaling for rotary embeddings.""" low_freq_factor = 1 high_freq_factor = 4 @@ -238,7 +249,31 @@ def apply_scaling(freqs: torch.Tensor, scale_factor: float, orig_context_len: in return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) -def precompute_freqs(dim: int, end: int, theta, scale_factor, orig_context_len): +def compute_linear_parameters(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): + """Linear scaling for rotary embeddings.""" + freqs /= scale_factor + return freqs + + +def compute_default_parameters(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): + """Default scaling for rotary embeddings.""" + return freqs + + +def apply_scaling(freqs: torch.Tensor, scale_factor: float, orig_context_len: int, rope_type="llama3"): + # FIXME: Llama-3.x specific scaling - we need to support yarn for Qwen2.5 models + + if rope_type == "default": + freqs = compute_default_parameters(freqs, scale_factor, orig_context_len) + elif rope_type == "linear": 
+ freqs = compute_linear_parameters(freqs, scale_factor, orig_context_len) + elif rope_type == "llama3": + freqs = compute_llama3_parameters(freqs, scale_factor, orig_context_len) + + return freqs + + +def precompute_freqs(dim: int, end: int, theta, scale_factor, orig_context_len, rope_type="llama3"): """ Precompute the frequency tensor for sine and cosine values with given dimensions. @@ -253,7 +288,7 @@ def precompute_freqs(dim: int, end: int, theta, scale_factor, orig_context_len): freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end) if scale_factor is not None: - freqs = apply_scaling(freqs, scale_factor, orig_context_len) + freqs = apply_scaling(freqs, scale_factor, orig_context_len, rope_type=rope_type) freqs = torch.outer(t, freqs).float() return torch.cos(freqs), torch.sin(freqs) @@ -614,3 +649,46 @@ def create_tt_model( tt_kv_cache = [l.attention.layer_past for l in model.layers] if paged_attention_config else None return tt_model_args, model, tt_kv_cache, state_dict + + +def hf_multimodal_encode(messages, processor): + hf_messages = [] + + for msg in messages: + hf_content = [] + + for item in msg.content: + if isinstance(item, ImageMedia): + hf_content.append( + { + "type": "image", + "image": item.image, + } + ) + elif isinstance(item, str): + hf_content.append( + { + "type": "text", + "text": item, + } + ) + + hf_messages.append( + { + "role": msg.role, + "content": hf_content, + } + ) + + encoded = processor.apply_chat_template( + hf_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to("cpu", dtype=torch.bfloat16) + + return SimpleNamespace( + **encoded, + tokens=encoded["input_ids"].squeeze(0), + vision=SimpleNamespace( + images=encoded["pixel_values"], + mask=None, + ), + ) diff --git a/models/tt_transformers/tt/decoder.py b/models/tt_transformers/tt/decoder.py index 24e95a709b8a..7a97b55e9b58 100644 --- a/models/tt_transformers/tt/decoder.py +++ b/models/tt_transformers/tt/decoder.py @@ -5,6 +5,7 @@ from models.common.lightweightmodule import LightweightModule from models.common.rmsnorm import RMSNorm from models.tt_transformers.tt.attention import Attention as DefaultAttention +from models.tt_transformers.tt.ccl import tt_all_reduce from models.tt_transformers.tt.distributed_norm import DistributedNorm from models.tt_transformers.tt.mlp import MLP from models.tt_transformers.tt.model_config import TensorGroup @@ -29,6 +30,8 @@ def __init__( self.state_dict = state_dict self.mesh_device = mesh_device + self.num_devices = args.num_devices + self.TG = self.num_devices == 32 self.args = args self.hidden_size = args.dim self.n_heads = args.n_heads @@ -102,6 +105,53 @@ def __init__( args, TG=args.is_galaxy, ) + if f"layers.{layer_num}.pre_feedforward_layernorm.weight" in self.state_dict: + self.pre_ff_norm = DistributedNorm( # pre_feedforward_layernorm + RMSNorm( + device=mesh_device, + dim=args.dim, + eps=args.norm_eps, + state_dict=state_dict, + add_unit_offset=self.args.rms_norm_add_unit_offset, + state_dict_prefix=args.get_state_dict_prefix("", layer_num), + weight_cache_path=None if args.dummy_weights else weight_cache_path, + weight_dtype=ttnn.bfloat16, + weight_key="pre_feedforward_layernorm", + is_distributed=self.args.is_distributed_norm, + sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], + sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], + ccl_topology=self.args.ccl_topology(), + ), + args, + TG=args.is_galaxy, + ) + else: + # If 
pre_feedforward_layernorm is not in state_dict, we do not use it + self.pre_ff_norm = None + + if f"layers.{layer_num}.post_feedforward_layernorm.weight" in self.state_dict: + self.post_ff_norm = DistributedNorm( # post_feedforward_layernorm + RMSNorm( + device=mesh_device, + dim=args.dim, + eps=args.norm_eps, + add_unit_offset=self.args.rms_norm_add_unit_offset, + state_dict=state_dict, + state_dict_prefix=args.get_state_dict_prefix("", layer_num), + weight_cache_path=None if args.dummy_weights else weight_cache_path, + weight_dtype=ttnn.bfloat16, + weight_key="post_feedforward_layernorm", + is_distributed=self.args.is_distributed_norm, + sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], + sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], + ccl_topology=self.args.ccl_topology(), + ), + args, + TG=args.is_galaxy, + ) + else: + # If post_feedforward_layernorm is not in state_dict, we do not use it + self.post_ff_norm = None def forward( self, @@ -116,6 +166,7 @@ def forward( kv_cache=None, ) -> ttnn.Tensor: TG = self.args.is_galaxy + residual = x # x is fractured across devices and interleaved in DRAM (for prefill) and sharded in L1 (for decode) skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG assert ( @@ -124,10 +175,15 @@ def forward( # Norms take fractured inputs and output replicated across devices attn_in = self.attention_norm(x, mode) # Attention takes replicated inputs and produces fractured outputs + if self.attention.is_sliding: + position_embeddings = rot_mats[1] + else: + position_embeddings = rot_mats[0] + attn_out = self.attention.forward( attn_in, current_pos, - rot_mats, + position_embeddings, user_id, mode, page_table=page_table, @@ -135,25 +191,60 @@ def forward( chunk_start_idx=chunk_start_idx, kv_cache=kv_cache, ) - # Here x and attn_out are both fractured across devices - h = ttnn.add(x, attn_out, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16 if TG else None) - ttnn.deallocate(attn_out) + if self.pre_ff_norm == None: + attn_out = ttnn.add(x, attn_out, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16 if TG else None) + + residual = attn_out + + hidden_states = self.ff_norm(attn_out, mode) + if self.pre_ff_norm is not None: + hidden_states = tt_all_reduce( + hidden_states, + self.mesh_device, + cluster_axis=0, + dim=3, + num_reduce_scatter_links=self.args.num_reduce_scatter_links, + num_all_gather_links=self.args.num_all_gather_links, + topology=self.args.ccl_topology(), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=self.args.ccl_dtype, + ) + hidden_states = ttnn.add(residual, hidden_states, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16) + + residual = hidden_states + + hidden_states = self.pre_ff_norm(hidden_states, mode) + if mode == "prefill": x.deallocate(True) - # Norms take fractured inputs and output replicated across devices - ff_in = self.ff_norm(h, mode) + # ttnn.deallocate(attn_out) + if TG and mode == "decode": - ff_in = ttnn.to_memory_config(ff_in, memory_config=self.model_config["MLP_ACT_MEMCFG"]) + hidden_states = ttnn.to_memory_config(hidden_states, memory_config=self.model_config["MLP_ACT_MEMCFG"]) # MLP takes replicated inputs and produces fractured outputs - ff_out = self.feed_forward.forward(ff_in, mode) - # ff_out and h are both fractured across devices + hidden_states = self.feed_forward.forward(hidden_states, mode) + activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( decoder_id=self.layer_num, 
tensor=TensorGroup.ACTIVATION ) + + if self.post_ff_norm is not None: + hidden_states = self.post_ff_norm(hidden_states, mode) # Gathered + hidden_states = tt_all_reduce( + hidden_states, + self.mesh_device, + cluster_axis=0, + dim=3, + num_reduce_scatter_links=self.args.num_reduce_scatter_links, + num_all_gather_links=self.args.num_all_gather_links, + topology=self.args.ccl_topology(), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=self.args.ccl_dtype, + ) out = ttnn.add( - h, - ff_out, + residual, + hidden_states, memory_config=skip_mem_cfg, dtype=self.args.ccl_dtype if TG and not self.args.is_distributed_norm(mode) diff --git a/models/tt_transformers/tt/embedding.py b/models/tt_transformers/tt/embedding.py index c1420ad22f68..344392d8237e 100644 --- a/models/tt_transformers/tt/embedding.py +++ b/models/tt_transformers/tt/embedding.py @@ -33,6 +33,7 @@ def __init__( cache_file_name=cache_name, ) - def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: + def forward(self, x: ttnn.Tensor, embed_scale: int = 1.0) -> ttnn.Tensor: x = ttnn.embedding(x, self.weights, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG) + x = ttnn.multiply(x, embed_scale) return x diff --git a/models/tt_transformers/tt/generator.py b/models/tt_transformers/tt/generator.py index c9413fe6f44a..ca92f4d53a25 100644 --- a/models/tt_transformers/tt/generator.py +++ b/models/tt_transformers/tt/generator.py @@ -22,6 +22,7 @@ get_padded_prefill_len, num_blocks_in_seq, ) +from models.tt_transformers.tt.model_config import CheckpointType @dataclass(frozen=True) @@ -57,7 +58,7 @@ def __init__(self, model, model_args, mesh_device, tokenizer=None, formatter=Non # Note: This function is called by vLLM def prefill_forward_text( - self, tokens: torch.Tensor, page_table=None, kv_cache=None, prompt_lens=None, empty_slots=None + self, tokens: torch.Tensor, page_table=None, kv_cache=None, prompt_lens=None, empty_slots=None, **kwargs ): if page_table is not None: assert isinstance(page_table, torch.Tensor), "page_table mush be torch.Tensor" @@ -79,6 +80,7 @@ def prefill_forward_text( seq_len = int(prompt_lens[idx]) last_token_idx = seq_len - 1 prefill_seq_len = get_padded_prefill_len(seq_len) + local_kwargs = kwargs.copy() # Avoid modifying original kwargs logger.info(f"Prefilling User {user_id + 1} up to {seq_len} tokens") @@ -94,6 +96,12 @@ def prefill_forward_text( ) model_kv_cache = kv_cache[model_id] if kv_cache is not None else None + # Check if 'pixel_values' exists and index it safely + if "pixel_values" in local_kwargs: + local_kwargs["pixel_values"] = local_kwargs["pixel_values"][idx] + if "image_grid_thw" in local_kwargs: + local_kwargs["image_grid_thw"] = local_kwargs["image_grid_thw"][idx] + logits = self.prefill_forward_single_user_text( prefill_ids, page_table=page_table_user, @@ -101,6 +109,7 @@ def prefill_forward_text( last_token_idx=last_token_idx, kv_cache=model_kv_cache, model_id=model_id, + **local_kwargs, ) out_list.append(logits) @@ -116,7 +125,9 @@ def prefill_forward_text( logger.info(f"Finished prefill for all users up to {batch_seq_len} tokens, Starting decode...") return output_logits - def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_token_idx, kv_cache=None, model_id=-1): + def prefill_forward_single_user_text( + self, tokens, page_table, user_id, last_token_idx, kv_cache=None, model_id=-1, **kwargs + ): seq_len = tokens.shape[-1] use_chunked_prefill = seq_len > self.model_args[model_id].max_prefill_chunk_size if use_chunked_prefill: @@ -165,6 +176,7 @@ def 
prefill_forward_single_user_text(self, tokens, page_table, user_id, last_tok start_pos=chunk_start, page_table=page_table_user_padded, chunk_page_table=chunk_page_table, + **kwargs, ) tt_logits = self.model[model_id].ttnn_prefill_forward( chunk_prefill_input, @@ -175,6 +187,7 @@ def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_tok chunk_start_idx=chunk_start, get_last_token=(last_token_idx_in_chunk // 32) * 32, kv_cache=kv_cache, + **kwargs, ) if chunk_start == last_chunk_start: @@ -185,6 +198,7 @@ def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_tok prefill_input, rot_mats_prefill, page_table_tt, _ = self.model[model_id].prepare_inputs_prefill( tokens, page_table=page_table, + **kwargs, ) tt_logits = self.model[model_id].ttnn_prefill_forward( @@ -485,6 +499,61 @@ def _prefill_forward_single_user( # Note: This function is called by vLLM def prefill_forward( + self, + vision_images, + vision_masks, + tokens, + xattn_caches, + total_lens, + prompt_lens, + page_table=None, + kv_cache=None, + cross_page_table=None, + empty_slots=None, + **kwargs, + ): + if self.model_args[0].checkpoint_type == CheckpointType.HuggingFace: + logits = self.prefill_forward_text( + tokens, + page_table=page_table, + kv_cache=kv_cache, + prompt_lens=prompt_lens, + pixel_values=vision_images, + **kwargs, + ) + + return logits, None, None, None, None + + else: + ( + output_logits, + prefill_output_xattn_masks, + prefill_output_full_text_row_masked_out_masks, + decode_output_xattn_masks, + decode_output_full_text_row_masked_out_masks, + ) = self.prefill_forward_llama_vision( + vision_images, + vision_masks, + tokens, + xattn_caches, + total_lens, + prompt_lens, + page_table=page_table, + kv_cache=kv_cache, + cross_page_table=cross_page_table, + empty_slots=empty_slots, + ) + + return ( + output_logits, + prefill_output_xattn_masks, + prefill_output_full_text_row_masked_out_masks, + decode_output_xattn_masks, + decode_output_full_text_row_masked_out_masks, + ) + + # Note: This function is called by vLLM + def prefill_forward_llama_vision( self, vision_images, vision_masks, @@ -581,7 +650,7 @@ def prefill_forward( ) # Note: This function is called by vLLM - def decode_forward( + def decode_forward_llama_vision( self, start_pos, tokens, @@ -645,6 +714,45 @@ def decode_forward( else: return tt_logits + def decode_forward( + self, + start_pos, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + xattn_caches=None, + page_table=None, + kv_cache=None, + cross_page_table=None, + enable_trace=True, + read_from_device=True, + ): + if self.model_args[0].checkpoint_type == CheckpointType.HuggingFace: + return self.decode_forward_text( + tokens, + start_pos, + enable_trace=enable_trace, + page_table=page_table, + kv_cache=kv_cache, + ) + else: + return self.decode_forward_llama_vision( + start_pos, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + xattn_caches, + page_table, + kv_cache, + cross_page_table, + enable_trace, + read_from_device, + ) + # Note: This function is called by vLLM def read_decode_output(self, tt_out, unpadded_batch, is_tokens=False): """ diff --git a/models/tt_transformers/tt/lm_head.py b/models/tt_transformers/tt/lm_head.py index 3be020957904..c540343a4a2c 100644 --- a/models/tt_transformers/tt/lm_head.py +++ 
b/models/tt_transformers/tt/lm_head.py @@ -31,6 +31,7 @@ def __init__( self.num_devices = args.num_devices size_per_device = self.vocab_size // self.num_devices + self.model_config = args.get_model_config() if args.is_galaxy: size_per_device = self.padded_vocab_size // self.num_devices @@ -138,12 +139,14 @@ def forward(self, x: ttnn.Tensor): compute_kernel_config=self.compute_kernel_config, program_config=pc, memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - dtype=ttnn.bfloat8_b, + dtype=self.args.lm_head_dtype or ttnn.bfloat8_b, + ) + outputs.append( + ttnn.sharded_to_interleaved(output, memory_config=self.model_config["LM_HEAD_OUTPUT_MEMCFG"]) ) - outputs.append(ttnn.sharded_to_interleaved(output, memory_config=ttnn.L1_MEMORY_CONFIG)) # Concatenate the outputs - output = ttnn.concat(outputs, dim=-1, memory_config=ttnn.L1_MEMORY_CONFIG) + output = ttnn.concat(outputs, dim=-1, memory_config=self.model_config["LM_HEAD_OUTPUT_MEMCFG"]) output = tt_all_reduce( output, diff --git a/models/tt_transformers/tt/load_checkpoints.py b/models/tt_transformers/tt/load_checkpoints.py index 6b28e2b4e5ce..b81d808f8635 100644 --- a/models/tt_transformers/tt/load_checkpoints.py +++ b/models/tt_transformers/tt/load_checkpoints.py @@ -85,6 +85,347 @@ def convert_hf_to_meta(state_dict, head_dim): return state_dict +def convert_vision_hf_to_meta(state_dict, head_dim): + state_dict = split_hf_keys(state_dict) + # state_dict = convert_hf_qkv_to_meta_format(state_dict, head_dim) + state_dict = map_vision_hf_to_meta_keys(state_dict, head_dim) + return state_dict + + +def map_hf_to_meta_keys(loaded_weights): + hf_to_meta = { + # Top level mappings + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + # Layer level mappings + "input_layernorm.weight": "attention_norm.weight", + "post_attention_layernorm.weight": "ffn_norm.weight", + # Attention module mappings + "self_attn.q_proj.weight": "attention.wq.weight", + "self_attn.k_proj.weight": "attention.wk.weight", + "self_attn.v_proj.weight": "attention.wv.weight", + "self_attn.o_proj.weight": "attention.wo.weight", + "self_attn.q_proj.bias": "attention.wq.bias", + "self_attn.k_proj.bias": "attention.wk.bias", + "self_attn.v_proj.bias": "attention.wv.bias", + "self_attn.q_norm.weight": "attention.q_norm.weight", + "self_attn.k_norm.weight": "attention.k_norm.weight", + "self_attn.o_proj.bias": "attention.wo.bias", + # Feed forward module mappings + "mlp.gate_proj.weight": "feed_forward.w1.weight", + "mlp.up_proj.weight": "feed_forward.w3.weight", + "mlp.down_proj.weight": "feed_forward.w2.weight", + # MLP bias mappings + "mlp.gate_proj.bias": "feed_forward.w1.bias", + "mlp.up_proj.bias": "feed_forward.w3.bias", + "mlp.down_proj.bias": "feed_forward.w2.bias", + # === Additional FFN layernorms (Gemma3 specific) === + "pre_feedforward_layernorm.weight": "pre_feedforward_layernorm.weight", + "post_feedforward_layernorm.weight": "post_feedforward_layernorm.weight", + # Direct module mappings + "gate_proj.weight": "w1.weight", + "down_proj.weight": "w2.weight", + "up_proj.weight": "w3.weight", + "q_proj.weight": "wq.weight", + "k_proj.weight": "wk.weight", + "v_proj.weight": "wv.weight", + "o_proj.weight": "wo.weight", + "q_proj.bias": "wq.bias", + "k_proj.bias": "wk.bias", + "v_proj.bias": "wv.bias", + "q_norm.weight": "q_norm.weight", + "k_norm.weight": "k_norm.weight", + "o_proj.bias": "wo.bias", + # Direct MLP bias mappings + "gate_proj.bias": "w1.bias", + "up_proj.bias": "w3.bias", + 
"down_proj.bias": "w2.bias", + "weight": "emb.weight", # For host embeddings + # Full path layer mappings + "model.layers.{layer}.input_layernorm.weight": "layers.{layer}.attention_norm.weight", + "model.layers.{layer}.post_attention_layernorm.weight": "layers.{layer}.ffn_norm.weight", + "model.layers.{layer}.self_attn.q_proj.weight": "layers.{layer}.attention.wq.weight", + "model.layers.{layer}.self_attn.k_proj.weight": "layers.{layer}.attention.wk.weight", + "model.layers.{layer}.self_attn.v_proj.weight": "layers.{layer}.attention.wv.weight", + "model.layers.{layer}.self_attn.o_proj.weight": "layers.{layer}.attention.wo.weight", + "model.layers.{layer}.self_attn.q_proj.bias": "layers.{layer}.attention.wq.bias", + "model.layers.{layer}.self_attn.k_proj.bias": "layers.{layer}.attention.wk.bias", + "model.layers.{layer}.self_attn.v_proj.bias": "layers.{layer}.attention.wv.bias", + "model.layers.{layer}.self_attn.q_norm.weight": "layers.{layer}.attention.q_norm.weight", + "model.layers.{layer}.self_attn.k_norm.weight": "layers.{layer}.attention.k_norm.weight", + "model.layers.{layer}.self_attn.o_proj.bias": "layers.{layer}.attention.wo.bias", + "model.layers.{layer}.mlp.gate_proj.weight": "layers.{layer}.feed_forward.w1.weight", + "model.layers.{layer}.mlp.up_proj.weight": "layers.{layer}.feed_forward.w3.weight", + "model.layers.{layer}.mlp.down_proj.weight": "layers.{layer}.feed_forward.w2.weight", + # Full path MLP bias mappings + "model.layers.{layer}.mlp.gate_proj.bias": "layers.{layer}.feed_forward.w1.bias", + "model.layers.{layer}.mlp.up_proj.bias": "layers.{layer}.feed_forward.w3.bias", + "model.layers.{layer}.mlp.down_proj.bias": "layers.{layer}.feed_forward.w2.bias", + "model.layers.{layer}.pre_feedforward_layernorm.weight": "layers.{layer}.pre_feedforward_layernorm.weight", + "model.layers.{layer}.post_feedforward_layernorm.weight": "layers.{layer}.post_feedforward_layernorm.weight", + } + + meta_state_dict = {} + for key, tensor in loaded_weights.items(): + # Remove known prefix if present + prefix = next((p for p in _get_known_prefixes_mapping().keys() if key.startswith(p)), "") + key = key.replace(prefix, _get_known_prefixes_mapping().get(prefix, ""), 1) + + new_key = key + if key in hf_to_meta: + # Direct match for top-level keys + new_key = hf_to_meta[key] + elif key.startswith("model.layers."): + # Extract layer number and form a template key + parts = key.split(".") + layer_num = parts[2] # e.g. "0" in "model.layers.0.input_layernorm.weight" + template_key = "model.layers.{layer}." + ".".join(parts[3:]) + if template_key in hf_to_meta: + new_key = hf_to_meta[template_key].format(layer=layer_num) + else: + new_key = key[len("model.") :] # Remove "model." 
prefix + + meta_state_dict[new_key] = tensor + + return meta_state_dict + + +def map_vision_meta_to_hf_keys(loaded_weights): + language_weights = { + key[len("language_model.") :]: tensor + for key, tensor in loaded_weights.items() + if key.startswith("language_model.") + } + mapped_language_weights = map_meta_to_hf_keys(language_weights, language_prefix="language_model.") + other_weights = {key: tensor for key, tensor in loaded_weights.items() if not key.startswith("language_model.")} + hf_state_dict = {**mapped_language_weights} + loaded_weights = {**other_weights} + meta_to_hf_mappings = { + # vision MLP + "c_fc.weight": "fc1.weight", + "c_fc.bias": "fc1.bias", + "c_proj.weight": "fc2.weight", + "c_proj.bias": "fc2.bias", + # vision attention + # "wq.weight": "q_proj.weight", + # "wk.weight": "k_proj.weight", + # "wv.weight": "v_proj.weight", + # "wo.weight": "out_proj.weight", + # "wq.bias": "q_proj.bias", + # "wk.bias": "k_proj.bias", + # "wv.bias": "v_proj.bias", + # "wo.bias": "out_proj.bias", + # vision encoder block + "attn.wq.weight": "self_attn.q_proj.weight", + "attn.wk.weight": "self_attn.k_proj.weight", + "attn.wv.weight": "self_attn.v_proj.weight", + "attn.wo.weight": "self_attn.out_proj.weight", + "attn.wq.bias": "self_attn.q_proj.bias", + "attn.wk.bias": "self_attn.k_proj.bias", + "attn.wv.bias": "self_attn.v_proj.bias", + "attn.wo.bias": "self_attn.out_proj.bias", + "ln_1.weight": "layer_norm1.weight", + "ln_1.bias": "layer_norm1.bias", + "ln_2.weight": "layer_norm2.weight", + "ln_2.bias": "layer_norm2.bias", + "mlp.c_fc.weight": "mlp.fc1.weight", + "mlp.c_fc.bias": "mlp.fc1.bias", + "mlp.c_proj.weight": "mlp.fc2.weight", + "mlp.c_proj.bias": "mlp.fc2.bias", + # vision encoder + "layers.{layer}.attn.wq.weight": "layers.{layer}.self_attn.q_proj.weight", + "layers.{layer}.attn.wk.weight": "layers.{layer}.self_attn.k_proj.weight", + "layers.{layer}.attn.wv.weight": "layers.{layer}.self_attn.v_proj.weight", + "layers.{layer}.attn.wo.weight": "layers.{layer}.self_attn.out_proj.weight", + "layers.{layer}.attn.wq.bias": "layers.{layer}.self_attn.q_proj.bias", + "layers.{layer}.attn.wk.bias": "layers.{layer}.self_attn.k_proj.bias", + "layers.{layer}.attn.wv.bias": "layers.{layer}.self_attn.v_proj.bias", + "layers.{layer}.attn.wo.bias": "layers.{layer}.self_attn.out_proj.bias", + "layers.{layer}.ln_1.weight": "layers.{layer}.layer_norm1.weight", + "layers.{layer}.ln_1.bias": "layers.{layer}.layer_norm1.bias", + "layers.{layer}.ln_2.weight": "layers.{layer}.layer_norm2.weight", + "layers.{layer}.ln_2.bias": "layers.{layer}.layer_norm2.bias", + "layers.{layer}.mlp.c_fc.weight": "layers.{layer}.mlp.fc1.weight", + "layers.{layer}.mlp.c_fc.bias": "layers.{layer}.mlp.fc1.bias", + "layers.{layer}.mlp.c_proj.weight": "layers.{layer}.mlp.fc2.weight", + "layers.{layer}.mlp.c_proj.bias": "layers.{layer}.mlp.fc2.bias", + # vision transformer + "encoder.layers.{layer}.attn.wq.weight": "encoder.layers.{layer}.self_attn.q_proj.weight", + "encoder.layers.{layer}.attn.wk.weight": "encoder.layers.{layer}.self_attn.k_proj.weight", + "encoder.layers.{layer}.attn.wv.weight": "encoder.layers.{layer}.self_attn.v_proj.weight", + "encoder.layers.{layer}.attn.wo.weight": "encoder.layers.{layer}.self_attn.out_proj.weight", + "encoder.layers.{layer}.attn.wq.bias": "encoder.layers.{layer}.self_attn.q_proj.bias", + "encoder.layers.{layer}.attn.wk.bias": "encoder.layers.{layer}.self_attn.k_proj.bias", + "encoder.layers.{layer}.attn.wv.bias": "encoder.layers.{layer}.self_attn.v_proj.bias", + 
"encoder.layers.{layer}.attn.wo.bias": "encoder.layers.{layer}.self_attn.out_proj.bias", + "ln_post.weight": "post_layernorm.weight", + "ln_post.bias": "post_layernorm.bias", + # Top level + "_linear.weight": "weight", # patch_embedding + "_linear.bias": "bias", # patch_embedding + "positional_embedding": "weight", # pos_emb + "model.vision_tower.vision_model.embeddings.patch_embedding._linear.weight": "vision_tower.vision_model.embeddings.patch_embedding.weight", + "model.vision_tower.vision_model.embeddings.patch_embedding._linear.bias": "vision_tower.vision_model.embeddings.patch_embedding._linear.bias", + "model.vision_tower.vision_model.embeddings.position_embedding.positional_embedding": "vision_tower.vision_model.embeddings.position_embedding.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wq.weight": "vision_tower.vision_model.encoder.layers.{layer}.self_attn.q_proj.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wk.weight": "vision_tower.vision_model.encoder.layers.{layer}.self_attn.k_proj.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wv.weight": "vision_tower.vision_model.encoder.layers.{layer}.self_attn.v_proj.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wo.weight": "vision_tower.vision_model.encoder.layers.{layer}.self_attn.out_proj.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wq.bias": "vision_tower.vision_model.encoder.layers.{layer}.self_attn.q_proj.bias", + "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wk.bias": "vision_tower.vision_model.encoder.layers.{layer}.self_attn.k_proj.bias", + "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wv.bias": "vision_tower.vision_model.encoder.layers.{layer}.self_attn.v_proj.bias", + "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wo.bias": "vision_tower.vision_model.encoder.layers.{layer}.self_attn.out_proj.bias", + "model.vision_tower.vision_model.encoder.layers.{layer}.ln_1.weight": "vision_tower.vision_model.encoder.layers.{layer}.layer_norm1.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.ln_1.bias": "vision_tower.vision_model.encoder.layers.{layer}.layer_norm1.bias", + "model.vision_tower.vision_model.encoder.layers.{layer}.ln_2.weight": "vision_tower.vision_model.encoder.layers.{layer}.layer_norm2.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.ln_2.bias": "vision_tower.vision_model.encoder.layers.{layer}.layer_norm2.bias", + "model.vision_tower.vision_model.encoder.layers.{layer}.mlp.c_fc.weight": "vision_tower.vision_model.encoder.layers.{layer}.mlp.fc1.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.mlp.c_fc.bias": "vision_tower.vision_model.encoder.layers.{layer}.mlp.fc1.bias", + "model.vision_tower.vision_model.encoder.layers.{layer}.mlp.c_proj.weight": "vision_tower.vision_model.encoder.layers.{layer}.mlp.fc2.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.mlp.c_proj.bias": "vision_tower.vision_model.encoder.layers.{layer}.mlp.fc2.bias", + "model.vision_tower.vision_model.ln_post.weight": "vision_tower.vision_model.post_layernorm.weight", + "model.vision_tower.vision_model.ln_post.bias": "vision_tower.vision_model.post_layernorm.bias", + } + + for key, tensor in loaded_weights.items(): + # Handle full model paths with layer numbers + if "model.vision_tower.vision_model.encoder.layers." 
in key: + parts = key.split(".") + layer_num = parts[5] + remainder = ".".join(parts[6:]) + if remainder in meta_to_hf_mappings: + new_key = f"model.vision_tower.vision_model.encoder.layers.{layer_num}.{meta_to_hf_mappings[remainder]}" + hf_state_dict[new_key] = tensor + continue + + # Handle full vision encoder paths with layer numbers + if "layers." in key: + parts = key.split(".") + layer_num = parts[1] # e.g. "0" in "model.layers.0.input_layernorm.weight" + template_key = "layers.{layer}." + ".".join(parts[2:]) + if template_key in meta_to_hf_mappings: + hf_state_dict[meta_to_hf_mappings[template_key].format(layer=layer_num)] = tensor + continue + + # Try exact matches first + if key in meta_to_hf_mappings: + hf_state_dict[meta_to_hf_mappings[key]] = tensor + continue + + # For submodule state dicts, try matching the end of the key + matched = False + for meta_pattern, hf_pattern in meta_to_hf_mappings.items(): + if key.endswith("." + meta_pattern): + # Replace only the matching part at the end + prefix = key[: -len(meta_pattern)] + new_key = prefix + hf_pattern + hf_state_dict[new_key] = tensor + matched = True + break + + # If no mapping found, keep the original key + if not matched: + hf_state_dict[key] = tensor + + return hf_state_dict + + +def map_vision_hf_to_meta_keys(loaded_weights, head_dim): + hf_to_meta = { + # vision MLP + "fc1.weight": "c_fc.weight", + "fc1.bias": "c_fc.bias", + "fc2.weight": "c_proj.weight", + "fc2.bias": "c_proj.bias", + # vision attention + # "q_proj.weight": "wq.weight", + # "k_proj.weight": "wk.weight", + # "v_proj.weight": "wv.weight", + # "out_proj.weight": "wo.weight", + # "q_proj.bias": "wq.bias", + # "k_proj.bias": "wk.bias", + # "v_proj.bias": "wv.bias", + # "out_proj.bias": "wo.bias", + # vision encoder + "self_attn.q_proj.weight": "attn.wq.weight", + "self_attn.k_proj.weight": "attn.wk.weight", + "self_attn.v_proj.weight": "attn.wv.weight", + "self_attn.out_proj.weight": "attn.wo.weight", + "self_attn.q_proj.bias": "attn.wq.bias", + "self_attn.k_proj.bias": "attn.wk.bias", + "self_attn.v_proj.bias": "attn.wv.bias", + "self_attn.out_proj.bias": "attn.wo.bias", + "layer_norm1.weight": "ln_1.weight", + "layer_norm1.bias": "ln_1.bias", + "layer_norm2.weight": "ln_2.weight", + "layer_norm2.bias": "ln_2.bias", + "mlp.fc1.weight": "mlp.c_fc.weight", + "mlp.fc1.bias": "mlp.c_fc.bias", + "mlp.fc2.weight": "mlp.c_proj.weight", + "mlp.fc2.bias": "mlp.c_proj.bias", + # Top level + # vision transformer + "encoder.layers.{layer}.self_attn.q_proj.weight": "encoder.layers.{layer}.attn.wq.weight", + "encoder.layers.{layer}.self_attn.k_proj.weight": "encoder.layers.{layer}.attn.wk.weight", + "encoder.layers.{layer}.self_attn.v_proj.weight": "encoder.layers.{layer}.attn.wv.weight", + "encoder.layers.{layer}.self_attn.out_proj.weight": "encoder.layers.{layer}.attn.wo.weight", + "encoder.layers.{layer}.self_attn.q_proj.bias": "encoder.layers.{layer}.attn.wq.bias", + "encoder.layers.{layer}.self_attn.k_proj.bias": "encoder.layers.{layer}.attn.wk.bias", + "encoder.layers.{layer}.self_attn.v_proj.bias": "encoder.layers.{layer}.attn.wv.bias", + "encoder.layers.{layer}.self_attn.out_proj.bias": "encoder.layers.{layer}.attn.wo.bias", + "post_layernorm.weight": "ln_post.weight", + "post_layernorm.bias": "ln_post.bias", + "weight": "_linear.weight", + "bias": "_linear.bias", + "weight": "positional_embedding", # pos_emb + "model.vision_tower.vision_model.embeddings.patch_embedding.weight": "model.vision_tower.vision_model.embeddings.patch_embedding._linear.weight", + 
"model.vision_tower.vision_model.embeddings.patch_embedding.bias": "model.vision_tower.vision_model.embeddings.patch_embedding._linear.bias", + "model.vision_tower.vision_model.embeddings.position_embedding.weight": "model.vision_tower.vision_model.embeddings.position_embedding.positional_embedding", + "model.vision_tower.vision_model.encoder.layers.{layer}.self_attn.q_proj.weight": "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wq.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.self_attn.k_proj.weight": "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wk.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.self_attn.v_proj.weight": "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wv.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.self_attn.out_proj.weight": "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wo.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.self_attn.q_proj.bias": "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wq.bias", + "model.vision_tower.vision_model.encoder.layers.{layer}.self_attn.k_proj.bias": "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wk.bias", + "model.vision_tower.vision_model.encoder.layers.{layer}.self_attn.v_proj.bias": "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wv.bias", + "model.vision_tower.vision_model.encoder.layers.{layer}.self_attn.out_proj.bias": "model.vision_tower.vision_model.encoder.layers.{layer}.attn.wo.bias", + "model.vision_tower.vision_model.encoder.layers.{layer}.layer_norm1.weight": "model.vision_tower.vision_model.encoder.layers.{layer}.ln_1.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.layer_norm1.bias": "model.vision_tower.vision_model.encoder.layers.{layer}.ln_1.bias", + "model.vision_tower.vision_model.encoder.layers.{layer}.layer_norm2.weight": "model.vision_tower.vision_model.encoder.layers.{layer}.ln_2.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.layer_norm2.bias": "model.vision_tower.vision_model.encoder.layers.{layer}.ln_2.bias", + "model.vision_tower.vision_model.encoder.layers.{layer}.mlp.fc1.weight": "model.vision_tower.vision_model.encoder.layers.{layer}.mlp.c_fc.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.mlp.fc1.bias": "model.vision_tower.vision_model.encoder.layers.{layer}.mlp.c_fc.bias", + "model.vision_tower.vision_model.encoder.layers.{layer}.mlp.fc2.weight": "model.vision_tower.vision_model.encoder.layers.{layer}.mlp.c_proj.weight", + "model.vision_tower.vision_model.encoder.layers.{layer}.mlp.fc2.bias": "model.vision_tower.vision_model.encoder.layers.{layer}.mlp.c_proj.bias", + "model.vision_tower.vision_model.post_layernorm.weight": "model.vision_tower.vision_model.ln_post.weight", + "model.vision_tower.vision_model.post_layernorm.bias": "model.vision_tower.vision_model.ln_post.bias", + } + + remapped = {} + for key, tensor in loaded_weights.items(): + if key in hf_to_meta: + remapped[hf_to_meta[key]] = tensor + elif "model.vision_tower.vision_model.encoder.layers." in key: + parts = key.split(".") + layer_num = parts[5] # e.g. "0" in "model.layers.0.input_layernorm.weight" + template_key = "model.vision_tower.vision_model.encoder.layers.{layer}." 
+ ".".join(parts[6:]) + if template_key in hf_to_meta: + remapped[hf_to_meta[template_key].format(layer=layer_num)] = tensor + else: + remapped[key] = tensor + + # Remove language_model keys + non_text_weights = {k: v for k, v in remapped.items() if not k.startswith("model.language_model.")} + text_weights = { + k: v for k, v in loaded_weights.items() if k.startswith("model.language_model.") or k.startswith("lm_head.") + } + text_weights = convert_hf_qkv_to_meta_format(text_weights, head_dim) + # remapped_text = map_hf_to_meta_keys(text_weights, prefix="model.language_model.") + remapped_text = map_hf_to_meta_keys(text_weights) + return {**non_text_weights, **remapped_text} + + def load_meta_state_dict(ckpt_dir, n_layers=None, start_layer_idx=0): checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" @@ -238,6 +579,7 @@ def map_hf_to_meta_keys(loaded_weights): """ replacements = [ ("^emb.weight", "weight"), + ("model.language_model.", ""), ("model.", ""), ("embed_tokens", "tok_embeddings"), ("lm_head", "output"), @@ -252,11 +594,19 @@ def map_hf_to_meta_keys(loaded_weights): ("k_proj", "wk"), ("v_proj", "wv"), ("o_proj", "wo"), + ("q_norm", "q_norm"), + ("k_norm", "k_norm"), ] return replace_keys(loaded_weights, replacements) -def map_meta_to_hf_keys(loaded_weights): +def convert_vision_meta_to_hf(state_dict, head_dim): + # state_dict = convert_meta_qkv_to_hf_format(state_dict, head_dim) + state_dict = map_vision_meta_to_hf_keys(state_dict) + return state_dict + + +def map_meta_to_hf_keys(loaded_weights, language_prefix=""): # Define mappings at each level of the hierarchy meta_to_hf_mappings = { # Top level @@ -266,6 +616,8 @@ def map_meta_to_hf_keys(loaded_weights): # Layer level "attention_norm.weight": "input_layernorm.weight", "ffn_norm.weight": "post_attention_layernorm.weight", + "pre_feedforward_layernorm.weight": "pre_feedforward_layernorm.weight", + "post_feedforward_layernorm.weight": "post_feedforward_layernorm.weight", # Attention module "attention.wq.weight": "self_attn.q_proj.weight", "attention.wk.weight": "self_attn.k_proj.weight", diff --git a/models/tt_transformers/tt/mlp.py b/models/tt_transformers/tt/mlp.py index 9893ec2440e4..ec9fe66d7506 100644 --- a/models/tt_transformers/tt/mlp.py +++ b/models/tt_transformers/tt/mlp.py @@ -72,7 +72,9 @@ def __init__( self.w3 = as_sharded_tensor("w3_sharded", ff1_3_dtype, dims=w1_dims) # Default activation is SILU - self.activation_type = self.args.mlp_activation_type + self.activation_type = ( + args.mlp_activation_type if hasattr(args, "mlp_activation_type") else ttnn.UnaryOpType.SILU + ) def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: """ diff --git a/models/tt_transformers/tt/model.py b/models/tt_transformers/tt/model.py index 591c915085e6..d9801db59594 100644 --- a/models/tt_transformers/tt/model.py +++ b/models/tt_transformers/tt/model.py @@ -58,6 +58,19 @@ def __init__( rope_theta=args.rope_theta, rope_scaling=args.rope_scaling, ) + + if args.rope_local_theta is not None: + self.rope_setup_local = ActualRopeSetupClass( + device=mesh_device, + batch_size=args.max_batch_size, + head_dim=args.head_dim, + max_seq_len=args.max_seq_len, + rope_theta=args.rope_local_theta, + rope_scaling=None, + ) + else: + self.rope_setup_local = None + self.trans_mats_dict = self.rope_setup.get_both_trans_mats() self.layers = [ @@ -105,6 +118,8 @@ def __init__( max_columns_per_device=self.args.max_columns_per_device_lm_head, ) + self.embed_scale = args.embed_scale + def 
prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None): """ Inputs are torch tensors or python types. This function returns ttnn @@ -122,7 +137,8 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag layout=ttnn.ROW_MAJOR_LAYOUT, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) - tokens_embd = self.embd(tokens) + tokens_embd = self.embd(tokens, self.embed_scale) + tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) # Slice the rot mats to the prefill seqlen @@ -133,6 +149,13 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag self.rope_setup.cos_matrix[:, :, start_pos : start_pos + S, :], self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :], ] + if self.rope_setup_local is not None: + tt_rot_mats_prefill_local = [ + self.rope_setup_local.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_setup_local.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + else: + tt_rot_mats_prefill_local = None if page_table is not None: tt_page_table = ttnn.from_torch( @@ -156,7 +179,7 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag else: tt_chunk_page_table = None - return tokens_embd, tt_rot_mats_prefill, tt_page_table, tt_chunk_page_table + return tokens_embd, [tt_rot_mats_prefill, tt_rot_mats_prefill_local], tt_page_table, tt_chunk_page_table def prepare_inputs_decode(self, *inputs): """ @@ -228,13 +251,18 @@ def transform_decode_inputs_device(self, tokens, current_pos, rope_idxs, page_ta Embed tokens """ tt_rot_mats = self.rope_setup.get_rot_mats(rope_idxs) - tt_tokens = self.embd(tokens) + if self.rope_setup_local is not None: + tt_rot_mats_local = self.rope_setup_local.get_rot_mats(rope_idxs) + else: + tt_rot_mats_local = None + tt_tokens = self.embd(tokens, self.embed_scale) + tt_tokens = ttnn.unsqueeze_to_4D(tt_tokens) tt_tokens = ttnn.to_memory_config( tt_tokens, self.args.model_config["DECODE_RESIDUAL_MEMCFG"], ) - return tt_tokens, current_pos, tt_rot_mats, page_table + return tt_tokens, current_pos, [tt_rot_mats, tt_rot_mats_local], page_table def concat_device_output(self, tt_out): """ diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index 59a955a568a5..ee9bb0c3335e 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -27,6 +27,8 @@ from models.tt_transformers.tt.load_checkpoints import ( convert_hf_to_meta, convert_meta_to_hf, + convert_vision_hf_to_meta, + convert_vision_meta_to_hf, load_hf_state_dict, load_meta_state_dict, reverse_permute, @@ -69,6 +71,7 @@ class OpGroup(Enum): LI_QKV_PREFILL = "li_qkv_prefill" LI_O_PREFILL = "li_o_prefill" SDPA_PREFILL = "sdpa_prefill" + ACCURACY = "accuracy" # This is a special group for accuracy mode, not an actual operator group class MathFidelitySetting(Enum): @@ -77,6 +80,7 @@ class MathFidelitySetting(Enum): HIFI2_NA = "hifi2na" # na specified `packer_l1_acc=False` and `fp32_dest_acc_en=False` in compute kernel config HIFI2_FP16 = "hifi2fp16" # fp16 specified `fp32_dest_acc_en=False` in compute kernel config HIFI4 = "hifi4" + HIFI4_FP32 = "hifi4fp32" class ModelOptimizations: @@ -141,7 +145,7 @@ def performance(cls, model_name): All models use bfp4 in FF1 and FF3 MLPs in this configuration """ base_model_name = get_base_model_name(model_name) - if base_model_name == "Qwen2.5-7B": + if base_model_name in ["Qwen2.5-7B", "gemma-3-4b"]: logger.info( f"Model {model_name} is degraded under standard 
high-performance settings, using BF16 attention and BFP8 MLP" ) @@ -235,7 +239,7 @@ def _default_settings(self): TensorGroup.WO: PrecisionSetting.BFP8, TensorGroup.KV_CACHE: PrecisionSetting.BFP8, # Activation across whole model - TensorGroup.ACTIVATION: None, # this signals that original dtype should be used + TensorGroup.ACTIVATION: None, }, "OpFidelity": { # MLP linear operators - BFP8 with FP16 accumulation to save L1 @@ -248,6 +252,7 @@ def _default_settings(self): OpGroup.LI_QKV_PREFILL: MathFidelitySetting.HIFI2, OpGroup.SDPA_PREFILL: MathFidelitySetting.HIFI4, OpGroup.LI_O_PREFILL: MathFidelitySetting.HIFI2, # FP32 accumulate is important here + OpGroup.ACCURACY: MathFidelitySetting.HIFI4_FP32, }, } @@ -549,7 +554,7 @@ def __init__( if max_prefill_chunk_size_div1024 is None: # TODO Improve this to be more general to more devices and models MAX_PREFILL_CHUNK_SIZES_DIV1024 = { - "gemma-3-4b": {"N150": 8, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, + "gemma-3-4b": {"N150": 128, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, "Llama-3.2-1B": {"N150": 128, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, "Llama-3.2-3B": {"N150": 8, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, "Llama-3.1-8B": {"N150": 4, "N300": 64, "T3K": 128, "TG": 128, "P150x4": 128}, @@ -583,9 +588,10 @@ def __init__( max_prefill_chunk_size_div1024 = int(max_prefill_chunk_size_div1024) self.max_prefill_chunk_size = max_prefill_chunk_size_div1024 * 1024 - if (self.base_model_name in ["Llama-3.1-8B", "Llama-3.2-11B", "Mistral-7B"] and self.device_name == "N150") or ( - self.base_model_name in ["Qwen2.5-7B"] and self.device_name == "N300" - ): + if ( + self.base_model_name in ["Llama-3.1-8B", "Llama-3.2-11B", "Mistral-7B", "gemma-3-4b"] + and self.device_name == "N150" + ) or (self.base_model_name in ["Qwen2.5-7B"] and self.device_name == "N300"): logger.info(f"Reducing prefill_len_cutoff to 512 for {self.model_name} on {self.device_name}") self.prefill_len_cutoff = 512 @@ -660,6 +666,12 @@ def __init__( fp32_dest_acc_en=True, packer_l1_acc=True, ) + self.compute_kernel_config_hifi4_fp32 = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=True, + packer_l1_acc=True, + dst_full_sync_en=False, + ) self.compute_kernel_config_hifi2_na = ttnn.WormholeComputeKernelConfig( math_fidelity=ttnn.MathFidelity.HiFi2, math_approx_mode=False, @@ -1230,7 +1242,10 @@ def _get_xattn_kv_prefill_mem_cfg(seq_len): self.model_config["XATTN_KV_PREFILL_MEM_CFG"] = _get_xattn_kv_prefill_mem_cfg - self.VISION_MAX_MM_SEQ = nearest_32(self.vision_chunk_ntok) + if self.is_vision(): + self.VISION_MAX_MM_SEQ = ( + self.vision_chunk_ntok if "gemma-3" in self.base_model_name else nearest_32(self.vision_chunk_ntok) + ) # RMS NORM self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"] = self.create_sharded_norm_config(attn_input_grid) @@ -1252,6 +1267,10 @@ def _get_xattn_kv_prefill_mem_cfg(seq_len): ), ) + self.model_config["LM_HEAD_OUTPUT_MEMCFG"] = ( + ttnn.DRAM_MEMORY_CONFIG if self.model_name == "gemma-3-4b-it" else ttnn.L1_MEMORY_CONFIG + ) + self.lm_head_dtype = ttnn.bfloat16 if self.model_name == "gemma-3-4b-it" else None self.set_tg_attention_config() self.is_multichip = self.num_devices > 1 @@ -1394,24 +1413,34 @@ def _get_hidden_activation_type(self, config): def _set_model_specific_params(self): # Gemma3 specific params - is_gemma3 = "gemma-3" in self.base_model_name.lower() - if is_gemma3: - self.rms_norm_add_unit_offset = True + self.rms_norm_add_unit_offset = "gemma-3" in 
self.base_model_name.lower() + self.embed_scale = 1.0 if not "gemma-3" in self.base_model_name.lower() else self.dim ** 0.5 def _set_params_from_dict(self, config, is_hf=False): + eos_token_id = config.get("eos_token_id", None) + self.image_token_index = config.get("image_token_index", 262144) + # Try to get text_config, if it doesn't exist everything is text config text_config = config.get("text_config", config) + self.eos_token_id = None if isinstance(eos_token_id, int) else eos_token_id + layer_types = text_config["layer_types"] if "layer_types" in text_config else None # Common params with different names between Meta and HF self.dim = text_config.get("dim", text_config.get("hidden_size")) self.n_heads = text_config.get("n_heads", text_config.get("num_attention_heads")) self.n_kv_heads = text_config.get("n_kv_heads", text_config.get("num_key_value_heads")) self.n_layers = text_config.get("n_layers", text_config.get("num_hidden_layers")) + + self.sliding_window_pattern = ( + [lt == "sliding_attention" for lt in layer_types] if layer_types is not None else [False] * self.n_layers + ) + self.full_model_n_layers = self.n_layers self.norm_eps = text_config.get("norm_eps", text_config.get("rms_norm_eps")) self.vocab_size = text_config["vocab_size"] self.padded_vocab_size = 128 * 1024 if self.is_galaxy else None self.head_dim = text_config.get("head_dim", self.dim // self.n_heads) or self.dim // self.n_heads + self.rope_local_theta = text_config.get("rope_local_base_freq", None) if is_hf: self.max_context_len = text_config.get("max_position_embeddings") else: @@ -1489,18 +1518,18 @@ def _set_params_from_dict(self, config, is_hf=False): self.vision_num_cross_attention_layers = config.get("vision_num_cross_attention_layers", -1) # Vision constants - self.vision_dim = 1280 - self.vision_mlp_ratio = 4 - self.vision_hidden_dim = int(self.vision_dim * self.vision_mlp_ratio) - self.vision_act_layer = ttnn.UnaryOpType.GELU - self.vision_dropout = 0.0 - self.vision_attn_n_heads = 16 - self.vision_head_dim = self.vision_dim // self.vision_attn_n_heads - self.vision_n_layers = 32 - self.vision_n_global_layers = 8 - self.vision_max_num_tiles = 4 - self.vision_patch_size = 14 - self.vision_in_channels = 3 + # self.vision_dim = 1280 + # self.vision_mlp_ratio = 4 + # self.vision_hidden_dim = int(self.vision_dim * self.vision_mlp_ratio) + # self.vision_act_layer = ttnn.UnaryOpType.GELU + # self.vision_dropout = 0.0 + # self.vision_attn_n_heads = 16 + # self.vision_head_dim = self.vision_dim // self.vision_attn_n_heads + # self.vision_n_layers = 32 + # self.vision_n_global_layers = 8 + # self.vision_max_num_tiles = 4 + # self.vision_patch_size = 14 + # self.vision_in_channels = 3 self.state_dict_text_prefix = self._get_text_prefix() self.is_multimodal = "vision_config" in config or self.is_vision() @@ -1576,7 +1605,68 @@ def _set_params(self, checkpoint_dir): else None ) + # def _set_vision_params(self, vision_config): + # self.vision_dim = vision_config.get("hidden_size", 1280) + # self.vision_mlp_ratio = vision_config.get("intermediate_size", self.vision_dim * 4) // self.vision_dim + # self.vision_hidden_dim = vision_config.get("intermediate_size", self.vision_dim * self.vision_mlp_ratio) + # self.vision_attn_n_heads = vision_config.get("num_attention_heads", 16) + # self.vision_head_dim = self.vision_dim // self.vision_attn_n_heads + # self.vision_n_layers = vision_config.get("num_hidden_layers", 32) + # self.vision_patch_size = vision_config.get("patch_size", 14) + # self.vision_in_channels = 
vision_config.get("num_channels", 3) + # self.vision_act_layer = ttnn.UnaryOpType.GELU # or read from config if variable + # self.vision_dropout = vision_config.get("attention_dropout", 0.0) + # self.vision_max_num_tiles = 4 + # self.vision_n_global_layers = 8 + + def _set_vision_params(self, vision_config): + self.vision_chunk_size = vision_config.get("vision_chunk_size", 896) + self.vision_max_num_chunks = vision_config.get("vision_max_num_chunks", 4) + self.vision_num_cross_attention_layers = vision_config.get("vision_num_cross_attention_layers", 8) + self.vision_dim = vision_config.get("hidden_size", 1152) + + intermediate_size = vision_config.get("intermediate_size", self.vision_dim * 4) + self.vision_mlp_ratio = intermediate_size // self.vision_dim + self.vision_hidden_dim = int(self.vision_dim * self.vision_mlp_ratio) + self.vision_attn_n_heads = vision_config.get("num_attention_heads", 16) + self.vision_head_dim = self.vision_dim // self.vision_attn_n_heads + + self.vision_n_layers = vision_config.get("num_hidden_layers", 27) + self.vision_patch_size = vision_config.get("patch_size", 14) + self.vision_in_channels = vision_config.get("num_channels", 3) + + self.vision_dropout = vision_config.get("attention_dropout", 0.0) + self.mm_tokens_per_image = vision_config.get("mm_tokens_per_image", 256) + + # Optional vision activation layer, defaults to GELU + act_layer = vision_config.get("act_layer", "gelu").lower() + self.vision_act_layer = { + "gelu": ttnn.UnaryOpType.GELU, + "relu": ttnn.UnaryOpType.RELU, + "silu": ttnn.UnaryOpType.SILU, + }.get(act_layer, ttnn.UnaryOpType.GELU) + + # Optional tuning knobs + # self.vision_max_num_tiles = vision_config.get("max_num_tiles", 4) + self.vision_n_global_layers = vision_config.get("n_global_layers", 8) + + # # Optional Meta-specific knobs + # self.vision_max_num_chunks = vision_config.get("max_num_chunks", 4) + # self.vision_num_cross_attention_layers = vision_config.get("num_cross_attention_layers", -1) + def _set_hf_params(self, checkpoint_dir): + def merge_text_config(base_config): + text_config = base_config.get("text_config", {}) + # Merge non-nested keys into text_config + text_config.update({k: v for k, v in base_config.items() if k not in ["text_config", "vision_config"]}) + return text_config + + def merge_vision_config(base_config): + vision_config = base_config.get("vision_config", {}) + # Merge non-nested keys into vision_config + vision_config.update({k: v for k, v in base_config.items() if k not in ["text_config", "vision_config"]}) + return vision_config + if self.from_hf_url: # Special case Qwen2.5-VL models until they are fully integrated into a HF release if "Qwen/Qwen2.5-VL" in self.model_name: @@ -1588,17 +1678,30 @@ def _set_hf_params(self, checkpoint_dir): logger.info( f"Loading state param for dummy {self.model_name} from {self.LOCAL_HF_PARAMS[self.model_name]}" ) - self.hf_config = AutoConfig.from_pretrained(self.LOCAL_HF_PARAMS[self.model_name]) + self.hf_config = AutoConfig.from_pretrained(self.LOCAL_HF_PARAMS[self.model_name]).to_dict() else: - self.hf_config = AutoConfig.from_pretrained(self.CKPT_DIR) - - config = self.hf_config.to_dict() + self.hf_config = AutoConfig.from_pretrained(self.CKPT_DIR).to_dict() + + if "text_config" in self.hf_config or "vision_config" in self.hf_config: + if "gemma-3-4b" in self.base_model_name: + self._set_params_from_dict(self.hf_config, is_hf=True) + if "vision_config" in self.hf_config: + merged_vision_config = merge_vision_config(self.hf_config) + 
self._set_vision_params(merged_vision_config) + else: + merged_text_config = merge_text_config(self.hf_config) + self._set_params_from_dict(merged_text_config, is_hf=True) + if "vision_config" in self.hf_config: + merged_vision_config = merge_vision_config(self.hf_config) + self._set_vision_params(merged_vision_config) + else: + self._set_params_from_dict(self.hf_config, is_hf=True) else: config_file = os.path.join(checkpoint_dir, "config.json") assert os.path.exists(config_file), f"config.json file not found at {config_file}" with open(config_file, "r") as f: config = json.load(f) - self._set_params_from_dict(config, is_hf=True) + self._set_params_from_dict(config, is_hf=True) def __repr__(self): return f"""ModelArgs( @@ -1622,15 +1725,37 @@ def __repr__(self): def is_vision(self): return self.vision_chunk_size > 0 - def get_state_dict_prefix(self, module_name, layer_num): - text_prefix = self.state_dict_text_prefix + def get_state_dict_prefix(self, module_name, layer_num, is_vision=False): + if "gemma-3-4b" in self.model_name: + if is_vision: + text_prefix = "model.vision_tower.vision_model.encoder." + + else: + # text_prefix = "model.language_model." + + text_prefix = "" + + else: + text_prefix = self.state_dict_text_prefix + layer_prefix = f"layers.{layer_num}." if layer_num is not None else "" + module_map = { "MLP": "feed_forward", "Attention": "attention", "TransformerBlock": "", "": "", # If no module is given, just get layer prefix } + + vision_module_map = { + "MLP": "mlp.", + "Attention": "self_attn.", + "TransformerBlock": "", + "": "", + } + + module_map = vision_module_map if is_vision else module_map + return text_prefix + layer_prefix + module_map[module_name] def weight_cache_path(self, dtype): @@ -1694,10 +1819,13 @@ def load_state_dict(self): if self.checkpoint_type == CheckpointType.HuggingFace: if self.is_multimodal: - state_dict = standardize_hf_keys_multimodal(state_dict) + if "gemma-3-4b" in self.model_name: + state_dict = convert_vision_hf_to_meta(state_dict, self.head_dim) + else: + state_dict = standardize_hf_keys_multimodal(state_dict) else: state_dict = standardize_hf_keys(state_dict) - state_dict = convert_hf_to_meta(state_dict, self.head_dim) + state_dict = convert_hf_to_meta(state_dict, self.head_dim) keys_dict = list(state_dict.keys())[:] remv = [f"layers.{i}." 
for i in list(range(self.n_layers, self.full_model_n_layers))] @@ -2108,7 +2236,7 @@ def create_tokenizer(self): # Add meta-compatible stop token list to the HF tokenizer if not "stop_tokens" in tokenizer.__dict__: - tokenizer.stop_tokens = [tokenizer.eos_token_id] + tokenizer.stop_tokens = self.eos_token_id if self.eos_token_id is not None else [tokenizer.eos_token_id] return tokenizer def encode_prompt(self, prompt_text, system_prompt_text=None, instruct=True): @@ -2151,11 +2279,8 @@ def reference_transformer(self, wrap=True, load_checkpoint=False): # Special case Qwen2.5-VL models until they are fully integrated into a HF release if "Qwen/Qwen2.5-VL" in self.model_name: from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig as AutoConfig - from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VLForConditionalGeneration as AutoModelForCausalLM, - ) else: - from transformers import AutoConfig, AutoModelForCausalLM + from transformers import AutoConfig, AutoModel # HF is much faster at loading from a checkpoint than generating from config # so use that by preference unless we don't have a checkpoint @@ -2163,16 +2288,16 @@ def reference_transformer(self, wrap=True, load_checkpoint=False): config = AutoConfig.from_pretrained(self.LOCAL_HF_PARAMS[self.model_name]) config.num_layers = self.n_layers config.num_hidden_layers = self.n_layers - model = AutoModelForCausalLM.from_config(config) + model = AutoModel.from_config(config) else: if self.cache_hf_flag and self.cached_hf_model is None: - model = AutoModelForCausalLM.from_pretrained(self.CKPT_DIR) + model = AutoModel.from_pretrained(self.CKPT_DIR) self.cached_hf_model = model elif self.cache_hf_flag and self.cached_hf_model is not None: model = self.cached_hf_model else: # No caching - load fresh each time - model = AutoModelForCausalLM.from_pretrained(self.CKPT_DIR) + model = AutoModel.from_pretrained(self.CKPT_DIR) # HACK: Assume that we want the language model layers only if hasattr(model, "language_model"): model.model = model.language_model @@ -2184,6 +2309,20 @@ def reference_transformer(self, wrap=True, load_checkpoint=False): else: return model + def reference_vision_multi_modal(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.multi_modal_projector + # layer._load_state_dict = layer.load_state_dict + # layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_rms_norm(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.multi_modal_projector.mm_soft_emb_norm + # layer._load_state_dict = layer.load_state_dict + # layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) + return layer + def reference_rms_norm(self): if self.checkpoint_type == CheckpointType.Meta: from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import RMSNorm @@ -2196,6 +2335,109 @@ def reference_rms_norm(self): layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) return layer + def reference_vision_transformer(self, wrap=True, load_checkpoint=False): + if self.checkpoint_type == CheckpointType.HuggingFace: + from transformers import AutoConfig, AutoModelForCausalLM + + if self.dummy_weights and not load_checkpoint: + config = AutoConfig.from_pretrained(self.LOCAL_HF_PARAMS[self.model_name]) + config.num_layers = self.n_layers + config.num_hidden_layers = self.n_layers + model = 
AutoModelForCausalLM.from_config(config) + else: + if "gemma-3" in self.model_name: + from transformers import Gemma3ForConditionalGeneration + + model = Gemma3ForConditionalGeneration.from_pretrained(self.CKPT_DIR) + model = model + else: + if self.cached_hf_model is None: + model = AutoModelForCausalLM.from_pretrained(self.CKPT_DIR) + self.cached_hf_model = model + else: + model = self.cached_hf_model + model.model.layers = model.model.layers[: self.n_layers] + if wrap: + wrapper = HfModelWrapper(model, self.head_dim) + return wrapper + else: + return model + + def reference_gemma_model(self): + model = self.reference_vision_transformer(wrap=False) + layer = model + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_model(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.vision_model + # layer._load_state_dict = layer.load_state_dict + # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_mlp(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.vision_model.encoder.layers[0].mlp + # layer._load_state_dict = layer.load_state_dict + # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_siglip_patch_embed(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.vision_model.embeddings.patch_embedding + # layer._load_state_dict = layer.load_state_dict + # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_pos_embedding(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.vision_model.embeddings.position_embedding + # layer._load_state_dict = layer.load_state_dict + # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_embedding(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.vision_model.embeddings + # layer._load_state_dict = layer.load_state_dict + # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_layernorm(self, layer_name="layer_norm1"): + model = self.reference_vision_transformer(wrap=False) + if layer_name == "layer_norm1": + layer = model.vision_tower.vision_model.encoder.layers[0].layer_norm1 + elif layer_name == "layer_norm2": + layer = model.vision_tower.vision_model.encoder.layers[0].layer_norm2 + else: + layer = model.vision_tower.vision_model.post_layernorm + # layer._load_state_dict = layer.load_state_dict + # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_attention(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.vision_model.encoder.layers[0].self_attn # Common naming + # layer._load_state_dict = layer.load_state_dict + # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_encoder_block(self): + model = self.reference_vision_transformer(wrap=False) + layer = 
model.vision_tower.vision_model.encoder.layers[0] + # layer._load_state_dict = layer.load_state_dict + # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_encoder(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.vision_model.encoder + # layer._load_state_dict = layer.load_state_dict + # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + def reference_mlp(self): if self.checkpoint_type == CheckpointType.Meta: from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import FeedForward @@ -2232,7 +2474,14 @@ def reference_decoder(self): else: model = self.reference_transformer(wrap=False) layer = model.model.layers[0] - wrapper = HfDecoderWrapper(layer, self.head_dim, model.model.rotary_emb) + rotary_emb = model.model.rotary_emb + + if "gemma-3" in self.model_name: + rotary_emb_local = model.model.rotary_emb_local + wrapper = HfGemmaDecoderWrapper(layer, self.head_dim, rotary_emb, rotary_emb_local) + else: + wrapper = HfDecoderWrapper(layer, self.head_dim, rotary_emb) + return wrapper def reference_attention(self): @@ -2243,7 +2492,11 @@ def reference_attention(self): else: model = self.reference_transformer(wrap=False) layer = model.model.layers[0].self_attn - use_position_embeddings = layer.__class__.__name__ in ("Qwen3Attention", "MistralAttention") + use_position_embeddings = layer.__class__.__name__ in ( + "Qwen3Attention", + "MistralAttention", + "Gemma3Attention", + ) wrapper = HfAttentionWrapper( layer, self.head_dim, model.model.rotary_emb if use_position_embeddings else None ) @@ -2397,24 +2650,79 @@ def cache_v(self): class HfDecoderWrapper: - def __init__(self, decoder, head_dim, rotary_emb): + def __init__(self, decoder, head_dim, rotary_emb, rotary_emb_local=None): + from transformers import DynamicCache + + self.decoder = decoder + self.head_dim = head_dim + self.rotary_emb = rotary_emb + self.rotary_emb_local = rotary_emb_local + self.past_key_values = DynamicCache() + + def forward(self, x, start_pos, freqs_cis_i, mask=None): + position_ids = torch.tensor([list(range(start_pos, start_pos + x.shape[1]))] * x.shape[0]) + model_name_env = os.getenv("HF_MODEL") + if "gemma-3-4b" in model_name_env.lower(): + position_embeddings = self.rotary_emb(x, position_ids) + position_embeddings_local = self.rotary_emb_local(x, position_ids) + else: + position_embeddings = self.rotary_emb(x, position_ids) + + if mask is not None: + while len(mask.shape) < 4: + mask = mask.unsqueeze(0) + if self.rotary_emb_local is not None: + result = self.decoder.forward( + x, + position_embeddings_global=position_embeddings, + position_embeddings_local=position_embeddings_local, + past_key_value=self.past_key_values, + use_cache=True, + position_ids=position_ids, + attention_mask=mask, + ) + else: + result = self.decoder.forward( + x, + position_embeddings=position_embeddings, + past_key_value=self.past_key_values, + use_cache=True, + position_ids=position_ids, + attention_mask=mask, + ) + output = result[0] + return output + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def load_state_dict(self, state_dict): + return self.decoder.load_state_dict(convert_meta_to_hf(state_dict, self.head_dim)) + + +class HfGemmaDecoderWrapper: + def __init__(self, decoder, head_dim, rotary_emb, rotary_emb_local): from transformers import DynamicCache self.decoder = decoder 
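+        # Gemma-3 uses two RoPE variants: a global one (rope_theta) for full-attention
+        # layers and a local one (rope_local_base_freq) for sliding-window layers, so this
+        # wrapper keeps both and passes them to the HF decoder layer as
+        # position_embeddings_global / position_embeddings_local in forward().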
self.head_dim = head_dim self.rotary_emb = rotary_emb + self.rotary_emb_local = rotary_emb_local self.past_key_values = DynamicCache() def forward(self, x, start_pos, freqs_cis_i, mask=None): position_ids = torch.tensor([list(range(start_pos, start_pos + x.shape[1]))] * x.shape[0]) - position_embeddings = self.rotary_emb(x, position_ids) + # TODO: Generalize for other HF models + position_embeddings_global = self.rotary_emb(x, position_ids) + position_embeddings_local = self.rotary_emb_local(x, position_ids) if mask is not None: while len(mask.shape) < 4: mask = mask.unsqueeze(0) result = self.decoder.forward( x, - position_embeddings=position_embeddings, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, past_key_value=self.past_key_values, use_cache=True, position_ids=position_ids, @@ -2525,6 +2833,7 @@ def get_math_fidelity(self, decoder_id, op: OpGroup, configuration: ModelArgs): MathFidelitySetting.HIFI2_NA: configuration.compute_kernel_config_hifi2_na, MathFidelitySetting.HIFI2_FP16: configuration.compute_kernel_config_hifi2_fp16, MathFidelitySetting.HIFI4: configuration.compute_kernel_config_hifi4, + MathFidelitySetting.HIFI4_FP32: configuration.compute_kernel_config_hifi4_fp32, } return math_fidelity_setting_lookup[self.decoder_optimizations[decoder_id].op_fidelity_settings[op]] diff --git a/models/tt_transformers/tt/multimodal/gemma/gemma_conv2d_patch.py b/models/tt_transformers/tt/multimodal/gemma/gemma_conv2d_patch.py new file mode 100644 index 000000000000..850f0610d793 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/gemma/gemma_conv2d_patch.py @@ -0,0 +1,123 @@ +""" +This is the Conv2dPath of Gemma-3-4b-it +We have reused the exisiting Conv2dPath of TtLlamaConv2dPath with few modifications. +We have added a check for weight to convert 4D to 2D +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.utility_functions import nearest_32 + + +class TtGemmaConv2dPatch(LightweightModule): + """Conv2D Patching layer. + Column parallel over unfolded input. + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. 
+ Input: (bsz, in_channels, width, height) + Output: (bsz, num_tokens, out_channels) + """ + + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + dtype, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + bias, + ): + super().__init__() + + self.mesh_device = mesh_device + self.num_devices = self.mesh_device.get_num_devices() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + + self.bias = ( + ttnn.as_tensor( + torch.reshape(state_dict[f"{state_dict_prefix}_linear.bias"], (1, -1)), + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + if bias + else None + ) + + self._unfold = torch.nn.Unfold(kernel_size=self.kernel_size, stride=self.stride) + + weight = state_dict[f"{state_dict_prefix}_linear.weight"] + if weight.ndim == 4: + weight = weight.view(out_channels, -1) + pad_len = nearest_32(weight.shape[-1]) - weight.shape[-1] + padding = torch.zeros(self.out_channels, pad_len, dtype=weight.dtype) + padded_weight = torch.cat([weight, padding], dim=-1) + padded_weight = padded_weight.permute(1, 0).reshape(1, 1, -1, self.out_channels) + + self._linear_weight = ttnn.as_tensor( + padded_weight, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + self.compute_kernel_config = ttnn.init_device_compute_kernel_config( + mesh_device.arch(), + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + def forward(self, x: torch.Tensor): + x = self._unfold(x) + x = x.permute(0, 2, 1) + + # Need to pad the last dimension of x to be a multiple of a tile + pad_len = nearest_32(x.shape[-1]) - x.shape[-1] + padding = torch.zeros((x.shape[0], x.shape[1], pad_len), dtype=x.dtype, device=x.device) + x = torch.cat([x, padding], dim=-1) + + x = ttnn.as_tensor( + x, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + out = ttnn.linear( + x, + self._linear_weight, + bias=self.bias, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + core_grid=ttnn.CoreGrid(y=8, x=8), + ) + + ttnn.deallocate(x) + + return out diff --git a/models/tt_transformers/tt/multimodal/gemma/gemma_e2e_model.py b/models/tt_transformers/tt/multimodal/gemma/gemma_e2e_model.py new file mode 100644 index 000000000000..4c427a6c1cb4 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/gemma/gemma_e2e_model.py @@ -0,0 +1,132 @@ +from typing import List + +import torch + +import ttnn +from models.tt_transformers.tt.model import Transformer +from models.tt_transformers.tt.multimodal.gemma.gemma_vision_model import TtGemmaTransformerVision +from models.tt_transformers.tt.multimodal.llama_vision_model import _stack_images + + +def _stack_images( + images: List[List[torch.Tensor]], # batch of samples, each with list of image embeddings +) -> List[torch.Tensor]: + """ + Concatenate image embeddings per sample into a single 2D tensor. 
+ + Args: + images: List of samples, each being a list of [num_patches, hidden_dim] tensors + + Returns: + List of [total_patches, hidden_dim] tensors, one per sample + """ + return [torch.cat(image_list, dim=0) for image_list in images] + + +class TtGemmaModel(Transformer): + def __init__( + self, + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=None, + use_paged_kv_cache=False, + ): + super().__init__( + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=paged_attention_config, + use_paged_kv_cache=use_paged_kv_cache, + ) + + self.vision_model = TtGemmaTransformerVision( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix="model.vision_tower.vision_model.", + dtype=dtype, + configuration=args, + weight_cache_path=weight_cache_path, + ) + + def prepare_inputs_prefill(self, pt_tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): + """ + Inputs are torch tensors or python types. This function returns ttnn + tensors on device. + TODO: Debate whether this function is responsible for padding + """ + + S = pt_tokens.shape[-1] + tokens = ttnn.from_torch( + pt_tokens.reshape(1, 1, 1, -1), + device=self.mesh_device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + self.embed_scale = self.args.dim**0.5 + tokens_embd = self.embd(tokens, self.embed_scale) + + vision_output = self.compute_vision_token(**kwargs) + tokens_embd = ttnn.to_torch(tokens_embd, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1)) + comp_vision_output = ttnn.to_torch( + vision_output, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0) + )[: vision_output.shape[0], :] + + image_features = comp_vision_output.squeeze(0) + special_image_mask = (pt_tokens == self.args.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(tokens_embd) + image_features = image_features.to(tokens_embd.device, tokens_embd.dtype) + tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features) + + tokens_embd = self.args.prepare_residual_tensor_prefill( + tokens_embd, + ) + + tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) + # Slice the rot mats to the prefill seqlen + assert ( + self.rope_setup.cos_matrix.shape[2] >= start_pos + S + ), f"Padded prefill end idx {start_pos + S} exceeds max seq len {self.rope_setup.cos_matrix.shape[2]}" + + tt_rot_mats_prefill_global = [ + self.rope_setup.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + + tt_rot_mats_prefill_local = [ + self.rope_setup_local.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_setup_local.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + + if page_table is not None: + tt_page_table = ttnn.from_torch( + page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_page_table = None + + if chunk_page_table is not None: + tt_chunk_page_table = ttnn.from_torch( + chunk_page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_chunk_page_table = None + + return tokens_embd, [tt_rot_mats_prefill_global, tt_rot_mats_prefill_local], tt_page_table, tt_chunk_page_table + + def compute_vision_token(self, pixel_values): + vision_output = 
self.vision_model(pixel_values) + return vision_output diff --git a/models/tt_transformers/tt/multimodal/gemma/gemma_image_attention.py b/models/tt_transformers/tt/multimodal/gemma/gemma_image_attention.py new file mode 100644 index 000000000000..66c9adad571e --- /dev/null +++ b/models/tt_transformers/tt/multimodal/gemma/gemma_image_attention.py @@ -0,0 +1,385 @@ +""" +This is the ImageAttention block for Gemma-3-4b-it +We have reused the TTLlamaImageAttention with some modification. +We have made the linears (Q,K,V) to be executed separately and added bias support for O_projection, along with few +configuration changes. +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.utility_functions import nearest_32 + + +class TtGemmaImageAttention(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.num_devices = configuration.num_devices + + self.hidden_size = configuration.vision_dim + self.n_heads = configuration.vision_attn_n_heads + self.head_dim = self.hidden_size // self.n_heads + self.n_kv_heads = self.n_heads + + self.n_local_heads = self.n_heads // configuration.num_devices + self.n_local_kv_heads = self.n_kv_heads // configuration.num_devices + + self.dtype = dtype + + self.grid_size = configuration.max_grid_size + + self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 + self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 + self.compute_kernel_config_sdpa = configuration.compute_kernel_config_sdpa + self.configuration = configuration + + self.model_config = configuration.get_model_config() + + if configuration.dummy_weights or (weight_cache_path is None): + cache_name = lambda _: None + else: + cache_name = lambda name: weight_cache_path / (f"{state_dict_prefix}{name}") + + wq_str = f"{state_dict_prefix}wq.weight" + wk_str = f"{state_dict_prefix}wk.weight" + wv_str = f"{state_dict_prefix}wv.weight" + wo_str = f"{state_dict_prefix}wo.weight" + + # when splitting the devices, we need to make sure that the number of heads is divisible by the number of devices + assert self.n_heads % configuration.num_devices == 0 + assert self.n_kv_heads % configuration.num_devices == 0 + + # Pad head_dim to multiple of 32 + def pad_head_dim(weight, heads_out=True): + # Pad head dim to multiple of 32 + # heads_out means that the output dim of this weight contains heads. 
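+            # Illustrative numbers (assuming the SigLIP-style vision config commonly used for
+            # Gemma-3-4b: hidden_size 1152 with 16 heads, so head_dim = 72): nearest_32(72) = 96,
+            # i.e. each 72-wide head is zero-padded to 96 and the padded heads dim becomes 16 * 96.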
+ dim = weight.shape[1] + assert weight.shape[0] == dim + padded_head_dim = nearest_32(self.head_dim) + padding_size = padded_head_dim - self.head_dim + if padding_size > 0: + if heads_out: + weight = weight.transpose(-1, -2) + weight = weight.reshape(dim, self.n_heads, self.head_dim) + padding = torch.zeros(dim, self.n_heads, padding_size, dtype=weight.dtype) + weight = torch.cat([weight, padding], dim=-1) + weight = weight.reshape(dim, self.n_heads * padded_head_dim) + if heads_out: + weight = weight.transpose(-1, -2) + return weight + + wq_padded = pad_head_dim(self.state_dict[wq_str]) + wk_padded = pad_head_dim(self.state_dict[wk_str]) + wv_padded = pad_head_dim(self.state_dict[wv_str]) + wo_padded = pad_head_dim(self.state_dict[wo_str], heads_out=False) + wq_chunked, wk_chunked, wv_chunked = ( + torch.chunk(w, configuration.num_devices) for w in [wq_padded, wk_padded, wv_padded] + ) + + # for Gemma + self.wq = ttnn.as_tensor( + tensor=wq_padded, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + cache_file_name=cache_name("wq_sharded"), + preprocess=lambda x: x.transpose(-2, -1), + ) + + self.wk = ttnn.as_tensor( + tensor=wk_padded, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + cache_file_name=cache_name("wk_sharded"), + preprocess=lambda x: x.transpose(-2, -1), + ) + + self.wv = ttnn.as_tensor( + tensor=wv_padded, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + cache_file_name=cache_name("wv_sharded"), + preprocess=lambda x: x.transpose(-2, -1), + ) + + self.wqkv = ttnn.as_tensor( + torch.concat( + [ + torch.concat( + [ + torch.transpose( + wq_chunked[i], + -2, + -1, + ), + torch.transpose( + wk_chunked[i], + -2, + -1, + ), + torch.transpose( + wv_chunked[i], + -2, + -1, + ), + ], + dim=-1, + ) + for i in range(configuration.num_devices) + ], + dim=-1, + ), + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + dtype=self.dtype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + cache_file_name=cache_name("wqkv_sharded"), + ) + + bq_str = f"{state_dict_prefix}wq.bias" + bk_str = f"{state_dict_prefix}wk.bias" + bv_str = f"{state_dict_prefix}wv.bias" + bo_str = f"{state_dict_prefix}wo.bias" + + if bq_str in self.state_dict: + + def pad_head_dim_bias(bias): + # Pad 1D bias to match padded head dim + dim = bias.shape[0] + assert ( + dim == self.n_heads * self.head_dim + ), f"Expected bias of shape ({self.n_heads} * {self.head_dim}) = {self.n_heads * self.head_dim}, but got {dim}" + + padded_head_dim = nearest_32(self.head_dim) + padding_size = padded_head_dim - self.head_dim + + if padding_size > 0: + bias = bias.view(self.n_heads, self.head_dim) + padding = torch.zeros(self.n_heads, padding_size, dtype=bias.dtype) + bias = torch.cat([bias, padding], dim=-1) + bias = bias.view(self.n_heads * padded_head_dim) + + return bias + + bq_padded = pad_head_dim_bias(self.state_dict[bq_str]) + bk_padded = pad_head_dim_bias(self.state_dict[bk_str]) + bv_padded = pad_head_dim_bias(self.state_dict[bv_str]) + + bq_chunked, bk_chunked, bv_chunked = ( + torch.chunk(b, configuration.num_devices) for b in [bq_padded, bk_padded, bv_padded] + ) + + 
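+            # The fused qkv bias below mirrors the fused wqkv weight layout: for every device
+            # chunk i, the head-padded q, k and v biases are concatenated on the last dim, and
+            # the resulting per-device blocks are sharded across the mesh with dim=-1.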
self.bqkv = ttnn.as_tensor( + torch.concat( + [ + torch.concat( + [ + bq_chunked[i], + bk_chunked[i], + bv_chunked[i], + ], + dim=-1, + ) + for i in range(configuration.num_devices) + ], + dim=-1, + ), + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + dtype=self.dtype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + cache_file_name=cache_name("bqkv_sharded"), + ) + + # for Gemma + self.bq = ttnn.as_tensor( + tensor=bq_padded, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + cache_file_name=cache_name("bq_sharded"), + ) + + self.bk = ttnn.as_tensor( + tensor=bk_padded, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + cache_file_name=cache_name("bk_sharded"), + ) + + self.bv = ttnn.as_tensor( + tensor=bv_padded, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + cache_file_name=cache_name("bv_sharded"), + ) + + else: + self.bqkv = None + + self.wo = ttnn.as_tensor( + torch.transpose( + wo_padded, + -2, + -1, + ), + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=self.dtype, + layout=ttnn.TILE_LAYOUT, + # cache_file_name=cache_name("wo_sharded"), + ) + + if bo_str in self.state_dict: + self.bo = ttnn.as_tensor( + self.state_dict[bo_str], + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=self.dtype, + layout=ttnn.TILE_LAYOUT, + cache_file_name=cache_name("bo_sharded"), + ) + else: + self.bo = None + + self.scale = self.head_dim**-0.5 + + def forward(self, x_11SH, mask=None): + seq_len = x_11SH.shape[-2] + + MAX_MM_SEQ_LEN = ( + seq_len if "gemma-3" in self.configuration.base_model_name else self.configuration.VISION_MAX_MM_SEQ + ) + + if seq_len > MAX_MM_SEQ_LEN: + x_11SH = ttnn.reshape(x_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) + + q_heads_1QSD = ttnn.linear( + x_11SH, + self.wq, + bias=self.bq, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config_hifi4, + program_config=None + if "gemma-3" in self.configuration.base_model_name + else self.model_config["IMAGE_ATTN_QKV_PROGCFG"](seq_len, MAX_MM_SEQ_LEN), + ) + + q_heads_1QSD = ttnn.transpose(ttnn.reshape(q_heads_1QSD, (1, seq_len, self.n_local_heads, -1)), 1, 2) + + k_heads_1KSD = ttnn.linear( + x_11SH, + self.wk, + bias=self.bk, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config_hifi4, + program_config=None + if "gemma-3" in self.configuration.base_model_name + else self.model_config["IMAGE_ATTN_QKV_PROGCFG"](seq_len, MAX_MM_SEQ_LEN), + ) + + k_heads_1KSD = ttnn.transpose(ttnn.reshape(k_heads_1KSD, (1, seq_len, self.n_local_heads, -1)), 1, 2) + + v_heads_1VSD = ttnn.linear( + x_11SH, + self.wv, + bias=self.bv, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config_hifi4, + program_config=None + if "gemma-3" in self.configuration.base_model_name + else self.model_config["IMAGE_ATTN_QKV_PROGCFG"](seq_len, 
MAX_MM_SEQ_LEN), + ) + v_heads_1VSD = ttnn.transpose(ttnn.reshape(v_heads_1VSD, (1, seq_len, self.n_local_heads, -1)), 1, 2) + + # TODO: get this from model_config + sdpa_cfg = ttnn.SDPAProgramConfig( + compute_with_storage_grid_size=(8, 8), q_chunk_size=128, k_chunk_size=128, exp_approx_mode=False + ) + attn_output_1QSD = ttnn.transformer.scaled_dot_product_attention( + q_heads_1QSD, + k_heads_1KSD, + v_heads_1VSD, + is_causal=False, + scale=self.scale, + attn_mask=mask, + program_config=sdpa_cfg, + compute_kernel_config=self.compute_kernel_config_sdpa, + ) + # deallocate keys and values + ttnn.deallocate(q_heads_1QSD) + ttnn.deallocate(k_heads_1KSD) + ttnn.deallocate(v_heads_1VSD) + + ### + # Output matmul + ### + attn_output_11SH = ttnn.experimental.nlp_concat_heads( + attn_output_1QSD, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + ttnn.deallocate(attn_output_1QSD) + + # reshaping long sequence to matmul fit on device + if seq_len > MAX_MM_SEQ_LEN: + attn_output_11SH = ttnn.reshape(attn_output_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) + + if self.num_devices > 1: + # self.bo = ttnn.all_gather(self.bo, dim=3, num_links=1) + attn_output_11SH = ttnn.all_gather(attn_output_11SH, dim=3, num_links=1) + + output_11SH = ttnn.linear( + attn_output_11SH, + self.wo, + bias=self.bo, + compute_kernel_config=self.compute_kernel_config_hifi4, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + program_config=None + if "gemma-3" in self.configuration.base_model_name + else self.model_config["IMAGE_ATTN_QKV_PROGCFG"](seq_len, MAX_MM_SEQ_LEN), + ) + if seq_len > MAX_MM_SEQ_LEN: + output_11SH = ttnn.reshape(output_11SH, [1, 1, seq_len, -1]) + ttnn.deallocate(attn_output_11SH) + + return output_11SH diff --git a/models/tt_transformers/tt/multimodal/gemma/gemma_image_block.py b/models/tt_transformers/tt/multimodal/gemma/gemma_image_block.py new file mode 100644 index 000000000000..4e2998dd107b --- /dev/null +++ b/models/tt_transformers/tt/multimodal/gemma/gemma_image_block.py @@ -0,0 +1,114 @@ +""" +This is the ImageTransformer block for Gemma-3-4b-it. +We have reused the TtLlamaImageTransformerBlock with incorporating the +TtGemmaImageAttention and TtGemmaImageFeedForward +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.gemma.gemma_image_attention import TtGemmaImageAttention +from models.tt_transformers.tt.multimodal.gemma.gemma_image_mlp import TtGemmaImageFeedForward +from models.tt_transformers.tt.multimodal.llama_layernorm import TtLayerNorm + + +class TtGemmaImageTransformerBlock(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + gated=False, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.num_devices = configuration.num_devices + self.hidden_size = configuration.vision_dim + self.gated = gated + + self.ln_1 = TtLayerNorm( + device=mesh_device, + dim=configuration.vision_dim, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}ln_1.", + weight_cache_path=weight_cache_path, + weight_dtype=dtype, + eps=configuration.norm_eps, + ) + + self.attn = TtGemmaImageAttention( + mesh_device, + state_dict, + state_dict_prefix=f"{state_dict_prefix}attn.", + weight_cache_path=weight_cache_path, + dtype=dtype, + configuration=configuration, + ) + + self.ln_2 = TtLayerNorm( + device=mesh_device, + dim=configuration.vision_dim, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}ln_2.", + weight_cache_path=weight_cache_path, + weight_dtype=dtype, + eps=configuration.norm_eps, + ) + + self.mlp = TtGemmaImageFeedForward( + mesh_device=mesh_device, + args=configuration, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}mlp.", + weight_cache_path=weight_cache_path, + dtype=dtype, + ) + + if gated: + # Gate tensors must be expanded to hidden dim or we get a PCC error + self.gate_attn = ttnn.as_tensor( + state_dict[f"{state_dict_prefix}gate_attn"].unsqueeze(0).expand(1, self.hidden_size), + dtype=ttnn.bfloat16, + device=self.mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + self.gate_ffn = ttnn.as_tensor( + state_dict[f"{state_dict_prefix}gate_ffn"].unsqueeze(0).expand(1, self.hidden_size), + dtype=ttnn.bfloat16, + device=self.mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + def forward(self, x_11SH, mask=None): + seq_len = x_11SH.shape[-2] + assert seq_len % 32 == 0 and seq_len > 0, "Seqlen must be divisible by 32" + + attn_out = self.attn(self.ln_1(x_11SH), mask=mask) + if self.gated: + attn_out = ttnn.mul(attn_out, ttnn.tanh(self.gate_attn)) + + if self.num_devices > 1: + attn_out = ttnn.all_gather(attn_out, dim=3, num_links=1) + res = ttnn.add(x_11SH, attn_out) + + mlp_out = self.mlp(self.ln_2(res)) + if self.gated: + mlp_out = ttnn.mul(mlp_out, ttnn.tanh(self.gate_ffn)) + out = ttnn.add(res, mlp_out) + + ttnn.deallocate(mlp_out) + ttnn.deallocate(attn_out) + ttnn.deallocate(res) + return out diff --git a/models/tt_transformers/tt/multimodal/gemma/gemma_image_mlp.py b/models/tt_transformers/tt/multimodal/gemma/gemma_image_mlp.py new file mode 100644 index 000000000000..dd307cfe66cc --- /dev/null +++ b/models/tt_transformers/tt/multimodal/gemma/gemma_image_mlp.py @@ -0,0 +1,120 @@ +""" +This is the FeedForward submodule for vision block in Gemma-3-4b-it +We have reused the TtLlamaImageFeedForward with few changes in CoreGrid and program_config configurations +""" + +# 
SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +class TtGemmaImageFeedForward(LightweightModule): + def __init__( + self, + mesh_device, + args, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.args = args + self.model_config = args.get_model_config() + torch_weight = lambda name, suffix: torch.transpose( + self.state_dict[f"{state_dict_prefix}{name}.{suffix}"], -2, -1 + ) + torch_bias = lambda name, suffix: self.state_dict[f"{state_dict_prefix}{name}.{suffix}"] + + if args.dummy_weights: + cache_name = lambda *_: None + else: + cache_name = lambda name, suffix: weight_cache_path / (state_dict_prefix + f"{name}.{suffix}") + + as_interleaved_tensor = lambda name, suffix, type, dim: ttnn.as_tensor( + ( + torch_weight(name, suffix) if suffix == "weight" else torch_bias(name, suffix) + ), # Grab only the wX part of the name + dtype=type, + device=self.mesh_device, + mesh_mapper=( + ttnn.ShardTensorToMesh(self.mesh_device, dim=dim) + if dim is not None + else ttnn.ReplicateTensorToMesh(self.mesh_device) + ), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + cache_file_name=cache_name(name, suffix), + ) + + # Sharded weights + self.c_fc_weight = as_interleaved_tensor("c_fc", "weight", dtype, dim=-1) + self.c_fc_bias = as_interleaved_tensor("c_fc", "bias", ttnn.bfloat16, dim=-1) + self.c_fc_bias = ttnn.reshape(self.c_fc_bias, [1, -1]) + self.c_proj_weight = as_interleaved_tensor("c_proj", "weight", dtype, dim=-2) + self.c_proj_bias = as_interleaved_tensor("c_proj", "bias", ttnn.bfloat16, dim=None) + + def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: + """ + w1 -> gate_proj + w2 -> down_proj + w3 -> up_proj + HF reference: self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + """ + seq_len = x.shape[-2] + + # Depends on whether we are padding or not + MAX_MM_SEQ_LEN = seq_len if "gemma-3" in self.args.base_model_name else self.args.VISION_MAX_MM_SEQ + + x_in = x + if seq_len >= MAX_MM_SEQ_LEN: # Too big to compute. Set different program configs based on seqlen + # Reshape input to to fit on device and parallelize computation + x_in = ttnn.reshape(x_in, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) + pc_1 = self.model_config["IMAGE_MLP_FC_PROGCFG"](seq_len, MAX_MM_SEQ_LEN) + pc_2 = self.model_config["IMAGE_MLP_PROJ_PROGCFG"](seq_len, MAX_MM_SEQ_LEN) + + # These use HiFi2; this drops 1 bit of the activations but would be FLOP-bound on 12 cores with HiFi4 + c_fc_out = ttnn.linear( + x_in, + self.c_fc_weight, + bias=self.c_fc_bias, + compute_kernel_config=self.args.compute_kernel_config_hifi4, + # core_grid=ttnn.CoreGrid(y=8, x=8) if not pc_1 else None, + dtype=ttnn.bfloat16, + # program_config=pc_1, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + activation="gelu", # NOTE: activation must be passed to linear here, not in program config! 
Bad output otherwise + ) + + c_proj_out = ttnn.linear( + c_fc_out, + self.c_proj_weight, + compute_kernel_config=self.args.compute_kernel_config_hifi4, + # core_grid=ttnn.CoreGrid(y=8, x=8) if not pc_2 else None, + dtype=ttnn.bfloat16, + # program_config=pc_2, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + # NOTE: Need to reshape to 4D so that fast_reduce_nc hsa a dim1 to work on + c_proj_out = ttnn.reshape(c_proj_out, [1, 1, seq_len, -1]) + + # All reduce + if self.args.num_devices > 1: # replace with reduce_scatter and all_gather + w2_out_gathered = ttnn.all_gather(c_proj_out, dim=1, num_links=1, topology=ttnn.Topology.Linear) + pre_bias_output = ttnn.experimental.fast_reduce_nc( + w2_out_gathered, dims=[1], output=None, compute_kernel_config=None + ) + else: + pre_bias_output = c_proj_out + + output = ttnn.add(pre_bias_output, self.c_proj_bias) + return output diff --git a/models/tt_transformers/tt/multimodal/gemma/gemma_image_transformer.py b/models/tt_transformers/tt/multimodal/gemma/gemma_image_transformer.py new file mode 100644 index 000000000000..2d1bc924d7c9 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/gemma/gemma_image_transformer.py @@ -0,0 +1,64 @@ +""" +This is the Entire ImageTransformer for Gemma-3-4b-it. +We have adapted the TtGemmaImageTransformerBlock from TtLlamaImageTransformerBlock +with changes incorporating the GemmaImageAttention and GemmaImageFeedForward +""" +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from tqdm import tqdm + +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.gemma.gemma_image_block import TtGemmaImageTransformerBlock + + +class TtGemmaImageTransformer(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + layers, + block_key="resblocks", + gated=False, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.gated = gated + + self.resblocks = [ + TtGemmaImageTransformerBlock( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}{block_key}.{i}.", + weight_cache_path=weight_cache_path, + dtype=dtype, + configuration=configuration, + gated=gated, + ) + for i in tqdm(range(layers), desc=f"Loading vision transformer layers") + ] + + def forward(self, x, return_intermediate=None, mask=None): + """ + Different from reference impl in that if return_intermediates, it returns + a list of intermediate tensors rather than a stack of intermediates. + Outer code will have to be aware and handle this correctly. + """ + seq_len = x.shape[-2] + assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128" + + out = [] + for idx, r in enumerate(self.resblocks): + if return_intermediate is not None and idx in return_intermediate: + out.append(x) + x = r(x, mask=mask) + if return_intermediate is not None: + return x, out + return x diff --git a/models/tt_transformers/tt/multimodal/gemma/gemma_vision_block.py b/models/tt_transformers/tt/multimodal/gemma/gemma_vision_block.py new file mode 100644 index 000000000000..208b84616546 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/gemma/gemma_vision_block.py @@ -0,0 +1,106 @@ +""" +This is the Vision Tower Model for Gemma-3-4b-it. +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.gemma.gemma_image_transformer import TtGemmaImageTransformer +from models.tt_transformers.tt.multimodal.gemma.siglip_vision_embedding import TtSiglipVisionEmbeddings +from models.tt_transformers.tt.multimodal.llama_layernorm import TtLayerNorm + + +class TtSiglipGemmaVisionModel(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + dtype, + configuration, + weight_cache_path=None, + return_intermediate=None, + ): + super().__init__() + self.state_dict = state_dict + self.mesh_device = mesh_device + + self.image_size = configuration.vision_chunk_size + self.patch_size = configuration.vision_patch_size + + self.width = configuration.vision_dim + self.layers = configuration.vision_n_layers + self.heads = configuration.vision_attn_n_heads + self.mlp_ratio = configuration.vision_mlp_ratio + self.act_layer = configuration.vision_act_layer + self.in_channels = configuration.vision_in_channels + self.n_global_layers = configuration.vision_n_global_layers + self.return_intermediate = return_intermediate + + self.embeddings = TtSiglipVisionEmbeddings( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}embeddings.", + dtype=dtype, + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.in_channels, + hidden_dim=self.width, + bias=True, + ) + + # transformer + self.encoder = TtGemmaImageTransformer( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}encoder.", + weight_cache_path=configuration.weight_cache_path(dtype), + dtype=dtype, + configuration=configuration, + layers=self.layers, + block_key="layers", + ) + + self.prepare_residual_tensor_prefill = configuration.prepare_residual_tensor_prefill + + self.ln_post = TtLayerNorm( + device=mesh_device, + dim=self.width, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}ln_post.", + weight_cache_path=configuration.weight_cache_path(dtype), + weight_dtype=dtype, + eps=configuration.norm_eps, + ) + + def forward(self, images): + assert isinstance( + images, torch.Tensor + ), "VisionEncoder input must be a torch tensor because of unfold in self.conv1" + + bsz, in_channel, h, w = images.shape + + x = self.embeddings(images) + attention_mask = torch.zeros(bsz, 1, x.shape[1], x.shape[1]) + + tt_mask = ttnn.from_torch( + attention_mask, + device=self.mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + x = self.encoder( + x, + mask=tt_mask, + ) + + x = self.ln_post(x) + + return x diff --git a/models/tt_transformers/tt/multimodal/gemma/gemma_vision_model.py b/models/tt_transformers/tt/multimodal/gemma/gemma_vision_model.py new file mode 100644 index 000000000000..42b6f58c5462 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/gemma/gemma_vision_model.py @@ -0,0 +1,67 @@ +""" +This is the Vision Transformer Block for Gemma-3-4b-it. +This involves vision followed by MultiModalProjector processing +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + + +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.gemma.gemma_vision_block import TtSiglipGemmaVisionModel +from models.tt_transformers.tt.multimodal.gemma.multi_modal_projector import TtGemma3MultiModalProjector + + +class TtGemmaTransformerVision(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + dtype, + configuration, + weight_cache_path=None, + return_intermediate=None, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.model_config = configuration.get_model_config() + + self.dim = configuration.dim + self.vision_dim = configuration.vision_dim + self.image_res = configuration.vision_chunk_size + self.patch_size = configuration.vision_patch_size + self.configuration = configuration + + self.vision_encoder = TtSiglipGemmaVisionModel( + mesh_device, + state_dict, + f"{state_dict_prefix}", + weight_cache_path=configuration.weight_cache_path(dtype), + dtype=dtype, + configuration=configuration, + return_intermediate=return_intermediate, + ) + + self.mmp = TtGemma3MultiModalProjector( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix="model.multi_modal_projector", + image_size=configuration.vision_chunk_size, + patch_size=configuration.vision_patch_size, + hidden_size=configuration.vision_hidden_dim, + mm_tokens_per_image=configuration.mm_tokens_per_image, + weight_cache_path=configuration.weight_cache_path(dtype), + layer_norm_eps=1e-06, # layer_norm_eps + dtype=dtype, + configuration=configuration, + ) + + def forward(self, images): + vision_tokens = self.vision_encoder(images)[0, :, :, :] + + vision_tokens = self.mmp(vision_tokens) + return vision_tokens diff --git a/models/tt_transformers/tt/multimodal/gemma/gemma_vision_rmsnorm.py b/models/tt_transformers/tt/multimodal/gemma/gemma_vision_rmsnorm.py new file mode 100644 index 000000000000..f3f3d801ac37 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/gemma/gemma_vision_rmsnorm.py @@ -0,0 +1,172 @@ +""" +This is the modified version of the RMSNorm for Gemma-3-4b-it model. + +We have modified the RMSNorm implementation equivalent to RMSNorm in Gemma-3-4b-it. +We have handled the unit offset addition in the RMSNorm implementation directly into the TTNN Weights +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule + +TILE = 32 +SHARD_HEIGHT = TILE # Current ttnn.rms_norm implementation requires shard height to be a single tile + + +class RMSNorm(LightweightModule): + """ + RMSNorm supporting replication over a MeshDevice and sharding within devices. + + This class implements a Root Mean Square Normalization (RMSNorm) that can be + distributed across multiple devices and cores. If the `device` parameter is a + MeshDevice, the weights and computations are replicated across all devices in + the mesh. Expects an interleaved input tensor, can optionally output a sharded tensor. + + Args: + device: The device or MeshDevice on which to perform the computations. + state_dict: The state dictionary containing the model parameters. + dim: Input dimension (e.g. model hidden dimension size). + layer_num: The layer number to determine the weight key in the state dictionary. + weight_key: The key for retrieving the weight from the state dictionary. + weight_cache_path: Optional path for caching the tilized weights. 
+ weight_memory_config: Configuration for the weight memory, default is DRAM_MEMORY_CONFIG. + weight_dtype: The data type for the tensors, bfp8_b hits >0.999 PCC in the models we tested. + model_config: Optional configuration dictionary for the model. + eps (float): Small value to avoid division by zero in normalization, default is 1e-05. + + If model_config is provided, it must specify SHARDED_NORM_INPUT_MEMCFG, SHARDED_NORM_PRGM_CFG + and SHARDED_NORM_OUTPUT_MEMCFG. If not provided, default configurations will be generated. + """ + + def __init__( + self, + device, + dim, + state_dict, + weight_key, + layer_num=None, + state_dict_prefix=None, + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps: float = 1e-06, + add_unit_offset=True, + sharded_program_config=None, + sharded_output_config=None, + output_mem_config=None, + ccl_topology=ttnn.Topology.Ring, + ): + super().__init__() + self.eps = eps + self.is_distributed = is_distributed + self.ccl_topology = ccl_topology + + if state_dict_prefix: + weight_name = f"{state_dict_prefix}{weight_key}.weight" + else: + if layer_num is None: + weight_name = f"{weight_key}.weight" + else: + weight_name = f"layers.{layer_num}.{weight_key}.weight" + + torch_weight = ( + state_dict[weight_name].unsqueeze(0).view(1, 1, dim).reshape([1, 1, dim // SHARD_HEIGHT, SHARD_HEIGHT]) + ) + if add_unit_offset: + torch_weight = torch_weight + 1.0 + + # # Add offset before caching + cache_name = None if weight_cache_path is None else weight_cache_path / weight_name + + # Compatibility with models that don't use mesh devices (e.g. single-chip Mistral-7b) + is_mesh_device = device.__class__.__name__ == "MeshDevice" + + self.weight = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=cache_name, + mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + ) + + if self.is_distributed: + self.weight_distributed = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=cache_name, + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape)) + if is_mesh_device + else None, + ) + + self.sharded_output_config = sharded_output_config + self.sharded_program_config = sharded_program_config + self.output_mem_config = output_mem_config + + self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=False, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) -> ttnn.Tensor: + # If input is sharded do sharded RMSNorm and optionally return sharded output + program_config = self.sharded_program_config if in_sharded else None + memory_config = self.sharded_output_config if out_sharded else None + distributed = self.is_distributed and self.is_distributed(mode) + norm = self._distributed_rmsnorm + weight = self.weight_distributed if distributed else self.weight + + if in_sharded: + assert not distributed, "Distributed RMSNorm does not support sharded inputs" + else: + assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor" + + x = norm( + x, + epsilon=self.eps, + weight=weight, + program_config=program_config, + memory_config=memory_config, + 
compute_kernel_config=self.compute_kernel_config_hifi2, + ) + + if in_sharded and not out_sharded: + return ttnn.sharded_to_interleaved(x) + else: + return x + + def _distributed_rmsnorm( + self, inp, epsilon=1e-6, weight=None, program_config=None, memory_config=None, compute_kernel_config=None + ): + inp = ttnn.sharded_to_interleaved(inp) + + xnorm = ttnn.pow(inp, 2) + + xnorm = ttnn.mean(xnorm, dim=-1, keepdim=True) + + xnorm = ttnn.rsqrt(xnorm + epsilon) + + xnorm = ttnn.multiply(inp, xnorm) + + weight = ttnn.reshape(weight, [1, 1, 1, -1]) + + output = ttnn.multiply(xnorm, weight) + + if memory_config is not None: + output = ttnn.to_memory_config(output, memory_config) + + ttnn.deallocate(xnorm) + ttnn.deallocate(inp) + + return output diff --git a/models/tt_transformers/tt/multimodal/gemma/multi_modal_projector.py b/models/tt_transformers/tt/multimodal/gemma/multi_modal_projector.py new file mode 100644 index 000000000000..4ba2280c9120 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/gemma/multi_modal_projector.py @@ -0,0 +1,131 @@ +""" +This is the implmentation of MultiModalprojector for Gemma-3-4b-it model. +There is no Independent MultiModalprojector support in TT-Transformers. +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.gemma.gemma_vision_rmsnorm import RMSNorm + + +class TtGemma3MultiModalProjector(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + image_size, + patch_size, + hidden_size, + mm_tokens_per_image, + weight_cache_path, + layer_norm_eps, + dtype, + configuration, + ): + super().__init__() + self.mesh_device = mesh_device + self.dtype = dtype + + self.patches_per_image = int(image_size // patch_size) + self.tokens_per_side = int(mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.hidden_size = hidden_size + + weight_key = state_dict_prefix + ".mm_input_projection_weight" + weight = state_dict[weight_key] + + if configuration.dummy_weights or (weight_cache_path is None): + cache_name = lambda _: None + else: + cache_name = lambda name: weight_cache_path / (f"{state_dict_prefix}{name}") + + # Pad dimensions to multiples of 32 + padded_vision_size = ((hidden_size + 31) // 32) * 32 + + if padded_vision_size != hidden_size: + padding = torch.zeros(hidden_size, padded_vision_size - hidden_size, dtype=weight.dtype) + weight = torch.cat([weight, padding], dim=-1) + + self.mm_input_projection_weight = ttnn.as_tensor( + weight, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + # cache_file_name=cache_name("mm_input_projection_weight"), # pcc drop fix later + ) + + # # Create RMSNorm layer + weight_key = state_dict_prefix + ".mm_soft_emb_norm" + self.mm_soft_emb_norm = RMSNorm( + device=mesh_device, + dim=1152, + state_dict=state_dict, + state_dict_prefix="", + weight_key=weight_key, + weight_dtype=dtype, + is_distributed=False, + # sharded_program_config=tt_model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], + # sharded_output_config=tt_model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], + ) + + def forward(self, vision_outputs: ttnn.Tensor) -> ttnn.Tensor: + batch_size, _, seq_length = vision_outputs.shape + mode = "decode" if seq_length <= 32 else "prefill" + + 
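+        # Worked example (assuming the published Gemma-3 vision settings: 896x896 input images,
+        # patch size 14, mm_tokens_per_image = 256): patches_per_image = 64 and tokens_per_side = 16,
+        # so kernel_size = 4 and the 4096 patch tokens are average-pooled down to 16 * 16 = 256 tokens.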
# Reshape: [batch, seq, hidden] -> [batch, hidden, seq] + reshaped_vision_outputs = ttnn.transpose(vision_outputs, 1, 2) + + ttnn.deallocate(vision_outputs) + + reshaped_vision_outputs = ttnn.reshape( + reshaped_vision_outputs, (batch_size, seq_length, self.patches_per_image, self.patches_per_image) + ) + + in_n, in_c, in_h, in_w = reshaped_vision_outputs.shape + reshaped_vision_outputs = ttnn.to_layout(reshaped_vision_outputs, ttnn.ROW_MAJOR_LAYOUT) + reshaped_vision_outputs = ttnn.permute(reshaped_vision_outputs, (0, 2, 3, 1)) + reshaped_vision_outputs = ttnn.reshape(reshaped_vision_outputs, (1, 1, in_n * in_h * in_w, in_c)) + pooled_vision_outputs = ttnn.avg_pool2d( + reshaped_vision_outputs, + batch_size=in_n, + input_h=in_h, + input_w=in_w, + channels=in_c, + kernel_size=(self.kernel_size, self.kernel_size), + stride=(self.kernel_size, self.kernel_size), + padding=(0, 0), + ceil_mode=False, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + applied_shard_scheme=ttnn.TensorMemoryLayout.BLOCK_SHARDED, + ) + # transpose + HOUT = ((in_h - self.kernel_size) // self.kernel_size) + 1 + WOUT = ((in_w - self.kernel_size) // self.kernel_size) + 1 + pooled_vision_outputs = ttnn.reshape(pooled_vision_outputs, (in_n, HOUT, WOUT, in_c)) + + pooled_vision_outputs = ttnn.permute(pooled_vision_outputs, (0, 3, 1, 2)) + pooled_vision_outputs = ttnn.to_layout(pooled_vision_outputs, ttnn.TILE_LAYOUT) + + pooled_vision_outputs = ttnn.reshape( + pooled_vision_outputs, (pooled_vision_outputs.shape[0], pooled_vision_outputs.shape[1], -1) + ) + + # # Flatten(2) + pooled_vision_outputs = ttnn.transpose(pooled_vision_outputs, 1, 2) + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs, mode=mode) + self.mm_input_projection_weight = ttnn.to_layout(self.mm_input_projection_weight, ttnn.TILE_LAYOUT) + projected_vision_outputs = ttnn.matmul(normed_vision_outputs, self.mm_input_projection_weight) + + ttnn.deallocate(pooled_vision_outputs) + ttnn.deallocate(normed_vision_outputs) + + return projected_vision_outputs diff --git a/models/tt_transformers/tt/multimodal/gemma/siglip_vision_embedding.py b/models/tt_transformers/tt/multimodal/gemma/siglip_vision_embedding.py new file mode 100644 index 000000000000..c4f5cad74ff0 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/gemma/siglip_vision_embedding.py @@ -0,0 +1,79 @@ +""" +This is the VisionEmbedding implementation for the Gemma-3-4b-it +This implementation combines patch_conv followed by Embeddings as a submodule. +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.gemma.gemma_conv2d_patch import TtGemmaConv2dPatch + + +class TtSiglipVisionEmbeddings(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + dtype, + image_size, + patch_size, + num_channels, + hidden_dim, + bias=True, + ): + super().__init__() + + self.image_size = image_size + self.patch_size = patch_size + self.hidden_dim = hidden_dim + self.num_channels = num_channels + self.mesh_device = mesh_device + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_ids = ttnn.arange(0, self.num_positions, 1, dtype=ttnn.uint32, device=self.mesh_device) + self.position_ids = ttnn.reshape(self.position_ids, (1, -1)) + + self.patch_embed = TtGemmaConv2dPatch( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}patch_embedding.", + dtype=dtype, + in_channels=num_channels, + out_channels=hidden_dim, + kernel_size=patch_size, + stride=patch_size, + bias=bias, + ) + + # Positional embedding + positional_embedding = state_dict[f"{state_dict_prefix}position_embedding.positional_embedding"] + + self.pos_emb_weights = ttnn.as_tensor( + positional_embedding, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + def forward(self, pixel_values: torch.Tensor) -> ttnn.Tensor: + """ + Args: + pixel_values: torch.Tensor of shape (B, C, H, W) + Returns: + embeddings: ttnn.Tensor of shape (B, num_patches, hidden_dim) + """ + patch_embeddings = self.patch_embed(pixel_values) # [B, num_patches, hidden_dim] + patch_embeddings = ttnn.reshape(patch_embeddings, (1, -1, self.hidden_dim)) + positional_embeddings = ttnn.embedding(self.position_ids, self.pos_emb_weights, layout=ttnn.TILE_LAYOUT) + embeddings = ttnn.add(patch_embeddings, positional_embeddings) + return embeddings diff --git a/models/tt_transformers/tt/rope.py b/models/tt_transformers/tt/rope.py index e5e96c148fb2..f8b4bc10ed8d 100644 --- a/models/tt_transformers/tt/rope.py +++ b/models/tt_transformers/tt/rope.py @@ -218,13 +218,46 @@ def _set_cos_sin_cache(self, seq_len: int, device: Any, dtype: torch.dtype) -> N self.register_buffer("sin_cached", sin.to(dtype), persistent=False) +class LinearRotaryEmbedding(RotaryEmbedding): + def __init__( + self, + dim: int, + max_position_embeddings: int, + base: float, + factor: float, + original_max_position_embeddings: int, + device: Optional[Any] = None, + ) -> None: + self.base = base + self.orig_context_len = original_max_position_embeddings + self.scaling_factor = factor + super().__init__(dim, max_position_embeddings, base, device) + + def apply_scaling(self, freqs: torch.Tensor) -> torch.Tensor: + freqs = freqs / self.scaling_factor + return freqs + + def _set_cos_sin_cache(self, seq_len: int, device: Any, dtype: torch.dtype) -> None: + self.max_seq_len_cached = seq_len + freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim)) + t = torch.arange(seq_len * 2.0) + freqs = self.apply_scaling(freqs) + freqs = torch.outer(t, freqs).float() + cos = torch.cos(freqs) + sin = torch.sin(freqs) + cos, sin = gather_cos_sin(torch.arange(seq_len), cos, sin) + + self.register_buffer("cos_cached", cos.to(dtype), persistent=False) + 
self.register_buffer("sin_cached", sin.to(dtype), persistent=False) + + def rotary_embedding_factory( dim: int, max_position_embeddings: int, base: float, rope_scaling: Optional[RopeScaling] = None, device: Optional[Any] = None, -) -> Union[RotaryEmbedding, YarnRotaryEmbedding, LlamaRotaryEmbedding]: +) -> Union[RotaryEmbedding, YarnRotaryEmbedding, LlamaRotaryEmbedding, LinearRotaryEmbedding]: if rope_scaling is None: return RotaryEmbedding(dim, max_position_embeddings, base, device) else: @@ -232,6 +265,8 @@ def rotary_embedding_factory( rotary_embedding = LlamaRotaryEmbedding elif rope_scaling.rope_type.value == "yarn": rotary_embedding = YarnRotaryEmbedding + elif rope_scaling.rope_type.value == "linear": + rotary_embedding = LinearRotaryEmbedding else: raise ValueError(f"Invalid rope_scaling: {rope_scaling}") return rotary_embedding(