26 changes: 23 additions & 3 deletions lm_engine/hf_models/modeling_utils/normalization.py
@@ -4,6 +4,8 @@

from __future__ import annotations

from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -27,9 +29,11 @@ def __init__(
self,
normalized_shape: int,
eps: float = 1e-6,
initialization_function: Callable | None = None,
use_padding_free_transformer: bool = False,
sequence_parallel: bool = False,
) -> LayerNorm:
self.initialization_function = initialization_function
super().__init__(normalized_shape, eps=eps)
Comment on lines +36 to 37
Contributor


critical

There's a multiple inheritance issue here. LayerNorm inherits from nn.LayerNorm and DTensorModule, but DTensorModule.__init__ is never called. This will cause an AttributeError when self.is_tp_enabled is accessed later. You should explicitly call the initializers of both parent classes.

Since nn.LayerNorm.__init__ calls reset_parameters(), which you've modified to depend on self.initialization_function, the order of initialization is important. The DTensorModule should be initialized first.

Suggested change
self.initialization_function = initialization_function
super().__init__(normalized_shape, eps=eps)
DTensorModule.__init__(self)
self.initialization_function = initialization_function
nn.LayerNorm.__init__(self, normalized_shape, eps=eps)


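To make the failure mode concrete, here is a minimal standalone sketch (hypothetical Mixin/Base/Child classes, not the real modules) of how a non-cooperative base class drops the mixin out of the initialization chain when only super().__init__() is called:

```python
# Minimal sketch: Base.__init__ does not chain to super().__init__(),
# so calling only super().__init__(...) in Child would skip Mixin entirely.


class Mixin:
    def __init__(self) -> None:
        self.is_tp_enabled = False  # attribute the subclass reads later


class Base:
    def __init__(self, n: int) -> None:
        self.n = n  # no super().__init__() call, so Mixin.__init__ never runs


class Child(Base, Mixin):
    def __init__(self, n: int) -> None:
        # With only `super().__init__(n)`, `is_tp_enabled` would be unset
        # and the first access would raise AttributeError.
        Mixin.__init__(self)    # initialize the mixin explicitly, first
        Base.__init__(self, n)  # then the main base class


child = Child(8)
print(child.is_tp_enabled)  # prints False instead of raising AttributeError
```

In the suggested change, DTensorModule.__init__ and the initialization_function assignment both run before nn.LayerNorm.__init__, so those attributes already exist when nn.LayerNorm.__init__ triggers reset_parameters().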
if self.is_tp_enabled:
@@ -55,8 +59,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

return x

@torch.no_grad()
def reset_parameters(self) -> None:
super().reset_parameters()
if self.initialization_function is None:
super().reset_parameters()
else:
self.initialization_function(self.weight, self.bias)

mark_parameter_as_initialized(self.weight)
mark_parameter_as_initialized(self.bias)

@@ -66,9 +75,11 @@ def __init__(
self,
normalized_shape: int,
eps: float = 1e-6,
initialization_function: Callable | None = None,
use_padding_free_transformer: bool = False,
sequence_parallel: bool = False,
) -> RMSNorm:
self.initialization_function = initialization_function
super().__init__(normalized_shape, eps=eps)
Comment on lines +82 to 83
Contributor


critical

Similar to LayerNorm, this __init__ method has a multiple inheritance issue. RMSNorm inherits from nn.RMSNorm and DTensorModule, but DTensorModule.__init__ is never called. This will cause an AttributeError when self.is_tp_enabled is accessed. You should explicitly call the initializers of both parent classes.

Suggested change
self.initialization_function = initialization_function
super().__init__(normalized_shape, eps=eps)
DTensorModule.__init__(self)
self.initialization_function = initialization_function
nn.RMSNorm.__init__(self, normalized_shape, eps=eps)


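As a usage sketch of the new hook (the "rmsnorm" registry key and the 0.95 constant are assumptions for illustration; the keyword parameters and import path come from this diff):

```python
# Hypothetical caller of the new `initialization_function` parameter.
import torch
import torch.nn as nn

from lm_engine.hf_models.modeling_utils.normalization import get_normalization_function


def small_scale_init(weight: torch.Tensor) -> None:
    # RMSNorm has only a weight; start the scale slightly below 1 instead of ones
    nn.init.constant_(weight, 0.95)


norm = get_normalization_function(
    "rmsnorm",  # assumed registry key, not confirmed by this diff
    normalized_shape=4096,
    eps=1e-5,
    initialization_function=small_scale_init,
)
```

Per the reset_parameters changes in this diff, for LayerNorm the callable receives both the weight and the bias, while for p_norm the added assert requires initialization_function to stay None.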
if self.is_tp_enabled:
@@ -103,8 +114,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

return x

@torch.no_grad()
def reset_parameters(self) -> None:
super().reset_parameters()
if self.initialization_function is None:
super().reset_parameters()
else:
self.initialization_function(self.weight)

mark_parameter_as_initialized(self.weight)


@@ -156,6 +172,7 @@ def get_normalization_function(
normalized_shape: int,
eps: float = 1e-5,
p: int | None = None,
initialization_function: Callable | None = None,
use_padding_free_transformer: bool = False,
sequence_parallel: bool = False,
) -> LayerNorm | RMSNorm | PNorm:
@@ -172,10 +189,13 @@
if normalization_function in _NORMALIZATION_FUNCTIONS:
if normalization_function == "p_norm":
assert p is not None
assert initialization_function is None
normalization = _NORMALIZATION_FUNCTIONS[normalization_function](**kwargs, p=p)
else:
assert p is None
normalization = _NORMALIZATION_FUNCTIONS[normalization_function](**kwargs)
normalization = _NORMALIZATION_FUNCTIONS[normalization_function](
**kwargs, initialization_function=initialization_function
)
else:
raise ValueError(f"unexpected `normalization_function` {normalization_function}")
