Commit 23fe9ff · 1 parent d08b98b
Add SDPA and Flash Attention support for PatchTST model
- Add _supports_sdpa = True and _supports_flash_attn = True to PatchTSTPreTrainedModel
- The existing PatchTSTAttention class already uses ALL_ATTENTION_FUNCTIONS to select the attention implementation based on config._attn_implementation (see the usage sketch after this list)
- Fix test_modeling_patchtst.py _prepare_for_class for dynamic batch sizes
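
A minimal usage sketch of what these flags enable, assuming the standard `transformers` loading path. Nothing below is code from this commit: the `AutoModel.from_config(..., attn_implementation=...)` route and the default `PatchTSTConfig` values are only illustrative.

```python
# Hedged sketch, not code from this commit: with _supports_sdpa / _supports_flash_attn set on
# PatchTSTPreTrainedModel, callers can request an attention backend when building the model,
# and attention is dispatched according to config._attn_implementation.
import torch
from transformers import AutoModel, PatchTSTConfig

config = PatchTSTConfig()  # default hyperparameters; a real run would load a pretrained checkpoint

# "sdpa" routes through torch.nn.functional.scaled_dot_product_attention; "eager" remains the
# fallback, and "flash_attention_2" additionally needs the flash-attn package and a supported GPU.
model = AutoModel.from_config(config, attn_implementation="sdpa")
model.eval()

# PatchTST takes past_values of shape (batch_size, sequence_length, num_input_channels).
past_values = torch.randn(2, config.context_length, config.num_input_channels)
with torch.no_grad():
    outputs = model(past_values=past_values)

print(model.config._attn_implementation)   # "sdpa"
print(outputs.last_hidden_state.shape)
```

Requesting `attn_implementation="flash_attention_2"` follows the same path but raises an error if the flash-attn package or a supported GPU is missing, so `"sdpa"` is the more portable backend to exercise.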
File tree (2 files changed: +11 -6 lines)
- src/transformers/models/patchtst
- tests/models/patchtst
[Diff bodies were not captured in this view. The first file has hunks around original lines 418-424 (1 line changed) and 555-559 (2 lines added); the second file has a hunk around original lines 184-206 (3 lines added, 5 lines changed).]