diff --git a/lm_engine/hf_models/__init__.py b/lm_engine/hf_models/__init__.py
index 1f876ae80..7e90cbffa 100644
--- a/lm_engine/hf_models/__init__.py
+++ b/lm_engine/hf_models/__init__.py
@@ -2,6 +2,8 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
+import torch
+
 from .config import CommonConfig
 from .loss import get_autoregressive_language_modeling_loss, is_aux_loss_zero
 from .mixins import CausalLMOutputWithPast, PipelineParallelInput, PipelineParallelOutput
@@ -39,3 +41,31 @@
 
 
 register_model_classes()
+
+
+def _patch_granitemoehybrid_weight_init() -> None:
+    try:
+        from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
+            GraniteMoeHybridMambaLayer,
+            GraniteMoeHybridPreTrainedModel,
+        )
+    except Exception:
+        return
+
+    if getattr(GraniteMoeHybridPreTrainedModel._init_weights, "_lm_engine_patched", False):
+        return
+
+    def _init_weights(self, module):
+        super(GraniteMoeHybridPreTrainedModel, self)._init_weights(module)
+        if isinstance(module, GraniteMoeHybridMambaLayer):
+            module.dt_bias.data.fill_(1.0)
+            module.A_log.data.copy_(
+                torch.log(torch.arange(1, module.num_heads + 1, dtype=module.A_log.dtype, device=module.A_log.device))
+            )
+            module.D.data.fill_(1.0)
+
+    _init_weights._lm_engine_patched = True
+    GraniteMoeHybridPreTrainedModel._init_weights = _init_weights
+
+
+_patch_granitemoehybrid_weight_init()
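The patch keeps the base `_init_weights` behavior and then restores the Mamba-layer defaults (`dt_bias = 1`, `A_log = log(1..num_heads)`, `D = 1`) that a generic re-initialization would otherwise overwrite; the `_lm_engine_patched` flag makes the patch idempotent across repeated imports. A minimal sketch of what the patched initializer writes, using toy tensors as stand-ins for the real `GraniteMoeHybridMambaLayer` attributes:

```python
import torch

# Hypothetical stand-ins for the patched module attributes.
num_heads = 4
dt_bias = torch.empty(num_heads)
A_log = torch.empty(num_heads)
D = torch.empty(num_heads)

# What the patched _init_weights applies to a GraniteMoeHybridMambaLayer:
dt_bias.fill_(1.0)
A_log.copy_(torch.log(torch.arange(1, num_heads + 1, dtype=A_log.dtype)))
D.fill_(1.0)

# exp(A_log) recovers A = [1, 2, ..., num_heads].
print(torch.exp(A_log))  # tensor([1., 2., 3., 4.])
```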
diff --git a/lm_engine/hf_models/config/__init__.py b/lm_engine/hf_models/config/__init__.py
index daa6499db..f32472530 100644
--- a/lm_engine/hf_models/config/__init__.py
+++ b/lm_engine/hf_models/config/__init__.py
@@ -11,15 +11,7 @@
 from ...utils import BaseArgs, divide_if_divisible
 from .mlp import _MLPArgs, _MoEArgs
-from .sequence_mixer import (
-    _CausalConvolution,
-    _GatedDeltaNetArgs,
-    _GRUArgs,
-    _M2RNNArgs,
-    _Mamba2Args,
-    _RNNArgs,
-    _SoftmaxAttentionArgs,
-)
+from .sequence_mixer import _GatedDeltaNetArgs, _GRUArgs, _M2RNNArgs, _Mamba2Args, _RNNArgs, _SoftmaxAttentionArgs
 
 
 def _hold_base_args(key: str) -> Callable:
@@ -37,7 +29,6 @@ def _run(self, *args, **kwargs):
 
 
 _SEQUENCE_MIXER_CONFIG_CLASSES = {
-    "causal_convolution": _CausalConvolution,
     "gru": _GRUArgs,
     "m2rnn": _M2RNNArgs,
    "mamba2": _Mamba2Args,
@@ -183,13 +174,7 @@ def _set_sequence_mixer_blocks(self) -> None:
             self.sequence_mixer_blocks = [{} for _ in range(self.num_layers)]
 
         sequence_mixer_blocks: list[
-            _CausalConvolution
-            | _GRUArgs
-            | _Mamba2Args
-            | _RNNArgs
-            | _M2RNNArgs
-            | _SoftmaxAttentionArgs
-            | _GatedDeltaNetArgs
+            _GRUArgs | _Mamba2Args | _RNNArgs | _M2RNNArgs | _SoftmaxAttentionArgs | _GatedDeltaNetArgs
         ] = []
         for i in range(self.num_layers):
             sequence_mixer_block = deepcopy(self.sequence_mixer_blocks[i])
diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py
index 6028c95db..5516139f5 100644
--- a/lm_engine/hf_models/config/sequence_mixer.py
+++ b/lm_engine/hf_models/config/sequence_mixer.py
@@ -134,19 +134,6 @@ def model_post_init(self, __context: Any) -> None:
         assert self.sequence_mixer_type == "m2rnn"
 
 
-class _CausalConvolution(BaseArgs):
-    sequence_mixer_type: str = "causal_convolution"
-    activation_function: str = "silu"
-    in_channels: int
-    out_channels: int
-    kernel_size: int
-    num_groups: int
-    add_bias: bool = False
-
-    def model_post_init(self, __context: Any) -> None:
-        assert self.sequence_mixer_type == "causal_convolution"
-
-
 class _GatedDeltaNetArgs(_SoftPlusDecayArgs):
     sequence_mixer_type: str = "gated_deltanet"
     k_head_dim: int
diff --git a/lm_engine/hf_models/modeling_utils/__init__.py b/lm_engine/hf_models/modeling_utils/__init__.py
index 3404b87d3..f121a32bc 100644
--- a/lm_engine/hf_models/modeling_utils/__init__.py
+++ b/lm_engine/hf_models/modeling_utils/__init__.py
@@ -3,7 +3,6 @@
 # **************************************************
 
 from .activations import get_activation_function, is_glu
-from .convolution import ParameterizedConv1d
 from .dropout import Dropout
 from .dtensor_module import DTensorModule
 from .embedding import ParameterizedEmbedding, get_tensor_parallel_vocab_info
diff --git a/lm_engine/hf_models/modeling_utils/convolution.py b/lm_engine/hf_models/modeling_utils/convolution.py
deleted file mode 100644
index 1dcf09758..000000000
--- a/lm_engine/hf_models/modeling_utils/convolution.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# **************************************************
-# Copyright (c) 2025, Mayank Mishra
-# **************************************************
-
-
-from __future__ import annotations
-
-import torch
-import torch.nn as nn
-
-from ..parameter import mark_parameter_as_initialized, mark_parameter_as_no_weight_decay
-
-
-class ParameterizedConv1d(nn.Conv1d):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int,
-        stride: int = 1,
-        padding: str | int = 0,
-        dilation: int = 1,
-        groups: int = 1,
-        bias: bool = True,
-        padding_mode: str = "zeros",  # TODO: refine this type
-        device=None,
-        dtype=None,
-        std: float | None = None,
-    ) -> ParameterizedConv1d:
-        self.std = std
-        super().__init__(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            groups=groups,
-            bias=bias,
-            padding_mode=padding_mode,
-            device=device,
-            dtype=dtype,
-        )
-
-        mark_parameter_as_no_weight_decay(self.bias)
-
-    @torch.no_grad()
-    def reset_parameters(self) -> None:
-        if self.std is None:
-            super().reset_parameters()
-        else:
-            nn.init.normal_(self.weight, mean=0, std=self.std)
-            if hasattr(self, "bias") and self.bias is not None:
-                self.bias.zero_()
-
-        mark_parameter_as_initialized(self.weight)
-        mark_parameter_as_initialized(self.bias)
diff --git a/lm_engine/hf_models/modeling_utils/depthwise_causal_convolution.py b/lm_engine/hf_models/modeling_utils/depthwise_causal_convolution.py
new file mode 100644
index 000000000..6d5947058
--- /dev/null
+++ b/lm_engine/hf_models/modeling_utils/depthwise_causal_convolution.py
@@ -0,0 +1,155 @@
+# **************************************************
+# Copyright (c) 2025, Mayank Mishra
+# **************************************************
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...enums import Kernel
+from ...kernels import is_kernel_allowed
+from ...utils import divide_if_divisible, is_causal_conv1d_available
+from ..parameter import (
+    mark_parameter_as_initialized,
+    mark_parameter_as_mup_learning_rate,
+    mark_parameter_as_no_weight_decay,
+)
+from .activations import get_activation_function
+
+
+if is_causal_conv1d_available():
+    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+
+
+def _apply_mask_to_padding_states(hidden_states: torch.Tensor, attention_mask: torch.Tensor | None) -> torch.Tensor:
+    """
+    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
+    """
+    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
+        dtype = hidden_states.dtype
+        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
+
+    return hidden_states
+
+
+class DepthwiseCausalConvolution(nn.Conv1d):
+    def __init__(
+        self,
+        hidden_size: int,
+        kernel_size: int,
+        activation_function: str,
+        add_bias: bool,
+        std: float | None,
+        use_padding_free_transformer: bool,
+    ) -> DepthwiseCausalConvolution:
+        if use_padding_free_transformer:
+            raise NotImplementedError()
+
+        self.std = std
+
+        super().__init__(
+            in_channels=hidden_size,
+            out_channels=hidden_size,
+            kernel_size=kernel_size,
+            padding=kernel_size - 1,
+            groups=hidden_size,
+            bias=add_bias,
+        )
+
+        self.activation_string = activation_function
+        self.activation_function = get_activation_function(self.activation_string)
+        self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"]
+        self.kernel_size = kernel_size
+
+        mark_parameter_as_no_weight_decay(self.bias)
+
+        self.reset_parameters()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_state: torch.Tensor | None,
+        attention_mask: torch.Tensor | None,
+        output_state: bool,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        S = hidden_states.size(1)
+        hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
+
+        if is_kernel_allowed(Kernel.causal_conv1d):
+            if input_state is None:
+                hidden_states = hidden_states.transpose(-1, -2)
+
+                if output_state:
+                    # F.pad trims the hidden_states if sequence_length > kernel_size
+                    input_state = F.pad(hidden_states, (self.kernel_size - S, 0))
+
+                hidden_states = causal_conv1d_fn(
+                    x=hidden_states,
+                    weight=self.weight.squeeze(1),
+                    bias=self.bias,
+                    activation=self.activation_string if self.use_activation_inside_kernel else None,
+                )
+
+                hidden_states = hidden_states.transpose(-1, -2)
+            else:
+                assert S == 1
+
+                input_state_buffer = input_state.clone()
+
+                hidden_states = causal_conv1d_update(
+                    x=hidden_states,
+                    conv_state=input_state_buffer,
+                    weight=self.weight.squeeze(1),
+                    bias=self.bias,
+                    activation=self.activation_string if self.use_activation_inside_kernel else None,
+                )
+
+                input_state = input_state_buffer if output_state else None
+
+            if not self.use_activation_inside_kernel:
+                hidden_states = self.activation_function(hidden_states)
+        else:
+            if input_state is None:
+                hidden_states = hidden_states.transpose(-1, -2)
+
+                if output_state:
+                    # F.pad trims the hidden_states if sequence_length > kernel_size
+                    input_state = F.pad(hidden_states, (self.kernel_size - S, 0))
+
+                hidden_states = super().forward(hidden_states)
+
+                # removes padding on the right side of the sequence
+                hidden_states = hidden_states[..., : 1 - self.kernel_size]
+                hidden_states = hidden_states.transpose(-1, -2)
+            else:
+                assert S == 1
+
+                input_state = input_state.roll(shifts=-1, dims=-1)
+                input_state[..., -1] = hidden_states[:, 0]
+
+                hidden_states = (input_state * self.weight.squeeze(1)).sum(dim=-1)
+                hidden_states = hidden_states[:, None, :]
+                if self.bias is not None:
+                    hidden_states = hidden_states + self.bias
+
+                if not output_state:
+                    input_state = None
+
+            hidden_states = self.activation_function(hidden_states)
+            hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
+
+        return hidden_states, input_state
+
+    @torch.no_grad()
+    def reset_parameters(self) -> None:
+        if self.std is None:
+            super().reset_parameters()
+        else:
+            nn.init.normal_(self.weight, mean=0, std=self.std)
+            if hasattr(self, "bias") and self.bias is not None:
+                self.bias.zero_()
+
+        mark_parameter_as_initialized(self.weight)
+        mark_parameter_as_initialized(self.bias)
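In the eager path above (the `super().forward` branch), `padding=kernel_size - 1` pads both ends of the sequence, and slicing off the last `kernel_size - 1` outputs leaves exactly the causal positions, so output `t` only sees inputs at positions `<= t`. A self-contained sketch of that pad-and-trim identity (toy shapes, not lm_engine code):

```python
import torch
import torch.nn.functional as F

# Conv1d with padding=K-1 pads BOTH sides; dropping the last K-1 outputs
# leaves exactly the causal positions, where output[t] sees inputs <= t.
torch.manual_seed(0)
B, C, S, K = 2, 3, 8, 4
x = torch.randn(B, C, S)
w = torch.randn(C, 1, K)  # depthwise: groups == channels

y = F.conv1d(x, w, padding=K - 1, groups=C)[..., : 1 - K]

# Reference: explicit causal window y[t] = sum_j w[j] * x_leftpad[t + j]
x_pad = F.pad(x, (K - 1, 0))
y_ref = torch.stack(
    [(x_pad[..., t : t + K] * w.squeeze(1)).sum(-1) for t in range(S)], dim=-1
)

assert torch.allclose(y, y_ref, atol=1e-6)
```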
diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py
index 3e2f97d37..d624f02b1 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py
@@ -9,7 +9,6 @@
     interleave_query_key_value_tensor_for_attention,
     split_query_key_value_tensor_for_attention,
 )
-from .causal_convolution import CausalConvolution
 from .gated_deltanet import GatedDeltaNet
 from .gru import GRU
 from .m2rnn import M2RNN
@@ -18,7 +17,7 @@
 from .utils import flash_attention
 
 
-SEQUENCE_MIXER_TYPE = Attention | CausalConvolution | GRU | Mamba2 | RNN | GatedDeltaNet
+SEQUENCE_MIXER_TYPE = Attention | GRU | Mamba2 | RNN | GatedDeltaNet
 
 
 def get_sequence_mixer(
@@ -29,25 +28,7 @@ def get_sequence_mixer(
 
     is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled()
 
-    if sequence_mixer_type == "causal_convolution":
-        assert not is_tp_enabled
-        return CausalConvolution(
-            hidden_size=config.hidden_size,
-            in_channels=block.in_channels,
-            out_channels=block.out_channels,
-            kernel_size=block.kernel_size,
-            num_groups=block.num_groups,
-            activation_function=block.activation_function,
-            add_bias=block.add_bias,
-            initializer_range=config.initializer_range,
-            m_width=config.m_width,
-            init_method=config.init_method,
-            num_layers=config.num_layers,
-            layer_idx=layer_idx,
-            use_depth_scaled_init=config.use_depth_scaled_init,
-            use_padding_free_transformer=use_padding_free_transformer,
-        )
-    elif sequence_mixer_type == "gru":
+    if sequence_mixer_type == "gru":
         assert not is_tp_enabled
         return GRU(
             input_size=config.hidden_size,
diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py
deleted file mode 100644
index c730539bc..000000000
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/causal_convolution.py
+++ /dev/null
@@ -1,261 +0,0 @@
-# **************************************************
-# Copyright (c) 2025, Mayank Mishra
-# **************************************************
-
-from __future__ import annotations
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from ....enums import Kernel
-from ....kernels import is_kernel_allowed
-from ....utils import divide_if_divisible, is_causal_conv1d_available
-from ...cache import ConstantCache, GenerationCache, GenerationState
-from ...parameter import mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay
-from ..activations import get_activation_function, is_glu
-from ..convolution import ParameterizedConv1d
-from ..init_utils import _get_std_for_linear
-from ..linear import ParameterizedLinear
-
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-
-
-def _apply_mask_to_padding_states(hidden_states: torch.Tensor, attention_mask: torch.Tensor | None) -> torch.Tensor:
-    """
-    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
-    """
-    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
-        dtype = hidden_states.dtype
-        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
-
-    return hidden_states
-
-
-def causal_convolution(
-    hidden_states: torch.Tensor,
-    input_state: torch.Tensor | None,
-    attention_mask: torch.Tensor | None,
-    conv1d_weight: torch.Tensor,
-    conv1d_bias: torch.Tensor | None,
-    conv1d_num_groups: int,
-    return_cache_state: bool,
-    activation_string: str,
-    conv1d_padding: int,
-    conv1d_stride: int = 1,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    casual_conv1d_compatible = conv1d_num_groups == conv1d_weight.size(0) and conv1d_weight.size(1) == 1
-    sequence_length = hidden_states.size(1)
-    kernel_size = conv1d_weight.size(-1)
-
-    assert conv1d_stride == 1
-    assert conv1d_padding == kernel_size - 1
-
-    hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
-
-    if is_kernel_allowed(Kernel.causal_conv1d) and casual_conv1d_compatible:
-        use_activation_inside_kernel = activation_string in [None, "silu", "swish"]
-
-        if input_state is None:
-            hidden_states = hidden_states.transpose(-1, -2)
-
-            if return_cache_state:
-                # F.pad trims the hidden_states if sequence_length > kernel_size
-                input_state = F.pad(hidden_states, (kernel_size - sequence_length, 0))
-
-            hidden_states = causal_conv1d_fn(
-                x=hidden_states,
-                weight=conv1d_weight.squeeze(1),
-                bias=conv1d_bias,
-                activation=activation_string if use_activation_inside_kernel else None,
-            )
-
-            hidden_states = hidden_states.transpose(-1, -2)
-        else:
-            assert sequence_length == 1
-
-            # we clone to prevent modification in-place
-            # torch compile can remove the clone if its not needed
-            # this is to prevent silent incorrectness down the line in the model
-            input_state_buffer = input_state.clone()
-            hidden_states = causal_conv1d_update(
-                x=hidden_states,
-                conv_state=input_state_buffer,
-                weight=conv1d_weight.squeeze(1),
-                bias=conv1d_bias,
-                activation=activation_string if use_activation_inside_kernel else None,
-            )
-            input_state = input_state_buffer if return_cache_state else None
-
-        if not use_activation_inside_kernel:
-            hidden_states = get_activation_function(activation_string)(hidden_states)
-    else:
-        if input_state is None:
-            hidden_states = hidden_states.transpose(-1, -2)
-
-            if return_cache_state:
-                # F.pad trims the hidden_states if sequence_length > kernel_size
-                input_state = F.pad(hidden_states, (kernel_size - sequence_length, 0))
-
-            hidden_states = F.conv1d(
-                input=hidden_states,
-                weight=conv1d_weight,
-                bias=conv1d_bias,
-                stride=conv1d_stride,
-                padding=conv1d_padding,
-                groups=conv1d_num_groups,
-            )
-
-            # removes padding on the right side of the sequence
-            hidden_states = hidden_states[..., : 1 - kernel_size]
-            hidden_states = hidden_states.transpose(-1, -2)
-        else:
-            assert sequence_length == 1
-
-            input_state = input_state.roll(shifts=-1, dims=-1)
-            input_state[..., -1] = hidden_states[:, 0]
-
-            hidden_states = (input_state * conv1d_weight.squeeze(1)).sum(dim=-1)
-            hidden_states = hidden_states[:, None, :]
-            if conv1d_bias is not None:
-                hidden_states = hidden_states + conv1d_bias
-
-            if not return_cache_state:
-                input_state = None
-
-        hidden_states = get_activation_function(activation_string)(hidden_states)
-        hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
-
-    return hidden_states, input_state
-
-
-class CausalConvolution(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int,
-        num_groups: int,
-        activation_function: str,
-        add_bias: bool,
-        initializer_range: float | None,
-        m_width: float,
-        init_method: str,
-        num_layers: int,
-        layer_idx: int,
-        use_depth_scaled_init: bool,
-        use_padding_free_transformer: bool,
-    ) -> CausalConvolution:
-        super().__init__()
-
-        if use_padding_free_transformer:
-            raise NotImplementedError()
-
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.kernel_size = kernel_size
-        self.num_groups = num_groups
-        self.layer_idx = layer_idx
-        self.activation_string = activation_function
-
-        self.input_projection = ParameterizedLinear(
-            hidden_size,
-            in_channels,
-            bias=add_bias,
-            std=_get_std_for_linear(
-                initializer_range=initializer_range,
-                init_method=init_method,
-                m_width=m_width,
-                fan_in=hidden_size,
-                num_layers=num_layers,
-                use_depth_scaled_init=False,
-            ),
-        )
-
-        divide_if_divisible(in_channels, num_groups, "")
-        divide_if_divisible(out_channels, num_groups, "")
-
-        if is_glu(self.activation_string):
-            intermediate_size = divide_if_divisible(out_channels, 2, "")
-        else:
-            intermediate_size = out_channels
-
-        self.conv1d = ParameterizedConv1d(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            bias=add_bias,
-            padding=kernel_size - 1,
-            groups=num_groups,
-            std=_get_std_for_linear(
-                initializer_range=initializer_range,
-                init_method=init_method,
-                m_width=m_width,
-                fan_in=kernel_size,
-                num_layers=num_layers,
-                use_depth_scaled_init=False,
-            ),
-        )
-
-        self.activation_function = get_activation_function(self.activation_string)
-
-        self.output_projection = ParameterizedLinear(
-            intermediate_size,
-            hidden_size,
-            bias=add_bias,
-            std=_get_std_for_linear(
-                initializer_range=initializer_range,
-                init_method=init_method,
-                m_width=m_width,
-                fan_in=intermediate_size,
-                num_layers=num_layers,
-                use_depth_scaled_init=use_depth_scaled_init,
-            ),
-        )
-
-        self.casual_conv1d_compatible = self.num_groups == self.in_channels == self.out_channels
-        self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"]
-
-        mark_parameter_as_mup_learning_rate(self.input_projection.weight)
-        mark_parameter_as_mup_learning_rate(self.conv1d.weight)
-        mark_parameter_as_mup_learning_rate(self.output_projection.weight)
-
-        mark_parameter_as_no_weight_decay(self.input_projection.bias)
-        mark_parameter_as_no_weight_decay(self.conv1d.bias)
-        mark_parameter_as_no_weight_decay(self.output_projection.bias)
-
-    def forward(
-        self, x: torch.Tensor, cache_params: GenerationCache | None = None, attention_mask: torch.Tensor | None = None
-    ) -> torch.Tensor:
-        input_state = (
-            None if cache_params is None else cache_params.get_cache(layer_idx=self.layer_idx, empty_value=None)
-        )
-
-        S = x.size(1)
-        x = self.input_projection(x)
-
-        x, input_state = causal_convolution(
-            hidden_states=x,
-            input_state=input_state,
-            attention_mask=attention_mask,
-            conv1d_weight=self.conv1d.weight,
-            conv1d_bias=self.conv1d.bias,
-            conv1d_num_groups=self.num_groups,
-            return_cache_state=cache_params is not None,
-            activation_string=self.activation_string,
-            conv1d_padding=self.kernel_size - 1,
-            conv1d_stride=1,
-        )
-
-        if cache_params is not None:
-            cache_params.update(
-                states=(GenerationState(state=input_state, method=ConstantCache, num_tokens_added=S),),
-                layer_idx=self.layer_idx,
-            )
-
-        x = self.output_projection(x)
-
-        return x
diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py
index 17ba86282..3e32bc2c5 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py
@@ -13,12 +13,11 @@
 from ....utils import divide_if_divisible, is_fla_available
 from ...cache import ConstantCache, GenerationCache, GenerationState
 from ..activations import silu
-from ..convolution import ParameterizedConv1d
 from ..decay_gate import SoftplusDecayGate
+from ..depthwise_causal_convolution import DepthwiseCausalConvolution
 from ..init_utils import _get_std_for_linear
 from ..linear import ParameterizedLinear
 from ..normalization import get_normalization_function
-from .causal_convolution import causal_convolution
 from .utils import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence
@@ -105,13 +104,11 @@
         )
         self.conv_size = conv_size
 
-        self.qkv_conv1d = ParameterizedConv1d(
-            in_channels=2 * self.key_dim + self.value_dim,
-            out_channels=2 * self.key_dim + self.value_dim,
+        self.qkv_conv1d = DepthwiseCausalConvolution(
+            hidden_size=2 * self.key_dim + self.value_dim,
             kernel_size=conv_size,
-            padding=conv_size - 1,
-            groups=2 * self.key_dim + self.value_dim,
-            bias=False,
+            activation_function="silu",
+            add_bias=False,
             std=_get_std_for_linear(
                 initializer_range=initializer_range,
                 init_method=init_method,
@@ -120,8 +117,8 @@
                 num_layers=num_layers,
                 use_depth_scaled_init=False,
             ),
+            use_padding_free_transformer=use_padding_free_transformer,
         )
-        self.activation_string = "silu"
 
         self.o_norm = get_normalization_function("rmsnorm", self.v_head_dim, eps=norm_eps)
         self.o_proj = ParameterizedLinear(
@@ -163,17 +160,11 @@
         else:
             a, b = self.ab_proj(hidden_states).chunk(2, dim=-1)
 
-        qkv, c = causal_convolution(
+        qkv, c = self.qkv_conv1d(
             hidden_states=qkv,
             input_state=c,
             attention_mask=attention_mask,
-            conv1d_weight=self.qkv_conv1d.weight,
-            conv1d_bias=self.qkv_conv1d.bias,
-            conv1d_num_groups=qkv.size(-1),
-            return_cache_state=cache_params is not None,
-            activation_string=self.activation_string,
-            conv1d_padding=self.conv_size - 1,
-            conv1d_stride=1,
+            output_state=cache_params is not None,
         )
 
         q, k, v = qkv.split((self.key_dim, self.key_dim, self.value_dim), dim=-1)
diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py
index 5ebc66993..407221fb4 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py
@@ -17,11 +17,10 @@
     mark_parameter_as_no_weight_decay,
 )
 from ..activations import clip_gradients, get_activation_function, is_glu, sigmoid, silu, tanh
-from ..convolution import ParameterizedConv1d
+from ..depthwise_causal_convolution import DepthwiseCausalConvolution
 from ..init_utils import _get_std_for_linear
 from ..linear import ParameterizedLinear
 from ..normalization import get_normalization_function
-from .causal_convolution import causal_convolution
 from .utils import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence
@@ -120,13 +119,11 @@
         if kernel_size is not None:
             assert not is_glu(self.activation_string)
 
-            self.conv1d = ParameterizedConv1d(
-                in_channels=self.state_size,
-                out_channels=self.state_size,
+            self.conv1d = DepthwiseCausalConvolution(
+                hidden_size=self.state_size,
                 kernel_size=kernel_size,
-                bias=add_bias,
-                padding=kernel_size - 1,
-                groups=self.state_size,
+                activation_function=self.activation_string,
+                add_bias=add_bias,
                 std=_get_std_for_linear(
                     initializer_range=initializer_range,
                     init_method=init_method,
@@ -135,6 +132,7 @@
                     num_layers=num_layers,
                     use_depth_scaled_init=False,
                 ),
+                use_padding_free_transformer=use_padding_free_transformer,
             )
 
             mark_parameter_as_mup_learning_rate(self.conv1d.weight)
@@ -206,17 +204,11 @@
         if self.kernel_size is None:
             x = self.activation_function(x)
         else:
-            x, c = causal_convolution(
+            x, c = self.conv1d(
                 hidden_states=x,
                 input_state=c,
                 attention_mask=attention_mask,
-                conv1d_weight=self.conv1d.weight,
-                conv1d_bias=self.conv1d.bias,
-                conv1d_num_groups=self.state_size,
-                return_cache_state=cache_params is not None,
-                activation_string=self.activation_string,
-                conv1d_padding=self.kernel_size - 1,
-                conv1d_stride=1,
+                output_state=cache_params is not None,
             )
 
         x, xf, xr = [i.view(*i.size()[:-1], -1, self.state_head_dim) for i in (x, xf, xr)]
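With the free function gone, every mixer now drives the convolution through the same four-argument call (`hidden_states`, `input_state`, `attention_mask`, `output_state`) and threads the returned state back into its cache. The sketch below re-implements just that prefill/decode state protocol in plain PyTorch (hypothetical `prefill`/`decode_step` helpers, bias and activation omitted) to show why a `(batch, channels, kernel_size)` state plus a roll-and-overwrite step reproduces the full convolution:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, S, K = 2, 5, 6, 4
w = torch.randn(C, K)  # depthwise kernel, one filter per channel

def prefill(x):  # x: (B, S, C) -> (y, state)
    xt = x.transpose(-1, -2)
    state = F.pad(xt, (K - S, 0))  # keeps the last K columns (pads or trims)
    y = F.conv1d(xt, w[:, None], padding=K - 1, groups=C)[..., : 1 - K]
    return y.transpose(-1, -2), state

def decode_step(x_t, state):  # x_t: (B, 1, C)
    state = state.roll(shifts=-1, dims=-1)  # drop oldest token
    state[..., -1] = x_t[:, 0]              # append newest token
    return (state * w).sum(dim=-1)[:, None, :], state

x = torch.randn(B, S, C)
x_next = torch.randn(B, 1, C)

y_full, _ = prefill(torch.cat([x, x_next], dim=1))
y_pre, state = prefill(x)
y_step, state = decode_step(x_next, state)

assert torch.allclose(y_full[:, :S], y_pre, atol=1e-5)
assert torch.allclose(y_full[:, -1:], y_step, atol=1e-5)
```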
diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py
index 090dddf1b..43e1ff04c 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/m2rnn.py
@@ -20,12 +20,11 @@
     mark_parameter_as_no_weight_decay,
 )
 from ..activations import clip_gradients, is_glu, silu, tanh
-from ..convolution import ParameterizedConv1d
 from ..decay_gate import SoftplusDecayGate
+from ..depthwise_causal_convolution import DepthwiseCausalConvolution
 from ..init_utils import _get_std_for_linear
 from ..linear import ParameterizedLinear
 from ..normalization import get_normalization_function
-from .causal_convolution import causal_convolution
 from .utils import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence
@@ -132,13 +131,11 @@
         else:
             assert self.activation_string is None or not is_glu(self.activation_string)
 
-            self.conv1d = ParameterizedConv1d(
-                in_channels=self.conv_dim,
-                out_channels=self.conv_dim,
+            self.conv1d = DepthwiseCausalConvolution(
+                hidden_size=self.conv_dim,
                 kernel_size=kernel_size,
-                bias=add_bias,
-                padding=kernel_size - 1,
-                groups=self.conv_dim,
+                activation_function=self.activation_string,
+                add_bias=add_bias,
                 std=_get_std_for_linear(
                     initializer_range=initializer_range,
                     init_method=init_method,
@@ -147,6 +144,7 @@
                     num_layers=num_layers,
                     use_depth_scaled_init=False,
                 ),
+                use_padding_free_transformer=use_padding_free_transformer,
             )
 
             mark_parameter_as_mup_learning_rate(self.conv1d.weight)
@@ -211,17 +209,11 @@
             f = self.decay_gate(f, final_exponential=True, output_dtype=f.dtype)
 
         if self.kernel_size is not None:
-            x, c = causal_convolution(
+            x, c = self.conv1d(
                 hidden_states=x,
                 input_state=c,
                 attention_mask=attention_mask,
-                conv1d_weight=self.conv1d.weight,
-                conv1d_bias=self.conv1d.bias,
-                conv1d_num_groups=self.conv_dim,
-                return_cache_state=cache_params is not None,
-                activation_string=self.activation_string,
-                conv1d_padding=self.kernel_size - 1,
-                conv1d_stride=1,
+                output_state=cache_params is not None,
             )
 
         q, k, v = x.split((self.q_shape, self.k_shape, self.v_shape), dim=-1)
diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py
index 7de4ee434..09907c4aa 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py
@@ -10,29 +10,25 @@
 from ....enums import Kernel
 from ....kernels import is_kernel_allowed
-from ....utils import divide_if_divisible, is_causal_conv1d_available, is_mamba_2_ssm_available
+from ....utils import divide_if_divisible, is_mamba_2_ssm_available
 from ...cache import ConstantCache, GenerationCache, GenerationState
 from ...parameter import (
     mark_parameter_as_initialized,
     mark_parameter_as_mup_learning_rate,
     mark_parameter_as_no_weight_decay,
 )
-from ..activations import get_activation_function, silu
-from ..convolution import ParameterizedConv1d
+from ..activations import silu
 from ..decay_gate import SoftplusDecayGate
+from ..depthwise_causal_convolution import DepthwiseCausalConvolution, _apply_mask_to_padding_states
 from ..init_utils import _get_std_for_linear
 from ..linear import ParameterizedLinear
 from ..normalization import get_normalization_function
-from .causal_convolution import _apply_mask_to_padding_states
 
 
 if is_mamba_2_ssm_available():
     from mamba_ssm.ops.triton.selective_state_update import selective_state_update
     from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
 
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-
 
 def _pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int) -> torch.Tensor:
     """
@@ -132,10 +128,6 @@
         self.layer_idx = layer_idx
         self.use_conv_bias = use_conv_bias
 
-        self.activation_string = ssm_activation_function
-        self.activation = get_activation_function(self.activation_string)
-        self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"]
-
         self.n_groups = num_groups
         self.head_dim = divide_if_divisible(ssm_intermediate_size, ssm_num_heads, "")
         self.chunk_size = chunk_size
@@ -144,13 +136,11 @@
 
         # 1D convolutional layer
         self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
-        self.conv1d = ParameterizedConv1d(
-            in_channels=self.conv_dim,
-            out_channels=self.conv_dim,
-            bias=use_conv_bias,
+        self.conv1d = DepthwiseCausalConvolution(
+            hidden_size=self.conv_dim,
             kernel_size=self.conv_kernel_size,
-            groups=self.conv_dim,
-            padding=self.conv_kernel_size - 1,
+            activation_function=ssm_activation_function,
+            add_bias=use_conv_bias,
             std=_get_std_for_linear(
                 initializer_range=initializer_range,
                 init_method=init_method,
@@ -159,6 +149,7 @@
                 num_layers=num_layers,
                 use_depth_scaled_init=False,
             ),
+            use_padding_free_transformer=False,
         )
 
         # projection of the input hidden states
@@ -258,16 +249,12 @@
 
         # 2. Convolution sequence transformation
         if use_precomputed_states:
-            conv_state = conv_state.roll(shifts=-1, dims=-1)
-            conv_state[:, :, -1] = hidden_states_B_C[:, 0, :].to(conv_state.device)
-
-            # We need to guarantee that anything regarding the cache is on the same device
-            conv_state = conv_state.to(device=self.conv1d.weight.device)
-
-            hidden_states_B_C = torch.sum(conv_state * self.conv1d.weight.squeeze(1), dim=-1)
-            if self.use_conv_bias:
-                hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
-            hidden_states_B_C = self.activation(hidden_states_B_C)
+            hidden_states_B_C, conv_state = self.conv1d(
+                hidden_states=hidden_states_B_C,
+                input_state=conv_state,
+                attention_mask=attention_mask,
+                output_state=True,
+            )
         else:
             # Init cache
             if cache_params is not None:
@@ -275,19 +262,17 @@
                 ssm_state = torch.zeros(
                     batch_size,
                     self.num_heads,
-                    divide_if_divisible(self.intermediate_size, self.num_heads, ""),
+                    divide_if_divisible(self.intermediate_size, self.num_heads),
                    self.ssm_state_size,
                     device=projected_states.device,
                     dtype=dtype,
                 )
 
-                hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
-                conv_state = F.pad(
-                    hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0)
-                )
-
-            hidden_states_B_C = self.activation(
-                self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
+            hidden_states_B_C, conv_state = self.conv1d(
+                hidden_states=hidden_states_B_C,
+                input_state=None,
+                attention_mask=attention_mask,
+                output_state=cache_params is not None,
             )
 
             hidden_states_B_C = _apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
@@ -509,12 +494,11 @@
             )
 
             # 2. Convolution sequence transformation
-            hidden_states_B_C = causal_conv1d_update(
-                hidden_states_B_C,
-                conv_state,
-                self.conv1d.weight.squeeze(1),
-                self.conv1d.bias,
-                self.activation_string,
+            hidden_states_B_C, conv_state = self.conv1d(
+                hidden_states=hidden_states_B_C[:, None, :],
+                input_state=conv_state,
+                attention_mask=None,
+                output_state=cache_params is not None,
             )
 
             hidden_states, B, C = torch.split(
@@ -591,36 +575,13 @@
                 [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
             )
 
-            hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
-
-            # 2. Convolution sequence transformation
-            # Init cache
-            if cache_params is not None:
-                # storing the states
-                # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
-                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
-                conv_state = F.pad(
-                    hidden_states_B_C_transposed,
-                    (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
-                )
-
-                cache_params.update(
-                    states=(
-                        GenerationState(state=conv_state, method=ConstantCache),
-                        GenerationState(state=ssm_state, method=ConstantCache),
-                    ),
-                    layer_idx=self.layer_idx,
-                )
-
-            hidden_states_B_C = causal_conv1d_fn(
-                x=hidden_states_B_C_transposed,
-                weight=self.conv1d.weight.squeeze(1),
-                bias=self.conv1d.bias,
-                activation=self.activation_string if self.use_activation_inside_kernel else None,
-            ).transpose(1, 2)
-
-            if not self.use_activation_inside_kernel:
-                hidden_states_B_C = self.activation(hidden_states_B_C)
+            hidden_states_B_C, conv_state = self.conv1d(
+                hidden_states=hidden_states_B_C,
+                input_state=None,
+                attention_mask=attention_mask,
+                output_state=cache_params is not None,
+            )
 
             hidden_states_B_C = _apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
 
             hidden_states, B, C = torch.split(
diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py
index 8c1c64c38..b7f2fd8c5 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py
@@ -17,11 +17,10 @@
     mark_parameter_as_no_weight_decay,
 )
 from ..activations import clip_gradients, get_activation_function, is_glu, silu, tanh
-from ..convolution import ParameterizedConv1d
+from ..depthwise_causal_convolution import DepthwiseCausalConvolution
 from ..init_utils import _get_std_for_linear
 from ..linear import ParameterizedLinear
 from ..normalization import get_normalization_function
-from .causal_convolution import causal_convolution
 from .utils import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence
@@ -97,13 +96,11 @@
         if kernel_size is not None:
             assert not is_glu(self.activation_string)
 
-            self.conv1d = ParameterizedConv1d(
-                in_channels=self.state_size,
-                out_channels=self.state_size,
+            self.conv1d = DepthwiseCausalConvolution(
+                hidden_size=self.state_size,
                 kernel_size=kernel_size,
-                bias=add_bias,
-                padding=kernel_size - 1,
-                groups=self.state_size,
+                activation_function=self.activation_string,
+                add_bias=add_bias,
                 std=_get_std_for_linear(
                     initializer_range=initializer_range,
                     init_method=init_method,
@@ -112,6 +109,7 @@
                     num_layers=num_layers,
                     use_depth_scaled_init=False,
                 ),
+                use_padding_free_transformer=use_padding_free_transformer,
             )
 
             mark_parameter_as_mup_learning_rate(self.conv1d.weight)
@@ -176,17 +174,11 @@
         if self.kernel_size is None:
             x = self.activation_function(x)
         else:
-            x, c = causal_convolution(
+            x, c = self.conv1d(
                 hidden_states=x,
                 input_state=c,
                 attention_mask=attention_mask,
-                conv1d_weight=self.conv1d.weight,
-                conv1d_bias=self.conv1d.bias,
-                conv1d_num_groups=self.state_size,
-                return_cache_state=cache_params is not None,
-                activation_string=self.activation_string,
-                conv1d_padding=self.kernel_size - 1,
-                conv1d_stride=1,
+                output_state=cache_params is not None,
            )
 
         x = x.view(*x.size()[:-1], -1, self.state_head_dim)
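Both the Mamba2 paths and the new module re-apply `_apply_mask_to_padding_states` after the convolution, since the depthwise kernel can smear non-zero values into padded positions. A minimal sketch of that helper's semantics; the guard skips the multiply for width-1 masks (single-token decode) and for batch size 1, where presumably no padding exists:

```python
import torch

def apply_mask(hidden_states, attention_mask):
    # Mirrors _apply_mask_to_padding_states: zero out padded positions.
    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
        dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
    return hidden_states

h = torch.ones(2, 3, 4)                      # (batch, seq, channels)
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])  # second sequence has one pad token

out = apply_mask(h, mask)
print(out[1, 2])                                    # zeroed: tensor([0., 0., 0., 0.])
print(apply_mask(h, mask[:, :1]).sum() == h.sum())  # width-1 mask: no-op
```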