16 changes: 14 additions & 2 deletions src/transformers/models/patchtst/modeling_patchtst.py
@@ -24,6 +24,7 @@

from ... import initialization as init
from ...activations import ACT2CLS
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -418,7 +419,7 @@ def __init__(self, config: PatchTSTConfig):
super().__init__()

self.channel_attention = config.channel_attention
# Multi-Head attention

self.self_attn = PatchTSTAttention(
embed_dim=config.d_model,
num_heads=config.num_attention_heads,
@@ -555,6 +556,9 @@ class PatchTSTPreTrainedModel(PreTrainedModel):
main_input_name = "past_values"
input_modalities = ("time",)
supports_gradient_checkpointing = False
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True

@torch.no_grad()
def _init_weights(self, module: nn.Module):
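
The new `_supports_flash_attn` / `_supports_sdpa` / `_supports_flex_attn` flags opt PatchTST into the shared attention-backend dispatch, so a specific implementation can be requested explicitly. A minimal sketch, assuming a small illustrative config (the sizes below are placeholders, not values from this PR):

import torch

from transformers import PatchTSTConfig, PatchTSTModel

# Illustrative config; values are arbitrary and not tied to any released checkpoint.
config = PatchTSTConfig(
    num_input_channels=7,
    context_length=512,
    patch_length=16,
    patch_stride=16,
    d_model=64,
    num_attention_heads=4,
    attn_implementation="sdpa",  # permitted now that _supports_sdpa is set
)
model = PatchTSTModel(config).eval()

# (batch, sequence_length, num_input_channels)
past_values = torch.randn(2, config.context_length, config.num_input_channels)
with torch.no_grad():
    last_hidden_state = model(past_values=past_values).last_hidden_state

The same `attn_implementation` argument can also be passed to `from_pretrained` when loading a checkpoint.
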
@@ -571,7 +575,15 @@ def _init_weights(self, module: nn.Module):
init.normal_(module.cls_token, std=0.02)
num_patches += 1
# initialize positional encoding
init.copy_(module.position_enc, module._init_pe(self.config, num_patches))
position_enc = module._init_pe(self.config, num_patches)
if is_deepspeed_zero3_enabled():
import deepspeed

with deepspeed.zero.GatheredParameters(module.position_enc, modifier_rank=None):
if module.position_enc.numel() > 0:
init.copy_(module.position_enc, position_enc)
else:
init.copy_(module.position_enc, position_enc)
Comment on lines +578 to +586

Contributor:
Not exactly against this, but why was it added to this PR? Can we move it to a separate PR? cc @kashif

Contributor:
The slow tests were failing without this change...

Contributor:
Gotcha, thanks for adding it then.

I will probably merge this PR tomorrow then; we're cutting the last PRs for v5 right now.

elif isinstance(module, nn.LayerNorm):
init.zeros_(module.bias)
init.ones_(module.weight)
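
For context on the ZeRO-3 branch in `_init_weights` above: under DeepSpeed ZeRO-3, parameters are sharded across ranks, so writing a freshly computed tensor into one means gathering it inside `deepspeed.zero.GatheredParameters` and guarding the copy on ranks that hold no data. A standalone sketch of the same pattern, wrapped in a hypothetical helper (`_copy_into_param` is not a transformers API):

import torch

from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled


def _copy_into_param(param: torch.nn.Parameter, new_value: torch.Tensor) -> None:
    # Hypothetical helper mirroring the pattern added to _init_weights above.
    if is_deepspeed_zero3_enabled():
        import deepspeed

        # Materialize the sharded parameter inside the context manager.
        with deepspeed.zero.GatheredParameters(param, modifier_rank=None):
            # Skip the write on ranks whose local tensor is an empty placeholder,
            # mirroring the numel() guard used in the diff.
            if param.numel() > 0:
                with torch.no_grad():
                    param.copy_(new_value)
    else:
        with torch.no_grad():
            param.copy_(new_value)
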
18 changes: 11 additions & 7 deletions tests/models/patchtst/test_modeling_patchtst.py
@@ -22,7 +22,7 @@

from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from transformers.testing_utils import is_flaky, require_read_token, require_torch, slow, torch_device
from transformers.utils import check_torch_load_is_safe

from ...test_configuration_common import ConfigTester
@@ -184,20 +184,23 @@ def test_config(self):
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

# Get the actual batch size from the inputs (may differ from model_tester.batch_size in some tests)
batch_size = inputs_dict["past_values"].shape[0]

# if PatchTSTForPretraining
if model_class == PatchTSTForPretraining:
inputs_dict.pop("future_values")
inputs_dict.pop("future_values", None)
# else if classification model:
elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
rng = random.Random(self.model_tester.seed)
labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_targets, rng=rng)
labels = ids_tensor([batch_size], self.model_tester.num_targets, rng=rng)
inputs_dict["target_values"] = labels
inputs_dict.pop("future_values")
inputs_dict.pop("future_values", None)
elif model_class in get_values(MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING):
rng = random.Random(self.model_tester.seed)
target_values = floats_tensor([self.model_tester.batch_size, self.model_tester.num_targets], rng=rng)
target_values = floats_tensor([batch_size, self.model_tester.num_targets], rng=rng)
inputs_dict["target_values"] = target_values
inputs_dict.pop("future_values")
inputs_dict.pop("future_values", None)
return inputs_dict

def test_save_load_strict(self):
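
The switch to reading `batch_size` from `past_values` (instead of trusting `model_tester.batch_size`) matters because some common tests resize the prepared inputs before `_prepare_for_class` builds labels, so labels sized from the stale configured value would no longer line up. A minimal illustration of the mismatch this avoids (the tensors below are stand-ins, not the real test fixtures):

import torch

# Stand-in inputs: a test may have trimmed the batch from the configured 13 down to 2.
past_values = torch.randn(2, 32, 7)      # (batch, context_length, num_input_channels)
configured_batch_size = 13               # what model_tester.batch_size would report
actual_batch_size = past_values.shape[0]

# Labels must line up with the tensors actually fed to the model.
target_values = torch.randint(0, 3, (actual_batch_size,))
assert target_values.shape[0] == past_values.shape[0]
# Building them with configured_batch_size would yield 13 labels for a batch of 2
# and trip a shape mismatch inside the classification loss.
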
@@ -329,7 +332,7 @@ def test_pretrain_head(self):
)
torch.testing.assert_close(output[0, :7, :1, :1], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)

# Publishing of pretrained weights are under internal review. Pretrained model is not yet downloadable.
@require_read_token
def test_prediction_head(self):
model = PatchTSTForPrediction.from_pretrained("namctin/patchtst_etth1_forecast").to(torch_device)
batch = prepare_batch(file="test-batch.pt")
@@ -349,6 +352,7 @@ def test_prediction_head(self):
)
torch.testing.assert_close(output[0, :1, :7], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)

@require_read_token
def test_prediction_generation(self):
model = PatchTSTForPrediction.from_pretrained("namctin/patchtst_etth1_forecast").to(torch_device)
batch = prepare_batch(file="test-batch.pt")
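
`@require_read_token` marks these integration tests as needing an authenticated Hugging Face token, since the `namctin/patchtst_etth1_forecast` weights are not openly downloadable (per the comment the decorator replaces). A hedged sketch of loading the same checkpoint outside the test suite, with a placeholder token:

import torch

from transformers import PatchTSTForPrediction

# Requires access to the gated repo: either `huggingface-cli login` beforehand
# or an explicit read token (placeholder shown, not a real credential).
model = PatchTSTForPrediction.from_pretrained(
    "namctin/patchtst_etth1_forecast",
    token="hf_xxx",
).eval()

# Dummy input shaped like ETTh1 (7 channels); values are random, for illustration only.
past_values = torch.randn(1, model.config.context_length, model.config.num_input_channels)
with torch.no_grad():
    prediction = model(past_values=past_values).prediction_outputs
print(prediction.shape)  # expected: (batch, prediction_length, num_input_channels)
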