[REFACTOR] drop convolution and only retain depthwise causal convolution#420

Open

mayank31398 wants to merge 20 commits into main from conv

Conversation

@mayank31398
Collaborator

No description provided.

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request refactors the convolution logic across several sequence mixer blocks by introducing a centralized DepthwiseCausalConvolution module and removing deprecated convolution components. It also adds zero-shot evaluation results and a weight initialization patch for Granite MoE Hybrid models. Feedback focuses on correcting the monkey-patching logic for weight initialization, fixing a parameter mismatch in Mamba2's bias configuration, addressing potential errors when bias parameters are null, ensuring consistent attention masking, and improving robustness for edge cases like unit kernel sizes.
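
For orientation, here is a minimal sketch of what a depthwise causal convolution computes (illustrative only: the class name, constructor arguments, and the (batch, seq_len, channels) layout are assumptions, not the PR's actual DepthwiseCausalConvolution module):

```python
import torch
import torch.nn as nn

class DepthwiseCausalConv1d(nn.Conv1d):
    """Depthwise causal 1D convolution over (batch, seq_len, channels) inputs."""

    def __init__(self, channels: int, kernel_size: int, add_bias: bool = True) -> None:
        # groups == channels makes the convolution depthwise (one filter per channel);
        # padding of kernel_size - 1 plus the right-side trim in forward() makes it causal.
        super().__init__(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            groups=channels,
            padding=kernel_size - 1,
            bias=add_bias,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        k = self.kernel_size[0]  # nn.Conv1d stores kernel_size as a tuple
        hidden_states = hidden_states.transpose(-1, -2)  # (B, S, C) -> (B, C, S)
        hidden_states = super().forward(hidden_states)
        if k > 1:
            # drop the right-side padding so position t only sees inputs <= t
            hidden_states = hidden_states[..., : 1 - k]
        return hidden_states.transpose(-1, -2)  # back to (B, S, C)

# sanity check: output length matches input length
y = DepthwiseCausalConv1d(channels=64, kernel_size=4)(torch.randn(2, 10, 64))
assert y.shape == (2, 10, 64)
```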

Comment on lines +58 to +68
```python
def _init_weights(self, module):
    super(GraniteMoeHybridPreTrainedModel, self)._init_weights(module)
    if isinstance(module, GraniteMoeHybridMambaLayer):
        module.dt_bias.data.fill_(1.0)
        module.A_log.data.copy_(
            torch.log(torch.arange(1, module.num_heads + 1, dtype=module.A_log.dtype, device=module.A_log.device))
        )
        module.D.data.fill_(1.0)

_init_weights._lm_engine_patched = True
GraniteMoeHybridPreTrainedModel._init_weights = _init_weights
```

Severity: high

Monkey-patching _init_weights by calling super()._init_weights inside the replacement function will bypass any initialization logic defined in the original GraniteMoeHybridPreTrainedModel._init_weights method. In the transformers library, model-specific PreTrainedModel subclasses typically implement _init_weights to handle their specific layers. By replacing the method and calling super(), you are skipping the original class's logic and calling the parent's (likely PreTrainedModel) method instead.

To correctly extend the existing initialization, you should save the original method and call it within your patch.

Suggested change

```diff
-def _init_weights(self, module):
-    super(GraniteMoeHybridPreTrainedModel, self)._init_weights(module)
-    if isinstance(module, GraniteMoeHybridMambaLayer):
-        module.dt_bias.data.fill_(1.0)
-        module.A_log.data.copy_(
-            torch.log(torch.arange(1, module.num_heads + 1, dtype=module.A_log.dtype, device=module.A_log.device))
-        )
-        module.D.data.fill_(1.0)
-
-_init_weights._lm_engine_patched = True
-GraniteMoeHybridPreTrainedModel._init_weights = _init_weights
+original_init_weights = GraniteMoeHybridPreTrainedModel._init_weights
+
+def _init_weights(self, module):
+    original_init_weights(self, module)
+    if isinstance(module, GraniteMoeHybridMambaLayer):
+        module.dt_bias.data.fill_(1.0)
+        module.A_log.data.copy_(
+            torch.log(torch.arange(1, module.num_heads + 1, dtype=module.A_log.dtype, device=module.A_log.device))
+        )
+        module.D.data.fill_(1.0)
```
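
One detail worth preserving from the original code: the `_lm_engine_patched` flag and the final reassignment. The flag presumably guards against applying the patch twice; with the suggested closure, re-running the patch would make `original_init_weights` point at an already-patched function. A hedged sketch of how the reviewer's closure and the existing guard could be combined:

```python
# Sketch only: keep the _lm_engine_patched guard from the original code so
# repeated calls do not re-wrap the method around itself.
if not getattr(GraniteMoeHybridPreTrainedModel._init_weights, "_lm_engine_patched", False):
    original_init_weights = GraniteMoeHybridPreTrainedModel._init_weights

    def _init_weights(self, module):
        original_init_weights(self, module)
        if isinstance(module, GraniteMoeHybridMambaLayer):
            module.dt_bias.data.fill_(1.0)
            module.A_log.data.copy_(
                torch.log(torch.arange(1, module.num_heads + 1, dtype=module.A_log.dtype, device=module.A_log.device))
            )
            module.D.data.fill_(1.0)

    _init_weights._lm_engine_patched = True
    GraniteMoeHybridPreTrainedModel._init_weights = _init_weights
```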

```python
    groups=self.conv_dim,
    padding=self.conv_kernel_size - 1,
    activation_function=self.activation_string,
    add_bias=add_bias,
```

Severity: high

In the refactor, bias=use_conv_bias was replaced with add_bias=add_bias. In Mamba2, add_bias typically refers to the linear projections, while use_conv_bias specifically controls the convolution layer's bias. This change might incorrectly enable or disable the convolution bias based on the wrong configuration parameter.

Suggested change

```diff
-    add_bias=add_bias,
+    add_bias=use_conv_bias,
```

Comment on lines +13 to +16
```python
from ...utils import divide_if_divisible, is_causal_conv1d_available
from ..parameter import (
    mark_parameter_as_initialized,
    mark_parameter_as_mup_learning_rate,
```

Severity: medium

The following imports are unused in this file: divide_if_divisible and mark_parameter_as_mup_learning_rate.

```python
from ...utils import is_causal_conv1d_available
from ..parameter import (
    mark_parameter_as_initialized,
    mark_parameter_as_no_weight_decay,
)
```

"""
Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
"""
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:

Severity: medium

The check attention_mask.shape[0] > 1 prevents the mask from being applied when the batch size is 1. Padding can still exist in a single-sequence batch if the sequence length is less than the maximum length. The mask should be applied whenever it is provided to ensure padding tokens are correctly zeroed out.

Suggested change

```diff
-if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
+if attention_mask is not None and attention_mask.shape[1] > 1:
```
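
For reference, a minimal sketch of what such a masking helper looks like (the repo's actual _apply_mask_to_padding_states may differ; the (batch, seq_len) mask layout with 1 for real tokens and 0 for padding is assumed from the linked mamba issue):

```python
import torch

# Sketch: zero out hidden states at padding positions.
def _apply_mask_to_padding_states(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor | None
) -> torch.Tensor:
    # skip during single-token decoding (shape[1] == 1), where there is no padding
    if attention_mask is not None and attention_mask.shape[1] > 1:
        hidden_states = hidden_states * attention_mask[..., None].to(hidden_states.dtype)
    return hidden_states
```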

```python
self.activation_function = get_activation_function(self.activation_string)
self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"]

mark_parameter_as_no_weight_decay(self.bias)
```

Severity: medium

If add_bias is False, self.bias will be None. Calling mark_parameter_as_no_weight_decay on None might lead to an error depending on its implementation. It is safer to check if the bias exists first.

Suggested change

```diff
-mark_parameter_as_no_weight_decay(self.bias)
+if self.bias is not None:
+    mark_parameter_as_no_weight_decay(self.bias)
```

Comment on lines +79 to +112
```python
if is_kernel_allowed(Kernel.causal_conv1d):
    if input_state is None:
        hidden_states = hidden_states.transpose(-1, -2)

        if return_cache_state:
            # F.pad trims the hidden_states if sequence_length > kernel_size
            input_state = F.pad(hidden_states, (self.kernel_size - S, 0))

        hidden_states = causal_conv1d_fn(
            x=hidden_states,
            weight=self.weight.squeeze(1),
            bias=self.bias,
            activation=self.activation_string if self.use_activation_inside_kernel else None,
        )

        hidden_states = hidden_states.transpose(-1, -2)
    else:
        assert S == 1

        # we clone to prevent in-place modification;
        # torch compile can remove the clone if it's not needed.
        # this is to prevent silent incorrectness down the line in the model
        input_state_buffer = input_state.clone()
        hidden_states = causal_conv1d_update(
            x=hidden_states,
            conv_state=input_state_buffer,
            weight=self.weight.squeeze(1),
            bias=self.bias,
            activation=self.activation_string if self.use_activation_inside_kernel else None,
        )
        input_state = input_state_buffer if return_cache_state else None

    if not self.use_activation_inside_kernel:
        hidden_states = self.activation_function(hidden_states)
```

Severity: medium

The output of the convolution is not masked in the is_kernel_allowed path. If a bias is present, the output at padding positions will be non-zero (the bias value after activation). This creates an inconsistency with the else path (line 141) where the output is explicitly masked. You should apply _apply_mask_to_padding_states to the output in both paths.
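
A hedged sketch of the fix being suggested (assuming attention_mask is in scope at this point, and reusing the masking helper sketched earlier):

```python
# Mask the kernel-path output the same way the eager path does, so a conv
# bias cannot leave non-zero activations at padding positions.
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
```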

```python
hidden_states = super().forward(hidden_states)

# removes padding on the right side of the sequence
hidden_states = hidden_states[..., : 1 - self.kernel_size]
```

Severity: medium

The slicing logic [: 1 - self.kernel_size] will result in an empty tensor if kernel_size is 1. While these models typically use larger kernels, the implementation should be robust to kernel_size=1.

Suggested change

```diff
-hidden_states = hidden_states[..., : 1 - self.kernel_size]
+# removes padding on the right side of the sequence
+if self.kernel_size > 1:
+    hidden_states = hidden_states[..., : 1 - self.kernel_size]
```
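
A quick demonstration of the edge case (illustrative only):

```python
import torch

kernel_size = 1
x = torch.randn(2, 64, 10)  # (batch, channels, seq_len)

# With kernel_size == 1 the slice end is 1 - 1 == 0, so every timestep is dropped.
print(x[..., : 1 - kernel_size].shape)  # torch.Size([2, 64, 0])
```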

```python
self.bias.zero_()

mark_parameter_as_initialized(self.weight)
mark_parameter_as_initialized(self.bias)
```

Severity: medium

Similar to the initialization, mark_parameter_as_initialized should only be called on self.bias if it is not None.

Suggested change

```diff
-mark_parameter_as_initialized(self.weight)
-mark_parameter_as_initialized(self.bias)
+mark_parameter_as_initialized(self.weight)
+if self.bias is not None:
+    mark_parameter_as_initialized(self.bias)
```

```python
    add_bias=add_bias,
    std=_get_std_for_linear(
        initializer_range=initializer_range,
        init_method=init_method,
```

Not sure if this is intended, but _get_std_for_linear should also take fan_in as an argument for the muP init.
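
A hedged sketch of the signature change being suggested (hypothetical: the "mup" value for init_method and the 1/sqrt(fan_in) scaling follow the usual muP convention for hidden weights, not necessarily this repo's exact formula):

```python
import math

def _get_std_for_linear(initializer_range: float, init_method: str, fan_in: int | None = None) -> float:
    std = initializer_range
    if init_method == "mup":
        # muP scales the init std of hidden weights by 1 / sqrt(fan_in)
        assert fan_in is not None, "mup init requires fan_in"
        std = std / math.sqrt(fan_in)
    return std
```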
