diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py
index d482efa5b832..c0bbdc748eea 100755
--- a/src/transformers/models/patchtst/modeling_patchtst.py
+++ b/src/transformers/models/patchtst/modeling_patchtst.py
@@ -24,6 +24,7 @@
 from ... import initialization as init
 from ...activations import ACT2CLS
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -418,7 +419,7 @@ def __init__(self, config: PatchTSTConfig):
         super().__init__()
         self.channel_attention = config.channel_attention
 
-        # Multi-Head attention
+
         self.self_attn = PatchTSTAttention(
             embed_dim=config.d_model,
             num_heads=config.num_attention_heads,
@@ -555,6 +556,9 @@ class PatchTSTPreTrainedModel(PreTrainedModel):
     main_input_name = "past_values"
     input_modalities = ("time",)
     supports_gradient_checkpointing = False
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
 
     @torch.no_grad()
     def _init_weights(self, module: nn.Module):
@@ -571,7 +575,15 @@ def _init_weights(self, module: nn.Module):
                 init.normal_(module.cls_token, std=0.02)
                 num_patches += 1
             # initialize positional encoding
-            init.copy_(module.position_enc, module._init_pe(self.config, num_patches))
+            position_enc = module._init_pe(self.config, num_patches)
+            if is_deepspeed_zero3_enabled():
+                import deepspeed
+
+                with deepspeed.zero.GatheredParameters(module.position_enc, modifier_rank=None):
+                    if module.position_enc.numel() > 0:
+                        init.copy_(module.position_enc, position_enc)
+            else:
+                init.copy_(module.position_enc, position_enc)
         elif isinstance(module, nn.LayerNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)
diff --git a/tests/models/patchtst/test_modeling_patchtst.py b/tests/models/patchtst/test_modeling_patchtst.py
index ba29095bf8ac..72ac0f8087d2 100644
--- a/tests/models/patchtst/test_modeling_patchtst.py
+++ b/tests/models/patchtst/test_modeling_patchtst.py
@@ -22,7 +22,7 @@
 from transformers import is_torch_available
 from transformers.models.auto import get_values
-from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
+from transformers.testing_utils import is_flaky, require_read_token, require_torch, slow, torch_device
 from transformers.utils import check_torch_load_is_safe
 
 from ...test_configuration_common import ConfigTester
@@ -184,20 +184,23 @@ def test_config(self):
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
 
+        # Get the actual batch size from the inputs (may differ from model_tester.batch_size in some tests)
+        batch_size = inputs_dict["past_values"].shape[0]
+
         # if PatchTSTForPretraining
         if model_class == PatchTSTForPretraining:
-            inputs_dict.pop("future_values")
+            inputs_dict.pop("future_values", None)
         # else if classification model:
         elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
             rng = random.Random(self.model_tester.seed)
-            labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_targets, rng=rng)
+            labels = ids_tensor([batch_size], self.model_tester.num_targets, rng=rng)
             inputs_dict["target_values"] = labels
-            inputs_dict.pop("future_values")
+            inputs_dict.pop("future_values", None)
         elif model_class in get_values(MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING):
             rng = random.Random(self.model_tester.seed)
-            target_values = floats_tensor([self.model_tester.batch_size, self.model_tester.num_targets], rng=rng)
+            target_values = floats_tensor([batch_size, self.model_tester.num_targets], rng=rng)
             inputs_dict["target_values"] = target_values
-            inputs_dict.pop("future_values")
+            inputs_dict.pop("future_values", None)
 
         return inputs_dict
 
     def test_save_load_strict(self):
@@ -329,7 +332,7 @@ def test_pretrain_head(self):
         )
         torch.testing.assert_close(output[0, :7, :1, :1], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
 
-    # Publishing of pretrained weights are under internal review. Pretrained model is not yet downloadable.
+    @require_read_token
     def test_prediction_head(self):
         model = PatchTSTForPrediction.from_pretrained("namctin/patchtst_etth1_forecast").to(torch_device)
         batch = prepare_batch(file="test-batch.pt")
@@ -349,6 +352,7 @@ def test_prediction_head(self):
         )
         torch.testing.assert_close(output[0, :1, :7], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
 
+    @require_read_token
     def test_prediction_generation(self):
         model = PatchTSTForPrediction.from_pretrained("namctin/patchtst_etth1_forecast").to(torch_device)
         batch = prepare_batch(file="test-batch.pt")
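
Note for reviewers: the ZeRO-3 branch added to `_init_weights` above follows the gather-then-copy pattern, since under DeepSpeed ZeRO-3 each rank only holds a shard of `position_enc` and the full tensor must be gathered before an in-place copy. The snippet below is a minimal, self-contained sketch of that pattern; `init_position_enc` and `toy_module` are hypothetical names for illustration only, not part of the transformers API, and the plain `.data.copy_` stands in for the library's `init.copy_` helper.

    # Sketch of the ZeRO-3-aware init pattern used in the patch (illustrative only).
    import torch
    import torch.nn as nn

    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled


    def init_position_enc(module: nn.Module, position_enc: torch.Tensor) -> None:
        if is_deepspeed_zero3_enabled():
            import deepspeed

            # Gather the partitioned parameter so every rank sees the full tensor;
            # modifier_rank=None mirrors the diff above.
            with deepspeed.zero.GatheredParameters(module.position_enc, modifier_rank=None):
                # Skip ranks where the gathered view is empty.
                if module.position_enc.numel() > 0:
                    module.position_enc.data.copy_(position_enc)
        else:
            # Without ZeRO-3 the parameter is fully materialized; copy directly.
            module.position_enc.data.copy_(position_enc)


    # Hypothetical stand-in for PatchTSTPositionalEncoding.
    toy_module = nn.Module()
    toy_module.position_enc = nn.Parameter(torch.zeros(16, 8))
    init_position_enc(toy_module, torch.randn(16, 8))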