From 918a369a83a52611782e815f89ee2b8605d839ba Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 9 Apr 2026 15:26:46 -0700
Subject: [PATCH 1/3] add init

Signed-off-by: Mayank Mishra
---
 .../hf_models/modeling_utils/normalization.py | 20 +++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/lm_engine/hf_models/modeling_utils/normalization.py b/lm_engine/hf_models/modeling_utils/normalization.py
index 0a73a99a4..2f2fd56cd 100644
--- a/lm_engine/hf_models/modeling_utils/normalization.py
+++ b/lm_engine/hf_models/modeling_utils/normalization.py
@@ -4,6 +4,8 @@
 
 from __future__ import annotations
 
+from typing import Callable
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -27,9 +29,11 @@ def __init__(
         self,
         normalized_shape: int,
         eps: float = 1e-6,
+        initialization_function: Callable | None = None,
         use_padding_free_transformer: bool = False,
         sequence_parallel: bool = False,
     ) -> LayerNorm:
+        self.initialization_function = initialization_function
         super().__init__(normalized_shape, eps=eps)
 
         if self.is_tp_enabled:
@@ -55,8 +59,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         return x
 
+    @torch.no_grad()
     def reset_parameters(self) -> None:
-        super().reset_parameters()
+        if self.initialization_function is None:
+            super().reset_parameters()
+        else:
+            self.initialization_function(self.weight, self.bias)
+
         mark_parameter_as_initialized(self.weight)
         mark_parameter_as_initialized(self.bias)
 
@@ -66,9 +75,11 @@ def __init__(
         self,
         normalized_shape: int,
         eps: float = 1e-6,
+        initialization_function: Callable | None = None,
         use_padding_free_transformer: bool = False,
         sequence_parallel: bool = False,
     ) -> RMSNorm:
+        self.initialization_function = initialization_function
         super().__init__(normalized_shape, eps=eps)
 
         if self.is_tp_enabled:
@@ -103,8 +114,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         return x
 
+    @torch.no_grad()
     def reset_parameters(self) -> None:
-        super().reset_parameters()
+        if self.initialization_function is None:
+            super().reset_parameters()
+        else:
+            self.initialization_function(self.weight)
+
         mark_parameter_as_initialized(self.weight)
 
 

From cd9f7506637c9f2a8f0fc943d1e936b5d0e49bcd Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 9 Apr 2026 15:27:24 -0700
Subject: [PATCH 2/3] add init

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/modeling_utils/normalization.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/lm_engine/hf_models/modeling_utils/normalization.py b/lm_engine/hf_models/modeling_utils/normalization.py
index 2f2fd56cd..06d008b11 100644
--- a/lm_engine/hf_models/modeling_utils/normalization.py
+++ b/lm_engine/hf_models/modeling_utils/normalization.py
@@ -172,6 +172,7 @@ def get_normalization_function(
     normalized_shape: int,
     eps: float = 1e-5,
     p: int | None = None,
+    initialization_function: Callable | None = None,
     use_padding_free_transformer: bool = False,
     sequence_parallel: bool = False,
 ) -> LayerNorm | RMSNorm | PNorm:
@@ -191,7 +192,9 @@ def get_normalization_function(
             normalization = _NORMALIZATION_FUNCTIONS[normalization_function](**kwargs, p=p)
         else:
             assert p is None
-            normalization = _NORMALIZATION_FUNCTIONS[normalization_function](**kwargs)
+            normalization = _NORMALIZATION_FUNCTIONS[normalization_function](
+                **kwargs, initialization_function=initialization_function
+            )
     else:
         raise ValueError(f"unexpected `normalization_function` {normalization_function}")
 

From 7618eccc1a1516c0222a334c6e750c759238c8a5 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 9 Apr 2026 15:31:31 -0700
Subject: [PATCH 3/3] add init

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/modeling_utils/normalization.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lm_engine/hf_models/modeling_utils/normalization.py b/lm_engine/hf_models/modeling_utils/normalization.py
index 06d008b11..4bd883d9b 100644
--- a/lm_engine/hf_models/modeling_utils/normalization.py
+++ b/lm_engine/hf_models/modeling_utils/normalization.py
@@ -189,6 +189,7 @@ def get_normalization_function(
     if normalization_function in _NORMALIZATION_FUNCTIONS:
         if normalization_function == "p_norm":
             assert p is not None
+            assert initialization_function is None
             normalization = _NORMALIZATION_FUNCTIONS[normalization_function](**kwargs, p=p)
         else:
             assert p is None
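
Reviewer note: a minimal usage sketch of the new initialization_function hook, not part of the series itself. It assumes get_normalization_function takes the normalization name as its first argument and that "rmsnorm" is a registered key in _NORMALIZATION_FUNCTIONS (neither is shown in these patches). Per patch 1, a LayerNorm initializer receives (weight, bias) while an RMSNorm initializer receives only (weight); per patch 3, the hook must be None for "p_norm".

import torch
import torch.nn as nn

from lm_engine.hf_models.modeling_utils.normalization import get_normalization_function


def my_rmsnorm_init(weight: torch.Tensor) -> None:
    # RMSNorm initializers receive only the weight (patch 1);
    # start the gain slightly below 1 instead of the default init
    nn.init.constant_(weight, 0.95)


# "rmsnorm" is an assumed registry key; eps defaults to 1e-5 (patch 2)
norm = get_normalization_function(
    "rmsnorm",
    normalized_shape=768,
    initialization_function=my_rmsnorm_init,
)
norm.reset_parameters()  # dispatches to my_rmsnorm_init under @torch.no_grad() (patch 1)

For LayerNorm the hook's signature would instead be my_init(weight, bias), matching the self.initialization_function(self.weight, self.bias) call in patch 1; passing initialization_function together with "p_norm" now trips the assertion added in patch 3.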