Commit 0cf2700

Add Attention mask support in test_accuracy
Parent: aa5c10c

2 files changed: +73 -3 lines

models/tt_transformers/PERF.md

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@ This configuration uses bfp4 MLP and bfp8 attention weights for all models excep
 | Mistral-7B | N150 | 95 | 99 | 29.75 | 100.24 |
 | Mistral-7B | N300 | 95 | 99 | 47.01 | 65.95 |
 | Mistral-7B | T3K | 95 | 99 | 67.82 | 53.93 |
-| gemma-3-1b | N150 |32 |48 | 53.3 |59.9 |
+| gemma-3-1b | N150 |83 |96 | 53.3 |59.9 |
 | gemma-3-4b | N150 | 78 | 95 | 34 | 68 |
 | gemma-3-4b | N300 | 78 | 95 | 35 | 125 |
 | gemma-3-27b | T3K | 90 | 99 | 16 | 331 |
@@ -85,7 +85,7 @@ Llama 3 models test as insensitive to attention precision and so we use bfp8 att
 | Mistral-7B | N150 | 95 | 99 | 29.75 | 100.24 |
 | Mistral-7B | N300 | 95 | 99 | 47.01 | 65.95 |
 | Mistral-7B | T3K | 95 | 99 | 67.82 | 53.93 |
-| gemma-3-1b | N150 |32 |48 | 51.0 |62.02 |
+| gemma-3-1b | N150 | 93 | 99 | 51.0 |62.02 |
 | gemma-3-4b | N150 | 88 | 98 | 30 | 79 |
 | gemma-3-4b | N300 | 86 | 98 | 32 | 135 |
 | gemma-3-27b | T3K | 91 | 100 | 15 | 361 |

models/tt_transformers/tests/test_accuracy.py

Lines changed: 71 additions & 1 deletion
@@ -10,7 +10,12 @@
 from loguru import logger

 import ttnn
-from models.tt_transformers.tt.common import PagedAttentionConfig, preprocess_inputs_prefill
+from models.tt_transformers.tt.common import (
+    PagedAttentionConfig,
+    create_causal_mask,
+    create_sliding_window_causal_mask,
+    preprocess_inputs_prefill,
+)
 from models.tt_transformers.tt.model import Transformer
 from models.tt_transformers.tt.model_config import DecodersPrecision, ModelArgs, parse_decoder_json
 from models.tt_transformers.tt.rope import get_rot_mats
@@ -262,6 +267,32 @@ def test_tt_model_acc(
                 pt_prefill_input[batch_id],
             )

+            if model_args.attention_mask:
+                attn_mask = torch.ones(prefill_lens[0] + 1).unsqueeze(0)
+                cache_postion = torch.arange(prefill_lens[0])
+                attention_mask = [
+                    create_sliding_window_causal_mask(
+                        prefill_input,
+                        attn_mask,
+                        cache_postion,
+                        model_args,
+                        paged_attention_config,
+                        device=mesh_device,
+                        mode="prefill",
+                    ),
+                    create_causal_mask(
+                        prefill_input,
+                        attn_mask,
+                        cache_postion,
+                        model_args,
+                        paged_attention_config,
+                        device=mesh_device,
+                        mode="prefill",
+                    ),
+                ]
+            else:
+                attention_mask = None
+
             tt_out = tt_model(
                 prefill_input,
                 current_pos=None,
@@ -270,6 +301,7 @@ def test_tt_model_acc(
                 user_id=batch_id,
                 mode="prefill",
                 page_table=page_table_tt,
+                attention_masks=attention_mask,
                 get_last_token=((decoding_pos[batch_id] - 1) // 32) * 32,
             )
@@ -322,13 +354,51 @@ def test_tt_model_acc(
             pt_decode_input,
             model_args.model_config["DECODE_RESIDUAL_MEMCFG"],
         )
+        # Run TT model
+        if model_args.attention_mask:
+            torch_current_pos = ttnn.to_torch(current_pos_tensor)
+            cur_batch_size = torch_current_pos.size(0)
+            max_len = torch_current_pos.max().item() + 1  # longest seq length (+1 since pos starts at 0)
+
+            # Initialize with zeros
+            attn_mask = torch.zeros(cur_batch_size, max_len, dtype=torch.long)
+            for j, length in enumerate(torch_current_pos.tolist()):
+                attn_mask[j, : length + 1] = 1
+
+            torch_current_pos = torch.tensor([max_len - 1])
+
+            attention_mask = [
+                create_sliding_window_causal_mask(
+                    decode_input,
+                    attn_mask,
+                    current_pos,
+                    model_args,
+                    paged_attention_config,
+                    device=mesh_device,
+                    mode="decode",
+                ),
+                create_causal_mask(
+                    decode_input,
+                    attn_mask,
+                    current_pos,
+                    model_args,
+                    paged_attention_config,
+                    device=mesh_device,
+                    mode="decode",
+                ),
+            ]
+            attention_mask = [ttnn.to_device(v, device=mesh_device) for v in attention_mask]
+        else:
+            attention_mask = None
+
         # Run TT model
         tt_out = tt_model(
             decode_input,
             current_pos_tensor,
             rot_mats_global=rot_mats,
             rot_mats_local=rot_mats_local,
             mode="decode",
+            attention_masks=attention_mask,
             page_table=page_table_tt,
         )

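The core of the decode-path change above is a right-padded 0/1 mask built from each user's current position; the commit then derives both a sliding-window and a global causal mask from it, presumably so models that alternate local and global attention layers (such as the gemma-3 entries updated in PERF.md) can use the appropriate variant per layer. Below is a minimal standalone sketch of that padding-mask construction, using plain torch tensors in place of the test's on-device position tensor; the helper name is hypothetical, not part of the repo.

```python
# Minimal sketch of the decode-mode padding-mask construction used in this commit.
# Assumption: a plain torch tensor stands in for the ttnn current-position tensor,
# and `build_decode_padding_mask` is a hypothetical helper name for illustration.
import torch


def build_decode_padding_mask(current_pos: torch.Tensor) -> torch.Tensor:
    """current_pos: (batch,) 0-based positions -> (batch, max_len) 0/1 mask."""
    batch_size = current_pos.size(0)
    max_len = current_pos.max().item() + 1  # +1 because positions are 0-based
    attn_mask = torch.zeros(batch_size, max_len, dtype=torch.long)
    for j, length in enumerate(current_pos.tolist()):
        attn_mask[j, : length + 1] = 1  # mark attended tokens; padding stays 0
    return attn_mask


# Example: users at positions 4, 2 and 6 produce rows padded to length 7.
print(build_decode_padding_mask(torch.tensor([4, 2, 6])))
```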