Draft
Changes from all commits
30 commits
ade16a2
WIP LoadCheckpoints Mistral 27B
Jul 3, 2025
c9488f2
Setup for Mistral 27b
Jul 3, 2025
1d61002
Setup for Mistral 24b
Jul 3, 2025
96d7e11
MLP Support added
Jul 3, 2025
397a105
RMSNorm and Patch Conv completed
Jul 4, 2025
cbe3a9d
Refactor Conv2D Patch
Jul 4, 2025
7929da4
WIP PixtralRotaryEmbedding
Jul 4, 2025
82039a0
Add vision attn test
nikileshx Jul 4, 2025
c58dd29
[WIP] Add RoPE tests
nikileshx Jul 4, 2025
465eb91
Complete vision RoPE and tests
nikileshx Jul 7, 2025
6b26dbd
Construct vision block pipeline
nikileshx Jul 7, 2025
5eeefca
Integrate vision block and debug rot_emb
nikileshx Jul 8, 2025
a61df3e
Integrate attn to pipeline
nikileshx Jul 8, 2025
df2403b
[WIP] Fix the VisionAttn and Enable vision tower: 0.75
nikileshx Jul 11, 2025
c2c0347
Refactor the vision_transformer tests
nikileshx Jul 16, 2025
68aaa3d
Add MMP and integrate vision tower
nikileshx Jul 19, 2025
ff87343
Debug the PixtralVisionModel
nikileshx Jul 22, 2025
a9d78d9
Add E2E test and handle modules
nikileshx Jul 22, 2025
08c6a6f
Add vision_model and MMP together
nikileshx Jul 22, 2025
b27f255
Enable E2E pipeline
nikileshx Jul 25, 2025
488dea0
Fix E2E pipeline prefill generation
nikileshx Jul 30, 2025
3201373
Refactor E2E
nikileshx Aug 11, 2025
9c16b20
Rebase Mistral-24b branch to align latest load_checkpoints
nikileshx Aug 11, 2025
4c7207b
mcw/dev_mistral-3.1-24b-instruct_branch
nikileshx Aug 13, 2025
4c0ccb4
Migrate mistral-24B to tt-transformers
MohammedTaherMcW Aug 14, 2025
8836378
mistral migration pr latest change
GanesanMulticoreware Aug 15, 2025
830beea
refactor the test_script
GanesanMulticoreware Aug 16, 2025
44e1f06
Fix: updated vision demo and conv2d tests for mistral_24B migration
GanesanMulticoreware Aug 18, 2025
be48e24
code cleanup
GanesanMulticoreware Aug 18, 2025
7fc99e8
add end to end test script
GanesanMulticoreware Aug 18, 2025
70 changes: 51 additions & 19 deletions models/tt_transformers/demo/simple_vision_demo.py
@@ -27,7 +27,9 @@
import ttnn
from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf
from models.perf.benchmarking_utils import BenchmarkProfiler
from models.tt_transformers.tt.common import hf_multimodal_encode
from models.tt_transformers.tt.generator import Generator
from models.tt_transformers.tt.model_config import CheckpointType


def get_batch_sampler(temperature, top_p, tokenizer):
@@ -62,6 +64,7 @@ def create_multimodal_model(
):
from models.tt_transformers.tt.model_config import ModelArgs
from models.tt_transformers.tt.multimodal.llama_vision_model import CrossAttentionTransformer
from models.tt_transformers.tt.multimodal.mistral_24b.model import MistralTransformer

tt_model_args = ModelArgs(mesh_device, max_batch_size=max_batch_size)
assert tt_model_args.is_vision(), "This model is multimodal"
@@ -76,14 +79,25 @@

if checkpoint is None:
checkpoint = tt_model_args.load_state_dict()
model = CrossAttentionTransformer(
mesh_device,
state_dict=checkpoint,
weight_cache_path=tt_model_args.weight_cache_path(dtype),
dtype=dtype,
configuration=tt_model_args,
use_paged_kv_cache=use_paged_kv_cache,
)

if tt_model_args.base_model_name == "Mistral-Small-3.1-24B":
model = MistralTransformer(
mesh_device=mesh_device,
state_dict=checkpoint,
weight_cache_path=tt_model_args.weight_cache_path(ttnn.bfloat8_b),
dtype=ttnn.bfloat8_b,
args=tt_model_args,
use_paged_kv_cache=use_paged_kv_cache,
)
else:
model = CrossAttentionTransformer(
mesh_device,
state_dict=checkpoint,
weight_cache_path=tt_model_args.weight_cache_path(dtype),
dtype=dtype,
configuration=tt_model_args,
use_paged_kv_cache=use_paged_kv_cache,
)
return tt_model_args, model, checkpoint
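Reviewer note: a minimal sketch (not part of the diff) of the dispatch this hunk adds, using the same names imported above; a lookup table would keep the selection in one place if more HF vision models follow:

VISION_MODEL_CLASSES = {
    "Mistral-Small-3.1-24B": MistralTransformer,  # new path added by this PR
}
model_cls = VISION_MODEL_CLASSES.get(tt_model_args.base_model_name, CrossAttentionTransformer)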


@@ -136,7 +150,7 @@ def prepare_generator_args(
)
@pytest.mark.parametrize(
"test_type,max_seq_len",
(("normal", 512),),
(("normal", 2048),),
ids=["normal"],
)
@pytest.mark.parametrize(
@@ -182,9 +196,6 @@ def test_multimodal_demo_text(
profiler = BenchmarkProfiler()
profiler.start("run")

ckpt_dir = os.environ["LLAMA_DIR"]
tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")

num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1
max_batch_size *= data_parallel # input batch_size is interpreted as size per DP group

@@ -195,11 +206,26 @@
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
)

HF_MODEL = model_args[0].checkpoint_type == CheckpointType.HuggingFace

if not HF_MODEL:
ckpt_dir = os.environ["LLAMA_DIR"]
tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")

tokenizer = Tokenizer(model_path=tokenizer_path)
formatter = ChatFormat(tokenizer)
else:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(model_args[0].CKPT_DIR)

generator = Generator(model, model_args, mesh_device)
tokenizer = Tokenizer(model_path=tokenizer_path)
formatter = ChatFormat(tokenizer)

xattn_caches = [model.setup_cache(model_args[i].max_batch_size) for i, model in enumerate(generator.model)]
xattn_caches = [
model.setup_cache(model_args[i].max_batch_size) if not HF_MODEL else None
for i, model in enumerate(generator.model)
]

# Create random images for trace capture with specific dimensions
trace_img_560x560 = create_random_image(560, 560)
@@ -260,10 +286,12 @@ def test_multimodal_demo_text(
total_users = len(dialogs)
num_batches = total_users // max_batch_size

sampler = get_batch_sampler(temperature, top_p, tokenizer)
sampler = get_batch_sampler(temperature, top_p, model_args[0].tokenizer)
_num_prefill_tokens = 0
_num_decode_tokens = 0

prompt_encoder = hf_multimodal_encode if HF_MODEL else formatter.encode_dialog_prompt

for iter_num in range(warmup_iters + 1):
logger.info(f"Iteration {iter_num}")
current_dialogs = trace_dialogs + dialogs
@@ -273,7 +301,8 @@
for msg in dialog:
print(f"{msg.role.capitalize()}: {msg.content}\n")
batch_model_input = [
formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) for dialog in batch_dialogs
prompt_encoder(dialog, processor) if HF_MODEL else prompt_encoder(dialog, tool_prompt_format=False)
for dialog in batch_dialogs
]
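Reviewer note: the two encoders take different arguments — hf_multimodal_encode(dialog, processor) versus formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) — which is why the branch sits inside the comprehension. One way to hide that difference, using only names already present in this file (illustrative sketch, not part of the diff):

if HF_MODEL:
    encode = lambda dialog: hf_multimodal_encode(dialog, processor)
else:
    encode = lambda dialog: formatter.encode_dialog_prompt(dialog, tool_prompt_format=False)
batch_model_input = [encode(dialog) for dialog in batch_dialogs]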

# Do initial prefill
@@ -288,7 +317,7 @@
total_lens = prefill_lens + max_gen_len

# Create padded tokens tensor for batch
pad_id = tokenizer.pad_id
pad_id = tokenizer.pad_token_id if HF_MODEL else tokenizer.pad_id
bsz = len(prompt_tokens)
tokens = torch.full((bsz, max(total_lens)), pad_id, dtype=torch.long)
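Reviewer note: some HuggingFace tokenizers ship without a pad token, so tokenizer.pad_token_id can be None on the HF path; a hedged guard (an assumption, not something this change requires) would be:

pad_id = tokenizer.pad_token_id if HF_MODEL else tokenizer.pad_id
if pad_id is None:  # assumption: fall back to EOS when no pad token is defined
    pad_id = tokenizer.eos_token_id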

@@ -312,6 +341,7 @@
xattn_caches,
total_lens,
prefill_lens,

)

# Get cached prefill time
@@ -329,6 +359,7 @@
xattn_caches,
total_lens,
prefill_lens,

)

prefill_end = time.perf_counter()
Expand Down Expand Up @@ -375,12 +406,13 @@ def test_multimodal_demo_text(
) # gen_idx is (num_tokens - 1) to avoid counting compile iter

# Log full text output for each user in batch

vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256]

for user_id in range(max_batch_size):
# Remove <|image|> tokens since they break the tokenizer
tokens_out = [
t if t not in vision_tokens else tokenizer.pad_id
t if t not in vision_tokens else pad_id
for t in tokens[user_id].tolist()[: position_id[user_id] + 2]
]
text = tokenizer.decode(tokens_out)