Skip to content

Commit 3d8da88

Browse files
committed
Add Support for mistralai/Mistral-Small-3.1-24B-Instruct-2503 model
1 parent 9357c74 commit 3d8da88

26 files changed

+3113
-56
lines changed

models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py

Lines changed: 523 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import os
5+
import pytest
6+
import torch
7+
from loguru import logger
8+
9+
import ttnn
10+
from models.tt_transformers.tt.model_config import ModelArgs
11+
from models.experimental.mistral_24b.tt.pipeline.vision_model import TtMistralVisionTransformer
12+
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
13+
14+
15+
def get_image_features(vision_tower, projector, input_tensor, image_sizes):
    """
    Run the vision tower and feed its last hidden state through the projector.

    ``vision_tower`` is called with ``(input_tensor, image_sizes)``; the
    ``last_hidden_state`` of its result has ``squeeze(0)`` applied and is then
    passed, together with ``image_sizes``, to ``projector``. The projector's
    return value is returned unchanged.
    """
    tower_result = vision_tower(input_tensor, image_sizes)
    hidden_tokens = tower_result.last_hidden_state.squeeze(0)
    return projector(hidden_tokens, image_sizes)
22+
23+
24+
@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
    "mesh_device",
    [
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
def test_mistral_vision_model(mesh_device, reset_seeds):
    """
    Compare TtMistralVisionTransformer with the torch vision tower + projector.

    The torch reference is the vision tower followed by the multi-modal
    projector (see ``get_image_features``). Only positions where the TT output
    is non-zero are compared, and the PCC must reach ``pcc_required``.
    """
    pcc_required = 0.97
    dtype = ttnn.bfloat8_b

    model_args = ModelArgs(mesh_device)
    state_dict = model_args.load_state_dict()

    def weights_under(prefix):
        # Sub-dictionary of the checkpoint for `prefix`, with the prefix removed from each key.
        offset = len(prefix)
        return {name[offset:]: weight for name, weight in state_dict.items() if name.startswith(prefix)}

    vision_prefix = "vision_tower."
    projector_prefix = "multi_modal_projector."

    ##### Reference model output (Torch) #####
    torch_tower = model_args.reference_vision_model()
    torch_tower.load_state_dict(weights_under(vision_prefix))

    torch_projector = model_args.reference_vision_multi_modal()
    torch_projector.load_state_dict(weights_under(projector_prefix))

    # Single square 3-channel image, one vision chunk on each side.
    side = model_args.vision_chunk_size
    input_tensor = torch.rand((1, 3, side, side), dtype=torch.bfloat16)

    reference_output = get_image_features(torch_tower, torch_projector, input_tensor, image_sizes=[(side, side)])

    ##### TT Model: TtMistralVisionTransformer #####
    tt_model = TtMistralVisionTransformer(
        mesh_device=mesh_device,
        state_dict=state_dict,
        state_dict_prefix=vision_prefix,
        dtype=dtype,
        model_args=model_args,
    )

    tt_device_output = tt_model(input_tensor, image_sizes=[(side, side)])
    gathered = ttnn.to_torch(tt_device_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))
    # Keep only as many columns as the device tensor's last dimension.
    tt_output = gathered[:, : tt_device_output.shape[-1]]

    # Restrict both tensors to the positions where the TT output is non-zero.
    nonzero_positions = tt_output.ne(0).nonzero(as_tuple=True)
    tt_output = tt_output[nonzero_positions]
    reference_output = reference_output[nonzero_positions]

    passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required)

    logger.info(comp_allclose(reference_output, tt_output))
    logger.info(f"PCC: {pcc_message}")
    assert passing, f"PCC below {pcc_required}. {pcc_message}"
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import os
5+
import pytest
6+
import torch
7+
from loguru import logger
8+
9+
import ttnn
10+
from models.tt_transformers.tt.model_config import ModelArgs
11+
from models.experimental.mistral_24b.tt.pipeline.mistral_vision_tower import MistralVisionTower
12+
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
13+
14+
15+
@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
    "mesh_device",
    [
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
def test_mistral_vision_tower(mesh_device, reset_seeds):
    """
    Compare MistralVisionTower with the torch reference vision model.

    Both models receive the same random image; the reference's
    ``last_hidden_state`` is compared with the TT output by PCC.
    """
    pcc_required = 0.99
    dtype = ttnn.bfloat16

    model_args = ModelArgs(mesh_device)
    state_dict = model_args.load_state_dict()

    # The reference model expects keys without the "vision_tower." prefix.
    first_layer_prefix = "vision_tower."
    prefix_len = len(first_layer_prefix)
    tower_weights = {}
    for name, weight in state_dict.items():
        if name.startswith(first_layer_prefix):
            tower_weights[name[prefix_len:]] = weight

    # Single square 3-channel image, one vision chunk on each side.
    side = model_args.vision_chunk_size
    input_tensor = torch.rand((1, 3, side, side), dtype=torch.bfloat16)

    ##### Reference model output (Torch) #####
    torch_tower = model_args.reference_vision_model()
    torch_tower.load_state_dict(tower_weights)
    reference_result = torch_tower(input_tensor, image_sizes=[(side, side)])
    reference_output = reference_result.last_hidden_state

    ##### TT Model: MistralVisionTower #####
    tt_tower = MistralVisionTower(
        mesh_device=mesh_device,
        state_dict=state_dict,
        state_dict_prefix=first_layer_prefix,
        dtype=dtype,
        configuration=model_args,
    )

    tt_device_output = tt_tower(input_tensor, image_sizes=[(side, side)])
    gathered = ttnn.to_torch(tt_device_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))
    # Trim the last dimension back to the device tensor's width, then apply squeeze(0).
    tt_output = gathered[:, :, :, : tt_device_output.shape[-1]].squeeze(0)

    passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required)

    logger.info(comp_allclose(reference_output, tt_output))
    logger.info(f"PCC: {pcc_message}")
    assert passing, f"PCC below {pcc_required}. {pcc_message}"
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
2+
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import os
6+
7+
import pytest
8+
import torch
9+
from loguru import logger
10+
11+
import ttnn
12+
from models.tt_transformers.tt.model_config import ModelArgs
13+
from models.experimental.mistral_24b.tt.vision_conv2d import TtMistralConv2dPatch
14+
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
15+
from ttnn import ConcatMeshToTensor
16+
17+
18+
@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
    "mesh_device",
    [
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
def test_conv2d_inference(
    mesh_device,
    reset_seeds,
):
    """
    Compare TtMistralConv2dPatch with the torch reference patch convolution.

    A random image is run through both implementations. The TT output is
    rearranged into the ``(N, C_out, H_out, W_out)`` layout and compared with
    the reference by PCC, which must reach ``pcc_required``.
    """
    pcc_required = 0.9999
    dtype = ttnn.bfloat16

    model_args = ModelArgs(mesh_device)
    state_dict = model_args.load_state_dict()

    # Ref model needs partial state dict, but our models use full state dict keys as cached weight names
    first_layer_prefix = "vision_tower.patch_conv."
    partial_state_dict = {
        k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix)
    }

    ##### Input shape and Conv2d patch parameters #####
    B, NCH, H, W = (1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size)
    in_channels, out_channels, kernel_size, stride, bias = (
        3,
        model_args.vision_dim,
        model_args.vision_patch_size,
        model_args.vision_patch_size,
        False,
    )

    assert NCH == in_channels, "Number of channels in input tensor should match in_channels for the Conv2d patch."
    assert isinstance(kernel_size, int), "Only symmetric kernel_size is currently supported."
    assert kernel_size == stride, "Only same kernel_size and stride are currently supported."

    assert H % kernel_size == 0, "Height should be divisible by kernel_size."
    assert W % kernel_size == 0, "Width should be divisible by kernel_size."

    ##### Prepare inputs #####
    input_tensor = torch.randn((B, NCH, H, W)).to(dtype=torch.bfloat16)
    logger.info(f"Input tensor shape: {input_tensor.shape}")

    reference_model = model_args.reference_conv2d_patch()
    reference_model.load_state_dict(partial_state_dict)
    reference_output = reference_model(input_tensor)

    tt_model = TtMistralConv2dPatch(
        mesh_device,
        state_dict,
        first_layer_prefix,
        dtype,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        bias,
    )
    tt_output = tt_model(input_tensor)

    ##### Check the outputs #####
    out = ttnn.from_device(tt_output)
    tt_output_torch = ttnn.to_torch(out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=2))

    # Only select output from one device
    tt_output_torch = tt_output_torch[0, ..., :out_channels]

    # 1. Restore batch dim
    tt_output_torch = tt_output_torch.unsqueeze(0)

    # 2. Permute to match Conv2D output: (N, C_out, H_out, W_out).
    #    With kernel_size == stride the output grid is (H // kernel_size, W // kernel_size);
    #    deriving it here avoids baking one model configuration into the test.
    patches_h = H // kernel_size
    patches_w = W // kernel_size
    tt_output_torch = tt_output_torch.permute(0, 2, 1).reshape(1, out_channels, patches_h, patches_w)

    # Pass the threshold explicitly so the check enforces the value named in the failure message.
    passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required)

    logger.info(comp_allclose(reference_output, tt_output_torch))
    logger.info(f"PCC: {pcc_message}")
    assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from loguru import logger
