@@ -153,6 +153,57 @@ def forward(
         return attn_output, attn_weights, None
 
 
+class PatchTSTSdpaAttention(PatchTSTAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        if output_attentions:
+            # SDPA cannot return attention weights; fall back to the parent (eager) implementation.
+            return super().forward(hidden_states, key_value_states=key_value_states, attention_mask=attention_mask, output_attentions=output_attentions, **kwargs)
+
+        bsz, tgt_len, _ = hidden_states.size()
+        is_cross_attention = key_value_states is not None
+        src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+        # 1. Projections (identical to the original implementation)
+        query_states = self.q_proj(hidden_states)
+        current_states = key_value_states if is_cross_attention else hidden_states
+        key_states = self.k_proj(current_states)
+        value_states = self.v_proj(current_states)
+
+        # 2. Reshape to (batch, heads, seq_len, head_dim), the layout SDPA expects
+        q_input_shape = (bsz, tgt_len, self.num_heads, self.head_dim)
+        kv_input_shape = (bsz, src_len, self.num_heads, self.head_dim)
+
+        query_states = query_states.view(*q_input_shape).transpose(1, 2)
+        key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+        value_states = value_states.view(*kv_input_shape).transpose(1, 2)
+
+        # 3. Execution
+        # attention_mask is forwarded because the original implementation supports it;
+        # SDPA handles broadcastable float masks automatically.
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            dropout_p=self.dropout if self.training else 0.0,
+            is_causal=self.is_causal,
+        )
+
+        # 4. Output projection (identical to the original implementation)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, None, None
+
+
 class PatchTSTBatchNorm(nn.Module):
     """
     Compute batch normalization over the sequence length (time) dimension.
@@ -418,8 +469,15 @@ def __init__(self, config: PatchTSTConfig):
         super().__init__()
 
         self.channel_attention = config.channel_attention
-        # Multi-Head attention
-        self.self_attn = PatchTSTAttention(
+
+        # Multi-head attention: use SDPA when requested, otherwise the eager implementation.
+        if config._attn_implementation == "sdpa":
+            self_attn_cls = PatchTSTSdpaAttention
+        else:
+            # "eager" and any other value keep the default attention class.
+            self_attn_cls = PatchTSTAttention
+
+        self.self_attn = self_attn_cls(
             embed_dim=config.d_model,
             num_heads=config.num_attention_heads,
             dropout=config.attention_dropout,
@@ -555,6 +613,7 @@ class PatchTSTPreTrainedModel(PreTrainedModel):
     main_input_name = "past_values"
     input_modalities = ("time",)
     supports_gradient_checkpointing = False
+    _supports_sdpa = True
 
     @torch.no_grad()
     def _init_weights(self, module: nn.Module):
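
A minimal usage sketch (not part of the diff) of how the new path would be exercised: it sets the private `config._attn_implementation` attribute that the added branch reads, whereas `from_pretrained(..., attn_implementation="sdpa")` would be the usual public route; the input shape follows the default `PatchTSTConfig`.

import torch
from transformers import PatchTSTConfig, PatchTSTModel

config = PatchTSTConfig()
config._attn_implementation = "sdpa"  # value checked by the new branch; "eager" keeps the original attention
model = PatchTSTModel(config).eval()

# past_values: (batch_size, sequence_length, num_input_channels)
past_values = torch.randn(2, config.context_length, config.num_input_channels)
with torch.no_grad():
    outputs = model(past_values=past_values)

    # Requesting attention weights triggers the eager fallback inside PatchTSTSdpaAttention.
    outputs_with_attn = model(past_values=past_values, output_attentions=True)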