
Conversation

@Furkan-rgb

What does this PR do?

Adds SDPA (Scaled Dot Product Attention) support for the PatchTST model.

Changes:

  • Added a PatchTSTSdpaAttention class that uses torch.nn.functional.scaled_dot_product_attention (see the sketch after this list)
  • Integrated attention class selection in PatchTSTEncoderLayer based on config._attn_implementation
  • Added _supports_sdpa = True to PatchTSTPreTrainedModel
  • Fixed test_modeling_patchtst.py _prepare_for_class method for proper dynamic batch size handling
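
For context, torch.nn.functional.scaled_dot_product_attention replaces the manual matmul + softmax + matmul of eager attention with one fused call. A minimal sketch of the core of such a forward pass (illustrative only, not the exact class added in this PR):

import torch.nn.functional as F


def sdpa_core(query, key, value, attention_mask=None, dropout_p=0.0):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    # One fused kernel; note that it does not return attention weights,
    # which is why the eager fallback mentioned in the notes below exists.
    return F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=dropout_p,
        is_causal=False,  # the PatchTST encoder attends bidirectionally over patches
    )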

Testing:

All 67 PatchTST tests pass (103 skipped as expected):

pytest tests/models/patchtst/test_modeling_patchtst.py

Notes:

  • The SDPA implementation falls back to eager attention when output_attentions=True, since SDPA doesn't return attention weights (see the sketch after these notes)
  • Uses the standard attention implementation pattern from other models in transformers
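
The fallback described in the first note usually takes the following shape; this is a hedged, self-contained sketch of the pattern, not the PR's actual class:

import torch
import torch.nn.functional as F
from torch import nn


class AttentionCoreWithFallback(nn.Module):
    # Illustrative: use SDPA normally, fall back to eager math when weights are requested.

    def forward(self, query, key, value, attention_mask=None, output_attentions=False):
        if output_attentions:
            # SDPA cannot return attention weights, so compute them explicitly (eager path).
            scores = torch.matmul(query, key.transpose(-1, -2)) / query.shape[-1] ** 0.5
            if attention_mask is not None:
                scores = scores + attention_mask
            weights = F.softmax(scores, dim=-1)
            return torch.matmul(weights, value), weights

        # Fast path: fused kernel, no weights available.
        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        return attn_output, None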

Contributor

@kashif kashif left a comment

looks good, thanks

@Furkan-rgb
Author

Seems like the documentation build is failing due to an improperly closed tag; unrelated to my changes.

Contributor

@vasqu vasqu left a comment

Sorry, we no longer support implementing separate Attention classes for each attention flavor. There are some older models which haven't yet been refactored for this.

Please take a look at models like Albert or Bert (Bert is a bit messy because it has additional enc-dec logic), which already implement this. The essence is to have (see the sketch after this list):

  • One Attention class
  • Specific attributes like is_causal, scaling, num_attention_heads to directly reuse across flavors
  • Generic code around the core attn mechanism, e.g. the projections and views/reshaping
  • The interface (ALL_ATTENTION_FUNCTIONS) that calls the underlying flavor for us
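
A condensed sketch of that layout, loosely modeled on the refactored attention in recent transformers versions (the class name is made up and the ALL_ATTENTION_FUNCTIONS import path is an assumption; the real PatchTST code differs in its details):

import torch
from torch import nn

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS  # assumed import path


def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # Plain attention, used when config._attn_implementation == "eager".
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


class ExampleAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.is_causal = False  # encoder-style, bidirectional attention

        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        # Generic projections + reshaping, shared by every attention flavor.
        batch_size, seq_len, _ = hidden_states.shape
        shape = (batch_size, seq_len, self.num_attention_heads, self.head_dim)
        query = self.q_proj(hidden_states).view(shape).transpose(1, 2)
        key = self.k_proj(hidden_states).view(shape).transpose(1, 2)
        value = self.v_proj(hidden_states).view(shape).transpose(1, 2)

        # The interface selects eager / sdpa / flash / flex based on the config.
        attention_interface = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self, query, key, value, attention_mask,
            scaling=self.scaling, is_causal=self.is_causal, **kwargs,
        )
        attn_output = attn_output.reshape(batch_size, seq_len, -1).contiguous()
        return self.o_proj(attn_output), attn_weights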

@Furkan-rgb Furkan-rgb force-pushed the add-sdpa-support-patchtst branch from 9233cd1 to 34deed0 on November 28, 2025 at 19:45
- Add _supports_sdpa = True and _supports_flash_attn = True to PatchTSTPreTrainedModel
- The existing PatchTSTAttention class already uses ALL_ATTENTION_FUNCTIONS
  to select the attention implementation based on config._attn_implementation
- Fix test_modeling_patchtst.py _prepare_for_class for dynamic batch sizes
@Furkan-rgb Furkan-rgb force-pushed the add-sdpa-support-patchtst branch from 34deed0 to 23fe9ff on November 28, 2025 at 19:48
@Furkan-rgb Furkan-rgb requested a review from vasqu on November 28, 2025 at 19:49
@Furkan-rgb
Author

Furkan-rgb commented Nov 28, 2025

Sorry, we no longer support implementing separate Attention classes for each attention flavor. There are some older models which haven't yet been refactored for this.

Please take a look at models like Albert or Bert (Bert is a bit messy because it has additional enc-dec logic), which already implement this. The essence is to have:

  • One Attention class
  • Specific attributes like is_causal, scaling, num_attention_heads to directly reuse across flavors
  • Generic code around the core attn mechanism, e.g. the projections and views/reshaping
  • The interface (ALL_ATTENTION_FUNCTIONS) that calls the underlying flavor for us

Thanks for the feedback. Took a look and updated accordingly. Much cleaner now. The existing PatchTSTAttention already uses ALL_ATTENTION_FUNCTIONS.
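
With the dispatch already in PatchTSTAttention, opting the model into the extra flavors mostly comes down to the class flags from the commit message above; a minimal sketch, assuming the usual transformers layout (everything except the two flags is illustrative):

from transformers import PatchTSTConfig
from transformers.modeling_utils import PreTrainedModel


class PatchTSTPreTrainedModel(PreTrainedModel):
    config_class = PatchTSTConfig
    base_model_prefix = "model"  # illustrative; see the actual modeling file
    # Flags checked by the attention dispatch so these implementations may be selected:
    _supports_sdpa = True
    _supports_flash_attn = True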

@kashif
Contributor

kashif commented Nov 28, 2025

@Furkan-rgb I fixed the SLOW tests and added explicit testing with sdpa
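
For reference, forcing a particular implementation in a test or script is just a from_pretrained kwarg; a short usage sketch (the checkpoint path is a placeholder, not the one used by the slow tests):

from transformers import PatchTSTModel

# Placeholder checkpoint path, for illustration only.
model = PatchTSTModel.from_pretrained("path/to/patchtst-checkpoint", attn_implementation="sdpa")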

@vasqu
Contributor

vasqu commented Dec 1, 2025

run-slow: patchtst

@github-actions
Contributor

github-actions bot commented Dec 1, 2025

This comment contains run-slow, running the specified jobs:

models: ["models/patchtst"]
quantizations: []

@github-actions
Contributor

github-actions bot commented Dec 1, 2025

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@vasqu
Contributor

vasqu commented Dec 1, 2025

run-slow: patchtst

@github-actions
Contributor

github-actions bot commented Dec 1, 2025

💔 This comment contains run-slow, but unknown error occurred and the workflow run aborted!

@github-actions
Contributor

github-actions bot commented Dec 1, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: patchtst

@vasqu
Contributor

vasqu commented Dec 1, 2025

run-slow: patchtst

@github-actions
Contributor

github-actions bot commented Dec 1, 2025

This comment contains run-slow, running the specified jobs:

models: ["models/patchtst"]
quantizations: []

@github-actions
Contributor

github-actions bot commented Dec 1, 2025

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

Contributor

@vasqu vasqu left a comment

Just a small comment about the deepspeed addition

I took the liberty of updating a bit more to add flex attn and enable some other tests. SDPA is enabled by default, so there's no need to set the attn implementation here. We should probably update the title though ~ something along the lines of "Update supported attns for PatchTST"

Comment on lines +578 to +586
position_enc = module._init_pe(self.config, num_patches)
if is_deepspeed_zero3_enabled():
    import deepspeed

    # Under ZeRO-3 the parameter is partitioned across ranks, so gather it
    # before writing, and only copy where it is actually materialized.
    with deepspeed.zero.GatheredParameters(module.position_enc, modifier_rank=None):
        if module.position_enc.numel() > 0:
            init.copy_(module.position_enc, position_enc)
else:
    init.copy_(module.position_enc, position_enc)
Contributor

Not exactly against this but why was this added to this PR? Can we move this to a separate PR? cc @kashif

Contributor

the slow tests were failing without this change...

Contributor

Gotcha, thanks for adding it then

I will probably merge this PR tomorrow then; we're cutting the last PRs for v5 right now
