Gradient Checkpointing causes model to compute junk results (NNX) #4626

Open
erfanzar opened this issue Mar 15, 2025 · 6 comments

@erfanzar
System information

  • Flax, JAX, JAXlib versions: (0.10.4)
  • Python version: (e.g., 3.10.2)

Problem you have encountered:

When using gradient checkpointing in a flax.nnx model (nnx.remat), the model generates incorrect or junk results. This happens on both GPU and TPU. If two models are loaded:

  1. Model 1: Uses gradient checkpointing (EasyDeLGradientCheckPointers.NOTHING_SAVEABLE).
  2. Model 2: Does not use gradient checkpointing.

Both models will generate junk results. However, if only Model 2 (without checkpointing) is created and used, it works correctly.

What you expected to happen:

Each model should independently function correctly, and the activation checkpointing (remat) should not corrupt inference outputs when applied to a separate model instance.

Logs, error messages, etc:

None provided.

Steps to reproduce:

A minimal reproducible example is given below. Changing gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE to EasyDeLGradientCheckPointers.NONE resolves the issue.

Code snippet to reproduce the issue:

import typing as tp

import easydel as ed
import jax
import transformers
from flax import nnx as nn  # assuming `nn` refers to flax.nnx here, since the issue concerns nnx.remat
from jax import numpy as jnp

# EasyDeLGradientCheckPointers is re-exported by easydel (used as
# ed.EasyDeLGradientCheckPointers elsewhere in this snippet).
EasyDeLGradientCheckPointers = ed.EasyDeLGradientCheckPointers
# get_gradient_checkpoint_policy and extract_static_parameters are
# EasyDeL-internal helpers; their import paths are omitted here.

M = tp.TypeVar("M")  # module-class type variable used by auto_remat

def auto_remat(
    *modules: tp.Type[M],
    policy: tp.Union[
        EasyDeLGradientCheckPointers, str
    ] = EasyDeLGradientCheckPointers.NONE,
    prevent_cse: bool = True,
) -> tp.Tuple[tp.Type[M], ...]:
    # With checkpointing disabled, return the module classes untouched.
    if policy == EasyDeLGradientCheckPointers.NONE:
        return modules
    if isinstance(policy, str):
        policy = get_gradient_checkpoint_policy(policy)
    outs = ()
    for module in modules:
        assert issubclass(module, nn.Module)
        static_argnums = extract_static_parameters(module=module)
        if static_argnums is None:
            static_argnums = ()

        # Note: this rebinds __call__ on the module *class*, not on an
        # instance, so the rematted method is shared by every instance of
        # that class, whenever it was created.
        module.__call__ = nn.remat(
            f=module.__call__,
            prevent_cse=prevent_cse,
            static_argnums=static_argnums,
            policy=policy,
        )
        outs += (module,)
    return outs

def main():
    sharding_axis_dims = (1, 1, 1, -1)
    prefill_length = 512
    max_new_tokens = 128
    max_length = max_new_tokens + prefill_length
    pretrained_model_name_or_path = "Qwen/Qwen2.5-7B-Instruct"

    dtype = param_dtype = jnp.bfloat16
    partition_axis = ed.PartitionAxis()
    tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = tokenizer.eos_token_id

    model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        auto_shard_model=True,
        sharding_axis_dims=sharding_axis_dims,
        config_kwargs=ed.EasyDeLBaseConfigDict(
            freq_max_position_embeddings=max_length,
            mask_max_position_embeddings=max_length,
            kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.NONE,
            gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
            attn_dtype=jnp.bfloat16,
            attn_mechanism=ed.AttentionMechanisms.AUTO,
        ),
        quantization_method=ed.EasyDeLQuantizationMethods.NONE,
        param_dtype=param_dtype,
        dtype=dtype,
        partition_axis=partition_axis,
        precision=jax.lax.Precision.DEFAULT,
    )

    inference = ed.vInference(
        model=model,
        processor_class=tokenizer,
        generation_config=ed.vInferenceConfig(
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            do_sample=True,
            top_p=0.95,
            top_k=10,
            eos_token_id=model.generation_config.eos_token_id,
            streaming_chunks=32,
            num_return_sequences=1,
        ),
    )

    inference.precompile(
        ed.vInferencePreCompileConfig(
            batch_size=1,
            prefill_length=prefill_length,
        )
    )

    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "write 10 lines story about why you love EasyDeL"},
    ]

    ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="jax",
        return_dict=True,
        add_generation_prompt=True,
    )
    print("Start Generation Process.")
    for response in inference.generate(**ids):
        ...
    print(
        tokenizer.batch_decode(
            response.sequences[..., response.padded_length :],
            skip_special_tokens=True,
        )
    )
    print(response.tokens_pre_second)

if __name__ == "__main__":
    main()

Workarounds Tried:

  • Setting gradient_checkpointing=EasyDeLGradientCheckPointers.NONE fixes the issue.
  • Ensuring that only one model (either with or without remat) is instantiated at a time prevents the corruption.
  • The issue only occurs when both models exist in memory simultaneously.

Possible Cause:

  • nn.remat might be affecting global state shared across models.
  • Memory corruption or state retention in flax.nnx affecting subsequent inference.

Additional Notes:

  • Further debugging of nn.remat handling in flax.nnx is needed.
  • Possible scope leakage between checkpointed and non-checkpointed models.

Expected Fix: Ensure that gradient checkpointing via nnx.remat does not interfere with models that do not use checkpointing in the same session.
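To make the scope-leakage hypothesis above concrete, here is a minimal, EasyDeL-independent sketch of the pattern auto_remat uses: rebinding __call__ on a module class affects every instance of that class, including instances created before the patch. Whether this is the actual mechanism behind the junk results here is unconfirmed.

class Layer:
    def __call__(self, x):
        return x + 1

pre_patch = Layer()           # instance created before any patching
original_call = Layer.__call__

# Stand-in for `module.__call__ = nn.remat(module.__call__, ...)`:
# the *class* attribute is replaced, not an instance attribute.
Layer.__call__ = lambda self, x: original_call(self, x) * 100

post_patch = Layer()

print(pre_patch(2))   # 300 -- the instance created before the patch is affected too
print(post_patch(2))  # 300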

@erfanzar
Author

Any update or help on this?

@peregilk

EasyDeL seems to be the only JAX implementation out there that supports GRPO training.

This is a blocking bug for my research. Is anyone else able to reproduce this, or at least confirm that it is actually a Flax-related bug?

@cgarciae
Collaborator

Thanks for reporting this @erfanzar. Would it be possible for you to create a test case where nnx.remat fails?
nnx.remat is not doing a whole lot except forwarding the underlying state to JAX.
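Something along these lines would help: a minimal, EasyDeL-independent sketch (assuming nnx.remat can wrap a plain function that takes the module as its first argument) that checks whether remat alone changes the forward output.

import jax
import jax.numpy as jnp
from flax import nnx

class MLP(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.l1 = nnx.Linear(8, 16, rngs=rngs)
        self.l2 = nnx.Linear(16, 8, rngs=rngs)

    def __call__(self, x):
        return self.l2(jax.nn.relu(self.l1(x)))

model = MLP(rngs=nnx.Rngs(0))
x = jnp.ones((2, 8))

# Plain forward pass.
y_plain = model(x)

# Same forward pass, but computed under remat.
y_remat = nnx.remat(lambda m, x: m(x))(model, x)

# Remat should not change forward values.
assert jnp.allclose(y_plain, y_remat), "nnx.remat changed the forward output"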

@demon2036

Hi @erfanzar ,

I've encountered similar issues with generating junk results when working with Qwen models, although perhaps for slightly different underlying reasons.

In my case, using my own library implementation, I found that the junk results with Qwen seemed related to numerical stability, particularly within the attention mechanism when using bfloat16. It appeared that some operations were becoming unstable at lower precision.

My workaround was to force key parts of the attention calculation to run in float32 precision, even if the rest of the model used bfloat16. This stabilized the computation and resolved the junk output issue for me.

Here's the relevant snippet from my attention implementation:

import math

import jax.nn as nn
import jax.numpy as jnp

# Snippet from inside the attention module's __call__ (hence self.head_dim).
# Assuming query_states, key_states, value_states are initially bf16/fp16
# and attn_mask is prepared appropriately.

# Force the QK^T dot product and scaling to float32
attn_weights = (
    query_states.astype(jnp.float32)
    @ key_states.swapaxes(-2, -1).astype(jnp.float32)
) / math.sqrt(self.head_dim)

# Apply mask in float32
if attn_mask is not None:
    causal_mask = attn_mask # Assuming mask is already correctly broadcastable
    # Ensure mask is also float32
    attn_weights = attn_weights.astype(jnp.float32) + causal_mask.astype(jnp.float32)

# Softmax in float32
attn_weights = nn.softmax(attn_weights.astype(jnp.float32), axis=-1)

# Weight values in float32
attn_output = attn_weights @ value_states.astype(jnp.float32)

# Cast final output back to the original lower precision if needed
attn_output = attn_output.astype(jnp.bfloat16) # Or original dtype

While your issue seems directly linked to nnx.remat potentially causing state interference, it's possible that the recomputation process within remat might be exacerbating underlying numerical precision sensitivities in the Qwen attention layers.

Perhaps you could try modifying the attention mechanism within your EasyDeL setup (if possible, or by modifying the source temporarily) to use float32 for the intermediate calculations as shown above, even when gradient_checkpointing is enabled. It's a bit of a long shot since the root causes might be different, but forcing higher precision in sensitive areas like attention sometimes helps stabilize things unexpectedly.

Hope this potentially offers another avenue to investigate or might provide some relief!

@erfanzar
Author

erfanzar commented Apr 5, 2025

Thanks @demon2036 for sharing this, but we are using FA3 kernels for GPUs and Splash for TPUs, so attention dtype isn't really the issue. But thanks, I'll double-check.

@demon2036

> Thanks @demon2036 for sharing this, but we are using FA3 kernels for GPUs and Splash for TPUs, so attention dtype isn't really the issue. But thanks, I'll double-check.

@erfanzar Have you had a chance to try decoding using greedy sampling?

If you use greedy sampling, is the first token generated correct?

If the first token is correct, but the subsequent tokens are incorrect, that might strongly suggest an underlying numerical precision issue. In that scenario, perhaps it could be worth investigating if forcing operations like RMSNorm (and maybe others) to compute in fp32 makes a difference?
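If it helps, here is a hedged sketch of the greedy check I have in mind, reusing only fields that already appear in the repro's vInferenceConfig; I'm assuming do_sample=False gives greedy (argmax) decoding in EasyDeL.

# Hypothetical greedy-decoding variant of the repro's generation config;
# whether do_sample=False maps to pure greedy decoding in EasyDeL is an assumption.
greedy_inference = ed.vInference(
    model=model,
    processor_class=tokenizer,
    generation_config=ed.vInferenceConfig(
        max_new_tokens=max_new_tokens,
        do_sample=False,  # assumed: disable sampling so each step picks the argmax token
        eos_token_id=model.generation_config.eos_token_id,
        streaming_chunks=32,
        num_return_sequences=1,
    ),
)

for response in greedy_inference.generate(**ids):
    ...
# Compare the first generated token against a known-good run (e.g. the model
# built with EasyDeLGradientCheckPointers.NONE) to see whether the divergence
# starts at the very first step or only later.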
