From 23fe9ff8507c15f81c72129188ce9158efa08c4b Mon Sep 17 00:00:00 2001
From: Furkan-rgb <50831308+Furkan-rgb@users.noreply.github.com>
Date: Thu, 27 Nov 2025 21:02:54 +0000
Subject: [PATCH 1/5] Add SDPA and Flash Attention support for PatchTST model

- Add _supports_sdpa = True and _supports_flash_attn = True to PatchTSTPreTrainedModel
- The existing PatchTSTAttention class already uses ALL_ATTENTION_FUNCTIONS to select the attention implementation based on config._attn_implementation
- Fix test_modeling_patchtst.py _prepare_for_class for dynamic batch sizes
---
 .../models/patchtst/modeling_patchtst.py        |  4 +++-
 tests/models/patchtst/test_modeling_patchtst.py | 13 ++++++++-----
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py
index d482efa5b832..6be83148539b 100755
--- a/src/transformers/models/patchtst/modeling_patchtst.py
+++ b/src/transformers/models/patchtst/modeling_patchtst.py
@@ -418,7 +418,7 @@ def __init__(self, config: PatchTSTConfig):
         super().__init__()
 
         self.channel_attention = config.channel_attention
-        # Multi-Head attention
+
         self.self_attn = PatchTSTAttention(
             embed_dim=config.d_model,
             num_heads=config.num_attention_heads,
@@ -555,6 +555,8 @@ class PatchTSTPreTrainedModel(PreTrainedModel):
     main_input_name = "past_values"
     input_modalities = ("time",)
     supports_gradient_checkpointing = False
+    _supports_flash_attn = True
+    _supports_sdpa = True
 
     @torch.no_grad()
     def _init_weights(self, module: nn.Module):
diff --git a/tests/models/patchtst/test_modeling_patchtst.py b/tests/models/patchtst/test_modeling_patchtst.py
index ba29095bf8ac..efbd4e83b98c 100644
--- a/tests/models/patchtst/test_modeling_patchtst.py
+++ b/tests/models/patchtst/test_modeling_patchtst.py
@@ -184,20 +184,23 @@ def test_config(self):
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
 
+        # Get the actual batch size from the inputs (may differ from model_tester.batch_size in some tests)
+        batch_size = inputs_dict["past_values"].shape[0]
+
         # if PatchTSTForPretraining
         if model_class == PatchTSTForPretraining:
-            inputs_dict.pop("future_values")
+            inputs_dict.pop("future_values", None)
         # else if classification model:
         elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
             rng = random.Random(self.model_tester.seed)
-            labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_targets, rng=rng)
+            labels = ids_tensor([batch_size], self.model_tester.num_targets, rng=rng)
             inputs_dict["target_values"] = labels
-            inputs_dict.pop("future_values")
+            inputs_dict.pop("future_values", None)
         elif model_class in get_values(MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING):
             rng = random.Random(self.model_tester.seed)
-            target_values = floats_tensor([self.model_tester.batch_size, self.model_tester.num_targets], rng=rng)
+            target_values = floats_tensor([batch_size, self.model_tester.num_targets], rng=rng)
             inputs_dict["target_values"] = target_values
-            inputs_dict.pop("future_values")
+            inputs_dict.pop("future_values", None)
         return inputs_dict
 
     def test_save_load_strict(self):

From 1869b172f48a2b68837b3da8ca6b4084ebcce63a Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Fri, 28 Nov 2025 21:33:56 +0100
Subject: [PATCH 2/5] Guard PatchTST positional init under ZeRO-3

---
 src/transformers/models/patchtst/modeling_patchtst.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py
index 6be83148539b..4d6912698f51 100755
--- a/src/transformers/models/patchtst/modeling_patchtst.py
+++ b/src/transformers/models/patchtst/modeling_patchtst.py
@@ -24,6 +24,7 @@
 
 from ... import initialization as init
 from ...activations import ACT2CLS
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -573,7 +574,15 @@ def _init_weights(self, module: nn.Module):
                 init.normal_(module.cls_token, std=0.02)
                 num_patches += 1
             # initialize positional encoding
-            init.copy_(module.position_enc, module._init_pe(self.config, num_patches))
+            position_enc = module._init_pe(self.config, num_patches)
+            if is_deepspeed_zero3_enabled():
+                import deepspeed
+
+                with deepspeed.zero.GatheredParameters(module.position_enc, modifier_rank=None):
+                    if module.position_enc.numel() > 0:
+                        init.copy_(module.position_enc, position_enc)
+            else:
+                init.copy_(module.position_enc, position_enc)
         elif isinstance(module, nn.LayerNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)

From d01e82d658ddfb1d801bd5f503e7ef3574ee6030 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Fri, 28 Nov 2025 21:50:44 +0100
Subject: [PATCH 3/5] Force SDPA in PatchTST regression integration test

---
 tests/models/patchtst/test_modeling_patchtst.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/models/patchtst/test_modeling_patchtst.py b/tests/models/patchtst/test_modeling_patchtst.py
index efbd4e83b98c..8b329c32c205 100644
--- a/tests/models/patchtst/test_modeling_patchtst.py
+++ b/tests/models/patchtst/test_modeling_patchtst.py
@@ -372,6 +372,7 @@ def test_prediction_generation(self):
 
     def test_regression_generation(self):
         model = PatchTSTForRegression.from_pretrained("ibm/patchtst-etth1-regression-distribution").to(torch_device)
+        model.config.attn_implementation = "sdpa"
         batch = prepare_batch(repo_id="ibm/patchtst-etth1-test-data", file="regression_distribution_batch.pt")
 
         torch.manual_seed(0)

From c53a0cf4bcb6aaa983d99dde056875ec4e1ff3f5 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Fri, 28 Nov 2025 21:59:39 +0100
Subject: [PATCH 4/5] Use sdpa attn in PatchTST regression test

---
 tests/models/patchtst/test_modeling_patchtst.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/models/patchtst/test_modeling_patchtst.py b/tests/models/patchtst/test_modeling_patchtst.py
index 8b329c32c205..5c1e5a9917f3 100644
--- a/tests/models/patchtst/test_modeling_patchtst.py
+++ b/tests/models/patchtst/test_modeling_patchtst.py
@@ -371,8 +371,9 @@ def test_prediction_generation(self):
         torch.testing.assert_close(mean_prediction[0, -1:], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
 
     def test_regression_generation(self):
-        model = PatchTSTForRegression.from_pretrained("ibm/patchtst-etth1-regression-distribution").to(torch_device)
-        model.config.attn_implementation = "sdpa"
+        model = PatchTSTForRegression.from_pretrained(
+            "ibm/patchtst-etth1-regression-distribution", attn_implementation="sdpa"
+        ).to(torch_device)
         batch = prepare_batch(repo_id="ibm/patchtst-etth1-test-data", file="regression_distribution_batch.pt")
 
         torch.manual_seed(0)

From 3c6e82e1629a177f54cd927fe12a28e522fd49a9 Mon Sep 17 00:00:00 2001
From: vasqu
Date: Mon, 1 Dec 2025 17:52:32 +0100
Subject: [PATCH 5/5] fixups re tests

---
 src/transformers/models/patchtst/modeling_patchtst.py | 1 +
 tests/models/patchtst/test_modeling_patchtst.py       | 9 ++++-----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py
index 4d6912698f51..c0bbdc748eea 100755
--- a/src/transformers/models/patchtst/modeling_patchtst.py
+++ b/src/transformers/models/patchtst/modeling_patchtst.py
@@ -558,6 +558,7 @@ class PatchTSTPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = False
     _supports_flash_attn = True
     _supports_sdpa = True
+    _supports_flex_attn = True
 
     @torch.no_grad()
     def _init_weights(self, module: nn.Module):
diff --git a/tests/models/patchtst/test_modeling_patchtst.py b/tests/models/patchtst/test_modeling_patchtst.py
index 5c1e5a9917f3..72ac0f8087d2 100644
--- a/tests/models/patchtst/test_modeling_patchtst.py
+++ b/tests/models/patchtst/test_modeling_patchtst.py
@@ -22,7 +22,7 @@
 
 from transformers import is_torch_available
 from transformers.models.auto import get_values
-from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
+from transformers.testing_utils import is_flaky, require_read_token, require_torch, slow, torch_device
 from transformers.utils import check_torch_load_is_safe
 
 from ...test_configuration_common import ConfigTester
@@ -332,7 +332,7 @@ def test_pretrain_head(self):
         )
         torch.testing.assert_close(output[0, :7, :1, :1], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
 
-    # Publishing of pretrained weights are under internal review. Pretrained model is not yet downloadable.
+    @require_read_token
     def test_prediction_head(self):
         model = PatchTSTForPrediction.from_pretrained("namctin/patchtst_etth1_forecast").to(torch_device)
         batch = prepare_batch(file="test-batch.pt")
@@ -352,6 +352,7 @@ def test_prediction_head(self):
         )
         torch.testing.assert_close(output[0, :1, :7], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
 
+    @require_read_token
     def test_prediction_generation(self):
         model = PatchTSTForPrediction.from_pretrained("namctin/patchtst_etth1_forecast").to(torch_device)
         batch = prepare_batch(file="test-batch.pt")
@@ -372,9 +372,7 @@ def test_prediction_generation(self):
         torch.testing.assert_close(mean_prediction[0, -1:], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
 
     def test_regression_generation(self):
-        model = PatchTSTForRegression.from_pretrained(
-            "ibm/patchtst-etth1-regression-distribution", attn_implementation="sdpa"
-        ).to(torch_device)
+        model = PatchTSTForRegression.from_pretrained("ibm/patchtst-etth1-regression-distribution").to(torch_device)
         batch = prepare_batch(repo_id="ibm/patchtst-etth1-test-data", file="regression_distribution_batch.pt")
 
         torch.manual_seed(0)
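
A minimal usage sketch of what this series enables (an illustration, not one of the patches above):
after patch 1 the attention backend for PatchTST can be selected at load time via
attn_implementation, exactly as patch 4 does in the regression test. "sdpa" should likewise be
swappable for "flash_attention_2" or, after patch 5, "flex_attention" (the identifiers the library
uses for these backends). The checkpoint name is the one already referenced in the integration
tests, and the input shapes are read from the model config rather than assumed.

    import torch
    from transformers import PatchTSTForRegression

    # Load the regression checkpoint used in the integration tests with the SDPA backend.
    model = PatchTSTForRegression.from_pretrained(
        "ibm/patchtst-etth1-regression-distribution", attn_implementation="sdpa"
    )
    model.eval()

    # Dummy input of shape (batch, context_length, num_input_channels), sizes taken from the config.
    past_values = torch.randn(1, model.config.context_length, model.config.num_input_channels)
    with torch.no_grad():
        outputs = model(past_values=past_values)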