
Commit 28b249b

Migrate mistral 24b to tt_transformers

1 parent 261d632 · commit 28b249b

20 files changed: +269 -114 lines

models/tt_transformers/demo/simple_vision_demo.py

Lines changed: 60 additions & 25 deletions

@@ -27,7 +27,9 @@
 import ttnn
 from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf
 from models.perf.benchmarking_utils import BenchmarkProfiler
+from models.tt_transformers.tt.common import hf_multimodal_encode
 from models.tt_transformers.tt.generator import Generator
+from models.tt_transformers.tt.model_config import CheckpointType
 
 
 def get_batch_sampler(temperature, top_p, tokenizer):
@@ -62,6 +64,7 @@ def create_multimodal_model(
 ):
     from models.tt_transformers.tt.model_config import ModelArgs
     from models.tt_transformers.tt.multimodal.llama_vision_model import CrossAttentionTransformer
+    from models.tt_transformers.tt.multimodal.mistral_24b.mistral_e2e_model import MistralTransformer
 
     tt_model_args = ModelArgs(mesh_device, max_batch_size=max_batch_size)
     assert tt_model_args.is_vision(), "This model is multimodal"
@@ -76,14 +79,25 @@ def create_multimodal_model(
 
     if checkpoint is None:
         checkpoint = tt_model_args.load_state_dict()
-    model = CrossAttentionTransformer(
-        mesh_device,
-        state_dict=checkpoint,
-        weight_cache_path=tt_model_args.weight_cache_path(dtype),
-        dtype=dtype,
-        configuration=tt_model_args,
-        use_paged_kv_cache=use_paged_kv_cache,
-    )
+
+    if tt_model_args.base_model_name == "Mistral-Small-3.1-24B":
+        model = MistralTransformer(
+            mesh_device=mesh_device,
+            state_dict=checkpoint,
+            weight_cache_path=tt_model_args.weight_cache_path(ttnn.bfloat8_b),
+            dtype=ttnn.bfloat8_b,
+            args=tt_model_args,
+            use_paged_kv_cache=use_paged_kv_cache,
+        )
+    else:
+        model = CrossAttentionTransformer(
+            mesh_device,
+            state_dict=checkpoint,
+            weight_cache_path=tt_model_args.weight_cache_path(dtype),
+            dtype=dtype,
+            configuration=tt_model_args,
+            use_paged_kv_cache=use_paged_kv_cache,
+        )
     return tt_model_args, model, checkpoint
 
 
@@ -136,7 +150,7 @@ def prepare_generator_args(
 )
 @pytest.mark.parametrize(
     "test_type,max_seq_len",
-    (("normal", 512),),
+    (("normal", 2048),),
     ids=["normal"],
 )
 @pytest.mark.parametrize(
@@ -182,9 +196,6 @@ def test_multimodal_demo_text(
     profiler = BenchmarkProfiler()
     profiler.start("run")
 
-    ckpt_dir = os.environ["LLAMA_DIR"]
-    tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")
-
     num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1
     max_batch_size *= data_parallel  # input batch_size is interpreted as size per DP group
 
@@ -195,11 +206,27 @@ def test_multimodal_demo_text(
         max_batch_size=max_batch_size,
         max_seq_len=max_seq_len,
     )
+
+    HF_MODEL = model_args[0].checkpoint_type == CheckpointType.HuggingFace
+
+    if not HF_MODEL:
+        ckpt_dir = os.environ["LLAMA_DIR"]
+        tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")
+
+        tokenizer = Tokenizer(model_path=tokenizer_path)
+        formatter = ChatFormat(tokenizer)
+    else:
+        from transformers import AutoProcessor
+
+        processor = AutoProcessor.from_pretrained(model_args[0].CKPT_DIR)
+        tokenizer = model_args[0].tokenizer
+
     generator = Generator(model, model_args, mesh_device)
-    tokenizer = Tokenizer(model_path=tokenizer_path)
-    formatter = ChatFormat(tokenizer)
 
-    xattn_caches = [model.setup_cache(model_args[i].max_batch_size) for i, model in enumerate(generator.model)]
+    xattn_caches = [
+        model.setup_cache(model_args[i].max_batch_size) if not HF_MODEL else None
+        for i, model in enumerate(generator.model)
+    ]
 
     # Create random images for trace capture with specific dimensions
     trace_img_560x560 = create_random_image(560, 560)
@@ -264,6 +291,8 @@ def test_multimodal_demo_text(
     _num_prefill_tokens = 0
     _num_decode_tokens = 0
 
+    prompt_encoder = hf_multimodal_encode if HF_MODEL else formatter.encode_dialog_prompt
+
     for iter_num in range(warmup_iters + 1):
         logger.info(f"Iteration {iter_num}")
         current_dialogs = trace_dialogs + dialogs
@@ -273,9 +302,15 @@ def test_multimodal_demo_text(
             for msg in dialog:
                 print(f"{msg.role.capitalize()}: {msg.content}\n")
         batch_model_input = [
-            formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) for dialog in batch_dialogs
+            prompt_encoder(dialog, processor) if HF_MODEL else prompt_encoder(dialog, tool_prompt_format=False)
+            for dialog in batch_dialogs
         ]
 
+        if HF_MODEL:
+            image_sizes = [model_input.image_sizes for model_input in batch_model_input]
+        else:
+            image_sizes = None
+
         # Do initial prefill
         vision_images = [
             model_input.vision.images if model_input.vision else None for model_input in batch_model_input
@@ -288,7 +323,7 @@ def test_multimodal_demo_text(
         total_lens = prefill_lens + max_gen_len
 
         # Create padded tokens tensor for batch
-        pad_id = tokenizer.pad_id
+        pad_id = tokenizer.pad_token_id if HF_MODEL else tokenizer.pad_id
         bsz = len(prompt_tokens)
         tokens = torch.full((bsz, max(total_lens)), pad_id, dtype=torch.long)
 
@@ -312,6 +347,7 @@ def test_multimodal_demo_text(
                 xattn_caches,
                 total_lens,
                 prefill_lens,
+                image_sizes=image_sizes,
             )
 
         # Get cached prefill time
@@ -323,12 +359,7 @@ def test_multimodal_demo_text(
             decode_batch_xattn_masks,
            decode_batch_text_masks,
        ) = generator.prefill_forward(
-            vision_images,
-            vision_mask,
-            tokens,
-            xattn_caches,
-            total_lens,
-            prefill_lens,
+            vision_images, vision_mask, tokens, xattn_caches, total_lens, prefill_lens, image_sizes=image_sizes
        )
 
         prefill_end = time.perf_counter()
@@ -375,12 +406,16 @@ def test_multimodal_demo_text(
         )  # gen_idx is (num_tokens - 1) to avoid counting compile iter
 
         # Log full text output for each user in batch
-        vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256]
+        if HF_MODEL:
+            # For HF models, get vision tokens from the processor if they exist
+            vision_tokens = []
+        else:
+            vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256]
 
         for user_id in range(max_batch_size):
             # Remove <|image|> tokens since they break the tokenizer
             tokens_out = [
-                t if t not in vision_tokens else tokenizer.pad_id
+                t if t not in vision_tokens else pad_id
                 for t in tokens[user_id].tolist()[: position_id[user_id] + 2]
             ]
             text = tokenizer.decode(tokens_out)
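
Note: the demo changes above make the prompt path checkpoint-aware. The sketch below condenses that dispatch into one place for readers; it is illustrative only, reuses the names from the diff (HF_MODEL, prompt_encoder, pad_id), and uses `dialog` as a stand-in for one entry of batch_dialogs.

    # Condensed sketch of the dual-tokenizer dispatch introduced in this commit.
    HF_MODEL = model_args[0].checkpoint_type == CheckpointType.HuggingFace

    if HF_MODEL:
        from transformers import AutoProcessor

        processor = AutoProcessor.from_pretrained(model_args[0].CKPT_DIR)  # HF checkpoint dir
        tokenizer = model_args[0].tokenizer
        prompt_encoder = hf_multimodal_encode  # new helper in tt/common.py
    else:
        tokenizer = Tokenizer(model_path=tokenizer_path)  # Meta-style tokenizer.model
        formatter = ChatFormat(tokenizer)
        prompt_encoder = formatter.encode_dialog_prompt

    # The two encoders take different arguments, so call sites branch as well:
    model_input = (
        prompt_encoder(dialog, processor)
        if HF_MODEL
        else prompt_encoder(dialog, tool_prompt_format=False)
    )

    # The padding-token attribute also differs between the two tokenizer APIs:
    pad_id = tokenizer.pad_token_id if HF_MODEL else tokenizer.pad_id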

models/experimental/mistral_24b/tests/test_conv2d.py renamed to models/tt_transformers/tests/multimodal/mistral_24b/test_conv2d.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 
 import ttnn
 from models.tt_transformers.tt.model_config import ModelArgs
-from models.experimental.mistral_24b.tt.vision_conv2d import TtMistralConv2dPatch
+from models.tt_transformers.tt.multimodal.mistral_24b.vision_conv2d import TtMistralConv2dPatch
 from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
 from ttnn import ConcatMeshToTensor
 

models/experimental/mistral_24b/tests/test_patch_rot_emb.py renamed to models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py

Lines changed: 6 additions & 6 deletions

@@ -1,17 +1,17 @@
 # SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
 # SPDX-License-Identifier: Apache-2.0
 
-from loguru import logger
+import os
 
-import torch
 import pytest
-import os
-import ttnn
+import torch
+from loguru import logger
 
-from models.experimental.mistral_24b.tt.vision_rope import VisionRotarySetup as RotarySetup
+import ttnn
 
-from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
+from models.tt_transformers.tt.multimodal.mistral_24b.vision_rope import VisionRotarySetup as RotarySetup
 from models.tt_transformers.tt.model_config import ModelArgs
+from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
 
 
 @torch.no_grad()

models/experimental/mistral_24b/tests/test_pixtral_transformer.py renamed to models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_transformer.py

Lines changed: 1 addition & 10 deletions

@@ -8,10 +8,8 @@
 from loguru import logger
 
 import ttnn
-from models.tt_transformers.tt.ccl import TT_CCL
-from models.tt_transformers.tt.model_config import ModelArgs
-
 from models.experimental.mistral_24b.tt.vision_pixtral_transformer import TtPixtralTransformer
+from models.tt_transformers.tt.model_config import ModelArgs
 from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
 
 
@@ -29,11 +27,6 @@
     ],
     indirect=True,
 )
-@pytest.mark.parametrize(
-    "device_params",
-    [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
-    indirect=True,
-)
 def test_image_transformer_inference(batch, num_chunks, mesh_device):
     pcc_required = 0.99
 
@@ -58,10 +51,8 @@ def test_image_transformer_inference(batch, num_chunks, mesh_device):
 
     all_tests_pass = True
 
-    tt_ccl = TT_CCL(mesh_device)
     tt_model = TtPixtralTransformer(
         mesh_device,
-        tt_ccl,
         state_dict,
         state_dict_prefix=first_layer_prefix,
         weight_cache_path=None,

models/experimental/mistral_24b/tests/test_vision_attention.py renamed to models/tt_transformers/tests/multimodal/mistral_24b/test_vision_attention.py

Lines changed: 1 addition & 11 deletions

@@ -8,12 +8,9 @@
 from loguru import logger
 
 import ttnn
-from models.tt_transformers.tt.ccl import TT_CCL
+from models.experimental.mistral_24b.tt.vision_attention import TtMistralImageAttention as TtLlamaImageAttention
 from models.tt_transformers.tt.model_config import ModelArgs
 from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
-
-from models.experimental.mistral_24b.tt.vision_attention import TtMistralImageAttention as TtLlamaImageAttention
-
 from ttnn import ConcatMeshToTensor
 
 
@@ -36,11 +33,6 @@
     "batch_size",
     (1,),
 )
-@pytest.mark.parametrize(
-    "device_params",
-    [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
-    indirect=True,
-)
 def test_vision_attention(mesh_device, seq_len, batch_size):
     logger.info(f"seq_len: {seq_len}, batch_size: {batch_size}")
     dtype = ttnn.bfloat16
@@ -61,10 +53,8 @@ def test_vision_attention(mesh_device, seq_len, batch_size):
     n_heads = model_args.vision_attn_n_heads
     head_dim = hidden_size // n_heads
 
-    tt_ccl = TT_CCL(mesh_device)
     tt_model = TtLlamaImageAttention(
         mesh_device,
-        tt_ccl,
         state_dict,
         state_dict_prefix=first_layer_prefix,
         weight_cache_path=model_args.weight_cache_path(dtype),

models/experimental/mistral_24b/tests/test_vision_rms.py renamed to models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py

Lines changed: 5 additions & 11 deletions

@@ -1,19 +1,13 @@
-# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
-
-# SPDX-License-Identifier: Apache-2.0
-
-from loguru import logger
+import os
 
-import torch
 import pytest
-import os
+import torch
+from loguru import logger
 
 import ttnn
-from models.experimental.mistral_24b.tt.rmsnorm import RMSNorm
-
-from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
-
+from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm
 from models.tt_transformers.tt.model_config import ModelArgs
+from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
 
 
 @torch.no_grad()

models/tt_transformers/tt/common.py

Lines changed: 45 additions & 0 deletions

@@ -5,9 +5,11 @@
 import math
 import re
 from enum import Enum
+from types import SimpleNamespace
 from typing import Optional
 
 import torch
+from llama_models.llama3.api.datatypes import ImageMedia
 from loguru import logger
 from pydantic import AliasChoices, BaseModel, Field
 
@@ -688,3 +690,46 @@ def create_tt_model(
     tt_kv_cache = [l.attention.layer_past for l in model.layers] if paged_attention_config else None
 
     return tt_model_args, model, tt_kv_cache, state_dict
+
+
+def hf_multimodal_encode(messages, processor):
+    hf_messages = []
+
+    for msg in messages:
+        hf_content = []
+
+        for item in msg.content:
+            if isinstance(item, ImageMedia):
+                hf_content.append(
+                    {
+                        "type": "image",
+                        "image": item.image,
+                    }
+                )
+            elif isinstance(item, str):
+                hf_content.append(
+                    {
+                        "type": "text",
+                        "text": item,
+                    }
+                )
+
+        hf_messages.append(
+            {
+                "role": msg.role,
+                "content": hf_content,
+            }
+        )
+
+    encoded = processor.apply_chat_template(
+        hf_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
+    ).to("cpu", dtype=torch.bfloat16)
+
+    return SimpleNamespace(
+        **encoded,
+        tokens=encoded["input_ids"].squeeze(0),
+        vision=SimpleNamespace(
+            images=encoded["pixel_values"],
+            mask=None,
+        ),
+    )
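
Note: a minimal usage sketch of the new helper, assuming a HuggingFace-style checkpoint directory (the path below is a placeholder) and the llama_models Message/ImageMedia types this demo already depends on; treat it as illustrative rather than canonical.

    # Illustrative only: encode one image+text turn for a HF multimodal checkpoint.
    from PIL import Image
    from llama_models.llama3.api.datatypes import ImageMedia, UserMessage
    from transformers import AutoProcessor

    from models.tt_transformers.tt.common import hf_multimodal_encode

    processor = AutoProcessor.from_pretrained("/path/to/Mistral-Small-3.1-24B")  # placeholder path
    image = Image.new("RGB", (560, 560))  # stand-in image
    dialog = [UserMessage(content=[ImageMedia(image=image), "Describe this image."])]

    model_input = hf_multimodal_encode(dialog, processor)
    tokens = model_input.tokens          # 1-D tensor of ids from the chat template
    images = model_input.vision.images   # the processor's pixel_values
    # When the processor also returns image_sizes, the demo above forwards it
    # via generator.prefill_forward(..., image_sizes=image_sizes).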
