From aa70aa00a10438f927d81d947935b3b6fe2a22e1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 28 Apr 2026 15:33:37 -0700 Subject: [PATCH 01/19] fix Signed-off-by: Mayank Mishra --- accelerated-model-architectures | 2 +- lm_engine/hf_models/config/__init__.py | 19 +- lm_engine/hf_models/config/sequence_mixer.py | 13 - .../sequence_mixer_blocks/__init__.py | 23 +- .../causal_convolution.py | 259 +++++++----------- 5 files changed, 99 insertions(+), 217 deletions(-) diff --git a/accelerated-model-architectures b/accelerated-model-architectures index e99632015..f9cf82992 160000 --- a/accelerated-model-architectures +++ b/accelerated-model-architectures @@ -1 +1 @@ -Subproject commit e996320150a371c2afcb4d1ad405ce0b0eb811fa +Subproject commit f9cf829925e0ebcc1b54cf7ee6200ebc3eeb06c3 diff --git a/lm_engine/hf_models/config/__init__.py b/lm_engine/hf_models/config/__init__.py index daa6499db..f32472530 100644 --- a/lm_engine/hf_models/config/__init__.py +++ b/lm_engine/hf_models/config/__init__.py @@ -11,15 +11,7 @@ from ...utils import BaseArgs, divide_if_divisible from .mlp import _MLPArgs, _MoEArgs -from .sequence_mixer import ( - _CausalConvolution, - _GatedDeltaNetArgs, - _GRUArgs, - _M2RNNArgs, - _Mamba2Args, - _RNNArgs, - _SoftmaxAttentionArgs, -) +from .sequence_mixer import _GatedDeltaNetArgs, _GRUArgs, _M2RNNArgs, _Mamba2Args, _RNNArgs, _SoftmaxAttentionArgs def _hold_base_args(key: str) -> Callable: @@ -37,7 +29,6 @@ def _run(self, *args, **kwargs): _SEQUENCE_MIXER_CONFIG_CLASSES = { - "causal_convolution": _CausalConvolution, "gru": _GRUArgs, "m2rnn": _M2RNNArgs, "mamba2": _Mamba2Args, @@ -183,13 +174,7 @@ def _set_sequence_mixer_blocks(self) -> None: self.sequence_mixer_blocks = [{} for _ in range(self.num_layers)] sequence_mixer_blocks: list[ - _CausalConvolution - | _GRUArgs - | _Mamba2Args - | _RNNArgs - | _M2RNNArgs - | _SoftmaxAttentionArgs - | _GatedDeltaNetArgs + _GRUArgs | _Mamba2Args | _RNNArgs | _M2RNNArgs | _SoftmaxAttentionArgs | _GatedDeltaNetArgs ] = [] for i in range(self.num_layers): sequence_mixer_block = deepcopy(self.sequence_mixer_blocks[i]) diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py index 6028c95db..5516139f5 100644 --- a/lm_engine/hf_models/config/sequence_mixer.py +++ b/lm_engine/hf_models/config/sequence_mixer.py @@ -134,19 +134,6 @@ def model_post_init(self, __context: Any) -> None: assert self.sequence_mixer_type == "m2rnn" -class _CausalConvolution(BaseArgs): - sequence_mixer_type: str = "causal_convolution" - activation_function: str = "silu" - in_channels: int - out_channels: int - kernel_size: int - num_groups: int - add_bias: bool = False - - def model_post_init(self, __context: Any) -> None: - assert self.sequence_mixer_type == "causal_convolution" - - class _GatedDeltaNetArgs(_SoftPlusDecayArgs): sequence_mixer_type: str = "gated_deltanet" k_head_dim: int diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index 3e2f97d37..d624f02b1 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -9,7 +9,6 @@ interleave_query_key_value_tensor_for_attention, split_query_key_value_tensor_for_attention, ) -from .causal_convolution import CausalConvolution from .gated_deltanet import GatedDeltaNet from .gru import GRU from .m2rnn import M2RNN @@ -18,7 +17,7 @@ from .utils import flash_attention -SEQUENCE_MIXER_TYPE = Attention | CausalConvolution | GRU | Mamba2 | RNN | GatedDeltaNet +SEQUENCE_MIXER_TYPE = Attention | GRU | Mamba2 | RNN | GatedDeltaNet def get_sequence_mixer( @@ -29,25 +28,7 @@ def get_sequence_mixer( is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() - if sequence_mixer_type == "causal_convolution": - assert not is_tp_enabled - return CausalConvolution( - hidden_size=config.hidden_size, - in_channels=block.in_channels, - out_channels=block.out_channels, - kernel_size=block.kernel_size, - num_groups=block.num_groups, - activation_function=block.activation_function, - add_bias=block.add_bias, - initializer_range=config.initializer_range, - m_width=config.m_width, - init_method=config.init_method, - num_layers=config.num_layers, - layer_idx=layer_idx, - use_depth_scaled_init=config.use_depth_scaled_init, - use_padding_free_transformer=use_padding_free_transformer, - ) - elif sequence_mixer_type == "gru": + if sequence_mixer_type == "gru": assert not is_tp_enabled return GRU( input_size=config.hidden_size, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py index c730539bc..e8a697f09 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py @@ -34,107 +34,9 @@ def _apply_mask_to_padding_states(hidden_states: torch.Tensor, attention_mask: t return hidden_states -def causal_convolution( - hidden_states: torch.Tensor, - input_state: torch.Tensor | None, - attention_mask: torch.Tensor | None, - conv1d_weight: torch.Tensor, - conv1d_bias: torch.Tensor | None, - conv1d_num_groups: int, - return_cache_state: bool, - activation_string: str, - conv1d_padding: int, - conv1d_stride: int = 1, -) -> tuple[torch.Tensor, torch.Tensor]: - casual_conv1d_compatible = conv1d_num_groups == conv1d_weight.size(0) and conv1d_weight.size(1) == 1 - sequence_length = hidden_states.size(1) - kernel_size = conv1d_weight.size(-1) - - assert conv1d_stride == 1 - assert conv1d_padding == kernel_size - 1 - - hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask) - - if is_kernel_allowed(Kernel.causal_conv1d) and casual_conv1d_compatible: - use_activation_inside_kernel = activation_string in [None, "silu", "swish"] - - if input_state is None: - hidden_states = hidden_states.transpose(-1, -2) - - if return_cache_state: - # F.pad trims the hidden_states if sequence_length > kernel_size - input_state = F.pad(hidden_states, (kernel_size - sequence_length, 0)) - - hidden_states = causal_conv1d_fn( - x=hidden_states, - weight=conv1d_weight.squeeze(1), - bias=conv1d_bias, - activation=activation_string if use_activation_inside_kernel else None, - ) - - hidden_states = hidden_states.transpose(-1, -2) - else: - assert sequence_length == 1 - - # we clone to prevent modification in-place - # torch compile can remove the clone if its not needed - # this is to prevent silent incorrectness down the line in the model - input_state_buffer = input_state.clone() - hidden_states = causal_conv1d_update( - x=hidden_states, - conv_state=input_state_buffer, - weight=conv1d_weight.squeeze(1), - bias=conv1d_bias, - activation=activation_string if use_activation_inside_kernel else None, - ) - input_state = input_state_buffer if return_cache_state else None - - if not use_activation_inside_kernel: - hidden_states = get_activation_function(activation_string)(hidden_states) - else: - if input_state is None: - hidden_states = hidden_states.transpose(-1, -2) - - if return_cache_state: - # F.pad trims the hidden_states if sequence_length > kernel_size - input_state = F.pad(hidden_states, (kernel_size - sequence_length, 0)) - - hidden_states = F.conv1d( - input=hidden_states, - weight=conv1d_weight, - bias=conv1d_bias, - stride=conv1d_stride, - padding=conv1d_padding, - groups=conv1d_num_groups, - ) - - # removes padding on the right side of the sequence - hidden_states = hidden_states[..., : 1 - kernel_size] - hidden_states = hidden_states.transpose(-1, -2) - else: - assert sequence_length == 1 - - input_state = input_state.roll(shifts=-1, dims=-1) - input_state[..., -1] = hidden_states[:, 0] - - hidden_states = (input_state * conv1d_weight.squeeze(1)).sum(dim=-1) - hidden_states = hidden_states[:, None, :] - if conv1d_bias is not None: - hidden_states = hidden_states + conv1d_bias - - if not return_cache_state: - input_state = None - - hidden_states = get_activation_function(activation_string)(hidden_states) - hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask) - - return hidden_states, input_state - - class CausalConvolution(nn.Module): def __init__( self, - hidden_size: int, in_channels: int, out_channels: int, kernel_size: int, @@ -146,7 +48,6 @@ def __init__( init_method: str, num_layers: int, layer_idx: int, - use_depth_scaled_init: bool, use_padding_free_transformer: bool, ) -> CausalConvolution: super().__init__() @@ -161,27 +62,8 @@ def __init__( self.layer_idx = layer_idx self.activation_string = activation_function - self.input_projection = ParameterizedLinear( - hidden_size, - in_channels, - bias=add_bias, - std=_get_std_for_linear( - initializer_range=initializer_range, - init_method=init_method, - m_width=m_width, - fan_in=hidden_size, - num_layers=num_layers, - use_depth_scaled_init=False, - ), - ) - - divide_if_divisible(in_channels, num_groups, "") - divide_if_divisible(out_channels, num_groups, "") - - if is_glu(self.activation_string): - intermediate_size = divide_if_divisible(out_channels, 2, "") - else: - intermediate_size = out_channels + divide_if_divisible(in_channels, num_groups) + divide_if_divisible(out_channels, num_groups) self.conv1d = ParameterizedConv1d( in_channels=in_channels, @@ -202,60 +84,107 @@ def __init__( self.activation_function = get_activation_function(self.activation_string) - self.output_projection = ParameterizedLinear( - intermediate_size, - hidden_size, - bias=add_bias, - std=_get_std_for_linear( - initializer_range=initializer_range, - init_method=init_method, - m_width=m_width, - fan_in=intermediate_size, - num_layers=num_layers, - use_depth_scaled_init=use_depth_scaled_init, - ), - ) - self.casual_conv1d_compatible = self.num_groups == self.in_channels == self.out_channels self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"] mark_parameter_as_mup_learning_rate(self.input_projection.weight) mark_parameter_as_mup_learning_rate(self.conv1d.weight) - mark_parameter_as_mup_learning_rate(self.output_projection.weight) mark_parameter_as_no_weight_decay(self.input_projection.bias) mark_parameter_as_no_weight_decay(self.conv1d.bias) - mark_parameter_as_no_weight_decay(self.output_projection.bias) def forward( - self, x: torch.Tensor, cache_params: GenerationCache | None = None, attention_mask: torch.Tensor | None = None - ) -> torch.Tensor: - input_state = ( - None if cache_params is None else cache_params.get_cache(layer_idx=self.layer_idx, empty_value=None) - ) - - S = x.size(1) - x = self.input_projection(x) - - x, input_state = causal_convolution( - hidden_states=x, - input_state=input_state, - attention_mask=attention_mask, - conv1d_weight=self.conv1d.weight, - conv1d_bias=self.conv1d.bias, - conv1d_num_groups=self.num_groups, - return_cache_state=cache_params is not None, - activation_string=self.activation_string, - conv1d_padding=self.kernel_size - 1, - conv1d_stride=1, - ) - - if cache_params is not None: - cache_params.update( - states=(GenerationState(state=input_state, method=ConstantCache, num_tokens_added=S),), - layer_idx=self.layer_idx, - ) + self, + hidden_states: torch.Tensor, + input_state: torch.Tensor | None, + attention_mask: torch.Tensor | None, + return_cache_state: bool, + activation_string: str, + conv1d_padding: int, + conv1d_stride: int = 1, + ) -> tuple[torch.Tensor, torch.Tensor]: + W = self.conv1d.weight + b = self.conv1d.bias + + casual_conv1d_compatible = self.num_groups == W.size(0) and W.size(1) == 1 + sequence_length = hidden_states.size(1) + + assert conv1d_stride == 1 + assert conv1d_padding == self.kernel_size - 1 - x = self.output_projection(x) + hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask) - return x + if is_kernel_allowed(Kernel.causal_conv1d) and casual_conv1d_compatible: + use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"] + + if input_state is None: + hidden_states = hidden_states.transpose(-1, -2) + + if return_cache_state: + # F.pad trims the hidden_states if sequence_length > kernel_size + input_state = F.pad(hidden_states, (self.kernel_size - sequence_length, 0)) + + hidden_states = causal_conv1d_fn( + x=hidden_states, + weight=W.squeeze(1), + bias=b, + activation=self.activation_string if use_activation_inside_kernel else None, + ) + + hidden_states = hidden_states.transpose(-1, -2) + else: + assert sequence_length == 1 + + # we clone to prevent modification in-place + # torch compile can remove the clone if its not needed + # this is to prevent silent incorrectness down the line in the model + input_state_buffer = input_state.clone() + hidden_states = causal_conv1d_update( + x=hidden_states, + conv_state=input_state_buffer, + weight=W.squeeze(1), + bias=b, + activation=self.activation_string if use_activation_inside_kernel else None, + ) + input_state = input_state_buffer if return_cache_state else None + + if not use_activation_inside_kernel: + hidden_states = self.activation_function(hidden_states) + else: + if input_state is None: + hidden_states = hidden_states.transpose(-1, -2) + + if return_cache_state: + # F.pad trims the hidden_states if sequence_length > kernel_size + input_state = F.pad(hidden_states, (self.kernel_size - sequence_length, 0)) + + hidden_states = F.conv1d( + input=hidden_states, + weight=W, + bias=b, + stride=conv1d_stride, + padding=conv1d_padding, + groups=self.num_groups, + ) + + # removes padding on the right side of the sequence + hidden_states = hidden_states[..., : 1 - self.kernel_size] + hidden_states = hidden_states.transpose(-1, -2) + else: + assert sequence_length == 1 + + input_state = input_state.roll(shifts=-1, dims=-1) + input_state[..., -1] = hidden_states[:, 0] + + hidden_states = (input_state * W.squeeze(1)).sum(dim=-1) + hidden_states = hidden_states[:, None, :] + if b is not None: + hidden_states = hidden_states + b + + if not return_cache_state: + input_state = None + + hidden_states = self.activation_function(hidden_states) + hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask) + + return hidden_states, input_state From ea6910d1613dff6ee7c0d01d26a1c4b84236ddff Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 28 Apr 2026 15:33:50 -0700 Subject: [PATCH 02/19] fix Signed-off-by: Mayank Mishra --- accelerated-model-architectures | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accelerated-model-architectures b/accelerated-model-architectures index f9cf82992..7f09c3228 160000 --- a/accelerated-model-architectures +++ b/accelerated-model-architectures @@ -1 +1 @@ -Subproject commit f9cf829925e0ebcc1b54cf7ee6200ebc3eeb06c3 +Subproject commit 7f09c32284b5e2365975716e72e6d0a2ca60222e From 3967ae3a8faad826e4b07483afb0a150be01c368 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 28 Apr 2026 15:37:57 -0700 Subject: [PATCH 03/19] fix Signed-off-by: Mayank Mishra --- .../causal_convolution.py | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py index e8a697f09..b2ec2796f 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py @@ -11,12 +11,10 @@ from ....enums import Kernel from ....kernels import is_kernel_allowed from ....utils import divide_if_divisible, is_causal_conv1d_available -from ...cache import ConstantCache, GenerationCache, GenerationState from ...parameter import mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay -from ..activations import get_activation_function, is_glu +from ..activations import get_activation_function from ..convolution import ParameterizedConv1d from ..init_utils import _get_std_for_linear -from ..linear import ParameterizedLinear if is_causal_conv1d_available(): @@ -99,41 +97,37 @@ def forward( input_state: torch.Tensor | None, attention_mask: torch.Tensor | None, return_cache_state: bool, - activation_string: str, conv1d_padding: int, conv1d_stride: int = 1, ) -> tuple[torch.Tensor, torch.Tensor]: W = self.conv1d.weight b = self.conv1d.bias - casual_conv1d_compatible = self.num_groups == W.size(0) and W.size(1) == 1 - sequence_length = hidden_states.size(1) + S = hidden_states.size(1) assert conv1d_stride == 1 assert conv1d_padding == self.kernel_size - 1 hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask) - if is_kernel_allowed(Kernel.causal_conv1d) and casual_conv1d_compatible: - use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"] - + if is_kernel_allowed(Kernel.causal_conv1d) and self.casual_conv1d_compatible: if input_state is None: hidden_states = hidden_states.transpose(-1, -2) if return_cache_state: # F.pad trims the hidden_states if sequence_length > kernel_size - input_state = F.pad(hidden_states, (self.kernel_size - sequence_length, 0)) + input_state = F.pad(hidden_states, (self.kernel_size - S, 0)) hidden_states = causal_conv1d_fn( x=hidden_states, weight=W.squeeze(1), bias=b, - activation=self.activation_string if use_activation_inside_kernel else None, + activation=self.activation_string if self.use_activation_inside_kernel else None, ) hidden_states = hidden_states.transpose(-1, -2) else: - assert sequence_length == 1 + assert S == 1 # we clone to prevent modification in-place # torch compile can remove the clone if its not needed @@ -144,11 +138,11 @@ def forward( conv_state=input_state_buffer, weight=W.squeeze(1), bias=b, - activation=self.activation_string if use_activation_inside_kernel else None, + activation=self.activation_string if self.use_activation_inside_kernel else None, ) input_state = input_state_buffer if return_cache_state else None - if not use_activation_inside_kernel: + if not self.use_activation_inside_kernel: hidden_states = self.activation_function(hidden_states) else: if input_state is None: @@ -156,7 +150,7 @@ def forward( if return_cache_state: # F.pad trims the hidden_states if sequence_length > kernel_size - input_state = F.pad(hidden_states, (self.kernel_size - sequence_length, 0)) + input_state = F.pad(hidden_states, (self.kernel_size - S, 0)) hidden_states = F.conv1d( input=hidden_states, @@ -171,7 +165,7 @@ def forward( hidden_states = hidden_states[..., : 1 - self.kernel_size] hidden_states = hidden_states.transpose(-1, -2) else: - assert sequence_length == 1 + assert S == 1 input_state = input_state.roll(shifts=-1, dims=-1) input_state[..., -1] = hidden_states[:, 0] From 76a645d25e81f04a4185b4496ff5847a409c309d Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 28 Apr 2026 16:04:52 -0700 Subject: [PATCH 04/19] fix Signed-off-by: Mayank Mishra --- eval_table.tex | 27 +++++++++++++++++++ .../causal_convolution.py | 3 --- 2 files changed, 27 insertions(+), 3 deletions(-) create mode 100644 eval_table.tex diff --git a/eval_table.tex b/eval_table.tex new file mode 100644 index 000000000..17371cbce --- /dev/null +++ b/eval_table.tex @@ -0,0 +1,27 @@ +\begin{table}[htbp] +\centering +\resizebox{\textwidth}{!}{% +\begin{tabular}{lccccccccccccc} +\toprule +Model & Wiki PPL & LMB PPL & LAMBADA & HellaSwag & PIQA & ARC-E & ARC-C & WinoGrande & BoolQ & OBQA & COPA & SciQ & Avg Acc \\ +\midrule +GRU & 27.61 & 85.99 & 23.19 & 46.67 & 70.13 & 59.26 & 30.72 & 51.22 & 52.91 & 34.40 & 65.00 & 75.50 & 50.90 \\ +M$^2$RNN & 22.98 & 33.25 & 34.91 & 47.81 & 70.51 & \underline{59.76} & \underline{30.80} & 51.70 & 53.91 & \textbf{35.20} & \textbf{73.00} & 78.90 & 53.65 \\ +Gated DeltaNet & 22.69 & 36.65 & 32.78 & 47.03 & 70.40 & 58.38 & 29.78 & \textbf{53.43} & 56.45 & 32.20 & 70.00 & \underline{80.30} & 53.08 \\ +M$^2$RNN-TT & 22.81 & 32.82 & 34.56 & 47.83 & \underline{70.95} & 58.75 & 30.12 & 50.59 & 54.04 & 33.40 & 69.00 & 77.60 & 52.68 \\ +Gated M$^2$RNN-TT & \underline{22.34} & \underline{29.72} & \underline{35.36} & \underline{48.08} & 70.57 & 58.80 & 30.55 & 51.07 & \underline{56.91} & 34.60 & 71.00 & 78.40 & 53.53 \\ +\midrule +Hybrid M$^2$RNN & \underline{21.32} & 38.10 & 34.76 & 48.38 & 70.84 & 59.01 & 30.89 & 52.72 & 53.12 & \underline{35.00} & 72.00 & \textbf{82.80} & 53.95 \\ +Hybrid M$^2$RNN-TT & 21.34 & 36.84 & 34.08 & \underline{48.42} & \textbf{71.06} & 60.65 & 30.72 & 52.72 & 53.67 & 34.40 & 68.00 & 80.00 & 53.37 \\ +Hybrid Gated DeltaNet & 21.68 & 29.55 & \underline{36.06} & 48.00 & 70.67 & 59.39 & 29.95 & \underline{53.12} & 55.44 & 34.20 & \textbf{73.00} & 82.10 & 54.19 \\ +Hybrid GDN + M$^2$RNN (1L) & 21.41 & \underline{28.59} & 35.61 & 48.37 & 70.84 & \textbf{60.86} & 30.38 & 51.85 & 53.98 & 34.00 & 68.00 & 80.80 & 53.47 \\ +Hybrid GDN + M$^2$RNN-TT (1L) & 21.41 & 30.28 & 35.55 & 48.03 & 70.89 & 58.42 & \underline{31.14} & 52.49 & \underline{58.50} & 32.80 & 69.00 & 80.60 & 53.74 \\ +\midrule +Hybrid GDN + M$^2$RNN (3L) & \textbf{21.25} & 26.79 & \textbf{37.49} & \textbf{48.50} & 70.24 & \underline{60.73} & \textbf{31.91} & \underline{52.64} & 55.54 & \underline{34.00} & \underline{72.00} & 81.50 & 54.46 \\ +Hybrid GDN + M$^2$RNN-TT (3L) & 21.40 & \textbf{26.77} & 37.40 & 48.48 & \underline{70.95} & 58.63 & 30.89 & 51.93 & \textbf{59.88} & 33.20 & \underline{72.00} & \underline{81.70} & 54.50 \\ +\bottomrule +\end{tabular} +} +\caption{Zero-shot evaluation results for 400M cosine models. Accuracy metrics are reported as \%; PPL metrics are perplexity (lower is better). \textbf{Bold} = best overall; \underline{underline} = best in group.} +\label{tab:eval_results} +\end{table} diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py index b2ec2796f..41fa11c3c 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py @@ -85,10 +85,7 @@ def __init__( self.casual_conv1d_compatible = self.num_groups == self.in_channels == self.out_channels self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"] - mark_parameter_as_mup_learning_rate(self.input_projection.weight) mark_parameter_as_mup_learning_rate(self.conv1d.weight) - - mark_parameter_as_no_weight_decay(self.input_projection.bias) mark_parameter_as_no_weight_decay(self.conv1d.bias) def forward( From d09913c14c6b7f1de44dd8efe586d925a946218d Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 28 Apr 2026 16:29:00 -0700 Subject: [PATCH 05/19] fix Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/convolution.py | 156 ++++++++++++++- .../causal_convolution.py | 181 ------------------ 2 files changed, 154 insertions(+), 183 deletions(-) delete mode 100644 lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py diff --git a/lm_engine/hf_models/modeling_utils/convolution.py b/lm_engine/hf_models/modeling_utils/convolution.py index 1dcf09758..308f392ce 100644 --- a/lm_engine/hf_models/modeling_utils/convolution.py +++ b/lm_engine/hf_models/modeling_utils/convolution.py @@ -2,13 +2,26 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** - from __future__ import annotations import torch import torch.nn as nn +import torch.nn.functional as F + +from ...enums import Kernel +from ...kernels import is_kernel_allowed +from ...utils import divide_if_divisible, is_causal_conv1d_available +from ..parameter import ( + mark_parameter_as_initialized, + mark_parameter_as_mup_learning_rate, + mark_parameter_as_no_weight_decay, +) +from .activations import get_activation_function +from .convolution import ParameterizedConv1d + -from ..parameter import mark_parameter_as_initialized, mark_parameter_as_no_weight_decay +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update class ParameterizedConv1d(nn.Conv1d): @@ -44,6 +57,145 @@ def __init__( mark_parameter_as_no_weight_decay(self.bias) + +def _apply_mask_to_padding_states(hidden_states: torch.Tensor, attention_mask: torch.Tensor | None) -> torch.Tensor: + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +class CausalConvolution(nn.Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + groups: int, + activation_function: str, + add_bias: bool, + std: float | None, + use_padding_free_transformer: bool, + ) -> CausalConvolution: + if use_padding_free_transformer: + raise NotImplementedError() + + # _get_std_for_linear( + # initializer_range=initializer_range, + # init_method=init_method, + # m_width=m_width, + # fan_in=kernel_size, + # num_layers=num_layers, + # use_depth_scaled_init=False, + # ) + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=kernel_size - 1, + groups=groups, + bias=add_bias, + ) + + self.activation_string = activation_function + self.activation_function = get_activation_function(self.activation_string) + self.casual_conv1d_compatible = self.groups == self.in_channels == self.out_channels + self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"] + self.std = std + + divide_if_divisible(in_channels, groups) + divide_if_divisible(out_channels, groups) + + mark_parameter_as_mup_learning_rate(self.weight) + mark_parameter_as_no_weight_decay(self.bias) + + def forward( + self, + hidden_states: torch.Tensor, + input_state: torch.Tensor | None, + attention_mask: torch.Tensor | None, + return_cache_state: bool, + conv1d_padding: int, + conv1d_stride: int = 1, + ) -> tuple[torch.Tensor, torch.Tensor]: + S = hidden_states.size(1) + + assert conv1d_stride == 1 + assert conv1d_padding == self.kernel_size - 1 + + hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask) + + if is_kernel_allowed(Kernel.causal_conv1d) and self.casual_conv1d_compatible: + if input_state is None: + hidden_states = hidden_states.transpose(-1, -2) + + if return_cache_state: + # F.pad trims the hidden_states if sequence_length > kernel_size + input_state = F.pad(hidden_states, (self.kernel_size - S, 0)) + + hidden_states = causal_conv1d_fn( + x=hidden_states, + weight=self.weight.squeeze(1), + bias=self.bias, + activation=self.activation_string if self.use_activation_inside_kernel else None, + ) + + hidden_states = hidden_states.transpose(-1, -2) + else: + assert S == 1 + + # we clone to prevent modification in-place + # torch compile can remove the clone if its not needed + # this is to prevent silent incorrectness down the line in the model + input_state_buffer = input_state.clone() + hidden_states = causal_conv1d_update( + x=hidden_states, + conv_state=input_state_buffer, + weight=self.weight.squeeze(1), + bias=self.bias, + activation=self.activation_string if self.use_activation_inside_kernel else None, + ) + input_state = input_state_buffer if return_cache_state else None + + if not self.use_activation_inside_kernel: + hidden_states = self.activation_function(hidden_states) + else: + if input_state is None: + hidden_states = hidden_states.transpose(-1, -2) + + if return_cache_state: + # F.pad trims the hidden_states if sequence_length > kernel_size + input_state = F.pad(hidden_states, (self.kernel_size - S, 0)) + + hidden_states = super().forward(hidden_states) + + # removes padding on the right side of the sequence + hidden_states = hidden_states[..., : 1 - self.kernel_size] + hidden_states = hidden_states.transpose(-1, -2) + else: + assert S == 1 + + input_state = input_state.roll(shifts=-1, dims=-1) + input_state[..., -1] = hidden_states[:, 0] + + hidden_states = (input_state * self.weight.squeeze(1)).sum(dim=-1) + hidden_states = hidden_states[:, None, :] + if self.bias is not None: + hidden_states = hidden_states + self.bias + + if not return_cache_state: + input_state = None + + hidden_states = self.activation_function(hidden_states) + hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask) + + return hidden_states, input_state + @torch.no_grad() def reset_parameters(self) -> None: if self.std is None: diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py deleted file mode 100644 index 41fa11c3c..000000000 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py +++ /dev/null @@ -1,181 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ....enums import Kernel -from ....kernels import is_kernel_allowed -from ....utils import divide_if_divisible, is_causal_conv1d_available -from ...parameter import mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay -from ..activations import get_activation_function -from ..convolution import ParameterizedConv1d -from ..init_utils import _get_std_for_linear - - -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update - - -def _apply_mask_to_padding_states(hidden_states: torch.Tensor, attention_mask: torch.Tensor | None) -> torch.Tensor: - """ - Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 - """ - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: - dtype = hidden_states.dtype - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - - return hidden_states - - -class CausalConvolution(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - num_groups: int, - activation_function: str, - add_bias: bool, - initializer_range: float | None, - m_width: float, - init_method: str, - num_layers: int, - layer_idx: int, - use_padding_free_transformer: bool, - ) -> CausalConvolution: - super().__init__() - - if use_padding_free_transformer: - raise NotImplementedError() - - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.num_groups = num_groups - self.layer_idx = layer_idx - self.activation_string = activation_function - - divide_if_divisible(in_channels, num_groups) - divide_if_divisible(out_channels, num_groups) - - self.conv1d = ParameterizedConv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - bias=add_bias, - padding=kernel_size - 1, - groups=num_groups, - std=_get_std_for_linear( - initializer_range=initializer_range, - init_method=init_method, - m_width=m_width, - fan_in=kernel_size, - num_layers=num_layers, - use_depth_scaled_init=False, - ), - ) - - self.activation_function = get_activation_function(self.activation_string) - - self.casual_conv1d_compatible = self.num_groups == self.in_channels == self.out_channels - self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"] - - mark_parameter_as_mup_learning_rate(self.conv1d.weight) - mark_parameter_as_no_weight_decay(self.conv1d.bias) - - def forward( - self, - hidden_states: torch.Tensor, - input_state: torch.Tensor | None, - attention_mask: torch.Tensor | None, - return_cache_state: bool, - conv1d_padding: int, - conv1d_stride: int = 1, - ) -> tuple[torch.Tensor, torch.Tensor]: - W = self.conv1d.weight - b = self.conv1d.bias - - S = hidden_states.size(1) - - assert conv1d_stride == 1 - assert conv1d_padding == self.kernel_size - 1 - - hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask) - - if is_kernel_allowed(Kernel.causal_conv1d) and self.casual_conv1d_compatible: - if input_state is None: - hidden_states = hidden_states.transpose(-1, -2) - - if return_cache_state: - # F.pad trims the hidden_states if sequence_length > kernel_size - input_state = F.pad(hidden_states, (self.kernel_size - S, 0)) - - hidden_states = causal_conv1d_fn( - x=hidden_states, - weight=W.squeeze(1), - bias=b, - activation=self.activation_string if self.use_activation_inside_kernel else None, - ) - - hidden_states = hidden_states.transpose(-1, -2) - else: - assert S == 1 - - # we clone to prevent modification in-place - # torch compile can remove the clone if its not needed - # this is to prevent silent incorrectness down the line in the model - input_state_buffer = input_state.clone() - hidden_states = causal_conv1d_update( - x=hidden_states, - conv_state=input_state_buffer, - weight=W.squeeze(1), - bias=b, - activation=self.activation_string if self.use_activation_inside_kernel else None, - ) - input_state = input_state_buffer if return_cache_state else None - - if not self.use_activation_inside_kernel: - hidden_states = self.activation_function(hidden_states) - else: - if input_state is None: - hidden_states = hidden_states.transpose(-1, -2) - - if return_cache_state: - # F.pad trims the hidden_states if sequence_length > kernel_size - input_state = F.pad(hidden_states, (self.kernel_size - S, 0)) - - hidden_states = F.conv1d( - input=hidden_states, - weight=W, - bias=b, - stride=conv1d_stride, - padding=conv1d_padding, - groups=self.num_groups, - ) - - # removes padding on the right side of the sequence - hidden_states = hidden_states[..., : 1 - self.kernel_size] - hidden_states = hidden_states.transpose(-1, -2) - else: - assert S == 1 - - input_state = input_state.roll(shifts=-1, dims=-1) - input_state[..., -1] = hidden_states[:, 0] - - hidden_states = (input_state * W.squeeze(1)).sum(dim=-1) - hidden_states = hidden_states[:, None, :] - if b is not None: - hidden_states = hidden_states + b - - if not return_cache_state: - input_state = None - - hidden_states = self.activation_function(hidden_states) - hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask) - - return hidden_states, input_state From b9a7837b87e475d89fb9e39fce37d9786ee68d39 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 28 Apr 2026 16:35:14 -0700 Subject: [PATCH 06/19] fix Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/convolution.py | 68 ++----------------- 1 file changed, 6 insertions(+), 62 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/convolution.py b/lm_engine/hf_models/modeling_utils/convolution.py index 308f392ce..b4bf44464 100644 --- a/lm_engine/hf_models/modeling_utils/convolution.py +++ b/lm_engine/hf_models/modeling_utils/convolution.py @@ -17,47 +17,12 @@ mark_parameter_as_no_weight_decay, ) from .activations import get_activation_function -from .convolution import ParameterizedConv1d if is_causal_conv1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -class ParameterizedConv1d(nn.Conv1d): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - padding: str | int = 0, - dilation: int = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = "zeros", # TODO: refine this type - device=None, - dtype=None, - std: float | None = None, - ) -> ParameterizedConv1d: - self.std = std - super().__init__( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - mark_parameter_as_no_weight_decay(self.bias) - - def _apply_mask_to_padding_states(hidden_states: torch.Tensor, attention_mask: torch.Tensor | None) -> torch.Tensor: """ Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 @@ -69,48 +34,33 @@ def _apply_mask_to_padding_states(hidden_states: torch.Tensor, attention_mask: t return hidden_states -class CausalConvolution(nn.Conv1d): +class DepthwiseCausalConvolution(nn.Conv1d): def __init__( self, - in_channels: int, - out_channels: int, + hidden_size: int, kernel_size: int, - groups: int, activation_function: str, add_bias: bool, std: float | None, use_padding_free_transformer: bool, - ) -> CausalConvolution: + ) -> DepthwiseCausalConvolution: if use_padding_free_transformer: raise NotImplementedError() - # _get_std_for_linear( - # initializer_range=initializer_range, - # init_method=init_method, - # m_width=m_width, - # fan_in=kernel_size, - # num_layers=num_layers, - # use_depth_scaled_init=False, - # ) - super().__init__( - in_channels=in_channels, - out_channels=out_channels, + in_channels=hidden_size, + out_channels=hidden_size, kernel_size=kernel_size, padding=kernel_size - 1, - groups=groups, + groups=hidden_size, bias=add_bias, ) self.activation_string = activation_function self.activation_function = get_activation_function(self.activation_string) - self.casual_conv1d_compatible = self.groups == self.in_channels == self.out_channels self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"] self.std = std - divide_if_divisible(in_channels, groups) - divide_if_divisible(out_channels, groups) - mark_parameter_as_mup_learning_rate(self.weight) mark_parameter_as_no_weight_decay(self.bias) @@ -120,14 +70,8 @@ def forward( input_state: torch.Tensor | None, attention_mask: torch.Tensor | None, return_cache_state: bool, - conv1d_padding: int, - conv1d_stride: int = 1, ) -> tuple[torch.Tensor, torch.Tensor]: S = hidden_states.size(1) - - assert conv1d_stride == 1 - assert conv1d_padding == self.kernel_size - 1 - hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask) if is_kernel_allowed(Kernel.causal_conv1d) and self.casual_conv1d_compatible: From 588ea11f04f64de8f5ddb392c6a530aa9efa1dd4 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 28 Apr 2026 16:46:28 -0700 Subject: [PATCH 07/19] fix Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/gated_deltanet.py | 23 ++++++------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py index 17ba86282..aa592d7a2 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py @@ -13,12 +13,11 @@ from ....utils import divide_if_divisible, is_fla_available from ...cache import ConstantCache, GenerationCache, GenerationState from ..activations import silu -from ..convolution import ParameterizedConv1d +from ..convolution import DepthwiseCausalConvolution from ..decay_gate import SoftplusDecayGate from ..init_utils import _get_std_for_linear from ..linear import ParameterizedLinear from ..normalization import get_normalization_function -from .causal_convolution import causal_convolution from .utils import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence @@ -105,13 +104,11 @@ def __init__( ) self.conv_size = conv_size - self.qkv_conv1d = ParameterizedConv1d( - in_channels=2 * self.key_dim + self.value_dim, - out_channels=2 * self.key_dim + self.value_dim, + self.qkv_conv1d = DepthwiseCausalConvolution( + hidden_size=2 * self.key_dim + self.value_dim, kernel_size=conv_size, - padding=conv_size - 1, - groups=2 * self.key_dim + self.value_dim, - bias=False, + activation_function="silu", + add_bias=False, std=_get_std_for_linear( initializer_range=initializer_range, init_method=init_method, @@ -120,8 +117,8 @@ def __init__( num_layers=num_layers, use_depth_scaled_init=False, ), + use_padding_free_transformer=use_padding_free_transformer, ) - self.activation_string = "silu" self.o_norm = get_normalization_function("rmsnorm", self.v_head_dim, eps=norm_eps) self.o_proj = ParameterizedLinear( @@ -163,17 +160,11 @@ def forward( else: a, b = self.ab_proj(hidden_states).chunk(2, dim=-1) - qkv, c = causal_convolution( + qkv, c = self.qkv_conv1d( hidden_states=qkv, input_state=c, attention_mask=attention_mask, - conv1d_weight=self.qkv_conv1d.weight, - conv1d_bias=self.qkv_conv1d.bias, - conv1d_num_groups=qkv.size(-1), return_cache_state=cache_params is not None, - activation_string=self.activation_string, - conv1d_padding=self.conv_size - 1, - conv1d_stride=1, ) q, k, v = qkv.split((self.key_dim, self.key_dim, self.value_dim), dim=-1) From 0113fbfd8b9202be1d8c948262768e48e87d1209 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 28 Apr 2026 16:48:46 -0700 Subject: [PATCH 08/19] fix Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/gru.py | 22 ++++++------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py index 5ebc66993..846b3b2e3 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py @@ -17,11 +17,10 @@ mark_parameter_as_no_weight_decay, ) from ..activations import clip_gradients, get_activation_function, is_glu, sigmoid, silu, tanh -from ..convolution import ParameterizedConv1d +from ..convolution import DepthwiseCausalConvolution, ParameterizedConv1d from ..init_utils import _get_std_for_linear from ..linear import ParameterizedLinear from ..normalization import get_normalization_function -from .causal_convolution import causal_convolution from .utils import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence @@ -120,13 +119,11 @@ def __init__( if kernel_size is not None: assert not is_glu(self.activation_string) - self.conv1d = ParameterizedConv1d( - in_channels=self.state_size, - out_channels=self.state_size, + self.conv1d = DepthwiseCausalConvolution( + hidden_size=self.state_size, kernel_size=kernel_size, - bias=add_bias, - padding=kernel_size - 1, - groups=self.state_size, + activation_function=self.activation_string, + add_bias=add_bias, std=_get_std_for_linear( initializer_range=initializer_range, init_method=init_method, @@ -135,6 +132,7 @@ def __init__( num_layers=num_layers, use_depth_scaled_init=False, ), + use_padding_free_transformer=use_padding_free_transformer, ) mark_parameter_as_mup_learning_rate(self.conv1d.weight) @@ -206,17 +204,11 @@ def forward( if self.kernel_size is None: x = self.activation_function(x) else: - x, c = causal_convolution( + x, c = self.conv1d( hidden_states=x, input_state=c, attention_mask=attention_mask, - conv1d_weight=self.conv1d.weight, - conv1d_bias=self.conv1d.bias, - conv1d_num_groups=self.state_size, return_cache_state=cache_params is not None, - activation_string=self.activation_string, - conv1d_padding=self.kernel_size - 1, - conv1d_stride=1, ) x, xf, xr = [i.view(*i.size()[:-1], -1, self.state_head_dim) for i in (x, xf, xr)] From 87bcf9bb08cfb40affd857a71d5e360b891d52af Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 28 Apr 2026 16:52:15 -0700 Subject: [PATCH 09/19] fix Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/m2rnn.py | 22 ++++++------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py index 090dddf1b..a481f5697 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py @@ -20,12 +20,11 @@ mark_parameter_as_no_weight_decay, ) from ..activations import clip_gradients, is_glu, silu, tanh -from ..convolution import ParameterizedConv1d +from ..convolution import DepthwiseCausalConvolution from ..decay_gate import SoftplusDecayGate from ..init_utils import _get_std_for_linear from ..linear import ParameterizedLinear from ..normalization import get_normalization_function -from .causal_convolution import causal_convolution from .utils import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence @@ -132,13 +131,11 @@ def __init__( else: assert self.activation_string is None or not is_glu(self.activation_string) - self.conv1d = ParameterizedConv1d( - in_channels=self.conv_dim, - out_channels=self.conv_dim, + self.conv1d = DepthwiseCausalConvolution( + hidden_size=self.conv_dim, kernel_size=kernel_size, - bias=add_bias, - padding=kernel_size - 1, - groups=self.conv_dim, + activation_function=self.activation_string, + add_bias=add_bias, std=_get_std_for_linear( initializer_range=initializer_range, init_method=init_method, @@ -147,6 +144,7 @@ def __init__( num_layers=num_layers, use_depth_scaled_init=False, ), + use_padding_free_transformer=use_padding_free_transformer, ) mark_parameter_as_mup_learning_rate(self.conv1d.weight) @@ -211,17 +209,11 @@ def forward( f = self.decay_gate(f, final_exponential=True, output_dtype=f.dtype) if self.kernel_size is not None: - x, c = causal_convolution( + x, c = self.conv1d( hidden_states=x, input_state=c, attention_mask=attention_mask, - conv1d_weight=self.conv1d.weight, - conv1d_bias=self.conv1d.bias, - conv1d_num_groups=self.conv_dim, return_cache_state=cache_params is not None, - activation_string=self.activation_string, - conv1d_padding=self.kernel_size - 1, - conv1d_stride=1, ) q, k, v = x.split((self.q_shape, self.k_shape, self.v_shape), dim=-1) From 531778a56a3ff4c784e4149523a27693beb1754e Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 28 Apr 2026 17:04:11 -0700 Subject: [PATCH 10/19] fix Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/rnn.py | 22 ++++++------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index 8c1c64c38..7eb87459f 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -17,11 +17,10 @@ mark_parameter_as_no_weight_decay, ) from ..activations import clip_gradients, get_activation_function, is_glu, silu, tanh -from ..convolution import ParameterizedConv1d +from ..convolution import DepthwiseCausalConvolution from ..init_utils import _get_std_for_linear from ..linear import ParameterizedLinear from ..normalization import get_normalization_function -from .causal_convolution import causal_convolution from .utils import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence @@ -97,13 +96,11 @@ def __init__( if kernel_size is not None: assert not is_glu(self.activation_string) - self.conv1d = ParameterizedConv1d( - in_channels=self.state_size, - out_channels=self.state_size, + self.conv1d = DepthwiseCausalConvolution( + hidden_size=self.state_size, kernel_size=kernel_size, - bias=add_bias, - padding=kernel_size - 1, - groups=self.state_size, + activation_function=self.activation_string, + add_bias=add_bias, std=_get_std_for_linear( initializer_range=initializer_range, init_method=init_method, @@ -112,6 +109,7 @@ def __init__( num_layers=num_layers, use_depth_scaled_init=False, ), + use_padding_free_transformer=use_padding_free_transformer, ) mark_parameter_as_mup_learning_rate(self.conv1d.weight) @@ -176,17 +174,11 @@ def forward( if self.kernel_size is None: x = self.activation_function(x) else: - x, c = causal_convolution( + x, c = self.conv1d( hidden_states=x, input_state=c, attention_mask=attention_mask, - conv1d_weight=self.conv1d.weight, - conv1d_bias=self.conv1d.bias, - conv1d_num_groups=self.state_size, return_cache_state=cache_params is not None, - activation_string=self.activation_string, - conv1d_padding=self.kernel_size - 1, - conv1d_stride=1, ) x = x.view(*x.size()[:-1], -1, self.state_head_dim) From 02b8eb21fad53e0f0b982f960130c069fa7ad58f Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 28 Apr 2026 19:16:08 -0700 Subject: [PATCH 11/19] fix Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/convolution.py | 1 - .../modeling_utils/sequence_mixer_blocks/mamba2.py | 13 ++++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/convolution.py b/lm_engine/hf_models/modeling_utils/convolution.py index b4bf44464..7f6dc5744 100644 --- a/lm_engine/hf_models/modeling_utils/convolution.py +++ b/lm_engine/hf_models/modeling_utils/convolution.py @@ -61,7 +61,6 @@ def __init__( self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"] self.std = std - mark_parameter_as_mup_learning_rate(self.weight) mark_parameter_as_no_weight_decay(self.bias) def forward( diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index 7de4ee434..fc31c2233 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -18,7 +18,7 @@ mark_parameter_as_no_weight_decay, ) from ..activations import get_activation_function, silu -from ..convolution import ParameterizedConv1d +from ..convolution import DepthwiseCausalConvolution from ..decay_gate import SoftplusDecayGate from ..init_utils import _get_std_for_linear from ..linear import ParameterizedLinear @@ -144,13 +144,11 @@ def __init__( # 1D convolutional layer self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size - self.conv1d = ParameterizedConv1d( - in_channels=self.conv_dim, - out_channels=self.conv_dim, - bias=use_conv_bias, + self.conv1d = DepthwiseCausalConvolution( + hidden_size=self.conv_dim, kernel_size=self.conv_kernel_size, - groups=self.conv_dim, - padding=self.conv_kernel_size - 1, + activation_function=self.activation_string, + add_bias=add_bias, std=_get_std_for_linear( initializer_range=initializer_range, init_method=init_method, @@ -159,6 +157,7 @@ def __init__( num_layers=num_layers, use_depth_scaled_init=False, ), + use_padding_free_transformer=False, ) # projection of the input hidden states From 362ee507d0f03f83769119c63ed3db19c907870f Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 28 Apr 2026 19:20:00 -0700 Subject: [PATCH 12/19] fix Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/__init__.py | 1 - lm_engine/hf_models/modeling_utils/convolution.py | 5 ++++- .../hf_models/modeling_utils/sequence_mixer_blocks/gru.py | 2 +- .../hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py | 3 +-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/__init__.py b/lm_engine/hf_models/modeling_utils/__init__.py index 3404b87d3..f121a32bc 100644 --- a/lm_engine/hf_models/modeling_utils/__init__.py +++ b/lm_engine/hf_models/modeling_utils/__init__.py @@ -3,7 +3,6 @@ # ************************************************** from .activations import get_activation_function, is_glu -from .convolution import ParameterizedConv1d from .dropout import Dropout from .dtensor_module import DTensorModule from .embedding import ParameterizedEmbedding, get_tensor_parallel_vocab_info diff --git a/lm_engine/hf_models/modeling_utils/convolution.py b/lm_engine/hf_models/modeling_utils/convolution.py index 7f6dc5744..ea7b19a9c 100644 --- a/lm_engine/hf_models/modeling_utils/convolution.py +++ b/lm_engine/hf_models/modeling_utils/convolution.py @@ -47,6 +47,8 @@ def __init__( if use_padding_free_transformer: raise NotImplementedError() + self.std = std + super().__init__( in_channels=hidden_size, out_channels=hidden_size, @@ -59,10 +61,11 @@ def __init__( self.activation_string = activation_function self.activation_function = get_activation_function(self.activation_string) self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"] - self.std = std mark_parameter_as_no_weight_decay(self.bias) + self.reset_parameters() + def forward( self, hidden_states: torch.Tensor, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py index 846b3b2e3..e16b062ff 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py @@ -17,7 +17,7 @@ mark_parameter_as_no_weight_decay, ) from ..activations import clip_gradients, get_activation_function, is_glu, sigmoid, silu, tanh -from ..convolution import DepthwiseCausalConvolution, ParameterizedConv1d +from ..convolution import DepthwiseCausalConvolution from ..init_utils import _get_std_for_linear from ..linear import ParameterizedLinear from ..normalization import get_normalization_function diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index fc31c2233..8f2c9f56f 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -18,12 +18,11 @@ mark_parameter_as_no_weight_decay, ) from ..activations import get_activation_function, silu -from ..convolution import DepthwiseCausalConvolution +from ..convolution import DepthwiseCausalConvolution, _apply_mask_to_padding_states from ..decay_gate import SoftplusDecayGate from ..init_utils import _get_std_for_linear from ..linear import ParameterizedLinear from ..normalization import get_normalization_function -from .causal_convolution import _apply_mask_to_padding_states if is_mamba_2_ssm_available(): From 69d314f3aeeb763f88c0686884c7d1043637386a Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Tue, 28 Apr 2026 22:27:21 -0700 Subject: [PATCH 13/19] fix Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index 8f2c9f56f..367a7a0e7 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -273,7 +273,7 @@ def _torch_forward( ssm_state = torch.zeros( batch_size, self.num_heads, - divide_if_divisible(self.intermediate_size, self.num_heads, ""), + divide_if_divisible(self.intermediate_size, self.num_heads), self.ssm_state_size, device=projected_states.device, dtype=dtype, From 6fbd1c335b1ecadd9ebe6cdbe08621ee8b3e88f3 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 29 Apr 2026 14:19:52 -0700 Subject: [PATCH 14/19] add Signed-off-by: Mayank Mishra --- lm_engine/hf_models/__init__.py | 30 +++++++++++++++++++ .../hf_models/modeling_utils/convolution.py | 1 + 2 files changed, 31 insertions(+) diff --git a/lm_engine/hf_models/__init__.py b/lm_engine/hf_models/__init__.py index 1f876ae80..7e90cbffa 100644 --- a/lm_engine/hf_models/__init__.py +++ b/lm_engine/hf_models/__init__.py @@ -2,6 +2,8 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** +import torch + from .config import CommonConfig from .loss import get_autoregressive_language_modeling_loss, is_aux_loss_zero from .mixins import CausalLMOutputWithPast, PipelineParallelInput, PipelineParallelOutput @@ -39,3 +41,31 @@ register_model_classes() + + +def _patch_granitemoehybrid_weight_init() -> None: + try: + from transformers.models.granitemoehybrid.modeling_granitemoehybrid import ( + GraniteMoeHybridMambaLayer, + GraniteMoeHybridPreTrainedModel, + ) + except Exception: + return + + if getattr(GraniteMoeHybridPreTrainedModel._init_weights, "_lm_engine_patched", False): + return + + def _init_weights(self, module): + super(GraniteMoeHybridPreTrainedModel, self)._init_weights(module) + if isinstance(module, GraniteMoeHybridMambaLayer): + module.dt_bias.data.fill_(1.0) + module.A_log.data.copy_( + torch.log(torch.arange(1, module.num_heads + 1, dtype=module.A_log.dtype, device=module.A_log.device)) + ) + module.D.data.fill_(1.0) + + _init_weights._lm_engine_patched = True + GraniteMoeHybridPreTrainedModel._init_weights = _init_weights + + +_patch_granitemoehybrid_weight_init() diff --git a/lm_engine/hf_models/modeling_utils/convolution.py b/lm_engine/hf_models/modeling_utils/convolution.py index ea7b19a9c..8c10ed8f1 100644 --- a/lm_engine/hf_models/modeling_utils/convolution.py +++ b/lm_engine/hf_models/modeling_utils/convolution.py @@ -61,6 +61,7 @@ def __init__( self.activation_string = activation_function self.activation_function = get_activation_function(self.activation_string) self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"] + self.casual_conv1d_compatible = True mark_parameter_as_no_weight_decay(self.bias) From eebc57198c40721f8e825a92173124a6147e95cf Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 29 Apr 2026 14:21:13 -0700 Subject: [PATCH 15/19] add Signed-off-by: Mayank Mishra --- .../{convolution.py => depthwise_causal_convolution.py} | 3 +-- .../modeling_utils/sequence_mixer_blocks/gated_deltanet.py | 2 +- .../hf_models/modeling_utils/sequence_mixer_blocks/gru.py | 2 +- .../hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py | 2 +- .../hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py | 2 +- .../hf_models/modeling_utils/sequence_mixer_blocks/rnn.py | 2 +- 6 files changed, 6 insertions(+), 7 deletions(-) rename lm_engine/hf_models/modeling_utils/{convolution.py => depthwise_causal_convolution.py} (97%) diff --git a/lm_engine/hf_models/modeling_utils/convolution.py b/lm_engine/hf_models/modeling_utils/depthwise_causal_convolution.py similarity index 97% rename from lm_engine/hf_models/modeling_utils/convolution.py rename to lm_engine/hf_models/modeling_utils/depthwise_causal_convolution.py index 8c10ed8f1..b1a01def7 100644 --- a/lm_engine/hf_models/modeling_utils/convolution.py +++ b/lm_engine/hf_models/modeling_utils/depthwise_causal_convolution.py @@ -61,7 +61,6 @@ def __init__( self.activation_string = activation_function self.activation_function = get_activation_function(self.activation_string) self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"] - self.casual_conv1d_compatible = True mark_parameter_as_no_weight_decay(self.bias) @@ -77,7 +76,7 @@ def forward( S = hidden_states.size(1) hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask) - if is_kernel_allowed(Kernel.causal_conv1d) and self.casual_conv1d_compatible: + if is_kernel_allowed(Kernel.causal_conv1d): if input_state is None: hidden_states = hidden_states.transpose(-1, -2) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py index aa592d7a2..3b7f0a742 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py @@ -13,8 +13,8 @@ from ....utils import divide_if_divisible, is_fla_available from ...cache import ConstantCache, GenerationCache, GenerationState from ..activations import silu -from ..convolution import DepthwiseCausalConvolution from ..decay_gate import SoftplusDecayGate +from ..depthwise_causal_convolution import DepthwiseCausalConvolution from ..init_utils import _get_std_for_linear from ..linear import ParameterizedLinear from ..normalization import get_normalization_function diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py index e16b062ff..d066a44c8 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py @@ -17,7 +17,7 @@ mark_parameter_as_no_weight_decay, ) from ..activations import clip_gradients, get_activation_function, is_glu, sigmoid, silu, tanh -from ..convolution import DepthwiseCausalConvolution +from ..depthwise_causal_convolution import DepthwiseCausalConvolution from ..init_utils import _get_std_for_linear from ..linear import ParameterizedLinear from ..normalization import get_normalization_function diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py index a481f5697..fbcff95f5 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py @@ -20,8 +20,8 @@ mark_parameter_as_no_weight_decay, ) from ..activations import clip_gradients, is_glu, silu, tanh -from ..convolution import DepthwiseCausalConvolution from ..decay_gate import SoftplusDecayGate +from ..depthwise_causal_convolution import DepthwiseCausalConvolution from ..init_utils import _get_std_for_linear from ..linear import ParameterizedLinear from ..normalization import get_normalization_function diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index 367a7a0e7..ff3c125a5 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -18,8 +18,8 @@ mark_parameter_as_no_weight_decay, ) from ..activations import get_activation_function, silu -from ..convolution import DepthwiseCausalConvolution, _apply_mask_to_padding_states from ..decay_gate import SoftplusDecayGate +from ..depthwise_causal_convolution import DepthwiseCausalConvolution, _apply_mask_to_padding_states from ..init_utils import _get_std_for_linear from ..linear import ParameterizedLinear from ..normalization import get_normalization_function diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index 7eb87459f..8d1046c2f 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -17,7 +17,7 @@ mark_parameter_as_no_weight_decay, ) from ..activations import clip_gradients, get_activation_function, is_glu, silu, tanh -from ..convolution import DepthwiseCausalConvolution +from ..depthwise_causal_convolution import DepthwiseCausalConvolution from ..init_utils import _get_std_for_linear from ..linear import ParameterizedLinear from ..normalization import get_normalization_function From fdeaf36159dc2f570c5d619001cdde71263c2beb Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 29 Apr 2026 14:23:15 -0700 Subject: [PATCH 16/19] add Signed-off-by: Mayank Mishra --- eval_table.tex | 27 --------------------------- 1 file changed, 27 deletions(-) delete mode 100644 eval_table.tex diff --git a/eval_table.tex b/eval_table.tex deleted file mode 100644 index 17371cbce..000000000 --- a/eval_table.tex +++ /dev/null @@ -1,27 +0,0 @@ -\begin{table}[htbp] -\centering -\resizebox{\textwidth}{!}{% -\begin{tabular}{lccccccccccccc} -\toprule -Model & Wiki PPL & LMB PPL & LAMBADA & HellaSwag & PIQA & ARC-E & ARC-C & WinoGrande & BoolQ & OBQA & COPA & SciQ & Avg Acc \\ -\midrule -GRU & 27.61 & 85.99 & 23.19 & 46.67 & 70.13 & 59.26 & 30.72 & 51.22 & 52.91 & 34.40 & 65.00 & 75.50 & 50.90 \\ -M$^2$RNN & 22.98 & 33.25 & 34.91 & 47.81 & 70.51 & \underline{59.76} & \underline{30.80} & 51.70 & 53.91 & \textbf{35.20} & \textbf{73.00} & 78.90 & 53.65 \\ -Gated DeltaNet & 22.69 & 36.65 & 32.78 & 47.03 & 70.40 & 58.38 & 29.78 & \textbf{53.43} & 56.45 & 32.20 & 70.00 & \underline{80.30} & 53.08 \\ -M$^2$RNN-TT & 22.81 & 32.82 & 34.56 & 47.83 & \underline{70.95} & 58.75 & 30.12 & 50.59 & 54.04 & 33.40 & 69.00 & 77.60 & 52.68 \\ -Gated M$^2$RNN-TT & \underline{22.34} & \underline{29.72} & \underline{35.36} & \underline{48.08} & 70.57 & 58.80 & 30.55 & 51.07 & \underline{56.91} & 34.60 & 71.00 & 78.40 & 53.53 \\ -\midrule -Hybrid M$^2$RNN & \underline{21.32} & 38.10 & 34.76 & 48.38 & 70.84 & 59.01 & 30.89 & 52.72 & 53.12 & \underline{35.00} & 72.00 & \textbf{82.80} & 53.95 \\ -Hybrid M$^2$RNN-TT & 21.34 & 36.84 & 34.08 & \underline{48.42} & \textbf{71.06} & 60.65 & 30.72 & 52.72 & 53.67 & 34.40 & 68.00 & 80.00 & 53.37 \\ -Hybrid Gated DeltaNet & 21.68 & 29.55 & \underline{36.06} & 48.00 & 70.67 & 59.39 & 29.95 & \underline{53.12} & 55.44 & 34.20 & \textbf{73.00} & 82.10 & 54.19 \\ -Hybrid GDN + M$^2$RNN (1L) & 21.41 & \underline{28.59} & 35.61 & 48.37 & 70.84 & \textbf{60.86} & 30.38 & 51.85 & 53.98 & 34.00 & 68.00 & 80.80 & 53.47 \\ -Hybrid GDN + M$^2$RNN-TT (1L) & 21.41 & 30.28 & 35.55 & 48.03 & 70.89 & 58.42 & \underline{31.14} & 52.49 & \underline{58.50} & 32.80 & 69.00 & 80.60 & 53.74 \\ -\midrule -Hybrid GDN + M$^2$RNN (3L) & \textbf{21.25} & 26.79 & \textbf{37.49} & \textbf{48.50} & 70.24 & \underline{60.73} & \textbf{31.91} & \underline{52.64} & 55.54 & \underline{34.00} & \underline{72.00} & 81.50 & 54.46 \\ -Hybrid GDN + M$^2$RNN-TT (3L) & 21.40 & \textbf{26.77} & 37.40 & 48.48 & \underline{70.95} & 58.63 & 30.89 & 51.93 & \textbf{59.88} & 33.20 & \underline{72.00} & \underline{81.70} & 54.50 \\ -\bottomrule -\end{tabular} -} -\caption{Zero-shot evaluation results for 400M cosine models. Accuracy metrics are reported as \%; PPL metrics are perplexity (lower is better). \textbf{Bold} = best overall; \underline{underline} = best in group.} -\label{tab:eval_results} -\end{table} From 4fd91235aac50fd0d1456c60a1bb469b702be328 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 29 Apr 2026 16:37:31 -0700 Subject: [PATCH 17/19] add Signed-off-by: Mayank Mishra --- .../depthwise_causal_convolution.py | 17 ++-- .../sequence_mixer_blocks/mamba2.py | 87 ++++++------------- 2 files changed, 33 insertions(+), 71 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/depthwise_causal_convolution.py b/lm_engine/hf_models/modeling_utils/depthwise_causal_convolution.py index b1a01def7..beab701ae 100644 --- a/lm_engine/hf_models/modeling_utils/depthwise_causal_convolution.py +++ b/lm_engine/hf_models/modeling_utils/depthwise_causal_convolution.py @@ -71,8 +71,8 @@ def forward( hidden_states: torch.Tensor, input_state: torch.Tensor | None, attention_mask: torch.Tensor | None, - return_cache_state: bool, - ) -> tuple[torch.Tensor, torch.Tensor]: + output_state: bool, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: S = hidden_states.size(1) hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask) @@ -80,7 +80,7 @@ def forward( if input_state is None: hidden_states = hidden_states.transpose(-1, -2) - if return_cache_state: + if output_state: # F.pad trims the hidden_states if sequence_length > kernel_size input_state = F.pad(hidden_states, (self.kernel_size - S, 0)) @@ -95,10 +95,8 @@ def forward( else: assert S == 1 - # we clone to prevent modification in-place - # torch compile can remove the clone if its not needed - # this is to prevent silent incorrectness down the line in the model input_state_buffer = input_state.clone() + hidden_states = causal_conv1d_update( x=hidden_states, conv_state=input_state_buffer, @@ -106,7 +104,8 @@ def forward( bias=self.bias, activation=self.activation_string if self.use_activation_inside_kernel else None, ) - input_state = input_state_buffer if return_cache_state else None + + input_state = input_state_buffer if output_state else None if not self.use_activation_inside_kernel: hidden_states = self.activation_function(hidden_states) @@ -114,7 +113,7 @@ def forward( if input_state is None: hidden_states = hidden_states.transpose(-1, -2) - if return_cache_state: + if output_state: # F.pad trims the hidden_states if sequence_length > kernel_size input_state = F.pad(hidden_states, (self.kernel_size - S, 0)) @@ -134,7 +133,7 @@ def forward( if self.bias is not None: hidden_states = hidden_states + self.bias - if not return_cache_state: + if not output_state: input_state = None hidden_states = self.activation_function(hidden_states) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index ff3c125a5..ef6e0def6 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -10,14 +10,14 @@ from ....enums import Kernel from ....kernels import is_kernel_allowed -from ....utils import divide_if_divisible, is_causal_conv1d_available, is_mamba_2_ssm_available +from ....utils import divide_if_divisible, is_mamba_2_ssm_available from ...cache import ConstantCache, GenerationCache, GenerationState from ...parameter import ( mark_parameter_as_initialized, mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay, ) -from ..activations import get_activation_function, silu +from ..activations import silu from ..decay_gate import SoftplusDecayGate from ..depthwise_causal_convolution import DepthwiseCausalConvolution, _apply_mask_to_padding_states from ..init_utils import _get_std_for_linear @@ -29,9 +29,6 @@ from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update - def _pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int) -> torch.Tensor: """ @@ -131,10 +128,6 @@ def __init__( self.layer_idx = layer_idx self.use_conv_bias = use_conv_bias - self.activation_string = ssm_activation_function - self.activation = get_activation_function(self.activation_string) - self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"] - self.n_groups = num_groups self.head_dim = divide_if_divisible(ssm_intermediate_size, ssm_num_heads, "") self.chunk_size = chunk_size @@ -146,7 +139,7 @@ def __init__( self.conv1d = DepthwiseCausalConvolution( hidden_size=self.conv_dim, kernel_size=self.conv_kernel_size, - activation_function=self.activation_string, + activation_function=ssm_activation_function, add_bias=add_bias, std=_get_std_for_linear( initializer_range=initializer_range, @@ -256,16 +249,12 @@ def _torch_forward( # 2. Convolution sequence transformation if use_precomputed_states: - conv_state = conv_state.roll(shifts=-1, dims=-1) - conv_state[:, :, -1] = hidden_states_B_C[:, 0, :].to(conv_state.device) - - # We need to guarantee that anything regarding the cache is on the same device - conv_state = conv_state.to(device=self.conv1d.weight.device) - - hidden_states_B_C = torch.sum(conv_state * self.conv1d.weight.squeeze(1), dim=-1) - if self.use_conv_bias: - hidden_states_B_C = hidden_states_B_C + self.conv1d.bias - hidden_states_B_C = self.activation(hidden_states_B_C) + hidden_states_B_C, conv_state = self.conv1d( + hidden_states=hidden_states_B_C, + input_state=conv_state, + attention_mask=attention_mask, + return_cache_state=True, + ) else: # Init cache if cache_params is not None: @@ -279,13 +268,11 @@ def _torch_forward( dtype=dtype, ) - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) - conv_state = F.pad( - hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) - ) - - hidden_states_B_C = self.activation( - self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) + hidden_states_B_C, conv_state = self.conv1d( + hidden_states=hidden_states_B_C, + input_state=None, + attention_mask=attention_mask, + return_cache_state=cache_params is not None, ) hidden_states_B_C = _apply_mask_to_padding_states(hidden_states_B_C, attention_mask) @@ -507,12 +494,11 @@ def _cuda_forward( ) # 2. Convolution sequence transformation - hidden_states_B_C = causal_conv1d_update( - hidden_states_B_C, - conv_state, - self.conv1d.weight.squeeze(1), - self.conv1d.bias, - self.activation_string, + hidden_states_B_C, conv_state = self.conv1d( + hidden_states=hidden_states_B_C[:, None, :], + input_state=conv_state, + attention_mask=None, + return_cache_state=cache_params is not None, ) hidden_states, B, C = torch.split( @@ -589,36 +575,13 @@ def _cuda_forward( [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) - # 2. Convolution sequence transformation - # Init cache - if cache_params is not None: - # storing the states - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - conv_state = F.pad( - hidden_states_B_C_transposed, - (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), - ) - - cache_params.update( - states=( - GenerationState(state=conv_state, method=ConstantCache), - GenerationState(state=ssm_state, method=ConstantCache), - ), - layer_idx=self.layer_idx, - ) - - hidden_states_B_C = causal_conv1d_fn( - x=hidden_states_B_C_transposed, - weight=self.conv1d.weight.squeeze(1), - bias=self.conv1d.bias, - activation=self.activation_string if self.use_activation_inside_kernel else None, - ).transpose(1, 2) - - if not self.use_activation_inside_kernel: - hidden_states_B_C = self.activation(hidden_states_B_C) + hidden_states_B_C, conv_state = self.conv1d( + hidden_states=hidden_states_B_C, + input_state=None, + attention_mask=attention_mask, + return_cache_state=cache_params is not None, + ) hidden_states_B_C = _apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( From 9a0ed238656f4843b2defc03066cbf79ab86892c Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 30 Apr 2026 00:23:36 -0700 Subject: [PATCH 18/19] init Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/gated_deltanet.py | 2 +- .../hf_models/modeling_utils/sequence_mixer_blocks/gru.py | 2 +- .../modeling_utils/sequence_mixer_blocks/m2rnn.py | 2 +- .../modeling_utils/sequence_mixer_blocks/mamba2.py | 8 ++++---- .../hf_models/modeling_utils/sequence_mixer_blocks/rnn.py | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py index 3b7f0a742..3e32bc2c5 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py @@ -164,7 +164,7 @@ def forward( hidden_states=qkv, input_state=c, attention_mask=attention_mask, - return_cache_state=cache_params is not None, + output_state=cache_params is not None, ) q, k, v = qkv.split((self.key_dim, self.key_dim, self.value_dim), dim=-1) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py index d066a44c8..407221fb4 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py @@ -208,7 +208,7 @@ def forward( hidden_states=x, input_state=c, attention_mask=attention_mask, - return_cache_state=cache_params is not None, + output_state=cache_params is not None, ) x, xf, xr = [i.view(*i.size()[:-1], -1, self.state_head_dim) for i in (x, xf, xr)] diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py index fbcff95f5..43e1ff04c 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py @@ -213,7 +213,7 @@ def forward( hidden_states=x, input_state=c, attention_mask=attention_mask, - return_cache_state=cache_params is not None, + output_state=cache_params is not None, ) q, k, v = x.split((self.q_shape, self.k_shape, self.v_shape), dim=-1) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index ef6e0def6..09907c4aa 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -253,7 +253,7 @@ def _torch_forward( hidden_states=hidden_states_B_C, input_state=conv_state, attention_mask=attention_mask, - return_cache_state=True, + output_state=True, ) else: # Init cache @@ -272,7 +272,7 @@ def _torch_forward( hidden_states=hidden_states_B_C, input_state=None, attention_mask=attention_mask, - return_cache_state=cache_params is not None, + output_state=cache_params is not None, ) hidden_states_B_C = _apply_mask_to_padding_states(hidden_states_B_C, attention_mask) @@ -498,7 +498,7 @@ def _cuda_forward( hidden_states=hidden_states_B_C[:, None, :], input_state=conv_state, attention_mask=None, - return_cache_state=cache_params is not None, + output_state=cache_params is not None, ) hidden_states, B, C = torch.split( @@ -580,7 +580,7 @@ def _cuda_forward( hidden_states=hidden_states_B_C, input_state=None, attention_mask=attention_mask, - return_cache_state=cache_params is not None, + output_state=cache_params is not None, ) hidden_states_B_C = _apply_mask_to_padding_states(hidden_states_B_C, attention_mask) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py index 8d1046c2f..b7f2fd8c5 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py @@ -178,7 +178,7 @@ def forward( hidden_states=x, input_state=c, attention_mask=attention_mask, - return_cache_state=cache_params is not None, + output_state=cache_params is not None, ) x = x.view(*x.size()[:-1], -1, self.state_head_dim) From fef2fa2b73442309feedebbf66104eff29f9a39a Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 30 Apr 2026 00:28:02 -0700 Subject: [PATCH 19/19] init Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/depthwise_causal_convolution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lm_engine/hf_models/modeling_utils/depthwise_causal_convolution.py b/lm_engine/hf_models/modeling_utils/depthwise_causal_convolution.py index beab701ae..6d5947058 100644 --- a/lm_engine/hf_models/modeling_utils/depthwise_causal_convolution.py +++ b/lm_engine/hf_models/modeling_utils/depthwise_causal_convolution.py @@ -61,6 +61,7 @@ def __init__( self.activation_string = activation_function self.activation_function = get_activation_function(self.activation_string) self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"] + self.kernel_size = kernel_size mark_parameter_as_no_weight_decay(self.bias)