95 changes: 71 additions & 24 deletions src/transformers/models/t5/modeling_t5.py
@@ -16,6 +16,7 @@

import copy
import math
from collections.abc import Callable
from typing import Optional, Union

import torch
@@ -27,6 +28,7 @@
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...masking_utils import create_bidirectional_mask, create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
@@ -38,9 +40,11 @@
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,
TransformersKwargs,
auto_docstring,
logging,
)
@@ -157,6 +161,35 @@ def forward(self, hidden_states):
return hidden_states


def eager_attention_forward(

Review comment (Contributor):
I would rather have the relative position bias within here; see #38301, or more specifically:

def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    **kwargs: Unpack[TransformersKwargs],
):
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(2, 3))
    # Relative positional embeddings
    if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query":
        query_length, key_length = query.shape[2], key.shape[2]
        if use_cache:
            position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1)
        else:
            position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
        position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
        distance = position_ids_l - position_ids_r
        positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
        positional_embedding = positional_embedding.to(dtype=query.dtype)  # fp16 compatibility
        if module.position_embedding_type == "relative_key":
            relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
            attn_weights = attn_weights + relative_position_scores
        elif module.position_embedding_type == "relative_key_query":
            relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
            relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
            attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key
    # Scaling is shifted in case of embeddings being relative
    attn_weights = attn_weights * scaling
    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    if head_mask is not None:
        attn_weights = attn_weights * head_mask
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
(no longer on main, but it should give you an idea of how this should look)

module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: Optional[float] = None,
dropout: float = 0.0,
**kwargs: Unpack[TransformersKwargs],
):
if scaling is None:
scaling = query.size(-1) ** -0.5

# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attn_weights


class T5Attention(nn.Module):
def __init__(
self,
@@ -165,6 +198,7 @@ def __init__(
layer_idx: Optional[int] = None,
):
super().__init__()
self.config = config
self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias
self.relative_attention_num_buckets = config.relative_attention_num_buckets
@@ -182,6 +216,8 @@ def __init__(
"when creating this class."
)

self.scaling = self.d_model**-0.5

Review comment (Contributor):
For completeness, we should have the is_causal flag here; you can look at Bart for this, i.e. False in the encoder, and in the decoder True for self-attention and False for cross-attention.
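
A rough sketch of what that could look like, modeled on the Bart convention (the is_causal wiring below is an assumption, not part of this PR):

class T5Attention(nn.Module):
    def __init__(
        self,
        config: T5Config,
        has_relative_attention_bias: bool = False,
        layer_idx: Optional[int] = None,
        is_causal: bool = False,  # assumption: new flag, defaulting to False as in Bart
    ):
        super().__init__()
        self.config = config
        self.is_causal = is_causal
        # ... rest of __init__ unchanged ...

# Hypothetical call sites:
#   decoder self-attention -> T5Attention(config, ..., is_causal=True)
#   encoder self-attention -> T5Attention(config, ..., is_causal=False)
#   cross-attention        -> T5Attention(config, ..., is_causal=False)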


self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -262,16 +298,16 @@ def compute_bias(self, query_length, key_length, device=None, cache_position=None):

def forward(
self,
hidden_states,
mask=None,
key_value_states=None,
position_bias=None,
past_key_values=None,
query_length=None,
use_cache=False,
output_attentions=False,
cache_position=None,
):
hidden_states: torch.FloatTensor,
key_value_states: Optional[torch.FloatTensor] = None,
past_key_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_bias: Optional[torch.FloatTensor] = None,
query_length: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
Comment on lines +301 to +310
Review comment (Contributor):
Let's not rename here; this would break BC. The type annotations are fine by themselves.
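
A minimal sketch of keeping the existing keyword names (so callers passing mask=... keep working) while still adding type hints; the annotation choices below are assumptions:

def forward(
    self,
    hidden_states: torch.FloatTensor,
    mask: Optional[torch.FloatTensor] = None,  # keep the old name instead of attention_mask
    key_value_states: Optional[torch.FloatTensor] = None,
    position_bias: Optional[torch.FloatTensor] = None,
    past_key_values: Optional[Cache] = None,
    query_length: Optional[int] = None,
    use_cache: bool = False,
    output_attentions: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
):
    ...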

"""
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
"""
@@ -319,39 +355,46 @@ def forward(
past_key_values.is_updated[self.layer_idx] = True

# compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
scores = torch.matmul(query_states, key_states.transpose(3, 2))

if position_bias is None:
key_length = key_states.shape[-2]
# cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
if not self.has_relative_attention_bias:
position_bias = torch.zeros(
(1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
(1, self.n_heads, seq_length, key_length), device=query_states.device, dtype=query_states.dtype
)
if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True
else:
position_bias = self.compute_bias(
real_seq_length, key_length, device=scores.device, cache_position=cache_position
real_seq_length, key_length, device=query_states.device, cache_position=cache_position
)
position_bias = position_bias[:, :, -seq_length:, :]

if mask is not None:
causal_mask = mask[:, :, :, : key_states.shape[-2]]
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
position_bias = position_bias + causal_mask

position_bias_masked = position_bias
scores += position_bias_masked

# (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
logger.warning_once(
"T5 uses relative position bias; SDPA/FlashAttention are not supported, falling back to eager."
)
Comment on lines +380 to +383
Review comment (Contributor):
This should never happen, as we don't support anything other than eager. I would even raise an error here, if anything.
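
A minimal sketch of the stricter variant suggested here, raising instead of warning (the message wording is an assumption):

if self.config._attn_implementation != "eager":
    raise ValueError(
        "T5Attention only supports eager attention because of its relative position bias; "
        f"got attn_implementation={self.config._attn_implementation!r}."
    )
attention_interface: Callable = eager_attention_forward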


attn_output = torch.matmul(attn_weights, value_states)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=position_bias,
dropout=0.0 if not self.training else self.dropout,
scaling=self.scaling,
output_attentions=output_attentions,
**kwargs,
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.inner_dim)
attn_output = attn_output.view(batch_size, -1, self.inner_dim).contiguous()
attn_output = self.o(attn_output)

outputs = (attn_output, position_bias)

Review comment (Contributor):
It would be nice if we could refactor this as part of this PR: we have an OutputRecorder that can handle collecting the weights, so we no longer need the explicit kwargs. You can take a look at other models like Llama or t5gemma2, which do this. In essence, you need the decorators (check_model_input, can_return_tuple) and the corresponding flag _can_record_outputs.
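
A rough sketch of the recorder-based pattern being referenced, as used by Llama-style models; the decorator name, import path, and mapping below are assumptions about current transformers internals, not part of this PR:

from ...utils.generic import check_model_inputs  # assumption: decorator/import names as in recent transformers

class T5Stack(T5PreTrainedModel):
    # Assumption: declare which submodules produce which recordable outputs, so attention
    # weights no longer need to be threaded through the per-layer return tuples.
    _can_record_outputs = {
        "attentions": T5Attention,
        "hidden_states": T5Block,
    }

    @check_model_inputs
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        ...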

@@ -547,6 +590,10 @@ class T5PreTrainedModel(PreTrainedModel):
config: T5Config
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_supports_attention_backend = True

Review comment (Contributor):
Not supported: kwargs are not used everywhere yet, and enc-dec will need another look.

_supports_flash_attn = False
_supports_sdpa = False
_supports_flex_attn = False
_can_compile_fullgraph = True

_no_split_modules = ["T5Block"]