diff --git a/lm_engine/hf_models/modeling_utils/normalization.py b/lm_engine/hf_models/modeling_utils/normalization.py
index 0a73a99a4..4bd883d9b 100644
--- a/lm_engine/hf_models/modeling_utils/normalization.py
+++ b/lm_engine/hf_models/modeling_utils/normalization.py
@@ -4,6 +4,8 @@
 from __future__ import annotations
 
+from typing import Callable
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -27,9 +29,11 @@ def __init__(
         self,
         normalized_shape: int,
         eps: float = 1e-6,
+        initialization_function: Callable | None = None,
         use_padding_free_transformer: bool = False,
         sequence_parallel: bool = False,
     ) -> LayerNorm:
+        self.initialization_function = initialization_function
         super().__init__(normalized_shape, eps=eps)
 
         if self.is_tp_enabled:
@@ -55,8 +59,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         return x
 
+    @torch.no_grad()
     def reset_parameters(self) -> None:
-        super().reset_parameters()
+        if self.initialization_function is None:
+            super().reset_parameters()
+        else:
+            self.initialization_function(self.weight, self.bias)
+
         mark_parameter_as_initialized(self.weight)
         mark_parameter_as_initialized(self.bias)
 
@@ -66,9 +75,11 @@ def __init__(
         self,
         normalized_shape: int,
         eps: float = 1e-6,
+        initialization_function: Callable | None = None,
         use_padding_free_transformer: bool = False,
         sequence_parallel: bool = False,
     ) -> RMSNorm:
+        self.initialization_function = initialization_function
         super().__init__(normalized_shape, eps=eps)
 
         if self.is_tp_enabled:
@@ -103,8 +114,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         return x
 
+    @torch.no_grad()
     def reset_parameters(self) -> None:
-        super().reset_parameters()
+        if self.initialization_function is None:
+            super().reset_parameters()
+        else:
+            self.initialization_function(self.weight)
+
         mark_parameter_as_initialized(self.weight)
 
 
@@ -156,6 +172,7 @@ def get_normalization_function(
     normalized_shape: int,
     eps: float = 1e-5,
     p: int | None = None,
+    initialization_function: Callable | None = None,
     use_padding_free_transformer: bool = False,
     sequence_parallel: bool = False,
 ) -> LayerNorm | RMSNorm | PNorm:
@@ -172,10 +189,13 @@ def get_normalization_function(
 
     if normalization_function in _NORMALIZATION_FUNCTIONS:
         if normalization_function == "p_norm":
             assert p is not None
+            assert initialization_function is None
             normalization = _NORMALIZATION_FUNCTIONS[normalization_function](**kwargs, p=p)
         else:
             assert p is None
-            normalization = _NORMALIZATION_FUNCTIONS[normalization_function](**kwargs)
+            normalization = _NORMALIZATION_FUNCTIONS[normalization_function](
+                **kwargs, initialization_function=initialization_function
+            )
     else:
         raise ValueError(f"unexpected `normalization_function` {normalization_function}")
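
Note: a minimal usage sketch of the new parameter, not part of the diff. The registry key "layernorm" and the 768 shape are assumptions for illustration; the import path is taken from the file location, and the initializer signature follows the call sites above (LayerNorm calls it as fn(weight, bias), RMSNorm as fn(weight)), so a default of bias=None covers both. Per the last hunk, "p_norm" asserts that no custom initializer is passed.

    import torch

    from lm_engine.hf_models.modeling_utils.normalization import get_normalization_function


    def small_scale_init(weight: torch.Tensor, bias: torch.Tensor | None = None) -> None:
        # reset_parameters now runs under @torch.no_grad(), so plain in-place ops are safe
        weight.fill_(0.9)  # start the norm scale slightly below the default of ones
        if bias is not None:
            bias.zero_()


    # The custom initializer replaces the default super().reset_parameters() path;
    # since self.initialization_function is assigned before super().__init__, it
    # already takes effect when the base class initializes the parameters.
    norm = get_normalization_function(
        "layernorm",  # assumed key in _NORMALIZATION_FUNCTIONS
        normalized_shape=768,
        initialization_function=small_scale_init,
    )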