diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 051fd8a5e7d0..d71607cd89fc 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -16,6 +16,7 @@
 
 import copy
 import math
+from collections.abc import Callable
 from typing import Optional, Union
 
 import torch
@@ -27,6 +28,7 @@
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_bidirectional_mask, create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -38,9 +40,11 @@
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
 from ...utils import (
     DUMMY_INPUTS,
     DUMMY_MASK,
+    TransformersKwargs,
     auto_docstring,
     logging,
 )
@@ -157,6 +161,35 @@ def forward(self, hidden_states):
         return hidden_states
 
 
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
 class T5Attention(nn.Module):
     def __init__(
         self,
@@ -165,6 +198,7 @@ def __init__(
         layer_idx: Optional[int] = None,
     ):
         super().__init__()
+        self.config = config
        self.is_decoder = config.is_decoder
         self.has_relative_attention_bias = has_relative_attention_bias
         self.relative_attention_num_buckets = config.relative_attention_num_buckets
@@ -182,6 +216,8 @@ def __init__(
                 "when creating this class."
             )
 
+        self.scaling = 1.0  # T5 does not rescale attention scores (absorbed into the weight init)
+
         self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
         self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
         self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -262,16 +298,16 @@ def compute_bias(self, query_length, key_length, device=None, cache_position=Non
 
     def forward(
         self,
-        hidden_states,
-        mask=None,
-        key_value_states=None,
-        position_bias=None,
-        past_key_values=None,
-        query_length=None,
-        use_cache=False,
-        output_attentions=False,
-        cache_position=None,
-    ):
+        hidden_states: torch.FloatTensor,
+        key_value_states: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_bias: Optional[torch.FloatTensor] = None,
+        query_length: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         """
         Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
""" @@ -319,7 +355,6 @@ def forward( past_key_values.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: key_length = key_states.shape[-2] @@ -327,31 +362,39 @@ def forward( real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=query_states.device, dtype=query_states.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: position_bias = self.compute_bias( - real_seq_length, key_length, device=scores.device, cache_position=cache_position + real_seq_length, key_length, device=query_states.device, cache_position=cache_position ) position_bias = position_bias[:, :, -seq_length:, :] - if mask is not None: - causal_mask = mask[:, :, :, : key_states.shape[-2]] + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] position_bias = position_bias + causal_mask - position_bias_masked = position_bias - scores += position_bias_masked - - # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + logger.warning_once( + "T5 uses relative position bias; SDPA/FlashAttention not supported, fall back to eager." + ) - attn_output = torch.matmul(attn_weights, value_states) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=position_bias, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + **kwargs, + ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = attn_output.view(batch_size, -1, self.inner_dim).contiguous() attn_output = self.o(attn_output) outputs = (attn_output, position_bias) @@ -547,6 +590,10 @@ class T5PreTrainedModel(PreTrainedModel): config: T5Config base_model_prefix = "transformer" supports_gradient_checkpointing = True + _supports_attention_backend = True + _supports_flash_attn = False + _supports_sdpa = False + _supports_flex_attn = False _can_compile_fullgraph = True _no_split_modules = ["T5Block"]