Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
530 changes: 530 additions & 0 deletions models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

"""
This file is a unit test for validating the Mistral-24B Vision Model pipeline.
"""

import os
import pytest
import torch
from loguru import logger

import ttnn
from models.tt_transformers.tt.ccl import TT_CCL
from models.tt_transformers.tt.model_config import ModelArgs
from models.experimental.mistral_24b.tt.pipeline.vision_model import TtMistralVisionTransformer
from models.common.utility_functions import comp_allclose, comp_pcc


def get_image_features(vision_tower, projector, input_tensor, image_sizes):
    """
    Run the vision tower and project its hidden states to image features.

    Args:
        vision_tower: Callable vision backbone; its result exposes `last_hidden_state`.
        projector: Multi-modal projector applied to the squeezed hidden states.
        input_tensor: Pixel input handed to the vision tower.
        image_sizes: List of (height, width) tuples, one per image.

    Returns:
        The projected image features.
    """
    hidden_states = vision_tower(input_tensor, image_sizes).last_hidden_state
    return projector(hidden_states.squeeze(0), image_sizes)


@pytest.mark.parametrize(
    "mesh_device",
    [
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    "device_params",
    [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
    indirect=True,
)
def test_mistral_vision_model(mesh_device, reset_seeds):
    """
    End-to-end check of the TT vision transformer + projector against the
    Torch reference; outputs must agree to a PCC of at least 0.97.
    """
    pcc_required = 0.97
    dtype = ttnn.bfloat8_b

    model_args = ModelArgs(mesh_device)
    state_dict = model_args.load_state_dict()

    # The Torch reference expects weight keys without the "vision_tower." prefix.
    vision_prefix = "vision_tower."
    vision_state = {k[len(vision_prefix) :]: v for k, v in state_dict.items() if k.startswith(vision_prefix)}

    ##### Reference model output (Torch) #####
    reference_model = model_args.reference_vision_model()
    reference_model.load_state_dict(vision_state)

    # Likewise for the multi-modal projector weights.
    mmp_prefix = "multi_modal_projector."
    mmp_state = {k[len(mmp_prefix) :]: v for k, v in state_dict.items() if k.startswith(mmp_prefix)}

    reference_mmp = model_args.reference_vision_multi_modal()
    reference_mmp.load_state_dict(mmp_state)

    # Single square image at the model's vision chunk size.
    B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size
    input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16)

    reference_output = get_image_features(reference_model, reference_mmp, input_tensor, image_sizes=[(H, W)])

    ##### TT Model: TtMistralVisionTransformer #####
    tt_ccl = TT_CCL(mesh_device=mesh_device)
    vision_model = TtMistralVisionTransformer(
        mesh_device=mesh_device,
        tt_ccl=tt_ccl,
        state_dict=state_dict,
        state_dict_prefix=vision_prefix,
        dtype=dtype,
        model_args=model_args,
    )

    tt_output = vision_model(input_tensor, image_sizes=[(H, W)])
    # Concatenate per-device shards on the last dim, then trim back to the
    # pre-composition tensor's width (the slice bound is read before rebinding).
    tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[
        :, : tt_output.shape[-1]
    ]

    # Compare only positions the TT model actually populated.
    non_zero_indices = tt_output.ne(0).nonzero(as_tuple=True)
    tt_output = tt_output[non_zero_indices]
    reference_output = reference_output[non_zero_indices]

    passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required)

    logger.info(comp_allclose(reference_output, tt_output))
    logger.info(f"PCC: {pcc_message}")
    assert passing, f"PCC below {pcc_required}. {pcc_message}"
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

"""
This file is a unit test for validating the Mistral-24B Vision Tower model.
"""

import os
import pytest
import torch
from loguru import logger

import ttnn
from models.tt_transformers.tt.ccl import TT_CCL
from models.tt_transformers.tt.model_config import ModelArgs
from models.experimental.mistral_24b.tt.pipeline.mistral_vision_tower import MistralVisionTower
from models.common.utility_functions import comp_allclose, comp_pcc


