From ab33b972e6886d545fb82652640b8ad037e3be2b Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 16 Feb 2025 13:32:24 +0530 Subject: [PATCH 1/7] initial commit --- pytorch_forecasting/data/data_modules.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 pytorch_forecasting/data/data_modules.py diff --git a/pytorch_forecasting/data/data_modules.py b/pytorch_forecasting/data/data_modules.py new file mode 100644 index 000000000..e69de29bb From dd8b6a067f56bb22020eb67983a4fa6f0c861de9 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 16 Feb 2025 13:32:55 +0530 Subject: [PATCH 2/7] adding the timeseries and data module --- pytorch_forecasting/data/data_modules.py | 370 +++++++++++++++++++++++ pytorch_forecasting/data/timeseries.py | 210 +++++++++++++ 2 files changed, 580 insertions(+) diff --git a/pytorch_forecasting/data/data_modules.py b/pytorch_forecasting/data/data_modules.py index e69de29bb..767efd2c7 100644 --- a/pytorch_forecasting/data/data_modules.py +++ b/pytorch_forecasting/data/data_modules.py @@ -0,0 +1,370 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +from lightning.pytorch import LightningDataModule +from sklearn.preprocessing import RobustScaler, StandardScaler +import torch +from torch.utils.data import DataLoader, Dataset + +from pytorch_forecasting.data.encoders import ( + EncoderNormalizer, + NaNLabelEncoder, + TorchNormalizer, +) +from pytorch_forecasting.data.timeseries import TimeSeries, _coerce_to_dict + +NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] + + +class EncoderDecoderTimeSeriesDataModule(LightningDataModule): + """ + Lightning DataModule for processing time series data in an encoder-decoder format. + + This module handles preprocessing, splitting, and batching of time series data + for use in deep learning models. It supports categorical and continuous features, + various scalers, and automatic target normalization. + + Parameters + ---------- + time_series_dataset : TimeSeries + The dataset containing time series data. + max_encoder_length : int, default=30 + Maximum length of the encoder input sequence. + min_encoder_length : Optional[int], default=None + Minimum length of the encoder input sequence. + Defaults to `max_encoder_length` if not specified. + max_prediction_length : int, default=1 + Maximum length of the decoder output sequence. + min_prediction_length : Optional[int], default=None + Minimum length of the decoder output sequence. + Defaults to `max_prediction_length` if not specified. + min_prediction_idx : Optional[int], default=None + Minimum index from which predictions start. + allow_missing_timesteps : bool, default=False + Whether to allow missing timesteps in the dataset. + add_relative_time_idx : bool, default=False + Whether to add a relative time index feature. + add_target_scales : bool, default=False + Whether to add target scaling information. + add_encoder_length : Union[bool, str], default="auto" + Whether to include encoder length information. + target_normalizer : + Union[NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None], + default="auto" + Normalizer for the target variable. If "auto", uses `RobustScaler`. + + categorical_encoders : Optional[Dict[str, NaNLabelEncoder]], default=None + Dictionary of categorical encoders. + + scalers : + Optional[Dict[str, Union[StandardScaler, RobustScaler, + TorchNormalizer, EncoderNormalizer]]], default=None + Dictionary of feature scalers. + + randomize_length : Union[None, Tuple[float, float], bool], default=False + Whether to randomize input sequence length. + predict_mode : bool, default=False + Whether the module is in prediction mode. + batch_size : int, default=32 + Batch size for DataLoader. + num_workers : int, default=0 + Number of workers for DataLoader. + train_val_test_split : tuple, default=(0.7, 0.15, 0.15) + Proportions for train, validation, and test dataset splits. + """ + + def __init__( + self, + time_series_dataset: TimeSeries, + max_encoder_length: int = 30, + min_encoder_length: Optional[int] = None, + max_prediction_length: int = 1, + min_prediction_length: Optional[int] = None, + min_prediction_idx: Optional[int] = None, + allow_missing_timesteps: bool = False, + add_relative_time_idx: bool = False, + add_target_scales: bool = False, + add_encoder_length: Union[bool, str] = "auto", + target_normalizer: Union[ + NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None + ] = "auto", + categorical_encoders: Optional[Dict[str, NaNLabelEncoder]] = None, + scalers: Optional[ + Dict[ + str, + Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer], + ] + ] = None, + randomize_length: Union[None, Tuple[float, float], bool] = False, + batch_size: int = 32, + num_workers: int = 0, + train_val_test_split: tuple = (0.7, 0.15, 0.15), + ): + super().__init__() + self.time_series_dataset = time_series_dataset + self.metadata = time_series_dataset.get_metadata() + + self.max_encoder_length = max_encoder_length + self.min_encoder_length = min_encoder_length or max_encoder_length + self.max_prediction_length = max_prediction_length + self.min_prediction_length = min_prediction_length or max_prediction_length + self.min_prediction_idx = min_prediction_idx + + self.allow_missing_timesteps = allow_missing_timesteps + self.add_relative_time_idx = add_relative_time_idx + self.add_target_scales = add_target_scales + self.add_encoder_length = add_encoder_length + self.randomize_length = randomize_length + + self.batch_size = batch_size + self.num_workers = num_workers + self.train_val_test_split = train_val_test_split + + if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto": + self.target_normalizer = RobustScaler() + else: + self.target_normalizer = target_normalizer + + self.categorical_encoders = _coerce_to_dict(categorical_encoders) + self.scalers = _coerce_to_dict(scalers) + + self.categorical_indices = [] + self.continuous_indices = [] + + for idx, col in enumerate(self.metadata["cols"]["x"]): + if self.metadata["col_type"].get(col) == "C": + self.categorical_indices.append(idx) + else: + self.continuous_indices.append(idx) + + def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: + processed_data = [] + + for idx in indices: + sample = self.time_series_dataset[idx.item()] + + target = sample["y"] + # if torch.isnan(target).any(): + # (f"Warning: NaNs detected. Sample index: {idx}, Value: {target}") + + if isinstance(target, torch.Tensor): + target = target.float() + else: + target = torch.tensor(target, dtype=torch.float32) + + features = sample["x"] + if isinstance(features, torch.Tensor): + features = features.float() + else: + features = torch.tensor(features, dtype=torch.float32) + + categorical = ( + features[:, self.categorical_indices] + if self.categorical_indices + else torch.zeros((features.shape[0], 0)) + ) + continuous = ( + features[:, self.continuous_indices] + if self.continuous_indices + else torch.zeros((features.shape[0], 0)) + ) + + processed_data.append( + { + "features": {"categorical": categorical, "continuous": continuous}, + "target": target, + "static": sample.get("st", None), + "group": sample.get("group", torch.tensor([0])), + "length": len(target), + } + ) + + return processed_data + + class _ProcessedEncoderDecoderDataset(Dataset): + """PyTorch Dataset for processed encoder-decoder time series data. + + Parameters + ---------- + processed_data : List[Dict[str, Any]] + List of preprocessed time series samples. + windows : List[Tuple[int, int, int, int]] + List of window tuples containing + (series_idx, start_idx, enc_length, pred_length). + add_relative_time_idx : bool, default=False + Whether to include relative time indices. + """ + + def __init__( + self, + processed_data: List[Dict[str, Any]], + windows: List[Tuple[int, int, int, int]], + add_relative_time_idx: bool = False, + ): + self.processed_data = processed_data + self.windows = windows + self.add_relative_time_idx = add_relative_time_idx + + def __len__(self): + return len(self.windows) + + def __getitem__(self, idx): + series_idx, start_idx, enc_length, pred_length = self.windows[idx] + data = self.processed_data[series_idx] + # if start_idx + enc_length + pred_length > len(data['target']): + # print(f"start_idx: {start_idx}, enc_length: {enc_length}, + # pred_length: {pred_length}, target length: {len(data['target'])}") + + end_idx = start_idx + enc_length + pred_length + encoder_indices = slice(start_idx, start_idx + enc_length) + decoder_indices = slice(start_idx + enc_length, end_idx) + + target_scale = data["target"][encoder_indices].abs().mean() + if target_scale == 0: + target_scale = torch.tensor(1.0) + + x = { + "encoder_cat": data["features"]["categorical"][encoder_indices], + "encoder_cont": data["features"]["continuous"][encoder_indices], + "decoder_cat": data["features"]["categorical"][decoder_indices], + "decoder_cont": data["features"]["continuous"][decoder_indices], + "encoder_lengths": torch.tensor(enc_length), + "decoder_lengths": torch.tensor(pred_length), + "decoder_target_lengths": torch.tensor(pred_length), + "groups": data["group"], + "encoder_time_idx": torch.arange(enc_length), + "decoder_time_idx": torch.arange(enc_length, enc_length + pred_length), + "target_scale": target_scale, + } + + if data["static"] is not None: + x["static_categorical_features"] = data["static"].unsqueeze(0) + x["static_continuous_features"] = torch.zeros((1, 0)) + + y = data["target"][decoder_indices] + if y.ndim == 1: + y = y.unsqueeze(-1) + + return x, y + + def _create_windows( + self, processed_data: List[Dict[str, Any]] + ) -> List[Tuple[int, int, int, int]]: + windows = [] + + for idx, data in enumerate(processed_data): + sequence_length = data["length"] + + if sequence_length < self.max_encoder_length + self.max_prediction_length: + continue + + effective_min_prediction_idx = ( + self.min_prediction_idx + if self.min_prediction_idx is not None + else self.max_encoder_length + ) + + max_prediction_idx = sequence_length - self.max_prediction_length + 1 + + if max_prediction_idx <= effective_min_prediction_idx: + continue + + for start_idx in range( + 0, max_prediction_idx - effective_min_prediction_idx + ): + if ( + start_idx + self.max_encoder_length + self.max_prediction_length + <= sequence_length + ): + windows.append( + ( + idx, + start_idx, + self.max_encoder_length, + self.max_prediction_length, + ) + ) + + return windows + + def setup(self, stage: Optional[str] = None): + total_series = len(self.time_series_dataset) + self._split_indices = torch.randperm(total_series) + + self._train_size = int(self.train_val_test_split[0] * total_series) + self._val_size = int(self.train_val_test_split[1] * total_series) + + self._train_indices = self._split_indices[: self._train_size] + self._val_indices = self._split_indices[ + self._train_size : self._train_size + self._val_size + ] + self._test_indices = self._split_indices[self._train_size + self._val_size :] + + if stage is None or stage == "fit": + if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"): + self.train_processed = self._preprocess_data(self._train_indices) + self.val_processed = self._preprocess_data(self._val_indices) + + self.train_windows = self._create_windows(self.train_processed) + self.val_windows = self._create_windows(self.val_processed) + + self.train_dataset = self._ProcessedEncoderDecoderDataset( + self.train_processed, self.train_windows, self.add_relative_time_idx + ) + self.val_dataset = self._ProcessedEncoderDecoderDataset( + self.val_processed, self.val_windows, self.add_relative_time_idx + ) + + if stage is None or stage == "test": + if not hasattr(self, "test_dataset"): + self.test_processed = self._preprocess_data(self._test_indices) + self.test_windows = self._create_windows(self.test_processed) + + self.test_dataset = self._ProcessedEncoderDecoderDataset( + self.test_processed, self.test_windows, self.add_relative_time_idx + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + collate_fn=self.collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + @staticmethod + def collate_fn(batch): + x_batch = { + "encoder_cat": torch.stack([x["encoder_cat"] for x, _ in batch]), + "encoder_cont": torch.stack([x["encoder_cont"] for x, _ in batch]), + "decoder_cat": torch.stack([x["decoder_cat"] for x, _ in batch]), + "decoder_cont": torch.stack([x["decoder_cont"] for x, _ in batch]), + "encoder_lengths": torch.stack([x["encoder_lengths"] for x, _ in batch]), + "decoder_lengths": torch.stack([x["decoder_lengths"] for x, _ in batch]), + "decoder_target_lengths": torch.stack( + [x["decoder_target_lengths"] for x, _ in batch] + ), + "groups": torch.stack([x["groups"] for x, _ in batch]), + "encoder_time_idx": torch.stack([x["encoder_time_idx"] for x, _ in batch]), + "decoder_time_idx": torch.stack([x["decoder_time_idx"] for x, _ in batch]), + "target_scale": torch.stack([x["target_scale"] for x, _ in batch]), + } + + y_batch = torch.stack([y for _, y in batch]) + return x_batch, y_batch diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 336eecd5f..8037be9fc 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2668,3 +2668,213 @@ def _coerce_to_dict(obj): if obj is None: return {} return deepcopy(obj) + + +class TimeSeries(Dataset): + """PyTorch Dataset for time series data stored in pandas DataFrame. + + ``__getitem__`` returns: + + * ``t``: tensor of shape (n_timepoints) + Time index for each time point in the past or present. Aligned with ``y``, + and ``x`` not ending in ``f``. + * ``y``: tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with ``t``. + * ``x``: tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with ``t``. + * ``group``: tensor of shape (n_groups) + Group identifiers for time series instances. + * ``st``: tensor of shape (n_static_features) + Static features. + + Optionally, the following str-keyed entries can be included: + + * ``t_f``: tensor of shape (n_timepoints_future) + Time index for each time point in the future. + Aligned with ``x_f``. + * ``x_f``: tensor of shape (n_timepoints_future, n_features) + Known features for each time point in the future. + Rows are time points, aligned with ``t_f``. + * ``weights``: tensor of shape (n_timepoints), only if weight is not None + * ``weight_f``: tensor of shape (n_timepoints_future), only if weight is + not None. + + ----------------------------------------------------------------------------------- + + ``get_metadata`` returns metadata: + + * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] } + Names of columns for y, x, and static features. + List elements are in same order as column dimensions. + Columns not appearing are assumed to be named (x0, x1, etc.), + (y0, y1, etc.), (st0, st1, etc.). + * ``col_type``: dict[str, str] + maps column names to data types "F" (numerical) and "C" (categorical). + Column names not occurring are assumed "F". + * ``col_known``: dict[str, str] + maps column names to "K" (future known) or "U" (future unknown). + Column names not occurring are assumed "K". + + Parameters + ---------- + data : pd.DataFrame + data frame with sequence data. + Column names must all be str, and contain str as referred to below. + data_future : pd.DataFrame, optional, default=None + data frame with future data. + Column names must all be str, and contain str as referred to below. + May contain only columns that are in time, group, weight, known, or static. + time : str, optional, default = first col not in group_ids, weight, target, static. + integer typed column denoting the time index within ``data``. + This column is used to determine the sequence of samples. + If there are no missing observations, + the time index should increase by ``+1`` for each subsequent sample. + The first time_idx for each series does not necessarily + have to be ``0`` but any value is allowed. + target : str or List[str], optional, default = last column (at iloc -1) + column(s) in ``data`` denoting the forecasting target. + Can be categorical or numerical dtype. + group : List[str], optional, default = None + list of column names identifying a time series instance within ``data``. + This means that the ``group`` together uniquely identify an instance, + and ``group`` together with ``time`` uniquely identify a single observation + within a time series instance. + If ``None``, the dataset is assumed to be a single time series. + weight : str, optional, default=None + column name for weights. + If ``None``, it is assumed that there is no weight column. + num : list of str, optional, default = all columns with dtype in "fi" + list of numerical variables in ``data``, + list may also contain list of str, which are then grouped together. + cat : list of str, optional, default = all columns with dtype in "Obc" + list of categorical variables in ``data``, + list may also contain list of str, which are then grouped together + (e.g. useful for product categories). + known : list of str, optional, default = all variables + list of variables that change over time and are known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for special days or promotion categories). + unknown : list of str, optional, default = no variables + list of variables that are not known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for weather categories). + static : list of str, optional, default = all variables not in known, unknown + list of variables that do not change over time, + list may also contain list of str, which are then grouped together. + """ + + def __init__( + self, + data: pd.DataFrame, + data_future: Optional[pd.DataFrame] = None, + time: Optional[str] = None, + target: Optional[Union[str, List[str]]] = None, + group: Optional[List[str]] = None, + weight: Optional[str] = None, + num: Optional[List[Union[str, List[str]]]] = None, + cat: Optional[List[Union[str, List[str]]]] = None, + known: Optional[List[Union[str, List[str]]]] = None, + unknown: Optional[List[Union[str, List[str]]]] = None, + static: Optional[List[Union[str, List[str]]]] = None, + ): + + self.data = data + self.data_future = data_future + self.time = time + self.target = _coerce_to_list(target) + self.group = _coerce_to_list(group) + self.weight = weight + self.num = _coerce_to_list(num) + self.cat = _coerce_to_list(cat) + self.known = _coerce_to_list(known) + self.unknown = _coerce_to_list(unknown) + self.static = _coerce_to_list(static) + + self.feature_cols = [ + col + for col in data.columns + if col not in [self.time] + self.group + [self.weight] + self.target + ] + if self.group: + self._groups = self.data.groupby(self.group).groups + self._group_ids = list(self._groups.keys()) + else: + self._groups = {"_single_group": self.data.index} + self._group_ids = ["_single_group"] + + self._prepare_metadata() + + def _prepare_metadata(self): + """Prepare metadata for the dataset.""" + self.metadata = { + "cols": { + "y": self.target, + "x": self.feature_cols, + "st": self.static, + }, + "col_type": {}, + "col_known": {}, + } + + all_cols = self.target + self.feature_cols + self.static + for col in all_cols: + self.metadata["col_type"][col] = "C" if col in self.cat else "F" + + self.metadata["col_known"][col] = "K" if col in self.known else "U" + + def __len__(self) -> int: + """Return number of time series in the dataset.""" + return len(self._group_ids) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get time series data for given index.""" + group_id = self._group_ids[index] + + if self.group: + mask = self._groups[group_id] + data = self.data.loc[mask] + else: + data = self.data + + result = { + "t": data[self.time].values, + "y": torch.tensor(data[self.target].values), + "x": torch.tensor(data[self.feature_cols].values), + "group": torch.tensor([hash(str(group_id))]), + "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), + } + + if self.data_future is not None: + if self.group: + future_mask = self.data_future.groupby(self.group).groups[group_id] + future_data = self.data_future.loc[future_mask] + else: + future_data = self.data_future + + result.update( + { + "t_f": torch.tensor(future_data[self.time].values), + "x_f": torch.tensor(future_data[self.known].values), + } + ) + + if self.weight: + result["weight_f"] = torch.tensor(future_data[self.weight].values) + + if self.weight: + result["weights"] = torch.tensor(data[self.weight].values) + + return result + + def get_metadata(self) -> Dict: + """Return metadata about the dataset. + + Returns + ------- + Dict + Dictionary containing: + - cols: column names for y, x, and static features + - col_type: mapping of columns to their types (F/C) + - col_known: mapping of columns to their future known status (K/U) + """ + return self.metadata From c4dd9cfc6ab6083de7068bbac1de02123743934a Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 16 Feb 2025 18:10:32 +0530 Subject: [PATCH 3/7] adding predict to setup --- pytorch_forecasting/data/data_modules.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/pytorch_forecasting/data/data_modules.py b/pytorch_forecasting/data/data_modules.py index 767efd2c7..52b20eacd 100644 --- a/pytorch_forecasting/data/data_modules.py +++ b/pytorch_forecasting/data/data_modules.py @@ -62,8 +62,6 @@ class EncoderDecoderTimeSeriesDataModule(LightningDataModule): randomize_length : Union[None, Tuple[float, float], bool], default=False Whether to randomize input sequence length. - predict_mode : bool, default=False - Whether the module is in prediction mode. batch_size : int, default=32 Batch size for DataLoader. num_workers : int, default=0 @@ -314,7 +312,7 @@ def setup(self, stage: Optional[str] = None): self.val_processed, self.val_windows, self.add_relative_time_idx ) - if stage is None or stage == "test": + elif stage is None or stage == "test": if not hasattr(self, "test_dataset"): self.test_processed = self._preprocess_data(self._test_indices) self.test_windows = self._create_windows(self.test_processed) @@ -322,6 +320,13 @@ def setup(self, stage: Optional[str] = None): self.test_dataset = self._ProcessedEncoderDecoderDataset( self.test_processed, self.test_windows, self.add_relative_time_idx ) + elif stage == "predict": + predict_indices = torch.arange(len(self.time_series_dataset)) + self.predict_processed = self._preprocess_data(predict_indices) + self.predict_windows = self._create_windows(self.predict_processed) + self.predict_dataset = self._ProcessedEncoderDecoderDataset( + self.predict_processed, self.predict_windows, self.add_relative_time_idx + ) def train_dataloader(self): return DataLoader( @@ -348,6 +353,14 @@ def test_dataloader(self): collate_fn=self.collate_fn, ) + def predict_dataloader(self): + return DataLoader( + self.predict_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + @staticmethod def collate_fn(batch): x_batch = { From 54d7828d33a53801b4e8e3a2c57b1b7793c9ee32 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 19 Feb 2025 19:28:20 +0530 Subject: [PATCH 4/7] Adding tests and some debugging --- pytorch_forecasting/data/data_modules.py | 8 + pytorch_forecasting/data/timeseries.py | 2 + tests/test_data/test_data_module.py | 432 +++++++++++++++++++++++ 3 files changed, 442 insertions(+) create mode 100644 tests/test_data/test_data_module.py diff --git a/pytorch_forecasting/data/data_modules.py b/pytorch_forecasting/data/data_modules.py index 52b20eacd..f2c5752de 100644 --- a/pytorch_forecasting/data/data_modules.py +++ b/pytorch_forecasting/data/data_modules.py @@ -379,5 +379,13 @@ def collate_fn(batch): "target_scale": torch.stack([x["target_scale"] for x, _ in batch]), } + if "static_categorical_features" in batch[0][0]: + x_batch["static_categorical_features"] = torch.stack( + [x["static_categorical_features"] for x, _ in batch] + ) + x_batch["static_continuous_features"] = torch.stack( + [x["static_continuous_features"] for x, _ in batch] + ) + y_batch = torch.stack([y for _, y in batch]) return x_batch, y_batch diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 8037be9fc..a08dc3721 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2657,6 +2657,8 @@ def _coerce_to_list(obj): """ if obj is None: return [] + if isinstance(obj, str): + return [obj] return list(obj) diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py new file mode 100644 index 000000000..5d85e55e5 --- /dev/null +++ b/tests/test_data/test_data_module.py @@ -0,0 +1,432 @@ +import numpy as np +import pandas as pd +import pytest +from torch.utils.data import DataLoader + +from pytorch_forecasting.data.data_modules import EncoderDecoderTimeSeriesDataModule +from pytorch_forecasting.data.timeseries import TimeSeries + + +@pytest.fixture +def sample_timeseries_data(): + """Generate a sample time series dataset for testing.""" + dates = pd.date_range(start="2023-01-01", periods=50, freq="D") + n_series = 100 + + data = [] + for i in range(n_series): + group_id = i + static_feat = i % 2 + + series = pd.DataFrame( + { + "time": (dates - dates[0]).days, + "group_id": group_id, + "category_1": np.random.randint(0, 3, len(dates), dtype=np.int32), + "category_2": np.random.randint(0, 5, len(dates), dtype=np.int32), + "value_1": np.random.randn(len(dates)).astype(np.float32), + "value_2": np.random.randn(len(dates)).astype(np.float32), + "known_future_1": np.random.randn(len(dates)).astype(np.float32), + "known_future_2": np.random.randint(0, 3, len(dates), dtype=np.int32), + "unknown_future_1": np.random.randn(len(dates)).astype(np.float32), + "target": np.sin(np.linspace(0, 8 * np.pi, len(dates))).astype( + np.float32 + ) + + np.random.randn(len(dates)).astype(np.float32) * 0.1, + "static_feat": np.full(len(dates), static_feat, dtype=np.int32), + } + ) + data.append(series) + + df = pd.concat(data, ignore_index=True) + + df = df.astype( + { + "time": np.int32, + "group_id": np.int32, + "category_1": np.int32, + "category_2": np.int32, + "value_1": np.float32, + "value_2": np.float32, + "known_future_1": np.float32, + "known_future_2": np.int32, + "unknown_future_1": np.float32, + "target": np.float32, + "static_feat": np.int32, + } + ) + + future_dates = pd.date_range(start="2023-02-20", periods=20, freq="D") + future_data = [] + for i in range(n_series): + group_id = i + future_series = pd.DataFrame( + { + "time": (future_dates - dates[0]).days, + "group_id": group_id, + "known_future_1": np.random.randn(len(future_dates)).astype(np.float32), + "known_future_2": np.random.randint( + 0, 3, len(future_dates), dtype=np.int32 + ), + } + ) + future_data.append(future_series) + + future_df = pd.concat(future_data, ignore_index=True) + + future_df = future_df.astype( + { + "time": np.int32, + "group_id": np.int32, + "known_future_1": np.float32, + "known_future_2": np.int32, + } + ) + + ts = TimeSeries( + data=df, + data_future=future_df, + time="time", + target="target", + group=["group_id"], + static=["static_feat"], + cat=["category_1", "category_2", "known_future_2"], + num=["value_1", "value_2", "known_future_1", "unknown_future_1"], + known=["known_future_1", "known_future_2"], + unknown=["unknown_future_1"], + ) + + return ts + + +def test_known_unknown_features(sample_timeseries_data): + """Test handling of known and unknown future features. + + This test checks: + + - Whether metadata correctly identifies known and unknown future features. + - Whether future data is correctly included in the dataset. + - The structure and presence of known future feature tensors in a sample. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + ) + + datamodule.setup() + batch = next(iter(datamodule.train_dataloader())) + x_batch, _ = batch + + # Verify metadata contains known/unknown information + metadata = sample_timeseries_data.get_metadata() + assert "col_known" in metadata + assert metadata["col_known"]["known_future_1"] == "K" + assert metadata["col_known"]["known_future_2"] == "K" + assert metadata["col_known"]["unknown_future_1"] == "U" + + # Verify future data handling + sample = sample_timeseries_data[0] + assert "x_f" in sample + assert sample["x_f"].shape[1] == 2 # known_future_1 and known_future_2 + + +def test_initialization(sample_timeseries_data): + """Test the initialization of the EncoderDecoderTimeSeriesDataModule. + + This test verifies: + + - The correct assignment of encoder and prediction lengths. + - The default batch size is set correctly. + - Categorical and continuous features are correctly identified. + - Metadata correctly maps categorical and continuous features. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=30, + max_prediction_length=10, + ) + + assert datamodule.max_encoder_length == 30 + assert datamodule.max_prediction_length == 10 + assert datamodule.batch_size == 32 + + # Check correct identification of categorical and continuous features + assert len(datamodule.categorical_indices) == 3 # category_1, category_2 + assert len(datamodule.continuous_indices) == 5 # value_1, value_2, static_feat + + # You might also want to verify the actual indices are correct + metadata = sample_timeseries_data.get_metadata() + feature_cols = metadata["cols"]["x"] + + # Verify categorical indices point to the right columns + for idx in datamodule.categorical_indices: + assert metadata["col_type"][feature_cols[idx]] == "C" + + # Verify continuous indices point to the right columns + for idx in datamodule.continuous_indices: + assert metadata["col_type"][feature_cols[idx]] == "F" + + +def test_setup_train_val_split(sample_timeseries_data): + """Test dataset splitting into train and validation sets. + + This test ensures: + + - The `setup` method properly splits the dataset. + - The train and validation datasets are correctly created. + - The size of the train dataset matches expectations based on the split ratio. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + train_val_test_split=(0.7, 0.15, 0.15), + ) + + datamodule.setup(stage="fit") + + # Verify dataset creation + assert hasattr(datamodule, "train_dataset") + assert hasattr(datamodule, "val_dataset") + + # Check split sizes + expected_train_size = int(0.7 * len(sample_timeseries_data)) + assert len(datamodule._train_indices) == expected_train_size + + +def test_data_loading(sample_timeseries_data): + """Test data loading and batch structure. + + This test checks: + + - The train dataloader is correctly instantiated. + - The batch contains all necessary components. + - The categorical and continuous features have the correct dimensions. + - The target tensor has the expected shape. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + batch_size=16, + ) + + datamodule.setup(stage="fit") + train_loader = datamodule.train_dataloader() + + # Verify DataLoader + assert isinstance(train_loader, DataLoader) + assert train_loader.batch_size == 16 + + # Check batch structure + batch = next(iter(train_loader)) + x_batch, y_batch = batch + + # Verify all required components are present + expected_keys = { + "encoder_cat", + "encoder_cont", + "decoder_cat", + "decoder_cont", + "encoder_lengths", + "decoder_lengths", + "decoder_target_lengths", + "groups", + "encoder_time_idx", + "decoder_time_idx", + "target_scale", + } + assert all(key in x_batch for key in expected_keys) + + # Check shapes + batch_size = 16 + assert x_batch["encoder_cat"].shape == ( + batch_size, + 20, + 3, + ) # (batch, time, n_cat_features) + assert x_batch["encoder_cont"].shape == ( + batch_size, + 20, + 5, + ) # (batch, time, n_cont_features) + assert x_batch["decoder_cat"].shape == ( + batch_size, + 5, + 3, + ) # (batch, pred_length, n_cat_features) + assert x_batch["decoder_cont"].shape == ( + batch_size, + 5, + 5, + ) # (batch, pred_length, n_cont_features) + assert y_batch.shape == (batch_size, 5, 1) # (batch, pred_length, n_targets) + + +def test_different_settings(sample_timeseries_data): + """Test different configuration settings. + + This test verifies: + + - The model handles different encoder and prediction lengths correctly. + - Relative time indices and target scales are properly included. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=15, + min_encoder_length=10, + max_prediction_length=3, + min_prediction_length=2, + batch_size=8, + add_relative_time_idx=True, + add_target_scales=True, + ) + + datamodule.setup(stage="fit") + train_loader = datamodule.train_dataloader() + batch = next(iter(train_loader)) + x_batch, y_batch = batch + + assert x_batch["encoder_cat"].shape[1] == 15 # max_encoder_length + assert x_batch["decoder_cat"].shape[1] == 3 # max_prediction_length + assert x_batch["encoder_time_idx"].shape[1] == 15 + assert "target_scale" in x_batch # verify target scales are included + + +def test_static_features(sample_timeseries_data): + """Test that static features are correctly included. + + This test ensures: + + - Static categorical features are present in the batch. + - Static feature tensor dimensions are as expected. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + ) + + datamodule.setup(stage="fit") + train_loader = datamodule.train_dataloader() + batch = next(iter(train_loader)) + x_batch, _ = batch + + # Verify static features are present + assert "static_categorical_features" in x_batch + assert ( + x_batch["static_categorical_features"].dim() == 3 + ) # (batch_size, 1, n_static_features) + + +def test_group_handling(sample_timeseries_data): + """Test that group information is correctly processed. + + This test verifies: + + - The presence of group identifiers in the batch. + - Group tensor dimensions are as expected. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + ) + + datamodule.setup(stage="fit") + train_loader = datamodule.train_dataloader() + batch = next(iter(train_loader)) + x_batch, _ = batch + + # Verify group information + assert "groups" in x_batch + assert x_batch["groups"].dim() == 2 # (batch_size, 1) + + +def test_window_creation(sample_timeseries_data): + """Test window creation for encoder-decoder time series. + + This test ensures: + + - Windows are correctly generated for each time series. + - Encoder and decoder window sizes match the expected values. + - Window indices reference valid series in the dataset. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + ) + + datamodule.setup(stage="fit") + + # Check that windows are created for each series in training set + processed_data = datamodule.train_processed + windows = datamodule.train_windows + + # Verify window parameters + for window in windows: + series_idx, start_idx, enc_length, pred_length = window + assert enc_length == 20 # max_encoder_length + assert pred_length == 5 # max_prediction_length + assert series_idx < len(processed_data) + + +def test_prediction_mode(sample_timeseries_data): + """Test the behavior of the datamodule in prediction mode. + + This test checks: + + - Whether the prediction dataset is properly created. + - The structure of the prediction batch. + - The presence of target scale information. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + ) + + datamodule.setup(stage="predict") + predict_loader = datamodule.predict_dataloader() + + # Check prediction dataset + assert hasattr(datamodule, "predict_dataset") + + # Verify prediction batch structure + batch = next(iter(predict_loader)) + x_batch, y_batch = batch + + assert x_batch["encoder_cat"].shape[1] == 20 + assert x_batch["decoder_cat"].shape[1] == 5 + assert "target_scale" in x_batch + + +@pytest.mark.parametrize( + "train_val_test_split", [(0.6, 0.2, 0.2), (0.8, 0.1, 0.1), (0.7, 0.15, 0.15)] +) +def test_different_splits(sample_timeseries_data, train_val_test_split): + """Test different train-validation-test splits. + + This test verifies: + + - The dataset is correctly split according to different ratios. + - The sizes of train, validation, and test sets match expected values. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + train_val_test_split=train_val_test_split, + ) + + datamodule.setup(stage="fit") + total_size = len(sample_timeseries_data) + expected_train_size = int(train_val_test_split[0] * total_size) + expected_val_size = int(train_val_test_split[1] * total_size) + expected_test_size = int(train_val_test_split[2] * total_size) + + assert len(datamodule._train_indices) == expected_train_size + assert len(datamodule._val_indices) == expected_val_size + assert len(datamodule._test_indices) == expected_test_size From 9f8256ed84a84a0461dba33c3c2811fa2f773104 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 19 Feb 2025 19:30:58 +0530 Subject: [PATCH 5/7] debug --- tests/test_data/test_data_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index 5d85e55e5..5600244fe 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -155,7 +155,7 @@ def test_initialization(sample_timeseries_data): assert len(datamodule.categorical_indices) == 3 # category_1, category_2 assert len(datamodule.continuous_indices) == 5 # value_1, value_2, static_feat - # You might also want to verify the actual indices are correct + # Verify the actual indices are correct metadata = sample_timeseries_data.get_metadata() feature_cols = metadata["cols"]["x"] From fda5f7edef7a76f7f90358d4997ff5b8d4afe9f0 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 19 Feb 2025 19:35:34 +0530 Subject: [PATCH 6/7] update comments --- tests/test_data/test_data_module.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index 5600244fe..52ecab27b 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -152,8 +152,12 @@ def test_initialization(sample_timeseries_data): assert datamodule.batch_size == 32 # Check correct identification of categorical and continuous features - assert len(datamodule.categorical_indices) == 3 # category_1, category_2 - assert len(datamodule.continuous_indices) == 5 # value_1, value_2, static_feat + assert ( + len(datamodule.categorical_indices) == 3 + ) # category_1, category_2, known_future_2 + assert ( + len(datamodule.continuous_indices) == 5 + ) # value_1, value_2, static_feat, known_future_1, unknown_future_1 # Verify the actual indices are correct metadata = sample_timeseries_data.get_metadata() From 6fa83261a5b200efb97b14e6486abafe3f400021 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Thu, 20 Mar 2025 23:08:01 +0530 Subject: [PATCH 7/7] baseclass metadata commit --- pytorch_forecasting/data/data_modules.py | 53 +++- .../models/base_model_lightning.py | 276 ++++++++++++++++++ .../tft_new_basic.py | 228 +++++++++++++++ 3 files changed, 554 insertions(+), 3 deletions(-) create mode 100644 pytorch_forecasting/models/base_model_lightning.py create mode 100644 pytorch_forecasting/models/temporal_fusion_transformer/tft_new_basic.py diff --git a/pytorch_forecasting/data/data_modules.py b/pytorch_forecasting/data/data_modules.py index f2c5752de..e1b48c40c 100644 --- a/pytorch_forecasting/data/data_modules.py +++ b/pytorch_forecasting/data/data_modules.py @@ -99,7 +99,7 @@ def __init__( ): super().__init__() self.time_series_dataset = time_series_dataset - self.metadata = time_series_dataset.get_metadata() + self.time_series_metadata = time_series_dataset.get_metadata() self.max_encoder_length = max_encoder_length self.min_encoder_length = min_encoder_length or max_encoder_length @@ -127,9 +127,10 @@ def __init__( self.categorical_indices = [] self.continuous_indices = [] + self.metadata = {} - for idx, col in enumerate(self.metadata["cols"]["x"]): - if self.metadata["col_type"].get(col) == "C": + for idx, col in enumerate(self.time_series_metadata["cols"]["x"]): + if self.time_series_metadata["col_type"].get(col) == "C": self.categorical_indices.append(idx) else: self.continuous_indices.append(idx) @@ -201,6 +202,35 @@ def __init__( self.processed_data = processed_data self.windows = windows self.add_relative_time_idx = add_relative_time_idx + self.metadata = {} + if len(windows) > 0: + + sample_idx = 0 + sample_x, sample_y = self.__getitem__(sample_idx) + self.metadata = { + "encoder_cat": sample_x["encoder_cat"].shape, + "encoder_cont": sample_x["encoder_cont"].shape, + "decoder_cat": sample_x["decoder_cat"].shape, + "decoder_cont": sample_x["decoder_cont"].shape, + "encoder_lengths": sample_x["encoder_lengths"].shape, + "decoder_lengths": sample_x["decoder_lengths"].shape, + "decoder_target_lengths": sample_x["decoder_target_lengths"].shape, + "groups": sample_x["groups"].shape, + "encoder_time_idx": sample_x["encoder_time_idx"].shape, + "decoder_time_idx": sample_x["decoder_time_idx"].shape, + "target_scale": sample_x["target_scale"].shape, + "target": sample_y.shape, + } + if "static_categorical_features" in sample_x: + self.metadata["static_categorical_features"] = sample_x[ + "static_categorical_features" + ].shape + self.metadata["static_continuous_features"] = sample_x[ + "static_continuous_features" + ].shape + + def get_metadata(self): + return self.metadata def __len__(self): return len(self.windows) @@ -311,6 +341,16 @@ def setup(self, stage: Optional[str] = None): self.val_dataset = self._ProcessedEncoderDecoderDataset( self.val_processed, self.val_windows, self.add_relative_time_idx ) + if hasattr(self.train_dataset, "get_metadata"): + dataset_metadata = self.train_dataset.get_metadata() + lengths = { + "max_encoder_length": self.max_encoder_length, + "max_prediction_length": self.max_prediction_length, + "min_encoder_length": self.min_encoder_length, + "min_prediction_length": self.min_prediction_length, + } + self.metadata.update(dataset_metadata) + self.metadata.update(lengths) elif stage is None or stage == "test": if not hasattr(self, "test_dataset"): @@ -320,6 +360,10 @@ def setup(self, stage: Optional[str] = None): self.test_dataset = self._ProcessedEncoderDecoderDataset( self.test_processed, self.test_windows, self.add_relative_time_idx ) + if hasattr(self.test_dataset, "get_metadata") and not self.metadata: + dataset_metadata = self.test_dataset.get_metadata() + self.metadata.update(dataset_metadata) + elif stage == "predict": predict_indices = torch.arange(len(self.time_series_dataset)) self.predict_processed = self._preprocess_data(predict_indices) @@ -327,6 +371,9 @@ def setup(self, stage: Optional[str] = None): self.predict_dataset = self._ProcessedEncoderDecoderDataset( self.predict_processed, self.predict_windows, self.add_relative_time_idx ) + if hasattr(self.predict_dataset, "get_metadata") and not self.metadata: + dataset_metadata = self.predict_dataset.get_metadata() + self.metadata.update(dataset_metadata) def train_dataloader(self): return DataLoader( diff --git a/pytorch_forecasting/models/base_model_lightning.py b/pytorch_forecasting/models/base_model_lightning.py new file mode 100644 index 000000000..a070b2e13 --- /dev/null +++ b/pytorch_forecasting/models/base_model_lightning.py @@ -0,0 +1,276 @@ +from typing import Dict, List, Optional, Tuple, Union + +from lightning.pytorch import LightningModule +from lightning.pytorch.utilities.types import STEP_OUTPUT +import torch +import torch.nn as nn +from torch.optim import Optimizer + + +class BaseModel(LightningModule): + def __init__( + self, + loss: nn.Module, + logging_metrics: Optional[List[nn.Module]] = None, + optimizer: Optional[Union[Optimizer, str]] = "adam", + optimizer_params: Optional[Dict] = None, + lr_scheduler: Optional[str] = None, + lr_scheduler_params: Optional[Dict] = None, + ): + """ + Base model for time series forecasting. + + Parameters + ---------- + loss : nn.Module + Loss function to use for training. + logging_metrics : Optional[List[nn.Module]], optional + List of metrics to log during training, validation, and testing. + optimizer : Optional[Union[Optimizer, str]], optional + Optimizer to use for training. + Can be a string ("adam", "sgd") or an instance of `torch.optim.Optimizer`. + optimizer_params : Optional[Dict], optional + Parameters for the optimizer. + lr_scheduler : Optional[str], optional + Learning rate scheduler to use. + Supported values: "reduce_lr_on_plateau", "step_lr". + lr_scheduler_params : Optional[Dict], optional + Parameters for the learning rate scheduler. + """ + super().__init__() + self.loss = loss + self.logging_metrics = logging_metrics if logging_metrics is not None else [] + self.optimizer = optimizer + self.optimizer_params = optimizer_params if optimizer_params is not None else {} + self.lr_scheduler = lr_scheduler + self.lr_scheduler_params = ( + lr_scheduler_params if lr_scheduler_params is not None else {} + ) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass of the model. + + Parameters + ---------- + x : Dict[str, torch.Tensor] + Dictionary containing input tensors + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary containing output tensors (e.g., predictions). + """ + raise NotImplementedError("Forward method must be implemented by subclass.") + + def training_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Training step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. + """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="train") + return {"loss": loss} + + def validation_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Validation step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. + """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="val") + return {"val_loss": loss} + + def test_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Test step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. + """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="test") + return {"test_loss": loss} + + def predict_step( + self, + batch: Tuple[Dict[str, torch.Tensor]], + batch_idx: int, + dataloader_idx: int = 0, + ) -> torch.Tensor: + """ + Prediction step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input tensors. + batch_idx : int + Index of the batch. + dataloader_idx : int + Index of the dataloader. + + Returns + ------- + torch.Tensor + Predicted output tensor. + """ + x, _ = batch + y_hat = self(x) + return y_hat + + def configure_optimizers(self) -> Dict: + """ + Configure the optimizer and learning rate scheduler. + + Returns + ------- + Dict + Dictionary containing the optimizer and scheduler configuration. + """ + optimizer = self._get_optimizer() + if self.lr_scheduler is not None: + scheduler = self._get_scheduler(optimizer) + if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + }, + } + else: + return {"optimizer": optimizer, "lr_scheduler": scheduler} + return {"optimizer": optimizer} + + def _get_optimizer(self) -> Optimizer: + """ + Get the optimizer based on the specified optimizer name and parameters. + + Returns + ------- + Optimizer + The optimizer instance. + """ + if isinstance(self.optimizer, str): + if self.optimizer.lower() == "adam": + return torch.optim.Adam(self.parameters(), **self.optimizer_params) + elif self.optimizer.lower() == "sgd": + return torch.optim.SGD(self.parameters(), **self.optimizer_params) + else: + raise ValueError(f"Optimizer {self.optimizer} not supported.") + elif isinstance(self.optimizer, Optimizer): + return self.optimizer + else: + raise ValueError( + "Optimizer must be either a string or " + "an instance of torch.optim.Optimizer." + ) + + def _get_scheduler( + self, optimizer: Optimizer + ) -> torch.optim.lr_scheduler._LRScheduler: + """ + Get the lr scheduler based on the specified scheduler name and params. + + Parameters + ---------- + optimizer : Optimizer + The optimizer instance. + + Returns + ------- + torch.optim.lr_scheduler._LRScheduler + The learning rate scheduler instance. + """ + if self.lr_scheduler.lower() == "reduce_lr_on_plateau": + return torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, **self.lr_scheduler_params + ) + elif self.lr_scheduler.lower() == "step_lr": + return torch.optim.lr_scheduler.StepLR( + optimizer, **self.lr_scheduler_params + ) + else: + raise ValueError(f"Scheduler {self.lr_scheduler} not supported.") + + def log_metrics( + self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = "val" + ) -> None: + """ + Log additional metrics during training, validation, or testing. + + Parameters + ---------- + y_hat : torch.Tensor + Predicted output tensor. + y : torch.Tensor + Target output tensor. + prefix : str + Prefix for the logged metrics (e.g., "train", "val", "test"). + """ + for metric in self.logging_metrics: + metric_value = metric(y_hat, y) + self.log( + f"{prefix}_{metric.__class__.__name__}", + metric_value, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_new_basic.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_new_basic.py new file mode 100644 index 000000000..274c8fef4 --- /dev/null +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_new_basic.py @@ -0,0 +1,228 @@ +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from torch.optim import Optimizer + +from pytorch_forecasting.models.base_model_lightning import BaseModel + + +class TFT(BaseModel): + def __init__( + self, + loss: nn.Module, + logging_metrics: Optional[List[nn.Module]] = None, + optimizer: Optional[Union[Optimizer, str]] = "adam", + optimizer_params: Optional[Dict] = None, + lr_scheduler: Optional[str] = None, + lr_scheduler_params: Optional[Dict] = None, + hidden_size: int = 64, + num_layers: int = 2, + attention_head_size: int = 4, + dropout: float = 0.1, + metadata: Dict = None, + # cont_feature_size: int = 0, + # cat_feature_size: int = 0, + # static_cat_feature_size: int = 0, + # static_cont_feature_size: int = 0, + output_size: int = 1, + # max_encoder_length: int = 30, + # max_prediction_length: int = 1, + ): + """ + Temporal Fusion Transformer (TFT) model that inherits from BaseModel. + + Parameters + ---------- + loss : nn.Module + Loss function to use for training. + logging_metrics : Optional[List[nn.Module]], optional + List of metrics to log during training, validation, and testing. + optimizer : Optional[Union[Optimizer, str]], optional + Optimizer to use for training. Can be a string ("adam", "sgd") or + an instance of `torch.optim.Optimizer`. + optimizer_params : Optional[Dict], optional + Parameters for the optimizer. + lr_scheduler : Optional[str], optional + Learning rate scheduler to use. + Supported values: "reduce_lr_on_plateau", "step_lr". + lr_scheduler_params : Optional[Dict], optional + Parameters for the learning rate scheduler. + hidden_size : int, default=64 + Hidden size for LSTM layers and attention mechanism. + num_layers : int, default=2 + Number of LSTM layers. + attention_head_size : int, default=4 + Number of attention heads. + dropout : float, default=0.1 + Dropout rate. + metadata : Dict + model metadata + """ + super().__init__( + loss=loss, + logging_metrics=logging_metrics, + optimizer=optimizer, + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + lr_scheduler_params=lr_scheduler_params, + ) + self.hidden_size = hidden_size + self.num_layers = num_layers + self.attention_head_size = attention_head_size + self.dropout = dropout + if metadata is None: + raise ValueError("metadata parameter is required") + self.metadata = metadata + self.cont_feature_size = self.metadata["encoder_cont"][-1] + self.cat_feature_size = self.metadata["encoder_cat"][-1] + self.static_cat_feature_size = self.metadata["static_categorical_features"][-1] + self.static_cont_feature_size = self.metadata["static_continuous_features"][-1] + self.output_size = output_size + self.max_encoder_length = self.metadata["max_encoder_length"] + self.max_prediction_length = self.metadata["max_prediction_length"] + + total_feature_size = self.cont_feature_size + self.cat_feature_size + total_static_size = self.static_cat_feature_size + self.static_cont_feature_size + + self.encoder_var_selection = nn.Sequential( + nn.Linear(total_feature_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, total_feature_size), + nn.Sigmoid(), + ) + + self.decoder_var_selection = nn.Sequential( + nn.Linear(total_feature_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, total_feature_size), + nn.Sigmoid(), + ) + + self.static_context_linear = ( + nn.Linear(total_static_size, hidden_size) if total_static_size > 0 else None + ) + + self.lstm_encoder = nn.LSTM( + input_size=total_feature_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + batch_first=True, + ) + + self.lstm_decoder = nn.LSTM( + input_size=total_feature_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + batch_first=True, + ) + + self.self_attention = nn.MultiheadAttention( + embed_dim=hidden_size, + num_heads=attention_head_size, + dropout=dropout, + batch_first=True, + ) + + self.pre_output = nn.Linear(hidden_size, hidden_size) + self.output_layer = nn.Linear(hidden_size, output_size) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass of the TFT model. + + Parameters + ---------- + x : Dict[str, torch.Tensor] + Dictionary containing input tensors: + - encoder_cat: Categorical encoder features + - encoder_cont: Continuous encoder features + - decoder_cat: Categorical decoder features + - decoder_cont: Continuous decoder features + - static_categorical_features: Static categorical features + - static_continuous_features: Static continuous features + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary containing output tensors: + - prediction: Prediction output (batch_size, prediction_length, output_size) + """ + batch_size = x["encoder_cont"].shape[0] + + encoder_cat = x.get( + "encoder_cat", + torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device), + ) + encoder_cont = x.get( + "encoder_cont", + torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device), + ) + decoder_cat = x.get( + "decoder_cat", + torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device), + ) + decoder_cont = x.get( + "decoder_cont", + torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device), + ) + + encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2) + decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2) + + static_context = None + if self.static_context_linear is not None: + static_cat = x.get( + "static_categorical_features", + torch.zeros(batch_size, 0, device=self.device), + ) + static_cont = x.get( + "static_continuous_features", + torch.zeros(batch_size, 0, device=self.device), + ) + static_input = torch.cat([static_cont, static_cat], dim=1) + static_context = self.static_context_linear(static_input) + + encoder_weights = self.encoder_var_selection(encoder_input) + encoder_input = encoder_input * encoder_weights + + decoder_weights = self.decoder_var_selection(decoder_input) + decoder_input = decoder_input * decoder_weights + + if static_context is not None: + encoder_static_context = static_context.unsqueeze(1).expand( + -1, self.max_encoder_length, -1 + ) + decoder_static_context = static_context.unsqueeze(1).expand( + -1, self.max_prediction_length, -1 + ) + + encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input) + encoder_output = encoder_output + encoder_static_context + decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n)) + decoder_output = decoder_output + decoder_static_context + else: + encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input) + decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n)) + + sequence = torch.cat([encoder_output, decoder_output], dim=1) + + if static_context is not None: + expanded_static_context = static_context.unsqueeze(1).expand( + -1, sequence.size(1), -1 + ) + + attended_output, _ = self.self_attention( + sequence + expanded_static_context, sequence, sequence + ) + else: + attended_output, _ = self.self_attention(sequence, sequence, sequence) + + decoder_attended = attended_output[:, -self.max_prediction_length :, :] + + output = nn.functional.relu(self.pre_output(decoder_attended)) + prediction = self.output_layer(output) + + return {"prediction": prediction}