
Commit a6255ce
Fix repetition issue for Gemma migration
1 parent 813a8a7

5 files changed: 145 additions & 9 deletions

models/tt_transformers/tt/attention.py

Lines changed: 23 additions & 2 deletions
@@ -388,6 +388,8 @@ def forward_decode(
         rot_mats=None,
         page_table=None,
         kv_cache=None,
+        causal_mask=None,
+        is_causal=True,
     ) -> ttnn.Tensor:
         """
         x: (seq_len, 1, batch, dim)
@@ -516,6 +518,8 @@ def forward_decode(
                 program_config=self.model_config["SDPA_DECODE_PROGCFG"],
                 compute_kernel_config=self.sdpa_decode_compute_kernel_cfg,
                 memory_config=ttnn.DRAM_MEMORY_CONFIG,
+                attn_mask=causal_mask,
+                is_causal=is_causal,
             )
         else:
             attn_output_1G4D = ttnn.transformer.scaled_dot_product_attention_decode(
@@ -527,6 +531,8 @@ def forward_decode(
                 program_config=self.model_config["SDPA_DECODE_PROGCFG"],
                 compute_kernel_config=self.sdpa_decode_compute_kernel_cfg,
                 memory_config=ttnn.DRAM_MEMORY_CONFIG,  # FIXME: why not L1 height sharded e.g. SCORES_BATCHED_MM_OUTPUT_MEMCFG?
+                attn_mask=causal_mask,
+                is_causal=is_causal,
             )
 
         ttnn.deallocate(q_heads_1BQD)
@@ -671,6 +677,8 @@ def forward_prefill(
         chunk_page_table=None,
         chunk_start_idx=None,
         kv_cache=None,
+        causal_mask=None,
+        is_causal=True,
     ):
         seq_len = x_11SH.shape[-2]
         assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128"
@@ -833,10 +841,11 @@ def forward_prefill(
             q_heads_1QSD_8b,
             k_heads_1KSD_8b,
             v_heads_1VSD_8b,
-            is_causal=True,
             scale=self.scale,
             compute_kernel_config=self.sdpa_prefill_compute_kernel_cfg,
             program_config=self.model_config["SDPA_PROGCFG"](seq_len),
+            attn_mask=causal_mask,
+            is_causal=is_causal,
         )
 
         # deallocate keys and values
@@ -915,6 +924,8 @@ def forward(
         chunk_page_table=None,
         chunk_start_idx=None,
         kv_cache=None,
+        causal_mask=None,
+        is_causal=True,
     ):
         if mode == "prefill":
             return self.forward_prefill(
@@ -925,9 +936,19 @@ def forward(
                 chunk_page_table=chunk_page_table,
                 chunk_start_idx=chunk_start_idx,
                 kv_cache=kv_cache,
+                causal_mask=causal_mask,
+                is_causal=is_causal,
             )
         else:
-            return self.forward_decode(x, current_pos, rot_mats, page_table=page_table, kv_cache=kv_cache)
+            return self.forward_decode(
+                x,
+                current_pos,
+                rot_mats,
+                page_table=page_table,
+                kv_cache=kv_cache,
+                causal_mask=causal_mask,
+                is_causal=is_causal,
+            )
 
     def prefill_prepare_tensor_for_kv_cache(self, key_or_value_layer, user_id):
         tensor_copy = ttnn.clone(key_or_value_layer)
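The new causal_mask / is_causal pair follows the convention the commit applies everywhere else: either the kernel builds the causal pattern itself (is_causal=True, no mask), or the caller supplies an explicit mask and turns the flag off. A minimal torch-only sketch of that contract, purely for illustration (it uses torch's SDPA, not the ttnn ops in the diff above):

# Sketch: an explicit lower-triangular mask with is_causal=False matches is_causal=True with no mask.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Path 1: implicit causal masking inside the kernel
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Path 2: explicit causal mask, causal flag disabled
causal_mask = torch.ones(16, 16, dtype=torch.bool).tril()
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=False)

assert torch.allclose(out_causal, out_masked, atol=1e-5)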

models/tt_transformers/tt/decoder.py

Lines changed: 4 additions & 0 deletions
@@ -176,6 +176,8 @@ def forward(
         chunk_page_table=None,
         chunk_start_idx=None,
         kv_cache=None,
+        causal_mask=None,
+        is_causal=True,
     ) -> ttnn.Tensor:
         TG = self.args.is_galaxy
         residual = x
@@ -204,6 +206,8 @@ def forward(
             chunk_page_table=chunk_page_table,
             chunk_start_idx=chunk_start_idx,
             kv_cache=kv_cache,
+            causal_mask=causal_mask,
+            is_causal=is_causal,
         )
         if self.pre_ff_norm == None:
             # Here x and attn_out are both fractured across devices

models/tt_transformers/tt/model.py

Lines changed: 80 additions & 5 deletions
@@ -9,7 +9,7 @@
 from models.common.lightweightmodule import LightweightModule
 from models.common.rmsnorm import RMSNorm
 from models.tt_transformers.tt.ccl import TT_CCL
-from models.tt_transformers.tt.common import copy_host_to_device
+from models.tt_transformers.tt.common import copy_host_to_device, create_causal_mask, create_sliding_window_causal_mask
 from models.tt_transformers.tt.decoder import TransformerBlock
 from models.tt_transformers.tt.distributed_norm import DistributedNorm
 from models.tt_transformers.tt.embedding import Embedding, ScaledEmbedding
@@ -32,6 +32,7 @@ def __init__(
         rope_setup_class=None,
     ):
         super().__init__()
+        self.paged_attention_config = paged_attention_config
         self.args = args
         self.vocab_size = args.vocab_size
         assert self.vocab_size > 0
@@ -187,14 +188,38 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag
             )
         else:
             tt_chunk_page_table = None
-
+        if self.args.attention_mask:
+            attn_mask = torch.ones(S + 1).unsqueeze(0)
+            cache_postion = torch.arange(S)
+            attention_mask = [
+                create_sliding_window_causal_mask(
+                    tokens_embd,
+                    attn_mask,
+                    cache_postion,
+                    self.args,
+                    self.paged_attention_config,
+                    device=self.mesh_device,
+                    mode="prefill",
+                ),
+                create_causal_mask(
+                    tokens_embd,
+                    attn_mask,
+                    cache_postion,
+                    self.args,
+                    self.paged_attention_config,
+                    device=self.mesh_device,
+                    mode="prefill",
+                ),
+            ]
+        else:
+            attention_mask = None
         return (
             tokens_embd,
             tt_rot_mats_prefill_global,
             tt_rot_mats_prefill_local,
             tt_page_table,
             tt_chunk_page_table,
-            None,
+            attention_mask,
         )
 
     def prepare_inputs_decode(self, *inputs):
@@ -258,7 +283,41 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None):
                 mesh_shape=self.args.cluster_shape,
             ),
         )
-        return tokens, current_pos_tt, rope_idxs_global, rope_idxs_local, page_table, None
+        if self.args.attention_mask:
+            batch_size = current_pos.size(0)
+            max_len = current_pos.max().item() + 1  # longest seq length (+1 since pos starts at 0)
+
+            # Initialize with zeros
+            attn_mask = torch.zeros(batch_size, max_len, dtype=torch.long)
+            for i, length in enumerate(current_pos.tolist()):
+                attn_mask[i, : length + 1] = 1
+
+            current_pos = torch.tensor([max_len - 1])
+
+            attention_mask = [
+                create_sliding_window_causal_mask(
+                    tokens,
+                    attn_mask,
+                    current_pos,
+                    self.args,
+                    self.paged_attention_config,
+                    device=self.mesh_device,
+                    mode="decode",
+                ),
+                create_causal_mask(
+                    tokens,
+                    attn_mask,
+                    current_pos,
+                    self.args,
+                    self.paged_attention_config,
+                    device=self.mesh_device,
+                    mode="decode",
+                ),
+            ]
+        else:
+            attention_mask = None
+
+        return tokens, current_pos_tt, rope_idxs_global, rope_idxs_local, page_table, attention_mask
 
     def _transform_decode_inputs_device(self, tokens):
         """
@@ -324,6 +383,7 @@ def ttnn_prefill_forward(
         chunk_start_idx=None,
         get_last_token=-1,
         kv_cache=None,
+        attention_masks=None,
         **kwargs,
     ):
         """
@@ -342,6 +402,7 @@ def ttnn_prefill_forward(
             chunk_start_idx=chunk_start_idx,
             get_last_token=get_last_token,
             kv_cache=kv_cache,
+            attention_masks=attention_masks,
         )
 
     def _increment_decode_positions_device(self, current_pos, rot_mat_idxs_global, rot_mat_idxs_local):
@@ -369,6 +430,7 @@ def ttnn_decode_forward(
         page_table=None,
         kv_cache=None,
         argmax_on_device=False,
+        attention_masks=None,
         **kwargs,
     ):
         """
@@ -388,6 +450,7 @@ def ttnn_decode_forward(
             rot_mats_local=rot_mats_local,
             mode="decode",
             page_table=page_table,
+            attention_masks=attention_masks,
             kv_cache=kv_cache,
         )
 
@@ -439,6 +502,7 @@ def forward(
         chunk_start_idx=None,
         get_last_token=-1,
         kv_cache=None,
+        attention_masks=None,
    ):
        for i, layer in enumerate(self.layers):
            # No-op if callers already provide the right memory config
@@ -449,7 +513,16 @@ def forward(
                 x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"], activation_dtype)
             elif activation_dtype is not None and x.dtype != activation_dtype:
                 x = ttnn.typecast(x, activation_dtype)
-
+            causal_mask = (
+                (
+                    attention_masks[0]
+                    if (hasattr(layer.attention, "is_sliding") and layer.attention.is_sliding)
+                    else attention_masks[1]
+                )
+                if attention_masks is not None
+                else None
+            )
+            is_causal = False if causal_mask is not None else True
             x = layer(
                 x,
                 current_pos,
@@ -461,6 +534,8 @@ def forward(
                 chunk_page_table=chunk_page_table,
                 chunk_start_idx=chunk_start_idx,
                 kv_cache=kv_cache[i] if kv_cache is not None else None,
+                causal_mask=causal_mask,
+                is_causal=is_causal,
             )
 
         if mode == "prefill" and get_last_token == -1:

models/tt_transformers/tt/model_config.py

Lines changed: 2 additions & 0 deletions
@@ -1431,11 +1431,13 @@ def _get_hidden_activation_type(self, config):
 
     def _set_model_specific_params(self):
         # Gemma3 specific params
+        self.attention_mask = False
         is_gemma3 = "gemma-3" in self.base_model_name.lower()
         if is_gemma3:
             self.rms_norm_add_unit_offset = True
             self.embed_scale = self.dim**0.5
             self.sliding_window = 512
+            self.attention_mask = True
 
     def _set_params_from_dict(self, config, is_hf=False):
         eos_token_id = config.get("eos_token_id", None)
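The flag is defined unconditionally before the Gemma-3 check so every model exposes args.attention_mask (the prepare_* paths above test it), and only Gemma-3 switches it on. A small sketch of that default-then-override pattern, with a stand-in class instead of the real ModelArgs:

# Sketch: default-then-override for the new flag; class and model names are illustrative stand-ins.
class FakeModelArgs:
    def __init__(self, base_model_name: str):
        self.base_model_name = base_model_name
        self._set_model_specific_params()

    def _set_model_specific_params(self):
        self.attention_mask = False           # safe default so every model can check the flag
        if "gemma-3" in self.base_model_name.lower():
            self.sliding_window = 512         # sliding-window size used by Gemma-3 layers
            self.attention_mask = True        # opt in to explicit attention masks

print(FakeModelArgs("gemma-3-4b-it").attention_mask)  # True
print(FakeModelArgs("Llama-3.1-8B").attention_mask)   # False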

models/tt_transformers/tt/multimodal/gemma3/gemma_e2e_model.py

Lines changed: 36 additions & 2 deletions
@@ -1,4 +1,7 @@
+import torch
+
 import ttnn
+from models.tt_transformers.tt.common import create_causal_mask, create_sliding_window_causal_mask
 from models.tt_transformers.tt.model import Transformer
 from models.tt_transformers.tt.multimodal.gemma3.gemma_vision_model import TtGemmaTransformerVision
 
@@ -109,8 +112,39 @@ def prepare_inputs_prefill(self, pt_tokens, start_pos=0, page_table=None, chunk_
             )
         else:
             tt_chunk_page_table = None
-
-        return tokens_embd, tt_rot_mats_prefill_global, tt_rot_mats_prefill_local, tt_page_table, tt_chunk_page_table
+        if self.args.attention_mask:
+            attn_mask = torch.ones(S + 1).unsqueeze(0)
+            cache_postion = torch.arange(S)
+            attention_mask = [
+                create_sliding_window_causal_mask(
+                    tokens_embd,
+                    attn_mask,
+                    cache_postion,
+                    self.args,
+                    self.paged_attention_config,
+                    device=self.mesh_device,
+                    mode="prefill",
+                ),
+                create_causal_mask(
+                    tokens_embd,
+                    attn_mask,
+                    cache_postion,
+                    self.args,
+                    self.paged_attention_config,
+                    device=self.mesh_device,
+                    mode="prefill",
+                ),
+            ]
+        else:
+            attention_mask = None
+        return (
+            tokens_embd,
+            tt_rot_mats_prefill_global,
+            tt_rot_mats_prefill_local,
+            tt_page_table,
+            tt_chunk_page_table,
+            attention_mask,
+        )
 
     def compute_vision_token(self, pixel_values=None):
         if pixel_values is None:
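The [sliding_window_mask, global_mask] list built here (and in model.py's prepare paths) is consumed per layer in Transformer.forward: sliding-window layers take the first entry, global layers the second, and is_causal is disabled whenever an explicit mask is supplied. A standalone sketch of that selection, with stand-in objects in place of real layers and ttnn tensors:

# Sketch of the per-layer mask selection; FakeAttention/FakeLayer are illustrative stand-ins.
from dataclasses import dataclass

@dataclass
class FakeAttention:
    is_sliding: bool = False  # real Gemma-3 sliding-window layers set this attribute

@dataclass
class FakeLayer:
    attention: FakeAttention

def select_mask(layer, attention_masks):
    if attention_masks is None:
        return None, True  # no mask supplied -> rely on kernel-side causal masking
    sliding = getattr(layer.attention, "is_sliding", False)
    causal_mask = attention_masks[0] if sliding else attention_masks[1]
    return causal_mask, False  # explicit mask -> is_causal must be False

masks = ["sliding_window_mask", "global_mask"]  # placeholders for the ttnn mask tensors
print(select_mask(FakeLayer(FakeAttention(is_sliding=True)), masks))   # ('sliding_window_mask', False)
print(select_mask(FakeLayer(FakeAttention(is_sliding=False)), masks))  # ('global_mask', False)
print(select_mask(FakeLayer(FakeAttention()), None))                   # (None, True)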
