30 changes: 30 additions & 0 deletions lm_engine/hf_models/__init__.py
@@ -2,6 +2,8 @@
# Copyright (c) 2025, Mayank Mishra
# **************************************************

import torch

from .config import CommonConfig
from .loss import get_autoregressive_language_modeling_loss, is_aux_loss_zero
from .mixins import CausalLMOutputWithPast, PipelineParallelInput, PipelineParallelOutput
@@ -39,3 +41,31 @@


register_model_classes()


def _patch_granitemoehybrid_weight_init() -> None:
    try:
        from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
            GraniteMoeHybridMambaLayer,
            GraniteMoeHybridPreTrainedModel,
        )
    except Exception:
        return

    if getattr(GraniteMoeHybridPreTrainedModel._init_weights, "_lm_engine_patched", False):
        return

    def _init_weights(self, module):
        super(GraniteMoeHybridPreTrainedModel, self)._init_weights(module)
        if isinstance(module, GraniteMoeHybridMambaLayer):
            module.dt_bias.data.fill_(1.0)
            module.A_log.data.copy_(
                torch.log(torch.arange(1, module.num_heads + 1, dtype=module.A_log.dtype, device=module.A_log.device))
            )
            module.D.data.fill_(1.0)

    _init_weights._lm_engine_patched = True
    GraniteMoeHybridPreTrainedModel._init_weights = _init_weights
Comment on lines +58 to +68
Contributor


high

Monkey-patching _init_weights by calling super()._init_weights inside the replacement function will bypass any initialization logic defined in the original GraniteMoeHybridPreTrainedModel._init_weights method. In the transformers library, model-specific PreTrainedModel subclasses typically implement _init_weights to handle their specific layers. By replacing the method and calling super(), you are skipping the original class's logic and calling the parent's (likely PreTrainedModel) method instead.

To correctly extend the existing initialization, you should save the original method and call it within your patch.

Suggested change
    def _init_weights(self, module):
        super(GraniteMoeHybridPreTrainedModel, self)._init_weights(module)
        if isinstance(module, GraniteMoeHybridMambaLayer):
            module.dt_bias.data.fill_(1.0)
            module.A_log.data.copy_(
                torch.log(torch.arange(1, module.num_heads + 1, dtype=module.A_log.dtype, device=module.A_log.device))
            )
            module.D.data.fill_(1.0)
    _init_weights._lm_engine_patched = True
    GraniteMoeHybridPreTrainedModel._init_weights = _init_weights
    original_init_weights = GraniteMoeHybridPreTrainedModel._init_weights
    def _init_weights(self, module):
        original_init_weights(self, module)
        if isinstance(module, GraniteMoeHybridMambaLayer):
            module.dt_bias.data.fill_(1.0)
            module.A_log.data.copy_(
                torch.log(torch.arange(1, module.num_heads + 1, dtype=module.A_log.dtype, device=module.A_log.device))
            )
            module.D.data.fill_(1.0)
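For context, a minimal standalone sketch (hypothetical Parent/Child classes, not the transformers ones) of why the super() call inside the patched function skips the subclass logic, while calling a saved reference to the original method preserves it:

class Parent:
    def _init_weights(self, module):
        print("parent init")

class Child(Parent):
    def _init_weights(self, module):
        print("child-specific init")  # the logic the super()-based patch would lose

original_init_weights = Child._init_weights  # saved before patching

def patched(self, module):
    super(Child, self)._init_weights(module)  # resolves to Parent, skipping Child's original method
    print("extra init")

def patched_correctly(self, module):
    original_init_weights(self, module)  # runs the original Child logic first
    print("extra init")

Child._init_weights = patched
Child()._init_weights(None)  # prints: parent init / extra init

Child._init_weights = patched_correctly
Child()._init_weights(None)  # prints: child-specific init / extra init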



_patch_granitemoehybrid_weight_init()
19 changes: 2 additions & 17 deletions lm_engine/hf_models/config/__init__.py
@@ -11,15 +11,7 @@

from ...utils import BaseArgs, divide_if_divisible
from .mlp import _MLPArgs, _MoEArgs
from .sequence_mixer import (
    _CausalConvolution,
    _GatedDeltaNetArgs,
    _GRUArgs,
    _M2RNNArgs,
    _Mamba2Args,
    _RNNArgs,
    _SoftmaxAttentionArgs,
)
from .sequence_mixer import _GatedDeltaNetArgs, _GRUArgs, _M2RNNArgs, _Mamba2Args, _RNNArgs, _SoftmaxAttentionArgs


def _hold_base_args(key: str) -> Callable:
@@ -37,7 +29,6 @@ def _run(self, *args, **kwargs):


_SEQUENCE_MIXER_CONFIG_CLASSES = {
"causal_convolution": _CausalConvolution,
"gru": _GRUArgs,
"m2rnn": _M2RNNArgs,
"mamba2": _Mamba2Args,
@@ -183,13 +174,7 @@ def _set_sequence_mixer_blocks(self) -> None:
        self.sequence_mixer_blocks = [{} for _ in range(self.num_layers)]

        sequence_mixer_blocks: list[
            _CausalConvolution
            | _GRUArgs
            | _Mamba2Args
            | _RNNArgs
            | _M2RNNArgs
            | _SoftmaxAttentionArgs
            | _GatedDeltaNetArgs
            _GRUArgs | _Mamba2Args | _RNNArgs | _M2RNNArgs | _SoftmaxAttentionArgs | _GatedDeltaNetArgs
        ] = []
        for i in range(self.num_layers):
            sequence_mixer_block = deepcopy(self.sequence_mixer_blocks[i])
13 changes: 0 additions & 13 deletions lm_engine/hf_models/config/sequence_mixer.py
@@ -134,19 +134,6 @@ def model_post_init(self, __context: Any) -> None:
assert self.sequence_mixer_type == "m2rnn"


class _CausalConvolution(BaseArgs):
    sequence_mixer_type: str = "causal_convolution"
    activation_function: str = "silu"
    in_channels: int
    out_channels: int
    kernel_size: int
    num_groups: int
    add_bias: bool = False

    def model_post_init(self, __context: Any) -> None:
        assert self.sequence_mixer_type == "causal_convolution"

class _GatedDeltaNetArgs(_SoftPlusDecayArgs):
    sequence_mixer_type: str = "gated_deltanet"
    k_head_dim: int
1 change: 0 additions & 1 deletion lm_engine/hf_models/modeling_utils/__init__.py
@@ -3,7 +3,6 @@
# **************************************************

from .activations import get_activation_function, is_glu
from .convolution import ParameterizedConv1d
from .dropout import Dropout
from .dtensor_module import DTensorModule
from .embedding import ParameterizedEmbedding, get_tensor_parallel_vocab_info
57 changes: 0 additions & 57 deletions lm_engine/hf_models/modeling_utils/convolution.py

This file was deleted.

155 changes: 155 additions & 0 deletions lm_engine/hf_models/modeling_utils/depthwise_causal_convolution.py
@@ -0,0 +1,155 @@
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...enums import Kernel
from ...kernels import is_kernel_allowed
from ...utils import divide_if_divisible, is_causal_conv1d_available
from ..parameter import (
    mark_parameter_as_initialized,
    mark_parameter_as_mup_learning_rate,
Comment on lines +13 to +16
Contributor


medium

The following imports are unused in this file: divide_if_divisible and mark_parameter_as_mup_learning_rate.

from ...utils import is_causal_conv1d_available
from ..parameter import (
    mark_parameter_as_initialized,
    mark_parameter_as_no_weight_decay,
)

    mark_parameter_as_no_weight_decay,
)
from .activations import get_activation_function


if is_causal_conv1d_available():
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update


def _apply_mask_to_padding_states(hidden_states: torch.Tensor, attention_mask: torch.Tensor | None) -> torch.Tensor:
"""
Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
"""
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
Contributor


medium

The check attention_mask.shape[0] > 1 prevents the mask from being applied when the batch size is 1. Padding can still exist in a single-sequence batch if the sequence length is less than the maximum length. The mask should be applied whenever it is provided to ensure padding tokens are correctly zeroed out.

Suggested change
    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
    if attention_mask is not None and attention_mask.shape[1] > 1:
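A quick illustration (made-up shapes) of the case the stricter check skips: a batch containing a single right-padded sequence still carries zeros in its attention mask, so the multiply is needed to zero out the padded positions.

import torch

hidden_states = torch.ones(1, 4, 3)  # batch 1, sequence length 4, hidden size 3
attention_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])  # only the first 2 tokens are real
masked = hidden_states * attention_mask[:, :, None]  # last two positions become zero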

        dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

    return hidden_states


class DepthwiseCausalConvolution(nn.Conv1d):
    def __init__(
        self,
        hidden_size: int,
        kernel_size: int,
        activation_function: str,
        add_bias: bool,
        std: float | None,
        use_padding_free_transformer: bool,
    ) -> DepthwiseCausalConvolution:
        if use_padding_free_transformer:
            raise NotImplementedError()

        self.std = std

        super().__init__(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size,
            padding=kernel_size - 1,
            groups=hidden_size,
            bias=add_bias,
        )

        self.activation_string = activation_function
        self.activation_function = get_activation_function(self.activation_string)
        self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"]
        self.kernel_size = kernel_size

        mark_parameter_as_no_weight_decay(self.bias)
Contributor


medium

If add_bias is False, self.bias will be None. Calling mark_parameter_as_no_weight_decay on None might lead to an error depending on its implementation. It is safer to check if the bias exists first.

Suggested change
        mark_parameter_as_no_weight_decay(self.bias)
        if self.bias is not None:
            mark_parameter_as_no_weight_decay(self.bias)
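For reference, nn.Conv1d registers its bias as None when constructed with bias=False, which is why the guard is needed; a small sketch:

import torch.nn as nn

conv = nn.Conv1d(in_channels=4, out_channels=4, kernel_size=3, groups=4, bias=False)
print(conv.bias)  # None -- any helper that touches conv.bias has to handle this case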


        self.reset_parameters()

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_state: torch.Tensor | None,
        attention_mask: torch.Tensor | None,
        output_state: bool,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        S = hidden_states.size(1)
        hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)

        if is_kernel_allowed(Kernel.causal_conv1d):
            if input_state is None:
                hidden_states = hidden_states.transpose(-1, -2)

                if output_state:
                    # F.pad trims the hidden_states if sequence_length > kernel_size
                    input_state = F.pad(hidden_states, (self.kernel_size - S, 0))

                hidden_states = causal_conv1d_fn(
                    x=hidden_states,
                    weight=self.weight.squeeze(1),
                    bias=self.bias,
                    activation=self.activation_string if self.use_activation_inside_kernel else None,
                )

                hidden_states = hidden_states.transpose(-1, -2)
            else:
                assert S == 1

                input_state_buffer = input_state.clone()

                hidden_states = causal_conv1d_update(
                    x=hidden_states,
                    conv_state=input_state_buffer,
                    weight=self.weight.squeeze(1),
                    bias=self.bias,
                    activation=self.activation_string if self.use_activation_inside_kernel else None,
                )

                input_state = input_state_buffer if output_state else None

            if not self.use_activation_inside_kernel:
                hidden_states = self.activation_function(hidden_states)
Comment on lines +80 to +112
Contributor


medium

The output of the convolution is not masked in the is_kernel_allowed path. If a bias is present, the output at padding positions will be non-zero (the bias value after activation). This creates an inconsistency with the else path (line 141) where the output is explicitly masked. You should apply _apply_mask_to_padding_states to the output in both paths.
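One possible shape of the fix (a sketch based on the reviewer's suggestion, not code from this PR): at the end of the kernel path, apply the same masking helper the non-kernel path uses so padding positions end up zero in both paths.

            if not self.use_activation_inside_kernel:
                hidden_states = self.activation_function(hidden_states)

            # mirror the non-kernel path: zero out padding positions after the convolution/activation
            hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)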

        else:
            if input_state is None:
                hidden_states = hidden_states.transpose(-1, -2)

                if output_state:
                    # F.pad trims the hidden_states if sequence_length > kernel_size
                    input_state = F.pad(hidden_states, (self.kernel_size - S, 0))

                hidden_states = super().forward(hidden_states)

                # removes padding on the right side of the sequence
                hidden_states = hidden_states[..., : 1 - self.kernel_size]
Contributor


medium

The slicing logic [: 1 - self.kernel_size] will result in an empty tensor if kernel_size is 1. While these models typically use larger kernels, the implementation should be robust to kernel_size=1.

Suggested change
                hidden_states = hidden_states[..., : 1 - self.kernel_size]
                # removes padding on the right side of the sequence
                if self.kernel_size > 1:
                    hidden_states = hidden_states[..., : 1 - self.kernel_size]
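A two-line check of the edge case (plain PyTorch, independent of this module): with kernel_size == 1 the slice's upper bound is 0, so the whole sequence dimension disappears.

import torch

x = torch.arange(5)
print(x[: 1 - 1].shape)  # torch.Size([0]) -- kernel_size == 1 empties the tensor
print(x[: 1 - 3].shape)  # torch.Size([3]) -- kernel_size == 3 just trims the 2 padded positions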

                hidden_states = hidden_states.transpose(-1, -2)
            else:
                assert S == 1

                input_state = input_state.roll(shifts=-1, dims=-1)
                input_state[..., -1] = hidden_states[:, 0]

                hidden_states = (input_state * self.weight.squeeze(1)).sum(dim=-1)
                hidden_states = hidden_states[:, None, :]
                if self.bias is not None:
                    hidden_states = hidden_states + self.bias

                if not output_state:
                    input_state = None

            hidden_states = self.activation_function(hidden_states)
            hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)

        return hidden_states, input_state

    @torch.no_grad()
    def reset_parameters(self) -> None:
        if self.std is None:
            super().reset_parameters()
        else:
            nn.init.normal_(self.weight, mean=0, std=self.std)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias.zero_()

        mark_parameter_as_initialized(self.weight)
        mark_parameter_as_initialized(self.bias)
Contributor


medium

Similar to the initialization, mark_parameter_as_initialized should only be called on self.bias if it is not None.

Suggested change
        mark_parameter_as_initialized(self.bias)
        mark_parameter_as_initialized(self.weight)
        if self.bias is not None:
            mark_parameter_as_initialized(self.bias)
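As a quick sanity check of the padding-and-trim scheme the new module relies on (standalone sketch, not this module's API): a depthwise nn.Conv1d with padding=kernel_size - 1 produces kernel_size - 1 extra positions on the right, and trimming them back yields a causal convolution whose output length matches the input.

import torch
import torch.nn as nn

kernel_size, hidden = 4, 8
conv = nn.Conv1d(hidden, hidden, kernel_size, padding=kernel_size - 1, groups=hidden)

x = torch.randn(2, hidden, 16)  # (batch, channels, sequence)
y = conv(x)  # length 16 + kernel_size - 1 = 19
y = y[..., : 1 - kernel_size]  # trim the right side back to length 16
print(y.shape)  # torch.Size([2, 8, 16])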

@@ -9,7 +9,6 @@
    interleave_query_key_value_tensor_for_attention,
    split_query_key_value_tensor_for_attention,
)
from .causal_convolution import CausalConvolution
from .gated_deltanet import GatedDeltaNet
from .gru import GRU
from .m2rnn import M2RNN
@@ -18,7 +17,7 @@
from .utils import flash_attention


SEQUENCE_MIXER_TYPE = Attention | CausalConvolution | GRU | Mamba2 | RNN | GatedDeltaNet
SEQUENCE_MIXER_TYPE = Attention | GRU | Mamba2 | RNN | GatedDeltaNet


def get_sequence_mixer(
@@ -29,25 +28,7 @@ def get_sequence_mixer(

    is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled()

    if sequence_mixer_type == "causal_convolution":
        assert not is_tp_enabled
        return CausalConvolution(
            hidden_size=config.hidden_size,
            in_channels=block.in_channels,
            out_channels=block.out_channels,
            kernel_size=block.kernel_size,
            num_groups=block.num_groups,
            activation_function=block.activation_function,
            add_bias=block.add_bias,
            initializer_range=config.initializer_range,
            m_width=config.m_width,
            init_method=config.init_method,
            num_layers=config.num_layers,
            layer_idx=layer_idx,
            use_depth_scaled_init=config.use_depth_scaled_init,
            use_padding_free_transformer=use_padding_free_transformer,
        )
    elif sequence_mixer_type == "gru":
    if sequence_mixer_type == "gru":
        assert not is_tp_enabled
        return GRU(
            input_size=config.hidden_size,