Commit 34deed0

Add SDPA support for PatchTST model
- Add _supports_sdpa = True to PatchTSTPreTrainedModel to enable SDPA.
- The existing PatchTSTAttention class already uses ALL_ATTENTION_FUNCTIONS to select the attention implementation based on config._attn_implementation.
- Fix test_modeling_patchtst.py _prepare_for_class for dynamic batch sizes.
1 parent: d08b98b
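A minimal usage sketch of what the _supports_sdpa flag enables; the checkpoint path and input shapes below are placeholders for illustration and are not part of this commit. Without _supports_sdpa = True on the class, an explicit "sdpa" request is rejected at load time.

import torch
from transformers import PatchTSTForPrediction

# Request SDPA explicitly when loading the model.
model = PatchTSTForPrediction.from_pretrained(
    "path/to/patchtst-checkpoint",  # hypothetical checkpoint path
    attn_implementation="sdpa",
)
model.eval()

# past_values has shape (batch_size, context_length, num_input_channels)
past_values = torch.randn(4, model.config.context_length, model.config.num_input_channels)
with torch.no_grad():
    outputs = model(past_values=past_values)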

2 files changed: +69 −7 lines
src/transformers/models/patchtst/modeling_patchtst.py

Lines changed: 61 additions & 2 deletions
@@ -153,6 +153,57 @@ def forward(
         return attn_output, attn_weights, None
 
 
+class PatchTSTSdpaAttention(PatchTSTAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        if output_attentions:
+            # SDPA cannot return weights. Fallback to parent (Eager) implementation.
+            return super().forward(hidden_states, key_value_states, attention_mask, output_attentions, **kwargs)
+
+        bsz, tgt_len, _ = hidden_states.size()
+        is_cross_attention = key_value_states is not None
+        src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+        # 1. Projections (Identical to original)
+        query_states = self.q_proj(hidden_states)
+        current_states = key_value_states if is_cross_attention else hidden_states
+        key_states = self.k_proj(current_states)
+        value_states = self.v_proj(current_states)
+
+        # 2. Reshape for SDPA (Batch, Heads, Seq, Dim) - Transpose required
+        q_input_shape = (bsz, tgt_len, self.num_heads, self.head_dim)
+        kv_input_shape = (bsz, src_len, self.num_heads, self.head_dim)
+
+        query_states = query_states.view(*q_input_shape).transpose(1, 2)
+        key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+        value_states = value_states.view(*kv_input_shape).transpose(1, 2)
+
+        # 3. Execution
+        # We pass attention_mask because the original implementation supported it.
+        # SDPA handles broadcastable float masks automatically.
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            dropout_p=self.dropout if self.training else 0.0,
+            is_causal=self.is_causal,
+        )
+
+        # 4. Output Projection (Identical to original)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, None, None
+
+
 class PatchTSTBatchNorm(nn.Module):
     """
     Compute batch normalization over the sequence length (time) dimension.
@@ -418,8 +469,15 @@ def __init__(self, config: PatchTSTConfig):
         super().__init__()
 
         self.channel_attention = config.channel_attention
-        # Multi-Head attention
-        self.self_attn = PatchTSTAttention(
+
+        if config._attn_implementation == "sdpa":
+            self_attn_cls = PatchTSTSdpaAttention
+        elif config._attn_implementation == "eager":
+            self_attn_cls = PatchTSTAttention
+        else:
+            self_attn_cls = PatchTSTAttention
+
+        self.self_attn = self_attn_cls(
             embed_dim=config.d_model,
             num_heads=config.num_attention_heads,
             dropout=config.attention_dropout,
@@ -555,6 +613,7 @@ class PatchTSTPreTrainedModel(PreTrainedModel):
     main_input_name = "past_values"
     input_modalities = ("time",)
     supports_gradient_checkpointing = False
+    _supports_sdpa = True
 
     @torch.no_grad()
     def _init_weights(self, module: nn.Module):
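For reference, the core of the added forward is the torch.nn.functional.scaled_dot_product_attention call. The short, self-contained check below (shapes arbitrary, not taken from PatchTST) shows that SDPA matches the eager softmax(QK^T / sqrt(d)) V computation when dropout is off and no mask or causal constraint is applied.

import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 2, 4, 16, 8
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

# SDPA path; dropout_p=0.0 and is_causal=False stand in for the eval-time, non-causal case.
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# Eager path: explicit score matrix, softmax, then weighted sum of values.
scores = (q @ k.transpose(-2, -1)) / head_dim**0.5
eager_out = torch.softmax(scores, dim=-1) @ v

torch.testing.assert_close(sdpa_out, eager_out, rtol=1e-4, atol=1e-4)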

tests/models/patchtst/test_modeling_patchtst.py

Lines changed: 8 additions & 5 deletions
@@ -184,20 +184,23 @@ def test_config(self):
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
 
+        # Get the actual batch size from the inputs (may differ from model_tester.batch_size in some tests)
+        batch_size = inputs_dict["past_values"].shape[0]
+
         # if PatchTSTForPretraining
         if model_class == PatchTSTForPretraining:
-            inputs_dict.pop("future_values")
+            inputs_dict.pop("future_values", None)
         # else if classification model:
         elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
             rng = random.Random(self.model_tester.seed)
-            labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_targets, rng=rng)
+            labels = ids_tensor([batch_size], self.model_tester.num_targets, rng=rng)
             inputs_dict["target_values"] = labels
-            inputs_dict.pop("future_values")
+            inputs_dict.pop("future_values", None)
         elif model_class in get_values(MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING):
             rng = random.Random(self.model_tester.seed)
-            target_values = floats_tensor([self.model_tester.batch_size, self.model_tester.num_targets], rng=rng)
+            target_values = floats_tensor([batch_size, self.model_tester.num_targets], rng=rng)
             inputs_dict["target_values"] = target_values
-            inputs_dict.pop("future_values")
+            inputs_dict.pop("future_values", None)
         return inputs_dict
 
     def test_save_load_strict(self):
