
Commit 4e6b224

Address comments and refactor comments
1 parent 733e9b3 commit 4e6b224

13 files changed (+53, -99 lines)


models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ def test_mistral_vision_model(mesh_device, reset_seeds):
         model_args=model_args,
     )

-    tt_output = vision_model(input_tensor, image_sizes=[(H, W)])  # [0]
+    tt_output = vision_model(input_tensor, image_sizes=[(H, W)])
     tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[
         :, : tt_output.shape[-1]
     ]
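For readers unfamiliar with the mesh composer pattern: `ConcatMeshToTensor(mesh_device, dim=-1)` concatenates every device's copy of the output along the last dimension, so when the output is replicated across the mesh, the slice `[:, : tt_output.shape[-1]]` (the shape is read from the pre-conversion ttnn tensor) trims the result back to a single copy. A minimal torch-only sketch of that idea, with made-up shapes:

import torch

# Pretend each of two mesh devices holds an identical [B, D] replica.
num_devices = 2
replica = torch.randn(4, 8)

# Concatenating along the last dim yields [B, num_devices * D] ...
gathered = torch.cat([replica] * num_devices, dim=-1)

# ... and slicing back to the original width recovers one replica.
recovered = gathered[:, : replica.shape[-1]]
assert torch.equal(recovered, replica)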

models/experimental/mistral_24b/tests/test_vision_mlp.py

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@

 import ttnn

-# from models.tt_transformers.tt.mlp import MLP
 from models.experimental.mistral_24b.tt.vision_mlp import MistralTTVisionMLP as MLP
 from models.tt_transformers.tt.model_config import ModelArgs
 from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull

models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@

 """
 This file implements the Vision Tower submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model.
+This pipeline constructs the vision tower from the vision model architecture.
 """

 import ttnn
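For orientation, a rough structural sketch of what a Pixtral-style vision tower composes; the module names below are illustrative assumptions, not this file's actual API:

# Hypothetical sketch of the stages a vision tower chains together.
class VisionTowerSketch:
    def __init__(self, patch_conv, ln_pre, transformer):
        self.patch_conv = patch_conv    # image -> patch embeddings
        self.ln_pre = ln_pre            # norm before the block stack
        self.transformer = transformer  # stack of pixtral image blocks

    def __call__(self, pixel_values, position_embeddings):
        x = self.patch_conv(pixel_values)
        x = self.ln_pre(x)
        return self.transformer(x, position_embeddings=position_embeddings)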

models/experimental/mistral_24b/tt/rmsnorm.py

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,12 @@
 # SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

 # SPDX-License-Identifier: Apache-2.0
+
+"""
+This is a modified version of RMSNorm for the Mistral-Small-3.1-24B-Instruct-2503 model.
+We introduced the `simplified_rms_norm` function for compatibility with this model.
+"""
+
 import ttnn
 from models.common.lightweightmodule import LightweightModule
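For reference, the math a simplified RMSNorm computes, as a plain-PyTorch sketch (assuming the standard RMSNorm formula; the ttnn version additionally handles tilization, sharding, and program configs):

import torch

def simplified_rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Scale by the reciprocal root-mean-square over the hidden dimension.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight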

models/experimental/mistral_24b/tt/vision_attention.py

Lines changed: 5 additions & 4 deletions
@@ -2,12 +2,13 @@

 # SPDX-License-Identifier: Apache-2.0
 """
-This file implements the vision attention submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model.
-
+This is a modified version of vision attention for the Mistral-Small-3.1-24B-Instruct-2503 model.
+We introduced the `apply_rotary_pos_emb_vision_tt` function in llama_image_attention for compatibility with this model.
 """
-import torch

+import torch
 import ttnn
+
 from models.common.lightweightmodule import LightweightModule
 from models.utility_functions import is_blackhole, nearest_32

@@ -162,7 +163,7 @@ def pad_head_dim(weight, heads_out=True):
     def forward(self, x_11SH, position_embeddings=None):
         seq_len = x_11SH.shape[-2]

-        MAX_MM_SEQ_LEN = self.configuration.VISION_MAX_MM_SEQ
+        MAX_MM_SEQ_LEN = seq_len

         if seq_len > MAX_MM_SEQ_LEN:
             x_11SH = ttnn.reshape(x_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1])
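Note that with `MAX_MM_SEQ_LEN = seq_len`, the `seq_len > MAX_MM_SEQ_LEN` branch can no longer trigger, so the reshape is effectively disabled. As for `apply_rotary_pos_emb_vision_tt`, here is a plain-PyTorch sketch of the standard rotary formula such a helper presumably mirrors (an assumption, not the ttnn code):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap and negate the two halves of the head dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_vision_ref(q, k, cos, sin):
    # Standard RoPE: mix the original and half-rotated tensors elementwise.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin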

models/experimental/mistral_24b/tt/vision_conv2d.py

Lines changed: 7 additions & 2 deletions
@@ -2,9 +2,14 @@

 # SPDX-License-Identifier: Apache-2.0

-import torch
+"""
+This is a modified version of the vision patch conv2d for the Mistral-Small-3.1-24B-Instruct-2503 model.
+We modified the llama_patch_conv2d to be compatible with this model.
+"""

+import torch
 import ttnn
+
 from models.common.lightweightmodule import LightweightModule

@@ -57,7 +62,7 @@ def __init__(

         self._unfold = torch.nn.Unfold(kernel_size=self.kernel_size, stride=self.stride)

-        weight = state_dict[f"{state_dict_prefix}weight"]
+        weight = state_dict[f"{state_dict_prefix}_linear.weight"]
         if weight.ndim == 4:
             weight = weight.reshape(out_channels, -1).T
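The `weight.reshape(out_channels, -1).T` context line points at the unfold-as-matmul trick this module is built on: with stride equal to kernel size, `Unfold` followed by a matmul is equivalent to a strided `Conv2d`. A plain-PyTorch sketch:

import torch

def patch_conv_ref(pixels: torch.Tensor, weight: torch.Tensor, patch: int) -> torch.Tensor:
    # pixels: [B, C, H, W]; weight: [out_ch, C, patch, patch] (a conv kernel)
    unfold = torch.nn.Unfold(kernel_size=patch, stride=patch)
    cols = unfold(pixels)                    # [B, C * patch * patch, n_patches]
    w = weight.reshape(weight.shape[0], -1)  # [out_ch, C * patch * patch]
    return w @ cols                          # [B, out_ch, n_patches]

For bias-free kernels this matches torch.nn.functional.conv2d(pixels, weight, stride=patch).flatten(2).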

models/experimental/mistral_24b/tt/vision_mlp.py

Lines changed: 4 additions & 3 deletions
@@ -1,13 +1,15 @@
 # SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.

 # SPDX-License-Identifier: Apache-2.0
+
 """
+This is a modified version of the FeedForward for the Mistral-Small-3.1-24B-Instruct-2503 model.
 This file implements the Vision FeedForward submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model.
-
 """
-import torch

+import torch
 import ttnn
+
 from models.common.lightweightmodule import LightweightModule

@@ -48,7 +50,6 @@ def as_tensor(name, dtype, is_bias=False):
             mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
             layout=ttnn.TILE_LAYOUT,
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
-            # cache_file_name=cache_name(name),
         )

         # Weights and Biases
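For reference, a Mistral-style FeedForward is a gated (SwiGLU-style) MLP; a plain-PyTorch sketch of the math, where the `w1`/`w3`/`w2` gate/up/down naming is an assumption borrowed from the usual convention:

import torch
import torch.nn.functional as F

def feed_forward_ref(x, w1, w2, w3):
    # gate: w1, up: w3, down: w2 -- silu-gated product, then project down.
    return (F.silu(x @ w1.T) * (x @ w3.T)) @ w2.T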

models/experimental/mistral_24b/tt/vision_mmp.py

Lines changed: 4 additions & 5 deletions
@@ -1,17 +1,16 @@
 # SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
 # SPDX-License-Identifier: Apache-2.0

+"""
+This file implements the Vision MultiModalProjector submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model.
+"""

 import torch
 from models.common.lightweightmodule import LightweightModule
 from models.experimental.mistral_24b.tt.rmsnorm import RMSNorm
 import ttnn
 from ttnn import ConcatMeshToTensor

-"""
-This file implements the Vision pixtral image submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model.
-"""
-

 class TTMistral3PatchMerger(LightweightModule):
     def __init__(

@@ -26,7 +25,7 @@ def __init__(
         super().__init__()
         self.device = mesh_device
         hidden_size = args.vision_dim
-        self.spatial_merge_size = 2  # TODO Handle in Model_config spatial_merge_size
+        self.spatial_merge_size = 2
         self.patch_size = args.vision_patch_size
         self.args = args
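For intuition on `spatial_merge_size = 2`: the patch merger folds each 2x2 window of patch embeddings into one token with 4x the channels before the merging projection, so the patch count drops by a factor of four. A plain-PyTorch sketch of that regrouping (the actual module may order channels differently):

import torch

def spatial_merge_ref(patches: torch.Tensor, grid_h: int, grid_w: int, merge: int = 2) -> torch.Tensor:
    # patches: [grid_h * grid_w, hidden] -> [(grid_h * grid_w) // merge**2, merge**2 * hidden]
    hidden = patches.shape[-1]
    x = patches.reshape(grid_h, grid_w, hidden)
    x = x.reshape(grid_h // merge, merge, grid_w // merge, merge, hidden)
    x = x.permute(0, 2, 1, 3, 4)  # group each merge x merge window together
    return x.reshape(-1, merge * merge * hidden)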

models/experimental/mistral_24b/tt/vision_pixtral_image_block.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 from models.experimental.mistral_24b.tt.vision_mlp import MistralTTVisionMLP as MLP

 """
-This file implements the Pixtral_image_block submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model.
+This file implements the pixtral image block specific for the Mistral-Small-3.1-24B-Instruct-2503 model.
 """
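For orientation, a pixtral image block follows the usual pre-norm residual layout; a compact sketch with assumed module names, not this file's exact signatures:

def image_block_ref(x, attention_norm, attention, ffn_norm, mlp, position_embeddings=None):
    # Pre-norm residual: norm -> attention -> add, then norm -> MLP -> add.
    x = x + attention(attention_norm(x), position_embeddings=position_embeddings)
    return x + mlp(ffn_norm(x))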

models/experimental/mistral_24b/tt/vision_pixtral_transformer.py

Lines changed: 5 additions & 0 deletions
@@ -2,6 +2,11 @@

 # SPDX-License-Identifier: Apache-2.0

+"""
+This file implements the Vision Transformer submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model.
+This pipeline iterates over the pixtral image blocks to generate the image embeddings.
+"""
+
 from tqdm import tqdm

 from models.common.lightweightmodule import LightweightModule
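A minimal sketch of the loop such a pipeline runs, assuming block signatures like those above; `tqdm` just wraps the iteration with a progress bar:

from tqdm import tqdm

def transformer_forward_ref(blocks, x, position_embeddings=None):
    # Apply each pixtral image block in sequence to build the image embeddings.
    for block in tqdm(blocks, desc="pixtral blocks"):
        x = block(x, position_embeddings=position_embeddings)
    return x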
