Gradient Checkpointing causes model to compute junk results (NNX) #4626
Comments
Any update or help on this?
EasyDeL seems to be the only JAX implementation out there that supports GRPO training. This is a blocking bug for my research. Is anyone else able to reproduce this? At least confirm that it is actually a Flax-related bug?
Thanks for reporting this @erfanzar. Would it be possible for you to create a test case where
Hi @erfanzar, I've encountered similar issues with generating junk results when working with Qwen models, although perhaps for slightly different underlying reasons. In my case, using my own library implementation, I found that the junk results with Qwen seemed related to numerical stability, particularly within the attention mechanism when using lower-precision dtypes (bf16/fp16). My workaround was to force key parts of the attention calculation to run in float32.

Here's the relevant snippet from my attention implementation:

import jax.numpy as jnp
import jax.nn as nn
import math

# Assuming query_states, key_states, value_states are initially bf16/fp16
# and attn_mask is prepared appropriately

# Force the QK dot product and scaling to float32
attn_weights = (query_states.astype(jnp.float32) @ key_states.swapaxes(-2, -1).astype(jnp.float32)) / math.sqrt(self.head_dim)

# Apply the mask in float32
if attn_mask is not None:
    causal_mask = attn_mask  # Assuming the mask is already correctly broadcastable
    # Ensure the mask is also float32
    attn_weights = attn_weights.astype(jnp.float32) + causal_mask.astype(jnp.float32)

# Softmax in float32
attn_weights = nn.softmax(attn_weights.astype(jnp.float32), axis=-1)

# Weight the values in float32
attn_output = attn_weights @ value_states.astype(jnp.float32)

# Cast the final output back to the original lower precision if needed
attn_output = attn_output.astype(jnp.bfloat16)  # Or the original dtype

While your issue seems directly linked to nnx.remat potentially causing state interference, it's possible that the recomputation process within remat is exacerbating underlying numerical-precision sensitivities in the Qwen attention layers. Perhaps you could try modifying the attention mechanism within your EasyDeL setup (if possible, or by modifying the source temporarily) to use float32 for the intermediate calculations as shown above, even when gradient_checkpointing is enabled. It's a bit of a long shot since the root causes might differ, but forcing higher precision in sensitive areas like attention sometimes helps stabilize things unexpectedly. Hope this offers another avenue to investigate or provides some relief!
Thanks @demon2036 for sharing this, but we are using FA3 kernels for GPUs and Splash for TPUs, so attention dtype isn't really the issue.
@erfanzar Have you had a chance to try decoding with greedy sampling? If you use greedy sampling, is the first generated token correct? If the first token is correct but the subsequent tokens are incorrect, that would strongly suggest an underlying numerical precision issue. In that scenario, it could be worth investigating whether forcing operations like RMSNorm (and maybe others) to compute in fp32 makes a difference.
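To make that fp32 experiment concrete, here is a minimal sketch of an RMSNorm that computes its statistics in float32 and casts back to the input dtype. This is not EasyDeL's or Flax's implementation; the module name, weight initialization, and epsilon are assumptions for illustration.

import jax.numpy as jnp
from flax import nnx

class RMSNormFP32(nnx.Module):
    # RMSNorm whose statistics are computed in float32 regardless of input dtype.
    def __init__(self, dim: int, eps: float = 1e-6):
        self.eps = eps
        self.weight = nnx.Param(jnp.ones((dim,)))

    def __call__(self, x):
        orig_dtype = x.dtype
        x32 = x.astype(jnp.float32)                                   # upcast activations
        variance = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)  # fp32 statistics
        normed = x32 / jnp.sqrt(variance + self.eps)
        return (self.weight.value * normed).astype(orig_dtype)        # cast back at the end

Swapping a layer like this in for the model's norm layers while leaving everything else in bf16 would help isolate whether normalization precision is the culprit.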
System information
Problem you have encountered:
When using gradient checkpointing (nnx.remat) in a flax.nnx model, the model generates incorrect or junk results. This happens on both GPU and TPU. If two models are loaded in the same session, Model 1 with gradient checkpointing (EasyDeLGradientCheckPointers.NOTHING_SAVEABLE) and Model 2 without, both models will generate junk results. However, if only Model 2 (without checkpointing) is created and used, it works correctly.
What you expected to happen:
Each model should function correctly independently, and activation checkpointing (remat) should not corrupt inference outputs when applied to a separate model instance.

Logs, error messages, etc:
(Provide any logs, traceback, or error messages if available.)
Steps to reproduce:
A minimal reproducible example is given below. Changing gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE to EasyDeLGradientCheckPointers.NONE resolves the issue.

Code snippet to reproduce the issue:
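(The reporter's original EasyDeL snippet is not preserved in this thread. As an illustration only, here is a stripped-down sketch of the described pattern in plain flax.nnx: one model whose forward pass is wrapped in nnx.remat and a second, identically initialized model without checkpointing, created in the same session. The TinyMLP module, its sizes, and the final comparison are assumptions for the sketch, not the reporter's code.)

import jax
import jax.numpy as jnp
from flax import nnx

class TinyMLP(nnx.Module):
    def __init__(self, din: int, dhidden: int, dout: int, *, rngs: nnx.Rngs):
        self.fc1 = nnx.Linear(din, dhidden, rngs=rngs)
        self.fc2 = nnx.Linear(dhidden, dout, rngs=rngs)

    def __call__(self, x):
        return self.fc2(jax.nn.relu(self.fc1(x)))

x = jnp.ones((4, 16))

# Model 1: forward pass wrapped in nnx.remat (activation checkpointing).
model_ckpt = TinyMLP(16, 64, 16, rngs=nnx.Rngs(0))
forward_ckpt = nnx.remat(lambda m, inp: m(inp))
y_ckpt = forward_ckpt(model_ckpt, x)

# Model 2: identical seed, plain forward pass, no checkpointing.
model_plain = TinyMLP(16, 64, 16, rngs=nnx.Rngs(0))
y_plain = model_plain(x)

# With identical seeds the two outputs should match; a mismatch that appears
# only when the remat-wrapped model exists would point at the interference
# described above.
print(jnp.allclose(y_ckpt, y_plain, atol=1e-5))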
Workarounds Tried:
Setting gradient_checkpointing=EasyDeLGradientCheckPointers.NONE fixes the issue.
Instantiating only one model (with remat) at a time prevents the corruption.

Possible Cause:
nnx.remat might be affecting global state shared across models.
Some behavior in flax.nnx may be affecting subsequent inference.

Additional Notes:
This may be related to nnx.remat handling in flax.nnx.

Expected Fix: Ensure that gradient checkpointing via nnx.remat does not interfere with models that do not use checkpointing in the same session.