@pytest.mark.parametrize(
    "mesh_device",
    [
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    "device_params",
    [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
    indirect=True,
)
def test_mistral_vision_tower(mesh_device, reset_seeds):
    """
    Compare the TT MistralVisionTower against the Torch reference vision model.

    Both consume the same random bfloat16 image; their last hidden states must
    agree to a PCC of at least 0.99.
    """
    pcc_required = 0.99
    dtype = ttnn.bfloat16

    model_args = ModelArgs(mesh_device)
    state_dict = model_args.load_state_dict()

    # The Torch reference expects weight keys without the "vision_tower." prefix.
    first_layer_prefix = "vision_tower."
    partial_state_dict = {
        k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix)
    }

    # Single square image at the model's vision chunk size.
    B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size
    input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16)

    ##### Reference model output (Torch) #####
    reference_model = model_args.reference_vision_model()
    reference_model.load_state_dict(partial_state_dict)
    reference_output = reference_model(input_tensor, image_sizes=[(H, W)]).last_hidden_state

    ##### TT Model: MistralVisionTower #####
    tt_ccl = TT_CCL(mesh_device)
    vision_model = MistralVisionTower(
        mesh_device=mesh_device,
        tt_ccl=tt_ccl,
        state_dict=state_dict,
        state_dict_prefix=first_layer_prefix,
        dtype=dtype,
        configuration=model_args,
    )

    tt_output = vision_model(input_tensor, image_sizes=[(H, W)])
    # Concatenate per-device shards on the last dim, then trim padding back to
    # the width of the pre-composition tensor.
    tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[
        :, :, :, : tt_output.shape[-1]
    ]
    tt_output = tt_output.squeeze(0)

    passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required)

    logger.info(comp_allclose(reference_output, tt_output))
    logger.info(f"PCC: {pcc_message}")
    assert passing, f"PCC below {pcc_required}. {pcc_message}"
101 changes: 101 additions & 0 deletions models/experimental/mistral_24b/tests/test_conv2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

"""
This file is a unit test for validating the Mistral-24B conv2d.
"""
import os

import pytest
import torch
from loguru import logger

import ttnn
from models.tt_transformers.tt.model_config import ModelArgs
from models.experimental.mistral_24b.tt.vision_conv2d import TtMistralConv2dPatch
from models.common.utility_functions import comp_allclose, comp_pcc
from ttnn import ConcatMeshToTensor


@pytest.mark.parametrize(
    "mesh_device",
    [
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
def test_conv2d_inference(
    mesh_device,
    reset_seeds,
):
    """
    Validate the TT conv2d patch-embedding against the Torch reference conv.

    A random bfloat16 image is run through both implementations; outputs must
    match to a PCC of at least 0.9999.
    """
    pcc_required = 0.9999
    dtype = ttnn.bfloat16

    model_args = ModelArgs(mesh_device)
    state_dict = model_args.load_state_dict()

    # Ref model needs partial state dict, but our models use full state dict keys as cached weight names
    first_layer_prefix = "vision_tower.patch_conv."
    partial_state_dict = {
        k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix)
    }

    ##### Create input tensor for the all gather #####
    B, NCH, H, W = (1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size)
    in_channels, out_channels, kernel_size, stride, bias = (
        3,
        model_args.vision_dim,
        model_args.vision_patch_size,
        model_args.vision_patch_size,
        False,
    )

    assert NCH == in_channels, "Number of channels in input tensor should match in_channels for the Conv2d patch."
    assert isinstance(kernel_size, int), "Only symmetric kernel_size is currently supported."
    assert kernel_size == stride, "Only same kernel_size and stride are currently supported."
    assert H % kernel_size == 0, "Height should be divisible by kernel_size."
    assert W % kernel_size == 0, "Width should be divisible by kernel_size."

    ##### Prepare inputs #####
    input_tensor = torch.randn((B, NCH, H, W)).to(dtype=torch.bfloat16)
    logger.info(f"Input tensor shape: {input_tensor.shape}")

    reference_model = model_args.reference_conv2d_patch()
    reference_model.load_state_dict(partial_state_dict)
    reference_output = reference_model(input_tensor)

    tt_model = TtMistralConv2dPatch(
        mesh_device,
        state_dict,
        first_layer_prefix,
        dtype,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        bias,
    )
    tt_output = tt_model(input_tensor)

    ##### Check the outputs #####
    out = ttnn.from_device(tt_output)
    tt_output_torch = ttnn.to_torch(out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=2))

    # Only select output from one device
    tt_output_torch = tt_output_torch[0, ..., :out_channels]

    # Restore the batch dim, then permute/reshape to the Conv2D output layout
    # (N, C_out, H_out, W_out). Spatial dims are derived from the patch size
    # instead of being hard-coded, so other chunk/patch sizes also work.
    tt_output_torch = tt_output_torch.unsqueeze(0)
    tt_output_torch = tt_output_torch.permute(0, 2, 1).reshape(B, out_channels, H // kernel_size, W // kernel_size)

    # Pass the declared threshold explicitly; previously comp_pcc silently used
    # its default, so the assert message did not reflect the checked bound.
    passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required)

    logger.info(comp_allclose(reference_output, tt_output_torch))
    logger.info(f"PCC: {pcc_message}")
    assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
