
Commit f85bffd

jennychristophernikileshx authored and committed

Handle state_dict prefix

Remove submodule comprehensive debugging tests. Eliminate mask and generate_mask function usage.

1 parent 0c1d2d8 commit f85bffd
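
The "Handle state_dict prefix" part of this change is not visible in the hunks below. As a rough illustration of the pattern such a fix usually involves, here is a minimal, hypothetical sketch of stripping a checkpoint prefix before loading weights; the PREFIX value and the reference_model name are assumptions, not taken from this commit.

    import torch

    # Hypothetical sketch: checkpoints are sometimes saved with a wrapper
    # prefix (e.g. "model.") on every key, which load_state_dict() will
    # reject unless the keys are normalized first.
    PREFIX = "model."  # assumed prefix; the actual one is not shown in this commit

    def strip_prefix(state_dict: dict, prefix: str = PREFIX) -> dict:
        """Return a copy of state_dict with `prefix` removed from matching keys."""
        return {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

    # Usage: load a prefixed checkpoint into an unprefixed module.
    # checkpoint = torch.load("weights.pt", map_location="cpu")
    # reference_model.load_state_dict(strip_prefix(checkpoint))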

19 files changed: +25 −7,700 lines

comprehensive_submodule_tests/comprehensive_attention_norm_test.py

Lines changed: 0 additions & 957 deletions
This file was deleted.

comprehensive_submodule_tests/comprehensive_conv2d_test.py

Lines changed: 0 additions & 841 deletions
This file was deleted.

comprehensive_submodule_tests/comprehensive_ffn_norm_test.py

Lines changed: 0 additions & 949 deletions
This file was deleted.

comprehensive_submodule_tests/comprehensive_ln_pre_rmsnorm_test.py

Lines changed: 0 additions & 811 deletions
This file was deleted.

comprehensive_submodule_tests/comprehensive_pixtral_attention_layer_block_test.py

Lines changed: 0 additions & 1114 deletions
This file was deleted.

comprehensive_submodule_tests/comprehensive_pixtral_attention_test.py

Lines changed: 0 additions & 1082 deletions
This file was deleted.

comprehensive_submodule_tests/comprehensive_pixtral_mlp_test.py

Lines changed: 0 additions & 952 deletions
This file was deleted.

comprehensive_submodule_tests/comprehensive_pixtral_transformer_test.py

Lines changed: 0 additions & 867 deletions
This file was deleted.

models/experimental/mistral_24b/tests/test_pixtral_transformer.py

Lines changed: 1 addition & 9 deletions
@@ -92,17 +92,9 @@ def test_image_transformer_inference(batch, num_chunks, mesh_device):
         pt_attention_input,
         force_replicated=True,
     )
-    tt_mask = ttnn.from_torch(
-        attention_mask,
-        device=mesh_device,
-        dtype=ttnn.bfloat16,
-        layout=ttnn.TILE_LAYOUT,
-        memory_config=ttnn.DRAM_MEMORY_CONFIG,
-        mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
-    )
 
     with torch.no_grad():
-        tt_out = tt_model(attention_input, mask=tt_mask, position_embeddings=(cos_t, sin_t))
+        tt_out = tt_model(attention_input, position_embeddings=(cos_t, sin_t))
     reference_output = reference_model(
         pt_attention_input, attention_mask=attention_mask, position_embeddings=(cos, sin)
     )[0]

models/experimental/mistral_24b/tests/test_vision_attention.py

Lines changed: 1 addition & 10 deletions
@@ -105,16 +105,7 @@ def test_vision_attention(mesh_device, seq_len, batch_size):
         mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
     )
 
-    tt_mask = ttnn.from_torch(
-        attention_mask,
-        device=mesh_device,
-        dtype=ttnn.bfloat8_b,
-        layout=ttnn.TILE_LAYOUT,
-        memory_config=ttnn.DRAM_MEMORY_CONFIG,
-        mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
-    )
-
-    tt_out = tt_model(attention_input, position_embeddings=(cos_t, sin_t), mask=tt_mask)
+    tt_out = tt_model(attention_input, position_embeddings=(cos_t, sin_t))
     tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=-1))[
         :, :, :, : tt_out.shape[-1]
     ]
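
Both tests ultimately compare tt_output_torch against reference_output; the assertion itself is outside these hunks. A minimal, self-contained sketch of the Pearson-correlation (PCC) check such tests usually end with (the pcc helper and the 0.99 threshold are assumptions, not taken from this commit):

    import torch

    def pcc(golden: torch.Tensor, actual: torch.Tensor) -> float:
        """Pearson correlation coefficient between two flattened tensors."""
        g = golden.flatten().to(torch.float32)
        a = actual.flatten().to(torch.float32)
        g = g - g.mean()
        a = a - a.mean()
        return float((g @ a) / (g.norm() * a.norm()))

    # Usage (assumed threshold):
    # assert pcc(reference_output, tt_output_torch) > 0.99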
