Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions keras/src/backend/jax/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,10 @@ def dot_product_attention(
# use flash attention
_can_use_flash_attention(query, key, value, bias, raise_error=True)

# Check if mask is traced - cannot use flash attention with traced masks
if mask is not None and isinstance(mask, jax.core.Tracer):
flash_attention = False

# TPU-specific flash attention path
if is_tpu and flash_attention:
# Get sharding parameters from distribution context
Expand Down