diff --git a/models/tt_transformers/demo/simple_vision_demo.py b/models/tt_transformers/demo/simple_vision_demo.py index cbd09301ab89..145a32d8c9b0 100644 --- a/models/tt_transformers/demo/simple_vision_demo.py +++ b/models/tt_transformers/demo/simple_vision_demo.py @@ -27,7 +27,9 @@ import ttnn from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf from models.perf.benchmarking_utils import BenchmarkProfiler +from models.tt_transformers.tt.common import hf_multimodal_encode from models.tt_transformers.tt.generator import Generator +from models.tt_transformers.tt.model_config import CheckpointType def get_batch_sampler(temperature, top_p, tokenizer): @@ -62,6 +64,7 @@ def create_multimodal_model( ): from models.tt_transformers.tt.model_config import ModelArgs from models.tt_transformers.tt.multimodal.llama_vision_model import CrossAttentionTransformer + from models.tt_transformers.tt.multimodal.mistral_24b.model import MistralTransformer tt_model_args = ModelArgs(mesh_device, max_batch_size=max_batch_size) assert tt_model_args.is_vision(), "This model is multimodal" @@ -76,14 +79,25 @@ def create_multimodal_model( if checkpoint is None: checkpoint = tt_model_args.load_state_dict() - model = CrossAttentionTransformer( - mesh_device, - state_dict=checkpoint, - weight_cache_path=tt_model_args.weight_cache_path(dtype), - dtype=dtype, - configuration=tt_model_args, - use_paged_kv_cache=use_paged_kv_cache, - ) + + if tt_model_args.base_model_name == "Mistral-Small-3.1-24B": + model = MistralTransformer( + mesh_device=mesh_device, + state_dict=checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(ttnn.bfloat8_b), + dtype=ttnn.bfloat8_b, + args=tt_model_args, + use_paged_kv_cache=use_paged_kv_cache, + ) + else: + model = CrossAttentionTransformer( + mesh_device, + state_dict=checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=tt_model_args, + use_paged_kv_cache=use_paged_kv_cache, + ) return tt_model_args, model, checkpoint @@ -136,7 +150,7 @@ def prepare_generator_args( ) @pytest.mark.parametrize( "test_type,max_seq_len", - (("normal", 512),), + (("normal", 2048),), ids=["normal"], ) @pytest.mark.parametrize( @@ -182,9 +196,6 @@ def test_multimodal_demo_text( profiler = BenchmarkProfiler() profiler.start("run") - ckpt_dir = os.environ["LLAMA_DIR"] - tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") - num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1 max_batch_size *= data_parallel # input batch_size is interpreted as size per DP group @@ -195,11 +206,26 @@ def test_multimodal_demo_text( max_batch_size=max_batch_size, max_seq_len=max_seq_len, ) + + HF_MODEL = model_args[0].checkpoint_type == CheckpointType.HuggingFace + + if not HF_MODEL: + ckpt_dir = os.environ["LLAMA_DIR"] + tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") + + tokenizer = Tokenizer(model_path=tokenizer_path) + formatter = ChatFormat(tokenizer) + else: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_args[0].CKPT_DIR) + generator = Generator(model, model_args, mesh_device) - tokenizer = Tokenizer(model_path=tokenizer_path) - formatter = ChatFormat(tokenizer) - xattn_caches = [model.setup_cache(model_args[i].max_batch_size) for i, model in enumerate(generator.model)] + xattn_caches = [ + model.setup_cache(model_args[i].max_batch_size) if not HF_MODEL else None + for i, model in enumerate(generator.model) + ] # Create random images for trace 
capture with specific dimensions trace_img_560x560 = create_random_image(560, 560) @@ -260,10 +286,12 @@ def test_multimodal_demo_text( total_users = len(dialogs) num_batches = total_users // max_batch_size - sampler = get_batch_sampler(temperature, top_p, tokenizer) + sampler = get_batch_sampler(temperature, top_p, model_args[0].tokenizer) _num_prefill_tokens = 0 _num_decode_tokens = 0 + prompt_encoder = hf_multimodal_encode if HF_MODEL else formatter.encode_dialog_prompt + for iter_num in range(warmup_iters + 1): logger.info(f"Iteration {iter_num}") current_dialogs = trace_dialogs + dialogs @@ -273,7 +301,8 @@ def test_multimodal_demo_text( for msg in dialog: print(f"{msg.role.capitalize()}: {msg.content}\n") batch_model_input = [ - formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) for dialog in batch_dialogs + prompt_encoder(dialog, processor) if HF_MODEL else prompt_encoder(dialog, tool_prompt_format=False) + for dialog in batch_dialogs ] # Do initial prefill @@ -288,7 +317,7 @@ def test_multimodal_demo_text( total_lens = prefill_lens + max_gen_len # Create padded tokens tensor for batch - pad_id = tokenizer.pad_id + pad_id = tokenizer.pad_token_id if HF_MODEL else tokenizer.pad_id bsz = len(prompt_tokens) tokens = torch.full((bsz, max(total_lens)), pad_id, dtype=torch.long) @@ -312,6 +341,7 @@ def test_multimodal_demo_text( xattn_caches, total_lens, prefill_lens, + ) # Get cached prefill time @@ -329,6 +359,7 @@ def test_multimodal_demo_text( xattn_caches, total_lens, prefill_lens, + ) prefill_end = time.perf_counter() @@ -375,12 +406,13 @@ def test_multimodal_demo_text( ) # gen_idx is (num_tokens - 1) to avoid counting compile iter # Log full text output for each user in batch + vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256] for user_id in range(max_batch_size): # Remove <|image|> tokens since they break the tokenizer tokens_out = [ - t if t not in vision_tokens else tokenizer.pad_id + t if t not in vision_tokens else pad_id for t in tokens[user_id].tolist()[: position_id[user_id] + 2] ] text = tokenizer.decode(tokens_out) diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/pipeline/test_end2end.py b/models/tt_transformers/tests/multimodal/mistral_24b/pipeline/test_end2end.py new file mode 100644 index 000000000000..fd2feba20036 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/pipeline/test_end2end.py @@ -0,0 +1,527 @@ +"""Test for Mistral-24B End-to-End Vision-Text Pipeline""" + +import torch +import pytest +from loguru import logger +from PIL import Image +import os +import ttnn + +from models.tt_transformers.tt.common import ( + sample_host, + PagedAttentionConfig, + preprocess_inputs_prefill, +) + +from models.tt_transformers.tt.model_config import DecodersPrecision +from models.tt_transformers.tt.multimodal.mistral_24b.model import MistralTransformer as Transformer + +from models.tt_transformers.tt.generator import Generator + +from models.tt_transformers.tt.multimodal.mistral_24b.vision_model import TtMistralVisionTransformer +from models.utility_functions import skip_for_grayskull, skip_for_blackhole + +from models.tt_transformers.tt.model_config import ModelArgs +from transformers import AutoProcessor, AutoModelForVision2Seq + +import re + + +def run_reference_demo_pipeline(messages, model_id="mistralai/Mistral-Small-3.1-24B-Instruct-2503"): + """ + Run Hugging Face reference demo model (Vision-Text pipeline) using given messages. 
+ """ + logger.info("Running reference HF vision-text model...") + + processor = AutoProcessor.from_pretrained(model_id) + model = AutoModelForVision2Seq.from_pretrained( + model_id, + device_map="auto", + torch_dtype=torch.bfloat16, + ) + + model.eval() + + # Apply chat template + prompt_text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, padding=True, padding_side="left" + ) + + # Extract images (already loaded) + image_inputs = [] + for msg in messages: + for item in msg["content"]: + if item["type"] == "image": + image_inputs.append(item["image"]) + + # Tokenize and move to model device + inputs = processor( + text=[prompt_text], + images=image_inputs, + return_tensors="pt", + ).to(model.device, dtype=torch.bfloat16) + + with torch.no_grad(): + generated_ids = model.generate( + **inputs, + max_new_tokens=100, + temperature=0.0, + top_p=0.9, + do_sample=False, + pad_token_id=model.config.pad_token_id, + ) + + # Decode + output = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + logger.info(f"HF reference model output: {output}") + + chat = parse_chat_output(output) + display_chat(logger, chat) + + return output + + +def parse_chat_output(text): + """Parse chat output format from generated text.""" + pattern = r"<\|(?Puser|assistant)\|>\s*(?P.*?)(?=<\|(?:user|assistant|end)\|>|$)" + matches = re.finditer(pattern, text, re.DOTALL) + return [(match.group("role"), match.group("message").strip()) for match in matches] + + +def display_chat(logger, conversation): + """Display chat conversation in formatted output.""" + for role, message in conversation: + if role == "user": + logger.info(f"👤 User: {message}") + elif role == "assistant": + logger.info(f"🤖 Assistant: {message}") + + +def setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations): + """Setup model arguments for vision-enabled model (Single Responsibility).""" + instruct = True if weights == "instruct" else False + + model_args = ModelArgs( + mesh_device=mesh_device, + instruct=instruct, + optimizations=optimizations, + max_seq_len=max_seq_len, + max_batch_size=batch_size, + ) + + return model_args, instruct + + +def setup_vision_prompts_and_tokenizer(model_args, instruct): + """Setup multimodal prompts and tokenizer for vision-enabled model.""" + image_path = "real_inputs/pixtral_transformer_inputs/people.jpg" + image = Image.open(image_path).convert("RGB") + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + # "image": "https://raw.githubusercontent.com/yavuzceliker/sample-images/refs/heads/main/images/image-1.jpg", + {"type": "text", "text": "Tell me what you see in the picture?"}, + ], + } + ] + + tokenizer = model_args.tokenizer + return messages, tokenizer + + +def process_vision_info(messages): + """Extract images (already opened) from messages.""" + image_inputs = [] + video_inputs = None # Not used + + for msg in messages: + content = msg.get("content", []) + for item in content: + if item.get("type") == "image": + image_inputs.append(item["image"]) + + return image_inputs, video_inputs + + +def process_real_vision_inputs(messages, model_args): + """Process real image inputs using AutoProcessor (Interface Segregation).""" + processor = AutoProcessor.from_pretrained(os.getenv("HF_MODEL")) + + text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, padding=True, padding_side="left" + ) + + image_inputs, video_inputs = process_vision_info(messages) + # 
image_inputs, video_inputs = None, None + + encoded = processor( + text=[text], images=image_inputs, videos=video_inputs, return_tensors="pt", return_dict=True + ).to("cpu", dtype=torch.bfloat16) + input_ids = encoded["input_ids"] + pixel_values = encoded["pixel_values"] if "pixel_values" in encoded else None + attention_mask = encoded["attention_mask"] if "attention_mask" in encoded else None + image_sizes = encoded["image_sizes"] if "image_sizes" in encoded else None + + return { + "input_ids": input_ids, + "pixel_values": pixel_values, + "attention_mask": attention_mask, + "image_sizes": image_sizes, + "processor": processor, + } + + +def load_separate_models_like_test_end2end(model_args, mesh_device, dtype, paged_attention, page_params): + """Load separate vision and text models following test_end2end.py pattern.""" + state_dict = model_args.load_state_dict() + + vision_prefix = "vision_tower." + # Setup paged attention config (exactly like test_end2end.py) + paged_attention_config = None + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + + # Load vision model (exactly like test_end2end.py) + vision_model = TtMistralVisionTransformer( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=vision_prefix, + dtype=dtype, + model_args=model_args, + ) + + # Load text model (exactly like test_end2end.py) + text_model = Transformer( + args=model_args, + mesh_device=mesh_device, + dtype=dtype, + state_dict=state_dict, + weight_cache_path=model_args.weight_cache_path(dtype), + paged_attention_config=paged_attention_config, + ) + logger.info("Separate vision and text models loaded like test_end2end.py") + return vision_model, text_model + + +def run_generation_exactly_like_test_end2end( + vision_model, + text_model, + processed_inputs, + model_args, + page_table=None, + paged_attention_config=None, + max_gen_len=20, + repetition_ngram_size=3, +): + """Run generation following the EXACT pattern from test_end2end.py.""" + input_ids = processed_inputs["input_ids"] + + logger.info("Running generation exactly like test_end2end.py...") + + logger.info("Running Vision Model...") + generator = Generator([text_model], [model_args], vision_model.mesh_device, tokenizer=model_args.tokenizer) + tt_kv_cache = [[l.attention.layer_past for l in text_model.layers]] if paged_attention_config else None + + input_tokens_prefill = input_ids + batch_size = input_tokens_prefill.shape[0] + + prompt_text = model_args.tokenizer.decode(input_ids[0].tolist()) + input_prompts = [prompt_text] + + ( + input_tokens_prefill_pt, + encoded_prompts, + decoding_pos, + prefill_lens, + ) = preprocess_inputs_prefill( + input_prompts, + model_args.tokenizer, + [model_args], + instruct=True, + max_generated_tokens=max_gen_len, + max_prefill_len=8192, + ) + + input_tokens_prefill_pt = torch.stack(input_tokens_prefill_pt).view(batch_size, -1) + + logger.info("Running prefill...") + logits = generator.prefill_forward_text( + input_tokens_prefill_pt, + page_table=page_table, + kv_cache=tt_kv_cache, + prompt_lens=decoding_pos, + vision_model=vision_model, + processed_inputs=processed_inputs, + ) + + prefilled_token = torch.argmax(logits, dim=-1) + prefilled_token_decoded_res = model_args.tokenizer.decode(prefilled_token[0].item()) + logger.info(f"prefilled_token_decoded_res: {prefilled_token_decoded_res}") + + logger.info(f"Prefilled token: {prefilled_token}") + + import torch.nn.functional as F + + 
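+    # Log the prompt round-trip and inspect the top-5 next-token probabilities from the prefill logits as a sanity check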
logger.info(f"Encoded prompt: {encoded_prompts[0]}") + logger.info(f"Decoded prompt: {model_args.tokenizer.decode(encoded_prompts[0])}") + + # logits: [1, 1, vocab_size] + last_logits = logits[0, -1] # shape: [vocab_size] + probs = F.softmax(last_logits, dim=-1) + + top_k = 5 + topk_probs, topk_indices = torch.topk(probs, k=top_k) + + topk_tokens = [model_args.tokenizer.decode([idx.item()]) for idx in topk_indices] + + logger.info("🔍 Top-5 predicted tokens (with probabilities):") + for i in range(top_k): + logger.info(f"{i+1}. Token: '{topk_tokens[i]}' (ID={topk_indices[i].item()}), P={topk_probs[i].item():.4f}") + + all_outputs = [encoded_prompts[0][: prefill_lens[0]]] + all_outputs[0].append(int(prefilled_token[0].item())) + + current_pos = torch.tensor([decoding_pos[0]]) + out_tok = prefilled_token + generation_length = max_gen_len + + results = [] + + logger.info("Starting decode loop...") + for iteration in range(generation_length): + logger.info(f"[Text] Decoding token {iteration}, current_pos: {current_pos.item()}") + + logits = generator.decode_forward_text( + out_tok, + current_pos, + enable_trace=False, + page_table=page_table, + kv_cache=tt_kv_cache, + ) + + _, out_tok = sample_host( + logits, + temperature=0, + top_p=0.9, + ) + + token_id = out_tok[0].item() + decoded_token = model_args.tokenizer.decode([token_id]) + logger.info(f"Generated token {iteration}: ID={token_id}, text='{decoded_token}'") + + # Stop if EOS detected + if token_id == model_args.tokenizer.eos_token_id: + logger.info("EOS token detected, stopping generation.") + break + + # Stop if repetition detected (n-gram) + if len(all_outputs[0]) >= repetition_ngram_size * 2: + last_ngram = tuple(all_outputs[0][-repetition_ngram_size:]) + for i in range(len(all_outputs[0]) - repetition_ngram_size): + if tuple(all_outputs[0][i : i + repetition_ngram_size]) == last_ngram: + logger.info(f"Detected {repetition_ngram_size}-gram repetition, stopping.") + break + + # Create result object + result = type("TokenResult", (), {"token": token_id, "text": decoded_token})() + + results.append(result) + + all_outputs[0].append(token_id) + current_pos += 1 + + # Early stopping (exactly like test_end2end.py) + if len(all_outputs[0]) >= 5 and all(t == all_outputs[0][-1] for t in all_outputs[0][-5:]): + logger.warning(f"Detected exact repetition of token {all_outputs[0][-1]} five times in a row. 
Stopping.") + break + + # Final response (exactly like test_end2end.py) + response = model_args.tokenizer.decode(all_outputs[0], skip_special_tokens=True) + logger.info(f"📝 Each iteration Generated Response:\n{response}") + logger.info(f"📝 Each iteration Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") + chat = parse_chat_output(response) + display_chat(logger, chat) + + logger.info(f" Each iteration Generated {len(results)} tokens successfully") + + # Final response (exactly like test_end2end.py) + response = model_args.tokenizer.decode(all_outputs[0], skip_special_tokens=True) + logger.info(f"📝 Final Generated Response:\n{response}") + logger.info(f"📝 Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") + chat = parse_chat_output(response) + display_chat(logger, chat) + + logger.info(f"Generated {len(results)} tokens successfully") + return results + + +def validate_e2e_outputs(results, expected_min_tokens=1): + """Validate end-to-end pipeline outputs.""" + if not results: + logger.error("No results generated from E2E pipeline") + return False + + if len(results) < expected_min_tokens: + logger.warning(f"Generated only {len(results)} tokens, expected at least {expected_min_tokens}") + return False + + # Check if tokens are valid + for result in results: + if not hasattr(result, "token") or not hasattr(result, "text"): + logger.error("Invalid result format") + return False + + logger.info("E2E pipeline validation passed") + return True + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@skip_for_blackhole("Failing on DRAM harvested P100a, see #21419") +@pytest.mark.timeout(1800) +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize( + "weights, layers", + [ + ("instruct", None), + ], + ids=["full"], +) +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False, + ), + ids=( + "paged_attention", + # "default_attention", + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (1024,), # Use smaller seq_len like test_end2end.py to avoid memory issues +) +@pytest.mark.parametrize( + "optimizations", + [ + lambda model_args: DecodersPrecision.accuracy(model_args.n_layers, model_args.model_name), + ], + ids=["accuracy"], +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "N150x4": (1, 4), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +# @pytest.mark.parametrize("device_params", [{"l1_small_size": 1584864, "trace_region_size": 0}], indirect=True) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 10 * 1024}], indirect=True) +def test_e2e_vision_text_pipeline( + weights, + layers, + max_seq_len, + batch_size, + paged_attention, + page_params, + optimizations, + mesh_device, + reset_seeds, + request, + device_params, +): + """Test end-to-end vision-text pipeline using proper Generator methods.""" + logger.info("Starting E2E vision-text pipeline test") + + # Use bfloat8_b like test_end2end.py for better memory efficiency + dtype = ttnn.bfloat8_b + + # Setup vision-enabled model configuration + model_args, instruct = setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations) + + if layers is not None: + model_args.n_layers = layers + + # Setup vision prompts and tokenizer + messages, tokenizer = 
setup_vision_prompts_and_tokenizer(model_args, instruct) + + # logger.info("Running reference HF vision-text model using messages..... ") + # hf_output = run_reference_demo_pipeline(messages) + + # Process real vision inputs from images + processed_inputs = process_real_vision_inputs(messages, model_args) + + # Load separate models following test_end2end.py pattern + logger.info("Loading separate vision and text models like test_end2end.py...") + vision_model, text_model = load_separate_models_like_test_end2end( + model_args, mesh_device, dtype, paged_attention, page_params + ) + + # Setup page table for paged attention (exactly like test_end2end.py) + page_table_tt = None + paged_attention_config = None + + # Prepare page table for paged attention (exactly like test_end2end.py) + page_table = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, -2) if batch_size > 1 else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + # Run generation following EXACT test_end2end.py pattern + logger.info("Running generation following EXACT test_end2end.py pattern...") + results = run_generation_exactly_like_test_end2end( + vision_model, text_model, processed_inputs, model_args, page_table, paged_attention_config, max_gen_len=600 + ) + + # Validate results + validation_passed = validate_e2e_outputs(results, expected_min_tokens=1) + + # Final validation + if validation_passed and len(results) > 0: + logger.info("✅ E2E vision-text pipeline test PASSED!") + logger.info(f"Successfully generated {len(results)} tokens") + + # Log generated tokens for debugging + for i, result in enumerate(results[:5]): + logger.info(f"Token {i}: {result.token} -> '{result.text}'") + else: + logger.error("❌ E2E pipeline test failed") + assert False, f"E2E pipeline failed - generated {len(results)} tokens, validation: {validation_passed}" \ No newline at end of file diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_conv2d.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_conv2d.py new file mode 100644 index 000000000000..cdae8d2ee702 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_conv2d.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.multimodal.mistral_24b.vision_conv2d import TtMistralConv2dPatch +from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +from ttnn import ConcatMeshToTensor + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_conv2d_inference( + mesh_device, + use_program_cache, + reset_seeds, +): + pcc_required = 0.9999 + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "vision_tower.patch_conv." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + num_devices = model_args.num_devices + + ##### Create input tensor for the all gather ##### + B, NCH, H, W = (1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size) + in_channels, out_channels, kernel_size, stride, bias = ( + 3, + model_args.vision_dim, + model_args.vision_patch_size, + model_args.vision_patch_size, + False, + ) + + assert NCH == in_channels, "Number of channels in input tensor should match in_channels for the Conv2d patch." + assert type(kernel_size) == int, "Only symmetric kernel_size is currently supported." + assert kernel_size == stride, "Only same kernel_size and stride are currently supported." + + assert H % kernel_size == 0, "Height should be divisible by kernel_size." + assert W % kernel_size == 0, "Width should be divisible by kernel_size." + + ##### Prepare inputs ##### + input_tensor = torch.randn((B, NCH, H, W)) + logger.info(f"Input tensor shape: {input_tensor.shape}") + + + reference_model = model_args.reference_conv2d_patch() + reference_model.load_state_dict(partial_state_dict) + reference_output = reference_model(input_tensor) + + tt_model = TtMistralConv2dPatch( + mesh_device, + state_dict, + first_layer_prefix, + dtype, + in_channels, + out_channels, + kernel_size, + stride, + bias, + ) + tt_output = tt_model(input_tensor) + + ##### Check the outputs ##### + out = ttnn.from_device(tt_output) + tt_output_torch = ttnn.to_torch(out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=2)) + + # Only select output from one device + tt_output_torch = tt_output_torch[0, ..., :out_channels] + + # 1. Restore batch dim + tt_output_torch = tt_output_torch.unsqueeze(0) + + # 2. Permute to match Conv2D output: (N, C_out, H_out, W_out) + tt_output_torch = tt_output_torch.permute(0, 2, 1).reshape(1, 1024, 64, 64) + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_mmp.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_mmp.py new file mode 100644 index 000000000000..c8b2b9d56221 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_mmp.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.multimodal.mistral_24b.vision_mmp import TTMistral3MultiModalProjector +from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +def test_multi_modal_inference(seq_len, batch_size, reset_seeds, device): + print("device:", device) + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + reference_model = tt_model_args.reference_vision_multi_modal() + # print(reference_model) + first_layer_prefix = "multi_modal_projector." + + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model.load_state_dict(partial_state_dict) + # create input tensor for multi_modal_projector layer + batch_size = 1 + seq_length = 1152 + patches_per_image = 64 + num_patches = patches_per_image * patches_per_image + input = torch.randn((1656, 1024)) # image_features: torch.Size([1656, 1024]) + + image_size = torch.tensor([[504, 644]], dtype=torch.int32) + + reference_output = reference_model(input, image_size) + print("reference_output:", reference_output.shape) + + # DistributedNorm inputs are fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode) + tt_input = ttnn.from_torch( + input, + device=device, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, -1), mesh_shape=tt_model_args.cluster_shape), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + tt_image_size = ttnn.from_torch( + image_size, + device=device, + dtype=ttnn.int32, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(device), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + print("state_dict ", state_dict.keys()) + tt_model = TTMistral3MultiModalProjector( + mesh_device=device, + args=tt_model_args, + state_dict=state_dict, + state_dict_prefix="multi_modal_projector.", + dtype=dtype, + eps=1e-06, # layer_norm_eps + ) + + # print("tt_input:", tt_input.memory_config()) + + tt_output = tt_model(tt_input, tt_image_size) + + output_torch = ttnn.to_torch(tt_output) + + print("output_torch:", output_torch.shape) + # # transpose output from NHWC to NCHW + # output_torch = output_torch.permute(0, 2, 1) + passing, pcc_message = comp_pcc(reference_output, output_torch) + pcc_required = 0.999 + logger.info(comp_allclose(reference_output, output_torch)) + 
logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py new file mode 100644 index 000000000000..cead380210b1 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py @@ -0,0 +1,95 @@ +import os + +import pytest +import torch +from loguru import logger + +import ttnn + +# models/tt_transformers/tt/common.py +from models.tt_transformers.tt.multimodal.mistral_24b.vision_rope import VisionRotarySetup as RotarySetup +from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_rot_emb(seq_len, batch_size, use_program_cache, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + partial_state_dict = {} + + reference_model = tt_model_args.reference_vision_rot_emb() + reference_model.load_state_dict(partial_state_dict) + + image_size = tt_model_args.vision_image_size + patch_size = tt_model_args.vision_patch_size + dim = tt_model_args.vision_head_dim + num_patches_per_dim = image_size // patch_size + num_patches = num_patches_per_dim * num_patches_per_dim + + + + position_ids = torch.arange(4096, dtype=torch.long) + + x = torch.randn(batch_size, 4096, 1024) + + cos, sin = reference_model(x, position_ids) + tt_model = RotarySetup( + device, + batch_size, + dim, + image_size, + patch_size, + num_patches, + tt_model_args.vision_rope_theta, + scale_factor=None, + orig_context_len=num_patches, + datatype=dtype, + ) + + cos2, sin2 = tt_model.get_rot_mats(position_ids) + cos2 = ttnn.from_device(cos2) + cos2 = ttnn.to_torch(cos2) + cos2 = cos2.squeeze(0) + + sin2 = ttnn.from_device(sin2) + sin2 = ttnn.to_torch(sin2) + sin2 = sin2.squeeze(0) + + passing, pcc_message = comp_pcc(cos, cos2) + + logger.info(comp_allclose(cos, cos2)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"COS PCC value is lower than {0.99} for some of the outputs. Check Warnings!" + + passing, pcc_message = comp_pcc(sin, sin2) + + logger.info(comp_allclose(sin, sin2)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"SIN PCC value is lower than {0.99} for some of the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_image_block.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_image_block.py new file mode 100644 index 000000000000..89843cccd027 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_image_block.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.multimodal.mistral_24b.vision_pixtral_image_block import TtPixtralImageTransformerBlock +from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 1),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_pixtral_image_block(batch, num_chunks, mesh_device, reset_seeds): + dtype = ttnn.bfloat16 + pcc_required = 0.99 + + model_args = ModelArgs(mesh_device) + model_args.n_layers = 1 + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower.transformer.layers.0." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + print("partial_state_dict keys:", partial_state_dict.keys()) + + dim = model_args.vision_dim + heads = model_args.vision_attn_n_heads + seq_len = model_args.vision_chunk_ntok - 1 + head_dim = dim // heads + + reference_model = model_args.reference_pixtral_image_block() + reference_model.load_state_dict(partial_state_dict) + + tt_model = TtPixtralImageTransformerBlock( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=model_args, + ) + + pt_attention_input = torch.randn(batch, seq_len, dim).to(torch.bfloat16) + attention_mask = torch.zeros(batch, 1, seq_len, seq_len).to(torch.bfloat16) + + B, T, D = pt_attention_input.shape + cos = torch.ones((1, T, head_dim)).to(torch.bfloat16) + sin = torch.zeros((1, T, head_dim)).to(torch.bfloat16) + + positional_embedding = (cos, sin) + attention_input = model_args.prepare_residual_tensor_prefill( + pt_attention_input, + force_replicated=True, + ) + tt_mask = ttnn.from_torch( + attention_mask, + device=mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + tt_out = tt_model(attention_input, mask=tt_mask) + reference_output = reference_model( + pt_attention_input, attention_mask=attention_mask, position_embeddings=positional_embedding + )[0] + + print("tt_out shape:", tt_out.shape) + print("reference_output shape:", reference_output.shape) + + tt_output_torch = ttnn.to_torch(tt_out).squeeze(0) + print("tt_output_torch shape:", tt_output_torch.shape) + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_transformer.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_transformer.py new file mode 100644 index 000000000000..9e1cc7a4ac45 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_transformer.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.multimodal.mistral_24b.vision_pixtral_transformer import TtPixtralTransformer +from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 1),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_image_transformer_inference(batch, num_chunks, mesh_device): + pcc_required = 0.99 + + model_args = ModelArgs(mesh_device) + dtype = ttnn.bfloat16 + + state_dict = model_args.load_state_dict() + n_layers = model_args.vision_n_layers + first_layer_prefix = "vision_tower.transformer." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + dim = model_args.vision_dim + heads = model_args.vision_attn_n_heads + seq_len = model_args.vision_chunk_ntok - 1 + head_dim = dim // heads + + reference_model = model_args.reference_vision_encoder() + reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + all_tests_pass = True + + tt_model = TtPixtralTransformer( + mesh_device, + state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=None, + dtype=dtype, + configuration=model_args, + layers=n_layers, + ) + + # Create PT input + pt_attention_input = torch.rand(batch, seq_len, dim).to(torch.bfloat16) + attention_mask = torch.zeros(batch, 1, seq_len, seq_len).to(torch.bfloat16) + + B, T, D = pt_attention_input.shape + cos = torch.ones((1, T, head_dim)).to(torch.bfloat16) + sin = torch.zeros((1, T, head_dim)).to(torch.bfloat16) + + position_embeddings = (cos, sin) + + + cos, sin = position_embeddings + + cos_t = ttnn.from_torch( + cos, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + sin_t = ttnn.from_torch( + sin, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + attention_input = model_args.prepare_residual_tensor_prefill( + pt_attention_input, + force_replicated=True, + ) + tt_mask = ttnn.from_torch( + attention_mask, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + with torch.no_grad(): + tt_out = tt_model(attention_input, mask=tt_mask, position_embeddings=(cos_t, sin_t)) + reference_output = reference_model( + pt_attention_input, attention_mask=attention_mask, position_embeddings=(cos, sin) + )[0] + tt_output_torch = ttnn.to_torch(tt_out) + tt_output_torch = tt_output_torch.squeeze(0) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + if not passing: + logger.warning(f"PCC value -- {pcc_message} -- is lower than {pcc_required} for the output.") + else: + logger.info(f"PCC: {pcc_message}") + logger.info(comp_allclose(reference_output, tt_output_torch)) + all_tests_pass = all_tests_pass and passing + + assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of 
the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_attention.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_attention.py new file mode 100644 index 000000000000..61cf9686dcef --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_attention.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.multimodal.mistral_24b.vision_attention import TtMistralImageAttention as TtLlamaImageAttention +from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_vision_attention(mesh_device, seq_len, batch_size): + + dtype = ttnn.bfloat8_b + + model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128) + model_args.n_layers = 1 + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower.transformer.layers.0.attention." + + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model = model_args.reference_vision_attention() + reference_model.load_state_dict(partial_state_dict) + + hidden_size = model_args.vision_dim + n_heads = model_args.vision_attn_n_heads + head_dim = hidden_size // n_heads + + tt_model = TtLlamaImageAttention( + mesh_device, + state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=model_args, + ) + + dim = model_args.vision_dim + pt_attention_input = torch.randn(batch_size, seq_len, dim).to(torch.bfloat16) + attention_mask = torch.zeros(batch_size, 1, seq_len, seq_len).to(torch.bfloat16) + + B, T, D = pt_attention_input.shape + cos = torch.ones((1, T, head_dim)).to(torch.bfloat16) + sin = torch.zeros((1, T, head_dim)).to(torch.bfloat16) + + + + attention_input = model_args.prepare_residual_tensor_prefill( + pt_attention_input, + force_replicated=True, + ) + + + cos_t = ttnn.from_torch( + cos, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + sin_t = ttnn.from_torch( + sin, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + tt_mask = ttnn.from_torch( + attention_mask, + device=mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + tt_out = tt_model(attention_input, position_embeddings=(cos_t, sin_t), mask=tt_mask) + tt_output_torch = ttnn.to_torch(tt_out, device=mesh_device)[0, :, :, :] + + reference_output = reference_model(pt_attention_input, attention_mask, position_embeddings=(cos, sin))[0] + pcc_required = 0.99 + + passing, pcc_message = 
comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_mlp.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_mlp.py new file mode 100644 index 000000000000..ab833311e982 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_mlp.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn + +# from models.tt_transformers.tt.mlp import MLP +from models.tt_transformers.tt.multimodal.mistral_24b.vision_mlp import MistralTTVisionMLP as MLP +from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (64 * 1024,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_mlp_inference(seq_len, batch_size, mesh_device, reset_seeds): + dtype = ttnn.bfloat8_b + mode = "decode" if seq_len <= 32 else "prefill" + + model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128) + model_args.n_layers = 1 + state_dict = model_args.load_state_dict() + + print("state_dict keys:", state_dict.keys()) + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "vision_tower.transformer.layers.0.feed_forward." 
+ + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model = model_args.reference_vision_mlp() + print(partial_state_dict.keys()) + reference_model.load_state_dict(partial_state_dict) + + tt_model = MLP( + mesh_device=mesh_device, + args=model_args, + state_dict=state_dict, + weight_cache_path=model_args.weight_cache_path(dtype), + state_dict_prefix="vision_tower.transformer.layers.0.feed_forward.", + dtype=dtype, + + ) + torch_input = torch.randn(1, 1, seq_len, 1024).to(torch.bfloat16) + print("torch_input shape:", torch_input.shape) + reference_output = reference_model(torch_input) + tt_input = ttnn.from_torch( + torch_input, + device=mesh_device, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, dims=(None, 3) if model_args.is_galaxy else (None, None), mesh_shape=model_args.cluster_shape + ), # When both dims are None, the mapper used is `ReplicateTensorToMesh` + dtype=ttnn.bfloat16, + memory_config=( + ( + tt_model.model_config["MLP_ACT_MEMCFG"] + if model_args.is_galaxy + else model_args.model_config["SHARDED_MLP_INPUT_MEMCFG"] + ) + if mode == "decode" + else ttnn.DRAM_MEMORY_CONFIG + ), + layout=ttnn.TILE_LAYOUT, + ) + + logger.info("Run MLP") + tt_output = tt_model(tt_input) + + tt_output_torch = ttnn.to_torch( + tt_output, + mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=model_args.cluster_shape), + ) + + tt_output_torch = tt_output_torch[:, :1, :, :] + + pcc_required = 0.99 + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + if passing: + logger.info("MLP Passed!") + else: + logger.warning("MLP Failed!") + + assert passing, f"MLP output does not meet PCC requirement {pcc_required}: {pcc_message}." diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py new file mode 100644 index 000000000000..1a7e934ed912 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.multimodal.mistral_24b.vision_model import TtMistralVisionTransformer +from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +def get_image_features(vision_tower, projector, input_tensor, image_sizes): + """ + Get image features from the vision tower and projector. + """ + vision_token = vision_tower(input_tensor, image_sizes).last_hidden_state + image_features = projector(vision_token.squeeze(0), image_sizes) + return image_features + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_mistral_vision_model(mesh_device, reset_seeds): + pcc_required = 0.97 + dtype = ttnn.bfloat8_b + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower." 
+ partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix) + } + + print("partial_state_dict keys:", partial_state_dict.keys()) + + ##### Reference model output (Torch) ##### + reference_model = model_args.reference_vision_model() + reference_model.load_state_dict(partial_state_dict) + + mmp_first_layer_prefix = "multi_modal_projector." + + mmp_partial_state_dict = { + k[len(mmp_first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(mmp_first_layer_prefix)) + } + + print("mmp_partial_state_dict keys:", mmp_partial_state_dict.keys()) + + reference_mmp = model_args.reference_vision_multi_modal() + reference_mmp.load_state_dict(mmp_partial_state_dict) + + B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size + input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16) + + # reference_output = reference_model(input_tensor, image_sizes=[(H, W)]) + reference_output = get_image_features(reference_model, reference_mmp, input_tensor, image_sizes=[(H, W)]) + + # ##### TT Model: TtMistralVisionTransformer ##### + vision_model = TtMistralVisionTransformer( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + dtype=dtype, + model_args=model_args, + ) + + tt_output = vision_model(input_tensor, image_sizes=[(H, W)]) # [0] + tt_output = ttnn.from_device(tt_output) + tt_output = ttnn.to_torch(tt_output) + + non_zero_indices = tt_output.ne(0).nonzero(as_tuple=True) + tt_output = tt_output[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC below {pcc_required}. {pcc_message}" diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py new file mode 100644 index 000000000000..84f983435b38 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py @@ -0,0 +1,108 @@ +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm +from models.tt_transformers.tt.distributed_norm import DistributedNorm +from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_rmsnorm_inference(seq_len, batch_size, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + reference_model = tt_model_args.reference_vision_rms() + + first_layer_prefix = "vision_tower.transformer.layers.0.ffn_norm." 
+ + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + + reference_model.load_state_dict(partial_state_dict) + + tt_inner_norm = RMSNorm( + device=device, + dim=1024, + state_dict=state_dict, + state_dict_prefix="vision_tower.transformer.layers.0.", + weight_key="ffn_norm", + weight_dtype=dtype, + is_distributed=tt_model_args.is_distributed_norm, + sharded_program_config=tt_model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=tt_model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], + ) + + # Wrap it in DistributedNorm + tt_model = DistributedNorm(tt_inner_norm, tt_model_args, TG=tt_model_args.is_galaxy) + + input = torch.rand(1, 1, 1024) + + reference_output = reference_model(input) + + # DistributedNorm inputs are fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode) + tt_input = ttnn.from_torch( + input, + device=device, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, -1), mesh_shape=tt_model_args.cluster_shape), + memory_config=( + tt_model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG + ), + ) + + tt_output = tt_model(tt_input, mode=mode) + + # DistributedNorm outputs are replicated across devices + tt_output_torch = ttnn.to_torch( + tt_output, + mesh_composer=ttnn.ConcatMesh2dToTensor( + device, dims=(0, 2) if tt_model_args.is_galaxy else (2, 0), mesh_shape=tt_model_args.cluster_shape + ), + )[:1, :, :] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("rms_norm Passed!") + else: + logger.warning("rms_norm Failed!") + + assert passing, f"rms_norm output does not meet PCC requirement {0.99}." diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py new file mode 100644 index 000000000000..4a2fcea4d984 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.multimodal.mistral_24b.mistral_vision_tower import MistralVisionTower +from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_mistral_vision_tower(mesh_device, reset_seeds): + pcc_required = 0.98 + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower." 
+    partial_state_dict = {
+        k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix)
+    }
+
+    B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size
+    input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16)
+
+    ##### Reference model output (Torch) #####
+    reference_model = model_args.reference_vision_model()
+    reference_model.load_state_dict(partial_state_dict)
+    reference_output = reference_model(input_tensor, image_sizes=[(H, W)])
+
+    reference_output = reference_output.last_hidden_state
+    ##### TT Model: MistralVisionTower #####
+    vision_model = MistralVisionTower(
+        mesh_device=mesh_device,
+        state_dict=state_dict,
+        state_dict_prefix=first_layer_prefix,
+        dtype=dtype,
+        configuration=model_args,
+    )
+
+    tt_output = vision_model(input_tensor, image_sizes=[(H, W)])
+    tt_output = ttnn.from_device(tt_output)
+    tt_output = ttnn.to_torch(tt_output).squeeze(0)
+    passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required)
+
+    logger.info(comp_allclose(reference_output, tt_output))
+    logger.info(f"PCC: {pcc_message}")
+    assert passing, f"PCC below {pcc_required}. {pcc_message}"
diff --git a/models/tt_transformers/tt/common.py b/models/tt_transformers/tt/common.py
index 9829a65d1b3f..b3f91b2b8582 100644
--- a/models/tt_transformers/tt/common.py
+++ b/models/tt_transformers/tt/common.py
@@ -9,10 +9,12 @@
 import torch
 from loguru import logger
-from pydantic import AliasChoices, BaseModel, Field
-
+from pydantic import BaseModel, Field, AliasChoices
+import os
 import ttnn
+from ttnn import ConcatMeshToTensor
+model_name = os.getenv("HF_MODEL")
 
 class HostEmbedding(torch.nn.Module):
     def __init__(self, model_args):
@@ -22,16 +24,12 @@ def __init__(self, model_args):
     def forward(self, x):
         return self.emb(x)
-
 class HostScaledEmbedding(HostEmbedding):
     def __init__(self, model_args):
         super().__init__(model_args)
         self.embed_scale = model_args.embed_scale
-
     def forward(self, x):
         return self.emb(x) * self.embed_scale
-
-
 # Default configuration for Paged Attention
 class PagedAttentionConfig:
     def __init__(self, block_size=32, max_num_blocks=1024):
@@ -51,13 +49,16 @@ class RopeScalingType(str, Enum):
 class RopeScaling(BaseModel):
     """RoPE scaling configuration."""
-
-    rope_type: RopeScalingType = Field(
+    if model_name == "mistralai/Mistral-Small-3.1-24B-Instruct-2503":
+        rope_type: RopeScalingType = Field(exclude=True, description="RoPE scaling type")
+        factor: Optional[float]
+        original_max_position_embeddings: int
+    else:
+        rope_type: RopeScalingType = Field(
         validation_alias=AliasChoices("rope_type", "type"), exclude=True, description="RoPE scaling type"
     )
-    factor: float
-    original_max_position_embeddings: Optional[int] = None
-
+        factor: float
+        original_max_position_embeddings: Optional[int] = None
 
 class RopeScalingLinear(RopeScaling):
     """RoPE scaling configuration for linear."""
@@ -87,6 +88,8 @@ def rope_scaling_model_factory(rope_scaling_params: dict) -> RopeScaling:
         return RopeScalingLinear(**rope_scaling_params)
     elif rope_scaling_type == RopeScalingType.LLAMA3:
         return RopeScalingLlama3(**rope_scaling_params)
+    elif rope_scaling_type == RopeScalingType.LINEAR:
+        return RopeScalingLinear(**rope_scaling_params)
     elif rope_scaling_type == RopeScalingType.YARN:
         return RopeScalingYarn(**rope_scaling_params)
     elif rope_scaling_type in ["default", "mrope"]:
@@ -96,7 +99,50 @@
         return None
     else:
         raise ValueError(f"Unexpected RoPE scaling type: 
{rope_scaling_type}") - +# below function is mistral 24B model specific function +def generate_block_attention_mask_tt(patch_embeds_list, tensor, tt_device): + tensor = ttnn.to_torch(tensor, mesh_composer=ConcatMeshToTensor(tt_device, dim=0)) + device = tensor.device + dtype = tensor.dtype + seq_len = tensor.shape[-2] + d_min = torch.finfo(dtype).min + causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device) + + block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1) + block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1) + for start, end in zip(block_start_idx, block_end_idx): + causal_mask[start:end, start:end] = 0 + + causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1) + + causal_mask_tt = ttnn.from_torch( + causal_mask, + device=tt_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + return causal_mask_tt + +# below function is mistral 24B model specific function +def position_ids_in_meshgrid_tt(tt_patch_embeds_list, max_width, device): + position_ids_tt = [] + for tt_patch in tt_patch_embeds_list: + shape = tt_patch.shape + height, width = shape[-2], shape[-1] + mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_width + v_grid + + tt_ids = ttnn.from_torch( + ids, + device=device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + position_ids_tt.append(tt_ids[:, 0]) + return ttnn.concat(position_ids_tt, dim=0) def encode_prompt_instruct(tokenizer, prompt_text, system_prompt_text=None): """<|begin_of_text|><|start_header_id|>system<|end_header_id|> @@ -254,6 +300,43 @@ def apply_scaling(freqs: torch.Tensor, scale_factor: float, orig_context_len: in new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) +# below function is mistral 24B model specific function +def apply_scaling_vision(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): + return freqs / scale_factor + +# below function is mistral 24B model specific function +def precompute_vision_freqs( + dim: int, max_patches_per_side: int, theta: float, scale_factor=None, orig_context_len=None +): + # Compute base frequencies + base_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + if scale_factor is not None: + base_freqs = apply_scaling_vision(base_freqs, scale_factor, orig_context_len) + + # Get height and width indices + h_idx = torch.arange(max_patches_per_side) + w_idx = torch.arange(max_patches_per_side) + + # Compute 2D frequency matrices + freqs_h = torch.outer(h_idx, base_freqs[::2]) + freqs_w = torch.outer(w_idx, base_freqs[1::2]) + + # Broadcast + merge + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape( + -1, dim // 2 + ) # Shape: [H*W, dim//2] + + full_freqs = torch.cat([inv_freq, inv_freq], dim=-1) + cos = full_freqs.cos() + sin = full_freqs.sin() + return cos, sin # Shape: [H*W, dim] + def precompute_freqs(dim: int, end: int, theta, scale_factor, orig_context_len): """ diff --git a/models/tt_transformers/tt/distributed_norm.py b/models/tt_transformers/tt/distributed_norm.py index 8adaed8d4b9c..7fcec16e9240 100644 --- a/models/tt_transformers/tt/distributed_norm.py +++ 
b/models/tt_transformers/tt/distributed_norm.py
@@ -2,13 +2,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import os
+
 import ttnn
 from models.common.lightweightmodule import LightweightModule
 from models.tt_transformers.tt.ccl import tt_distributed_rmsnorm, tt_sharded_distributed_rmsnorm
 
+model_name = os.getenv("HF_MODEL")
+
 class DistributedNorm(LightweightModule):
-    def __init__(self, norm, args, tt_ccl, TG=False):
+    def __init__(self, norm, args, tt_ccl=None, TG=False):
         self.norm = norm
         self.args = args
         self.tt_ccl = tt_ccl
@@ -69,11 +72,20 @@ def forward(self, x, mode):
             compute_kernel_config=self.ln_cfg,
         )
 
-        input_mem_cfg = self.norm.sharded_output_config if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG
+        model_name = os.getenv("HF_MODEL")
+        if model_name == "mistralai/Mistral-Small-3.1-24B-Instruct-2503":
+            input_mem_cfg = (
+                self.norm.sharded_output_config
+                if (mode == "decode" and self.norm.sharded_output_config is not None)
+                else ttnn.DRAM_MEMORY_CONFIG
+            )
+        else:
+            input_mem_cfg = self.norm.sharded_output_config if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG
 
         # Distributed norm already performs a gather
         if self.args.is_multichip and not self.args.is_distributed_norm(mode):
-            x = ttnn.experimental.all_gather_async(
+            if model_name == "mistralai/Mistral-Small-3.1-24B-Instruct-2503":
+                # mistral 24B specific operation
+                x = ttnn.all_gather(x, dim=3, num_links=1, topology=self.args.ccl_topology(), memory_config=input_mem_cfg)
+            else:
+                x = ttnn.experimental.all_gather_async(
                 x,
                 persistent_output_buffer=None,
                 dim=3,
@@ -93,7 +105,10 @@ def forward(self, x, mode):
 
         # Distributed norm requires a gather
         if self.args.is_distributed_norm(mode):
-            x = ttnn.experimental.all_gather_async(
+            if model_name == "mistralai/Mistral-Small-3.1-24B-Instruct-2503":
+                x = ttnn.all_gather(x, dim=3, num_links=1, topology=self.args.ccl_topology())
+            else:
+                x = ttnn.experimental.all_gather_async(
                 x,
                 persistent_output_buffer=None,
                 dim=3,
diff --git a/models/tt_transformers/tt/generator.py b/models/tt_transformers/tt/generator.py
index 603d104fb3b3..0de0b3d342fb 100644
--- a/models/tt_transformers/tt/generator.py
+++ b/models/tt_transformers/tt/generator.py
@@ -3,6 +3,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from dataclasses import dataclass
+import os
 
 import torch
 from llama_models.llama3.api.datatypes import InterleavedTextMedia, StopReason
@@ -23,6 +25,7 @@
     num_blocks_in_seq,
 )
 
+model_name = os.getenv("HF_MODEL")
 
 @dataclass(frozen=True)
 class SamplingParams:
@@ -58,7 +61,7 @@ def __init__(self, model, model_args, mesh_device, tokenizer=None, formatter=Non
 
     # Note: This function is called by vLLM
     def prefill_forward_text(
-        self, tokens: torch.Tensor, page_table=None, kv_cache=None, prompt_lens=None, empty_slots=None
+        self, tokens: torch.Tensor, page_table=None, kv_cache=None, prompt_lens=None, empty_slots=None, **kwargs
    ):
         if page_table is not None:
             assert isinstance(page_table, torch.Tensor), "page_table mush be torch.Tensor"
@@ -97,11 +100,12 @@ def prefill_forward_text(
 
             logits = self.prefill_forward_single_user_text(
                 prefill_ids,
-                page_table=page_table_user,
+                page_table=page_table_user if page_table is not None else None,
                 user_id=group_user_id,
                 last_token_idx=last_token_idx,
                 kv_cache=model_kv_cache,
                 model_id=model_id,
+                **kwargs,
             )
             out_list.append(logits)
@@ -117,7 +121,9 @@ def prefill_forward_text(
         logger.info(f"Finished prefill for all users up to {batch_seq_len} tokens, Starting decode...")
         return output_logits
 
-    def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_token_idx, kv_cache=None, model_id=-1):
+    def 
prefill_forward_single_user_text( + self, tokens, page_table, user_id, last_token_idx, kv_cache=None, model_id=-1, **kwargs + ): seq_len = tokens.shape[-1] use_chunked_prefill = seq_len > self.model_args[model_id].max_prefill_chunk_size if use_chunked_prefill: @@ -155,57 +161,99 @@ def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_tok ), f"Chunk end should be less than seq_len, got chunk_end={chunk_end} and seq_len={seq_len}" chunk_tokens = tokens[:, chunk_start:chunk_end] chunk_page_table = page_table_user[:, chunk_start // block_size : chunk_end // block_size] + if model_name == "mistralai/Mistral-Small-3.1-24B-Instruct-2503": + ( + chunk_prefill_input, + chunk_rot_mats_prefill, + page_table_tt, + chunk_page_table_tt, + ) = self.model[model_id].prepare_inputs_prefill( + chunk_tokens, + start_pos=chunk_start, + page_table=page_table_user_padded, + chunk_page_table=chunk_page_table, + **kwargs, + ) + tt_logits = self.model[model_id].ttnn_prefill_forward( + chunk_prefill_input, + rot_mats=chunk_rot_mats_prefill, + user_id=CHUNK_USER_ID, + page_table=page_table_tt, + chunk_page_table=chunk_page_table_tt, + chunk_start_idx=chunk_start, + get_last_token=(last_token_idx_in_chunk // 32) * 32, + kv_cache=kv_cache, + **kwargs, + ) + else: + ( + chunk_prefill_input, + chunk_rot_mats_global_prefill, + chunk_rot_mats_local_prefill, + page_table_tt, + chunk_page_table_tt, + ) = self.model[model_id].prepare_inputs_prefill( + chunk_tokens, + start_pos=chunk_start, + page_table=page_table_user_padded, + chunk_page_table=chunk_page_table, + ) + tt_logits = self.model[model_id].ttnn_prefill_forward( + chunk_prefill_input, + rot_mats_global=chunk_rot_mats_global_prefill, + rot_mats_local=chunk_rot_mats_local_prefill, + user_id=CHUNK_USER_ID, + page_table=page_table_tt, + chunk_page_table=chunk_page_table_tt, + chunk_start_idx=chunk_start, + get_last_token=(last_token_idx_in_chunk // 32) * 32, + kv_cache=kv_cache, + ) + + if chunk_start == last_chunk_start: + return tt_logits + else: + del tt_logits + else: + if model_name == "mistralai/Mistral-Small-3.1-24B-Instruct-2503": + prefill_input, rot_mats_prefill, page_table_tt, _ = self.model[model_id].prepare_inputs_prefill( + tokens, + page_table=page_table, + **kwargs, + ) + + tt_logits = self.model[model_id].ttnn_prefill_forward( + prefill_input, + rot_mats=rot_mats_prefill, + user_id=user_id, + page_table=page_table_tt, + get_last_token=(last_token_idx // 32) * 32, + kv_cache=kv_cache, + ) + return tt_logits + else: ( - chunk_prefill_input, - chunk_rot_mats_global_prefill, - chunk_rot_mats_local_prefill, + prefill_input, + rot_mats_global_prefill, + rot_mats_local_prefill, page_table_tt, - chunk_page_table_tt, + _, ) = self.model[model_id].prepare_inputs_prefill( - chunk_tokens, - start_pos=chunk_start, - page_table=page_table_user_padded, - chunk_page_table=chunk_page_table, + tokens, + page_table=page_table, ) + tt_logits = self.model[model_id].ttnn_prefill_forward( - chunk_prefill_input, - rot_mats_global=chunk_rot_mats_global_prefill, - rot_mats_local=chunk_rot_mats_local_prefill, - user_id=CHUNK_USER_ID, + prefill_input, + rot_mats_global=rot_mats_global_prefill, + rot_mats_local=rot_mats_local_prefill, + user_id=user_id, page_table=page_table_tt, - chunk_page_table=chunk_page_table_tt, - chunk_start_idx=chunk_start, - get_last_token=(last_token_idx_in_chunk // 32) * 32, + get_last_token=(last_token_idx // 32) * 32, kv_cache=kv_cache, ) - - if chunk_start == last_chunk_start: - return tt_logits - else: - del tt_logits - else: 
- ( - prefill_input, - rot_mats_global_prefill, - rot_mats_local_prefill, - page_table_tt, - _, - ) = self.model[model_id].prepare_inputs_prefill( - tokens, - page_table=page_table, - ) - - tt_logits = self.model[model_id].ttnn_prefill_forward( - prefill_input, - rot_mats_global=rot_mats_global_prefill, - rot_mats_local=rot_mats_local_prefill, - user_id=user_id, - page_table=page_table_tt, - get_last_token=(last_token_idx // 32) * 32, - kv_cache=kv_cache, - ) - return tt_logits + return tt_logits # Note: This function is called by vLLM def decode_forward_text( @@ -265,34 +313,59 @@ def _decode_forward_no_trace_text( tt_rot_mat_idxs_global = [] tt_rot_mat_idxs_local = [] tt_page_table = [] + tt_rot_mat_idxs_global = [] + tt_rot_mats=[] + tt_rot_mat_idxs_local = [] for i in range(self.data_parallel): - user_page_table = page_table[i] if page_table is not None else None - model_i = self.model[i] - ( - tt_tokens_i, - tt_current_pos_i, - tt_rot_mat_idxs_global_i, - tt_rot_mat_idxs_local_i, - tt_page_table_i, - ) = model_i.prepare_inputs_decode(tokens[i], current_pos[i], user_page_table) - tt_tokens.append(tt_tokens_i) - tt_current_pos.append(tt_current_pos_i) - tt_rot_mat_idxs_global.append(tt_rot_mat_idxs_global_i) - tt_rot_mat_idxs_local.append(tt_rot_mat_idxs_local_i) - tt_page_table.append(tt_page_table_i) + if model_name == "mistralai/Mistral-Small-3.1-24B-Instruct-2503": + user_page_table = page_table[i] if page_table is not None else None + tt_tokens_i, tt_current_pos_i, tt_rot_mats_i, tt_page_table_i = self.model[i].prepare_inputs_decode( + tokens[i], current_pos[i], user_page_table + ) + tt_tokens.append(tt_tokens_i) + tt_current_pos.append(tt_current_pos_i) + tt_rot_mats.append(tt_rot_mats_i) + tt_page_table.append(tt_page_table_i) + else: + user_page_table = page_table[i] if page_table is not None else None + model_i = self.model[i] + ( + tt_tokens_i, + tt_current_pos_i, + tt_rot_mat_idxs_global_i, + tt_rot_mat_idxs_local_i, + tt_page_table_i, + ) = model_i.prepare_inputs_decode(tokens[i], current_pos[i], user_page_table) + tt_tokens.append(tt_tokens_i) + tt_current_pos.append(tt_current_pos_i) + tt_rot_mat_idxs_global.append(tt_rot_mat_idxs_global_i) + tt_rot_mat_idxs_local.append(tt_rot_mat_idxs_local_i) + tt_page_table.append(tt_page_table_i) + for i in range(self.data_parallel): user_kv_cache = kv_cache[i] if kv_cache is not None else None - tt_logits_i = self.model[i].ttnn_decode_forward( - tt_tokens[i], - tt_current_pos[i], - rot_mat_idxs_global=tt_rot_mat_idxs_global[i], - rot_mat_idxs_local=tt_rot_mat_idxs_local[i], - page_table=tt_page_table[i], - kv_cache=user_kv_cache, - argmax_on_device=argmax_on_device, - ) + if model_name == "mistralai/Mistral-Small-3.1-24B-Instruct-2503": + tt_logits_i = self.model[i].ttnn_decode_forward( + tt_tokens[i], + tt_current_pos[i], + rot_mats=tt_rot_mats[i], + page_table=tt_page_table[i], + kv_cache=user_kv_cache, + argmax_on_device=argmax_on_device, + ) + else: + tt_logits_i = self.model[i].ttnn_decode_forward( + tt_tokens[i], + tt_current_pos[i], + rot_mat_idxs_global=tt_rot_mat_idxs_global[i], + rot_mat_idxs_local=tt_rot_mat_idxs_local[i], + page_table=tt_page_table[i], + kv_cache=user_kv_cache, + argmax_on_device=argmax_on_device, + ) + tt_logits.append(tt_logits_i) return tt_logits @@ -332,14 +405,54 @@ def _capture_trace_text( trace_id = ttnn.begin_trace_capture(self.model_args[i].mesh_device, cq_id=0) trace_ids[i] = trace_id user_kv_cache = kv_cache[i] if kv_cache is not None else None - tt_out_trace.append( - 
self.model[i].ttnn_decode_forward( - *device_inputs[i], kv_cache=user_kv_cache, argmax_on_device=argmax_on_device + if model_name=="mistralai/Mistral-Small-3.1-24B-Instruct-2503": + transformed_inputs = self.model[i].transform_decode_inputs_device(*(device_inputs[i])) + tt_out_trace.append( + self.model[i].ttnn_decode_forward( + *transformed_inputs, kv_cache=user_kv_cache, argmax_on_device=argmax_on_device + ) + ) + else: + tt_out_trace.append( + self.model[i].ttnn_decode_forward( + *device_inputs[i], kv_cache=user_kv_cache, argmax_on_device=argmax_on_device + ) ) - ) ttnn.end_trace_capture(self.model_args[i].mesh_device, trace_id, cq_id=0) logger.info("Done Capturing Decode Trace") return trace_ids, tt_out_trace, *device_inputs +# Note: This function is specific to the Mistral model + def _decode_forward_trace_text( + self, + trace_ids, + device_inputs, + tt_out_trace, + tokens, + current_pos, + page_table=None, + ): + """ + Executes the trace for the decode_forward method but does not read back outputs. + """ + host_inputs = [] + for i in range(self.data_parallel): + user_page_table = page_table[i] if page_table is not None else None + host_inputs_i = self.model[i].prepare_decode_inputs_host(tokens[i], current_pos[i], user_page_table) + host_inputs.append(host_inputs_i) + + to_device = [] + for i in range(self.data_parallel): + to_device.append( + copy_host_to_device( + host_tensors=host_inputs[i], + device_tensors=device_inputs[i], + ) + ) + device_inputs = to_device + for i, trace_id in trace_ids.items(): + ttnn.execute_trace(self.model_args[i].mesh_device, trace_id, cq_id=0, blocking=False) + + return tt_out_trace def _easy_trace_text( self, @@ -359,28 +472,38 @@ def _easy_trace_text( self.trace_ids_text = trace_ids self.trace_inputs_text = device_inputs self.trace_output_text = tt_out_trace + if model_name=="mistralai/Mistral-Small-3.1-24B-Instruct-2503": + trace_logits_rm = self._decode_forward_trace_text( + self.trace_ids_text, + self.trace_inputs_text, + self.trace_output_text, + tokens, + current_pos, + page_table=page_table, + ) + return trace_logits_rm + else: + reset_inputs = not argmax_on_device + if self.prev_page_table is None or any( + not torch.equal(prev, curr) for prev, curr in zip(self.prev_page_table, page_table) + ): + reset_inputs = True + self.prev_page_table = page_table + if reset_inputs: + for i in range(self.data_parallel): + user_page_table = page_table[i] if page_table is not None else None + host_inputs_i = self.model[i].prepare_decode_inputs_host(tokens[i], current_pos[i], user_page_table) + + copy_host_to_device( + host_tensors=host_inputs_i, + device_tensors=self.trace_inputs_text[i], + ) + + for i, trace_id in self.trace_ids_text.items(): + ttnn.execute_trace(self.model_args[i].mesh_device, trace_id, cq_id=0, blocking=False) + + return self.trace_output_text - reset_inputs = not argmax_on_device - if self.prev_page_table is None or any( - not torch.equal(prev, curr) for prev, curr in zip(self.prev_page_table, page_table) - ): - reset_inputs = True - self.prev_page_table = page_table - - if reset_inputs: - for i in range(self.data_parallel): - user_page_table = page_table[i] if page_table is not None else None - host_inputs_i = self.model[i].prepare_decode_inputs_host(tokens[i], current_pos[i], user_page_table) - - copy_host_to_device( - host_tensors=host_inputs_i, - device_tensors=self.trace_inputs_text[i], - ) - - for i, trace_id in self.trace_ids_text.items(): - ttnn.execute_trace(self.model_args[i].mesh_device, trace_id, cq_id=0, blocking=False) - - 
return self.trace_output_text def _prefill_forward_single_user( self, diff --git a/models/tt_transformers/tt/generator_vllm.py b/models/tt_transformers/tt/generator_vllm.py index 5125f551053d..701b3dac50c4 100644 --- a/models/tt_transformers/tt/generator_vllm.py +++ b/models/tt_transformers/tt/generator_vllm.py @@ -373,3 +373,111 @@ def decode_forward(self, *args, **kwargs): def allocate_kv_cache(self, *args, **kwargs): return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path) + +def input_processor_for_mistral(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]): + input_processor = ctx.get_hf_processor() + if "prompt" in inputs: + prompt_text = inputs["prompt"] + else: + assert "prompt_token_ids" in inputs, "prompt_token_ids must be available in server mode" + prompt_text = input_processor.decode(inputs["prompt_token_ids"], skip_special_tokens=False) + + if "multi_modal_data" in inputs and "image" in inputs["multi_modal_data"]: + images = inputs["multi_modal_data"]["image"] + else: + images = None + + processed_inputs = input_processor( + text=prompt_text, + images=images, + return_tensors="pt", + ) + + assert processed_inputs.input_ids.shape[0] == 1, "Only one image is processed at a time by vLLM" + return { + "type": inputs["type"], + "prompt_token_ids": processed_inputs.input_ids[0].tolist(), + "prompt": prompt_text, + "multi_modal_data": {"image": processed_inputs}, # [INFO] add processed_inputs + } + + +from types import SimpleNamespace + + +class CustomNamespace(SimpleNamespace): + def __contains__(self, key): + return key in self.__dict__ + + +@INPUT_REGISTRY.register_input_processor(input_processor_for_mistral) +class Mistral3ForConditionalGeneration(Generator, SupportsMultiModal): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.MISTRAL_IMAGE_TOKEN_ID = 10 + self.max_gen_len = self.model_args[0].max_seq_len - 1 # TODO: double check what this should be + + @classmethod + def initialize_vllm_model( + cls, hf_config, mesh_device, max_batch_size, max_seq_len=131072, n_layers=None, tt_data_parallel=1 + ): + submesh_devices = create_submeshes(mesh_device, tt_data_parallel) + + model_args = [] + model = [] + state_dict = None + + for submesh in submesh_devices: + model_args_i, model_i, state_dict = create_multimodal_model( + mesh_device=submesh, + max_batch_size=max_batch_size // tt_data_parallel, + max_seq_len=max_seq_len, + use_paged_kv_cache=True, + checkpoint=state_dict, + ) + model_args.append(model_args_i) + model.append(model_i) + + return cls(model, model_args, mesh_device) + + @property + def cache_path(self): + return self.model_args[0].model_cache_path + + def prefill_forward(self, *args, **kwargs): + self.tokenizer = self.model_args[0].tokenizer + pad_token_id = self.tokenizer.pad_token_id + + tokens = kwargs["tokens"] + prompt_lens = kwargs["prompt_lens"] + inputs = CustomNamespace() + inputs.input_ids = tokens + data = kwargs.get("images", None) # This contains the entire Data list, not just the pixel values + for i in range(tokens.shape[0]): # for each user, fix their padding + tokens[i][prompt_lens[i] :] = pad_token_id + pixel_values = None + + if hasattr(data[0], "pixel_values"): + # If inputs is a list of objects with .pixel_values, concatenate them + pixel_values = torch.concat([im.pixel_values for im in data if hasattr(im, "pixel_values")], dim=0) + + page_table = kwargs.get("page_table", None) + kv_cache = kwargs.get("kv_cache", None) + vision_images = pixel_values + + 
vision_images = [vision_images] if vision_images is not None else None + + return super().prefill_forward_text( + tokens=inputs.input_ids, + page_table=page_table, + kv_cache=kv_cache, + prompt_lens=prompt_lens, + pixel_values=vision_images, + ) + + def allocate_kv_cache(self, *args, **kwargs): + return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path) + + def decode_forward(self, *args, **kwargs): + return super().decode_forward_text(*args, **kwargs) \ No newline at end of file diff --git a/models/tt_transformers/tt/load_checkpoints.py b/models/tt_transformers/tt/load_checkpoints.py index 6b28e2b4e5ce..831cd005c9c8 100644 --- a/models/tt_transformers/tt/load_checkpoints.py +++ b/models/tt_transformers/tt/load_checkpoints.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 @@ -11,7 +11,7 @@ from loguru import logger from safetensors.torch import load_file as safetensors_load_file from tqdm import tqdm - +model_name = os.getenv("HF_MODEL") # TODO Update function for large models: For 1 layer tests we only want to load 1 checkpoint file, instead of all. def load_hf_state_dict(ckpt_dir): @@ -101,7 +101,7 @@ def load_meta_state_dict(ckpt_dir, n_layers=None, start_layer_idx=0): def load_chunked_checkpoints(checkpoints, n_layers, start_layer_idx): checkpoint = {} - (f"Loading {len(checkpoints)} chunked checkpoint files") + (f"Loading {len(checkpoints)} checkpoint files") for ckpt in tqdm(checkpoints): if n_layers: # Layer range is in the file name, like layers_start-end.pth @@ -134,7 +134,10 @@ def load_sharded_checkpoints(checkpoints, n_layers): logger.info(f"Loading {len(checkpoints)} sharded checkpoint files") for ckpt in tqdm(checkpoints): loaded_ckpt = torch.load(ckpt, map_location="cpu") - for key, value in loaded_ckpt.items(): + for ( + key, + value, + ) in loaded_ckpt.items(): if "layers." in key: layer_num = int(key.split("layers.")[1].split(".")[0]) if n_layers and layer_num >= n_layers: @@ -147,10 +150,10 @@ def load_sharded_checkpoints(checkpoints, n_layers): # concat checkpoint values for key, value in checkpoint.items(): - if len(value) == 1 or is_param_replicated_across_shards(key): + if len(value) == 1 or "norm" in key: checkpoint[key] = value[0] else: - if key.endswith("tok_embeddings.weight") or key.endswith("output.weight"): + if key == "tok_embeddings.weight" or key == "output.weight": assert value[0].shape[1] == 8192 # FIXME: do we need this hardcoded shape? # Concatenate along dimension 0 for llama3 token embeddings weight and lm head checkpoint[key] = torch.cat(value, dim=0) @@ -256,12 +259,41 @@ def map_hf_to_meta_keys(loaded_weights): return replace_keys(loaded_weights, replacements) +def map_vision_meta_to_hf_keys(loaded_weights): + """ + Map Hugging Face checkpoint keys to Meta checkpoint keys. + You can use this to support other models by adding more mappings. + See replace_keys for more details on the format of replacements. 
+    """
+    inverted_mapping = [
+        ("attention_norm", "input_layernorm"),
+        ("ffn_norm", "post_attention_layernorm"),
+        ("attention", "self_attn"),
+        ("feed_forward", "mlp"),
+        ("w1", "gate_proj"),
+        ("w2", "down_proj"),
+        ("w3", "up_proj"),
+        ("wq", "q_proj"),
+        ("wk", "k_proj"),
+        ("wv", "v_proj"),
+        ("wo", "o_proj"),
+    ]
+
+    return replace_keys(loaded_weights, inverted_mapping)
+
+
+def convert_vision_meta_to_hf(state_dict, head_dim):
+    # state_dict = convert_meta_qkv_to_hf_format(state_dict, head_dim)
+    state_dict = map_vision_meta_to_hf_keys(state_dict)
+    return state_dict
+
+
 def map_meta_to_hf_keys(loaded_weights):
     # Define mappings at each level of the hierarchy
     meta_to_hf_mappings = {
         # Top level
         "tok_embeddings.weight": "model.embed_tokens.weight",
-        "norm.weight": "model.norm.weight",
+        # "norm.weight": "model.norm.weight",
         "output.weight": "lm_head.weight",
         # Layer level
         "attention_norm.weight": "input_layernorm.weight",
@@ -304,7 +336,9 @@ def map_meta_to_hf_keys(loaded_weights):
         # Host embeddings
         "emb.weight": "weight",
     }
-
+    # Add norm.weight mapping for non-Mistral models
+    if model_name != "mistralai/Mistral-Small-3.1-24B-Instruct-2503":
+        meta_to_hf_mappings["norm.weight"] = "model.norm.weight"
     hf_state_dict = {}
     for key, tensor in loaded_weights.items():
         # Handle full model paths with layer numbers
@@ -325,7 +359,7 @@ def map_meta_to_hf_keys(loaded_weights):
         # For submodule state dicts, try matching the end of the key
         matched = False
         for meta_pattern, hf_pattern in meta_to_hf_mappings.items():
-            if key.endswith("." + meta_pattern):
+            if key.endswith(meta_pattern):
                 # Replace only the matching part at the end
                 prefix = key[: -len(meta_pattern)]
                 new_key = prefix + hf_pattern
@@ -336,6 +370,7 @@ def map_meta_to_hf_keys(loaded_weights):
         # If no mapping found, keep the original key
         if not matched:
             hf_state_dict[key] = tensor
+
     return hf_state_dict
diff --git a/models/tt_transformers/tt/model.py b/models/tt_transformers/tt/model.py
index 67b15ceaef63..862af77d5f01 100644
--- a/models/tt_transformers/tt/model.py
+++ b/models/tt_transformers/tt/model.py
@@ -462,5 +462,4 @@ def forward(
 
     if mode == "prefill":
         x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-        # x = ttnn.to_memory_config(x, memory_config=ttnn.DRAM_MEMORY_CONFIG)
     return x
diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py
index 5ed83397c3be..3751b7d534e5 100644
--- a/models/tt_transformers/tt/model_config.py
+++ b/models/tt_transformers/tt/model_config.py
@@ -27,6 +27,7 @@
 from models.tt_transformers.tt.load_checkpoints import (
     convert_hf_to_meta,
     convert_meta_to_hf,
+    convert_vision_meta_to_hf,
     load_hf_state_dict,
     load_meta_state_dict,
     reverse_permute,
@@ -35,11 +36,16 @@
 )
 from models.utility_functions import is_blackhole, is_wormhole_b0, nearest_32
 
+model_name = os.getenv("HF_MODEL")
+
 # file names for performance and accuracy mode override files
 PERFORMANCE_DECODER_CONFIG_FILENAME = "performance_decoder_config.json"
 ACCURACY_DECODER_CONFIG_FILENAME = "accuracy_decoder_config.json"
 
 class TensorGroup(Enum):
     FF1_FF3 = "ff1_3"
     FF2 = "ff2"
@@ -140,7 +146,10 @@ def performance(cls, model_name):
"""Configuration optimized for performance All models use bfp4 in FF1 and FF3 MLPs in this configuration """ - base_model_name = get_base_model_name(model_name) + if model_name=="mistralai/Mistral-Small-3.1-24B-Instruct-2503": + base_model_name = model_name.split("B-")[0] + "B" if "B-" in model_name else model_name + else: + base_model_name = get_base_model_name(model_name) if base_model_name == "Qwen2.5-7B": logger.info( f"Model {model_name} is degraded under standard high-performance settings, using BF16 attention and BFP8 MLP" @@ -564,8 +573,11 @@ def __init__( "Qwen2.5-VL-32B": {"N150": None, "N300": None, "T3K": 64, "TG": None, "P150x4": None}, "Qwen2.5-VL-72B": {"N150": None, "N300": None, "T3K": 32, "TG": None, "P150x4": None}, "Phi-3.5-mini-instruct": {"N150": 128, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, + "gemma-3-1b-it": {"N150": 8, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, + "gemma-3-4b-it": {"N150": 8, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, "QwQ-32B": {"N150": None, "N300": None, "T3K": 64, "TG": 128, "P150x4": 128}, "Qwen3-32B": {"N150": None, "N300": None, "T3K": 64, "TG": 128, "P150x4": 128}, + "Mistral-Small-3.1-24B-Instruct-2503": {"N150": None, "N300": None, "T3K": 64, "TG": 128, "P150x4": 128} } try: max_prefill_chunk_size_div1024 = MAX_PREFILL_CHUNK_SIZES_DIV1024[self.base_model_name][self.device_name] @@ -719,13 +731,19 @@ def __init__( # All Gather Matmul for Dense Out (DO) # TODO: Is there a better way to decide if fused all gather matmul should be used? And is there a better way to use the flag, instead of passing it into model_config? # NOTE: Fused all gather matmul only suppports a core grid of size num_devices x 1 - # TODO: #26657 (self.num_devices == 8 and os.getenv("ACTUAL_DEVICE", "") != "TG") should be refactored, and investigate if ACTUAL_DEVICE environment variable is still used - self.model_config["USE_FUSED_ALL_GATHER_MATMUL"] = ( - self.num_devices == 8 - and os.getenv("ACTUAL_DEVICE", "") != "TG" - and (self.dim // self.tile_size // self.num_devices) % self.num_devices == 0 - and self.num_devices > 1 - ) + if model_name=="mistralai/Mistral-Small-3.1-24B-Instruct-2503": + self.model_config["USE_FUSED_ALL_GATHER_MATMUL"] = ( + self.ccl_topology() == ttnn.Topology.Ring + and (self.dim // self.tile_size // self.num_devices) % self.num_devices == 0 + and self.num_devices > 1 + ) + else: + self.model_config["USE_FUSED_ALL_GATHER_MATMUL"] = ( + self.num_devices == 8 + and os.getenv("ACTUAL_DEVICE", "") != "TG" + and (self.dim // self.tile_size // self.num_devices) % self.num_devices == 0 + and self.num_devices > 1 + ) if self.model_config["USE_FUSED_ALL_GATHER_MATMUL"]: do_core_grid_size = (8, 1) @@ -798,12 +816,15 @@ def __init__( if self.is_galaxy else ( 1024 - if self.num_devices == 8 - and os.getenv("ACTUAL_DEVICE", "") != "TG" + if ( + (model_name == "mistralai/Mistral-Small-3.1-24B-Instruct-2503" and self.ccl_topology() == ttnn.Topology.Ring) + or (self.num_devices == 8 and os.getenv("ACTUAL_DEVICE", "") != "TG") + ) and 1024 % (self.dim / self.num_devices) == 0 else self.dim ) ) + num_rows = lambda seq_len: min(seq_len, 1024) dram_sharded_wo = not (self.model_config["USE_FUSED_ALL_GATHER_MATMUL"] or self.is_galaxy) self.model_config["WO_PREFILL_PROGCFG"] = lambda seq_len: self.matmul_config( @@ -1495,7 +1516,10 @@ def _set_params_from_dict(self, config, is_hf=False): self.mlp_activation_type = self._get_hidden_activation_type(text_config) # Vision params (Meta-specific) - self.vision_chunk_size = 
config.get("vision_chunk_size", -1) + if model_name=="mistralai/Mistral-Small-3.1-24B-Instruct-2503": + self.vision_chunk_size = config.get("vision_chunk_size", 896) + else: + self.vision_chunk_size = config.get("vision_chunk_size", -1) self.vision_max_num_chunks = config.get("vision_max_num_chunks", 4) self.vision_num_cross_attention_layers = config.get("vision_num_cross_attention_layers", -1) @@ -1586,8 +1610,61 @@ def _set_params(self, checkpoint_dir): if self.rope_scaling_factor is not None else None ) +# Note: This function is specific to the Mistral model. + def _set_vision_params(self, vision_config): + self.vision_chunk_size = vision_config.get("vision_chunk_size", 896) + self.vision_max_num_chunks = vision_config.get("vision_max_num_chunks", 4) + self.vision_num_cross_attention_layers = vision_config.get("vision_num_cross_attention_layers", 8) + self.vision_dim = vision_config.get("hidden_size", 1152) + intermediate_size = vision_config.get("intermediate_size", self.vision_dim * 4) + self.vision_image_size = vision_config.get("image_size", 1540) + self.vision_rope_theta = vision_config.get("rope_theta", 10000.0) + self.image_token_index = vision_config.get("image_token_index", 10) + + self.vision_mlp_ratio = intermediate_size // self.vision_dim + self.vision_hidden_dim = int(self.vision_dim * self.vision_mlp_ratio) + self.vision_attn_n_heads = vision_config.get("num_attention_heads") or vision_config.get("num_heads") or 16 + self.vision_head_dim = self.vision_dim // self.vision_attn_n_heads + + self.vision_n_layers = vision_config.get("num_hidden_layers") or vision_config.get("depth") or 27 + self.vision_patch_size = vision_config.get("patch_size", 14) + self.vision_in_channels = vision_config.get("num_channels", 3) + + self.vision_dropout = vision_config.get("attention_dropout", 0.0) + self.mm_tokens_per_image = vision_config.get("mm_tokens_per_image", 256) + + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + self.vision_head_dim = vision_config.get("head_dim", 64) + + # Optional vision activation layer, defaults to GELU + act_layer = vision_config.get("act_layer", "gelu").lower() + self.vision_act_layer = { + "gelu": ttnn.UnaryOpType.GELU, + "relu": ttnn.UnaryOpType.RELU, + "silu": ttnn.UnaryOpType.SILU, + }.get(act_layer, ttnn.UnaryOpType.GELU) + + # Optional tuning knobs + # self.vision_max_num_tiles = vision_config.get("max_num_tiles", 4) + # self.vision_n_global_layers = vision_config.get("n_global_layers", 8) + + # # Optional Meta-specific knobs + # self.vision_max_num_chunks = vision_config.get("max_num_chunks", 4) + # self.vision_num_cross_attention_layers = vision_config.get("num_cross_attention_layers", -1) def _set_hf_params(self, checkpoint_dir): + def merge_text_config(base_config): + text_config = base_config.get("text_config", {}) + # Merge non-nested keys into text_config + text_config.update({k: v for k, v in base_config.items() if k not in ["text_config", "vision_config"]}) + return text_config + + def merge_vision_config(base_config): + vision_config = base_config.get("vision_config", {}) + # Merge non-nested keys into vision_config + vision_config.update({k: v for k, v in base_config.items() if k not in ["text_config", "vision_config"]}) + return vision_config + if self.from_hf_url: # Special case Qwen2.5-VL models until they are fully integrated into a HF release if "Qwen/Qwen2.5-VL" in self.model_name: @@ -1604,12 +1681,26 @@ def _set_hf_params(self, checkpoint_dir): self.hf_config = AutoConfig.from_pretrained(self.CKPT_DIR) config = 
self.hf_config.to_dict() + # Note: This function is specific to the Mistral model. + if "text_config" in config or "vision_config" in config: + merged_text_config = merge_text_config(config) + self._set_params_from_dict(merged_text_config, is_hf=True) + + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + self._set_vision_params(config["vision_config"]) + else: + if "vision_config" in config: + merged_vision_config = merge_vision_config(config) + self._set_vision_params(merged_vision_config) + else: + self._set_params_from_dict(config, is_hf=True) + else: config_file = os.path.join(checkpoint_dir, "config.json") assert os.path.exists(config_file), f"config.json file not found at {config_file}" with open(config_file, "r") as f: config = json.load(f) - self._set_params_from_dict(config, is_hf=True) + self._set_params_from_dict(config) def __repr__(self): return f"""ModelArgs( @@ -1634,7 +1725,9 @@ def is_vision(self): return self.vision_chunk_size > 0 def get_state_dict_prefix(self, module_name, layer_num): - text_prefix = self.state_dict_text_prefix + text_prefix = ( + self.state_dict_text_prefix if self.is_vision() and not "Mistral-Small-3.1-24B" in self.model_name else "" + ) layer_prefix = f"layers.{layer_num}." if layer_num is not None else "" module_map = { "MLP": "feed_forward", @@ -1642,6 +1735,16 @@ def get_state_dict_prefix(self, module_name, layer_num): "TransformerBlock": "", "": "", # If no module is given, just get layer prefix } + #Note: This function is specific to the Mistral model. + vision_module_map = { + "MLP": "mlp.", + "Attention": "self_attn.", + "TransformerBlock": "", + "": "", + } + module_map = ( + vision_module_map if self.is_vision() and not "Mistral-Small-3.1-24B" in self.model_name else module_map + ) return text_prefix + layer_prefix + module_map[module_name] def weight_cache_path(self, dtype): @@ -1687,6 +1790,8 @@ def load_state_dict(self): ) print("Loading Qwen2.5-VL model: ", AutoModelForCausalLM) + elif "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + from transformers import Mistral3ForConditionalGeneration as AutoModelForCausalLM else: from transformers import AutoModelForCausalLM @@ -2067,55 +2172,76 @@ def create_tokenizer(self): logger.info(f"Model name: {self.model_name}") logger.info(f"Base model name: {self.base_model_name}") - try: - # Try to load tokenizer from the original model path - tokenizer = AutoTokenizer.from_pretrained(self.TOKENIZER_PATH) - logger.info(f"Successfully loaded tokenizer from {self.TOKENIZER_PATH}") - except Exception as e: - logger.warning(f"Failed to load tokenizer from {self.TOKENIZER_PATH}: {e}") - - # Try to use base model tokenizer as fallback - fallback_tokenizer_path = base_model_tokenizer_mapping.get(self.base_model_name) - - # If no direct match, try to infer from model name patterns - if not fallback_tokenizer_path: - model_name_lower = self.model_name.lower() - if "qwen2.5" in model_name_lower and "0.5b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-Coder-0.5B-Instruct" - elif "qwen2.5" in model_name_lower and "1.5b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-1.5B-Instruct" - elif "qwen2.5" in model_name_lower and "3b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-3B-Instruct" - elif "qwen2.5" in model_name_lower and "7b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-7B-Instruct" - elif "qwen2.5" in model_name_lower and "14b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-14B-Instruct" - elif "qwen2.5" 
in model_name_lower and "32b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-32B-Instruct" - elif "qwen2.5" in model_name_lower and "72b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-72B-Instruct" - elif "llama" in model_name_lower and "3.1" in model_name_lower and "8b" in model_name_lower: - fallback_tokenizer_path = "meta-llama/Llama-3.1-8B-Instruct" - elif "llama" in model_name_lower and "3.1" in model_name_lower and "70b" in model_name_lower: - fallback_tokenizer_path = "meta-llama/Llama-3.1-70B-Instruct" - elif "llama" in model_name_lower and "3.2" in model_name_lower and "1b" in model_name_lower: - fallback_tokenizer_path = "meta-llama/Llama-3.2-1B-Instruct" - elif "llama" in model_name_lower and "3.2" in model_name_lower and "3b" in model_name_lower: - fallback_tokenizer_path = "meta-llama/Llama-3.2-3B-Instruct" - elif "mistral" in model_name_lower and "7b" in model_name_lower: - fallback_tokenizer_path = "mistralai/Mistral-7B-Instruct-v0.3" - - if fallback_tokenizer_path: - logger.info(f"Attempting to use fallback tokenizer: {fallback_tokenizer_path}") - try: - tokenizer = AutoTokenizer.from_pretrained(fallback_tokenizer_path) - logger.info(f"Successfully loaded fallback tokenizer from {fallback_tokenizer_path}") - except Exception as fallback_e: - logger.error(f"Failed to load fallback tokenizer from {fallback_tokenizer_path}: {fallback_e}") - raise fallback_e - else: - logger.error(f"No fallback tokenizer found for base model: {self.base_model_name}") - raise e + # Special handling for Mistral-Small-3.1-24B-Instruct-2503 + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + tokenizer = AutoTokenizer.from_pretrained( + "mistralai/Mistral-Small-3.1-24B-Instruct-2503", trust_remote_code=True + ) + logger.info("Manually setting Mistral instruct-style chat template on the tokenizer.") + + mistral_template = """{% for message in messages %} + {% if message['role'] == 'system' %} + <|system|> + {{ message['content'] }} + {% elif message['role'] == 'user' %} + [INST] {{ message['content'] }} [/INST] + {% elif message['role'] == 'assistant' %} + {{ message['content'] }}{{ eos_token }} + {% endif %} + {% endfor %}""" + tokenizer.chat_template = mistral_template + else: + try: + # Try to load tokenizer from the original model path + tokenizer = AutoTokenizer.from_pretrained(self.TOKENIZER_PATH) + logger.info(f"Successfully loaded tokenizer from {self.TOKENIZER_PATH}") + except Exception as e: + logger.warning(f"Failed to load tokenizer from {self.TOKENIZER_PATH}: {e}") + + # Try to use base model tokenizer as fallback + fallback_tokenizer_path = base_model_tokenizer_mapping.get(self.base_model_name) + + # If no direct match, try to infer from model name patterns + if not fallback_tokenizer_path: + model_name_lower = self.model_name.lower() + if "qwen2.5" in model_name_lower and "0.5b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-Coder-0.5B-Instruct" + elif "qwen2.5" in model_name_lower and "1.5b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-1.5B-Instruct" + elif "qwen2.5" in model_name_lower and "3b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-3B-Instruct" + elif "qwen2.5" in model_name_lower and "7b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-7B-Instruct" + elif "qwen2.5" in model_name_lower and "14b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-14B-Instruct" + elif "qwen2.5" in model_name_lower and "32b" in model_name_lower: + fallback_tokenizer_path = 
"Qwen/Qwen2.5-32B-Instruct" + elif "qwen2.5" in model_name_lower and "72b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-72B-Instruct" + elif "llama" in model_name_lower and "3.1" in model_name_lower and "8b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.1-8B-Instruct" + elif "llama" in model_name_lower and "3.1" in model_name_lower and "70b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.1-70B-Instruct" + elif "llama" in model_name_lower and "3.2" in model_name_lower and "1b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.2-1B-Instruct" + elif "llama" in model_name_lower and "3.2" in model_name_lower and "3b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.2-3B-Instruct" + elif "mistral" in model_name_lower and "7b" in model_name_lower: + fallback_tokenizer_path = "mistralai/Mistral-7B-Instruct-v0.3" + + if fallback_tokenizer_path: + logger.info(f"Attempting to use fallback tokenizer: {fallback_tokenizer_path}") + try: + tokenizer = AutoTokenizer.from_pretrained(fallback_tokenizer_path) + logger.info(f"Successfully loaded fallback tokenizer from {fallback_tokenizer_path}") + except Exception as fallback_e: + logger.error( + f"Failed to load fallback tokenizer from {fallback_tokenizer_path}: {fallback_e}" + ) + raise fallback_e + else: + logger.error(f"No fallback tokenizer found for base model: {self.base_model_name}") + raise e # Add meta-compatible stop token list to the HF tokenizer if not "stop_tokens" in tokenizer.__dict__: @@ -2165,6 +2291,8 @@ def reference_transformer(self, wrap=True, load_checkpoint=False): from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLForConditionalGeneration as AutoModelForCausalLM, ) + elif "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + from transformers import Mistral3ForConditionalGeneration as AutoModelForCausalLM else: from transformers import AutoConfig, AutoModelForCausalLM @@ -2194,6 +2322,48 @@ def reference_transformer(self, wrap=True, load_checkpoint=False): return wrapper else: return model + # Note: This function is specific to the Mistral model. 
+ def reference_vision_multi_modal(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.multi_modal_projector + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_rms_norm(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.multi_modal_projector.mm_soft_emb_norm + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_rms_norm_qwen(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.visual.blocks[0].norm1 + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_rms_norm_qwen_merger(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.visual.merger + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_qwen_patch_embed(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.visual.patch_embed + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_qwen_rotary_embed(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.visual.rotary_pos_emb + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) + return layer def reference_rms_norm(self): if self.checkpoint_type == CheckpointType.Meta: @@ -2202,10 +2372,155 @@ def reference_rms_norm(self): return RMSNorm(self.dim, self.norm_eps) else: model = self.reference_transformer(wrap=False) - layer = model.model.norm + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + layers = getattr(model, "layers", getattr(model, "model", {}).layers) + layer = layers[0].input_layernorm + else: + layer = model.model.norm layer._load_state_dict = layer.load_state_dict layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) return layer + # Note: This function is specific to the Mistral model. 
+ def reference_vision_transformer(self, wrap=True, load_checkpoint=False): + if self.checkpoint_type == CheckpointType.HuggingFace: + from transformers import AutoConfig, AutoModelForCausalLM + + if self.dummy_weights and not load_checkpoint: + config = AutoConfig.from_pretrained(self.LOCAL_HF_PARAMS[self.model_name]) + config.num_layers = self.n_layers + config.num_hidden_layers = self.n_layers + model = AutoModelForCausalLM.from_config(config) + else: + if "gemma-3" in self.model_name: + from transformers import Gemma3ForConditionalGeneration + + model = Gemma3ForConditionalGeneration.from_pretrained(self.CKPT_DIR) + model = model + elif "Qwen2.5-VL-7B" in self.model_name: + from transformers import Qwen2_5_VLForConditionalGeneration + + model = Qwen2_5_VLForConditionalGeneration.from_pretrained(self.CKPT_DIR) + model = model + elif "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + from transformers import Mistral3ForConditionalGeneration + + model = Mistral3ForConditionalGeneration.from_pretrained(self.CKPT_DIR) + model = model + + else: + if self.cached_hf_model is None: + model = AutoModelForCausalLM.from_pretrained(self.CKPT_DIR) + self.cached_hf_model = model + else: + model = self.cached_hf_model + model.model.layers = model.model.layers[: self.n_layers] + if wrap: + wrapper = HfModelWrapper(model, self.head_dim) + return wrapper + else: + return model + + def reference_vision_model(self): + model = self.reference_vision_transformer(wrap=False) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + # Mistral-Small-3.1-24B-Instruct-2503 has a different structure + layer = model.vision_tower + else: + layer = model.vision_tower.vision_model + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_pixtral_image_block(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.transformer.layers[0] + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_mlp(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.transformer.layers[0].feed_forward + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_rms(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.ln_pre + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_conv2d_patch(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.patch_conv + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_siglip_patch_embed(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.vision_model.embeddings.patch_embedding + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_pos_embedding(self): + model = self.reference_vision_transformer(wrap=False) + layer = 
model.vision_tower.vision_model.embeddings.position_embedding + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_embedding(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.vision_model.embeddings + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_layernorm(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.vision_model.encoder.layers[0].layer_norm1 + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_attention(self): + model = self.reference_vision_transformer(wrap=False) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + layer = model.vision_tower.transformer.layers[0].attention + else: + layer = model.vision_tower.vision_model.encoder.layers[0].self_attn # Common naming + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_rot_emb(self): + model = self.reference_vision_transformer(wrap=False) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + layer = model.vision_tower.patch_positional_embedding + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_encoder_block(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.vision_model.encoder.layers[0] + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_encoder(self): + model = self.reference_vision_transformer(wrap=False) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + layer = model.vision_tower.transformer + else: + layer = model.vision_tower.vision_model.encoder + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer def reference_mlp(self): if self.checkpoint_type == CheckpointType.Meta: @@ -2221,15 +2536,18 @@ def reference_mlp(self): def reference_embedding(self, reference_model=None): if self.checkpoint_type == CheckpointType.Meta: - from models.tt_transformers.tt.common import HostEmbedding, HostScaledEmbedding + from models.tt_transformers.tt.common import HostEmbedding,HostScaledEmbedding - return HostEmbedding(self) if self.embed_scale is None else HostScaledEmbedding(self) + return HostEmbedding(self)if self.embed_scale is None else HostScaledEmbedding(self) else: if reference_model is None: model = self.reference_transformer(wrap=False) layer = model.model.embed_tokens else: - layer = reference_model.model.model.embed_tokens + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + layer = reference_model.model.embed_tokens + else: + layer = reference_model.model.model.embed_tokens layer._load_state_dict = layer.load_state_dict layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) 
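The reference_vision_* helpers above all rely on the same trick: fetch a submodule of the Hugging Face reference model, keep its original load_state_dict, and re-point load_state_dict at a wrapper that converts Meta-style keys to HF naming before loading. A minimal standalone sketch of that pattern follows; the toy module, the toy key mapping, and toy_convert_vision_meta_to_hf are illustrative stand-ins, not the real tt_transformers helpers.

    import torch
    from torch import nn


    def toy_convert_vision_meta_to_hf(state_dict):
        # Stand-in for convert_vision_meta_to_hf: rename Meta-style prefixes to HF-style ones.
        renames = [("attention_norm", "input_layernorm"), ("feed_forward", "mlp")]
        out = {}
        for key, value in state_dict.items():
            for meta_name, hf_name in renames:
                key = key.replace(meta_name, hf_name)
            out[key] = value
        return out


    class ToyEncoderLayer(nn.Module):
        # Stand-in for one HF vision encoder layer with HF-style attribute names.
        def __init__(self):
            super().__init__()
            self.input_layernorm = nn.LayerNorm(4)


    layer = ToyEncoderLayer()
    # Same wrapping used by the reference_* helpers: keep the original loader, then
    # route every call through the key converter before the real load happens.
    layer._load_state_dict = layer.load_state_dict
    layer.load_state_dict = lambda sd: layer._load_state_dict(toy_convert_vision_meta_to_hf(sd))

    # Meta-style keys now load cleanly into the HF-style module.
    layer.load_state_dict({"attention_norm.weight": torch.ones(4), "attention_norm.bias": torch.zeros(4)})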
diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/mistral_vision_tower.py b/models/tt_transformers/tt/multimodal/mistral_24b/mistral_vision_tower.py new file mode 100644 index 000000000000..f22bc75fda24 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/mistral_vision_tower.py @@ -0,0 +1,162 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.common import generate_block_attention_mask_tt, position_ids_in_meshgrid_tt +from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm +from models.tt_transformers.tt.multimodal.mistral_24b.vision_conv2d import TtMistralConv2dPatch +from models.tt_transformers.tt.multimodal.mistral_24b.vision_pixtral_transformer import TtPixtralTransformer +from models.tt_transformers.tt.multimodal.mistral_24b.vision_rope import VisionRotarySetup as RotarySetup +from ttnn import ConcatMeshToTensor + + +class MistralVisionTower(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + dtype, + configuration, + return_intermediate=None, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.dtype = dtype + self.config = configuration + + self.image_size = configuration.vision_chunk_size + self.patch_size = configuration.vision_patch_size + self.width = configuration.vision_dim + self.layers = configuration.vision_n_layers + self.heads = configuration.vision_attn_n_heads + self.vision_head_dim = configuration.vision_head_dim + self.mlp_ratio = configuration.vision_mlp_ratio + self.act_layer = configuration.vision_act_layer + self.in_channels = configuration.vision_in_channels + self.n_global_layers = configuration.vision_n_global_layers + self.max_seq_len = configuration.max_seq_len + self.return_intermediate = return_intermediate + self.n_layers = configuration.vision_n_layers + + in_channels, out_channels, kernel_size, stride, bias = ( + 3, + configuration.vision_dim, + configuration.vision_patch_size, + configuration.vision_patch_size, + False, + ) + + self.patch_conv = TtMistralConv2dPatch( + mesh_device=self.mesh_device, + state_dict=self.state_dict, + state_dict_prefix=f"{state_dict_prefix}patch_conv.", + dtype=self.dtype, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + bias=bias, + ) + + self.ln_pre = RMSNorm( + device=mesh_device, + dim=self.width, + state_dict=self.state_dict, + state_dict_prefix=state_dict_prefix, + weight_dtype=dtype, + weight_key="ln_pre", + is_distributed=False, + ) + + image_size = configuration.vision_image_size + patch_size = configuration.vision_patch_size + dim = configuration.vision_head_dim + num_patches_per_dim = image_size // patch_size + num_patches = num_patches_per_dim * num_patches_per_dim + self.num_patches = num_patches + + self.patch_positional_embedding = RotarySetup( + self.mesh_device, + 1, + dim, + image_size, + patch_size, + num_patches, + configuration.vision_rope_theta, + scale_factor=None, + orig_context_len=num_patches, + datatype=dtype, + ) + + self.transformer = TtPixtralTransformer( + mesh_device=self.mesh_device, + state_dict=self.state_dict, + state_dict_prefix=f"{state_dict_prefix}transformer.", + weight_cache_path=configuration.weight_cache_path(dtype), + dtype=self.dtype, + configuration=configuration, + layers=self.n_layers, + ) + + def forward(self, input_tensor, 
image_sizes=None): + """ + input_tensor shape: (B, C, H, W) + """ + patch_embeds = self.patch_conv(input_tensor) + patch_embeds = ttnn.transpose(patch_embeds, 1, 2) + height, width = image_sizes[0] + patch_embeds = ttnn.reshape( + patch_embeds, + [patch_embeds.shape[0], self.width, height // self.patch_size, width // self.patch_size], + ) + + patch_embeds_list = [ + ttnn.slice( + patch_embeds, + [0, 0, 0, 0], + [1, self.width, size[0] // self.patch_size, size[1] // self.patch_size], + ) + for size in image_sizes + ] + + reshaped_patches = [] + for p in patch_embeds_list: + p = ttnn.reshape(p, (1, self.width, -1)) + p = ttnn.transpose(p, 1, 2) + reshaped_patches.append(p) + + patch_embeds = ttnn.concat(reshaped_patches, dim=0) + + # ln_pre RMS Norm + mode = "prefill" # if self.max_seq_len <= 32 else "prefill" + patch_embeds = self.ln_pre(patch_embeds, mode=mode) + + # # positional embeddings + position_ids = position_ids_in_meshgrid_tt( + patch_embeds_list, + max_width=self.config.vision_image_size // self.config.vision_patch_size, + device=self.mesh_device, + ) + + torch_position_ids = ttnn.to_torch(position_ids, mesh_composer=ConcatMeshToTensor(self.mesh_device, dim=0))[ + : position_ids.shape[-1] + ] + + position_embeddings = self.patch_positional_embedding.get_rot_mats(torch_position_ids) + + attention_mask = generate_block_attention_mask_tt( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds, tt_device=self.mesh_device + ) + + patch_embeds = ttnn.unsqueeze(patch_embeds, 0) + out = self.transformer(patch_embeds, mask=attention_mask, position_embeddings=position_embeddings) + # deallocate position_embeddings + ttnn.deallocate(position_embeddings[0]) + ttnn.deallocate(position_embeddings[1]) + + return out diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/model.py b/models/tt_transformers/tt/multimodal/mistral_24b/model.py new file mode 100644 index 000000000000..36061a79588a --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/model.py @@ -0,0 +1,123 @@ +""" +This is the end-to-end pipeline for the Mistral-Small-3.1-24B-Instruct-2503 model. + +The `MistralTransformer` class inherits from the `Transformer` class in tt_transformers. +It overrides `prepare_inputs_prefill` to run inference on the vision model and +pass the resulting visual tokens to the text model along with text tokens. +""" + + +import torch + +import ttnn +from models.tt_transformers.tt.model import Transformer +from ttnn import ConcatMeshToTensor + + +class MistralTransformer(Transformer): + def __init__( + self, + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=None, + use_paged_kv_cache=False, + ): + super().__init__( + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=paged_attention_config, + use_paged_kv_cache=use_paged_kv_cache, + ) + + def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): + """ + Inputs are torch tensors or python types. This function returns ttnn + tensors on device. 
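+        For the Mistral-24B vision path, kwargs is expected to carry processed_inputs
+        (pixel_values, input_ids, image_sizes) and the vision_model; vision features are
+        scattered into the token embeddings at image-token positions before prefill.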
+ TODO: Debate whether this function is responsible for padding + """ + + tokens = tokens.reshape(1, 1, 1, -1) + S = tokens.shape[-1] + tokens = ttnn.from_torch( + tokens, + device=self.mesh_device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + tokens_embd = self.embd(tokens) + + pixel_values = kwargs["processed_inputs"]["pixel_values"] + input_ids = kwargs["processed_inputs"]["input_ids"] + image_sizes = kwargs["processed_inputs"]["image_sizes"] + + if pixel_values is not None: + vision_model = kwargs["vision_model"] + vision_output = vision_model(pixel_values, image_sizes) + vision_output_torch = ttnn.to_torch( + vision_output, mesh_composer=ConcatMeshToTensor(self.mesh_device, dim=0) + )[: vision_output.shape[0]] + tokens_embd = ttnn.to_torch(tokens_embd, mesh_composer=ConcatMeshToTensor(self.mesh_device, dim=-1)) + sliced_token_embds = tokens_embd[: tokens_embd.shape[0]] + + image_features = vision_output_torch + + input_ids = torch.nn.functional.pad( + input_ids, (0, tokens_embd.shape[1] - input_ids.shape[1]), "constant", 0 + ) + special_image_mask = (input_ids == 10).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(tokens_embd) + image_features = image_features.to(tokens_embd.device, tokens_embd.dtype) + tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features) + + + tokens_embd = ttnn.from_torch( + tokens_embd, + dtype=ttnn.bfloat16, + device=self.mesh_device, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh( + self.mesh_device, dims=(None, 2), mesh_shape=list(self.mesh_device.shape) + ), + ) + + tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) + # Slice the rot mats to the prefill seqlen + assert ( + self.rope_setup.cos_matrix.shape[2] >= start_pos + S + ), f"Padded prefill end idx {start_pos + S} exceeds max seq len {self.rope_setup.cos_matrix.shape[2]}" + + tt_rot_mats_prefill_global = [ + self.rope_setup.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + + if page_table is not None: + tt_page_table = ttnn.from_torch( + page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_page_table = None + + if chunk_page_table is not None: + tt_chunk_page_table = ttnn.from_torch( + chunk_page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_chunk_page_table = None + + return tokens_embd, tt_rot_mats_prefill_global, tt_page_table, tt_chunk_page_table diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/rmsnorm.py b/models/tt_transformers/tt/multimodal/mistral_24b/rmsnorm.py new file mode 100644 index 000000000000..7018e519fd5a --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/rmsnorm.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 +import ttnn +from models.common.lightweightmodule import LightweightModule + +TILE = 32 +SHARD_HEIGHT = TILE # Current ttnn.rms_norm implementation requires shard height to be a single tile + + +class RMSNorm(LightweightModule): + """ + RMSNorm supporting replication over a MeshDevice and sharding within devices. + + This class implements a Root Mean Square Normalization (RMSNorm) that can be + distributed across multiple devices and cores. 
If the `device` parameter is a + MeshDevice, the weights and computations are replicated across all devices in + the mesh. Expects an interleaved input tensor, can optionally output a sharded tensor. + + Args: + device: The device or MeshDevice on which to perform the computations. + state_dict: The state dictionary containing the model parameters. + dim: Input dimension (e.g. model hidden dimension size). + layer_num: The layer number to determine the weight key in the state dictionary. + weight_key: The key for retrieving the weight from the state dictionary. + weight_cache_path: Optional path for caching the tilized weights. + weight_memory_config: Configuration for the weight memory, default is DRAM_MEMORY_CONFIG. + weight_dtype: The data type for the tensors, bfp8_b hits >0.999 PCC in the models we tested. + model_config: Optional configuration dictionary for the model. + eps (float): Small value to avoid division by zero in normalization, default is 1e-05. + + If model_config is provided, it must specify SHARDED_NORM_INPUT_MEMCFG, SHARDED_NORM_PRGM_CFG + and SHARDED_NORM_OUTPUT_MEMCFG. If not provided, default configurations will be generated. + """ + + def __init__( + self, + device, + dim, + state_dict, + weight_key, + layer_num=None, + state_dict_prefix=None, + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps: float = 1e-05, + sharded_program_config=None, + sharded_output_config=None, + output_mem_config=None, + ccl_topology=ttnn.Topology.Ring, + ): + super().__init__() + self.eps = eps + self.is_distributed = is_distributed + self.ccl_topology = ccl_topology + + if state_dict_prefix: + weight_name = f"{state_dict_prefix}{weight_key}.weight" + else: + if layer_num is None: + weight_name = f"{weight_key}.weight" + else: + weight_name = f"layers.{layer_num}.{weight_key}.weight" + + torch_weight = ( + state_dict[weight_name].unsqueeze(0).view(1, 1, dim).reshape([1, 1, dim // SHARD_HEIGHT, SHARD_HEIGHT]) + ) + + cache_name = None if weight_cache_path is None else weight_cache_path / weight_name + + # Compatibility with models that don't use mesh devices (e.g. 
single-chip Mistral-7b) + is_mesh_device = device.__class__.__name__ == "MeshDevice" + + self.weight = ttnn.as_tensor( + state_dict[weight_name].unsqueeze(0).view(1, 1, dim), + device=device, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT if weight_dtype == ttnn.bfloat8_b else ttnn.ROW_MAJOR_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=cache_name, + mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + ) + + if self.is_distributed: + self.weight_distributed = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=cache_name, + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape)) + if is_mesh_device + else None, + ) + + self.sharded_output_config = sharded_output_config + self.sharded_program_config = sharded_program_config + self.output_mem_config = output_mem_config + + self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=False, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) -> ttnn.Tensor: + # If input is sharded do sharded RMSNorm and optionally return sharded output + program_config = self.sharded_program_config if in_sharded else None + memory_config = self.sharded_output_config if out_sharded else None + distributed = self.is_distributed and self.is_distributed(mode) + norm = self._distributed_rmsnorm if distributed else ttnn.rms_norm + weight = self.weight_distributed if distributed else self.weight + + if in_sharded: + assert not distributed, "Distributed RMSNorm does not support sharded inputs" + else: + assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor" + + x = norm( + x, + epsilon=self.eps, + weight=weight, + program_config=program_config, + memory_config=memory_config, + compute_kernel_config=self.compute_kernel_config_hifi2, + ) + + if in_sharded and not out_sharded: + return ttnn.sharded_to_interleaved(x) + else: + return x + + def _distributed_rmsnorm( + self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None + ): + assert program_config is None, "Distributed RMSNorm does not support sharded inputs" + assert memory_config is None, "Distributed RMSNorm does not support sharded outputs" + + # Run distributed rmsnorm part 1 + tt_stats = ttnn.rms_norm_pre_all_gather(inp, compute_kernel_config=compute_kernel_config, dtype=ttnn.bfloat16) + # AllGather stats + tt_stats = ttnn.all_gather( + tt_stats, + dim=3, + num_links=1, + topology=self.ccl_topology, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + # Run distributed rmsnorm part 2 + tt_out = ttnn.rms_norm_post_all_gather( + inp, + tt_stats, + epsilon=epsilon, + weight=weight, + compute_kernel_config=compute_kernel_config, + ) + tt_stats.deallocate(True) + + return tt_out diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_attention.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_attention.py new file mode 100644 index 000000000000..f3c8daa31945 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_attention.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.utility_functions import nearest_32 + + +def rotate_half(x): + last_dim = x.shape[-1] + half = last_dim // 2 + + x1 = ttnn.slice(x, (0, 0, 0, 0), (x.shape[0], x.shape[1], x.shape[2], half)) + x2 = ttnn.slice(x, (0, 0, 0, half), (x.shape[0], x.shape[1], x.shape[2], last_dim)) + + neg_x2 = ttnn.mul(x2, -1, use_legacy=False) + return ttnn.concat([neg_x2, x1], dim=-1) + + +def apply_rotary_pos_emb_vision_tt(q, k, cos, sin): + cos = ttnn.unsqueeze(cos, 0) + sin = ttnn.unsqueeze(sin, 0) + + q_embed = ttnn.add(ttnn.mul(q, cos, use_legacy=True), ttnn.mul(rotate_half(q), sin, use_legacy=True)) + k_embed = ttnn.add(ttnn.mul(k, cos), ttnn.mul(rotate_half(k), sin)) + return q_embed, k_embed + + +class TtMistralImageAttention(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.num_devices = configuration.num_devices + + self.hidden_size = configuration.vision_dim + self.n_heads = configuration.vision_attn_n_heads + self.head_dim = self.hidden_size // self.n_heads + self.n_kv_heads = self.n_heads + + self.n_local_heads = self.n_heads // configuration.num_devices + self.n_local_kv_heads = self.n_kv_heads // configuration.num_devices + + self.dtype = dtype + + self.grid_size = configuration.max_grid_size + + self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 + self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 + self.compute_kernel_config_sdpa = configuration.compute_kernel_config_sdpa + self.configuration = configuration + + self.model_config = configuration.get_model_config() + + if configuration.dummy_weights or (weight_cache_path is None): + cache_name = lambda _: None + else: + cache_name = lambda name: weight_cache_path / (f"{state_dict_prefix}{name}") + + wq_str = f"{state_dict_prefix}wq.weight" + wk_str = f"{state_dict_prefix}wk.weight" + wv_str = f"{state_dict_prefix}wv.weight" + wo_str = f"{state_dict_prefix}wo.weight" + + # when splitting the devices, we need to make sure that the number of heads is divisible by the number of devices + assert self.n_heads % configuration.num_devices == 0 + assert self.n_kv_heads % configuration.num_devices == 0 + + # Pad head_dim to multiple of 32 + def pad_head_dim(weight, heads_out=True): + # Pad head dim to multiple of 32 + # heads_out means that the output dim of this weight contains heads. 
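+            # wq/wk/wv carry heads on the output dim (heads_out=True), wo carries them on
+            # the input dim. Padding head_dim up to nearest_32 keeps each per-head slice
+            # aligned to the 32-wide tile size used by the fused QKV matmul and head splitting.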
+ dim = weight.shape[1] + assert weight.shape[0] == dim + padded_head_dim = nearest_32(self.head_dim) + padding_size = padded_head_dim - self.head_dim + if padding_size > 0: + if heads_out: + weight = weight.transpose(-1, -2) + weight = weight.reshape(dim, self.n_heads, self.head_dim) + padding = torch.zeros(dim, self.n_heads, padding_size, dtype=weight.dtype) + weight = torch.cat([weight, padding], dim=-1) + weight = weight.reshape(dim, self.n_heads * padded_head_dim) + if heads_out: + weight = weight.transpose(-1, -2) + return weight + + wq_padded = pad_head_dim(self.state_dict[wq_str]) + wk_padded = pad_head_dim(self.state_dict[wk_str]) + wv_padded = pad_head_dim(self.state_dict[wv_str]) + wo_padded = pad_head_dim(self.state_dict[wo_str], heads_out=False) + wq_chunked, wk_chunked, wv_chunked = ( + torch.chunk(w, configuration.num_devices) for w in [wq_padded, wk_padded, wv_padded] + ) + + self.wqkv = ttnn.as_tensor( + torch.concat( + [ + torch.concat( + [ + torch.transpose( + wq_chunked[i], + -2, + -1, + ), + torch.transpose( + wk_chunked[i], + -2, + -1, + ), + torch.transpose( + wv_chunked[i], + -2, + -1, + ), + ], + dim=-1, + ) + for i in range(configuration.num_devices) + ], + dim=-1, + ), + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + dtype=self.dtype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + cache_file_name=cache_name("wqkv_sharded"), + ) + + self.wo = ttnn.as_tensor( + torch.transpose( + wo_padded, + -2, + -1, + ), + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-2), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=self.dtype, + layout=ttnn.TILE_LAYOUT, + cache_file_name=cache_name("wo_sharded"), + ) + + self.scale = self.head_dim**-0.5 + + def forward(self, x_11SH, position_embeddings=None, mask=None): + seq_len = x_11SH.shape[-2] + + MAX_MM_SEQ_LEN = self.configuration.VISION_MAX_MM_SEQ + + if seq_len > MAX_MM_SEQ_LEN: + x_11SH = ttnn.reshape(x_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) + + xqkv_fused = ttnn.linear( + x_11SH, + self.wqkv, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config_hifi4, + program_config=self.model_config["IMAGE_ATTN_QKV_PROGCFG"](seq_len, MAX_MM_SEQ_LEN), + ) + if seq_len > MAX_MM_SEQ_LEN: + xqkv_fused = ttnn.reshape(xqkv_fused, [1, 1, seq_len, -1]) + + # split qkv into heads + ( + q_heads_1QSD, + k_heads_1KSD, + v_heads_1VSD, + ) = ttnn.experimental.nlp_create_qkv_heads( + xqkv_fused, + num_heads=self.n_local_heads, + num_kv_heads=self.n_local_kv_heads, + transpose_k_heads=False, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + if position_embeddings is not None: + cos, sin = position_embeddings + q_heads_1QSD, k_heads_1KSD = apply_rotary_pos_emb_vision_tt(q_heads_1QSD, k_heads_1KSD, cos, sin) + ttnn.deallocate(xqkv_fused) + # TODO: get this from model_config + sdpa_cfg = ttnn.SDPAProgramConfig( + compute_with_storage_grid_size=(8, 8), q_chunk_size=128, k_chunk_size=128, exp_approx_mode=False + ) + attn_output_1QSD = ttnn.transformer.scaled_dot_product_attention( + q_heads_1QSD, + k_heads_1KSD, + v_heads_1VSD, + is_causal=False, + scale=self.scale, + program_config=sdpa_cfg, + compute_kernel_config=self.compute_kernel_config_sdpa, + ) + # deallocate keys and values + ttnn.deallocate(q_heads_1QSD) + ttnn.deallocate(k_heads_1KSD) + ttnn.deallocate(v_heads_1VSD) + + ### + # Output matmul + ### + attn_output_11SH = ttnn.experimental.nlp_concat_heads( + 
attn_output_1QSD, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + ttnn.deallocate(attn_output_1QSD) + + # reshaping long sequence to matmul fit on device + if seq_len > MAX_MM_SEQ_LEN: + attn_output_11SH = ttnn.reshape(attn_output_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) + + output_11SH = ttnn.linear( + attn_output_11SH, + self.wo, + compute_kernel_config=self.compute_kernel_config_hifi4, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + program_config=self.model_config["IMAGE_ATTN_OUT_PROGCFG"](seq_len, MAX_MM_SEQ_LEN), + ) + if seq_len > MAX_MM_SEQ_LEN: + output_11SH = ttnn.reshape(output_11SH, [1, 1, seq_len, -1]) + ttnn.deallocate(attn_output_11SH) + + # All reduce + if self.num_devices > 1: # replace with reduce_scatter and all_gather + dense_out_gathered = ttnn.all_gather(output_11SH, dim=1, num_links=1, topology=ttnn.Topology.Linear) + output_11SH.deallocate(True) + dense_out_reduced = ttnn.experimental.fast_reduce_nc( + dense_out_gathered, dims=[1], output=None, compute_kernel_config=None + ) + # slicing the required sequence length + dense_out_reduced = dense_out_reduced[:, :, : dense_out_gathered.shape[-2], :] + return dense_out_reduced + else: + return output_11SH diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_conv2d.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_conv2d.py new file mode 100644 index 000000000000..0b16dca7fbcf --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_conv2d.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +class TtMistralConv2dPatch(LightweightModule): + """Conv2D Patching layer. + Column parallel over unfolded input. + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. 
+ Input: (bsz, in_channels, width, height) + Output: (bsz, num_tokens, out_channels) + """ + + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + dtype, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + bias, + ): + super().__init__() + + self.mesh_device = mesh_device + self.num_devices = self.mesh_device.get_num_devices() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + + self.bias = ( + ttnn.as_tensor( + torch.reshape(state_dict[f"{state_dict_prefix}_linear.bias"], (1, -1)), + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + if bias + else None + ) + + self._unfold = torch.nn.Unfold(kernel_size=self.kernel_size, stride=self.stride) + + weight = state_dict[f"{state_dict_prefix}weight"] + if weight.ndim == 4: + weight = weight.reshape(out_channels, -1).T + # pad_len = nearest_32(weight.shape[-1]) - weight.shape[-1] + # padding = torch.zeros(self.out_channels, pad_len, dtype=weight.dtype) + # padded_weight = torch.cat([weight, padding], dim=-1) + # padded_weight = padded_weight.permute(1, 0).reshape(1, 1, -1, self.out_channels) + + self._linear_weight = ttnn.as_tensor( + weight, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + self.compute_kernel_config = ttnn.init_device_compute_kernel_config( + mesh_device.arch(), + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + def forward(self, x: torch.Tensor): + x = self._unfold(x) + x = x.permute(0, 2, 1) + + # Need to pad the last dimension of x to be a multiple of a tile + # pad_len = nearest_32(x.shape[-1]) - x.shape[-1] + # padding = torch.zeros((x.shape[0], x.shape[1], pad_len), dtype=x.dtype, device=x.device) + # x = torch.cat([x, padding], dim=-1) + + x = ttnn.as_tensor( + x, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + out = ttnn.linear( + x, + self._linear_weight, + bias=self.bias, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + core_grid=ttnn.CoreGrid(y=8, x=8), + ) + + return out diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_mlp.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_mlp.py new file mode 100644 index 000000000000..8c8612d937d4 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_mlp.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +class MistralTTVisionMLP(LightweightModule): + def __init__( + self, + mesh_device, + args, + state_dict, + weight_cache_path, + dtype, + state_dict_prefix=None, + ): + super().__init__() + + self.mesh_device = mesh_device + self.args = args + self.state_dict = state_dict + self.dim = args.dim + + def get_weight(name): + return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) + + def get_bias(name): + return state_dict[f"{state_dict_prefix}{name}.bias"] + + def cache_name(name): + if args.dummy_weights: + return None + return weight_cache_path / f"{state_dict_prefix}.{name}" + + def as_tensor(name, dtype, is_bias=False): + tensor_data = get_bias(name) if is_bias else get_weight(name) + return ttnn.as_tensor( + tensor_data, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + # cache_file_name=cache_name(name), + ) + + # Weights and Biases + self.w1 = as_tensor("w1", dtype) + self.b1 = as_tensor("w1", ttnn.bfloat16, is_bias=False) + + self.w3 = as_tensor("w3", dtype) + self.b3 = as_tensor("w3", ttnn.bfloat16, is_bias=False) + + self.w2 = as_tensor("w2", dtype) + self.b2 = as_tensor("w2", ttnn.bfloat16, is_bias=False) + + self.compute_kernel_config_hifi4 = self.args.compute_kernel_config_hifi4 + + def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: + """ + Qwen HF MLP reference: + output = down_proj(act_fn(gate_proj(x)) * up_proj(x)) + Mapping: + w1 -> gate_proj + w3 -> up_proj + w2 -> down_proj + """ + + # Linear with SILU activation + w1_out = ttnn.linear( + x, + self.w1, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + activation="silu", + compute_kernel_config=self.compute_kernel_config_hifi4, + ) + + w3_out = ttnn.linear( + x, + self.w3, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config_hifi4, + ) + + # Element-wise multiply + w2_in = ttnn.mul(w1_out, w3_out, dtype=ttnn.bfloat16) + + # Final projection + w2_out = ttnn.linear( + w2_in, + self.w2, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config_hifi4, + ) + + ttnn.deallocate(w1_out) + ttnn.deallocate(w3_out) + ttnn.deallocate(w2_in) + return w2_out diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_mmp.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_mmp.py new file mode 100644 index 000000000000..d6cf6e3be6b9 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_mmp.py @@ -0,0 +1,169 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm +from ttnn import ConcatMeshToTensor + + +class TTMistral3PatchMerger(LightweightModule): + def __init__( + self, + mesh_device, + args, + state_dict, + state_dict_prefix, + weight_cache_path=None, + dtype=ttnn.bfloat16, + ): + super().__init__() + self.device = mesh_device + hidden_size = args.vision_dim + self.spatial_merge_size = 2 # TODO Handle in Model_config spatial_merge_size + self.patch_size = args.vision_patch_size + self.args = args + + def get_weight(name): + return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) + + def get_bias(name): + return state_dict[f"{state_dict_prefix}{name}.bias"] + + def cache_name(name): + if args.dummy_weights: + return None + return weight_cache_path / f"{state_dict_prefix}.{name}" + + def as_tensor(name, dtype, is_bias=False): + tensor_data = get_bias(name) if is_bias else get_weight(name) + return ttnn.as_tensor( + tensor_data, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + + ) + + self.merging_weights = as_tensor("merging_layer", dtype) + self.merging_bias = as_tensor("merging_layer", ttnn.bfloat16, is_bias=False) + + def forward(self, image_features: ttnn.Tensor, image_sizes) -> ttnn.Tensor: + image_sizes = [ + (image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes + ] + + tokens_per_image = [h * w for h, w in image_sizes] + d = image_features.shape[-1] + + permuted_tensor = [] + for image_index, image_tokens in enumerate(ttnn.split(image_features, tokens_per_image, dim=0)): + # Reshape image_tokens into a 2D grid + h, w = image_sizes[image_index] + + image_tokens = ttnn.to_layout(image_tokens, ttnn.ROW_MAJOR_LAYOUT) + + image_grid = ttnn.view(image_tokens, (h, w, d)) + # Permute the grid to have channels last + image_grid = ttnn.permute(image_grid, (2, 0, 1)) # Channels first + image_grid = ttnn.unsqueeze(image_grid, dim=0) # Add batch dimension + # Reshape the grid to merge patches + if self.args.num_devices > 1: + image_grid_torch = ttnn.to_torch(image_grid, mesh_composer=ConcatMeshToTensor(self.device, dim=0)) + image_grid_torch = image_grid_torch[0].unsqueeze(0) # shape: [1, 1024, 30, 44] + image_grid_torch = image_grid_torch.to(dtype=torch.bfloat16) + else: + image_grid_torch = ttnn.to_torch(image_grid).to(dtype=torch.bfloat16) + + grid = torch.nn.functional.unfold( + image_grid_torch, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size + ) + + grid = ttnn.from_torch(grid, device=self.device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + + grid = ttnn.view(grid, (d * self.spatial_merge_size**2, -1)) + grid = ttnn.transpose(grid, 0, 1) # Transpose to have features first + + permuted_tensor.append(grid) + + image_features = ttnn.concat(permuted_tensor, dim=0) + # Apply merging layer + image_features = ttnn.linear( + image_features, self.merging_weights, dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + + return image_features + + +class TTMistral3MultiModalProjector(LightweightModule): + def __init__(self, mesh_device, args, state_dict, state_dict_prefix, dtype, eps, weight_cache_path=None): + super().__init__() + + self.norm = RMSNorm( + device=mesh_device, + dim=args.vision_dim, + state_dict=state_dict, + 
state_dict_prefix=state_dict_prefix, + weight_key="norm", + weight_dtype=dtype, + eps=eps, + ) + + self.patch_merger = TTMistral3PatchMerger( + mesh_device=mesh_device, + args=args, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}patch_merger.", + ) + + def get_weight(name): + return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) + + def get_bias(name): + return state_dict[f"{state_dict_prefix}{name}.bias"] + + def cache_name(name): + if args.dummy_weights: + return None + return weight_cache_path / f"{state_dict_prefix}.{name}" + + def as_tensor(name, dtype, is_bias=False): + tensor_data = get_bias(name) if is_bias else get_weight(name) + return ttnn.as_tensor( + tensor_data, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + + ) + + self.linear_1_weight = as_tensor("linear_1", dtype) + self.linear_1_bias = as_tensor("linear_1", ttnn.bfloat16, is_bias=False) + + self.linear_2_weight = as_tensor("linear_2", dtype) + self.linear_2_bias = as_tensor("linear_2", ttnn.bfloat16, is_bias=False) + + def forward(self, image_features: ttnn.Tensor, image_sizes): + image_features = self.norm(image_features, mode="decode") + image_features = self.patch_merger(image_features, image_sizes) + + hidden_states = ttnn.linear( + image_features, + self.linear_1_weight, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + activation="gelu", # Using GELU activation as per Mistral 3 model + ) + + hidden_states = ttnn.linear( + hidden_states, self.linear_2_weight, dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + + return hidden_states diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_model.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_model.py new file mode 100644 index 000000000000..60a920f8fcca --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_model.py @@ -0,0 +1,44 @@ +""" +This is the end-to-end architecture of the Mistral-24B vision model. + +It brings together all components related to visual and MultiModalProjector together. 
+""" + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.mistral_24b.mistral_vision_tower import MistralVisionTower +from models.tt_transformers.tt.multimodal.mistral_24b.vision_mmp import TTMistral3MultiModalProjector + + +class TtMistralVisionTransformer(LightweightModule): + def __init__(self, mesh_device, state_dict, state_dict_prefix, dtype, model_args): + super().__init__() + self.state_dict = state_dict + self.mesh_device = mesh_device + + self.vision_tower = MistralVisionTower( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + dtype=dtype, + configuration=model_args, + ) + + self.mmp = TTMistral3MultiModalProjector( + mesh_device=mesh_device, + args=model_args, + state_dict=state_dict, + state_dict_prefix="multi_modal_projector.", + dtype=dtype, + eps=1e-05, # layer_norm_eps + ) + + def forward(self, input_tensor, image_sizes=None): + """ + input_tensor shape: (B, C, H, W) + """ + + x = self.vision_tower(input_tensor, image_sizes=image_sizes) + x = ttnn.squeeze(ttnn.squeeze(x, 0), 0) + x = self.mmp(x, image_sizes) + return x diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_image_block.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_image_block.py new file mode 100644 index 000000000000..983a0d0891fa --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_image_block.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm +from models.tt_transformers.tt.multimodal.mistral_24b.vision_attention import ( + TtMistralImageAttention as TtLlamaImageAttention, +) +from models.tt_transformers.tt.multimodal.mistral_24b.vision_mlp import MistralTTVisionMLP as MLP + + +class TtPixtralImageTransformerBlock(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + ): + super().__init__() + self.state_dict = state_dict + self.mesh_device = mesh_device + self.num_devices = configuration.num_devices + self.hidden_size = configuration.vision_dim + + self.attention_norm = RMSNorm( + device=mesh_device, + dim=configuration.vision_dim, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="attention_norm", + weight_dtype=dtype, + is_distributed=False, + sharded_program_config=configuration.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=configuration.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], + ) + + self.attention = TtLlamaImageAttention( + mesh_device, + state_dict, + state_dict_prefix=f"{state_dict_prefix}attention.", + weight_cache_path=weight_cache_path, + dtype=dtype, + configuration=configuration, + ) + + self.ffn_norm = RMSNorm( + device=mesh_device, + dim=configuration.vision_dim, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="ffn_norm", + weight_dtype=dtype, + is_distributed=False, + sharded_program_config=configuration.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=configuration.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], + ) + + self.mlp = MLP( + mesh_device=mesh_device, + args=configuration, + state_dict=state_dict, + weight_cache_path=weight_cache_path, + 
state_dict_prefix=f"{state_dict_prefix}feed_forward.", + dtype=dtype, + ) + + def forward(self, x_input, mask=None, position_embeddings=None): + mode = "prefill" + attn_out = self.attention( + self.attention_norm(x_input, mode=mode), position_embeddings=position_embeddings, mask=mask + ) + res = ttnn.add(x_input, attn_out, use_legacy=True) + mlp_out = self.mlp(self.ffn_norm(res, mode=mode)) + out = ttnn.add(res, mlp_out) + return out diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_transformer.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_transformer.py new file mode 100644 index 000000000000..a8179f9a4dfe --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_transformer.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from tqdm import tqdm + +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.mistral_24b.vision_pixtral_image_block import TtPixtralImageTransformerBlock + + +class TtPixtralTransformer(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + layers, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + + block_key = "layers" + self.resblocks = [ + TtPixtralImageTransformerBlock( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}{block_key}.{i}.", + weight_cache_path=weight_cache_path, + dtype=dtype, + configuration=configuration, + ) + for i in tqdm(range(layers), desc=f"Loading vision transformer layers") + ] + + def forward(self, x, return_intermediate=None, mask=None, position_embeddings=None): + """ + Different from reference impl in that if return_intermediates, it returns + a list of intermediate tensors rather than a stack of intermediates. + Outer code will have to be aware and handle this correctly. + """ + out = [] + for idx, r in enumerate(self.resblocks): + if return_intermediate is not None and idx in return_intermediate: + out.append(x) + x = r(x, mask=mask, position_embeddings=position_embeddings) + if return_intermediate is not None: + return x, out + return x diff --git a/models/tt_transformers/tt/multimodal/mistral_24b/vision_rope.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_rope.py new file mode 100644 index 000000000000..3fdd45caea8f --- /dev/null +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_rope.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.common import precompute_vision_freqs +from ttnn import ReplicateTensorToMesh + + +def compute_gather_cos_sin(dhead, max_patches_per_side, theta, scale_factor, orig_context_len, position_ids): + cos, sin = precompute_vision_freqs(dhead, max_patches_per_side, theta, scale_factor, orig_context_len) + return cos, sin + + +class VisionRotarySetup(LightweightModule): + def __init__( + self, + device, + batch_size: int, + head_dim: int, + image_size: int, + patch_size: int, + max_seq_len: int, + rope_theta: float, + scale_factor: float, # use None to disable rope scaling + orig_context_len: int, # only used if scaling enabled + datatype=ttnn.bfloat16, + ): + super().__init__() + + self.batch_size = batch_size + self.head_dim = head_dim + self.device = device + self.is_mesh_device = isinstance(device, ttnn._ttnn.multi_device.MeshDevice) + self.num_devices = device.get_num_devices() if self.is_mesh_device else 1 + if self.num_devices == 32: + self.batch_size_per_device_group = max(self.batch_size // list(device.shape)[1], 1) + else: + self.batch_size_per_device_group = self.batch_size + self.core_grid = device.compute_with_storage_grid_size() + + max_patches_per_side = image_size // patch_size + + # Generate the cos/sin matrices needed for ttnn.embedding op + cos_matrix, sin_matrix = compute_gather_cos_sin( + dhead=head_dim, + max_patches_per_side=max_patches_per_side, + theta=rope_theta, + scale_factor=scale_factor, + orig_context_len=orig_context_len, + position_ids=torch.arange(max_seq_len), + ) + self.cos_matrix = ttnn.from_torch( + cos_matrix, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=datatype, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + self.sin_matrix = ttnn.from_torch( + sin_matrix, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=datatype, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + + def get_rot_mats(self, position_idxs, return_rot_idxs=False): + device = self.device + + + # If position_idxs is a torch tensor, get the TTNN version of it + if isinstance(position_idxs, torch.Tensor): + rot_idxs = position_idxs.unsqueeze(0) + else: + rot_idxs = position_idxs + assert len(rot_idxs.shape) == 2 and rot_idxs.shape[0] == 1, "rot_idxs must be a [1, batch] tensor" + + rot_idxs = ttnn.from_torch( + rot_idxs, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=ttnn.uint32, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + # Send the idxs to device + if rot_idxs.device != device: + rot_idxs = ttnn.to_device(rot_idxs, device, memory_config=ttnn.DRAM_MEMORY_CONFIG) + + embedding_layout = ttnn.TILE_LAYOUT + cos = ttnn.embedding(rot_idxs, self.cos_matrix, layout=embedding_layout) # [1, batch, head_dim] + sin = ttnn.embedding(rot_idxs, self.sin_matrix, layout=embedding_layout) # [1, batch, head_dim] + + if return_rot_idxs: + return [cos, sin], rot_idxs + ttnn.deallocate(rot_idxs) + return [cos, sin] diff --git a/real_inputs/pixtral_transformer_inputs/people.jpg b/real_inputs/pixtral_transformer_inputs/people.jpg new file mode 100644 index 000000000000..16dad8dcbf18 Binary files /dev/null and b/real_inputs/pixtral_transformer_inputs/people.jpg differ
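For reference, the text/vision fusion implemented in MistralTransformer.prepare_inputs_prefill above follows the usual HF multimodal recipe: positions whose token id equals the image-token id are overwritten with the flattened vision features before the embeddings are handed to the text model. A minimal host-side torch sketch of that step (function name and tensor shapes are illustrative; the image-token id 10 mirrors the hardcoded value in the code above):

    import torch

    def merge_vision_tokens(tokens_embd, input_ids, image_features, image_token_id=10):
        # tokens_embd:    [seq_len, hidden]       text embeddings from the embedding table
        # input_ids:      [1, seq_len]            token ids; image positions hold image_token_id
        # image_features: [n_img_tokens, hidden]  output of the vision tower + projector
        mask = (input_ids == image_token_id).squeeze(0).unsqueeze(-1).expand_as(tokens_embd)
        # Overwrite masked positions, in order, with the vision features
        return tokens_embd.masked_scatter(mask, image_features.to(tokens_embd.dtype))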