Add SDPA support for PatchTST model #42465
Conversation
kashif left a comment:
looks good thanks
Seems like the documentation build is failing due to an improperly closed tag during the build. Unrelated to my changes.
Sorry, we no longer support implementing separate Attention classes for each attention flavor. There are some older models which haven't been refactored yet for this case.
Please take a look at models like Albert or Bert (a bit messy because it has additional enc-dec logic) which already implement this. The essence is to have the following (a rough sketch follows below):
- One Attention class
- Specific attributes like `is_causal`, `scaling`, `num_attention_heads` to directly reuse across flavors
- Generic code around the core attn mechanism, e.g. the projections and views/reshaping
- The interface (`ALL_ATTENTION_FUNCTIONS`) that calls the underlying flavor for us
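Not code from this PR; a minimal sketch of that pattern as it appears in recent Transformers models. The class name, config fields, and the eager fallback below are illustrative assumptions, while `ALL_ATTENTION_FUNCTIONS` and `config._attn_implementation` are the actual hooks referred to above:

```python
from typing import Callable, Optional

import torch
from torch import nn

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS  # registry of attention flavors


def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # Plain softmax attention, used when config._attn_implementation == "eager".
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output.transpose(1, 2).contiguous(), attn_weights


class SingleFlavorAttention(nn.Module):  # hypothetical name, not the PatchTST class
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Attributes reused across flavors.
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.is_causal = False  # encoder-style attention
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None, **kwargs):
        bsz, seq_len, _ = hidden_states.shape
        shape = (bsz, seq_len, self.num_attention_heads, self.head_dim)
        # Generic projections and reshaping shared by all flavors.
        query = self.q_proj(hidden_states).view(shape).transpose(1, 2)
        key = self.k_proj(hidden_states).view(shape).transpose(1, 2)
        value = self.v_proj(hidden_states).view(shape).transpose(1, 2)

        # Dispatch to the requested flavor (eager/sdpa/flash/flex) through the interface.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self, query, key, value, attention_mask, scaling=self.scaling, dropout=0.0, **kwargs
        )
        attn_output = attn_output.reshape(bsz, seq_len, -1)
        return self.o_proj(attn_output), attn_weights
```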
Force-pushed from 9233cd1 to 34deed0
- Add `_supports_sdpa = True` and `_supports_flash_attn = True` to `PatchTSTPreTrainedModel` (see the sketch below)
- The existing `PatchTSTAttention` class already uses `ALL_ATTENTION_FUNCTIONS` to select the attention implementation based on `config._attn_implementation`
- Fix `test_modeling_patchtst.py` `_prepare_for_class` for dynamic batch sizes
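A minimal sketch of what that flag change looks like; only the two flag names come from the commit message above, the rest of the class body is elided and assumed:

```python
from transformers import PreTrainedModel


class PatchTSTPreTrainedModel(PreTrainedModel):
    # Advertise the supported attention backends so users can request them
    # via attn_implementation="sdpa" / "flash_attention_2" when loading.
    _supports_sdpa = True
    _supports_flash_attn = True
    # ... rest of the class (config_class, _init_weights, etc.) unchanged ...
```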
Force-pushed from 34deed0 to 23fe9ff
Thanks for the feedback. Took a look and updated accordingly. Much cleaner now. The existing `PatchTSTAttention` already uses `ALL_ATTENTION_FUNCTIONS`.
@Furkan-rgb I fixed the SLOW tests and am explicitly testing with sdpa
run-slow: patchtst
This comment contains models: ["models/patchtst"]
CI Results: ✅ No failing test specific to this PR 🎉!
run-slow: patchtst
💔 This comment contains
[For maintainers] Suggested jobs to run (before merge): run-slow: patchtst
run-slow: patchtst
This comment contains models: ["models/patchtst"]
CI Results: ✅ No failing test specific to this PR 🎉!
vasqu left a comment:
Just a small comment about the deepspeed addition.
I took the liberty of updating some more to add flex attn and enable some other tests. SDPA is enabled by default, so there's no need to set the attn implementation here. We should probably update the title though, to something along the lines of "Update supported attns for PatchTST".
    position_enc = module._init_pe(self.config, num_patches)
    if is_deepspeed_zero3_enabled():
        import deepspeed

        with deepspeed.zero.GatheredParameters(module.position_enc, modifier_rank=None):
            if module.position_enc.numel() > 0:
                init.copy_(module.position_enc, position_enc)
    else:
        init.copy_(module.position_enc, position_enc)
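A likely reading of this snippet, inferred from the code rather than stated in the thread: under DeepSpeed ZeRO-3 the parameter is partitioned across ranks, so `deepspeed.zero.GatheredParameters` temporarily materializes `module.position_enc` before the in-place `init.copy_`, and the `numel() > 0` check appears to guard ranks whose local shard is empty.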
Not exactly against this but why was this added to this PR? Can we move this to a separate PR? cc @kashif
the slow tests were failing without this change...
Gotcha, thx for adding then
I will probably merge this PR tomorrow then; we're cutting the last PRs for v5 rn
What does this PR do?
Adds SDPA (Scaled Dot Product Attention) support for the PatchTST model.
Changes:
- Added a `PatchTSTSdpaAttention` class using `torch.nn.functional.scaled_dot_product_attention`
- The attention implementation is selected in `PatchTSTEncoderLayer` based on `config._attn_implementation`
- Added `_supports_sdpa = True` to `PatchTSTPreTrainedModel`
- Fixed the `test_modeling_patchtst.py` `_prepare_for_class` method for proper dynamic batch size handling

Testing:
All 67 PatchTST tests pass (103 skipped as expected):
Notes:
- SDPA cannot be used with `output_attentions=True`, since SDPA doesn't return attention weights
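For completeness, a hedged usage sketch, not from the PR: the checkpoint path is a placeholder, and per the review above SDPA is already the default, so `attn_implementation` only needs to be passed when selecting a different backend.

```python
# Illustrative only: selecting the attention backend when loading PatchTST.
import torch
from transformers import PatchTSTModel

model = PatchTSTModel.from_pretrained(
    "your-org/your-patchtst-checkpoint",  # placeholder, not a real checkpoint
    attn_implementation="sdpa",  # default; "eager" or "flash_attention_2" are alternatives
)

# Dummy input with shape (batch, context_length, num_input_channels).
past_values = torch.randn(2, model.config.context_length, model.config.num_input_channels)
with torch.no_grad():
    outputs = model(past_values=past_values)
print(outputs.last_hidden_state.shape)
```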