 import ttnn
 from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf
 from models.perf.benchmarking_utils import BenchmarkProfiler
+from models.tt_transformers.tt.common import hf_multimodal_encode
 from models.tt_transformers.tt.generator import Generator
+from models.tt_transformers.tt.model_config import CheckpointType


 def get_batch_sampler(temperature, top_p, tokenizer):
@@ -62,6 +64,7 @@ def create_multimodal_model(
 ):
     from models.tt_transformers.tt.model_config import ModelArgs
     from models.tt_transformers.tt.multimodal.llama_vision_model import CrossAttentionTransformer
+    from models.tt_transformers.tt.multimodal.mistral_24b.mistral_e2e_model import MistralTransformer

     tt_model_args = ModelArgs(mesh_device, max_batch_size=max_batch_size)
     assert tt_model_args.is_vision(), "This model is multimodal"
@@ -76,14 +79,25 @@ def create_multimodal_model(

     if checkpoint is None:
         checkpoint = tt_model_args.load_state_dict()
-    model = CrossAttentionTransformer(
-        mesh_device,
-        state_dict=checkpoint,
-        weight_cache_path=tt_model_args.weight_cache_path(dtype),
-        dtype=dtype,
-        configuration=tt_model_args,
-        use_paged_kv_cache=use_paged_kv_cache,
-    )
+
+    if tt_model_args.base_model_name == "Mistral-Small-3.1-24B":
+        model = MistralTransformer(
+            mesh_device=mesh_device,
+            state_dict=checkpoint,
+            weight_cache_path=tt_model_args.weight_cache_path(ttnn.bfloat8_b),
+            dtype=ttnn.bfloat8_b,
+            args=tt_model_args,
+            use_paged_kv_cache=use_paged_kv_cache,
+        )
+    else:
+        model = CrossAttentionTransformer(
+            mesh_device,
+            state_dict=checkpoint,
+            weight_cache_path=tt_model_args.weight_cache_path(dtype),
+            dtype=dtype,
+            configuration=tt_model_args,
+            use_paged_kv_cache=use_paged_kv_cache,
+        )
     return tt_model_args, model, checkpoint


@@ -136,7 +150,7 @@ def prepare_generator_args(
 )
 @pytest.mark.parametrize(
     "test_type,max_seq_len",
-    (("normal", 512),),
+    (("normal", 2048),),
     ids=["normal"],
 )
 @pytest.mark.parametrize(
@@ -182,9 +196,6 @@ def test_multimodal_demo_text(
     profiler = BenchmarkProfiler()
     profiler.start("run")

-    ckpt_dir = os.environ["LLAMA_DIR"]
-    tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")
-
     num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1
     max_batch_size *= data_parallel  # input batch_size is interpreted as size per DP group

@@ -195,11 +206,27 @@ def test_multimodal_demo_text(
         max_batch_size=max_batch_size,
         max_seq_len=max_seq_len,
     )
+
+    HF_MODEL = model_args[0].checkpoint_type == CheckpointType.HuggingFace
+
+    if not HF_MODEL:
+        ckpt_dir = os.environ["LLAMA_DIR"]
+        tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")
+
+        tokenizer = Tokenizer(model_path=tokenizer_path)
+        formatter = ChatFormat(tokenizer)
+    else:
+        from transformers import AutoProcessor
+
+        processor = AutoProcessor.from_pretrained(model_args[0].CKPT_DIR)
+        tokenizer = model_args[0].tokenizer
+
     generator = Generator(model, model_args, mesh_device)
-    tokenizer = Tokenizer(model_path=tokenizer_path)
-    formatter = ChatFormat(tokenizer)

-    xattn_caches = [model.setup_cache(model_args[i].max_batch_size) for i, model in enumerate(generator.model)]
+    xattn_caches = [
+        model.setup_cache(model_args[i].max_batch_size) if not HF_MODEL else None
+        for i, model in enumerate(generator.model)
+    ]

     # Create random images for trace capture with specific dimensions
     trace_img_560x560 = create_random_image(560, 560)
@@ -264,6 +291,8 @@ def test_multimodal_demo_text(
     _num_prefill_tokens = 0
     _num_decode_tokens = 0

+    prompt_encoder = hf_multimodal_encode if HF_MODEL else formatter.encode_dialog_prompt
+
     for iter_num in range(warmup_iters + 1):
         logger.info(f"Iteration {iter_num}")
         current_dialogs = trace_dialogs + dialogs
@@ -273,9 +302,15 @@ def test_multimodal_demo_text(
                 for msg in dialog:
                     print(f"{msg.role.capitalize()}: {msg.content}\n")
             batch_model_input = [
-                formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) for dialog in batch_dialogs
+                prompt_encoder(dialog, processor) if HF_MODEL else prompt_encoder(dialog, tool_prompt_format=False)
+                for dialog in batch_dialogs
             ]

+            if HF_MODEL:
+                image_sizes = [model_input.image_sizes for model_input in batch_model_input]
+            else:
+                image_sizes = None
+
             # Do initial prefill
             vision_images = [
                 model_input.vision.images if model_input.vision else None for model_input in batch_model_input
@@ -288,7 +323,7 @@ def test_multimodal_demo_text(
             total_lens = prefill_lens + max_gen_len

             # Create padded tokens tensor for batch
-            pad_id = tokenizer.pad_id
+            pad_id = tokenizer.pad_token_id if HF_MODEL else tokenizer.pad_id
             bsz = len(prompt_tokens)
             tokens = torch.full((bsz, max(total_lens)), pad_id, dtype=torch.long)

@@ -312,6 +347,7 @@ def test_multimodal_demo_text(
                 xattn_caches,
                 total_lens,
                 prefill_lens,
+                image_sizes=image_sizes,
             )

             # Get cached prefill time
@@ -323,12 +359,7 @@ def test_multimodal_demo_text(
                 decode_batch_xattn_masks,
                 decode_batch_text_masks,
             ) = generator.prefill_forward(
-                vision_images,
-                vision_mask,
-                tokens,
-                xattn_caches,
-                total_lens,
-                prefill_lens,
+                vision_images, vision_mask, tokens, xattn_caches, total_lens, prefill_lens, image_sizes=image_sizes
             )

             prefill_end = time.perf_counter()
@@ -375,12 +406,16 @@ def test_multimodal_demo_text(
             )  # gen_idx is (num_tokens - 1) to avoid counting compile iter

             # Log full text output for each user in batch
-            vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256]
+            if HF_MODEL:
+                # For HF models, get vision tokens from the processor if they exist
+                vision_tokens = []
+            else:
+                vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256]

             for user_id in range(max_batch_size):
                 # Remove <|image|> tokens since they break the tokenizer
                 tokens_out = [
-                    t if t not in vision_tokens else tokenizer.pad_id
+                    t if t not in vision_tokens else pad_id
                     for t in tokens[user_id].tolist()[: position_id[user_id] + 2]
                 ]
                 text = tokenizer.decode(tokens_out)
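
Note for reviewers: `hf_multimodal_encode` is imported above but its body is outside this diff. Below is a minimal, hypothetical sketch of the kind of object the HF branch appears to expect it to return. The `tokens` / `vision.images` / `image_sizes` attribute names are inferred from how `batch_model_input` is consumed in the demo, and the processor call uses the stock `transformers` API, not necessarily the repo's actual helper.

from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class VisionInput:
    images: Optional[List[Any]]  # raw PIL images (or preprocessed tensors)


@dataclass
class HFModelInput:
    # Attribute names are inferred from the demo's usage; the real helper may differ.
    tokens: List[int]
    vision: Optional[VisionInput]
    image_sizes: Optional[Any]


def hf_multimodal_encode_sketch(dialog, processor) -> HFModelInput:
    # Hypothetical encoder: join the dialog text and preprocess any images
    # attached to the messages in a single processor call.
    text = " ".join(msg.content for msg in dialog if isinstance(msg.content, str))
    images = [img for msg in dialog for img in getattr(msg, "images", [])]
    inputs = processor(text=text, images=images or None, return_tensors="pt")
    return HFModelInput(
        tokens=inputs["input_ids"][0].tolist(),
        vision=VisionInput(images=images) if images else None,
        image_sizes=inputs.get("image_sizes"),  # only some processors emit this field
    )

Usage would mirror the demo's HF branch, e.g. `hf_multimodal_encode_sketch(dialog, AutoProcessor.from_pretrained(model_args[0].CKPT_DIR))`.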