From d678fb21edf8edb824404a6e405d9d174bf63f2d Mon Sep 17 00:00:00 2001
From: DuyguA
Date: Thu, 27 Nov 2025 13:13:03 +0100
Subject: [PATCH 1/5] changes for new attention interface

---
 src/transformers/models/t5/modeling_t5.py | 91 ++++++++++++++++-------
 1 file changed, 65 insertions(+), 26 deletions(-)

diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 051fd8a5e7d0..0925e59bdbe2 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -16,6 +16,7 @@
 
 import copy
 import math
+from collections.abc import Callable
 from typing import Optional, Union
 
 import torch
@@ -26,6 +27,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...masking_utils import create_bidirectional_mask, create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
@@ -37,10 +39,12 @@
     Seq2SeqSequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
 from ...utils import (
     DUMMY_INPUTS,
     DUMMY_MASK,
+    TransformersKwargs,
     auto_docstring,
     logging,
 )
@@ -156,6 +160,33 @@ def forward(self, hidden_states):
         hidden_states = hidden_states + self.dropout(forwarded_states)
         return hidden_states
 
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
 
 class T5Attention(nn.Module):
     def __init__(
@@ -165,6 +196,7 @@ def __init__(
         layer_idx: Optional[int] = None,
     ):
         super().__init__()
+        self.config = config
         self.is_decoder = config.is_decoder
         self.has_relative_attention_bias = has_relative_attention_bias
         self.relative_attention_num_buckets = config.relative_attention_num_buckets
@@ -182,6 +214,8 @@ def __init__(
                 "when creating this class."
             )
 
+        self.scaling = self.d_model**-0.5
+
         self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
         self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
         self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -262,16 +296,16 @@ def compute_bias(self, query_length, key_length, device=None, cache_position=Non
 
     def forward(
         self,
-        hidden_states,
-        mask=None,
-        key_value_states=None,
-        position_bias=None,
-        past_key_values=None,
-        query_length=None,
-        use_cache=False,
-        output_attentions=False,
-        cache_position=None,
-    ):
+        hidden_states: torch.FloatTensor,
+        key_value_states: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_bias: Optional[torch.FloatTensor] = None,
+        query_length: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         """
         Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
         """
@@ -319,7 +353,6 @@ def forward(
                 past_key_values.is_updated[self.layer_idx] = True
 
         # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
-        scores = torch.matmul(query_states, key_states.transpose(3, 2))
 
         if position_bias is None:
             key_length = key_states.shape[-2]
@@ -327,31 +360,37 @@ def forward(
             real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
             if not self.has_relative_attention_bias:
                 position_bias = torch.zeros(
-                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
+                    (1, self.n_heads, seq_length, key_length), device=query_states.device, dtype=query_states.dtype
                 )
                 if self.gradient_checkpointing and self.training:
                     position_bias.requires_grad = True
             else:
                 position_bias = self.compute_bias(
-                    real_seq_length, key_length, device=scores.device, cache_position=cache_position
+                    real_seq_length, key_length, device=query_states.device, cache_position=cache_position
                 )
                 position_bias = position_bias[:, :, -seq_length:, :]
 
-            if mask is not None:
-                causal_mask = mask[:, :, :, : key_states.shape[-2]]
+            if attention_mask is not None:
+                causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
                 position_bias = position_bias + causal_mask
 
-        position_bias_masked = position_bias
-        scores += position_bias_masked
-
-        # (batch_size, n_heads, seq_length, key_length)
-        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
-        attn_output = torch.matmul(attn_weights, value_states)
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask=position_bias,
+            dropout=0.0 if not self.training else self.dropout,
+            scaling=self.scaling,
+            output_attentions=output_attentions,
+            **kwargs,
+        )
 
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
+        attn_output = attn_output.view(batch_size, -1, self.inner_dim).contiguous()
         attn_output = self.o(attn_output)
 
         outputs = (attn_output, position_bias)
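
Note (not part of the patch series): a minimal sketch of how the eager_attention_forward helper introduced in PATCH 1/5 can be exercised on its own. It assumes the patched modeling_t5.py is importable; the tensor shapes, the all-zero position bias, and the use of nn.Identity as a stand-in module are illustrative assumptions only.

# Illustrative only -- assumes the patched modeling_t5.py from PATCH 1/5 is importable.
import torch
from torch import nn

from transformers.models.t5.modeling_t5 import eager_attention_forward

batch, n_heads, seq_len, head_dim = 2, 8, 5, 64
query = torch.randn(batch, n_heads, seq_len, head_dim)
key = torch.randn(batch, n_heads, seq_len, head_dim)
value = torch.randn(batch, n_heads, seq_len, head_dim)
# T5Attention passes its relative position bias through the attention_mask argument.
position_bias = torch.zeros(1, n_heads, seq_len, seq_len)

attn_output, attn_weights = eager_attention_forward(
    nn.Identity(),              # any nn.Module works; only `.training` is read for dropout
    query,
    key,
    value,
    attention_mask=position_bias,
    scaling=1.0,                # keep the raw dot products; pass None to fall back to 1/sqrt(head_dim)
    dropout=0.0,
)
print(attn_output.shape)   # torch.Size([2, 5, 8, 64]) -- heads moved next to head_dim by the final transpose
print(attn_weights.shape)  # torch.Size([2, 8, 5, 5])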
From fbe163f1c13408591f050acbdff6649f2b68fa16 Mon Sep 17 00:00:00 2001
From: DuyguA
Date: Tue, 2 Dec 2025 11:52:53 +0100
Subject: [PATCH 2/5] no support for flash attn

---
 src/transformers/models/t5/modeling_t5.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 0925e59bdbe2..c5d9f3b59408 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -27,8 +27,8 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...masking_utils import create_bidirectional_mask, create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -160,6 +160,7 @@ def forward(self, hidden_states):
         hidden_states = hidden_states + self.dropout(forwarded_states)
         return hidden_states
 
+
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
@@ -188,6 +189,7 @@ def eager_attention_forward(
 
     return attn_output, attn_weights
 
+
 class T5Attention(nn.Module):
     def __init__(
         self,
@@ -586,6 +588,10 @@ class T5PreTrainedModel(PreTrainedModel):
     config: T5Config
     base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
+    _supports_attention_backend = True
+    _supports_flash_attn = False
+    _supports_sdpa = False
+    _supports_flex_attn = False
     _can_compile_fullgraph = True
     _no_split_modules = ["T5Block"]

From d93f0c4344b03a5afa96ad7b0cabd991f97e84e3 Mon Sep 17 00:00:00 2001
From: DuyguA
Date: Tue, 2 Dec 2025 13:49:40 +0100
Subject: [PATCH 3/5] restrict only eager attention

---
 src/transformers/models/t5/modeling_t5.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index c5d9f3b59408..28205c82907e 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -39,7 +39,7 @@
     Seq2SeqSequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import (
     DUMMY_INPUTS,
@@ -378,7 +378,9 @@ def forward(
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            logger.warning_once(
+                "T5 "T5 uses relative position bias; SDPA/FlashAttention not supported. Falling back to eager."
+            )
 
         attn_output, attn_weights = attention_interface(

From f9f47a444de9ee6a789d19927740b781588b09f0 Mon Sep 17 00:00:00 2001
From: DuyguA
Date: Tue, 2 Dec 2025 13:55:42 +0100
Subject: [PATCH 4/5] fixed typo

---
 src/transformers/models/t5/modeling_t5.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 28205c82907e..5b197093ad66 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -379,7 +379,7 @@ def forward(
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             logger.warning_once(
-                "T5 "T5 uses relative position bias; SDPA/FlashAttention not supported. Falling back to eager."
+                "T5 uses relative position bias; SDPA/FlashAttention not supported. Falling back to eager."
             )
 
         attn_output, attn_weights = attention_interface(

From 145b10f2b0913da8f9dda306e52055a5fde2560e Mon Sep 17 00:00:00 2001
From: DuyguA
Date: Tue, 2 Dec 2025 13:56:49 +0100
Subject: [PATCH 5/5] minor cosmetics

---
 src/transformers/models/t5/modeling_t5.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 5b197093ad66..d71607cd89fc 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -379,7 +379,7 @@ def forward(
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             logger.warning_once(
-                "T5 uses relative position bias; SDPA/FlashAttention not supported. Falling back to eager."
+                "T5 uses relative position bias; SDPA/FlashAttention not supported, fall back to eager."
             )
 
         attn_output, attn_weights = attention_interface(
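
Note (not part of the patch series): a quick end-to-end sketch of the state after PATCH 5/5 is applied, assuming a transformers checkout that contains this series. The checkpoint name, prompt, and generation settings below are placeholder assumptions. The class flags added in PATCH 2/5 advertise no SDPA/Flash support, and PATCH 3/5 adds the forward-time warning for any non-eager setting, so T5 stays on the eager attention path.

# Illustrative only -- checkpoint name and prompt are placeholders.
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-small")
# Request the eager path explicitly; it is the backend the patched T5 attention uses.
model = T5ForConditionalGeneration.from_pretrained("t5-small", attn_implementation="eager")

inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))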