[REFACTOR] drop convolution and only retain depthwise causal convolution #420
mayank31398 wants to merge 20 commits into main
Conversation
Code Review
This pull request refactors the convolution logic across several sequence mixer blocks by introducing a centralized DepthwiseCausalConvolution module and removing deprecated convolution components. It also adds zero-shot evaluation results and a weight initialization patch for Granite MoE Hybrid models. Feedback focuses on correcting the monkey-patching logic for weight initialization, fixing a parameter mismatch in Mamba2's bias configuration, addressing potential errors when bias parameters are null, ensuring consistent attention masking, and improving robustness for edge cases like unit kernel sizes.
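For readers less familiar with the pattern being centralized here: a depthwise causal convolution convolves each channel independently (`groups == channels`), left-pads by `kernel_size - 1`, and trims the right side so position `t` never sees future positions. A minimal sketch (the class and argument names below are illustrative, not the actual `DepthwiseCausalConvolution` API in this PR):

```python
import torch
import torch.nn as nn


class DepthwiseCausalConv1dSketch(nn.Module):
    """Illustrative only; the real module in this PR may differ."""

    def __init__(self, channels: int, kernel_size: int) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        # groups == channels -> depthwise; padding == kernel_size - 1 -> causal after trimming
        self.conv = nn.Conv1d(channels, channels, kernel_size, groups=channels, padding=kernel_size - 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, channels, seq_len)
        hidden_states = self.conv(hidden_states)
        if self.kernel_size > 1:
            # drop the extra positions introduced by the right-side padding
            hidden_states = hidden_states[..., : 1 - self.kernel_size]
        return hidden_states
```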
def _init_weights(self, module):
    super(GraniteMoeHybridPreTrainedModel, self)._init_weights(module)
    if isinstance(module, GraniteMoeHybridMambaLayer):
        module.dt_bias.data.fill_(1.0)
        module.A_log.data.copy_(
            torch.log(torch.arange(1, module.num_heads + 1, dtype=module.A_log.dtype, device=module.A_log.device))
        )
        module.D.data.fill_(1.0)

_init_weights._lm_engine_patched = True
GraniteMoeHybridPreTrainedModel._init_weights = _init_weights
Monkey-patching _init_weights by calling super()._init_weights inside the replacement function will bypass any initialization logic defined in the original GraniteMoeHybridPreTrainedModel._init_weights method. In the transformers library, model-specific PreTrainedModel subclasses typically implement _init_weights to handle their specific layers. By replacing the method and calling super(), you are skipping the original class's logic and calling the parent's (likely PreTrainedModel) method instead.
To correctly extend the existing initialization, you should save the original method and call it within your patch.
Suggested change:

- def _init_weights(self, module):
-     super(GraniteMoeHybridPreTrainedModel, self)._init_weights(module)
-     if isinstance(module, GraniteMoeHybridMambaLayer):
-         module.dt_bias.data.fill_(1.0)
-         module.A_log.data.copy_(
-             torch.log(torch.arange(1, module.num_heads + 1, dtype=module.A_log.dtype, device=module.A_log.device))
-         )
-         module.D.data.fill_(1.0)
- _init_weights._lm_engine_patched = True
- GraniteMoeHybridPreTrainedModel._init_weights = _init_weights
+ original_init_weights = GraniteMoeHybridPreTrainedModel._init_weights
+ def _init_weights(self, module):
+     original_init_weights(self, module)
+     if isinstance(module, GraniteMoeHybridMambaLayer):
+         module.dt_bias.data.fill_(1.0)
+         module.A_log.data.copy_(
+             torch.log(torch.arange(1, module.num_heads + 1, dtype=module.A_log.dtype, device=module.A_log.device))
+         )
+         module.D.data.fill_(1.0)
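One detail worth keeping from the original code (my reading, not something the suggestion above enforces): the `_lm_engine_patched` marker, which prevents the method from being wrapped a second time if the patching code runs more than once. Roughly, assuming the same imports and names as the surrounding file:

```python
if not getattr(GraniteMoeHybridPreTrainedModel._init_weights, "_lm_engine_patched", False):
    original_init_weights = GraniteMoeHybridPreTrainedModel._init_weights

    def _init_weights(self, module):
        # run the original initialization first, then the Mamba-specific overrides
        original_init_weights(self, module)
        if isinstance(module, GraniteMoeHybridMambaLayer):
            module.dt_bias.data.fill_(1.0)
            module.A_log.data.copy_(
                torch.log(torch.arange(1, module.num_heads + 1, dtype=module.A_log.dtype, device=module.A_log.device))
            )
            module.D.data.fill_(1.0)

    _init_weights._lm_engine_patched = True
    GraniteMoeHybridPreTrainedModel._init_weights = _init_weights
```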
    groups=self.conv_dim,
    padding=self.conv_kernel_size - 1,
    activation_function=self.activation_string,
    add_bias=add_bias,
In the refactor, bias=use_conv_bias was replaced with add_bias=add_bias. In Mamba2, add_bias typically refers to the linear projections, while use_conv_bias specifically controls the convolution layer's bias. This change might incorrectly enable or disable the convolution bias based on the wrong configuration parameter.
Suggested change:

- add_bias=add_bias,
+ add_bias=use_conv_bias,
from ...utils import divide_if_divisible, is_causal_conv1d_available
from ..parameter import (
    mark_parameter_as_initialized,
    mark_parameter_as_mup_learning_rate,
| """ | ||
| Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 | ||
| """ | ||
| if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: |
The check attention_mask.shape[0] > 1 prevents the mask from being applied when the batch size is 1. Padding can still exist in a single-sequence batch if the sequence length is less than the maximum length. The mask should be applied whenever it is provided to ensure padding tokens are correctly zeroed out.
Suggested change:

- if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
+ if attention_mask is not None and attention_mask.shape[1] > 1:
self.activation_function = get_activation_function(self.activation_string)
self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"]

mark_parameter_as_no_weight_decay(self.bias)
If add_bias is False, self.bias will be None. Calling mark_parameter_as_no_weight_decay on None might lead to an error depending on its implementation. It is safer to check if the bias exists first.
Suggested change:

- mark_parameter_as_no_weight_decay(self.bias)
+ if self.bias is not None:
+     mark_parameter_as_no_weight_decay(self.bias)
if is_kernel_allowed(Kernel.causal_conv1d):
    if input_state is None:
        hidden_states = hidden_states.transpose(-1, -2)

        if return_cache_state:
            # F.pad trims the hidden_states if sequence_length > kernel_size
            input_state = F.pad(hidden_states, (self.kernel_size - S, 0))

        hidden_states = causal_conv1d_fn(
            x=hidden_states,
            weight=self.weight.squeeze(1),
            bias=self.bias,
            activation=self.activation_string if self.use_activation_inside_kernel else None,
        )

        hidden_states = hidden_states.transpose(-1, -2)
    else:
        assert S == 1

        # we clone to prevent modification in-place
        # torch compile can remove the clone if its not needed
        # this is to prevent silent incorrectness down the line in the model
        input_state_buffer = input_state.clone()
        hidden_states = causal_conv1d_update(
            x=hidden_states,
            conv_state=input_state_buffer,
            weight=self.weight.squeeze(1),
            bias=self.bias,
            activation=self.activation_string if self.use_activation_inside_kernel else None,
        )
        input_state = input_state_buffer if return_cache_state else None

    if not self.use_activation_inside_kernel:
        hidden_states = self.activation_function(hidden_states)
The output of the convolution is not masked in the is_kernel_allowed path. If a bias is present, the output at padding positions will be non-zero (the bias value after activation). This creates an inconsistency with the else path (line 141) where the output is explicitly masked. You should apply _apply_mask_to_padding_states to the output in both paths.
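A sketch of what that could look like at the end of the kernel path (assuming `attention_mask` is in scope at this point and `_apply_mask_to_padding_states` is the helper quoted earlier in this diff):

```python
if not self.use_activation_inside_kernel:
    hidden_states = self.activation_function(hidden_states)

# mirror the non-kernel path: re-mask so padding positions do not carry
# a non-zero bias/activation value into the rest of the model
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
```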
hidden_states = super().forward(hidden_states)

# removes padding on the right side of the sequence
hidden_states = hidden_states[..., : 1 - self.kernel_size]
The slicing logic [: 1 - self.kernel_size] will result in an empty tensor if kernel_size is 1. While these models typically use larger kernels, the implementation should be robust to kernel_size=1.
Suggested change:

- hidden_states = hidden_states[..., : 1 - self.kernel_size]
+ # removes padding on the right side of the sequence
+ if self.kernel_size > 1:
+     hidden_states = hidden_states[..., : 1 - self.kernel_size]
self.bias.zero_()

mark_parameter_as_initialized(self.weight)
mark_parameter_as_initialized(self.bias)
    add_bias=add_bias,
    std=_get_std_for_linear(
        initializer_range=initializer_range,
        init_method=init_method,
Not sure if this is intended or not, but _get_std_for_linear should also take the fan_in as an arg for the mup init.
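For context, muP-style initialization typically scales the standard deviation by the layer's fan-in, so the helper would need it as an argument. A hypothetical sketch only; the actual `_get_std_for_linear` signature and scaling rule in this repo may differ:

```python
import math


def _get_std_for_linear(initializer_range: float, init_method: str, fan_in: int | None = None) -> float:
    # hypothetical: "mup" scales the base std by 1 / sqrt(fan_in)
    if init_method == "mup":
        assert fan_in is not None, "mup init needs the layer's fan_in"
        return initializer_range / math.sqrt(fan_in)
    # plain normal init just uses the configured range
    return initializer_range
```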