Add SDPA and FlashAttention support to T5 #42453
base: main
Changes from all commits: d678fb2, fbe163f, ae02420, 84cfb85, d93f0c4, f9f47a4, 145b10f, d0e14b1
@@ -16,6 +16,7 @@
 import copy
 import math
+from collections.abc import Callable
 from typing import Optional, Union

 import torch
@@ -27,6 +28,7 @@
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_bidirectional_mask, create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -38,9 +40,11 @@
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
 from ...utils import (
     DUMMY_INPUTS,
     DUMMY_MASK,
+    TransformersKwargs,
     auto_docstring,
     logging,
 )
@@ -157,6 +161,35 @@ def forward(self, hidden_states):
         return hidden_states


+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
 class T5Attention(nn.Module):
     def __init__(
         self,
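For reference, here is a minimal standalone call to the new helper. This is only a sketch: it assumes the `eager_attention_forward` defined in the diff above is in scope, and the shapes, the dummy module, and the scaling value are made up for illustration.

```python
import torch
from torch import nn

# Illustrative shapes (assumptions for this sketch, not values taken from T5).
batch, heads, q_len, kv_len, head_dim = 2, 8, 5, 7, 64
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# T5 folds its relative position bias into the additive mask term, so later in this
# diff the bias is what gets passed to the helper as `attention_mask`.
position_bias = torch.randn(1, heads, q_len, kv_len)

dummy_module = nn.Module()  # only `.training` is read inside the helper
attn_output, attn_weights = eager_attention_forward(
    dummy_module, query, key, value, attention_mask=position_bias, scaling=1.0, dropout=0.0
)
print(attn_output.shape)   # torch.Size([2, 5, 8, 64]), i.e. (batch, q_len, heads, head_dim)
print(attn_weights.shape)  # torch.Size([2, 8, 5, 7]), i.e. (batch, heads, q_len, kv_len)
```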
@@ -165,6 +198,7 @@ def __init__(
         layer_idx: Optional[int] = None,
     ):
         super().__init__()
+        self.config = config
         self.is_decoder = config.is_decoder
         self.has_relative_attention_bias = has_relative_attention_bias
         self.relative_attention_num_buckets = config.relative_attention_num_buckets
@@ -182,6 +216,8 @@ def __init__(
                 "when creating this class."
             )

+        self.scaling = self.d_model**-0.5
Contributor: For completeness, we should have the …
         self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
         self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
         self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -262,16 +298,16 @@ def compute_bias(self, query_length, key_length, device=None, cache_position=None
     def forward(
         self,
-        hidden_states,
-        mask=None,
-        key_value_states=None,
-        position_bias=None,
-        past_key_values=None,
-        query_length=None,
-        use_cache=False,
-        output_attentions=False,
-        cache_position=None,
-    ):
+        hidden_states: torch.FloatTensor,
+        key_value_states: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_bias: Optional[torch.FloatTensor] = None,
+        query_length: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
Contributor (on lines +301 to +310): Let's not rename here, this would break BC. The type annotations are fine by themselves.
         """
         Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
         """
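On the rename concern above: the backward-compatibility risk is for keyword callers, since code written against the old parameter name `mask` stops working once the parameter is called `attention_mask`. A quick, purely illustrative way to check which names the installed transformers version exposes:

```python
import inspect

from transformers.models.t5.modeling_t5 import T5Attention

# Prints the parameter names of T5Attention.forward for the installed version; before
# this PR the mask argument is named `mask`, the diff renames it to `attention_mask`.
print(list(inspect.signature(T5Attention.forward).parameters))
```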
@@ -319,39 +355,46 @@ def forward(
                 past_key_values.is_updated[self.layer_idx] = True

-        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
-        scores = torch.matmul(query_states, key_states.transpose(3, 2))
-
         if position_bias is None:
             key_length = key_states.shape[-2]
             # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
             real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
             if not self.has_relative_attention_bias:
                 position_bias = torch.zeros(
-                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
+                    (1, self.n_heads, seq_length, key_length), device=query_states.device, dtype=query_states.dtype
                 )
                 if self.gradient_checkpointing and self.training:
                     position_bias.requires_grad = True
             else:
                 position_bias = self.compute_bias(
-                    real_seq_length, key_length, device=scores.device, cache_position=cache_position
+                    real_seq_length, key_length, device=query_states.device, cache_position=cache_position
                 )
                 position_bias = position_bias[:, :, -seq_length:, :]

-            if mask is not None:
-                causal_mask = mask[:, :, :, : key_states.shape[-2]]
+            if attention_mask is not None:
+                causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
                 position_bias = position_bias + causal_mask

-        position_bias_masked = position_bias
-        scores += position_bias_masked
-
-        # (batch_size, n_heads, seq_length, key_length)
-        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            logger.warning_once(
+                "T5 uses relative position bias; SDPA/FlashAttention not supported, fall back to eager."
+            )
Contributor (on lines +380 to +383): This should never happen as we don't support anything other than eager. I would even raise an error here, if anything.
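If the warning were turned into a hard failure as suggested, one possible shape for the check is sketched below. The helper name and its exact wording are assumptions, not part of this diff.

```python
from transformers import T5Config


def check_attn_implementation(config: T5Config) -> None:
    """Hypothetical helper: refuse non-eager backends instead of silently falling back."""
    impl = getattr(config, "_attn_implementation", "eager")
    if impl != "eager":
        raise ValueError(
            "T5 relies on a learned relative position bias and currently supports only the "
            f"'eager' attention implementation, got '{impl}'."
        )
```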
-        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask=position_bias,
+            dropout=0.0 if not self.training else self.dropout,
+            scaling=self.scaling,
+            output_attentions=output_attentions,
+            **kwargs,
+        )

-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
+        attn_output = attn_output.view(batch_size, -1, self.inner_dim).contiguous()
         attn_output = self.o(attn_output)

         outputs = (attn_output, position_bias)
Contributor: Would be nice if we could refactor this along in this PR, we have an …
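For reference, models that do enable several backends resolve the attention callable from the library's registry instead of hard-coding the eager function. A sketch of that pattern is below; it is not what this diff does (T5 stays on the eager path), and it assumes `eager_attention_forward` from above is in scope.

```python
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def resolve_attention_interface(config):
    """Sketch of the registry-based dispatch used by models with full backend support."""
    if config._attn_implementation == "eager":
        return eager_attention_forward
    return ALL_ATTENTION_FUNCTIONS[config._attn_implementation]
```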
@@ -547,6 +590,10 @@ class T5PreTrainedModel(PreTrainedModel):
     config: T5Config
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
+    _supports_attention_backend = True
Contributor: Not supported - kwargs are not used everywhere so far and enc-dec will need another look.
+    _supports_flash_attn = False
+    _supports_sdpa = False
+    _supports_flex_attn = False
     _can_compile_fullgraph = True

     _no_split_modules = ["T5Block"]
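These class flags are what the loading machinery consults when a user requests a specific backend. A hedged usage sketch (the checkpoint name is just an example):

```python
from transformers import T5ForConditionalGeneration

# Explicitly request the eager path, the only backend these flags leave enabled.
model = T5ForConditionalGeneration.from_pretrained(
    "google-t5/t5-small", attn_implementation="eager"
)

# With _supports_sdpa / _supports_flash_attn left at False, a request such as
# attn_implementation="sdpa" is expected to be rejected at load time rather than
# silently falling back to eager.
```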
I would rather have the relative position bias within here, see #38301 or, more specifically, transformers/src/transformers/models/bert/modeling_bert.py, lines 121 to 176 at 1c3188f.
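A rough sketch of that idea follows, with the relative position bias passed to the eager helper as its own argument instead of being pre-folded into `attention_mask`. Everything here is an assumption for illustration, not code from this PR or from the referenced modeling_bert.py lines.

```python
from typing import Optional

import torch
from torch import nn


def eager_attention_forward_with_bias(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    position_bias: Optional[torch.Tensor] = None,
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    **kwargs,
):
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if position_bias is not None:
        # The relative position bias is applied inside the attention function itself.
        attn_weights = attn_weights + position_bias[:, :, :, : key.shape[-2]]
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights
```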