Commit 23fe9ff

Add SDPA and Flash Attention support for PatchTST model
- Add _supports_sdpa = True and _supports_flash_attn = True to PatchTSTPreTrainedModel
- The existing PatchTSTAttention class already uses ALL_ATTENTION_FUNCTIONS to select the attention implementation based on config._attn_implementation
- Fix test_modeling_patchtst.py _prepare_for_class for dynamic batch sizes
1 parent d08b98b commit 23fe9ff
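
For context on the second point: in recent transformers releases, attention modules pick their kernel from a registry keyed by config._attn_implementation. A minimal sketch of that dispatch pattern, assuming a current transformers version (the helper name is illustrative; this is not the exact PatchTST code):

```python
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

def pick_attention_interface(config, eager_attention_forward):
    # Default to the eager implementation; any other registered backend
    # ("sdpa", "flash_attention_2", ...) is looked up in the registry.
    attention_interface = eager_attention_forward
    if config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[config._attn_implementation]
    return attention_interface
```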

File tree

2 files changed (+11, -6)

src/transformers/models/patchtst/modeling_patchtst.py (3 additions, 1 deletion)

```diff
@@ -418,7 +418,7 @@ def __init__(self, config: PatchTSTConfig):
         super().__init__()

         self.channel_attention = config.channel_attention
-        # Multi-Head attention
+
         self.self_attn = PatchTSTAttention(
             embed_dim=config.d_model,
             num_heads=config.num_attention_heads,
@@ -555,6 +555,8 @@ class PatchTSTPreTrainedModel(PreTrainedModel):
     main_input_name = "past_values"
     input_modalities = ("time",)
     supports_gradient_checkpointing = False
+    _supports_flash_attn = True
+    _supports_sdpa = True

     @torch.no_grad()
     def _init_weights(self, module: nn.Module):
```
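
With the two flags in place, the PreTrainedModel loading path should stop rejecting SDPA and FlashAttention requests for PatchTST. A hedged usage sketch; the checkpoint id below is a placeholder, not a real repository:

```python
from transformers import PatchTSTModel

# Placeholder checkpoint id for illustration; substitute a real PatchTST checkpoint.
model = PatchTSTModel.from_pretrained(
    "your-org/your-patchtst-checkpoint",
    attn_implementation="sdpa",  # accepted now that _supports_sdpa = True
)
```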

tests/models/patchtst/test_modeling_patchtst.py (8 additions, 5 deletions)

```diff
@@ -184,20 +184,23 @@ def test_config(self):
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

+        # Get the actual batch size from the inputs (may differ from model_tester.batch_size in some tests)
+        batch_size = inputs_dict["past_values"].shape[0]
+
         # if PatchTSTForPretraining
         if model_class == PatchTSTForPretraining:
-            inputs_dict.pop("future_values")
+            inputs_dict.pop("future_values", None)
         # else if classification model:
         elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
             rng = random.Random(self.model_tester.seed)
-            labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_targets, rng=rng)
+            labels = ids_tensor([batch_size], self.model_tester.num_targets, rng=rng)
             inputs_dict["target_values"] = labels
-            inputs_dict.pop("future_values")
+            inputs_dict.pop("future_values", None)
         elif model_class in get_values(MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING):
             rng = random.Random(self.model_tester.seed)
-            target_values = floats_tensor([self.model_tester.batch_size, self.model_tester.num_targets], rng=rng)
+            target_values = floats_tensor([batch_size, self.model_tester.num_targets], rng=rng)
             inputs_dict["target_values"] = target_values
-            inputs_dict.pop("future_values")
+            inputs_dict.pop("future_values", None)
         return inputs_dict

     def test_save_load_strict(self):
```
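
Why the test fix reads the batch size off the tensor: some common-test mixins re-batch the prepared inputs (for example, slicing down to a single sample), so label tensors built from model_tester.batch_size can end up mismatched. The pop("future_values", None) change likewise tolerates input dicts that never contained that key. An illustrative sketch with made-up shapes:

```python
import torch

# Hypothetical shapes, for illustration only.
past_values = torch.randn(1, 512, 7)        # a test re-batched the inputs to batch size 1
batch_size = past_values.shape[0]           # read the real batch size off the tensor
target_values = torch.randn(batch_size, 2)  # labels now always line up with past_values
assert target_values.shape[0] == past_values.shape[0]
```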
