
Commit d40f3f7

Add support for sliding window mask in TT-Transformers
1 parent e888971 commit d40f3f7

File tree

4 files changed: +8 -146 lines changed


models/tt_transformers/tt/attention.py

Lines changed: 4 additions & 13 deletions
@@ -30,6 +30,8 @@ def __init__(
         super().__init__()
 
         self.mesh_device = mesh_device
+        self.layer_idx = layer_num
+        self.configuration = configuration
         self.tt_ccl = tt_ccl
         self.num_devices = configuration.num_devices
         self.TG = self.num_devices == 32
@@ -388,7 +390,6 @@ def forward_decode(
         rot_mats=None,
         page_table=None,
         kv_cache=None,
-        attn_mask=None,
     ) -> ttnn.Tensor:
         """
         x: (seq_len, 1, batch, dim)
@@ -513,12 +514,11 @@ def forward_decode(
                 values,
                 cur_pos_tensor=current_pos,
                 page_table_tensor=page_table,
-                attn_mask=attn_mask,
-                is_causal=True if attn_mask is None else False,
                 scale=self.scale,
                 program_config=self.model_config["SDPA_DECODE_PROGCFG"],
                 compute_kernel_config=self.sdpa_decode_compute_kernel_cfg,
                 memory_config=ttnn.DRAM_MEMORY_CONFIG,
+                sliding_window=self.configuration.sliding_window if self.is_sliding else 0,
             )
         else:
             attn_output_1G4D = ttnn.transformer.scaled_dot_product_attention_decode(
@@ -527,11 +527,10 @@
                 values,
                 cur_pos_tensor=current_pos,
                 scale=self.scale,
-                is_causal=True if attn_mask is None else False,
-                attn_mask=attn_mask,
                 program_config=self.model_config["SDPA_DECODE_PROGCFG"],
                 compute_kernel_config=self.sdpa_decode_compute_kernel_cfg,
                 memory_config=ttnn.DRAM_MEMORY_CONFIG,  # FIXME: why not L1 height sharded e.g. SCORES_BATCHED_MM_OUTPUT_MEMCFG?
+                sliding_window=self.configuration.sliding_window if self.is_sliding else 0,
             )
 
         ttnn.deallocate(q_heads_1BQD)
@@ -676,7 +675,6 @@ def forward_prefill(
         chunk_page_table=None,
         chunk_start_idx=None,
         kv_cache=None,
-        attn_mask=None,
     ):
         seq_len = x_11SH.shape[-2]
         assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128"
@@ -831,8 +829,6 @@
                 values_BKSD,
                 page_table,
                 chunk_start_idx,
-                attn_mask=attn_mask,
-                is_causal=True if attn_mask is None else False,
                 compute_kernel_config=self.sdpa_prefill_compute_kernel_cfg,
                 program_config=self.model_config["SDPA_PROGCFG"](seq_len),
             )
@@ -844,8 +840,6 @@
                 scale=self.scale,
                 compute_kernel_config=self.sdpa_prefill_compute_kernel_cfg,
                 program_config=self.model_config["SDPA_PROGCFG"](seq_len),
-                attn_mask=attn_mask,
-                is_causal=True if attn_mask is None else False,
             )
 
             # deallocate keys and values
@@ -924,7 +918,6 @@ def forward(
         chunk_page_table=None,
         chunk_start_idx=None,
         kv_cache=None,
-        attn_mask=None,
     ):
         if mode == "prefill":
             return self.forward_prefill(
@@ -935,7 +928,6 @@
                 chunk_page_table=chunk_page_table,
                 chunk_start_idx=chunk_start_idx,
                 kv_cache=kv_cache,
-                attn_mask=attn_mask,
             )
         else:
             return self.forward_decode(
@@ -944,7 +936,6 @@
                 rot_mats,
                 page_table=page_table,
                 kv_cache=kv_cache,
-                attn_mask=attn_mask,
             )
 
     def prefill_prepare_tensor_for_kv_cache(self, key_or_value_layer, user_id):
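
Note: as a reference for what the new sliding_window argument is expected to apply inside the kernel, here is a minimal pure-PyTorch sketch of sliding-window causal attention. The band mask mirrors the triu/tril construction deleted from ttnn_prefill_forward in model.py below; this is an illustration of the intended semantics, not the ttnn kernel implementation, and a window of 0 is treated as plain causal attention.

import torch
import torch.nn.functional as F


def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # Additive mask: -inf above the diagonal (no future tokens) and -inf for keys
    # more than `window - 1` positions behind the query, so each query attends to
    # at most `window` keys ending at its own position.
    mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    mask += torch.tril(torch.full((seq_len, seq_len), float("-inf")), diagonal=-window)
    return mask


def reference_sdpa(q, k, v, window: int) -> torch.Tensor:
    # window == 0 falls back to ordinary causal attention, matching the
    # `sliding_window=... if self.is_sliding else 0` call sites above.
    if window > 0:
        return F.scaled_dot_product_attention(q, k, v, attn_mask=sliding_window_causal_mask(q.shape[-2], window))
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)


q = k = v = torch.randn(1, 8, 32, 64)  # (batch, n_heads, seq_len, head_dim)
print(reference_sdpa(q, k, v, window=8).shape)  # torch.Size([1, 8, 32, 64])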

models/tt_transformers/tt/decoder.py

Lines changed: 0 additions & 2 deletions
@@ -202,7 +202,6 @@ def forward(
         chunk_page_table=None,
         chunk_start_idx=None,
         kv_cache=None,
-        attn_mask=None,
     ) -> ttnn.Tensor:
         TG = self.args.is_galaxy
         residual = x
@@ -231,7 +230,6 @@
             chunk_page_table=chunk_page_table,
             chunk_start_idx=chunk_start_idx,
             kv_cache=kv_cache,
-            attn_mask=attn_mask,
         )
         if self.pre_ff_norm == None:
             # Here x and attn_out are both fractured across devices

models/tt_transformers/tt/model.py

Lines changed: 3 additions & 130 deletions
@@ -9,7 +9,7 @@
 from models.common.lightweightmodule import LightweightModule
 from models.common.rmsnorm import RMSNorm
 from models.tt_transformers.tt.ccl import TT_CCL
-from models.tt_transformers.tt.common import copy_host_to_device, get_decode_mask
+from models.tt_transformers.tt.common import copy_host_to_device
 from models.tt_transformers.tt.decoder import TransformerBlock
 from models.tt_transformers.tt.distributed_norm import DistributedNorm
 from models.tt_transformers.tt.embedding import Embedding, ScaledEmbedding
@@ -30,10 +30,8 @@ def __init__(
         use_paged_kv_cache=False,
         attention_class=None,
         rope_setup_class=None,
-        attn_mask=None,
     ):
         super().__init__()
-        self.paged_attention_config = paged_attention_config
         self.args = args
         self.vocab_size = args.vocab_size
         assert self.vocab_size > 0
@@ -130,31 +128,6 @@ def __init__(
             max_columns_per_device=self.args.max_columns_per_device_lm_head,
         )
 
-        if hasattr(self.args, "sliding_window") and self.args.sliding_window is not None:
-            # We are using sliding window attention in this model. We can create a custom attention mask to apply the sliding attention
-            # First we create the mask for all decode positions on host [bsz, n_heads_per_device, seq_len, seq_len]
-            self.decode_sliding_mask_mat = get_decode_mask(
-                self.args,
-                self.mesh_device,
-                paged_attention_config=paged_attention_config,
-            )
-            # Then we copy a slice for a single decode position for each user on to device [bsz, n_heads_per_device, 1, seq_len]
-            # We can update this tensor on host each iteration and copy to device to save storing the large square tensor on device
-            self.device_decode_sliding_mask = ttnn.as_tensor(
-                torch.concat(
-                    [self.decode_sliding_mask_mat[i, :, 0:1, :].unsqueeze(0) for i in range(self.args.max_batch_size)],
-                    axis=0,
-                ).transpose(1, 2),
-                dtype=ttnn.bfloat4_b,
-                layout=ttnn.TILE_LAYOUT,
-                device=self.mesh_device,
-                memory_config=ttnn.DRAM_MEMORY_CONFIG,
-                mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
-            )
-        else:
-            self.decode_sliding_mask_mat = None
-            self.device_decode_sliding_mask = None
-
     def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None):
         """
         Inputs are torch tensors or python types. This function returns ttnn
@@ -214,38 +187,13 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag
             )
         else:
             tt_chunk_page_table = None
-        if self.args.attention_mask:
-            attn_mask = torch.ones(S + 1).unsqueeze(0)
-            cache_postion = torch.arange(S)
-            attention_mask = [
-                create_sliding_window_causal_mask(
-                    tokens_embd,
-                    attn_mask,
-                    cache_postion,
-                    self.args,
-                    self.paged_attention_config,
-                    device=self.mesh_device,
-                    mode="prefill",
-                ),
-                create_causal_mask(
-                    tokens_embd,
-                    attn_mask,
-                    cache_postion,
-                    self.args,
-                    self.paged_attention_config,
-                    device=self.mesh_device,
-                    mode="prefill",
-                ),
-            ]
-        else:
-            attention_mask = None
+
         return (
             tokens_embd,
             tt_rot_mats_prefill_global,
             tt_rot_mats_prefill_local,
             tt_page_table,
             tt_chunk_page_table,
-            attention_mask,
         )
 
     def prepare_inputs_decode(self, *inputs):
@@ -309,41 +257,7 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None):
                 mesh_shape=self.args.cluster_shape,
             ),
         )
-        if self.args.attention_mask:
-            batch_size = current_pos.size(0)
-            max_len = current_pos.max().item() + 1  # longest seq length (+1 since pos starts at 0)
-
-            # Initialize with zeros
-            attn_mask = torch.zeros(batch_size, max_len, dtype=torch.long)
-            for i, length in enumerate(current_pos.tolist()):
-                attn_mask[i, : length + 1] = 1
-
-            current_pos = torch.tensor([max_len - 1])
-
-            attention_mask = [
-                create_sliding_window_causal_mask(
-                    tokens,
-                    attn_mask,
-                    current_pos,
-                    self.args,
-                    self.paged_attention_config,
-                    device=self.mesh_device,
-                    mode="decode",
-                ),
-                create_causal_mask(
-                    tokens,
-                    attn_mask,
-                    current_pos,
-                    self.args,
-                    self.paged_attention_config,
-                    device=self.mesh_device,
-                    mode="decode",
-                ),
-            ]
-        else:
-            attention_mask = None
-
-        return tokens, current_pos_tt, rope_idxs_global, rope_idxs_local, page_table, attention_mask
+        return tokens, current_pos_tt, rope_idxs_global, rope_idxs_local, page_table
 
     def _transform_decode_inputs_device(self, tokens):
         """
@@ -414,17 +328,6 @@ def ttnn_prefill_forward(
         This method will take device tensors and any other args to run forward.
         It returns ttnn device tensors.
         """
-        if hasattr(self.args, "sliding_window") and self.args.sliding_window is not None:
-            mask = torch.triu(torch.full((1, 1, x.shape[-2], x.shape[-2]), -float("inf")), diagonal=1)
-            sliding_mask = mask + torch.tril(
-                torch.full((1, 1, x.shape[-2], x.shape[-2]), -float("inf")),
-                diagonal=-self.args.sliding_window,
-            )
-            sliding_attn_mask = ttnn.from_torch(
-                sliding_mask, device=self.mesh_device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16
-            )
-        else:
-            sliding_attn_mask = None
         return self.forward(
             x,
             current_pos=None,
@@ -437,7 +340,6 @@
             chunk_start_idx=chunk_start_idx,
            get_last_token=get_last_token,
            kv_cache=kv_cache,
-            sliding_attn_mask=sliding_attn_mask,
         )
 
     def _increment_decode_positions_device(self, current_pos, rot_mat_idxs_global, rot_mat_idxs_local):
@@ -456,24 +358,6 @@ def _increment_decode_positions_device(self, current_pos, rot_mat_idxs_global, r
         if rot_mat_idxs_local is not None:
             ttnn.plus_one(rot_mat_idxs_local)
 
-    def update_attention_masks(self, current_pos):
-        torch_mask = torch.concat(
-            [
-                self.decode_sliding_mask_mat[i, :, current_pos[i].item() : current_pos[i].item() + 1, :].unsqueeze(0)
-                for i in range(self.decode_sliding_mask_mat.shape[0])
-            ],
-            axis=0,
-        ).transpose(1, 2)
-        sliding_window_causal_mask = ttnn.as_tensor(
-            torch_mask,
-            dtype=ttnn.bfloat4_b,
-            layout=ttnn.TILE_LAYOUT,
-            device=None,
-            memory_config=ttnn.DRAM_MEMORY_CONFIG,
-            mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
-        )
-        ttnn.copy_host_to_device_tensor(sliding_window_causal_mask, self.device_decode_sliding_mask)
-
     def ttnn_decode_forward(
         self,
         x,
@@ -502,7 +386,6 @@ def ttnn_decode_forward(
             mode="decode",
            page_table=page_table,
            kv_cache=kv_cache,
-            sliding_attn_mask=self.device_decode_sliding_mask,
         )
 
         # Gather the output across all devices and untilize the tensor (for argmax)
@@ -553,7 +436,6 @@ def forward(
         chunk_start_idx=None,
         get_last_token=-1,
         kv_cache=None,
-        sliding_attn_mask=None,
     ):
         for i, layer in enumerate(self.layers):
             # No-op if callers already provide the right memory config
@@ -565,14 +447,6 @@
             elif activation_dtype is not None and x.dtype != activation_dtype:
                 x = ttnn.typecast(x, activation_dtype)
 
-            if sliding_attn_mask is not None:
-                attn_mask_i = (
-                    sliding_attn_mask
-                    if (hasattr(layer.attention, "is_sliding") and layer.attention.is_sliding)
-                    else None
-                )
-            else:
-                attn_mask_i = None
             x = layer(
                 x,
                 current_pos,
@@ -584,7 +458,6 @@
                 chunk_page_table=chunk_page_table,
                 chunk_start_idx=chunk_start_idx,
                 kv_cache=kv_cache[i] if kv_cache is not None else None,
-                attn_mask=attn_mask_i,
             )
 
         if mode == "prefill" and get_last_token == -1:
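
For context on what the deleted decode path was doing, below is a rough host-side sketch of the per-iteration mask maintenance it performed. get_decode_mask lives in tt/common.py and is not part of this diff, so the square mask here is reconstructed from the triu/tril logic removed from ttnn_prefill_forward; shapes follow the deleted comments. With the window now applied inside the SDPA decode kernel, this per-step slicing and host-to-device copy, along with the resident [bsz, n_heads, seq_len, seq_len] mask, is no longer needed.

import torch


def build_decode_sliding_mask(bsz: int, n_heads: int, seq_len: int, window: int) -> torch.Tensor:
    # Host-resident square mask, [bsz, n_heads, seq_len, seq_len]
    # (stand-in for get_decode_mask, reconstructed from the deleted prefill mask logic).
    mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    mask += torch.tril(torch.full((seq_len, seq_len), float("-inf")), diagonal=-window)
    return mask.expand(bsz, n_heads, seq_len, seq_len).clone()


def decode_step_mask(square_mask: torch.Tensor, current_pos: torch.Tensor) -> torch.Tensor:
    # One row per user at its current position, [bsz, 1, n_heads, seq_len],
    # mirroring the slicing and transpose in the deleted update_attention_masks.
    rows = [
        square_mask[i, :, current_pos[i].item() : current_pos[i].item() + 1, :].unsqueeze(0)
        for i in range(square_mask.shape[0])
    ]
    return torch.cat(rows, dim=0).transpose(1, 2)


square = build_decode_sliding_mask(bsz=2, n_heads=8, seq_len=128, window=32)
print(decode_step_mask(square, torch.tensor([5, 70])).shape)  # torch.Size([2, 1, 8, 128])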

models/tt_transformers/tt/model_config.py

Lines changed: 1 addition & 1 deletion
@@ -2410,7 +2410,7 @@ def create_tokenizer(self):
 
         # Add meta-compatible stop token list to the HF tokenizer
         if not "stop_tokens" in tokenizer.__dict__:
-            tokenizer.stop_tokens = [tokenizer.eos_token_id]
+            tokenizer.stop_tokens = self.eos_token_id if self.eos_token_id is not None else [tokenizer.eos_token_id]
         # Phi-3-mini uses "<|end|>" as EOS token
         if "phi-3-mini" in self.base_model_name.lower():
             tokenizer.stop_tokens.append(tokenizer.encode("<|end|>")[0])
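
The one-line change above prefers a stop-token list supplied by the model configuration over the tokenizer's single EOS id. A small standalone sketch of that fallback, assuming (as the later stop_tokens.append call implies) that self.eos_token_id already holds a list when it is set:

def resolve_stop_tokens(config_eos_token_id, tokenizer_eos_token_id):
    # Prefer the model config's stop-token list; otherwise fall back to the
    # tokenizer's single EOS id wrapped in a list.
    if config_eos_token_id is not None:
        return config_eos_token_id
    return [tokenizer_eos_token_id]


print(resolve_stop_tokens(None, 2))        # [2]
print(resolve_stop_tokens([106, 107], 2))  # [106, 107]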
