diff --git a/src/transformers/models/aimv2/modeling_aimv2.py b/src/transformers/models/aimv2/modeling_aimv2.py index 4c7e81ed59cc..0a66e3abc578 100644 --- a/src/transformers/models/aimv2/modeling_aimv2.py +++ b/src/transformers/models/aimv2/modeling_aimv2.py @@ -90,11 +90,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index 77a3d65478d6..3f7416f90450 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -65,11 +65,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index 779f4a63e378..f135c4fe697c 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -72,11 +72,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 96a6a82da91d..a409c9271a37 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -54,11 +54,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 2428222e0dbe..c0dd1d61d162 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1009,11 +1009,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index 73597dd98d82..c336724ba77f 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -51,11 +51,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index d4b19101c861..2d91f1eb8c2c 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -75,11 +75,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 1bf2179deec6..4ef5c5813c78 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -55,11 +55,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index 835275f4e20c..ae780ae257ff 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -225,11 +225,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 87da76281717..65699c7cd131 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -108,11 +108,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/cwm/modeling_cwm.py b/src/transformers/models/cwm/modeling_cwm.py index df9760ed1ba7..787d437981ac 100644 --- a/src/transformers/models/cwm/modeling_cwm.py +++ b/src/transformers/models/cwm/modeling_cwm.py @@ -253,11 +253,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py index 89230d7a80b2..58262b420969 100644 --- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -158,11 +158,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index cfd8d91dfb9a..048d50d4c724 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -44,11 +44,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 3a0ddf6e3f90..0f5cd60e4bcc 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -118,11 +118,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 99524915b9f6..5a2b6d7649a2 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -518,11 +518,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index b9ebf9856264..e74ee4f5e1be 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -61,11 +61,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 65671913b27f..21bf73813cb3 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -198,11 +198,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py index b53ddf923e70..6f5d6f83aab4 100644 --- a/src/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -277,11 +277,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py index ccd05fe26347..58b75bb85535 100644 --- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -52,11 +52,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py index f4d0ce11255f..e3c52b33da71 100644 --- a/src/transformers/models/evolla/modeling_evolla.py +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -953,11 +953,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/exaone4/modeling_exaone4.py b/src/transformers/models/exaone4/modeling_exaone4.py index cb70c9cff142..27e5633a0615 100644 --- a/src/transformers/models/exaone4/modeling_exaone4.py +++ b/src/transformers/models/exaone4/modeling_exaone4.py @@ -58,11 +58,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index a6fd7a5aba99..6b43e0868f97 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -1077,11 +1077,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 8a508e2de54c..e0ee373bba40 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -295,11 +295,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index c982c36f9aab..2252e07452f5 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -344,11 +344,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py index 84e6dd3bd77d..0f43aeba0f30 100644 --- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -321,11 +321,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index ff5e0a00cc0d..26b4a2906204 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -55,11 +55,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 373d49bc942c..bc4c55cb5bf9 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -56,11 +56,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" @@ -475,11 +471,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index fc7d6fd40a80..912ae9abbc7f 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -321,11 +321,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 42de2e0724f3..2c718be7ed4b 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -196,11 +196,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index f722ad416a2f..008605f7e951 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -53,11 +53,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index c9e7245956f3..f9e6d383a480 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -1104,11 +1104,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 606a59390e6e..b92032910c0b 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -104,11 +104,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py index 4d184a0b1982..e96064ffe19b 100644 --- a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +++ b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py @@ -53,11 +53,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py index 281a50a9e2cc..79e8348a988d 100644 --- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py @@ -53,11 +53,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index dcfa6d9cd23b..a7ed73ca3122 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -529,11 +529,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 38d6f29c3f04..657cc3645f64 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -386,11 +386,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 51691e0ba4ab..6064ae12bf67 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -54,11 +54,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index fbda366fc319..5619abc3d92c 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -70,11 +70,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index b102a111e10f..1694332d5539 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -56,11 +56,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index f1d639d16bbd..a9d2b94a5c65 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -56,11 +56,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py index 73b9c4a8fde0..cb060644b7d1 100644 --- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py @@ -58,11 +58,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e3adac5d117d..0f7b3029f8c1 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -60,11 +60,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/longcat_flash/modeling_longcat_flash.py b/src/transformers/models/longcat_flash/modeling_longcat_flash.py index 4135bce33d83..8f019768cba0 100644 --- a/src/transformers/models/longcat_flash/modeling_longcat_flash.py +++ b/src/transformers/models/longcat_flash/modeling_longcat_flash.py @@ -55,11 +55,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 004ed68cef23..3992c502c59f 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -60,11 +60,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 60c7e2d49eed..b257983df723 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -193,11 +193,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 0f6e2a1d3efc..08b20f7f61e1 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -48,11 +48,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 1faff1f4dcea..3b200de8c555 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -143,11 +143,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index a2d303782bdd..a584dee7a2cc 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -376,11 +376,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index f078518e0c1f..66945db1c451 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -50,11 +50,7 @@ def __init__(self, hidden_size, eps=1e-5): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/ovis2/modeling_ovis2.py b/src/transformers/models/ovis2/modeling_ovis2.py index d990e08190b6..7be3076f176d 100644 --- a/src/transformers/models/ovis2/modeling_ovis2.py +++ b/src/transformers/models/ovis2/modeling_ovis2.py @@ -102,11 +102,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 29b3d2847ed1..7c36d666841a 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -289,11 +289,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index eab15068d252..5db69165fdfb 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -1137,11 +1137,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 12e41214094d..c63e0c7b3287 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -556,11 +556,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index 3e6d468a8184..c910d1fbaf66 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -293,11 +293,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 8bda140d3cdb..f58c74a6f77e 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -63,11 +63,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 477694d5fb2b..37c0fa8d45fe 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -296,11 +296,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 1be0487cea98..3203e406db01 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1405,11 +1405,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" @@ -1625,11 +1621,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index eab677bce4fe..3b5fd3090d7b 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -55,11 +55,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/seed_oss/modeling_seed_oss.py b/src/transformers/models/seed_oss/modeling_seed_oss.py index 682193ca8d51..da9da1fc6a69 100644 --- a/src/transformers/models/seed_oss/modeling_seed_oss.py +++ b/src/transformers/models/seed_oss/modeling_seed_oss.py @@ -55,11 +55,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index e23d4993e84c..4934e152155b 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -272,11 +272,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 0ec1b90f89da..e2df304ab780 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -127,11 +127,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index fb51f5add858..a7a11206755d 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -72,11 +72,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 5b5f532cebf7..fe67adf5a248 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -89,11 +89,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return nn.functional.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"