 from models.common.lightweightmodule import LightweightModule
 from models.common.rmsnorm import RMSNorm
 from models.tt_transformers.tt.ccl import TT_CCL
-from models.tt_transformers.tt.common import copy_host_to_device, get_decode_mask
+from models.tt_transformers.tt.common import copy_host_to_device
 from models.tt_transformers.tt.decoder import TransformerBlock
 from models.tt_transformers.tt.distributed_norm import DistributedNorm
 from models.tt_transformers.tt.embedding import Embedding, ScaledEmbedding
@@ -30,10 +30,8 @@ def __init__(
         use_paged_kv_cache=False,
         attention_class=None,
         rope_setup_class=None,
-        attn_mask=None,
     ):
         super().__init__()
-        self.paged_attention_config = paged_attention_config
         self.args = args
         self.vocab_size = args.vocab_size
         assert self.vocab_size > 0
@@ -130,31 +128,6 @@ def __init__(
             max_columns_per_device=self.args.max_columns_per_device_lm_head,
         )

-        if hasattr(self.args, "sliding_window") and self.args.sliding_window is not None:
-            # We are using sliding window attention in this model. We can create a custom attention mask to apply the sliding attention
-            # First we create the mask for all decode positions on host [bsz, n_heads_per_device, seq_len, seq_len]
-            self.decode_sliding_mask_mat = get_decode_mask(
-                self.args,
-                self.mesh_device,
-                paged_attention_config=paged_attention_config,
-            )
-            # Then we copy a slice for a single decode position for each user on to device [bsz, n_heads_per_device, 1, seq_len]
-            # We can update this tensor on host each iteration and copy to device to save storing the large square tensor on device
-            self.device_decode_sliding_mask = ttnn.as_tensor(
-                torch.concat(
-                    [self.decode_sliding_mask_mat[i, :, 0:1, :].unsqueeze(0) for i in range(self.args.max_batch_size)],
-                    axis=0,
-                ).transpose(1, 2),
-                dtype=ttnn.bfloat4_b,
-                layout=ttnn.TILE_LAYOUT,
-                device=self.mesh_device,
-                memory_config=ttnn.DRAM_MEMORY_CONFIG,
-                mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
-            )
-        else:
-            self.decode_sliding_mask_mat = None
-            self.device_decode_sliding_mask = None
-
     def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None):
         """
         Inputs are torch tensors or python types. This function returns ttnn
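The removed __init__ block above precomputed a sliding-window mask for every decode position on host and kept only a one-position slice on device, refreshing it each iteration. Below is a rough torch-only sketch of that idea; build_decode_sliding_mask and the example sizes are illustrative stand-ins for get_decode_mask, and the ttnn upload (ttnn.as_tensor with ReplicateTensorToMesh) is omitted.

import torch

def build_decode_sliding_mask(bsz, n_heads_per_device, seq_len, sliding_window):
    # Additive mask: 0.0 where a query may attend to a key, -inf otherwise.
    q = torch.arange(seq_len).view(-1, 1)  # query (decode) positions
    k = torch.arange(seq_len).view(1, -1)  # key positions
    allowed = (k <= q) & (k > q - sliding_window)  # causal AND inside the window
    mask = torch.zeros(seq_len, seq_len)
    mask[~allowed] = -float("inf")
    return mask.expand(bsz, n_heads_per_device, seq_len, seq_len).contiguous()

mask_mat = build_decode_sliding_mask(bsz=2, n_heads_per_device=8, seq_len=128, sliding_window=64)
# Per-user slice at decode position 0, then transposed to
# [bsz, 1, n_heads_per_device, seq_len] as in the removed __init__ code.
single_pos = torch.concat(
    [mask_mat[i, :, 0:1, :].unsqueeze(0) for i in range(mask_mat.shape[0])], axis=0
).transpose(1, 2)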
@@ -214,38 +187,13 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag
             )
         else:
             tt_chunk_page_table = None
-        if self.args.attention_mask:
-            attn_mask = torch.ones(S + 1).unsqueeze(0)
-            cache_postion = torch.arange(S)
-            attention_mask = [
-                create_sliding_window_causal_mask(
-                    tokens_embd,
-                    attn_mask,
-                    cache_postion,
-                    self.args,
-                    self.paged_attention_config,
-                    device=self.mesh_device,
-                    mode="prefill",
-                ),
-                create_causal_mask(
-                    tokens_embd,
-                    attn_mask,
-                    cache_postion,
-                    self.args,
-                    self.paged_attention_config,
-                    device=self.mesh_device,
-                    mode="prefill",
-                ),
-            ]
-        else:
-            attention_mask = None
+
         return (
             tokens_embd,
             tt_rot_mats_prefill_global,
             tt_rot_mats_prefill_local,
             tt_page_table,
             tt_chunk_page_table,
-            attention_mask,
         )

     def prepare_inputs_decode(self, *inputs):
@@ -309,41 +257,7 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None):
                 mesh_shape=self.args.cluster_shape,
             ),
         )
-        if self.args.attention_mask:
-            batch_size = current_pos.size(0)
-            max_len = current_pos.max().item() + 1  # longest seq length (+1 since pos starts at 0)
-
-            # Initialize with zeros
-            attn_mask = torch.zeros(batch_size, max_len, dtype=torch.long)
-            for i, length in enumerate(current_pos.tolist()):
-                attn_mask[i, : length + 1] = 1
-
-            current_pos = torch.tensor([max_len - 1])
-
-            attention_mask = [
-                create_sliding_window_causal_mask(
-                    tokens,
-                    attn_mask,
-                    current_pos,
-                    self.args,
-                    self.paged_attention_config,
-                    device=self.mesh_device,
-                    mode="decode",
-                ),
-                create_causal_mask(
-                    tokens,
-                    attn_mask,
-                    current_pos,
-                    self.args,
-                    self.paged_attention_config,
-                    device=self.mesh_device,
-                    mode="decode",
-                ),
-            ]
-        else:
-            attention_mask = None
-
-        return tokens, current_pos_tt, rope_idxs_global, rope_idxs_local, page_table, attention_mask
+        return tokens, current_pos_tt, rope_idxs_global, rope_idxs_local, page_table

     def _transform_decode_inputs_device(self, tokens):
         """
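Before this change, prepare_decode_inputs_host also built a per-user 0/1 padding mask from current_pos, from which the causal and sliding-window masks were then derived. A minimal host-side sketch of just that step, with example positions:

import torch

current_pos = torch.tensor([3, 7, 0, 5])  # example decode position per user
max_len = current_pos.max().item() + 1    # longest sequence (+1 since positions are 0-based)
attn_mask = torch.zeros(current_pos.size(0), max_len, dtype=torch.long)
for i, length in enumerate(current_pos.tolist()):
    attn_mask[i, : length + 1] = 1        # 1 = valid token, 0 = padding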
@@ -414,17 +328,6 @@ def ttnn_prefill_forward(
         This method will take device tensors and any other args to run forward.
         It returns ttnn device tensors.
         """
-        if hasattr(self.args, "sliding_window") and self.args.sliding_window is not None:
-            mask = torch.triu(torch.full((1, 1, x.shape[-2], x.shape[-2]), -float("inf")), diagonal=1)
-            sliding_mask = mask + torch.tril(
-                torch.full((1, 1, x.shape[-2], x.shape[-2]), -float("inf")),
-                diagonal=-self.args.sliding_window,
-            )
-            sliding_attn_mask = ttnn.from_torch(
-                sliding_mask, device=self.mesh_device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16
-            )
-        else:
-            sliding_attn_mask = None
         return self.forward(
             x,
             current_pos=None,
@@ -437,7 +340,6 @@ def ttnn_prefill_forward(
             chunk_start_idx=chunk_start_idx,
             get_last_token=get_last_token,
             kv_cache=kv_cache,
-            sliding_attn_mask=sliding_attn_mask,
         )

     def _increment_decode_positions_device(self, current_pos, rot_mat_idxs_global, rot_mat_idxs_local):
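For prefill, the removed ttnn_prefill_forward branch built the sliding-window mask as a causal triu of -inf plus a tril cutoff at -sliding_window, then wrapped it with ttnn.from_torch. A small self-contained check of that construction (S and window are arbitrary example values):

import torch

S, window = 5, 2
mask = torch.triu(torch.full((1, 1, S, S), -float("inf")), diagonal=1)
sliding_mask = mask + torch.tril(torch.full((1, 1, S, S), -float("inf")), diagonal=-window)
allowed = sliding_mask[0, 0] == 0
for i in range(S):
    # Row i may attend to the last `window` positions up to and including itself.
    assert allowed[i].nonzero().flatten().tolist() == list(range(max(0, i - window + 1), i + 1))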
@@ -456,24 +358,6 @@ def _increment_decode_positions_device(self, current_pos, rot_mat_idxs_global, r
         if rot_mat_idxs_local is not None:
             ttnn.plus_one(rot_mat_idxs_local)

-    def update_attention_masks(self, current_pos):
-        torch_mask = torch.concat(
-            [
-                self.decode_sliding_mask_mat[i, :, current_pos[i].item() : current_pos[i].item() + 1, :].unsqueeze(0)
-                for i in range(self.decode_sliding_mask_mat.shape[0])
-            ],
-            axis=0,
-        ).transpose(1, 2)
-        sliding_window_causal_mask = ttnn.as_tensor(
-            torch_mask,
-            dtype=ttnn.bfloat4_b,
-            layout=ttnn.TILE_LAYOUT,
-            device=None,
-            memory_config=ttnn.DRAM_MEMORY_CONFIG,
-            mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
-        )
-        ttnn.copy_host_to_device_tensor(sliding_window_causal_mask, self.device_decode_sliding_mask)
-
     def ttnn_decode_forward(
         self,
         x,
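Each decode step, the removed update_attention_masks re-sliced the precomputed square mask on host, one row per user at that user's current position, and copied it onto the preallocated device tensor. A host-only sketch of the slicing, with a zero placeholder standing in for the real mask and the ttnn.copy_host_to_device_tensor step omitted:

import torch

bsz, n_heads_per_device, seq_len = 4, 8, 32
decode_sliding_mask_mat = torch.zeros(bsz, n_heads_per_device, seq_len, seq_len)  # placeholder contents
current_pos = torch.tensor([5, 0, 17, 9])  # example per-user decode positions
torch_mask = torch.concat(
    [
        decode_sliding_mask_mat[i, :, current_pos[i].item() : current_pos[i].item() + 1, :].unsqueeze(0)
        for i in range(decode_sliding_mask_mat.shape[0])
    ],
    axis=0,
).transpose(1, 2)  # shape [bsz, 1, n_heads_per_device, seq_len]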
@@ -502,7 +386,6 @@ def ttnn_decode_forward(
             mode="decode",
             page_table=page_table,
             kv_cache=kv_cache,
-            sliding_attn_mask=self.device_decode_sliding_mask,
         )

         # Gather the output across all devices and untilize the tensor (for argmax)
@@ -553,7 +436,6 @@ def forward(
         chunk_start_idx=None,
         get_last_token=-1,
         kv_cache=None,
-        sliding_attn_mask=None,
     ):
         for i, layer in enumerate(self.layers):
             # No-op if callers already provide the right memory config
@@ -565,14 +447,6 @@ def forward(
             elif activation_dtype is not None and x.dtype != activation_dtype:
                 x = ttnn.typecast(x, activation_dtype)

-            if sliding_attn_mask is not None:
-                attn_mask_i = (
-                    sliding_attn_mask
-                    if (hasattr(layer.attention, "is_sliding") and layer.attention.is_sliding)
-                    else None
-                )
-            else:
-                attn_mask_i = None
             x = layer(
                 x,
                 current_pos,
@@ -584,7 +458,6 @@ def forward(
                 chunk_page_table=chunk_page_table,
                 chunk_start_idx=chunk_start_idx,
                 kv_cache=kv_cache[i] if kv_cache is not None else None,
-                attn_mask=attn_mask_i,
             )

         if mode == "prefill" and get_last_token == -1:
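Inside forward, the removed per-layer branch routed the sliding mask only to layers whose attention module is flagged is_sliding. A compact sketch of the same selection, with layer as an illustrative stand-in object:

def pick_attn_mask(layer, sliding_attn_mask):
    # Sliding-window layers receive the sliding mask; full-attention layers receive None.
    if sliding_attn_mask is None:
        return None
    return sliding_attn_mask if getattr(layer.attention, "is_sliding", False) else None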