2+
3+
import torch
4+
import pytest
5+
import os
6+
import ttnn
7+
8+
# models/tt_transformers/tt/common.py
9+
from models.experimental.mistral_24b.tt.vision_rope import VisionRotarySetup as RotarySetup
10+
11+
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
12+
from models.tt_transformers.tt.model_config import ModelArgs
13+
14+
15+
@torch.no_grad()
@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
    "device",
    [
        # NOTE(review): the sibling tests in this commit read the "MESH_DEVICE"
        # environment variable; the lowercase "device" key here may be a
        # search-and-replace leftover. Confirm against the fixture before changing.
        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
            os.environ.get("device"), len(ttnn.get_device_ids())
        )
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    "seq_len",
    (128,),
)
@pytest.mark.parametrize(
    "batch_size",
    (1,),
)
def test_rot_emb(seq_len, batch_size, reset_seeds, device):
    """
    Compare the cos/sin tables from VisionRotarySetup with the torch reference.

    The reference rotary embedding and the TT ``get_rot_mats`` are evaluated on
    the same position ids; the cos and sin outputs must each reach
    ``pcc_required`` PCC.
    """
    pcc_required = 0.99
    dtype = ttnn.bfloat16

    tt_model_args = ModelArgs(
        device,
        max_batch_size=batch_size,
        max_seq_len=128,
    )

    tt_model_args.n_layers = 1
    # The reference module is loaded with an empty state dict.
    partial_state_dict = {}

    reference_model = tt_model_args.reference_vision_rot_emb()
    reference_model.load_state_dict(partial_state_dict)

    image_size = tt_model_args.vision_image_size
    patch_size = tt_model_args.vision_patch_size
    dim = tt_model_args.vision_head_dim
    num_patches_per_dim = image_size // patch_size
    num_patches = num_patches_per_dim * num_patches_per_dim

    # Fixed sizes used by this test for the position ids and the dummy activation.
    num_positions = 4096
    hidden_width = 1024
    position_ids = torch.arange(num_positions, dtype=torch.long)
    x = torch.randn(batch_size, num_positions, hidden_width)

    cos, sin = reference_model(x, position_ids)

    tt_model = RotarySetup(
        device,
        batch_size,
        dim,
        image_size,
        patch_size,
        num_patches,
        tt_model_args.vision_rope_theta,
        scale_factor=None,
        orig_context_len=num_patches,
        datatype=dtype,
    )

    tt_cos, tt_sin = tt_model.get_rot_mats(position_ids)
    # Bring each table back to host as a torch tensor and apply squeeze(0).
    cos2 = ttnn.to_torch(ttnn.from_device(tt_cos)).squeeze(0)
    sin2 = ttnn.to_torch(ttnn.from_device(tt_sin)).squeeze(0)

    for label, expected, actual in (("COS", cos, cos2), ("SIN", sin, sin2)):
        # Pass the threshold explicitly so it always matches the failure message.
        passing, pcc_message = comp_pcc(expected, actual, pcc_required)

        logger.info(comp_allclose(expected, actual))
        logger.info(f"PCC: {pcc_message}")
        assert passing, f"{label} PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"

0 commit comments

Comments
 (0)