Commit 25e8431

Rebase Experimental E2E and refactor to include tt_ccl

1 parent: 5a92a0d
File tree

14 files changed: +95 −12 lines

models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py

Lines changed: 16 additions & 4 deletions
@@ -6,6 +6,7 @@
 import os
 import ttnn
 
+from models.tt_transformers.tt.ccl import TT_CCL
 from models.tt_transformers.tt.common import (
     sample_host,
     PagedAttentionConfig,
@@ -117,8 +118,14 @@ def setup_vision_prompts_and_tokenizer(model_args, instruct):
         {
             "role": "user",
             "content": [
-                {"type": "image", "image": "https://www.theeducationmagazine.com/wp-content/uploads/2020/03/18.jpg"},
-                {"type": "text", "text": "Tell me who you see in the image and describe the image ?"},
+                {
+                    "type": "image",
+                    "image": "https://img.freepik.com/premium-photo/girl-hugging-dog-with-girl-hugging-her_737761-2565.jpg",
+                },
+                {
+                    "type": "text",
+                    "text": "Is there a cat in this image? If not, what animal do you see in the image? Describe the image in detail in 600 words.",
+                },
             ],
         }
     ]
@@ -182,9 +189,11 @@ def load_separate_models_like_test_end2end(model_args, mesh_device, dtype, paged
         max_num_blocks=page_params["page_max_num_blocks"],
     )
 
+    tt_ccl = TT_CCL(mesh_device)
     # Load vision model (exactly like test_end2end.py)
     vision_model = TtMistralVisionTransformer(
         mesh_device=mesh_device,
+        tt_ccl=tt_ccl,
         state_dict=state_dict,
         state_dict_prefix=vision_prefix,
         dtype=dtype,
@@ -418,6 +427,11 @@ def validate_e2e_outputs(results, expected_min_tokens=1):
     ],
     ids=["accuracy"],
 )
+@pytest.mark.parametrize(
+    "device_params",
+    [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
+    indirect=True,
+)
 @pytest.mark.parametrize(
     "mesh_device",
     [
@@ -427,8 +441,6 @@ def validate_e2e_outputs(results, expected_min_tokens=1):
     ],
     indirect=True,
 )
-# @pytest.mark.parametrize("device_params", [{"l1_small_size": 1584864, "trace_region_size": 0}], indirect=True)
-@pytest.mark.parametrize("device_params", [{"l1_small_size": 10 * 1024}], indirect=True)
 def test_e2e_vision_text_pipeline(
     weights,
     layers,
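
Taken together, these hunks establish the pattern the rest of this commit repeats: construct one TT_CCL per mesh and thread it into the vision stack. A minimal sketch assembled from the lines above (the surrounding test provides mesh_device, state_dict, vision_prefix, dtype, and model_args):

from models.tt_transformers.tt.ccl import TT_CCL
from models.experimental.mistral_24b.tt.pipeline.vision_model import TtMistralVisionTransformer

# One CCL helper per mesh; every module that issues collective ops receives it.
tt_ccl = TT_CCL(mesh_device)
vision_model = TtMistralVisionTransformer(
    mesh_device=mesh_device,
    tt_ccl=tt_ccl,
    state_dict=state_dict,
    state_dict_prefix=vision_prefix,
    dtype=dtype,
    model_args=model_args,
)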

models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py

Lines changed: 8 additions & 0 deletions
@@ -7,6 +7,7 @@
 from loguru import logger
 
 import ttnn
+from models.tt_transformers.tt.ccl import TT_CCL
 from models.tt_transformers.tt.model_config import ModelArgs
 from models.experimental.mistral_24b.tt.pipeline.vision_model import TtMistralVisionTransformer
 from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
@@ -31,6 +32,11 @@ def get_image_features(vision_tower, projector, input_tensor, image_sizes):
     ],
     indirect=True,
 )
+@pytest.mark.parametrize(
+    "device_params",
+    [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
+    indirect=True,
+)
 def test_mistral_vision_model(mesh_device, reset_seeds):
     pcc_required = 0.97
     dtype = ttnn.bfloat8_b
@@ -62,8 +68,10 @@ def test_mistral_vision_model(mesh_device, reset_seeds):
     reference_output = get_image_features(reference_model, reference_mmp, input_tensor, image_sizes=[(H, W)])
 
     # ##### TT Model: TtMistralVisionTransformer #####
+    tt_ccl = TT_CCL(mesh_device=mesh_device)
     vision_model = TtMistralVisionTransformer(
         mesh_device=mesh_device,
+        tt_ccl=tt_ccl,
         state_dict=state_dict,
         state_dict_prefix=first_layer_prefix,
         dtype=dtype,

models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py

Lines changed: 8 additions & 0 deletions
@@ -7,6 +7,7 @@
 from loguru import logger
 
 import ttnn
+from models.tt_transformers.tt.ccl import TT_CCL
 from models.tt_transformers.tt.model_config import ModelArgs
 from models.experimental.mistral_24b.tt.pipeline.mistral_vision_tower import MistralVisionTower
 from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
@@ -22,6 +23,11 @@
     ],
     indirect=True,
 )
+@pytest.mark.parametrize(
+    "device_params",
+    [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
+    indirect=True,
+)
 def test_mistral_vision_tower(mesh_device, reset_seeds):
     pcc_required = 0.99
     dtype = ttnn.bfloat16
@@ -43,9 +49,11 @@ def test_mistral_vision_tower(mesh_device, reset_seeds):
     reference_output = reference_model(input_tensor, image_sizes=[(H, W)])
 
     reference_output = reference_output.last_hidden_state
+    tt_ccl = TT_CCL(mesh_device)
     ##### TT Model: MistralVisionTower #####
     vision_model = MistralVisionTower(
         mesh_device=mesh_device,
+        tt_ccl=tt_ccl,
         state_dict=state_dict,
         state_dict_prefix=first_layer_prefix,
         dtype=dtype,

models/experimental/mistral_24b/tests/test_pixtral_transformer.py

Lines changed: 8 additions & 0 deletions
@@ -8,6 +8,7 @@
 from loguru import logger
 
 import ttnn
+from models.tt_transformers.tt.ccl import TT_CCL
 from models.tt_transformers.tt.model_config import ModelArgs
 
 from models.experimental.mistral_24b.tt.vision_pixtral_transformer import TtPixtralTransformer
@@ -28,6 +29,11 @@
     ],
     indirect=True,
 )
+@pytest.mark.parametrize(
+    "device_params",
+    [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
+    indirect=True,
+)
 def test_image_transformer_inference(batch, num_chunks, mesh_device):
     pcc_required = 0.99
 
@@ -52,8 +58,10 @@ def test_image_transformer_inference(batch, num_chunks, mesh_device):
 
     all_tests_pass = True
 
+    tt_ccl = TT_CCL(mesh_device)
     tt_model = TtPixtralTransformer(
         mesh_device,
+        tt_ccl,
         state_dict,
         state_dict_prefix=first_layer_prefix,
         weight_cache_path=None,

models/experimental/mistral_24b/tests/test_vision_attention.py

Lines changed: 8 additions & 0 deletions
@@ -9,6 +9,7 @@
 from loguru import logger
 
 import ttnn
+from models.tt_transformers.tt.ccl import TT_CCL
 from models.tt_transformers.tt.model_config import ModelArgs
 from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
 
@@ -36,6 +37,11 @@
     "batch_size",
     (1,),
 )
+@pytest.mark.parametrize(
+    "device_params",
+    [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}],
+    indirect=True,
+)
 def test_vision_attention(mesh_device, seq_len, batch_size):
     logger.info(f"seq_len: {seq_len}, batch_size: {batch_size}")
     dtype = ttnn.bfloat16
@@ -56,8 +62,10 @@ def test_vision_attention(mesh_device, seq_len, batch_size):
     n_heads = model_args.vision_attn_n_heads
     head_dim = hidden_size // n_heads
 
+    tt_ccl = TT_CCL(mesh_device)
    tt_model = TtLlamaImageAttention(
         mesh_device,
+        tt_ccl,
         state_dict,
         state_dict_prefix=first_layer_prefix,
         weight_cache_path=model_args.weight_cache_path(dtype),

models/experimental/mistral_24b/tt/model.py

Lines changed: 9 additions & 1 deletion
@@ -100,6 +100,14 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag
             self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :],
         ]
 
+        if hasattr(self, "rope_local_setup"):
+            tt_rot_mats_prefill_local = [
+                self.rope_local_setup.cos_matrix[:, :, start_pos : start_pos + S, :],
+                self.rope_local_setup.sin_matrix[:, :, start_pos : start_pos + S, :],
+            ]
+        else:
+            tt_rot_mats_prefill_local = None
+
         if page_table is not None:
             tt_page_table = ttnn.from_torch(
                 page_table,
@@ -122,4 +130,4 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag
         else:
             tt_chunk_page_table = None
 
-        return tokens_embd, tt_rot_mats_prefill_global, tt_page_table, tt_chunk_page_table
+        return tokens_embd, tt_rot_mats_prefill_global, tt_rot_mats_prefill_local, tt_page_table, tt_chunk_page_table
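
Because the return value grows from four items to five, any caller of prepare_inputs_prefill has to unpack the extra slot. A minimal sketch of a caller-side update (variable names are illustrative, not from this diff):

# Hypothetical call site: the third element is the local rope matrices,
# which is None whenever the model has no rope_local_setup attribute.
(
    tokens_embd,
    rot_mats_global,
    rot_mats_local,
    tt_page_table,
    tt_chunk_page_table,
) = model.prepare_inputs_prefill(tokens, start_pos=0, page_table=page_table)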

models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,7 @@ class MistralVisionTower(LightweightModule):
     def __init__(
         self,
         mesh_device,
+        tt_ccl,
         state_dict,
         state_dict_prefix,
         dtype,
@@ -28,6 +29,7 @@ def __init__(
 
         self.state_dict = state_dict
         self.mesh_device = mesh_device
+        self.tt_ccl = tt_ccl
         self.dtype = dtype
         self.config = configuration
 
@@ -98,6 +100,7 @@ def __init__(
 
         self.transformer = TtPixtralTransformer(
             mesh_device=self.mesh_device,
+            tt_ccl=tt_ccl,
             state_dict=self.state_dict,
             state_dict_prefix=f"{state_dict_prefix}transformer.",
             weight_cache_path=configuration.weight_cache_path(dtype),

models/experimental/mistral_24b/tt/pipeline/vision_model.py

Lines changed: 3 additions & 1 deletion
@@ -12,13 +12,15 @@
 
 
 class TtMistralVisionTransformer(LightweightModule):
-    def __init__(self, mesh_device, state_dict, state_dict_prefix, dtype, model_args):
+    def __init__(self, mesh_device, tt_ccl, state_dict, state_dict_prefix, dtype, model_args):
         super().__init__()
         self.state_dict = state_dict
         self.mesh_device = mesh_device
+        self.tt_ccl = tt_ccl
 
         self.vision_tower = MistralVisionTower(
             mesh_device=mesh_device,
+            tt_ccl=self.tt_ccl,
             state_dict=state_dict,
             state_dict_prefix=state_dict_prefix,
             dtype=dtype,

models/experimental/mistral_24b/tt/vision_attention.py

Lines changed: 20 additions & 2 deletions
@@ -6,7 +6,7 @@
 
 import ttnn
 from models.common.lightweightmodule import LightweightModule
-from models.utility_functions import nearest_32
+from models.utility_functions import is_blackhole, nearest_32
 
 
 def rotate_half(x):
@@ -33,6 +33,7 @@ class TtMistralImageAttention(LightweightModule):
     def __init__(
         self,
         mesh_device,
+        tt_ccl,
         state_dict,
         state_dict_prefix,
         weight_cache_path,
@@ -43,6 +44,7 @@ def __init__(
 
         self.state_dict = state_dict
         self.mesh_device = mesh_device
+        self.tt_ccl = tt_ccl
         self.num_devices = configuration.num_devices
 
         self.hidden_size = configuration.vision_dim
@@ -237,7 +239,23 @@ def forward(self, x_11SH, position_embeddings=None):
 
         # All reduce
         if self.num_devices > 1:  # replace with reduce_scatter and all_gather
-            dense_out_gathered = ttnn.all_gather(output_11SH, dim=1, num_links=1, topology=ttnn.Topology.Linear)
+            # TODO: 26411
+            # Remove this blackhole condition once fabric CCLs are working on blackhole
+            if is_blackhole():
+                dense_out_gathered = ttnn.all_gather(output_11SH, dim=1, num_links=1, topology=ttnn.Topology.Linear)
+            else:
+                dense_out_gathered = ttnn.experimental.all_gather_async(
+                    output_11SH,
+                    persistent_output_buffer=None,
+                    dim=1,
+                    multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(),
+                    num_links=1,
+                    topology=ttnn.Topology.Linear,
+                    barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(),
+                    chunks_per_sync=10,
+                    num_workers_per_link=2,
+                    num_buffers_per_channel=2,
+                )
             output_11SH.deallocate(True)
             dense_out_reduced = ttnn.experimental.fast_reduce_nc(
                 dense_out_gathered, dims=[1], output=None, compute_kernel_config=None
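
Both branches keep the same emulated all-reduce semantics: gather every device's partial dense output along dim 1, then reduce that dim back down with fast_reduce_nc. A torch sketch of the intended math (an interpretation of the calls above, assuming fast_reduce_nc performs a sum over dims=[1]):

import torch

def emulated_all_reduce(per_device_partials: list[torch.Tensor]) -> torch.Tensor:
    # Each entry is one device's partial output, shape [1, 1, S, H].
    # all_gather / all_gather_async with dim=1 stacks them into [1, D, S, H]...
    gathered = torch.cat(per_device_partials, dim=1)
    # ...and reducing over dim 1 recovers the summed [1, 1, S, H] result on every device.
    return gathered.sum(dim=1, keepdim=True)

The semaphore arguments on the async path come from the new TT_CCL helper, which is why this commit threads tt_ccl into the attention module and everything above it.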

models/experimental/mistral_24b/tt/vision_pixtral_image_block.py

Lines changed: 3 additions & 0 deletions
@@ -14,6 +14,7 @@ class TtPixtralImageTransformerBlock(LightweightModule):
     def __init__(
         self,
         mesh_device,
+        tt_ccl,
         state_dict,
         state_dict_prefix,
         weight_cache_path,
@@ -23,6 +24,7 @@ def __init__(
         super().__init__()
         self.state_dict = state_dict
         self.mesh_device = mesh_device
+        self.tt_ccl = tt_ccl
         self.configuration = configuration
         self.num_devices = configuration.num_devices
         self.hidden_size = configuration.vision_dim
@@ -40,6 +42,7 @@ def __init__(
 
         self.attention = TtLlamaImageAttention(
             mesh_device,
+            tt_ccl,
             state_dict,
             state_dict_prefix=f"{state_dict_prefix}attention.",
             weight_cache_path=weight_cache_path,
