diff --git a/models/experimental/qwen25_vl/tests/test_e2e.py b/models/experimental/qwen25_vl/tests/test_e2e.py deleted file mode 100644 index 63a9b0ad0b00..000000000000 --- a/models/experimental/qwen25_vl/tests/test_e2e.py +++ /dev/null @@ -1,408 +0,0 @@ -"""Test for Qwen 2.5 VL End-to-End Vision-Text Pipeline""" - -import torch -import pytest -from loguru import logger -import os -import ttnn -from models.tt_transformers.tt.common import ( - sample_host, - PagedAttentionConfig, - preprocess_inputs_prefill, -) -from models.tt_transformers.tt.model_config import DecodersPrecision - -from models.experimental.qwen25_vl.tt.model import Qwen25VLTransformer as Transformer - -from models.tt_transformers.tt.generator import Generator - -from models.experimental.qwen25_vl.tt.vision_model import TtQwen2_5_VisionTransformerPretrainedModel -from models.utility_functions import skip_for_grayskull, skip_for_blackhole - -from models.tt_transformers.tt.model_config import ModelArgs -from transformers import AutoProcessor - -import re - - -def parse_chat_output(text): - """Parse chat output format from generated text.""" - pattern = r"<\|(?Puser|assistant)\|>\s*(?P.*?)(?=<\|(?:user|assistant|end)\|>|$)" - matches = re.finditer(pattern, text, re.DOTALL) - return [(match.group("role"), match.group("message").strip()) for match in matches] - - -def display_chat(logger, conversation): - """Display chat conversation in formatted output.""" - for role, message in conversation: - if role == "user": - logger.info(f"👤 User: {message}") - elif role == "assistant": - logger.info(f"🤖 Assistant: {message}") - - -def setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations): - """Setup model arguments for vision-enabled model (Single Responsibility).""" - instruct = True if weights == "instruct" else False - - model_args = ModelArgs( - mesh_device=mesh_device, - instruct=instruct, - optimizations=optimizations, - max_seq_len=max_seq_len, - max_batch_size=batch_size, - ) - - return model_args, instruct - - -def setup_vision_prompts_and_tokenizer(model_args, instruct): - """Setup multimodal prompts and tokenizer for vision-enabled model.""" - messages = [ - { - "role": "user", - "content": [ - { - "type": "image", - "image": "https://raw.githubusercontent.com/yavuzceliker/sample-images/refs/heads/main/images/image-1.jpg", - }, - {"type": "text", "text": "Describe this image in detail in 1000 words."}, - ], - } - ] - - tokenizer = model_args.tokenizer - return messages, tokenizer - - -def process_real_vision_inputs(messages, model_args): - """Process real image inputs using AutoProcessor (Interface Segregation).""" - processor = AutoProcessor.from_pretrained(os.getenv("HF_MODEL")) - - encoded = processor.apply_chat_template( - messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" - ).to("cpu", dtype=torch.bfloat16) - - input_ids = encoded["input_ids"] - pixel_values = encoded["pixel_values"] - attention_mask = encoded["attention_mask"] - image_grid_thw = encoded["image_grid_thw"] - - # logger.info(f"Processed vision inputs - input_ids: {input_ids.shape}, pixel_values: {pixel_values.shape}") - - return { - "input_ids": input_ids, - "pixel_values": pixel_values, - "attention_mask": attention_mask, - "image_grid_thw": image_grid_thw, - "processor": processor, - } - - -def load_separate_models_like_test_end2end(model_args, mesh_device, dtype, paged_attention, page_params): - """Load separate vision and text models following test_end2end.py pattern.""" - 
state_dict = model_args.load_state_dict() - - vision_prefix = "visual." - - # Setup paged attention config (exactly like test_end2end.py) - paged_attention_config = None - if paged_attention: - paged_attention_config = PagedAttentionConfig( - block_size=page_params["page_block_size"], - max_num_blocks=page_params["page_max_num_blocks"], - ) - - # Load vision model (exactly like test_end2end.py) - vision_model = TtQwen2_5_VisionTransformerPretrainedModel( - mesh_device=mesh_device, - state_dict=state_dict, - state_dict_prefix=vision_prefix, - dtype=dtype, - model_args=model_args, - weight_cache_path=model_args.weight_cache_path(dtype), - layers=model_args.vision_n_layers, - ) - # Load text model (exactly like test_end2end.py) - - text_model = Transformer( - args=model_args, - mesh_device=mesh_device, - dtype=dtype, - state_dict=state_dict, - weight_cache_path=model_args.weight_cache_path(dtype), - paged_attention_config=paged_attention_config, - ) - - logger.info("Separate vision and text models loaded like test_end2end.py") - return vision_model, text_model - - -def run_generation_exactly_like_test_end2end( - vision_model, text_model, processed_inputs, model_args, page_table=None, paged_attention_config=None, max_gen_len=20 -): - """Run generation following the EXACT pattern from test_end2end.py.""" - input_ids = processed_inputs["input_ids"] - pixel_values = processed_inputs["pixel_values"] - - logger.info("Running generation exactly like test_end2end.py...") - - logger.info("Running Vision Model...") - - generator = Generator([text_model], [model_args], vision_model.mesh_device, tokenizer=model_args.tokenizer) - - tt_kv_cache = [[l.attention.layer_past for l in text_model.layers]] if paged_attention_config else None - - input_tokens_prefill = input_ids - batch_size = input_tokens_prefill.shape[0] - - prompt_text = model_args.tokenizer.decode(input_ids[0].tolist()) - input_prompts = [prompt_text] - - ( - input_tokens_prefill_pt, - encoded_prompts, - decoding_pos, - prefill_lens, - ) = preprocess_inputs_prefill( - input_prompts, - model_args.tokenizer, - [model_args], - instruct=True, - max_generated_tokens=max_gen_len, - max_prefill_len=8192, - ) - - input_tokens_prefill_pt = torch.stack(input_tokens_prefill_pt).view(batch_size, -1) - - logger.info("Running prefill...") - logits = generator.prefill_forward_text( - input_tokens_prefill_pt, - page_table=page_table, - kv_cache=tt_kv_cache, - prompt_lens=decoding_pos, - vision_model=vision_model, - processed_inputs=processed_inputs, - ) - - prefilled_token = torch.argmax(logits, dim=-1) - logger.info(f"Prefilled token: {prefilled_token}") - - all_outputs = [encoded_prompts[0][: prefill_lens[0]]] - all_outputs[0].append(int(prefilled_token[0].item())) - - current_pos = torch.tensor([decoding_pos[0]]) - out_tok = prefilled_token - generation_length = 2000 - - results = [] - - logger.info("Starting decode loop...") - for iteration in range(generation_length): - logger.info(f"[Text] Decoding token {iteration}, current_pos: {current_pos.item()}") - - logits = generator.decode_forward_text( - out_tok, - current_pos, - enable_trace=False, - page_table=page_table, - kv_cache=tt_kv_cache, - ) - - _, out_tok = sample_host( - logits, - temperature=0, - top_p=0.9, - ) - - token_id = out_tok[0].item() - decoded_token = model_args.tokenizer.decode([token_id]) - logger.info(f"Generated token {iteration}: ID={token_id}, text='{decoded_token}'") - - # Create result object - result = type("TokenResult", (), {"token": token_id, "text": decoded_token})() - - 
results.append(result) - - all_outputs[0].append(token_id) - current_pos += 1 - - if token_id == 151645 or token_id == 151643: - logger.warning("Reached End token") - break - - # Early stopping (exactly like test_end2end.py) - if len(all_outputs[0]) >= 5 and all(t == all_outputs[0][-1] for t in all_outputs[0][-5:]): - logger.warning(f"Detected exact repetition of token {all_outputs[0][-1]} five times in a row. Stopping.") - break - - # Final response (exactly like test_end2end.py) - response = model_args.tokenizer.decode(all_outputs[0], skip_special_tokens=True) - logger.info(f"📝 Final Generated Response:\n{response}") - logger.info(f"📝 Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") - chat = parse_chat_output(response) - display_chat(logger, chat) - - logger.info(f"Generated {len(results)} tokens successfully") - return results - - -def validate_e2e_outputs(results, expected_min_tokens=1): - """Validate end-to-end pipeline outputs.""" - if not results: - logger.error("No results generated from E2E pipeline") - return False - - if len(results) < expected_min_tokens: - logger.warning(f"Generated only {len(results)} tokens, expected at least {expected_min_tokens}") - return False - - # Check if tokens are valid - for result in results: - if not hasattr(result, "token") or not hasattr(result, "text"): - logger.error("Invalid result format") - return False - - logger.info("E2E pipeline validation passed") - return True - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@skip_for_blackhole("Failing on DRAM harvested P100a, see #21419") -@pytest.mark.timeout(1800) -@pytest.mark.models_performance_bare_metal -@pytest.mark.parametrize( - "weights, layers", - [ - ("instruct", None), - ], - ids=["full"], -) -@pytest.mark.parametrize( - "paged_attention", - ( - True, - # False, - ), - ids=( - "paged_attention", - # "default_attention", - ), -) -@pytest.mark.parametrize( - "page_params", - [{"page_block_size": 32, "page_max_num_blocks": 1024}], -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -@pytest.mark.parametrize( - "max_seq_len", - (2048,), # Use smaller seq_len like test_end2end.py to avoid memory issues -) -@pytest.mark.parametrize( - "optimizations", - [ - lambda model_args: DecodersPrecision.accuracy(model_args.n_layers, model_args.model_name), - ], - ids=["accuracy"], -) -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -# @pytest.mark.parametrize("device_params", [{"l1_small_size": 1584864, "trace_region_size": 0}], indirect=True) -def test_e2e_vision_text_pipeline( - weights, - layers, - max_seq_len, - batch_size, - paged_attention, - page_params, - optimizations, - mesh_device, - reset_seeds, - request, - device_params, -): - """Test end-to-end vision-text pipeline using proper Generator methods.""" - logger.info("Starting E2E vision-text pipeline test") - - # Use bfloat8_b like test_end2end.py for better memory efficiency - dtype = ttnn.bfloat8_b - - # Setup vision-enabled model configuration - model_args, instruct = setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations) - - if layers is not None: - model_args.n_layers = layers - - # Setup vision prompts and tokenizer - messages, tokenizer = setup_vision_prompts_and_tokenizer(model_args, instruct) - - # Process real vision inputs from images - processed_inputs = process_real_vision_inputs(messages, model_args) 
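# --- Illustrative aside, not part of the patch: the decode loop above stops on the
# hard-coded ids 151645/151643, and the e2e model further down in this patch keys image
# features off 151655. Assuming the Qwen/Qwen2.5-VL-7B-Instruct tokenizer (the id this
# patch also adds to the tokenizer mapping in model_config.py), those magic numbers can
# be sanity-checked like this:
from transformers import AutoTokenizer

qwen_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
assert qwen_tok.convert_tokens_to_ids("<|endoftext|>") == 151643  # generic end-of-text stop id
assert qwen_tok.convert_tokens_to_ids("<|im_end|>") == 151645     # chat end-of-turn stop id
assert qwen_tok.convert_tokens_to_ids("<|image_pad|>") == 151655  # placeholder later replaced by vision features
# --- end aside ---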
- - # Load separate models following test_end2end.py pattern - logger.info("Loading separate vision and text models like test_end2end.py...") - vision_model, text_model = load_separate_models_like_test_end2end( - model_args, mesh_device, dtype, paged_attention, page_params - ) - - # Setup page table for paged attention (exactly like test_end2end.py) - page_table_tt = None - paged_attention_config = None - - # Prepare page table for paged attention (exactly like test_end2end.py) - page_table = None - if paged_attention: - paged_attention_config = PagedAttentionConfig( - block_size=page_params["page_block_size"], - max_num_blocks=page_params["page_max_num_blocks"], - ) - # Implied shuffling of blocks - permutation = torch.randperm(paged_attention_config.max_num_blocks) - # Page table which maps virtual blocks to physical - reverse_permutation = torch.argsort(permutation) - page_table = reverse_permutation.reshape( - model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size - ) - page_table_tt = ttnn.from_torch( - page_table, - device=mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, -2) if batch_size > 1 else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - # Run generation following EXACT test_end2end.py pattern - logger.info("Running generation following EXACT test_end2end.py pattern...") - results = run_generation_exactly_like_test_end2end( - vision_model, text_model, processed_inputs, model_args, page_table, paged_attention_config, max_gen_len=10 - ) - - # Validate results - validation_passed = validate_e2e_outputs(results, expected_min_tokens=1) - - # Final validation - if validation_passed and len(results) > 0: - logger.info("✅ E2E vision-text pipeline test PASSED!") - logger.info(f"Successfully generated {len(results)} tokens") - - # Log generated tokens for debugging - for i, result in enumerate(results[:5]): - logger.info(f"Token {i}: {result.token} -> '{result.text}'") - else: - logger.error("❌ E2E pipeline test failed") - assert False, f"E2E pipeline failed - generated {len(results)} tokens, validation: {validation_passed}" diff --git a/models/experimental/qwen25_vl/tt/rope.py b/models/experimental/qwen25_vl/tt/rope.py deleted file mode 100644 index 97a89ed8bf53..000000000000 --- a/models/experimental/qwen25_vl/tt/rope.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -This is the vision rotary embedding implementation for Qwen-VL-7B. - -The existing RotarySetup(models/tt_transformers/tt/rope.py) in tt_transformers can't be used here, -as Qwen-VL uses a different logic for applying rotary embeddings. -This version is implemented specifically to match Qwen's design. 
-""" - - -import torch -import ttnn - - -class TTQwen2_5_VisionRotaryEmbedding: - def __init__(self, device, dim: int, theta: float = 10000.0, mode="decode"): - self.dim = dim - self.theta = theta - self.device = device - - arange_indices = ttnn.arange(start=0, end=dim, step=2, device=device) - arange_indices = ttnn.to_layout(arange_indices, ttnn.TILE_LAYOUT) - exponent = ttnn.div(arange_indices, dim) - pow_result = ttnn.pow(theta, exponent) - recip = ttnn.reciprocal(pow_result) - self.inv_freq = ttnn.multiply(recip, 1.0) - - def __call__(self, seqlen: int): - tt_seq = ttnn.arange(end=seqlen, device=self.device) - tt_seq = ttnn.to_torch(tt_seq) - tt_inv_freq = ttnn.to_torch(self.inv_freq) - tt_freqs = torch.outer(tt_seq, tt_inv_freq) - tt_freqs = ttnn.from_torch( - tt_freqs, - device=self.device, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.L1_MEMORY_CONFIG, - ) - ttnn.deallocate(self.inv_freq) - - return tt_freqs diff --git a/models/experimental/qwen25_vl/tt/text_mlp.py b/models/experimental/qwen25_vl/tt/text_mlp.py deleted file mode 100644 index ef973009ec45..000000000000 --- a/models/experimental/qwen25_vl/tt/text_mlp.py +++ /dev/null @@ -1,114 +0,0 @@ -""" -This is the text MLP implementation for Qwen-VL-7B. - -The existing MLP in tt_transformers(models/tt_transformers/tt/mlp.py) caused "Statically Allocated Circular L1 Buffer" issue -when used with the Qwen VL model. To avoid this, the MLP was re-written without memory optimizations. -""" - -import torch - -import ttnn -from models.common.lightweightmodule import LightweightModule - - -class MLP(LightweightModule): - def __init__( - self, - mesh_device, - args, - state_dict, - weight_cache_path, - dtype, - layer_num=0, - model_config=None, - state_dict_prefix=None, - ): - super().__init__() - - self.mesh_device = mesh_device - self.args = args - self.state_dict = state_dict - self.dim = args.dim - - state_dict_prefix = state_dict_prefix or args.get_state_dict_prefix(self.__class__.__name__, layer_num) - - def get_weight(name): - return torch.transpose(state_dict[f"{state_dict_prefix}.{name}.weight"], -2, -1) - - def get_bias(name): - return state_dict[f"{state_dict_prefix}{name}.bias"] - - def cache_name(name): - if args.dummy_weights: - return None - return weight_cache_path / f"{state_dict_prefix}.{name}" - - def as_tensor(name, dtype, is_bias=False): - tensor_data = get_bias(name) if is_bias else get_weight(name) - return ttnn.as_tensor( - tensor_data, - dtype=dtype, - device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - # cache_file_name=cache_name(name), - ) - - # Weights and Biases - self.w1 = as_tensor("w1", dtype) - self.w3 = as_tensor("w3", dtype) - self.w2 = as_tensor("w2", dtype) - - self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi4, - fp32_dest_acc_en=True, - packer_l1_acc=True, - dst_full_sync_en=False, - ) - - def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: - """ - Qwen HF MLP reference: - output = down_proj(act_fn(gate_proj(x)) * up_proj(x)) - Mapping: - w1 -> gate_proj - w3 -> up_proj - w2 -> down_proj - """ - - # Linear with GELU activation - w1_out = ttnn.linear( - x, - self.w1, - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - activation="silu", - compute_kernel_config=self.compute_kernel_config, - ) - - w3_out = ttnn.linear( - x, - self.w3, - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - 
compute_kernel_config=self.compute_kernel_config, - ) - - # Element-wise multiply - w2_in = ttnn.mul(w1_out, w3_out, dtype=ttnn.bfloat16) - - # Final projection - w2_out = ttnn.linear( - w2_in, - self.w2, - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config, - ) - - ttnn.deallocate(w1_out) - ttnn.deallocate(w3_out) - ttnn.deallocate(w2_in) - - return w2_out diff --git a/models/tt_transformers/demo/simple_vision_demo.py b/models/tt_transformers/demo/simple_vision_demo.py index 7d21da9ca274..22dd0b002335 100644 --- a/models/tt_transformers/demo/simple_vision_demo.py +++ b/models/tt_transformers/demo/simple_vision_demo.py @@ -27,7 +27,9 @@ import ttnn from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf from models.perf.benchmarking_utils import BenchmarkProfiler +from models.tt_transformers.tt.common import hf_multimodal_encode from models.tt_transformers.tt.generator import Generator +from models.tt_transformers.tt.model_config import CheckpointType def get_batch_sampler(temperature, top_p, tokenizer): @@ -62,6 +64,7 @@ def create_multimodal_model( ): from models.tt_transformers.tt.model_config import ModelArgs from models.tt_transformers.tt.multimodal.llama_vision_model import CrossAttentionTransformer + from models.tt_transformers.tt.multimodal.qwen_vl.qwen_e2e_model import TtQwen_Model tt_model_args = ModelArgs(mesh_device, max_batch_size=max_batch_size) assert tt_model_args.is_vision(), "This model is multimodal" @@ -76,14 +79,25 @@ def create_multimodal_model( if checkpoint is None: checkpoint = tt_model_args.load_state_dict() - model = CrossAttentionTransformer( - mesh_device, - state_dict=checkpoint, - weight_cache_path=tt_model_args.weight_cache_path(dtype), - dtype=dtype, - configuration=tt_model_args, - use_paged_kv_cache=use_paged_kv_cache, - ) + + if tt_model_args.base_model_name == "Qwen2.5-VL-7B": + model = TtQwen_Model( + mesh_device=mesh_device, + state_dict=checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(ttnn.bfloat8_b), + dtype=ttnn.bfloat8_b, + args=tt_model_args, + use_paged_kv_cache=use_paged_kv_cache, + ) + else: + model = CrossAttentionTransformer( + mesh_device, + state_dict=checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=tt_model_args, + use_paged_kv_cache=use_paged_kv_cache, + ) return tt_model_args, model, checkpoint @@ -128,7 +142,7 @@ def prepare_generator_args( ) @pytest.mark.parametrize( "test_type,max_seq_len", - (("normal", 512),), + (("normal", 2048),), ids=["normal"], ) @pytest.mark.parametrize( @@ -172,9 +186,6 @@ def test_multimodal_demo_text( profiler = BenchmarkProfiler() profiler.start("run") - ckpt_dir = os.environ["LLAMA_DIR"] - tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") - num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1 max_batch_size *= data_parallel # input batch_size is interpreted as size per DP group @@ -185,11 +196,26 @@ def test_multimodal_demo_text( max_batch_size=max_batch_size, max_seq_len=max_seq_len, ) + + HF_MODEL = model_args[0].checkpoint_type == CheckpointType.HuggingFace + + if not HF_MODEL: + ckpt_dir = os.environ["LLAMA_DIR"] + tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") + + tokenizer = Tokenizer(model_path=tokenizer_path) + formatter = ChatFormat(tokenizer) + else: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_args[0].CKPT_DIR) + generator = 
Generator(model, model_args, mesh_device) - tokenizer = Tokenizer(model_path=tokenizer_path) - formatter = ChatFormat(tokenizer) - xattn_caches = [model.setup_cache(model_args[i].max_batch_size) for i, model in enumerate(generator.model)] + xattn_caches = [ + model.setup_cache(model_args[i].max_batch_size) if not HF_MODEL else None + for i, model in enumerate(generator.model) + ] # Create random images for trace capture with specific dimensions trace_img_560x560 = create_random_image(560, 560) @@ -250,10 +276,12 @@ def test_multimodal_demo_text( total_users = len(dialogs) num_batches = total_users // max_batch_size - sampler = get_batch_sampler(temperature, top_p, tokenizer) + sampler = get_batch_sampler(temperature, top_p, model_args[0].tokenizer) _num_prefill_tokens = 0 _num_decode_tokens = 0 + prompt_encoder = hf_multimodal_encode if HF_MODEL else formatter.encode_dialog_prompt + for iter_num in range(warmup_iters + 1): logger.info(f"Iteration {iter_num}") current_dialogs = trace_dialogs + dialogs @@ -263,9 +291,17 @@ def test_multimodal_demo_text( for msg in dialog: print(f"{msg.role.capitalize()}: {msg.content}\n") batch_model_input = [ - formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) for dialog in batch_dialogs + prompt_encoder(dialog, processor) if HF_MODEL else prompt_encoder(dialog, tool_prompt_format=False) + for dialog in batch_dialogs ] + if HF_MODEL: + # Use the processor's tokenizer instead of model_args tokenizer to ensure consistency + tokenizer = processor.tokenizer + image_grid_thw = [model_input.image_grid_thw for model_input in batch_model_input] + else: + image_grid_thw = None + # Do initial prefill vision_images = [ model_input.vision.images if model_input.vision else None for model_input in batch_model_input @@ -278,7 +314,7 @@ def test_multimodal_demo_text( total_lens = prefill_lens + max_gen_len # Create padded tokens tensor for batch - pad_id = tokenizer.pad_id + pad_id = tokenizer.pad_token_id if HF_MODEL else tokenizer.pad_id bsz = len(prompt_tokens) tokens = torch.full((bsz, max(total_lens)), pad_id, dtype=torch.long) @@ -302,6 +338,7 @@ def test_multimodal_demo_text( xattn_caches, total_lens, prefill_lens, + image_grid_thw=image_grid_thw, ) # Get cached prefill time @@ -319,6 +356,7 @@ def test_multimodal_demo_text( xattn_caches, total_lens, prefill_lens, + image_grid_thw=image_grid_thw, ) prefill_end = time.perf_counter() @@ -365,12 +403,16 @@ def test_multimodal_demo_text( ) # gen_idx is (num_tokens - 1) to avoid counting compile iter # Log full text output for each user in batch - vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256] + if HF_MODEL: + # For HF models, get vision tokens from the processor if they exist + vision_tokens = [] + else: + vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256] for user_id in range(max_batch_size): # Remove <|image|> tokens since they break the tokenizer tokens_out = [ - t if t not in vision_tokens else tokenizer.pad_id + t if t not in vision_tokens else pad_id for t in tokens[user_id].tolist()[: position_id[user_id] + 2] ] text = tokenizer.decode(tokens_out) diff --git a/models/experimental/qwen25_vl/tests/test_attention.py b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_attention.py similarity index 85% rename from models/experimental/qwen25_vl/tests/test_attention.py rename to models/tt_transformers/tests/multimodal/qwen_vl/test_image_attention.py index c01f78924f4e..b6fa77675ef1 100644 --- a/models/experimental/qwen25_vl/tests/test_attention.py +++ 
b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_attention.py @@ -8,8 +8,7 @@ import ttnn from models.tt_transformers.tt.model_config import ModelArgs - -from models.experimental.qwen25_vl.tt.attention import TtQwen2_5_VLVisionSdpaAttention +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_attention import TtQwen2_5_VLVisionSdpaAttention from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull @@ -83,14 +82,28 @@ def test_attention_inference(batch, num_chunks, mesh_device, reset_seeds): pt_attention_input.unsqueeze(0), force_replicated=True ) - cos_tensor = ttnn.from_torch(cos, device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) - sin_tensor = ttnn.from_torch(sin, device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + cos_tensor = ttnn.from_torch( + cos, + device=mesh_device, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + ) + sin_tensor = ttnn.from_torch( + sin, + device=mesh_device, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + ) # Step 6: run TT tt_out = tt_model(tt_attention_input, cu_seqlens, position_embeddings=(cos_tensor, sin_tensor)) # Doing contract in tt is correct!! - tt_output_torch = ttnn.to_torch(tt_out, device=mesh_device).squeeze(0) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[ + : tt_out.shape[0], : + ] passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) diff --git a/models/experimental/qwen25_vl/tests/test_vision_block.py b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_block.py similarity index 85% rename from models/experimental/qwen25_vl/tests/test_vision_block.py rename to models/tt_transformers/tests/multimodal/qwen_vl/test_image_block.py index 69bd51eb5617..f3e8e5b73645 100644 --- a/models/experimental/qwen25_vl/tests/test_vision_block.py +++ b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_block.py @@ -8,8 +8,7 @@ import ttnn from models.tt_transformers.tt.model_config import ModelArgs - -from models.experimental.qwen25_vl.tt.vision_block import TtQwen2_5_VLVisionBlock +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_block import TtQwen2_5_VLVisionBlock from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull @@ -73,12 +72,23 @@ def test_transformer_inference(batch, num_chunks, mesh_device, reset_seeds): pt_attention_input.unsqueeze(0), force_replicated=True ) - cos_tensor = ttnn.from_torch(cos, device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) - sin_tensor = ttnn.from_torch(sin, device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + cos_tensor = ttnn.from_torch( + cos, + device=mesh_device, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + ) + sin_tensor = ttnn.from_torch( + sin, + device=mesh_device, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + ) tt_out = tt_model(tt_attention_input, cu_seqlens, position_embeddings=(cos_tensor, sin_tensor)) - - tt_output_torch = ttnn.to_torch(tt_out, device=mesh_device).squeeze(0).squeeze(0) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[0, 0, :, :] passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) diff --git 
a/models/experimental/qwen25_vl/tests/test_image_merger.py b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_merger.py similarity index 88% rename from models/experimental/qwen25_vl/tests/test_image_merger.py rename to models/tt_transformers/tests/multimodal/qwen_vl/test_image_merger.py index 442165307a72..dd4bc5bae365 100644 --- a/models/experimental/qwen25_vl/tests/test_image_merger.py +++ b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_merger.py @@ -1,15 +1,15 @@ """"Test for Qwen 2.5 VL Patch Merger""" -from loguru import logger +import os -import torch import pytest -import os -import ttnn -from models.experimental.qwen25_vl.tt.patch_merger import TTQwen2_5_VLPatchMerger +import torch +from loguru import logger -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +import ttnn from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_patch_merger import TTQwen2_5_VLPatchMerger +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull @torch.no_grad() @@ -58,6 +58,7 @@ def test_patch_merger_inference(seq_len, batch_size, reset_seeds, device): dim=5120, state_dict=state_dict, weight_key=first_layer_prefix, + args=tt_model_args, layer_num=None, state_dict_prefix="", weight_cache_path=None, @@ -77,6 +78,7 @@ def test_patch_merger_inference(seq_len, batch_size, reset_seeds, device): input, device=device, dtype=dtype, + mesh_mapper=ttnn.ReplicateTensorToMesh(device), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) @@ -86,9 +88,8 @@ def test_patch_merger_inference(seq_len, batch_size, reset_seeds, device): tt_output = tt_model(tt_input) # DistributedNorm outputs are replicated across devices - tt_output_torch = ttnn.to_torch( - tt_output, - ) + tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(device, dim=0)) + tt_output_torch = tt_output_torch[0, :] passing, pcc_message = comp_pcc(reference_output, tt_output_torch) diff --git a/models/experimental/qwen25_vl/tests/test_mlp.py b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_mlp.py similarity index 94% rename from models/experimental/qwen25_vl/tests/test_mlp.py rename to models/tt_transformers/tests/multimodal/qwen_vl/test_image_mlp.py index ee936435b8ec..f7b3dad7ba45 100644 --- a/models/experimental/qwen25_vl/tests/test_mlp.py +++ b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_mlp.py @@ -8,8 +8,7 @@ import ttnn from models.tt_transformers.tt.model_config import ModelArgs - -from models.experimental.qwen25_vl.tt.mlp import QwenTTVisionMLP +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_mlp import QwenTTVisionMLP from models.utility_functions import comp_allclose, comp_pcc, nearest_32, skip_for_grayskull @@ -69,8 +68,8 @@ def test_mlp_inference(batch, num_chunks, mesh_device, reset_seeds): tt_output = tt_model(tt_input) - tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ - :, :1, :, : + tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[ + 0, :, :, : ].squeeze() pcc_required = 0.99 diff --git a/models/experimental/qwen25_vl/tests/test_vision_model.py b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_model.py similarity index 87% rename from models/experimental/qwen25_vl/tests/test_vision_model.py rename to models/tt_transformers/tests/multimodal/qwen_vl/test_image_model.py index bf0bef57ca25..1d3d3313197a 
100644 --- a/models/experimental/qwen25_vl/tests/test_vision_model.py +++ b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_model.py @@ -7,11 +7,9 @@ from loguru import logger import ttnn - from models.tt_transformers.tt.model_config import ModelArgs - -from models.experimental.qwen25_vl.tt.vision_model import TtQwen2_5_VisionTransformerPretrainedModel -from models.utility_functions import comp_pcc, skip_for_grayskull, comp_allclose +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_vision_model import TtQwen2_5_VisionTransformerPretrainedModel +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull @skip_for_grayskull("Requires wormhole_b0 to run") @@ -41,8 +39,6 @@ def test_vision_inference(batch, num_chunks, mesh_device, reset_seeds): k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - dim = model_args.vision_dim - reference_model = model_args.reference_vision_model() reference_model.load_state_dict(partial_state_dict) reference_model.eval() @@ -69,8 +65,9 @@ def test_vision_inference(batch, num_chunks, mesh_device, reset_seeds): tt_attention_input = model_args.prepare_residual_tensor_prefill(pt_input.unsqueeze(0), force_replicated=True) tt_out = tt_model(tt_attention_input, grid_thw) - - tt_output_torch = ttnn.to_torch(tt_out, device=mesh_device) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[ + : tt_out.shape[0], : + ] non_zero_indices = tt_output_torch.ne(0).nonzero(as_tuple=True) tt_output_torch = tt_output_torch[non_zero_indices] diff --git a/models/experimental/qwen25_vl/tests/test_image_patch_embed.py b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_patch_embed.py similarity index 83% rename from models/experimental/qwen25_vl/tests/test_image_patch_embed.py rename to models/tt_transformers/tests/multimodal/qwen_vl/test_image_patch_embed.py index c08423637b2f..27c645a8b643 100644 --- a/models/experimental/qwen25_vl/tests/test_image_patch_embed.py +++ b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_patch_embed.py @@ -1,15 +1,15 @@ """"Test for Qwen 2.5 VL Patch Embed""" -from loguru import logger +import os -import torch import pytest -import os -import ttnn -from models.experimental.qwen25_vl.tt.patch_embed import TTQwen2_5_VisionPatchEmbed +import torch +from loguru import logger -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +import ttnn from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_patch_embed import TTQwen2_5_VisionPatchEmbed +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull @torch.no_grad() @@ -55,6 +55,7 @@ def test_embed_inference(seq_len, batch_size, reset_seeds, device): tt_model = TTQwen2_5_VisionPatchEmbed( device=device, + args=tt_model_args, patch_size=14, temporal_patch_size=2, in_channels=3, @@ -73,14 +74,19 @@ def test_embed_inference(seq_len, batch_size, reset_seeds, device): reference_output = reference_model(input) tt_input = ttnn.from_torch( - input, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + input, + device=device, + dtype=dtype, + mesh_mapper=ttnn.ReplicateTensorToMesh(device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) tt_output = tt_model(tt_input) - tt_output_torch = ttnn.to_torch( - tt_output, - ) + tt_output_torch = ttnn.to_torch(tt_output, 
mesh_composer=ttnn.ConcatMeshToTensor(device, dim=-1))[ + : tt_output.shape[0], : + ] passing, pcc_message = comp_pcc(reference_output, tt_output_torch) diff --git a/models/experimental/qwen25_vl/tests/test_vision_rms.py b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_rms.py similarity index 97% rename from models/experimental/qwen25_vl/tests/test_vision_rms.py rename to models/tt_transformers/tests/multimodal/qwen_vl/test_image_rms.py index ca453d09940c..01fd4b6e1eef 100644 --- a/models/experimental/qwen25_vl/tests/test_vision_rms.py +++ b/models/tt_transformers/tests/multimodal/qwen_vl/test_image_rms.py @@ -1,19 +1,16 @@ """Test for Qwen 2.5 VL RMSNorm Layer Inference""" -from loguru import logger +import os -import torch import pytest -import os +import torch +from loguru import logger import ttnn -from models.experimental.qwen25_vl.tt.rmsnorm import RMSNorm - from models.tt_transformers.tt.distributed_norm import DistributedNorm - - -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_rmsnorm import RMSNorm +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull @torch.no_grad() diff --git a/models/tt_transformers/tt/common.py b/models/tt_transformers/tt/common.py index 5eebf47ce735..a8302a71e7d5 100644 --- a/models/tt_transformers/tt/common.py +++ b/models/tt_transformers/tt/common.py @@ -5,9 +5,11 @@ import math import re from enum import Enum +from types import SimpleNamespace from typing import Optional import torch +from llama_models.llama3.api.datatypes import ImageMedia from loguru import logger from pydantic import BaseModel, Field @@ -614,3 +616,46 @@ def create_tt_model( tt_kv_cache = [l.attention.layer_past for l in model.layers] if paged_attention_config else None return tt_model_args, model, tt_kv_cache, state_dict + + +def hf_multimodal_encode(messages, processor): + hf_messages = [] + + for msg in messages: + hf_content = [] + + for item in msg.content: + if isinstance(item, ImageMedia): + hf_content.append( + { + "type": "image", + "image": item.image, + } + ) + elif isinstance(item, str): + hf_content.append( + { + "type": "text", + "text": item, + } + ) + + hf_messages.append( + { + "role": msg.role, + "content": hf_content, + } + ) + + encoded = processor.apply_chat_template( + hf_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to("cpu", dtype=torch.bfloat16) + + return SimpleNamespace( + **encoded, + tokens=encoded["input_ids"].squeeze(0), + vision=SimpleNamespace( + images=encoded["pixel_values"], + mask=None, + ), + ) diff --git a/models/tt_transformers/tt/decoder.py b/models/tt_transformers/tt/decoder.py index d86520611b5f..905df8e8de17 100644 --- a/models/tt_transformers/tt/decoder.py +++ b/models/tt_transformers/tt/decoder.py @@ -4,7 +4,6 @@ import ttnn from models.common.lightweightmodule import LightweightModule from models.common.rmsnorm import RMSNorm -from models.experimental.qwen25_vl.tt.text_mlp import MLP as QwenMLP from models.tt_transformers.tt.attention import Attention as DefaultAttention from models.tt_transformers.tt.distributed_norm import DistributedNorm from models.tt_transformers.tt.mlp import MLP @@ -41,6 +40,7 @@ def __init__( self.current = 0 self.model_config = args.get_model_config() self.simplified_rms = True if self.args.base_model_name == "Qwen2.5-VL-7B" else False + self.simplified_rms = True if 
self.args.base_model_name == "Qwen2.5-VL-7B" else False self.layer_num = layer_num @@ -57,26 +57,16 @@ def __init__( paged_attention_config=paged_attention_config, use_paged_kv_cache=use_paged_kv_cache, ) - if self.args.base_model_name == "Qwen2.5-VL-7B": - self.feed_forward = QwenMLP( - mesh_device=mesh_device, - args=args, - state_dict=state_dict, - weight_cache_path=weight_cache_path, - layer_num=layer_num, - dtype=dtype, - model_config=self.model_config, - ) - else: - self.feed_forward = MLP( - mesh_device=mesh_device, - args=args, - state_dict=state_dict, - weight_cache_path=weight_cache_path, - layer_num=layer_num, - dtype=dtype, - model_config=self.model_config, - ) + + self.feed_forward = MLP( + mesh_device=mesh_device, + args=args, + state_dict=state_dict, + weight_cache_path=weight_cache_path, + layer_num=layer_num, + dtype=dtype, + model_config=self.model_config, + ) self.attention_norm = DistributedNorm( RMSNorm( device=mesh_device, diff --git a/models/tt_transformers/tt/generator.py b/models/tt_transformers/tt/generator.py index cd0620049b91..c6dd416d376c 100644 --- a/models/tt_transformers/tt/generator.py +++ b/models/tt_transformers/tt/generator.py @@ -22,6 +22,7 @@ get_padded_prefill_len, num_blocks_in_seq, ) +from models.tt_transformers.tt.model_config import CheckpointType @dataclass(frozen=True) @@ -79,6 +80,7 @@ def prefill_forward_text( seq_len = int(prompt_lens[idx]) last_token_idx = seq_len - 1 prefill_seq_len = get_padded_prefill_len(seq_len) + local_kwargs = kwargs.copy() # Avoid modifying original kwargs logger.info(f"Prefilling User {user_id + 1} up to {seq_len} tokens") @@ -94,6 +96,12 @@ def prefill_forward_text( ) model_kv_cache = kv_cache[model_id] if kv_cache is not None else None + # Check if 'pixel_values' exists and index it safely + if "pixel_values" in local_kwargs: + local_kwargs["pixel_values"] = local_kwargs["pixel_values"][idx] + if "image_grid_thw" in local_kwargs: + local_kwargs["image_grid_thw"] = local_kwargs["image_grid_thw"][idx] + logits = self.prefill_forward_single_user_text( prefill_ids, page_table=page_table_user, @@ -101,7 +109,7 @@ def prefill_forward_text( last_token_idx=last_token_idx, kv_cache=model_kv_cache, model_id=model_id, - **kwargs, + **local_kwargs, ) out_list.append(logits) @@ -491,6 +499,61 @@ def _prefill_forward_single_user( # Note: This function is called by vLLM def prefill_forward( + self, + vision_images, + vision_masks, + tokens, + xattn_caches, + total_lens, + prompt_lens, + page_table=None, + kv_cache=None, + cross_page_table=None, + empty_slots=None, + **kwargs, + ): + if self.model_args[0].checkpoint_type == CheckpointType.HuggingFace: + logits = self.prefill_forward_text( + tokens, + page_table=page_table, + kv_cache=kv_cache, + prompt_lens=prompt_lens, + pixel_values=vision_images, + **kwargs, + ) + + return logits, None, None, None, None + + else: + ( + output_logits, + prefill_output_xattn_masks, + prefill_output_full_text_row_masked_out_masks, + decode_output_xattn_masks, + decode_output_full_text_row_masked_out_masks, + ) = self.prefill_forward_llama_vision( + vision_images, + vision_masks, + tokens, + xattn_caches, + total_lens, + prompt_lens, + page_table=page_table, + kv_cache=kv_cache, + cross_page_table=cross_page_table, + empty_slots=empty_slots, + ) + + return ( + output_logits, + prefill_output_xattn_masks, + prefill_output_full_text_row_masked_out_masks, + decode_output_xattn_masks, + decode_output_full_text_row_masked_out_masks, + ) + + # Note: This function is called by vLLM + def 
prefill_forward_llama_vision( self, vision_images, vision_masks, @@ -587,7 +650,7 @@ def prefill_forward( ) # Note: This function is called by vLLM - def decode_forward( + def decode_forward_llama_vision( self, start_pos, tokens, @@ -651,6 +714,47 @@ def decode_forward( else: return tt_logits + def decode_forward( + self, + start_pos, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + xattn_caches=None, + page_table=None, + kv_cache=None, + cross_page_table=None, + enable_trace=True, + read_from_device=True, + ): + import os + + if os.environ.get("HF_MODEL"): + return self.decode_forward_text( + tokens, + start_pos, + enable_trace=enable_trace, + page_table=page_table, + kv_cache=kv_cache, + ) + else: + return self.decode_forward_llama_vision( + start_pos, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + xattn_caches, + page_table, + kv_cache, + cross_page_table, + enable_trace, + read_from_device, + ) + # Note: This function is called by vLLM def read_decode_output(self, tt_out, unpadded_batch, is_tokens=False): """ diff --git a/models/tt_transformers/tt/load_checkpoints.py b/models/tt_transformers/tt/load_checkpoints.py index 38e68df53e2b..1214d18fec67 100644 --- a/models/tt_transformers/tt/load_checkpoints.py +++ b/models/tt_transformers/tt/load_checkpoints.py @@ -213,6 +213,12 @@ def convert_meta_to_hf(state_dict, head_dim): return state_dict +def convert_vision_meta_to_hf(state_dict, head_dim): + # state_dict = convert_meta_qkv_to_hf_format(state_dict, head_dim) + state_dict = map_vision_meta_to_hf_keys(state_dict) + return state_dict + + def replace_keys(state_dict, replacements): """ Replacements are in the form (pattern, replacement). 
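# --- Illustrative aside, not part of the patch: the body of replace_keys sits outside
# this hunk, so the sketch below only assumes the behaviour its docstring implies --
# each (pattern, replacement) pair is applied to every checkpoint key via re.sub. That
# is what lets the "model.language_model." rule added in the next hunk strip the extra
# prefix carried by Qwen2.5-VL HuggingFace checkpoints before the usual HF -> Meta renames.
import re

def apply_key_replacements_sketch(state_dict, replacements):
    # hypothetical helper mirroring the documented (pattern, replacement) contract
    for pattern, replacement in replacements:
        state_dict = {re.sub(pattern, replacement, key): value for key, value in state_dict.items()}
    return state_dict

# e.g. "model.language_model.layers.0.self_attn.q_proj.weight" -> "layers.0.self_attn.q_proj.weight"
# --- end aside ---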
@@ -238,6 +244,7 @@ def map_hf_to_meta_keys(loaded_weights): """ replacements = [ ("^emb.weight", "weight"), + ("model.language_model.", ""), ("language.model.", ""), ("model.", ""), ("embed_tokens", "tok_embeddings"), diff --git a/models/tt_transformers/tt/model.py b/models/tt_transformers/tt/model.py index 8fce82afe8bd..36baaa66363e 100644 --- a/models/tt_transformers/tt/model.py +++ b/models/tt_transformers/tt/model.py @@ -360,6 +360,7 @@ def forward( chunk_start_idx=None, get_last_token=-1, kv_cache=None, + **kwargs, ): for i, layer in enumerate(self.layers): # No-op if callers already provide the right memory config diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index 33641b93a268..95a51be08fe6 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -1725,23 +1725,18 @@ def load_state_dict(self): else: assert self.checkpoint_type == CheckpointType.HuggingFace if self.from_hf_url: - # Special case Qwen2.5-VL models until they are fully integrated into a HF release - if "Qwen2.5-VL" in self.model_name: - from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VLForConditionalGeneration as AutoModelForCausalLM, - ) + from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText - print("Loading Qwen2.5-VL model: ", AutoModelForCausalLM) + if "Qwen2.5-VL-7B" in self.model_name: + model = AutoModelForImageTextToText.from_pretrained(self.CKPT_DIR, torch_dtype="auto") else: - from transformers import AutoModelForCausalLM - - model = AutoModelForCausalLM.from_pretrained( - self.CKPT_DIR, - torch_dtype="auto" - # Note that the default setting is torch.dtype.float32, but model weights are - # may come in any dtype. If the model's weights are in torch.dtype.bfloat16, this would result in 2x memory usage from an - # unnecessary cast. - ) + model = AutoModelForCausalLM.from_pretrained( + self.CKPT_DIR, + torch_dtype="auto" + # Note that the default setting is torch.dtype.float32, but model weights are + # may come in any dtype. If the model's weights are in torch.dtype.bfloat16, this would result in 2x memory usage from an + # unnecessary cast. + ) if self.cache_hf_flag: self.cached_hf_model = model state_dict = model.state_dict() @@ -1753,7 +1748,7 @@ def load_state_dict(self): state_dict = standardize_hf_keys_multimodal(state_dict) else: state_dict = standardize_hf_keys(state_dict) - state_dict = convert_hf_to_meta(state_dict, self.head_dim) + state_dict = convert_hf_to_meta(state_dict, self.head_dim) keys_dict = list(state_dict.keys())[:] remv = [f"layers.{i}." 
for i in list(range(self.n_layers, self.full_model_n_layers))] @@ -2096,6 +2091,7 @@ def create_tokenizer(self): "Qwen2.5-1.5B": "Qwen/Qwen2.5-1.5B-Instruct", "Qwen2.5-3B": "Qwen/Qwen2.5-3B-Instruct", "Qwen2.5-7B": "Qwen/Qwen2.5-7B-Instruct", + "Qwen2.5-VL-7B": "Qwen/Qwen2.5-VL-7B-Instruct", "Qwen2.5-14B": "Qwen/Qwen2.5-14B-Instruct", "Qwen2.5-32B": "Qwen/Qwen2.5-32B-Instruct", "Qwen2.5-72B": "Qwen/Qwen2.5-72B-Instruct", diff --git a/models/experimental/qwen25_vl/tt/model.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_e2e_model.py similarity index 60% rename from models/experimental/qwen25_vl/tt/model.py rename to models/tt_transformers/tt/multimodal/qwen_vl/qwen_e2e_model.py index 50cbbe924784..e8cc53f7c3a6 100644 --- a/models/experimental/qwen25_vl/tt/model.py +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_e2e_model.py @@ -1,19 +1,29 @@ -""" -This is the end-to-end pipeline for the Qwen-VL 2.5 model. +from typing import List -The `Qwen25VLTransformer` class inherits from the `Transformer` class in tt_transformers. -It overrides `prepare_inputs_prefill` to run inference on the vision model and -pass the resulting visual tokens to the text model along with text tokens. -""" - - -import ttnn import torch +import ttnn from models.tt_transformers.tt.model import Transformer +from models.tt_transformers.tt.multimodal.llama_vision_model import _stack_images +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_vision_model import TtQwen2_5_VisionTransformerPretrainedModel + +def _stack_images( + images: List[List[torch.Tensor]], # batch of samples, each with list of image embeddings +) -> List[torch.Tensor]: + """ + Concatenate image embeddings per sample into a single 2D tensor. -class Qwen25VLTransformer(Transformer): + Args: + images: List of samples, each being a list of [num_patches, hidden_dim] tensors + + Returns: + List of [total_patches, hidden_dim] tensors, one per sample + """ + return [torch.cat(image_list, dim=0) for image_list in images] + + +class TtQwen_Model(Transformer): def __init__( self, args, @@ -34,17 +44,26 @@ def __init__( use_paged_kv_cache=use_paged_kv_cache, ) - def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): + self.vision_model = TtQwen2_5_VisionTransformerPretrainedModel( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix="visual.", + dtype=dtype, + model_args=args, + weight_cache_path=args.weight_cache_path(dtype), + layers=args.vision_n_layers, + ) + + def prepare_inputs_prefill(self, pt_tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): """ Inputs are torch tensors or python types. This function returns ttnn tensors on device. 
TODO: Debate whether this function is responsible for padding """ - tokens = tokens.reshape(1, 1, 1, -1) - S = tokens.shape[-1] + S = pt_tokens.shape[-1] tokens = ttnn.from_torch( - tokens, + pt_tokens.reshape(1, 1, 1, -1), device=self.mesh_device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, @@ -54,31 +73,21 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag tokens_embd = self.embd(tokens) # tokens_embd = ttnn.multiply(tokens_embd, self.embed_scale) - pixel_values = kwargs["processed_inputs"]["pixel_values"] - input_ids = kwargs["processed_inputs"]["input_ids"] - image_grid_thw = kwargs["processed_inputs"]["image_grid_thw"] - - vision_model = kwargs["vision_model"] - pixel_values = self.args.prepare_residual_tensor_prefill(pixel_values.unsqueeze(0), force_replicated=True) + vision_output = self.compute_vision_token(**kwargs) - vision_output = vision_model(pixel_values, image_grid_thw) + tokens_embd = ttnn.to_torch(tokens_embd, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1)) + comp_vision_output = ttnn.to_torch( + vision_output, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0) + )[: vision_output.shape[0], :] - tokens_embd = ttnn.to_torch(tokens_embd) - comp_vision_output = ttnn.to_torch(ttnn.from_device(vision_output)) - - input_ids = torch.nn.functional.pad(input_ids, (0, tokens_embd.shape[1] - input_ids.shape[1]), "constant", 0) image_features = comp_vision_output.squeeze(0) - special_image_mask = (input_ids == 151655).unsqueeze(-1) + special_image_mask = (pt_tokens == 151655).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(tokens_embd) image_features = image_features.to(tokens_embd.device, tokens_embd.dtype) tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features) - tokens_embd = ttnn.from_torch( + tokens_embd = self.args.prepare_residual_tensor_prefill( tokens_embd, - dtype=ttnn.bfloat16, - device=self.mesh_device, - layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) @@ -115,3 +124,9 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag tt_chunk_page_table = None return tokens_embd, tt_rot_mats_prefill_global, tt_page_table, tt_chunk_page_table + + def compute_vision_token(self, pixel_values, image_grid_thw): + pixel_values = self.args.prepare_residual_tensor_prefill(pixel_values.unsqueeze(0), force_replicated=True) + + vision_output = self.vision_model(pixel_values, image_grid_thw) + return vision_output diff --git a/models/experimental/qwen25_vl/tt/attention.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_attention.py similarity index 85% rename from models/experimental/qwen25_vl/tt/attention.py rename to models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_attention.py index a515bf35e737..bd8dabf68676 100644 --- a/models/experimental/qwen25_vl/tt/attention.py +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_attention.py @@ -9,6 +9,7 @@ import torch + import ttnn from models.common.lightweightmodule import LightweightModule @@ -38,6 +39,7 @@ def __init__(self, mesh_device, state_dict, state_dict_prefix, dtype, configurat self.num_heads = 16 self.head_dim = self.hidden_size // self.num_heads self.scale = self.head_dim**-0.5 + self.configuration = configuration # Load qkv weight & bias (fused): shape [hidden_size, hidden_size*3] qkv_weight = state_dict[f"{state_dict_prefix}qkv.weight"] @@ -48,11 +50,17 @@ def __init__(self, mesh_device, 
state_dict, state_dict_prefix, dtype, configurat torch.transpose(qkv_weight, -2, -1), device=mesh_device, dtype=dtype, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) self.qkv_bias = ttnn.as_tensor( - qkv_bias, device=mesh_device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + qkv_bias, + device=mesh_device, + dtype=dtype, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) # Output projection: proj @@ -63,11 +71,17 @@ def __init__(self, mesh_device, state_dict, state_dict_prefix, dtype, configurat torch.transpose(proj_weight, -2, -1), device=mesh_device, dtype=dtype, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) self.proj_bias = ttnn.as_tensor( - proj_bias, device=mesh_device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + proj_bias, + device=mesh_device, + dtype=dtype, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( @@ -94,6 +108,9 @@ def forward(self, hidden_states, cu_seqlens, position_embeddings): compute_kernel_config=self.compute_kernel_config, ) # shape [batch, seq_len, hidden_size*3] + if self.configuration.num_devices > 1: + qkv = ttnn.all_gather(qkv, dim=-1, num_links=1) + (q, k, v) = ttnn.permute(ttnn.reshape(qkv, [seq_len, 3, self.num_heads, -1]), [1, 0, 2, 3]) ttnn.deallocate(qkv) @@ -135,7 +152,9 @@ def forward(self, hidden_states, cu_seqlens, position_embeddings): memory_config=ttnn.DRAM_MEMORY_CONFIG, compute_kernel_config=self.compute_kernel_config, ) - ttnn.deallocate(attn_output) + if self.configuration.num_devices > 1: + output = ttnn.all_gather(output, dim=1, num_links=1) + return output diff --git a/models/experimental/qwen25_vl/tt/vision_block.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_block.py similarity index 89% rename from models/experimental/qwen25_vl/tt/vision_block.py rename to models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_block.py index 72461a61c7a6..14f81eae0263 100644 --- a/models/experimental/qwen25_vl/tt/vision_block.py +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_block.py @@ -5,9 +5,9 @@ import ttnn from models.common.lightweightmodule import LightweightModule -from models.experimental.qwen25_vl.tt.mlp import QwenTTVisionMLP -from models.experimental.qwen25_vl.tt.rmsnorm import RMSNorm -from models.experimental.qwen25_vl.tt.attention import TtQwen2_5_VLVisionSdpaAttention +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_attention import TtQwen2_5_VLVisionSdpaAttention +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_mlp import QwenTTVisionMLP +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_rmsnorm import RMSNorm class TtQwen2_5_VLVisionBlock(LightweightModule): diff --git a/models/experimental/qwen25_vl/tt/mlp.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_mlp.py similarity index 90% rename from models/experimental/qwen25_vl/tt/mlp.py rename to models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_mlp.py index e1ebfe02d6ed..ec13ff554e85 100644 --- a/models/experimental/qwen25_vl/tt/mlp.py +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_mlp.py @@ -48,7 +48,7 @@ def as_tensor(name, dtype, is_bias=False): tensor_data, 
dtype=dtype, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, # cache_file_name=cache_name(name), @@ -101,6 +101,10 @@ def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: compute_kernel_config=self.compute_kernel_config, ) + if self.args.num_devices > 1: + w1_out = ttnn.all_gather(w1_out, dim=3, num_links=1) + w3_out = ttnn.all_gather(w3_out, dim=3, num_links=1) + # Element-wise multiply w2_in = ttnn.mul(w1_out, w3_out, dtype=ttnn.bfloat16) @@ -118,4 +122,7 @@ def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: ttnn.deallocate(w3_out) ttnn.deallocate(w2_in) + if self.args.num_devices > 1: + w2_out = ttnn.all_gather(w2_out, dim=len(w2_out.shape) - 1, num_links=1) + return w2_out diff --git a/models/experimental/qwen25_vl/tt/patch_embed.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_patch_embed.py similarity index 90% rename from models/experimental/qwen25_vl/tt/patch_embed.py rename to models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_patch_embed.py index 36262c777f7e..2cdc3e679103 100644 --- a/models/experimental/qwen25_vl/tt/patch_embed.py +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_patch_embed.py @@ -14,6 +14,7 @@ class TTQwen2_5_VisionPatchEmbed: def __init__( self, device, + args, patch_size, temporal_patch_size, in_channels, @@ -36,6 +37,7 @@ def __init__( self.embed_dim = embed_dim self.weight_memory_config = weight_memory_config self.weight_dtype = weight_dtype + self.args = args weight_name_1 = f"{state_dict_prefix}{weight_key}proj.weight" torch_weight = state_dict[weight_name_1] @@ -45,6 +47,7 @@ def __init__( weight_matrix.T, device=self.device, dtype=self.weight_dtype, + mesh_mapper=ttnn.ShardTensorToMesh(self.device, dim=-1), layout=ttnn.TILE_LAYOUT, memory_config=self.weight_memory_config, ) @@ -59,4 +62,7 @@ def __call__(self, x: ttnn.Tensor) -> ttnn.Tensor: x_flattened = ttnn.reshape(x, (x.shape[2], -1)) output = ttnn.matmul(x_flattened, self.weight, compute_kernel_config=self.compute_kernel_config) + if self.args.num_devices > 1: + output = ttnn.all_gather(output, dim=1, num_links=1) + return output diff --git a/models/experimental/qwen25_vl/tt/patch_merger.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_patch_merger.py similarity index 77% rename from models/experimental/qwen25_vl/tt/patch_merger.py rename to models/tt_transformers/tt/multimodal/qwen_vl/qwen_patch_merger.py index a1001b0ec96d..71cd4b4c02c2 100644 --- a/models/experimental/qwen25_vl/tt/patch_merger.py +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_patch_merger.py @@ -6,8 +6,7 @@ """ import ttnn -from models.experimental.qwen25_vl.tt.rmsnorm import RMSNorm -from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_rmsnorm import RMSNorm class TTQwen2_5_VLPatchMerger: @@ -17,6 +16,7 @@ def __init__( dim, state_dict, weight_key, + args, layer_num=None, state_dict_prefix="", weight_cache_path=None, @@ -33,39 +33,35 @@ def __init__( self.eps = eps self.mode = mode - tt_model_args = ModelArgs( - device, - max_batch_size=1, - max_seq_len=128, - ) + self.args = args weight_name_1 = f"{state_dict_prefix}{weight_key}ln_q.weight" weight_name_2 = f"{state_dict_prefix}{weight_key}feed_forward.0.weight" weight_name_3 = f"{state_dict_prefix}{weight_key}feed_forward.2.weight" - bias_name_2 = f"{state_dict_prefix}{weight_key}feed_forward.0.bias" - bias_name_3 = 
f"{state_dict_prefix}{weight_key}feed_forward.2.bias" - self.weight_1 = ttnn.as_tensor( state_dict[weight_name_1], device=device, dtype=weight_dtype, + mesh_mapper=ttnn.ShardTensorToMesh(device, dim=-1), layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=weight_memory_config, ) self.weight_2 = ttnn.as_tensor( - state_dict[weight_name_2], + state_dict[weight_name_2].transpose(0, 1), device=device, dtype=weight_dtype, + mesh_mapper=ttnn.ShardTensorToMesh(device, dim=-1), layout=ttnn.TILE_LAYOUT, memory_config=weight_memory_config, ) self.weight_3 = ttnn.as_tensor( - state_dict[weight_name_3], + state_dict[weight_name_3].transpose(0, 1), device=device, dtype=weight_dtype, + mesh_mapper=ttnn.ShardTensorToMesh(device, dim=-1), layout=ttnn.TILE_LAYOUT, memory_config=weight_memory_config, ) @@ -79,14 +75,10 @@ def __init__( weight_key="visual.merger.ln_q", weight_dtype=ttnn.bfloat16, is_distributed=False, - sharded_program_config=tt_model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_program_config=self.args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], sharded_output_config=False, ) - self.weight_3 = ttnn.transpose(self.weight_3, 0, 1) - - self.weight_2 = ttnn.transpose(self.weight_2, 0, 1) - self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( math_fidelity=ttnn.MathFidelity.HiFi4, fp32_dest_acc_en=True, @@ -96,8 +88,7 @@ def __init__( def __call__(self, x): x = self.ln_q(x, mode=self.mode) - - x = ttnn.reshape(x, (-1, self.hidden_size)) + x = ttnn.reshape(x, (1, 1, -1, self.hidden_size)) x = ttnn.linear( x, @@ -105,6 +96,10 @@ def __call__(self, x): memory_config=ttnn.DRAM_MEMORY_CONFIG, compute_kernel_config=self.compute_kernel_config, ) + + if self.args.num_devices > 1: + x = ttnn.all_gather(x, dim=3) + x = ttnn.gelu(x) x = ttnn.linear( @@ -113,5 +108,9 @@ def __call__(self, x): memory_config=ttnn.DRAM_MEMORY_CONFIG, compute_kernel_config=self.compute_kernel_config, ) + if self.args.num_devices > 1: + x = ttnn.all_gather(x, dim=3, num_links=1) + + x = ttnn.reshape(x, (-1, x.shape[-1])) return x diff --git a/models/experimental/qwen25_vl/tt/rmsnorm.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_rmsnorm.py similarity index 98% rename from models/experimental/qwen25_vl/tt/rmsnorm.py rename to models/tt_transformers/tt/multimodal/qwen_vl/qwen_rmsnorm.py index 025ffcfc638b..c0a32474fa49 100644 --- a/models/experimental/qwen25_vl/tt/rmsnorm.py +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_rmsnorm.py @@ -117,8 +117,6 @@ def forward(self, x: ttnn.Tensor, mode="decode", in_sharded=False, out_sharded=F def _distributed_rmsnorm( self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None ): - inp = ttnn.sharded_to_interleaved(inp, ttnn.DRAM_MEMORY_CONFIG) - xnorm = ttnn.pow(inp, 2) xnorm = ttnn.mean(xnorm, dim=-1, keepdim=True) diff --git a/models/experimental/qwen25_vl/tt/vision_model.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_vision_model.py similarity index 88% rename from models/experimental/qwen25_vl/tt/vision_model.py rename to models/tt_transformers/tt/multimodal/qwen_vl/qwen_vision_model.py index 974dbbe8daa4..9a1a7f16b9d3 100644 --- a/models/experimental/qwen25_vl/tt/vision_model.py +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_vision_model.py @@ -5,16 +5,16 @@ and patch merger for visual input processing. 
""" -import ttnn -from tqdm import tqdm -from models.common.lightweightmodule import LightweightModule -from models.experimental.qwen25_vl.tt.vision_block import TtQwen2_5_VLVisionBlock -from models.experimental.qwen25_vl.tt.patch_embed import TTQwen2_5_VisionPatchEmbed -from models.experimental.qwen25_vl.tt.rope import TTQwen2_5_VisionRotaryEmbedding -from models.experimental.qwen25_vl.tt.patch_merger import TTQwen2_5_VLPatchMerger - import torch import torch.nn.functional as F +from tqdm import tqdm + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_block import TtQwen2_5_VLVisionBlock +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_image_patch_embed import TTQwen2_5_VisionPatchEmbed +from models.tt_transformers.tt.multimodal.qwen_vl.qwen_patch_merger import TTQwen2_5_VLPatchMerger +from models.tt_transformers.tt.rope import TTQwen2_5_VisionRotaryEmbedding class TtQwen2_5_VisionTransformerPretrainedModel(LightweightModule): @@ -43,6 +43,7 @@ def __init__( self.patch_embed = TTQwen2_5_VisionPatchEmbed( device=mesh_device, + args=model_args, patch_size=self.patch_size, temporal_patch_size=temporal_patch_size, in_channels=3, @@ -81,6 +82,7 @@ def __init__( dim=5120, state_dict=state_dict, state_dict_prefix=state_dict_prefix, + args=model_args, weight_key="merger.", layer_num=None, weight_cache_path=None, @@ -119,7 +121,9 @@ def rot_pos_emb(self, grid_thw): pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb_full = ttnn.to_torch(rotary_pos_emb_full, device=self.mesh_device) + rotary_pos_emb_full = ttnn.to_torch( + rotary_pos_emb_full, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0) + )[: rotary_pos_emb_full.shape[0], :] rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb @@ -166,9 +170,7 @@ def get_window_index(self, grid_thw): def forward(self, hidden_states, grid_thw): hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) - window_index, cu_window_seqlens = self.get_window_index(grid_thw) cu_window_seqlens = torch.tensor( cu_window_seqlens, @@ -178,9 +180,7 @@ def forward(self, hidden_states, grid_thw): cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) seq_len = hidden_states.shape[-2] - hidden_states = ttnn.reshape(hidden_states, [seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1]) - # hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) tt_index = ttnn.from_torch( window_index.view(-1, 1, 1).expand(-1, hidden_states.shape[-2], hidden_states.shape[-1]).permute(1, 2, 0), device=self.mesh_device, @@ -191,22 +191,31 @@ def forward(self, hidden_states, grid_thw): hidden_states = ttnn.gather(ttnn.permute(hidden_states, (1, 2, 0)), dim=-1, index=tt_index) hidden_states = ttnn.permute(hidden_states, (2, 0, 1)) - - hidden_states = ttnn.reshape(hidden_states, [seq_len, -1]) - # hidden_states = hidden_states.reshape(seq_len, -1) + hidden_states = ttnn.reshape(hidden_states, [1, 1, seq_len, -1]) rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - cos_tensor = ttnn.from_torch(emb.cos(), device=self.mesh_device, dtype=ttnn.bfloat16, 
diff --git a/models/tt_transformers/tt/rope.py b/models/tt_transformers/tt/rope.py
index e5e96c148fb2..2b13582897d0 100644
--- a/models/tt_transformers/tt/rope.py
+++ b/models/tt_transformers/tt/rope.py
@@ -447,3 +447,26 @@ def get_rot_mats(
         if return_rot_idxs:
             return [cos, sin], rot_idxs
         return [cos, sin]
+
+
+class TTQwen2_5_VisionRotaryEmbedding:
+    def __init__(self, device, dim: int, theta: float = 10000.0, mode="decode"):
+        self.dim = dim
+        self.theta = theta
+        self.device = device
+
+        arange_indices = ttnn.arange(start=0, end=dim, step=2, device=device)
+        arange_indices = ttnn.to_layout(arange_indices, ttnn.TILE_LAYOUT)
+        exponent = ttnn.div(arange_indices, dim)
+        pow_result = ttnn.pow(theta, exponent)
+        recip = ttnn.reciprocal(pow_result)
+        self.inv_freq = ttnn.multiply(recip, 1.0)
+
+    def __call__(self, seqlen: int):
+        tt_seq = ttnn.arange(end=seqlen, device=self.device)
+        tt_seq = ttnn.reshape(tt_seq, [1, 1, 1, tt_seq.shape[0]])
+        tt_freq = ttnn.reshape(self.inv_freq, [1, 1, 1, self.inv_freq.shape[0]])
+        tt_freqs = ttnn.outer(tt_seq, tt_freq)
+        tt_freqs = ttnn.reshape(tt_freqs, [tt_freqs.shape[2], tt_freqs.shape[3]])
+
+        return tt_freqs
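Reviewer note: the new `TTQwen2_5_VisionRotaryEmbedding` builds `inv_freq[i] = 1 / theta ** (2i / dim)` on device and returns `outer(arange(seqlen), inv_freq)` with shape `[seqlen, dim // 2]`. A host-side torch reference for sanity-checking the device kernels; illustrative only, not part of the patch:

```python
# Reference check for TTQwen2_5_VisionRotaryEmbedding:
# inv_freq[i] = 1 / theta ** (2i / dim), freqs = outer(arange(seqlen), inv_freq).
import torch


def vision_rotary_reference(seqlen: int, dim: int, theta: float = 10000.0) -> torch.Tensor:
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    return torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)  # [seqlen, dim // 2]
```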