85 changes: 60 additions & 25 deletions models/tt_transformers/demo/simple_vision_demo.py
@@ -27,7 +27,9 @@
import ttnn
from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf
from models.perf.benchmarking_utils import BenchmarkProfiler
from models.tt_transformers.tt.common import hf_multimodal_encode
from models.tt_transformers.tt.generator import Generator
from models.tt_transformers.tt.model_config import CheckpointType


def get_batch_sampler(temperature, top_p, tokenizer):
@@ -62,6 +64,7 @@ def create_multimodal_model(
):
from models.tt_transformers.tt.model_config import ModelArgs
from models.tt_transformers.tt.multimodal.llama_vision_model import CrossAttentionTransformer
from models.tt_transformers.tt.multimodal.mistral_24b.mistral_e2e_model import MistralTransformer

tt_model_args = ModelArgs(mesh_device, max_batch_size=max_batch_size)
assert tt_model_args.is_vision(), "This model is multimodal"
@@ -76,14 +79,25 @@

if checkpoint is None:
checkpoint = tt_model_args.load_state_dict()
model = CrossAttentionTransformer(
mesh_device,
state_dict=checkpoint,
weight_cache_path=tt_model_args.weight_cache_path(dtype),
dtype=dtype,
configuration=tt_model_args,
use_paged_kv_cache=use_paged_kv_cache,
)

if tt_model_args.base_model_name == "Mistral-Small-3.1-24B":
model = MistralTransformer(
mesh_device=mesh_device,
state_dict=checkpoint,
weight_cache_path=tt_model_args.weight_cache_path(ttnn.bfloat8_b),
dtype=ttnn.bfloat8_b,
args=tt_model_args,
use_paged_kv_cache=use_paged_kv_cache,
)
else:
model = CrossAttentionTransformer(
mesh_device,
state_dict=checkpoint,
weight_cache_path=tt_model_args.weight_cache_path(dtype),
dtype=dtype,
configuration=tt_model_args,
use_paged_kv_cache=use_paged_kv_cache,
)
return tt_model_args, model, checkpoint
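For orientation, a hedged usage sketch (not part of the diff; keyword names beyond mesh_device, max_batch_size and dtype are illustrative and may not match the real signature):

model_args, model, checkpoint = create_multimodal_model(
    mesh_device,
    max_batch_size=1,
    max_seq_len=2048,      # assumed keyword, mirroring the demo's max_seq_len parametrization
    dtype=ttnn.bfloat8_b,
)
# For a Mistral-Small-3.1-24B checkpoint this now returns a MistralTransformer;
# every other vision checkpoint still gets a CrossAttentionTransformer, so call
# sites do not need to branch on model family.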


@@ -136,7 +150,7 @@ def prepare_generator_args(
)
@pytest.mark.parametrize(
"test_type,max_seq_len",
(("normal", 512),),
(("normal", 2048),),
ids=["normal"],
)
@pytest.mark.parametrize(
@@ -182,9 +196,6 @@ def test_multimodal_demo_text(
profiler = BenchmarkProfiler()
profiler.start("run")

ckpt_dir = os.environ["LLAMA_DIR"]
tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")

num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1
max_batch_size *= data_parallel # input batch_size is interpreted as size per DP group

@@ -195,11 +206,27 @@
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
)

HF_MODEL = model_args[0].checkpoint_type == CheckpointType.HuggingFace

if not HF_MODEL:
ckpt_dir = os.environ["LLAMA_DIR"]
tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")

tokenizer = Tokenizer(model_path=tokenizer_path)
formatter = ChatFormat(tokenizer)
else:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(model_args[0].CKPT_DIR)
tokenizer = model_args[0].tokenizer

generator = Generator(model, model_args, mesh_device)
tokenizer = Tokenizer(model_path=tokenizer_path)
formatter = ChatFormat(tokenizer)

xattn_caches = [model.setup_cache(model_args[i].max_batch_size) for i, model in enumerate(generator.model)]
xattn_caches = [
model.setup_cache(model_args[i].max_batch_size) if not HF_MODEL else None
for i, model in enumerate(generator.model)
]

# Create random images for trace capture with specific dimensions
trace_img_560x560 = create_random_image(560, 560)
@@ -264,6 +291,8 @@ def test_multimodal_demo_text(
_num_prefill_tokens = 0
_num_decode_tokens = 0

prompt_encoder = hf_multimodal_encode if HF_MODEL else formatter.encode_dialog_prompt

for iter_num in range(warmup_iters + 1):
logger.info(f"Iteration {iter_num}")
current_dialogs = trace_dialogs + dialogs
@@ -273,9 +302,15 @@
for msg in dialog:
print(f"{msg.role.capitalize()}: {msg.content}\n")
batch_model_input = [
formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) for dialog in batch_dialogs
prompt_encoder(dialog, processor) if HF_MODEL else prompt_encoder(dialog, tool_prompt_format=False)
for dialog in batch_dialogs
]

if HF_MODEL:
image_sizes = [model_input.image_sizes for model_input in batch_model_input]
else:
image_sizes = None

# Do initial prefill
vision_images = [
model_input.vision.images if model_input.vision else None for model_input in batch_model_input
@@ -288,7 +323,7 @@
total_lens = prefill_lens + max_gen_len

# Create padded tokens tensor for batch
pad_id = tokenizer.pad_id
pad_id = tokenizer.pad_token_id if HF_MODEL else tokenizer.pad_id
bsz = len(prompt_tokens)
tokens = torch.full((bsz, max(total_lens)), pad_id, dtype=torch.long)
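A hedged worked example of the padding above (made-up lengths, pad_id taken as 0; not part of the diff):

import torch

prefill_lens = torch.tensor([5, 9])                                    # two users' prompt lengths
total_lens = prefill_lens + 3                                          # max_gen_len = 3 -> tensor([ 8, 12])
tokens = torch.full((2, int(total_lens.max())), 0, dtype=torch.long)  # (2, 12) filled with pad_id
tokens[0, :5] = torch.arange(1, 6)                                     # each user's prompt overwrites its left slice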

@@ -312,6 +347,7 @@
xattn_caches,
total_lens,
prefill_lens,
image_sizes=image_sizes,
)

# Get cached prefill time
@@ -323,12 +359,7 @@
decode_batch_xattn_masks,
decode_batch_text_masks,
) = generator.prefill_forward(
vision_images,
vision_mask,
tokens,
xattn_caches,
total_lens,
prefill_lens,
vision_images, vision_mask, tokens, xattn_caches, total_lens, prefill_lens, image_sizes=image_sizes
)

prefill_end = time.perf_counter()
@@ -375,12 +406,16 @@ def test_multimodal_demo_text(
) # gen_idx is (num_tokens - 1) to avoid counting compile iter

# Log full text output for each user in batch
vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256]
if HF_MODEL:
# For HF models, get vision tokens from the processor if they exist
vision_tokens = []
else:
vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256]

for user_id in range(max_batch_size):
# Remove <|image|> tokens since they break the tokenizer
tokens_out = [
t if t not in vision_tokens else tokenizer.pad_id
t if t not in vision_tokens else pad_id
for t in tokens[user_id].tolist()[: position_id[user_id] + 2]
]
text = tokenizer.decode(tokens_out)
102 changes: 102 additions & 0 deletions models/tt_transformers/tests/multimodal/mistral_24b/test_conv2d.py
@@ -0,0 +1,102 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

"""
This file is a unit test validating the Mistral-24B vision tower's Conv2d patch-embedding layer (TtMistralConv2dPatch).
"""
import os

import pytest
import torch
from loguru import logger

import ttnn
from models.tt_transformers.tt.model_config import ModelArgs
from models.tt_transformers.tt.multimodal.mistral_24b.vision_conv2d import TtMistralConv2dPatch
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
from ttnn import ConcatMeshToTensor


@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
"mesh_device",
[
{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
)
],
indirect=True,
)
def test_conv2d_inference(
mesh_device,
reset_seeds,
):
pcc_required = 0.9999
dtype = ttnn.bfloat16

model_args = ModelArgs(mesh_device)
state_dict = model_args.load_state_dict()

# Ref model needs partial state dict, but our models use full state dict keys as cached weight names
first_layer_prefix = "vision_tower.patch_conv."
partial_state_dict = {
k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix))
}
num_devices = model_args.num_devices

##### Create input tensor for the all gather #####
B, NCH, H, W = (1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size)
in_channels, out_channels, kernel_size, stride, bias = (
3,
model_args.vision_dim,
model_args.vision_patch_size,
model_args.vision_patch_size,
False,
)

assert NCH == in_channels, "Number of channels in input tensor should match in_channels for the Conv2d patch."
assert type(kernel_size) == int, "Only symmetric kernel_size is currently supported."
assert kernel_size == stride, "Only same kernel_size and stride are currently supported."

assert H % kernel_size == 0, "Height should be divisible by kernel_size."
assert W % kernel_size == 0, "Width should be divisible by kernel_size."

##### Prepare inputs #####
input_tensor = torch.randn((B, NCH, H, W)).to(dtype=torch.bfloat16)
logger.info(f"Input tensor shape: {input_tensor.shape}")

reference_model = model_args.reference_conv2d_patch()
reference_model.load_state_dict(partial_state_dict)
reference_output = reference_model(input_tensor)

tt_model = TtMistralConv2dPatch(
mesh_device,
state_dict,
first_layer_prefix,
dtype,
in_channels,
out_channels,
kernel_size,
stride,
bias,
)
tt_output = tt_model(input_tensor)

##### Check the outputs #####
out = ttnn.from_device(tt_output)
tt_output_torch = ttnn.to_torch(out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=2))

# Only select output from one device
tt_output_torch = tt_output_torch[0, ..., :out_channels]

# 1. Restore batch dim: (num_patches, C_out) -> (1, num_patches, C_out)
tt_output_torch = tt_output_torch.unsqueeze(0)
# 2. Permute and reshape to match the Conv2d output layout (N, C_out, H_out, W_out):
#    (1, 4096, 1024) -> (1, 1024, 4096) -> (1, 1024, 64, 64)
tt_output_torch = tt_output_torch.permute(0, 2, 1).reshape(1, 1024, 64, 64)

passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required)

logger.info(comp_allclose(reference_output, tt_output_torch))
logger.info(f"PCC: {pcc_message}")
assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
@@ -0,0 +1,93 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

import os

import pytest
import torch
from loguru import logger

import ttnn
from models.tt_transformers.tt.model_config import ModelArgs
from models.tt_transformers.tt.multimodal.mistral_24b.vision_rope import VisionRotarySetup as RotarySetup
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull


@torch.no_grad()
@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
"device",
[
{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
os.environ.get("device"), len(ttnn.get_device_ids())
)
],
indirect=True,
)
@pytest.mark.parametrize(
"seq_len",
(128,),
)
@pytest.mark.parametrize(
"batch_size",
(1,),
)
def test_rot_emb(seq_len, batch_size, reset_seeds, device):
dtype = ttnn.bfloat16
mode = "decode" if seq_len <= 32 else "prefill"

tt_model_args = ModelArgs(
device,
max_batch_size=batch_size,
max_seq_len=128,
)

tt_model_args.n_layers = 1
partial_state_dict = {}

reference_model = tt_model_args.reference_vision_rot_emb()
reference_model.load_state_dict(partial_state_dict)

image_size = tt_model_args.vision_image_size
patch_size = tt_model_args.vision_patch_size
dim = tt_model_args.vision_head_dim
num_patches_per_dim = image_size // patch_size
num_patches = num_patches_per_dim * num_patches_per_dim
position_ids = torch.arange(4096, dtype=torch.long)

x = torch.randn(batch_size, 4096, 1024)

cos, sin = reference_model(x, position_ids)
tt_model = RotarySetup(
device,
batch_size,
dim,
image_size,
patch_size,
num_patches,
tt_model_args.vision_rope_theta,
scale_factor=None,
orig_context_len=num_patches,
datatype=dtype,
)

cos2, sin2 = tt_model.get_rot_mats(position_ids)
cos2 = ttnn.from_device(cos2)
cos2 = ttnn.to_torch(cos2)
cos2 = cos2.squeeze(0)

sin2 = ttnn.from_device(sin2)
sin2 = ttnn.to_torch(sin2)
sin2 = sin2.squeeze(0)

passing, pcc_message = comp_pcc(cos, cos2)

logger.info(comp_allclose(cos, cos2))
logger.info(f"PCC: {pcc_message}")
assert passing, f"COS PCC value is lower than {0.99} for some of the outputs. Check Warnings!"

passing, pcc_message = comp_pcc(sin, sin2)

logger.info(comp_allclose(sin, sin2))
logger.info(f"PCC: {pcc_message}")
assert passing, f"SIN PCC value is lower than {0.99} for some of the outputs. Check Warnings!"