diff --git a/models/common/rmsnorm.py b/models/common/rmsnorm.py index 35d7ec55121e..4c15554d80a4 100644 --- a/models/common/rmsnorm.py +++ b/models/common/rmsnorm.py @@ -52,6 +52,7 @@ def __init__( output_mem_config=None, ccl_topology=ttnn.Topology.Ring, tt_ccl=None, + simplified_rms=False, ): super().__init__() self.device = device @@ -114,13 +115,21 @@ def __init__( fp32_dest_acc_en=True, packer_l1_acc=True, ) + self.simplified_rms = simplified_rms def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) -> ttnn.Tensor: # If input is sharded do sharded RMSNorm and optionally return sharded output program_config = self.sharded_program_config if in_sharded else None memory_config = self.sharded_output_config if out_sharded else None distributed = self.is_distributed and self.is_distributed(mode) - norm = self._distributed_rmsnorm if distributed else ttnn.rms_norm + norm = ( + self._simplified_rmsnorm + if self.simplified_rms + else self._distributed_rmsnorm + if distributed + else ttnn.rms_norm + ) + weight = self.weight_distributed if distributed else self.weight if in_sharded: @@ -142,6 +151,25 @@ def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) -> else: return x + def _simplified_rmsnorm( + self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None + ): + inp = ttnn.sharded_to_interleaved(inp, ttnn.DRAM_MEMORY_CONFIG) + xnorm = ttnn.pow(inp, 2) + xnorm = ttnn.mean(xnorm, dim=-1, keepdim=True) + xnorm = ttnn.rsqrt(xnorm + epsilon) + xnorm = ttnn.multiply(inp, xnorm) + weight = ttnn.reshape(weight, [1, 1, -1]) + output = ttnn.multiply(xnorm, (weight), use_legacy=False) + + if memory_config is not None: + output = ttnn.to_memory_config(output, memory_config) + + ttnn.deallocate(xnorm) + ttnn.deallocate(weight) + + return output + def _distributed_rmsnorm( self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None ): diff --git a/models/experimental/qwen25_vl/tests/test_attention.py b/models/experimental/qwen25_vl/tests/test_attention.py new file mode 100644 index 000000000000..c01f78924f4e --- /dev/null +++ b/models/experimental/qwen25_vl/tests/test_attention.py @@ -0,0 +1,128 @@ +"""Test for Qwen 2.5 VL Vision Attention""" + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs + +from models.experimental.qwen25_vl.tt.attention import TtQwen2_5_VLVisionSdpaAttention +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_attention_inference(batch, num_chunks, mesh_device, reset_seeds): + dtype = ttnn.bfloat16 + pcc_required = 0.99 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "visual.blocks.0.attn." 
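+ # e.g. "visual.blocks.0.attn.qkv.weight" -> "qkv.weight"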
+ partial_state_dict = {
+ k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix))
+ }
+
+ reference_model = model_args.reference_vision_attention()
+ reference_model.load_state_dict(partial_state_dict)
+ reference_model.eval()
+
+ tt_model = TtQwen2_5_VLVisionSdpaAttention(
+ mesh_device,
+ state_dict,
+ state_dict_prefix=first_layer_prefix,
+ # weight_cache_path=model_args.weight_cache_path(dtype),
+ dtype=dtype,
+ configuration=model_args,
+ )
+
+ seq_len = 4096
+ hidden_dim = 1280
+ num_heads = 16
+ head_dim = hidden_dim // num_heads # 80
+
+ # Step 1: PyTorch input (no batch dim)
+ pt_attention_input = torch.randn(seq_len, hidden_dim)
+ cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32)
+
+ # Step 2: precompute cos/sin
+ cos, sin = precompute_rope_cos_sin(seq_len, head_dim)
+
+ # Step 3: run the PyTorch reference
+ reference_output = reference_model(
+ pt_attention_input, cu_seqlens, rotary_pos_emb=None, position_embeddings=(cos, sin)
+ )
+
+ # Step 4: prepare the TT input
+ tt_attention_input = model_args.prepare_residual_tensor_prefill(
+ pt_attention_input.unsqueeze(0), force_replicated=True
+ )
+
+ # Step 5: move cos/sin to device
+ cos_tensor = ttnn.from_torch(cos, device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
+ sin_tensor = ttnn.from_torch(sin, device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
+
+ # Step 6: run the TT model
+ tt_out = tt_model(tt_attention_input, cu_seqlens, position_embeddings=(cos_tensor, sin_tensor))
+
+ tt_output_torch = ttnn.to_torch(tt_out, device=mesh_device).squeeze(0)
+
+ passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required)
+
+ logger.info(comp_allclose(reference_output, tt_output_torch))
+ logger.info(f"PCC: {pcc_message}")
+
+ assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
+
+
+def precompute_rope_cos_sin(seq_len: int, dim: int, theta: float = 10000.0):
+ """
+ Precompute RoPE cos/sin tensors.
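+
+ Uses the standard RoPE frequencies inv_freq[i] = theta ** (-i / half_dim) for
+ i in [0, half_dim); the half-dim block is duplicated so cos/sin each span `dim`.
+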
+ Args: + seq_len: sequence length (number of tokens) + dim: hidden size (usually head_dim, not full hidden_size) + theta: RoPE theta parameter (default 10000) + Returns: + cos, sin: [seq_len, dim] each + """ + # Build the rope frequencies + half_dim = dim // 2 + freq_seq = torch.arange(half_dim, dtype=torch.float32) + inv_freq = 1.0 / (theta ** (freq_seq / half_dim)) + + # positions: [seq_len] + positions = torch.arange(seq_len, dtype=torch.float32) + + # Outer product: [seq_len, half_dim] + sinusoid_inp = torch.outer(positions, inv_freq) + + # Concatenate for complex dim + sin = torch.sin(torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)) + cos = torch.cos(torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)) + + return cos, sin diff --git a/models/experimental/qwen25_vl/tests/test_e2e.py b/models/experimental/qwen25_vl/tests/test_e2e.py new file mode 100644 index 000000000000..63a9b0ad0b00 --- /dev/null +++ b/models/experimental/qwen25_vl/tests/test_e2e.py @@ -0,0 +1,408 @@ +"""Test for Qwen 2.5 VL End-to-End Vision-Text Pipeline""" + +import torch +import pytest +from loguru import logger +import os +import ttnn +from models.tt_transformers.tt.common import ( + sample_host, + PagedAttentionConfig, + preprocess_inputs_prefill, +) +from models.tt_transformers.tt.model_config import DecodersPrecision + +from models.experimental.qwen25_vl.tt.model import Qwen25VLTransformer as Transformer + +from models.tt_transformers.tt.generator import Generator + +from models.experimental.qwen25_vl.tt.vision_model import TtQwen2_5_VisionTransformerPretrainedModel +from models.utility_functions import skip_for_grayskull, skip_for_blackhole + +from models.tt_transformers.tt.model_config import ModelArgs +from transformers import AutoProcessor + +import re + + +def parse_chat_output(text): + """Parse chat output format from generated text.""" + pattern = r"<\|(?Puser|assistant)\|>\s*(?P.*?)(?=<\|(?:user|assistant|end)\|>|$)" + matches = re.finditer(pattern, text, re.DOTALL) + return [(match.group("role"), match.group("message").strip()) for match in matches] + + +def display_chat(logger, conversation): + """Display chat conversation in formatted output.""" + for role, message in conversation: + if role == "user": + logger.info(f"👤 User: {message}") + elif role == "assistant": + logger.info(f"🤖 Assistant: {message}") + + +def setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations): + """Setup model arguments for vision-enabled model (Single Responsibility).""" + instruct = True if weights == "instruct" else False + + model_args = ModelArgs( + mesh_device=mesh_device, + instruct=instruct, + optimizations=optimizations, + max_seq_len=max_seq_len, + max_batch_size=batch_size, + ) + + return model_args, instruct + + +def setup_vision_prompts_and_tokenizer(model_args, instruct): + """Setup multimodal prompts and tokenizer for vision-enabled model.""" + messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://raw.githubusercontent.com/yavuzceliker/sample-images/refs/heads/main/images/image-1.jpg", + }, + {"type": "text", "text": "Describe this image in detail in 1000 words."}, + ], + } + ] + + tokenizer = model_args.tokenizer + return messages, tokenizer + + +def process_real_vision_inputs(messages, model_args): + """Process real image inputs using AutoProcessor (Interface Segregation).""" + processor = AutoProcessor.from_pretrained(os.getenv("HF_MODEL")) + + encoded = processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, 
return_dict=True, return_tensors="pt" + ).to("cpu", dtype=torch.bfloat16) + + input_ids = encoded["input_ids"] + pixel_values = encoded["pixel_values"] + attention_mask = encoded["attention_mask"] + image_grid_thw = encoded["image_grid_thw"] + + # logger.info(f"Processed vision inputs - input_ids: {input_ids.shape}, pixel_values: {pixel_values.shape}") + + return { + "input_ids": input_ids, + "pixel_values": pixel_values, + "attention_mask": attention_mask, + "image_grid_thw": image_grid_thw, + "processor": processor, + } + + +def load_separate_models_like_test_end2end(model_args, mesh_device, dtype, paged_attention, page_params): + """Load separate vision and text models following test_end2end.py pattern.""" + state_dict = model_args.load_state_dict() + + vision_prefix = "visual." + + # Setup paged attention config (exactly like test_end2end.py) + paged_attention_config = None + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + + # Load vision model (exactly like test_end2end.py) + vision_model = TtQwen2_5_VisionTransformerPretrainedModel( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=vision_prefix, + dtype=dtype, + model_args=model_args, + weight_cache_path=model_args.weight_cache_path(dtype), + layers=model_args.vision_n_layers, + ) + # Load text model (exactly like test_end2end.py) + + text_model = Transformer( + args=model_args, + mesh_device=mesh_device, + dtype=dtype, + state_dict=state_dict, + weight_cache_path=model_args.weight_cache_path(dtype), + paged_attention_config=paged_attention_config, + ) + + logger.info("Separate vision and text models loaded like test_end2end.py") + return vision_model, text_model + + +def run_generation_exactly_like_test_end2end( + vision_model, text_model, processed_inputs, model_args, page_table=None, paged_attention_config=None, max_gen_len=20 +): + """Run generation following the EXACT pattern from test_end2end.py.""" + input_ids = processed_inputs["input_ids"] + pixel_values = processed_inputs["pixel_values"] + + logger.info("Running generation exactly like test_end2end.py...") + + logger.info("Running Vision Model...") + + generator = Generator([text_model], [model_args], vision_model.mesh_device, tokenizer=model_args.tokenizer) + + tt_kv_cache = [[l.attention.layer_past for l in text_model.layers]] if paged_attention_config else None + + input_tokens_prefill = input_ids + batch_size = input_tokens_prefill.shape[0] + + prompt_text = model_args.tokenizer.decode(input_ids[0].tolist()) + input_prompts = [prompt_text] + + ( + input_tokens_prefill_pt, + encoded_prompts, + decoding_pos, + prefill_lens, + ) = preprocess_inputs_prefill( + input_prompts, + model_args.tokenizer, + [model_args], + instruct=True, + max_generated_tokens=max_gen_len, + max_prefill_len=8192, + ) + + input_tokens_prefill_pt = torch.stack(input_tokens_prefill_pt).view(batch_size, -1) + + logger.info("Running prefill...") + logits = generator.prefill_forward_text( + input_tokens_prefill_pt, + page_table=page_table, + kv_cache=tt_kv_cache, + prompt_lens=decoding_pos, + vision_model=vision_model, + processed_inputs=processed_inputs, + ) + + prefilled_token = torch.argmax(logits, dim=-1) + logger.info(f"Prefilled token: {prefilled_token}") + + all_outputs = [encoded_prompts[0][: prefill_lens[0]]] + all_outputs[0].append(int(prefilled_token[0].item())) + + current_pos = torch.tensor([decoding_pos[0]]) + out_tok = prefilled_token + 
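+ # Hard upper bound on decode steps (this overrides the max_gen_len argument);
+ # the loop below exits early on an end token (151645 / 151643) or after five
+ # identical tokens in a row.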
generation_length = 2000 + + results = [] + + logger.info("Starting decode loop...") + for iteration in range(generation_length): + logger.info(f"[Text] Decoding token {iteration}, current_pos: {current_pos.item()}") + + logits = generator.decode_forward_text( + out_tok, + current_pos, + enable_trace=False, + page_table=page_table, + kv_cache=tt_kv_cache, + ) + + _, out_tok = sample_host( + logits, + temperature=0, + top_p=0.9, + ) + + token_id = out_tok[0].item() + decoded_token = model_args.tokenizer.decode([token_id]) + logger.info(f"Generated token {iteration}: ID={token_id}, text='{decoded_token}'") + + # Create result object + result = type("TokenResult", (), {"token": token_id, "text": decoded_token})() + + results.append(result) + + all_outputs[0].append(token_id) + current_pos += 1 + + if token_id == 151645 or token_id == 151643: + logger.warning("Reached End token") + break + + # Early stopping (exactly like test_end2end.py) + if len(all_outputs[0]) >= 5 and all(t == all_outputs[0][-1] for t in all_outputs[0][-5:]): + logger.warning(f"Detected exact repetition of token {all_outputs[0][-1]} five times in a row. Stopping.") + break + + # Final response (exactly like test_end2end.py) + response = model_args.tokenizer.decode(all_outputs[0], skip_special_tokens=True) + logger.info(f"📝 Final Generated Response:\n{response}") + logger.info(f"📝 Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") + chat = parse_chat_output(response) + display_chat(logger, chat) + + logger.info(f"Generated {len(results)} tokens successfully") + return results + + +def validate_e2e_outputs(results, expected_min_tokens=1): + """Validate end-to-end pipeline outputs.""" + if not results: + logger.error("No results generated from E2E pipeline") + return False + + if len(results) < expected_min_tokens: + logger.warning(f"Generated only {len(results)} tokens, expected at least {expected_min_tokens}") + return False + + # Check if tokens are valid + for result in results: + if not hasattr(result, "token") or not hasattr(result, "text"): + logger.error("Invalid result format") + return False + + logger.info("E2E pipeline validation passed") + return True + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@skip_for_blackhole("Failing on DRAM harvested P100a, see #21419") +@pytest.mark.timeout(1800) +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize( + "weights, layers", + [ + ("instruct", None), + ], + ids=["full"], +) +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False, + ), + ids=( + "paged_attention", + # "default_attention", + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (2048,), # Use smaller seq_len like test_end2end.py to avoid memory issues +) +@pytest.mark.parametrize( + "optimizations", + [ + lambda model_args: DecodersPrecision.accuracy(model_args.n_layers, model_args.model_name), + ], + ids=["accuracy"], +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +# @pytest.mark.parametrize("device_params", [{"l1_small_size": 1584864, "trace_region_size": 0}], indirect=True) +def test_e2e_vision_text_pipeline( + weights, + layers, + max_seq_len, + batch_size, + paged_attention, + page_params, + optimizations, + mesh_device, + 
reset_seeds, + request, + device_params, +): + """Test end-to-end vision-text pipeline using proper Generator methods.""" + logger.info("Starting E2E vision-text pipeline test") + + # Use bfloat8_b like test_end2end.py for better memory efficiency + dtype = ttnn.bfloat8_b + + # Setup vision-enabled model configuration + model_args, instruct = setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations) + + if layers is not None: + model_args.n_layers = layers + + # Setup vision prompts and tokenizer + messages, tokenizer = setup_vision_prompts_and_tokenizer(model_args, instruct) + + # Process real vision inputs from images + processed_inputs = process_real_vision_inputs(messages, model_args) + + # Load separate models following test_end2end.py pattern + logger.info("Loading separate vision and text models like test_end2end.py...") + vision_model, text_model = load_separate_models_like_test_end2end( + model_args, mesh_device, dtype, paged_attention, page_params + ) + + # Setup page table for paged attention (exactly like test_end2end.py) + page_table_tt = None + paged_attention_config = None + + # Prepare page table for paged attention (exactly like test_end2end.py) + page_table = None + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, -2) if batch_size > 1 else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + # Run generation following EXACT test_end2end.py pattern + logger.info("Running generation following EXACT test_end2end.py pattern...") + results = run_generation_exactly_like_test_end2end( + vision_model, text_model, processed_inputs, model_args, page_table, paged_attention_config, max_gen_len=10 + ) + + # Validate results + validation_passed = validate_e2e_outputs(results, expected_min_tokens=1) + + # Final validation + if validation_passed and len(results) > 0: + logger.info("✅ E2E vision-text pipeline test PASSED!") + logger.info(f"Successfully generated {len(results)} tokens") + + # Log generated tokens for debugging + for i, result in enumerate(results[:5]): + logger.info(f"Token {i}: {result.token} -> '{result.text}'") + else: + logger.error("❌ E2E pipeline test failed") + assert False, f"E2E pipeline failed - generated {len(results)} tokens, validation: {validation_passed}" diff --git a/models/experimental/qwen25_vl/tests/test_image_merger.py b/models/experimental/qwen25_vl/tests/test_image_merger.py new file mode 100644 index 000000000000..442165307a72 --- /dev/null +++ b/models/experimental/qwen25_vl/tests/test_image_merger.py @@ -0,0 +1,103 @@ +""""Test for Qwen 2.5 VL Patch Merger""" + +from loguru import logger + +import torch +import pytest +import os +import ttnn +from models.experimental.qwen25_vl.tt.patch_merger import TTQwen2_5_VLPatchMerger + +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +from models.tt_transformers.tt.model_config import 
ModelArgs + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_patch_merger_inference(seq_len, batch_size, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + reference_model = tt_model_args.reference_vision_qwen_merger() # Qwen Patch merger + first_layer_prefix = "visual.merger." + + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model.load_state_dict(partial_state_dict) + + tt_model = TTQwen2_5_VLPatchMerger( + device=device, + dim=5120, + state_dict=state_dict, + weight_key=first_layer_prefix, + layer_num=None, + state_dict_prefix="", + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps=1e-06, + dims=3584, + context_dim=1280, + spatial_merge_size=2, + mode=mode, + ) + + input = torch.rand(1, 4, 1280) + reference_output = reference_model(input) + tt_input = ttnn.from_torch( + input, + device=device, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + tt_input = ttnn.reshape(tt_input, [1, 4, 1280]) + + tt_output = tt_model(tt_input) + + # DistributedNorm outputs are replicated across devices + tt_output_torch = ttnn.to_torch( + tt_output, + ) + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("Merger Passed!") + else: + logger.warning("Merger Failed!") + + assert passing, f"Merger output does not meet PCC requirement {0.99}." 
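For reference, the merger this test validates is functionally a small torch pipeline. A minimal sketch, assuming torch.nn.RMSNorm (PyTorch >= 2.4); the class and argument names here are ours, not the HF module's:

import torch.nn as nn

class PatchMergerSketch(nn.Module):
    # RMSNorm each context_dim-wide patch token, then fold spatial_merge_size**2
    # neighboring tokens into one hidden vector and project it through an MLP.
    def __init__(self, dim=3584, context_dim=1280, spatial_merge_size=2, eps=1e-6):
        super().__init__()
        self.hidden_size = context_dim * spatial_merge_size**2  # 5120
        self.ln_q = nn.RMSNorm(context_dim, eps=eps)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),  # feed_forward.0
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),  # feed_forward.2
        )

    def forward(self, x):  # x: [n_tokens, context_dim]
        return self.mlp(self.ln_q(x).view(-1, self.hidden_size))
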
diff --git a/models/experimental/qwen25_vl/tests/test_image_patch_embed.py b/models/experimental/qwen25_vl/tests/test_image_patch_embed.py new file mode 100644 index 000000000000..c08423637b2f --- /dev/null +++ b/models/experimental/qwen25_vl/tests/test_image_patch_embed.py @@ -0,0 +1,95 @@ +""""Test for Qwen 2.5 VL Patch Embed""" + +from loguru import logger + +import torch +import pytest +import os +import ttnn +from models.experimental.qwen25_vl.tt.patch_embed import TTQwen2_5_VisionPatchEmbed + +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +from models.tt_transformers.tt.model_config import ModelArgs + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_embed_inference(seq_len, batch_size, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + reference_model = tt_model_args.reference_vision_qwen_patch_embed() # Qwen Patch embed + first_layer_prefix = "visual.patch_embed." + + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model.load_state_dict(partial_state_dict) + + tt_model = TTQwen2_5_VisionPatchEmbed( + device=device, + patch_size=14, + temporal_patch_size=2, + in_channels=3, + embed_dim=1280, + state_dict=state_dict, + weight_key=first_layer_prefix, + layer_num=None, + state_dict_prefix="", + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + mode=mode, + ) + + input = torch.rand(1, 1, 1380, 1176) + reference_output = reference_model(input) + + tt_input = ttnn.from_torch( + input, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + + tt_output = tt_model(tt_input) + + tt_output_torch = ttnn.to_torch( + tt_output, + ) + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("Patch embed Passed!") + else: + logger.warning("Patch embed Failed!") + + assert passing, f"Patch embed output does not meet PCC requirement {0.99}." 
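A note on the matmul trick exercised by the patch-embed test above (and implemented in tt/patch_embed.py later in this diff): when a convolution's stride equals its kernel size, the receptive fields are disjoint, so Conv3d collapses to a matmul over flattened patches. A minimal torch check under the shapes this test uses (1176 = 3 * 2 * 14 * 14):

import torch

# With stride == kernel size, every output element of the conv sees a disjoint
# input block, so Conv3d reduces to a matmul over flattened patches.
embed_dim, C, T, P = 1280, 3, 2, 14
conv = torch.nn.Conv3d(C, embed_dim, kernel_size=(T, P, P), stride=(T, P, P), bias=False)
x = torch.randn(1, C, T, P, P)  # exactly one patch of input

out_conv = conv(x).flatten()  # [embed_dim]
out_matmul = (x.reshape(1, -1) @ conv.weight.view(embed_dim, -1).T).flatten()
assert torch.allclose(out_conv, out_matmul, atol=1e-4)
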
diff --git a/models/experimental/qwen25_vl/tests/test_mlp.py b/models/experimental/qwen25_vl/tests/test_mlp.py new file mode 100644 index 000000000000..ee936435b8ec --- /dev/null +++ b/models/experimental/qwen25_vl/tests/test_mlp.py @@ -0,0 +1,82 @@ +""""Test for Qwen 2.5 VL Vision MLP""" + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs + +from models.experimental.qwen25_vl.tt.mlp import QwenTTVisionMLP +from models.utility_functions import comp_allclose, comp_pcc, nearest_32, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_mlp_inference(batch, num_chunks, mesh_device, reset_seeds): + dtype = ttnn.bfloat16 + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "visual.blocks.0.feed_forward." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + model_args.WEIGHTS_DTYPE = dtype + + dim = model_args.vision_dim + seq_len = nearest_32(model_args.vision_chunk_ntok) * num_chunks + reference_model = model_args.reference_vision_mlp() + reference_model.load_state_dict(partial_state_dict) + + tt_model = QwenTTVisionMLP( + mesh_device=mesh_device, + args=model_args, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + ) + + torch_input = torch.randn(1, batch, seq_len, dim) + + reference_output = reference_model(torch_input).squeeze() + + tt_input = ttnn.from_torch( + torch_input, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + ) + + tt_output = tt_model(tt_input) + + tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ + :, :1, :, : + ].squeeze() + + pcc_required = 0.99 + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
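The reference this MLP test compares against is the gated, SwiGLU-style feed-forward. A minimal torch sketch (parameter names are ours, not the HF module's):

import torch
import torch.nn.functional as F

def swiglu_mlp(x, gate_w, gate_b, up_w, up_b, down_w, down_b):
    # down_proj(SiLU(gate_proj(x)) * up_proj(x)): w1/w3/w2 in the TT module
    hidden = F.silu(x @ gate_w.T + gate_b) * (x @ up_w.T + up_b)
    return hidden @ down_w.T + down_b
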
diff --git a/models/experimental/qwen25_vl/tests/test_vision_block.py b/models/experimental/qwen25_vl/tests/test_vision_block.py new file mode 100644 index 000000000000..69bd51eb5617 --- /dev/null +++ b/models/experimental/qwen25_vl/tests/test_vision_block.py @@ -0,0 +1,116 @@ +""""Test for Qwen 2.5 VL Vision Transformer Block""" + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs + +from models.experimental.qwen25_vl.tt.vision_block import TtQwen2_5_VLVisionBlock +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_transformer_inference(batch, num_chunks, mesh_device, reset_seeds): + dtype = ttnn.bfloat16 + pcc_required = 0.99 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "visual.blocks.0." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + dim = model_args.vision_dim + + reference_model = model_args.reference_vision_block() + reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + vision_dim = model_args.vision_dim + n_heads = model_args.vision_attn_n_heads + head_dim = vision_dim // n_heads + seq_len = model_args.vision_chunk_ntok - 1 + + tt_model = TtQwen2_5_VLVisionBlock( + mesh_device, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + model_args=model_args, + dtype=dtype, + ) + + pt_attention_input = torch.randn(seq_len, vision_dim) # no batch dim + cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32) + + cos, sin = precompute_rope_cos_sin(seq_len, head_dim) + + reference_output = reference_model( + pt_attention_input, cu_seqlens, rotary_pos_emb=None, position_embeddings=(cos, sin) + ) + + tt_attention_input = model_args.prepare_residual_tensor_prefill( + pt_attention_input.unsqueeze(0), force_replicated=True + ) + + cos_tensor = ttnn.from_torch(cos, device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + sin_tensor = ttnn.from_torch(sin, device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + + tt_out = tt_model(tt_attention_input, cu_seqlens, position_embeddings=(cos_tensor, sin_tensor)) + + tt_output_torch = ttnn.to_torch(tt_out, device=mesh_device).squeeze(0).squeeze(0) + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" + + +def precompute_rope_cos_sin(seq_len: int, dim: int, theta: float = 10000.0): + """ + Precompute RoPE cos/sin tensors. 
+ Args: + seq_len: sequence length (number of tokens) + dim: hidden size (usually head_dim, not full hidden_size) + theta: RoPE theta parameter (default 10000) + Returns: + cos, sin: [seq_len, dim] each + """ + # Build the rope frequencies + half_dim = dim // 2 + freq_seq = torch.arange(half_dim, dtype=torch.float32) + inv_freq = 1.0 / (theta ** (freq_seq / half_dim)) + + # positions: [seq_len] + positions = torch.arange(seq_len, dtype=torch.float32) + + # Outer product: [seq_len, half_dim] + sinusoid_inp = torch.outer(positions, inv_freq) + + # Concatenate for complex dim + sin = torch.sin(torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)) + cos = torch.cos(torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)) + + return cos, sin diff --git a/models/experimental/qwen25_vl/tests/test_vision_model.py b/models/experimental/qwen25_vl/tests/test_vision_model.py new file mode 100644 index 000000000000..bf0bef57ca25 --- /dev/null +++ b/models/experimental/qwen25_vl/tests/test_vision_model.py @@ -0,0 +1,84 @@ +"""Test for Qwen 2.5 VL Vision Transformer Pretrained Model Inference""" + +import os + +import pytest +import torch +from loguru import logger + +import ttnn + +from models.tt_transformers.tt.model_config import ModelArgs + +from models.experimental.qwen25_vl.tt.vision_model import TtQwen2_5_VisionTransformerPretrainedModel +from models.utility_functions import comp_pcc, skip_for_grayskull, comp_allclose + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_vision_inference(batch, num_chunks, mesh_device, reset_seeds): + dtype = ttnn.bfloat16 + pcc_required = 0.99 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "visual." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + dim = model_args.vision_dim + + reference_model = model_args.reference_vision_model() + reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + n_layers = model_args.vision_n_layers + + tt_model = TtQwen2_5_VisionTransformerPretrainedModel( + mesh_device, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + model_args=model_args, + dtype=dtype, + layers=n_layers, + ) + + pt_input = torch.randn([32, 1176]) # no batch dim + grid_thw = torch.tensor([[1, 4, 8]]) + + reference_output = reference_model( + pt_input, + grid_thw, + ) + + tt_attention_input = model_args.prepare_residual_tensor_prefill(pt_input.unsqueeze(0), force_replicated=True) + tt_out = tt_model(tt_attention_input, grid_thw) + + tt_output_torch = ttnn.to_torch(tt_out, device=mesh_device) + + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
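The constants in the vision-model test above are consistent by construction: grid_thw = [[1, 4, 8]] yields 1 * 4 * 8 = 32 patch tokens, and each flattened patch holds in_channels * temporal_patch_size * patch_size**2 = 3 * 2 * 14 * 14 = 1176 values, matching the torch.randn([32, 1176]) input:

t, h, w = 1, 4, 8
in_channels, temporal_patch_size, patch_size = 3, 2, 14
assert t * h * w == 32  # rows of the test input
assert in_channels * temporal_patch_size * patch_size**2 == 1176  # columns
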
diff --git a/models/experimental/qwen25_vl/tests/test_vision_rms.py b/models/experimental/qwen25_vl/tests/test_vision_rms.py new file mode 100644 index 000000000000..ca453d09940c --- /dev/null +++ b/models/experimental/qwen25_vl/tests/test_vision_rms.py @@ -0,0 +1,113 @@ +"""Test for Qwen 2.5 VL RMSNorm Layer Inference""" + +from loguru import logger + +import torch +import pytest +import os + +import ttnn +from models.experimental.qwen25_vl.tt.rmsnorm import RMSNorm + +from models.tt_transformers.tt.distributed_norm import DistributedNorm + + +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +from models.tt_transformers.tt.model_config import ModelArgs + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_rmsnorm_inference(seq_len, batch_size, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + dim = tt_model_args.vision_dim + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + reference_model = tt_model_args.reference_vision_rms_norm() # Qwen2_5 RMSNorm + first_layer_prefix = "visual.blocks.0.norm1." + + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model.load_state_dict(partial_state_dict) + + tt_inner_norm = RMSNorm( + device=device, + dim=dim, + state_dict=state_dict, + state_dict_prefix="", + weight_key=first_layer_prefix[:-1], # Remove trailing dot + weight_dtype=dtype, + is_distributed=False, + sharded_program_config=tt_model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=tt_model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], + ) + + # Wrap it in DistributedNorm + tt_model = DistributedNorm(tt_inner_norm, tt_model_args, TG=tt_model_args.is_galaxy) + + input = torch.rand(1, 1, 1280) + + reference_output = reference_model(input) + + # DistributedNorm inputs are fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode) + tt_input = ttnn.from_torch( + input, + device=device, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, -1), mesh_shape=tt_model_args.cluster_shape), + memory_config=( + tt_model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG + ), + ) + + tt_output = tt_model(tt_input, mode=mode) + + # DistributedNorm outputs are replicated across devices + tt_output_torch = ttnn.to_torch( + tt_output, + mesh_composer=ttnn.ConcatMesh2dToTensor( + device, dims=(0, 2) if tt_model_args.is_galaxy else (2, 0), mesh_shape=tt_model_args.cluster_shape + ), + )[:1, :, :] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("rms_norm Passed!") + else: + logger.warning("rms_norm Failed!") + + assert passing, f"rms_norm output does not meet PCC requirement {0.99}." 
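Functionally, the normalization this test (and the TT RMSNorm used throughout these modules) targets is one line of torch; a minimal reference sketch:

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # y = x * rsqrt(mean(x^2, dim=-1) + eps) * weight
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
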
diff --git a/models/experimental/qwen25_vl/tt/attention.py b/models/experimental/qwen25_vl/tt/attention.py new file mode 100644 index 000000000000..a515bf35e737 --- /dev/null +++ b/models/experimental/qwen25_vl/tt/attention.py @@ -0,0 +1,141 @@ +""" +This is the vision attention implementation for Qwen-VL-7B. + +We couldn't reuse the LLaMA version from tt_transformers because it expects separate q, k, v weights, +but Qwen-VL uses fused qkv weights. So this has been rewritten to support that, +based on the original code at: +models/tt_transformers/tt/multimodal/llama_image_attention.py +""" + + +import torch +import ttnn +from models.common.lightweightmodule import LightweightModule + + +def rotate_half(x): + x1 = ttnn.slice(x, (0, 0, 0), (x.shape[0], x.shape[1], x.shape[2] // 2)) + x2 = ttnn.slice(x, (0, 0, x.shape[-1] // 2), (x.shape[0], x.shape[1], x.shape[2])) + return ttnn.concat([ttnn.mul(x2, -1, use_legacy=False), x1], dim=-1) + + +def apply_rotary_pos_emb_vision_tt(q, k, cos, sin): + cos = ttnn.unsqueeze(cos, -2) + sin = ttnn.unsqueeze(sin, -2) + + q_embed = ttnn.add(ttnn.mul(q, cos), ttnn.mul(rotate_half(q), sin)) + k_embed = ttnn.add(ttnn.mul(k, cos), ttnn.mul(rotate_half(k), sin)) + return q_embed, k_embed + + +class TtQwen2_5_VLVisionSdpaAttention(LightweightModule): + def __init__(self, mesh_device, state_dict, state_dict_prefix, dtype, configuration): + super().__init__() + + self.mesh_device = mesh_device + self.dtype = dtype + self.hidden_size = 1280 + self.num_heads = 16 + self.head_dim = self.hidden_size // self.num_heads + self.scale = self.head_dim**-0.5 + + # Load qkv weight & bias (fused): shape [hidden_size, hidden_size*3] + qkv_weight = state_dict[f"{state_dict_prefix}qkv.weight"] + qkv_bias = state_dict[f"{state_dict_prefix}qkv.bias"] + + # Transpose to [hidden_size, 3*hidden_size] for matmul + self.qkv_weight = ttnn.as_tensor( + torch.transpose(qkv_weight, -2, -1), + device=mesh_device, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + self.qkv_bias = ttnn.as_tensor( + qkv_bias, device=mesh_device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + + # Output projection: proj + proj_weight = state_dict[f"{state_dict_prefix}proj.weight"] # shape [hidden_size, hidden_size] + proj_bias = state_dict[f"{state_dict_prefix}proj.bias"] # shape [hidden_size] + + self.proj_weight = ttnn.as_tensor( + torch.transpose(proj_weight, -2, -1), + device=mesh_device, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + self.proj_bias = ttnn.as_tensor( + proj_bias, device=mesh_device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + + self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=True, + packer_l1_acc=True, + dst_full_sync_en=False, + ) + + def forward(self, hidden_states, cu_seqlens, position_embeddings): + """ + hidden_states: ttnn.Tensor of shape [batch, seq_len, hidden_size] + position_embeddings: tuple (cos, sin) each of shape [seq_len, head_dim] + """ + seq_len = hidden_states.shape[-2] + cos, sin = position_embeddings + # Fused qkv projection + qkv = ttnn.linear( + hidden_states, + self.qkv_weight, + bias=self.qkv_bias, + dtype=self.dtype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) # shape [batch, seq_len, hidden_size*3] + + (q, k, v) = ttnn.permute(ttnn.reshape(qkv, [seq_len, 3, self.num_heads, -1]), [1, 0, 2, 
3]) + ttnn.deallocate(qkv) + + # Apply rotary position embeddings + q, k = apply_rotary_pos_emb_vision_tt(q, k, cos, sin) + # return q + + seq_len = cu_seqlens[-1].item() + + q = ttnn.unsqueeze(ttnn.permute(ttnn.pad(q, [(0, 0), (0, 0), (0, 16)], 0), [1, 0, 2]), 0) + k = ttnn.unsqueeze(ttnn.permute(ttnn.pad(k, [(0, 0), (0, 0), (0, 16)], 0), [1, 0, 2]), 0) + v = ttnn.unsqueeze(ttnn.permute(ttnn.pad(v, [(0, 0), (0, 0), (0, 16)], 0), [1, 0, 2]), 0) + + attn_output = ttnn.transformer.scaled_dot_product_attention( + q, k, v, is_causal=False, scale=self.scale + ) # shape [1, seq_len, num_heads, head_dim] + + ttnn.deallocate(q) + ttnn.deallocate(k) + ttnn.deallocate(v) + + # attn_output shape: [1, 16, 4096, 96] + # Need to slice back from 96 → 80 + attn_output = ttnn.slice( + attn_output, + (0, 0, 0, 0), + (attn_output.shape[0], attn_output.shape[1], attn_output.shape[2], self.head_dim), # head_dim=80 + ) + + attn_output = ttnn.permute(ttnn.squeeze(attn_output, 0), [1, 0, 2]) + attn_output = ttnn.reshape(attn_output, [seq_len, -1]) + + # Final projection + output = ttnn.linear( + attn_output, + self.proj_weight, + bias=self.proj_bias, + dtype=self.dtype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + + ttnn.deallocate(attn_output) + + return output diff --git a/models/experimental/qwen25_vl/tt/mlp.py b/models/experimental/qwen25_vl/tt/mlp.py new file mode 100644 index 000000000000..e1ebfe02d6ed --- /dev/null +++ b/models/experimental/qwen25_vl/tt/mlp.py @@ -0,0 +1,121 @@ +""" +This is the MLP (feed-forward) implementation for Qwen-VL-7B. + +We couldn't reuse TtLlamaImageFeedForward from tt_transformers because the logic is different. +Qwen does: down_proj(act_fn(gate_proj(x)) * up_proj(x)) +Tt does: c_proj(activation(c_fc(x))) + +So this version was written specifically for Qwen, based on its architecture. 
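+
+In symbols: output = w2 @ (SiLU(w1 @ x) * (w3 @ x)), where w1 = gate_proj,
+w3 = up_proj, and w2 = down_proj.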
+""" + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +class QwenTTVisionMLP(LightweightModule): + def __init__( + self, + mesh_device, + args, + state_dict, + weight_cache_path, + dtype, + state_dict_prefix=None, + ): + super().__init__() + + self.mesh_device = mesh_device + self.args = args + self.state_dict = state_dict + self.dim = args.dim + + def get_weight(name): + return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) + + def get_bias(name): + return state_dict[f"{state_dict_prefix}{name}.bias"] + + def cache_name(name): + if args.dummy_weights: + return None + return weight_cache_path / f"{state_dict_prefix}.{name}" + + def as_tensor(name, dtype, is_bias=False): + tensor_data = get_bias(name) if is_bias else get_weight(name) + return ttnn.as_tensor( + tensor_data, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + # cache_file_name=cache_name(name), + ) + + # Weights and Biases + self.w1 = as_tensor("w1", dtype) + self.b1 = as_tensor("w1", ttnn.bfloat16, is_bias=True) + + self.w3 = as_tensor("w3", dtype) + self.b3 = as_tensor("w3", ttnn.bfloat16, is_bias=True) + + self.w2 = as_tensor("w2", dtype) + self.b2 = as_tensor("w2", ttnn.bfloat16, is_bias=True) + + self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=True, + packer_l1_acc=True, + dst_full_sync_en=False, + ) + + def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: + """ + Qwen HF MLP reference: + output = down_proj(act_fn(gate_proj(x)) * up_proj(x)) + Mapping: + w1 -> gate_proj + w3 -> up_proj + w2 -> down_proj + """ + + # Linear with GELU activation + w1_out = ttnn.linear( + x, + self.w1, + bias=self.b1, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + activation="silu", + compute_kernel_config=self.compute_kernel_config, + ) + + w3_out = ttnn.linear( + x, + self.w3, + bias=self.b3, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + + # Element-wise multiply + w2_in = ttnn.mul(w1_out, w3_out, dtype=ttnn.bfloat16) + + # Final projection + w2_out = ttnn.linear( + w2_in, + self.w2, + bias=self.b2, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + + ttnn.deallocate(w1_out) + ttnn.deallocate(w3_out) + ttnn.deallocate(w2_in) + + return w2_out diff --git a/models/experimental/qwen25_vl/tt/model.py b/models/experimental/qwen25_vl/tt/model.py new file mode 100644 index 000000000000..50cbbe924784 --- /dev/null +++ b/models/experimental/qwen25_vl/tt/model.py @@ -0,0 +1,117 @@ +""" +This is the end-to-end pipeline for the Qwen-VL 2.5 model. + +The `Qwen25VLTransformer` class inherits from the `Transformer` class in tt_transformers. +It overrides `prepare_inputs_prefill` to run inference on the vision model and +pass the resulting visual tokens to the text model along with text tokens. 
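+
+Vision features are merged into the text sequence via masked scatter: every position in
+`input_ids` holding the image placeholder token (id 151655) is overwritten with the
+corresponding vision embedding before prefill.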
+""" + + +import ttnn +import torch + +from models.tt_transformers.tt.model import Transformer + + +class Qwen25VLTransformer(Transformer): + def __init__( + self, + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=None, + use_paged_kv_cache=False, + ): + super().__init__( + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=paged_attention_config, + use_paged_kv_cache=use_paged_kv_cache, + ) + + def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): + """ + Inputs are torch tensors or python types. This function returns ttnn + tensors on device. + TODO: Debate whether this function is responsible for padding + """ + + tokens = tokens.reshape(1, 1, 1, -1) + S = tokens.shape[-1] + tokens = ttnn.from_torch( + tokens, + device=self.mesh_device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + # self.embed_scale = args.dim**0.5 + tokens_embd = self.embd(tokens) + # tokens_embd = ttnn.multiply(tokens_embd, self.embed_scale) + + pixel_values = kwargs["processed_inputs"]["pixel_values"] + input_ids = kwargs["processed_inputs"]["input_ids"] + image_grid_thw = kwargs["processed_inputs"]["image_grid_thw"] + + vision_model = kwargs["vision_model"] + pixel_values = self.args.prepare_residual_tensor_prefill(pixel_values.unsqueeze(0), force_replicated=True) + + vision_output = vision_model(pixel_values, image_grid_thw) + + tokens_embd = ttnn.to_torch(tokens_embd) + comp_vision_output = ttnn.to_torch(ttnn.from_device(vision_output)) + + input_ids = torch.nn.functional.pad(input_ids, (0, tokens_embd.shape[1] - input_ids.shape[1]), "constant", 0) + image_features = comp_vision_output.squeeze(0) + special_image_mask = (input_ids == 151655).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(tokens_embd) + image_features = image_features.to(tokens_embd.device, tokens_embd.dtype) + tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features) + + tokens_embd = ttnn.from_torch( + tokens_embd, + dtype=ttnn.bfloat16, + device=self.mesh_device, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) + # Slice the rot mats to the prefill seqlen + assert ( + self.rope_setup.cos_matrix.shape[2] >= start_pos + S + ), f"Padded prefill end idx {start_pos + S} exceeds max seq len {self.rope_setup.cos_matrix.shape[2]}" + + tt_rot_mats_prefill_global = [ + self.rope_setup.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + + if page_table is not None: + tt_page_table = ttnn.from_torch( + page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_page_table = None + + if chunk_page_table is not None: + tt_chunk_page_table = ttnn.from_torch( + chunk_page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_chunk_page_table = None + + return tokens_embd, tt_rot_mats_prefill_global, tt_page_table, tt_chunk_page_table diff --git a/models/experimental/qwen25_vl/tt/patch_embed.py b/models/experimental/qwen25_vl/tt/patch_embed.py new file mode 100644 index 000000000000..36262c777f7e --- /dev/null +++ 
b/models/experimental/qwen25_vl/tt/patch_embed.py @@ -0,0 +1,62 @@ +""" +This is the patch embedding implementation for Qwen-VL-7B. + +The existing TtLlamaConv2dPatch from tt_transformers uses Conv2d, but Qwen needs Conv3d instead. +Since the stride size is the same as the kernel size for this operation, we can use a matrix +multiplication (matmul) instead of a convolution. This is necessary because +`ttnn.experimental.conv3d` currently only supports Conv3d with stride (1, 1, 1). +""" + +import ttnn + + +class TTQwen2_5_VisionPatchEmbed: + def __init__( + self, + device, + patch_size, + temporal_patch_size, + in_channels, + embed_dim, + state_dict, + weight_key, + layer_num=None, + state_dict_prefix="", + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + mode="decode", + ): + super().__init__() + self.mode = mode + self.device = device + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + self.weight_memory_config = weight_memory_config + self.weight_dtype = weight_dtype + + weight_name_1 = f"{state_dict_prefix}{weight_key}proj.weight" + torch_weight = state_dict[weight_name_1] + + weight_matrix = torch_weight.view(self.embed_dim, -1) + self.weight = ttnn.from_torch( + weight_matrix.T, + device=self.device, + dtype=self.weight_dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=self.weight_memory_config, + ) + self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=True, + packer_l1_acc=True, + dst_full_sync_en=False, + ) + + def __call__(self, x: ttnn.Tensor) -> ttnn.Tensor: + x_flattened = ttnn.reshape(x, (x.shape[2], -1)) + output = ttnn.matmul(x_flattened, self.weight, compute_kernel_config=self.compute_kernel_config) + + return output diff --git a/models/experimental/qwen25_vl/tt/patch_merger.py b/models/experimental/qwen25_vl/tt/patch_merger.py new file mode 100644 index 000000000000..a1001b0ec96d --- /dev/null +++ b/models/experimental/qwen25_vl/tt/patch_merger.py @@ -0,0 +1,117 @@ +""" +This is the patch merger implementation used in the Qwen-VL-7B model. + +There's no existing implementation for this in tt_transformers, +so it was written specifically based on Qwen-VL's architecture. 
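+
+Pipeline: RMSNorm (ln_q) -> reshape to (-1, context_dim * spatial_merge_size**2)
+-> Linear -> GELU -> Linear.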
+""" + +import ttnn +from models.experimental.qwen25_vl.tt.rmsnorm import RMSNorm +from models.tt_transformers.tt.model_config import ModelArgs + + +class TTQwen2_5_VLPatchMerger: + def __init__( + self, + device, + dim, + state_dict, + weight_key, + layer_num=None, + state_dict_prefix="", + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps: float = 1e-06, + dims=3584, + context_dim=1280, + spatial_merge_size=2, + mode="decode", + ): + super().__init__() + self.eps = eps + self.mode = mode + + tt_model_args = ModelArgs( + device, + max_batch_size=1, + max_seq_len=128, + ) + + weight_name_1 = f"{state_dict_prefix}{weight_key}ln_q.weight" + weight_name_2 = f"{state_dict_prefix}{weight_key}feed_forward.0.weight" + weight_name_3 = f"{state_dict_prefix}{weight_key}feed_forward.2.weight" + + bias_name_2 = f"{state_dict_prefix}{weight_key}feed_forward.0.bias" + bias_name_3 = f"{state_dict_prefix}{weight_key}feed_forward.2.bias" + + self.weight_1 = ttnn.as_tensor( + state_dict[weight_name_1], + device=device, + dtype=weight_dtype, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=weight_memory_config, + ) + + self.weight_2 = ttnn.as_tensor( + state_dict[weight_name_2], + device=device, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=weight_memory_config, + ) + + self.weight_3 = ttnn.as_tensor( + state_dict[weight_name_3], + device=device, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=weight_memory_config, + ) + + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = RMSNorm( + device=device, + dim=1280, + state_dict=state_dict, + state_dict_prefix="", + weight_key="visual.merger.ln_q", + weight_dtype=ttnn.bfloat16, + is_distributed=False, + sharded_program_config=tt_model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=False, + ) + + self.weight_3 = ttnn.transpose(self.weight_3, 0, 1) + + self.weight_2 = ttnn.transpose(self.weight_2, 0, 1) + + self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=True, + packer_l1_acc=True, + dst_full_sync_en=False, + ) + + def __call__(self, x): + x = self.ln_q(x, mode=self.mode) + + x = ttnn.reshape(x, (-1, self.hidden_size)) + + x = ttnn.linear( + x, + self.weight_2, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + x = ttnn.gelu(x) + + x = ttnn.linear( + x, + self.weight_3, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + + return x diff --git a/models/experimental/qwen25_vl/tt/rmsnorm.py b/models/experimental/qwen25_vl/tt/rmsnorm.py new file mode 100644 index 000000000000..025ffcfc638b --- /dev/null +++ b/models/experimental/qwen25_vl/tt/rmsnorm.py @@ -0,0 +1,140 @@ +""" +This is a modified RMSNorm implementation for Qwen-VL-7B. + +It's based on the existing RMSNorm in models/common/rmsnorm.py, +with slight changes to support the bf8 data type. 
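+
+Computes y = x * rsqrt(mean(x**2, dim=-1) + eps) * weight. Note that forward() always
+takes the interleaved `_distributed_rmsnorm` path here, whatever `is_distributed` says.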
+""" + + +import ttnn +from models.common.lightweightmodule import LightweightModule + +TILE = 32 +SHARD_HEIGHT = TILE # Current ttnn.rms_norm implementation requires shard height to be a single tile + + +class RMSNorm(LightweightModule): + def __init__( + self, + device, + dim, + state_dict, + weight_key, + layer_num=None, + state_dict_prefix=None, + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps: float = 1e-06, + sharded_program_config=None, + sharded_output_config=None, + output_mem_config=None, + ccl_topology=ttnn.Topology.Ring, + ): + super().__init__() + self.eps = eps + self.is_distributed = is_distributed + self.ccl_topology = ccl_topology + + if state_dict_prefix: + weight_name = f"{state_dict_prefix}{weight_key}.weight" + else: + if layer_num is None: + weight_name = f"{weight_key}.weight" + else: + weight_name = f"layers.{layer_num}.{weight_key}.weight" + torch_weight = ( + state_dict[weight_name].unsqueeze(0).view(1, 1, dim).reshape([1, 1, dim // SHARD_HEIGHT, SHARD_HEIGHT]) + ) + + cache_name = None if weight_cache_path is None else weight_cache_path / weight_name + + # Compatibility with models that don't use mesh devices (e.g. single-chip Mistral-7b) + is_mesh_device = device.__class__.__name__ == "MeshDevice" + + self.weight = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=weight_memory_config, + # cache_file_name=cache_name, + mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + ) + + if self.is_distributed: + self.weight_distributed = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=weight_memory_config, + # cache_file_name=cache_name, + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape)) + if is_mesh_device + else None, + ) + + self.sharded_output_config = sharded_output_config + self.sharded_program_config = sharded_program_config + self.output_mem_config = output_mem_config + + self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=False, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + def forward(self, x: ttnn.Tensor, mode="decode", in_sharded=False, out_sharded=False) -> ttnn.Tensor: + # If input is sharded do sharded RMSNorm and optionally return sharded output + program_config = self.sharded_program_config if in_sharded else None + memory_config = self.sharded_output_config if out_sharded else None + distributed = self.is_distributed and self.is_distributed(mode) + norm = self._distributed_rmsnorm + weight = self.weight_distributed if distributed else self.weight + + if in_sharded: + assert not distributed, "Distributed RMSNorm does not support sharded inputs" + else: + assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor" + + x = norm( + x, + epsilon=self.eps, + weight=weight, + program_config=program_config, + memory_config=memory_config, + compute_kernel_config=self.compute_kernel_config_hifi2, + ) + + if in_sharded and not out_sharded: + return ttnn.sharded_to_interleaved(x) + else: + return x + + def _distributed_rmsnorm( + self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None + ): + inp = ttnn.sharded_to_interleaved(inp, ttnn.DRAM_MEMORY_CONFIG) + + xnorm = ttnn.pow(inp, 2) + + xnorm = ttnn.mean(xnorm, dim=-1, 
+
+        xnorm = ttnn.rsqrt(xnorm + epsilon)
+
+        xnorm = ttnn.multiply(inp, xnorm)
+
+        weight = ttnn.reshape(weight, [1, 1, -1])
+
+        output = ttnn.multiply(xnorm, (weight), use_legacy=False)
+
+        if memory_config is not None:
+            output = ttnn.to_memory_config(output, memory_config)
+
+        ttnn.deallocate(xnorm)
+        ttnn.deallocate(weight)
+
+        return output
diff --git a/models/experimental/qwen25_vl/tt/rope.py b/models/experimental/qwen25_vl/tt/rope.py
new file mode 100644
index 000000000000..97a89ed8bf53
--- /dev/null
+++ b/models/experimental/qwen25_vl/tt/rope.py
@@ -0,0 +1,41 @@
+"""
+This is the vision rotary embedding implementation for Qwen-VL-7B.
+
+The existing RotarySetup (models/tt_transformers/tt/rope.py) in tt_transformers can't be used here,
+as Qwen-VL uses a different logic for applying rotary embeddings.
+This version is implemented specifically to match Qwen's design.
+"""
+
+
+import torch
+import ttnn
+
+
+class TTQwen2_5_VisionRotaryEmbedding:
+    def __init__(self, device, dim: int, theta: float = 10000.0, mode="decode"):
+        self.dim = dim
+        self.theta = theta
+        self.device = device
+
+        # inv_freq = 1 / theta^(arange(0, dim, 2) / dim), kept on device for reuse
+        arange_indices = ttnn.arange(start=0, end=dim, step=2, device=device)
+        arange_indices = ttnn.to_layout(arange_indices, ttnn.TILE_LAYOUT)
+        exponent = ttnn.div(arange_indices, dim)
+        pow_result = ttnn.pow(theta, exponent)
+        self.inv_freq = ttnn.reciprocal(pow_result)
+
+    def __call__(self, seqlen: int):
+        # Compute the outer product of positions and inverse frequencies in
+        # torch, then move the result back to device memory.
+        tt_seq = ttnn.arange(end=seqlen, device=self.device)
+        tt_seq = ttnn.to_torch(tt_seq)
+        tt_inv_freq = ttnn.to_torch(self.inv_freq)
+        tt_freqs = torch.outer(tt_seq, tt_inv_freq)
+        tt_freqs = ttnn.from_torch(
+            tt_freqs,
+            device=self.device,
+            dtype=ttnn.bfloat16,
+            layout=ttnn.TILE_LAYOUT,
+            memory_config=ttnn.L1_MEMORY_CONFIG,
+        )
+        # Note: inv_freq is deliberately kept allocated so the embedding can be
+        # called again for a different sequence length.
+
+        return tt_freqs
diff --git a/models/experimental/qwen25_vl/tt/text_mlp.py b/models/experimental/qwen25_vl/tt/text_mlp.py
new file mode 100644
index 000000000000..ef973009ec45
--- /dev/null
+++ b/models/experimental/qwen25_vl/tt/text_mlp.py
@@ -0,0 +1,114 @@
+"""
+This is the text MLP implementation for Qwen-VL-7B.
+
+The existing MLP in tt_transformers (models/tt_transformers/tt/mlp.py) caused a
+"Statically Allocated Circular L1 Buffer" issue when used with the Qwen VL model.
+To avoid this, the MLP was rewritten without memory optimizations.
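+
+The computation is the standard SwiGLU feed-forward used by the Qwen2 text
+layers. A minimal sketch, in the Meta-style naming used below
+(w1 = gate_proj, w3 = up_proj, w2 = down_proj):
+
+    out = w2(silu(w1(x)) * w3(x))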
+""" + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +class MLP(LightweightModule): + def __init__( + self, + mesh_device, + args, + state_dict, + weight_cache_path, + dtype, + layer_num=0, + model_config=None, + state_dict_prefix=None, + ): + super().__init__() + + self.mesh_device = mesh_device + self.args = args + self.state_dict = state_dict + self.dim = args.dim + + state_dict_prefix = state_dict_prefix or args.get_state_dict_prefix(self.__class__.__name__, layer_num) + + def get_weight(name): + return torch.transpose(state_dict[f"{state_dict_prefix}.{name}.weight"], -2, -1) + + def get_bias(name): + return state_dict[f"{state_dict_prefix}{name}.bias"] + + def cache_name(name): + if args.dummy_weights: + return None + return weight_cache_path / f"{state_dict_prefix}.{name}" + + def as_tensor(name, dtype, is_bias=False): + tensor_data = get_bias(name) if is_bias else get_weight(name) + return ttnn.as_tensor( + tensor_data, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + # cache_file_name=cache_name(name), + ) + + # Weights and Biases + self.w1 = as_tensor("w1", dtype) + self.w3 = as_tensor("w3", dtype) + self.w2 = as_tensor("w2", dtype) + + self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=True, + packer_l1_acc=True, + dst_full_sync_en=False, + ) + + def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: + """ + Qwen HF MLP reference: + output = down_proj(act_fn(gate_proj(x)) * up_proj(x)) + Mapping: + w1 -> gate_proj + w3 -> up_proj + w2 -> down_proj + """ + + # Linear with GELU activation + w1_out = ttnn.linear( + x, + self.w1, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + activation="silu", + compute_kernel_config=self.compute_kernel_config, + ) + + w3_out = ttnn.linear( + x, + self.w3, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + + # Element-wise multiply + w2_in = ttnn.mul(w1_out, w3_out, dtype=ttnn.bfloat16) + + # Final projection + w2_out = ttnn.linear( + w2_in, + self.w2, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + + ttnn.deallocate(w1_out) + ttnn.deallocate(w3_out) + ttnn.deallocate(w2_in) + + return w2_out diff --git a/models/experimental/qwen25_vl/tt/vision_block.py b/models/experimental/qwen25_vl/tt/vision_block.py new file mode 100644 index 000000000000..72461a61c7a6 --- /dev/null +++ b/models/experimental/qwen25_vl/tt/vision_block.py @@ -0,0 +1,79 @@ +""" +This is the vision block used in the Qwen-VL-7B architecture +consisting of RMSnorm and self-attention layer followed by an MLP layer. 
+""" + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.experimental.qwen25_vl.tt.mlp import QwenTTVisionMLP +from models.experimental.qwen25_vl.tt.rmsnorm import RMSNorm +from models.experimental.qwen25_vl.tt.attention import TtQwen2_5_VLVisionSdpaAttention + + +class TtQwen2_5_VLVisionBlock(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + dtype, + model_args, + weight_cache_path=None, + state_dict_prefix=None, + ): + super().__init__() + + self.norm1 = RMSNorm( + device=mesh_device, + dim=1280, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="norm1", + weight_dtype=dtype, + is_distributed=False, + sharded_program_config=model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], + ) + + self.norm2 = RMSNorm( + device=mesh_device, + dim=1280, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="norm2", + weight_dtype=dtype, + is_distributed=False, + sharded_program_config=model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], + ) + + self.attn = TtQwen2_5_VLVisionSdpaAttention( + mesh_device, + state_dict, + state_dict_prefix=f"{state_dict_prefix}attn.", + # weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=model_args, + ) + + self.mlp = QwenTTVisionMLP( + mesh_device=mesh_device, + args=model_args, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}feed_forward.", + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + ) + + def forward(self, hidden_states, cu_seqlens, position_embeddings): + hidden_states = ttnn.add( + hidden_states, + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ), + ) + + hidden_states = ttnn.add(hidden_states, self.mlp(self.norm2(hidden_states))) + + return hidden_states diff --git a/models/experimental/qwen25_vl/tt/vision_model.py b/models/experimental/qwen25_vl/tt/vision_model.py new file mode 100644 index 000000000000..974dbbe8daa4 --- /dev/null +++ b/models/experimental/qwen25_vl/tt/vision_model.py @@ -0,0 +1,242 @@ +""" +This is the end-to-end architecture of the Qwen-VL 2.5 vision model. + +It brings together all components—patch embedding, vision blocks, rotary embeddings, +and patch merger for visual input processing. 
+""" + +import ttnn +from tqdm import tqdm +from models.common.lightweightmodule import LightweightModule +from models.experimental.qwen25_vl.tt.vision_block import TtQwen2_5_VLVisionBlock +from models.experimental.qwen25_vl.tt.patch_embed import TTQwen2_5_VisionPatchEmbed +from models.experimental.qwen25_vl.tt.rope import TTQwen2_5_VisionRotaryEmbedding +from models.experimental.qwen25_vl.tt.patch_merger import TTQwen2_5_VLPatchMerger + +import torch +import torch.nn.functional as F + + +class TtQwen2_5_VisionTransformerPretrainedModel(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + model_args, + layers, + block_key="", + gated=False, + ): + self.spatial_merge_size = model_args.spatial_merge_size + self.patch_size = model_args.vision_patch_size + self.fullatt_block_indexes = model_args.fullatt_block_indexes + self.window_size = model_args.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + self.mesh_device = mesh_device + hidden_size = model_args.vision_dim + n_heads = model_args.vision_attn_n_heads + out_hidden_size = model_args.out_hidden_size + temporal_patch_size = model_args.temporal_patch_size + + self.patch_embed = TTQwen2_5_VisionPatchEmbed( + device=mesh_device, + patch_size=self.patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=3, + embed_dim=hidden_size, + state_dict=state_dict, + weight_key="patch_embed.", + layer_num=None, + state_dict_prefix=state_dict_prefix, + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + ) + + head_dim = hidden_size // n_heads + + self.rotary_pos_emb = TTQwen2_5_VisionRotaryEmbedding( + device=mesh_device, + dim=head_dim // 2, + theta=10000.0, + ) + + self.blocks = [ + TtQwen2_5_VLVisionBlock( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}blocks.{i}.", + weight_cache_path=weight_cache_path, + dtype=dtype, + model_args=model_args, + ) + for i in tqdm(range(layers), desc=f"Loading vision transformer blocks") + ] + + self.merger = TTQwen2_5_VLPatchMerger( + device=mesh_device, + dim=5120, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="merger.", + layer_num=None, + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps=1e-06, + dims=out_hidden_size, + context_dim=hidden_size, + spatial_merge_size=self.spatial_merge_size, + ) + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb_full = ttnn.to_torch(rotary_pos_emb_full, device=self.mesh_device) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return 
+
+    def get_window_index(self, grid_thw):
+        window_index: list = []
+        cu_window_seqlens: list = [0]
+        window_index_id = 0
+        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
+
+        for grid_t, grid_h, grid_w in grid_thw:
+            llm_grid_h, llm_grid_w = (
+                grid_h // self.spatial_merge_size,
+                grid_w // self.spatial_merge_size,
+            )
+            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
+            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+            index_padded = index_padded.reshape(
+                grid_t,
+                num_windows_h,
+                vit_merger_window_size,
+                num_windows_w,
+                vit_merger_window_size,
+            )
+            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                grid_t,
+                num_windows_h * num_windows_w,
+                vit_merger_window_size,
+                vit_merger_window_size,
+            )
+            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+            index_padded = index_padded.reshape(-1)
+            index_new = index_padded[index_padded != -100]
+            window_index.append(index_new + window_index_id)
+            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+        window_index = torch.cat(window_index, dim=0)
+
+        return window_index, cu_window_seqlens
+
+    def forward(self, hidden_states, grid_thw):
+        hidden_states = self.patch_embed(hidden_states)
+
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+        cu_window_seqlens = torch.tensor(
+            cu_window_seqlens,
+            # device=hidden_states.device,
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+        seq_len = hidden_states.shape[-2]
+
+        hidden_states = ttnn.reshape(hidden_states, [seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1])
+        # hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+        tt_index = ttnn.from_torch(
+            window_index.view(-1, 1, 1).expand(-1, hidden_states.shape[-2], hidden_states.shape[-1]).permute(1, 2, 0),
+            device=self.mesh_device,
+            dtype=ttnn.uint32,
+            layout=ttnn.TILE_LAYOUT,
+            # memory_config=ttnn.L1_MEMORY_CONFIG,
+        )
+
+        hidden_states = ttnn.gather(ttnn.permute(hidden_states, (1, 2, 0)), dim=-1, index=tt_index)
+        hidden_states = ttnn.permute(hidden_states, (2, 0, 1))
+
+        hidden_states = ttnn.reshape(hidden_states, [seq_len, -1])
+        # hidden_states = hidden_states.reshape(seq_len, -1)
+
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+
+        cos_tensor = ttnn.from_torch(emb.cos(), device=self.mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
+        sin_tensor = ttnn.from_torch(emb.sin(), device=self.mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
+
+        position_embeddings = (cos_tensor, sin_tensor)
+
+        ttnn.deallocate(tt_index)
+
+        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            dim=0,
+            # Select dtype based on the following factors:
+            #  - FA2 requires that cu_seqlens_q must have dtype int32
+            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+            # See https://github.com/huggingface/transformers/pull/34852 for more information
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+        for layer_num, blk in enumerate(self.blocks):
+            if layer_num in self.fullatt_block_indexes:
+                cu_seqlens_now = cu_seqlens
+            else:
+                cu_seqlens_now = cu_window_seqlens
+
+            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)
+
+        ttnn.deallocate(cos_tensor)
+        ttnn.deallocate(sin_tensor)
+        hidden_states = self.merger(hidden_states)
+        reverse_indices = torch.argsort(window_index)
+
+        tt_reverse_indices = ttnn.from_torch(
+            reverse_indices.view(-1, 1).expand(-1, hidden_states.shape[-1]).transpose(0, 1),
+            device=self.mesh_device,
+            dtype=ttnn.uint32,
+            layout=ttnn.TILE_LAYOUT,
+        )
+        hidden_states = ttnn.gather(ttnn.permute(hidden_states, (1, 0)), dim=-1, index=tt_reverse_indices)
+        hidden_states = ttnn.permute(hidden_states, (1, 0))
+
+        return hidden_states
diff --git a/models/tt_transformers/tt/attention.py b/models/tt_transformers/tt/attention.py
index 47ba6a7d95fd..12ccbd544c39 100644
--- a/models/tt_transformers/tt/attention.py
+++ b/models/tt_transformers/tt/attention.py
@@ -51,6 +51,7 @@ def __init__(
         self.batch_size_per_device_group = (
             max(self.max_batch_size // self.num_device_groups, 1) if self.TG else self.max_batch_size
         )
+        self.simplified_rms = configuration.base_model_name == "Qwen2.5-VL-7B"
         self.n_local_heads = self.n_heads // self.num_devices_per_group
         self.n_local_kv_heads = self.n_kv_heads // self.num_devices_per_group
@@ -262,6 +263,7 @@ def norm_reshard(x, norm, mode):
                 is_distributed=False,
                 sharded_program_config=None,  # FIXME: add height-sharded support. self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"],
                 sharded_output_config=None,  # FIXME: add height-sharded support. self.model_config["CREATE_QKV_DECODE_SHARD"]
+                simplified_rms=self.simplified_rms,
             )
             self.q_norm = lambda x, mode: norm_reshard(x, fn_q_norm, mode)
         else:
@@ -281,6 +283,7 @@ def norm_reshard(x, norm, mode):
                 is_distributed=False,
                 sharded_program_config=None,  # FIXME: add height-sharded support. self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"],
                 sharded_output_config=None,  # FIXME: add height-sharded support. self.model_config["CREATE_QKV_DECODE_SHARD"],
self.model_config["CREATE_QKV_DECODE_SHARD"], + simplified_rms=self.simplified_rms, ) self.k_norm = lambda x, mode: norm_reshard(x, fn_k_norm, mode) else: diff --git a/models/tt_transformers/tt/decoder.py b/models/tt_transformers/tt/decoder.py index 24e95a709b8a..d86520611b5f 100644 --- a/models/tt_transformers/tt/decoder.py +++ b/models/tt_transformers/tt/decoder.py @@ -4,6 +4,7 @@ import ttnn from models.common.lightweightmodule import LightweightModule from models.common.rmsnorm import RMSNorm +from models.experimental.qwen25_vl.tt.text_mlp import MLP as QwenMLP from models.tt_transformers.tt.attention import Attention as DefaultAttention from models.tt_transformers.tt.distributed_norm import DistributedNorm from models.tt_transformers.tt.mlp import MLP @@ -39,6 +40,7 @@ def __init__( self.n_kv_heads = args.n_kv_heads self.current = 0 self.model_config = args.get_model_config() + self.simplified_rms = True if self.args.base_model_name == "Qwen2.5-VL-7B" else False self.layer_num = layer_num @@ -55,15 +57,26 @@ def __init__( paged_attention_config=paged_attention_config, use_paged_kv_cache=use_paged_kv_cache, ) - self.feed_forward = MLP( - mesh_device=mesh_device, - args=args, - state_dict=state_dict, - weight_cache_path=weight_cache_path, - layer_num=layer_num, - dtype=dtype, - model_config=self.model_config, - ) + if self.args.base_model_name == "Qwen2.5-VL-7B": + self.feed_forward = QwenMLP( + mesh_device=mesh_device, + args=args, + state_dict=state_dict, + weight_cache_path=weight_cache_path, + layer_num=layer_num, + dtype=dtype, + model_config=self.model_config, + ) + else: + self.feed_forward = MLP( + mesh_device=mesh_device, + args=args, + state_dict=state_dict, + weight_cache_path=weight_cache_path, + layer_num=layer_num, + dtype=dtype, + model_config=self.model_config, + ) self.attention_norm = DistributedNorm( RMSNorm( device=mesh_device, @@ -79,6 +92,7 @@ def __init__( sharded_program_config=self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"], sharded_output_config=self.model_config["SHARDED_ATTN_INPUT_MEMCFG"], ccl_topology=self.args.ccl_topology(), + simplified_rms=self.simplified_rms, ), args, TG=args.is_galaxy, @@ -98,6 +112,7 @@ def __init__( sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], ccl_topology=self.args.ccl_topology(), + simplified_rms=self.simplified_rms, ), args, TG=args.is_galaxy, diff --git a/models/tt_transformers/tt/generator.py b/models/tt_transformers/tt/generator.py index c9413fe6f44a..cd0620049b91 100644 --- a/models/tt_transformers/tt/generator.py +++ b/models/tt_transformers/tt/generator.py @@ -57,7 +57,7 @@ def __init__(self, model, model_args, mesh_device, tokenizer=None, formatter=Non # Note: This function is called by vLLM def prefill_forward_text( - self, tokens: torch.Tensor, page_table=None, kv_cache=None, prompt_lens=None, empty_slots=None + self, tokens: torch.Tensor, page_table=None, kv_cache=None, prompt_lens=None, empty_slots=None, **kwargs ): if page_table is not None: assert isinstance(page_table, torch.Tensor), "page_table mush be torch.Tensor" @@ -101,6 +101,7 @@ def prefill_forward_text( last_token_idx=last_token_idx, kv_cache=model_kv_cache, model_id=model_id, + **kwargs, ) out_list.append(logits) @@ -116,7 +117,9 @@ def prefill_forward_text( logger.info(f"Finished prefill for all users up to {batch_seq_len} tokens, Starting decode...") return output_logits - def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_token_idx, 
+    def prefill_forward_single_user_text(
+        self, tokens, page_table, user_id, last_token_idx, kv_cache=None, model_id=-1, **kwargs
+    ):
         seq_len = tokens.shape[-1]
         use_chunked_prefill = seq_len > self.model_args[model_id].max_prefill_chunk_size
         if use_chunked_prefill:
@@ -165,6 +168,7 @@ def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_token_idx, kv_cache=None, model_id=-1):
                     start_pos=chunk_start,
                     page_table=page_table_user_padded,
                     chunk_page_table=chunk_page_table,
+                    **kwargs,
                 )
                 tt_logits = self.model[model_id].ttnn_prefill_forward(
                     chunk_prefill_input,
@@ -175,6 +179,7 @@ def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_token_idx, kv_cache=None, model_id=-1):
                     chunk_start_idx=chunk_start,
                     get_last_token=(last_token_idx_in_chunk // 32) * 32,
                     kv_cache=kv_cache,
+                    **kwargs,
                 )
                 if chunk_start == last_chunk_start:
@@ -185,6 +190,7 @@ def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_token_idx, kv_cache=None, model_id=-1):
             prefill_input, rot_mats_prefill, page_table_tt, _ = self.model[model_id].prepare_inputs_prefill(
                 tokens,
                 page_table=page_table,
+                **kwargs,
             )
             tt_logits = self.model[model_id].ttnn_prefill_forward(
diff --git a/models/tt_transformers/tt/load_checkpoints.py b/models/tt_transformers/tt/load_checkpoints.py
index 6b28e2b4e5ce..38e68df53e2b 100644
--- a/models/tt_transformers/tt/load_checkpoints.py
+++ b/models/tt_transformers/tt/load_checkpoints.py
@@ -238,6 +238,7 @@ def map_hf_to_meta_keys(loaded_weights):
     """
     replacements = [
         ("^emb.weight", "weight"),
+        ("language.model.", ""),
         ("model.", ""),
         ("embed_tokens", "tok_embeddings"),
         ("lm_head", "output"),
@@ -256,6 +257,35 @@ def map_hf_to_meta_keys(loaded_weights):
     return replace_keys(loaded_weights, replacements)
+def map_vision_meta_to_hf_keys(loaded_weights):
+    """
+    Map Meta checkpoint keys to Hugging Face checkpoint keys.
+    You can use this to support other models by adding more mappings.
+    See replace_keys for more details on the format of replacements.
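+
+    Example (hypothetical key, for illustration only):
+        "blocks.0.attention.wq.weight" -> "blocks.0.self_attn.q_proj.weight"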
+ """ + inverted_mapping = [ + ("attention_norm", "input_layernorm"), + ("ffn_norm", "post_attention_layernorm"), + ("attention", "self_attn"), + ("feed_forward", "mlp"), + ("w1", "gate_proj"), + ("w2", "down_proj"), + ("w3", "up_proj"), + ("wq", "q_proj"), + ("wk", "k_proj"), + ("wv", "v_proj"), + ("wo", "o_proj"), + ] + + return replace_keys(loaded_weights, inverted_mapping) + + +def convert_vision_meta_to_hf(state_dict, head_dim): + # state_dict = convert_meta_qkv_to_hf_format(state_dict, head_dim) + state_dict = map_vision_meta_to_hf_keys(state_dict) + return state_dict + + def map_meta_to_hf_keys(loaded_weights): # Define mappings at each level of the hierarchy meta_to_hf_mappings = { diff --git a/models/tt_transformers/tt/model.py b/models/tt_transformers/tt/model.py index 591c915085e6..8fce82afe8bd 100644 --- a/models/tt_transformers/tt/model.py +++ b/models/tt_transformers/tt/model.py @@ -40,6 +40,7 @@ def __init__( self.model_config = args.get_model_config() self.grid_size = self.args.max_grid_size state_dict_prefix = args.get_state_dict_prefix("", None) + self.simplified_rms = True if self.args.base_model_name == "Qwen2.5-VL-7B" else False self.embd = Embedding( mesh_device=mesh_device, @@ -90,6 +91,7 @@ def __init__( sharded_program_config=self.model_config["SHARDED_NORM_LM_HEAD_PRGM_CFG"], sharded_output_config=self.model_config["LM_HEAD_INPUT_MEMCFG"], ccl_topology=self.args.ccl_topology(), + simplified_rms=self.simplified_rms, ), args, args.is_galaxy, diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index 59a955a568a5..33641b93a268 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -27,6 +27,7 @@ from models.tt_transformers.tt.load_checkpoints import ( convert_hf_to_meta, convert_meta_to_hf, + convert_vision_meta_to_hf, load_hf_state_dict, load_meta_state_dict, reverse_permute, @@ -1576,7 +1577,54 @@ def _set_params(self, checkpoint_dir): else None ) + def _set_vision_params(self, vision_config): + self.vision_chunk_size = vision_config.get("vision_chunk_size", 896) + self.vision_max_num_chunks = vision_config.get("vision_max_num_chunks", 4) + self.vision_num_cross_attention_layers = vision_config.get("vision_num_cross_attention_layers", 8) + self.vision_dim = vision_config.get("hidden_size", 1152) + + intermediate_size = vision_config.get("intermediate_size", self.vision_dim * 4) + self.vision_mlp_ratio = intermediate_size // self.vision_dim + self.vision_hidden_dim = int(self.vision_dim * self.vision_mlp_ratio) + self.vision_attn_n_heads = vision_config.get("num_attention_heads") or vision_config.get("num_heads") or 16 + self.vision_head_dim = self.vision_dim // self.vision_attn_n_heads + + self.vision_n_layers = vision_config.get("num_hidden_layers") or vision_config.get("depth") or 27 + self.vision_patch_size = vision_config.get("patch_size", 14) + self.vision_in_channels = vision_config.get("num_channels", 3) + + self.vision_dropout = vision_config.get("attention_dropout", 0.0) + self.mm_tokens_per_image = vision_config.get("mm_tokens_per_image", 256) + + # Qwen2.5 VL specific params + if "Qwen2.5-VL-7B" in self.base_model_name: + self.spatial_merge_size = vision_config.get("spatial_merge_size") + self.window_size = vision_config.get("window_size") + self.fullatt_block_indexes = vision_config.get("fullatt_block_indexes") + self.out_hidden_size = vision_config.get("out_hidden_size") + self.temporal_patch_size = vision_config.get("temporal_patch_size") + + # Optional 
+        act_layer = vision_config.get("act_layer", "gelu").lower()
+        self.vision_act_layer = {
+            "gelu": ttnn.UnaryOpType.GELU,
+            "relu": ttnn.UnaryOpType.RELU,
+            "silu": ttnn.UnaryOpType.SILU,
+        }.get(act_layer, ttnn.UnaryOpType.GELU)
+
     def _set_hf_params(self, checkpoint_dir):
+        def merge_text_config(base_config):
+            text_config = base_config.get("text_config", {})
+            # Merge non-nested keys into text_config
+            text_config.update({k: v for k, v in base_config.items() if k not in ["text_config", "vision_config"]})
+            return text_config
+
+        def merge_vision_config(base_config):
+            vision_config = base_config.get("vision_config", {})
+            # Merge non-nested keys into vision_config
+            vision_config.update({k: v for k, v in base_config.items() if k not in ["text_config", "vision_config"]})
+            return vision_config
+
         if self.from_hf_url:
             # Special case Qwen2.5-VL models until they are fully integrated into a HF release
             if "Qwen/Qwen2.5-VL" in self.model_name:
@@ -1593,12 +1641,20 @@ def _set_hf_params(self, checkpoint_dir):
                 self.hf_config = AutoConfig.from_pretrained(self.CKPT_DIR)
             config = self.hf_config.to_dict()
+            if "text_config" in config or "vision_config" in config:
+                merged_text_config = merge_text_config(config)
+                self._set_params_from_dict(merged_text_config, is_hf=True)
+                if "vision_config" in config:
+                    self._set_vision_params(config["vision_config"])
+            else:
+                self._set_params_from_dict(config, is_hf=True)
+
         else:
             config_file = os.path.join(checkpoint_dir, "config.json")
             assert os.path.exists(config_file), f"config.json file not found at {config_file}"
             with open(config_file, "r") as f:
                 config = json.load(f)
-            self._set_params_from_dict(config, is_hf=True)
+            self._set_params_from_dict(config)
     def __repr__(self):
         return f"""ModelArgs(
@@ -1623,7 +1679,7 @@ def is_vision(self):
         return self.vision_chunk_size > 0
     def get_state_dict_prefix(self, module_name, layer_num):
-        text_prefix = self.state_dict_text_prefix
+        text_prefix = "text_model." if self.is_vision() and "Qwen2.5-VL" not in self.model_name else ""
         layer_prefix = f"layers.{layer_num}." if layer_num is not None else ""
         module_map = {
             "MLP": "feed_forward",
@@ -2196,6 +2252,133 @@ def reference_rms_norm(self):
             layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim))
             return layer
+    def reference_vision_transformer(self, wrap=True, load_checkpoint=False):
+        if self.checkpoint_type == CheckpointType.HuggingFace:
+            from transformers import AutoConfig, AutoModelForCausalLM
+
+            if self.dummy_weights and not load_checkpoint:
+                config = AutoConfig.from_pretrained(self.LOCAL_HF_PARAMS[self.model_name])
+                config.num_layers = self.n_layers
+                config.num_hidden_layers = self.n_layers
+                model = AutoModelForCausalLM.from_config(config)
+            else:
+                from transformers import Qwen2_5_VLForConditionalGeneration
+
+                if "Qwen2.5-VL" in self.model_name:
+                    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(self.CKPT_DIR)
+                else:
+                    if self.cached_hf_model is None:
+                        model = AutoModelForCausalLM.from_pretrained(self.CKPT_DIR)
+                        self.cached_hf_model = model
+                    else:
+                        model = self.cached_hf_model
+                    model.model.layers = model.model.layers[: self.n_layers]
+
+            if wrap:
+                wrapper = HfModelWrapper(model, self.head_dim)
+                return wrapper
+            else:
+                return model
+
+    def reference_vision_model(self):
+        model = self.reference_vision_transformer(wrap=False)
+        if "Qwen2.5-VL-7B" in self.model_name:
+            layer = model.visual
+            layer._load_state_dict = layer.load_state_dict
+            layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+
+        return layer
+
+    def reference_vision_block(self):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model.visual.blocks[0]
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+
+        return layer
+
+    def reference_vision_mlp(self):
+        model = self.reference_vision_transformer(wrap=False)
+        if "Qwen2.5-VL-7B" in self.model_name:
+            layer = model.visual.blocks[0].mlp
+            layer._load_state_dict = layer.load_state_dict
+            layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+            return layer
+
+    def reference_siglip_patch_embed(self):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model.vision_tower.vision_model.embeddings.patch_embedding
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_pos_embedding(self):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model.vision_tower.vision_model.embeddings.position_embedding
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_embedding(self):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model.vision_tower.vision_model.embeddings
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_layernorm(self):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model.vision_tower.vision_model.encoder.layers[0].layer_norm1
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_attention(self):
+        model = self.reference_vision_transformer(wrap=False)
+        if "Qwen2.5-VL-7B" in self.model_name:
+            layer = model.visual.blocks[0].attn
+            layer._load_state_dict = layer.load_state_dict
+            layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+            return layer
+
+    def reference_vision_encoder_block(self):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model.vision_tower.vision_model.encoder.layers[0]
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_rms_norm(self):
+        model = self.reference_vision_transformer(wrap=False)
+        if "Qwen2.5-VL-7B" in self.model_name:
+            layer = model.visual.blocks[0].norm1
+            layer._load_state_dict = layer.load_state_dict
+            layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim))
+            return layer
+
+    def reference_vision_encoder(self):
+        model = self.reference_vision_transformer(wrap=False)
+        if "Qwen2.5-VL-7B" in self.model_name:
+            layer = model.visual.blocks
+            layer._load_state_dict = layer.load_state_dict
+            layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+
+        return layer
+
+    def reference_vision_qwen_merger(self):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model.visual.merger
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim))
+        return layer
+
+    def reference_vision_qwen_patch_embed(self):
+        model = self.reference_vision_transformer(wrap=False)
+        layer = model.visual.patch_embed
+        layer._load_state_dict = layer.load_state_dict
+        layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim))
+        return layer
+
     def reference_mlp(self):
         if self.checkpoint_type == CheckpointType.Meta:
             from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import FeedForward