From 2763a965c0b39978ec6ebf8e49e5a1160e99b4b2 Mon Sep 17 00:00:00 2001 From: MohammedTaherMcW Date: Thu, 7 Aug 2025 15:57:00 +0000 Subject: [PATCH 1/2] Migrate Qwen 2.5 VL 7B to tt_transformers --- .../demo/simple_vision_demo.py | 82 ++++-- .../qwen_vl/test_image_attention.py | 141 ++++++++++ .../multimodal/qwen_vl/test_image_block.py | 126 +++++++++ .../multimodal/qwen_vl/test_image_merger.py | 104 ++++++++ .../multimodal/qwen_vl/test_image_mlp.py | 81 ++++++ .../multimodal/qwen_vl/test_image_model.py | 81 ++++++ .../qwen_vl/test_image_patch_embed.py | 101 +++++++ .../multimodal/qwen_vl/test_image_rms.py | 110 ++++++++ models/tt_transformers/tt/common.py | 45 ++++ models/tt_transformers/tt/decoder.py | 32 +-- models/tt_transformers/tt/generator.py | 108 +++++++- models/tt_transformers/tt/load_checkpoints.py | 7 + models/tt_transformers/tt/model.py | 1 + models/tt_transformers/tt/model_config.py | 28 +- .../tt/multimodal/qwen_vl/qwen_e2e_model.py | 132 +++++++++ .../qwen_vl/qwen_image_attention.py | 160 +++++++++++ .../tt/multimodal/qwen_vl/qwen_image_block.py | 79 ++++++ .../tt/multimodal/qwen_vl/qwen_image_mlp.py | 128 +++++++++ .../qwen_vl/qwen_image_patch_embed.py | 68 +++++ .../multimodal/qwen_vl/qwen_patch_merger.py | 116 ++++++++ .../tt/multimodal/qwen_vl/qwen_rmsnorm.py | 138 ++++++++++ .../multimodal/qwen_vl/qwen_vision_model.py | 250 ++++++++++++++++++ models/tt_transformers/tt/rope.py | 23 ++ 23 files changed, 2082 insertions(+), 59 deletions(-) create mode 100644 models/tt_transformers/tests/multimodal/qwen_vl/test_image_attention.py create mode 100644 models/tt_transformers/tests/multimodal/qwen_vl/test_image_block.py create mode 100644 models/tt_transformers/tests/multimodal/qwen_vl/test_image_merger.py create mode 100644 models/tt_transformers/tests/multimodal/qwen_vl/test_image_mlp.py create mode 100644 models/tt_transformers/tests/multimodal/qwen_vl/test_image_model.py create mode 100644 models/tt_transformers/tests/multimodal/qwen_vl/test_image_patch_embed.py create mode 100644 models/tt_transformers/tests/multimodal/qwen_vl/test_image_rms.py create mode 100644 models/tt_transformers/tt/multimodal/qwen_vl/qwen_e2e_model.py create mode 100644 models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_attention.py create mode 100644 models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_block.py create mode 100644 models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_mlp.py create mode 100644 models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_patch_embed.py create mode 100644 models/tt_transformers/tt/multimodal/qwen_vl/qwen_patch_merger.py create mode 100644 models/tt_transformers/tt/multimodal/qwen_vl/qwen_rmsnorm.py create mode 100644 models/tt_transformers/tt/multimodal/qwen_vl/qwen_vision_model.py diff --git a/models/tt_transformers/demo/simple_vision_demo.py b/models/tt_transformers/demo/simple_vision_demo.py index 7d21da9ca274..22dd0b002335 100644 --- a/models/tt_transformers/demo/simple_vision_demo.py +++ b/models/tt_transformers/demo/simple_vision_demo.py @@ -27,7 +27,9 @@ import ttnn from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf from models.perf.benchmarking_utils import BenchmarkProfiler +from models.tt_transformers.tt.common import hf_multimodal_encode from models.tt_transformers.tt.generator import Generator +from models.tt_transformers.tt.model_config import CheckpointType def get_batch_sampler(temperature, top_p, tokenizer): @@ -62,6 +64,7 @@ def create_multimodal_model( ): from models.tt_transformers.tt.model_config 
import ModelArgs from models.tt_transformers.tt.multimodal.llama_vision_model import CrossAttentionTransformer + from models.tt_transformers.tt.multimodal.qwen_vl.qwen_e2e_model import TtQwen_Model tt_model_args = ModelArgs(mesh_device, max_batch_size=max_batch_size) assert tt_model_args.is_vision(), "This model is multimodal" @@ -76,14 +79,25 @@ def create_multimodal_model( if checkpoint is None: checkpoint = tt_model_args.load_state_dict() - model = CrossAttentionTransformer( - mesh_device, - state_dict=checkpoint, - weight_cache_path=tt_model_args.weight_cache_path(dtype), - dtype=dtype, - configuration=tt_model_args, - use_paged_kv_cache=use_paged_kv_cache, - ) + + if tt_model_args.base_model_name == "Qwen2.5-VL-7B": + model = TtQwen_Model( + mesh_device=mesh_device, + state_dict=checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(ttnn.bfloat8_b), + dtype=ttnn.bfloat8_b, + args=tt_model_args, + use_paged_kv_cache=use_paged_kv_cache, + ) + else: + model = CrossAttentionTransformer( + mesh_device, + state_dict=checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=tt_model_args, + use_paged_kv_cache=use_paged_kv_cache, + ) return tt_model_args, model, checkpoint @@ -128,7 +142,7 @@ def prepare_generator_args( ) @pytest.mark.parametrize( "test_type,max_seq_len", - (("normal", 512),), + (("normal", 2048),), ids=["normal"], ) @pytest.mark.parametrize( @@ -172,9 +186,6 @@ def test_multimodal_demo_text( profiler = BenchmarkProfiler() profiler.start("run") - ckpt_dir = os.environ["LLAMA_DIR"] - tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") - num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1 max_batch_size *= data_parallel # input batch_size is interpreted as size per DP group @@ -185,11 +196,26 @@ def test_multimodal_demo_text( max_batch_size=max_batch_size, max_seq_len=max_seq_len, ) + + HF_MODEL = model_args[0].checkpoint_type == CheckpointType.HuggingFace + + if not HF_MODEL: + ckpt_dir = os.environ["LLAMA_DIR"] + tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") + + tokenizer = Tokenizer(model_path=tokenizer_path) + formatter = ChatFormat(tokenizer) + else: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_args[0].CKPT_DIR) + generator = Generator(model, model_args, mesh_device) - tokenizer = Tokenizer(model_path=tokenizer_path) - formatter = ChatFormat(tokenizer) - xattn_caches = [model.setup_cache(model_args[i].max_batch_size) for i, model in enumerate(generator.model)] + xattn_caches = [ + model.setup_cache(model_args[i].max_batch_size) if not HF_MODEL else None + for i, model in enumerate(generator.model) + ] # Create random images for trace capture with specific dimensions trace_img_560x560 = create_random_image(560, 560) @@ -250,10 +276,12 @@ def test_multimodal_demo_text( total_users = len(dialogs) num_batches = total_users // max_batch_size - sampler = get_batch_sampler(temperature, top_p, tokenizer) + sampler = get_batch_sampler(temperature, top_p, model_args[0].tokenizer) _num_prefill_tokens = 0 _num_decode_tokens = 0 + prompt_encoder = hf_multimodal_encode if HF_MODEL else formatter.encode_dialog_prompt + for iter_num in range(warmup_iters + 1): logger.info(f"Iteration {iter_num}") current_dialogs = trace_dialogs + dialogs @@ -263,9 +291,17 @@ def test_multimodal_demo_text( for msg in dialog: print(f"{msg.role.capitalize()}: {msg.content}\n") batch_model_input = [ - formatter.encode_dialog_prompt(dialog, 
tool_prompt_format=False) for dialog in batch_dialogs + prompt_encoder(dialog, processor) if HF_MODEL else prompt_encoder(dialog, tool_prompt_format=False) + for dialog in batch_dialogs ] + if HF_MODEL: + # Use the processor's tokenizer instead of model_args tokenizer to ensure consistency + tokenizer = processor.tokenizer + image_grid_thw = [model_input.image_grid_thw for model_input in batch_model_input] + else: + image_grid_thw = None + # Do initial prefill vision_images = [ model_input.vision.images if model_input.vision else None for model_input in batch_model_input @@ -278,7 +314,7 @@ def test_multimodal_demo_text( total_lens = prefill_lens + max_gen_len # Create padded tokens tensor for batch - pad_id = tokenizer.pad_id + pad_id = tokenizer.pad_token_id if HF_MODEL else tokenizer.pad_id bsz = len(prompt_tokens) tokens = torch.full((bsz, max(total_lens)), pad_id, dtype=torch.long) @@ -302,6 +338,7 @@ def test_multimodal_demo_text( xattn_caches, total_lens, prefill_lens, + image_grid_thw=image_grid_thw, ) # Get cached prefill time @@ -319,6 +356,7 @@ def test_multimodal_demo_text( xattn_caches, total_lens, prefill_lens, + image_grid_thw=image_grid_thw, ) prefill_end = time.perf_counter() @@ -365,12 +403,16 @@ def test_multimodal_demo_text( ) # gen_idx is (num_tokens - 1) to avoid counting compile iter # Log full text output for each user in batch - vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256] + if HF_MODEL: + # For HF models, get vision tokens from the processor if they exist + vision_tokens = [] + else: + vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256] for user_id in range(max_batch_size): # Remove <|image|> tokens since they break the tokenizer tokens_out = [ - t if t not in vision_tokens else tokenizer.pad_id + t if t not in vision_tokens else pad_id for t in tokens[user_id].tolist()[: position_id[user_id] + 2] ] text = tokenizer.decode(tokens_out) diff --git a/models/tt_transformers/tests/multimodal/qwen_vl/test_image_attention.py b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_attention.py new file mode 100644 index 000000000000..b6fa77675ef1 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_attention.py @@ -0,0 +1,141 @@ +"""Test for Qwen 2.5 VL Vision Attention""" + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_attention import TtQwen2_5_VLVisionSdpaAttention +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_attention_inference(batch, num_chunks, mesh_device, reset_seeds): + dtype = ttnn.bfloat16 + pcc_required = 0.99 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "visual.blocks.0.attn." 
+ partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + dim = model_args.vision_dim + + reference_model = model_args.reference_vision_attention() + reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + hidden_size = model_args.vision_dim + n_heads = model_args.vision_attn_n_heads + head_dim = hidden_size // n_heads + seq_len = model_args.vision_chunk_ntok + + tt_model = TtQwen2_5_VLVisionSdpaAttention( + mesh_device, + state_dict, + state_dict_prefix=first_layer_prefix, + # weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=model_args, + ) + + seq_len = 4096 + hidden_dim = 1280 + num_heads = 16 + head_dim = hidden_dim // num_heads # 80 + rotary_dim = head_dim // 2 # 40 + + # Step 1: PyTorch input + pt_attention_input = torch.randn(seq_len, hidden_dim) # no batch dim + cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32) + + # Step 2: precompute cos/sin + cos, sin = precompute_rope_cos_sin(seq_len, head_dim) + + # Step 3: run PyTorch reference + reference_output = reference_model( + pt_attention_input, cu_seqlens, rotary_pos_emb=None, position_embeddings=(cos, sin) + ) + + # Step 4: TT input + tt_attention_input = model_args.prepare_residual_tensor_prefill( + pt_attention_input.unsqueeze(0), force_replicated=True + ) + + cos_tensor = ttnn.from_torch( + cos, + device=mesh_device, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + ) + sin_tensor = ttnn.from_torch( + sin, + device=mesh_device, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + ) + + # Step 6: run TT + tt_out = tt_model(tt_attention_input, cu_seqlens, position_embeddings=(cos_tensor, sin_tensor)) + + # Doing contract in tt is correct!! + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[ + : tt_out.shape[0], : + ] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" + + +def precompute_rope_cos_sin(seq_len: int, dim: int, theta: float = 10000.0): + """ + Precompute RoPE cos/sin tensors. 
+ Args: + seq_len: sequence length (number of tokens) + dim: hidden size (usually head_dim, not full hidden_size) + theta: RoPE theta parameter (default 10000) + Returns: + cos, sin: [seq_len, dim] each + """ + # Build the rope frequencies + half_dim = dim // 2 + freq_seq = torch.arange(half_dim, dtype=torch.float32) + inv_freq = 1.0 / (theta ** (freq_seq / half_dim)) + + # positions: [seq_len] + positions = torch.arange(seq_len, dtype=torch.float32) + + # Outer product: [seq_len, half_dim] + sinusoid_inp = torch.outer(positions, inv_freq) + + # Concatenate for complex dim + sin = torch.sin(torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)) + cos = torch.cos(torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)) + + return cos, sin diff --git a/models/tt_transformers/tests/multimodal/qwen_vl/test_image_block.py b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_block.py new file mode 100644 index 000000000000..f3e8e5b73645 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_block.py @@ -0,0 +1,126 @@ +""""Test for Qwen 2.5 VL Vision Transformer Block""" + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_block import TtQwen2_5_VLVisionBlock +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_transformer_inference(batch, num_chunks, mesh_device, reset_seeds): + dtype = ttnn.bfloat16 + pcc_required = 0.99 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "visual.blocks.0." 
+ partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + dim = model_args.vision_dim + + reference_model = model_args.reference_vision_block() + reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + vision_dim = model_args.vision_dim + n_heads = model_args.vision_attn_n_heads + head_dim = vision_dim // n_heads + seq_len = model_args.vision_chunk_ntok - 1 + + tt_model = TtQwen2_5_VLVisionBlock( + mesh_device, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + model_args=model_args, + dtype=dtype, + ) + + pt_attention_input = torch.randn(seq_len, vision_dim) # no batch dim + cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32) + + cos, sin = precompute_rope_cos_sin(seq_len, head_dim) + + reference_output = reference_model( + pt_attention_input, cu_seqlens, rotary_pos_emb=None, position_embeddings=(cos, sin) + ) + + tt_attention_input = model_args.prepare_residual_tensor_prefill( + pt_attention_input.unsqueeze(0), force_replicated=True + ) + + cos_tensor = ttnn.from_torch( + cos, + device=mesh_device, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + ) + sin_tensor = ttnn.from_torch( + sin, + device=mesh_device, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + ) + + tt_out = tt_model(tt_attention_input, cu_seqlens, position_embeddings=(cos_tensor, sin_tensor)) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[0, 0, :, :] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" + + +def precompute_rope_cos_sin(seq_len: int, dim: int, theta: float = 10000.0): + """ + Precompute RoPE cos/sin tensors. 
+ Args: + seq_len: sequence length (number of tokens) + dim: hidden size (usually head_dim, not full hidden_size) + theta: RoPE theta parameter (default 10000) + Returns: + cos, sin: [seq_len, dim] each + """ + # Build the rope frequencies + half_dim = dim // 2 + freq_seq = torch.arange(half_dim, dtype=torch.float32) + inv_freq = 1.0 / (theta ** (freq_seq / half_dim)) + + # positions: [seq_len] + positions = torch.arange(seq_len, dtype=torch.float32) + + # Outer product: [seq_len, half_dim] + sinusoid_inp = torch.outer(positions, inv_freq) + + # Concatenate for complex dim + sin = torch.sin(torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)) + cos = torch.cos(torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)) + + return cos, sin diff --git a/models/tt_transformers/tests/multimodal/qwen_vl/test_image_merger.py b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_merger.py new file mode 100644 index 000000000000..dd4bc5bae365 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_merger.py @@ -0,0 +1,104 @@ +""""Test for Qwen 2.5 VL Patch Merger""" + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_patch_merger import TTQwen2_5_VLPatchMerger +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_patch_merger_inference(seq_len, batch_size, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + reference_model = tt_model_args.reference_vision_qwen_merger() # Qwen Patch merger + first_layer_prefix = "visual.merger." 
+ + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model.load_state_dict(partial_state_dict) + + tt_model = TTQwen2_5_VLPatchMerger( + device=device, + dim=5120, + state_dict=state_dict, + weight_key=first_layer_prefix, + args=tt_model_args, + layer_num=None, + state_dict_prefix="", + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps=1e-06, + dims=3584, + context_dim=1280, + spatial_merge_size=2, + mode=mode, + ) + + input = torch.rand(1, 4, 1280) + reference_output = reference_model(input) + tt_input = ttnn.from_torch( + input, + device=device, + dtype=dtype, + mesh_mapper=ttnn.ReplicateTensorToMesh(device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + tt_input = ttnn.reshape(tt_input, [1, 4, 1280]) + + tt_output = tt_model(tt_input) + + # DistributedNorm outputs are replicated across devices + tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(device, dim=0)) + tt_output_torch = tt_output_torch[0, :] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("Merger Passed!") + else: + logger.warning("Merger Failed!") + + assert passing, f"Merger output does not meet PCC requirement {0.99}." diff --git a/models/tt_transformers/tests/multimodal/qwen_vl/test_image_mlp.py b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_mlp.py new file mode 100644 index 000000000000..f7b3dad7ba45 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_mlp.py @@ -0,0 +1,81 @@ +""""Test for Qwen 2.5 VL Vision MLP""" + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_mlp import QwenTTVisionMLP +from models.utility_functions import comp_allclose, comp_pcc, nearest_32, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_mlp_inference(batch, num_chunks, mesh_device, reset_seeds): + dtype = ttnn.bfloat16 + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "visual.blocks.0.feed_forward." 
+ partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + model_args.WEIGHTS_DTYPE = dtype + + dim = model_args.vision_dim + seq_len = nearest_32(model_args.vision_chunk_ntok) * num_chunks + reference_model = model_args.reference_vision_mlp() + reference_model.load_state_dict(partial_state_dict) + + tt_model = QwenTTVisionMLP( + mesh_device=mesh_device, + args=model_args, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + ) + + torch_input = torch.randn(1, batch, seq_len, dim) + + reference_output = reference_model(torch_input).squeeze() + + tt_input = ttnn.from_torch( + torch_input, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + ) + + tt_output = tt_model(tt_input) + + tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[ + 0, :, :, : + ].squeeze() + + pcc_required = 0.99 + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/qwen_vl/test_image_model.py b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_model.py new file mode 100644 index 000000000000..1d3d3313197a --- /dev/null +++ b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_model.py @@ -0,0 +1,81 @@ +"""Test for Qwen 2.5 VL Vision Transformer Pretrained Model Inference""" + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_vision_model import TtQwen2_5_VisionTransformerPretrainedModel +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 4),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_vision_inference(batch, num_chunks, mesh_device, reset_seeds): + dtype = ttnn.bfloat16 + pcc_required = 0.99 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "visual." 
+ partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model = model_args.reference_vision_model() + reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + n_layers = model_args.vision_n_layers + + tt_model = TtQwen2_5_VisionTransformerPretrainedModel( + mesh_device, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + model_args=model_args, + dtype=dtype, + layers=n_layers, + ) + + pt_input = torch.randn([32, 1176]) # no batch dim + grid_thw = torch.tensor([[1, 4, 8]]) + + reference_output = reference_model( + pt_input, + grid_thw, + ) + + tt_attention_input = model_args.prepare_residual_tensor_prefill(pt_input.unsqueeze(0), force_replicated=True) + tt_out = tt_model(tt_attention_input, grid_thw) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[ + : tt_out.shape[0], : + ] + + non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) + tt_output_torch = tt_output_torch[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/tt_transformers/tests/multimodal/qwen_vl/test_image_patch_embed.py b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_patch_embed.py new file mode 100644 index 000000000000..27c645a8b643 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_patch_embed.py @@ -0,0 +1,101 @@ +""""Test for Qwen 2.5 VL Patch Embed""" + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_patch_embed import TTQwen2_5_VisionPatchEmbed +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_embed_inference(seq_len, batch_size, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + reference_model = tt_model_args.reference_vision_qwen_patch_embed() # Qwen Patch embed + first_layer_prefix = "visual.patch_embed." 
+ + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model.load_state_dict(partial_state_dict) + + tt_model = TTQwen2_5_VisionPatchEmbed( + device=device, + args=tt_model_args, + patch_size=14, + temporal_patch_size=2, + in_channels=3, + embed_dim=1280, + state_dict=state_dict, + weight_key=first_layer_prefix, + layer_num=None, + state_dict_prefix="", + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + mode=mode, + ) + + input = torch.rand(1, 1, 1380, 1176) + reference_output = reference_model(input) + + tt_input = ttnn.from_torch( + input, + device=device, + dtype=dtype, + mesh_mapper=ttnn.ReplicateTensorToMesh(device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + tt_output = tt_model(tt_input) + + tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(device, dim=-1))[ + : tt_output.shape[0], : + ] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("Patch embed Passed!") + else: + logger.warning("Patch embed Failed!") + + assert passing, f"Patch embed output does not meet PCC requirement {0.99}." diff --git a/models/tt_transformers/tests/multimodal/qwen_vl/test_image_rms.py b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_rms.py new file mode 100644 index 000000000000..01fd4b6e1eef --- /dev/null +++ b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_rms.py @@ -0,0 +1,110 @@ +"""Test for Qwen 2.5 VL RMSNorm Layer Inference""" + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.distributed_norm import DistributedNorm +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_rmsnorm import RMSNorm +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_rmsnorm_inference(seq_len, batch_size, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + dim = tt_model_args.vision_dim + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + reference_model = tt_model_args.reference_vision_rms_norm() # Qwen2_5 RMSNorm + first_layer_prefix = "visual.blocks.0.norm1." 
+ + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model.load_state_dict(partial_state_dict) + + tt_inner_norm = RMSNorm( + device=device, + dim=dim, + state_dict=state_dict, + state_dict_prefix="", + weight_key=first_layer_prefix[:-1], # Remove trailing dot + weight_dtype=dtype, + is_distributed=False, + sharded_program_config=tt_model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=tt_model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], + ) + + # Wrap it in DistributedNorm + tt_model = DistributedNorm(tt_inner_norm, tt_model_args, TG=tt_model_args.is_galaxy) + + input = torch.rand(1, 1, 1280) + + reference_output = reference_model(input) + + # DistributedNorm inputs are fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode) + tt_input = ttnn.from_torch( + input, + device=device, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, -1), mesh_shape=tt_model_args.cluster_shape), + memory_config=( + tt_model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG + ), + ) + + tt_output = tt_model(tt_input, mode=mode) + + # DistributedNorm outputs are replicated across devices + tt_output_torch = ttnn.to_torch( + tt_output, + mesh_composer=ttnn.ConcatMesh2dToTensor( + device, dims=(0, 2) if tt_model_args.is_galaxy else (2, 0), mesh_shape=tt_model_args.cluster_shape + ), + )[:1, :, :] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("rms_norm Passed!") + else: + logger.warning("rms_norm Failed!") + + assert passing, f"rms_norm output does not meet PCC requirement {0.99}." 
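Note on the RMSNorm tests above: they validate the TT output against the HF reference module
(reference_vision_rms_norm) using the 0.99 PCC check. A minimal PyTorch sketch of the standard
RMSNorm computation that reference follows (illustrative helper, not part of this patch; eps
assumed to match the 1e-6 used elsewhere in the patch):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then scale by the learned weight.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight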
diff --git a/models/tt_transformers/tt/common.py b/models/tt_transformers/tt/common.py index 5eebf47ce735..a8302a71e7d5 100644 --- a/models/tt_transformers/tt/common.py +++ b/models/tt_transformers/tt/common.py @@ -5,9 +5,11 @@ import math import re from enum import Enum +from types import SimpleNamespace from typing import Optional import torch +from llama_models.llama3.api.datatypes import ImageMedia from loguru import logger from pydantic import BaseModel, Field @@ -614,3 +616,46 @@ def create_tt_model( tt_kv_cache = [l.attention.layer_past for l in model.layers] if paged_attention_config else None return tt_model_args, model, tt_kv_cache, state_dict + + +def hf_multimodal_encode(messages, processor): + hf_messages = [] + + for msg in messages: + hf_content = [] + + for item in msg.content: + if isinstance(item, ImageMedia): + hf_content.append( + { + "type": "image", + "image": item.image, + } + ) + elif isinstance(item, str): + hf_content.append( + { + "type": "text", + "text": item, + } + ) + + hf_messages.append( + { + "role": msg.role, + "content": hf_content, + } + ) + + encoded = processor.apply_chat_template( + hf_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to("cpu", dtype=torch.bfloat16) + + return SimpleNamespace( + **encoded, + tokens=encoded["input_ids"].squeeze(0), + vision=SimpleNamespace( + images=encoded["pixel_values"], + mask=None, + ), + ) diff --git a/models/tt_transformers/tt/decoder.py b/models/tt_transformers/tt/decoder.py index d86520611b5f..905df8e8de17 100644 --- a/models/tt_transformers/tt/decoder.py +++ b/models/tt_transformers/tt/decoder.py @@ -4,7 +4,6 @@ import ttnn from models.common.lightweightmodule import LightweightModule from models.common.rmsnorm import RMSNorm -from models.experimental.qwen25_vl.tt.text_mlp import MLP as QwenMLP from models.tt_transformers.tt.attention import Attention as DefaultAttention from models.tt_transformers.tt.distributed_norm import DistributedNorm from models.tt_transformers.tt.mlp import MLP @@ -41,6 +40,7 @@ def __init__( self.current = 0 self.model_config = args.get_model_config() self.simplified_rms = True if self.args.base_model_name == "Qwen2.5-VL-7B" else False + self.simplified_rms = True if self.args.base_model_name == "Qwen2.5-VL-7B" else False self.layer_num = layer_num @@ -57,26 +57,16 @@ def __init__( paged_attention_config=paged_attention_config, use_paged_kv_cache=use_paged_kv_cache, ) - if self.args.base_model_name == "Qwen2.5-VL-7B": - self.feed_forward = QwenMLP( - mesh_device=mesh_device, - args=args, - state_dict=state_dict, - weight_cache_path=weight_cache_path, - layer_num=layer_num, - dtype=dtype, - model_config=self.model_config, - ) - else: - self.feed_forward = MLP( - mesh_device=mesh_device, - args=args, - state_dict=state_dict, - weight_cache_path=weight_cache_path, - layer_num=layer_num, - dtype=dtype, - model_config=self.model_config, - ) + + self.feed_forward = MLP( + mesh_device=mesh_device, + args=args, + state_dict=state_dict, + weight_cache_path=weight_cache_path, + layer_num=layer_num, + dtype=dtype, + model_config=self.model_config, + ) self.attention_norm = DistributedNorm( RMSNorm( device=mesh_device, diff --git a/models/tt_transformers/tt/generator.py b/models/tt_transformers/tt/generator.py index cd0620049b91..c6dd416d376c 100644 --- a/models/tt_transformers/tt/generator.py +++ b/models/tt_transformers/tt/generator.py @@ -22,6 +22,7 @@ get_padded_prefill_len, num_blocks_in_seq, ) +from models.tt_transformers.tt.model_config import 
CheckpointType @dataclass(frozen=True) @@ -79,6 +80,7 @@ def prefill_forward_text( seq_len = int(prompt_lens[idx]) last_token_idx = seq_len - 1 prefill_seq_len = get_padded_prefill_len(seq_len) + local_kwargs = kwargs.copy() # Avoid modifying original kwargs logger.info(f"Prefilling User {user_id + 1} up to {seq_len} tokens") @@ -94,6 +96,12 @@ def prefill_forward_text( ) model_kv_cache = kv_cache[model_id] if kv_cache is not None else None + # Check if 'pixel_values' exists and index it safely + if "pixel_values" in local_kwargs: + local_kwargs["pixel_values"] = local_kwargs["pixel_values"][idx] + if "image_grid_thw" in local_kwargs: + local_kwargs["image_grid_thw"] = local_kwargs["image_grid_thw"][idx] + logits = self.prefill_forward_single_user_text( prefill_ids, page_table=page_table_user, @@ -101,7 +109,7 @@ def prefill_forward_text( last_token_idx=last_token_idx, kv_cache=model_kv_cache, model_id=model_id, - **kwargs, + **local_kwargs, ) out_list.append(logits) @@ -491,6 +499,61 @@ def _prefill_forward_single_user( # Note: This function is called by vLLM def prefill_forward( + self, + vision_images, + vision_masks, + tokens, + xattn_caches, + total_lens, + prompt_lens, + page_table=None, + kv_cache=None, + cross_page_table=None, + empty_slots=None, + **kwargs, + ): + if self.model_args[0].checkpoint_type == CheckpointType.HuggingFace: + logits = self.prefill_forward_text( + tokens, + page_table=page_table, + kv_cache=kv_cache, + prompt_lens=prompt_lens, + pixel_values=vision_images, + **kwargs, + ) + + return logits, None, None, None, None + + else: + ( + output_logits, + prefill_output_xattn_masks, + prefill_output_full_text_row_masked_out_masks, + decode_output_xattn_masks, + decode_output_full_text_row_masked_out_masks, + ) = self.prefill_forward_llama_vision( + vision_images, + vision_masks, + tokens, + xattn_caches, + total_lens, + prompt_lens, + page_table=page_table, + kv_cache=kv_cache, + cross_page_table=cross_page_table, + empty_slots=empty_slots, + ) + + return ( + output_logits, + prefill_output_xattn_masks, + prefill_output_full_text_row_masked_out_masks, + decode_output_xattn_masks, + decode_output_full_text_row_masked_out_masks, + ) + + # Note: This function is called by vLLM + def prefill_forward_llama_vision( self, vision_images, vision_masks, @@ -587,7 +650,7 @@ def prefill_forward( ) # Note: This function is called by vLLM - def decode_forward( + def decode_forward_llama_vision( self, start_pos, tokens, @@ -651,6 +714,47 @@ def decode_forward( else: return tt_logits + def decode_forward( + self, + start_pos, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + xattn_caches=None, + page_table=None, + kv_cache=None, + cross_page_table=None, + enable_trace=True, + read_from_device=True, + ): + import os + + if os.environ.get("HF_MODEL"): + return self.decode_forward_text( + tokens, + start_pos, + enable_trace=enable_trace, + page_table=page_table, + kv_cache=kv_cache, + ) + else: + return self.decode_forward_llama_vision( + start_pos, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + xattn_caches, + page_table, + kv_cache, + cross_page_table, + enable_trace, + read_from_device, + ) + # Note: This function is called by vLLM def read_decode_output(self, tt_out, unpadded_batch, is_tokens=False): """ diff --git a/models/tt_transformers/tt/load_checkpoints.py 
b/models/tt_transformers/tt/load_checkpoints.py index 38e68df53e2b..1214d18fec67 100644 --- a/models/tt_transformers/tt/load_checkpoints.py +++ b/models/tt_transformers/tt/load_checkpoints.py @@ -213,6 +213,12 @@ def convert_meta_to_hf(state_dict, head_dim): return state_dict +def convert_vision_meta_to_hf(state_dict, head_dim): + # state_dict = convert_meta_qkv_to_hf_format(state_dict, head_dim) + state_dict = map_vision_meta_to_hf_keys(state_dict) + return state_dict + + def replace_keys(state_dict, replacements): """ Replacements are in the form (pattern, replacement). @@ -238,6 +244,7 @@ def map_hf_to_meta_keys(loaded_weights): """ replacements = [ ("^emb.weight", "weight"), + ("model.language_model.", ""), ("language.model.", ""), ("model.", ""), ("embed_tokens", "tok_embeddings"), diff --git a/models/tt_transformers/tt/model.py b/models/tt_transformers/tt/model.py index 8fce82afe8bd..36baaa66363e 100644 --- a/models/tt_transformers/tt/model.py +++ b/models/tt_transformers/tt/model.py @@ -360,6 +360,7 @@ def forward( chunk_start_idx=None, get_last_token=-1, kv_cache=None, + **kwargs, ): for i, layer in enumerate(self.layers): # No-op if callers already provide the right memory config diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index 33641b93a268..95a51be08fe6 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -1725,23 +1725,18 @@ def load_state_dict(self): else: assert self.checkpoint_type == CheckpointType.HuggingFace if self.from_hf_url: - # Special case Qwen2.5-VL models until they are fully integrated into a HF release - if "Qwen2.5-VL" in self.model_name: - from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VLForConditionalGeneration as AutoModelForCausalLM, - ) + from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText - print("Loading Qwen2.5-VL model: ", AutoModelForCausalLM) + if "Qwen2.5-VL-7B" in self.model_name: + model = AutoModelForImageTextToText.from_pretrained(self.CKPT_DIR, torch_dtype="auto") else: - from transformers import AutoModelForCausalLM - - model = AutoModelForCausalLM.from_pretrained( - self.CKPT_DIR, - torch_dtype="auto" - # Note that the default setting is torch.dtype.float32, but model weights are - # may come in any dtype. If the model's weights are in torch.dtype.bfloat16, this would result in 2x memory usage from an - # unnecessary cast. - ) + model = AutoModelForCausalLM.from_pretrained( + self.CKPT_DIR, + torch_dtype="auto" + # Note that the default setting is torch.dtype.float32, but model weights are + # may come in any dtype. If the model's weights are in torch.dtype.bfloat16, this would result in 2x memory usage from an + # unnecessary cast. + ) if self.cache_hf_flag: self.cached_hf_model = model state_dict = model.state_dict() @@ -1753,7 +1748,7 @@ def load_state_dict(self): state_dict = standardize_hf_keys_multimodal(state_dict) else: state_dict = standardize_hf_keys(state_dict) - state_dict = convert_hf_to_meta(state_dict, self.head_dim) + state_dict = convert_hf_to_meta(state_dict, self.head_dim) keys_dict = list(state_dict.keys())[:] remv = [f"layers.{i}." 
for i in list(range(self.n_layers, self.full_model_n_layers))] @@ -2096,6 +2091,7 @@ def create_tokenizer(self): "Qwen2.5-1.5B": "Qwen/Qwen2.5-1.5B-Instruct", "Qwen2.5-3B": "Qwen/Qwen2.5-3B-Instruct", "Qwen2.5-7B": "Qwen/Qwen2.5-7B-Instruct", + "Qwen2.5-VL-7B": "Qwen/Qwen2.5-VL-7B-Instruct", "Qwen2.5-14B": "Qwen/Qwen2.5-14B-Instruct", "Qwen2.5-32B": "Qwen/Qwen2.5-32B-Instruct", "Qwen2.5-72B": "Qwen/Qwen2.5-72B-Instruct", diff --git a/models/tt_transformers/tt/multimodal/qwen_vl/qwen_e2e_model.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_e2e_model.py new file mode 100644 index 000000000000..e8cc53f7c3a6 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_e2e_model.py @@ -0,0 +1,132 @@ +from typing import List + +import torch + +import ttnn +from models.tt_transformers.tt.model import Transformer +from models.tt_transformers.tt.multimodal.llama_vision_model import _stack_images +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_vision_model import TtQwen2_5_VisionTransformerPretrainedModel + + +def _stack_images( + images: List[List[torch.Tensor]], # batch of samples, each with list of image embeddings +) -> List[torch.Tensor]: + """ + Concatenate image embeddings per sample into a single 2D tensor. + + Args: + images: List of samples, each being a list of [num_patches, hidden_dim] tensors + + Returns: + List of [total_patches, hidden_dim] tensors, one per sample + """ + return [torch.cat(image_list, dim=0) for image_list in images] + + +class TtQwen_Model(Transformer): + def __init__( + self, + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=None, + use_paged_kv_cache=False, + ): + super().__init__( + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=paged_attention_config, + use_paged_kv_cache=use_paged_kv_cache, + ) + + self.vision_model = TtQwen2_5_VisionTransformerPretrainedModel( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix="visual.", + dtype=dtype, + model_args=args, + weight_cache_path=args.weight_cache_path(dtype), + layers=args.vision_n_layers, + ) + + def prepare_inputs_prefill(self, pt_tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): + """ + Inputs are torch tensors or python types. This function returns ttnn + tensors on device. 
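+        For Qwen2.5-VL this also runs the vision tower (compute_vision_token) and scatters its
+        output into the token embeddings at the image placeholder positions (token id 151655)
+        before the result is re-uploaded as the prefill residual tensor.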
+ TODO: Debate whether this function is responsible for padding + """ + + S = pt_tokens.shape[-1] + tokens = ttnn.from_torch( + pt_tokens.reshape(1, 1, 1, -1), + device=self.mesh_device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + # self.embed_scale = args.dim**0.5 + tokens_embd = self.embd(tokens) + # tokens_embd = ttnn.multiply(tokens_embd, self.embed_scale) + + vision_output = self.compute_vision_token(**kwargs) + + tokens_embd = ttnn.to_torch(tokens_embd, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1)) + comp_vision_output = ttnn.to_torch( + vision_output, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0) + )[: vision_output.shape[0], :] + + image_features = comp_vision_output.squeeze(0) + special_image_mask = (pt_tokens == 151655).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(tokens_embd) + image_features = image_features.to(tokens_embd.device, tokens_embd.dtype) + tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features) + + tokens_embd = self.args.prepare_residual_tensor_prefill( + tokens_embd, + ) + + tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) + # Slice the rot mats to the prefill seqlen + assert ( + self.rope_setup.cos_matrix.shape[2] >= start_pos + S + ), f"Padded prefill end idx {start_pos + S} exceeds max seq len {self.rope_setup.cos_matrix.shape[2]}" + + tt_rot_mats_prefill_global = [ + self.rope_setup.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + + if page_table is not None: + tt_page_table = ttnn.from_torch( + page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_page_table = None + + if chunk_page_table is not None: + tt_chunk_page_table = ttnn.from_torch( + chunk_page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_chunk_page_table = None + + return tokens_embd, tt_rot_mats_prefill_global, tt_page_table, tt_chunk_page_table + + def compute_vision_token(self, pixel_values, image_grid_thw): + pixel_values = self.args.prepare_residual_tensor_prefill(pixel_values.unsqueeze(0), force_replicated=True) + + vision_output = self.vision_model(pixel_values, image_grid_thw) + return vision_output diff --git a/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_attention.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_attention.py new file mode 100644 index 000000000000..bd8dabf68676 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_attention.py @@ -0,0 +1,160 @@ +""" +This is the vision attention implementation for Qwen-VL-7B. + +We couldn't reuse the LLaMA version from tt_transformers because it expects separate q, k, v weights, +but Qwen-VL uses fused qkv weights. 
So this has been rewritten to support that, +based on the original code at: +models/tt_transformers/tt/multimodal/llama_image_attention.py +""" + + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +def rotate_half(x): + x1 = ttnn.slice(x, (0, 0, 0), (x.shape[0], x.shape[1], x.shape[2] // 2)) + x2 = ttnn.slice(x, (0, 0, x.shape[-1] // 2), (x.shape[0], x.shape[1], x.shape[2])) + return ttnn.concat([ttnn.mul(x2, -1, use_legacy=False), x1], dim=-1) + + +def apply_rotary_pos_emb_vision_tt(q, k, cos, sin): + cos = ttnn.unsqueeze(cos, -2) + sin = ttnn.unsqueeze(sin, -2) + + q_embed = ttnn.add(ttnn.mul(q, cos), ttnn.mul(rotate_half(q), sin)) + k_embed = ttnn.add(ttnn.mul(k, cos), ttnn.mul(rotate_half(k), sin)) + return q_embed, k_embed + + +class TtQwen2_5_VLVisionSdpaAttention(LightweightModule): + def __init__(self, mesh_device, state_dict, state_dict_prefix, dtype, configuration): + super().__init__() + + self.mesh_device = mesh_device + self.dtype = dtype + self.hidden_size = 1280 + self.num_heads = 16 + self.head_dim = self.hidden_size // self.num_heads + self.scale = self.head_dim**-0.5 + self.configuration = configuration + + # Load qkv weight & bias (fused): shape [hidden_size, hidden_size*3] + qkv_weight = state_dict[f"{state_dict_prefix}qkv.weight"] + qkv_bias = state_dict[f"{state_dict_prefix}qkv.bias"] + + # Transpose to [hidden_size, 3*hidden_size] for matmul + self.qkv_weight = ttnn.as_tensor( + torch.transpose(qkv_weight, -2, -1), + device=mesh_device, + dtype=dtype, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + self.qkv_bias = ttnn.as_tensor( + qkv_bias, + device=mesh_device, + dtype=dtype, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + # Output projection: proj + proj_weight = state_dict[f"{state_dict_prefix}proj.weight"] # shape [hidden_size, hidden_size] + proj_bias = state_dict[f"{state_dict_prefix}proj.bias"] # shape [hidden_size] + + self.proj_weight = ttnn.as_tensor( + torch.transpose(proj_weight, -2, -1), + device=mesh_device, + dtype=dtype, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + self.proj_bias = ttnn.as_tensor( + proj_bias, + device=mesh_device, + dtype=dtype, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=True, + packer_l1_acc=True, + dst_full_sync_en=False, + ) + + def forward(self, hidden_states, cu_seqlens, position_embeddings): + """ + hidden_states: ttnn.Tensor of shape [batch, seq_len, hidden_size] + position_embeddings: tuple (cos, sin) each of shape [seq_len, head_dim] + """ + seq_len = hidden_states.shape[-2] + cos, sin = position_embeddings + # Fused qkv projection + qkv = ttnn.linear( + hidden_states, + self.qkv_weight, + bias=self.qkv_bias, + dtype=self.dtype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) # shape [batch, seq_len, hidden_size*3] + + if self.configuration.num_devices > 1: + qkv = ttnn.all_gather(qkv, dim=-1, num_links=1) + + (q, k, v) = ttnn.permute(ttnn.reshape(qkv, [seq_len, 3, self.num_heads, -1]), [1, 0, 2, 3]) + ttnn.deallocate(qkv) + + # Apply rotary position embeddings + 
q, k = apply_rotary_pos_emb_vision_tt(q, k, cos, sin) + # return q + + seq_len = cu_seqlens[-1].item() + + q = ttnn.unsqueeze(ttnn.permute(ttnn.pad(q, [(0, 0), (0, 0), (0, 16)], 0), [1, 0, 2]), 0) + k = ttnn.unsqueeze(ttnn.permute(ttnn.pad(k, [(0, 0), (0, 0), (0, 16)], 0), [1, 0, 2]), 0) + v = ttnn.unsqueeze(ttnn.permute(ttnn.pad(v, [(0, 0), (0, 0), (0, 16)], 0), [1, 0, 2]), 0) + + attn_output = ttnn.transformer.scaled_dot_product_attention( + q, k, v, is_causal=False, scale=self.scale + ) # shape [1, seq_len, num_heads, head_dim] + + ttnn.deallocate(q) + ttnn.deallocate(k) + ttnn.deallocate(v) + + # attn_output shape: [1, 16, 4096, 96] + # Need to slice back from 96 → 80 + attn_output = ttnn.slice( + attn_output, + (0, 0, 0, 0), + (attn_output.shape[0], attn_output.shape[1], attn_output.shape[2], self.head_dim), # head_dim=80 + ) + + attn_output = ttnn.permute(ttnn.squeeze(attn_output, 0), [1, 0, 2]) + attn_output = ttnn.reshape(attn_output, [seq_len, -1]) + + # Final projection + output = ttnn.linear( + attn_output, + self.proj_weight, + bias=self.proj_bias, + dtype=self.dtype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + ttnn.deallocate(attn_output) + + if self.configuration.num_devices > 1: + output = ttnn.all_gather(output, dim=1, num_links=1) + + return output diff --git a/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_block.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_block.py new file mode 100644 index 000000000000..14f81eae0263 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_block.py @@ -0,0 +1,79 @@ +""" +This is the vision block used in the Qwen-VL-7B architecture +consisting of RMSnorm and self-attention layer followed by an MLP layer. 
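+The residual structure follows the HF reference block:
+    x = x + attn(norm1(x)); x = x + mlp(norm2(x))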
+""" + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_attention import TtQwen2_5_VLVisionSdpaAttention +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_mlp import QwenTTVisionMLP +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_rmsnorm import RMSNorm + + +class TtQwen2_5_VLVisionBlock(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + dtype, + model_args, + weight_cache_path=None, + state_dict_prefix=None, + ): + super().__init__() + + self.norm1 = RMSNorm( + device=mesh_device, + dim=1280, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="norm1", + weight_dtype=dtype, + is_distributed=False, + sharded_program_config=model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], + ) + + self.norm2 = RMSNorm( + device=mesh_device, + dim=1280, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="norm2", + weight_dtype=dtype, + is_distributed=False, + sharded_program_config=model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], + ) + + self.attn = TtQwen2_5_VLVisionSdpaAttention( + mesh_device, + state_dict, + state_dict_prefix=f"{state_dict_prefix}attn.", + # weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=model_args, + ) + + self.mlp = QwenTTVisionMLP( + mesh_device=mesh_device, + args=model_args, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}feed_forward.", + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + ) + + def forward(self, hidden_states, cu_seqlens, position_embeddings): + hidden_states = ttnn.add( + hidden_states, + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ), + ) + + hidden_states = ttnn.add(hidden_states, self.mlp(self.norm2(hidden_states))) + + return hidden_states diff --git a/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_mlp.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_mlp.py new file mode 100644 index 000000000000..ec13ff554e85 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_mlp.py @@ -0,0 +1,128 @@ +""" +This is the MLP (feed-forward) implementation for Qwen-VL-7B. + +We couldn't reuse TtLlamaImageFeedForward from tt_transformers because the logic is different. +Qwen does: down_proj(act_fn(gate_proj(x)) * up_proj(x)) +Tt does: c_proj(activation(c_fc(x))) + +So this version was written specifically for Qwen, based on its architecture. 
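+act_fn here is SiLU; in the TT code below it is fused into the gate projection (w1) through the
+activation argument of ttnn.linear, so the element-wise product with up_proj (w3) follows directly.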
+""" + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +class QwenTTVisionMLP(LightweightModule): + def __init__( + self, + mesh_device, + args, + state_dict, + weight_cache_path, + dtype, + state_dict_prefix=None, + ): + super().__init__() + + self.mesh_device = mesh_device + self.args = args + self.state_dict = state_dict + self.dim = args.dim + + def get_weight(name): + return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) + + def get_bias(name): + return state_dict[f"{state_dict_prefix}{name}.bias"] + + def cache_name(name): + if args.dummy_weights: + return None + return weight_cache_path / f"{state_dict_prefix}.{name}" + + def as_tensor(name, dtype, is_bias=False): + tensor_data = get_bias(name) if is_bias else get_weight(name) + return ttnn.as_tensor( + tensor_data, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + # cache_file_name=cache_name(name), + ) + + # Weights and Biases + self.w1 = as_tensor("w1", dtype) + self.b1 = as_tensor("w1", ttnn.bfloat16, is_bias=True) + + self.w3 = as_tensor("w3", dtype) + self.b3 = as_tensor("w3", ttnn.bfloat16, is_bias=True) + + self.w2 = as_tensor("w2", dtype) + self.b2 = as_tensor("w2", ttnn.bfloat16, is_bias=True) + + self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=True, + packer_l1_acc=True, + dst_full_sync_en=False, + ) + + def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: + """ + Qwen HF MLP reference: + output = down_proj(act_fn(gate_proj(x)) * up_proj(x)) + Mapping: + w1 -> gate_proj + w3 -> up_proj + w2 -> down_proj + """ + + # Linear with GELU activation + w1_out = ttnn.linear( + x, + self.w1, + bias=self.b1, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + activation="silu", + compute_kernel_config=self.compute_kernel_config, + ) + + w3_out = ttnn.linear( + x, + self.w3, + bias=self.b3, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + + if self.args.num_devices > 1: + w1_out = ttnn.all_gather(w1_out, dim=3, num_links=1) + w3_out = ttnn.all_gather(w3_out, dim=3, num_links=1) + + # Element-wise multiply + w2_in = ttnn.mul(w1_out, w3_out, dtype=ttnn.bfloat16) + + # Final projection + w2_out = ttnn.linear( + w2_in, + self.w2, + bias=self.b2, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + + ttnn.deallocate(w1_out) + ttnn.deallocate(w3_out) + ttnn.deallocate(w2_in) + + if self.args.num_devices > 1: + w2_out = ttnn.all_gather(w2_out, dim=len(w2_out.shape) - 1, num_links=1) + + return w2_out diff --git a/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_patch_embed.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_patch_embed.py new file mode 100644 index 000000000000..2cdc3e679103 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_patch_embed.py @@ -0,0 +1,68 @@ +""" +This is the patch embedding implementation for Qwen-VL-7B. + +The existing TtLlamaConv2dPatch from tt_transformers uses Conv2d, but Qwen needs Conv3d instead. +Since the stride size is the same as the kernel size for this operation, we can use a matrix +multiplication (matmul) instead of a convolution. This is necessary because +`ttnn.experimental.conv3d` currently only supports Conv3d with stride (1, 1, 1). 
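+With in_channels=3, temporal_patch_size=2 and patch_size=14, each patch flattens to
+3 * 2 * 14 * 14 = 1176 values, so the Conv3d kernel can be viewed as a [1280, 1176] matrix and
+applied as one matmul. A PyTorch sketch of the equivalence (names illustrative, shapes from this file):
+    w = conv3d_weight.view(embed_dim, -1)   # [1280, 1176]
+    out = patches.view(-1, 1176) @ w.T      # [num_patches, embed_dim]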
+""" + +import ttnn + + +class TTQwen2_5_VisionPatchEmbed: + def __init__( + self, + device, + args, + patch_size, + temporal_patch_size, + in_channels, + embed_dim, + state_dict, + weight_key, + layer_num=None, + state_dict_prefix="", + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + mode="decode", + ): + super().__init__() + self.mode = mode + self.device = device + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + self.weight_memory_config = weight_memory_config + self.weight_dtype = weight_dtype + self.args = args + + weight_name_1 = f"{state_dict_prefix}{weight_key}proj.weight" + torch_weight = state_dict[weight_name_1] + + weight_matrix = torch_weight.view(self.embed_dim, -1) + self.weight = ttnn.from_torch( + weight_matrix.T, + device=self.device, + dtype=self.weight_dtype, + mesh_mapper=ttnn.ShardTensorToMesh(self.device, dim=-1), + layout=ttnn.TILE_LAYOUT, + memory_config=self.weight_memory_config, + ) + self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=True, + packer_l1_acc=True, + dst_full_sync_en=False, + ) + + def __call__(self, x: ttnn.Tensor) -> ttnn.Tensor: + x_flattened = ttnn.reshape(x, (x.shape[2], -1)) + output = ttnn.matmul(x_flattened, self.weight, compute_kernel_config=self.compute_kernel_config) + + if self.args.num_devices > 1: + output = ttnn.all_gather(output, dim=1, num_links=1) + + return output diff --git a/models/tt_transformers/tt/multimodal/qwen_vl/qwen_patch_merger.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_patch_merger.py new file mode 100644 index 000000000000..71cd4b4c02c2 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_patch_merger.py @@ -0,0 +1,116 @@ +""" +This is the patch merger implementation used in the Qwen-VL-7B model. + +There's no existing implementation for this in tt_transformers, +so it was written specifically based on Qwen-VL's architecture. 
+""" + +import ttnn +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_rmsnorm import RMSNorm + + +class TTQwen2_5_VLPatchMerger: + def __init__( + self, + device, + dim, + state_dict, + weight_key, + args, + layer_num=None, + state_dict_prefix="", + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps: float = 1e-06, + dims=3584, + context_dim=1280, + spatial_merge_size=2, + mode="decode", + ): + super().__init__() + self.eps = eps + self.mode = mode + + self.args = args + + weight_name_1 = f"{state_dict_prefix}{weight_key}ln_q.weight" + weight_name_2 = f"{state_dict_prefix}{weight_key}feed_forward.0.weight" + weight_name_3 = f"{state_dict_prefix}{weight_key}feed_forward.2.weight" + + self.weight_1 = ttnn.as_tensor( + state_dict[weight_name_1], + device=device, + dtype=weight_dtype, + mesh_mapper=ttnn.ShardTensorToMesh(device, dim=-1), + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=weight_memory_config, + ) + + self.weight_2 = ttnn.as_tensor( + state_dict[weight_name_2].transpose(0, 1), + device=device, + dtype=weight_dtype, + mesh_mapper=ttnn.ShardTensorToMesh(device, dim=-1), + layout=ttnn.TILE_LAYOUT, + memory_config=weight_memory_config, + ) + + self.weight_3 = ttnn.as_tensor( + state_dict[weight_name_3].transpose(0, 1), + device=device, + dtype=weight_dtype, + mesh_mapper=ttnn.ShardTensorToMesh(device, dim=-1), + layout=ttnn.TILE_LAYOUT, + memory_config=weight_memory_config, + ) + + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = RMSNorm( + device=device, + dim=1280, + state_dict=state_dict, + state_dict_prefix="", + weight_key="visual.merger.ln_q", + weight_dtype=ttnn.bfloat16, + is_distributed=False, + sharded_program_config=self.args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=False, + ) + + self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=True, + packer_l1_acc=True, + dst_full_sync_en=False, + ) + + def __call__(self, x): + x = self.ln_q(x, mode=self.mode) + x = ttnn.reshape(x, (1, 1, -1, self.hidden_size)) + + x = ttnn.linear( + x, + self.weight_2, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + + if self.args.num_devices > 1: + x = ttnn.all_gather(x, dim=3) + + x = ttnn.gelu(x) + + x = ttnn.linear( + x, + self.weight_3, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + if self.args.num_devices > 1: + x = ttnn.all_gather(x, dim=3, num_links=1) + + x = ttnn.reshape(x, (-1, x.shape[-1])) + + return x diff --git a/models/tt_transformers/tt/multimodal/qwen_vl/qwen_rmsnorm.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_rmsnorm.py new file mode 100644 index 000000000000..c0a32474fa49 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_rmsnorm.py @@ -0,0 +1,138 @@ +""" +This is a modified RMSNorm implementation for Qwen-VL-7B. + +It's based on the existing RMSNorm in models/common/rmsnorm.py, +with slight changes to support the bf8 data type. 
+""" + + +import ttnn +from models.common.lightweightmodule import LightweightModule + +TILE = 32 +SHARD_HEIGHT = TILE # Current ttnn.rms_norm implementation requires shard height to be a single tile + + +class RMSNorm(LightweightModule): + def __init__( + self, + device, + dim, + state_dict, + weight_key, + layer_num=None, + state_dict_prefix=None, + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps: float = 1e-06, + sharded_program_config=None, + sharded_output_config=None, + output_mem_config=None, + ccl_topology=ttnn.Topology.Ring, + ): + super().__init__() + self.eps = eps + self.is_distributed = is_distributed + self.ccl_topology = ccl_topology + + if state_dict_prefix: + weight_name = f"{state_dict_prefix}{weight_key}.weight" + else: + if layer_num is None: + weight_name = f"{weight_key}.weight" + else: + weight_name = f"layers.{layer_num}.{weight_key}.weight" + torch_weight = ( + state_dict[weight_name].unsqueeze(0).view(1, 1, dim).reshape([1, 1, dim // SHARD_HEIGHT, SHARD_HEIGHT]) + ) + + cache_name = None if weight_cache_path is None else weight_cache_path / weight_name + + # Compatibility with models that don't use mesh devices (e.g. single-chip Mistral-7b) + is_mesh_device = device.__class__.__name__ == "MeshDevice" + + self.weight = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=weight_memory_config, + # cache_file_name=cache_name, + mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + ) + + if self.is_distributed: + self.weight_distributed = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=weight_memory_config, + # cache_file_name=cache_name, + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape)) + if is_mesh_device + else None, + ) + + self.sharded_output_config = sharded_output_config + self.sharded_program_config = sharded_program_config + self.output_mem_config = output_mem_config + + self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=False, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + def forward(self, x: ttnn.Tensor, mode="decode", in_sharded=False, out_sharded=False) -> ttnn.Tensor: + # If input is sharded do sharded RMSNorm and optionally return sharded output + program_config = self.sharded_program_config if in_sharded else None + memory_config = self.sharded_output_config if out_sharded else None + distributed = self.is_distributed and self.is_distributed(mode) + norm = self._distributed_rmsnorm + weight = self.weight_distributed if distributed else self.weight + + if in_sharded: + assert not distributed, "Distributed RMSNorm does not support sharded inputs" + else: + assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor" + + x = norm( + x, + epsilon=self.eps, + weight=weight, + program_config=program_config, + memory_config=memory_config, + compute_kernel_config=self.compute_kernel_config_hifi2, + ) + + if in_sharded and not out_sharded: + return ttnn.sharded_to_interleaved(x) + else: + return x + + def _distributed_rmsnorm( + self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None + ): + xnorm = ttnn.pow(inp, 2) + + xnorm = ttnn.mean(xnorm, dim=-1, keepdim=True) + + xnorm = ttnn.rsqrt(xnorm + epsilon) + + xnorm = 
ttnn.multiply(inp, xnorm) + + weight = ttnn.reshape(weight, [1, 1, -1]) + + output = ttnn.multiply(xnorm, (weight), use_legacy=False) + + if memory_config is not None: + output = ttnn.to_memory_config(output, memory_config) + + ttnn.deallocate(xnorm) + ttnn.deallocate(weight) + + return output diff --git a/models/tt_transformers/tt/multimodal/qwen_vl/qwen_vision_model.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_vision_model.py new file mode 100644 index 000000000000..9a1a7f16b9d3 --- /dev/null +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_vision_model.py @@ -0,0 +1,250 @@ +""" +This is the end-to-end architecture of the Qwen-VL 2.5 vision model. + +It brings together all components—patch embedding, vision blocks, rotary embeddings, +and patch merger for visual input processing. +""" + +import torch +import torch.nn.functional as F +from tqdm import tqdm + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_block import TtQwen2_5_VLVisionBlock +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_patch_embed import TTQwen2_5_VisionPatchEmbed +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_patch_merger import TTQwen2_5_VLPatchMerger +from models.tt_transformers.tt.rope import TTQwen2_5_VisionRotaryEmbedding + + +class TtQwen2_5_VisionTransformerPretrainedModel(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + model_args, + layers, + block_key="", + gated=False, + ): + self.spatial_merge_size = model_args.spatial_merge_size + self.patch_size = model_args.vision_patch_size + self.fullatt_block_indexes = model_args.fullatt_block_indexes + self.window_size = model_args.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + self.mesh_device = mesh_device + hidden_size = model_args.vision_dim + n_heads = model_args.vision_attn_n_heads + out_hidden_size = model_args.out_hidden_size + temporal_patch_size = model_args.temporal_patch_size + + self.patch_embed = TTQwen2_5_VisionPatchEmbed( + device=mesh_device, + args=model_args, + patch_size=self.patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=3, + embed_dim=hidden_size, + state_dict=state_dict, + weight_key="patch_embed.", + layer_num=None, + state_dict_prefix=state_dict_prefix, + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + ) + + head_dim = hidden_size // n_heads + + self.rotary_pos_emb = TTQwen2_5_VisionRotaryEmbedding( + device=mesh_device, + dim=head_dim // 2, + theta=10000.0, + ) + + self.blocks = [ + TtQwen2_5_VLVisionBlock( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}blocks.{i}.", + weight_cache_path=weight_cache_path, + dtype=dtype, + model_args=model_args, + ) + for i in tqdm(range(layers), desc=f"Loading vision transformer blocks") + ] + + self.merger = TTQwen2_5_VLPatchMerger( + device=mesh_device, + dim=5120, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + args=model_args, + weight_key="merger.", + layer_num=None, + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps=1e-06, + dims=out_hidden_size, + context_dim=hidden_size, + spatial_merge_size=self.spatial_merge_size, + ) + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = 
torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb_full = ttnn.to_torch( + rotary_pos_emb_full, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0) + )[: rotary_pos_emb_full.shape[0], :] + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward(self, hidden_states, grid_thw): + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + # device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len = hidden_states.shape[-2] + hidden_states = ttnn.reshape(hidden_states, [seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1]) + tt_index = ttnn.from_torch( + window_index.view(-1, 1, 1).expand(-1, hidden_states.shape[-2], hidden_states.shape[-1]).permute(1, 2, 0), + device=self.mesh_device, + dtype=ttnn.uint32, + layout=ttnn.TILE_LAYOUT, + # memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + hidden_states = ttnn.gather(ttnn.permute(hidden_states, (1, 2, 0)), dim=-1, index=tt_index) + 
hidden_states = ttnn.permute(hidden_states, (2, 0, 1)) + hidden_states = ttnn.reshape(hidden_states, [1, 1, seq_len, -1]) + + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + + cos_tensor = ttnn.from_torch( + emb.cos(), + device=self.mesh_device, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + layout=ttnn.TILE_LAYOUT, + ) + sin_tensor = ttnn.from_torch( + emb.sin(), + device=self.mesh_device, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + layout=ttnn.TILE_LAYOUT, + ) + + position_embeddings = (cos_tensor, sin_tensor) + + ttnn.deallocate(tt_index) + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) + + ttnn.deallocate(cos_tensor) + ttnn.deallocate(sin_tensor) + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + + tt_reverse_indices = ttnn.from_torch( + reverse_indices.view(-1, 1).expand(-1, hidden_states.shape[-1]).transpose(0, 1), + device=self.mesh_device, + dtype=ttnn.uint32, + layout=ttnn.TILE_LAYOUT, + ) + hidden_states = ttnn.gather(ttnn.permute(hidden_states, (1, 0)), dim=-1, index=tt_reverse_indices) + hidden_states = ttnn.permute(hidden_states, (1, 0)) + + return hidden_states diff --git a/models/tt_transformers/tt/rope.py b/models/tt_transformers/tt/rope.py index e5e96c148fb2..2b13582897d0 100644 --- a/models/tt_transformers/tt/rope.py +++ b/models/tt_transformers/tt/rope.py @@ -447,3 +447,26 @@ def get_rot_mats( if return_rot_idxs: return [cos, sin], rot_idxs return [cos, sin] + + +class TTQwen2_5_VisionRotaryEmbedding: + def __init__(self, device, dim: int, theta: float = 10000.0, mode="decode"): + self.dim = dim + self.theta = theta + self.device = device + + arange_indices = ttnn.arange(start=0, end=dim, step=2, device=device) + arange_indices = ttnn.to_layout(arange_indices, ttnn.TILE_LAYOUT) + exponent = ttnn.div(arange_indices, dim) + pow_result = ttnn.pow(theta, exponent) + recip = ttnn.reciprocal(pow_result) + self.inv_freq = ttnn.multiply(recip, 1.0) + + def __call__(self, seqlen: int): + tt_seq = ttnn.arange(end=seqlen, device=self.device) + tt_seq = ttnn.reshape(tt_seq, [1, 1, 1, tt_seq.shape[0]]) + tt_freq = ttnn.reshape(self.inv_freq, [1, 1, 1, self.inv_freq.shape[0]]) + tt_freqs = ttnn.outer(tt_seq, tt_freq) + tt_freqs = ttnn.reshape(tt_freqs, [tt_freqs.shape[2], tt_freqs.shape[3]]) + + return tt_freqs From 333da1cc01e2d37933f309685c42ebe92dcb0b5a Mon Sep 17 00:00:00 2001 From: MohammedTaherMcW Date: Thu, 7 Aug 2025 16:05:58 +0000 Subject: [PATCH 2/2] Remove experimental Qwen --- .../qwen25_vl/tests/test_attention.py 
| 128 ------ .../experimental/qwen25_vl/tests/test_e2e.py | 408 ------------------ .../qwen25_vl/tests/test_image_merger.py | 103 ----- .../qwen25_vl/tests/test_image_patch_embed.py | 95 ---- .../experimental/qwen25_vl/tests/test_mlp.py | 82 ---- .../qwen25_vl/tests/test_vision_block.py | 116 ----- .../qwen25_vl/tests/test_vision_model.py | 84 ---- .../qwen25_vl/tests/test_vision_rms.py | 113 ----- models/experimental/qwen25_vl/tt/attention.py | 141 ------ models/experimental/qwen25_vl/tt/mlp.py | 121 ------ models/experimental/qwen25_vl/tt/model.py | 117 ----- .../experimental/qwen25_vl/tt/patch_embed.py | 62 --- .../experimental/qwen25_vl/tt/patch_merger.py | 117 ----- models/experimental/qwen25_vl/tt/rmsnorm.py | 140 ------ models/experimental/qwen25_vl/tt/rope.py | 41 -- models/experimental/qwen25_vl/tt/text_mlp.py | 114 ----- .../experimental/qwen25_vl/tt/vision_block.py | 79 ---- .../experimental/qwen25_vl/tt/vision_model.py | 242 ----------- 18 files changed, 2303 deletions(-) delete mode 100644 models/experimental/qwen25_vl/tests/test_attention.py delete mode 100644 models/experimental/qwen25_vl/tests/test_e2e.py delete mode 100644 models/experimental/qwen25_vl/tests/test_image_merger.py delete mode 100644 models/experimental/qwen25_vl/tests/test_image_patch_embed.py delete mode 100644 models/experimental/qwen25_vl/tests/test_mlp.py delete mode 100644 models/experimental/qwen25_vl/tests/test_vision_block.py delete mode 100644 models/experimental/qwen25_vl/tests/test_vision_model.py delete mode 100644 models/experimental/qwen25_vl/tests/test_vision_rms.py delete mode 100644 models/experimental/qwen25_vl/tt/attention.py delete mode 100644 models/experimental/qwen25_vl/tt/mlp.py delete mode 100644 models/experimental/qwen25_vl/tt/model.py delete mode 100644 models/experimental/qwen25_vl/tt/patch_embed.py delete mode 100644 models/experimental/qwen25_vl/tt/patch_merger.py delete mode 100644 models/experimental/qwen25_vl/tt/rmsnorm.py delete mode 100644 models/experimental/qwen25_vl/tt/rope.py delete mode 100644 models/experimental/qwen25_vl/tt/text_mlp.py delete mode 100644 models/experimental/qwen25_vl/tt/vision_block.py delete mode 100644 models/experimental/qwen25_vl/tt/vision_model.py diff --git a/models/experimental/qwen25_vl/tests/test_attention.py b/models/experimental/qwen25_vl/tests/test_attention.py deleted file mode 100644 index c01f78924f4e..000000000000 --- a/models/experimental/qwen25_vl/tests/test_attention.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Test for Qwen 2.5 VL Vision Attention""" - -import os - -import pytest -import torch -from loguru import logger - -import ttnn -from models.tt_transformers.tt.model_config import ModelArgs - -from models.experimental.qwen25_vl.tt.attention import TtQwen2_5_VLVisionSdpaAttention -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull - - -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "batch, num_chunks", - ((1, 4),), -) -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -def test_attention_inference(batch, num_chunks, mesh_device, reset_seeds): - dtype = ttnn.bfloat16 - pcc_required = 0.99 - - model_args = ModelArgs(mesh_device) - state_dict = model_args.load_state_dict() - - # Ref model needs partial state dict, but our models use full state dict keys as cached weight names - first_layer_prefix = 
"visual.blocks.0.attn." - partial_state_dict = { - k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) - } - - dim = model_args.vision_dim - - reference_model = model_args.reference_vision_attention() - reference_model.load_state_dict(partial_state_dict) - reference_model.eval() - - hidden_size = model_args.vision_dim - n_heads = model_args.vision_attn_n_heads - head_dim = hidden_size // n_heads - seq_len = model_args.vision_chunk_ntok - - tt_model = TtQwen2_5_VLVisionSdpaAttention( - mesh_device, - state_dict, - state_dict_prefix=first_layer_prefix, - # weight_cache_path=model_args.weight_cache_path(dtype), - dtype=dtype, - configuration=model_args, - ) - - seq_len = 4096 - hidden_dim = 1280 - num_heads = 16 - head_dim = hidden_dim // num_heads # 80 - rotary_dim = head_dim // 2 # 40 - - # Step 1: PyTorch input - pt_attention_input = torch.randn(seq_len, hidden_dim) # no batch dim - cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32) - - # Step 2: precompute cos/sin - cos, sin = precompute_rope_cos_sin(seq_len, head_dim) - - # Step 3: run PyTorch reference - reference_output = reference_model( - pt_attention_input, cu_seqlens, rotary_pos_emb=None, position_embeddings=(cos, sin) - ) - - # Step 4: TT input - tt_attention_input = model_args.prepare_residual_tensor_prefill( - pt_attention_input.unsqueeze(0), force_replicated=True - ) - - cos_tensor = ttnn.from_torch(cos, device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) - sin_tensor = ttnn.from_torch(sin, device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) - - # Step 6: run TT - tt_out = tt_model(tt_attention_input, cu_seqlens, position_embeddings=(cos_tensor, sin_tensor)) - - # Doing contract in tt is correct!! - tt_output_torch = ttnn.to_torch(tt_out, device=mesh_device).squeeze(0) - - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) - - logger.info(comp_allclose(reference_output, tt_output_torch)) - logger.info(f"PCC: {pcc_message}") - - assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" - - -def precompute_rope_cos_sin(seq_len: int, dim: int, theta: float = 10000.0): - """ - Precompute RoPE cos/sin tensors. 
- Args: - seq_len: sequence length (number of tokens) - dim: hidden size (usually head_dim, not full hidden_size) - theta: RoPE theta parameter (default 10000) - Returns: - cos, sin: [seq_len, dim] each - """ - # Build the rope frequencies - half_dim = dim // 2 - freq_seq = torch.arange(half_dim, dtype=torch.float32) - inv_freq = 1.0 / (theta ** (freq_seq / half_dim)) - - # positions: [seq_len] - positions = torch.arange(seq_len, dtype=torch.float32) - - # Outer product: [seq_len, half_dim] - sinusoid_inp = torch.outer(positions, inv_freq) - - # Concatenate for complex dim - sin = torch.sin(torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)) - cos = torch.cos(torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)) - - return cos, sin diff --git a/models/experimental/qwen25_vl/tests/test_e2e.py b/models/experimental/qwen25_vl/tests/test_e2e.py deleted file mode 100644 index 63a9b0ad0b00..000000000000 --- a/models/experimental/qwen25_vl/tests/test_e2e.py +++ /dev/null @@ -1,408 +0,0 @@ -"""Test for Qwen 2.5 VL End-to-End Vision-Text Pipeline""" - -import torch -import pytest -from loguru import logger -import os -import ttnn -from models.tt_transformers.tt.common import ( - sample_host, - PagedAttentionConfig, - preprocess_inputs_prefill, -) -from models.tt_transformers.tt.model_config import DecodersPrecision - -from models.experimental.qwen25_vl.tt.model import Qwen25VLTransformer as Transformer - -from models.tt_transformers.tt.generator import Generator - -from models.experimental.qwen25_vl.tt.vision_model import TtQwen2_5_VisionTransformerPretrainedModel -from models.utility_functions import skip_for_grayskull, skip_for_blackhole - -from models.tt_transformers.tt.model_config import ModelArgs -from transformers import AutoProcessor - -import re - - -def parse_chat_output(text): - """Parse chat output format from generated text.""" - pattern = r"<\|(?Puser|assistant)\|>\s*(?P.*?)(?=<\|(?:user|assistant|end)\|>|$)" - matches = re.finditer(pattern, text, re.DOTALL) - return [(match.group("role"), match.group("message").strip()) for match in matches] - - -def display_chat(logger, conversation): - """Display chat conversation in formatted output.""" - for role, message in conversation: - if role == "user": - logger.info(f"👤 User: {message}") - elif role == "assistant": - logger.info(f"🤖 Assistant: {message}") - - -def setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations): - """Setup model arguments for vision-enabled model (Single Responsibility).""" - instruct = True if weights == "instruct" else False - - model_args = ModelArgs( - mesh_device=mesh_device, - instruct=instruct, - optimizations=optimizations, - max_seq_len=max_seq_len, - max_batch_size=batch_size, - ) - - return model_args, instruct - - -def setup_vision_prompts_and_tokenizer(model_args, instruct): - """Setup multimodal prompts and tokenizer for vision-enabled model.""" - messages = [ - { - "role": "user", - "content": [ - { - "type": "image", - "image": "https://raw.githubusercontent.com/yavuzceliker/sample-images/refs/heads/main/images/image-1.jpg", - }, - {"type": "text", "text": "Describe this image in detail in 1000 words."}, - ], - } - ] - - tokenizer = model_args.tokenizer - return messages, tokenizer - - -def process_real_vision_inputs(messages, model_args): - """Process real image inputs using AutoProcessor (Interface Segregation).""" - processor = AutoProcessor.from_pretrained(os.getenv("HF_MODEL")) - - encoded = processor.apply_chat_template( - messages, add_generation_prompt=True, 
tokenize=True, return_dict=True, return_tensors="pt" - ).to("cpu", dtype=torch.bfloat16) - - input_ids = encoded["input_ids"] - pixel_values = encoded["pixel_values"] - attention_mask = encoded["attention_mask"] - image_grid_thw = encoded["image_grid_thw"] - - # logger.info(f"Processed vision inputs - input_ids: {input_ids.shape}, pixel_values: {pixel_values.shape}") - - return { - "input_ids": input_ids, - "pixel_values": pixel_values, - "attention_mask": attention_mask, - "image_grid_thw": image_grid_thw, - "processor": processor, - } - - -def load_separate_models_like_test_end2end(model_args, mesh_device, dtype, paged_attention, page_params): - """Load separate vision and text models following test_end2end.py pattern.""" - state_dict = model_args.load_state_dict() - - vision_prefix = "visual." - - # Setup paged attention config (exactly like test_end2end.py) - paged_attention_config = None - if paged_attention: - paged_attention_config = PagedAttentionConfig( - block_size=page_params["page_block_size"], - max_num_blocks=page_params["page_max_num_blocks"], - ) - - # Load vision model (exactly like test_end2end.py) - vision_model = TtQwen2_5_VisionTransformerPretrainedModel( - mesh_device=mesh_device, - state_dict=state_dict, - state_dict_prefix=vision_prefix, - dtype=dtype, - model_args=model_args, - weight_cache_path=model_args.weight_cache_path(dtype), - layers=model_args.vision_n_layers, - ) - # Load text model (exactly like test_end2end.py) - - text_model = Transformer( - args=model_args, - mesh_device=mesh_device, - dtype=dtype, - state_dict=state_dict, - weight_cache_path=model_args.weight_cache_path(dtype), - paged_attention_config=paged_attention_config, - ) - - logger.info("Separate vision and text models loaded like test_end2end.py") - return vision_model, text_model - - -def run_generation_exactly_like_test_end2end( - vision_model, text_model, processed_inputs, model_args, page_table=None, paged_attention_config=None, max_gen_len=20 -): - """Run generation following the EXACT pattern from test_end2end.py.""" - input_ids = processed_inputs["input_ids"] - pixel_values = processed_inputs["pixel_values"] - - logger.info("Running generation exactly like test_end2end.py...") - - logger.info("Running Vision Model...") - - generator = Generator([text_model], [model_args], vision_model.mesh_device, tokenizer=model_args.tokenizer) - - tt_kv_cache = [[l.attention.layer_past for l in text_model.layers]] if paged_attention_config else None - - input_tokens_prefill = input_ids - batch_size = input_tokens_prefill.shape[0] - - prompt_text = model_args.tokenizer.decode(input_ids[0].tolist()) - input_prompts = [prompt_text] - - ( - input_tokens_prefill_pt, - encoded_prompts, - decoding_pos, - prefill_lens, - ) = preprocess_inputs_prefill( - input_prompts, - model_args.tokenizer, - [model_args], - instruct=True, - max_generated_tokens=max_gen_len, - max_prefill_len=8192, - ) - - input_tokens_prefill_pt = torch.stack(input_tokens_prefill_pt).view(batch_size, -1) - - logger.info("Running prefill...") - logits = generator.prefill_forward_text( - input_tokens_prefill_pt, - page_table=page_table, - kv_cache=tt_kv_cache, - prompt_lens=decoding_pos, - vision_model=vision_model, - processed_inputs=processed_inputs, - ) - - prefilled_token = torch.argmax(logits, dim=-1) - logger.info(f"Prefilled token: {prefilled_token}") - - all_outputs = [encoded_prompts[0][: prefill_lens[0]]] - all_outputs[0].append(int(prefilled_token[0].item())) - - current_pos = torch.tensor([decoding_pos[0]]) - out_tok = 
prefilled_token - generation_length = 2000 - - results = [] - - logger.info("Starting decode loop...") - for iteration in range(generation_length): - logger.info(f"[Text] Decoding token {iteration}, current_pos: {current_pos.item()}") - - logits = generator.decode_forward_text( - out_tok, - current_pos, - enable_trace=False, - page_table=page_table, - kv_cache=tt_kv_cache, - ) - - _, out_tok = sample_host( - logits, - temperature=0, - top_p=0.9, - ) - - token_id = out_tok[0].item() - decoded_token = model_args.tokenizer.decode([token_id]) - logger.info(f"Generated token {iteration}: ID={token_id}, text='{decoded_token}'") - - # Create result object - result = type("TokenResult", (), {"token": token_id, "text": decoded_token})() - - results.append(result) - - all_outputs[0].append(token_id) - current_pos += 1 - - if token_id == 151645 or token_id == 151643: - logger.warning("Reached End token") - break - - # Early stopping (exactly like test_end2end.py) - if len(all_outputs[0]) >= 5 and all(t == all_outputs[0][-1] for t in all_outputs[0][-5:]): - logger.warning(f"Detected exact repetition of token {all_outputs[0][-1]} five times in a row. Stopping.") - break - - # Final response (exactly like test_end2end.py) - response = model_args.tokenizer.decode(all_outputs[0], skip_special_tokens=True) - logger.info(f"📝 Final Generated Response:\n{response}") - logger.info(f"📝 Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") - chat = parse_chat_output(response) - display_chat(logger, chat) - - logger.info(f"Generated {len(results)} tokens successfully") - return results - - -def validate_e2e_outputs(results, expected_min_tokens=1): - """Validate end-to-end pipeline outputs.""" - if not results: - logger.error("No results generated from E2E pipeline") - return False - - if len(results) < expected_min_tokens: - logger.warning(f"Generated only {len(results)} tokens, expected at least {expected_min_tokens}") - return False - - # Check if tokens are valid - for result in results: - if not hasattr(result, "token") or not hasattr(result, "text"): - logger.error("Invalid result format") - return False - - logger.info("E2E pipeline validation passed") - return True - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@skip_for_blackhole("Failing on DRAM harvested P100a, see #21419") -@pytest.mark.timeout(1800) -@pytest.mark.models_performance_bare_metal -@pytest.mark.parametrize( - "weights, layers", - [ - ("instruct", None), - ], - ids=["full"], -) -@pytest.mark.parametrize( - "paged_attention", - ( - True, - # False, - ), - ids=( - "paged_attention", - # "default_attention", - ), -) -@pytest.mark.parametrize( - "page_params", - [{"page_block_size": 32, "page_max_num_blocks": 1024}], -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -@pytest.mark.parametrize( - "max_seq_len", - (2048,), # Use smaller seq_len like test_end2end.py to avoid memory issues -) -@pytest.mark.parametrize( - "optimizations", - [ - lambda model_args: DecodersPrecision.accuracy(model_args.n_layers, model_args.model_name), - ], - ids=["accuracy"], -) -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -# @pytest.mark.parametrize("device_params", [{"l1_small_size": 1584864, "trace_region_size": 0}], indirect=True) -def test_e2e_vision_text_pipeline( - weights, - layers, - max_seq_len, - batch_size, - paged_attention, - page_params, - optimizations, - 
mesh_device, - reset_seeds, - request, - device_params, -): - """Test end-to-end vision-text pipeline using proper Generator methods.""" - logger.info("Starting E2E vision-text pipeline test") - - # Use bfloat8_b like test_end2end.py for better memory efficiency - dtype = ttnn.bfloat8_b - - # Setup vision-enabled model configuration - model_args, instruct = setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations) - - if layers is not None: - model_args.n_layers = layers - - # Setup vision prompts and tokenizer - messages, tokenizer = setup_vision_prompts_and_tokenizer(model_args, instruct) - - # Process real vision inputs from images - processed_inputs = process_real_vision_inputs(messages, model_args) - - # Load separate models following test_end2end.py pattern - logger.info("Loading separate vision and text models like test_end2end.py...") - vision_model, text_model = load_separate_models_like_test_end2end( - model_args, mesh_device, dtype, paged_attention, page_params - ) - - # Setup page table for paged attention (exactly like test_end2end.py) - page_table_tt = None - paged_attention_config = None - - # Prepare page table for paged attention (exactly like test_end2end.py) - page_table = None - if paged_attention: - paged_attention_config = PagedAttentionConfig( - block_size=page_params["page_block_size"], - max_num_blocks=page_params["page_max_num_blocks"], - ) - # Implied shuffling of blocks - permutation = torch.randperm(paged_attention_config.max_num_blocks) - # Page table which maps virtual blocks to physical - reverse_permutation = torch.argsort(permutation) - page_table = reverse_permutation.reshape( - model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size - ) - page_table_tt = ttnn.from_torch( - page_table, - device=mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, -2) if batch_size > 1 else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - # Run generation following EXACT test_end2end.py pattern - logger.info("Running generation following EXACT test_end2end.py pattern...") - results = run_generation_exactly_like_test_end2end( - vision_model, text_model, processed_inputs, model_args, page_table, paged_attention_config, max_gen_len=10 - ) - - # Validate results - validation_passed = validate_e2e_outputs(results, expected_min_tokens=1) - - # Final validation - if validation_passed and len(results) > 0: - logger.info("✅ E2E vision-text pipeline test PASSED!") - logger.info(f"Successfully generated {len(results)} tokens") - - # Log generated tokens for debugging - for i, result in enumerate(results[:5]): - logger.info(f"Token {i}: {result.token} -> '{result.text}'") - else: - logger.error("❌ E2E pipeline test failed") - assert False, f"E2E pipeline failed - generated {len(results)} tokens, validation: {validation_passed}" diff --git a/models/experimental/qwen25_vl/tests/test_image_merger.py b/models/experimental/qwen25_vl/tests/test_image_merger.py deleted file mode 100644 index 442165307a72..000000000000 --- a/models/experimental/qwen25_vl/tests/test_image_merger.py +++ /dev/null @@ -1,103 +0,0 @@ -""""Test for Qwen 2.5 VL Patch Merger""" - -from loguru import logger - -import torch -import pytest -import os -import ttnn -from models.experimental.qwen25_vl.tt.patch_merger import TTQwen2_5_VLPatchMerger - -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull -from 
models.tt_transformers.tt.model_config import ModelArgs - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("device"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -@pytest.mark.parametrize( - "seq_len", - (128,), -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -def test_patch_merger_inference(seq_len, batch_size, reset_seeds, device): - dtype = ttnn.bfloat16 - mode = "decode" if seq_len <= 32 else "prefill" - - tt_model_args = ModelArgs( - device, - max_batch_size=batch_size, - max_seq_len=128, - ) - - tt_model_args.n_layers = 1 - state_dict = tt_model_args.load_state_dict() - - reference_model = tt_model_args.reference_vision_qwen_merger() # Qwen Patch merger - first_layer_prefix = "visual.merger." - - partial_state_dict = { - k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) - } - - reference_model.load_state_dict(partial_state_dict) - - tt_model = TTQwen2_5_VLPatchMerger( - device=device, - dim=5120, - state_dict=state_dict, - weight_key=first_layer_prefix, - layer_num=None, - state_dict_prefix="", - weight_cache_path=None, - weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, - weight_dtype=ttnn.bfloat16, - is_distributed=None, - eps=1e-06, - dims=3584, - context_dim=1280, - spatial_merge_size=2, - mode=mode, - ) - - input = torch.rand(1, 4, 1280) - reference_output = reference_model(input) - tt_input = ttnn.from_torch( - input, - device=device, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) - - tt_input = ttnn.reshape(tt_input, [1, 4, 1280]) - - tt_output = tt_model(tt_input) - - # DistributedNorm outputs are replicated across devices - tt_output_torch = ttnn.to_torch( - tt_output, - ) - - passing, pcc_message = comp_pcc(reference_output, tt_output_torch) - - logger.info(comp_allclose(reference_output, tt_output_torch)) - logger.info(f"PCC: {pcc_message}") - - if passing: - logger.info("Merger Passed!") - else: - logger.warning("Merger Failed!") - - assert passing, f"Merger output does not meet PCC requirement {0.99}." 
diff --git a/models/experimental/qwen25_vl/tests/test_image_patch_embed.py b/models/experimental/qwen25_vl/tests/test_image_patch_embed.py deleted file mode 100644 index c08423637b2f..000000000000 --- a/models/experimental/qwen25_vl/tests/test_image_patch_embed.py +++ /dev/null @@ -1,95 +0,0 @@ -""""Test for Qwen 2.5 VL Patch Embed""" - -from loguru import logger - -import torch -import pytest -import os -import ttnn -from models.experimental.qwen25_vl.tt.patch_embed import TTQwen2_5_VisionPatchEmbed - -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull -from models.tt_transformers.tt.model_config import ModelArgs - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("device"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -@pytest.mark.parametrize( - "seq_len", - (128,), -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -def test_embed_inference(seq_len, batch_size, reset_seeds, device): - dtype = ttnn.bfloat16 - mode = "decode" if seq_len <= 32 else "prefill" - - tt_model_args = ModelArgs( - device, - max_batch_size=batch_size, - max_seq_len=128, - ) - - tt_model_args.n_layers = 1 - state_dict = tt_model_args.load_state_dict() - - reference_model = tt_model_args.reference_vision_qwen_patch_embed() # Qwen Patch embed - first_layer_prefix = "visual.patch_embed." - - partial_state_dict = { - k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) - } - - reference_model.load_state_dict(partial_state_dict) - - tt_model = TTQwen2_5_VisionPatchEmbed( - device=device, - patch_size=14, - temporal_patch_size=2, - in_channels=3, - embed_dim=1280, - state_dict=state_dict, - weight_key=first_layer_prefix, - layer_num=None, - state_dict_prefix="", - weight_cache_path=None, - weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, - weight_dtype=ttnn.bfloat16, - mode=mode, - ) - - input = torch.rand(1, 1, 1380, 1176) - reference_output = reference_model(input) - - tt_input = ttnn.from_torch( - input, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG - ) - - tt_output = tt_model(tt_input) - - tt_output_torch = ttnn.to_torch( - tt_output, - ) - - passing, pcc_message = comp_pcc(reference_output, tt_output_torch) - - logger.info(comp_allclose(reference_output, tt_output_torch)) - logger.info(f"PCC: {pcc_message}") - - if passing: - logger.info("Patch embed Passed!") - else: - logger.warning("Patch embed Failed!") - - assert passing, f"Patch embed output does not meet PCC requirement {0.99}." 
diff --git a/models/experimental/qwen25_vl/tests/test_mlp.py b/models/experimental/qwen25_vl/tests/test_mlp.py deleted file mode 100644 index ee936435b8ec..000000000000 --- a/models/experimental/qwen25_vl/tests/test_mlp.py +++ /dev/null @@ -1,82 +0,0 @@ -""""Test for Qwen 2.5 VL Vision MLP""" - -import os - -import pytest -import torch -from loguru import logger - -import ttnn -from models.tt_transformers.tt.model_config import ModelArgs - -from models.experimental.qwen25_vl.tt.mlp import QwenTTVisionMLP -from models.utility_functions import comp_allclose, comp_pcc, nearest_32, skip_for_grayskull - - -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "batch, num_chunks", - ((1, 4),), -) -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -def test_mlp_inference(batch, num_chunks, mesh_device, reset_seeds): - dtype = ttnn.bfloat16 - model_args = ModelArgs(mesh_device) - state_dict = model_args.load_state_dict() - - # Ref model needs partial state dict, but our models use full state dict keys as cached weight names - first_layer_prefix = "visual.blocks.0.feed_forward." - partial_state_dict = { - k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) - } - - model_args.WEIGHTS_DTYPE = dtype - - dim = model_args.vision_dim - seq_len = nearest_32(model_args.vision_chunk_ntok) * num_chunks - reference_model = model_args.reference_vision_mlp() - reference_model.load_state_dict(partial_state_dict) - - tt_model = QwenTTVisionMLP( - mesh_device=mesh_device, - args=model_args, - state_dict=state_dict, - state_dict_prefix=first_layer_prefix, - weight_cache_path=model_args.weight_cache_path(dtype), - dtype=dtype, - ) - - torch_input = torch.randn(1, batch, seq_len, dim) - - reference_output = reference_model(torch_input).squeeze() - - tt_input = ttnn.from_torch( - torch_input, - device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - layout=ttnn.TILE_LAYOUT, - ) - - tt_output = tt_model(tt_input) - - tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ - :, :1, :, : - ].squeeze() - - pcc_required = 0.99 - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) - - logger.info(comp_allclose(reference_output, tt_output_torch)) - logger.info(f"PCC: {pcc_message}") - - assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
diff --git a/models/experimental/qwen25_vl/tests/test_vision_block.py b/models/experimental/qwen25_vl/tests/test_vision_block.py deleted file mode 100644 index 69bd51eb5617..000000000000 --- a/models/experimental/qwen25_vl/tests/test_vision_block.py +++ /dev/null @@ -1,116 +0,0 @@ -""""Test for Qwen 2.5 VL Vision Transformer Block""" - -import os - -import pytest -import torch -from loguru import logger - -import ttnn -from models.tt_transformers.tt.model_config import ModelArgs - -from models.experimental.qwen25_vl.tt.vision_block import TtQwen2_5_VLVisionBlock -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull - - -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "batch, num_chunks", - ((1, 4),), -) -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -def test_transformer_inference(batch, num_chunks, mesh_device, reset_seeds): - dtype = ttnn.bfloat16 - pcc_required = 0.99 - - model_args = ModelArgs(mesh_device) - state_dict = model_args.load_state_dict() - - # Ref model needs partial state dict, but our models use full state dict keys as cached weight names - first_layer_prefix = "visual.blocks.0." - partial_state_dict = { - k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) - } - - dim = model_args.vision_dim - - reference_model = model_args.reference_vision_block() - reference_model.load_state_dict(partial_state_dict) - reference_model.eval() - - vision_dim = model_args.vision_dim - n_heads = model_args.vision_attn_n_heads - head_dim = vision_dim // n_heads - seq_len = model_args.vision_chunk_ntok - 1 - - tt_model = TtQwen2_5_VLVisionBlock( - mesh_device, - state_dict=state_dict, - state_dict_prefix=first_layer_prefix, - weight_cache_path=model_args.weight_cache_path(dtype), - model_args=model_args, - dtype=dtype, - ) - - pt_attention_input = torch.randn(seq_len, vision_dim) # no batch dim - cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32) - - cos, sin = precompute_rope_cos_sin(seq_len, head_dim) - - reference_output = reference_model( - pt_attention_input, cu_seqlens, rotary_pos_emb=None, position_embeddings=(cos, sin) - ) - - tt_attention_input = model_args.prepare_residual_tensor_prefill( - pt_attention_input.unsqueeze(0), force_replicated=True - ) - - cos_tensor = ttnn.from_torch(cos, device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) - sin_tensor = ttnn.from_torch(sin, device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) - - tt_out = tt_model(tt_attention_input, cu_seqlens, position_embeddings=(cos_tensor, sin_tensor)) - - tt_output_torch = ttnn.to_torch(tt_out, device=mesh_device).squeeze(0).squeeze(0) - - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) - - logger.info(comp_allclose(reference_output, tt_output_torch)) - logger.info(f"PCC: {pcc_message}") - - assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" - - -def precompute_rope_cos_sin(seq_len: int, dim: int, theta: float = 10000.0): - """ - Precompute RoPE cos/sin tensors. 
- Args: - seq_len: sequence length (number of tokens) - dim: hidden size (usually head_dim, not full hidden_size) - theta: RoPE theta parameter (default 10000) - Returns: - cos, sin: [seq_len, dim] each - """ - # Build the rope frequencies - half_dim = dim // 2 - freq_seq = torch.arange(half_dim, dtype=torch.float32) - inv_freq = 1.0 / (theta ** (freq_seq / half_dim)) - - # positions: [seq_len] - positions = torch.arange(seq_len, dtype=torch.float32) - - # Outer product: [seq_len, half_dim] - sinusoid_inp = torch.outer(positions, inv_freq) - - # Concatenate for complex dim - sin = torch.sin(torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)) - cos = torch.cos(torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)) - - return cos, sin diff --git a/models/experimental/qwen25_vl/tests/test_vision_model.py b/models/experimental/qwen25_vl/tests/test_vision_model.py deleted file mode 100644 index bf0bef57ca25..000000000000 --- a/models/experimental/qwen25_vl/tests/test_vision_model.py +++ /dev/null @@ -1,84 +0,0 @@ -"""Test for Qwen 2.5 VL Vision Transformer Pretrained Model Inference""" - -import os - -import pytest -import torch -from loguru import logger - -import ttnn - -from models.tt_transformers.tt.model_config import ModelArgs - -from models.experimental.qwen25_vl.tt.vision_model import TtQwen2_5_VisionTransformerPretrainedModel -from models.utility_functions import comp_pcc, skip_for_grayskull, comp_allclose - - -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "batch, num_chunks", - ((1, 4),), -) -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -def test_vision_inference(batch, num_chunks, mesh_device, reset_seeds): - dtype = ttnn.bfloat16 - pcc_required = 0.99 - - model_args = ModelArgs(mesh_device) - state_dict = model_args.load_state_dict() - - # Ref model needs partial state dict, but our models use full state dict keys as cached weight names - first_layer_prefix = "visual." - partial_state_dict = { - k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) - } - - dim = model_args.vision_dim - - reference_model = model_args.reference_vision_model() - reference_model.load_state_dict(partial_state_dict) - reference_model.eval() - - n_layers = model_args.vision_n_layers - - tt_model = TtQwen2_5_VisionTransformerPretrainedModel( - mesh_device, - state_dict=state_dict, - state_dict_prefix=first_layer_prefix, - weight_cache_path=model_args.weight_cache_path(dtype), - model_args=model_args, - dtype=dtype, - layers=n_layers, - ) - - pt_input = torch.randn([32, 1176]) # no batch dim - grid_thw = torch.tensor([[1, 4, 8]]) - - reference_output = reference_model( - pt_input, - grid_thw, - ) - - tt_attention_input = model_args.prepare_residual_tensor_prefill(pt_input.unsqueeze(0), force_replicated=True) - tt_out = tt_model(tt_attention_input, grid_thw) - - tt_output_torch = ttnn.to_torch(tt_out, device=mesh_device) - - non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) - tt_output_torch = tt_output_torch[non_zero_indices] - reference_output = reference_output[non_zero_indices] - - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) - - logger.info(comp_allclose(reference_output, tt_output_torch)) - logger.info(f"PCC: {pcc_message}") - - assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
diff --git a/models/experimental/qwen25_vl/tests/test_vision_rms.py b/models/experimental/qwen25_vl/tests/test_vision_rms.py deleted file mode 100644 index ca453d09940c..000000000000 --- a/models/experimental/qwen25_vl/tests/test_vision_rms.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Test for Qwen 2.5 VL RMSNorm Layer Inference""" - -from loguru import logger - -import torch -import pytest -import os - -import ttnn -from models.experimental.qwen25_vl.tt.rmsnorm import RMSNorm - -from models.tt_transformers.tt.distributed_norm import DistributedNorm - - -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull -from models.tt_transformers.tt.model_config import ModelArgs - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("device"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -@pytest.mark.parametrize( - "seq_len", - (128,), -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -def test_rmsnorm_inference(seq_len, batch_size, reset_seeds, device): - dtype = ttnn.bfloat16 - mode = "decode" if seq_len <= 32 else "prefill" - - tt_model_args = ModelArgs( - device, - max_batch_size=batch_size, - max_seq_len=128, - ) - - dim = tt_model_args.vision_dim - - tt_model_args.n_layers = 1 - state_dict = tt_model_args.load_state_dict() - - reference_model = tt_model_args.reference_vision_rms_norm() # Qwen2_5 RMSNorm - first_layer_prefix = "visual.blocks.0.norm1." - - partial_state_dict = { - k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) - } - - reference_model.load_state_dict(partial_state_dict) - - tt_inner_norm = RMSNorm( - device=device, - dim=dim, - state_dict=state_dict, - state_dict_prefix="", - weight_key=first_layer_prefix[:-1], # Remove trailing dot - weight_dtype=dtype, - is_distributed=False, - sharded_program_config=tt_model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], - sharded_output_config=tt_model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], - ) - - # Wrap it in DistributedNorm - tt_model = DistributedNorm(tt_inner_norm, tt_model_args, TG=tt_model_args.is_galaxy) - - input = torch.rand(1, 1, 1280) - - reference_output = reference_model(input) - - # DistributedNorm inputs are fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode) - tt_input = ttnn.from_torch( - input, - device=device, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, -1), mesh_shape=tt_model_args.cluster_shape), - memory_config=( - tt_model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG - ), - ) - - tt_output = tt_model(tt_input, mode=mode) - - # DistributedNorm outputs are replicated across devices - tt_output_torch = ttnn.to_torch( - tt_output, - mesh_composer=ttnn.ConcatMesh2dToTensor( - device, dims=(0, 2) if tt_model_args.is_galaxy else (2, 0), mesh_shape=tt_model_args.cluster_shape - ), - )[:1, :, :] - - passing, pcc_message = comp_pcc(reference_output, tt_output_torch) - - logger.info(comp_allclose(reference_output, tt_output_torch)) - logger.info(f"PCC: {pcc_message}") - - if passing: - logger.info("rms_norm Passed!") - else: - logger.warning("rms_norm Failed!") - - assert passing, f"rms_norm output does not meet PCC requirement {0.99}." 
diff --git a/models/experimental/qwen25_vl/tt/attention.py b/models/experimental/qwen25_vl/tt/attention.py deleted file mode 100644 index a515bf35e737..000000000000 --- a/models/experimental/qwen25_vl/tt/attention.py +++ /dev/null @@ -1,141 +0,0 @@ -""" -This is the vision attention implementation for Qwen-VL-7B. - -We couldn't reuse the LLaMA version from tt_transformers because it expects separate q, k, v weights, -but Qwen-VL uses fused qkv weights. So this has been rewritten to support that, -based on the original code at: -models/tt_transformers/tt/multimodal/llama_image_attention.py -""" - - -import torch -import ttnn -from models.common.lightweightmodule import LightweightModule - - -def rotate_half(x): - x1 = ttnn.slice(x, (0, 0, 0), (x.shape[0], x.shape[1], x.shape[2] // 2)) - x2 = ttnn.slice(x, (0, 0, x.shape[-1] // 2), (x.shape[0], x.shape[1], x.shape[2])) - return ttnn.concat([ttnn.mul(x2, -1, use_legacy=False), x1], dim=-1) - - -def apply_rotary_pos_emb_vision_tt(q, k, cos, sin): - cos = ttnn.unsqueeze(cos, -2) - sin = ttnn.unsqueeze(sin, -2) - - q_embed = ttnn.add(ttnn.mul(q, cos), ttnn.mul(rotate_half(q), sin)) - k_embed = ttnn.add(ttnn.mul(k, cos), ttnn.mul(rotate_half(k), sin)) - return q_embed, k_embed - - -class TtQwen2_5_VLVisionSdpaAttention(LightweightModule): - def __init__(self, mesh_device, state_dict, state_dict_prefix, dtype, configuration): - super().__init__() - - self.mesh_device = mesh_device - self.dtype = dtype - self.hidden_size = 1280 - self.num_heads = 16 - self.head_dim = self.hidden_size // self.num_heads - self.scale = self.head_dim**-0.5 - - # Load qkv weight & bias (fused): shape [hidden_size, hidden_size*3] - qkv_weight = state_dict[f"{state_dict_prefix}qkv.weight"] - qkv_bias = state_dict[f"{state_dict_prefix}qkv.bias"] - - # Transpose to [hidden_size, 3*hidden_size] for matmul - self.qkv_weight = ttnn.as_tensor( - torch.transpose(qkv_weight, -2, -1), - device=mesh_device, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) - self.qkv_bias = ttnn.as_tensor( - qkv_bias, device=mesh_device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG - ) - - # Output projection: proj - proj_weight = state_dict[f"{state_dict_prefix}proj.weight"] # shape [hidden_size, hidden_size] - proj_bias = state_dict[f"{state_dict_prefix}proj.bias"] # shape [hidden_size] - - self.proj_weight = ttnn.as_tensor( - torch.transpose(proj_weight, -2, -1), - device=mesh_device, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) - self.proj_bias = ttnn.as_tensor( - proj_bias, device=mesh_device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG - ) - - self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi4, - fp32_dest_acc_en=True, - packer_l1_acc=True, - dst_full_sync_en=False, - ) - - def forward(self, hidden_states, cu_seqlens, position_embeddings): - """ - hidden_states: ttnn.Tensor of shape [batch, seq_len, hidden_size] - position_embeddings: tuple (cos, sin) each of shape [seq_len, head_dim] - """ - seq_len = hidden_states.shape[-2] - cos, sin = position_embeddings - # Fused qkv projection - qkv = ttnn.linear( - hidden_states, - self.qkv_weight, - bias=self.qkv_bias, - dtype=self.dtype, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config, - ) # shape [batch, seq_len, hidden_size*3] - - (q, k, v) = ttnn.permute(ttnn.reshape(qkv, [seq_len, 3, self.num_heads, -1]), [1, 0, 
2, 3]) - ttnn.deallocate(qkv) - - # Apply rotary position embeddings - q, k = apply_rotary_pos_emb_vision_tt(q, k, cos, sin) - # return q - - seq_len = cu_seqlens[-1].item() - - q = ttnn.unsqueeze(ttnn.permute(ttnn.pad(q, [(0, 0), (0, 0), (0, 16)], 0), [1, 0, 2]), 0) - k = ttnn.unsqueeze(ttnn.permute(ttnn.pad(k, [(0, 0), (0, 0), (0, 16)], 0), [1, 0, 2]), 0) - v = ttnn.unsqueeze(ttnn.permute(ttnn.pad(v, [(0, 0), (0, 0), (0, 16)], 0), [1, 0, 2]), 0) - - attn_output = ttnn.transformer.scaled_dot_product_attention( - q, k, v, is_causal=False, scale=self.scale - ) # shape [1, seq_len, num_heads, head_dim] - - ttnn.deallocate(q) - ttnn.deallocate(k) - ttnn.deallocate(v) - - # attn_output shape: [1, 16, 4096, 96] - # Need to slice back from 96 → 80 - attn_output = ttnn.slice( - attn_output, - (0, 0, 0, 0), - (attn_output.shape[0], attn_output.shape[1], attn_output.shape[2], self.head_dim), # head_dim=80 - ) - - attn_output = ttnn.permute(ttnn.squeeze(attn_output, 0), [1, 0, 2]) - attn_output = ttnn.reshape(attn_output, [seq_len, -1]) - - # Final projection - output = ttnn.linear( - attn_output, - self.proj_weight, - bias=self.proj_bias, - dtype=self.dtype, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config, - ) - - ttnn.deallocate(attn_output) - - return output diff --git a/models/experimental/qwen25_vl/tt/mlp.py b/models/experimental/qwen25_vl/tt/mlp.py deleted file mode 100644 index e1ebfe02d6ed..000000000000 --- a/models/experimental/qwen25_vl/tt/mlp.py +++ /dev/null @@ -1,121 +0,0 @@ -""" -This is the MLP (feed-forward) implementation for Qwen-VL-7B. - -We couldn't reuse TtLlamaImageFeedForward from tt_transformers because the logic is different. -Qwen does: down_proj(act_fn(gate_proj(x)) * up_proj(x)) -Tt does: c_proj(activation(c_fc(x))) - -So this version was written specifically for Qwen, based on its architecture. 
-""" - -import torch - -import ttnn -from models.common.lightweightmodule import LightweightModule - - -class QwenTTVisionMLP(LightweightModule): - def __init__( - self, - mesh_device, - args, - state_dict, - weight_cache_path, - dtype, - state_dict_prefix=None, - ): - super().__init__() - - self.mesh_device = mesh_device - self.args = args - self.state_dict = state_dict - self.dim = args.dim - - def get_weight(name): - return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) - - def get_bias(name): - return state_dict[f"{state_dict_prefix}{name}.bias"] - - def cache_name(name): - if args.dummy_weights: - return None - return weight_cache_path / f"{state_dict_prefix}.{name}" - - def as_tensor(name, dtype, is_bias=False): - tensor_data = get_bias(name) if is_bias else get_weight(name) - return ttnn.as_tensor( - tensor_data, - dtype=dtype, - device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - # cache_file_name=cache_name(name), - ) - - # Weights and Biases - self.w1 = as_tensor("w1", dtype) - self.b1 = as_tensor("w1", ttnn.bfloat16, is_bias=True) - - self.w3 = as_tensor("w3", dtype) - self.b3 = as_tensor("w3", ttnn.bfloat16, is_bias=True) - - self.w2 = as_tensor("w2", dtype) - self.b2 = as_tensor("w2", ttnn.bfloat16, is_bias=True) - - self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi4, - fp32_dest_acc_en=True, - packer_l1_acc=True, - dst_full_sync_en=False, - ) - - def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: - """ - Qwen HF MLP reference: - output = down_proj(act_fn(gate_proj(x)) * up_proj(x)) - Mapping: - w1 -> gate_proj - w3 -> up_proj - w2 -> down_proj - """ - - # Linear with GELU activation - w1_out = ttnn.linear( - x, - self.w1, - bias=self.b1, - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - activation="silu", - compute_kernel_config=self.compute_kernel_config, - ) - - w3_out = ttnn.linear( - x, - self.w3, - bias=self.b3, - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config, - ) - - # Element-wise multiply - w2_in = ttnn.mul(w1_out, w3_out, dtype=ttnn.bfloat16) - - # Final projection - w2_out = ttnn.linear( - w2_in, - self.w2, - bias=self.b2, - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config, - ) - - ttnn.deallocate(w1_out) - ttnn.deallocate(w3_out) - ttnn.deallocate(w2_in) - - return w2_out diff --git a/models/experimental/qwen25_vl/tt/model.py b/models/experimental/qwen25_vl/tt/model.py deleted file mode 100644 index 50cbbe924784..000000000000 --- a/models/experimental/qwen25_vl/tt/model.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -This is the end-to-end pipeline for the Qwen-VL 2.5 model. - -The `Qwen25VLTransformer` class inherits from the `Transformer` class in tt_transformers. -It overrides `prepare_inputs_prefill` to run inference on the vision model and -pass the resulting visual tokens to the text model along with text tokens. 
-""" - - -import ttnn -import torch - -from models.tt_transformers.tt.model import Transformer - - -class Qwen25VLTransformer(Transformer): - def __init__( - self, - args, - dtype, - mesh_device, - state_dict, - weight_cache_path, - paged_attention_config=None, - use_paged_kv_cache=False, - ): - super().__init__( - args, - dtype, - mesh_device, - state_dict, - weight_cache_path, - paged_attention_config=paged_attention_config, - use_paged_kv_cache=use_paged_kv_cache, - ) - - def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): - """ - Inputs are torch tensors or python types. This function returns ttnn - tensors on device. - TODO: Debate whether this function is responsible for padding - """ - - tokens = tokens.reshape(1, 1, 1, -1) - S = tokens.shape[-1] - tokens = ttnn.from_torch( - tokens, - device=self.mesh_device, - dtype=ttnn.uint32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - # self.embed_scale = args.dim**0.5 - tokens_embd = self.embd(tokens) - # tokens_embd = ttnn.multiply(tokens_embd, self.embed_scale) - - pixel_values = kwargs["processed_inputs"]["pixel_values"] - input_ids = kwargs["processed_inputs"]["input_ids"] - image_grid_thw = kwargs["processed_inputs"]["image_grid_thw"] - - vision_model = kwargs["vision_model"] - pixel_values = self.args.prepare_residual_tensor_prefill(pixel_values.unsqueeze(0), force_replicated=True) - - vision_output = vision_model(pixel_values, image_grid_thw) - - tokens_embd = ttnn.to_torch(tokens_embd) - comp_vision_output = ttnn.to_torch(ttnn.from_device(vision_output)) - - input_ids = torch.nn.functional.pad(input_ids, (0, tokens_embd.shape[1] - input_ids.shape[1]), "constant", 0) - image_features = comp_vision_output.squeeze(0) - special_image_mask = (input_ids == 151655).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(tokens_embd) - image_features = image_features.to(tokens_embd.device, tokens_embd.dtype) - tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features) - - tokens_embd = ttnn.from_torch( - tokens_embd, - dtype=ttnn.bfloat16, - device=self.mesh_device, - layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - - tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) - # Slice the rot mats to the prefill seqlen - assert ( - self.rope_setup.cos_matrix.shape[2] >= start_pos + S - ), f"Padded prefill end idx {start_pos + S} exceeds max seq len {self.rope_setup.cos_matrix.shape[2]}" - - tt_rot_mats_prefill_global = [ - self.rope_setup.cos_matrix[:, :, start_pos : start_pos + S, :], - self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :], - ] - - if page_table is not None: - tt_page_table = ttnn.from_torch( - page_table, - device=self.mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - else: - tt_page_table = None - - if chunk_page_table is not None: - tt_chunk_page_table = ttnn.from_torch( - chunk_page_table, - device=self.mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - else: - tt_chunk_page_table = None - - return tokens_embd, tt_rot_mats_prefill_global, tt_page_table, tt_chunk_page_table diff --git a/models/experimental/qwen25_vl/tt/patch_embed.py b/models/experimental/qwen25_vl/tt/patch_embed.py deleted file mode 100644 index 36262c777f7e..000000000000 --- 
a/models/experimental/qwen25_vl/tt/patch_embed.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -This is the patch embedding implementation for Qwen-VL-7B. - -The existing TtLlamaConv2dPatch from tt_transformers uses Conv2d, but Qwen needs Conv3d instead. -Since the stride size is the same as the kernel size for this operation, we can use a matrix -multiplication (matmul) instead of a convolution. This is necessary because -`ttnn.experimental.conv3d` currently only supports Conv3d with stride (1, 1, 1). -""" - -import ttnn - - -class TTQwen2_5_VisionPatchEmbed: - def __init__( - self, - device, - patch_size, - temporal_patch_size, - in_channels, - embed_dim, - state_dict, - weight_key, - layer_num=None, - state_dict_prefix="", - weight_cache_path=None, - weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, - weight_dtype=ttnn.bfloat16, - mode="decode", - ): - super().__init__() - self.mode = mode - self.device = device - self.patch_size = patch_size - self.temporal_patch_size = temporal_patch_size - self.in_channels = in_channels - self.embed_dim = embed_dim - self.weight_memory_config = weight_memory_config - self.weight_dtype = weight_dtype - - weight_name_1 = f"{state_dict_prefix}{weight_key}proj.weight" - torch_weight = state_dict[weight_name_1] - - weight_matrix = torch_weight.view(self.embed_dim, -1) - self.weight = ttnn.from_torch( - weight_matrix.T, - device=self.device, - dtype=self.weight_dtype, - layout=ttnn.TILE_LAYOUT, - memory_config=self.weight_memory_config, - ) - self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi4, - fp32_dest_acc_en=True, - packer_l1_acc=True, - dst_full_sync_en=False, - ) - - def __call__(self, x: ttnn.Tensor) -> ttnn.Tensor: - x_flattened = ttnn.reshape(x, (x.shape[2], -1)) - output = ttnn.matmul(x_flattened, self.weight, compute_kernel_config=self.compute_kernel_config) - - return output diff --git a/models/experimental/qwen25_vl/tt/patch_merger.py b/models/experimental/qwen25_vl/tt/patch_merger.py deleted file mode 100644 index a1001b0ec96d..000000000000 --- a/models/experimental/qwen25_vl/tt/patch_merger.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -This is the patch merger implementation used in the Qwen-VL-7B model. - -There's no existing implementation for this in tt_transformers, -so it was written specifically based on Qwen-VL's architecture. 
-""" - -import ttnn -from models.experimental.qwen25_vl.tt.rmsnorm import RMSNorm -from models.tt_transformers.tt.model_config import ModelArgs - - -class TTQwen2_5_VLPatchMerger: - def __init__( - self, - device, - dim, - state_dict, - weight_key, - layer_num=None, - state_dict_prefix="", - weight_cache_path=None, - weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, - weight_dtype=ttnn.bfloat16, - is_distributed=None, - eps: float = 1e-06, - dims=3584, - context_dim=1280, - spatial_merge_size=2, - mode="decode", - ): - super().__init__() - self.eps = eps - self.mode = mode - - tt_model_args = ModelArgs( - device, - max_batch_size=1, - max_seq_len=128, - ) - - weight_name_1 = f"{state_dict_prefix}{weight_key}ln_q.weight" - weight_name_2 = f"{state_dict_prefix}{weight_key}feed_forward.0.weight" - weight_name_3 = f"{state_dict_prefix}{weight_key}feed_forward.2.weight" - - bias_name_2 = f"{state_dict_prefix}{weight_key}feed_forward.0.bias" - bias_name_3 = f"{state_dict_prefix}{weight_key}feed_forward.2.bias" - - self.weight_1 = ttnn.as_tensor( - state_dict[weight_name_1], - device=device, - dtype=weight_dtype, - layout=ttnn.ROW_MAJOR_LAYOUT, - memory_config=weight_memory_config, - ) - - self.weight_2 = ttnn.as_tensor( - state_dict[weight_name_2], - device=device, - dtype=weight_dtype, - layout=ttnn.TILE_LAYOUT, - memory_config=weight_memory_config, - ) - - self.weight_3 = ttnn.as_tensor( - state_dict[weight_name_3], - device=device, - dtype=weight_dtype, - layout=ttnn.TILE_LAYOUT, - memory_config=weight_memory_config, - ) - - self.hidden_size = context_dim * (spatial_merge_size**2) - self.ln_q = RMSNorm( - device=device, - dim=1280, - state_dict=state_dict, - state_dict_prefix="", - weight_key="visual.merger.ln_q", - weight_dtype=ttnn.bfloat16, - is_distributed=False, - sharded_program_config=tt_model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], - sharded_output_config=False, - ) - - self.weight_3 = ttnn.transpose(self.weight_3, 0, 1) - - self.weight_2 = ttnn.transpose(self.weight_2, 0, 1) - - self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi4, - fp32_dest_acc_en=True, - packer_l1_acc=True, - dst_full_sync_en=False, - ) - - def __call__(self, x): - x = self.ln_q(x, mode=self.mode) - - x = ttnn.reshape(x, (-1, self.hidden_size)) - - x = ttnn.linear( - x, - self.weight_2, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config, - ) - x = ttnn.gelu(x) - - x = ttnn.linear( - x, - self.weight_3, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config, - ) - - return x diff --git a/models/experimental/qwen25_vl/tt/rmsnorm.py b/models/experimental/qwen25_vl/tt/rmsnorm.py deleted file mode 100644 index 025ffcfc638b..000000000000 --- a/models/experimental/qwen25_vl/tt/rmsnorm.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -This is a modified RMSNorm implementation for Qwen-VL-7B. - -It's based on the existing RMSNorm in models/common/rmsnorm.py, -with slight changes to support the bf8 data type. 
-""" - - -import ttnn -from models.common.lightweightmodule import LightweightModule - -TILE = 32 -SHARD_HEIGHT = TILE # Current ttnn.rms_norm implementation requires shard height to be a single tile - - -class RMSNorm(LightweightModule): - def __init__( - self, - device, - dim, - state_dict, - weight_key, - layer_num=None, - state_dict_prefix=None, - weight_cache_path=None, - weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, - weight_dtype=ttnn.bfloat16, - is_distributed=None, - eps: float = 1e-06, - sharded_program_config=None, - sharded_output_config=None, - output_mem_config=None, - ccl_topology=ttnn.Topology.Ring, - ): - super().__init__() - self.eps = eps - self.is_distributed = is_distributed - self.ccl_topology = ccl_topology - - if state_dict_prefix: - weight_name = f"{state_dict_prefix}{weight_key}.weight" - else: - if layer_num is None: - weight_name = f"{weight_key}.weight" - else: - weight_name = f"layers.{layer_num}.{weight_key}.weight" - torch_weight = ( - state_dict[weight_name].unsqueeze(0).view(1, 1, dim).reshape([1, 1, dim // SHARD_HEIGHT, SHARD_HEIGHT]) - ) - - cache_name = None if weight_cache_path is None else weight_cache_path / weight_name - - # Compatibility with models that don't use mesh devices (e.g. single-chip Mistral-7b) - is_mesh_device = device.__class__.__name__ == "MeshDevice" - - self.weight = ttnn.as_tensor( - torch_weight, - device=device, - dtype=weight_dtype, - layout=ttnn.TILE_LAYOUT, - memory_config=weight_memory_config, - # cache_file_name=cache_name, - mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, - ) - - if self.is_distributed: - self.weight_distributed = ttnn.as_tensor( - torch_weight, - device=device, - dtype=weight_dtype, - layout=ttnn.ROW_MAJOR_LAYOUT, - memory_config=weight_memory_config, - # cache_file_name=cache_name, - mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape)) - if is_mesh_device - else None, - ) - - self.sharded_output_config = sharded_output_config - self.sharded_program_config = sharded_program_config - self.output_mem_config = output_mem_config - - self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi2, - math_approx_mode=False, - fp32_dest_acc_en=True, - packer_l1_acc=True, - ) - - def forward(self, x: ttnn.Tensor, mode="decode", in_sharded=False, out_sharded=False) -> ttnn.Tensor: - # If input is sharded do sharded RMSNorm and optionally return sharded output - program_config = self.sharded_program_config if in_sharded else None - memory_config = self.sharded_output_config if out_sharded else None - distributed = self.is_distributed and self.is_distributed(mode) - norm = self._distributed_rmsnorm - weight = self.weight_distributed if distributed else self.weight - - if in_sharded: - assert not distributed, "Distributed RMSNorm does not support sharded inputs" - else: - assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor" - - x = norm( - x, - epsilon=self.eps, - weight=weight, - program_config=program_config, - memory_config=memory_config, - compute_kernel_config=self.compute_kernel_config_hifi2, - ) - - if in_sharded and not out_sharded: - return ttnn.sharded_to_interleaved(x) - else: - return x - - def _distributed_rmsnorm( - self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None - ): - inp = ttnn.sharded_to_interleaved(inp, ttnn.DRAM_MEMORY_CONFIG) - - xnorm = ttnn.pow(inp, 2) - - xnorm = ttnn.mean(xnorm, dim=-1, 
keepdim=True) - - xnorm = ttnn.rsqrt(xnorm + epsilon) - - xnorm = ttnn.multiply(inp, xnorm) - - weight = ttnn.reshape(weight, [1, 1, -1]) - - output = ttnn.multiply(xnorm, (weight), use_legacy=False) - - if memory_config is not None: - output = ttnn.to_memory_config(output, memory_config) - - ttnn.deallocate(xnorm) - ttnn.deallocate(weight) - - return output diff --git a/models/experimental/qwen25_vl/tt/rope.py b/models/experimental/qwen25_vl/tt/rope.py deleted file mode 100644 index 97a89ed8bf53..000000000000 --- a/models/experimental/qwen25_vl/tt/rope.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -This is the vision rotary embedding implementation for Qwen-VL-7B. - -The existing RotarySetup(models/tt_transformers/tt/rope.py) in tt_transformers can't be used here, -as Qwen-VL uses a different logic for applying rotary embeddings. -This version is implemented specifically to match Qwen's design. -""" - - -import torch -import ttnn - - -class TTQwen2_5_VisionRotaryEmbedding: - def __init__(self, device, dim: int, theta: float = 10000.0, mode="decode"): - self.dim = dim - self.theta = theta - self.device = device - - arange_indices = ttnn.arange(start=0, end=dim, step=2, device=device) - arange_indices = ttnn.to_layout(arange_indices, ttnn.TILE_LAYOUT) - exponent = ttnn.div(arange_indices, dim) - pow_result = ttnn.pow(theta, exponent) - recip = ttnn.reciprocal(pow_result) - self.inv_freq = ttnn.multiply(recip, 1.0) - - def __call__(self, seqlen: int): - tt_seq = ttnn.arange(end=seqlen, device=self.device) - tt_seq = ttnn.to_torch(tt_seq) - tt_inv_freq = ttnn.to_torch(self.inv_freq) - tt_freqs = torch.outer(tt_seq, tt_inv_freq) - tt_freqs = ttnn.from_torch( - tt_freqs, - device=self.device, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.L1_MEMORY_CONFIG, - ) - ttnn.deallocate(self.inv_freq) - - return tt_freqs diff --git a/models/experimental/qwen25_vl/tt/text_mlp.py b/models/experimental/qwen25_vl/tt/text_mlp.py deleted file mode 100644 index ef973009ec45..000000000000 --- a/models/experimental/qwen25_vl/tt/text_mlp.py +++ /dev/null @@ -1,114 +0,0 @@ -""" -This is the text MLP implementation for Qwen-VL-7B. - -The existing MLP in tt_transformers(models/tt_transformers/tt/mlp.py) caused "Statically Allocated Circular L1 Buffer" issue -when used with the Qwen VL model. To avoid this, the MLP was re-written without memory optimizations. 
-""" - -import torch - -import ttnn -from models.common.lightweightmodule import LightweightModule - - -class MLP(LightweightModule): - def __init__( - self, - mesh_device, - args, - state_dict, - weight_cache_path, - dtype, - layer_num=0, - model_config=None, - state_dict_prefix=None, - ): - super().__init__() - - self.mesh_device = mesh_device - self.args = args - self.state_dict = state_dict - self.dim = args.dim - - state_dict_prefix = state_dict_prefix or args.get_state_dict_prefix(self.__class__.__name__, layer_num) - - def get_weight(name): - return torch.transpose(state_dict[f"{state_dict_prefix}.{name}.weight"], -2, -1) - - def get_bias(name): - return state_dict[f"{state_dict_prefix}{name}.bias"] - - def cache_name(name): - if args.dummy_weights: - return None - return weight_cache_path / f"{state_dict_prefix}.{name}" - - def as_tensor(name, dtype, is_bias=False): - tensor_data = get_bias(name) if is_bias else get_weight(name) - return ttnn.as_tensor( - tensor_data, - dtype=dtype, - device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - # cache_file_name=cache_name(name), - ) - - # Weights and Biases - self.w1 = as_tensor("w1", dtype) - self.w3 = as_tensor("w3", dtype) - self.w2 = as_tensor("w2", dtype) - - self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi4, - fp32_dest_acc_en=True, - packer_l1_acc=True, - dst_full_sync_en=False, - ) - - def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: - """ - Qwen HF MLP reference: - output = down_proj(act_fn(gate_proj(x)) * up_proj(x)) - Mapping: - w1 -> gate_proj - w3 -> up_proj - w2 -> down_proj - """ - - # Linear with GELU activation - w1_out = ttnn.linear( - x, - self.w1, - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - activation="silu", - compute_kernel_config=self.compute_kernel_config, - ) - - w3_out = ttnn.linear( - x, - self.w3, - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config, - ) - - # Element-wise multiply - w2_in = ttnn.mul(w1_out, w3_out, dtype=ttnn.bfloat16) - - # Final projection - w2_out = ttnn.linear( - w2_in, - self.w2, - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config, - ) - - ttnn.deallocate(w1_out) - ttnn.deallocate(w3_out) - ttnn.deallocate(w2_in) - - return w2_out diff --git a/models/experimental/qwen25_vl/tt/vision_block.py b/models/experimental/qwen25_vl/tt/vision_block.py deleted file mode 100644 index 72461a61c7a6..000000000000 --- a/models/experimental/qwen25_vl/tt/vision_block.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -This is the vision block used in the Qwen-VL-7B architecture -consisting of RMSnorm and self-attention layer followed by an MLP layer. 
-""" - -import ttnn -from models.common.lightweightmodule import LightweightModule -from models.experimental.qwen25_vl.tt.mlp import QwenTTVisionMLP -from models.experimental.qwen25_vl.tt.rmsnorm import RMSNorm -from models.experimental.qwen25_vl.tt.attention import TtQwen2_5_VLVisionSdpaAttention - - -class TtQwen2_5_VLVisionBlock(LightweightModule): - def __init__( - self, - mesh_device, - state_dict, - dtype, - model_args, - weight_cache_path=None, - state_dict_prefix=None, - ): - super().__init__() - - self.norm1 = RMSNorm( - device=mesh_device, - dim=1280, - state_dict=state_dict, - state_dict_prefix=state_dict_prefix, - weight_key="norm1", - weight_dtype=dtype, - is_distributed=False, - sharded_program_config=model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], - sharded_output_config=model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], - ) - - self.norm2 = RMSNorm( - device=mesh_device, - dim=1280, - state_dict=state_dict, - state_dict_prefix=state_dict_prefix, - weight_key="norm2", - weight_dtype=dtype, - is_distributed=False, - sharded_program_config=model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], - sharded_output_config=model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], - ) - - self.attn = TtQwen2_5_VLVisionSdpaAttention( - mesh_device, - state_dict, - state_dict_prefix=f"{state_dict_prefix}attn.", - # weight_cache_path=model_args.weight_cache_path(dtype), - dtype=dtype, - configuration=model_args, - ) - - self.mlp = QwenTTVisionMLP( - mesh_device=mesh_device, - args=model_args, - state_dict=state_dict, - state_dict_prefix=f"{state_dict_prefix}feed_forward.", - weight_cache_path=model_args.weight_cache_path(dtype), - dtype=dtype, - ) - - def forward(self, hidden_states, cu_seqlens, position_embeddings): - hidden_states = ttnn.add( - hidden_states, - self.attn( - self.norm1(hidden_states), - cu_seqlens=cu_seqlens, - position_embeddings=position_embeddings, - ), - ) - - hidden_states = ttnn.add(hidden_states, self.mlp(self.norm2(hidden_states))) - - return hidden_states diff --git a/models/experimental/qwen25_vl/tt/vision_model.py b/models/experimental/qwen25_vl/tt/vision_model.py deleted file mode 100644 index 974dbbe8daa4..000000000000 --- a/models/experimental/qwen25_vl/tt/vision_model.py +++ /dev/null @@ -1,242 +0,0 @@ -""" -This is the end-to-end architecture of the Qwen-VL 2.5 vision model. - -It brings together all components—patch embedding, vision blocks, rotary embeddings, -and patch merger for visual input processing. 
-""" - -import ttnn -from tqdm import tqdm -from models.common.lightweightmodule import LightweightModule -from models.experimental.qwen25_vl.tt.vision_block import TtQwen2_5_VLVisionBlock -from models.experimental.qwen25_vl.tt.patch_embed import TTQwen2_5_VisionPatchEmbed -from models.experimental.qwen25_vl.tt.rope import TTQwen2_5_VisionRotaryEmbedding -from models.experimental.qwen25_vl.tt.patch_merger import TTQwen2_5_VLPatchMerger - -import torch -import torch.nn.functional as F - - -class TtQwen2_5_VisionTransformerPretrainedModel(LightweightModule): - def __init__( - self, - mesh_device, - state_dict, - state_dict_prefix, - weight_cache_path, - dtype, - model_args, - layers, - block_key="", - gated=False, - ): - self.spatial_merge_size = model_args.spatial_merge_size - self.patch_size = model_args.vision_patch_size - self.fullatt_block_indexes = model_args.fullatt_block_indexes - self.window_size = model_args.window_size - self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size - self.mesh_device = mesh_device - hidden_size = model_args.vision_dim - n_heads = model_args.vision_attn_n_heads - out_hidden_size = model_args.out_hidden_size - temporal_patch_size = model_args.temporal_patch_size - - self.patch_embed = TTQwen2_5_VisionPatchEmbed( - device=mesh_device, - patch_size=self.patch_size, - temporal_patch_size=temporal_patch_size, - in_channels=3, - embed_dim=hidden_size, - state_dict=state_dict, - weight_key="patch_embed.", - layer_num=None, - state_dict_prefix=state_dict_prefix, - weight_cache_path=None, - weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, - weight_dtype=ttnn.bfloat16, - ) - - head_dim = hidden_size // n_heads - - self.rotary_pos_emb = TTQwen2_5_VisionRotaryEmbedding( - device=mesh_device, - dim=head_dim // 2, - theta=10000.0, - ) - - self.blocks = [ - TtQwen2_5_VLVisionBlock( - mesh_device=mesh_device, - state_dict=state_dict, - state_dict_prefix=f"{state_dict_prefix}blocks.{i}.", - weight_cache_path=weight_cache_path, - dtype=dtype, - model_args=model_args, - ) - for i in tqdm(range(layers), desc=f"Loading vision transformer blocks") - ] - - self.merger = TTQwen2_5_VLPatchMerger( - device=mesh_device, - dim=5120, - state_dict=state_dict, - state_dict_prefix=state_dict_prefix, - weight_key="merger.", - layer_num=None, - weight_cache_path=None, - weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, - weight_dtype=ttnn.bfloat16, - is_distributed=None, - eps=1e-06, - dims=out_hidden_size, - context_dim=hidden_size, - spatial_merge_size=self.spatial_merge_size, - ) - - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb_full = ttnn.to_torch(rotary_pos_emb_full, device=self.mesh_device) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return 
rotary_pos_emb - - def get_window_index(self, grid_thw): - window_index: list = [] - cu_window_seqlens: list = [0] - window_index_id = 0 - vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size - - for grid_t, grid_h, grid_w in grid_thw: - llm_grid_h, llm_grid_w = ( - grid_h // self.spatial_merge_size, - grid_w // self.spatial_merge_size, - ) - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) - index_padded = index_padded.reshape( - grid_t, - num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size, - ) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, - num_windows_h * num_windows_w, - vit_merger_window_size, - vit_merger_window_size, - ) - seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != -100] - window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] - cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() - window_index = torch.cat(window_index, dim=0) - - return window_index, cu_window_seqlens - - def forward(self, hidden_states, grid_thw): - hidden_states = self.patch_embed(hidden_states) - - rotary_pos_emb = self.rot_pos_emb(grid_thw) - - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - # device=hidden_states.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) - - seq_len = hidden_states.shape[-2] - - hidden_states = ttnn.reshape(hidden_states, [seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1]) - # hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - tt_index = ttnn.from_torch( - window_index.view(-1, 1, 1).expand(-1, hidden_states.shape[-2], hidden_states.shape[-1]).permute(1, 2, 0), - device=self.mesh_device, - dtype=ttnn.uint32, - layout=ttnn.TILE_LAYOUT, - # memory_config=ttnn.L1_MEMORY_CONFIG, - ) - - hidden_states = ttnn.gather(ttnn.permute(hidden_states, (1, 2, 0)), dim=-1, index=tt_index) - hidden_states = ttnn.permute(hidden_states, (2, 0, 1)) - - hidden_states = ttnn.reshape(hidden_states, [seq_len, -1]) - # hidden_states = hidden_states.reshape(seq_len, -1) - - rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - rotary_pos_emb = rotary_pos_emb[window_index, :, :] - rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - - cos_tensor = ttnn.from_torch(emb.cos(), device=self.mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) - sin_tensor = ttnn.from_torch(emb.sin(), device=self.mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) - - position_embeddings = (cos_tensor, sin_tensor) - - ttnn.deallocate(tt_index) - - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype 
based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - - for layer_num, blk in enumerate(self.blocks): - if layer_num in self.fullatt_block_indexes: - cu_seqlens_now = cu_seqlens - else: - cu_seqlens_now = cu_window_seqlens - - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) - - ttnn.deallocate(cos_tensor) - ttnn.deallocate(sin_tensor) - hidden_states = self.merger(hidden_states) - reverse_indices = torch.argsort(window_index) - - tt_reverse_indices = ttnn.from_torch( - reverse_indices.view(-1, 1).expand(-1, hidden_states.shape[-1]).transpose(0, 1), - device=self.mesh_device, - dtype=ttnn.uint32, - layout=ttnn.TILE_LAYOUT, - ) - hidden_states = ttnn.gather(ttnn.permute(hidden_states, (1, 0)), dim=-1, index=tt_reverse_indices) - hidden_states = ttnn.permute(hidden_states, (1, 0)) - - return hidden_states
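
Two pieces of the removed vision forward pass are easy to miss in the diff: the window reordering (a gather by window_index that is later undone with argsort) and the cu_seqlens construction. A small plain-PyTorch sketch of both, with illustrative shapes (a single 4x8 patch grid and a 2x2 spatial merge; `get_window_index` is only referenced in a comment):

    import torch
    import torch.nn.functional as F

    # cu_seqlens for full attention: cumulative patch counts per frame, with a leading 0.
    grid_thw = torch.tensor([[1, 4, 8]])          # one image, 4x8 patch grid (as in the tests)
    cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(0)
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)   # -> tensor([0, 32])

    # Window reorder / restore round trip at spatial-merge-group granularity.
    seq_len, merge_unit, dim = 32, 4, 8               # merge_unit = spatial_merge_size ** 2
    hidden = torch.randn(seq_len, dim)
    window_index = torch.randperm(seq_len // merge_unit)   # stands in for get_window_index()

    grouped = hidden.reshape(seq_len // merge_unit, merge_unit, dim)
    reordered = grouped[window_index].reshape(seq_len, dim)    # windows made contiguous
    # ... vision blocks and the patch merger run on `reordered` ...
    restored = reordered.reshape(seq_len // merge_unit, merge_unit, dim)[torch.argsort(window_index)]
    assert torch.equal(restored.reshape(seq_len, dim), hidden)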