From c216b15402912d7bb663f37df4e75837f2dd33d2 Mon Sep 17 00:00:00 2001
From: nikileshx
Date: Thu, 3 Jul 2025 12:38:52 +0000
Subject: [PATCH] Add Support for mistralai/Mistral-Small-3.1-24B-Instruct-2503 model

---
 .../tests/pipeline_tests/test_end2end.py      | 530 ++++++++++++++++++
 .../tests/pipeline_tests/test_vision_model.py |  97 ++++
 .../tests/pipeline_tests/test_vision_tower.py |  75 +++
 .../mistral_24b/tests/test_conv2d.py          | 101 ++++
 .../mistral_24b/tests/test_patch_rot_emb.py   |  93 +++
 .../tests/test_pixtral_transformer.py         | 120 ++++
 .../tests/test_vision_attention.py            | 122 ++++
 .../mistral_24b/tests/test_vision_mlp.py      |  89 +++
 .../mistral_24b/tests/test_vision_rms.py      |  99 ++++
 models/experimental/mistral_24b/tt/model.py   | 134 +++++
 .../tt/pipeline/mistral_vision_tower.py       | 169 ++++++
 .../mistral_24b/tt/pipeline/vision_model.py   |  51 ++
 models/experimental/mistral_24b/tt/rmsnorm.py | 202 +++++++
 .../mistral_24b/tt/vision_attention.py        | 269 +++++++++
 .../mistral_24b/tt/vision_conv2d.py           | 109 ++++
 .../experimental/mistral_24b/tt/vision_mlp.py | 115 ++++
 .../experimental/mistral_24b/tt/vision_mmp.py | 171 ++++++
 .../tt/vision_pixtral_image_block.py          |  87 +++
 .../tt/vision_pixtral_transformer.py          |  61 ++
 .../mistral_24b/tt/vision_rope.py             | 105 ++++
 models/tt_transformers/tt/common.py           |  56 ++
 models/tt_transformers/tt/load_checkpoints.py | 152 ++++-
 models/tt_transformers/tt/model_config.py     | 372 +++++++++---
 23 files changed, 3307 insertions(+), 72 deletions(-)
 create mode 100644 models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py
 create mode 100644 models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py
 create mode 100644 models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py
 create mode 100644 models/experimental/mistral_24b/tests/test_conv2d.py
 create mode 100644 models/experimental/mistral_24b/tests/test_patch_rot_emb.py
 create mode 100644 models/experimental/mistral_24b/tests/test_pixtral_transformer.py
 create mode 100644 models/experimental/mistral_24b/tests/test_vision_attention.py
 create mode 100644 models/experimental/mistral_24b/tests/test_vision_mlp.py
 create mode 100644 models/experimental/mistral_24b/tests/test_vision_rms.py
 create mode 100644 models/experimental/mistral_24b/tt/model.py
 create mode 100644 models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py
 create mode 100644 models/experimental/mistral_24b/tt/pipeline/vision_model.py
 create mode 100644 models/experimental/mistral_24b/tt/rmsnorm.py
 create mode 100644 models/experimental/mistral_24b/tt/vision_attention.py
 create mode 100644 models/experimental/mistral_24b/tt/vision_conv2d.py
 create mode 100644 models/experimental/mistral_24b/tt/vision_mlp.py
 create mode 100644 models/experimental/mistral_24b/tt/vision_mmp.py
 create mode 100644 models/experimental/mistral_24b/tt/vision_pixtral_image_block.py
 create mode 100644 models/experimental/mistral_24b/tt/vision_pixtral_transformer.py
 create mode 100644 models/experimental/mistral_24b/tt/vision_rope.py

diff --git a/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py b/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py
new file mode 100644
index 000000000000..96c1e1e00c26
--- /dev/null
+++ b/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py
@@ -0,0 +1,530 @@
+# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+# SPDX-License-Identifier: Apache-2.0 + +"""Test for Mistral-24B End-to-End Vision-Text Pipeline""" + +import torch +import pytest +from loguru import logger +import os +import ttnn + +from models.tt_transformers.tt.ccl import TT_CCL +from models.tt_transformers.tt.common import ( + sample_host, + PagedAttentionConfig, + preprocess_inputs_prefill, +) + +from models.tt_transformers.tt.model_config import DecodersPrecision +from models.experimental.mistral_24b.tt.model import MistralTransformer as Transformer + +from models.tt_transformers.tt.generator import Generator + +from models.experimental.mistral_24b.tt.pipeline.vision_model import TtMistralVisionTransformer +from models.tt_transformers.tt.model_config import ModelArgs +from transformers import AutoProcessor, AutoModelForVision2Seq + +import re + + +def run_reference_demo_pipeline(messages, model_id="mistralai/Mistral-Small-3.1-24B-Instruct-2503"): + """ + Run Hugging Face reference demo model (Vision-Text pipeline) using given messages. + """ + logger.info("Running reference HF vision-text model...") + + processor = AutoProcessor.from_pretrained(model_id) + model = AutoModelForVision2Seq.from_pretrained( + model_id, + device_map="auto", + torch_dtype=torch.bfloat16, + ) + + model.eval() + + # Apply chat template + prompt_text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, padding=True, padding_side="left" + ) + + # Extract images (already loaded) + image_inputs = [] + for msg in messages: + for item in msg["content"]: + if item["type"] == "image": + image_inputs.append(item["image"]) + + # Tokenize and move to model device + inputs = processor( + text=[prompt_text], + images=image_inputs, + return_tensors="pt", + ).to(model.device, dtype=torch.bfloat16) + + with torch.no_grad(): + generated_ids = model.generate( + **inputs, + max_new_tokens=100, + temperature=0.0, + top_p=0.9, + do_sample=False, + pad_token_id=model.config.pad_token_id, + ) + + # Decode + output = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + logger.info(f"HF reference model output: {output}") + + chat = parse_chat_output(output) + display_chat(logger, chat) + + return output + + +def parse_chat_output(text): + """Parse chat output format from generated text.""" + pattern = r"<\|(?Puser|assistant)\|>\s*(?P.*?)(?=<\|(?:user|assistant|end)\|>|$)" + matches = re.finditer(pattern, text, re.DOTALL) + return [(match.group("role"), match.group("message").strip()) for match in matches] + + +def display_chat(logger, conversation): + """Display chat conversation in formatted output.""" + for role, message in conversation: + if role == "user": + logger.info(f"👤 User: {message}") + elif role == "assistant": + logger.info(f"🤖 Assistant: {message}") + + +def setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations): + """Setup model arguments for vision-enabled model (Single Responsibility).""" + instruct = True if weights == "instruct" else False + + model_args = ModelArgs( + mesh_device=mesh_device, + instruct=instruct, + optimizations=optimizations, + max_seq_len=max_seq_len, + max_batch_size=batch_size, + ) + + return model_args, instruct + + +def setup_vision_prompts_and_tokenizer(model_args, instruct): + """Setup multimodal prompts and tokenizer for vision-enabled model.""" + messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://img.freepik.com/premium-photo/girl-hugging-dog-with-girl-hugging-her_737761-2565.jpg", + }, + { + "type": "text", + 
"text": "Is there a cat in this image? If not, what animal do you see in the image? Describe the image in detail in 600 words.", + }, + ], + } + ] + + tokenizer = model_args.tokenizer + return messages, tokenizer + + +def process_vision_info(messages): + """Extract images (already opened) from messages.""" + image_inputs = [] + video_inputs = None # Not used + + for msg in messages: + content = msg.get("content", []) + for item in content: + if item.get("type") == "image": + image_inputs.append(item["image"]) + + return image_inputs, video_inputs + + +def process_real_vision_inputs(messages, model_args): + """Process real image inputs using AutoProcessor (Interface Segregation).""" + processor = AutoProcessor.from_pretrained(os.getenv("HF_MODEL")) + + text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, padding=True, padding_side="left" + ) + + image_inputs, video_inputs = process_vision_info(messages) + + encoded = processor( + text=[text], images=image_inputs, videos=video_inputs, return_tensors="pt", return_dict=True + ).to("cpu", dtype=torch.bfloat16) + input_ids = encoded["input_ids"] + pixel_values = encoded["pixel_values"] if "pixel_values" in encoded else None + attention_mask = encoded["attention_mask"] if "attention_mask" in encoded else None + image_sizes = encoded["image_sizes"] if "image_sizes" in encoded else None + + return { + "input_ids": input_ids, + "pixel_values": pixel_values, + "attention_mask": attention_mask, + "image_sizes": image_sizes, + "processor": processor, + } + + +def load_separate_models_like_test_end2end(model_args, mesh_device, dtype, paged_attention, page_params): + """Load separate vision and text models following test_end2end.py pattern.""" + state_dict = model_args.load_state_dict() + + vision_prefix = "vision_tower." 
+ # Setup paged attention config (exactly like test_end2end.py) + paged_attention_config = None + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + + tt_ccl = TT_CCL(mesh_device) + # Load vision model (exactly like test_end2end.py) + vision_model = TtMistralVisionTransformer( + mesh_device=mesh_device, + tt_ccl=tt_ccl, + state_dict=state_dict, + state_dict_prefix=vision_prefix, + dtype=dtype, + model_args=model_args, + ) + + # Load text model (exactly like test_end2end.py) + text_model = Transformer( + args=model_args, + mesh_device=mesh_device, + dtype=dtype, + state_dict=state_dict, + weight_cache_path=model_args.weight_cache_path(dtype), + paged_attention_config=paged_attention_config, + ) + logger.info("Separate vision and text models loaded like test_end2end.py") + return vision_model, text_model + + +def run_generation_exactly_like_test_end2end( + vision_model, + text_model, + processed_inputs, + model_args, + page_table=None, + paged_attention_config=None, + max_gen_len=20, + repetition_ngram_size=3, +): + """Run generation following the EXACT pattern from test_end2end.py.""" + input_ids = processed_inputs["input_ids"] + + logger.info("Running generation exactly like test_end2end.py...") + + logger.info("Running Vision Model...") + generator = Generator([text_model], [model_args], vision_model.mesh_device, tokenizer=model_args.tokenizer) + tt_kv_cache = [[l.attention.layer_past for l in text_model.layers]] if paged_attention_config else None + + input_tokens_prefill = input_ids + batch_size = input_tokens_prefill.shape[0] + + prompt_text = model_args.tokenizer.decode(input_ids[0].tolist()) + input_prompts = [prompt_text] + + ( + input_tokens_prefill_pt, + encoded_prompts, + decoding_pos, + prefill_lens, + ) = preprocess_inputs_prefill( + input_prompts, + model_args.tokenizer, + [model_args], + instruct=True, + max_generated_tokens=max_gen_len, + max_prefill_len=8192, + ) + + input_tokens_prefill_pt = torch.stack(input_tokens_prefill_pt).view(batch_size, -1) + + logger.info("Running prefill...") + logits = generator.prefill_forward_text( + input_tokens_prefill_pt, + page_table=page_table, + kv_cache=tt_kv_cache, + prompt_lens=decoding_pos, + vision_model=vision_model, + processed_inputs=processed_inputs, + ) + + prefilled_token = torch.argmax(logits, dim=-1) + prefilled_token_decoded_res = model_args.tokenizer.decode(prefilled_token[0].item()) + logger.info(f"prefilled_token_decoded_res: {prefilled_token_decoded_res}") + + logger.info(f"Prefilled token: {prefilled_token}") + + import torch.nn.functional as F + + logger.info(f"Encoded prompt: {encoded_prompts[0]}") + logger.info(f"Decoded prompt: {model_args.tokenizer.decode(encoded_prompts[0])}") + + # logits: [1, 1, vocab_size] + last_logits = logits[0, -1] # shape: [vocab_size] + probs = F.softmax(last_logits, dim=-1) + + top_k = 5 + topk_probs, topk_indices = torch.topk(probs, k=top_k) + + topk_tokens = [model_args.tokenizer.decode([idx.item()]) for idx in topk_indices] + + logger.info("Top-5 predicted tokens (with probabilities):") + for i in range(top_k): + logger.info(f"{i+1}. 
Token: '{topk_tokens[i]}' (ID={topk_indices[i].item()}), P={topk_probs[i].item():.4f}") + + all_outputs = [encoded_prompts[0][: prefill_lens[0]]] + all_outputs[0].append(int(prefilled_token[0].item())) + + current_pos = torch.tensor([decoding_pos[0]]) + out_tok = prefilled_token + generation_length = max_gen_len + + results = [] + + logger.info("Starting decode loop...") + for iteration in range(generation_length): + logger.info(f"[Text] Decoding token {iteration}, current_pos: {current_pos.item()}") + + logits = generator.decode_forward_text( + out_tok, + current_pos, + enable_trace=False, + page_table=page_table, + kv_cache=tt_kv_cache, + ) + + _, out_tok = sample_host( + logits, + temperature=0, + top_p=0.9, + ) + + token_id = out_tok[0].item() + decoded_token = model_args.tokenizer.decode([token_id]) + logger.info(f"Generated token {iteration}: ID={token_id}, text='{decoded_token}'") + + # Stop if EOS detected + if token_id == model_args.tokenizer.eos_token_id: + logger.info("EOS token detected, stopping generation.") + break + + # Stop if repetition detected (n-gram) + if len(all_outputs[0]) >= repetition_ngram_size * 2: + last_ngram = tuple(all_outputs[0][-repetition_ngram_size:]) + for i in range(len(all_outputs[0]) - repetition_ngram_size): + if tuple(all_outputs[0][i : i + repetition_ngram_size]) == last_ngram: + logger.info(f"Detected {repetition_ngram_size}-gram repetition, stopping.") + break + + # Create result object + result = type("TokenResult", (), {"token": token_id, "text": decoded_token})() + + results.append(result) + + all_outputs[0].append(token_id) + current_pos += 1 + + # Early stopping (exactly like test_end2end.py) + if len(all_outputs[0]) >= 5 and all(t == all_outputs[0][-1] for t in all_outputs[0][-5:]): + logger.warning(f"Detected exact repetition of token {all_outputs[0][-1]} five times in a row. 
Stopping.") + break + + # Final response (exactly like test_end2end.py) + response = model_args.tokenizer.decode(all_outputs[0], skip_special_tokens=True) + logger.info(f"Each iteration Generated Response:\n{response}") + logger.info(f"Each iteration Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") + chat = parse_chat_output(response) + display_chat(logger, chat) + + logger.info(f" Each iteration Generated {len(results)} tokens successfully") + + # Final response (exactly like test_end2end.py) + response = model_args.tokenizer.decode(all_outputs[0], skip_special_tokens=True) + logger.info(f"Final Generated Response:\n{response}") + logger.info(f"Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") + chat = parse_chat_output(response) + display_chat(logger, chat) + + logger.info(f"Generated {len(results)} tokens successfully") + return results + + +def validate_e2e_outputs(results, expected_min_tokens=1): + """Validate end-to-end pipeline outputs.""" + if not results: + logger.error("No results generated from E2E pipeline") + return False + + if len(results) < expected_min_tokens: + logger.warning(f"Generated only {len(results)} tokens, expected at least {expected_min_tokens}") + return False + + # Check if tokens are valid + for result in results: + if not hasattr(result, "token") or not hasattr(result, "text"): + logger.error("Invalid result format") + return False + + logger.info("E2E pipeline validation passed") + return True + + +@torch.no_grad() +@pytest.mark.timeout(1800) +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize( + "weights, layers", + [ + ("instruct", None), + ], + ids=["full"], +) +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False, + ), + ids=( + "paged_attention", + # "default_attention", + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (1024 * 8,), # Use smaller seq_len like test_end2end.py to avoid memory issues +) +@pytest.mark.parametrize( + "optimizations", + [ + lambda model_args: DecodersPrecision.accuracy(model_args.n_layers, model_args.model_name), + ], + ids=["accuracy"], +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "N150x4": (1, 4), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_e2e_vision_text_pipeline( + weights, + layers, + max_seq_len, + batch_size, + paged_attention, + page_params, + optimizations, + mesh_device, + reset_seeds, + request, + device_params, +): + """Test end-to-end vision-text pipeline using proper Generator methods.""" + logger.info("Starting E2E vision-text pipeline test") + + # Use bfloat8_b like test_end2end.py for better memory efficiency + dtype = ttnn.bfloat8_b + + # Setup vision-enabled model configuration + model_args, instruct = setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations) + + if layers is not None: + model_args.n_layers = layers + + # Setup vision prompts and tokenizer + messages, tokenizer = setup_vision_prompts_and_tokenizer(model_args, instruct) + + # Process real vision inputs from images + processed_inputs = process_real_vision_inputs(messages, model_args) + + # Load separate 
models following test_end2end.py pattern + logger.info("Loading separate vision and text models like test_end2end.py...") + vision_model, text_model = load_separate_models_like_test_end2end( + model_args, mesh_device, dtype, paged_attention, page_params + ) + + # Setup page table for paged attention (exactly like test_end2end.py) + page_table_tt = None + paged_attention_config = None + + # Prepare page table for paged attention (exactly like test_end2end.py) + page_table = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, -2) if batch_size > 1 else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + # Run generation following EXACT test_end2end.py pattern + logger.info("Running generation following EXACT test_end2end.py pattern...") + results = run_generation_exactly_like_test_end2end( + vision_model, text_model, processed_inputs, model_args, page_table, paged_attention_config, max_gen_len=1024 * 4 + ) + + # Validate results + validation_passed = validate_e2e_outputs(results, expected_min_tokens=1) + + # Final validation + if validation_passed and len(results) > 0: + logger.info("E2E vision-text pipeline test PASSED!") + logger.info(f"Successfully generated {len(results)} tokens") + + # Log generated tokens for debugging + for i, result in enumerate(results[:5]): + logger.info(f"Token {i}: {result.token} -> '{result.text}'") + else: + logger.error("E2E pipeline test failed") + assert False, f"E2E pipeline failed - generated {len(results)} tokens, validation: {validation_passed}" diff --git a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py new file mode 100644 index 000000000000..74c558a11dfe --- /dev/null +++ b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +""" +This file is a unit test for validating the Mistral-24B Vision Model pipeline. +""" + +import os +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.ccl import TT_CCL +from models.tt_transformers.tt.model_config import ModelArgs +from models.experimental.mistral_24b.tt.pipeline.vision_model import TtMistralVisionTransformer +from models.common.utility_functions import comp_allclose, comp_pcc + + +def get_image_features(vision_tower, projector, input_tensor, image_sizes): + """ + Get image features from the vision tower and projector. 
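# The end-to-end test above builds the paged-attention page table by shuffling physical KV-cache
# blocks and inverting the permutation, so each user's contiguous virtual blocks map to scattered
# physical blocks. A torch-only sketch of that construction (block counts are illustrative):
import torch

def build_page_table_sketch(max_num_blocks: int, max_batch_size: int) -> torch.Tensor:
    permutation = torch.randperm(max_num_blocks)      # shuffled physical block order
    reverse_permutation = torch.argsort(permutation)  # virtual block -> physical block
    # Each user owns a contiguous slice of virtual blocks.
    return reverse_permutation.reshape(max_batch_size, max_num_blocks // max_batch_size)

page_table = build_page_table_sketch(max_num_blocks=1024, max_batch_size=1)
assert page_table.shape == (1, 1024)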
+ """ + vision_token = vision_tower(input_tensor, image_sizes).last_hidden_state + image_features = projector(vision_token.squeeze(0), image_sizes) + return image_features + + +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_mistral_vision_model(mesh_device, reset_seeds): + pcc_required = 0.97 + dtype = ttnn.bfloat8_b + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix) + } + + ##### Reference model output (Torch) ##### + reference_model = model_args.reference_vision_model() + reference_model.load_state_dict(partial_state_dict) + + mmp_first_layer_prefix = "multi_modal_projector." + + mmp_partial_state_dict = { + k[len(mmp_first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(mmp_first_layer_prefix)) + } + + reference_mmp = model_args.reference_vision_multi_modal() + reference_mmp.load_state_dict(mmp_partial_state_dict) + + B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size + input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16) + + reference_output = get_image_features(reference_model, reference_mmp, input_tensor, image_sizes=[(H, W)]) + + # ##### TT Model: TtMistralVisionTransformer ##### + tt_ccl = TT_CCL(mesh_device=mesh_device) + vision_model = TtMistralVisionTransformer( + mesh_device=mesh_device, + tt_ccl=tt_ccl, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + dtype=dtype, + model_args=model_args, + ) + + tt_output = vision_model(input_tensor, image_sizes=[(H, W)]) + tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ + :, : tt_output.shape[-1] + ] + + non_zero_indices = tt_output.ne(0).nonzero(as_tuple=True) + tt_output = tt_output[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC below {pcc_required}. {pcc_message}" diff --git a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py new file mode 100644 index 000000000000..c596c29fa98b --- /dev/null +++ b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +""" +This file is a unit test for validating the Mistral-24B Vision Tower model. 
+""" + +import os +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.ccl import TT_CCL +from models.tt_transformers.tt.model_config import ModelArgs +from models.experimental.mistral_24b.tt.pipeline.mistral_vision_tower import MistralVisionTower +from models.common.utility_functions import comp_allclose, comp_pcc + + +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_mistral_vision_tower(mesh_device, reset_seeds): + pcc_required = 0.99 + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix) + } + + B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size + input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16) + print("state_dict ", state_dict.keys()) + ##### Reference model output (Torch) ##### + reference_model = model_args.reference_vision_model() + reference_model.load_state_dict(partial_state_dict) + reference_output = reference_model(input_tensor, image_sizes=[(H, W)]) + + reference_output = reference_output.last_hidden_state + tt_ccl = TT_CCL(mesh_device) + ##### TT Model: MistralVisionTower ##### + vision_model = MistralVisionTower( + mesh_device=mesh_device, + tt_ccl=tt_ccl, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + dtype=dtype, + configuration=model_args, + ) + + tt_output = vision_model(input_tensor, image_sizes=[(H, W)]) + tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ + :, :, :, : tt_output.shape[-1] + ] + tt_output = tt_output.squeeze(0) + passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC below {pcc_required}. {pcc_message}" diff --git a/models/experimental/mistral_24b/tests/test_conv2d.py b/models/experimental/mistral_24b/tests/test_conv2d.py new file mode 100644 index 000000000000..337e5ade4bfa --- /dev/null +++ b/models/experimental/mistral_24b/tests/test_conv2d.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +""" +This file is a unit test for validating the Mistral-24B conv2d. 
+""" +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.experimental.mistral_24b.tt.vision_conv2d import TtMistralConv2dPatch +from models.common.utility_functions import comp_allclose, comp_pcc +from ttnn import ConcatMeshToTensor + + +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_conv2d_inference( + mesh_device, + reset_seeds, +): + pcc_required = 0.9999 + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "vision_tower.patch_conv." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + num_devices = model_args.num_devices + + ##### Create input tensor for the all gather ##### + B, NCH, H, W = (1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size) + in_channels, out_channels, kernel_size, stride, bias = ( + 3, + model_args.vision_dim, + model_args.vision_patch_size, + model_args.vision_patch_size, + False, + ) + + assert NCH == in_channels, "Number of channels in input tensor should match in_channels for the Conv2d patch." + assert type(kernel_size) == int, "Only symmetric kernel_size is currently supported." + assert kernel_size == stride, "Only same kernel_size and stride are currently supported." + + assert H % kernel_size == 0, "Height should be divisible by kernel_size." + assert W % kernel_size == 0, "Width should be divisible by kernel_size." + + ##### Prepare inputs ##### + input_tensor = torch.randn((B, NCH, H, W)).to(dtype=torch.bfloat16) + logger.info(f"Input tensor shape: {input_tensor.shape}") + + reference_model = model_args.reference_conv2d_patch() + reference_model.load_state_dict(partial_state_dict) + reference_output = reference_model(input_tensor) + + tt_model = TtMistralConv2dPatch( + mesh_device, + state_dict, + first_layer_prefix, + dtype, + in_channels, + out_channels, + kernel_size, + stride, + bias, + ) + tt_output = tt_model(input_tensor) + + ##### Check the outputs ##### + out = ttnn.from_device(tt_output) + tt_output_torch = ttnn.to_torch(out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=2)) + + # Only select output from one device + tt_output_torch = tt_output_torch[0, ..., :out_channels] + + # 1. Restore batch dim + tt_output_torch = tt_output_torch.unsqueeze(0) + # 1 1024 4096 + # 2. Permute to match Conv2D output: (N, C_out, H_out, W_out) + tt_output_torch = tt_output_torch.permute(0, 2, 1).reshape(1, 1024, 110, 110) + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/experimental/mistral_24b/tests/test_patch_rot_emb.py b/models/experimental/mistral_24b/tests/test_patch_rot_emb.py new file mode 100644 index 000000000000..2d20bb9591d8 --- /dev/null +++ b/models/experimental/mistral_24b/tests/test_patch_rot_emb.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +from loguru import logger + +import torch +import pytest +import os +import ttnn + +from models.experimental.mistral_24b.tt.vision_rope import VisionRotarySetup as RotarySetup + +from models.common.utility_functions import comp_allclose, comp_pcc +from models.tt_transformers.tt.model_config import ModelArgs + + +@torch.no_grad() +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_rot_emb(seq_len, batch_size, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + partial_state_dict = {} + + reference_model = tt_model_args.reference_vision_rot_emb() + reference_model.load_state_dict(partial_state_dict) + + image_size = tt_model_args.vision_image_size + patch_size = tt_model_args.vision_patch_size + dim = tt_model_args.vision_head_dim + num_patches_per_dim = image_size // patch_size + num_patches = num_patches_per_dim * num_patches_per_dim + position_ids = torch.arange(4096, dtype=torch.long) + + x = torch.randn(batch_size, 4096, 1024) + + cos, sin = reference_model(x, position_ids) + tt_model = RotarySetup( + device, + batch_size, + dim, + image_size, + patch_size, + num_patches, + tt_model_args.vision_rope_theta, + scale_factor=None, + orig_context_len=num_patches, + datatype=dtype, + ) + + cos2, sin2 = tt_model.get_rot_mats(position_ids) + cos2 = ttnn.from_device(cos2) + cos2 = ttnn.to_torch(cos2) + cos2 = cos2.squeeze(0) + + sin2 = ttnn.from_device(sin2) + sin2 = ttnn.to_torch(sin2) + sin2 = sin2.squeeze(0) + + passing, pcc_message = comp_pcc(cos, cos2) + + logger.info(comp_allclose(cos, cos2)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"COS PCC value is lower than {0.99} for some of the outputs. Check Warnings!" + + passing, pcc_message = comp_pcc(sin, sin2) + + logger.info(comp_allclose(sin, sin2)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"SIN PCC value is lower than {0.99} for some of the outputs. Check Warnings!" diff --git a/models/experimental/mistral_24b/tests/test_pixtral_transformer.py b/models/experimental/mistral_24b/tests/test_pixtral_transformer.py new file mode 100644 index 000000000000..a2cbbcadd825 --- /dev/null +++ b/models/experimental/mistral_24b/tests/test_pixtral_transformer.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
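# The patch rotary-embedding test above compares cos/sin tables from the TT rotary setup against
# the reference. A simplified 1D torch sketch of how such tables come from position ids and an
# inverse-frequency ladder (the vision RoPE here operates on patch-grid position ids; theta and
# head_dim below are illustrative):
import torch

def rope_tables_sketch(position_ids: torch.Tensor, head_dim: int = 64, theta: float = 10000.0):
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim/2,)
    angles = position_ids.float()[:, None] * inv_freq[None, :]                     # (seq, head_dim/2)
    emb = torch.cat((angles, angles), dim=-1)                                      # (seq, head_dim)
    return emb.cos(), emb.sin()

cos, sin = rope_tables_sketch(torch.arange(4096))
assert cos.shape == (4096, 64)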
+# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.ccl import TT_CCL +from models.tt_transformers.tt.model_config import ModelArgs + +from models.experimental.mistral_24b.tt.vision_pixtral_transformer import TtPixtralTransformer +from models.common.utility_functions import comp_allclose, comp_pcc + + +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 1),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_image_transformer_inference(batch, num_chunks, mesh_device): + pcc_required = 0.99 + + model_args = ModelArgs(mesh_device) + dtype = ttnn.bfloat16 + + state_dict = model_args.load_state_dict() + n_layers = model_args.vision_n_layers + first_layer_prefix = "vision_tower.transformer." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + dim = model_args.vision_dim + heads = model_args.vision_attn_n_heads + seq_len = model_args.vision_chunk_ntok - 1 + head_dim = dim // heads + + reference_model = model_args.reference_vision_encoder() + reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + all_tests_pass = True + + tt_ccl = TT_CCL(mesh_device) + tt_model = TtPixtralTransformer( + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=None, + dtype=dtype, + configuration=model_args, + layers=n_layers, + ) + + # Create PT input + pt_attention_input = torch.rand(batch, seq_len, dim).to(torch.bfloat16) + attention_mask = torch.zeros(batch, 1, seq_len, seq_len).to(torch.bfloat16) + + B, T, D = pt_attention_input.shape + cos = torch.ones((1, T, head_dim)).to(torch.bfloat16) + sin = torch.zeros((1, T, head_dim)).to(torch.bfloat16) + + cos_t = ttnn.from_torch( + cos, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + sin_t = ttnn.from_torch( + sin, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + attention_input = model_args.prepare_residual_tensor_prefill( + pt_attention_input, + force_replicated=True, + ) + + with torch.no_grad(): + tt_out = tt_model(attention_input, position_embeddings=(cos_t, sin_t)) + reference_output = reference_model( + pt_attention_input, attention_mask=attention_mask, position_embeddings=(cos, sin) + )[0] + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[ + : tt_out.shape[0] + ] + tt_output_torch = tt_output_torch.squeeze(0) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + if not passing: + logger.warning(f"PCC value -- {pcc_message} -- is lower than {pcc_required} for the output.") + else: + logger.info(f"PCC: {pcc_message}") + logger.info(comp_allclose(reference_output, tt_output_torch)) + all_tests_pass = all_tests_pass and passing + + assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
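# The pixtral transformer test above feeds cos = ones and sin = zeros, which makes the rotary
# update an identity, so the comparison isolates the block's attention/MLP behaviour from
# positional mixing. A torch sketch of the usual rotate-half formulation those (cos, sin)
# tensors feed into (assuming the HF-style convention; shapes are illustrative):
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_sketch(q, k, cos, sin):
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

q = torch.randn(1, 4, 128, 64)                               # (batch, heads, seq, head_dim)
cos, sin = torch.ones(1, 1, 128, 64), torch.zeros(1, 1, 128, 64)
q_rot, _ = apply_rope_sketch(q, q, cos, sin)
assert torch.equal(q_rot, q)                                 # identity rotation, as in the test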
diff --git a/models/experimental/mistral_24b/tests/test_vision_attention.py b/models/experimental/mistral_24b/tests/test_vision_attention.py new file mode 100644 index 000000000000..6529c7b6cccb --- /dev/null +++ b/models/experimental/mistral_24b/tests/test_vision_attention.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.ccl import TT_CCL +from models.tt_transformers.tt.model_config import ModelArgs +from models.common.utility_functions import comp_allclose, comp_pcc + +from models.experimental.mistral_24b.tt.vision_attention import TtMistralImageAttention as TtLlamaImageAttention + +from ttnn import ConcatMeshToTensor + + +@torch.no_grad() +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_vision_attention(mesh_device, seq_len, batch_size): + logger.info(f"seq_len: {seq_len}, batch_size: {batch_size}") + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=256) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower.transformer.layers.0.attention." + + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model = model_args.reference_vision_attention() + reference_model.load_state_dict(partial_state_dict) + + hidden_size = model_args.vision_dim + n_heads = model_args.vision_attn_n_heads + head_dim = hidden_size // n_heads + + tt_ccl = TT_CCL(mesh_device) + tt_model = TtLlamaImageAttention( + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=model_args, + ) + + dim = model_args.vision_dim + pt_attention_input = torch.randn(batch_size, seq_len, dim).to(torch.bfloat16) + attention_mask = torch.zeros(batch_size, 1, seq_len, seq_len).to(torch.bfloat16) + + B, T, D = pt_attention_input.shape + cos = torch.ones((1, T, head_dim)).to(torch.bfloat16) + sin = torch.zeros((1, T, head_dim)).to(torch.bfloat16) + + attention_input = ttnn.from_torch( + pt_attention_input.unsqueeze(0), + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + cos_t = ttnn.from_torch( + cos, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + sin_t = ttnn.from_torch( + sin, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + tt_out = tt_model(attention_input, position_embeddings=(cos_t, sin_t)) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=-1))[ + :, :, :, : tt_out.shape[-1] + ] + tt_output_torch = tt_output_torch.squeeze(0) + 
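# The image attention module under test computes standard multi-head self-attention over patch
# tokens (the test passes an all-zero mask). A minimal torch reference for the core computation,
# not the TT kernel (shapes illustrative):
import torch
import torch.nn.functional as F

def sdpa_sketch(q, k, v):
    # q, k, v: (batch, heads, seq, head_dim)
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return F.softmax(scores, dim=-1) @ v

q = torch.randn(1, 16, 128, 64)
out = sdpa_sketch(q, q, q)
assert torch.allclose(out, F.scaled_dot_product_attention(q, q, q), atol=1e-4)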
reference_output = reference_model(pt_attention_input, attention_mask, position_embeddings=(cos, sin))[0] + pcc_required = 0.99 + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/experimental/mistral_24b/tests/test_vision_mlp.py b/models/experimental/mistral_24b/tests/test_vision_mlp.py new file mode 100644 index 000000000000..03ea842f82bf --- /dev/null +++ b/models/experimental/mistral_24b/tests/test_vision_mlp.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn + +from models.experimental.mistral_24b.tt.vision_mlp import MistralTTVisionMLP as MLP +from models.tt_transformers.tt.model_config import ModelArgs +from models.common.utility_functions import comp_allclose, comp_pcc + + +@torch.no_grad() +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (64 * 1024,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_mlp_inference(seq_len, batch_size, mesh_device, reset_seeds): + dtype = ttnn.bfloat8_b + mode = "decode" if seq_len <= 32 else "prefill" + + model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128) + model_args.n_layers = 1 + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower.transformer.layers.0.feed_forward." + + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model = model_args.reference_vision_mlp() + reference_model.load_state_dict(partial_state_dict) + + tt_model = MLP( + mesh_device=mesh_device, + args=model_args, + state_dict=state_dict, + weight_cache_path=model_args.weight_cache_path(dtype), + state_dict_prefix="vision_tower.transformer.layers.0.feed_forward.", + dtype=dtype, + ) + torch_input = torch.randn(1, 1, seq_len, 1024).to(torch.bfloat16) + + reference_output = reference_model(torch_input) + tt_input = ttnn.from_torch( + torch_input, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + ) + + logger.info("Run MLP") + tt_output = tt_model(tt_input) + + tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ + :, :, :, :1024 + ] + pcc_required = 0.99 + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + if passing: + logger.info("MLP Passed!") + else: + logger.warning("MLP Failed!") + + assert passing, f"MLP output does not meet PCC requirement {pcc_required}: {pcc_message}." diff --git a/models/experimental/mistral_24b/tests/test_vision_rms.py b/models/experimental/mistral_24b/tests/test_vision_rms.py new file mode 100644 index 000000000000..add1d178d25c --- /dev/null +++ b/models/experimental/mistral_24b/tests/test_vision_rms.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
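# The vision MLP test above exercises the feed_forward block. Assuming the usual Mistral/Pixtral
# gate/up/down projection with SiLU (an assumption about the reference module; sizes below are
# illustrative), the computation it validates looks like:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMlpSketch(nn.Module):
    def __init__(self, dim=1024, hidden=4096):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden, bias=False)
        self.up_proj = nn.Linear(dim, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

y = GatedMlpSketch()(torch.randn(1, 128, 1024))
assert y.shape == (1, 128, 1024)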
+ +# SPDX-License-Identifier: Apache-2.0 + +from loguru import logger + +import torch +import pytest +import os + +import ttnn +from models.experimental.mistral_24b.tt.rmsnorm import RMSNorm + +from models.common.utility_functions import comp_allclose, comp_pcc + +from models.tt_transformers.tt.model_config import ModelArgs + + +@torch.no_grad() +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_rmsnorm_inference(seq_len, batch_size, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + reference_model = tt_model_args.reference_vision_rms() + + first_layer_prefix = "vision_tower.transformer.layers.0.ffn_norm." + + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model.load_state_dict(partial_state_dict) + + tt_model = RMSNorm( + device=device, + dim=1024, + state_dict=state_dict, + state_dict_prefix="vision_tower.transformer.layers.0.", + weight_key="ffn_norm", + weight_dtype=dtype, + is_distributed=False, + simplified_rms=True, + ) + input = torch.rand(batch_size, seq_len, 1024) + + reference_output = reference_model(input) + + tt_input = ttnn.from_torch( + input, + device=device, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device=device), + ) + + tt_output = tt_model(tt_input, mode=mode) + + tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(device, dim=-1))[ + :, : tt_output.shape[-1] + ] + + logger.info(f"tt_output_torch: {tt_output_torch.shape}") + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("rms_norm Passed!") + else: + logger.warning("rms_norm Failed!") + + assert passing, f"rms_norm output does not meet PCC requirement {0.99}." diff --git a/models/experimental/mistral_24b/tt/model.py b/models/experimental/mistral_24b/tt/model.py new file mode 100644 index 000000000000..764c12bf3a1d --- /dev/null +++ b/models/experimental/mistral_24b/tt/model.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +""" +This is the end-to-end pipeline for the Mistral-Small-3.1-24B-Instruct-2503 model. + +The `MistralTransformer` class inherits from the `Transformer` class in tt_transformers. +It overrides `prepare_inputs_prefill` to run inference on the vision model and +pass the resulting visual tokens to the text model along with text tokens. 
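# The docstring above describes splicing vision tokens into the text token stream. The
# prepare_inputs_prefill override below does this with a special-image-token mask and
# masked_scatter; a host-side torch sketch of that merge (token id 10 as the image placeholder,
# as in the code; sizes illustrative):
import torch

def merge_image_tokens_sketch(token_embeds, input_ids, image_features, image_token_id=10):
    # token_embeds: (seq, dim), input_ids: (seq,), image_features: (num_image_tokens, dim)
    mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(token_embeds)
    return token_embeds.masked_scatter(mask, image_features.to(token_embeds.dtype))

embeds = torch.zeros(6, 4)
ids = torch.tensor([1, 10, 10, 2, 3, 4])
merged = merge_image_tokens_sketch(embeds, ids, torch.ones(2, 4))
assert merged[1].eq(1).all() and merged[0].eq(0).all()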
+""" + + +import ttnn +import torch + +from models.tt_transformers.tt.model import Transformer +from ttnn import ConcatMeshToTensor + + +class MistralTransformer(Transformer): + def __init__( + self, + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=None, + use_paged_kv_cache=False, + ): + super().__init__( + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=paged_attention_config, + use_paged_kv_cache=use_paged_kv_cache, + ) + + def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): + """ + Inputs are torch tensors or python types. This function returns ttnn + tensors on device. + TODO: Debate whether this function is responsible for padding + """ + + tokens = tokens.reshape(1, 1, 1, -1) + S = tokens.shape[-1] + tokens = ttnn.from_torch( + tokens, + device=self.mesh_device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + tokens_embd = self.embd(tokens) + + pixel_values = kwargs["processed_inputs"]["pixel_values"] + input_ids = kwargs["processed_inputs"]["input_ids"] + image_sizes = kwargs["processed_inputs"]["image_sizes"] + + if pixel_values is not None: + vision_model = kwargs["vision_model"] + vision_output = vision_model(pixel_values, image_sizes) + vision_output_torch = ttnn.to_torch( + vision_output, mesh_composer=ConcatMeshToTensor(self.mesh_device, dim=-1) + )[:, : vision_output.shape[-1]] + tokens_embd = ttnn.to_torch(tokens_embd, mesh_composer=ConcatMeshToTensor(self.mesh_device, dim=-1)) + sliced_token_embds = tokens_embd[: tokens_embd.shape[0]] + + image_features = vision_output_torch + + input_ids = torch.nn.functional.pad( + input_ids, (0, tokens_embd.shape[1] - input_ids.shape[1]), "constant", 0 + ) + special_image_mask = (input_ids == 10).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(tokens_embd) + image_features = image_features.to(tokens_embd.device, tokens_embd.dtype) + tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features) + + tokens_embd = ttnn.from_torch( + tokens_embd, + dtype=ttnn.bfloat16, + device=self.mesh_device, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh( + self.mesh_device, dims=(None, 2), mesh_shape=list(self.mesh_device.shape) + ), + ) + + tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) + # Slice the rot mats to the prefill seqlen + assert ( + self.rope_setup.cos_matrix.shape[2] >= start_pos + S + ), f"Padded prefill end idx {start_pos + S} exceeds max seq len {self.rope_setup.cos_matrix.shape[2]}" + + tt_rot_mats_prefill_global = [ + self.rope_setup.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + + if hasattr(self, "rope_local_setup"): + tt_rot_mats_prefill_local = [ + self.rope_local_setup.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_local_setup.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + else: + tt_rot_mats_prefill_local = None + + if page_table is not None: + tt_page_table = ttnn.from_torch( + page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_page_table = None + + if chunk_page_table is not None: + tt_chunk_page_table = ttnn.from_torch( + chunk_page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + 
mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_chunk_page_table = None + + return tokens_embd, tt_rot_mats_prefill_global, tt_rot_mats_prefill_local, tt_page_table, tt_chunk_page_table diff --git a/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py b/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py new file mode 100644 index 000000000000..7a244d83543f --- /dev/null +++ b/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py @@ -0,0 +1,169 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +""" +This file implements the Vision Tower submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. +This pipeline constructs the vision tower from vision model architecture. +""" + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.experimental.mistral_24b.tt.vision_conv2d import TtMistralConv2dPatch +from models.experimental.mistral_24b.tt.rmsnorm import RMSNorm + +from models.tt_transformers.tt.common import position_ids_in_meshgrid_tt +from models.experimental.mistral_24b.tt.vision_rope import VisionRotarySetup as RotarySetup + +from models.experimental.mistral_24b.tt.vision_pixtral_transformer import TtPixtralTransformer +from ttnn import ConcatMeshToTensor + + +class MistralVisionTower(LightweightModule): + def __init__( + self, + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix, + dtype, + configuration, + return_intermediate=None, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.tt_ccl = tt_ccl + self.dtype = dtype + self.config = configuration + + self.image_size = configuration.vision_chunk_size + self.patch_size = configuration.vision_patch_size + self.width = configuration.vision_dim + self.layers = configuration.vision_n_layers + self.heads = configuration.vision_attn_n_heads + self.vision_head_dim = configuration.vision_head_dim + self.mlp_ratio = configuration.vision_mlp_ratio + self.act_layer = configuration.vision_act_layer + self.in_channels = configuration.vision_in_channels + self.n_global_layers = configuration.vision_n_global_layers + self.max_seq_len = configuration.max_seq_len + self.return_intermediate = return_intermediate + self.n_layers = configuration.vision_n_layers + + in_channels, out_channels, kernel_size, stride, bias = ( + 3, + configuration.vision_dim, + configuration.vision_patch_size, + configuration.vision_patch_size, + False, + ) + + self.patch_conv = TtMistralConv2dPatch( + mesh_device=self.mesh_device, + state_dict=self.state_dict, + state_dict_prefix=f"{state_dict_prefix}patch_conv.", + dtype=self.dtype, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + bias=bias, + ) + + self.ln_pre = RMSNorm( + device=mesh_device, + dim=self.width, + state_dict=self.state_dict, + state_dict_prefix=state_dict_prefix, + weight_dtype=dtype, + weight_key="ln_pre", + is_distributed=False, + simplified_rms=True, + ) + + image_size = configuration.vision_image_size + patch_size = configuration.vision_patch_size + dim = configuration.vision_head_dim + num_patches_per_dim = image_size // patch_size + num_patches = num_patches_per_dim * num_patches_per_dim + self.num_patches = num_patches + + self.patch_positional_embedding = RotarySetup( + self.mesh_device, + 1, + dim, + image_size, + patch_size, + num_patches, + configuration.vision_rope_theta, + scale_factor=None, + orig_context_len=num_patches, + 
datatype=dtype, + ) + + self.transformer = TtPixtralTransformer( + mesh_device=self.mesh_device, + tt_ccl=tt_ccl, + state_dict=self.state_dict, + state_dict_prefix=f"{state_dict_prefix}transformer.", + weight_cache_path=configuration.weight_cache_path(dtype), + dtype=self.dtype, + configuration=configuration, + layers=self.n_layers, + ) + + def forward(self, input_tensor, image_sizes=None): + """ + input_tensor shape: (B, C, H, W) + """ + patch_embeds = self.patch_conv(input_tensor) + patch_embeds = ttnn.transpose(patch_embeds, 1, 2) + height, width = image_sizes[0] + patch_embeds = ttnn.reshape( + patch_embeds, + [patch_embeds.shape[0], self.width, height // self.patch_size, width // self.patch_size], + ) + + patch_embeds_list = [ + ttnn.slice( + patch_embeds, + [0, 0, 0, 0], + [1, self.width, size[0] // self.patch_size, size[1] // self.patch_size], + ) + for size in image_sizes + ] + + reshaped_patches = [] + for p in patch_embeds_list: + p = ttnn.reshape(p, (1, self.width, -1)) + p = ttnn.transpose(p, 1, 2) + reshaped_patches.append(p) + + patch_embeds = ttnn.concat(reshaped_patches, dim=0) + + # ln_pre RMS Norm + mode = "prefill" + patch_embeds = self.ln_pre(patch_embeds, mode=mode) + + # # positional embeddings + position_ids = position_ids_in_meshgrid_tt( + patch_embeds_list, + max_width=self.config.vision_image_size // self.config.vision_patch_size, + device=self.mesh_device, + ) + + torch_position_ids = ttnn.to_torch(position_ids, mesh_composer=ConcatMeshToTensor(self.mesh_device, dim=0))[ + : position_ids.shape[-1] + ] + + position_embeddings = self.patch_positional_embedding.get_rot_mats(torch_position_ids) + + patch_embeds = ttnn.unsqueeze(patch_embeds, 0) + out = self.transformer(patch_embeds, position_embeddings=position_embeddings) + # deallocate position_embeddings + ttnn.deallocate(position_embeddings[0]) + ttnn.deallocate(position_embeddings[1]) + + return out diff --git a/models/experimental/mistral_24b/tt/pipeline/vision_model.py b/models/experimental/mistral_24b/tt/pipeline/vision_model.py new file mode 100644 index 000000000000..ebc816a71279 --- /dev/null +++ b/models/experimental/mistral_24b/tt/pipeline/vision_model.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +""" +This is the end-to-end architecture of the Mistral-24B vision model. + +It brings together all components related to visual and MultiModalProjector together. 
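# The vision tower forward above derives per-patch position ids from the patch grid via
# position_ids_in_meshgrid_tt. A torch sketch of the presumed mapping: each patch at grid
# location (row, col) gets id = row * max_width + col, with max_width = image_size // patch_size:
import torch

def position_ids_in_meshgrid_sketch(num_rows: int, num_cols: int, max_width: int) -> torch.Tensor:
    rows, cols = torch.meshgrid(torch.arange(num_rows), torch.arange(num_cols), indexing="ij")
    return (rows * max_width + cols).flatten()

ids = position_ids_in_meshgrid_sketch(num_rows=2, num_cols=3, max_width=110)
assert ids.tolist() == [0, 1, 2, 110, 111, 112]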
+""" + +import ttnn +from models.common.lightweightmodule import LightweightModule + +from models.experimental.mistral_24b.tt.pipeline.mistral_vision_tower import MistralVisionTower +from models.experimental.mistral_24b.tt.vision_mmp import TTMistral3MultiModalProjector + + +class TtMistralVisionTransformer(LightweightModule): + def __init__(self, mesh_device, tt_ccl, state_dict, state_dict_prefix, dtype, model_args): + super().__init__() + self.state_dict = state_dict + self.mesh_device = mesh_device + self.tt_ccl = tt_ccl + + self.vision_tower = MistralVisionTower( + mesh_device=mesh_device, + tt_ccl=self.tt_ccl, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + dtype=dtype, + configuration=model_args, + ) + + self.mmp = TTMistral3MultiModalProjector( + mesh_device=mesh_device, + args=model_args, + state_dict=state_dict, + state_dict_prefix="multi_modal_projector.", + dtype=dtype, + eps=1e-05, # layer_norm_eps + ) + + def forward(self, input_tensor, image_sizes=None): + """ + input_tensor shape: (B, C, H, W) + """ + + x = self.vision_tower(input_tensor, image_sizes=image_sizes) + x = ttnn.squeeze(ttnn.squeeze(x, 0), 0) + x = self.mmp(x, image_sizes) + return x diff --git a/models/experimental/mistral_24b/tt/rmsnorm.py b/models/experimental/mistral_24b/tt/rmsnorm.py new file mode 100644 index 000000000000..0aa7cec84448 --- /dev/null +++ b/models/experimental/mistral_24b/tt/rmsnorm.py @@ -0,0 +1,202 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +""" +This is the modified version of the rmsnorm for the Mistral-Small-3.1-24B-Instruct-2503 model. +We introduced the `simplified_rms_norm` function to be compatible with the Mistral-Small-3.1-24B-Instruct-2503 model. +""" + +import ttnn +from models.common.lightweightmodule import LightweightModule + +TILE = 32 +SHARD_HEIGHT = TILE # Current ttnn.rms_norm implementation requires shard height to be a single tile + + +class RMSNorm(LightweightModule): + """ + RMSNorm supporting replication over a MeshDevice and sharding within devices. + + This class implements a Root Mean Square Normalization (RMSNorm) that can be + distributed across multiple devices and cores. If the `device` parameter is a + MeshDevice, the weights and computations are replicated across all devices in + the mesh. Expects an interleaved input tensor, can optionally output a sharded tensor. + + Args: + device: The device or MeshDevice on which to perform the computations. + state_dict: The state dictionary containing the model parameters. + dim: Input dimension (e.g. model hidden dimension size). + layer_num: The layer number to determine the weight key in the state dictionary. + weight_key: The key for retrieving the weight from the state dictionary. + weight_cache_path: Optional path for caching the tilized weights. + weight_memory_config: Configuration for the weight memory, default is DRAM_MEMORY_CONFIG. + weight_dtype: The data type for the tensors, bfp8_b hits >0.999 PCC in the models we tested. + model_config: Optional configuration dictionary for the model. + eps (float): Small value to avoid division by zero in normalization, default is 1e-05. + + If model_config is provided, it must specify SHARDED_NORM_INPUT_MEMCFG, SHARDED_NORM_PRGM_CFG + and SHARDED_NORM_OUTPUT_MEMCFG. If not provided, default configurations will be generated. 
+ """ + + def __init__( + self, + device, + dim, + state_dict, + weight_key, + layer_num=None, + state_dict_prefix=None, + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps: float = 1e-05, + add_unit_offset=False, + sharded_program_config=None, + sharded_output_config=None, + output_mem_config=None, + ccl_topology=ttnn.Topology.Ring, + simplified_rms=False, + ): + super().__init__() + self.eps = eps + self.is_distributed = is_distributed + self.ccl_topology = ccl_topology + + if state_dict_prefix: + weight_name = f"{state_dict_prefix}{weight_key}.weight" + else: + if layer_num is None: + weight_name = f"{weight_key}.weight" + else: + weight_name = f"layers.{layer_num}.{weight_key}.weight" + + torch_weight = ( + state_dict[weight_name].unsqueeze(0).view(1, 1, dim).reshape([1, 1, dim // SHARD_HEIGHT, SHARD_HEIGHT]) + ) + + # Add offset before caching + if add_unit_offset: + torch_weight = torch_weight + 1.0 + + cache_name = None if weight_cache_path is None else weight_cache_path / weight_name + + # Compatibility with models that don't use mesh devices (e.g. single-chip Mistral-7b) + is_mesh_device = device.__class__.__name__ == "MeshDevice" + + self.weight = ttnn.as_tensor( + state_dict[weight_name].unsqueeze(0).view(1, 1, dim), + device=device, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT if weight_dtype == ttnn.bfloat8_b else ttnn.ROW_MAJOR_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=cache_name, + mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + ) + + if self.is_distributed: + self.weight_distributed = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=cache_name, + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape)) + if is_mesh_device + else None, + ) + + self.sharded_output_config = sharded_output_config + self.sharded_program_config = sharded_program_config + self.output_mem_config = output_mem_config + + self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=False, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + self.simplified_rms = simplified_rms + + def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) -> ttnn.Tensor: + # If input is sharded do sharded RMSNorm and optionally return sharded output + program_config = self.sharded_program_config if in_sharded else None + memory_config = self.sharded_output_config if out_sharded else None + distributed = self.is_distributed and self.is_distributed(mode) + norm = ( + self._simplified_rmsnorm + if self.simplified_rms + else self._distributed_rmsnorm + if distributed + else ttnn.rms_norm + ) + + weight = self.weight_distributed if distributed else self.weight + + if in_sharded: + assert not distributed, "Distributed RMSNorm does not support sharded inputs" + else: + assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor" + + x = norm( + x, + epsilon=self.eps, + weight=weight, + program_config=program_config, + memory_config=memory_config, + compute_kernel_config=self.compute_kernel_config_hifi2, + ) + + if in_sharded and not out_sharded: + return ttnn.sharded_to_interleaved(x) + else: + return x + + def _simplified_rmsnorm( + self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None + 
): + inp = ttnn.sharded_to_interleaved(inp, ttnn.DRAM_MEMORY_CONFIG) + xnorm = ttnn.pow(inp, 2) + xnorm = ttnn.mean(xnorm, dim=-1, keepdim=True) + xnorm = ttnn.rsqrt(xnorm + epsilon) + xnorm = ttnn.multiply(inp, xnorm) + weight = ttnn.reshape(weight, [1, 1, -1]) + output = ttnn.multiply(xnorm, (weight), use_legacy=False) + + if memory_config is not None: + output = ttnn.to_memory_config(output, memory_config) + + ttnn.deallocate(xnorm) + ttnn.deallocate(weight) + + return output + + def _distributed_rmsnorm( + self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None + ): + assert program_config is None, "Distributed RMSNorm does not support sharded inputs" + assert memory_config is None, "Distributed RMSNorm does not support sharded outputs" + + # Run distributed rmsnorm part 1 + tt_stats = ttnn.rms_norm_pre_all_gather(inp, compute_kernel_config=compute_kernel_config, dtype=ttnn.bfloat16) + # AllGather stats + tt_stats = ttnn.all_gather( + tt_stats, + dim=3, + num_links=1, + topology=self.ccl_topology, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + # Run distributed rmsnorm part 2 + tt_out = ttnn.rms_norm_post_all_gather( + inp, + tt_stats, + epsilon=epsilon, + weight=weight, + compute_kernel_config=compute_kernel_config, + ) + tt_stats.deallocate(True) + + return tt_out diff --git a/models/experimental/mistral_24b/tt/vision_attention.py b/models/experimental/mistral_24b/tt/vision_attention.py new file mode 100644 index 000000000000..1590db960b17 --- /dev/null +++ b/models/experimental/mistral_24b/tt/vision_attention.py @@ -0,0 +1,269 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 +""" +This is the modified version of the vision_attention for the Mistral-Small-3.1-24B-Instruct-2503 model. +We introduced the `apply_rotary_pos_emb_vision_tt` function to llama_image_attention to be compatible with the Mistral-Small-3.1-24B-Instruct-2503 model. 
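The rotary embedding applied to the vision Q/K heads follows the standard rotate-half
formulation; a torch-level sketch of the math implemented by `rotate_half` and
`apply_rotary_pos_emb_vision_tt` below (illustrative only):

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb_vision(q, k, cos, sin):
        # cos/sin have shape (seq_len, head_dim) and broadcast over the head dimension
        return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin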
+""" + +import torch +import ttnn + +from models.common.lightweightmodule import LightweightModule +from models.common.utility_functions import is_blackhole, nearest_32 + + +def rotate_half(x): + last_dim = x.shape[-1] + half = last_dim // 2 + + x1 = ttnn.slice(x, (0, 0, 0, 0), (x.shape[0], x.shape[1], x.shape[2], half)) + x2 = ttnn.slice(x, (0, 0, 0, half), (x.shape[0], x.shape[1], x.shape[2], last_dim)) + + neg_x2 = ttnn.mul(x2, -1, use_legacy=False) + return ttnn.concat([neg_x2, x1], dim=-1) + + +def apply_rotary_pos_emb_vision_tt(q, k, cos, sin): + cos = ttnn.unsqueeze(cos, 0) + sin = ttnn.unsqueeze(sin, 0) + + q_embed = ttnn.add(ttnn.mul(q, cos, use_legacy=True), ttnn.mul(rotate_half(q), sin, use_legacy=True)) + k_embed = ttnn.add(ttnn.mul(k, cos), ttnn.mul(rotate_half(k), sin)) + return q_embed, k_embed + + +class TtMistralImageAttention(LightweightModule): + def __init__( + self, + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.tt_ccl = tt_ccl + self.num_devices = configuration.num_devices + + self.hidden_size = configuration.vision_dim + self.n_heads = configuration.vision_attn_n_heads + self.head_dim = self.hidden_size // self.n_heads + self.n_kv_heads = self.n_heads + + self.n_local_heads = self.n_heads // configuration.num_devices + self.n_local_kv_heads = self.n_kv_heads // configuration.num_devices + + self.dtype = dtype + + self.grid_size = configuration.max_grid_size + + self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 + self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 + self.compute_kernel_config_sdpa = configuration.compute_kernel_config_sdpa + self.configuration = configuration + + self.model_config = configuration.get_model_config() + + if configuration.dummy_weights or (weight_cache_path is None): + cache_name = lambda _: None + else: + cache_name = lambda name: weight_cache_path / (f"{state_dict_prefix}{name}") + + wq_str = f"{state_dict_prefix}wq.weight" + wk_str = f"{state_dict_prefix}wk.weight" + wv_str = f"{state_dict_prefix}wv.weight" + wo_str = f"{state_dict_prefix}wo.weight" + + # when splitting the devices, we need to make sure that the number of heads is divisible by the number of devices + assert self.n_heads % configuration.num_devices == 0 + assert self.n_kv_heads % configuration.num_devices == 0 + + # Pad head_dim to multiple of 32 + def pad_head_dim(weight, heads_out=True): + # Pad head dim to multiple of 32 + # heads_out means that the output dim of this weight contains heads. 
+ dim = weight.shape[1] + assert weight.shape[0] == dim + padded_head_dim = nearest_32(self.head_dim) + padding_size = padded_head_dim - self.head_dim + if padding_size > 0: + if heads_out: + weight = weight.transpose(-1, -2) + weight = weight.reshape(dim, self.n_heads, self.head_dim) + padding = torch.zeros(dim, self.n_heads, padding_size, dtype=weight.dtype) + weight = torch.cat([weight, padding], dim=-1) + weight = weight.reshape(dim, self.n_heads * padded_head_dim) + if heads_out: + weight = weight.transpose(-1, -2) + return weight + + wq_padded = pad_head_dim(self.state_dict[wq_str]) + wk_padded = pad_head_dim(self.state_dict[wk_str]) + wv_padded = pad_head_dim(self.state_dict[wv_str]) + wo_padded = pad_head_dim(self.state_dict[wo_str], heads_out=False) + wq_chunked, wk_chunked, wv_chunked = ( + torch.chunk(w, configuration.num_devices) for w in [wq_padded, wk_padded, wv_padded] + ) + + self.wqkv = ttnn.as_tensor( + torch.concat( + [ + torch.concat( + [ + torch.transpose( + wq_chunked[i], + -2, + -1, + ), + torch.transpose( + wk_chunked[i], + -2, + -1, + ), + torch.transpose( + wv_chunked[i], + -2, + -1, + ), + ], + dim=-1, + ) + for i in range(configuration.num_devices) + ], + dim=-1, + ), + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + dtype=self.dtype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + ) + + self.wo = ttnn.as_tensor( + torch.transpose( + wo_padded, + -2, + -1, + ), + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-2), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=self.dtype, + layout=ttnn.TILE_LAYOUT, + ) + + self.scale = self.head_dim**-0.5 + + def forward(self, x_11SH, position_embeddings=None): + seq_len = x_11SH.shape[-2] + + MAX_MM_SEQ_LEN = seq_len + + if seq_len > MAX_MM_SEQ_LEN: + x_11SH = ttnn.reshape(x_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) + + xqkv_fused = ttnn.linear( + x_11SH, + self.wqkv, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config_hifi4, + program_config=self.model_config["IMAGE_ATTN_QKV_PROGCFG"](seq_len, MAX_MM_SEQ_LEN), + ) + if seq_len > MAX_MM_SEQ_LEN: + xqkv_fused = ttnn.reshape(xqkv_fused, [1, 1, seq_len, -1]) + + # split qkv into heads + ( + q_heads_1QSD, + k_heads_1KSD, + v_heads_1VSD, + ) = ttnn.experimental.nlp_create_qkv_heads( + xqkv_fused, + num_heads=self.n_local_heads, + num_kv_heads=self.n_local_kv_heads, + transpose_k_heads=False, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + if position_embeddings is not None: + cos, sin = position_embeddings + q_heads_1QSD, k_heads_1KSD = apply_rotary_pos_emb_vision_tt(q_heads_1QSD, k_heads_1KSD, cos, sin) + ttnn.deallocate(xqkv_fused) + # TODO: get this from model_config + sdpa_cfg = ttnn.SDPAProgramConfig( + compute_with_storage_grid_size=(8, 8), q_chunk_size=128, k_chunk_size=128, exp_approx_mode=False + ) + attn_output_1QSD = ttnn.transformer.scaled_dot_product_attention( + q_heads_1QSD, + k_heads_1KSD, + v_heads_1VSD, + is_causal=False, + scale=self.scale, + program_config=sdpa_cfg, + compute_kernel_config=self.compute_kernel_config_sdpa, + ) + # deallocate keys and values + ttnn.deallocate(q_heads_1QSD) + ttnn.deallocate(k_heads_1KSD) + ttnn.deallocate(v_heads_1VSD) + + ### + # Output matmul + ### + attn_output_11SH = ttnn.experimental.nlp_concat_heads( + attn_output_1QSD, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + ttnn.deallocate(attn_output_1QSD) + + # reshaping long sequence to matmul fit on 
device + if seq_len > MAX_MM_SEQ_LEN: + attn_output_11SH = ttnn.reshape(attn_output_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) + + output_11SH = ttnn.linear( + attn_output_11SH, + self.wo, + compute_kernel_config=self.compute_kernel_config_hifi4, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + program_config=self.model_config["IMAGE_ATTN_OUT_PROGCFG"](seq_len, MAX_MM_SEQ_LEN), + ) + if seq_len > MAX_MM_SEQ_LEN: + output_11SH = ttnn.reshape(output_11SH, [1, 1, seq_len, -1]) + ttnn.deallocate(attn_output_11SH) + + # All reduce + if self.num_devices > 1: # replace with reduce_scatter and all_gather + # TODO: 26411 + # Remove this blackhole condition once fabric CCLs are working on blackhole + if is_blackhole(): + dense_out_gathered = ttnn.all_gather(output_11SH, dim=1, num_links=1, topology=ttnn.Topology.Linear) + else: + dense_out_gathered = ttnn.experimental.all_gather_async( + output_11SH, + persistent_output_buffer=None, + dim=1, + multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(), + num_links=1, + topology=ttnn.Topology.Linear, + barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(), + chunks_per_sync=10, + num_workers_per_link=2, + num_buffers_per_channel=2, + ) + output_11SH.deallocate(True) + dense_out_reduced = ttnn.experimental.fast_reduce_nc( + dense_out_gathered, dims=[1], output=None, compute_kernel_config=None + ) + # slicing the required sequence length + dense_out_reduced = dense_out_reduced[:, :, : dense_out_gathered.shape[-2], :] + return dense_out_reduced + else: + return output_11SH diff --git a/models/experimental/mistral_24b/tt/vision_conv2d.py b/models/experimental/mistral_24b/tt/vision_conv2d.py new file mode 100644 index 000000000000..4dc76f9f5ada --- /dev/null +++ b/models/experimental/mistral_24b/tt/vision_conv2d.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +""" +This is the modified version of the vision_patch_conv2d for the Mistral-Small-3.1-24B-Instruct-2503 model. +We have modified the llama_patch_conv2d to be compatible with the Mistral-Small-3.1-24B-Instruct-2503 model. +""" + +import torch +import ttnn + +from models.common.lightweightmodule import LightweightModule + + +class TtMistralConv2dPatch(LightweightModule): + """Conv2D Patching layer. + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. 
+ Input: (bsz, in_channels, width, height) + Output: (bsz, num_tokens, out_channels) + """ + + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + dtype, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + bias, + ): + super().__init__() + + self.mesh_device = mesh_device + self.num_devices = self.mesh_device.get_num_devices() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + + self.bias = ( + ttnn.as_tensor( + torch.reshape(state_dict[f"{state_dict_prefix}_linear.bias"], (1, -1)), + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + if bias + else None + ) + + self._unfold = torch.nn.Unfold(kernel_size=self.kernel_size, stride=self.stride) + + weight = state_dict[f"{state_dict_prefix}_linear.weight"] + if weight.ndim == 4: + weight = weight.reshape(out_channels, -1).T + + self._linear_weight = ttnn.as_tensor( + weight, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + self.compute_kernel_config = ttnn.init_device_compute_kernel_config( + mesh_device.arch(), + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + def forward(self, x: torch.Tensor): + x = self._unfold(x) + x = x.permute(0, 2, 1) + + x = ttnn.as_tensor( + x, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + out = ttnn.linear( + x, + self._linear_weight, + bias=self.bias, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + core_grid=ttnn.CoreGrid(y=8, x=8), + ) + + return out diff --git a/models/experimental/mistral_24b/tt/vision_mlp.py b/models/experimental/mistral_24b/tt/vision_mlp.py new file mode 100644 index 000000000000..30c84ea94f03 --- /dev/null +++ b/models/experimental/mistral_24b/tt/vision_mlp.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +""" +This is the modified version of the FeedForward for the Mistral-Small-3.1-24B-Instruct-2503 model. +This file implements the Vision FeedForward submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. 
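The block computes the gated feed-forward used by the Pixtral vision encoder,
down_proj(silu(gate_proj(x)) * up_proj(x)), with w1/w3/w2 mapping to gate/up/down.
A torch sketch of the same computation (weights assumed already transposed to
(in_features, out_features), as they are stored below; illustrative only):

    import torch
    import torch.nn.functional as F

    def gated_mlp_reference(x, w1_gate, w3_up, w2_down):
        return (F.silu(x @ w1_gate) * (x @ w3_up)) @ w2_down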
+""" + +import torch +import ttnn + +from models.common.lightweightmodule import LightweightModule + + +class MistralTTVisionMLP(LightweightModule): + def __init__( + self, + mesh_device, + args, + state_dict, + weight_cache_path, + dtype, + state_dict_prefix=None, + ): + super().__init__() + + self.mesh_device = mesh_device + self.args = args + self.state_dict = state_dict + self.dim = args.dim + + def get_weight(name): + return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) + + def get_bias(name): + return state_dict[f"{state_dict_prefix}{name}.bias"] + + def cache_name(name): + if args.dummy_weights: + return None + return weight_cache_path / f"{state_dict_prefix}.{name}" + + def as_tensor(name, dtype, is_bias=False): + tensor_data = get_bias(name) if is_bias else get_weight(name) + return ttnn.as_tensor( + tensor_data, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + # Weights and Biases + self.w1 = as_tensor("w1", dtype) + self.b1 = as_tensor("w1", ttnn.bfloat16, is_bias=False) + + self.w3 = as_tensor("w3", dtype) + self.b3 = as_tensor("w3", ttnn.bfloat16, is_bias=False) + + self.w2 = as_tensor("w2", dtype) + self.b2 = as_tensor("w2", ttnn.bfloat16, is_bias=False) + + self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=True, + packer_l1_acc=True, + dst_full_sync_en=False, + ) + + def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: + """ + Qwen HF MLP reference: + output = down_proj(act_fn(gate_proj(x)) * up_proj(x)) + Mapping: + w1 -> gate_proj + w3 -> up_proj + w2 -> down_proj + """ + + # Linear with SILU activation + w1_out = ttnn.linear( + x, + self.w1, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + activation="silu", + compute_kernel_config=self.compute_kernel_config, + ) + + w3_out = ttnn.linear( + x, + self.w3, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + + # Element-wise multiply + w2_in = ttnn.mul(w1_out, w3_out, dtype=ttnn.bfloat16) + + # Final projection + w2_out = ttnn.linear( + w2_in, + self.w2, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + + ttnn.deallocate(w1_out) + ttnn.deallocate(w3_out) + ttnn.deallocate(w2_in) + return w2_out diff --git a/models/experimental/mistral_24b/tt/vision_mmp.py b/models/experimental/mistral_24b/tt/vision_mmp.py new file mode 100644 index 000000000000..6e88dbf65680 --- /dev/null +++ b/models/experimental/mistral_24b/tt/vision_mmp.py @@ -0,0 +1,171 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +""" +This file implements the Vision MultiModalProjector submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. 
+""" + +import torch +from models.common.lightweightmodule import LightweightModule +from models.experimental.mistral_24b.tt.rmsnorm import RMSNorm +import ttnn +from ttnn import ConcatMeshToTensor + + +class TTMistral3PatchMerger(LightweightModule): + def __init__( + self, + mesh_device, + args, + state_dict, + state_dict_prefix, + weight_cache_path=None, + dtype=ttnn.bfloat16, + ): + super().__init__() + self.device = mesh_device + hidden_size = args.vision_dim + self.spatial_merge_size = 2 + self.patch_size = args.vision_patch_size + self.args = args + + def get_weight(name): + return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) + + def get_bias(name): + return state_dict[f"{state_dict_prefix}{name}.bias"] + + def cache_name(name): + if args.dummy_weights: + return None + return weight_cache_path / f"{state_dict_prefix}.{name}" + + def as_tensor(name, dtype, is_bias=False): + tensor_data = get_bias(name) if is_bias else get_weight(name) + return ttnn.as_tensor( + tensor_data, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + self.merging_weights = as_tensor("merging_layer", dtype) + self.merging_bias = as_tensor("merging_layer", ttnn.bfloat16, is_bias=False) + + def forward(self, image_features: ttnn.Tensor, image_sizes) -> ttnn.Tensor: + image_sizes = [ + (image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes + ] + + tokens_per_image = [h * w for h, w in image_sizes] + d = image_features.shape[-1] + + permuted_tensor = [] + for image_index, image_tokens in enumerate(ttnn.split(image_features, tokens_per_image, dim=0)): + # Reshape image_tokens into a 2D grid + h, w = image_sizes[image_index] + + image_tokens = ttnn.to_layout(image_tokens, ttnn.ROW_MAJOR_LAYOUT) + + image_grid = ttnn.view(image_tokens, (h, w, d)) + # Permute the grid to have channels last + image_grid = ttnn.permute(image_grid, (2, 0, 1)) # Channels first + image_grid = ttnn.unsqueeze(image_grid, dim=0) # Add batch dimension + # Reshape the grid to merge patches + if self.args.num_devices > 1: + image_grid_torch = ttnn.to_torch(image_grid, mesh_composer=ConcatMeshToTensor(self.device, dim=0)) + image_grid_torch = image_grid_torch[0].unsqueeze(0) # shape: [1, 1024, 30, 44] + image_grid_torch = image_grid_torch.to(dtype=torch.bfloat16) + else: + image_grid_torch = ttnn.to_torch(image_grid).to(dtype=torch.bfloat16) + + grid = torch.nn.functional.unfold( + image_grid_torch, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size + ) + + grid = ttnn.from_torch(grid, device=self.device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + + grid = ttnn.view(grid, (d * self.spatial_merge_size**2, -1)) + grid = ttnn.transpose(grid, 0, 1) # Transpose to have features first + + permuted_tensor.append(grid) + + image_features = ttnn.concat(permuted_tensor, dim=0) + # Apply merging layer + image_features = ttnn.linear( + image_features, self.merging_weights, dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + + return image_features + + +class TTMistral3MultiModalProjector(LightweightModule): + def __init__(self, mesh_device, args, state_dict, state_dict_prefix, dtype, eps, weight_cache_path=None): + super().__init__() + + self.norm = RMSNorm( + device=mesh_device, + dim=args.vision_dim, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="norm", + weight_dtype=dtype, + eps=eps, + 
is_distributed=False, + simplified_rms=True, + ) + + self.patch_merger = TTMistral3PatchMerger( + mesh_device=mesh_device, + args=args, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}patch_merger.", + ) + + def get_weight(name): + return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) + + def get_bias(name): + return state_dict[f"{state_dict_prefix}{name}.bias"] + + def cache_name(name): + if args.dummy_weights: + return None + return weight_cache_path / f"{state_dict_prefix}.{name}" + + def as_tensor(name, dtype, is_bias=False): + tensor_data = get_bias(name) if is_bias else get_weight(name) + return ttnn.as_tensor( + tensor_data, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + self.linear_1_weight = as_tensor("linear_1", dtype) + self.linear_1_bias = as_tensor("linear_1", ttnn.bfloat16, is_bias=False) + + self.linear_2_weight = as_tensor("linear_2", dtype) + self.linear_2_bias = as_tensor("linear_2", ttnn.bfloat16, is_bias=False) + + def forward(self, image_features: ttnn.Tensor, image_sizes): + image_features = self.norm(image_features, mode="decode") + image_features = self.patch_merger(image_features, image_sizes) + + hidden_states = ttnn.linear( + image_features, + self.linear_1_weight, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + activation="gelu", # Using GELU activation as per Mistral 3 model + ) + + hidden_states = ttnn.linear( + hidden_states, self.linear_2_weight, dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + + return hidden_states diff --git a/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py b/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py new file mode 100644 index 000000000000..66a010a35af8 --- /dev/null +++ b/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.experimental.mistral_24b.tt.rmsnorm import RMSNorm + +from models.experimental.mistral_24b.tt.vision_attention import TtMistralImageAttention as TtLlamaImageAttention +from models.experimental.mistral_24b.tt.vision_mlp import MistralTTVisionMLP as MLP + +""" +This file implements the pixtral image block specific for the Mistral-Small-3.1-24B-Instruct-2503 model. 
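Each block is a standard pre-norm transformer block over the patch sequence; a torch-style
sketch of the forward implemented below (illustrative only):

    def pixtral_block_reference(x, attention, feed_forward, attention_norm, ffn_norm, position_embeddings):
        # attention and the MLP each read an RMS-normalized input and are added back
        # onto the residual stream
        h = x + attention(attention_norm(x), position_embeddings=position_embeddings)
        return h + feed_forward(ffn_norm(h))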
+""" + + +class TtPixtralImageTransformerBlock(LightweightModule): + def __init__( + self, + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + ): + super().__init__() + self.state_dict = state_dict + self.mesh_device = mesh_device + self.tt_ccl = tt_ccl + self.configuration = configuration + self.num_devices = configuration.num_devices + self.hidden_size = configuration.vision_dim + + self.attention_norm = RMSNorm( + device=mesh_device, + dim=configuration.vision_dim, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="attention_norm", + weight_dtype=dtype, + is_distributed=False, + simplified_rms=True, + ) + + self.attention = TtLlamaImageAttention( + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix=f"{state_dict_prefix}attention.", + weight_cache_path=weight_cache_path, + dtype=dtype, + configuration=configuration, + ) + + self.ffn_norm = RMSNorm( + device=mesh_device, + dim=configuration.vision_dim, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="ffn_norm", + weight_dtype=dtype, + is_distributed=False, + simplified_rms=True, + ) + + self.mlp = MLP( + mesh_device=mesh_device, + args=configuration, + state_dict=state_dict, + weight_cache_path=weight_cache_path, + state_dict_prefix=f"{state_dict_prefix}feed_forward.", + dtype=dtype, + ) + + def forward(self, x_input, position_embeddings=None): + mode = "prefill" + # attention norm Input and result replicated + attn_norm_res = self.attention_norm(x_input, mode=mode) + # attention Input and results replicated + attn_out = self.attention(attn_norm_res, position_embeddings=position_embeddings) + res = ttnn.add(x_input, attn_out, use_legacy=True) + ffn_norm_res = self.ffn_norm(res, mode=mode) + mlp_out = self.mlp(ffn_norm_res) + out = ttnn.add(res, mlp_out) + return out diff --git a/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py b/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py new file mode 100644 index 000000000000..7e45e9ff8573 --- /dev/null +++ b/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +""" +This file implements the Vision Transformer submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. +This pipeline iterates over the pixtral image blocks to generate the image embeddings. 
+""" + +from tqdm import tqdm + +from models.common.lightweightmodule import LightweightModule +from models.experimental.mistral_24b.tt.vision_pixtral_image_block import TtPixtralImageTransformerBlock + + +class TtPixtralTransformer(LightweightModule): + def __init__( + self, + mesh_device, + tt_ccl, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + layers, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.tt_ccl = tt_ccl + + block_key = "layers" + self.resblocks = [ + TtPixtralImageTransformerBlock( + mesh_device=mesh_device, + tt_ccl=self.tt_ccl, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}{block_key}.{i}.", + weight_cache_path=weight_cache_path, + dtype=dtype, + configuration=configuration, + ) + for i in tqdm(range(layers), desc=f"Loading vision transformer layers") + ] + + def forward(self, x, return_intermediate=None, position_embeddings=None): + """ + Different from reference impl in that if return_intermediates, it returns + a list of intermediate tensors rather than a stack of intermediates. + Outer code will have to be aware and handle this correctly. + """ + out = [] + for idx, r in enumerate(self.resblocks): + if return_intermediate is not None and idx in return_intermediate: + out.append(x) + x = r(x, position_embeddings=position_embeddings) + if return_intermediate is not None: + return x, out + return x diff --git a/models/experimental/mistral_24b/tt/vision_rope.py b/models/experimental/mistral_24b/tt/vision_rope.py new file mode 100644 index 000000000000..bb299dc4ca07 --- /dev/null +++ b/models/experimental/mistral_24b/tt/vision_rope.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +""" +This is the modified version of the RoPE for the Mistral-Small-3.1-24B-Instruct-2503 model. +We have modified the compute_gather_cos_sin function of RMSNorm to be compatible with the Mistral-Small-3.1-24B-Instruct-2503 model. 
+""" + +import torch +import ttnn + +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.common import precompute_mistral_vision_freqs +from ttnn import ReplicateTensorToMesh + + +def compute_gather_cos_sin(dhead, max_patches_per_side, theta, scale_factor, orig_context_len, position_ids): + cos, sin = precompute_mistral_vision_freqs(dhead, max_patches_per_side, theta, scale_factor, orig_context_len) + return cos, sin + + +class VisionRotarySetup(LightweightModule): + def __init__( + self, + device, + batch_size: int, + head_dim: int, + image_size: int, + patch_size: int, + max_seq_len: int, + rope_theta: float, + scale_factor: float, # use None to disable rope scaling + orig_context_len: int, # only used if scaling enabled + datatype=ttnn.bfloat16, + ): + super().__init__() + + self.batch_size = batch_size + self.head_dim = head_dim + self.device = device + self.is_mesh_device = isinstance(device, ttnn._ttnn.multi_device.MeshDevice) + self.num_devices = device.get_num_devices() if self.is_mesh_device else 1 + if self.num_devices == 32: + self.batch_size_per_device_group = max(self.batch_size // list(device.shape)[1], 1) + else: + self.batch_size_per_device_group = self.batch_size + self.core_grid = device.compute_with_storage_grid_size() + + max_patches_per_side = image_size // patch_size + + # Generate the cos/sin matrices needed for ttnn.embedding op + cos_matrix, sin_matrix = compute_gather_cos_sin( + dhead=head_dim, + max_patches_per_side=max_patches_per_side, + theta=rope_theta, + scale_factor=scale_factor, + orig_context_len=orig_context_len, + position_ids=torch.arange(max_seq_len), + ) + self.cos_matrix = ttnn.from_torch( + cos_matrix, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=datatype, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + self.sin_matrix = ttnn.from_torch( + sin_matrix, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=datatype, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + + def get_rot_mats(self, position_idxs, return_rot_idxs=False): + device = self.device + + # If position_idxs is a torch tensor, get the TTNN version of it + if isinstance(position_idxs, torch.Tensor): + rot_idxs = position_idxs.unsqueeze(0) + else: + rot_idxs = position_idxs + assert len(rot_idxs.shape) == 2 and rot_idxs.shape[0] == 1, "rot_idxs must be a [1, batch] tensor" + + rot_idxs = ttnn.from_torch( + rot_idxs, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=ttnn.uint32, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + # Send the idxs to device + if rot_idxs.device != device: + rot_idxs = ttnn.to_device(rot_idxs, device, memory_config=ttnn.DRAM_MEMORY_CONFIG) + + embedding_layout = ttnn.TILE_LAYOUT + cos = ttnn.embedding(rot_idxs, self.cos_matrix, layout=embedding_layout) # [1, batch, head_dim] + sin = ttnn.embedding(rot_idxs, self.sin_matrix, layout=embedding_layout) # [1, batch, head_dim] + + if return_rot_idxs: + return [cos, sin], rot_idxs + ttnn.deallocate(rot_idxs) + return [cos, sin] diff --git a/models/tt_transformers/tt/common.py b/models/tt_transformers/tt/common.py index 35c3b4449b47..2c9e4ea63cbf 100644 --- a/models/tt_transformers/tt/common.py +++ b/models/tt_transformers/tt/common.py @@ -114,6 +114,26 @@ def rope_scaling_model_factory( raise ValueError(f"Unexpected RoPE scaling type: {rope_scaling_type}") +def position_ids_in_meshgrid_tt(tt_patch_embeds_list, max_width, 
device): + position_ids_tt = [] + for tt_patch in tt_patch_embeds_list: + shape = tt_patch.shape + height, width = shape[-2], shape[-1] + mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_width + v_grid + + tt_ids = ttnn.from_torch( + ids, + device=device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + position_ids_tt.append(tt_ids[:, 0]) + return ttnn.concat(position_ids_tt, dim=0) + + def encode_prompt_instruct(tokenizer, prompt_text, system_prompt_text=None): """<|begin_of_text|><|start_header_id|>system<|end_header_id|> {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|> @@ -368,6 +388,42 @@ def get_prefill_rot_mat(head_dim, mesh_device, seq_len, theta, scale_factor, ori rot_mats = [cos_gathereds, sin_gathereds] return rot_mats +def apply_scaling_vision(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): + return freqs / scale_factor + + +def precompute_mistral_vision_freqs( + dim: int, max_patches_per_side: int, theta: float, scale_factor=None, orig_context_len=None +): + # Compute base frequencies + base_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + if scale_factor is not None: + base_freqs = apply_scaling_vision(base_freqs, scale_factor, orig_context_len) + + # Get height and width indices + h_idx = torch.arange(max_patches_per_side) + w_idx = torch.arange(max_patches_per_side) + + # Compute 2D frequency matrices + freqs_h = torch.outer(h_idx, base_freqs[::2]) + freqs_w = torch.outer(w_idx, base_freqs[1::2]) + + # Broadcast + merge + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape( + -1, dim // 2 + ) # Shape: [H*W, dim//2] + + full_freqs = torch.cat([inv_freq, inv_freq], dim=-1) + cos = full_freqs.cos() + sin = full_freqs.sin() + return cos, sin # Shape: [H*W, dim] + # Add-Multiply method of rotary embeddings for prefill def get_rot_transformation_mat(dhead): diff --git a/models/tt_transformers/tt/load_checkpoints.py b/models/tt_transformers/tt/load_checkpoints.py index 62a32a6c7bc4..58cf8c38c268 100644 --- a/models/tt_transformers/tt/load_checkpoints.py +++ b/models/tt_transformers/tt/load_checkpoints.py @@ -278,7 +278,10 @@ def convert_hf_qkv_to_meta_format(loaded_weights, head_dim): """Convert HuggingFace QKV weights to Meta format for RoPE compatibility.""" converted_weights = {} for key, tensor in loaded_weights.items(): - if "q_proj.weight" in key or "k_proj.weight" in key: + if "vision_tower" in key: + # Skip conversion for vision tower weights + converted_weights[key] = tensor + elif "q_proj.weight" in key or "k_proj.weight" in key: # For weights: n_heads = tensor.shape[0] // head_dim n_heads = tensor.shape[0] // head_dim converted_weights[key] = reverse_permute(tensor, n_heads, tensor.shape[0], tensor.shape[1]) @@ -600,6 +603,12 @@ def flatten_conv_linear(state_dict): return state_dict +def convert_vision_meta_to_hf(state_dict, head_dim): + # state_dict = convert_meta_qkv_to_hf_format(state_dict, head_dim) + state_dict = map_vision_meta_to_hf_keys(state_dict) + return state_dict + + def map_hf_to_meta_keys(loaded_weights): """ Map Hugging Face checkpoint keys to Meta checkpoint keys. 
@@ -625,6 +634,7 @@ def map_hf_to_meta_keys(loaded_weights): ("o_proj", "wo"), ("q_norm", "q_norm"), ("k_norm", "k_norm"), + ("patch_conv.weight", "patch_conv._linear.weight"), ] return replace_keys(loaded_weights, replacements) @@ -658,6 +668,146 @@ def map_meta_to_hf_keys(state_dict): return replace_keys(state_dict, replacements) +def map_vision_meta_to_hf_keys(loaded_weights): + language_weights = { + key[len("language_model.") :]: tensor + for key, tensor in loaded_weights.items() + if key.startswith("language_model.") + } + mapped_language_weights = map_meta_to_hf_keys(language_weights, language_prefix="language_model.") + other_weights = {key: tensor for key, tensor in loaded_weights.items() if not key.startswith("language_model.")} + hf_state_dict = {**mapped_language_weights} + loaded_weights = {**other_weights} + meta_to_hf_mappings = { + # vision MLP + "c_fc.weight": "fc1.weight", + "c_fc.bias": "fc1.bias", + "c_proj.weight": "fc2.weight", + "c_proj.bias": "fc2.bias", + # vision attention + # "wq.weight": "q_proj.weight", + # "wk.weight": "k_proj.weight", + # "wv.weight": "v_proj.weight", + # "wo.weight": "out_proj.weight", + # "wq.bias": "q_proj.bias", + # "wk.bias": "k_proj.bias", + # "wv.bias": "v_proj.bias", + # "wo.bias": "out_proj.bias", + # vision encoder block + "attn.wq.weight": "self_attn.q_proj.weight", + "attn.wk.weight": "self_attn.k_proj.weight", + "attn.wv.weight": "self_attn.v_proj.weight", + "attn.wo.weight": "self_attn.out_proj.weight", + "attn.wq.bias": "self_attn.q_proj.bias", + "attn.wk.bias": "self_attn.k_proj.bias", + "attn.wv.bias": "self_attn.v_proj.bias", + "attn.wo.bias": "self_attn.out_proj.bias", + "ln_1.weight": "layer_norm1.weight", + "ln_1.bias": "layer_norm1.bias", + "ln_2.weight": "layer_norm2.weight", + "ln_2.bias": "layer_norm2.bias", + "mlp.c_fc.weight": "mlp.fc1.weight", + "mlp.c_fc.bias": "mlp.fc1.bias", + "mlp.c_proj.weight": "mlp.fc2.weight", + "mlp.c_proj.bias": "mlp.fc2.bias", + # vision encoder + "layers.{layer}.attn.wq.weight": "layers.{layer}.self_attn.q_proj.weight", + "layers.{layer}.attn.wk.weight": "layers.{layer}.self_attn.k_proj.weight", + "layers.{layer}.attn.wv.weight": "layers.{layer}.self_attn.v_proj.weight", + "layers.{layer}.attn.wo.weight": "layers.{layer}.self_attn.out_proj.weight", + "layers.{layer}.attn.wq.bias": "layers.{layer}.self_attn.q_proj.bias", + "layers.{layer}.attn.wk.bias": "layers.{layer}.self_attn.k_proj.bias", + "layers.{layer}.attn.wv.bias": "layers.{layer}.self_attn.v_proj.bias", + "layers.{layer}.attn.wo.bias": "layers.{layer}.self_attn.out_proj.bias", + "layers.{layer}.ln_1.weight": "layers.{layer}.layer_norm1.weight", + "layers.{layer}.ln_1.bias": "layers.{layer}.layer_norm1.bias", + "layers.{layer}.ln_2.weight": "layers.{layer}.layer_norm2.weight", + "layers.{layer}.ln_2.bias": "layers.{layer}.layer_norm2.bias", + "layers.{layer}.mlp.c_fc.weight": "layers.{layer}.mlp.fc1.weight", + "layers.{layer}.mlp.c_fc.bias": "layers.{layer}.mlp.fc1.bias", + "layers.{layer}.mlp.c_proj.weight": "layers.{layer}.mlp.fc2.weight", + "layers.{layer}.mlp.c_proj.bias": "layers.{layer}.mlp.fc2.bias", + # vision transformer + "encoder.layers.{layer}.attn.wq.weight": "encoder.layers.{layer}.self_attn.q_proj.weight", + "encoder.layers.{layer}.attn.wk.weight": "encoder.layers.{layer}.self_attn.k_proj.weight", + "encoder.layers.{layer}.attn.wv.weight": "encoder.layers.{layer}.self_attn.v_proj.weight", + "encoder.layers.{layer}.attn.wo.weight": "encoder.layers.{layer}.self_attn.out_proj.weight", + 
"encoder.layers.{layer}.attn.wq.bias": "encoder.layers.{layer}.self_attn.q_proj.bias", + "encoder.layers.{layer}.attn.wk.bias": "encoder.layers.{layer}.self_attn.k_proj.bias", + "encoder.layers.{layer}.attn.wv.bias": "encoder.layers.{layer}.self_attn.v_proj.bias", + "encoder.layers.{layer}.attn.wo.bias": "encoder.layers.{layer}.self_attn.out_proj.bias", + "ln_post.weight": "post_layernorm.weight", + "ln_post.bias": "post_layernorm.bias", + # Top level + "_linear.weight": "weight", # patch_embedding + "_linear.bias": "bias", # patch_embedding + "positional_embedding": "weight", # pos_emb + "visual.embeddings.patch_embedding._linear.weight": "visual.embeddings.patch_embedding.weight", + "visual.embeddings.patch_embedding._linear.bias": "visual.embeddings.patch_embedding._linear.bias", + "visual.embeddings.position_embedding.positional_embedding": "visual.embeddings.position_embedding.weight", + "visual.encoder.layers.{layer}.attn.wq.weight": "visual.encoder.layers.{layer}.self_attn.q_proj.weight", + "visual.encoder.layers.{layer}.attn.wk.weight": "visual.encoder.layers.{layer}.self_attn.k_proj.weight", + "visual.encoder.layers.{layer}.attn.wv.weight": "visual.encoder.layers.{layer}.self_attn.v_proj.weight", + "visual.encoder.layers.{layer}.attn.wo.weight": "visual.encoder.layers.{layer}.self_attn.out_proj.weight", + "visual.encoder.layers.{layer}.attn.wq.bias": "visual.encoder.layers.{layer}.self_attn.q_proj.bias", + "visual.encoder.layers.{layer}.attn.wk.bias": "visual.encoder.layers.{layer}.self_attn.k_proj.bias", + "visual.encoder.layers.{layer}.attn.wv.bias": "visual.encoder.layers.{layer}.self_attn.v_proj.bias", + "visual.encoder.layers.{layer}.attn.wo.bias": "visual.encoder.layers.{layer}.self_attn.out_proj.bias", + "visual.encoder.layers.{layer}.ln_1.weight": "visual.encoder.layers.{layer}.layer_norm1.weight", + "visual.encoder.layers.{layer}.ln_1.bias": "visual.encoder.layers.{layer}.layer_norm1.bias", + "visual.encoder.layers.{layer}.ln_2.weight": "visual.encoder.layers.{layer}.layer_norm2.weight", + "visual.encoder.layers.{layer}.ln_2.bias": "visual.encoder.layers.{layer}.layer_norm2.bias", + "visual.encoder.layers.{layer}.mlp.c_fc.weight": "visual.encoder.layers.{layer}.mlp.fc1.weight", + "visual.encoder.layers.{layer}.mlp.c_fc.bias": "visual.encoder.layers.{layer}.mlp.fc1.bias", + "visual.encoder.layers.{layer}.mlp.c_proj.weight": "visual.encoder.layers.{layer}.mlp.fc2.weight", + "visual.encoder.layers.{layer}.mlp.c_proj.bias": "visual.encoder.layers.{layer}.mlp.fc2.bias", + "visual.ln_post.weight": "visual.post_layernorm.weight", + "visual.ln_post.bias": "visual.post_layernorm.bias", + } + + for key, tensor in loaded_weights.items(): + # Handle full model paths with layer numbers + if "model.vision_tower.vision_model.encoder.layers." in key: + parts = key.split(".") + layer_num = parts[5] + remainder = ".".join(parts[6:]) + if remainder in meta_to_hf_mappings: + new_key = f"model.vision_tower.vision_model.encoder.layers.{layer_num}.{meta_to_hf_mappings[remainder]}" + hf_state_dict[new_key] = tensor + continue + + # Handle full vision encoder paths with layer numbers + if "layers." in key: + parts = key.split(".") + layer_num = parts[1] # e.g. "0" in "model.layers.0.input_layernorm.weight" + template_key = "layers.{layer}." 
+ ".".join(parts[2:]) + if template_key in meta_to_hf_mappings: + hf_state_dict[meta_to_hf_mappings[template_key].format(layer=layer_num)] = tensor + continue + + # Try exact matches first + if key in meta_to_hf_mappings: + hf_state_dict[meta_to_hf_mappings[key]] = tensor + continue + + # For submodule state dicts, try matching the end of the key + matched = False + for meta_pattern, hf_pattern in meta_to_hf_mappings.items(): + if key.endswith("." + meta_pattern): + # Replace only the matching part at the end + prefix = key[: -len(meta_pattern)] + new_key = prefix + hf_pattern + hf_state_dict[new_key] = tensor + matched = True + break + + # If no mapping found, keep the original key + if not matched: + hf_state_dict[key] = tensor + + return hf_state_dict + + def convert_meta_qkv_to_hf_format(loaded_weights, head_dim): """Convert Meta QKV weights back to HuggingFace format.""" converted_weights = {} diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index 848811a0aa8a..4d320008f10c 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -31,6 +31,7 @@ convert_hf_to_meta_mllama, convert_meta_to_hf, convert_vision_hf_to_meta, + convert_vision_meta_to_hf, load_hf_state_dict, load_meta_state_dict, reverse_permute, @@ -597,6 +598,7 @@ def __init__( "Phi-3-mini-128k-instruct": {"N150": 32, "N300": 64, "T3K": 128, "TG": 128, "P150x4": 128}, "QwQ-32B": {"N150": None, "N300": None, "T3K": 64, "TG": 128, "P150x4": 128}, "Qwen3-32B": {"N150": None, "N300": None, "T3K": 64, "TG": 128, "P150x4": 128}, + "Mistral-Small-3.1-24B": {"N150": 8, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, } try: max_prefill_chunk_size_div1024 = MAX_PREFILL_CHUNK_SIZES_DIV1024[self.base_model_name][self.device_name] @@ -1465,7 +1467,7 @@ def prepare_residual_tensor_prefill(self, x_bsh, force_replicated=False): return xs_1BSH def _get_text_prefix(self): - if self.is_llama_vision(): + if self.is_multimodal: return "text_model." 
else: return "" @@ -1700,8 +1702,8 @@ def _set_params(self, checkpoint_dir): else None ) - def _set_vision_params(self, config): - vision_config = config.get("vision_config", config) + def _set_vision_params(self, vision_config): + vision_config = vision_config.get("vision_config", vision_config) self.vision_chunk_size = vision_config.get("vision_chunk_size", vision_config.get("image_size", -1)) self.image_size = vision_config.get("image_size", -1) @@ -1712,17 +1714,25 @@ def _set_vision_params(self, config): self.vision_dim = vision_config.get("hidden_size", 1280) intermediate_size = vision_config.get("intermediate_size", self.vision_dim * 4) + self.vision_image_size = vision_config.get("image_size", 1540) + self.vision_rope_theta = vision_config.get("rope_theta", 10000.0) + self.vision_mlp_ratio = intermediate_size // self.vision_dim self.vision_hidden_dim = int(self.vision_dim * self.vision_mlp_ratio) - self.vision_attn_n_heads = vision_config.get("num_attention_heads", 16) + self.vision_attn_n_heads = vision_config.get("num_attention_heads") or vision_config.get("num_heads") or 16 self.vision_head_dim = self.vision_dim // self.vision_attn_n_heads - self.vision_n_layers = vision_config.get("num_hidden_layers", 32) + self.vision_n_layers = vision_config.get("num_hidden_layers") or vision_config.get("depth") or 27 self.vision_patch_size = vision_config.get("patch_size", 14) self.vision_in_channels = vision_config.get("num_channels", 3) self.vision_dropout = vision_config.get("attention_dropout", 0.0) - self.mm_tokens_per_image = vision_config.get("mm_tokens_per_image", config.get("mm_tokens_per_image", 256)) + self.mm_tokens_per_image = vision_config.get( + "mm_tokens_per_image", vision_config.get("mm_tokens_per_image", 256) + ) + + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + self.vision_head_dim = vision_config.get("head_dim", 64) # Optional vision activation layer, defaults to GELU act_layer = vision_config.get("act_layer", "gelu").lower() @@ -1737,6 +1747,18 @@ def _set_vision_params(self, config): self.vision_n_global_layers = vision_config.get("n_global_layers", vision_config.get("num_global_layers", 8)) def _set_hf_params(self, checkpoint_dir): + def merge_text_config(base_config): + text_config = base_config.get("text_config", {}) + # Merge non-nested keys into text_config + text_config.update({k: v for k, v in base_config.items() if k not in ["text_config", "vision_config"]}) + return text_config + + def merge_vision_config(base_config): + vision_config = base_config.get("vision_config", {}) + # Merge non-nested keys into vision_config + vision_config.update({k: v for k, v in base_config.items() if k not in ["text_config", "vision_config"]}) + return vision_config + if self.from_hf_url: from transformers import AutoConfig @@ -1755,12 +1777,14 @@ def _set_hf_params(self, checkpoint_dir): ) config = self.hf_config.to_dict() + self._set_params_from_dict(config, is_hf=True) + else: config_file = os.path.join(checkpoint_dir, "config.json") assert os.path.exists(config_file), f"config.json file not found at {config_file}" with open(config_file, "r") as f: config = json.load(f) - self._set_params_from_dict(config, is_hf=True) + self._set_params_from_dict(config) # compatibility with _set_params if "llama" in self.model_name.lower(): @@ -1824,8 +1848,15 @@ def is_llama_vision(self): return ("llama" in self.CKPT_DIR.lower()) and ("vision" in self.CKPT_DIR.lower()) def get_state_dict_prefix(self, module_name, layer_num, is_vision=False): - text_prefix = 
self.state_dict_text_prefix - vision_prefix = self.state_dict_vision_prefix + if self.is_multimodal: + no_prefix_models = { + "Mistral-Small-3.1-24B-Instruct-2503", + } + text_prefix = "" if self.model_name in no_prefix_models else self.state_dict_text_prefix + else: + text_prefix = "" if not is_vision else self.state_dict_text_prefix + + vision_prefix = self.state_dict_vision_prefix if is_vision else "" layer_prefix = f"layers.{layer_num}." if layer_num is not None else "" @@ -1907,6 +1938,11 @@ def load_state_dict(self): assert self.checkpoint_type == CheckpointType.HuggingFace if self.from_hf_url: model_cls = self.get_hf_model_cls() + + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + from transformers import Mistral3ForConditionalGeneration as AutoModelForCausalLM + + model_cls = AutoModelForCausalLM model = model_cls.from_pretrained( self.CKPT_DIR, torch_dtype="auto", @@ -2379,57 +2415,84 @@ def create_tokenizer(self): logger.info(f"Successfully loaded tokenizer from {self.TOKENIZER_PATH}") except Exception as e: logger.warning(f"Failed to load tokenizer from {self.TOKENIZER_PATH}: {e}") - - # Try to use base model tokenizer as fallback - fallback_tokenizer_path = base_model_tokenizer_mapping.get(self.base_model_name) - - # If no direct match, try to infer from model name patterns - if not fallback_tokenizer_path: - model_name_lower = self.model_name.lower() - if "qwen2.5" in model_name_lower and "0.5b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-Coder-0.5B-Instruct" - elif "qwen2.5" in model_name_lower and "1.5b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-1.5B-Instruct" - elif "qwen2.5" in model_name_lower and "3b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-3B-Instruct" - elif "qwen2.5" in model_name_lower and "7b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-7B-Instruct" - elif "qwen2.5" in model_name_lower and "14b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-14B-Instruct" - elif "qwen2.5" in model_name_lower and "32b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-32B-Instruct" - elif "qwen2.5" in model_name_lower and "72b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-72B-Instruct" - elif "llama" in model_name_lower and "3.1" in model_name_lower and "8b" in model_name_lower: - fallback_tokenizer_path = "meta-llama/Llama-3.1-8B-Instruct" - elif "llama" in model_name_lower and "3.1" in model_name_lower and "70b" in model_name_lower: - fallback_tokenizer_path = "meta-llama/Llama-3.1-70B-Instruct" - elif "llama" in model_name_lower and "3.2" in model_name_lower and "1b" in model_name_lower: - fallback_tokenizer_path = "meta-llama/Llama-3.2-1B-Instruct" - elif "llama" in model_name_lower and "3.2" in model_name_lower and "3b" in model_name_lower: - fallback_tokenizer_path = "meta-llama/Llama-3.2-3B-Instruct" - elif "mistral" in model_name_lower and "7b" in model_name_lower: - fallback_tokenizer_path = "mistralai/Mistral-7B-Instruct-v0.3" - elif ( - "phi-3-mini" in model_name_lower - and "128k" in model_name_lower - and "instruct" in model_name_lower - ): - fallback_tokenizer_path = "microsoft/Phi-3-mini-128k-instruct" - - if fallback_tokenizer_path: - logger.info(f"Attempting to use fallback tokenizer: {fallback_tokenizer_path}") - try: - tokenizer = AutoTokenizer.from_pretrained( - fallback_tokenizer_path, local_files_only=os.getenv("CI") == "true" - ) - logger.info(f"Successfully loaded fallback tokenizer from {fallback_tokenizer_path}") - 
except Exception as fallback_e: - logger.error(f"Failed to load fallback tokenizer from {fallback_tokenizer_path}: {fallback_e}") - raise fallback_e - else: - logger.error(f"No fallback tokenizer found for base model: {self.base_model_name}") - raise e + # Special handling for Mistral-Small-3.1-24B-Instruct-2503 + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + tokenizer = AutoTokenizer.from_pretrained( + "mistralai/Mistral-Small-3.1-24B-Instruct-2503", trust_remote_code=True + ) + logger.info("Manually setting Mistral instruct-style chat template on the tokenizer.") + + mistral_template = """{% for message in messages %} + {% if message['role'] == 'system' %} + <|system|> + {{ message['content'] }} + {% elif message['role'] == 'user' %} + [INST] {{ message['content'] }} [/INST] + {% elif message['role'] == 'assistant' %} + {{ message['content'] }}{{ eos_token }} + {% endif %} + {% endfor %}""" + tokenizer.chat_template = mistral_template + else: + try: + # Try to load tokenizer from the original model path + tokenizer = AutoTokenizer.from_pretrained(self.TOKENIZER_PATH) + logger.info(f"Successfully loaded tokenizer from {self.TOKENIZER_PATH}") + except Exception as e: + logger.warning(f"Failed to load tokenizer from {self.TOKENIZER_PATH}: {e}") + + # Try to use base model tokenizer as fallback + fallback_tokenizer_path = base_model_tokenizer_mapping.get(self.base_model_name) + + # If no direct match, try to infer from model name patterns + if not fallback_tokenizer_path: + model_name_lower = self.model_name.lower() + if "qwen2.5" in model_name_lower and "0.5b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-Coder-0.5B-Instruct" + elif "qwen2.5" in model_name_lower and "1.5b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-1.5B-Instruct" + elif "qwen2.5" in model_name_lower and "3b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-3B-Instruct" + elif "qwen2.5" in model_name_lower and "7b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-7B-Instruct" + elif "qwen2.5" in model_name_lower and "14b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-14B-Instruct" + elif "qwen2.5" in model_name_lower and "32b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-32B-Instruct" + elif "qwen2.5" in model_name_lower and "72b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-72B-Instruct" + elif "llama" in model_name_lower and "3.1" in model_name_lower and "8b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.1-8B-Instruct" + elif "llama" in model_name_lower and "3.1" in model_name_lower and "70b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.1-70B-Instruct" + elif "llama" in model_name_lower and "3.2" in model_name_lower and "1b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.2-1B-Instruct" + elif "llama" in model_name_lower and "3.2" in model_name_lower and "3b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.2-3B-Instruct" + elif "mistral" in model_name_lower and "7b" in model_name_lower: + fallback_tokenizer_path = "mistralai/Mistral-7B-Instruct-v0.3" + elif ( + "phi-3-mini" in model_name_lower + and "128k" in model_name_lower + and "instruct" in model_name_lower + ): + fallback_tokenizer_path = "microsoft/Phi-3-mini-128k-instruct" + + if fallback_tokenizer_path: + logger.info(f"Attempting to use fallback tokenizer: {fallback_tokenizer_path}") + try: + tokenizer = AutoTokenizer.from_pretrained( + 
fallback_tokenizer_path, local_files_only=os.getenv("CI") == "true" + ) + logger.info(f"Successfully loaded fallback tokenizer from {fallback_tokenizer_path}") + except Exception as fallback_e: + logger.error( + f"Failed to load fallback tokenizer from {fallback_tokenizer_path}: {fallback_e}" + ) + raise fallback_e + else: + logger.error(f"No fallback tokenizer found for base model: {self.base_model_name}") + raise e # Add meta-compatible stop token list to the HF tokenizer if not "stop_tokens" in tokenizer.__dict__: @@ -2540,6 +2603,24 @@ def reference_transformer(self, wrap=True, load_checkpoint=False): else: return model + def reference_vision_multi_modal(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.multi_modal_projector + # layer._load_state_dict = layer.load_state_dict + # layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_rms_norm(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.multi_modal_projector.mm_soft_emb_norm + # layer._load_state_dict = layer.load_state_dict + # layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + def reference_rms_norm(self): if self.checkpoint_type == CheckpointType.Meta: from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import RMSNorm @@ -2547,7 +2628,8 @@ def reference_rms_norm(self): return RMSNorm(self.dim, self.norm_eps) else: model = self.reference_transformer(wrap=False) - layer = model.model.norm + layers = getattr(model, "layers", getattr(model, "model", {}).layers) + layer = layers[0].input_layernorm layer._load_state_dict = layer.load_state_dict layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) return layer @@ -2558,18 +2640,23 @@ def reference_vision_transformer(self, wrap=True, load_checkpoint=False): model_cls = self.get_hf_model_cls() - if self.dummy_weights and not load_checkpoint: - config = AutoConfig.from_pretrained(self.LOCAL_HF_PARAMS[self.model_name]) - config.num_layers = self.n_layers - config.num_hidden_layers = self.n_layers - model = model_cls.from_config(config) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + from transformers import Mistral3ForConditionalGeneration as AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained(self.CKPT_DIR, torch_dtype=torch.bfloat16) else: - if self.cached_hf_model is None: - model = model_cls.from_pretrained(self.CKPT_DIR, local_files_only=os.getenv("CI") == "true") - self.cached_hf_model = model + if self.dummy_weights and not load_checkpoint: + config = AutoConfig.from_pretrained(self.LOCAL_HF_PARAMS[self.model_name]) + config.num_layers = self.n_layers + config.num_hidden_layers = self.n_layers + model = model_cls.from_config(config) else: - model = self.cached_hf_model - model.model.layers = model.model.layers[: self.n_layers] + if self.cached_hf_model is None: + model = model_cls.from_pretrained(self.CKPT_DIR, local_files_only=os.getenv("CI") == "true") + self.cached_hf_model = model + else: + model = self.cached_hf_model + model.model.layers = model.model.layers[: self.n_layers] if wrap: 
@@ -2558,18 +2640,23 @@ def reference_vision_transformer(self, wrap=True, load_checkpoint=False):
 
         model_cls = self.get_hf_model_cls()
 
-        if self.dummy_weights and not load_checkpoint:
-            config = AutoConfig.from_pretrained(self.LOCAL_HF_PARAMS[self.model_name])
-            config.num_layers = self.n_layers
-            config.num_hidden_layers = self.n_layers
-            model = model_cls.from_config(config)
+        if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name:
+            from transformers import Mistral3ForConditionalGeneration as AutoModelForCausalLM
+
+            model = AutoModelForCausalLM.from_pretrained(self.CKPT_DIR, torch_dtype=torch.bfloat16)
         else:
-            if self.cached_hf_model is None:
-                model = model_cls.from_pretrained(self.CKPT_DIR, local_files_only=os.getenv("CI") == "true")
-                self.cached_hf_model = model
+            if self.dummy_weights and not load_checkpoint:
+                config = AutoConfig.from_pretrained(self.LOCAL_HF_PARAMS[self.model_name])
+                config.num_layers = self.n_layers
+                config.num_hidden_layers = self.n_layers
+                model = model_cls.from_config(config)
             else:
-                model = self.cached_hf_model
-            model.model.layers = model.model.layers[: self.n_layers]
+                if self.cached_hf_model is None:
+                    model = model_cls.from_pretrained(self.CKPT_DIR, local_files_only=os.getenv("CI") == "true")
+                    self.cached_hf_model = model
+                else:
+                    model = self.cached_hf_model
+                model.model.layers = model.model.layers[: self.n_layers]
 
         if wrap:
             wrapper = HfModelWrapper(model, self.head_dim)
@@ -2577,6 +2664,149 @@ def reference_vision_transformer(self, wrap=True, load_checkpoint=False):
         else:
             return model
 
+    def reference_gemma_model(self):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_model(self):
+        model = self.reference_vision_transformer(wrap=False)
+        if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name:
+            # Mistral-Small-3.1-24B-Instruct-2503 has a different structure
+            layer = model.vision_tower
+        else:
+            layer = model.vision_tower.vision_model
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_mlp(self, layer_idx=0):
+        model = self.reference_vision_transformer(wrap=False)
+        if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name:
+            layer = model.vision_tower.transformer.layers[layer_idx].feed_forward
+        else:
+            layer = model.vision_tower.vision_model.encoder.layers[0].mlp
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_pixtral_image_block(self, layer_num=0):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model.vision_tower.transformer.layers[layer_num]
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_rms(self):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model.vision_tower.transformer.layers[0].ffn_norm
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_conv2d_patch(self):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model.vision_tower.patch_conv
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
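# --- Reviewer note (illustrative sketch, not part of the patch) ---
# The Mistral-Small-3.1 branches above exist because Pixtral-style checkpoints keep the
# vision encoder blocks under vision_tower.transformer.layers (attention / feed_forward /
# ffn_norm), while CLIP/SigLIP-style checkpoints keep them under
# vision_tower.vision_model.encoder.layers (self_attn / mlp). The helper name below is a
# placeholder; the attribute paths come directly from the accessors above.
def vision_encoder_layers(hf_model, is_pixtral_style: bool):
    if is_pixtral_style:
        return hf_model.vision_tower.transformer.layers
    return hf_model.vision_tower.vision_model.encoder.layers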
+
+    def reference_siglip_patch_embed(self):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model.vision_tower.vision_model.embeddings.patch_embedding
+        # layer._load_state_dict = layer.load_state_dict
+        # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_pos_embedding(self):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model.vision_tower.vision_model.embeddings.position_embedding
+        # layer._load_state_dict = layer.load_state_dict
+        # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_embedding(self):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model.vision_tower.vision_model.embeddings
+        # layer._load_state_dict = layer.load_state_dict
+        # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_layernorm(self, layer_name="layer_norm1"):
+        model = self.reference_vision_transformer(wrap=False)
+        if layer_name == "layer_norm1":
+            layer = model.vision_tower.vision_model.encoder.layers[0].layer_norm1
+        elif layer_name == "layer_norm2":
+            layer = model.vision_tower.vision_model.encoder.layers[0].layer_norm2
+        else:
+            layer = model.vision_tower.vision_model.post_layernorm
+        # layer._load_state_dict = layer.load_state_dict
+        # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_attention(self, layer_idx=0):
+        model = self.reference_vision_transformer(wrap=False)
+        if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name:
+            layer = model.vision_tower.transformer.layers[layer_idx].attention
+        else:
+            layer = model.vision_tower.vision_model.encoder.layers[0].self_attn  # Common naming
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_rot_emb(self):
+        model = self.reference_vision_transformer(wrap=False)
+        if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name:
+            layer = model.vision_tower.patch_positional_embedding
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_encoder_block(self):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model.vision_tower.vision_model.encoder.layers[0]
+        # layer._load_state_dict = layer.load_state_dict
+        # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_encoder(self):
+        model = self.reference_vision_transformer(wrap=False)
+        if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name:
+            layer = model.vision_tower.transformer
+        else:
+            layer = model.vision_tower.vision_model.encoder
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
     def reference_mlp(self):
         if self.checkpoint_type == CheckpointType.Meta:
             from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import FeedForward