Skip to content

Commit c216b15

Browse files
nikileshxMohammedTaherMcW
authored and committed
Add Support for mistralai/Mistral-Small-3.1-24B-Instruct-2503 model
1 parent 4b989ef commit c216b15

23 files changed

+3307
-72
lines changed

models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py

Lines changed: 530 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""
5+
This file is a unit test for validating the Mistral-24B Vision Model pipeline.
6+
"""
7+
8+
import os
9+
import pytest
10+
import torch
11+
from loguru import logger
12+
13+
import ttnn
14+
from models.tt_transformers.tt.ccl import TT_CCL
15+
from models.tt_transformers.tt.model_config import ModelArgs
16+
from models.experimental.mistral_24b.tt.pipeline.vision_model import TtMistralVisionTransformer
17+
from models.common.utility_functions import comp_allclose, comp_pcc
18+
19+
20+
def get_image_features(vision_tower, projector, input_tensor, image_sizes):
    """Run the vision tower, then project its last hidden state.

    Args:
        vision_tower: Callable invoked as ``vision_tower(input_tensor, image_sizes)``;
            the returned object must expose ``last_hidden_state``.
        projector: Callable invoked with the hidden state (dim 0 squeezed) and
            ``image_sizes``.
        input_tensor: Image input, forwarded unchanged to ``vision_tower``.
        image_sizes: Forwarded unchanged to both callables.

    Returns:
        Whatever ``projector`` returns.
    """
    tower_output = vision_tower(input_tensor, image_sizes)
    hidden_state = tower_output.last_hidden_state.squeeze(0)
    return projector(hidden_state, image_sizes)
27+
28+
29+
@pytest.mark.parametrize(
    "mesh_device",
    [
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    "device_params",
    [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
    indirect=True,
)
def test_mistral_vision_model(mesh_device, reset_seeds):
    """Compare TtMistralVisionTransformer against the torch vision tower + projector.

    The torch reference (vision tower followed by the multi-modal projector) and
    the TT pipeline are fed the same random image; the test passes when
    ``comp_pcc`` reports the outputs meet the required PCC.
    """
    pcc_required = 0.97
    dtype = ttnn.bfloat8_b

    model_args = ModelArgs(mesh_device)
    state_dict = model_args.load_state_dict()

    def strip_prefix(prefix):
        # Reference modules expect keys relative to their own root.
        return {key[len(prefix) :]: value for key, value in state_dict.items() if key.startswith(prefix)}

    ##### Reference model output (Torch) #####
    first_layer_prefix = "vision_tower."
    reference_model = model_args.reference_vision_model()
    reference_model.load_state_dict(strip_prefix(first_layer_prefix))

    reference_mmp = model_args.reference_vision_multi_modal()
    reference_mmp.load_state_dict(strip_prefix("multi_modal_projector."))

    B, C = 1, 3
    H = W = model_args.vision_chunk_size
    input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16)

    reference_output = get_image_features(reference_model, reference_mmp, input_tensor, image_sizes=[(H, W)])

    ##### TT Model: TtMistralVisionTransformer #####
    # The TT model receives the full state dict plus the prefix, not the stripped one.
    vision_model = TtMistralVisionTransformer(
        mesh_device=mesh_device,
        tt_ccl=TT_CCL(mesh_device=mesh_device),
        state_dict=state_dict,
        state_dict_prefix=first_layer_prefix,
        dtype=dtype,
        model_args=model_args,
    )

    tt_device_output = vision_model(input_tensor, image_sizes=[(H, W)])
    feature_width = tt_device_output.shape[-1]
    gathered = ttnn.to_torch(tt_device_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))
    # Concatenating over the mesh on the last dim widens the tensor; keep only
    # the first `feature_width` columns.
    tt_output = gathered[:, :feature_width]

    # NOTE(review): positions where the TT output is exactly zero are dropped from
    # BOTH tensors before comparison. Presumably this skips padding, but it would
    # also hide a TT output that is wrongly zero - confirm this is intended.
    keep = tt_output.ne(0).nonzero(as_tuple=True)
    tt_output = tt_output[keep]
    reference_output = reference_output[keep]

    passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required)

    logger.info(comp_allclose(reference_output, tt_output))
    logger.info(f"PCC: {pcc_message}")
    assert passing, f"PCC below {pcc_required}. {pcc_message}"
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""
5+
This file is a unit test for validating the Mistral-24B Vision Tower model.
6+
"""
7+
8+
import os
9+
import pytest
10+
import torch
11+
from loguru import logger
12+
13+
import ttnn
14+
from models.tt_transformers.tt.ccl import TT_CCL
15+
from models.tt_transformers.tt.model_config import ModelArgs
16+
from models.experimental.mistral_24b.tt.pipeline.mistral_vision_tower import MistralVisionTower
17+
from models.common.utility_functions import comp_allclose, comp_pcc
18+
19+
20+
@pytest.mark.parametrize(
    "mesh_device",
    [
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    "device_params",
    [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
    indirect=True,
)
def test_mistral_vision_tower(mesh_device, reset_seeds):
    """Compare MistralVisionTower against the torch reference vision tower.

    Both models are fed the same random image; the test passes when
    ``comp_pcc`` reports that the TT output and the reference
    ``last_hidden_state`` meet the required PCC.
    """
    pcc_required = 0.99
    dtype = ttnn.bfloat16

    model_args = ModelArgs(mesh_device)
    state_dict = model_args.load_state_dict()

    # The reference model expects keys relative to "vision_tower."; the TT model
    # below takes the full state dict plus the prefix.
    first_layer_prefix = "vision_tower."
    partial_state_dict = {
        k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix)
    }

    B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size
    input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16)
    # Log a short summary instead of printing every checkpoint key to stdout.
    logger.debug(f"Loaded {len(partial_state_dict)} '{first_layer_prefix}' tensors out of {len(state_dict)} total")

    ##### Reference model output (Torch) #####
    reference_model = model_args.reference_vision_model()
    reference_model.load_state_dict(partial_state_dict)
    reference_output = reference_model(input_tensor, image_sizes=[(H, W)])
    reference_output = reference_output.last_hidden_state

    ##### TT Model: MistralVisionTower #####
    tt_ccl = TT_CCL(mesh_device=mesh_device)
    vision_model = MistralVisionTower(
        mesh_device=mesh_device,
        tt_ccl=tt_ccl,
        state_dict=state_dict,
        state_dict_prefix=first_layer_prefix,
        dtype=dtype,
        configuration=model_args,
    )

    tt_output = vision_model(input_tensor, image_sizes=[(H, W)])
    # Concatenating over the mesh on the last dim widens the tensor; keep only
    # the first per-device width of columns.
    tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[
        :, :, :, : tt_output.shape[-1]
    ]
    tt_output = tt_output.squeeze(0)

    passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required)

    logger.info(comp_allclose(reference_output, tt_output))
    logger.info(f"PCC: {pcc_message}")
    assert passing, f"PCC below {pcc_required}. {pcc_message}"
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""
5+
This file is a unit test for validating the Mistral-24B conv2d.
6+
"""
7+
import os
8+
9+
import pytest
10+
import torch
11+
from loguru import logger
12+
13+
import ttnn
14+
from models.tt_transformers.tt.model_config import ModelArgs
15+
from models.experimental.mistral_24b.tt.vision_conv2d import TtMistralConv2dPatch
16+
from models.common.utility_functions import comp_allclose, comp_pcc
17+
from ttnn import ConcatMeshToTensor
18+
19+
20+
@pytest.mark.parametrize(
    "mesh_device",
    [
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
def test_conv2d_inference(
    mesh_device,
    reset_seeds,
):
    """Compare TtMistralConv2dPatch against the torch reference patch convolution.

    Both models are fed the same random image. The TT output is rearranged into
    the reference's (N, C_out, H_out, W_out) layout and the test passes when
    ``comp_pcc`` reports the outputs meet ``pcc_required``.
    """
    pcc_required = 0.9999
    dtype = ttnn.bfloat16

    model_args = ModelArgs(mesh_device)
    state_dict = model_args.load_state_dict()

    # Ref model needs partial state dict, but our models use full state dict keys as cached weight names
    first_layer_prefix = "vision_tower.patch_conv."
    partial_state_dict = {
        k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix))
    }

    ##### Input and convolution geometry #####
    B, NCH, H, W = (1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size)
    in_channels, out_channels, kernel_size, stride, bias = (
        3,
        model_args.vision_dim,
        model_args.vision_patch_size,
        model_args.vision_patch_size,
        False,
    )

    assert NCH == in_channels, "Number of channels in input tensor should match in_channels for the Conv2d patch."
    assert isinstance(kernel_size, int), "Only symmetric kernel_size is currently supported."
    assert kernel_size == stride, "Only same kernel_size and stride are currently supported."

    assert H % kernel_size == 0, "Height should be divisible by kernel_size."
    assert W % kernel_size == 0, "Width should be divisible by kernel_size."

    ##### Prepare inputs #####
    input_tensor = torch.randn((B, NCH, H, W)).to(dtype=torch.bfloat16)
    logger.info(f"Input tensor shape: {input_tensor.shape}")

    reference_model = model_args.reference_conv2d_patch()
    reference_model.load_state_dict(partial_state_dict)
    reference_output = reference_model(input_tensor)

    tt_model = TtMistralConv2dPatch(
        mesh_device,
        state_dict,
        first_layer_prefix,
        dtype,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        bias,
    )
    tt_output = tt_model(input_tensor)

    ##### Check the outputs #####
    out = ttnn.from_device(tt_output)
    tt_output_torch = ttnn.to_torch(out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=2))

    # Only select output from one device
    tt_output_torch = tt_output_torch[0, ..., :out_channels]

    # 1. Restore batch dim -> (1, num_patches, out_channels)
    tt_output_torch = tt_output_torch.unsqueeze(0)
    # 2. Permute to match Conv2D output: (N, C_out, H_out, W_out).
    #    With kernel_size == stride and exact divisibility (asserted above),
    #    the patch grid is (H // kernel_size) x (W // kernel_size).
    tt_output_torch = tt_output_torch.permute(0, 2, 1).reshape(
        1, out_channels, H // kernel_size, W // kernel_size
    )

    # Pass the declared threshold so the check matches the failure message.
    passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required)

    logger.info(comp_allclose(reference_output, tt_output_torch))
    logger.info(f"PCC: {pcc_message}")
    assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from loguru import logger
5+
6+
import torch
7+
import pytest
8+
import os
9+
import ttnn
10+
11+
from models.experimental.mistral_24b.tt.vision_rope import VisionRotarySetup as RotarySetup
12+
13+
from models.common.utility_functions import comp_allclose, comp_pcc
14+
from models.tt_transformers.tt.model_config import ModelArgs
15+
16+
17+
@torch.no_grad()
@pytest.mark.parametrize(
    "device",
    [
        # NOTE(review): the sibling vision tests read the "MESH_DEVICE" env var;
        # this one reads "device", so the lookup presumably always falls through
        # to the device count - confirm which name is intended.
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("device"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    "seq_len",
    (128,),
)
@pytest.mark.parametrize(
    "batch_size",
    (1,),
)
def test_rot_emb(seq_len, batch_size, reset_seeds, device):
    """Compare the TT vision rotary setup's cos/sin tables against the torch reference.

    The reference rotary embedding and ``VisionRotarySetup`` are queried with the
    same position ids; cos and sin are each checked with ``comp_pcc``.
    """
    # Assumed to equal comp_pcc's default, which the original relied on implicitly - TODO confirm.
    pcc_required = 0.99
    dtype = ttnn.bfloat16
    num_positions = 4096  # number of position ids fed to both implementations
    hidden_dim = 1024  # last dim of the dummy activation passed to the reference

    tt_model_args = ModelArgs(
        device,
        max_batch_size=batch_size,
        max_seq_len=128,
    )

    tt_model_args.n_layers = 1
    partial_state_dict = {}  # nothing is loaded into the reference from the checkpoint

    reference_model = tt_model_args.reference_vision_rot_emb()
    reference_model.load_state_dict(partial_state_dict)

    image_size = tt_model_args.vision_image_size
    patch_size = tt_model_args.vision_patch_size
    dim = tt_model_args.vision_head_dim
    num_patches_per_dim = image_size // patch_size
    num_patches = num_patches_per_dim * num_patches_per_dim
    position_ids = torch.arange(num_positions, dtype=torch.long)

    x = torch.randn(batch_size, num_positions, hidden_dim)

    cos, sin = reference_model(x, position_ids)
    tt_model = RotarySetup(
        device,
        batch_size,
        dim,
        image_size,
        patch_size,
        num_patches,
        tt_model_args.vision_rope_theta,
        scale_factor=None,
        orig_context_len=num_patches,
        datatype=dtype,
    )

    cos2, sin2 = tt_model.get_rot_mats(position_ids)
    cos2 = ttnn.from_device(cos2)
    cos2 = ttnn.to_torch(cos2)
    cos2 = cos2.squeeze(0)

    sin2 = ttnn.from_device(sin2)
    sin2 = ttnn.to_torch(sin2)
    sin2 = sin2.squeeze(0)

    passing, pcc_message = comp_pcc(cos, cos2, pcc_required)

    logger.info(comp_allclose(cos, cos2))
    logger.info(f"PCC: {pcc_message}")
    assert passing, f"COS PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"

    passing, pcc_message = comp_pcc(sin, sin2, pcc_required)

    logger.info(comp_allclose(sin, sin2))
    logger.info(f"PCC: {pcc_message}")
    assert passing, f"SIN PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"

0 commit comments

Comments
 (0)