Skip to content

Commit caad9b3

Browse files
committed
Add unit tests to adopt tt_ccl
1 parent 28b249b commit caad9b3

File tree

7 files changed

+193
-6
lines changed

7 files changed

+193
-6
lines changed

models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from loguru import logger
99

1010
import ttnn
11+
from models.tt_transformers.tt.model_config import ModelArgs
1112

13+
# models/tt_transformers/tt/common.py
1214
from models.tt_transformers.tt.multimodal.mistral_24b.vision_rope import VisionRotarySetup as RotarySetup
13-
from models.tt_transformers.tt.model_config import ModelArgs
1415
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
1516

1617

models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_transformer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
from loguru import logger
99

1010
import ttnn
11-
from models.experimental.mistral_24b.tt.vision_pixtral_transformer import TtPixtralTransformer
11+
from models.tt_transformers.tt.ccl import TT_CCL
1212
from models.tt_transformers.tt.model_config import ModelArgs
13+
from models.tt_transformers.tt.multimodal.mistral_24b.vision_pixtral_transformer import TtPixtralTransformer
1314
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
1415

1516

@@ -27,6 +28,11 @@
2728
],
2829
indirect=True,
2930
)
31+
@pytest.mark.parametrize(
32+
"device_params",
33+
[{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
34+
indirect=True,
35+
)
3036
def test_image_transformer_inference(batch, num_chunks, mesh_device):
3137
pcc_required = 0.99
3238

@@ -51,8 +57,10 @@ def test_image_transformer_inference(batch, num_chunks, mesh_device):
5157

5258
all_tests_pass = True
5359

60+
tt_ccl = TT_CCL(mesh_device)
5461
tt_model = TtPixtralTransformer(
5562
mesh_device,
63+
tt_ccl,
5664
state_dict,
5765
state_dict_prefix=first_layer_prefix,
5866
weight_cache_path=None,

models/tt_transformers/tests/multimodal/mistral_24b/test_vision_attention.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@
88
from loguru import logger
99

1010
import ttnn
11-
from models.experimental.mistral_24b.tt.vision_attention import TtMistralImageAttention as TtLlamaImageAttention
11+
from models.tt_transformers.tt.ccl import TT_CCL
1212
from models.tt_transformers.tt.model_config import ModelArgs
13+
from models.tt_transformers.tt.multimodal.mistral_24b.vision_attention import (
14+
TtMistralImageAttention as TtLlamaImageAttention,
15+
)
1316
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
1417
from ttnn import ConcatMeshToTensor
1518

@@ -33,6 +36,11 @@
3336
"batch_size",
3437
(1,),
3538
)
39+
@pytest.mark.parametrize(
40+
"device_params",
41+
[{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
42+
indirect=True,
43+
)
3644
def test_vision_attention(mesh_device, seq_len, batch_size):
3745
logger.info(f"seq_len: {seq_len}, batch_size: {batch_size}")
3846
dtype = ttnn.bfloat16
@@ -53,8 +61,10 @@ def test_vision_attention(mesh_device, seq_len, batch_size):
5361
n_heads = model_args.vision_attn_n_heads
5462
head_dim = hidden_size // n_heads
5563

64+
tt_ccl = TT_CCL(mesh_device)
5665
tt_model = TtLlamaImageAttention(
5766
mesh_device,
67+
tt_ccl,
5868
state_dict,
5969
state_dict_prefix=first_layer_prefix,
6070
weight_cache_path=model_args.weight_cache_path(dtype),

models/tt_transformers/tests/multimodal/mistral_24b/test_vision_mlp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from loguru import logger
1010

1111
import ttnn
12-
13-
from models.experimental.mistral_24b.tt.vision_mlp import MistralTTVisionMLP as MLP
1412
from models.tt_transformers.tt.model_config import ModelArgs
13+
14+
from models.tt_transformers.tt.multimodal.mistral_24b.vision_mlp import MistralTTVisionMLP as MLP
1515
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
1616

1717

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import os
5+
6+
import pytest
7+
import torch
8+
from loguru import logger
9+
10+
import ttnn
11+
from models.tt_transformers.tt.ccl import TT_CCL
12+
from models.tt_transformers.tt.model_config import ModelArgs
13+
from models.tt_transformers.tt.multimodal.mistral_24b.vision_model import TtMistralVisionTransformer
14+
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
15+
16+
17+
def get_image_features(vision_tower, projector, input_tensor, image_sizes):
18+
"""
19+
Get image features from the vision tower and projector.
20+
"""
21+
vision_token = vision_tower(input_tensor, image_sizes).last_hidden_state
22+
image_features = projector(vision_token.squeeze(0), image_sizes)
23+
return image_features
24+
25+
26+
@skip_for_grayskull("Requires wormhole_b0 to run")
27+
@pytest.mark.parametrize(
28+
"mesh_device",
29+
[
30+
{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
31+
os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
32+
)
33+
],
34+
indirect=True,
35+
)
36+
@pytest.mark.parametrize(
37+
"device_params",
38+
[{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
39+
indirect=True,
40+
)
41+
def test_mistral_vision_model(mesh_device, reset_seeds):
42+
pcc_required = 0.97
43+
dtype = ttnn.bfloat8_b
44+
45+
model_args = ModelArgs(mesh_device)
46+
state_dict = model_args.load_state_dict()
47+
48+
first_layer_prefix = "vision_tower."
49+
partial_state_dict = {
50+
k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix)
51+
}
52+
53+
##### Reference model output (Torch) #####
54+
reference_model = model_args.reference_vision_model()
55+
reference_model.load_state_dict(partial_state_dict)
56+
57+
mmp_first_layer_prefix = "multi_modal_projector."
58+
59+
mmp_partial_state_dict = {
60+
k[len(mmp_first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(mmp_first_layer_prefix))
61+
}
62+
63+
reference_mmp = model_args.reference_vision_multi_modal()
64+
reference_mmp.load_state_dict(mmp_partial_state_dict)
65+
66+
B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size
67+
input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16)
68+
69+
reference_output = get_image_features(reference_model, reference_mmp, input_tensor, image_sizes=[(H, W)])
70+
71+
# ##### TT Model: TtMistralVisionTransformer #####
72+
tt_ccl = TT_CCL(mesh_device=mesh_device)
73+
vision_model = TtMistralVisionTransformer(
74+
mesh_device=mesh_device,
75+
tt_ccl=tt_ccl,
76+
state_dict=state_dict,
77+
state_dict_prefix=first_layer_prefix,
78+
dtype=dtype,
79+
model_args=model_args,
80+
)
81+
82+
tt_output = vision_model(input_tensor, image_sizes=[(H, W)]) # [0]
83+
tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[
84+
:, : tt_output.shape[-1]
85+
]
86+
87+
non_zero_indices = tt_output.ne(0).nonzero(as_tuple=True)
88+
tt_output = tt_output[non_zero_indices]
89+
reference_output = reference_output[non_zero_indices]
90+
91+
passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required)
92+
93+
logger.info(comp_allclose(reference_output, tt_output))
94+
logger.info(f"PCC: {pcc_message}")
95+
assert passing, f"PCC below {pcc_required}. {pcc_message}"

models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from loguru import logger
66

77
import ttnn
8-
from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm
98
from models.tt_transformers.tt.model_config import ModelArgs
9+
from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm
1010
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
1111

1212

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import os
5+
6+
import pytest
7+
import torch
8+
from loguru import logger
9+
10+
import ttnn
11+
from models.tt_transformers.tt.ccl import TT_CCL
12+
from models.tt_transformers.tt.model_config import ModelArgs
13+
from models.tt_transformers.tt.multimodal.mistral_24b.mistral_vision_tower import MistralVisionTower
14+
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
15+
16+
17+
@skip_for_grayskull("Requires wormhole_b0 to run")
18+
@pytest.mark.parametrize(
19+
"mesh_device",
20+
[
21+
{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
22+
os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
23+
)
24+
],
25+
indirect=True,
26+
)
27+
@pytest.mark.parametrize(
28+
"device_params",
29+
[{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
30+
indirect=True,
31+
)
32+
def test_mistral_vision_tower(mesh_device, reset_seeds):
33+
pcc_required = 0.99
34+
dtype = ttnn.bfloat16
35+
36+
model_args = ModelArgs(mesh_device)
37+
state_dict = model_args.load_state_dict()
38+
39+
first_layer_prefix = "vision_tower."
40+
partial_state_dict = {
41+
k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix)
42+
}
43+
44+
B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size
45+
input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16)
46+
47+
##### Reference model output (Torch) #####
48+
reference_model = model_args.reference_vision_model()
49+
reference_model.load_state_dict(partial_state_dict)
50+
reference_output = reference_model(input_tensor, image_sizes=[(H, W)])
51+
52+
reference_output = reference_output.last_hidden_state
53+
tt_ccl = TT_CCL(mesh_device)
54+
##### TT Model: MistralVisionTower #####
55+
vision_model = MistralVisionTower(
56+
mesh_device=mesh_device,
57+
tt_ccl=tt_ccl,
58+
state_dict=state_dict,
59+
state_dict_prefix=first_layer_prefix,
60+
dtype=dtype,
61+
configuration=model_args,
62+
)
63+
64+
tt_output = vision_model(input_tensor, image_sizes=[(H, W)])
65+
tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[
66+
:, :, :, : tt_output.shape[-1]
67+
]
68+
tt_output = tt_output.squeeze(0)
69+
passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required)
70+
71+
logger.info(comp_allclose(reference_output, tt_output))
72+
logger.info(f"PCC: {pcc_message}")
73+
assert passing, f"PCC below {pcc_required}. {pcc_message}"

0 commit comments

Comments (0)