diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index 15cc90f7374..504504642d2 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1276,6 +1276,10 @@ def dot_product_attention( # use flash attention _can_use_flash_attention(query, key, value, bias, raise_error=True) + # On TPU, traced masks cause ConcretizationTypeError with flash attention. + if is_tpu and mask is not None and isinstance(mask, jax.core.Tracer): + flash_attention = False + # TPU-specific flash attention path if is_tpu and flash_attention: # Get sharding parameters from distribution context