from loguru import logger

import ttnn
- from models.tt_transformers.tt.common import PagedAttentionConfig, preprocess_inputs_prefill
+ from models.tt_transformers.tt.common import (
+     PagedAttentionConfig,
+     create_causal_mask,
+     create_sliding_window_causal_mask,
+     preprocess_inputs_prefill,
+ )
from models.tt_transformers.tt.model import Transformer
from models.tt_transformers.tt.model_config import DecodersPrecision, ModelArgs, parse_decoder_json
from models.tt_transformers.tt.rope import get_rot_mats
@@ -262,6 +267,32 @@ def test_tt_model_acc(
    pt_prefill_input[batch_id],
)

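+ # Build explicit attention masks for prefill only when the model config asks for them,
+ # presumably for models that mix sliding-window and global attention layers (inferred
+ # from the two mask variants constructed below).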
+ if model_args.attention_mask:
+     attn_mask = torch.ones(prefill_lens[0] + 1).unsqueeze(0)
+     cache_position = torch.arange(prefill_lens[0])
+     attention_mask = [
+         create_sliding_window_causal_mask(
+             prefill_input,
+             attn_mask,
+             cache_position,
+             model_args,
+             paged_attention_config,
+             device=mesh_device,
+             mode="prefill",
+         ),
+         create_causal_mask(
+             prefill_input,
+             attn_mask,
+             cache_position,
+             model_args,
+             paged_attention_config,
+             device=mesh_device,
+             mode="prefill",
+         ),
+     ]
+ else:
+     attention_mask = None
+
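+ # The masks (or None) are passed to the prefill call below through the new attention_masks kwarg.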
tt_out = tt_model(
    prefill_input,
    current_pos=None,
@@ -270,6 +301,7 @@ def test_tt_model_acc(
    user_id=batch_id,
    mode="prefill",
    page_table=page_table_tt,
+     attention_masks=attention_mask,
    get_last_token=((decoding_pos[batch_id] - 1) // 32) * 32,
)

@@ -322,13 +354,51 @@ def test_tt_model_acc(
    pt_decode_input,
    model_args.model_config["DECODE_RESIDUAL_MEMCFG"],
)
+ # Build decode attention masks when the model config requires explicit masks
+ if model_args.attention_mask:
+     torch_current_pos = ttnn.to_torch(current_pos_tensor)
+     cur_batch_size = torch_current_pos.size(0)
+     max_len = torch_current_pos.max().item() + 1  # longest seq length (+1 since pos starts at 0)
+
+     # Initialize with zeros, then mark each user's valid positions (0..current_pos) with ones
+     attn_mask = torch.zeros(cur_batch_size, max_len, dtype=torch.long)
+     for j, length in enumerate(torch_current_pos.tolist()):
+         attn_mask[j, : length + 1] = 1
+
+     torch_current_pos = torch.tensor([max_len - 1])
+
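+     # Two mask variants are built (sliding-window and global causal); the model is
+     # assumed to pick the appropriate one per attention layer.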
+     attention_mask = [
+         create_sliding_window_causal_mask(
+             decode_input,
+             attn_mask,
+             current_pos,
+             model_args,
+             paged_attention_config,
+             device=mesh_device,
+             mode="decode",
+         ),
+         create_causal_mask(
+             decode_input,
+             attn_mask,
+             current_pos,
+             model_args,
+             paged_attention_config,
+             device=mesh_device,
+             mode="decode",
+         ),
+     ]
+     attention_mask = [ttnn.to_device(v, device=mesh_device) for v in attention_mask]
+ else:
+     attention_mask = None
+
# Run TT model
tt_out = tt_model(
    decode_input,
    current_pos_tensor,
    rot_mats_global=rot_mats,
    rot_mats_local=rot_mats_local,
    mode="decode",
+     attention_masks=attention_mask,
    page_table=page_table_tt,
)