93 changes: 93 additions & 0 deletions models/experimental/mistral_24b/tests/test_patch_rot_emb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

from loguru import logger

import torch
import pytest
import os
import ttnn

from models.experimental.mistral_24b.tt.vision_rope import VisionRotarySetup as RotarySetup

from models.common.utility_functions import comp_allclose, comp_pcc
from models.tt_transformers.tt.model_config import ModelArgs


@torch.no_grad()
@pytest.mark.parametrize(
    "device",
    [
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("device"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    "seq_len",
    (128,),
)
@pytest.mark.parametrize(
    "batch_size",
    (1,),
)
def test_rot_emb(seq_len, batch_size, reset_seeds, device):
    """
    Compare the TT vision rotary-embedding setup against the Torch reference.

    The cos/sin tables from both implementations must match to PCC >= 0.99.
    """
    # NOTE(review): sibling tests read os.environ["MESH_DEVICE"], but this
    # file's parametrize reads os.environ["device"] — confirm which variable
    # CI actually sets before relying on the mesh-shape mapping above.
    pcc_required = 0.99
    dtype = ttnn.bfloat16

    tt_model_args = ModelArgs(
        device,
        max_batch_size=batch_size,
        max_seq_len=128,
    )

    tt_model_args.n_layers = 1

    # The rotary-embedding reference has no learnable weights, so an empty
    # state dict is sufficient to initialize it.
    reference_model = tt_model_args.reference_vision_rot_emb()
    reference_model.load_state_dict({})

    image_size = tt_model_args.vision_image_size
    patch_size = tt_model_args.vision_patch_size
    dim = tt_model_args.vision_head_dim
    num_patches_per_dim = image_size // patch_size
    num_patches = num_patches_per_dim * num_patches_per_dim

    # assumes 4096 positions and hidden dim 1024 — TODO: derive these from
    # num_patches / model args instead of hard-coding.
    position_ids = torch.arange(4096, dtype=torch.long)
    x = torch.randn(batch_size, 4096, 1024)

    cos, sin = reference_model(x, position_ids)

    tt_model = RotarySetup(
        device,
        batch_size,
        dim,
        image_size,
        patch_size,
        num_patches,
        tt_model_args.vision_rope_theta,
        scale_factor=None,
        orig_context_len=num_patches,
        datatype=dtype,
    )

    cos2, sin2 = tt_model.get_rot_mats(position_ids)
    # Bring the TT tables back to torch and drop the leading batch dim.
    cos2 = ttnn.to_torch(ttnn.from_device(cos2)).squeeze(0)
    sin2 = ttnn.to_torch(ttnn.from_device(sin2)).squeeze(0)

    passing, pcc_message = comp_pcc(cos, cos2, pcc_required)
    logger.info(comp_allclose(cos, cos2))
    logger.info(f"PCC: {pcc_message}")
    assert passing, f"COS PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"

    passing, pcc_message = comp_pcc(sin, sin2, pcc_required)
    logger.info(comp_allclose(sin, sin2))
    logger.info(f"PCC: {pcc_message}")
    assert passing, f"SIN PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
Loading