From bc319e25e02abd210d9cb15596b08ca4aed4bbef Mon Sep 17 00:00:00 2001 From: Abhinav Date: Sat, 15 Nov 2025 12:51:23 +0530 Subject: [PATCH 1/3] Disable flash attention for traced masks Add check for traced masks to disable flash attention. --- keras/src/backend/jax/nn.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index 15cc90f7374..c04866eeeae 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1276,6 +1276,10 @@ def dot_product_attention( # use flash attention _can_use_flash_attention(query, key, value, bias, raise_error=True) + # Check if mask is traced - cannot use flash attention with traced masks + if mask is not None and isinstance(mask, jax.core.Tracer): + flash_attention = False + # TPU-specific flash attention path if is_tpu and flash_attention: # Get sharding parameters from distribution context From a6b65a4d711542d284d5c750024c372d4c6feae6 Mon Sep 17 00:00:00 2001 From: Abhinav Date: Sat, 15 Nov 2025 13:03:07 +0530 Subject: [PATCH 2/3] Update keras/src/backend/jax/nn.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras/src/backend/jax/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index c04866eeeae..ec77ab41f98 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1276,8 +1276,8 @@ def dot_product_attention( # use flash attention _can_use_flash_attention(query, key, value, bias, raise_error=True) - # Check if mask is traced - cannot use flash attention with traced masks - if mask is not None and isinstance(mask, jax.core.Tracer): + # On TPU, flash attention cannot be used with traced masks. + if is_tpu and mask is not None and isinstance(mask, jax.core.Tracer): flash_attention = False # TPU-specific flash attention path From 862c2719f79489bf305c7e184c59b799e29899b8 Mon Sep 17 00:00:00 2001 From: Abhinav Date: Sat, 15 Nov 2025 13:03:36 +0530 Subject: [PATCH 3/3] Update keras/src/backend/jax/nn.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras/src/backend/jax/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index ec77ab41f98..504504642d2 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1276,7 +1276,7 @@ def dot_product_attention( # use flash attention _can_use_flash_attention(query, key, value, bias, raise_error=True) - # On TPU, flash attention cannot be used with traced masks. + # On TPU, traced masks cause ConcretizationTypeError with flash attention. if is_tpu and mask is not None and isinstance(mask, jax.core.Tracer): flash_attention = False