
Conversation


@mstojkovicTT commented Nov 27, 2025

What does this PR do?

Fixes #42398

This PR replaces custom RMSNorm/T5-style norm implementations (e.g. in Llama) that manually compute variance and scaling with the built-in torch.nn.functional.rms_norm. For example, code like:

input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)

is simplified to:

return F.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon)

This keeps the behavior and epsilon handling the same while reducing the number of ops, which should improve performance for users without requiring any changes on their side.
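For context, here is a minimal standalone sketch of the new call (the shapes and eps value are illustrative, not taken from the modeling code); hidden_states.shape[-1:] is the normalized_shape argument, so the RMS statistics are computed over the hidden dimension only:

import torch
import torch.nn.functional as F

hidden_states = torch.randn(2, 3, 64)  # (batch, seq, hidden)
weight = torch.ones(64)
eps = 1e-6

# normalized_shape = (64,): normalize over the last (hidden) dimension only
out = F.rms_norm(hidden_states, hidden_states.shape[-1:], weight, eps)
print(out.shape)  # torch.Size([2, 3, 64])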

To verify performance and numerical stability, I wrote the following test:

import timeit
import torch
import torch.nn as nn

# Original implementation
class LlamaRMSNormHF(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


# New implementation using torch.nn.functional.rms_norm
class LlamaRMSNormTorch(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        return nn.functional.rms_norm(
            hidden_states,
            hidden_states.shape[-1:],
            self.weight,
            self.variance_epsilon,
        )

def bench(module, x, iters=1000):
    # Warmup
    for _ in range(10):
        module(x)

    if x.is_cuda:
        torch.cuda.synchronize()
    start = timeit.default_timer()
    for _ in range(iters):
        module(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    end = timeit.default_timer()

    return (end - start) / iters


def test_llama_rms_norm_equivalence():
    torch.manual_seed(0)

    hidden_size = 64
    batch_size = 2
    seq_len = 3

    dtype = torch.bfloat16
    device = "cpu" 

    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)

    hf_module = LlamaRMSNormHF(hidden_size, eps=1e-6).to(device)
    new_module = LlamaRMSNormTorch(hidden_size, eps=1e-6).to(device)

    # make sure they have the same weights
    with torch.no_grad():
        new_module.weight.copy_(hf_module.weight)

    y_hf = hf_module(x)
    y_new = new_module(x)
    y_new = y_new.to(y_hf.dtype) # torch.allclose needs same dtype

    # Check numerically close
    print(torch.allclose(y_hf, y_new, atol=1e-5, rtol=1e-5))

    # speed benchmark
    t_hf = bench(hf_module, x)
    t_new = bench(new_module, x)

    print(f"HF   RMSNorm: {t_hf * 1e6:.2f} µs / call")
    print(f"F.rms_norm  : {t_new * 1e6:.2f} µs / call")

test_llama_rms_norm_equivalence()

The results show the following:

  • for the cpu device:
    True
    HF   RMSNorm: 86.75 µs / call
    F.rms_norm  : 47.25 µs / call
  • for the cuda device:
    True
    HF   RMSNorm: 112.27 µs / call
    F.rms_norm  : 83.05 µs / call

Note: I noticed that when I try dtypes lower than float32, the old implementation keeps the output at float32, but my new one returns the dtype of the input tensor. That's why I have to cast to y_hf.dtype (trying float64, for example, makes both implementations output float64). This can be changed, depending on what we want to accomplish.
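To make the dtype note above concrete, here is a quick sketch reusing the two modules from the test script (run it after the class definitions; the shape is arbitrary, and the commented dtypes reflect the behavior described above):

x_bf16 = torch.randn(2, 3, 64, dtype=torch.bfloat16)

hf_out = LlamaRMSNormHF(64)(x_bf16)
new_out = LlamaRMSNormTorch(64)(x_bf16)

print(hf_out.dtype)   # torch.float32: the float32 weight promotes the final multiply
print(new_out.dtype)  # torch.bfloat16: F.rms_norm returns the input dtype here, as noted above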

Who can review?

@Rocketknight1

@github-actions

[For maintainers] Suggested jobs to run (before merge)

run-slow: aimv2, apertus, arcee, aria, bamba, bitnet, blt, chameleon, clvp, csm, cwm, deepseek_v2, deepseek_v3, dia, diffllama, doge

@Rocketknight1

Hey @mstojkovicTT, thanks for the PR! We definitely want the functions to be a drop-in replacement, so they should return exactly the same dtype as the old functions did.

Also, in your tests you're initializing self.weight = nn.Parameter(torch.ones(hidden_size)) which means that the final scaling step is just multiplying by 1 and having no effect. Can you try with randomly initialized weights with a mean of 1 but a little bit of variance instead, so we can see if everything is equivalent with the original?


mstojkovicTT commented Dec 1, 2025

We definitely want the functions to be a drop-in replacement, so they should return exactly the same dtype as the old functions did.

This may be problematic because of the following scenario:

  • Let's say we have an input that is bfloat16.
  • When I run the Hugging Face implementation, the return dtype is float32.
  • I don't think this is intended, considering that we do
return self.weight * hidden_states.to(input_dtype)
and the float32 weight promotes the product right back to float32 (see the snippet after this list).
  • Here, one of the following should be our goal: either we just want to upcast the input to float32 for precision in the accumulation, in which case we don't care about input_dtype and keep it around for no reason, or we want to accumulate in float32 and preserve the input dtype (which is not what currently happens).
  • One case where I see hidden_states.to(input_dtype) being useful is when we have "larger" dtypes than float32, like float64, but in that case we downcast to float32 for no reason.
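A minimal sketch of the promotion mentioned above (this is just standard PyTorch type promotion; the tensors are illustrative, not taken from the model code):

import torch

weight = torch.ones(4)                    # float32, like the nn.Parameter in the module
x = torch.randn(4, dtype=torch.bfloat16)  # low-precision "hidden_states"

# Mirrors `self.weight * hidden_states.to(input_dtype)` from the old forward:
out = weight * x
print(out.dtype)  # torch.float32 -- the float32 weight promotes the product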

Here is the PyTorch default implementation of rms_norm when it is not using custom CUDA or MPS kernels. There we can see that it does pretty much the same thing, except that it casts the output of the op back to the input dtype, rather than casting the hidden_states that is later multiplied with the weight (and that is what causes this weird dtype change).
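For reference, a rough Python paraphrase of that composite path (simplified for illustration; not the exact PyTorch source) could look like this:

import torch

def rms_norm_composite_sketch(x, weight, eps):
    # Sketch of the non-CUDA/MPS composite path described above.
    input_dtype = x.dtype
    # Low-precision inputs are upcast for the accumulation.
    acc = x.to(torch.float32) if input_dtype in (torch.float16, torch.bfloat16) else x
    out = acc * torch.rsqrt(acc.pow(2).mean(-1, keepdim=True) + eps)
    if weight is not None:
        out = out * weight
    # The final result, after the weight multiply, is cast back to the input dtype.
    return out.to(input_dtype)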

Can you try with randomly initialized weights with a mean of 1 but a little bit of variance instead, so we can see if everything is equivalent with the original?

I did the same testing, just with the addition of

with torch.no_grad():
    random_weight = torch.randn(hidden_size, device=device) * 0.05 + 1.0
    hf_module.weight.copy_(random_weight)
    new_module.weight.copy_(random_weight)

and everything still works.

And also, thank you for taking the time to review this!
