diff --git a/models/tt_transformers/demo/simple_vision_demo.py b/models/tt_transformers/demo/simple_vision_demo.py index d51340a4429f..951f0b7e13c9 100644 --- a/models/tt_transformers/demo/simple_vision_demo.py +++ b/models/tt_transformers/demo/simple_vision_demo.py @@ -27,7 +27,9 @@ import ttnn from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf from models.perf.benchmarking_utils import BenchmarkProfiler +from models.tt_transformers.tt.common import hf_multimodal_encode from models.tt_transformers.tt.generator import Generator +from models.tt_transformers.tt.model_config import CheckpointType def get_batch_sampler(temperature, top_p, tokenizer): @@ -62,6 +64,7 @@ def create_multimodal_model( ): from models.tt_transformers.tt.model_config import ModelArgs from models.tt_transformers.tt.multimodal.llama_vision_model import CrossAttentionTransformer + from models.tt_transformers.tt.multimodal.mistral_24b.mistral_e2e_model import MistralTransformer tt_model_args = ModelArgs(mesh_device, max_batch_size=max_batch_size) assert tt_model_args.is_vision(), "This model is multimodal" @@ -76,14 +79,25 @@ def create_multimodal_model( if checkpoint is None: checkpoint = tt_model_args.load_state_dict() - model = CrossAttentionTransformer( - mesh_device, - state_dict=checkpoint, - weight_cache_path=tt_model_args.weight_cache_path(dtype), - dtype=dtype, - configuration=tt_model_args, - use_paged_kv_cache=use_paged_kv_cache, - ) + + if tt_model_args.base_model_name == "Mistral-Small-3.1-24B": + model = MistralTransformer( + mesh_device=mesh_device, + state_dict=checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(ttnn.bfloat8_b), + dtype=ttnn.bfloat8_b, + args=tt_model_args, + use_paged_kv_cache=use_paged_kv_cache, + ) + else: + model = CrossAttentionTransformer( + mesh_device, + state_dict=checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=tt_model_args, + use_paged_kv_cache=use_paged_kv_cache, + ) return tt_model_args, model, checkpoint @@ -136,7 +150,7 @@ def prepare_generator_args( ) @pytest.mark.parametrize( "test_type,max_seq_len", - (("normal", 512),), + (("normal", 2048),), ids=["normal"], ) @pytest.mark.parametrize( @@ -182,9 +196,6 @@ def test_multimodal_demo_text( profiler = BenchmarkProfiler() profiler.start("run") - ckpt_dir = os.environ["LLAMA_DIR"] - tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") - num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1 max_batch_size *= data_parallel # input batch_size is interpreted as size per DP group @@ -195,11 +206,27 @@ def test_multimodal_demo_text( max_batch_size=max_batch_size, max_seq_len=max_seq_len, ) + + HF_MODEL = model_args[0].checkpoint_type == CheckpointType.HuggingFace + + if not HF_MODEL: + ckpt_dir = os.environ["LLAMA_DIR"] + tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") + + tokenizer = Tokenizer(model_path=tokenizer_path) + formatter = ChatFormat(tokenizer) + else: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_args[0].CKPT_DIR) + tokenizer = model_args[0].tokenizer + generator = Generator(model, model_args, mesh_device) - tokenizer = Tokenizer(model_path=tokenizer_path) - formatter = ChatFormat(tokenizer) - xattn_caches = [model.setup_cache(model_args[i].max_batch_size) for i, model in enumerate(generator.model)] + xattn_caches = [ + model.setup_cache(model_args[i].max_batch_size) if not HF_MODEL else None + for i, model in 
enumerate(generator.model) + ] # Create random images for trace capture with specific dimensions trace_img_560x560 = create_random_image(560, 560) @@ -264,6 +291,8 @@ def test_multimodal_demo_text( _num_prefill_tokens = 0 _num_decode_tokens = 0 + prompt_encoder = hf_multimodal_encode if HF_MODEL else formatter.encode_dialog_prompt + for iter_num in range(warmup_iters + 1): logger.info(f"Iteration {iter_num}") current_dialogs = trace_dialogs + dialogs @@ -273,9 +302,15 @@ def test_multimodal_demo_text( for msg in dialog: print(f"{msg.role.capitalize()}: {msg.content}\n") batch_model_input = [ - formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) for dialog in batch_dialogs + prompt_encoder(dialog, processor) if HF_MODEL else prompt_encoder(dialog, tool_prompt_format=False) + for dialog in batch_dialogs ] + if HF_MODEL: + image_sizes = [model_input.image_sizes for model_input in batch_model_input] + else: + image_sizes = None + # Do initial prefill vision_images = [ model_input.vision.images if model_input.vision else None for model_input in batch_model_input @@ -288,7 +323,7 @@ def test_multimodal_demo_text( total_lens = prefill_lens + max_gen_len # Create padded tokens tensor for batch - pad_id = tokenizer.pad_id + pad_id = tokenizer.pad_token_id if HF_MODEL else tokenizer.pad_id bsz = len(prompt_tokens) tokens = torch.full((bsz, max(total_lens)), pad_id, dtype=torch.long) @@ -312,6 +347,7 @@ def test_multimodal_demo_text( xattn_caches, total_lens, prefill_lens, + image_sizes=image_sizes, ) # Get cached prefill time @@ -323,12 +359,7 @@ def test_multimodal_demo_text( decode_batch_xattn_masks, decode_batch_text_masks, ) = generator.prefill_forward( - vision_images, - vision_mask, - tokens, - xattn_caches, - total_lens, - prefill_lens, + vision_images, vision_mask, tokens, xattn_caches, total_lens, prefill_lens, image_sizes=image_sizes ) prefill_end = time.perf_counter() @@ -375,12 +406,16 @@ def test_multimodal_demo_text( ) # gen_idx is (num_tokens - 1) to avoid counting compile iter # Log full text output for each user in batch - vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256] + if HF_MODEL: + # For HF models, get vision tokens from the processor if they exist + vision_tokens = [] + else: + vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256] for user_id in range(max_batch_size): # Remove <|image|> tokens since they break the tokenizer tokens_out = [ - t if t not in vision_tokens else tokenizer.pad_id + t if t not in vision_tokens else pad_id for t in tokens[user_id].tolist()[: position_id[user_id] + 2] ] text = tokenizer.decode(tokens_out) diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_conv2d.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_conv2d.py new file mode 100644 index 000000000000..f754b220bfc2 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_conv2d.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +""" +This file is a unit test for validating the Mistral-24B conv2d. 
+""" +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.mistral_24b.vision_conv2d import TtMistralConv2dPatch +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +from ttnn import ConcatMeshToTensor + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_conv2d_inference( + mesh_device, + reset_seeds, +): + pcc_required = 0.9999 + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "vision_tower.patch_conv." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + num_devices = model_args.num_devices + + ##### Create input tensor for the all gather ##### + B, NCH, H, W = (1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size) + in_channels, out_channels, kernel_size, stride, bias = ( + 3, + model_args.vision_dim, + model_args.vision_patch_size, + model_args.vision_patch_size, + False, + ) + + assert NCH == in_channels, "Number of channels in input tensor should match in_channels for the Conv2d patch." + assert type(kernel_size) == int, "Only symmetric kernel_size is currently supported." + assert kernel_size == stride, "Only same kernel_size and stride are currently supported." + + assert H % kernel_size == 0, "Height should be divisible by kernel_size." + assert W % kernel_size == 0, "Width should be divisible by kernel_size." + + ##### Prepare inputs ##### + input_tensor = torch.randn((B, NCH, H, W)).to(dtype=torch.bfloat16) + logger.info(f"Input tensor shape: {input_tensor.shape}") + + reference_model = model_args.reference_conv2d_patch() + reference_model.load_state_dict(partial_state_dict) + reference_output = reference_model(input_tensor) + + tt_model = TtMistralConv2dPatch( + mesh_device, + state_dict, + first_layer_prefix, + dtype, + in_channels, + out_channels, + kernel_size, + stride, + bias, + ) + tt_output = tt_model(input_tensor) + + ##### Check the outputs ##### + out = ttnn.from_device(tt_output) + tt_output_torch = ttnn.to_torch(out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=2)) + + # Only select output from one device + tt_output_torch = tt_output_torch[0, ..., :out_channels] + + # 1. Restore batch dim + tt_output_torch = tt_output_torch.unsqueeze(0) + # 1 1024 4096 + # 2. Permute to match Conv2D output: (N, C_out, H_out, W_out) + tt_output_torch = tt_output_torch.permute(0, 2, 1).reshape(1, 1024, 64, 64) + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py new file mode 100644 index 000000000000..6ccc36813ae9 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.mistral_24b.vision_rope import VisionRotarySetup as RotarySetup +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_rot_emb(seq_len, batch_size, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + partial_state_dict = {} + + reference_model = tt_model_args.reference_vision_rot_emb() + reference_model.load_state_dict(partial_state_dict) + + image_size = tt_model_args.vision_image_size + patch_size = tt_model_args.vision_patch_size + dim = tt_model_args.vision_head_dim + num_patches_per_dim = image_size // patch_size + num_patches = num_patches_per_dim * num_patches_per_dim + position_ids = torch.arange(4096, dtype=torch.long) + + x = torch.randn(batch_size, 4096, 1024) + + cos, sin = reference_model(x, position_ids) + tt_model = RotarySetup( + device, + batch_size, + dim, + image_size, + patch_size, + num_patches, + tt_model_args.vision_rope_theta, + scale_factor=None, + orig_context_len=num_patches, + datatype=dtype, + ) + + cos2, sin2 = tt_model.get_rot_mats(position_ids) + cos2 = ttnn.from_device(cos2) + cos2 = ttnn.to_torch(cos2) + cos2 = cos2.squeeze(0) + + sin2 = ttnn.from_device(sin2) + sin2 = ttnn.to_torch(sin2) + sin2 = sin2.squeeze(0) + + passing, pcc_message = comp_pcc(cos, cos2) + + logger.info(comp_allclose(cos, cos2)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"COS PCC value is lower than {0.99} for some of the outputs. Check Warnings!" + + passing, pcc_message = comp_pcc(sin, sin2) + + logger.info(comp_allclose(sin, sin2)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"SIN PCC value is lower than {0.99} for some of the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_transformer.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_transformer.py new file mode 100644 index 000000000000..618e7122e2f0 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_transformer.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.ccl import TT_CCL +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.mistral_24b.vision_pixtral_transformer import TtPixtralTransformer +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 1),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_image_transformer_inference(batch, num_chunks, mesh_device): + pcc_required = 0.99 + + model_args = ModelArgs(mesh_device) + dtype = ttnn.bfloat16 + + state_dict = model_args.load_state_dict() + n_layers = model_args.vision_n_layers + first_layer_prefix = "vision_tower.transformer." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + dim = model_args.vision_dim + heads = model_args.vision_attn_n_heads + seq_len = model_args.vision_chunk_ntok - 1 + head_dim = dim // heads + + reference_model = model_args.reference_vision_encoder() + reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + all_tests_pass = True + + tt_ccl = TT_CCL(mesh_device) + tt_model = TtPixtralTransformer( + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=None, + dtype=dtype, + configuration=model_args, + layers=n_layers, + ) + + # Create PT input + pt_attention_input = torch.rand(batch, seq_len, dim).to(torch.bfloat16) + attention_mask = torch.zeros(batch, 1, seq_len, seq_len).to(torch.bfloat16) + + B, T, D = pt_attention_input.shape + cos = torch.ones((1, T, head_dim)).to(torch.bfloat16) + sin = torch.zeros((1, T, head_dim)).to(torch.bfloat16) + + cos_t = ttnn.from_torch( + cos, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + sin_t = ttnn.from_torch( + sin, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + attention_input = model_args.prepare_residual_tensor_prefill( + pt_attention_input, + force_replicated=True, + ) + + with torch.no_grad(): + tt_out = tt_model(attention_input, position_embeddings=(cos_t, sin_t)) + reference_output = reference_model( + pt_attention_input, attention_mask=attention_mask, position_embeddings=(cos, sin) + )[0] + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[ + : tt_out.shape[0] + ] + tt_output_torch = tt_output_torch.squeeze(0) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + if not passing: + logger.warning(f"PCC value -- {pcc_message} -- is lower than {pcc_required} for the output.") + else: + logger.info(f"PCC: {pcc_message}") + logger.info(comp_allclose(reference_output, tt_output_torch)) + all_tests_pass = all_tests_pass and passing + + assert all_tests_pass, f"PCC 
value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_attention.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_attention.py new file mode 100644 index 000000000000..4f4994704f64 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_attention.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.ccl import TT_CCL +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.mistral_24b.vision_attention import ( + TtMistralImageAttention as TtLlamaImageAttention, +) +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +from ttnn import ConcatMeshToTensor + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_vision_attention(mesh_device, seq_len, batch_size): + logger.info(f"seq_len: {seq_len}, batch_size: {batch_size}") + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=256) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower.transformer.layers.0.attention." 
+ + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model = model_args.reference_vision_attention() + reference_model.load_state_dict(partial_state_dict) + + hidden_size = model_args.vision_dim + n_heads = model_args.vision_attn_n_heads + head_dim = hidden_size // n_heads + + tt_ccl = TT_CCL(mesh_device) + tt_model = TtLlamaImageAttention( + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=model_args, + ) + + dim = model_args.vision_dim + pt_attention_input = torch.randn(batch_size, seq_len, dim).to(torch.bfloat16) + attention_mask = torch.zeros(batch_size, 1, seq_len, seq_len).to(torch.bfloat16) + + B, T, D = pt_attention_input.shape + cos = torch.ones((1, T, head_dim)).to(torch.bfloat16) + sin = torch.zeros((1, T, head_dim)).to(torch.bfloat16) + + attention_input = ttnn.from_torch( + pt_attention_input.unsqueeze(0), + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + cos_t = ttnn.from_torch( + cos, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + sin_t = ttnn.from_torch( + sin, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + tt_out = tt_model(attention_input, position_embeddings=(cos_t, sin_t)) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=-1))[ + :, :, :, : tt_out.shape[-1] + ] + tt_output_torch = tt_output_torch.squeeze(0) + reference_output = reference_model(pt_attention_input, attention_mask, position_embeddings=(cos, sin))[0] + pcc_required = 0.99 + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_mlp.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_mlp.py new file mode 100644 index 000000000000..ac17185e02a0 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_mlp.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs + +from models.tt_transformers.tt.multimodal.mistral_24b.vision_mlp import MistralTTVisionMLP as MLP +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (64 * 1024,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_mlp_inference(seq_len, batch_size, mesh_device, reset_seeds): + dtype = ttnn.bfloat8_b + mode = "decode" if seq_len <= 32 else "prefill" + + model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128) + model_args.n_layers = 1 + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower.transformer.layers.0.feed_forward." + + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model = model_args.reference_vision_mlp() + reference_model.load_state_dict(partial_state_dict) + + tt_model = MLP( + mesh_device=mesh_device, + args=model_args, + state_dict=state_dict, + weight_cache_path=model_args.weight_cache_path(dtype), + state_dict_prefix="vision_tower.transformer.layers.0.feed_forward.", + dtype=dtype, + ) + torch_input = torch.randn(1, 1, seq_len, 1024).to(torch.bfloat16) + + reference_output = reference_model(torch_input) + tt_input = ttnn.from_torch( + torch_input, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + ) + + logger.info("Run MLP") + tt_output = tt_model(tt_input) + + tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ + :, :, :, :1024 + ] + pcc_required = 0.99 + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + if passing: + logger.info("MLP Passed!") + else: + logger.warning("MLP Failed!") + + assert passing, f"MLP output does not meet PCC requirement {pcc_required}: {pcc_message}." diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py new file mode 100644 index 000000000000..e9f1ec9a91ed --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +""" +This file is a unit test for validating the Mistral-24B Vision Model pipeline. 
+""" + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.ccl import TT_CCL +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.mistral_24b.vision_model import TtMistralVisionTransformer +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +def get_image_features(vision_tower, projector, input_tensor, image_sizes): + """ + Get image features from the vision tower and projector. + """ + vision_token = vision_tower(input_tensor, image_sizes).last_hidden_state + image_features = projector(vision_token.squeeze(0), image_sizes) + return image_features + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_mistral_vision_model(mesh_device, reset_seeds): + pcc_required = 0.97 + dtype = ttnn.bfloat8_b + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix) + } + + ##### Reference model output (Torch) ##### + reference_model = model_args.reference_vision_model() + reference_model.load_state_dict(partial_state_dict) + + mmp_first_layer_prefix = "multi_modal_projector." + + mmp_partial_state_dict = { + k[len(mmp_first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(mmp_first_layer_prefix)) + } + + reference_mmp = model_args.reference_vision_multi_modal() + reference_mmp.load_state_dict(mmp_partial_state_dict) + + B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size + input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16) + + reference_output = get_image_features(reference_model, reference_mmp, input_tensor, image_sizes=[(H, W)]) + + # ##### TT Model: TtMistralVisionTransformer ##### + tt_ccl = TT_CCL(mesh_device=mesh_device) + vision_model = TtMistralVisionTransformer( + mesh_device=mesh_device, + tt_ccl=tt_ccl, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + dtype=dtype, + model_args=model_args, + ) + + tt_output = vision_model(input_tensor, image_sizes=[(H, W)]) # [0] + tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ + :, : tt_output.shape[-1] + ] + + non_zero_indices = tt_output.ne(0).nonzero(as_tuple=True) + tt_output = tt_output[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC below {pcc_required}. {pcc_message}" diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py new file mode 100644 index 000000000000..868cfec3c806 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_rmsnorm_inference(seq_len, batch_size, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + reference_model = tt_model_args.reference_vision_rms() + + first_layer_prefix = "vision_tower.transformer.layers.0.ffn_norm." + + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model.load_state_dict(partial_state_dict) + + tt_model = RMSNorm( + device=device, + dim=1024, + state_dict=state_dict, + state_dict_prefix="vision_tower.transformer.layers.0.", + weight_key="ffn_norm", + weight_dtype=dtype, + is_distributed=False, + simplified_rms=True, + ) + input = torch.rand(batch_size, seq_len, 1024) + + reference_output = reference_model(input) + + tt_input = ttnn.from_torch( + input, + device=device, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device=device), + ) + + tt_output = tt_model(tt_input, mode=mode) + + tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(device, dim=-1))[ + :, : tt_output.shape[-1] + ] + + logger.info(f"tt_output_torch: {tt_output_torch.shape}") + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("rms_norm Passed!") + else: + logger.warning("rms_norm Failed!") + + assert passing, f"rms_norm output does not meet PCC requirement {0.99}." diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py new file mode 100644 index 000000000000..fdfe2f1dcb5b --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +""" +This file is a unit test for validating the Mistral-24B Vision Tower model. 
+""" + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.ccl import TT_CCL +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.mistral_24b.mistral_vision_tower import MistralVisionTower +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_mistral_vision_tower(mesh_device, reset_seeds): + pcc_required = 0.99 + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix) + } + + B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size + input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16) + + ##### Reference model output (Torch) ##### + reference_model = model_args.reference_vision_model() + reference_model.load_state_dict(partial_state_dict) + reference_output = reference_model(input_tensor, image_sizes=[(H, W)]) + + reference_output = reference_output.last_hidden_state + tt_ccl = TT_CCL(mesh_device) + ##### TT Model: MistralVisionTower ##### + vision_model = MistralVisionTower( + mesh_device=mesh_device, + tt_ccl=tt_ccl, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + dtype=dtype, + configuration=model_args, + ) + + tt_output = vision_model(input_tensor, image_sizes=[(H, W)]) + tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ + :, :, :, : tt_output.shape[-1] + ] + tt_output = tt_output.squeeze(0) + passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC below {pcc_required}. 
{pcc_message}" diff --git a/models/tt_transformers/tt/common.py b/models/tt_transformers/tt/common.py index 28612764db38..b7d8023a4b51 100644 --- a/models/tt_transformers/tt/common.py +++ b/models/tt_transformers/tt/common.py @@ -5,9 +5,11 @@ import math import re from enum import Enum +from types import SimpleNamespace from typing import Optional import torch +from llama_models.llama3.api.datatypes import ImageMedia from loguru import logger from pydantic import AliasChoices, BaseModel, Field @@ -98,6 +100,26 @@ def rope_scaling_model_factory(rope_scaling_params: dict) -> RopeScaling: raise ValueError(f"Unexpected RoPE scaling type: {rope_scaling_type}") +def position_ids_in_meshgrid_tt(tt_patch_embeds_list, max_width, device): + position_ids_tt = [] + for tt_patch in tt_patch_embeds_list: + shape = tt_patch.shape + height, width = shape[-2], shape[-1] + mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_width + v_grid + + tt_ids = ttnn.from_torch( + ids, + device=device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + position_ids_tt.append(tt_ids[:, 0]) + return ttnn.concat(position_ids_tt, dim=0) + + def encode_prompt_instruct(tokenizer, prompt_text, system_prompt_text=None): """<|begin_of_text|><|start_header_id|>system<|end_header_id|> {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|> @@ -283,6 +305,43 @@ def apply_llama3_scaling(freqs: torch.Tensor, scale_factor: float, orig_context_ return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) +def apply_scaling_vision(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): + return freqs / scale_factor + + +def precompute_mistral_vision_freqs( + dim: int, max_patches_per_side: int, theta: float, scale_factor=None, orig_context_len=None +): + # Compute base frequencies + base_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + if scale_factor is not None: + base_freqs = apply_scaling_vision(base_freqs, scale_factor, orig_context_len) + + # Get height and width indices + h_idx = torch.arange(max_patches_per_side) + w_idx = torch.arange(max_patches_per_side) + + # Compute 2D frequency matrices + freqs_h = torch.outer(h_idx, base_freqs[::2]) + freqs_w = torch.outer(w_idx, base_freqs[1::2]) + + # Broadcast + merge + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape( + -1, dim // 2 + ) # Shape: [H*W, dim//2] + + full_freqs = torch.cat([inv_freq, inv_freq], dim=-1) + cos = full_freqs.cos() + sin = full_freqs.sin() + return cos, sin # Shape: [H*W, dim] + + def precompute_freqs(dim: int, end: int, theta, scale_factor, orig_context_len): """ Precompute the frequency tensor for sine and cosine values with given dimensions. 
@@ -615,3 +674,46 @@ def create_tt_model( tt_kv_cache = [l.attention.layer_past for l in model.layers] if paged_attention_config else None return tt_model_args, model, tt_kv_cache, state_dict + + +def hf_multimodal_encode(messages, processor): + hf_messages = [] + + for msg in messages: + hf_content = [] + + for item in msg.content: + if isinstance(item, ImageMedia): + hf_content.append( + { + "type": "image", + "image": item.image, + } + ) + elif isinstance(item, str): + hf_content.append( + { + "type": "text", + "text": item, + } + ) + + hf_messages.append( + { + "role": msg.role, + "content": hf_content, + } + ) + + encoded = processor.apply_chat_template( + hf_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to("cpu", dtype=torch.bfloat16) + + return SimpleNamespace( + **encoded, + tokens=encoded["input_ids"].squeeze(0), + vision=SimpleNamespace( + images=encoded["pixel_values"], + mask=None, + ), + ) diff --git a/models/tt_transformers/tt/generator.py b/models/tt_transformers/tt/generator.py index 2eb8863c24c9..115c23ee3b9d 100644 --- a/models/tt_transformers/tt/generator.py +++ b/models/tt_transformers/tt/generator.py @@ -22,6 +22,7 @@ get_padded_prefill_len, num_blocks_in_seq, ) +from models.tt_transformers.tt.model_config import CheckpointType @dataclass(frozen=True) @@ -86,6 +87,7 @@ def prefill_forward_text( seq_len = int(prompt_lens[idx]) last_token_idx = seq_len - 1 prefill_seq_len = get_padded_prefill_len(seq_len) + local_kwargs = kwargs.copy() logger.info(f"Prefilling User {user_id + 1} up to {seq_len} tokens") @@ -101,6 +103,12 @@ def prefill_forward_text( ) model_kv_cache = kv_cache[model_id] if kv_cache is not None else None + # Check if 'pixel_values' exists and index it safely + if "pixel_values" in local_kwargs: + local_kwargs["pixel_values"] = local_kwargs["pixel_values"][idx] + if "image_sizes" in local_kwargs: + local_kwargs["image_sizes"] = local_kwargs["image_sizes"][idx] + logits = self.prefill_forward_single_user_text( prefill_ids, page_table=page_table_user, @@ -108,7 +116,7 @@ def prefill_forward_text( last_token_idx=last_token_idx, kv_cache=model_kv_cache, model_id=model_id, - **kwargs, + **local_kwargs, ) out_list.append(logits) @@ -188,6 +196,7 @@ def prefill_forward_single_user_text( chunk_start_idx=chunk_start, get_last_token=(last_token_idx_in_chunk // 32) * 32, kv_cache=kv_cache, + **kwargs, ) if chunk_start == last_chunk_start: @@ -492,6 +501,61 @@ def _prefill_forward_single_user( # Note: This function is called by vLLM def prefill_forward( + self, + vision_images, + vision_masks, + tokens, + xattn_caches, + total_lens, + prompt_lens, + page_table=None, + kv_cache=None, + cross_page_table=None, + empty_slots=None, + **kwargs, + ): + if self.model_args[0].checkpoint_type == CheckpointType.HuggingFace: + logits = self.prefill_forward_text( + tokens, + page_table=page_table, + kv_cache=kv_cache, + prompt_lens=prompt_lens, + pixel_values=vision_images, + **kwargs, + ) + + return logits, None, None, None, None + + else: + ( + output_logits, + prefill_output_xattn_masks, + prefill_output_full_text_row_masked_out_masks, + decode_output_xattn_masks, + decode_output_full_text_row_masked_out_masks, + ) = self.prefill_forward_llama_vision( + vision_images, + vision_masks, + tokens, + xattn_caches, + total_lens, + prompt_lens, + page_table=page_table, + kv_cache=kv_cache, + cross_page_table=cross_page_table, + empty_slots=empty_slots, + ) + + return ( + output_logits, + prefill_output_xattn_masks, + 
prefill_output_full_text_row_masked_out_masks, + decode_output_xattn_masks, + decode_output_full_text_row_masked_out_masks, + ) + + # Note: This function is called by vLLM + def prefill_forward_llama_vision( self, vision_images, vision_masks, @@ -588,7 +652,7 @@ def prefill_forward( ) # Note: This function is called by vLLM - def decode_forward( + def decode_forward_llama_vision( self, start_pos, tokens, @@ -652,6 +716,47 @@ def decode_forward( else: return tt_logits + def decode_forward( + self, + start_pos, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + xattn_caches=None, + page_table=None, + kv_cache=None, + cross_page_table=None, + enable_trace=True, + read_from_device=True, + ): + pass + + if self.model_args[0].checkpoint_type == CheckpointType.HuggingFace: + return self.decode_forward_text( + tokens, + start_pos, + enable_trace=enable_trace, + page_table=page_table, + kv_cache=kv_cache, + ) + else: + return self.decode_forward_llama_vision( + start_pos, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + xattn_caches, + page_table, + kv_cache, + cross_page_table, + enable_trace, + read_from_device, + ) + # Note: This function is called by vLLM def read_decode_output(self, tt_out, async_read=False): """ diff --git a/models/tt_transformers/tt/load_checkpoints.py b/models/tt_transformers/tt/load_checkpoints.py index fa00a6b882cf..ddbcccf41557 100644 --- a/models/tt_transformers/tt/load_checkpoints.py +++ b/models/tt_transformers/tt/load_checkpoints.py @@ -528,7 +528,10 @@ def convert_hf_qkv_to_meta_format(loaded_weights, head_dim): """Convert HuggingFace QKV weights to Meta format for RoPE compatibility.""" converted_weights = {} for key, tensor in loaded_weights.items(): - if "q_proj.weight" in key or "k_proj.weight" in key: + if "vision_tower" in key: + # Skip conversion for vision tower weights + converted_weights[key] = tensor + elif "q_proj.weight" in key or "k_proj.weight" in key: # For weights: n_heads = tensor.shape[0] // head_dim n_heads = tensor.shape[0] // head_dim converted_weights[key] = reverse_permute(tensor, n_heads, tensor.shape[0], tensor.shape[1]) @@ -592,10 +595,32 @@ def map_hf_to_meta_keys(loaded_weights): ("o_proj", "wo"), ("q_norm", "q_norm"), ("k_norm", "k_norm"), + ("patch_conv.weight", "patch_conv._linear.weight"), ] return replace_keys(loaded_weights, replacements) +def map_vision_meta_to_hf_keys(loaded_weights): + """ + Map Hugging Face checkpoint keys to Meta checkpoint keys. + You can use this to support other models by adding more mappings. + See replace_keys for more details on the format of replacements. 
+ """ + base_mapping = [ + ("w1", "gate_proj"), + ("w2", "down_proj"), + ("w3", "up_proj"), + ("wq", "q_proj"), + ("wk", "k_proj"), + ("wv", "v_proj"), + ("wo", "o_proj"), + ("_linear.weight", "weight"), + ] + mapping = base_mapping + + return replace_keys(loaded_weights, mapping) + + def convert_vision_meta_to_hf(state_dict, head_dim): # state_dict = convert_meta_qkv_to_hf_format(state_dict, head_dim) state_dict = map_vision_meta_to_hf_keys(state_dict) diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index 65c339b68c9a..bb94d3e7c251 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -571,6 +571,7 @@ def __init__( "Phi-3.5-mini-instruct": {"N150": 128, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, "QwQ-32B": {"N150": None, "N300": None, "T3K": 64, "TG": 128, "P150x4": 128}, "Qwen3-32B": {"N150": None, "N300": None, "T3K": 64, "TG": 128, "P150x4": 128}, + "Mistral-Small-3.1-24B": {"N150": 8, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, } try: max_prefill_chunk_size_div1024 = MAX_PREFILL_CHUNK_SIZES_DIV1024[self.base_model_name][self.device_name] @@ -1530,6 +1531,24 @@ def _set_params_from_dict(self, config, is_hf=False): self._set_vision_params(config) self.is_multimodal = "vision_config" in config or self.is_vision() + # Vision params (Meta-specific) + self.vision_chunk_size = config.get("vision_chunk_size", 896) + self.vision_max_num_chunks = config.get("vision_max_num_chunks", 4) + self.vision_num_cross_attention_layers = config.get("vision_num_cross_attention_layers", -1) + + # Vision constants + self.vision_dim = 1280 + self.vision_mlp_ratio = 4 + self.vision_hidden_dim = int(self.vision_dim * self.vision_mlp_ratio) + self.vision_act_layer = ttnn.UnaryOpType.GELU + self.vision_dropout = 0.0 + self.vision_attn_n_heads = 16 + self.vision_head_dim = self.vision_dim // self.vision_attn_n_heads + self.vision_n_layers = 32 + self.vision_n_global_layers = 8 + self.vision_max_num_tiles = 4 + self.vision_patch_size = 14 + self.vision_in_channels = 3 self.state_dict_text_prefix = self._get_text_prefix() self.state_dict_vision_prefix = self._get_vision_prefix() @@ -1605,28 +1624,32 @@ def _set_params(self, checkpoint_dir): else None ) - def _set_vision_params(self, config): - vision_config = config.get("vision_config", config) - - self.vision_chunk_size = vision_config.get("vision_chunk_size", -1) + def _set_vision_params(self, vision_config): + self.vision_chunk_size = vision_config.get("vision_chunk_size", 896) self.image_size = vision_config.get("image_size", 896) self.vision_max_num_chunks = vision_config.get("vision_max_num_chunks", 4) - self.vision_num_cross_attention_layers = vision_config.get("vision_num_cross_attention_layers", -1) - self.vision_dim = vision_config.get("hidden_size", 1280) - + self.vision_num_cross_attention_layers = vision_config.get("vision_num_cross_attention_layers", 8) + self.vision_dim = vision_config.get("hidden_size", 1152) intermediate_size = vision_config.get("intermediate_size", self.vision_dim * 4) + self.vision_image_size = vision_config.get("image_size", 1540) + self.vision_rope_theta = vision_config.get("rope_theta", 10000.0) + self.image_token_index = vision_config.get("image_token_index", 10) + self.vision_mlp_ratio = intermediate_size // self.vision_dim self.vision_hidden_dim = int(self.vision_dim * self.vision_mlp_ratio) - self.vision_attn_n_heads = vision_config.get("num_attention_heads", 16) + self.vision_attn_n_heads = 
vision_config.get("num_attention_heads") or vision_config.get("num_heads") or 16 self.vision_head_dim = self.vision_dim // self.vision_attn_n_heads - self.vision_n_layers = vision_config.get("num_hidden_layers", 32) + self.vision_n_layers = vision_config.get("num_hidden_layers") or vision_config.get("depth") or 27 self.vision_patch_size = vision_config.get("patch_size", 14) self.vision_in_channels = vision_config.get("num_channels", 3) self.vision_dropout = vision_config.get("attention_dropout", 0.0) self.mm_tokens_per_image = vision_config.get("mm_tokens_per_image", 256) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + self.vision_head_dim = vision_config.get("head_dim", 64) + # Optional vision activation layer, defaults to GELU act_layer = vision_config.get("act_layer", "gelu").lower() self.vision_act_layer = { @@ -1640,6 +1663,18 @@ def _set_vision_params(self, config): self.vision_n_global_layers = vision_config.get("n_global_layers", 8) def _set_hf_params(self, checkpoint_dir): + def merge_text_config(base_config): + text_config = base_config.get("text_config", {}) + # Merge non-nested keys into text_config + text_config.update({k: v for k, v in base_config.items() if k not in ["text_config", "vision_config"]}) + return text_config + + def merge_vision_config(base_config): + vision_config = base_config.get("vision_config", {}) + # Merge non-nested keys into vision_config + vision_config.update({k: v for k, v in base_config.items() if k not in ["text_config", "vision_config"]}) + return vision_config + if self.from_hf_url: # Special case Qwen2.5-VL models until they are fully integrated into a HF release if "Qwen/Qwen2.5-VL" in self.model_name: @@ -1655,12 +1690,25 @@ def _set_hf_params(self, checkpoint_dir): else: self.hf_config = AutoConfig.from_pretrained(self.CKPT_DIR) config = self.hf_config.to_dict() + if "text_config" in config or "vision_config" in config: + merged_text_config = merge_text_config(config) + self._set_params_from_dict(merged_text_config, is_hf=True) + + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + self._set_vision_params(config["vision_config"]) + else: + if "vision_config" in config: + merged_vision_config = merge_vision_config(config) + self._set_vision_params(merged_vision_config) + else: + self._set_params_from_dict(config, is_hf=True) + else: config_file = os.path.join(checkpoint_dir, "config.json") assert os.path.exists(config_file), f"config.json file not found at {config_file}" with open(config_file, "r") as f: config = json.load(f) - self._set_params_from_dict(config, is_hf=True) + self._set_params_from_dict(config) def __repr__(self): return f"""ModelArgs( @@ -1686,8 +1734,12 @@ def is_vision(self): return self.vision_chunk_size > 0 def get_state_dict_prefix(self, module_name, layer_num, is_vision=False): - text_prefix = self.state_dict_text_prefix - vision_prefix = self.state_dict_vision_prefix + if self.is_vision() and self.model_name.startswith("Mistral") and "Small-3.1-24B" not in self.model_name: + text_prefix = self.state_dict_text_prefix + else: + text_prefix = "" if not is_vision else self.state_dict_text_prefix + + vision_prefix = self.state_dict_vision_prefix if is_vision else "" layer_prefix = f"layers.{layer_num}." 
if layer_num is not None else "" @@ -1753,6 +1805,8 @@ def load_state_dict(self): ) print("Loading Qwen2.5-VL model: ", AutoModelForCausalLM) + elif "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + from transformers import Mistral3ForConditionalGeneration as AutoModelForCausalLM else: from transformers import AutoModelForCausalLM @@ -2135,55 +2189,76 @@ def create_tokenizer(self): logger.info(f"Model name: {self.model_name}") logger.info(f"Base model name: {self.base_model_name}") - try: - # Try to load tokenizer from the original model path - tokenizer = AutoTokenizer.from_pretrained(self.TOKENIZER_PATH) - logger.info(f"Successfully loaded tokenizer from {self.TOKENIZER_PATH}") - except Exception as e: - logger.warning(f"Failed to load tokenizer from {self.TOKENIZER_PATH}: {e}") - - # Try to use base model tokenizer as fallback - fallback_tokenizer_path = base_model_tokenizer_mapping.get(self.base_model_name) - - # If no direct match, try to infer from model name patterns - if not fallback_tokenizer_path: - model_name_lower = self.model_name.lower() - if "qwen2.5" in model_name_lower and "0.5b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-Coder-0.5B-Instruct" - elif "qwen2.5" in model_name_lower and "1.5b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-1.5B-Instruct" - elif "qwen2.5" in model_name_lower and "3b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-3B-Instruct" - elif "qwen2.5" in model_name_lower and "7b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-7B-Instruct" - elif "qwen2.5" in model_name_lower and "14b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-14B-Instruct" - elif "qwen2.5" in model_name_lower and "32b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-32B-Instruct" - elif "qwen2.5" in model_name_lower and "72b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-72B-Instruct" - elif "llama" in model_name_lower and "3.1" in model_name_lower and "8b" in model_name_lower: - fallback_tokenizer_path = "meta-llama/Llama-3.1-8B-Instruct" - elif "llama" in model_name_lower and "3.1" in model_name_lower and "70b" in model_name_lower: - fallback_tokenizer_path = "meta-llama/Llama-3.1-70B-Instruct" - elif "llama" in model_name_lower and "3.2" in model_name_lower and "1b" in model_name_lower: - fallback_tokenizer_path = "meta-llama/Llama-3.2-1B-Instruct" - elif "llama" in model_name_lower and "3.2" in model_name_lower and "3b" in model_name_lower: - fallback_tokenizer_path = "meta-llama/Llama-3.2-3B-Instruct" - elif "mistral" in model_name_lower and "7b" in model_name_lower: - fallback_tokenizer_path = "mistralai/Mistral-7B-Instruct-v0.3" - - if fallback_tokenizer_path: - logger.info(f"Attempting to use fallback tokenizer: {fallback_tokenizer_path}") - try: - tokenizer = AutoTokenizer.from_pretrained(fallback_tokenizer_path) - logger.info(f"Successfully loaded fallback tokenizer from {fallback_tokenizer_path}") - except Exception as fallback_e: - logger.error(f"Failed to load fallback tokenizer from {fallback_tokenizer_path}: {fallback_e}") - raise fallback_e - else: - logger.error(f"No fallback tokenizer found for base model: {self.base_model_name}") - raise e + # Special handling for Mistral-Small-3.1-24B-Instruct-2503 + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + tokenizer = AutoTokenizer.from_pretrained( + "mistralai/Mistral-Small-3.1-24B-Instruct-2503", trust_remote_code=True + ) + logger.info("Manually setting Mistral instruct-style 
chat template on the tokenizer.") + + mistral_template = """{% for message in messages %} + {% if message['role'] == 'system' %} + <|system|> + {{ message['content'] }} + {% elif message['role'] == 'user' %} + [INST] {{ message['content'] }} [/INST] + {% elif message['role'] == 'assistant' %} + {{ message['content'] }}{{ eos_token }} + {% endif %} + {% endfor %}""" + tokenizer.chat_template = mistral_template + else: + try: + # Try to load tokenizer from the original model path + tokenizer = AutoTokenizer.from_pretrained(self.TOKENIZER_PATH) + logger.info(f"Successfully loaded tokenizer from {self.TOKENIZER_PATH}") + except Exception as e: + logger.warning(f"Failed to load tokenizer from {self.TOKENIZER_PATH}: {e}") + + # Try to use base model tokenizer as fallback + fallback_tokenizer_path = base_model_tokenizer_mapping.get(self.base_model_name) + + # If no direct match, try to infer from model name patterns + if not fallback_tokenizer_path: + model_name_lower = self.model_name.lower() + if "qwen2.5" in model_name_lower and "0.5b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-Coder-0.5B-Instruct" + elif "qwen2.5" in model_name_lower and "1.5b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-1.5B-Instruct" + elif "qwen2.5" in model_name_lower and "3b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-3B-Instruct" + elif "qwen2.5" in model_name_lower and "7b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-7B-Instruct" + elif "qwen2.5" in model_name_lower and "14b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-14B-Instruct" + elif "qwen2.5" in model_name_lower and "32b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-32B-Instruct" + elif "qwen2.5" in model_name_lower and "72b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-72B-Instruct" + elif "llama" in model_name_lower and "3.1" in model_name_lower and "8b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.1-8B-Instruct" + elif "llama" in model_name_lower and "3.1" in model_name_lower and "70b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.1-70B-Instruct" + elif "llama" in model_name_lower and "3.2" in model_name_lower and "1b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.2-1B-Instruct" + elif "llama" in model_name_lower and "3.2" in model_name_lower and "3b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.2-3B-Instruct" + elif "mistral" in model_name_lower and "7b" in model_name_lower: + fallback_tokenizer_path = "mistralai/Mistral-7B-Instruct-v0.3" + + if fallback_tokenizer_path: + logger.info(f"Attempting to use fallback tokenizer: {fallback_tokenizer_path}") + try: + tokenizer = AutoTokenizer.from_pretrained(fallback_tokenizer_path) + logger.info(f"Successfully loaded fallback tokenizer from {fallback_tokenizer_path}") + except Exception as fallback_e: + logger.error( + f"Failed to load fallback tokenizer from {fallback_tokenizer_path}: {fallback_e}" + ) + raise fallback_e + else: + logger.error(f"No fallback tokenizer found for base model: {self.base_model_name}") + raise e # Add meta-compatible stop token list to the HF tokenizer if not "stop_tokens" in tokenizer.__dict__: @@ -2246,6 +2321,8 @@ def reference_transformer(self, wrap=True, load_checkpoint=False): from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLForConditionalGeneration as AutoModelForCausalLM, ) + elif "Mistral-Small-3.1-24B-Instruct-2503" in 
self.model_name: + from transformers import Mistral3ForConditionalGeneration as AutoModelForCausalLM else: from transformers import AutoConfig, AutoModelForCausalLM @@ -2287,6 +2364,8 @@ def reference_vision_multi_modal(self): layer = model.multi_modal_projector # layer._load_state_dict = layer.load_state_dict # layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer def reference_vision_rms_norm(self): @@ -2294,6 +2373,8 @@ def reference_vision_rms_norm(self): layer = model.multi_modal_projector.mm_soft_emb_norm # layer._load_state_dict = layer.load_state_dict # layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer def reference_rms_norm(self): @@ -2303,7 +2384,8 @@ def reference_rms_norm(self): return RMSNorm(self.dim, self.norm_eps) else: model = self.reference_transformer(wrap=False) - layer = model.model.norm + layers = getattr(model, "layers", getattr(model, "model", {}).layers) + layer = layers[0].input_layernorm layer._load_state_dict = layer.load_state_dict layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) return layer @@ -2323,6 +2405,12 @@ def reference_vision_transformer(self, wrap=True, load_checkpoint=False): model = Gemma3ForConditionalGeneration.from_pretrained(self.CKPT_DIR) model = model + elif "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + from transformers import Mistral3ForConditionalGeneration + + model = Mistral3ForConditionalGeneration.from_pretrained(self.CKPT_DIR, torch_dtype=torch.bfloat16) + model = model + else: if self.cached_hf_model is None: model = AutoModelForCausalLM.from_pretrained(self.CKPT_DIR) @@ -2345,16 +2433,44 @@ def reference_gemma_model(self): def reference_vision_model(self): model = self.reference_vision_transformer(wrap=False) - layer = model.vision_tower.vision_model - # layer._load_state_dict = layer.load_state_dict - # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + # Mistral-Small-3.1-24B-Instruct-2503 has a different structure + layer = model.vision_tower + else: + layer = model.vision_tower.vision_model + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer - def reference_vision_mlp(self): + def reference_vision_mlp(self, layer_idx=0): model = self.reference_vision_transformer(wrap=False) - layer = model.vision_tower.vision_model.encoder.layers[0].mlp - # layer._load_state_dict = layer.load_state_dict - # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + layer = model.vision_tower.transformer.layers[layer_idx].feed_forward + else: + layer = model.vision_tower.vision_model.encoder.layers[0].mlp + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_pixtral_image_block(self, layer_num=0): + model = 
self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.transformer.layers[layer_num] + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_rms(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.transformer.layers[0].ffn_norm + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_conv2d_patch(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.patch_conv + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer def reference_siglip_patch_embed(self): @@ -2390,11 +2506,22 @@ def reference_vision_layernorm(self, layer_name="layer_norm1"): # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer - def reference_vision_attention(self): + def reference_vision_attention(self, layer_idx=0): model = self.reference_vision_transformer(wrap=False) - layer = model.vision_tower.vision_model.encoder.layers[0].self_attn # Common naming - # layer._load_state_dict = layer.load_state_dict - # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + layer = model.vision_tower.transformer.layers[layer_idx].attention + else: + layer = model.vision_tower.vision_model.encoder.layers[0].self_attn # Common naming + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_rot_emb(self): + model = self.reference_vision_transformer(wrap=False) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + layer = model.vision_tower.patch_positional_embedding + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer def reference_vision_encoder_block(self): @@ -2406,9 +2533,12 @@ def reference_vision_encoder_block(self): def reference_vision_encoder(self): model = self.reference_vision_transformer(wrap=False) - layer = model.vision_tower.vision_model.encoder - # layer._load_state_dict = layer.load_state_dict - # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + layer = model.vision_tower.transformer + else: + layer = model.vision_tower.vision_model.encoder + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer def reference_mlp(self): diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/mistral_e2e_model.py b/models/tt_transformers/tt/multimodal/mistral_24b/mistral_e2e_model.py new file mode 100644 index 000000000000..a4257bfe1bc7 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/mistral_e2e_model.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +""" +This is the end-to-end pipeline for the Mistral-Small-3.1-24B-Instruct-2503 model. + +The `MistralTransformer` class inherits from the `Transformer` class in tt_transformers. +It overrides `prepare_inputs_prefill` to run inference on the vision model and +pass the resulting visual tokens to the text model along with text tokens. +""" + + +import ttnn +from models.tt_transformers.tt.model import Transformer +from models.tt_transformers.tt.multimodal.mistral_24b.vision_model import TtMistralVisionTransformer + + +class MistralTransformer(Transformer): + def __init__( + self, + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=None, + use_paged_kv_cache=False, + ): + super().__init__( + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=paged_attention_config, + use_paged_kv_cache=use_paged_kv_cache, + ) + + self.vision_model = TtMistralVisionTransformer( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix="vision_tower.", + dtype=dtype, + model_args=args, + tt_ccl=self.tt_ccl, + ) + + def prepare_inputs_prefill(self, pt_tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): + """ + Inputs are torch tensors or python types. This function returns ttnn + tensors on device. + TODO: Debate whether this function is responsible for padding + """ + + S = pt_tokens.shape[-1] + tokens = ttnn.from_torch( + pt_tokens.reshape(1, 1, 1, -1), + device=self.mesh_device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + tokens_embd = self.embd(tokens) + + vision_output = self.compute_vision_token(**kwargs) + + if vision_output is not None: + tokens_embd = ttnn.to_torch(tokens_embd, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1)) + comp_vision_output = ttnn.to_torch( + vision_output, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0) + )[: vision_output.shape[0], :] + + image_features = comp_vision_output.squeeze(0) + special_image_mask = (pt_tokens == 10).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(tokens_embd) + image_features = image_features.to(tokens_embd.device, tokens_embd.dtype) + tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features) + + tokens_embd = self.args.prepare_residual_tensor_prefill( + tokens_embd, + ) + + tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) + # Slice the rot mats to the prefill seqlen + assert ( + self.rope_setup.cos_matrix.shape[2] >= start_pos + S + ), f"Padded prefill end idx {start_pos + S} exceeds max seq len {self.rope_setup.cos_matrix.shape[2]}" + + tt_rot_mats_prefill_global = [ + self.rope_setup.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + + if hasattr(self, "rope_local_setup"): + tt_rot_mats_prefill_local = [ + self.rope_local_setup.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_local_setup.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + else: + tt_rot_mats_prefill_local = None + + if page_table is not None: + tt_page_table = ttnn.from_torch( + page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_page_table = None + + if chunk_page_table is not None: + tt_chunk_page_table = ttnn.from_torch( + chunk_page_table, + device=self.mesh_device, + dtype=ttnn.int32, + 
layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_chunk_page_table = None + + return tokens_embd, tt_rot_mats_prefill_global, tt_rot_mats_prefill_local, tt_page_table, tt_chunk_page_table + + def compute_vision_token(self, pixel_values, image_sizes): + if pixel_values is not None: + vision_output = self.vision_model(pixel_values, image_sizes) + return vision_output + return None diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/mistral_vision_tower.py b/models/tt_transformers/tt/multimodal/mistral_24b/mistral_vision_tower.py new file mode 100644 index 000000000000..d875011d2712 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/mistral_vision_tower.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +""" +This file implements the Vision Tower submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. +This pipeline constructs the vision tower from vision model architecture. +""" + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.common import position_ids_in_meshgrid_tt +from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm +from models.tt_transformers.tt.multimodal.mistral_24b.vision_conv2d import TtMistralConv2dPatch +from models.tt_transformers.tt.multimodal.mistral_24b.vision_pixtral_transformer import TtPixtralTransformer +from models.tt_transformers.tt.multimodal.mistral_24b.vision_rope import VisionRotarySetup as RotarySetup +from ttnn import ConcatMeshToTensor + + +class MistralVisionTower(LightweightModule): + def __init__( + self, + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix, + dtype, + configuration, + return_intermediate=None, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.tt_ccl = tt_ccl + self.dtype = dtype + self.config = configuration + + self.image_size = configuration.vision_chunk_size + self.patch_size = configuration.vision_patch_size + self.width = configuration.vision_dim + self.layers = configuration.vision_n_layers + self.heads = configuration.vision_attn_n_heads + self.vision_head_dim = configuration.vision_head_dim + self.mlp_ratio = configuration.vision_mlp_ratio + self.act_layer = configuration.vision_act_layer + self.in_channels = configuration.vision_in_channels + self.n_global_layers = configuration.vision_n_global_layers + self.max_seq_len = configuration.max_seq_len + self.return_intermediate = return_intermediate + self.n_layers = configuration.vision_n_layers + + in_channels, out_channels, kernel_size, stride, bias = ( + 3, + configuration.vision_dim, + configuration.vision_patch_size, + configuration.vision_patch_size, + False, + ) + + self.patch_conv = TtMistralConv2dPatch( + mesh_device=self.mesh_device, + state_dict=self.state_dict, + state_dict_prefix=f"{state_dict_prefix}patch_conv.", + dtype=self.dtype, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + bias=bias, + ) + + self.ln_pre = RMSNorm( + device=mesh_device, + dim=self.width, + state_dict=self.state_dict, + state_dict_prefix=state_dict_prefix, + weight_dtype=dtype, + weight_key="ln_pre", + is_distributed=False, + simplified_rms=True, + ) + + image_size = configuration.vision_image_size + patch_size = configuration.vision_patch_size + dim = configuration.vision_head_dim + num_patches_per_dim = image_size // patch_size + num_patches = 
num_patches_per_dim * num_patches_per_dim + self.num_patches = num_patches + + self.patch_positional_embedding = RotarySetup( + self.mesh_device, + 1, + dim, + image_size, + patch_size, + num_patches, + configuration.vision_rope_theta, + scale_factor=None, + orig_context_len=num_patches, + datatype=dtype, + ) + + self.transformer = TtPixtralTransformer( + mesh_device=self.mesh_device, + tt_ccl=tt_ccl, + state_dict=self.state_dict, + state_dict_prefix=f"{state_dict_prefix}transformer.", + weight_cache_path=configuration.weight_cache_path(dtype), + dtype=self.dtype, + configuration=configuration, + layers=self.n_layers, + ) + + def forward(self, input_tensor, image_sizes=None): + """ + input_tensor shape: (B, C, H, W) + """ + patch_embeds = self.patch_conv(input_tensor) + patch_embeds = ttnn.transpose(patch_embeds, 1, 2) + height, width = image_sizes[0] + patch_embeds = ttnn.reshape( + patch_embeds, + [patch_embeds.shape[0], self.width, height // self.patch_size, width // self.patch_size], + ) + + patch_embeds_list = [ + ttnn.slice( + patch_embeds, + [0, 0, 0, 0], + [1, self.width, size[0] // self.patch_size, size[1] // self.patch_size], + ) + for size in image_sizes + ] + + reshaped_patches = [] + for p in patch_embeds_list: + p = ttnn.reshape(p, (1, self.width, -1)) + p = ttnn.transpose(p, 1, 2) + reshaped_patches.append(p) + + patch_embeds = ttnn.concat(reshaped_patches, dim=0) + + # ln_pre RMS Norm + mode = "prefill" + patch_embeds = self.ln_pre(patch_embeds, mode=mode) + + # # positional embeddings + position_ids = position_ids_in_meshgrid_tt( + patch_embeds_list, + max_width=self.config.vision_image_size // self.config.vision_patch_size, + device=self.mesh_device, + ) + + torch_position_ids = ttnn.to_torch(position_ids, mesh_composer=ConcatMeshToTensor(self.mesh_device, dim=0))[ + : position_ids.shape[-1] + ] + + position_embeddings = self.patch_positional_embedding.get_rot_mats(torch_position_ids) + + patch_embeds = ttnn.unsqueeze(patch_embeds, 0) + out = self.transformer(patch_embeds, position_embeddings=position_embeddings) + # deallocate position_embeddings + ttnn.deallocate(position_embeddings[0]) + ttnn.deallocate(position_embeddings[1]) + + return out diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/rmsnorm.py b/models/tt_transformers/tt/multimodal/mistral_24b/rmsnorm.py new file mode 100644 index 000000000000..0aa7cec84448 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/rmsnorm.py @@ -0,0 +1,202 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +""" +This is the modified version of the rmsnorm for the Mistral-Small-3.1-24B-Instruct-2503 model. +We introduced the `simplified_rms_norm` function to be compatible with the Mistral-Small-3.1-24B-Instruct-2503 model. +""" + +import ttnn +from models.common.lightweightmodule import LightweightModule + +TILE = 32 +SHARD_HEIGHT = TILE # Current ttnn.rms_norm implementation requires shard height to be a single tile + + +class RMSNorm(LightweightModule): + """ + RMSNorm supporting replication over a MeshDevice and sharding within devices. + + This class implements a Root Mean Square Normalization (RMSNorm) that can be + distributed across multiple devices and cores. If the `device` parameter is a + MeshDevice, the weights and computations are replicated across all devices in + the mesh. Expects an interleaved input tensor, can optionally output a sharded tensor. + + Args: + device: The device or MeshDevice on which to perform the computations. 
+ state_dict: The state dictionary containing the model parameters. + dim: Input dimension (e.g. model hidden dimension size). + layer_num: The layer number to determine the weight key in the state dictionary. + weight_key: The key for retrieving the weight from the state dictionary. + weight_cache_path: Optional path for caching the tilized weights. + weight_memory_config: Configuration for the weight memory, default is DRAM_MEMORY_CONFIG. + weight_dtype: The data type for the tensors, bfp8_b hits >0.999 PCC in the models we tested. + model_config: Optional configuration dictionary for the model. + eps (float): Small value to avoid division by zero in normalization, default is 1e-05. + + If model_config is provided, it must specify SHARDED_NORM_INPUT_MEMCFG, SHARDED_NORM_PRGM_CFG + and SHARDED_NORM_OUTPUT_MEMCFG. If not provided, default configurations will be generated. + """ + + def __init__( + self, + device, + dim, + state_dict, + weight_key, + layer_num=None, + state_dict_prefix=None, + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps: float = 1e-05, + add_unit_offset=False, + sharded_program_config=None, + sharded_output_config=None, + output_mem_config=None, + ccl_topology=ttnn.Topology.Ring, + simplified_rms=False, + ): + super().__init__() + self.eps = eps + self.is_distributed = is_distributed + self.ccl_topology = ccl_topology + + if state_dict_prefix: + weight_name = f"{state_dict_prefix}{weight_key}.weight" + else: + if layer_num is None: + weight_name = f"{weight_key}.weight" + else: + weight_name = f"layers.{layer_num}.{weight_key}.weight" + + torch_weight = ( + state_dict[weight_name].unsqueeze(0).view(1, 1, dim).reshape([1, 1, dim // SHARD_HEIGHT, SHARD_HEIGHT]) + ) + + # Add offset before caching + if add_unit_offset: + torch_weight = torch_weight + 1.0 + + cache_name = None if weight_cache_path is None else weight_cache_path / weight_name + + # Compatibility with models that don't use mesh devices (e.g. 
single-chip Mistral-7b) + is_mesh_device = device.__class__.__name__ == "MeshDevice" + + self.weight = ttnn.as_tensor( + state_dict[weight_name].unsqueeze(0).view(1, 1, dim), + device=device, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT if weight_dtype == ttnn.bfloat8_b else ttnn.ROW_MAJOR_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=cache_name, + mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + ) + + if self.is_distributed: + self.weight_distributed = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=cache_name, + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape)) + if is_mesh_device + else None, + ) + + self.sharded_output_config = sharded_output_config + self.sharded_program_config = sharded_program_config + self.output_mem_config = output_mem_config + + self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=False, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + self.simplified_rms = simplified_rms + + def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) -> ttnn.Tensor: + # If input is sharded do sharded RMSNorm and optionally return sharded output + program_config = self.sharded_program_config if in_sharded else None + memory_config = self.sharded_output_config if out_sharded else None + distributed = self.is_distributed and self.is_distributed(mode) + norm = ( + self._simplified_rmsnorm + if self.simplified_rms + else self._distributed_rmsnorm + if distributed + else ttnn.rms_norm + ) + + weight = self.weight_distributed if distributed else self.weight + + if in_sharded: + assert not distributed, "Distributed RMSNorm does not support sharded inputs" + else: + assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor" + + x = norm( + x, + epsilon=self.eps, + weight=weight, + program_config=program_config, + memory_config=memory_config, + compute_kernel_config=self.compute_kernel_config_hifi2, + ) + + if in_sharded and not out_sharded: + return ttnn.sharded_to_interleaved(x) + else: + return x + + def _simplified_rmsnorm( + self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None + ): + inp = ttnn.sharded_to_interleaved(inp, ttnn.DRAM_MEMORY_CONFIG) + xnorm = ttnn.pow(inp, 2) + xnorm = ttnn.mean(xnorm, dim=-1, keepdim=True) + xnorm = ttnn.rsqrt(xnorm + epsilon) + xnorm = ttnn.multiply(inp, xnorm) + weight = ttnn.reshape(weight, [1, 1, -1]) + output = ttnn.multiply(xnorm, (weight), use_legacy=False) + + if memory_config is not None: + output = ttnn.to_memory_config(output, memory_config) + + ttnn.deallocate(xnorm) + ttnn.deallocate(weight) + + return output + + def _distributed_rmsnorm( + self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None + ): + assert program_config is None, "Distributed RMSNorm does not support sharded inputs" + assert memory_config is None, "Distributed RMSNorm does not support sharded outputs" + + # Run distributed rmsnorm part 1 + tt_stats = ttnn.rms_norm_pre_all_gather(inp, compute_kernel_config=compute_kernel_config, dtype=ttnn.bfloat16) + # AllGather stats + tt_stats = ttnn.all_gather( + tt_stats, + dim=3, + num_links=1, + topology=self.ccl_topology, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + # Run distributed rmsnorm part 2 + 
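+        # Part 1 above produced each device's local statistics over its shard of the
+        # hidden dimension, and the all_gather shared those partial stats across the
+        # mesh; part 2 below applies the full RMS normalization together with the
+        # sharded weight on every device.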
tt_out = ttnn.rms_norm_post_all_gather( + inp, + tt_stats, + epsilon=epsilon, + weight=weight, + compute_kernel_config=compute_kernel_config, + ) + tt_stats.deallocate(True) + + return tt_out diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_attention.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_attention.py new file mode 100644 index 000000000000..03c6f755dfc9 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_attention.py @@ -0,0 +1,269 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 +""" +This is the modified version of the vision_attention for the Mistral-Small-3.1-24B-Instruct-2503 model. +We introduced the `apply_rotary_pos_emb_vision_tt` function to llama_image_attention to be compatible with the Mistral-Small-3.1-24B-Instruct-2503 model. + +""" +import torch +import ttnn + +from models.common.lightweightmodule import LightweightModule +from models.utility_functions import is_blackhole, nearest_32 + + +def rotate_half(x): + last_dim = x.shape[-1] + half = last_dim // 2 + + x1 = ttnn.slice(x, (0, 0, 0, 0), (x.shape[0], x.shape[1], x.shape[2], half)) + x2 = ttnn.slice(x, (0, 0, 0, half), (x.shape[0], x.shape[1], x.shape[2], last_dim)) + + neg_x2 = ttnn.mul(x2, -1, use_legacy=False) + return ttnn.concat([neg_x2, x1], dim=-1) + + +def apply_rotary_pos_emb_vision_tt(q, k, cos, sin): + cos = ttnn.unsqueeze(cos, 0) + sin = ttnn.unsqueeze(sin, 0) + + q_embed = ttnn.add(ttnn.mul(q, cos, use_legacy=True), ttnn.mul(rotate_half(q), sin, use_legacy=True)) + k_embed = ttnn.add(ttnn.mul(k, cos), ttnn.mul(rotate_half(k), sin)) + return q_embed, k_embed + + +class TtMistralImageAttention(LightweightModule): + def __init__( + self, + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.tt_ccl = tt_ccl + self.num_devices = configuration.num_devices + + self.hidden_size = configuration.vision_dim + self.n_heads = configuration.vision_attn_n_heads + self.head_dim = self.hidden_size // self.n_heads + self.n_kv_heads = self.n_heads + + self.n_local_heads = self.n_heads // configuration.num_devices + self.n_local_kv_heads = self.n_kv_heads // configuration.num_devices + + self.dtype = dtype + + self.grid_size = configuration.max_grid_size + + self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 + self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 + self.compute_kernel_config_sdpa = configuration.compute_kernel_config_sdpa + self.configuration = configuration + + self.model_config = configuration.get_model_config() + + if configuration.dummy_weights or (weight_cache_path is None): + cache_name = lambda _: None + else: + cache_name = lambda name: weight_cache_path / (f"{state_dict_prefix}{name}") + + wq_str = f"{state_dict_prefix}wq.weight" + wk_str = f"{state_dict_prefix}wk.weight" + wv_str = f"{state_dict_prefix}wv.weight" + wo_str = f"{state_dict_prefix}wo.weight" + + # when splitting the devices, we need to make sure that the number of heads is divisible by the number of devices + assert self.n_heads % configuration.num_devices == 0 + assert self.n_kv_heads % configuration.num_devices == 0 + + # Pad head_dim to multiple of 32 + def pad_head_dim(weight, heads_out=True): + # Pad head dim to multiple of 32 + # heads_out means that the output dim of this weight contains heads. 
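+        #   Example (illustrative numbers, not tied to this checkpoint): a head_dim of 64
+        #   is already a multiple of the 32-wide tile and is left untouched, while a
+        #   head_dim of 72 would be zero-padded per head up to nearest_32(72) == 96.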
+ dim = weight.shape[1] + assert weight.shape[0] == dim + padded_head_dim = nearest_32(self.head_dim) + padding_size = padded_head_dim - self.head_dim + if padding_size > 0: + if heads_out: + weight = weight.transpose(-1, -2) + weight = weight.reshape(dim, self.n_heads, self.head_dim) + padding = torch.zeros(dim, self.n_heads, padding_size, dtype=weight.dtype) + weight = torch.cat([weight, padding], dim=-1) + weight = weight.reshape(dim, self.n_heads * padded_head_dim) + if heads_out: + weight = weight.transpose(-1, -2) + return weight + + wq_padded = pad_head_dim(self.state_dict[wq_str]) + wk_padded = pad_head_dim(self.state_dict[wk_str]) + wv_padded = pad_head_dim(self.state_dict[wv_str]) + wo_padded = pad_head_dim(self.state_dict[wo_str], heads_out=False) + wq_chunked, wk_chunked, wv_chunked = ( + torch.chunk(w, configuration.num_devices) for w in [wq_padded, wk_padded, wv_padded] + ) + + self.wqkv = ttnn.as_tensor( + torch.concat( + [ + torch.concat( + [ + torch.transpose( + wq_chunked[i], + -2, + -1, + ), + torch.transpose( + wk_chunked[i], + -2, + -1, + ), + torch.transpose( + wv_chunked[i], + -2, + -1, + ), + ], + dim=-1, + ) + for i in range(configuration.num_devices) + ], + dim=-1, + ), + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + dtype=self.dtype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + ) + + self.wo = ttnn.as_tensor( + torch.transpose( + wo_padded, + -2, + -1, + ), + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-2), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=self.dtype, + layout=ttnn.TILE_LAYOUT, + ) + + self.scale = self.head_dim**-0.5 + + def forward(self, x_11SH, position_embeddings=None): + seq_len = x_11SH.shape[-2] + + MAX_MM_SEQ_LEN = seq_len + + if seq_len > MAX_MM_SEQ_LEN: + x_11SH = ttnn.reshape(x_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) + + xqkv_fused = ttnn.linear( + x_11SH, + self.wqkv, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config_hifi4, + program_config=self.model_config["IMAGE_ATTN_QKV_PROGCFG"](seq_len, MAX_MM_SEQ_LEN), + ) + if seq_len > MAX_MM_SEQ_LEN: + xqkv_fused = ttnn.reshape(xqkv_fused, [1, 1, seq_len, -1]) + + # split qkv into heads + ( + q_heads_1QSD, + k_heads_1KSD, + v_heads_1VSD, + ) = ttnn.experimental.nlp_create_qkv_heads( + xqkv_fused, + num_heads=self.n_local_heads, + num_kv_heads=self.n_local_kv_heads, + transpose_k_heads=False, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + if position_embeddings is not None: + cos, sin = position_embeddings + q_heads_1QSD, k_heads_1KSD = apply_rotary_pos_emb_vision_tt(q_heads_1QSD, k_heads_1KSD, cos, sin) + ttnn.deallocate(xqkv_fused) + # TODO: get this from model_config + sdpa_cfg = ttnn.SDPAProgramConfig( + compute_with_storage_grid_size=(8, 8), q_chunk_size=128, k_chunk_size=128, exp_approx_mode=False + ) + attn_output_1QSD = ttnn.transformer.scaled_dot_product_attention( + q_heads_1QSD, + k_heads_1KSD, + v_heads_1VSD, + is_causal=False, + scale=self.scale, + program_config=sdpa_cfg, + compute_kernel_config=self.compute_kernel_config_sdpa, + ) + # deallocate keys and values + ttnn.deallocate(q_heads_1QSD) + ttnn.deallocate(k_heads_1KSD) + ttnn.deallocate(v_heads_1VSD) + + ### + # Output matmul + ### + attn_output_11SH = ttnn.experimental.nlp_concat_heads( + attn_output_1QSD, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + ttnn.deallocate(attn_output_1QSD) + + # reshaping long sequence to matmul fit on 
device + if seq_len > MAX_MM_SEQ_LEN: + attn_output_11SH = ttnn.reshape(attn_output_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) + + output_11SH = ttnn.linear( + attn_output_11SH, + self.wo, + compute_kernel_config=self.compute_kernel_config_hifi4, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + program_config=self.model_config["IMAGE_ATTN_OUT_PROGCFG"](seq_len, MAX_MM_SEQ_LEN), + ) + if seq_len > MAX_MM_SEQ_LEN: + output_11SH = ttnn.reshape(output_11SH, [1, 1, seq_len, -1]) + ttnn.deallocate(attn_output_11SH) + + # All reduce + if self.num_devices > 1: # replace with reduce_scatter and all_gather + # TODO: 26411 + # Remove this blackhole condition once fabric CCLs are working on blackhole + if is_blackhole(): + dense_out_gathered = ttnn.all_gather(output_11SH, dim=1, num_links=1, topology=ttnn.Topology.Linear) + else: + dense_out_gathered = ttnn.experimental.all_gather_async( + output_11SH, + persistent_output_buffer=None, + dim=1, + multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(), + num_links=1, + topology=ttnn.Topology.Linear, + barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(), + chunks_per_sync=10, + num_workers_per_link=2, + num_buffers_per_channel=2, + ) + output_11SH.deallocate(True) + dense_out_reduced = ttnn.experimental.fast_reduce_nc( + dense_out_gathered, dims=[1], output=None, compute_kernel_config=None + ) + # slicing the required sequence length + dense_out_reduced = dense_out_reduced[:, :, : dense_out_gathered.shape[-2], :] + return dense_out_reduced + else: + return output_11SH diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_conv2d.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_conv2d.py new file mode 100644 index 000000000000..4dc76f9f5ada --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_conv2d.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +""" +This is the modified version of the vision_patch_conv2d for the Mistral-Small-3.1-24B-Instruct-2503 model. +We have modified the llama_patch_conv2d to be compatible with the Mistral-Small-3.1-24B-Instruct-2503 model. +""" + +import torch +import ttnn + +from models.common.lightweightmodule import LightweightModule + + +class TtMistralConv2dPatch(LightweightModule): + """Conv2D Patching layer. + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. 
+ Input: (bsz, in_channels, width, height) + Output: (bsz, num_tokens, out_channels) + """ + + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + dtype, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + bias, + ): + super().__init__() + + self.mesh_device = mesh_device + self.num_devices = self.mesh_device.get_num_devices() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + + self.bias = ( + ttnn.as_tensor( + torch.reshape(state_dict[f"{state_dict_prefix}_linear.bias"], (1, -1)), + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + if bias + else None + ) + + self._unfold = torch.nn.Unfold(kernel_size=self.kernel_size, stride=self.stride) + + weight = state_dict[f"{state_dict_prefix}_linear.weight"] + if weight.ndim == 4: + weight = weight.reshape(out_channels, -1).T + + self._linear_weight = ttnn.as_tensor( + weight, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + self.compute_kernel_config = ttnn.init_device_compute_kernel_config( + mesh_device.arch(), + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + def forward(self, x: torch.Tensor): + x = self._unfold(x) + x = x.permute(0, 2, 1) + + x = ttnn.as_tensor( + x, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + out = ttnn.linear( + x, + self._linear_weight, + bias=self.bias, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + core_grid=ttnn.CoreGrid(y=8, x=8), + ) + + return out diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_mlp.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_mlp.py new file mode 100644 index 000000000000..30c84ea94f03 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_mlp.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +""" +This is the modified version of the FeedForward for the Mistral-Small-3.1-24B-Instruct-2503 model. +This file implements the Vision FeedForward submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. 
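+
+Rough reference sketch of the computation (illustrative only, not executed here; it
+assumes the SwiGLU naming used below, i.e. w1 -> gate_proj, w3 -> up_proj, w2 -> down_proj,
+with weights stored as [in, out] matrices):
+
+    import torch.nn.functional as F
+
+    def swiglu_mlp(x, w1, w2, w3):
+        # down_proj(silu(gate_proj(x)) * up_proj(x))
+        return (F.silu(x @ w1) * (x @ w3)) @ w2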
+""" + +import torch +import ttnn + +from models.common.lightweightmodule import LightweightModule + + +class MistralTTVisionMLP(LightweightModule): + def __init__( + self, + mesh_device, + args, + state_dict, + weight_cache_path, + dtype, + state_dict_prefix=None, + ): + super().__init__() + + self.mesh_device = mesh_device + self.args = args + self.state_dict = state_dict + self.dim = args.dim + + def get_weight(name): + return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) + + def get_bias(name): + return state_dict[f"{state_dict_prefix}{name}.bias"] + + def cache_name(name): + if args.dummy_weights: + return None + return weight_cache_path / f"{state_dict_prefix}.{name}" + + def as_tensor(name, dtype, is_bias=False): + tensor_data = get_bias(name) if is_bias else get_weight(name) + return ttnn.as_tensor( + tensor_data, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + # Weights and Biases + self.w1 = as_tensor("w1", dtype) + self.b1 = as_tensor("w1", ttnn.bfloat16, is_bias=False) + + self.w3 = as_tensor("w3", dtype) + self.b3 = as_tensor("w3", ttnn.bfloat16, is_bias=False) + + self.w2 = as_tensor("w2", dtype) + self.b2 = as_tensor("w2", ttnn.bfloat16, is_bias=False) + + self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=True, + packer_l1_acc=True, + dst_full_sync_en=False, + ) + + def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: + """ + Qwen HF MLP reference: + output = down_proj(act_fn(gate_proj(x)) * up_proj(x)) + Mapping: + w1 -> gate_proj + w3 -> up_proj + w2 -> down_proj + """ + + # Linear with SILU activation + w1_out = ttnn.linear( + x, + self.w1, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + activation="silu", + compute_kernel_config=self.compute_kernel_config, + ) + + w3_out = ttnn.linear( + x, + self.w3, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + + # Element-wise multiply + w2_in = ttnn.mul(w1_out, w3_out, dtype=ttnn.bfloat16) + + # Final projection + w2_out = ttnn.linear( + w2_in, + self.w2, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + + ttnn.deallocate(w1_out) + ttnn.deallocate(w3_out) + ttnn.deallocate(w2_in) + return w2_out diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_mmp.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_mmp.py new file mode 100644 index 000000000000..2dd21e0d0177 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_mmp.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +""" +This file implements the Vision MultiModalProjector submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. 
+""" + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm +from ttnn import ConcatMeshToTensor + + +class TTMistral3PatchMerger(LightweightModule): + def __init__( + self, + mesh_device, + args, + state_dict, + state_dict_prefix, + weight_cache_path=None, + dtype=ttnn.bfloat16, + ): + super().__init__() + self.device = mesh_device + hidden_size = args.vision_dim + self.spatial_merge_size = 2 + self.patch_size = args.vision_patch_size + self.args = args + + def get_weight(name): + return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) + + def get_bias(name): + return state_dict[f"{state_dict_prefix}{name}.bias"] + + def cache_name(name): + if args.dummy_weights: + return None + return weight_cache_path / f"{state_dict_prefix}.{name}" + + def as_tensor(name, dtype, is_bias=False): + tensor_data = get_bias(name) if is_bias else get_weight(name) + return ttnn.as_tensor( + tensor_data, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + self.merging_weights = as_tensor("merging_layer", dtype) + self.merging_bias = as_tensor("merging_layer", ttnn.bfloat16, is_bias=False) + + def forward(self, image_features: ttnn.Tensor, image_sizes) -> ttnn.Tensor: + image_sizes = [ + (image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes + ] + + tokens_per_image = [h * w for h, w in image_sizes] + d = image_features.shape[-1] + + permuted_tensor = [] + for image_index, image_tokens in enumerate(ttnn.split(image_features, tokens_per_image, dim=0)): + # Reshape image_tokens into a 2D grid + h, w = image_sizes[image_index] + + image_tokens = ttnn.to_layout(image_tokens, ttnn.ROW_MAJOR_LAYOUT) + + image_grid = ttnn.view(image_tokens, (h, w, d)) + # Permute the grid to have channels last + image_grid = ttnn.permute(image_grid, (2, 0, 1)) # Channels first + image_grid = ttnn.unsqueeze(image_grid, dim=0) # Add batch dimension + # Reshape the grid to merge patches + if self.args.num_devices > 1: + image_grid_torch = ttnn.to_torch(image_grid, mesh_composer=ConcatMeshToTensor(self.device, dim=0)) + image_grid_torch = image_grid_torch[0].unsqueeze(0) # shape: [1, 1024, 30, 44] + image_grid_torch = image_grid_torch.to(dtype=torch.bfloat16) + else: + image_grid_torch = ttnn.to_torch(image_grid).to(dtype=torch.bfloat16) + + grid = torch.nn.functional.unfold( + image_grid_torch, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size + ) + + grid = ttnn.from_torch(grid, device=self.device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + + grid = ttnn.view(grid, (d * self.spatial_merge_size**2, -1)) + grid = ttnn.transpose(grid, 0, 1) # Transpose to have features first + + permuted_tensor.append(grid) + + image_features = ttnn.concat(permuted_tensor, dim=0) + # Apply merging layer + image_features = ttnn.linear( + image_features, self.merging_weights, dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + + return image_features + + +class TTMistral3MultiModalProjector(LightweightModule): + def __init__(self, mesh_device, args, state_dict, state_dict_prefix, dtype, eps, weight_cache_path=None): + super().__init__() + + self.norm = RMSNorm( + device=mesh_device, + dim=args.vision_dim, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="norm", + weight_dtype=dtype, + eps=eps, + 
is_distributed=False, + simplified_rms=True, + ) + + self.patch_merger = TTMistral3PatchMerger( + mesh_device=mesh_device, + args=args, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}patch_merger.", + ) + + def get_weight(name): + return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) + + def get_bias(name): + return state_dict[f"{state_dict_prefix}{name}.bias"] + + def cache_name(name): + if args.dummy_weights: + return None + return weight_cache_path / f"{state_dict_prefix}.{name}" + + def as_tensor(name, dtype, is_bias=False): + tensor_data = get_bias(name) if is_bias else get_weight(name) + return ttnn.as_tensor( + tensor_data, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + self.linear_1_weight = as_tensor("linear_1", dtype) + self.linear_1_bias = as_tensor("linear_1", ttnn.bfloat16, is_bias=False) + + self.linear_2_weight = as_tensor("linear_2", dtype) + self.linear_2_bias = as_tensor("linear_2", ttnn.bfloat16, is_bias=False) + + def forward(self, image_features: ttnn.Tensor, image_sizes): + image_features = self.norm(image_features, mode="decode") + image_features = self.patch_merger(image_features, image_sizes) + + hidden_states = ttnn.linear( + image_features, + self.linear_1_weight, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + activation="gelu", # Using GELU activation as per Mistral 3 model + ) + + hidden_states = ttnn.linear( + hidden_states, self.linear_2_weight, dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + + return hidden_states diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_model.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_model.py new file mode 100644 index 000000000000..f85cbf2e7ecc --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_model.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +""" +This is the end-to-end architecture of the Mistral-24B vision model. + +It brings together all components related to visual and MultiModalProjector together. 
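+
+Rough data flow (shapes indicative only): pixel_values (B, C, H, W) -> MistralVisionTower
+-> per-patch features -> TTMistral3MultiModalProjector -> image embeddings in the text
+model's hidden size, which MistralTransformer.prepare_inputs_prefill scatters into the
+token embeddings at the image-token positions.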
+""" + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.mistral_24b.mistral_vision_tower import MistralVisionTower +from models.tt_transformers.tt.multimodal.mistral_24b.vision_mmp import TTMistral3MultiModalProjector + + +class TtMistralVisionTransformer(LightweightModule): + def __init__(self, mesh_device, tt_ccl, state_dict, state_dict_prefix, dtype, model_args): + super().__init__() + self.state_dict = state_dict + self.mesh_device = mesh_device + self.tt_ccl = tt_ccl + + self.vision_tower = MistralVisionTower( + mesh_device=mesh_device, + tt_ccl=self.tt_ccl, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + dtype=dtype, + configuration=model_args, + ) + + self.mmp = TTMistral3MultiModalProjector( + mesh_device=mesh_device, + args=model_args, + state_dict=state_dict, + state_dict_prefix="multi_modal_projector.", + dtype=dtype, + eps=1e-05, # layer_norm_eps + ) + + def forward(self, input_tensor, image_sizes=None): + """ + input_tensor shape: (B, C, H, W) + """ + + x = self.vision_tower(input_tensor, image_sizes=image_sizes) + x = ttnn.squeeze(ttnn.squeeze(x, 0), 0) + x = self.mmp(x, image_sizes) + return x diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_image_block.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_image_block.py new file mode 100644 index 000000000000..a564dc282ba6 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_image_block.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm +from models.tt_transformers.tt.multimodal.mistral_24b.vision_attention import ( + TtMistralImageAttention as TtLlamaImageAttention, +) +from models.tt_transformers.tt.multimodal.mistral_24b.vision_mlp import MistralTTVisionMLP as MLP + +""" +This file implements the pixtral image block specific for the Mistral-Small-3.1-24B-Instruct-2503 model. 
+""" + + +class TtPixtralImageTransformerBlock(LightweightModule): + def __init__( + self, + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + ): + super().__init__() + self.state_dict = state_dict + self.mesh_device = mesh_device + self.tt_ccl = tt_ccl + self.configuration = configuration + self.num_devices = configuration.num_devices + self.hidden_size = configuration.vision_dim + + self.attention_norm = RMSNorm( + device=mesh_device, + dim=configuration.vision_dim, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="attention_norm", + weight_dtype=dtype, + is_distributed=False, + simplified_rms=True, + ) + + self.attention = TtLlamaImageAttention( + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix=f"{state_dict_prefix}attention.", + weight_cache_path=weight_cache_path, + dtype=dtype, + configuration=configuration, + ) + + self.ffn_norm = RMSNorm( + device=mesh_device, + dim=configuration.vision_dim, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="ffn_norm", + weight_dtype=dtype, + is_distributed=False, + simplified_rms=True, + ) + + self.mlp = MLP( + mesh_device=mesh_device, + args=configuration, + state_dict=state_dict, + weight_cache_path=weight_cache_path, + state_dict_prefix=f"{state_dict_prefix}feed_forward.", + dtype=dtype, + ) + + def forward(self, x_input, position_embeddings=None): + mode = "prefill" + # attention norm Input and result replicated + attn_norm_res = self.attention_norm(x_input, mode=mode) + # attention Input and results replicated + attn_out = self.attention(attn_norm_res, position_embeddings=position_embeddings) + res = ttnn.add(x_input, attn_out, use_legacy=True) + ffn_norm_res = self.ffn_norm(res, mode=mode) + mlp_out = self.mlp(ffn_norm_res) + out = ttnn.add(res, mlp_out) + return out diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_transformer.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_transformer.py new file mode 100644 index 000000000000..d21e417875f0 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_transformer.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +""" +This file implements the Vision Transformer submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. +This pipeline iterates over the pixtral image blocks to generate the image embeddings. 
+""" + +from tqdm import tqdm + +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.mistral_24b.vision_pixtral_image_block import TtPixtralImageTransformerBlock + + +class TtPixtralTransformer(LightweightModule): + def __init__( + self, + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + layers, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.tt_ccl = tt_ccl + + block_key = "layers" + self.resblocks = [ + TtPixtralImageTransformerBlock( + mesh_device=mesh_device, + tt_ccl=self.tt_ccl, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}{block_key}.{i}.", + weight_cache_path=weight_cache_path, + dtype=dtype, + configuration=configuration, + ) + for i in tqdm(range(layers), desc=f"Loading vision transformer layers") + ] + + def forward(self, x, return_intermediate=None, position_embeddings=None): + """ + Different from reference impl in that if return_intermediates, it returns + a list of intermediate tensors rather than a stack of intermediates. + Outer code will have to be aware and handle this correctly. + """ + out = [] + for idx, r in enumerate(self.resblocks): + if return_intermediate is not None and idx in return_intermediate: + out.append(x) + x = r(x, position_embeddings=position_embeddings) + if return_intermediate is not None: + return x, out + return x diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_rope.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_rope.py new file mode 100644 index 000000000000..bb299dc4ca07 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_rope.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +""" +This is the modified version of the RoPE for the Mistral-Small-3.1-24B-Instruct-2503 model. +We have modified the compute_gather_cos_sin function of RMSNorm to be compatible with the Mistral-Small-3.1-24B-Instruct-2503 model. 
+""" + +import torch +import ttnn + +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.common import precompute_mistral_vision_freqs +from ttnn import ReplicateTensorToMesh + + +def compute_gather_cos_sin(dhead, max_patches_per_side, theta, scale_factor, orig_context_len, position_ids): + cos, sin = precompute_mistral_vision_freqs(dhead, max_patches_per_side, theta, scale_factor, orig_context_len) + return cos, sin + + +class VisionRotarySetup(LightweightModule): + def __init__( + self, + device, + batch_size: int, + head_dim: int, + image_size: int, + patch_size: int, + max_seq_len: int, + rope_theta: float, + scale_factor: float, # use None to disable rope scaling + orig_context_len: int, # only used if scaling enabled + datatype=ttnn.bfloat16, + ): + super().__init__() + + self.batch_size = batch_size + self.head_dim = head_dim + self.device = device + self.is_mesh_device = isinstance(device, ttnn._ttnn.multi_device.MeshDevice) + self.num_devices = device.get_num_devices() if self.is_mesh_device else 1 + if self.num_devices == 32: + self.batch_size_per_device_group = max(self.batch_size // list(device.shape)[1], 1) + else: + self.batch_size_per_device_group = self.batch_size + self.core_grid = device.compute_with_storage_grid_size() + + max_patches_per_side = image_size // patch_size + + # Generate the cos/sin matrices needed for ttnn.embedding op + cos_matrix, sin_matrix = compute_gather_cos_sin( + dhead=head_dim, + max_patches_per_side=max_patches_per_side, + theta=rope_theta, + scale_factor=scale_factor, + orig_context_len=orig_context_len, + position_ids=torch.arange(max_seq_len), + ) + self.cos_matrix = ttnn.from_torch( + cos_matrix, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=datatype, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + self.sin_matrix = ttnn.from_torch( + sin_matrix, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=datatype, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + + def get_rot_mats(self, position_idxs, return_rot_idxs=False): + device = self.device + + # If position_idxs is a torch tensor, get the TTNN version of it + if isinstance(position_idxs, torch.Tensor): + rot_idxs = position_idxs.unsqueeze(0) + else: + rot_idxs = position_idxs + assert len(rot_idxs.shape) == 2 and rot_idxs.shape[0] == 1, "rot_idxs must be a [1, batch] tensor" + + rot_idxs = ttnn.from_torch( + rot_idxs, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=ttnn.uint32, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + # Send the idxs to device + if rot_idxs.device != device: + rot_idxs = ttnn.to_device(rot_idxs, device, memory_config=ttnn.DRAM_MEMORY_CONFIG) + + embedding_layout = ttnn.TILE_LAYOUT + cos = ttnn.embedding(rot_idxs, self.cos_matrix, layout=embedding_layout) # [1, batch, head_dim] + sin = ttnn.embedding(rot_idxs, self.sin_matrix, layout=embedding_layout) # [1, batch, head_dim] + + if return_rot_idxs: + return [cos, sin], rot_idxs + ttnn.deallocate(rot_idxs) + return [cos, sin]