[REFACTOR] drop convolution and only retain depthwise causal convolution #420
base: main
Changes from all commits: aa70aa0, ea6910d, 3967ae3, 76a645d, d09913c, b9a7837, 588ea11, 0113fbf, 87bcf9b, 531778a, 02b8eb2, 362ee50, 69d314f, ff96940, 6fbd1c3, eebc571, fdeaf36, 4fd9123, 9a0ed23, fef2fa2
Diff for the added file (`@@ -0,0 +1,155 @@`):
```python
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...enums import Kernel
from ...kernels import is_kernel_allowed
from ...utils import divide_if_divisible, is_causal_conv1d_available
from ..parameter import (
    mark_parameter_as_initialized,
    mark_parameter_as_mup_learning_rate,
```
**Contributor** commented on lines +13 to +16.
```python
    mark_parameter_as_no_weight_decay,
)
from .activations import get_activation_function


if is_causal_conv1d_available():
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update


def _apply_mask_to_padding_states(hidden_states: torch.Tensor, attention_mask: torch.Tensor | None) -> torch.Tensor:
    """
    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
    """
    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
```
**Contributor:** The check …
```python
        dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

    return hidden_states


class DepthwiseCausalConvolution(nn.Conv1d):
    def __init__(
        self,
        hidden_size: int,
        kernel_size: int,
        activation_function: str,
        add_bias: bool,
        std: float | None,
        use_padding_free_transformer: bool,
    ) -> DepthwiseCausalConvolution:
        if use_padding_free_transformer:
            raise NotImplementedError()

        self.std = std

        super().__init__(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size,
            padding=kernel_size - 1,
            groups=hidden_size,
            bias=add_bias,
        )

        self.activation_string = activation_function
        self.activation_function = get_activation_function(self.activation_string)
        self.use_activation_inside_kernel = self.activation_string in [None, "silu", "swish"]
        self.kernel_size = kernel_size

        mark_parameter_as_no_weight_decay(self.bias)
```
**Contributor:** If `add_bias` is `False`, `self.bias` is `None` here, and `mark_parameter_as_no_weight_decay(self.bias)` is called on `None`.
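For illustration, a sketch of the kind of guard this likely refers to (an assumption about the intended suggested change, not the reviewer's verbatim code):

```python
# nn.Conv1d sets self.bias to None when bias=False, so only tag it when it exists
if self.bias is not None:
    mark_parameter_as_no_weight_decay(self.bias)
```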
```python
        self.reset_parameters()

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_state: torch.Tensor | None,
        attention_mask: torch.Tensor | None,
        output_state: bool,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        S = hidden_states.size(1)
        hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)

        if is_kernel_allowed(Kernel.causal_conv1d):
            if input_state is None:
                hidden_states = hidden_states.transpose(-1, -2)

                if output_state:
                    # F.pad trims the hidden_states if sequence_length > kernel_size
                    input_state = F.pad(hidden_states, (self.kernel_size - S, 0))

                hidden_states = causal_conv1d_fn(
                    x=hidden_states,
                    weight=self.weight.squeeze(1),
                    bias=self.bias,
                    activation=self.activation_string if self.use_activation_inside_kernel else None,
                )

                hidden_states = hidden_states.transpose(-1, -2)
            else:
                assert S == 1

                input_state_buffer = input_state.clone()

                hidden_states = causal_conv1d_update(
                    x=hidden_states,
                    conv_state=input_state_buffer,
                    weight=self.weight.squeeze(1),
                    bias=self.bias,
                    activation=self.activation_string if self.use_activation_inside_kernel else None,
                )

                input_state = input_state_buffer if output_state else None

            if not self.use_activation_inside_kernel:
                hidden_states = self.activation_function(hidden_states)
```
**Contributor** (comment on lines +80 to +112): The output of the convolution is not masked in the `causal_conv1d` path; unlike the eager path below, this branch never applies `_apply_mask_to_padding_states` to its output.
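For illustration, a minimal sketch of what masking the kernel path's output could look like (an assumption about the intended fix, not code from this PR):

```python
# hypothetical addition at the end of the causal_conv1d branch:
# zero out padding positions in the output, mirroring the eager path
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
```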
```python
        else:
            if input_state is None:
                hidden_states = hidden_states.transpose(-1, -2)

                if output_state:
                    # F.pad trims the hidden_states if sequence_length > kernel_size
                    input_state = F.pad(hidden_states, (self.kernel_size - S, 0))

                hidden_states = super().forward(hidden_states)

                # removes padding on the right side of the sequence
                hidden_states = hidden_states[..., : 1 - self.kernel_size]
```
**Contributor:** The slicing logic `hidden_states[..., : 1 - self.kernel_size]` yields an empty tensor when `kernel_size == 1`.
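A sketch of an equivalent slice that also handles `kernel_size == 1` (an assumption about the suggested change, not the reviewer's verbatim code):

```python
# keep only the first S output positions; for kernel_size > 1 this matches the
# original slice, and for kernel_size == 1 it avoids producing an empty tensor
hidden_states = hidden_states[..., :S]
```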
```python
                hidden_states = hidden_states.transpose(-1, -2)
            else:
                assert S == 1

                input_state = input_state.roll(shifts=-1, dims=-1)
                input_state[..., -1] = hidden_states[:, 0]

                hidden_states = (input_state * self.weight.squeeze(1)).sum(dim=-1)
                hidden_states = hidden_states[:, None, :]
                if self.bias is not None:
                    hidden_states = hidden_states + self.bias

            if not output_state:
                input_state = None

            hidden_states = self.activation_function(hidden_states)
            hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)

        return hidden_states, input_state

    @torch.no_grad()
    def reset_parameters(self) -> None:
        if self.std is None:
            super().reset_parameters()
        else:
            nn.init.normal_(self.weight, mean=0, std=self.std)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias.zero_()

        mark_parameter_as_initialized(self.weight)
        mark_parameter_as_initialized(self.bias)
```
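For context, a minimal usage sketch of the module added in this diff (the argument values, the `"silu"` activation string, and running the eager path are assumptions, not taken from the PR):

```python
import torch

# hypothetical standalone usage; in the repository the module is wired into a larger model
conv = DepthwiseCausalConvolution(
    hidden_size=64,
    kernel_size=4,
    activation_function="silu",
    add_bias=True,
    std=None,
    use_padding_free_transformer=False,
)

x = torch.randn(2, 16, 64)  # (batch, sequence, hidden)
out, state = conv(x, input_state=None, attention_mask=None, output_state=True)
# out: (2, 16, 64); state: (2, 64, 4), the cached window used for single-token decoding
```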
**Contributor:** Monkey-patching `_init_weights` by calling `super()._init_weights` inside the replacement function will bypass any initialization logic defined in the original `GraniteMoeHybridPreTrainedModel._init_weights` method. In the `transformers` library, model-specific `PreTrainedModel` subclasses typically implement `_init_weights` to handle their specific layers. By replacing the method and calling `super()`, you are skipping the original class's logic and calling the parent's (likely `PreTrainedModel`) method instead. To correctly extend the existing initialization, you should save the original method and call it within your patch.
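For illustration, a sketch of the pattern the comment describes, saving the original `_init_weights` and calling it from the patch (the import path and the `DepthwiseCausalConvolution` branch are assumptions, not the PR's actual code):

```python
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import GraniteMoeHybridPreTrainedModel

# keep a reference to the class's own _init_weights so its model-specific logic still runs
_original_init_weights = GraniteMoeHybridPreTrainedModel._init_weights

def _patched_init_weights(self, module):
    # run the original initialization first
    _original_init_weights(self, module)

    # then handle the custom layer (illustrative extension point, not from this PR)
    if isinstance(module, DepthwiseCausalConvolution):
        module.reset_parameters()

GraniteMoeHybridPreTrainedModel._init_weights = _patched_init_weights
```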
_init_weightsby callingsuper()._init_weightsinside the replacement function will bypass any initialization logic defined in the originalGraniteMoeHybridPreTrainedModel._init_weightsmethod. In thetransformerslibrary, model-specificPreTrainedModelsubclasses typically implement_init_weightsto handle their specific layers. By replacing the method and callingsuper(), you are skipping the original class's logic and calling the parent's (likelyPreTrainedModel) method instead.To correctly extend the existing initialization, you should save the original method and call it within your patch.