From 252598d2ce3f31244a422cd9206961776ea79615 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Sun, 6 Apr 2025 18:43:51 +0530
Subject: [PATCH 01/30] D1, D2 layer commit

---
 pytorch_forecasting/data/data_module.py | 633 ++++++++++++++++++++++++
 pytorch_forecasting/data/timeseries.py  | 257 ++++++++++
 2 files changed, 890 insertions(+)
 create mode 100644 pytorch_forecasting/data/data_module.py

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
new file mode 100644
index 000000000..56917696d
--- /dev/null
+++ b/pytorch_forecasting/data/data_module.py
@@ -0,0 +1,633 @@
+#######################################################################################
+# Disclaimer: This data module is still work in progress and experimental, please
+# use with care. It is a basic skeleton of how the data-handling pipeline may look
+# in the future.
+# This is the D2 layer, which will handle the preprocessing and the data loaders.
+# For now, the pipeline handles the simplest situation: the whole dataset can be
+# loaded into memory.
+#######################################################################################
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from lightning.pytorch import LightningDataModule
+from sklearn.preprocessing import RobustScaler, StandardScaler
+import torch
+from torch.utils.data import DataLoader, Dataset
+
+from pytorch_forecasting.data.encoders import (
+    EncoderNormalizer,
+    NaNLabelEncoder,
+    TorchNormalizer,
+)
+from pytorch_forecasting.data.timeseries import TimeSeries, _coerce_to_dict
+
+NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer]
+
+
+class EncoderDecoderTimeSeriesDataModule(LightningDataModule):
+    """
+    Lightning DataModule for processing time series data in an encoder-decoder format.
+
+    This module handles preprocessing, splitting, and batching of time series data
+    for use in deep learning models. It supports categorical and continuous features,
+    various scalers, and automatic target normalization.
+
+    Parameters
+    ----------
+    time_series_dataset : TimeSeries
+        The dataset containing time series data.
+    max_encoder_length : int, default=30
+        Maximum length of the encoder input sequence.
+    min_encoder_length : Optional[int], default=None
+        Minimum length of the encoder input sequence.
+        Defaults to `max_encoder_length` if not specified.
+    max_prediction_length : int, default=1
+        Maximum length of the decoder output sequence.
+    min_prediction_length : Optional[int], default=None
+        Minimum length of the decoder output sequence.
+        Defaults to `max_prediction_length` if not specified.
+    min_prediction_idx : Optional[int], default=None
+        Minimum index from which predictions start.
+    allow_missing_timesteps : bool, default=False
+        Whether to allow missing timesteps in the dataset.
+    add_relative_time_idx : bool, default=False
+        Whether to add a relative time index feature.
+    add_target_scales : bool, default=False
+        Whether to add target scaling information.
+    add_encoder_length : Union[bool, str], default="auto"
+        Whether to include encoder length information.
+    target_normalizer :
+        Union[NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None],
+        default="auto"
+        Normalizer for the target variable. If "auto", uses `RobustScaler`.
+
+    categorical_encoders : Optional[Dict[str, NaNLabelEncoder]], default=None
+        Dictionary of categorical encoders.
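+        For example (an illustrative sketch only; the encoders are not yet applied
+        in this experimental skeleton, see the TODO in ``_preprocess_data``), a
+        mapping such as ``{"product_id": NaNLabelEncoder(add_nan=True)}`` would
+        encode a hypothetical ``product_id`` column while reserving a code for
+        unseen or missing values.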
+ + scalers : + Optional[Dict[str, Union[StandardScaler, RobustScaler, + TorchNormalizer, EncoderNormalizer]]], default=None + Dictionary of feature scalers. + + randomize_length : Union[None, Tuple[float, float], bool], default=False + Whether to randomize input sequence length. + batch_size : int, default=32 + Batch size for DataLoader. + num_workers : int, default=0 + Number of workers for DataLoader. + train_val_test_split : tuple, default=(0.7, 0.15, 0.15) + Proportions for train, validation, and test dataset splits. + """ + + def __init__( + self, + time_series_dataset: TimeSeries, + max_encoder_length: int = 30, + min_encoder_length: Optional[int] = None, + max_prediction_length: int = 1, + min_prediction_length: Optional[int] = None, + min_prediction_idx: Optional[int] = None, + allow_missing_timesteps: bool = False, + add_relative_time_idx: bool = False, + add_target_scales: bool = False, + add_encoder_length: Union[bool, str] = "auto", + target_normalizer: Union[ + NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None + ] = "auto", + categorical_encoders: Optional[Dict[str, NaNLabelEncoder]] = None, + scalers: Optional[ + Dict[ + str, + Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer], + ] + ] = None, + randomize_length: Union[None, Tuple[float, float], bool] = False, + batch_size: int = 32, + num_workers: int = 0, + train_val_test_split: tuple = (0.7, 0.15, 0.15), + ): + super().__init__() + self.time_series_dataset = time_series_dataset + self.time_series_metadata = time_series_dataset.get_metadata() + + self.max_encoder_length = max_encoder_length + self.min_encoder_length = min_encoder_length or max_encoder_length + self.max_prediction_length = max_prediction_length + self.min_prediction_length = min_prediction_length or max_prediction_length + self.min_prediction_idx = min_prediction_idx + + self.allow_missing_timesteps = allow_missing_timesteps + self.add_relative_time_idx = add_relative_time_idx + self.add_target_scales = add_target_scales + self.add_encoder_length = add_encoder_length + self.randomize_length = randomize_length + + self.batch_size = batch_size + self.num_workers = num_workers + self.train_val_test_split = train_val_test_split + + if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto": + self.target_normalizer = RobustScaler() + else: + self.target_normalizer = target_normalizer + + self.categorical_encoders = _coerce_to_dict(categorical_encoders) + self.scalers = _coerce_to_dict(scalers) + + self.categorical_indices = [] + self.continuous_indices = [] + self._metadata = None + + for idx, col in enumerate(self.time_series_metadata["cols"]["x"]): + if self.time_series_metadata["col_type"].get(col) == "C": + self.categorical_indices.append(idx) + else: + self.continuous_indices.append(idx) + + def _prepare_metadata(self): + """Prepare metadata for model initialisation. + + Returns + ------- + dict + dictionary containing the following keys: + + * ``encoder_cat``: Number of categorical variables in the encoder. + Computed as ``len(self.categorical_indices)``, which counts the + categorical feature indices. + * ``encoder_cont``: Number of continuous variables in the encoder. + Computed as ``len(self.continuous_indices)``, which counts the + continuous feature indices. + * ``decoder_cat``: Number of categorical variables in the decoder that + are known in advance. 
+            Computed by filtering ``self.time_series_metadata["cols"]["x"]``
+            where col_type == "C" (categorical) and col_known == "K" (known).
+        * ``decoder_cont``: Number of continuous variables in the decoder that
+            are known in advance.
+            Computed by filtering ``self.time_series_metadata["cols"]["x"]``
+            where col_type == "F" (continuous) and col_known == "K" (known).
+        * ``target``: Number of target variables.
+            Computed as ``len(self.time_series_metadata["cols"]["y"])``, which
+            gives the number of output target columns.
+        * ``static_categorical_features``: Number of static categorical features.
+            Computed by filtering ``self.time_series_metadata["cols"]["st"]``
+            (static features) where col_type == "C" (categorical).
+        * ``static_continuous_features``: Number of static continuous features.
+            Computed as the difference between
+            ``len(self.time_series_metadata["cols"]["st"])`` (static features)
+            and the number of static categorical features.
+        * ``max_encoder_length``: maximum encoder length.
+            Taken directly from `self.max_encoder_length`.
+        * ``max_prediction_length``: maximum prediction length.
+            Taken directly from `self.max_prediction_length`.
+        * ``min_encoder_length``: minimum encoder length.
+            Taken directly from `self.min_encoder_length`.
+        * ``min_prediction_length``: minimum prediction length.
+            Taken directly from `self.min_prediction_length`.
+
+        """
+        encoder_cat_count = len(self.categorical_indices)
+        encoder_cont_count = len(self.continuous_indices)
+
+        decoder_cat_count = len(
+            [
+                col
+                for col in self.time_series_metadata["cols"]["x"]
+                if self.time_series_metadata["col_type"].get(col) == "C"
+                and self.time_series_metadata["col_known"].get(col) == "K"
+            ]
+        )
+        decoder_cont_count = len(
+            [
+                col
+                for col in self.time_series_metadata["cols"]["x"]
+                if self.time_series_metadata["col_type"].get(col) == "F"
+                and self.time_series_metadata["col_known"].get(col) == "K"
+            ]
+        )
+
+        target_count = len(self.time_series_metadata["cols"]["y"])
+        metadata = {
+            "encoder_cat": encoder_cat_count,
+            "encoder_cont": encoder_cont_count,
+            "decoder_cat": decoder_cat_count,
+            "decoder_cont": decoder_cont_count,
+            "target": target_count,
+        }
+        if self.time_series_metadata["cols"]["st"]:
+            static_cat_count = len(
+                [
+                    col
+                    for col in self.time_series_metadata["cols"]["st"]
+                    if self.time_series_metadata["col_type"].get(col) == "C"
+                ]
+            )
+            static_cont_count = (
+                len(self.time_series_metadata["cols"]["st"]) - static_cat_count
+            )
+
+            metadata["static_categorical_features"] = static_cat_count
+            metadata["static_continuous_features"] = static_cont_count
+        else:
+            metadata["static_categorical_features"] = 0
+            metadata["static_continuous_features"] = 0
+
+        metadata.update(
+            {
+                "max_encoder_length": self.max_encoder_length,
+                "max_prediction_length": self.max_prediction_length,
+                "min_encoder_length": self.min_encoder_length,
+                "min_prediction_length": self.min_prediction_length,
+            }
+        )
+
+        return metadata
+
+    @property
+    def metadata(self):
+        """Compute metadata for model initialization.
+
+        This property returns a dictionary containing the shapes and key information
+        related to the time series model. The metadata includes:
+
+        * ``encoder_cat``: Number of categorical variables in the encoder.
+        * ``encoder_cont``: Number of continuous variables in the encoder.
+        * ``decoder_cat``: Number of categorical variables in the decoder that are
+          known in advance.
+        * ``decoder_cont``: Number of continuous variables in the decoder that are
+          known in advance.
+ * ``target``: Number of target variables. + + If static features are present, the following keys are added: + + * ``static_categorical_features``: Number of static categorical features + * ``static_continuous_features``: Number of static continuous features + + It also contains the following information: + + * ``max_encoder_length``: maximum encoder length + * ``max_prediction_length``: maximum prediction length + * ``min_encoder_length``: minimum encoder length + * ``min_prediction_length``: minimum prediction length + """ + if self._metadata is None: + self._metadata = self._prepare_metadata() + return self._metadata + + def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: + """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset. + + Preprocessing steps + -------------------- + + * Converts target (`y`) and features (`x`) to `torch.float32`. + * Masks time points that are at or before the cutoff time. + * Splits features into categorical and continuous subsets based on + predefined indices. + + + TODO: add scalers, target normalizers etc. + """ + processed_data = [] + + for idx in indices: + sample = self.time_series_dataset[idx.item()] + + target = sample["y"] + features = sample["x"] + times = sample["t"] + cutoff_time = sample["cutoff_time"] + + time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) + + if isinstance(target, torch.Tensor): + target = target.float() + else: + target = torch.tensor(target, dtype=torch.float32) + + if isinstance(features, torch.Tensor): + features = features.float() + else: + features = torch.tensor(features, dtype=torch.float32) + + # TODO: add scalers, target normalizers etc. + + categorical = ( + features[:, self.categorical_indices] + if self.categorical_indices + else torch.zeros((features.shape[0], 0)) + ) + continuous = ( + features[:, self.continuous_indices] + if self.continuous_indices + else torch.zeros((features.shape[0], 0)) + ) + + processed_data.append( + { + "features": {"categorical": categorical, "continuous": continuous}, + "target": target, + "static": sample.get("st", None), + "group": sample.get("group", torch.tensor([0])), + "length": len(target), + "time_mask": time_mask, + "times": times, + "cutoff_time": cutoff_time, + } + ) + + return processed_data + + class _ProcessedEncoderDecoderDataset(Dataset): + """PyTorch Dataset for processed encoder-decoder time series data. + + Parameters + ---------- + processed_data : List[Dict[str, Any]] + List of preprocessed time series samples. + windows : List[Tuple[int, int, int, int]] + List of window tuples containing + (series_idx, start_idx, enc_length, pred_length). + add_relative_time_idx : bool, default=False + Whether to include relative time indices. + """ + + def __init__( + self, + processed_data: List[Dict[str, Any]], + windows: List[Tuple[int, int, int, int]], + add_relative_time_idx: bool = False, + ): + self.processed_data = processed_data + self.windows = windows + self.add_relative_time_idx = add_relative_time_idx + + def __len__(self): + return len(self.windows) + + def __getitem__(self, idx): + """Retrieve a processed time series window for dataloader input. + + x : dict + Dictionary containing model inputs: + + * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features) + Categorical features for the encoder. + * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features) + Continuous features for the encoder. 
+ * ``decoder_cat`` : tensor of shape (pred_length, n_cat_features) + Categorical features for the decoder. + * ``decoder_cont`` : tensor of shape (pred_length, n_cont_features) + Continuous features for the decoder. + * ``encoder_lengths`` : tensor of shape (1,) + Length of the encoder sequence. + * ``decoder_lengths`` : tensor of shape (1,) + Length of the decoder sequence. + * ``decoder_target_lengths`` : tensor of shape (1,) + Length of the decoder target sequence. + * ``groups`` : tensor of shape (1,) + Group identifier for the time series instance. + * ``encoder_time_idx`` : tensor of shape (enc_length,) + Time indices for the encoder sequence. + * ``decoder_time_idx`` : tensor of shape (pred_length,) + Time indices for the decoder sequence. + * ``target_scale`` : tensor of shape (1,) + Scaling factor for the target values. + * ``encoder_mask`` : tensor of shape (enc_length,) + Boolean mask indicating valid encoder time points. + * ``decoder_mask`` : tensor of shape (pred_length,) + Boolean mask indicating valid decoder time points. + + If static features are present, the following keys are added: + + * ``static_categorical_features`` : tensor of shape + (1, n_static_cat_features), optional + Static categorical features, if available. + * ``static_continuous_features`` : tensor of shape (1, 0), optional + Placeholder for static continuous features (currently empty). + + y : tensor of shape ``(pred_length, n_targets)`` + Target values for the decoder sequence. + """ + series_idx, start_idx, enc_length, pred_length = self.windows[idx] + data = self.processed_data[series_idx] + + end_idx = start_idx + enc_length + pred_length + encoder_indices = slice(start_idx, start_idx + enc_length) + decoder_indices = slice(start_idx + enc_length, end_idx) + + target_scale = data["target"][encoder_indices] + target_scale = target_scale[~torch.isnan(target_scale)].abs().mean() + if torch.isnan(target_scale) or target_scale == 0: + target_scale = torch.tensor(1.0) + + encoder_mask = ( + data["time_mask"][encoder_indices] + if "time_mask" in data + else torch.ones(enc_length, dtype=torch.bool) + ) + decoder_mask = ( + data["time_mask"][decoder_indices] + if "time_mask" in data + else torch.zeros(pred_length, dtype=torch.bool) + ) + + x = { + "encoder_cat": data["features"]["categorical"][encoder_indices], + "encoder_cont": data["features"]["continuous"][encoder_indices], + "decoder_cat": data["features"]["categorical"][decoder_indices], + "decoder_cont": data["features"]["continuous"][decoder_indices], + "encoder_lengths": torch.tensor(enc_length), + "decoder_lengths": torch.tensor(pred_length), + "decoder_target_lengths": torch.tensor(pred_length), + "groups": data["group"], + "encoder_time_idx": torch.arange(enc_length), + "decoder_time_idx": torch.arange(enc_length, enc_length + pred_length), + "target_scale": target_scale, + "encoder_mask": encoder_mask, + "decoder_mask": decoder_mask, + } + if data["static"] is not None: + x["static_categorical_features"] = data["static"].unsqueeze(0) + x["static_continuous_features"] = torch.zeros((1, 0)) + + y = data["target"][decoder_indices] + if y.ndim == 1: + y = y.unsqueeze(-1) + + return x, y + + def _create_windows( + self, processed_data: List[Dict[str, Any]] + ) -> List[Tuple[int, int, int, int]]: + """Generate sliding windows for training, validation, and testing. + + Returns + ------- + List[Tuple[int, int, int, int]] + A list of tuples, where each tuple consists of: + - ``series_idx`` : int + Index of the time series in `processed_data`. 
+            - ``start_idx`` : int
+                Start index of the encoder window.
+            - ``enc_length`` : int
+                Length of the encoder input sequence.
+            - ``pred_length`` : int
+                Length of the decoder output sequence.
+
+            For example, with ``max_encoder_length=30``,
+            ``max_prediction_length=1`` and a series of length 40, windows
+            start at indices 0 through 9.
+        """
+        windows = []
+
+        for idx, data in enumerate(processed_data):
+            sequence_length = data["length"]
+
+            if sequence_length < self.max_encoder_length + self.max_prediction_length:
+                continue
+
+            effective_min_prediction_idx = (
+                self.min_prediction_idx
+                if self.min_prediction_idx is not None
+                else self.max_encoder_length
+            )
+
+            max_prediction_idx = sequence_length - self.max_prediction_length + 1
+
+            if max_prediction_idx <= effective_min_prediction_idx:
+                continue
+
+            for start_idx in range(
+                0, max_prediction_idx - effective_min_prediction_idx
+            ):
+                if (
+                    start_idx + self.max_encoder_length + self.max_prediction_length
+                    <= sequence_length
+                ):
+                    windows.append(
+                        (
+                            idx,
+                            start_idx,
+                            self.max_encoder_length,
+                            self.max_prediction_length,
+                        )
+                    )
+
+        return windows
+
+    def setup(self, stage: Optional[str] = None):
+        """Prepare the datasets for training, validation, testing, or prediction.
+
+        Parameters
+        ----------
+        stage : Optional[str], default=None
+            Specifies the stage of setup. Can be one of:
+            - ``"fit"`` : Prepares training and validation datasets.
+            - ``"test"`` : Prepares the test dataset.
+            - ``"predict"`` : Prepares the dataset for inference.
+            - ``None`` : Prepares the training, validation, and test datasets.
+        """
+        total_series = len(self.time_series_dataset)
+        self._split_indices = torch.randperm(total_series)
+
+        self._train_size = int(self.train_val_test_split[0] * total_series)
+        self._val_size = int(self.train_val_test_split[1] * total_series)
+
+        self._train_indices = self._split_indices[: self._train_size]
+        self._val_indices = self._split_indices[
+            self._train_size : self._train_size + self._val_size
+        ]
+        self._test_indices = self._split_indices[self._train_size + self._val_size :]
+
+        if stage is None or stage == "fit":
+            if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"):
+                self.train_processed = self._preprocess_data(self._train_indices)
+                self.val_processed = self._preprocess_data(self._val_indices)
+
+                self.train_windows = self._create_windows(self.train_processed)
+                self.val_windows = self._create_windows(self.val_processed)
+
+                self.train_dataset = self._ProcessedEncoderDecoderDataset(
+                    self.train_processed, self.train_windows, self.add_relative_time_idx
+                )
+                self.val_dataset = self._ProcessedEncoderDecoderDataset(
+                    self.val_processed, self.val_windows, self.add_relative_time_idx
+                )
+                # print(self.val_dataset[0])
+
+        if stage is None or stage == "test":
+            if not hasattr(self, "test_dataset"):
+                self.test_processed = self._preprocess_data(self._test_indices)
+                self.test_windows = self._create_windows(self.test_processed)
+
+                self.test_dataset = self._ProcessedEncoderDecoderDataset(
+                    self.test_processed, self.test_windows, self.add_relative_time_idx
+                )
+        elif stage == "predict":
+            predict_indices = torch.arange(len(self.time_series_dataset))
+            self.predict_processed = self._preprocess_data(predict_indices)
+            self.predict_windows = self._create_windows(self.predict_processed)
+            self.predict_dataset = self._ProcessedEncoderDecoderDataset(
+                self.predict_processed, self.predict_windows, self.add_relative_time_idx
+            )
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=True,
+            collate_fn=self.collate_fn,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            collate_fn=self.collate_fn,
+        )
+
+    def test_dataloader(self):
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            collate_fn=self.collate_fn,
+        )
+
+    def predict_dataloader(self):
+        return DataLoader(
+            self.predict_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            collate_fn=self.collate_fn,
+        )
+
+    @staticmethod
+    def collate_fn(batch):
+        x_batch = {
+            "encoder_cat": torch.stack([x["encoder_cat"] for x, _ in batch]),
+            "encoder_cont": torch.stack([x["encoder_cont"] for x, _ in batch]),
+            "decoder_cat": torch.stack([x["decoder_cat"] for x, _ in batch]),
+            "decoder_cont": torch.stack([x["decoder_cont"] for x, _ in batch]),
+            "encoder_lengths": torch.stack([x["encoder_lengths"] for x, _ in batch]),
+            "decoder_lengths": torch.stack([x["decoder_lengths"] for x, _ in batch]),
+            "decoder_target_lengths": torch.stack(
+                [x["decoder_target_lengths"] for x, _ in batch]
+            ),
+            "groups": torch.stack([x["groups"] for x, _ in batch]),
+            "encoder_time_idx": torch.stack([x["encoder_time_idx"] for x, _ in batch]),
+            "decoder_time_idx": torch.stack([x["decoder_time_idx"] for x, _ in batch]),
+            "target_scale": torch.stack([x["target_scale"] for x, _ in batch]),
+            "encoder_mask": torch.stack([x["encoder_mask"] for x, _ in batch]),
+            "decoder_mask": torch.stack([x["decoder_mask"] for x, _ in batch]),
+        }
+
+        if "static_categorical_features" in batch[0][0]:
+            x_batch["static_categorical_features"] = torch.stack(
+                [x["static_categorical_features"] for x, _ in batch]
+            )
+            x_batch["static_continuous_features"] = torch.stack(
+                [x["static_continuous_features"] for x, _ in batch]
+            )
+
+        y_batch = torch.stack([y for _, y in batch])
+        return x_batch, y_batch
diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py
index 336eecd5f..bc8300300 100644
--- a/pytorch_forecasting/data/timeseries.py
+++ b/pytorch_forecasting/data/timeseries.py
@@ -2657,6 +2657,8 @@ def _coerce_to_list(obj):
     """
     if obj is None:
         return []
+    if isinstance(obj, str):
+        return [obj]
     return list(obj)
 
 
@@ -2668,3 +2670,258 @@ def _coerce_to_dict(obj):
     """
     if obj is None:
         return {}
     return deepcopy(obj)
+
+
+#######################################################################################
+# Disclaimer: This dataset class is still work in progress and experimental, please
+# use with care. This class is a basic skeleton of how the data-handling pipeline
+# may look in the future.
+# This is the D1 layer, a "Raw Dataset Layer" mainly for raw data ingestion and for
+# turning the data into tensors.
+# For now, the pipeline handles the simplest situation: the whole dataset can be
+# loaded into memory.
+#######################################################################################
+
+
+class TimeSeries(Dataset):
+    """PyTorch Dataset for time series data stored in a pandas DataFrame.
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        data frame with sequence data.
+        Column names must all be str, and contain str as referred to below.
+    data_future : pd.DataFrame, optional, default=None
+        data frame with future data.
+        Column names must all be str, and contain str as referred to below.
+        May contain only columns that are in time, group, weight, known, or static.
+    time : str, optional, default = first col not in group, weight, target, static.
+        integer typed column denoting the time index within ``data``.
+        This column is used to determine the sequence of samples.
+        If there are no missing observations,
+        the time index should increase by ``+1`` for each subsequent sample.
+        The first time index of each series does not have to be ``0``;
+        any starting value is allowed.
+    target : str or List[str], optional, default = last column (at iloc -1)
+        column(s) in ``data`` denoting the forecasting target.
+        Can be categorical or numerical dtype.
+    group : List[str], optional, default = None
+        list of column names identifying a time series instance within ``data``.
+        This means that the ``group`` columns together uniquely identify an instance,
+        and ``group`` together with ``time`` uniquely identify a single observation
+        within a time series instance.
+        If ``None``, the dataset is assumed to be a single time series.
+    weight : str, optional, default=None
+        column name for weights.
+        If ``None``, it is assumed that there is no weight column.
+    num : list of str, optional, default = all columns with dtype in "fi"
+        list of numerical variables in ``data``,
+        list may also contain list of str, which are then grouped together.
+    cat : list of str, optional, default = all columns with dtype in "Obc"
+        list of categorical variables in ``data``,
+        list may also contain list of str, which are then grouped together
+        (e.g. useful for product categories).
+    known : list of str, optional, default = all variables
+        list of variables that change over time and are known in the future,
+        list may also contain list of str, which are then grouped together
+        (e.g. useful for special days or promotion categories).
+    unknown : list of str, optional, default = no variables
+        list of variables that are not known in the future,
+        list may also contain list of str, which are then grouped together
+        (e.g. useful for weather categories).
+    static : list of str, optional, default = all variables not in known, unknown
+        list of variables that do not change over time,
+        list may also contain list of str, which are then grouped together.
+    """
+
+    def __init__(
+        self,
+        data: pd.DataFrame,
+        data_future: Optional[pd.DataFrame] = None,
+        time: Optional[str] = None,
+        target: Optional[Union[str, List[str]]] = None,
+        group: Optional[List[str]] = None,
+        weight: Optional[str] = None,
+        num: Optional[List[Union[str, List[str]]]] = None,
+        cat: Optional[List[Union[str, List[str]]]] = None,
+        known: Optional[List[Union[str, List[str]]]] = None,
+        unknown: Optional[List[Union[str, List[str]]]] = None,
+        static: Optional[List[Union[str, List[str]]]] = None,
+    ):
+
+        self.data = data
+        self.data_future = data_future
+        self.time = time
+        self.target = _coerce_to_list(target)
+        self.group = _coerce_to_list(group)
+        self.weight = weight
+        self.num = _coerce_to_list(num)
+        self.cat = _coerce_to_list(cat)
+        self.known = _coerce_to_list(known)
+        self.unknown = _coerce_to_list(unknown)
+        self.static = _coerce_to_list(static)
+
+        self.feature_cols = [
+            col
+            for col in data.columns
+            if col not in [self.time] + self.group + [self.weight] + self.target
+        ]
+        if self.group:
+            self._groups = self.data.groupby(self.group).groups
+            self._group_ids = list(self._groups.keys())
+        else:
+            self._groups = {"_single_group": self.data.index}
+            self._group_ids = ["_single_group"]
+
+        self._prepare_metadata()
+
+    def _prepare_metadata(self):
+        """Prepare metadata for the dataset.
+
+        The function populates ``self.metadata``, which contains:
+
+        * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] }
+            Names of columns for y, x, and static features.
+ List elements are in same order as column dimensions. + Columns not appearing are assumed to be named (x0, x1, etc.), + (y0, y1, etc.), (st0, st1, etc.). + * ``col_type``: dict[str, str] + maps column names to data types "F" (numerical) and "C" (categorical). + Column names not occurring are assumed "F". + * ``col_known``: dict[str, str] + maps column names to "K" (future known) or "U" (future unknown). + Column names not occurring are assumed "K". + """ + self.metadata = { + "cols": { + "y": self.target, + "x": self.feature_cols, + "st": self.static, + }, + "col_type": {}, + "col_known": {}, + } + + all_cols = self.target + self.feature_cols + self.static + for col in all_cols: + self.metadata["col_type"][col] = "C" if col in self.cat else "F" + + self.metadata["col_known"][col] = "K" if col in self.known else "U" + + def __len__(self) -> int: + """Return number of time series in the dataset.""" + return len(self._group_ids) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get time series data for given index. + + It returns: + + * ``t``: ``numpy.ndarray`` of shape (n_timepoints,) + Time index for each time point in the past or present. Aligned with ``y``, + and ``x`` not ending in ``f``. + * ``y``: tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with ``t``. + * ``x``: tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with ``t``. + * ``group``: tensor of shape (n_groups) + Group identifiers for time series instances. + * ``st``: tensor of shape (n_static_features) + Static features. + * ``cutoff_time``: float or ``numpy.float64`` + Cutoff time for the time series instance. + + Optionally, the following str-keyed entry can be included: + + * ``weights``: tensor of shape (n_timepoints), only if weight is not None + """ + group_id = self._group_ids[index] + + if self.group: + mask = self._groups[group_id] + data = self.data.loc[mask] + else: + data = self.data + + cutoff_time = data[self.time].max() + + result = { + "t": data[self.time].values, + "y": torch.tensor(data[self.target].values), + "x": torch.tensor(data[self.feature_cols].values), + "group": torch.tensor([hash(str(group_id))]), + "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), + "cutoff_time": cutoff_time, + } + + if self.data_future is not None: + if self.group: + future_mask = self.data_future.groupby(self.group).groups[group_id] + future_data = self.data_future.loc[future_mask] + else: + future_data = self.data_future + + combined_times = np.concatenate( + [data[self.time].values, future_data[self.time].values] + ) + combined_times = np.unique(combined_times) + combined_times.sort() + + num_timepoints = len(combined_times) + x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) + y_merged = np.full((num_timepoints, len(self.target)), np.nan) + + current_time_indices = {t: i for i, t in enumerate(combined_times)} + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + x_merged[idx] = data[self.feature_cols].values[i] + y_merged[idx] = data[self.target].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices: + idx = current_time_indices[t] + for j, col in enumerate(self.known): + if col in self.feature_cols: + feature_idx = self.feature_cols.index(col) + x_merged[idx, feature_idx] = future_data[col].values[i] + + result.update( + { + "t": combined_times, + "x": 
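+                    # merged past and known-future feature values; timepoints that
+                    # appear only in ``data_future`` keep NaN in the unknown columns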
torch.tensor(x_merged, dtype=torch.float32),
+                    "y": torch.tensor(y_merged, dtype=torch.float32),
+                }
+            )
+
+        if self.weight:
+            if self.data_future is not None and self.weight in self.data_future.columns:
+                weights_merged = np.full(num_timepoints, np.nan)
+                for i, t in enumerate(data[self.time].values):
+                    idx = current_time_indices[t]
+                    weights_merged[idx] = data[self.weight].values[i]
+
+                for i, t in enumerate(future_data[self.time].values):
+                    if t in current_time_indices and self.weight in future_data.columns:
+                        idx = current_time_indices[t]
+                        weights_merged[idx] = future_data[self.weight].values[i]
+
+                result["weights"] = torch.tensor(weights_merged, dtype=torch.float32)
+            else:
+                result["weights"] = torch.tensor(
+                    data[self.weight].values, dtype=torch.float32
+                )
+
+        return result
+
+    def get_metadata(self) -> Dict:
+        """Return metadata about the dataset.
+
+        Returns
+        -------
+        Dict
+            Dictionary containing:
+            - cols: column names for y, x, and static features
+            - col_type: mapping of columns to their types (F/C)
+            - col_known: mapping of columns to their future known status (K/U)
+        """
+        return self.metadata

From d0d1c3ec7fb3bdee8e80d9ff83cd43e8990a5319 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Sun, 6 Apr 2025 18:47:46 +0530
Subject: [PATCH 02/30] remove one comment

---
 pytorch_forecasting/data/data_module.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index 56917696d..2958f1705 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -550,7 +550,6 @@ def setup(self, stage: Optional[str] = None):
                 self.val_dataset = self._ProcessedEncoderDecoderDataset(
                     self.val_processed, self.val_windows, self.add_relative_time_idx
                 )
-                # print(self.val_dataset[0])
 
         if stage is None or stage == "test":
             if not hasattr(self, "test_dataset"):

From 80e64d218a744557bd493ea07547f0f42b029573 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Sun, 6 Apr 2025 19:07:01 +0530
Subject: [PATCH 03/30] model layer commit

---
 .../models/base/base_model_refactor.py        | 283 ++++++++++++++++++
 .../tft_version_two.py                        | 218 ++++++++++++++
 2 files changed, 501 insertions(+)
 create mode 100644 pytorch_forecasting/models/base/base_model_refactor.py
 create mode 100644 pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py

diff --git a/pytorch_forecasting/models/base/base_model_refactor.py b/pytorch_forecasting/models/base/base_model_refactor.py
new file mode 100644
index 000000000..ccd2c2600
--- /dev/null
+++ b/pytorch_forecasting/models/base/base_model_refactor.py
@@ -0,0 +1,283 @@
+########################################################################################
+# Disclaimer: This base class is still work in progress and experimental, please
+# use with care. It is a basic skeleton of how the base classes may look
+# in version 2.
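+# Subclasses are expected to implement ``forward`` and return a dict containing a
+# "prediction" key; the training/validation/test steps below rely on that contract.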
+######################################################################################## + + +from typing import Dict, List, Optional, Tuple, Union + +from lightning.pytorch import LightningModule +from lightning.pytorch.utilities.types import STEP_OUTPUT +import torch +import torch.nn as nn +from torch.optim import Optimizer + + +class BaseModel(LightningModule): + def __init__( + self, + loss: nn.Module, + logging_metrics: Optional[List[nn.Module]] = None, + optimizer: Optional[Union[Optimizer, str]] = "adam", + optimizer_params: Optional[Dict] = None, + lr_scheduler: Optional[str] = None, + lr_scheduler_params: Optional[Dict] = None, + ): + """ + Base model for time series forecasting. + + Parameters + ---------- + loss : nn.Module + Loss function to use for training. + logging_metrics : Optional[List[nn.Module]], optional + List of metrics to log during training, validation, and testing. + optimizer : Optional[Union[Optimizer, str]], optional + Optimizer to use for training. + Can be a string ("adam", "sgd") or an instance of `torch.optim.Optimizer`. + optimizer_params : Optional[Dict], optional + Parameters for the optimizer. + lr_scheduler : Optional[str], optional + Learning rate scheduler to use. + Supported values: "reduce_lr_on_plateau", "step_lr". + lr_scheduler_params : Optional[Dict], optional + Parameters for the learning rate scheduler. + """ + super().__init__() + self.loss = loss + self.logging_metrics = logging_metrics if logging_metrics is not None else [] + self.optimizer = optimizer + self.optimizer_params = optimizer_params if optimizer_params is not None else {} + self.lr_scheduler = lr_scheduler + self.lr_scheduler_params = ( + lr_scheduler_params if lr_scheduler_params is not None else {} + ) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass of the model. + + Parameters + ---------- + x : Dict[str, torch.Tensor] + Dictionary containing input tensors + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary containing output tensors + """ + raise NotImplementedError("Forward method must be implemented by subclass.") + + def training_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Training step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. + """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="train") + return {"loss": loss} + + def validation_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Validation step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. 
+ """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="val") + return {"val_loss": loss} + + def test_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Test step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. + """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="test") + return {"test_loss": loss} + + def predict_step( + self, + batch: Tuple[Dict[str, torch.Tensor]], + batch_idx: int, + dataloader_idx: int = 0, + ) -> torch.Tensor: + """ + Prediction step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input tensors. + batch_idx : int + Index of the batch. + dataloader_idx : int + Index of the dataloader. + + Returns + ------- + torch.Tensor + Predicted output tensor. + """ + x, _ = batch + y_hat = self(x) + return y_hat + + def configure_optimizers(self) -> Dict: + """ + Configure the optimizer and learning rate scheduler. + + Returns + ------- + Dict + Dictionary containing the optimizer and scheduler configuration. + """ + optimizer = self._get_optimizer() + if self.lr_scheduler is not None: + scheduler = self._get_scheduler(optimizer) + if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + }, + } + else: + return {"optimizer": optimizer, "lr_scheduler": scheduler} + return {"optimizer": optimizer} + + def _get_optimizer(self) -> Optimizer: + """ + Get the optimizer based on the specified optimizer name and parameters. + + Returns + ------- + Optimizer + The optimizer instance. + """ + if isinstance(self.optimizer, str): + if self.optimizer.lower() == "adam": + return torch.optim.Adam(self.parameters(), **self.optimizer_params) + elif self.optimizer.lower() == "sgd": + return torch.optim.SGD(self.parameters(), **self.optimizer_params) + else: + raise ValueError(f"Optimizer {self.optimizer} not supported.") + elif isinstance(self.optimizer, Optimizer): + return self.optimizer + else: + raise ValueError( + "Optimizer must be either a string or " + "an instance of torch.optim.Optimizer." + ) + + def _get_scheduler( + self, optimizer: Optimizer + ) -> torch.optim.lr_scheduler._LRScheduler: + """ + Get the lr scheduler based on the specified scheduler name and params. + + Parameters + ---------- + optimizer : Optimizer + The optimizer instance. + + Returns + ------- + torch.optim.lr_scheduler._LRScheduler + The learning rate scheduler instance. 
+ """ + if self.lr_scheduler.lower() == "reduce_lr_on_plateau": + return torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, **self.lr_scheduler_params + ) + elif self.lr_scheduler.lower() == "step_lr": + return torch.optim.lr_scheduler.StepLR( + optimizer, **self.lr_scheduler_params + ) + else: + raise ValueError(f"Scheduler {self.lr_scheduler} not supported.") + + def log_metrics( + self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = "val" + ) -> None: + """ + Log additional metrics during training, validation, or testing. + + Parameters + ---------- + y_hat : torch.Tensor + Predicted output tensor. + y : torch.Tensor + Target output tensor. + prefix : str + Prefix for the logged metrics (e.g., "train", "val", "test"). + """ + for metric in self.logging_metrics: + metric_value = metric(y_hat, y) + self.log( + f"{prefix}_{metric.__class__.__name__}", + metric_value, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py new file mode 100644 index 000000000..30f70f98e --- /dev/null +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py @@ -0,0 +1,218 @@ +######################################################################################## +# Disclaimer: This implementation is based on the new version of data pipeline and is +# experimental, please use with care. +######################################################################################## + +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from torch.optim import Optimizer + +from pytorch_forecasting.models.base.base_model_refactor import BaseModel + + +class TFT(BaseModel): + def __init__( + self, + loss: nn.Module, + logging_metrics: Optional[List[nn.Module]] = None, + optimizer: Optional[Union[Optimizer, str]] = "adam", + optimizer_params: Optional[Dict] = None, + lr_scheduler: Optional[str] = None, + lr_scheduler_params: Optional[Dict] = None, + hidden_size: int = 64, + num_layers: int = 2, + attention_head_size: int = 4, + dropout: float = 0.1, + metadata: Optional[Dict] = None, + output_size: int = 1, + ): + super().__init__( + loss=loss, + logging_metrics=logging_metrics, + optimizer=optimizer, + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + lr_scheduler_params=lr_scheduler_params, + ) + self.hidden_size = hidden_size + self.num_layers = num_layers + self.attention_head_size = attention_head_size + self.dropout = dropout + self.metadata = metadata + self.output_size = output_size + + self.max_encoder_length = self.metadata["max_encoder_length"] + self.max_prediction_length = self.metadata["max_prediction_length"] + self.encoder_cont = self.metadata["encoder_cont"] + self.encoder_cat = self.metadata["encoder_cat"] + self.static_categorical_features = self.metadata["static_categorical_features"] + self.static_continuous_features = self.metadata["static_continuous_features"] + + total_feature_size = self.encoder_cont + self.encoder_cat + total_static_size = ( + self.static_categorical_features + self.static_continuous_features + ) + + self.encoder_var_selection = nn.Sequential( + nn.Linear(total_feature_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, total_feature_size), + nn.Sigmoid(), + ) + + self.decoder_var_selection = nn.Sequential( + nn.Linear(total_feature_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, total_feature_size), + 
nn.Sigmoid(), + ) + + self.static_context_linear = ( + nn.Linear(total_static_size, hidden_size) if total_static_size > 0 else None + ) + + self.lstm_encoder = nn.LSTM( + input_size=total_feature_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + batch_first=True, + ) + + self.lstm_decoder = nn.LSTM( + input_size=total_feature_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + batch_first=True, + ) + + self.self_attention = nn.MultiheadAttention( + embed_dim=hidden_size, + num_heads=attention_head_size, + dropout=dropout, + batch_first=True, + ) + + self.pre_output = nn.Linear(hidden_size, hidden_size) + self.output_layer = nn.Linear(hidden_size, output_size) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass of the TFT model. + + Parameters + ---------- + x : Dict[str, torch.Tensor] + Dictionary containing input tensors: + - encoder_cat: Categorical encoder features + - encoder_cont: Continuous encoder features + - decoder_cat: Categorical decoder features + - decoder_cont: Continuous decoder features + - static_categorical_features: Static categorical features + - static_continuous_features: Static continuous features + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary containing output tensors: + - prediction: Prediction output (batch_size, prediction_length, output_size) + """ + batch_size = x["encoder_cont"].shape[0] + + encoder_cat = x.get( + "encoder_cat", + torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device), + ) + encoder_cont = x.get( + "encoder_cont", + torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device), + ) + decoder_cat = x.get( + "decoder_cat", + torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device), + ) + decoder_cont = x.get( + "decoder_cont", + torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device), + ) + + encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2) + decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2) + + static_context = None + if self.static_context_linear is not None: + static_cat = x.get( + "static_categorical_features", + torch.zeros(batch_size, 0, device=self.device), + ) + static_cont = x.get( + "static_continuous_features", + torch.zeros(batch_size, 0, device=self.device), + ) + + if static_cat.size(2) == 0 and static_cont.size(2) == 0: + static_context = None + elif static_cat.size(2) == 0: + static_input = static_cont.to( + dtype=self.static_context_linear.weight.dtype + ) + static_context = self.static_context_linear(static_input) + static_context = static_context.view(batch_size, self.hidden_size) + elif static_cont.size(2) == 0: + static_input = static_cat.to( + dtype=self.static_context_linear.weight.dtype + ) + static_context = self.static_context_linear(static_input) + static_context = static_context.view(batch_size, self.hidden_size) + else: + + static_input = torch.cat([static_cont, static_cat], dim=1).to( + dtype=self.static_context_linear.weight.dtype + ) + static_context = self.static_context_linear(static_input) + static_context = static_context.view(batch_size, self.hidden_size) + + encoder_weights = self.encoder_var_selection(encoder_input) + encoder_input = encoder_input * encoder_weights + + decoder_weights = self.decoder_var_selection(decoder_input) + decoder_input = decoder_input * decoder_weights + + if static_context is not None: + encoder_static_context = static_context.unsqueeze(1).expand( + -1, self.max_encoder_length, 
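+                # -1 leaves the batch and hidden dimensions unchanged; the static
+                # context is repeated across every encoder timestep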
-1 + ) + decoder_static_context = static_context.unsqueeze(1).expand( + -1, self.max_prediction_length, -1 + ) + + encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input) + encoder_output = encoder_output + encoder_static_context + decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n)) + decoder_output = decoder_output + decoder_static_context + else: + encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input) + decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n)) + + sequence = torch.cat([encoder_output, decoder_output], dim=1) + + if static_context is not None: + expanded_static_context = static_context.unsqueeze(1).expand( + -1, sequence.size(1), -1 + ) + + attended_output, _ = self.self_attention( + sequence + expanded_static_context, sequence, sequence + ) + else: + attended_output, _ = self.self_attention(sequence, sequence, sequence) + + decoder_attended = attended_output[:, -self.max_prediction_length :, :] + + output = nn.functional.relu(self.pre_output(decoder_attended)) + prediction = self.output_layer(output) + + return {"prediction": prediction} From 0319c29bf08a9d100b2b4d711ded0082f172e7f9 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 6 Apr 2025 19:19:15 +0530 Subject: [PATCH 04/30] Example notebook --- examples/ptf_V2_example.ipynb | 5137 +++++++++++++++++++++++++++++++++ 1 file changed, 5137 insertions(+) create mode 100644 examples/ptf_V2_example.ipynb diff --git a/examples/ptf_V2_example.ipynb b/examples/ptf_V2_example.ipynb new file mode 100644 index 000000000..031d9d634 --- /dev/null +++ b/examples/ptf_V2_example.ipynb @@ -0,0 +1,5137 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2630DaOEI4AJ", + "outputId": "96798236-d2f1-4436-c047-49c3771d56c7" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting pytorch-forecasting\n", + " Downloading pytorch_forecasting-1.3.0-py3-none-any.whl.metadata (13 kB)\n", + "Requirement already satisfied: numpy<=3.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.0.2)\n", + "Requirement already satisfied: torch!=2.0.1,<3.0.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.6.0+cu124)\n", + "Collecting lightning<3.0.0,>=2.0.0 (from pytorch-forecasting)\n", + " Downloading lightning-2.5.1-py3-none-any.whl.metadata (39 kB)\n", + "Requirement already satisfied: scipy<2.0,>=1.8 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.14.1)\n", + "Requirement already satisfied: pandas<3.0.0,>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.2.2)\n", + "Requirement already satisfied: scikit-learn<2.0,>=1.2 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.6.1)\n", + "Requirement already satisfied: PyYAML<8.0,>=5.4 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.0.2)\n", + "Requirement already satisfied: fsspec<2026.0,>=2022.5.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2025.3.2)\n", + "Collecting lightning-utilities<2.0,>=0.10.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)\n", + "Requirement already satisfied: packaging<25.0,>=20.0 in /usr/local/lib/python3.11/dist-packages (from 
lightning<3.0.0,>=2.0.0->pytorch-forecasting) (24.2)\n",
+        "  [... remaining dependency checks, wheel downloads, and CUDA package reinstalls elided ...]\n",
+        "Successfully installed lightning-2.5.1 
lightning-utilities-0.14.3 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch-forecasting-1.3.0 pytorch-lightning-2.5.1 torchmetrics-1.7.0\n" + ] + } + ], + "source": [ + "!pip install pytorch-forecasting" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "M7PQerTbI_tM" + }, + "outputs": [], + "source": [ + "from typing import Any, Dict, List, Optional, Tuple, Union\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "from torch.utils.data import Dataset\n", + "\n", + "from pytorch_forecasting.data.timeseries import _coerce_to_dict" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "XmL5ukG9JDTD" + }, + "outputs": [], + "source": [ + "def _coerce_to_list(obj):\n", + " \"\"\"Coerce object to list.\n", + "\n", + " None is coerced to empty list, otherwise list constructor is used.\n", + " \"\"\"\n", + " if obj is None:\n", + " return []\n", + " if isinstance(obj, str):\n", + " return [obj]\n", + " return list(obj)\n", + "\n", + "\n", + "class TimeSeries(Dataset):\n", + " \"\"\"PyTorch Dataset for time series data stored in pandas DataFrame.\n", + "\n", + " Parameters\n", + " ----------\n", + " data : pd.DataFrame\n", + " data frame with sequence data.\n", + " Column names must all be str, and contain str as referred to below.\n", + " data_future : pd.DataFrame, optional, default=None\n", + " data frame with future data.\n", + " Column names must all be str, and contain str as referred to below.\n", + " May contain only columns that are in time, group, weight, known, or static.\n", + " time : str, optional, default = first col not in group_ids, weight, target, static.\n", + " integer typed column denoting the time index within ``data``.\n", + " This column is used to determine the sequence of samples.\n", + " If there are no missing observations,\n", + " the time index should increase by ``+1`` for each subsequent sample.\n", + " The first time_idx for each series does not necessarily\n", + " have to be ``0`` but any value is allowed.\n", + " target : str or List[str], optional, default = last column (at iloc -1)\n", + " column(s) in ``data`` denoting the forecasting target.\n", + " Can be categorical or numerical dtype.\n", + " group : List[str], optional, default = None\n", + " list of column names identifying a time series instance within ``data``.\n", + " This means that the ``group`` together uniquely identify an instance,\n", + " and ``group`` together with ``time`` uniquely identify a single observation\n", + " within a time series instance.\n", + " If ``None``, the dataset is assumed to be a single time series.\n", + " weight : str, optional, default=None\n", + " column name for weights.\n", + " If ``None``, it is assumed that there is no weight column.\n", + " num : list of str, optional, default = all columns with dtype in \"fi\"\n", + " list of numerical variables in ``data``,\n", + " list may also contain list of str, which are then grouped together.\n", + " cat : list of str, optional, default = all columns with dtype in \"Obc\"\n", + " list of categorical variables in ``data``,\n", + " list may also contain list of str, which are then grouped together\n", + " (e.g. 
useful for product categories).\n",
+    "    known : list of str, optional, default = all variables\n",
+    "        list of variables that change over time and are known in the future,\n",
+    "        list may also contain list of str, which are then grouped together\n",
+    "        (e.g. useful for special days or promotion categories).\n",
+    "    unknown : list of str, optional, default = no variables\n",
+    "        list of variables that are not known in the future,\n",
+    "        list may also contain list of str, which are then grouped together\n",
+    "        (e.g. useful for weather categories).\n",
+    "    static : list of str, optional, default = all variables not in known, unknown\n",
+    "        list of variables that do not change over time,\n",
+    "        list may also contain list of str, which are then grouped together.\n",
+    "    \"\"\"\n",
+    "\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        data: pd.DataFrame,\n",
+    "        data_future: Optional[pd.DataFrame] = None,\n",
+    "        time: Optional[str] = None,\n",
+    "        target: Optional[Union[str, List[str]]] = None,\n",
+    "        group: Optional[List[str]] = None,\n",
+    "        weight: Optional[str] = None,\n",
+    "        num: Optional[List[Union[str, List[str]]]] = None,\n",
+    "        cat: Optional[List[Union[str, List[str]]]] = None,\n",
+    "        known: Optional[List[Union[str, List[str]]]] = None,\n",
+    "        unknown: Optional[List[Union[str, List[str]]]] = None,\n",
+    "        static: Optional[List[Union[str, List[str]]]] = None,\n",
+    "    ):\n",
+    "\n",
+    "        self.data = data\n",
+    "        self.data_future = data_future\n",
+    "        self.time = time\n",
+    "        self.target = _coerce_to_list(target)\n",
+    "        self.group = _coerce_to_list(group)\n",
+    "        self.weight = weight\n",
+    "        self.num = _coerce_to_list(num)\n",
+    "        self.cat = _coerce_to_list(cat)\n",
+    "        self.known = _coerce_to_list(known)\n",
+    "        self.unknown = _coerce_to_list(unknown)\n",
+    "        self.static = _coerce_to_list(static)\n",
+    "\n",
+    "        self.feature_cols = [\n",
+    "            col\n",
+    "            for col in data.columns\n",
+    "            if col not in [self.time] + self.group + [self.weight] + self.target\n",
+    "        ]\n",
+    "        if self.group:\n",
+    "            self._groups = self.data.groupby(self.group).groups\n",
+    "            self._group_ids = list(self._groups.keys())\n",
+    "        else:\n",
+    "            self._groups = {\"_single_group\": self.data.index}\n",
+    "            self._group_ids = [\"_single_group\"]\n",
+    "\n",
+    "        self._prepare_metadata()\n",
+    "\n",
+    "    def _prepare_metadata(self):\n",
+    "        \"\"\"Prepare metadata for the dataset.\n",
+    "\n",
+    "        The function populates ``self.metadata``, which contains:\n",
+    "\n",
+    "        * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] }\n",
+    "          Names of columns for y, x, and static features.\n",
+    "          List elements are in same order as column dimensions.\n",
+    "          Columns not appearing are assumed to be named (x0, x1, etc.),\n",
+    "          (y0, y1, etc.), (st0, st1, etc.).\n",
+    "        * ``col_type``: dict[str, str]\n",
+    "          maps column names to data types \"F\" (numerical) and \"C\" (categorical).\n",
+    "          Column names not occurring are assumed \"F\".\n",
+    "        * ``col_known``: dict[str, str]\n",
+    "          maps column names to \"K\" (future known) or \"U\" (future unknown).\n",
+    "          Column names not occurring are assumed \"K\".\n",
+    "        \"\"\"\n",
+    "        self.metadata = {\n",
+    "            \"cols\": {\n",
+    "                \"y\": self.target,\n",
+    "                \"x\": self.feature_cols,\n",
+    "                \"st\": self.static,\n",
+    "            },\n",
+    "            \"col_type\": {},\n",
+    "            \"col_known\": {},\n",
+    "        }\n",
+    "\n",
+    "        all_cols = self.target + self.feature_cols + self.static\n",
+    "        for col in all_cols:\n",
+    "            self.metadata[\"col_type\"][col] = \"C\" if col in self.cat else \"F\"\n",
+    "\n",
+    "            
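# columns not declared as known default to future-unknown (\"U\")\n",
+    "            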
self.metadata[\"col_known\"][col] = \"K\" if col in self.known else \"U\"\n", + "\n", + " def __len__(self) -> int:\n", + " \"\"\"Return number of time series in the dataset.\"\"\"\n", + " return len(self._group_ids)\n", + "\n", + " def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:\n", + " \"\"\"Get time series data for given index.\n", + "\n", + " It returns:\n", + "\n", + " * ``t``: ``numpy.ndarray`` of shape (n_timepoints,)\n", + " Time index for each time point in the past or present. Aligned with ``y``,\n", + " and ``x`` not ending in ``f``.\n", + " * ``y``: tensor of shape (n_timepoints, n_targets)\n", + " Target values for each time point. Rows are time points, aligned with ``t``.\n", + " * ``x``: tensor of shape (n_timepoints, n_features)\n", + " Features for each time point. Rows are time points, aligned with ``t``.\n", + " * ``group``: tensor of shape (n_groups)\n", + " Group identifiers for time series instances.\n", + " * ``st``: tensor of shape (n_static_features)\n", + " Static features.\n", + " * ``cutoff_time``: float or ``numpy.float64``\n", + " Cutoff time for the time series instance.\n", + "\n", + " Optionally, the following str-keyed entry can be included:\n", + "\n", + " * ``weights``: tensor of shape (n_timepoints), only if weight is not None\n", + " \"\"\"\n", + " group_id = self._group_ids[index]\n", + "\n", + " if self.group:\n", + " mask = self._groups[group_id]\n", + " data = self.data.loc[mask]\n", + " else:\n", + " data = self.data\n", + "\n", + " cutoff_time = data[self.time].max()\n", + "\n", + " result = {\n", + " \"t\": data[self.time].values,\n", + " \"y\": torch.tensor(data[self.target].values),\n", + " \"x\": torch.tensor(data[self.feature_cols].values),\n", + " \"group\": torch.tensor([hash(str(group_id))]),\n", + " \"st\": torch.tensor(data[self.static].iloc[0].values if self.static else []),\n", + " \"cutoff_time\": cutoff_time,\n", + " }\n", + "\n", + " if self.data_future is not None:\n", + " if self.group:\n", + " future_mask = self.data_future.groupby(self.group).groups[group_id]\n", + " future_data = self.data_future.loc[future_mask]\n", + " else:\n", + " future_data = self.data_future\n", + "\n", + " combined_times = np.concatenate(\n", + " [data[self.time].values, future_data[self.time].values]\n", + " )\n", + " combined_times = np.unique(combined_times)\n", + " combined_times.sort()\n", + "\n", + " num_timepoints = len(combined_times)\n", + " x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan)\n", + " y_merged = np.full((num_timepoints, len(self.target)), np.nan)\n", + "\n", + " current_time_indices = {t: i for i, t in enumerate(combined_times)}\n", + " for i, t in enumerate(data[self.time].values):\n", + " idx = current_time_indices[t]\n", + " x_merged[idx] = data[self.feature_cols].values[i]\n", + " y_merged[idx] = data[self.target].values[i]\n", + "\n", + " for i, t in enumerate(future_data[self.time].values):\n", + " if t in current_time_indices:\n", + " idx = current_time_indices[t]\n", + " for j, col in enumerate(self.known):\n", + " if col in self.feature_cols:\n", + " feature_idx = self.feature_cols.index(col)\n", + " x_merged[idx, feature_idx] = future_data[col].values[i]\n", + "\n", + " result.update(\n", + " {\n", + " \"t\": combined_times,\n", + " \"x\": torch.tensor(x_merged, dtype=torch.float32),\n", + " \"y\": torch.tensor(y_merged, dtype=torch.float32),\n", + " }\n", + " )\n", + "\n", + " if self.weight:\n", + " if self.data_future is not None and self.weight in self.data_future.columns:\n", + " 
weights_merged = np.full(num_timepoints, np.nan)\n", + " for i, t in enumerate(data[self.time].values):\n", + " idx = current_time_indices[t]\n", + " weights_merged[idx] = data[self.weight].values[i]\n", + "\n", + " for i, t in enumerate(future_data[self.time].values):\n", + " if t in current_time_indices and self.weight in future_data.columns:\n", + " idx = current_time_indices[t]\n", + " weights_merged[idx] = future_data[self.weight].values[i]\n", + "\n", + " result[\"weights\"] = torch.tensor(weights_merged, dtype=torch.float32)\n", + " else:\n", + " result[\"weights\"] = torch.tensor(\n", + " data[self.weight].values, dtype=torch.float32\n", + " )\n", + "\n", + " return result\n", + "\n", + " def get_metadata(self) -> Dict:\n", + " \"\"\"Return metadata about the dataset.\n", + "\n", + " Returns\n", + " -------\n", + " Dict\n", + " Dictionary containing:\n", + " - cols: column names for y, x, and static features\n", + " - col_type: mapping of columns to their types (F/C)\n", + " - col_known: mapping of columns to their future known status (K/U)\n", + " \"\"\"\n", + " return self.metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "id": "0Rw9LgsXJI5V" + }, + "outputs": [], + "source": [ + "from typing import Dict, List, Optional, Union\n", + "\n", + "from lightning.pytorch import LightningDataModule\n", + "from sklearn.preprocessing import RobustScaler, StandardScaler\n", + "import torch\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "from pytorch_forecasting.data.encoders import (\n", + " EncoderNormalizer,\n", + " NaNLabelEncoder,\n", + " TorchNormalizer,\n", + ")\n", + "\n", + "NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer]\n", + "\n", + "\n", + "class EncoderDecoderTimeSeriesDataModule(LightningDataModule):\n", + " \"\"\"\n", + " Lightning DataModule for processing time series data in an encoder-decoder format.\n", + "\n", + " This module handles preprocessing, splitting, and batching of time series data\n", + " for use in deep learning models. 
It supports categorical and continuous features,\n", + " various scalers, and automatic target normalization.\n", + "\n", + " Parameters\n", + " ----------\n", + " time_series_dataset : TimeSeries\n", + " The dataset containing time series data.\n", + " max_encoder_length : int, default=30\n", + " Maximum length of the encoder input sequence.\n", + " min_encoder_length : Optional[int], default=None\n", + " Minimum length of the encoder input sequence.\n", + " Defaults to `max_encoder_length` if not specified.\n", + " max_prediction_length : int, default=1\n", + " Maximum length of the decoder output sequence.\n", + " min_prediction_length : Optional[int], default=None\n", + " Minimum length of the decoder output sequence.\n", + " Defaults to `max_prediction_length` if not specified.\n", + " min_prediction_idx : Optional[int], default=None\n", + " Minimum index from which predictions start.\n", + " allow_missing_timesteps : bool, default=False\n", + " Whether to allow missing timesteps in the dataset.\n", + " add_relative_time_idx : bool, default=False\n", + " Whether to add a relative time index feature.\n", + " add_target_scales : bool, default=False\n", + " Whether to add target scaling information.\n", + " add_encoder_length : Union[bool, str], default=\"auto\"\n", + " Whether to include encoder length information.\n", + " target_normalizer :\n", + " Union[NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None],\n", + " default=\"auto\"\n", + " Normalizer for the target variable. If \"auto\", uses `RobustScaler`.\n", + "\n", + " categorical_encoders : Optional[Dict[str, NaNLabelEncoder]], default=None\n", + " Dictionary of categorical encoders.\n", + "\n", + " scalers :\n", + " Optional[Dict[str, Union[StandardScaler, RobustScaler,\n", + " TorchNormalizer, EncoderNormalizer]]], default=None\n", + " Dictionary of feature scalers.\n", + "\n", + " randomize_length : Union[None, Tuple[float, float], bool], default=False\n", + " Whether to randomize input sequence length.\n", + " batch_size : int, default=32\n", + " Batch size for DataLoader.\n", + " num_workers : int, default=0\n", + " Number of workers for DataLoader.\n", + " train_val_test_split : tuple, default=(0.7, 0.15, 0.15)\n", + " Proportions for train, validation, and test dataset splits.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " time_series_dataset: TimeSeries,\n", + " max_encoder_length: int = 30,\n", + " min_encoder_length: Optional[int] = None,\n", + " max_prediction_length: int = 1,\n", + " min_prediction_length: Optional[int] = None,\n", + " min_prediction_idx: Optional[int] = None,\n", + " allow_missing_timesteps: bool = False,\n", + " add_relative_time_idx: bool = False,\n", + " add_target_scales: bool = False,\n", + " add_encoder_length: Union[bool, str] = \"auto\",\n", + " target_normalizer: Union[\n", + " NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None\n", + " ] = \"auto\",\n", + " categorical_encoders: Optional[Dict[str, NaNLabelEncoder]] = None,\n", + " scalers: Optional[\n", + " Dict[\n", + " str,\n", + " Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer],\n", + " ]\n", + " ] = None,\n", + " randomize_length: Union[None, Tuple[float, float], bool] = False,\n", + " batch_size: int = 32,\n", + " num_workers: int = 0,\n", + " train_val_test_split: tuple = (0.7, 0.15, 0.15),\n", + " ):\n", + " super().__init__()\n", + " self.time_series_dataset = time_series_dataset\n", + " self.time_series_metadata = time_series_dataset.get_metadata()\n", + "\n", + " 
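# window-length bookkeeping: the minima fall back to the maxima,\n",
+    "        # so windows are fixed-size unless explicitly overridden\n",
+    "        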
self.max_encoder_length = max_encoder_length\n",
+    "        self.min_encoder_length = min_encoder_length or max_encoder_length\n",
+    "        self.max_prediction_length = max_prediction_length\n",
+    "        self.min_prediction_length = min_prediction_length or max_prediction_length\n",
+    "        self.min_prediction_idx = min_prediction_idx\n",
+    "\n",
+    "        self.allow_missing_timesteps = allow_missing_timesteps\n",
+    "        self.add_relative_time_idx = add_relative_time_idx\n",
+    "        self.add_target_scales = add_target_scales\n",
+    "        self.add_encoder_length = add_encoder_length\n",
+    "        self.randomize_length = randomize_length\n",
+    "\n",
+    "        self.batch_size = batch_size\n",
+    "        self.num_workers = num_workers\n",
+    "        self.train_val_test_split = train_val_test_split\n",
+    "\n",
+    "        if isinstance(target_normalizer, str) and target_normalizer.lower() == \"auto\":\n",
+    "            self.target_normalizer = RobustScaler()\n",
+    "        else:\n",
+    "            self.target_normalizer = target_normalizer\n",
+    "\n",
+    "        self.categorical_encoders = _coerce_to_dict(categorical_encoders)\n",
+    "        self.scalers = _coerce_to_dict(scalers)\n",
+    "\n",
+    "        self.categorical_indices = []\n",
+    "        self.continuous_indices = []\n",
+    "        self._metadata = None\n",
+    "\n",
+    "        for idx, col in enumerate(self.time_series_metadata[\"cols\"][\"x\"]):\n",
+    "            if self.time_series_metadata[\"col_type\"].get(col) == \"C\":\n",
+    "                self.categorical_indices.append(idx)\n",
+    "            else:\n",
+    "                self.continuous_indices.append(idx)\n",
+    "\n",
+    "    def _prepare_metadata(self):\n",
+    "        \"\"\"Prepare metadata for model initialisation.\n",
+    "\n",
+    "        Returns\n",
+    "        -------\n",
+    "        dict\n",
+    "            dictionary containing the following keys:\n",
+    "\n",
+    "            * ``encoder_cat``: Number of categorical variables in the encoder.\n",
+    "              Computed as ``len(self.categorical_indices)``, which counts the\n",
+    "              categorical feature indices.\n",
+    "            * ``encoder_cont``: Number of continuous variables in the encoder.\n",
+    "              Computed as ``len(self.continuous_indices)``, which counts the\n",
+    "              continuous feature indices.\n",
+    "            * ``decoder_cat``: Number of categorical variables in the decoder that\n",
+    "              are known in advance.\n",
+    "              Computed by filtering ``self.time_series_metadata[\"cols\"][\"x\"]``\n",
+    "              where col_type == \"C\" (categorical) and col_known == \"K\" (known).\n",
+    "            * ``decoder_cont``: Number of continuous variables in the decoder that\n",
+    "              are known in advance.\n",
+    "              Computed by filtering ``self.time_series_metadata[\"cols\"][\"x\"]``\n",
+    "              where col_type == \"F\" (continuous) and col_known == \"K\" (known).\n",
+    "            * ``target``: Number of target variables.\n",
+    "              Computed as ``len(self.time_series_metadata[\"cols\"][\"y\"])``, which\n",
+    "              gives the number of output target columns.\n",
+    "            * ``static_categorical_features``: Number of static categorical features.\n",
+    "              Computed by filtering ``self.time_series_metadata[\"cols\"][\"st\"]``\n",
+    "              (static features) where col_type == \"C\" (categorical).\n",
+    "            * ``static_continuous_features``: Number of static continuous features.\n",
+    "              Computed as the difference between\n",
+    "              ``len(self.time_series_metadata[\"cols\"][\"st\"])`` (static features)\n",
+    "              and ``static_categorical_features``.\n",
+    "            * ``max_encoder_length``: maximum encoder length.\n",
+    "              Taken directly from `self.max_encoder_length`.\n",
+    "            * ``max_prediction_length``: maximum prediction length.\n",
+    "              Taken directly from `self.max_prediction_length`.\n",
+    "            * ``min_encoder_length``: minimum encoder length.\n",
+    "              Taken directly from 
`self.min_encoder_length`.\n", + " * ``min_prediction_length``: minimum prediction length\n", + " Taken directly from `self.min_prediction_length`.\n", + "\n", + " \"\"\"\n", + " encoder_cat_count = len(self.categorical_indices)\n", + " encoder_cont_count = len(self.continuous_indices)\n", + "\n", + " decoder_cat_count = len(\n", + " [\n", + " col\n", + " for col in self.time_series_metadata[\"cols\"][\"x\"]\n", + " if self.time_series_metadata[\"col_type\"].get(col) == \"C\"\n", + " and self.time_series_metadata[\"col_known\"].get(col) == \"K\"\n", + " ]\n", + " )\n", + " decoder_cont_count = len(\n", + " [\n", + " col\n", + " for col in self.time_series_metadata[\"cols\"][\"x\"]\n", + " if self.time_series_metadata[\"col_type\"].get(col) == \"F\"\n", + " and self.time_series_metadata[\"col_known\"].get(col) == \"K\"\n", + " ]\n", + " )\n", + "\n", + " target_count = len(self.time_series_metadata[\"cols\"][\"y\"])\n", + " metadata = {\n", + " \"encoder_cat\": encoder_cat_count,\n", + " \"encoder_cont\": encoder_cont_count,\n", + " \"decoder_cat\": decoder_cat_count,\n", + " \"decoder_cont\": decoder_cont_count,\n", + " \"target\": target_count,\n", + " }\n", + " if self.time_series_metadata[\"cols\"][\"st\"]:\n", + " static_cat_count = len(\n", + " [\n", + " col\n", + " for col in self.time_series_metadata[\"cols\"][\"st\"]\n", + " if self.time_series_metadata[\"col_type\"].get(col) == \"C\"\n", + " ]\n", + " )\n", + " static_cont_count = (\n", + " len(self.time_series_metadata[\"cols\"][\"st\"]) - static_cat_count\n", + " )\n", + "\n", + " metadata[\"static_categorical_features\"] = static_cat_count\n", + " metadata[\"static_continuous_features\"] = static_cont_count\n", + " else:\n", + " metadata[\"static_categorical_features\"] = 0\n", + " metadata[\"static_continuous_features\"] = 0\n", + "\n", + " metadata.update(\n", + " {\n", + " \"max_encoder_length\": self.max_encoder_length,\n", + " \"max_prediction_length\": self.max_prediction_length,\n", + " \"min_encoder_length\": self.min_encoder_length,\n", + " \"min_prediction_length\": self.min_prediction_length,\n", + " }\n", + " )\n", + "\n", + " return metadata\n", + "\n", + " @property\n", + " def metadata(self):\n", + " \"\"\"Compute metadata for model initialization.\n", + "\n", + " This property returns a dictionary containing the shapes and key information\n", + " related to the time series model. 
The metadata includes:\n", + "\n", + " * ``encoder_cat``: Number of categorical variables in the encoder.\n", + " * ``encoder_cont``: Number of continuous variables in the encoder.\n", + " * ``decoder_cat``: Number of categorical variables in the decoder that are\n", + " known in advance.\n", + " * ``decoder_cont``: Number of continuous variables in the decoder that are\n", + " known in advance.\n", + " * ``target``: Number of target variables.\n", + "\n", + " If static features are present, the following keys are added:\n", + "\n", + " * ``static_categorical_features``: Number of static categorical features\n", + " * ``static_continuous_features``: Number of static continuous features\n", + "\n", + " It also contains the following information:\n", + "\n", + " * ``max_encoder_length``: maximum encoder length\n", + " * ``max_prediction_length``: maximum prediction length\n", + " * ``min_encoder_length``: minimum encoder length\n", + " * ``min_prediction_length``: minimum prediction length\n", + " \"\"\"\n", + " if self._metadata is None:\n", + " self._metadata = self._prepare_metadata()\n", + " return self._metadata\n", + "\n", + " def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]:\n", + " \"\"\"Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset.\n", + "\n", + " Preprocessing steps\n", + " --------------------\n", + "\n", + " * Converts target (`y`) and features (`x`) to `torch.float32`.\n", + " * Masks time points that are at or before the cutoff time.\n", + " * Splits features into categorical and continuous subsets based on\n", + " predefined indices.\n", + "\n", + "\n", + " TODO: add scalers, target normalizers etc.\n", + " \"\"\"\n", + " processed_data = []\n", + "\n", + " for idx in indices:\n", + " sample = self.time_series_dataset[idx.item()]\n", + "\n", + " target = sample[\"y\"]\n", + " features = sample[\"x\"]\n", + " times = sample[\"t\"]\n", + " cutoff_time = sample[\"cutoff_time\"]\n", + "\n", + " time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool)\n", + "\n", + " if isinstance(target, torch.Tensor):\n", + " target = target.float()\n", + " else:\n", + " target = torch.tensor(target, dtype=torch.float32)\n", + "\n", + " if isinstance(features, torch.Tensor):\n", + " features = features.float()\n", + " else:\n", + " features = torch.tensor(features, dtype=torch.float32)\n", + "\n", + " # TODO: add scalers, target normalizers etc.\n", + "\n", + " categorical = (\n", + " features[:, self.categorical_indices]\n", + " if self.categorical_indices\n", + " else torch.zeros((features.shape[0], 0))\n", + " )\n", + " continuous = (\n", + " features[:, self.continuous_indices]\n", + " if self.continuous_indices\n", + " else torch.zeros((features.shape[0], 0))\n", + " )\n", + "\n", + " processed_data.append(\n", + " {\n", + " \"features\": {\"categorical\": categorical, \"continuous\": continuous},\n", + " \"target\": target,\n", + " \"static\": sample.get(\"st\", None),\n", + " \"group\": sample.get(\"group\", torch.tensor([0])),\n", + " \"length\": len(target),\n", + " \"time_mask\": time_mask,\n", + " \"times\": times,\n", + " \"cutoff_time\": cutoff_time,\n", + " }\n", + " )\n", + "\n", + " return processed_data\n", + "\n", + " class _ProcessedEncoderDecoderDataset(Dataset):\n", + " \"\"\"PyTorch Dataset for processed encoder-decoder time series data.\n", + "\n", + " Parameters\n", + " ----------\n", + " processed_data : List[Dict[str, Any]]\n", + " List of preprocessed time series samples.\n", + " windows : List[Tuple[int, 
int, int, int]]\n", + " List of window tuples containing\n", + " (series_idx, start_idx, enc_length, pred_length).\n", + " add_relative_time_idx : bool, default=False\n", + " Whether to include relative time indices.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " processed_data: List[Dict[str, Any]],\n", + " windows: List[Tuple[int, int, int, int]],\n", + " add_relative_time_idx: bool = False,\n", + " ):\n", + " self.processed_data = processed_data\n", + " self.windows = windows\n", + " self.add_relative_time_idx = add_relative_time_idx\n", + "\n", + " def __len__(self):\n", + " return len(self.windows)\n", + "\n", + " def __getitem__(self, idx):\n", + " \"\"\"Retrieve a processed time series window for dataloader input.\n", + "\n", + " x : dict\n", + " Dictionary containing model inputs:\n", + "\n", + " * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features)\n", + " Categorical features for the encoder.\n", + " * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features)\n", + " Continuous features for the encoder.\n", + " * ``decoder_cat`` : tensor of shape (pred_length, n_cat_features)\n", + " Categorical features for the decoder.\n", + " * ``decoder_cont`` : tensor of shape (pred_length, n_cont_features)\n", + " Continuous features for the decoder.\n", + " * ``encoder_lengths`` : tensor of shape (1,)\n", + " Length of the encoder sequence.\n", + " * ``decoder_lengths`` : tensor of shape (1,)\n", + " Length of the decoder sequence.\n", + " * ``decoder_target_lengths`` : tensor of shape (1,)\n", + " Length of the decoder target sequence.\n", + " * ``groups`` : tensor of shape (1,)\n", + " Group identifier for the time series instance.\n", + " * ``encoder_time_idx`` : tensor of shape (enc_length,)\n", + " Time indices for the encoder sequence.\n", + " * ``decoder_time_idx`` : tensor of shape (pred_length,)\n", + " Time indices for the decoder sequence.\n", + " * ``target_scale`` : tensor of shape (1,)\n", + " Scaling factor for the target values.\n", + " * ``encoder_mask`` : tensor of shape (enc_length,)\n", + " Boolean mask indicating valid encoder time points.\n", + " * ``decoder_mask`` : tensor of shape (pred_length,)\n", + " Boolean mask indicating valid decoder time points.\n", + "\n", + " If static features are present, the following keys are added:\n", + "\n", + " * ``static_categorical_features`` : tensor of shape\n", + " (1, n_static_cat_features), optional\n", + " Static categorical features, if available.\n", + " * ``static_continuous_features`` : tensor of shape (1, 0), optional\n", + " Placeholder for static continuous features (currently empty).\n", + "\n", + " y : tensor of shape ``(pred_length, n_targets)``\n", + " Target values for the decoder sequence.\n", + " \"\"\"\n", + " series_idx, start_idx, enc_length, pred_length = self.windows[idx]\n", + " data = self.processed_data[series_idx]\n", + "\n", + " end_idx = start_idx + enc_length + pred_length\n", + " encoder_indices = slice(start_idx, start_idx + enc_length)\n", + " decoder_indices = slice(start_idx + enc_length, end_idx)\n", + "\n", + " target_scale = data[\"target\"][encoder_indices]\n", + " target_scale = target_scale[~torch.isnan(target_scale)].abs().mean()\n", + " if torch.isnan(target_scale) or target_scale == 0:\n", + " target_scale = torch.tensor(1.0)\n", + "\n", + " encoder_mask = (\n", + " data[\"time_mask\"][encoder_indices]\n", + " if \"time_mask\" in data\n", + " else torch.ones(enc_length, dtype=torch.bool)\n", + " )\n", + " decoder_mask = (\n", + " 
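# unlike the encoder fallback (all True), the decoder falls back to an\n",
+    "                # all-False mask, marking future steps as unobserved\n",
+    "                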
data[\"time_mask\"][decoder_indices]\n", + " if \"time_mask\" in data\n", + " else torch.zeros(pred_length, dtype=torch.bool)\n", + " )\n", + "\n", + " x = {\n", + " \"encoder_cat\": data[\"features\"][\"categorical\"][encoder_indices],\n", + " \"encoder_cont\": data[\"features\"][\"continuous\"][encoder_indices],\n", + " \"decoder_cat\": data[\"features\"][\"categorical\"][decoder_indices],\n", + " \"decoder_cont\": data[\"features\"][\"continuous\"][decoder_indices],\n", + " \"encoder_lengths\": torch.tensor(enc_length),\n", + " \"decoder_lengths\": torch.tensor(pred_length),\n", + " \"decoder_target_lengths\": torch.tensor(pred_length),\n", + " \"groups\": data[\"group\"],\n", + " \"encoder_time_idx\": torch.arange(enc_length),\n", + " \"decoder_time_idx\": torch.arange(enc_length, enc_length + pred_length),\n", + " \"target_scale\": target_scale,\n", + " \"encoder_mask\": encoder_mask,\n", + " \"decoder_mask\": decoder_mask,\n", + " }\n", + " if data[\"static\"] is not None:\n", + " x[\"static_categorical_features\"] = data[\"static\"].unsqueeze(0)\n", + " x[\"static_continuous_features\"] = torch.zeros((1, 0))\n", + "\n", + " y = data[\"target\"][decoder_indices]\n", + " if y.ndim == 1:\n", + " y = y.unsqueeze(-1)\n", + "\n", + " return x, y\n", + "\n", + " def _create_windows(\n", + " self, processed_data: List[Dict[str, Any]]\n", + " ) -> List[Tuple[int, int, int, int]]:\n", + " \"\"\"Generate sliding windows for training, validation, and testing.\n", + "\n", + " Returns\n", + " -------\n", + " List[Tuple[int, int, int, int]]\n", + " A list of tuples, where each tuple consists of:\n", + " - ``series_idx`` : int\n", + " Index of the time series in `processed_data`.\n", + " - ``start_idx`` : int\n", + " Start index of the encoder window.\n", + " - ``enc_length`` : int\n", + " Length of the encoder input sequence.\n", + " - ``pred_length`` : int\n", + " Length of the decoder output sequence.\n", + " \"\"\"\n", + " windows = []\n", + "\n", + " for idx, data in enumerate(processed_data):\n", + " sequence_length = data[\"length\"]\n", + "\n", + " if sequence_length < self.max_encoder_length + self.max_prediction_length:\n", + " continue\n", + "\n", + " effective_min_prediction_idx = (\n", + " self.min_prediction_idx\n", + " if self.min_prediction_idx is not None\n", + " else self.max_encoder_length\n", + " )\n", + "\n", + " max_prediction_idx = sequence_length - self.max_prediction_length + 1\n", + "\n", + " if max_prediction_idx <= effective_min_prediction_idx:\n", + " continue\n", + "\n", + " for start_idx in range(\n", + " 0, max_prediction_idx - effective_min_prediction_idx\n", + " ):\n", + " if (\n", + " start_idx + self.max_encoder_length + self.max_prediction_length\n", + " <= sequence_length\n", + " ):\n", + " windows.append(\n", + " (\n", + " idx,\n", + " start_idx,\n", + " self.max_encoder_length,\n", + " self.max_prediction_length,\n", + " )\n", + " )\n", + "\n", + " return windows\n", + "\n", + " def setup(self, stage: Optional[str] = None):\n", + " \"\"\"Prepare the datasets for training, validation, testing, or prediction.\n", + "\n", + " Parameters\n", + " ----------\n", + " stage : Optional[str], default=None\n", + " Specifies the stage of setup. 
Can be one of:\n",
+    "            - ``\"fit\"`` : Prepares training and validation datasets.\n",
+    "            - ``\"test\"`` : Prepares the test dataset.\n",
+    "            - ``\"predict\"`` : Prepares the dataset for inference.\n",
+    "            - ``None`` : Prepares the training, validation, and test datasets.\n",
+    "        \"\"\"\n",
+    "        total_series = len(self.time_series_dataset)\n",
+    "        self._split_indices = torch.randperm(total_series)\n",
+    "\n",
+    "        self._train_size = int(self.train_val_test_split[0] * total_series)\n",
+    "        self._val_size = int(self.train_val_test_split[1] * total_series)\n",
+    "\n",
+    "        self._train_indices = self._split_indices[: self._train_size]\n",
+    "        self._val_indices = self._split_indices[\n",
+    "            self._train_size : self._train_size + self._val_size\n",
+    "        ]\n",
+    "        self._test_indices = self._split_indices[self._train_size + self._val_size :]\n",
+    "\n",
+    "        if stage is None or stage == \"fit\":\n",
+    "            if not hasattr(self, \"train_dataset\") or not hasattr(self, \"val_dataset\"):\n",
+    "                self.train_processed = self._preprocess_data(self._train_indices)\n",
+    "                self.val_processed = self._preprocess_data(self._val_indices)\n",
+    "\n",
+    "                self.train_windows = self._create_windows(self.train_processed)\n",
+    "                self.val_windows = self._create_windows(self.val_processed)\n",
+    "\n",
+    "                self.train_dataset = self._ProcessedEncoderDecoderDataset(\n",
+    "                    self.train_processed, self.train_windows, self.add_relative_time_idx\n",
+    "                )\n",
+    "                self.val_dataset = self._ProcessedEncoderDecoderDataset(\n",
+    "                    self.val_processed, self.val_windows, self.add_relative_time_idx\n",
+    "                )\n",
+    "\n",
+    "        if stage is None or stage == \"test\":\n",
+    "            if not hasattr(self, \"test_dataset\"):\n",
+    "                self.test_processed = self._preprocess_data(self._test_indices)\n",
+    "                self.test_windows = self._create_windows(self.test_processed)\n",
+    "\n",
+    "                self.test_dataset = self._ProcessedEncoderDecoderDataset(\n",
+    "                    self.test_processed, self.test_windows, self.add_relative_time_idx\n",
+    "                )\n",
+    "        elif stage == \"predict\":\n",
+    "            predict_indices = torch.arange(len(self.time_series_dataset))\n",
+    "            self.predict_processed = self._preprocess_data(predict_indices)\n",
+    "            self.predict_windows = self._create_windows(self.predict_processed)\n",
+    "            self.predict_dataset = self._ProcessedEncoderDecoderDataset(\n",
+    "                self.predict_processed, self.predict_windows, self.add_relative_time_idx\n",
+    "            )\n",
+    "\n",
+    "    def train_dataloader(self):\n",
+    "        return DataLoader(\n",
+    "            self.train_dataset,\n",
+    "            batch_size=self.batch_size,\n",
+    "            num_workers=self.num_workers,\n",
+    "            shuffle=True,\n",
+    "            collate_fn=self.collate_fn,\n",
+    "        )\n",
+    "\n",
+    "    def val_dataloader(self):\n",
+    "        return DataLoader(\n",
+    "            self.val_dataset,\n",
+    "            batch_size=self.batch_size,\n",
+    "            num_workers=self.num_workers,\n",
+    "            collate_fn=self.collate_fn,\n",
+    "        )\n",
+    "\n",
+    "    def test_dataloader(self):\n",
+    "        return DataLoader(\n",
+    "            self.test_dataset,\n",
+    "            batch_size=self.batch_size,\n",
+    "            num_workers=self.num_workers,\n",
+    "            collate_fn=self.collate_fn,\n",
+    "        )\n",
+    "\n",
+    "    def predict_dataloader(self):\n",
+    "        return DataLoader(\n",
+    "            self.predict_dataset,\n",
+    "            batch_size=self.batch_size,\n",
+    "            num_workers=self.num_workers,\n",
+    "            collate_fn=self.collate_fn,\n",
+    "        )\n",
+    "\n",
+    "    @staticmethod\n",
+    "    def collate_fn(batch):\n",
+    "        x_batch = {\n",
+    "            \"encoder_cat\": torch.stack([x[\"encoder_cat\"] for x, _ in batch]),\n",
+    "            \"encoder_cont\": torch.stack([x[\"encoder_cont\"] for x, _ in batch]),\n",
+    "            
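# stacking assumes every window shares the same encoder and prediction\n",
+    "            # lengths, which _create_windows guarantees\n",
+    "            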
\"decoder_cat\": torch.stack([x[\"decoder_cat\"] for x, _ in batch]),\n", + " \"decoder_cont\": torch.stack([x[\"decoder_cont\"] for x, _ in batch]),\n", + " \"encoder_lengths\": torch.stack([x[\"encoder_lengths\"] for x, _ in batch]),\n", + " \"decoder_lengths\": torch.stack([x[\"decoder_lengths\"] for x, _ in batch]),\n", + " \"decoder_target_lengths\": torch.stack(\n", + " [x[\"decoder_target_lengths\"] for x, _ in batch]\n", + " ),\n", + " \"groups\": torch.stack([x[\"groups\"] for x, _ in batch]),\n", + " \"encoder_time_idx\": torch.stack([x[\"encoder_time_idx\"] for x, _ in batch]),\n", + " \"decoder_time_idx\": torch.stack([x[\"decoder_time_idx\"] for x, _ in batch]),\n", + " \"target_scale\": torch.stack([x[\"target_scale\"] for x, _ in batch]),\n", + " \"encoder_mask\": torch.stack([x[\"encoder_mask\"] for x, _ in batch]),\n", + " \"decoder_mask\": torch.stack([x[\"decoder_mask\"] for x, _ in batch]),\n", + " }\n", + "\n", + " if \"static_categorical_features\" in batch[0][0]:\n", + " x_batch[\"static_categorical_features\"] = torch.stack(\n", + " [x[\"static_categorical_features\"] for x, _ in batch]\n", + " )\n", + " x_batch[\"static_continuous_features\"] = torch.stack(\n", + " [x[\"static_continuous_features\"] for x, _ in batch]\n", + " )\n", + "\n", + " y_batch = torch.stack([y for _, y in batch])\n", + " return x_batch, y_batch" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "WX-FRdusJSVN", + "outputId": "730f0fe2-f5af-4871-859d-9a4043bbeac7" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6723608094711289,\n \"min\": -1.2295384314749236,\n \"max\": 1.3194322331654313,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.3958947786657995,\n 0.7816648993958805,\n -0.9655256111265276\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6766981377008536,\n \"min\": -1.2295384314749236,\n \"max\": 1.3194322331654313,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.5707668517530808,\n 0.5020485177883972,\n -0.7734543579445009\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n -0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n 
\"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2783997753182645,\n \"min\": 0.011560494046953695,\n \"max\": 0.9996855497257285,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.43066463390546217,\n 0.08751257529405387,\n 0.3593350820130162\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", + "type": "dataframe", + "variable_name": "data_df" + }, + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
0000.0231380.24983401.0000000.4546680
1010.2498340.21382100.9950040.4546680
2020.2138210.67182900.9800670.4546680
3030.6718290.78104200.9553360.4546680
4040.7810420.70609200.9210610.4546680
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "text/plain": [ + " series_id time_idx x y category future_known_feature \\\n", + "0 0 0 0.023138 0.249834 0 1.000000 \n", + "1 0 1 0.249834 0.213821 0 0.995004 \n", + "2 0 2 0.213821 0.671829 0 0.980067 \n", + "3 0 3 0.671829 0.781042 0 0.955336 \n", + "4 0 4 0.781042 0.706092 0 0.921061 \n", + "\n", + " static_feature static_feature_cat \n", + "0 0.454668 0 \n", + "1 0.454668 0 \n", + "2 0.454668 0 \n", + "3 0.454668 0 \n", + "4 0.454668 0 " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from lightning.pytorch import Trainer\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "\n", + "from pytorch_forecasting.metrics import MAE, SMAPE\n", + "\n", + "num_series = 100\n", + "seq_length = 50\n", + "data_list = []\n", + "for i in range(num_series):\n", + " x = np.arange(seq_length)\n", + " y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)\n", + " category = i % 5\n", + " static_value = np.random.rand()\n", + " for t in range(seq_length - 1):\n", + " data_list.append(\n", + " {\n", + " \"series_id\": i,\n", + " \"time_idx\": t,\n", + " \"x\": y[t],\n", + " \"y\": y[t + 1],\n", + " \"category\": category,\n", + " \"future_known_feature\": np.cos(t / 10),\n", + " \"static_feature\": static_value,\n", + " \"static_feature_cat\": i % 3,\n", + " }\n", + " )\n", + "data_df = pd.DataFrame(data_list)\n", + "data_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "id": "AxxPHK6AKSD2" + }, + "outputs": [], + "source": [ + "dataset = TimeSeries(\n", + " data=data_df,\n", + " time=\"time_idx\",\n", + " target=\"y\",\n", + " group=[\"series_id\"],\n", + " num=[\"x\", \"future_known_feature\", \"static_feature\"],\n", + " cat=[\"category\", \"static_feature_cat\"],\n", + " known=[\"future_known_feature\"],\n", + " unknown=[\"x\", \"category\"],\n", + " static=[\"static_feature\", \"static_feature_cat\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "id": "5U5Lr_ZFKX0s" + }, + "outputs": [], + "source": [ + "data_module = EncoderDecoderTimeSeriesDataModule(\n", + " time_series_dataset=dataset,\n", + " max_encoder_length=30,\n", + " max_prediction_length=1,\n", + " batch_size=32,\n", + " categorical_encoders={\n", + " \"category\": NaNLabelEncoder(add_nan=True),\n", + " \"static_feature_cat\": NaNLabelEncoder(add_nan=True),\n", + " },\n", + " scalers={\n", + " \"x\": StandardScaler(),\n", + " \"future_known_feature\": StandardScaler(),\n", + " \"static_feature\": StandardScaler(),\n", + " },\n", + " target_normalizer=TorchNormalizer(),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "I8NgHxNqK9uV" + }, + "outputs": [], + "source": [ + "from typing import Dict, List, Optional, Union\n", + "\n", + "from lightning.pytorch import LightningModule\n", + "from lightning.pytorch.utilities.types import STEP_OUTPUT\n", + "import torch\n", + "from torch.optim import Optimizer\n", + "\n", + "\n", + "class BaseModel(LightningModule):\n", + " def __init__(\n", + " self,\n", + " loss: nn.Module,\n", + " logging_metrics: Optional[List[nn.Module]] = None,\n", + " optimizer: Optional[Union[Optimizer, str]] = \"adam\",\n", + " optimizer_params: Optional[Dict] = None,\n", + " lr_scheduler: Optional[str] = None,\n", + " lr_scheduler_params: Optional[Dict] = None,\n", + " ):\n", + " \"\"\"\n", + " Base model for time series forecasting.\n", + "\n", + " Parameters\n", + " 
----------\n", + " loss : nn.Module\n", + " Loss function to use for training.\n", + " logging_metrics : Optional[List[nn.Module]], optional\n", + " List of metrics to log during training, validation, and testing.\n", + " optimizer : Optional[Union[Optimizer, str]], optional\n", + " Optimizer to use for training.\n", + " Can be a string (\"adam\", \"sgd\") or an instance of `torch.optim.Optimizer`.\n", + " optimizer_params : Optional[Dict], optional\n", + " Parameters for the optimizer.\n", + " lr_scheduler : Optional[str], optional\n", + " Learning rate scheduler to use.\n", + " Supported values: \"reduce_lr_on_plateau\", \"step_lr\".\n", + " lr_scheduler_params : Optional[Dict], optional\n", + " Parameters for the learning rate scheduler.\n", + " \"\"\"\n", + " super().__init__()\n", + " self.loss = loss\n", + " self.logging_metrics = logging_metrics if logging_metrics is not None else []\n", + " self.optimizer = optimizer\n", + " self.optimizer_params = optimizer_params if optimizer_params is not None else {}\n", + " self.lr_scheduler = lr_scheduler\n", + " self.lr_scheduler_params = (\n", + " lr_scheduler_params if lr_scheduler_params is not None else {}\n", + " )\n", + "\n", + " def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n", + " \"\"\"\n", + " Forward pass of the model.\n", + "\n", + " Parameters\n", + " ----------\n", + " x : Dict[str, torch.Tensor]\n", + " Dictionary containing input tensors\n", + "\n", + " Returns\n", + " -------\n", + " Dict[str, torch.Tensor]\n", + " Dictionary containing output tensors\n", + " \"\"\"\n", + " raise NotImplementedError(\"Forward method must be implemented by subclass.\")\n", + "\n", + " def training_step(\n", + " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", + " ) -> STEP_OUTPUT:\n", + " \"\"\"\n", + " Training step for the model.\n", + "\n", + " Parameters\n", + " ----------\n", + " batch : Tuple[Dict[str, torch.Tensor]]\n", + " Batch of data containing input and target tensors.\n", + " batch_idx : int\n", + " Index of the batch.\n", + "\n", + " Returns\n", + " -------\n", + " STEP_OUTPUT\n", + " Dictionary containing the loss and other metrics.\n", + " \"\"\"\n", + " x, y = batch\n", + " y_hat_dict = self(x)\n", + " y_hat = y_hat_dict[\"prediction\"]\n", + " loss = self.loss(y_hat, y)\n", + " self.log(\n", + " \"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True\n", + " )\n", + " self.log_metrics(y_hat, y, prefix=\"train\")\n", + " return {\"loss\": loss}\n", + "\n", + " def validation_step(\n", + " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", + " ) -> STEP_OUTPUT:\n", + " \"\"\"\n", + " Validation step for the model.\n", + "\n", + " Parameters\n", + " ----------\n", + " batch : Tuple[Dict[str, torch.Tensor]]\n", + " Batch of data containing input and target tensors.\n", + " batch_idx : int\n", + " Index of the batch.\n", + "\n", + " Returns\n", + " -------\n", + " STEP_OUTPUT\n", + " Dictionary containing the loss and other metrics.\n", + " \"\"\"\n", + " x, y = batch\n", + " y_hat_dict = self(x)\n", + " y_hat = y_hat_dict[\"prediction\"]\n", + " loss = self.loss(y_hat, y)\n", + " self.log(\n", + " \"val_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n", + " )\n", + " self.log_metrics(y_hat, y, prefix=\"val\")\n", + " return {\"val_loss\": loss}\n", + "\n", + " def test_step(\n", + " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", + " ) -> STEP_OUTPUT:\n", + " \"\"\"\n", + " Test step for the model.\n", + "\n", + 
" Parameters\n", + " ----------\n", + " batch : Tuple[Dict[str, torch.Tensor]]\n", + " Batch of data containing input and target tensors.\n", + " batch_idx : int\n", + " Index of the batch.\n", + "\n", + " Returns\n", + " -------\n", + " STEP_OUTPUT\n", + " Dictionary containing the loss and other metrics.\n", + " \"\"\"\n", + " x, y = batch\n", + " y_hat_dict = self(x)\n", + " y_hat = y_hat_dict[\"prediction\"]\n", + " loss = self.loss(y_hat, y)\n", + " self.log(\n", + " \"test_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n", + " )\n", + " self.log_metrics(y_hat, y, prefix=\"test\")\n", + " return {\"test_loss\": loss}\n", + "\n", + " def predict_step(\n", + " self,\n", + " batch: Tuple[Dict[str, torch.Tensor]],\n", + " batch_idx: int,\n", + " dataloader_idx: int = 0,\n", + " ) -> torch.Tensor:\n", + " \"\"\"\n", + " Prediction step for the model.\n", + "\n", + " Parameters\n", + " ----------\n", + " batch : Tuple[Dict[str, torch.Tensor]]\n", + " Batch of data containing input tensors.\n", + " batch_idx : int\n", + " Index of the batch.\n", + " dataloader_idx : int\n", + " Index of the dataloader.\n", + "\n", + " Returns\n", + " -------\n", + " torch.Tensor\n", + " Predicted output tensor.\n", + " \"\"\"\n", + " x, _ = batch\n", + " y_hat = self(x)\n", + " return y_hat\n", + "\n", + " def configure_optimizers(self) -> Dict:\n", + " \"\"\"\n", + " Configure the optimizer and learning rate scheduler.\n", + "\n", + " Returns\n", + " -------\n", + " Dict\n", + " Dictionary containing the optimizer and scheduler configuration.\n", + " \"\"\"\n", + " optimizer = self._get_optimizer()\n", + " if self.lr_scheduler is not None:\n", + " scheduler = self._get_scheduler(optimizer)\n", + " if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n", + " return {\n", + " \"optimizer\": optimizer,\n", + " \"lr_scheduler\": {\n", + " \"scheduler\": scheduler,\n", + " \"monitor\": \"val_loss\",\n", + " },\n", + " }\n", + " else:\n", + " return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n", + " return {\"optimizer\": optimizer}\n", + "\n", + " def _get_optimizer(self) -> Optimizer:\n", + " \"\"\"\n", + " Get the optimizer based on the specified optimizer name and parameters.\n", + "\n", + " Returns\n", + " -------\n", + " Optimizer\n", + " The optimizer instance.\n", + " \"\"\"\n", + " if isinstance(self.optimizer, str):\n", + " if self.optimizer.lower() == \"adam\":\n", + " return torch.optim.Adam(self.parameters(), **self.optimizer_params)\n", + " elif self.optimizer.lower() == \"sgd\":\n", + " return torch.optim.SGD(self.parameters(), **self.optimizer_params)\n", + " else:\n", + " raise ValueError(f\"Optimizer {self.optimizer} not supported.\")\n", + " elif isinstance(self.optimizer, Optimizer):\n", + " return self.optimizer\n", + " else:\n", + " raise ValueError(\n", + " \"Optimizer must be either a string or \"\n", + " \"an instance of torch.optim.Optimizer.\"\n", + " )\n", + "\n", + " def _get_scheduler(\n", + " self, optimizer: Optimizer\n", + " ) -> torch.optim.lr_scheduler._LRScheduler:\n", + " \"\"\"\n", + " Get the lr scheduler based on the specified scheduler name and params.\n", + "\n", + " Parameters\n", + " ----------\n", + " optimizer : Optimizer\n", + " The optimizer instance.\n", + "\n", + " Returns\n", + " -------\n", + " torch.optim.lr_scheduler._LRScheduler\n", + " The learning rate scheduler instance.\n", + " \"\"\"\n", + " if self.lr_scheduler.lower() == \"reduce_lr_on_plateau\":\n", + " return 
torch.optim.lr_scheduler.ReduceLROnPlateau(\n", + " optimizer, **self.lr_scheduler_params\n", + " )\n", + " elif self.lr_scheduler.lower() == \"step_lr\":\n", + " return torch.optim.lr_scheduler.StepLR(\n", + " optimizer, **self.lr_scheduler_params\n", + " )\n", + " else:\n", + " raise ValueError(f\"Scheduler {self.lr_scheduler} not supported.\")\n", + "\n", + " def log_metrics(\n", + " self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = \"val\"\n", + " ) -> None:\n", + " \"\"\"\n", + " Log additional metrics during training, validation, or testing.\n", + "\n", + " Parameters\n", + " ----------\n", + " y_hat : torch.Tensor\n", + " Predicted output tensor.\n", + " y : torch.Tensor\n", + " Target output tensor.\n", + " prefix : str\n", + " Prefix for the logged metrics (e.g., \"train\", \"val\", \"test\").\n", + " \"\"\"\n", + " for metric in self.logging_metrics:\n", + " metric_value = metric(y_hat, y)\n", + " self.log(\n", + " f\"{prefix}_{metric.__class__.__name__}\",\n", + " metric_value,\n", + " on_step=False,\n", + " on_epoch=True,\n", + " prog_bar=True,\n", + " logger=True,\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "n5EBeGucK_k0" + }, + "outputs": [], + "source": [ + "from typing import Dict, List, Optional, Union\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import Optimizer\n", + "\n", + "\n", + "class TFT(BaseModel):\n", + " def __init__(\n", + " self,\n", + " loss: nn.Module,\n", + " logging_metrics: Optional[List[nn.Module]] = None,\n", + " optimizer: Optional[Union[Optimizer, str]] = \"adam\",\n", + " optimizer_params: Optional[Dict] = None,\n", + " lr_scheduler: Optional[str] = None,\n", + " lr_scheduler_params: Optional[Dict] = None,\n", + " hidden_size: int = 64,\n", + " num_layers: int = 2,\n", + " attention_head_size: int = 4,\n", + " dropout: float = 0.1,\n", + " metadata: Optional[Dict] = None,\n", + " output_size: int = 1,\n", + " ):\n", + " super().__init__(\n", + " loss=loss,\n", + " logging_metrics=logging_metrics,\n", + " optimizer=optimizer,\n", + " optimizer_params=optimizer_params,\n", + " lr_scheduler=lr_scheduler,\n", + " lr_scheduler_params=lr_scheduler_params,\n", + " )\n", + " self.hidden_size = hidden_size\n", + " self.num_layers = num_layers\n", + " self.attention_head_size = attention_head_size\n", + " self.dropout = dropout\n", + " self.metadata = metadata\n", + " self.output_size = output_size\n", + "\n", + " self.max_encoder_length = self.metadata[\"max_encoder_length\"]\n", + " self.max_prediction_length = self.metadata[\"max_prediction_length\"]\n", + " self.encoder_cont = self.metadata[\"encoder_cont\"]\n", + " self.encoder_cat = self.metadata[\"encoder_cat\"]\n", + " self.static_categorical_features = self.metadata[\"static_categorical_features\"]\n", + " self.static_continuous_features = self.metadata[\"static_continuous_features\"]\n", + "\n", + " total_feature_size = self.encoder_cont + self.encoder_cat\n", + " total_static_size = (\n", + " self.static_categorical_features + self.static_continuous_features\n", + " )\n", + "\n", + " self.encoder_var_selection = nn.Sequential(\n", + " nn.Linear(total_feature_size, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_size, total_feature_size),\n", + " nn.Sigmoid(),\n", + " )\n", + "\n", + " self.decoder_var_selection = nn.Sequential(\n", + " nn.Linear(total_feature_size, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_size, total_feature_size),\n", + " nn.Sigmoid(),\n", + " )\n", + "\n", 
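+    "        # NOTE: the two sigmoid-gated MLPs above are a deliberately simplified\n",
+    "        # stand-in for TFT's variable selection networks: they emit per-feature\n",
+    "        # weights in [0, 1] that rescale the encoder/decoder inputs, rather than\n",
+    "        # the GRN-based soft feature selection of the original paper.\n",
+    "\n",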
+ " self.static_context_linear = (\n", + " nn.Linear(total_static_size, hidden_size) if total_static_size > 0 else None\n", + " )\n", + "\n", + " self.lstm_encoder = nn.LSTM(\n", + " input_size=total_feature_size,\n", + " hidden_size=hidden_size,\n", + " num_layers=num_layers,\n", + " dropout=dropout,\n", + " batch_first=True,\n", + " )\n", + "\n", + " self.lstm_decoder = nn.LSTM(\n", + " input_size=total_feature_size,\n", + " hidden_size=hidden_size,\n", + " num_layers=num_layers,\n", + " dropout=dropout,\n", + " batch_first=True,\n", + " )\n", + "\n", + " self.self_attention = nn.MultiheadAttention(\n", + " embed_dim=hidden_size,\n", + " num_heads=attention_head_size,\n", + " dropout=dropout,\n", + " batch_first=True,\n", + " )\n", + "\n", + " self.pre_output = nn.Linear(hidden_size, hidden_size)\n", + " self.output_layer = nn.Linear(hidden_size, output_size)\n", + "\n", + " def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n", + " \"\"\"\n", + " Forward pass of the TFT model.\n", + "\n", + " Parameters\n", + " ----------\n", + " x : Dict[str, torch.Tensor]\n", + " Dictionary containing input tensors:\n", + " - encoder_cat: Categorical encoder features\n", + " - encoder_cont: Continuous encoder features\n", + " - decoder_cat: Categorical decoder features\n", + " - decoder_cont: Continuous decoder features\n", + " - static_categorical_features: Static categorical features\n", + " - static_continuous_features: Static continuous features\n", + "\n", + " Returns\n", + " -------\n", + " Dict[str, torch.Tensor]\n", + " Dictionary containing output tensors:\n", + " - prediction: Prediction output (batch_size, prediction_length, output_size)\n", + " \"\"\"\n", + " batch_size = x[\"encoder_cont\"].shape[0]\n", + "\n", + " encoder_cat = x.get(\n", + " \"encoder_cat\",\n", + " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", + " )\n", + " encoder_cont = x.get(\n", + " \"encoder_cont\",\n", + " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", + " )\n", + " decoder_cat = x.get(\n", + " \"decoder_cat\",\n", + " torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", + " )\n", + " decoder_cont = x.get(\n", + " \"decoder_cont\",\n", + " torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", + " )\n", + "\n", + " encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2)\n", + " decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2)\n", + "\n", + " static_context = None\n", + " if self.static_context_linear is not None:\n", + " static_cat = x.get(\n", + " \"static_categorical_features\",\n", + " torch.zeros(batch_size, 0, device=self.device),\n", + " )\n", + " static_cont = x.get(\n", + " \"static_continuous_features\",\n", + " torch.zeros(batch_size, 0, device=self.device),\n", + " )\n", + "\n", + " if static_cat.size(2) == 0 and static_cont.size(2) == 0:\n", + " static_context = None\n", + " elif static_cat.size(2) == 0:\n", + " static_input = static_cont.to(\n", + " dtype=self.static_context_linear.weight.dtype\n", + " )\n", + " static_context = self.static_context_linear(static_input)\n", + " static_context = static_context.view(batch_size, self.hidden_size)\n", + " elif static_cont.size(2) == 0:\n", + " static_input = static_cat.to(\n", + " dtype=self.static_context_linear.weight.dtype\n", + " )\n", + " static_context = self.static_context_linear(static_input)\n", + " static_context = static_context.view(batch_size, self.hidden_size)\n", + " else:\n", + "\n", + " 
+    "                static_input = torch.cat([static_cont, static_cat], dim=2).to(\n",
+    "                    dtype=self.static_context_linear.weight.dtype\n",
+    "                )\n",
+    "                static_context = self.static_context_linear(static_input)\n",
+    "                static_context = static_context.view(batch_size, self.hidden_size)\n",
+    "\n",
+    "        encoder_weights = self.encoder_var_selection(encoder_input)\n",
+    "        encoder_input = encoder_input * encoder_weights\n",
+    "\n",
+    "        decoder_weights = self.decoder_var_selection(decoder_input)\n",
+    "        decoder_input = decoder_input * decoder_weights\n",
+    "\n",
+    "        if static_context is not None:\n",
+    "            encoder_static_context = static_context.unsqueeze(1).expand(\n",
+    "                -1, self.max_encoder_length, -1\n",
+    "            )\n",
+    "            decoder_static_context = static_context.unsqueeze(1).expand(\n",
+    "                -1, self.max_prediction_length, -1\n",
+    "            )\n",
+    "\n",
+    "            encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n",
+    "            encoder_output = encoder_output + encoder_static_context\n",
+    "            decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n",
+    "            decoder_output = decoder_output + decoder_static_context\n",
+    "        else:\n",
+    "            encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n",
+    "            decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n",
+    "\n",
+    "        sequence = torch.cat([encoder_output, decoder_output], dim=1)\n",
+    "\n",
+    "        if static_context is not None:\n",
+    "            expanded_static_context = static_context.unsqueeze(1).expand(\n",
+    "                -1, sequence.size(1), -1\n",
+    "            )\n",
+    "\n",
+    "            attended_output, _ = self.self_attention(\n",
+    "                sequence + expanded_static_context, sequence, sequence\n",
+    "            )\n",
+    "        else:\n",
+    "            attended_output, _ = self.self_attention(sequence, sequence, sequence)\n",
+    "\n",
+    "        decoder_attended = attended_output[:, -self.max_prediction_length :, :]\n",
+    "\n",
+    "        output = nn.functional.relu(self.pre_output(decoder_attended))\n",
+    "        prediction = self.output_layer(output)\n",
+    "\n",
+    "        return {\"prediction\": prediction}\n"
+   ]
+  },
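+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative shape sanity-check for the TFT sketch above; not part of the\n",
+    "# original experiment. The metadata dict mirrors the keys TFT.__init__ reads,\n",
+    "# with hypothetical sizes, and the batch uses the (batch, time, features) /\n",
+    "# (batch, 1, n_features) layouts the forward pass expects.\n",
+    "_dummy_metadata = {\n",
+    "    \"max_encoder_length\": 30,\n",
+    "    \"max_prediction_length\": 1,\n",
+    "    \"encoder_cont\": 1,\n",
+    "    \"encoder_cat\": 0,\n",
+    "    \"static_categorical_features\": 2,\n",
+    "    \"static_continuous_features\": 0,\n",
+    "}\n",
+    "_tft = TFT(loss=nn.MSELoss(), metadata=_dummy_metadata)\n",
+    "_batch = {\n",
+    "    \"encoder_cont\": torch.randn(4, 30, 1),\n",
+    "    \"decoder_cont\": torch.randn(4, 1, 1),\n",
+    "    # Static categoricals are consumed as float tensors here; random values\n",
+    "    # stand in for label-encoded categories.\n",
+    "    \"static_categorical_features\": torch.randn(4, 1, 2),\n",
+    "}\n",
+    "with torch.no_grad():\n",
+    "    _out = _tft(_batch)\n",
+    "print(_out[\"prediction\"].shape)  # expected: torch.Size([4, 1, 1])\n"
+   ]
+  },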
"4a22ad1ba0c240a3af31a06d117220fb", + "3ac302bd9b054b9dbe0b76ba7575db68", + "78f9afd80d634c08965990e66297e5b6", + "b791ccce40544f6ba57198a0392df981", + "bff2fcab4d574dc0bcc7382a17f6b080", + "607dc1dee95147f4a39b16170a08f780", + "0c40053cf7ba496d87691bfeee0975e4", + "7b989ec8edc148cb9c3e58aa61637a3b", + "94a5cc9098f5445593b3212857bd0da2", + "c236df5d6fdf4bd1b9fdf687819bdffb", + "e3fcdbd89eb64510b3ed6947d728de1f", + "52a0d1617601472590aa3b5641236890", + "ef103e08597e424a81ff70060b4edc77", + "d25815a46ca64bbea97e6b7f4fff954a", + "3eb0147ac4d345faae3713d64eb0e66c", + "a561990cc8204ef08effecce2855a882", + "afd962bafddc4be2b6ec83067402a19a", + "c9de18b7a60148e1b35c0dfb0b20e0d5", + "5f4abc4cadbc4fc7aae645d53f2a03aa", + "02a2df129afd4ec5a0e6a11ad8d67ee3", + "eb40207588114030aeb2e66e6a66858b", + "344cd98e1e4f49dba72ad6df49695e1d", + "4f5df150e0044ee798c04fa2c722f6c0", + "5307c74c93dc472bbeeb88e6bfb487f1", + "8792c8ef2b3d43748b0062327d901bb3", + "fa845f388d7844a388f079d7cb115824", + "3d59af9594a445d68dbcc8e51994fcbf", + "fb7a0d9d587e4806a7519d9201f7f0aa", + "13d0c7ce5a394744a44439556d43d24d", + "d09998b10967447e917c8ac88abf26db", + "e03f176873e6485a91118a9cb4d7fa4d", + "3866642b5f6c4eaa8f3982da9efa2a4b", + "fb1c1f0130934f24b163d87846f4210c", + "0c32b264b85b437f825531554a51e7e1", + "056e8d3d437c40c889f19e316fbd38ac", + "f2fb8e95caad4be792063e44a9401f2c", + "46fd6f7300514bd9b11abeafc51970ec", + "642bce0f7ab848459f4b5c91068113d4", + "dc1b070cf7bd43ed8c422e2c1cec6438", + "2fad0082ed35466caaf6948f6ff4aac0", + "2d9cd71eec004370b5780a7814d4dd7b", + "db2a9324292e415e8765117e3dab8356", + "fae0d1779ea045108616dad057ed24e1", + "d22ee6b52bbf45bdbe05eb3c23b966c9", + "e579bbd44a5241f7a7fa027147295ec3", + "9eb0c593adf7413da7a4abeec48af166", + "57cb9f2255e4487383465ea0a438a01b", + "b3d493c1744248fc92301d57dcc07fe9", + "c8ddbeb5fdb9447c88561abb59edac29", + "6fd91510b57043c093e4a1fd8302fec3", + "e623887237994d72bfbff0fb2d76e697", + "debe1c18531a401e8e3e93791787e1f8" + ] + }, + "id": "Si7bbZIULBZz", + "outputId": "8c1308d6-c1b3-4a67-f56c-f11c225b592c" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.\n", + "INFO:lightning.pytorch.utilities.rank_zero:You are using the plain ModelCheckpoint callback. 
+      "INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.\n",
+      "INFO: GPU available: False, used: False\n",
+      "INFO: TPU available: False, using: 0 TPU cores\n",
+      "INFO: HPU available: False, using: 0 HPUs\n",
+      "INFO: \n",
+      "  | Name                  | Type               | Params | Mode \n",
+      "---------------------------------------------------------------------\n",
+      "0 | loss                  | MSELoss            | 0      | train\n",
+      "1 | encoder_var_selection | Sequential         | 709    | train\n",
+      "2 | decoder_var_selection | Sequential         | 709    | train\n",
+      "3 | static_context_linear | Linear             | 192    | train\n",
+      "4 | lstm_encoder          | LSTM               | 51.5 K | train\n",
+      "5 | lstm_decoder          | LSTM               | 51.5 K | train\n",
+      "6 | self_attention        | MultiheadAttention | 16.6 K | train\n",
+      "7 | pre_output            | Linear             | 4.2 K  | train\n",
+      "8 | output_layer          | Linear             | 65     | train\n",
+      "---------------------------------------------------------------------\n",
+      "125 K     Trainable params\n",
+      "0         Non-trainable params\n",
+      "125 K     Total params\n",
+      "0.502     Total estimated model params size (MB)\n",
+      "18        Modules in train mode\n",
+      "0         Modules in eval mode\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Training model...\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "Sanity Checking: |          | 0/? [00:00"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃        Test metric        ┃       DataLoader 0        ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│         test_MAE          │    0.4572468101978302     │\n",
+       "│        test_SMAPE         │    1.0497652292251587     │\n",
+       "│         test_loss         │   0.022910255938768387    │\n",
+       "└───────────────────────────┴───────────────────────────┘\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Prediction shape: torch.Size([32, 1, 1])\n",
+      "First prediction values: [[-0.06341379]]\n",
+      "First true values: [[0.08132173]]\n",
+      "\n",
+      "TFT model test complete!\n"
+     ]
+    }
+   ],
+   "source": [
+    "model = TFT(\n",
+    "    loss=nn.MSELoss(),\n",
+    "    logging_metrics=[MAE(), SMAPE()],\n",
+    "    optimizer=\"adam\",\n",
+    "    optimizer_params={\"lr\": 1e-3},\n",
+    "    lr_scheduler=\"reduce_lr_on_plateau\",\n",
+    "    lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n",
+    "    hidden_size=64,\n",
+    "    num_layers=2,\n",
+    "    attention_head_size=4,\n",
+    "    dropout=0.1,\n",
+    "    metadata=data_module.metadata,\n",
+    ")\n",
+    "\n",
+    "print(\"\\nTraining model...\")\n",
+    "trainer = Trainer(max_epochs=5, accelerator=\"auto\", devices=1, enable_progress_bar=True)\n",
+    "\n",
+    "trainer.fit(model, data_module)\n",
+    "\n",
+    "print(\"\\nEvaluating model...\")\n",
+    "test_metrics = trainer.test(model, data_module)\n",
+    "\n",
+    "model.eval()\n",
+    "with torch.no_grad():\n",
+    "    test_batch = next(iter(data_module.test_dataloader()))\n",
+    "    x_test, y_test = test_batch\n",
+    "    y_pred = model(x_test)\n",
+    "\n",
+    "    print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n",
+    "    print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n",
+    "    print(\"First true values:\", y_test[0].cpu().numpy())\n",
+    "print(\"\\nTFT model test complete!\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
"box_style": "", + "children": [ + "IPY_MODEL_eb40207588114030aeb2e66e6a66858b", + "IPY_MODEL_344cd98e1e4f49dba72ad6df49695e1d", + "IPY_MODEL_4f5df150e0044ee798c04fa2c722f6c0" + ], + "layout": "IPY_MODEL_5307c74c93dc472bbeeb88e6bfb487f1" + } + }, + "056e8d3d437c40c889f19e316fbd38ac": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "059bbf019de944fda861994e0e7439dd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_78f9afd80d634c08965990e66297e5b6", + "placeholder": "​", + "style": "IPY_MODEL_b791ccce40544f6ba57198a0392df981", + "value": "Validation DataLoader 0: 100%" + } + }, + "0c32b264b85b437f825531554a51e7e1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2fad0082ed35466caaf6948f6ff4aac0", + "placeholder": "​", + "style": "IPY_MODEL_2d9cd71eec004370b5780a7814d4dd7b", + "value": " 9/9 [00:00<00:00, 31.33it/s]" + } + }, + "0c40053cf7ba496d87691bfeee0975e4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + 
"grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0d2d00cf7992454484dfd939928aca14": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "1016ad0f475740ef9e2d5054323d22e7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "119506368b47405786c806373b0bb67d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f9b429aef6904d6995836f1c1279d38d", + "placeholder": "​", + "style": "IPY_MODEL_8151f2de21b44422876341b10e6b8568", + "value": " 42/42 [00:02<00:00, 15.49it/s, v_num=2, train_loss_step=0.010, val_loss=0.0243, val_MAE=0.477, val_SMAPE=1.120, train_loss_epoch=0.0156, train_MAE=0.480, train_SMAPE=1.020]" + } + }, + "13d0c7ce5a394744a44439556d43d24d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, 
+ "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "197c1200910e4256945d64dd9ab33902": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bff2fcab4d574dc0bcc7382a17f6b080", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_607dc1dee95147f4a39b16170a08f780", + "value": 9 + } + }, + "27ec3bfc950847afb2850ad9ea9bdd49": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "100%" + } + }, + "2828e899a43b4f17b0bd7c86a701152b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_059bbf019de944fda861994e0e7439dd", + "IPY_MODEL_197c1200910e4256945d64dd9ab33902", + "IPY_MODEL_4a22ad1ba0c240a3af31a06d117220fb" + ], + "layout": "IPY_MODEL_3ac302bd9b054b9dbe0b76ba7575db68" + } + }, + "2b2eb8cd85ba4c16aefa42634b5458cb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + 
"model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2b8d1923ff3145789f77661364163cd1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "2d9cd71eec004370b5780a7814d4dd7b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "2db1853b41104999ab6884b33b217f0b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "2fad0082ed35466caaf6948f6ff4aac0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + 
"_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "344cd98e1e4f49dba72ad6df49695e1d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3d59af9594a445d68dbcc8e51994fcbf", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_fb7a0d9d587e4806a7519d9201f7f0aa", + "value": 9 + } + }, + "3529aafd3f884efe8ba4bb3e33a05c5c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "35e7d924e4e44ceaad836bde8fef17d8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": 
"IPY_MODEL_1016ad0f475740ef9e2d5054323d22e7", + "placeholder": "​", + "style": "IPY_MODEL_b232c9df604a4ce59f422cf60499605f", + "value": "Epoch 4: 100%" + } + }, + "3852c1b855934e6e862ceaf37a855700": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3866642b5f6c4eaa8f3982da9efa2a4b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f2fb8e95caad4be792063e44a9401f2c", + "placeholder": "​", + "style": "IPY_MODEL_46fd6f7300514bd9b11abeafc51970ec", + "value": "Validation DataLoader 0: 100%" + } + }, + "3ac302bd9b054b9dbe0b76ba7575db68": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "3d59af9594a445d68dbcc8e51994fcbf": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + 
"_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3eb0147ac4d345faae3713d64eb0e66c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "46fd6f7300514bd9b11abeafc51970ec": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "4a22ad1ba0c240a3af31a06d117220fb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0c40053cf7ba496d87691bfeee0975e4", + "placeholder": "​", + "style": "IPY_MODEL_7b989ec8edc148cb9c3e58aa61637a3b", + "value": " 9/9 [00:00<00:00, 25.66it/s]" + } + }, + "4f5df150e0044ee798c04fa2c722f6c0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_13d0c7ce5a394744a44439556d43d24d", + "placeholder": "​", + "style": "IPY_MODEL_d09998b10967447e917c8ac88abf26db", + "value": " 9/9 [00:00<00:00, 32.96it/s]" + } + }, + "51455b67f5e741b192b5b832aff3d3d3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + 
"_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3852c1b855934e6e862ceaf37a855700", + "placeholder": "​", + "style": "IPY_MODEL_5ad79ce792314292a0425b6f03b85881", + "value": "Sanity Checking DataLoader 0: 100%" + } + }, + "52a0d1617601472590aa3b5641236890": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c9de18b7a60148e1b35c0dfb0b20e0d5", + "placeholder": "​", + "style": "IPY_MODEL_5f4abc4cadbc4fc7aae645d53f2a03aa", + "value": " 9/9 [00:00<00:00, 34.85it/s]" + } + }, + "5307c74c93dc472bbeeb88e6bfb487f1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "57cb9f2255e4487383465ea0a438a01b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + 
"right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5ad79ce792314292a0425b6f03b85881": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5f4abc4cadbc4fc7aae645d53f2a03aa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "607dc1dee95147f4a39b16170a08f780": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "642bce0f7ab848459f4b5c91068113d4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6922f0547118463e9468196bf9b1a5c2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f9cb515fd148446ca5fcb47fab504a74", + "max": 42, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_adf882d15ee84b2aa3ad6145e2c913bf", + "value": 42 + } + }, + "6fd91510b57043c093e4a1fd8302fec3": { + "model_module": "@jupyter-widgets/controls", + 
"model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "78f9afd80d634c08965990e66297e5b6": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7b4b93e92400404abfd2c31129cda96f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_35e7d924e4e44ceaad836bde8fef17d8", + "IPY_MODEL_6922f0547118463e9468196bf9b1a5c2", + "IPY_MODEL_119506368b47405786c806373b0bb67d" + ], + "layout": "IPY_MODEL_27ec3bfc950847afb2850ad9ea9bdd49" + } + }, + "7b989ec8edc148cb9c3e58aa61637a3b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "8151f2de21b44422876341b10e6b8568": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "831387d48f5f4b8cb0662c5df77aefa4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + 
"_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b868fe32ecc44b50b823ae533755a8ef", + "placeholder": "​", + "style": "IPY_MODEL_0d2d00cf7992454484dfd939928aca14", + "value": "Validation DataLoader 0: 100%" + } + }, + "8792c8ef2b3d43748b0062327d901bb3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8997cb6735c341bc985d4aacb6e20999": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "8b5d5e147c8142328d30e2492b57e2c1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2b2eb8cd85ba4c16aefa42634b5458cb", + "placeholder": "​", + "style": "IPY_MODEL_edc278a53bb74140ad1a416406f0b0bb", + "value": " 2/2 [00:00<00:00, 21.87it/s]" + } + }, + "8e46e7b0ea3b429293dc6358980dffe2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_831387d48f5f4b8cb0662c5df77aefa4", + "IPY_MODEL_d08f30db32ec440f94d88727adfc3ee6", + "IPY_MODEL_9dd7c13be6a54e6b843ccb8afd59fb6b" + ], + "layout": "IPY_MODEL_2db1853b41104999ab6884b33b217f0b" + } + }, + "94a5cc9098f5445593b3212857bd0da2": { + "model_module": "@jupyter-widgets/controls", + 
"model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_c236df5d6fdf4bd1b9fdf687819bdffb", + "IPY_MODEL_e3fcdbd89eb64510b3ed6947d728de1f", + "IPY_MODEL_52a0d1617601472590aa3b5641236890" + ], + "layout": "IPY_MODEL_ef103e08597e424a81ff70060b4edc77" + } + }, + "9dd7c13be6a54e6b843ccb8afd59fb6b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_df6002b4a12a49fe8d62210c5ebd5b06", + "placeholder": "​", + "style": "IPY_MODEL_adaafe80e3ee4a5c8d6d987800630f2e", + "value": " 9/9 [00:00<00:00, 33.87it/s]" + } + }, + "9eb0c593adf7413da7a4abeec48af166": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "100%" + } + }, + "a561990cc8204ef08effecce2855a882": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + 
"object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "adaafe80e3ee4a5c8d6d987800630f2e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "adf882d15ee84b2aa3ad6145e2c913bf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "afd962bafddc4be2b6ec83067402a19a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "b232c9df604a4ce59f422cf60499605f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b3d493c1744248fc92301d57dcc07fe9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b791ccce40544f6ba57198a0392df981": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b868fe32ecc44b50b823ae533755a8ef": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + 
"bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bff2fcab4d574dc0bcc7382a17f6b080": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c236df5d6fdf4bd1b9fdf687819bdffb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d25815a46ca64bbea97e6b7f4fff954a", + "placeholder": "​", + "style": "IPY_MODEL_3eb0147ac4d345faae3713d64eb0e66c", + "value": "Validation DataLoader 0: 100%" + } + }, + "c8ddbeb5fdb9447c88561abb59edac29": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": 
null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c9de18b7a60148e1b35c0dfb0b20e0d5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ccd9a65e258244f4bd29c741c3cb4441": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e9dcea6508f344caa5f1f99e2bfdd4db", + "max": 2, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_2b8d1923ff3145789f77661364163cd1", + "value": 2 + } + }, + "d08f30db32ec440f94d88727adfc3ee6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f6131d3c810442c5ad4f71177dac1f5e", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_8997cb6735c341bc985d4aacb6e20999", + "value": 9 + } + }, + "d09998b10967447e917c8ac88abf26db": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "d22ee6b52bbf45bdbe05eb3c23b966c9": { + 
"model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c8ddbeb5fdb9447c88561abb59edac29", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_6fd91510b57043c093e4a1fd8302fec3", + "value": 9 + } + }, + "d25815a46ca64bbea97e6b7f4fff954a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "db2a9324292e415e8765117e3dab8356": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_fae0d1779ea045108616dad057ed24e1", + "IPY_MODEL_d22ee6b52bbf45bdbe05eb3c23b966c9", + "IPY_MODEL_e579bbd44a5241f7a7fa027147295ec3" + ], + "layout": "IPY_MODEL_9eb0c593adf7413da7a4abeec48af166" + } + }, + "dc1b070cf7bd43ed8c422e2c1cec6438": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "debe1c18531a401e8e3e93791787e1f8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + 
}, + "df6002b4a12a49fe8d62210c5ebd5b06": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e03f176873e6485a91118a9cb4d7fa4d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_3866642b5f6c4eaa8f3982da9efa2a4b", + "IPY_MODEL_fb1c1f0130934f24b163d87846f4210c", + "IPY_MODEL_0c32b264b85b437f825531554a51e7e1" + ], + "layout": "IPY_MODEL_056e8d3d437c40c889f19e316fbd38ac" + } + }, + "e3fcdbd89eb64510b3ed6947d728de1f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a561990cc8204ef08effecce2855a882", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_afd962bafddc4be2b6ec83067402a19a", + "value": 9 + } + }, + "e579bbd44a5241f7a7fa027147295ec3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e623887237994d72bfbff0fb2d76e697", + "placeholder": "​", + "style": "IPY_MODEL_debe1c18531a401e8e3e93791787e1f8", + "value": " 9/9 [00:00<00:00, 33.35it/s]" + } + }, + "e623887237994d72bfbff0fb2d76e697": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": 
"1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e9dcea6508f344caa5f1f99e2bfdd4db": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "eb40207588114030aeb2e66e6a66858b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8792c8ef2b3d43748b0062327d901bb3", + "placeholder": "​", + "style": "IPY_MODEL_fa845f388d7844a388f079d7cb115824", + "value": "Validation DataLoader 0: 100%" + } + }, + "edc278a53bb74140ad1a416406f0b0bb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "ef103e08597e424a81ff70060b4edc77": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": 
{ + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "f2fb8e95caad4be792063e44a9401f2c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f422383b4af64a26b16842297d5c84a6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_51455b67f5e741b192b5b832aff3d3d3", + "IPY_MODEL_ccd9a65e258244f4bd29c741c3cb4441", + "IPY_MODEL_8b5d5e147c8142328d30e2492b57e2c1" + ], + "layout": "IPY_MODEL_3529aafd3f884efe8ba4bb3e33a05c5c" + } + }, + "f6131d3c810442c5ad4f71177dac1f5e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": 
null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f9b429aef6904d6995836f1c1279d38d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f9cb515fd148446ca5fcb47fab504a74": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fa845f388d7844a388f079d7cb115824": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "fae0d1779ea045108616dad057ed24e1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_57cb9f2255e4487383465ea0a438a01b", + "placeholder": "​", + "style": "IPY_MODEL_b3d493c1744248fc92301d57dcc07fe9", + "value": "Testing DataLoader 0: 100%" + } + }, + "fb1c1f0130934f24b163d87846f4210c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_642bce0f7ab848459f4b5c91068113d4", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_dc1b070cf7bd43ed8c422e2c1cec6438", + "value": 9 + } + }, + "fb7a0d9d587e4806a7519d9201f7f0aa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From 6364780ae121298e3d98a2c14c6f6747bf62a7b4 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 6 Apr 2025 19:34:57 +0530 Subject: [PATCH 05/30] update docstring --- pytorch_forecasting/data/timeseries.py | 44 +++++++++++++++----------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index bc8300300..9da02d3a0 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2815,25 +2815,31 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: """Get time series data for given index. - It returns: - - * ``t``: ``numpy.ndarray`` of shape (n_timepoints,) - Time index for each time point in the past or present. Aligned with ``y``, - and ``x`` not ending in ``f``. - * ``y``: tensor of shape (n_timepoints, n_targets) - Target values for each time point. Rows are time points, aligned with ``t``. - * ``x``: tensor of shape (n_timepoints, n_features) - Features for each time point. Rows are time points, aligned with ``t``. - * ``group``: tensor of shape (n_groups) - Group identifiers for time series instances. - * ``st``: tensor of shape (n_static_features) - Static features. - * ``cutoff_time``: float or ``numpy.float64`` - Cutoff time for the time series instance. 
From 257183ce4d2b1f7fd40c95ecd7dc38c8004a017b Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Fri, 11 Apr 2025 01:54:50 +0530
Subject: [PATCH 06/30] update data_module.py

---
 pytorch_forecasting/data/data_module.py | 160 ++++++++++++------------
 1 file changed, 80 insertions(+), 80 deletions(-)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index 2958f1705..c796b85fa 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -1,15 +1,6 @@
-#######################################################################################
-# Disclaimer: This data-module is still work in progress and experimental, please
-# use with care. This data-module is a basic skeleton of how the data-handling pipeline
-# may look like in the future.
-# This is D2 layer that will handle the preprocessing and data loaders.
-# For now, this pipeline handles the simplest situation: The whole data can be loaded
-# into the memory.
-#######################################################################################
-
 from typing import Any, Dict, List, Optional, Tuple, Union

-from lightning.pytorch import LightningDataModule
+from lightning.pytorch import LightningDataModule, LightningModule
 from sklearn.preprocessing import RobustScaler, StandardScaler
 import torch
 from torch.utils.data import DataLoader, Dataset
@@ -19,7 +10,11 @@
     NaNLabelEncoder,
     TorchNormalizer,
 )
-from pytorch_forecasting.data.timeseries import TimeSeries, _coerce_to_dict
+from pytorch_forecasting.data.timeseries import (
+    TimeSeries,
+    _coerce_to_dict,
+    _coerce_to_list,
+)

 NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer]

@@ -274,7 +269,7 @@ def metadata(self):
             self._metadata = self._prepare_metadata()
         return self._metadata

-    def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]:
+    def _preprocess_data(self, series_idx: torch.Tensor) -> List[Dict[str, Any]]:
         """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset.

         Preprocessing steps
@@ -288,63 +283,58 @@

         TODO: add scalers, target normalizers etc.
""" - processed_data = [] + sample = self.time_series_dataset[series_idx] - for idx in indices: - sample = self.time_series_dataset[idx.item()] + target = sample["y"] + features = sample["x"] + times = sample["t"] + cutoff_time = sample["cutoff_time"] - target = sample["y"] - features = sample["x"] - times = sample["t"] - cutoff_time = sample["cutoff_time"] + time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) - time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) - - if isinstance(target, torch.Tensor): - target = target.float() - else: - target = torch.tensor(target, dtype=torch.float32) - - if isinstance(features, torch.Tensor): - features = features.float() - else: - features = torch.tensor(features, dtype=torch.float32) + if isinstance(target, torch.Tensor): + target = target.float() + else: + target = torch.tensor(target, dtype=torch.float32) - # TODO: add scalers, target normalizers etc. + if isinstance(features, torch.Tensor): + features = features.float() + else: + features = torch.tensor(features, dtype=torch.float32) - categorical = ( - features[:, self.categorical_indices] - if self.categorical_indices - else torch.zeros((features.shape[0], 0)) - ) - continuous = ( - features[:, self.continuous_indices] - if self.continuous_indices - else torch.zeros((features.shape[0], 0)) - ) + # TODO: add scalers, target normalizers etc. - processed_data.append( - { - "features": {"categorical": categorical, "continuous": continuous}, - "target": target, - "static": sample.get("st", None), - "group": sample.get("group", torch.tensor([0])), - "length": len(target), - "time_mask": time_mask, - "times": times, - "cutoff_time": cutoff_time, - } - ) + categorical = ( + features[:, self.categorical_indices] + if self.categorical_indices + else torch.zeros((features.shape[0], 0)) + ) + continuous = ( + features[:, self.continuous_indices] + if self.continuous_indices + else torch.zeros((features.shape[0], 0)) + ) - return processed_data + return { + "features": {"categorical": categorical, "continuous": continuous}, + "target": target, + "static": sample.get("st", None), + "group": sample.get("group", torch.tensor([0])), + "length": len(target), + "time_mask": time_mask, + "times": times, + "cutoff_time": cutoff_time, + } class _ProcessedEncoderDecoderDataset(Dataset): """PyTorch Dataset for processed encoder-decoder time series data. Parameters ---------- - processed_data : List[Dict[str, Any]] - List of preprocessed time series samples. + dataset : TimeSeries + The base time series dataset that provides access to raw data and metadata. + data_module : EncoderDecoderTimeSeriesDataModule + The data module handling preprocessing and metadata configuration. windows : List[Tuple[int, int, int, int]] List of window tuples containing (series_idx, start_idx, enc_length, pred_length). @@ -354,11 +344,13 @@ class _ProcessedEncoderDecoderDataset(Dataset): def __init__( self, - processed_data: List[Dict[str, Any]], + dataset: TimeSeries, + data_module: "EncoderDecoderTimeSeriesDataModule", windows: List[Tuple[int, int, int, int]], add_relative_time_idx: bool = False, ): - self.processed_data = processed_data + self.dataset = dataset + self.data_module = data_module self.windows = windows self.add_relative_time_idx = add_relative_time_idx @@ -410,7 +402,7 @@ def __getitem__(self, idx): Target values for the decoder sequence. 
""" series_idx, start_idx, enc_length, pred_length = self.windows[idx] - data = self.processed_data[series_idx] + data = self.data_module._preprocess_data(series_idx) end_idx = start_idx + enc_length + pred_length encoder_indices = slice(start_idx, start_idx + enc_length) @@ -457,9 +449,7 @@ def __getitem__(self, idx): return x, y - def _create_windows( - self, processed_data: List[Dict[str, Any]] - ) -> List[Tuple[int, int, int, int]]: + def _create_windows(self, indices: torch.Tensor) -> List[Tuple[int, int, int, int]]: """Generate sliding windows for training, validation, and testing. Returns @@ -477,8 +467,10 @@ def _create_windows( """ windows = [] - for idx, data in enumerate(processed_data): - sequence_length = data["length"] + for idx in indices: + series_idx = idx.item() + sample = self.time_series_dataset[series_idx] + sequence_length = len(sample["y"]) if sequence_length < self.max_encoder_length + self.max_prediction_length: continue @@ -503,7 +495,7 @@ def _create_windows( ): windows.append( ( - idx, + series_idx, start_idx, self.max_encoder_length, self.max_prediction_length, @@ -538,33 +530,41 @@ def setup(self, stage: Optional[str] = None): if stage is None or stage == "fit": if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"): - self.train_processed = self._preprocess_data(self._train_indices) - self.val_processed = self._preprocess_data(self._val_indices) - - self.train_windows = self._create_windows(self.train_processed) - self.val_windows = self._create_windows(self.val_processed) + self.train_windows = self._create_windows(self._train_indices) + self.val_windows = self._create_windows(self._val_indices) self.train_dataset = self._ProcessedEncoderDecoderDataset( - self.train_processed, self.train_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.train_windows, + self.add_relative_time_idx, ) self.val_dataset = self._ProcessedEncoderDecoderDataset( - self.val_processed, self.val_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.val_windows, + self.add_relative_time_idx, ) - elif stage is None or stage == "test": + elif stage == "test": if not hasattr(self, "test_dataset"): - self.test_processed = self._preprocess_data(self._test_indices) - self.test_windows = self._create_windows(self.test_processed) - + self.test_windows = self._create_windows(self._test_indices) self.test_dataset = self._ProcessedEncoderDecoderDataset( - self.test_processed, self.test_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.test_windows, + self, + self.add_relative_time_idx, ) elif stage == "predict": predict_indices = torch.arange(len(self.time_series_dataset)) - self.predict_processed = self._preprocess_data(predict_indices) - self.predict_windows = self._create_windows(self.predict_processed) + self.predict_windows = self._create_windows(predict_indices) self.predict_dataset = self._ProcessedEncoderDecoderDataset( - self.predict_processed, self.predict_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.predict_windows, + self, + self.add_relative_time_idx, ) def train_dataloader(self): From 9cdcb195c4c9e3f9b6d0e76ef3b6ed889bc14998 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 11 Apr 2025 01:56:55 +0530 Subject: [PATCH 07/30] update data_module.py --- pytorch_forecasting/data/data_module.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index c796b85fa..9a4a5bf5e 
From 9cdcb195c4c9e3f9b6d0e76ef3b6ed889bc14998 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Fri, 11 Apr 2025 01:56:55 +0530
Subject: [PATCH 07/30] update data_module.py

---
 pytorch_forecasting/data/data_module.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index c796b85fa..9a4a5bf5e 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -553,7 +553,6 @@ def setup(self, stage: Optional[str] = None):
                     self.time_series_dataset,
                     self,
                     self.test_windows,
-                    self,
                     self.add_relative_time_idx,
                 )
         elif stage == "predict":
@@ -563,7 +562,6 @@
                 self.time_series_dataset,
                 self,
                 self.predict_windows,
-                self,
                 self.add_relative_time_idx,
             )

From ac56d4fd56aeeb1287f162559c67e785de4446f4 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Fri, 11 Apr 2025 02:05:58 +0530
Subject: [PATCH 08/30] Add disclaimer

---
 pytorch_forecasting/data/data_module.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index 9a4a5bf5e..b33a11d47 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -1,6 +1,15 @@
+#######################################################################################
+# Disclaimer: This data-module is still a work in progress and experimental; please
+# use it with care. This data-module is a basic skeleton of what the data-handling
+# pipeline may look like in the future.
+# This is the D2 layer that will handle the preprocessing and data loaders.
+# For now, this pipeline handles the simplest situation: the whole dataset can be
+# loaded into memory.
+#######################################################################################
+
 from typing import Any, Dict, List, Optional, Tuple, Union

-from lightning.pytorch import LightningDataModule, LightningModule
+from lightning.pytorch import LightningDataModule
 from sklearn.preprocessing import RobustScaler, StandardScaler
 import torch
 from torch.utils.data import DataLoader, Dataset
@@ -13,7 +22,6 @@
 from pytorch_forecasting.data.timeseries import (
     TimeSeries,
     _coerce_to_dict,
-    _coerce_to_list,
 )

 NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer]
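With the fixes above in place, the datamodule can be exercised end to end before
the notebook update that follows. A minimal sketch, assuming `ds` is a
`TimeSeries` whose series are longer than max_encoder_length +
max_prediction_length (the stage string and loader protocol are standard
Lightning; batch layout follows the dataset's `__getitem__` contract, and the
parameter values are illustrative):

    from pytorch_forecasting.data.data_module import (
        EncoderDecoderTimeSeriesDataModule,
    )

    dm = EncoderDecoderTimeSeriesDataModule(
        time_series_dataset=ds,
        max_encoder_length=8,
        max_prediction_length=2,
        batch_size=32,
    )
    dm.setup("fit")  # builds train/val windows; preprocessing stays lazy
    x, y = next(iter(dm.train_dataloader()))
    # x: dict of encoder/decoder tensors; y: decoder targets, shaped
    # (batch, max_prediction_length, n_targets) after default collation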
From 25bc7ee135584d8fa6c83f5273fbbf04a3775c99 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Fri, 11 Apr 2025 02:12:33 +0530
Subject: [PATCH 09/30] update notebook as well

---
 examples/ptf_V2_example.ipynb | 2110 ++++++++++++++++-----------------
 1 file changed, 1055 insertions(+), 1055 deletions(-)

diff --git a/examples/ptf_V2_example.ipynb b/examples/ptf_V2_example.ipynb
index 031d9d634..0d61de395 100644
--- a/examples/ptf_V2_example.ipynb
+++ b/examples/ptf_V2_example.ipynb
@@ -8,7 +8,7 @@
      "base_uri": "https://localhost:8080/"
     },
     "id": "2630DaOEI4AJ",
-    "outputId": "96798236-d2f1-4436-c047-49c3771d56c7"
+    "outputId": "a6f99bf0-957b-431a-f512-6abba6629768"
    },
    "outputs": [
     {
    [... diff over notebook outputs elided: pip install log churn only
     (download speeds, ANSI-escape progress bars, the torchmetrics
     1.7.0 -> 1.7.1 and typing-extensions 4.13.0 -> 4.13.1 version bumps)
     and cell "execution_count" renumbering (8 -> 2, 9 -> 3, 17 -> 12) ...]
    "source": [
     "from typing import Dict, List, Optional, Union\n",
     "\n",
-    "from lightning.pytorch import LightningDataModule\n",
+    "from lightning.pytorch import LightningDataModule,
LightningModule\n", "from sklearn.preprocessing import RobustScaler, StandardScaler\n", "import torch\n", "from torch.utils.data import DataLoader, Dataset\n", @@ -715,7 +715,7 @@ " self._metadata = self._prepare_metadata()\n", " return self._metadata\n", "\n", - " def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]:\n", + " def _preprocess_data(self, series_idx: torch.Tensor) -> List[Dict[str, Any]]:\n", " \"\"\"Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset.\n", "\n", " Preprocessing steps\n", @@ -729,55 +729,48 @@ "\n", " TODO: add scalers, target normalizers etc.\n", " \"\"\"\n", - " processed_data = []\n", - "\n", - " for idx in indices:\n", - " sample = self.time_series_dataset[idx.item()]\n", - "\n", - " target = sample[\"y\"]\n", - " features = sample[\"x\"]\n", - " times = sample[\"t\"]\n", - " cutoff_time = sample[\"cutoff_time\"]\n", + " sample = self.time_series_dataset[series_idx]\n", "\n", - " time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool)\n", + " target = sample[\"y\"]\n", + " features = sample[\"x\"]\n", + " times = sample[\"t\"]\n", + " cutoff_time = sample[\"cutoff_time\"]\n", "\n", - " if isinstance(target, torch.Tensor):\n", - " target = target.float()\n", - " else:\n", - " target = torch.tensor(target, dtype=torch.float32)\n", + " time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool)\n", "\n", - " if isinstance(features, torch.Tensor):\n", - " features = features.float()\n", - " else:\n", - " features = torch.tensor(features, dtype=torch.float32)\n", + " if isinstance(target, torch.Tensor):\n", + " target = target.float()\n", + " else:\n", + " target = torch.tensor(target, dtype=torch.float32)\n", "\n", - " # TODO: add scalers, target normalizers etc.\n", + " if isinstance(features, torch.Tensor):\n", + " features = features.float()\n", + " else:\n", + " features = torch.tensor(features, dtype=torch.float32)\n", "\n", - " categorical = (\n", - " features[:, self.categorical_indices]\n", - " if self.categorical_indices\n", - " else torch.zeros((features.shape[0], 0))\n", - " )\n", - " continuous = (\n", - " features[:, self.continuous_indices]\n", - " if self.continuous_indices\n", - " else torch.zeros((features.shape[0], 0))\n", - " )\n", + " # TODO: add scalers, target normalizers etc.\n", "\n", - " processed_data.append(\n", - " {\n", - " \"features\": {\"categorical\": categorical, \"continuous\": continuous},\n", - " \"target\": target,\n", - " \"static\": sample.get(\"st\", None),\n", - " \"group\": sample.get(\"group\", torch.tensor([0])),\n", - " \"length\": len(target),\n", - " \"time_mask\": time_mask,\n", - " \"times\": times,\n", - " \"cutoff_time\": cutoff_time,\n", - " }\n", - " )\n", + " categorical = (\n", + " features[:, self.categorical_indices]\n", + " if self.categorical_indices\n", + " else torch.zeros((features.shape[0], 0))\n", + " )\n", + " continuous = (\n", + " features[:, self.continuous_indices]\n", + " if self.continuous_indices\n", + " else torch.zeros((features.shape[0], 0))\n", + " )\n", "\n", - " return processed_data\n", + " return {\n", + " \"features\": {\"categorical\": categorical, \"continuous\": continuous},\n", + " \"target\": target,\n", + " \"static\": sample.get(\"st\", None),\n", + " \"group\": sample.get(\"group\", torch.tensor([0])),\n", + " \"length\": len(target),\n", + " \"time_mask\": time_mask,\n", + " \"times\": times,\n", + " \"cutoff_time\": cutoff_time,\n", + " }\n", "\n", " class _ProcessedEncoderDecoderDataset(Dataset):\n", " 
\"\"\"PyTorch Dataset for processed encoder-decoder time series data.\n", @@ -795,11 +788,13 @@ "\n", " def __init__(\n", " self,\n", - " processed_data: List[Dict[str, Any]],\n", + " dataset: TimeSeries,\n", + " data_module: \"EncoderDecoderTimeSeriesDataModule\",\n", " windows: List[Tuple[int, int, int, int]],\n", " add_relative_time_idx: bool = False,\n", " ):\n", - " self.processed_data = processed_data\n", + " self.dataset = dataset\n", + " self.data_module = data_module\n", " self.windows = windows\n", " self.add_relative_time_idx = add_relative_time_idx\n", "\n", @@ -850,8 +845,9 @@ " y : tensor of shape ``(pred_length, n_targets)``\n", " Target values for the decoder sequence.\n", " \"\"\"\n", + "\n", " series_idx, start_idx, enc_length, pred_length = self.windows[idx]\n", - " data = self.processed_data[series_idx]\n", + " data = self.data_module._preprocess_data(series_idx)\n", "\n", " end_idx = start_idx + enc_length + pred_length\n", " encoder_indices = slice(start_idx, start_idx + enc_length)\n", @@ -898,9 +894,7 @@ "\n", " return x, y\n", "\n", - " def _create_windows(\n", - " self, processed_data: List[Dict[str, Any]]\n", - " ) -> List[Tuple[int, int, int, int]]:\n", + " def _create_windows(self, indices: torch.Tensor) -> List[Tuple[int, int, int, int]]:\n", " \"\"\"Generate sliding windows for training, validation, and testing.\n", "\n", " Returns\n", @@ -918,8 +912,10 @@ " \"\"\"\n", " windows = []\n", "\n", - " for idx, data in enumerate(processed_data):\n", - " sequence_length = data[\"length\"]\n", + " for idx in indices:\n", + " series_idx = idx.item()\n", + " sample = self.time_series_dataset[series_idx]\n", + " sequence_length = len(sample[\"y\"])\n", "\n", " if sequence_length < self.max_encoder_length + self.max_prediction_length:\n", " continue\n", @@ -944,7 +940,7 @@ " ):\n", " windows.append(\n", " (\n", - " idx,\n", + " series_idx,\n", " start_idx,\n", " self.max_encoder_length,\n", " self.max_prediction_length,\n", @@ -979,34 +975,39 @@ "\n", " if stage is None or stage == \"fit\":\n", " if not hasattr(self, \"train_dataset\") or not hasattr(self, \"val_dataset\"):\n", - " self.train_processed = self._preprocess_data(self._train_indices)\n", - " self.val_processed = self._preprocess_data(self._val_indices)\n", - "\n", - " self.train_windows = self._create_windows(self.train_processed)\n", - " self.val_windows = self._create_windows(self.val_processed)\n", + " self.train_windows = self._create_windows(self._train_indices)\n", + " self.val_windows = self._create_windows(self._val_indices)\n", "\n", " self.train_dataset = self._ProcessedEncoderDecoderDataset(\n", - " self.train_processed, self.train_windows, self.add_relative_time_idx\n", + " self.time_series_dataset,\n", + " self,\n", + " self.train_windows,\n", + " self.add_relative_time_idx,\n", " )\n", " self.val_dataset = self._ProcessedEncoderDecoderDataset(\n", - " self.val_processed, self.val_windows, self.add_relative_time_idx\n", + " self.time_series_dataset,\n", + " self,\n", + " self.val_windows,\n", + " self.add_relative_time_idx,\n", " )\n", - " # print(self.val_dataset[0])\n", "\n", - " elif stage is None or stage == \"test\":\n", + " elif stage == \"test\":\n", " if not hasattr(self, \"test_dataset\"):\n", - " self.test_processed = self._preprocess_data(self._test_indices)\n", - " self.test_windows = self._create_windows(self.test_processed)\n", - "\n", + " self.test_windows = self._create_windows(self._test_indices)\n", " self.test_dataset = self._ProcessedEncoderDecoderDataset(\n", - " 
self.test_processed, self.test_windows, self.add_relative_time_idx\n", + " self.time_series_dataset,\n", + " self,\n", + " self.test_windows,\n", + " self.add_relative_time_idx,\n", " )\n", " elif stage == \"predict\":\n", " predict_indices = torch.arange(len(self.time_series_dataset))\n", - " self.predict_processed = self._preprocess_data(predict_indices)\n", - " self.predict_windows = self._create_windows(self.predict_processed)\n", + " self.predict_windows = self._create_windows(predict_indices)\n", " self.predict_dataset = self._ProcessedEncoderDecoderDataset(\n", - " self.predict_processed, self.predict_windows, self.add_relative_time_idx\n", + " self.time_series_dataset,\n", + " self,\n", + " self.predict_windows,\n", + " self.add_relative_time_idx,\n", " )\n", "\n", " def train_dataloader(self):\n", @@ -1076,26 +1077,26 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "WX-FRdusJSVN", - "outputId": "730f0fe2-f5af-4871-859d-9a4043bbeac7" + "outputId": "2b7ae9bd-0bee-4c05-a512-61193b462274" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { - "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6723608094711289,\n \"min\": -1.2295384314749236,\n \"max\": 1.3194322331654313,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.3958947786657995,\n 0.7816648993958805,\n -0.9655256111265276\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6766981377008536,\n \"min\": -1.2295384314749236,\n \"max\": 1.3194322331654313,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.5707668517530808,\n 0.5020485177883972,\n -0.7734543579445009\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n -0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2783997753182645,\n \"min\": 0.011560494046953695,\n \"max\": 0.9996855497257285,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.43066463390546217,\n 0.08751257529405387,\n 0.3593350820130162\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n 
\"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", + "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6729063389612355,\n \"min\": -1.298580096280071,\n \"max\": 1.2952368656723652,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.4192357980400779,\n 0.7249265048690496,\n -0.966494107115657\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6772588242164835,\n \"min\": -1.298580096280071,\n \"max\": 1.2952368656723652,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.6408981347050933,\n 0.758204270990956,\n -0.6996142581496693\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n -0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2897600721149767,\n \"min\": 0.009824875416031609,\n \"max\": 0.9977597461896692,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.18589175791370482,\n 0.8789086949480461,\n 0.010073414332526953\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", "type": "dataframe", "variable_name": "data_df" }, "text/html": [ "\n", - "
\n", + "
\n", "
\n", "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
000-0.0712220.33976301.0000000.6264260
1010.3397630.18934800.9950040.6264260
2020.1893480.67598900.9800670.6264260
3030.6759890.79726100.9553360.6264260
4040.7972610.99501600.9210610.6264260
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "\n", - "
\n", - " \n", - "\n", - "\n", - "\n", - " \n", - "
\n", - "\n", - "
\n", - "
\n" + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting pytorch-forecasting\n", + " Downloading pytorch_forecasting-1.3.0-py3-none-any.whl.metadata (13 kB)\n", + "Requirement already satisfied: numpy<=3.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.0.2)\n", + "Requirement already satisfied: torch!=2.0.1,<3.0.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.6.0+cu124)\n", + "Collecting lightning<3.0.0,>=2.0.0 (from pytorch-forecasting)\n", + " Downloading lightning-2.5.1-py3-none-any.whl.metadata (39 kB)\n", + "Requirement already satisfied: scipy<2.0,>=1.8 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.14.1)\n", + "Requirement already satisfied: pandas<3.0.0,>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.2.2)\n", + "Requirement already satisfied: scikit-learn<2.0,>=1.2 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.6.1)\n", + "Requirement already satisfied: PyYAML<8.0,>=5.4 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.0.2)\n", + "Requirement already satisfied: fsspec<2026.0,>=2022.5.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2025.3.2)\n", + "Collecting lightning-utilities<2.0,>=0.10.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)\n", + "Requirement already satisfied: packaging<25.0,>=20.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (24.2)\n", + "Collecting torchmetrics<3.0,>=0.7.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)\n", + "Requirement already satisfied: tqdm<6.0,>=4.57.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.67.1)\n", + "Requirement already satisfied: typing-extensions<6.0,>=4.4.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.13.1)\n", + "Collecting pytorch-lightning (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading pytorch_lightning-2.5.1-py3-none-any.whl.metadata (20 kB)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", + "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (3.6.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.18.0)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.4.2)\n", + "Requirement already satisfied: jinja2 in 
/usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.1.6)\n", + "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-curand-cu12==10.3.5.147 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (0.6.2)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (2.21.5)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (12.4.127)\n", + "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.2.0)\n", + "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.0)\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.11.15)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from 
lightning-utilities<2.0,>=0.10.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (75.2.0)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas<3.0.0,>=1.3.0->pytorch-forecasting) (1.17.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.0.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2.6.1)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.2)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (25.3.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.5.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.3.2)\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (0.3.1)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.18.3)\n", + "Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.11/dist-packages (from yarl<2.0,>=1.17.0->aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.10)\n", + "Downloading pytorch_forecasting-1.3.0-py3-none-any.whl (197 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m197.7/197.7 kB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading lightning-2.5.1-py3-none-any.whl (818 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m818.9/818.9 kB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m58.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m55.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading 
 nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n", +      "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", +      "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n", +      "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n", +      "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n", +      "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n", +      "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", +      "\u001b[?25hDownloading lightning_utilities-0.14.3-py3-none-any.whl (28 kB)\n", +      "Downloading torchmetrics-1.7.1-py3-none-any.whl (961 kB)\n", +      "\u001b[?25hDownloading pytorch_lightning-2.5.1-py3-none-any.whl (822 kB)\n", +      "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, lightning-utilities, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torchmetrics, pytorch-lightning, lightning, pytorch-forecasting\n", +      "  Attempting uninstall: nvidia-nvjitlink-cu12\n", +      "    Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n", +      "    Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n", +      "      Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n", +      "  Attempting uninstall: nvidia-curand-cu12\n", +      "    Found existing installation: nvidia-curand-cu12 10.3.6.82\n", +      "    Uninstalling nvidia-curand-cu12-10.3.6.82:\n", +      "      Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n", +      "  Attempting uninstall: nvidia-cufft-cu12\n", +      "    Found existing installation: nvidia-cufft-cu12 11.2.3.61\n", +      "    Uninstalling nvidia-cufft-cu12-11.2.3.61:\n", +      "      Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n", +      "  Attempting uninstall:  
nvidia-cuda-runtime-cu12\n", + " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n", + " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cuda-cupti-cu12\n", + " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cublas-cu12\n", + " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n", + " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n", + " Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n", + " Attempting uninstall: nvidia-cusparse-cu12\n", + " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n", + " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n", + " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n", + " Attempting uninstall: nvidia-cudnn-cu12\n", + " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n", + " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n", + " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n", + " Attempting uninstall: nvidia-cusolver-cu12\n", + " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n", + " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n", + " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n", + "Successfully installed lightning-2.5.1 lightning-utilities-0.14.3 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch-forecasting-1.3.0 pytorch-lightning-2.5.1 torchmetrics-1.7.1\n" + ] + } ], - "text/plain": [ - " series_id time_idx x y category future_known_feature \\\n", - "0 0 0 -0.071222 0.339763 0 1.000000 \n", - "1 0 1 0.339763 0.189348 0 0.995004 \n", - "2 0 2 0.189348 0.675989 0 0.980067 \n", - "3 0 3 0.675989 0.797261 0 0.955336 \n", - "4 0 4 0.797261 0.995016 0 0.921061 \n", - "\n", - " static_feature static_feature_cat \n", - "0 0.626426 0 \n", - "1 0.626426 0 \n", - "2 0.626426 0 \n", - "3 0.626426 0 \n", - "4 0.626426 0 " + "source": [ + "!pip install pytorch-forecasting" ] }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from lightning.pytorch import Trainer\n", - "import pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "\n", - "from pytorch_forecasting.metrics import MAE, SMAPE\n", - "\n", - "num_series = 100\n", - "seq_length = 50\n", - "data_list = []\n", - "for i in range(num_series):\n", - " x = np.arange(seq_length)\n", - " y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)\n", - " category = i % 5\n", - " static_value = np.random.rand()\n", - " for t in range(seq_length - 1):\n", - " data_list.append(\n", - " {\n", - " \"series_id\": i,\n", - " \"time_idx\": t,\n", - " \"x\": y[t],\n", - " \"y\": y[t + 1],\n", - " \"category\": category,\n", - " \"future_known_feature\": np.cos(t / 10),\n", - " \"static_feature\": static_value,\n", - " \"static_feature_cat\": i % 3,\n", - " }\n", - " )\n", - "data_df = pd.DataFrame(data_list)\n", - 
"data_df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "AxxPHK6AKSD2" - }, - "outputs": [], - "source": [ - "dataset = TimeSeries(\n", - " data=data_df,\n", - " time=\"time_idx\",\n", - " target=\"y\",\n", - " group=[\"series_id\"],\n", - " num=[\"x\", \"future_known_feature\", \"static_feature\"],\n", - " cat=[\"category\", \"static_feature_cat\"],\n", - " known=[\"future_known_feature\"],\n", - " unknown=[\"x\", \"category\"],\n", - " static=[\"static_feature\", \"static_feature_cat\"],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "5U5Lr_ZFKX0s" - }, - "outputs": [], - "source": [ - "data_module = EncoderDecoderTimeSeriesDataModule(\n", - " time_series_dataset=dataset,\n", - " max_encoder_length=30,\n", - " max_prediction_length=1,\n", - " batch_size=32,\n", - " categorical_encoders={\n", - " \"category\": NaNLabelEncoder(add_nan=True),\n", - " \"static_feature_cat\": NaNLabelEncoder(add_nan=True),\n", - " },\n", - " scalers={\n", - " \"x\": StandardScaler(),\n", - " \"future_known_feature\": StandardScaler(),\n", - " \"static_feature\": StandardScaler(),\n", - " },\n", - " target_normalizer=TorchNormalizer(),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "I8NgHxNqK9uV" - }, - "outputs": [], - "source": [ - "from typing import Dict, List, Optional, Union\n", - "\n", - "from lightning.pytorch.utilities.types import STEP_OUTPUT\n", - "import torch\n", - "from torch.optim import Optimizer\n", - "\n", - "\n", - "class BaseModel(LightningModule):\n", - " def __init__(\n", - " self,\n", - " loss: nn.Module,\n", - " logging_metrics: Optional[List[nn.Module]] = None,\n", - " optimizer: Optional[Union[Optimizer, str]] = \"adam\",\n", - " optimizer_params: Optional[Dict] = None,\n", - " lr_scheduler: Optional[str] = None,\n", - " lr_scheduler_params: Optional[Dict] = None,\n", - " ):\n", - " \"\"\"\n", - " Base model for time series forecasting.\n", - "\n", - " Parameters\n", - " ----------\n", - " loss : nn.Module\n", - " Loss function to use for training.\n", - " logging_metrics : Optional[List[nn.Module]], optional\n", - " List of metrics to log during training, validation, and testing.\n", - " optimizer : Optional[Union[Optimizer, str]], optional\n", - " Optimizer to use for training.\n", - " Can be a string (\"adam\", \"sgd\") or an instance of `torch.optim.Optimizer`.\n", - " optimizer_params : Optional[Dict], optional\n", - " Parameters for the optimizer.\n", - " lr_scheduler : Optional[str], optional\n", - " Learning rate scheduler to use.\n", - " Supported values: \"reduce_lr_on_plateau\", \"step_lr\".\n", - " lr_scheduler_params : Optional[Dict], optional\n", - " Parameters for the learning rate scheduler.\n", - " \"\"\"\n", - " super().__init__()\n", - " self.loss = loss\n", - " self.logging_metrics = logging_metrics if logging_metrics is not None else []\n", - " self.optimizer = optimizer\n", - " self.optimizer_params = optimizer_params if optimizer_params is not None else {}\n", - " self.lr_scheduler = lr_scheduler\n", - " self.lr_scheduler_params = (\n", - " lr_scheduler_params if lr_scheduler_params is not None else {}\n", - " )\n", - "\n", - " def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n", - " \"\"\"\n", - " Forward pass of the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " x : Dict[str, torch.Tensor]\n", - " Dictionary containing input tensors\n", - "\n", - " Returns\n", - " -------\n", - " 
Dict[str, torch.Tensor]\n", - " Dictionary containing output tensors\n", - " \"\"\"\n", - " raise NotImplementedError(\"Forward method must be implemented by subclass.\")\n", - "\n", - " def training_step(\n", - " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", - " ) -> STEP_OUTPUT:\n", - " \"\"\"\n", - " Training step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input and target tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - "\n", - " Returns\n", - " -------\n", - " STEP_OUTPUT\n", - " Dictionary containing the loss and other metrics.\n", - " \"\"\"\n", - " x, y = batch\n", - " y_hat_dict = self(x)\n", - " y_hat = y_hat_dict[\"prediction\"]\n", - " loss = self.loss(y_hat, y)\n", - " self.log(\n", - " \"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True\n", - " )\n", - " self.log_metrics(y_hat, y, prefix=\"train\")\n", - " return {\"loss\": loss}\n", - "\n", - " def validation_step(\n", - " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", - " ) -> STEP_OUTPUT:\n", - " \"\"\"\n", - " Validation step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input and target tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - "\n", - " Returns\n", - " -------\n", - " STEP_OUTPUT\n", - " Dictionary containing the loss and other metrics.\n", - " \"\"\"\n", - " x, y = batch\n", - " y_hat_dict = self(x)\n", - " y_hat = y_hat_dict[\"prediction\"]\n", - " loss = self.loss(y_hat, y)\n", - " self.log(\n", - " \"val_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n", - " )\n", - " self.log_metrics(y_hat, y, prefix=\"val\")\n", - " return {\"val_loss\": loss}\n", - "\n", - " def test_step(\n", - " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", - " ) -> STEP_OUTPUT:\n", - " \"\"\"\n", - " Test step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input and target tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - "\n", - " Returns\n", - " -------\n", - " STEP_OUTPUT\n", - " Dictionary containing the loss and other metrics.\n", - " \"\"\"\n", - " x, y = batch\n", - " y_hat_dict = self(x)\n", - " y_hat = y_hat_dict[\"prediction\"]\n", - " loss = self.loss(y_hat, y)\n", - " self.log(\n", - " \"test_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n", - " )\n", - " self.log_metrics(y_hat, y, prefix=\"test\")\n", - " return {\"test_loss\": loss}\n", - "\n", - " def predict_step(\n", - " self,\n", - " batch: Tuple[Dict[str, torch.Tensor]],\n", - " batch_idx: int,\n", - " dataloader_idx: int = 0,\n", - " ) -> torch.Tensor:\n", - " \"\"\"\n", - " Prediction step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - " dataloader_idx : int\n", - " Index of the dataloader.\n", - "\n", - " Returns\n", - " -------\n", - " torch.Tensor\n", - " Predicted output tensor.\n", - " \"\"\"\n", - " x, _ = batch\n", - " y_hat = self(x)\n", - " return y_hat\n", - "\n", - " def configure_optimizers(self) -> Dict:\n", - " \"\"\"\n", - " Configure the optimizer and learning rate scheduler.\n", - "\n", - " Returns\n", - " -------\n", - " 
Dict\n", - " Dictionary containing the optimizer and scheduler configuration.\n", - " \"\"\"\n", - " optimizer = self._get_optimizer()\n", - " if self.lr_scheduler is not None:\n", - " scheduler = self._get_scheduler(optimizer)\n", - " if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n", - " return {\n", - " \"optimizer\": optimizer,\n", - " \"lr_scheduler\": {\n", - " \"scheduler\": scheduler,\n", - " \"monitor\": \"val_loss\",\n", - " },\n", - " }\n", - " else:\n", - " return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n", - " return {\"optimizer\": optimizer}\n", - "\n", - " def _get_optimizer(self) -> Optimizer:\n", - " \"\"\"\n", - " Get the optimizer based on the specified optimizer name and parameters.\n", - "\n", - " Returns\n", - " -------\n", - " Optimizer\n", - " The optimizer instance.\n", - " \"\"\"\n", - " if isinstance(self.optimizer, str):\n", - " if self.optimizer.lower() == \"adam\":\n", - " return torch.optim.Adam(self.parameters(), **self.optimizer_params)\n", - " elif self.optimizer.lower() == \"sgd\":\n", - " return torch.optim.SGD(self.parameters(), **self.optimizer_params)\n", - " else:\n", - " raise ValueError(f\"Optimizer {self.optimizer} not supported.\")\n", - " elif isinstance(self.optimizer, Optimizer):\n", - " return self.optimizer\n", - " else:\n", - " raise ValueError(\n", - " \"Optimizer must be either a string or \"\n", - " \"an instance of torch.optim.Optimizer.\"\n", - " )\n", - "\n", - " def _get_scheduler(\n", - " self, optimizer: Optimizer\n", - " ) -> torch.optim.lr_scheduler._LRScheduler:\n", - " \"\"\"\n", - " Get the lr scheduler based on the specified scheduler name and params.\n", - "\n", - " Parameters\n", - " ----------\n", - " optimizer : Optimizer\n", - " The optimizer instance.\n", - "\n", - " Returns\n", - " -------\n", - " torch.optim.lr_scheduler._LRScheduler\n", - " The learning rate scheduler instance.\n", - " \"\"\"\n", - " if self.lr_scheduler.lower() == \"reduce_lr_on_plateau\":\n", - " return torch.optim.lr_scheduler.ReduceLROnPlateau(\n", - " optimizer, **self.lr_scheduler_params\n", - " )\n", - " elif self.lr_scheduler.lower() == \"step_lr\":\n", - " return torch.optim.lr_scheduler.StepLR(\n", - " optimizer, **self.lr_scheduler_params\n", - " )\n", - " else:\n", - " raise ValueError(f\"Scheduler {self.lr_scheduler} not supported.\")\n", - "\n", - " def log_metrics(\n", - " self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = \"val\"\n", - " ) -> None:\n", - " \"\"\"\n", - " Log additional metrics during training, validation, or testing.\n", - "\n", - " Parameters\n", - " ----------\n", - " y_hat : torch.Tensor\n", - " Predicted output tensor.\n", - " y : torch.Tensor\n", - " Target output tensor.\n", - " prefix : str\n", - " Prefix for the logged metrics (e.g., \"train\", \"val\", \"test\").\n", - " \"\"\"\n", - " for metric in self.logging_metrics:\n", - " metric_value = metric(y_hat, y)\n", - " self.log(\n", - " f\"{prefix}_{metric.__class__.__name__}\",\n", - " metric_value,\n", - " on_step=False,\n", - " on_epoch=True,\n", - " prog_bar=True,\n", - " logger=True,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "n5EBeGucK_k0" - }, - "outputs": [], - "source": [ - "from typing import Dict, List, Optional, Tuple, Union\n", - "\n", - "import torch\n", - "import torch.nn as nn\n", - "from torch.optim import Optimizer\n", - "\n", - "\n", - "class TFT(BaseModel):\n", - " def __init__(\n", - " self,\n", - " loss: nn.Module,\n", - " logging_metrics: 
Optional[List[nn.Module]] = None,\n", - " optimizer: Optional[Union[Optimizer, str]] = \"adam\",\n", - " optimizer_params: Optional[Dict] = None,\n", - " lr_scheduler: Optional[str] = None,\n", - " lr_scheduler_params: Optional[Dict] = None,\n", - " hidden_size: int = 64,\n", - " num_layers: int = 2,\n", - " attention_head_size: int = 4,\n", - " dropout: float = 0.1,\n", - " metadata: Optional[Dict] = None,\n", - " output_size: int = 1,\n", - " ):\n", - " super().__init__(\n", - " loss=loss,\n", - " logging_metrics=logging_metrics,\n", - " optimizer=optimizer,\n", - " optimizer_params=optimizer_params,\n", - " lr_scheduler=lr_scheduler,\n", - " lr_scheduler_params=lr_scheduler_params,\n", - " )\n", - " self.hidden_size = hidden_size\n", - " self.num_layers = num_layers\n", - " self.attention_head_size = attention_head_size\n", - " self.dropout = dropout\n", - " self.metadata = metadata\n", - " self.output_size = output_size\n", - "\n", - " self.max_encoder_length = self.metadata[\"max_encoder_length\"]\n", - " self.max_prediction_length = self.metadata[\"max_prediction_length\"]\n", - " self.encoder_cont = self.metadata[\"encoder_cont\"]\n", - " self.encoder_cat = self.metadata[\"encoder_cat\"]\n", - " self.static_categorical_features = self.metadata[\"static_categorical_features\"]\n", - " self.static_continuous_features = self.metadata[\"static_continuous_features\"]\n", - "\n", - " total_feature_size = self.encoder_cont + self.encoder_cat\n", - " total_static_size = (\n", - " self.static_categorical_features + self.static_continuous_features\n", - " )\n", - "\n", - " self.encoder_var_selection = nn.Sequential(\n", - " nn.Linear(total_feature_size, hidden_size),\n", - " nn.ReLU(),\n", - " nn.Linear(hidden_size, total_feature_size),\n", - " nn.Sigmoid(),\n", - " )\n", - "\n", - " self.decoder_var_selection = nn.Sequential(\n", - " nn.Linear(total_feature_size, hidden_size),\n", - " nn.ReLU(),\n", - " nn.Linear(hidden_size, total_feature_size),\n", - " nn.Sigmoid(),\n", - " )\n", - "\n", - " self.static_context_linear = (\n", - " nn.Linear(total_static_size, hidden_size) if total_static_size > 0 else None\n", - " )\n", - "\n", - " self.lstm_encoder = nn.LSTM(\n", - " input_size=total_feature_size,\n", - " hidden_size=hidden_size,\n", - " num_layers=num_layers,\n", - " dropout=dropout,\n", - " batch_first=True,\n", - " )\n", - "\n", - " self.lstm_decoder = nn.LSTM(\n", - " input_size=total_feature_size,\n", - " hidden_size=hidden_size,\n", - " num_layers=num_layers,\n", - " dropout=dropout,\n", - " batch_first=True,\n", - " )\n", - "\n", - " self.self_attention = nn.MultiheadAttention(\n", - " embed_dim=hidden_size,\n", - " num_heads=attention_head_size,\n", - " dropout=dropout,\n", - " batch_first=True,\n", - " )\n", - "\n", - " self.pre_output = nn.Linear(hidden_size, hidden_size)\n", - " self.output_layer = nn.Linear(hidden_size, output_size)\n", - "\n", - " def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n", - " \"\"\"\n", - " Forward pass of the TFT model.\n", - "\n", - " Parameters\n", - " ----------\n", - " x : Dict[str, torch.Tensor]\n", - " Dictionary containing input tensors:\n", - " - encoder_cat: Categorical encoder features\n", - " - encoder_cont: Continuous encoder features\n", - " - decoder_cat: Categorical decoder features\n", - " - decoder_cont: Continuous decoder features\n", - " - static_categorical_features: Static categorical features\n", - " - static_continuous_features: Static continuous features\n", - "\n", - " Returns\n", - " -------\n", - 
" Dict[str, torch.Tensor]\n", - " Dictionary containing output tensors:\n", - " - prediction: Prediction output (batch_size, prediction_length, output_size)\n", - " \"\"\"\n", - " batch_size = x[\"encoder_cont\"].shape[0]\n", - "\n", - " encoder_cat = x.get(\n", - " \"encoder_cat\",\n", - " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", - " )\n", - " encoder_cont = x.get(\n", - " \"encoder_cont\",\n", - " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", - " )\n", - " decoder_cat = x.get(\n", - " \"decoder_cat\",\n", - " torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", - " )\n", - " decoder_cont = x.get(\n", - " \"decoder_cont\",\n", - " torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", - " )\n", - "\n", - " encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2)\n", - " decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2)\n", - "\n", - " static_context = None\n", - " if self.static_context_linear is not None:\n", - " static_cat = x.get(\n", - " \"static_categorical_features\",\n", - " torch.zeros(batch_size, 0, device=self.device),\n", - " )\n", - " static_cont = x.get(\n", - " \"static_continuous_features\",\n", - " torch.zeros(batch_size, 0, device=self.device),\n", - " )\n", - "\n", - " if static_cat.size(2) == 0 and static_cont.size(2) == 0:\n", - " static_context = None\n", - " elif static_cat.size(2) == 0:\n", - " static_input = static_cont.to(\n", - " dtype=self.static_context_linear.weight.dtype\n", - " )\n", - " static_context = self.static_context_linear(static_input)\n", - " static_context = static_context.view(batch_size, self.hidden_size)\n", - " elif static_cont.size(2) == 0:\n", - " static_input = static_cat.to(\n", - " dtype=self.static_context_linear.weight.dtype\n", - " )\n", - " static_context = self.static_context_linear(static_input)\n", - " static_context = static_context.view(batch_size, self.hidden_size)\n", - " else:\n", - "\n", - " static_input = torch.cat([static_cont, static_cat], dim=1).to(\n", - " dtype=self.static_context_linear.weight.dtype\n", - " )\n", - " static_context = self.static_context_linear(static_input)\n", - " static_context = static_context.view(batch_size, self.hidden_size)\n", - "\n", - " encoder_weights = self.encoder_var_selection(encoder_input)\n", - " encoder_input = encoder_input * encoder_weights\n", - "\n", - " decoder_weights = self.decoder_var_selection(decoder_input)\n", - " decoder_input = decoder_input * decoder_weights\n", - "\n", - " if static_context is not None:\n", - " encoder_static_context = static_context.unsqueeze(1).expand(\n", - " -1, self.max_encoder_length, -1\n", - " )\n", - " decoder_static_context = static_context.unsqueeze(1).expand(\n", - " -1, self.max_prediction_length, -1\n", - " )\n", - "\n", - " encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n", - " encoder_output = encoder_output + encoder_static_context\n", - " decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n", - " decoder_output = decoder_output + decoder_static_context\n", - " else:\n", - " encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n", - " decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n", - "\n", - " sequence = torch.cat([encoder_output, decoder_output], dim=1)\n", - "\n", - " if static_context is not None:\n", - " expanded_static_context = static_context.unsqueeze(1).expand(\n", - " -1, sequence.size(1), -1\n", - " )\n", - "\n", - " attended_output, _ 
 = self.self_attention(\n", -    "            sequence + expanded_static_context, sequence, sequence\n", -    "        )\n", -    "        else:\n", -    "            attended_output, _ = self.self_attention(sequence, sequence, sequence)\n", -    "\n", -    "        decoder_attended = attended_output[:, -self.max_prediction_length :, :]\n", -    "\n", -    "        output = nn.functional.relu(self.pre_output(decoder_attended))\n", -    "        prediction = self.output_layer(output)\n", -    "\n", -    "        return {\"prediction\": prediction}" -   ] -  }, -  { -   "cell_type": "code", -   "execution_count": 15, -   "metadata": { -    "colab": { -     "base_uri": "https://localhost:8080/", -     "height": 1000 -    }, -    "id": "Si7bbZIULBZz", -    "outputId": "b5e8d9c9-e1a3-4632-b764-a3b7b1931012" -   }, -   "outputs": [ -    { -     "name": "stderr", -     "output_type": "stream", -     "text": [ -      "INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.\n", -      "INFO: GPU available: False, used: False\n", -      "INFO: TPU available: False, using: 0 TPU cores\n", -      "INFO: HPU available: False, using: 0 HPUs\n", -      "INFO: \n", -      "  | Name                  | Type               | Params | Mode \n", -      "---------------------------------------------------------------------\n", -      "0 | loss                  | MSELoss            | 0      | train\n", -      "1 | encoder_var_selection | Sequential         | 709    | train\n", -      "2 | decoder_var_selection | Sequential         | 709    | train\n", -      "3 | static_context_linear | Linear             | 192    | train\n", -      "4 | lstm_encoder          | LSTM               | 51.5 K | train\n", -      "5 | lstm_decoder          | LSTM               | 51.5 K | train\n", -      "6 | self_attention        | MultiheadAttention | 16.6 K | train\n", -      "7 | pre_output            | Linear             | 4.2 K  | train\n", -      "8 | output_layer          | Linear             | 65     | train\n", -      "---------------------------------------------------------------------\n", -      "125 K     Trainable params\n", -      "0         Non-trainable params\n", -      "125 K     Total params\n", -      "0.502     Total estimated model params size (MB)\n", -      "18        Modules in train mode\n", -      "0         Modules in eval mode\n" -     ] -    }, -    { -     "name": "stdout", -     "output_type": "stream", -     "text": [ -      "\n", -      "Training model...\n" -     ] -    }, -    { -     "data": { -      "text/plain": [ -       "Sanity Checking: |          |  
0/? [00:00 int:\n", + " \"\"\"Return number of time series in the dataset.\"\"\"\n", + " return len(self._group_ids)\n", + "\n", + " def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:\n", + " \"\"\"Get time series data for given index.\n", + "\n", + " Returns\n", + " -------\n", + " t : numpy.ndarray of shape (n_timepoints,)\n", + " Time index for each time point in the past or present. Aligned with `y`,\n", + " and `x` not ending in `f`.\n", + "\n", + " y : torch.Tensor of shape (n_timepoints, n_targets)\n", + " Target values for each time point. Rows are time points, aligned with `t`.\n", + "\n", + " x : torch.Tensor of shape (n_timepoints, n_features)\n", + " Features for each time point. Rows are time points, aligned with `t`.\n", + "\n", + " group : torch.Tensor of shape (n_groups,)\n", + " Group identifiers for time series instances.\n", + "\n", + " st : torch.Tensor of shape (n_static_features,)\n", + " Static features.\n", + "\n", + " cutoff_time : float or numpy.float64\n", + " Cutoff time for the time series instance.\n", + "\n", + " Other Returns\n", + " -------------\n", + " weights : torch.Tensor of shape (n_timepoints,), optional\n", + " Only included if weights are not `None`.\n", + " \"\"\"\n", + " group_id = self._group_ids[index]\n", + "\n", + " if self.group:\n", + " mask = self._groups[group_id]\n", + " data = self.data.loc[mask]\n", + " else:\n", + " data = self.data\n", + "\n", + " cutoff_time = data[self.time].max()\n", + "\n", + " result = {\n", + " \"t\": data[self.time].values,\n", + " \"y\": torch.tensor(data[self.target].values),\n", + " \"x\": torch.tensor(data[self.feature_cols].values),\n", + " \"group\": torch.tensor([hash(str(group_id))]),\n", + " \"st\": torch.tensor(data[self.static].iloc[0].values if self.static else []),\n", + " \"cutoff_time\": cutoff_time,\n", + " }\n", + "\n", + " if self.data_future is not None:\n", + " if self.group:\n", + " future_mask = self.data_future.groupby(self.group).groups[group_id]\n", + " future_data = self.data_future.loc[future_mask]\n", + " else:\n", + " future_data = self.data_future\n", + "\n", + " combined_times = np.concatenate(\n", + " [data[self.time].values, future_data[self.time].values]\n", + " )\n", + " combined_times = np.unique(combined_times)\n", + " combined_times.sort()\n", + "\n", + " num_timepoints = len(combined_times)\n", + " x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan)\n", + " y_merged = np.full((num_timepoints, len(self.target)), np.nan)\n", + "\n", + " current_time_indices = {t: i for i, t in enumerate(combined_times)}\n", + " for i, t in enumerate(data[self.time].values):\n", + " idx = current_time_indices[t]\n", + " x_merged[idx] = data[self.feature_cols].values[i]\n", + " y_merged[idx] = data[self.target].values[i]\n", + "\n", + " for i, t in enumerate(future_data[self.time].values):\n", + " if t in current_time_indices:\n", + " idx = current_time_indices[t]\n", + " for j, col in enumerate(self.known):\n", + " if col in self.feature_cols:\n", + " feature_idx = self.feature_cols.index(col)\n", + " x_merged[idx, feature_idx] = future_data[col].values[i]\n", + "\n", + " result.update(\n", + " {\n", + " \"t\": combined_times,\n", + " \"x\": torch.tensor(x_merged, dtype=torch.float32),\n", + " \"y\": torch.tensor(y_merged, dtype=torch.float32),\n", + " }\n", + " )\n", + "\n", + " if self.weight:\n", + " if self.data_future is not None and self.weight in self.data_future.columns:\n", + " weights_merged = np.full(num_timepoints, np.nan)\n", + " for i, t in 
enumerate(data[self.time].values):\n", + " idx = current_time_indices[t]\n", + " weights_merged[idx] = data[self.weight].values[i]\n", + "\n", + " for i, t in enumerate(future_data[self.time].values):\n", + " if t in current_time_indices and self.weight in future_data.columns:\n", + " idx = current_time_indices[t]\n", + " weights_merged[idx] = future_data[self.weight].values[i]\n", + "\n", + " result[\"weights\"] = torch.tensor(weights_merged, dtype=torch.float32)\n", + " else:\n", + " result[\"weights\"] = torch.tensor(\n", + " data[self.weight].values, dtype=torch.float32\n", + " )\n", + "\n", + " return result\n", + "\n", + " def get_metadata(self) -> Dict:\n", + " \"\"\"Return metadata about the dataset.\n", + "\n", + " Returns\n", + " -------\n", + " Dict\n", + " Dictionary containing:\n", + " - cols: column names for y, x, and static features\n", + " - col_type: mapping of columns to their types (F/C)\n", + " - col_known: mapping of columns to their future known status (K/U)\n", + " \"\"\"\n", + " return self.metadata" ] }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "a71969db0c5644efb9161637a12a32d2", - "version_major": 2, - "version_minor": 0 + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0Rw9LgsXJI5V" }, - "text/plain": [ - "Validation: | | 0/? [00:00 List[Dict[str, Any]]:\n", + " \"\"\"Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset.\n", + "\n", + " Preprocessing steps\n", + " --------------------\n", + "\n", + " * Converts target (`y`) and features (`x`) to `torch.float32`.\n", + " * Masks time points that are at or before the cutoff time.\n", + " * Splits features into categorical and continuous subsets based on\n", + " predefined indices.\n", + "\n", + "\n", + " TODO: add scalers, target normalizers etc.\n", + " \"\"\"\n", + " sample = self.time_series_dataset[series_idx]\n", + "\n", + " target = sample[\"y\"]\n", + " features = sample[\"x\"]\n", + " times = sample[\"t\"]\n", + " cutoff_time = sample[\"cutoff_time\"]\n", + "\n", + " time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool)\n", + "\n", + " if isinstance(target, torch.Tensor):\n", + " target = target.float()\n", + " else:\n", + " target = torch.tensor(target, dtype=torch.float32)\n", + "\n", + " if isinstance(features, torch.Tensor):\n", + " features = features.float()\n", + " else:\n", + " features = torch.tensor(features, dtype=torch.float32)\n", + "\n", + " # TODO: add scalers, target normalizers etc.\n", + "\n", + " categorical = (\n", + " features[:, self.categorical_indices]\n", + " if self.categorical_indices\n", + " else torch.zeros((features.shape[0], 0))\n", + " )\n", + " continuous = (\n", + " features[:, self.continuous_indices]\n", + " if self.continuous_indices\n", + " else torch.zeros((features.shape[0], 0))\n", + " )\n", + "\n", + " return {\n", + " \"features\": {\"categorical\": categorical, \"continuous\": continuous},\n", + " \"target\": target,\n", + " \"static\": sample.get(\"st\", None),\n", + " \"group\": sample.get(\"group\", torch.tensor([0])),\n", + " \"length\": len(target),\n", + " \"time_mask\": time_mask,\n", + " \"times\": times,\n", + " \"cutoff_time\": cutoff_time,\n", + " }\n", + "\n", + " class _ProcessedEncoderDecoderDataset(Dataset):\n", + " \"\"\"PyTorch Dataset for processed encoder-decoder time series data.\n", + "\n", + " Parameters\n", + " ----------\n", + " dataset : TimeSeries\n", + " The base time series dataset 
that provides access to raw data and metadata.\n", + " data_module : EncoderDecoderTimeSeriesDataModule\n", + " The data module handling preprocessing and metadata configuration.\n", + " windows : List[Tuple[int, int, int, int]]\n", + " List of window tuples containing\n", + " (series_idx, start_idx, enc_length, pred_length).\n", + " add_relative_time_idx : bool, default=False\n", + " Whether to include relative time indices.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " dataset: TimeSeries,\n", + " data_module: \"EncoderDecoderTimeSeriesDataModule\",\n", + " windows: List[Tuple[int, int, int, int]],\n", + " add_relative_time_idx: bool = False,\n", + " ):\n", + " self.dataset = dataset\n", + " self.data_module = data_module\n", + " self.windows = windows\n", + " self.add_relative_time_idx = add_relative_time_idx\n", + "\n", + " def __len__(self):\n", + " return len(self.windows)\n", + "\n", + " def __getitem__(self, idx):\n", + " \"\"\"Retrieve a processed time series window for dataloader input.\n", + "\n", + " x : dict\n", + " Dictionary containing model inputs:\n", + "\n", + " * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features)\n", + " Categorical features for the encoder.\n", + " * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features)\n", + " Continuous features for the encoder.\n", + " * ``decoder_cat`` : tensor of shape (pred_length, n_cat_features)\n", + " Categorical features for the decoder.\n", + " * ``decoder_cont`` : tensor of shape (pred_length, n_cont_features)\n", + " Continuous features for the decoder.\n", + " * ``encoder_lengths`` : tensor of shape (1,)\n", + " Length of the encoder sequence.\n", + " * ``decoder_lengths`` : tensor of shape (1,)\n", + " Length of the decoder sequence.\n", + " * ``decoder_target_lengths`` : tensor of shape (1,)\n", + " Length of the decoder target sequence.\n", + " * ``groups`` : tensor of shape (1,)\n", + " Group identifier for the time series instance.\n", + " * ``encoder_time_idx`` : tensor of shape (enc_length,)\n", + " Time indices for the encoder sequence.\n", + " * ``decoder_time_idx`` : tensor of shape (pred_length,)\n", + " Time indices for the decoder sequence.\n", + " * ``target_scale`` : tensor of shape (1,)\n", + " Scaling factor for the target values.\n", + " * ``encoder_mask`` : tensor of shape (enc_length,)\n", + " Boolean mask indicating valid encoder time points.\n", + " * ``decoder_mask`` : tensor of shape (pred_length,)\n", + " Boolean mask indicating valid decoder time points.\n", + "\n", + " If static features are present, the following keys are added:\n", + "\n", + " * ``static_categorical_features`` : tensor of shape\n", + " (1, n_static_cat_features), optional\n", + " Static categorical features, if available.\n", + " * ``static_continuous_features`` : tensor of shape (1, 0), optional\n", + " Placeholder for static continuous features (currently empty).\n", + "\n", + " y : tensor of shape ``(pred_length, n_targets)``\n", + " Target values for the decoder sequence.\n", + " \"\"\"\n", + " series_idx, start_idx, enc_length, pred_length = self.windows[idx]\n", + " data = self.data_module._preprocess_data(series_idx)\n", + "\n", + " end_idx = start_idx + enc_length + pred_length\n", + " encoder_indices = slice(start_idx, start_idx + enc_length)\n", + " decoder_indices = slice(start_idx + enc_length, end_idx)\n", + "\n", + " target_scale = data[\"target\"][encoder_indices]\n", + " target_scale = target_scale[~torch.isnan(target_scale)].abs().mean()\n", + " if 
torch.isnan(target_scale) or target_scale == 0:\n", + " target_scale = torch.tensor(1.0)\n", + "\n", + " encoder_mask = (\n", + " data[\"time_mask\"][encoder_indices]\n", + " if \"time_mask\" in data\n", + " else torch.ones(enc_length, dtype=torch.bool)\n", + " )\n", + " decoder_mask = (\n", + " data[\"time_mask\"][decoder_indices]\n", + " if \"time_mask\" in data\n", + " else torch.zeros(pred_length, dtype=torch.bool)\n", + " )\n", + "\n", + " x = {\n", + " \"encoder_cat\": data[\"features\"][\"categorical\"][encoder_indices],\n", + " \"encoder_cont\": data[\"features\"][\"continuous\"][encoder_indices],\n", + " \"decoder_cat\": data[\"features\"][\"categorical\"][decoder_indices],\n", + " \"decoder_cont\": data[\"features\"][\"continuous\"][decoder_indices],\n", + " \"encoder_lengths\": torch.tensor(enc_length),\n", + " \"decoder_lengths\": torch.tensor(pred_length),\n", + " \"decoder_target_lengths\": torch.tensor(pred_length),\n", + " \"groups\": data[\"group\"],\n", + " \"encoder_time_idx\": torch.arange(enc_length),\n", + " \"decoder_time_idx\": torch.arange(enc_length, enc_length + pred_length),\n", + " \"target_scale\": target_scale,\n", + " \"encoder_mask\": encoder_mask,\n", + " \"decoder_mask\": decoder_mask,\n", + " }\n", + " if data[\"static\"] is not None:\n", + " x[\"static_categorical_features\"] = data[\"static\"].unsqueeze(0)\n", + " x[\"static_continuous_features\"] = torch.zeros((1, 0))\n", + "\n", + " y = data[\"target\"][decoder_indices]\n", + " if y.ndim == 1:\n", + " y = y.unsqueeze(-1)\n", + "\n", + " return x, y\n", + "\n", + " def _create_windows(self, indices: torch.Tensor) -> List[Tuple[int, int, int, int]]:\n", + " \"\"\"Generate sliding windows for training, validation, and testing.\n", + "\n", + " Returns\n", + " -------\n", + " List[Tuple[int, int, int, int]]\n", + " A list of tuples, where each tuple consists of:\n", + " - ``series_idx`` : int\n", + " Index of the time series in `time_series_dataset`.\n", + " - ``start_idx`` : int\n", + " Start index of the encoder window.\n", + " - ``enc_length`` : int\n", + " Length of the encoder input sequence.\n", + " - ``pred_length`` : int\n", + " Length of the decoder output sequence.\n", + " \"\"\"\n", + " windows = []\n", + "\n", + " for idx in indices:\n", + " series_idx = idx.item()\n", + " sample = self.time_series_dataset[series_idx]\n", + " sequence_length = len(sample[\"y\"])\n", + "\n", + " if sequence_length < self.max_encoder_length + self.max_prediction_length:\n", + " continue\n", + "\n", + " effective_min_prediction_idx = (\n", + " self.min_prediction_idx\n", + " if self.min_prediction_idx is not None\n", + " else self.max_encoder_length\n", + " )\n", + "\n", + " max_prediction_idx = sequence_length - self.max_prediction_length + 1\n", + "\n", + " if max_prediction_idx <= effective_min_prediction_idx:\n", + " continue\n", + "\n", + " for start_idx in range(\n", + " 0, max_prediction_idx - effective_min_prediction_idx\n", + " ):\n", + " if (\n", + " start_idx + self.max_encoder_length + self.max_prediction_length\n", + " <= sequence_length\n", + " ):\n", + " windows.append(\n", + " (\n", + " series_idx,\n", + " start_idx,\n", + " self.max_encoder_length,\n", + " self.max_prediction_length,\n", + " )\n", + " )\n", + "\n", + " return windows\n", + "\n", + " def setup(self, stage: Optional[str] = None):\n", + " \"\"\"Prepare the datasets for training, validation, testing, or prediction.\n", + "\n", + " Parameters\n", + " ----------\n", + " stage : Optional[str], default=None\n", + " Specifies the stage of 
setup. Can be one of:\n", + " - ``\"fit\"`` : Prepares training and validation datasets.\n", + " - ``\"test\"`` : Prepares the test dataset.\n", + " - ``\"predict\"`` : Prepares the dataset for inference.\n", + " - ``None`` : Prepares ``fit`` datasets.\n", + " \"\"\"\n", + " total_series = len(self.time_series_dataset)\n", + " self._split_indices = torch.randperm(total_series)\n", + "\n", + " self._train_size = int(self.train_val_test_split[0] * total_series)\n", + " self._val_size = int(self.train_val_test_split[1] * total_series)\n", + "\n", + " self._train_indices = self._split_indices[: self._train_size]\n", + " self._val_indices = self._split_indices[\n", + " self._train_size : self._train_size + self._val_size\n", + " ]\n", + " self._test_indices = self._split_indices[self._train_size + self._val_size :]\n", + "\n", + " if stage is None or stage == \"fit\":\n", + " if not hasattr(self, \"train_dataset\") or not hasattr(self, \"val_dataset\"):\n", + " self.train_windows = self._create_windows(self._train_indices)\n", + " self.val_windows = self._create_windows(self._val_indices)\n", + "\n", + " self.train_dataset = self._ProcessedEncoderDecoderDataset(\n", + " self.time_series_dataset,\n", + " self,\n", + " self.train_windows,\n", + " self.add_relative_time_idx,\n", + " )\n", + " self.val_dataset = self._ProcessedEncoderDecoderDataset(\n", + " self.time_series_dataset,\n", + " self,\n", + " self.val_windows,\n", + " self.add_relative_time_idx,\n", + " )\n", + "\n", + " elif stage == \"test\":\n", + " if not hasattr(self, \"test_dataset\"):\n", + " self.test_windows = self._create_windows(self._test_indices)\n", + " self.test_dataset = self._ProcessedEncoderDecoderDataset(\n", + " self.time_series_dataset,\n", + " self,\n", + " self.test_windows,\n", + " self.add_relative_time_idx,\n", + " )\n", + " elif stage == \"predict\":\n", + " predict_indices = torch.arange(len(self.time_series_dataset))\n", + " self.predict_windows = self._create_windows(predict_indices)\n", + " self.predict_dataset = self._ProcessedEncoderDecoderDataset(\n", + " self.time_series_dataset,\n", + " self,\n", + " self.predict_windows,\n", + " self.add_relative_time_idx,\n", + " )\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(\n", + " self.train_dataset,\n", + " batch_size=self.batch_size,\n", + " num_workers=self.num_workers,\n", + " shuffle=True,\n", + " collate_fn=self.collate_fn,\n", + " )\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(\n", + " self.val_dataset,\n", + " batch_size=self.batch_size,\n", + " num_workers=self.num_workers,\n", + " collate_fn=self.collate_fn,\n", + " )\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(\n", + " self.test_dataset,\n", + " batch_size=self.batch_size,\n", + " num_workers=self.num_workers,\n", + " collate_fn=self.collate_fn,\n", + " )\n", + "\n", + " def predict_dataloader(self):\n", + " return DataLoader(\n", + " self.predict_dataset,\n", + " batch_size=self.batch_size,\n", + " num_workers=self.num_workers,\n", + " collate_fn=self.collate_fn,\n", + " )\n", + "\n", + " @staticmethod\n", + " def collate_fn(batch):\n", + " x_batch = {\n", + " \"encoder_cat\": torch.stack([x[\"encoder_cat\"] for x, _ in batch]),\n", + " \"encoder_cont\": torch.stack([x[\"encoder_cont\"] for x, _ in batch]),\n", + " \"decoder_cat\": torch.stack([x[\"decoder_cat\"] for x, _ in batch]),\n", + " \"decoder_cont\": torch.stack([x[\"decoder_cont\"] for x, _ in batch]),\n", + " \"encoder_lengths\": torch.stack([x[\"encoder_lengths\"] 
for x, _ in batch]),\n", + " \"decoder_lengths\": torch.stack([x[\"decoder_lengths\"] for x, _ in batch]),\n", + " \"decoder_target_lengths\": torch.stack(\n", + " [x[\"decoder_target_lengths\"] for x, _ in batch]\n", + " ),\n", + " \"groups\": torch.stack([x[\"groups\"] for x, _ in batch]),\n", + " \"encoder_time_idx\": torch.stack([x[\"encoder_time_idx\"] for x, _ in batch]),\n", + " \"decoder_time_idx\": torch.stack([x[\"decoder_time_idx\"] for x, _ in batch]),\n", + " \"target_scale\": torch.stack([x[\"target_scale\"] for x, _ in batch]),\n", + " \"encoder_mask\": torch.stack([x[\"encoder_mask\"] for x, _ in batch]),\n", + " \"decoder_mask\": torch.stack([x[\"decoder_mask\"] for x, _ in batch]),\n", + " }\n", + "\n", + " if \"static_categorical_features\" in batch[0][0]:\n", + " x_batch[\"static_categorical_features\"] = torch.stack(\n", + " [x[\"static_categorical_features\"] for x, _ in batch]\n", + " )\n", + " x_batch[\"static_continuous_features\"] = torch.stack(\n", + " [x[\"static_continuous_features\"] for x, _ in batch]\n", + " )\n", + "\n", + " y_batch = torch.stack([y for _, y in batch])\n", + " return x_batch, y_batch" ] }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "07e51fe161674795a4be13573c3236dd", - "version_major": 2, - "version_minor": 0 + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "WX-FRdusJSVN", + "outputId": "2b7ae9bd-0bee-4c05-a512-61193b462274" }, - "text/plain": [ - "Validation: | | 0/? [00:00\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
000-0.0712220.33976301.0000000.6264260
1010.3397630.18934800.9950040.6264260
2020.1893480.67598900.9800670.6264260
3030.6759890.79726100.9553360.6264260
4040.7972610.99501600.9210610.6264260
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "text/plain": [ + " series_id time_idx x y category future_known_feature \\\n", + "0 0 0 -0.071222 0.339763 0 1.000000 \n", + "1 0 1 0.339763 0.189348 0 0.995004 \n", + "2 0 2 0.189348 0.675989 0 0.980067 \n", + "3 0 3 0.675989 0.797261 0 0.955336 \n", + "4 0 4 0.797261 0.995016 0 0.921061 \n", + "\n", + " static_feature static_feature_cat \n", + "0 0.626426 0 \n", + "1 0.626426 0 \n", + "2 0.626426 0 \n", + "3 0.626426 0 \n", + "4 0.626426 0 " + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from lightning.pytorch import Trainer\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "\n", + "from pytorch_forecasting.metrics import MAE, SMAPE\n", + "\n", + "num_series = 100\n", + "seq_length = 50\n", + "data_list = []\n", + "for i in range(num_series):\n", + " x = np.arange(seq_length)\n", + " y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)\n", + " category = i % 5\n", + " static_value = np.random.rand()\n", + " for t in range(seq_length - 1):\n", + " data_list.append(\n", + " {\n", + " \"series_id\": i,\n", + " \"time_idx\": t,\n", + " \"x\": y[t],\n", + " \"y\": y[t + 1],\n", + " \"category\": category,\n", + " \"future_known_feature\": np.cos(t / 10),\n", + " \"static_feature\": static_value,\n", + " \"static_feature_cat\": i % 3,\n", + " }\n", + " )\n", + "data_df = pd.DataFrame(data_list)\n", + "data_df.head()" ] }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "3564138e49ce41ffa18e3fd3ab85abb6", - "version_major": 2, - "version_minor": 0 + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "AxxPHK6AKSD2" }, - "text/plain": [ - "Validation: | | 0/? 
+    "    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n",
+    "        \"\"\"\n",
+    "        Forward pass of the model.\n",
+    "\n",
+    "        Parameters\n",
+    "        ----------\n",
+    "        x : Dict[str, torch.Tensor]\n",
+    "            Dictionary containing input tensors\n",
+    "\n",
+    "        Returns\n",
+    "        -------\n",
+    "        Dict[str, torch.Tensor]\n",
+    "            Dictionary containing output tensors\n",
+    "        \"\"\"\n",
+    "        raise NotImplementedError(\"Forward method must be implemented by subclass.\")\n",
+    "\n",
+    "    def training_step(\n",
+    "        self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n",
+    "    ) -> STEP_OUTPUT:\n",
+    "        \"\"\"\n",
+    "        Training step for the model.\n",
+    "\n",
+    "        Parameters\n",
+    "        ----------\n",
+    "        batch : Tuple[Dict[str, torch.Tensor]]\n",
+    "            Batch of data containing input and target tensors.\n",
+    "        batch_idx : int\n",
+    "            Index of the batch.\n",
+    "\n",
+    "        Returns\n",
+    "        -------\n",
+    "        STEP_OUTPUT\n",
+    "            Dictionary containing the loss and other metrics.\n",
+    "        \"\"\"\n",
+    "        x, y = batch\n",
+    "        y_hat_dict = self(x)\n",
+    "        y_hat = y_hat_dict[\"prediction\"]\n",
+    "        loss = self.loss(y_hat, y)\n",
+    "        self.log(\n",
+    "            \"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True\n",
+    "        )\n",
+    "        self.log_metrics(y_hat, y, prefix=\"train\")\n",
+    "        return {\"loss\": loss}\n",
+    "\n",
+    "    def validation_step(\n",
+    "        self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n",
+    "    ) -> STEP_OUTPUT:\n",
+    "        \"\"\"\n",
+    "        Validation step for the model.\n",
+    "\n",
+    "        Parameters\n",
+    "        ----------\n",
+    "        batch : Tuple[Dict[str, torch.Tensor]]\n",
+    "            Batch of data containing input and target tensors.\n",
+    "        batch_idx : int\n",
+    "            Index of the batch.\n",
+    "\n",
+    "        Returns\n",
+    "        -------\n",
+    "        STEP_OUTPUT\n",
+    "            Dictionary containing the loss and other metrics.\n",
+    "        \"\"\"\n",
+    "        x, y = batch\n",
+    "        y_hat_dict = self(x)\n",
+    "        y_hat = y_hat_dict[\"prediction\"]\n",
+    "        loss = self.loss(y_hat, y)\n",
+    "        self.log(\n",
+    "            \"val_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n",
+    "        )\n",
+    "        self.log_metrics(y_hat, y, prefix=\"val\")\n",
+    "        return {\"val_loss\": loss}\n",
+    "\n",
+    "    def test_step(\n",
+    "        self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n",
+    "    ) -> STEP_OUTPUT:\n",
+    "        \"\"\"\n",
+    "        Test step for the model.\n",
+    "\n",
+    "        Parameters\n",
+    "        ----------\n",
+    "        batch : Tuple[Dict[str, torch.Tensor]]\n",
+    "            Batch of data containing input and target tensors.\n",
+    "        batch_idx : int\n",
+    "            Index of the batch.\n",
+    "\n",
+    "        Returns\n",
+    "        -------\n",
+    "        STEP_OUTPUT\n",
+    "            Dictionary containing the loss and other metrics.\n",
+    "        \"\"\"\n",
+    "        x, y = batch\n",
+    "        y_hat_dict = self(x)\n",
+    "        y_hat = y_hat_dict[\"prediction\"]\n",
+    "        loss = self.loss(y_hat, y)\n",
+    "        self.log(\n",
+    "            \"test_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n",
+    "        )\n",
+    "        self.log_metrics(y_hat, y, prefix=\"test\")\n",
+    "        return {\"test_loss\": loss}\n",
+    "\n",
+    "    def predict_step(\n",
+    "        self,\n",
+    "        batch: Tuple[Dict[str, torch.Tensor]],\n",
+    "        batch_idx: int,\n",
+    "        dataloader_idx: int = 0,\n",
+    "    ) -> torch.Tensor:\n",
+    "        \"\"\"\n",
+    "        Prediction step for the model.\n",
+    "\n",
+    "        Parameters\n",
+    "        ----------\n",
+    "        batch : Tuple[Dict[str, torch.Tensor]]\n",
+    "            Batch of data containing input tensors.\n",
+    "        batch_idx : int\n",
+    "            Index of the batch.\n",
+    "        dataloader_idx : int\n",
+    "            Index of the dataloader.\n",
+    "\n",
+    "        Returns\n",
+    "        -------\n",
+    "        torch.Tensor\n",
+    "            Predicted output tensor.\n",
+    "        
\"\"\"\n", + " x, _ = batch\n", + " y_hat = self(x)\n", + " return y_hat\n", + "\n", + " def configure_optimizers(self) -> Dict:\n", + " \"\"\"\n", + " Configure the optimizer and learning rate scheduler.\n", + "\n", + " Returns\n", + " -------\n", + " Dict\n", + " Dictionary containing the optimizer and scheduler configuration.\n", + " \"\"\"\n", + " optimizer = self._get_optimizer()\n", + " if self.lr_scheduler is not None:\n", + " scheduler = self._get_scheduler(optimizer)\n", + " if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n", + " return {\n", + " \"optimizer\": optimizer,\n", + " \"lr_scheduler\": {\n", + " \"scheduler\": scheduler,\n", + " \"monitor\": \"val_loss\",\n", + " },\n", + " }\n", + " else:\n", + " return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n", + " return {\"optimizer\": optimizer}\n", + "\n", + " def _get_optimizer(self) -> Optimizer:\n", + " \"\"\"\n", + " Get the optimizer based on the specified optimizer name and parameters.\n", + "\n", + " Returns\n", + " -------\n", + " Optimizer\n", + " The optimizer instance.\n", + " \"\"\"\n", + " if isinstance(self.optimizer, str):\n", + " if self.optimizer.lower() == \"adam\":\n", + " return torch.optim.Adam(self.parameters(), **self.optimizer_params)\n", + " elif self.optimizer.lower() == \"sgd\":\n", + " return torch.optim.SGD(self.parameters(), **self.optimizer_params)\n", + " else:\n", + " raise ValueError(f\"Optimizer {self.optimizer} not supported.\")\n", + " elif isinstance(self.optimizer, Optimizer):\n", + " return self.optimizer\n", + " else:\n", + " raise ValueError(\n", + " \"Optimizer must be either a string or \"\n", + " \"an instance of torch.optim.Optimizer.\"\n", + " )\n", + "\n", + " def _get_scheduler(\n", + " self, optimizer: Optimizer\n", + " ) -> torch.optim.lr_scheduler._LRScheduler:\n", + " \"\"\"\n", + " Get the lr scheduler based on the specified scheduler name and params.\n", + "\n", + " Parameters\n", + " ----------\n", + " optimizer : Optimizer\n", + " The optimizer instance.\n", + "\n", + " Returns\n", + " -------\n", + " torch.optim.lr_scheduler._LRScheduler\n", + " The learning rate scheduler instance.\n", + " \"\"\"\n", + " if self.lr_scheduler.lower() == \"reduce_lr_on_plateau\":\n", + " return torch.optim.lr_scheduler.ReduceLROnPlateau(\n", + " optimizer, **self.lr_scheduler_params\n", + " )\n", + " elif self.lr_scheduler.lower() == \"step_lr\":\n", + " return torch.optim.lr_scheduler.StepLR(\n", + " optimizer, **self.lr_scheduler_params\n", + " )\n", + " else:\n", + " raise ValueError(f\"Scheduler {self.lr_scheduler} not supported.\")\n", + "\n", + " def log_metrics(\n", + " self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = \"val\"\n", + " ) -> None:\n", + " \"\"\"\n", + " Log additional metrics during training, validation, or testing.\n", + "\n", + " Parameters\n", + " ----------\n", + " y_hat : torch.Tensor\n", + " Predicted output tensor.\n", + " y : torch.Tensor\n", + " Target output tensor.\n", + " prefix : str\n", + " Prefix for the logged metrics (e.g., \"train\", \"val\", \"test\").\n", + " \"\"\"\n", + " for metric in self.logging_metrics:\n", + " metric_value = metric(y_hat, y)\n", + " self.log(\n", + " f\"{prefix}_{metric.__class__.__name__}\",\n", + " metric_value,\n", + " on_step=False,\n", + " on_epoch=True,\n", + " prog_bar=True,\n", + " logger=True,\n", + " )" ] }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO: `Trainer.fit` stopped: `max_epochs=5` 
reached.\n", - "INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Evaluating model...\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9595f2def65a44f1a66965b07ff2a51f", - "version_major": 2, - "version_minor": 0 + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "n5EBeGucK_k0" }, - "text/plain": [ - "Testing: | | 0/? [00:00 0 else None\n", + " )\n", + "\n", + " self.lstm_encoder = nn.LSTM(\n", + " input_size=total_feature_size,\n", + " hidden_size=hidden_size,\n", + " num_layers=num_layers,\n", + " dropout=dropout,\n", + " batch_first=True,\n", + " )\n", + "\n", + " self.lstm_decoder = nn.LSTM(\n", + " input_size=total_feature_size,\n", + " hidden_size=hidden_size,\n", + " num_layers=num_layers,\n", + " dropout=dropout,\n", + " batch_first=True,\n", + " )\n", + "\n", + " self.self_attention = nn.MultiheadAttention(\n", + " embed_dim=hidden_size,\n", + " num_heads=attention_head_size,\n", + " dropout=dropout,\n", + " batch_first=True,\n", + " )\n", + "\n", + " self.pre_output = nn.Linear(hidden_size, hidden_size)\n", + " self.output_layer = nn.Linear(hidden_size, output_size)\n", + "\n", + " def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n", + " \"\"\"\n", + " Forward pass of the TFT model.\n", + "\n", + " Parameters\n", + " ----------\n", + " x : Dict[str, torch.Tensor]\n", + " Dictionary containing input tensors:\n", + " - encoder_cat: Categorical encoder features\n", + " - encoder_cont: Continuous encoder features\n", + " - decoder_cat: Categorical decoder features\n", + " - decoder_cont: Continuous decoder features\n", + " - static_categorical_features: Static categorical features\n", + " - static_continuous_features: Static continuous features\n", + "\n", + " Returns\n", + " -------\n", + " Dict[str, torch.Tensor]\n", + " Dictionary containing output tensors:\n", + " - prediction: Prediction output (batch_size, prediction_length, output_size)\n", + " \"\"\"\n", + " batch_size = x[\"encoder_cont\"].shape[0]\n", + "\n", + " encoder_cat = x.get(\n", + " \"encoder_cat\",\n", + " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", + " )\n", + " encoder_cont = x.get(\n", + " \"encoder_cont\",\n", + " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", + " )\n", + " decoder_cat = x.get(\n", + " \"decoder_cat\",\n", + " torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", + " )\n", + " decoder_cont = x.get(\n", + " \"decoder_cont\",\n", + " torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", + " )\n", + "\n", + " encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2)\n", + " decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2)\n", + "\n", + " static_context = None\n", + " if self.static_context_linear is not None:\n", + " static_cat = x.get(\n", + " \"static_categorical_features\",\n", + " torch.zeros(batch_size, 0, device=self.device),\n", + " )\n", + " static_cont = x.get(\n", + " \"static_continuous_features\",\n", + " torch.zeros(batch_size, 0, device=self.device),\n", + " )\n", + "\n", + " if static_cat.size(2) == 0 and static_cont.size(2) == 0:\n", + " static_context = None\n", + " elif static_cat.size(2) == 0:\n", + " static_input = static_cont.to(\n", + " dtype=self.static_context_linear.weight.dtype\n", + " )\n", + " static_context = 
self.static_context_linear(static_input)\n", + " static_context = static_context.view(batch_size, self.hidden_size)\n", + " elif static_cont.size(2) == 0:\n", + " static_input = static_cat.to(\n", + " dtype=self.static_context_linear.weight.dtype\n", + " )\n", + " static_context = self.static_context_linear(static_input)\n", + " static_context = static_context.view(batch_size, self.hidden_size)\n", + " else:\n", + "\n", + " static_input = torch.cat([static_cont, static_cat], dim=1).to(\n", + " dtype=self.static_context_linear.weight.dtype\n", + " )\n", + " static_context = self.static_context_linear(static_input)\n", + " static_context = static_context.view(batch_size, self.hidden_size)\n", + "\n", + " encoder_weights = self.encoder_var_selection(encoder_input)\n", + " encoder_input = encoder_input * encoder_weights\n", + "\n", + " decoder_weights = self.decoder_var_selection(decoder_input)\n", + " decoder_input = decoder_input * decoder_weights\n", + "\n", + " if static_context is not None:\n", + " encoder_static_context = static_context.unsqueeze(1).expand(\n", + " -1, self.max_encoder_length, -1\n", + " )\n", + " decoder_static_context = static_context.unsqueeze(1).expand(\n", + " -1, self.max_prediction_length, -1\n", + " )\n", + "\n", + " encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n", + " encoder_output = encoder_output + encoder_static_context\n", + " decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n", + " decoder_output = decoder_output + decoder_static_context\n", + " else:\n", + " encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n", + " decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n", + "\n", + " sequence = torch.cat([encoder_output, decoder_output], dim=1)\n", + "\n", + " if static_context is not None:\n", + " expanded_static_context = static_context.unsqueeze(1).expand(\n", + " -1, sequence.size(1), -1\n", + " )\n", + "\n", + " attended_output, _ = self.self_attention(\n", + " sequence + expanded_static_context, sequence, sequence\n", + " )\n", + " else:\n", + " attended_output, _ = self.self_attention(sequence, sequence, sequence)\n", + "\n", + " decoder_attended = attended_output[:, -self.max_prediction_length :, :]\n", + "\n", + " output = nn.functional.relu(self.pre_output(decoder_attended))\n", + " prediction = self.output_layer(output)\n", + "\n", + " return {\"prediction\": prediction}" ] }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
-       "┃        Test metric               DataLoader 0        ┃\n",
-       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
-       "│         test_MAE              0.4637378454208374     │\n",
-       "│        test_SMAPE             1.0857858657836914     │\n",
-       "│         test_loss            0.014832879416644573    │\n",
-       "└───────────────────────────┴───────────────────────────┘\n",
-       "
\n" + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "78e4ca50a25146818b654814147ed1ab", + "8cd83f0e9663436b967df6e5d38766ae", + "4c7ee7e81def457aacb0d898f92437c0", + "788def26447b409da872d076b4ba8061", + "407fe1dea38d48c99473c4019089bc90", + "6084560f9ea74d2c9e106a9500bc65eb", + "f44964101be74d03a83601eed71d8d23", + "7c642543b9e446e8acbfc8ffb47510eb", + "445b8b2d05594cf091b5275315795516", + "d3f55c1a7193496dae6d6651a5e1024f", + "c89be46611364ee5abe04a5e2f60a66d", + "e2b1c2ad8a1e4456852ad49086775ba5", + "bf6e9a42c1fb45d1ad77cd21355869b6", + "a19520255a734aa5ae9bde6727c0ea58", + "2379f649cc0649ba99cca78a0d2231d8", + "8e42e30ac1ce4c33acc6b2070bfead23", + "8855eee7f91b4f54b412c7d78212521e", + "146e9284ec9543e0babeb82674d9b9be", + "caced8bb622949c7ad2dc714da459791", + "d178aa86336342aa96c6af585f887a03", + "3a3cfb2dfc094e76906f912562c3f16e", + "6e36dd66970e430f9d9ceffa97abeffb", + "a71969db0c5644efb9161637a12a32d2", + "425461aab83744a4a78e5288dc0f6cff", + "8b804e67c6da46ac845b32c5be62e884", + "9fa24325d05f43a0be3f436c469a4c32", + "b6655e268f754b72b083604a50845afb", + "e44c84c0476a424a9be0c307170813e6", + "4a3eb3595f72469292bd3cfba1a6a077", + "b4aeaefcd9ed4c508b45047df1e8f1f1", + "4fd1738d6c3c402e84620ceae44e87cf", + "003bf82570884cc1aa2dfeb489e1216e", + "00f4b77ed87e491cbfe77c9ec290ac19", + "07e51fe161674795a4be13573c3236dd", + "931eae1ec2834b5d89702f49018260d8", + "5d8f0b2318f849008b8aaa72e3a8fba9", + "d6d0a0893f7b428db860633f37f9f4ab", + "fb7da6100bff4f359705dec33b913c39", + "22dff2b9d78b405db32b95df05046bd9", + "fd3515421fc5438e8894377037c3049f", + "cd50f28b37af46fbac6c8e63656e4384", + "917cd4785e3e4563aab3e21a458ff61d", + "dcfde39e3b7842a6832f32f3efc72fb6", + "628c517725d4470f86e1c949d4321a45", + "3564138e49ce41ffa18e3fd3ab85abb6", + "23f2ad3634f64492b6ab1758c447038d", + "759cb2575bac454ab2e5b4e859cf4e6c", + "da7a77693dcf43d884d2681d3be0831c", + "16238363d4094d68892999adc0b9a0b3", + "188c04e4fb2a4214a0b2d93b42858efc", + "709f4b4d2e3f454eb61f2addc49da64d", + "2f5f5c4f08714d9bb2d48b9891d4969c", + "62161138282642a190fb53dcd61548f1", + "757cea1affed4f7888f3136158a13600", + "0b817bd6c784496594763bee72b78031", + "8144af2dce5044e481c9f4db6f47f1f6", + "af91a519e65946e88e9ad12d924776b1", + "6a4d0f0907ac470e898e70d88f399fcd", + "13f2f627d06d462589ecea442d0d0fd0", + "27e1f807b99f43bc814f20c8e2dc543c", + "3c8961406f994fe4accde9d16cdb6fe6", + "6dc9efabf69e45fbbdc4425fc8ce0c3f", + "77bf17a9e3554cc1beef372c861d5472", + "7f7592567346419db6d24d3c6a70bfc8", + "62b2bfcf51a7466a9d6170b99b6ec793", + "dd7cb24a33194ea99d166085fef36cbf", + "b2297524052f4abca043a42f97c10bc6", + "bbaba48376ed4f8695fdcd77d62614d2", + "0d0effdc3edb47ba9d6c086946ba86d6", + "57d47150e3ce4de4aec78def4295d087", + "a4fd00d13e624d41b2564d403fee0b6c", + "7afadad53d034e0d840a7c4837d42dea", + "83fdd35e50f84646a5b9757148cb97f3", + "18ae939c9c704565bb4937931e326aba", + "8e49dbeadca3499b95b05e2c4211d057", + "1aedd9b0c2f14939b2ad59494e02aed8", + "4b6a957a4fb44af58947b5a1fef7882e", + "9595f2def65a44f1a66965b07ff2a51f", + "70e131c2803d4f6fa59d99b8ec1d5e93", + "6f13514495d642a281ec2e70ff772c76", + "8b22fbff367f480b87331935a62d4cfd", + "becf001d1a4f42e88edd89e0849c008c", + "ab1836ed5b934de6ae2fe761f0be60db", + "d83f83a0be0e42f8b4f6458053286a66", + "96d02a5000f74366ab53ef25ea8d95d9", + "6b631a66a94a4b8ab2c90daf4cf146b8", + "5962ef174f574964a1f403e663e18f80", + "6cc6d52b1d974074adfe12afcc2d3e2c" + ] + }, + "id": "Si7bbZIULBZz", + 
"outputId": "b5e8d9c9-e1a3-4632-b764-a3b7b1931012" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.\n", + "INFO:lightning.pytorch.utilities.rank_zero:You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.\n", + "INFO: GPU available: False, used: False\n", + "INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False\n", + "INFO: TPU available: False, using: 0 TPU cores\n", + "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", + "INFO: HPU available: False, using: 0 HPUs\n", + "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n", + "INFO: \n", + " | Name | Type | Params | Mode \n", + "---------------------------------------------------------------------\n", + "0 | loss | MSELoss | 0 | train\n", + "1 | encoder_var_selection | Sequential | 709 | train\n", + "2 | decoder_var_selection | Sequential | 709 | train\n", + "3 | static_context_linear | Linear | 192 | train\n", + "4 | lstm_encoder | LSTM | 51.5 K | train\n", + "5 | lstm_decoder | LSTM | 51.5 K | train\n", + "6 | self_attention | MultiheadAttention | 16.6 K | train\n", + "7 | pre_output | Linear | 4.2 K | train\n", + "8 | output_layer | Linear | 65 | train\n", + "---------------------------------------------------------------------\n", + "125 K Trainable params\n", + "0 Non-trainable params\n", + "125 K Total params\n", + "0.502 Total estimated model params size (MB)\n", + "18 Modules in train mode\n", + "0 Modules in eval mode\n", + "INFO:lightning.pytorch.callbacks.model_summary:\n", + " | Name | Type | Params | Mode \n", + "---------------------------------------------------------------------\n", + "0 | loss | MSELoss | 0 | train\n", + "1 | encoder_var_selection | Sequential | 709 | train\n", + "2 | decoder_var_selection | Sequential | 709 | train\n", + "3 | static_context_linear | Linear | 192 | train\n", + "4 | lstm_encoder | LSTM | 51.5 K | train\n", + "5 | lstm_decoder | LSTM | 51.5 K | train\n", + "6 | self_attention | MultiheadAttention | 16.6 K | train\n", + "7 | pre_output | Linear | 4.2 K | train\n", + "8 | output_layer | Linear | 65 | train\n", + "---------------------------------------------------------------------\n", + "125 K Trainable params\n", + "0 Non-trainable params\n", + "125 K Total params\n", + "0.502 Total estimated model params size (MB)\n", + "18 Modules in train mode\n", + "0 Modules in eval mode\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Training model...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "78e4ca50a25146818b654814147ed1ab", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? 
[00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test_MAE 0.4637378454208374 │\n", + "│ test_SMAPE 1.0857858657836914 │\n", + "│ test_loss 0.014832879416644573 │\n", + "└───────────────────────────┴───────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.4637378454208374 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.0857858657836914 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.014832879416644573 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Prediction shape: torch.Size([32, 1, 1])\n", + "First prediction values: [[-0.00597369]]\n", + "First true values: [[-0.09480439]]\n", + "\n", + "TFT model test complete!\n" + ] + } ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.4637378454208374 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.0857858657836914 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.014832879416644573 \u001b[0m\u001b[35m \u001b[0m│\n", - "└───────────────────────────┴───────────────────────────┘\n" + "source": [ + "model = TFT(\n", + " loss=nn.MSELoss(),\n", + " logging_metrics=[MAE(), SMAPE()],\n", + " optimizer=\"adam\",\n", + " optimizer_params={\"lr\": 1e-3},\n", + " lr_scheduler=\"reduce_lr_on_plateau\",\n", + " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", + " hidden_size=64,\n", + " num_layers=2,\n", + " attention_head_size=4,\n", + " dropout=0.1,\n", + " metadata=data_module.metadata,\n", + ")\n", + "\n", + "print(\"\\nTraining model...\")\n", + "trainer = Trainer(max_epochs=5, accelerator=\"auto\", devices=1, enable_progress_bar=True)\n", + "\n", + "trainer.fit(model, data_module)\n", + "\n", + "print(\"\\nEvaluating model...\")\n", + "test_metrics = trainer.test(model, data_module)\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " test_batch = next(iter(data_module.test_dataloader()))\n", + " x_test, y_test = test_batch\n", + " y_pred = model(x_test)\n", + "\n", + " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n", + " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n", + " print(\"First true values:\", y_test[0].cpu().numpy())\n", + "print(\"\\nTFT 
model test complete!\")" ] }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Prediction shape: torch.Size([32, 1, 1])\n", - "First prediction values: [[-0.00597369]]\n", - "First true values: [[-0.09480439]]\n", - "\n", - "TFT model test complete!\n" - ] - } - ], - "source": [ - "model = TFT(\n", - " loss=nn.MSELoss(),\n", - " logging_metrics=[MAE(), SMAPE()],\n", - " optimizer=\"adam\",\n", - " optimizer_params={\"lr\": 1e-3},\n", - " lr_scheduler=\"reduce_lr_on_plateau\",\n", - " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", - " hidden_size=64,\n", - " num_layers=2,\n", - " attention_head_size=4,\n", - " dropout=0.1,\n", - " metadata=data_module.metadata,\n", - ")\n", - "\n", - "print(\"\\nTraining model...\")\n", - "trainer = Trainer(max_epochs=5, accelerator=\"auto\", devices=1, enable_progress_bar=True)\n", - "\n", - "trainer.fit(model, data_module)\n", - "\n", - "print(\"\\nEvaluating model...\")\n", - "test_metrics = trainer.test(model, data_module)\n", - "\n", - "model.eval()\n", - "with torch.no_grad():\n", - " test_batch = next(iter(data_module.test_dataloader()))\n", - " x_test, y_test = test_batch\n", - " y_pred = model(x_test)\n", - "\n", - " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n", - " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n", - " print(\"First true values:\", y_test[0].cpu().numpy())\n", - "print(\"\\nTFT model test complete!\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zVRwi2MvLGgc" - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "003bf82570884cc1aa2dfeb489e1216e": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "00f4b77ed87e491cbfe77c9ec290ac19": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - 
"_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "07e51fe161674795a4be13573c3236dd": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_931eae1ec2834b5d89702f49018260d8", - "IPY_MODEL_5d8f0b2318f849008b8aaa72e3a8fba9", - "IPY_MODEL_d6d0a0893f7b428db860633f37f9f4ab" - ], - "layout": "IPY_MODEL_fb7da6100bff4f359705dec33b913c39" - } - }, - "0b817bd6c784496594763bee72b78031": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "0d0effdc3edb47ba9d6c086946ba86d6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_18ae939c9c704565bb4937931e326aba", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_8e49dbeadca3499b95b05e2c4211d057", - "value": 9 - } - }, - "13f2f627d06d462589ecea442d0d0fd0": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_62b2bfcf51a7466a9d6170b99b6ec793", - "placeholder": "​", - "style": "IPY_MODEL_dd7cb24a33194ea99d166085fef36cbf", - "value": " 9/9 [00:00<00:00, 11.96it/s]" - } - }, - "146e9284ec9543e0babeb82674d9b9be": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "16238363d4094d68892999adc0b9a0b3": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": 
null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "188c04e4fb2a4214a0b2d93b42858efc": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "18ae939c9c704565bb4937931e326aba": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "1aedd9b0c2f14939b2ad59494e02aed8": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - 
"_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "22dff2b9d78b405db32b95df05046bd9": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2379f649cc0649ba99cca78a0d2231d8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_3a3cfb2dfc094e76906f912562c3f16e", - "placeholder": "​", - "style": "IPY_MODEL_6e36dd66970e430f9d9ceffa97abeffb", - "value": " 42/42 [00:05<00:00,  7.01it/s, v_num=1, train_loss_step=0.00959, val_loss=0.0166, val_MAE=0.472, val_SMAPE=1.050, train_loss_epoch=0.0169, train_MAE=0.464, train_SMAPE=1.030]" - } - }, - "23f2ad3634f64492b6ab1758c447038d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - 
"layout": "IPY_MODEL_188c04e4fb2a4214a0b2d93b42858efc", - "placeholder": "​", - "style": "IPY_MODEL_709f4b4d2e3f454eb61f2addc49da64d", - "value": "Validation DataLoader 0: 100%" - } - }, - "27e1f807b99f43bc814f20c8e2dc543c": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "2f5f5c4f08714d9bb2d48b9891d4969c": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "3564138e49ce41ffa18e3fd3ab85abb6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_23f2ad3634f64492b6ab1758c447038d", - "IPY_MODEL_759cb2575bac454ab2e5b4e859cf4e6c", - "IPY_MODEL_da7a77693dcf43d884d2681d3be0831c" - ], - "layout": "IPY_MODEL_16238363d4094d68892999adc0b9a0b3" - } - }, - "3a3cfb2dfc094e76906f912562c3f16e": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - 
"_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "3c8961406f994fe4accde9d16cdb6fe6": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "407fe1dea38d48c99473c4019089bc90": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": 
"hidden", - "width": "100%" - } - }, - "425461aab83744a4a78e5288dc0f6cff": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_e44c84c0476a424a9be0c307170813e6", - "placeholder": "​", - "style": "IPY_MODEL_4a3eb3595f72469292bd3cfba1a6a077", - "value": "Validation DataLoader 0: 100%" - } - }, - "445b8b2d05594cf091b5275315795516": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "4a3eb3595f72469292bd3cfba1a6a077": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "4b6a957a4fb44af58947b5a1fef7882e": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "4c7ee7e81def457aacb0d898f92437c0": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_7c642543b9e446e8acbfc8ffb47510eb", - "max": 2, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_445b8b2d05594cf091b5275315795516", - "value": 2 - } - }, - "4fd1738d6c3c402e84620ceae44e87cf": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "57d47150e3ce4de4aec78def4295d087": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": 
"@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_1aedd9b0c2f14939b2ad59494e02aed8", - "placeholder": "​", - "style": "IPY_MODEL_4b6a957a4fb44af58947b5a1fef7882e", - "value": " 9/9 [00:00<00:00, 12.24it/s]" - } - }, - "5962ef174f574964a1f403e663e18f80": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "5d8f0b2318f849008b8aaa72e3a8fba9": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_cd50f28b37af46fbac6c8e63656e4384", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_917cd4785e3e4563aab3e21a458ff61d", - "value": 9 - } - }, - "6084560f9ea74d2c9e106a9500bc65eb": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, 
- "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "62161138282642a190fb53dcd61548f1": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "628c517725d4470f86e1c949d4321a45": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "62b2bfcf51a7466a9d6170b99b6ec793": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "6a4d0f0907ac470e898e70d88f399fcd": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_77bf17a9e3554cc1beef372c861d5472", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_7f7592567346419db6d24d3c6a70bfc8", - "value": 9 - } - }, - "6b631a66a94a4b8ab2c90daf4cf146b8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "6cc6d52b1d974074adfe12afcc2d3e2c": { - 
"model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "6dc9efabf69e45fbbdc4425fc8ce0c3f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "6e36dd66970e430f9d9ceffa97abeffb": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "6f13514495d642a281ec2e70ff772c76": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_96d02a5000f74366ab53ef25ea8d95d9", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_6b631a66a94a4b8ab2c90daf4cf146b8", - "value": 9 - } - }, - "709f4b4d2e3f454eb61f2addc49da64d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "70e131c2803d4f6fa59d99b8ec1d5e93": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_ab1836ed5b934de6ae2fe761f0be60db", - "placeholder": "​", - "style": "IPY_MODEL_d83f83a0be0e42f8b4f6458053286a66", - "value": "Testing DataLoader 0: 100%" - } - }, - "757cea1affed4f7888f3136158a13600": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": 
"@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "759cb2575bac454ab2e5b4e859cf4e6c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_2f5f5c4f08714d9bb2d48b9891d4969c", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_62161138282642a190fb53dcd61548f1", - "value": 9 - } - }, - "77bf17a9e3554cc1beef372c861d5472": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "788def26447b409da872d076b4ba8061": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_d3f55c1a7193496dae6d6651a5e1024f", - "placeholder": "​", - "style": "IPY_MODEL_c89be46611364ee5abe04a5e2f60a66d", - "value": " 2/2 [00:00<00:00, 18.07it/s]" 
- } - }, - "78e4ca50a25146818b654814147ed1ab": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_8cd83f0e9663436b967df6e5d38766ae", - "IPY_MODEL_4c7ee7e81def457aacb0d898f92437c0", - "IPY_MODEL_788def26447b409da872d076b4ba8061" - ], - "layout": "IPY_MODEL_407fe1dea38d48c99473c4019089bc90" - } - }, - "7afadad53d034e0d840a7c4837d42dea": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "7c642543b9e446e8acbfc8ffb47510eb": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "7f7592567346419db6d24d3c6a70bfc8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - 
"_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "8144af2dce5044e481c9f4db6f47f1f6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_af91a519e65946e88e9ad12d924776b1", - "IPY_MODEL_6a4d0f0907ac470e898e70d88f399fcd", - "IPY_MODEL_13f2f627d06d462589ecea442d0d0fd0" - ], - "layout": "IPY_MODEL_27e1f807b99f43bc814f20c8e2dc543c" - } - }, - "83fdd35e50f84646a5b9757148cb97f3": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "8855eee7f91b4f54b412c7d78212521e": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "8b22fbff367f480b87331935a62d4cfd": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_5962ef174f574964a1f403e663e18f80", - "placeholder": "​", - "style": "IPY_MODEL_6cc6d52b1d974074adfe12afcc2d3e2c", - "value": " 9/9 [00:00<00:00, 12.57it/s]" - } - }, - "8b804e67c6da46ac845b32c5be62e884": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": 
"@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_b4aeaefcd9ed4c508b45047df1e8f1f1", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_4fd1738d6c3c402e84620ceae44e87cf", - "value": 9 - } - }, - "8cd83f0e9663436b967df6e5d38766ae": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_6084560f9ea74d2c9e106a9500bc65eb", - "placeholder": "​", - "style": "IPY_MODEL_f44964101be74d03a83601eed71d8d23", - "value": "Sanity Checking DataLoader 0: 100%" - } - }, - "8e42e30ac1ce4c33acc6b2070bfead23": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": "100%" - } - }, - "8e49dbeadca3499b95b05e2c4211d057": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "917cd4785e3e4563aab3e21a458ff61d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "931eae1ec2834b5d89702f49018260d8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": 
"HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_22dff2b9d78b405db32b95df05046bd9", - "placeholder": "​", - "style": "IPY_MODEL_fd3515421fc5438e8894377037c3049f", - "value": "Validation DataLoader 0: 100%" - } - }, - "9595f2def65a44f1a66965b07ff2a51f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_70e131c2803d4f6fa59d99b8ec1d5e93", - "IPY_MODEL_6f13514495d642a281ec2e70ff772c76", - "IPY_MODEL_8b22fbff367f480b87331935a62d4cfd" - ], - "layout": "IPY_MODEL_becf001d1a4f42e88edd89e0849c008c" - } - }, - "96d02a5000f74366ab53ef25ea8d95d9": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "9fa24325d05f43a0be3f436c469a4c32": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_003bf82570884cc1aa2dfeb489e1216e", - "placeholder": "​", - "style": "IPY_MODEL_00f4b77ed87e491cbfe77c9ec290ac19", - "value": " 9/9 [00:00<00:00, 12.36it/s]" - } - }, - "a19520255a734aa5ae9bde6727c0ea58": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": 
"IPY_MODEL_caced8bb622949c7ad2dc714da459791", - "max": 42, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_d178aa86336342aa96c6af585f887a03", - "value": 42 - } - }, - "a4fd00d13e624d41b2564d403fee0b6c": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "a71969db0c5644efb9161637a12a32d2": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_425461aab83744a4a78e5288dc0f6cff", - "IPY_MODEL_8b804e67c6da46ac845b32c5be62e884", - "IPY_MODEL_9fa24325d05f43a0be3f436c469a4c32" - ], - "layout": "IPY_MODEL_b6655e268f754b72b083604a50845afb" - } - }, - "ab1836ed5b934de6ae2fe761f0be60db": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "af91a519e65946e88e9ad12d924776b1": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": 
[], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_3c8961406f994fe4accde9d16cdb6fe6", - "placeholder": "​", - "style": "IPY_MODEL_6dc9efabf69e45fbbdc4425fc8ce0c3f", - "value": "Validation DataLoader 0: 100%" - } - }, - "b2297524052f4abca043a42f97c10bc6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_bbaba48376ed4f8695fdcd77d62614d2", - "IPY_MODEL_0d0effdc3edb47ba9d6c086946ba86d6", - "IPY_MODEL_57d47150e3ce4de4aec78def4295d087" - ], - "layout": "IPY_MODEL_a4fd00d13e624d41b2564d403fee0b6c" - } - }, - "b4aeaefcd9ed4c508b45047df1e8f1f1": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "b6655e268f754b72b083604a50845afb": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": 
null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "bbaba48376ed4f8695fdcd77d62614d2": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_7afadad53d034e0d840a7c4837d42dea", - "placeholder": "​", - "style": "IPY_MODEL_83fdd35e50f84646a5b9757148cb97f3", - "value": "Validation DataLoader 0: 100%" - } - }, - "becf001d1a4f42e88edd89e0849c008c": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": "100%" - } - }, - "bf6e9a42c1fb45d1ad77cd21355869b6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_8855eee7f91b4f54b412c7d78212521e", - "placeholder": "​", - "style": "IPY_MODEL_146e9284ec9543e0babeb82674d9b9be", - "value": "Epoch 4: 100%" - } - }, - "c89be46611364ee5abe04a5e2f60a66d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "caced8bb622949c7ad2dc714da459791": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - 
"_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "cd50f28b37af46fbac6c8e63656e4384": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "d178aa86336342aa96c6af585f887a03": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "d3f55c1a7193496dae6d6651a5e1024f": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": 
null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "d6d0a0893f7b428db860633f37f9f4ab": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_dcfde39e3b7842a6832f32f3efc72fb6", - "placeholder": "​", - "style": "IPY_MODEL_628c517725d4470f86e1c949d4321a45", - "value": " 9/9 [00:00<00:00, 12.22it/s]" - } - }, - "d83f83a0be0e42f8b4f6458053286a66": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "da7a77693dcf43d884d2681d3be0831c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_757cea1affed4f7888f3136158a13600", - "placeholder": "​", - "style": "IPY_MODEL_0b817bd6c784496594763bee72b78031", - "value": " 9/9 [00:00<00:00, 12.37it/s]" - } - }, - "dcfde39e3b7842a6832f32f3efc72fb6": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "dd7cb24a33194ea99d166085fef36cbf": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { 
- "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "e2b1c2ad8a1e4456852ad49086775ba5": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_bf6e9a42c1fb45d1ad77cd21355869b6", - "IPY_MODEL_a19520255a734aa5ae9bde6727c0ea58", - "IPY_MODEL_2379f649cc0649ba99cca78a0d2231d8" - ], - "layout": "IPY_MODEL_8e42e30ac1ce4c33acc6b2070bfead23" - } - }, - "e44c84c0476a424a9be0c307170813e6": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f44964101be74d03a83601eed71d8d23": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zVRwi2MvLGgc" + }, + "outputs": [], + "source": [] } - }, - "fb7da6100bff4f359705dec33b913c39": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": 
null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" } }, - "fd3515421fc5438e8894377037c3049f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - } + "nbformat": 4, + "nbformat_minor": 0 } - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} From 20aafb749cfebdb1f9789b4dff5120fa8527db74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 30 Apr 2025 18:40:01 +0200 Subject: [PATCH 15/30] refactor file --- pytorch_forecasting/data/__init__.py | 3 +- .../data/timeseries/__init__.py | 9 + .../data/timeseries/_coerce.py | 25 ++ .../_timeseries.py} | 286 +----------------- .../data/timeseries/_timeseries_v2.py | 276 +++++++++++++++++ 5 files changed, 314 insertions(+), 285 deletions(-) create mode 100644 pytorch_forecasting/data/timeseries/__init__.py create mode 100644 pytorch_forecasting/data/timeseries/_coerce.py rename pytorch_forecasting/data/{timeseries.py => timeseries/_timeseries.py} (90%) create mode 100644 pytorch_forecasting/data/timeseries/_timeseries_v2.py diff --git a/pytorch_forecasting/data/__init__.py b/pytorch_forecasting/data/__init__.py index 301c8394d..17be285a0 100644 --- a/pytorch_forecasting/data/__init__.py +++ b/pytorch_forecasting/data/__init__.py @@ -13,10 +13,11 @@ TorchNormalizer, ) from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler -from pytorch_forecasting.data.timeseries import TimeSeriesDataSet +from pytorch_forecasting.data.timeseries import TimeSeries, TimeSeriesDataSet __all__ = [ "TimeSeriesDataSet", + "TimeSeries", "NaNLabelEncoder", "GroupNormalizer", "TorchNormalizer", diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py new file mode 100644 index 000000000..7734cccf2 --- /dev/null +++ b/pytorch_forecasting/data/timeseries/__init__.py @@ -0,0 +1,9 @@ +"""Data loaders for time series data.""" + +from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries +from pytorch_forecasting.data.timeseries._timeseries import TimeSeriesDataSet + +__all__ = [ + "TimeSeriesDataSet", + "TimeSeries", +] diff --git a/pytorch_forecasting/data/timeseries/_coerce.py b/pytorch_forecasting/data/timeseries/_coerce.py new file mode 100644 index 000000000..328431aa8 --- /dev/null +++ b/pytorch_forecasting/data/timeseries/_coerce.py @@ -0,0 +1,25 @@ +"""Coercion functions for various data types.""" + +from copy import deepcopy + + +def _coerce_to_list(obj): + """Coerce object to list. + + None is coerced to empty list, otherwise list constructor is used. 
+ """ + if obj is None: + return [] + if isinstance(obj, str): + return [obj] + return list(obj) + + +def _coerce_to_dict(obj): + """Coerce object to dict. + + None is coerce to empty dict, otherwise deepcopy is used. + """ + if obj is None: + return {} + return deepcopy(obj) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py similarity index 90% rename from pytorch_forecasting/data/timeseries.py rename to pytorch_forecasting/data/timeseries/_timeseries.py index fda08d561..263e0ea3a 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries/_timeseries.py @@ -9,7 +9,7 @@ from copy import copy as _copy, deepcopy from functools import lru_cache import inspect -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Optional, Type, TypeVar, Union import warnings import numpy as np @@ -31,6 +31,7 @@ TorchNormalizer, ) from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler +from pytorch_forecasting.data.timeseries._coerce import _coerce_to_dict, _coerce_to_list from pytorch_forecasting.utils import repr_class from pytorch_forecasting.utils._dependencies import _check_matplotlib @@ -2663,286 +2664,3 @@ def __repr__(self) -> str: attributes=self.get_parameters(), extra_attributes=dict(length=len(self)), ) - - -def _coerce_to_list(obj): - """Coerce object to list. - - None is coerced to empty list, otherwise list constructor is used. - """ - if obj is None: - return [] - if isinstance(obj, str): - return [obj] - return list(obj) - - -def _coerce_to_dict(obj): - """Coerce object to dict. - - None is coerce to empty dict, otherwise deepcopy is used. - """ - if obj is None: - return {} - return deepcopy(obj) - - -####################################################################################### -# Disclaimer: This dataset class is still work in progress and experimental, please -# use with care. This class is a basic skeleton of how the data-handling pipeline may -# look like in the future. -# This is the D1 layer that is a "Raw Dataset Layer" mainly for raw data ingestion -# and turning the data to tensors. -# For now, this pipeline handles the simplest situation: The whole data can be loaded -# into the memory. -####################################################################################### - - -class TimeSeries(Dataset): - """PyTorch Dataset for time series data stored in pandas DataFrame. - - Parameters - ---------- - data : pd.DataFrame - data frame with sequence data. - Column names must all be str, and contain str as referred to below. - data_future : pd.DataFrame, optional, default=None - data frame with future data. - Column names must all be str, and contain str as referred to below. - May contain only columns that are in time, group, weight, known, or static. - time : str, optional, default = first col not in group_ids, weight, target, static. - integer typed column denoting the time index within ``data``. - This column is used to determine the sequence of samples. - If there are no missing observations, - the time index should increase by ``+1`` for each subsequent sample. - The first time_idx for each series does not necessarily - have to be ``0`` but any value is allowed. - target : str or List[str], optional, default = last column (at iloc -1) - column(s) in ``data`` denoting the forecasting target. - Can be categorical or numerical dtype. 
- group : List[str], optional, default = None - list of column names identifying a time series instance within ``data``. - This means that the ``group`` together uniquely identify an instance, - and ``group`` together with ``time`` uniquely identify a single observation - within a time series instance. - If ``None``, the dataset is assumed to be a single time series. - weight : str, optional, default=None - column name for weights. - If ``None``, it is assumed that there is no weight column. - num : list of str, optional, default = all columns with dtype in "fi" - list of numerical variables in ``data``, - list may also contain list of str, which are then grouped together. - cat : list of str, optional, default = all columns with dtype in "Obc" - list of categorical variables in ``data``, - list may also contain list of str, which are then grouped together - (e.g. useful for product categories). - known : list of str, optional, default = all variables - list of variables that change over time and are known in the future, - list may also contain list of str, which are then grouped together - (e.g. useful for special days or promotion categories). - unknown : list of str, optional, default = no variables - list of variables that are not known in the future, - list may also contain list of str, which are then grouped together - (e.g. useful for weather categories). - static : list of str, optional, default = all variables not in known, unknown - list of variables that do not change over time, - list may also contain list of str, which are then grouped together. - """ - - def __init__( - self, - data: pd.DataFrame, - data_future: Optional[pd.DataFrame] = None, - time: Optional[str] = None, - target: Optional[Union[str, List[str]]] = None, - group: Optional[List[str]] = None, - weight: Optional[str] = None, - num: Optional[List[Union[str, List[str]]]] = None, - cat: Optional[List[Union[str, List[str]]]] = None, - known: Optional[List[Union[str, List[str]]]] = None, - unknown: Optional[List[Union[str, List[str]]]] = None, - static: Optional[List[Union[str, List[str]]]] = None, - ): - - self.data = data - self.data_future = data_future - self.time = time - self.target = _coerce_to_list(target) - self.group = _coerce_to_list(group) - self.weight = weight - self.num = _coerce_to_list(num) - self.cat = _coerce_to_list(cat) - self.known = _coerce_to_list(known) - self.unknown = _coerce_to_list(unknown) - self.static = _coerce_to_list(static) - - self.feature_cols = [ - col - for col in data.columns - if col not in [self.time] + self.group + [self.weight] + self.target - ] - if self.group: - self._groups = self.data.groupby(self.group).groups - self._group_ids = list(self._groups.keys()) - else: - self._groups = {"_single_group": self.data.index} - self._group_ids = ["_single_group"] - - self._prepare_metadata() - - def _prepare_metadata(self): - """Prepare metadata for the dataset. - - The funcion returns metadata that contains: - - * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] } - Names of columns for y, x, and static features. - List elements are in same order as column dimensions. - Columns not appearing are assumed to be named (x0, x1, etc.), - (y0, y1, etc.), (st0, st1, etc.). - * ``col_type``: dict[str, str] - maps column names to data types "F" (numerical) and "C" (categorical). - Column names not occurring are assumed "F". - * ``col_known``: dict[str, str] - maps column names to "K" (future known) or "U" (future unknown). - Column names not occurring are assumed "K". 
- """ - self.metadata = { - "cols": { - "y": self.target, - "x": self.feature_cols, - "st": self.static, - }, - "col_type": {}, - "col_known": {}, - } - - all_cols = self.target + self.feature_cols + self.static - for col in all_cols: - self.metadata["col_type"][col] = "C" if col in self.cat else "F" - - self.metadata["col_known"][col] = "K" if col in self.known else "U" - - def __len__(self) -> int: - """Return number of time series in the dataset.""" - return len(self._group_ids) - - def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: - """Get time series data for given index. - - Returns - ------- - t : numpy.ndarray of shape (n_timepoints,) - Time index for each time point in the past or present. Aligned with `y`, - and `x` not ending in `f`. - - y : torch.Tensor of shape (n_timepoints, n_targets) - Target values for each time point. Rows are time points, aligned with `t`. - - x : torch.Tensor of shape (n_timepoints, n_features) - Features for each time point. Rows are time points, aligned with `t`. - - group : torch.Tensor of shape (n_groups,) - Group identifiers for time series instances. - - st : torch.Tensor of shape (n_static_features,) - Static features. - - cutoff_time : float or numpy.float64 - Cutoff time for the time series instance. - - Other Returns - ------------- - weights : torch.Tensor of shape (n_timepoints,), optional - Only included if weights are not `None`. - """ - group_id = self._group_ids[index] - - if self.group: - mask = self._groups[group_id] - data = self.data.loc[mask] - else: - data = self.data - - cutoff_time = data[self.time].max() - - result = { - "t": data[self.time].values, - "y": torch.tensor(data[self.target].values), - "x": torch.tensor(data[self.feature_cols].values), - "group": torch.tensor([hash(str(group_id))]), - "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), - "cutoff_time": cutoff_time, - } - - if self.data_future is not None: - if self.group: - future_mask = self.data_future.groupby(self.group).groups[group_id] - future_data = self.data_future.loc[future_mask] - else: - future_data = self.data_future - - combined_times = np.concatenate( - [data[self.time].values, future_data[self.time].values] - ) - combined_times = np.unique(combined_times) - combined_times.sort() - - num_timepoints = len(combined_times) - x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) - y_merged = np.full((num_timepoints, len(self.target)), np.nan) - - current_time_indices = {t: i for i, t in enumerate(combined_times)} - for i, t in enumerate(data[self.time].values): - idx = current_time_indices[t] - x_merged[idx] = data[self.feature_cols].values[i] - y_merged[idx] = data[self.target].values[i] - - for i, t in enumerate(future_data[self.time].values): - if t in current_time_indices: - idx = current_time_indices[t] - for j, col in enumerate(self.known): - if col in self.feature_cols: - feature_idx = self.feature_cols.index(col) - x_merged[idx, feature_idx] = future_data[col].values[i] - - result.update( - { - "t": combined_times, - "x": torch.tensor(x_merged, dtype=torch.float32), - "y": torch.tensor(y_merged, dtype=torch.float32), - } - ) - - if self.weight: - if self.data_future is not None and self.weight in self.data_future.columns: - weights_merged = np.full(num_timepoints, np.nan) - for i, t in enumerate(data[self.time].values): - idx = current_time_indices[t] - weights_merged[idx] = data[self.weight].values[i] - - for i, t in enumerate(future_data[self.time].values): - if t in current_time_indices and 
self.weight in future_data.columns:
-                        idx = current_time_indices[t]
-                        weights_merged[idx] = future_data[self.weight].values[i]
-
-                result["weights"] = torch.tensor(weights_merged, dtype=torch.float32)
-            else:
-                result["weights"] = torch.tensor(
-                    data[self.weight].values, dtype=torch.float32
-                )
-
-        return result
-
-    def get_metadata(self) -> Dict:
-        """Return metadata about the dataset.
-
-        Returns
-        -------
-        Dict
-            Dictionary containing:
-            - cols: column names for y, x, and static features
-            - col_type: mapping of columns to their types (F/C)
-            - col_known: mapping of columns to their future known status (K/U)
-        """
-        return self.metadata
diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
new file mode 100644
index 000000000..53bf7228d
--- /dev/null
+++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
@@ -0,0 +1,276 @@
+"""
+Timeseries dataset - v2 prototype.
+
+Beta version, experimental - use for testing but not in production.
+"""
+
+from typing import Dict, List, Optional, Union
+import warnings
+
+import numpy as np
+import pandas as pd
+import torch
+from torch.utils.data import Dataset
+
+from pytorch_forecasting.data.timeseries._coerce import _coerce_to_list
+
+
+#######################################################################################
+# Disclaimer: This dataset class is still work in progress and experimental, please
+# use with care. This class is a basic skeleton of how the data-handling pipeline may
+# look in the future.
+# This is the D1 layer that is a "Raw Dataset Layer" mainly for raw data ingestion
+# and turning the data to tensors.
+# For now, this pipeline handles the simplest situation: The whole data can be loaded
+# into the memory.
+#######################################################################################
+
+
+class TimeSeries(Dataset):
+    """PyTorch Dataset for time series data stored in pandas DataFrame.
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        data frame with sequence data.
+        Column names must all be str, and contain str as referred to below.
+    data_future : pd.DataFrame, optional, default=None
+        data frame with future data.
+        Column names must all be str, and contain str as referred to below.
+        May contain only columns that are in time, group, weight, known, or static.
+    time : str, optional, default = first col not in group_ids, weight, target, static.
+        integer typed column denoting the time index within ``data``.
+        This column is used to determine the sequence of samples.
+        If there are no missing observations,
+        the time index should increase by ``+1`` for each subsequent sample.
+        The first time_idx for each series does not necessarily
+        have to be ``0`` but any value is allowed.
+    target : str or List[str], optional, default = last column (at iloc -1)
+        column(s) in ``data`` denoting the forecasting target.
+        Can be categorical or numerical dtype.
+    group : List[str], optional, default = None
+        list of column names identifying a time series instance within ``data``.
+        This means that the ``group`` together uniquely identify an instance,
+        and ``group`` together with ``time`` uniquely identify a single observation
+        within a time series instance.
+        If ``None``, the dataset is assumed to be a single time series.
+    weight : str, optional, default=None
+        column name for weights.
+        If ``None``, it is assumed that there is no weight column.
+    num : list of str, optional, default = all columns with dtype in "fi"
+        list of numerical variables in ``data``,
+        list may also contain list of str, which are then grouped together.
+    cat : list of str, optional, default = all columns with dtype in "Obc"
+        list of categorical variables in ``data``,
+        list may also contain list of str, which are then grouped together
+        (e.g. useful for product categories).
+    known : list of str, optional, default = all variables
+        list of variables that change over time and are known in the future,
+        list may also contain list of str, which are then grouped together
+        (e.g. useful for special days or promotion categories).
+    unknown : list of str, optional, default = no variables
+        list of variables that are not known in the future,
+        list may also contain list of str, which are then grouped together
+        (e.g. useful for weather categories).
+    static : list of str, optional, default = all variables not in known, unknown
+        list of variables that do not change over time,
+        list may also contain list of str, which are then grouped together.
+    """
+
+    def __init__(
+        self,
+        data: pd.DataFrame,
+        data_future: Optional[pd.DataFrame] = None,
+        time: Optional[str] = None,
+        target: Optional[Union[str, List[str]]] = None,
+        group: Optional[List[str]] = None,
+        weight: Optional[str] = None,
+        num: Optional[List[Union[str, List[str]]]] = None,
+        cat: Optional[List[Union[str, List[str]]]] = None,
+        known: Optional[List[Union[str, List[str]]]] = None,
+        unknown: Optional[List[Union[str, List[str]]]] = None,
+        static: Optional[List[Union[str, List[str]]]] = None,
+    ):
+
+        self.data = data
+        self.data_future = data_future
+        self.time = time
+        self.target = _coerce_to_list(target)
+        self.group = _coerce_to_list(group)
+        self.weight = weight
+        self.num = _coerce_to_list(num)
+        self.cat = _coerce_to_list(cat)
+        self.known = _coerce_to_list(known)
+        self.unknown = _coerce_to_list(unknown)
+        self.static = _coerce_to_list(static)
+
+        self.feature_cols = [
+            col
+            for col in data.columns
+            if col not in [self.time] + self.group + [self.weight] + self.target
+        ]
+        if self.group:
+            self._groups = self.data.groupby(self.group).groups
+            self._group_ids = list(self._groups.keys())
+        else:
+            self._groups = {"_single_group": self.data.index}
+            self._group_ids = ["_single_group"]
+
+        self._prepare_metadata()
+
+    def _prepare_metadata(self):
+        """Prepare metadata for the dataset.
+
+        The function returns metadata that contains:
+
+        * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] }
+          Names of columns for y, x, and static features.
+          List elements are in same order as column dimensions.
+          Columns not appearing are assumed to be named (x0, x1, etc.),
+          (y0, y1, etc.), (st0, st1, etc.).
+        * ``col_type``: dict[str, str]
+          maps column names to data types "F" (numerical) and "C" (categorical).
+          Column names not occurring are assumed "F".
+        * ``col_known``: dict[str, str]
+          maps column names to "K" (future known) or "U" (future unknown).
+          Column names not occurring are assumed "K".
+ """ + self.metadata = { + "cols": { + "y": self.target, + "x": self.feature_cols, + "st": self.static, + }, + "col_type": {}, + "col_known": {}, + } + + all_cols = self.target + self.feature_cols + self.static + for col in all_cols: + self.metadata["col_type"][col] = "C" if col in self.cat else "F" + + self.metadata["col_known"][col] = "K" if col in self.known else "U" + + def __len__(self) -> int: + """Return number of time series in the dataset.""" + return len(self._group_ids) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get time series data for given index. + + Returns + ------- + t : numpy.ndarray of shape (n_timepoints,) + Time index for each time point in the past or present. Aligned with `y`, + and `x` not ending in `f`. + + y : torch.Tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with `t`. + + x : torch.Tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with `t`. + + group : torch.Tensor of shape (n_groups,) + Group identifiers for time series instances. + + st : torch.Tensor of shape (n_static_features,) + Static features. + + cutoff_time : float or numpy.float64 + Cutoff time for the time series instance. + + Other Returns + ------------- + weights : torch.Tensor of shape (n_timepoints,), optional + Only included if weights are not `None`. + """ + group_id = self._group_ids[index] + + if self.group: + mask = self._groups[group_id] + data = self.data.loc[mask] + else: + data = self.data + + cutoff_time = data[self.time].max() + + result = { + "t": data[self.time].values, + "y": torch.tensor(data[self.target].values), + "x": torch.tensor(data[self.feature_cols].values), + "group": torch.tensor([hash(str(group_id))]), + "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), + "cutoff_time": cutoff_time, + } + + if self.data_future is not None: + if self.group: + future_mask = self.data_future.groupby(self.group).groups[group_id] + future_data = self.data_future.loc[future_mask] + else: + future_data = self.data_future + + combined_times = np.concatenate( + [data[self.time].values, future_data[self.time].values] + ) + combined_times = np.unique(combined_times) + combined_times.sort() + + num_timepoints = len(combined_times) + x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) + y_merged = np.full((num_timepoints, len(self.target)), np.nan) + + current_time_indices = {t: i for i, t in enumerate(combined_times)} + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + x_merged[idx] = data[self.feature_cols].values[i] + y_merged[idx] = data[self.target].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices: + idx = current_time_indices[t] + for j, col in enumerate(self.known): + if col in self.feature_cols: + feature_idx = self.feature_cols.index(col) + x_merged[idx, feature_idx] = future_data[col].values[i] + + result.update( + { + "t": combined_times, + "x": torch.tensor(x_merged, dtype=torch.float32), + "y": torch.tensor(y_merged, dtype=torch.float32), + } + ) + + if self.weight: + if self.data_future is not None and self.weight in self.data_future.columns: + weights_merged = np.full(num_timepoints, np.nan) + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + weights_merged[idx] = data[self.weight].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices and 
self.weight in future_data.columns:
+                        idx = current_time_indices[t]
+                        weights_merged[idx] = future_data[self.weight].values[i]
+
+                result["weights"] = torch.tensor(weights_merged, dtype=torch.float32)
+            else:
+                result["weights"] = torch.tensor(
+                    data[self.weight].values, dtype=torch.float32
+                )
+
+        return result
+
+    def get_metadata(self) -> Dict:
+        """Return metadata about the dataset.
+
+        Returns
+        -------
+        Dict
+            Dictionary containing:
+            - cols: column names for y, x, and static features
+            - col_type: mapping of columns to their types (F/C)
+            - col_known: mapping of columns to their future known status (K/U)
+        """
+        return self.metadata
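(A minimal usage sketch of the D1 TimeSeries class added above -- the toy frame and column names are invented for illustration and are not part of the patch:)

    import pandas as pd

    from pytorch_forecasting.data.timeseries import TimeSeries

    df = pd.DataFrame(
        {
            "time_idx": [0, 1, 2, 0, 1, 2],              # integer time index per series
            "group_id": ["a", "a", "a", "b", "b", "b"],  # identifies the instance
            "value": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],     # forecasting target
        }
    )
    ds = TimeSeries(data=df, time="time_idx", target=["value"], group=["group_id"])

    len(ds)            # 2 -- one entry per time series instance (group)
    sample = ds[0]     # dict with keys "t", "y", "x", "group", "st", "cutoff_time"
    sample["y"].shape  # torch.Size([3, 1]) -- three time points, one target
    ds.get_metadata()  # {"cols": {"y": ["value"], ...}, "col_type": ..., "col_known": ...}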
From 043820dd3be3041a019fd9cd2cb1e681d25a79a1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Wed, 30 Apr 2025 18:43:50 +0200
Subject: [PATCH 16/30] warning

---
 pytorch_forecasting/data/timeseries/_timeseries_v2.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
index 53bf7228d..1c91d2525 100644
--- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py
+++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
@@ -104,6 +104,18 @@ def __init__(
         self.unknown = _coerce_to_list(unknown)
         self.static = _coerce_to_list(static)
 
+        warnings.warn(
+            "TimeSeries is part of an experimental rework of the "
+            "pytorch-forecasting data layer, "
+            "scheduled for release with v2.0.0. "
+            "The API is not stable and may change without prior warning. "
+            "For beta testing, but not for stable production use. "
+            "Feedback and suggestions are very welcome in "
+            "pytorch-forecasting issue 1736, "
+            "https://github.com/sktime/pytorch-forecasting/issues/1736",
+            UserWarning,
+        )
+
         self.feature_cols = [
             col
             for col in data.columns

From 1720a15e9cff3e5c3ebcd0bf3ec03995d068e4b2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Thu, 1 May 2025 13:58:09 +0200
Subject: [PATCH 17/30] linting

---
 pytorch_forecasting/data/timeseries/__init__.py       | 2 +-
 pytorch_forecasting/data/timeseries/_timeseries_v2.py | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py
index 7734cccf2..85973267a 100644
--- a/pytorch_forecasting/data/timeseries/__init__.py
+++ b/pytorch_forecasting/data/timeseries/__init__.py
@@ -1,7 +1,7 @@
 """Data loaders for time series data."""
 
-from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries
 from pytorch_forecasting.data.timeseries._timeseries import TimeSeriesDataSet
+from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries
 
 __all__ = [
     "TimeSeriesDataSet",
diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
index 1c91d2525..76972ab4d 100644
--- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py
+++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
@@ -14,7 +14,6 @@
 from pytorch_forecasting.data.timeseries._coerce import _coerce_to_list
 
-
 #######################################################################################
 # Disclaimer: This dataset class is still work in progress and experimental, please
 # use with care. This class is a basic skeleton of how the data-handling pipeline may

From af44474d16b3fcdf5e99acb4b9d1f7345119d8cd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Thu, 1 May 2025 14:21:58 +0200
Subject: [PATCH 18/30] move coercion to utils

---
 pytorch_forecasting/data/data_module.py                   | 6 ++----
 pytorch_forecasting/{data/timeseries => utils}/_coerce.py | 0
 2 files changed, 2 insertions(+), 4 deletions(-)
 rename pytorch_forecasting/{data/timeseries => utils}/_coerce.py (100%)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index 1203e83ac..9d3ebbedb 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -19,10 +19,8 @@
     NaNLabelEncoder,
     TorchNormalizer,
 )
-from pytorch_forecasting.data.timeseries import (
-    TimeSeries,
-    _coerce_to_dict,
-)
+from pytorch_forecasting.data.timeseries import TimeSeries
+from pytorch_forecasting.utils._coerce import _coerce_to_dict
 
 NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer]
 
diff --git a/pytorch_forecasting/data/timeseries/_coerce.py b/pytorch_forecasting/utils/_coerce.py
similarity index 100%
rename from pytorch_forecasting/data/timeseries/_coerce.py
rename to pytorch_forecasting/utils/_coerce.py

From a3cb8b736b0b134c8faa97f5ef2993deb28fb75b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Thu, 1 May 2025 14:22:18 +0200
Subject: [PATCH 19/30] linting

---
 pytorch_forecasting/data/timeseries/_timeseries.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_forecasting/data/timeseries/_timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py
index 263e0ea3a..30fe9e0bb 100644
--- a/pytorch_forecasting/data/timeseries/_timeseries.py
+++ b/pytorch_forecasting/data/timeseries/_timeseries.py
@@ -31,8 +31,8 @@
     TorchNormalizer,
 )
 from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler
-from pytorch_forecasting.data.timeseries._coerce import _coerce_to_dict, _coerce_to_list
 from pytorch_forecasting.utils import repr_class
+from pytorch_forecasting.utils._coerce import _coerce_to_dict, _coerce_to_list
 from pytorch_forecasting.utils._dependencies import _check_matplotlib
 

From 75d7fb54d8405ef493197c5a4d2fc86a5e9e9d5b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Thu, 1 May 2025 14:25:51 +0200
Subject: [PATCH 20/30] Update _timeseries_v2.py

---
 pytorch_forecasting/data/timeseries/_timeseries_v2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
index 76972ab4d..afa45725b 100644
--- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py
+++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
@@ -12,7 +12,7 @@
 import torch
 from torch.utils.data import Dataset
 
-from pytorch_forecasting.data.timeseries._coerce import _coerce_to_list
+from pytorch_forecasting.utils._coerce import _coerce_to_list
 
 #######################################################################################
 # Disclaimer: This dataset class is still work in progress and experimental, please

From 1b946e699be9db2e201a2361779a695356a0460b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Thu, 1 May 2025 14:30:13 +0200
Subject: [PATCH 21/30] Update __init__.py

---
 pytorch_forecasting/data/timeseries/__init__.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/pytorch_forecasting/data/timeseries/__init__.py 
b/pytorch_forecasting/data/timeseries/__init__.py index 85973267a..b359a0aa9 100644 --- a/pytorch_forecasting/data/timeseries/__init__.py +++ b/pytorch_forecasting/data/timeseries/__init__.py @@ -1,9 +1,15 @@ """Data loaders for time series data.""" -from pytorch_forecasting.data.timeseries._timeseries import TimeSeriesDataSet +from pytorch_forecasting.data.timeseries._timeseries import ( + _find_end_indices, + check_for_nonfinite, + TimeSeriesDataSet, +) from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries __all__ = [ + "_find_end_indices", + "check_for_nonfinite", "TimeSeriesDataSet", "TimeSeries", ] From 3edb08b7ea1b97d06b47b0ebcc83aaef9bec8083 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 May 2025 14:33:17 +0200 Subject: [PATCH 22/30] Update __init__.py --- pytorch_forecasting/data/timeseries/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py index b359a0aa9..788c08201 100644 --- a/pytorch_forecasting/data/timeseries/__init__.py +++ b/pytorch_forecasting/data/timeseries/__init__.py @@ -1,9 +1,9 @@ """Data loaders for time series data.""" from pytorch_forecasting.data.timeseries._timeseries import ( + TimeSeriesDataSet, _find_end_indices, check_for_nonfinite, - TimeSeriesDataSet, ) from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries From e350291c110f567e69946e0e113f2471b7472738 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 11 May 2025 22:10:01 +0530 Subject: [PATCH 23/30] update tests --- tests/test_data/test_data_module.py | 72 ++++++++++++++--------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index c14e3d8f4..4051b852c 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -9,7 +9,7 @@ @pytest.fixture def sample_timeseries_data(): """Create a sample time series dataset with only numerical values.""" - num_groups = 5 + num_groups = 10 seq_length = 100 groups = [] @@ -128,22 +128,22 @@ def test_metadata_property(data_module): assert metadata["decoder_cont"] == 1 # Only known_future marked as known -# def test_setup(data_module): -# """Test the setup method that prepares the datasets.""" -# data_module.setup(stage="fit") -# print(data_module._val_indices) -# assert hasattr(data_module, "train_dataset") -# assert hasattr(data_module, "val_dataset") -# assert len(data_module.train_windows) > 0 -# assert len(data_module.val_windows) > 0 -# -# data_module.setup(stage="test") -# assert hasattr(data_module, "test_dataset") -# assert len(data_module.test_windows) > 0 -# -# data_module.setup(stage="predict") -# assert hasattr(data_module, "predict_dataset") -# assert len(data_module.predict_windows) > 0 +def test_setup(data_module): + """Test the setup method that prepares the datasets.""" + data_module.setup(stage="fit") + print(data_module._val_indices) + assert hasattr(data_module, "train_dataset") + assert hasattr(data_module, "val_dataset") + assert len(data_module.train_windows) > 0 + assert len(data_module.val_windows) > 0 + + data_module.setup(stage="test") + assert hasattr(data_module, "test_dataset") + assert len(data_module.test_windows) > 0 + + data_module.setup(stage="predict") + assert hasattr(data_module, "predict_dataset") + assert len(data_module.predict_windows) > 0 def test_create_windows(data_module): @@ -407,25 +407,25 @@ def 
test_with_static_features(): assert "static_continuous_features" in x -# def test_different_train_val_test_split(sample_timeseries_data): -# """Test with different train/val/test split ratios.""" -# dm = EncoderDecoderTimeSeriesDataModule( -# time_series_dataset=sample_timeseries_data, -# max_encoder_length=24, -# max_prediction_length=12, -# batch_size=4, -# train_val_test_split=(0.8, 0.1, 0.1), -# ) -# -# dm.setup() -# -# total_series = len(sample_timeseries_data) -# expected_train = int(0.8 * total_series) -# expected_val = int(0.1 * total_series) -# -# assert len(dm._train_indices) == expected_train -# assert len(dm._val_indices) == expected_val -# assert len(dm._test_indices) == total_series - expected_train - expected_val +def test_different_train_val_test_split(sample_timeseries_data): + """Test with different train/val/test split ratios.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=4, + train_val_test_split=(0.8, 0.1, 0.1), + ) + + dm.setup() + + total_series = len(sample_timeseries_data) + expected_train = int(0.8 * total_series) + expected_val = int(0.1 * total_series) + + assert len(dm._train_indices) == expected_train + assert len(dm._val_indices) == expected_val + assert len(dm._test_indices) == total_series - expected_train - expected_val def test_multivariate_target(): From 3099691d3cc792bd528f50ff3c51a0fa4a9ce28a Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 12 May 2025 00:22:27 +0530 Subject: [PATCH 24/30] update tft_v2 --- .../tft_version_two.py | 65 +++++++++++-------- 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py index 30f70f98e..2bfe407d7 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py @@ -36,6 +36,8 @@ def __init__( lr_scheduler=lr_scheduler, lr_scheduler_params=lr_scheduler_params, ) + self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"]) + self.hidden_size = hidden_size self.num_layers = num_layers self.attention_head_size = attention_head_size @@ -47,42 +49,51 @@ def __init__( self.max_prediction_length = self.metadata["max_prediction_length"] self.encoder_cont = self.metadata["encoder_cont"] self.encoder_cat = self.metadata["encoder_cat"] - self.static_categorical_features = self.metadata["static_categorical_features"] - self.static_continuous_features = self.metadata["static_continuous_features"] - - total_feature_size = self.encoder_cont + self.encoder_cat - total_static_size = ( - self.static_categorical_features + self.static_continuous_features - ) - - self.encoder_var_selection = nn.Sequential( - nn.Linear(total_feature_size, hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, total_feature_size), - nn.Sigmoid(), - ) - - self.decoder_var_selection = nn.Sequential( - nn.Linear(total_feature_size, hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, total_feature_size), - nn.Sigmoid(), - ) + self.encoder_input_dim = self.encoder_cont + self.encoder_cat + self.decoder_cont = self.metadata["decoder_cont"] + self.decoder_cat = self.metadata["decoder_cat"] + self.decoder_input_dim = self.decoder_cont + self.decoder_cat + self.static_cat_dim = self.metadata.get("static_categorical_features", 0) + self.static_cont_dim = 
self.metadata.get("static_continuous_features", 0) + self.static_input_dim = self.static_cat_dim + self.static_cont_dim + + if self.encoder_input_dim > 0: + self.encoder_var_selection = nn.Sequential( + nn.Linear(self.encoder_input_dim, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, self.encoder_input_dim), + nn.Sigmoid(), + ) + else: + self.encoder_var_selection = None + + if self.decoder_input_dim > 0: + self.decoder_var_selection = nn.Sequential( + nn.Linear(self.decoder_input_dim, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, self.decoder_input_dim), + nn.Sigmoid(), + ) + else: + self.decoder_var_selection = None - self.static_context_linear = ( - nn.Linear(total_static_size, hidden_size) if total_static_size > 0 else None - ) + if self.static_input_dim > 0: + self.static_context_linear = nn.Linear(self.static_input_dim, hidden_size) + else: + self.static_context_linear = None + _lstm_encoder_input_actual_dim = self.encoder_input_dim self.lstm_encoder = nn.LSTM( - input_size=total_feature_size, + input_size=max(1, _lstm_encoder_input_actual_dim), hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, batch_first=True, ) + _lstm_decoder_input_actual_dim = self.decoder_input_dim self.lstm_decoder = nn.LSTM( - input_size=total_feature_size, + input_size=max(1, _lstm_decoder_input_actual_dim), hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, @@ -97,7 +108,7 @@ def __init__( ) self.pre_output = nn.Linear(hidden_size, hidden_size) - self.output_layer = nn.Linear(hidden_size, output_size) + self.output_layer = nn.Linear(hidden_size, self.output_size) def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ From d6e62bbb5da93f0cfc903b1edd9b68a8e2806243 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 12 May 2025 00:51:22 +0530 Subject: [PATCH 25/30] update notebook --- examples/ptf_V2_example.ipynb | 3360 ++++++++++----------------------- 1 file changed, 989 insertions(+), 2371 deletions(-) diff --git a/examples/ptf_V2_example.ipynb b/examples/ptf_V2_example.ipynb index 90d8e93be..2e39108f3 100644 --- a/examples/ptf_V2_example.ipynb +++ b/examples/ptf_V2_example.ipynb @@ -1,2404 +1,1022 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "2630DaOEI4AJ", - "outputId": "a6f99bf0-957b-431a-f512-6abba6629768" + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2630DaOEI4AJ", + "outputId": "b4cc7100-2a9c-41a3-e890-164b90c91c03" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting pytorch-forecasting\n", + " Downloading pytorch_forecasting-1.3.0-py3-none-any.whl.metadata (13 kB)\n", + "Requirement already satisfied: numpy<=3.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.0.2)\n", + "Requirement already satisfied: torch!=2.0.1,<3.0.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.6.0+cu124)\n", + "Collecting lightning<3.0.0,>=2.0.0 (from pytorch-forecasting)\n", + " Downloading lightning-2.5.1.post0-py3-none-any.whl.metadata (39 kB)\n", + "Requirement already satisfied: scipy<2.0,>=1.8 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.15.2)\n", + "Requirement already satisfied: pandas<3.0.0,>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.2.2)\n", + "Requirement already satisfied: 
scikit-learn<2.0,>=1.2 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.6.1)\n", + "Requirement already satisfied: PyYAML<8.0,>=5.4 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.0.2)\n", + "Requirement already satisfied: fsspec<2026.0,>=2022.5.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2025.3.2)\n", + "Collecting lightning-utilities<2.0,>=0.10.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)\n", + "Requirement already satisfied: packaging<25.0,>=20.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (24.2)\n", + "Collecting torchmetrics<3.0,>=0.7.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)\n", + "Requirement already satisfied: tqdm<6.0,>=4.57.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.67.1)\n", + "Requirement already satisfied: typing-extensions<6.0,>=4.4.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.13.2)\n", + "Collecting pytorch-lightning (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading pytorch_lightning-2.5.1.post0-py3-none-any.whl.metadata (20 kB)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", + "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (3.6.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.18.0)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.4.2)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.1.6)\n", + "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading 
nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-curand-cu12==10.3.5.147 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (0.6.2)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (2.21.5)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (12.4.127)\n", + "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.2.0)\n", + "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.0)\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.11.15)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from lightning-utilities<2.0,>=0.10.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (75.2.0)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas<3.0.0,>=1.3.0->pytorch-forecasting) (1.17.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.0.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2.6.1)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.2)\n", + "Requirement already satisfied: 
attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (25.3.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.6.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.4.3)\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (0.3.1)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.20.0)\n", + "Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.11/dist-packages (from yarl<2.0,>=1.17.0->aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.10)\n", + "Downloading pytorch_forecasting-1.3.0-py3-none-any.whl (197 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m197.7/197.7 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading lightning-2.5.1.post0-py3-none-any.whl (819 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m819.0/819.0 kB\u001b[0m \u001b[31m16.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m37.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m20.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n", + "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m12.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m64.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading lightning_utilities-0.14.3-py3-none-any.whl (28 kB)\n", + "Downloading torchmetrics-1.7.1-py3-none-any.whl (961 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m961.5/961.5 kB\u001b[0m \u001b[31m35.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading pytorch_lightning-2.5.1.post0-py3-none-any.whl (823 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.1/823.1 kB\u001b[0m \u001b[31m37.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, lightning-utilities, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torchmetrics, pytorch-lightning, lightning, pytorch-forecasting\n", + " Attempting uninstall: nvidia-nvjitlink-cu12\n", + " Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n", + " Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n", + " Attempting uninstall: nvidia-curand-cu12\n", + " Found existing installation: nvidia-curand-cu12 10.3.6.82\n", + " Uninstalling nvidia-curand-cu12-10.3.6.82:\n", + " Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n", + " Attempting uninstall: nvidia-cufft-cu12\n", + " Found existing installation: nvidia-cufft-cu12 11.2.3.61\n", + " Uninstalling nvidia-cufft-cu12-11.2.3.61:\n", + " Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n", + " Attempting uninstall: nvidia-cuda-runtime-cu12\n", + " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n", + " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cuda-cupti-cu12\n", + " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cublas-cu12\n", + " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n", + " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n", + " Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n", + " 
Attempting uninstall: nvidia-cusparse-cu12\n", + " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n", + " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n", + " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n", + " Attempting uninstall: nvidia-cudnn-cu12\n", + " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n", + " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n", + " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n", + " Attempting uninstall: nvidia-cusolver-cu12\n", + " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n", + " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n", + " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n", + "Successfully installed lightning-2.5.1.post0 lightning-utilities-0.14.3 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch-forecasting-1.3.0 pytorch-lightning-2.5.1.post0 torchmetrics-1.7.1\n" + ] + } + ], + "source": [ + "!pip install pytorch-forecasting" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "M7PQerTbI_tM" + }, + "outputs": [], + "source": [ + "from typing import Any, Dict, List, Optional, Tuple, Union\n", + "\n", + "from lightning.pytorch import Trainer\n", + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn.preprocessing import RobustScaler, StandardScaler\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import Optimizer\n", + "from torch.utils.data import Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DGTyf3vct-Jk" + }, + "outputs": [], + "source": [ + "from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule\n", + "from pytorch_forecasting.data.encoders import (\n", + " EncoderNormalizer,\n", + " NaNLabelEncoder,\n", + " TorchNormalizer,\n", + ")\n", + "from pytorch_forecasting.data.timeseries import TimeSeries\n", + "from pytorch_forecasting.metrics import MAE, SMAPE\n", + "from pytorch_forecasting.models.temporal_fusion_transformer.tft_version_two import TFT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "WX-FRdusJSVN", + "outputId": "d162e241-3076-415c-db39-8c571bbaa282" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6702424716860947,\n \"min\": -1.2572875930191487,\n \"max\": 1.347291996576924,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.40064266948811306,\n 0.688757012378203,\n -0.9278241195910876\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n 
\"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6742204890063661,\n \"min\": -1.2572875930191487,\n \"max\": 1.347291996576924,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.6039073395571968,\n 0.5832480743546181,\n -0.801772762118357\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n -0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2912918986931303,\n \"min\": 0.007584244652032224,\n \"max\": 0.9959799570401108,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.4307103381570838,\n 0.6664272198589233,\n 0.16731443141739688\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", + "type": "dataframe", + "variable_name": "data_df" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting pytorch-forecasting\n", - " Downloading pytorch_forecasting-1.3.0-py3-none-any.whl.metadata (13 kB)\n", - "Requirement already satisfied: numpy<=3.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.0.2)\n", - "Requirement already satisfied: torch!=2.0.1,<3.0.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.6.0+cu124)\n", - "Collecting lightning<3.0.0,>=2.0.0 (from pytorch-forecasting)\n", - " Downloading lightning-2.5.1-py3-none-any.whl.metadata (39 kB)\n", - "Requirement already satisfied: scipy<2.0,>=1.8 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.14.1)\n", - "Requirement already satisfied: pandas<3.0.0,>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.2.2)\n", - "Requirement already satisfied: scikit-learn<2.0,>=1.2 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.6.1)\n", - "Requirement already satisfied: PyYAML<8.0,>=5.4 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.0.2)\n", - "Requirement already satisfied: fsspec<2026.0,>=2022.5.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2025.3.2)\n", - "Collecting lightning-utilities<2.0,>=0.10.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)\n", - "Requirement already satisfied: packaging<25.0,>=20.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (24.2)\n", - "Collecting torchmetrics<3.0,>=0.7.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)\n", - "Requirement already satisfied: tqdm<6.0,>=4.57.0 in 
/usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.67.1)\n",
- "... (remainder of the superseded pip install log: dependency resolution, CUDA wheel downloads, and package uninstalls, near-identical to the new log above) ...\n",
- "Successfully installed lightning-2.5.1 lightning-utilities-0.14.3 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch-forecasting-1.3.0 pytorch-lightning-2.5.1 torchmetrics-1.7.1\n"
- ]
- }
+ "text/html": [
+ "(HTML table rendering of data_df.head(); same rows and values as the text/plain output below)
\n" ], - "source": [ - "!pip install pytorch-forecasting" + "text/plain": [ + " series_id time_idx x y category future_known_feature \\\n", + "0 0 0 0.200968 0.232298 0 1.000000 \n", + "1 0 1 0.232298 0.336669 0 0.995004 \n", + "2 0 2 0.336669 0.636063 0 0.980067 \n", + "3 0 3 0.636063 0.927710 0 0.955336 \n", + "4 0 4 0.927710 1.008554 0 0.921061 \n", + "\n", + " static_feature static_feature_cat \n", + "0 0.68729 0 \n", + "1 0.68729 0 \n", + "2 0.68729 0 \n", + "3 0.68729 0 \n", + "4 0.68729 0 " ] }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "M7PQerTbI_tM" + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "\n", + "num_series = 100\n", + "seq_length = 50\n", + "data_list = []\n", + "for i in range(num_series):\n", + " x = np.arange(seq_length)\n", + " y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)\n", + " category = i % 5\n", + " static_value = np.random.rand()\n", + " for t in range(seq_length - 1):\n", + " data_list.append(\n", + " {\n", + " \"series_id\": i,\n", + " \"time_idx\": t,\n", + " \"x\": y[t],\n", + " \"y\": y[t + 1],\n", + " \"category\": category,\n", + " \"future_known_feature\": np.cos(t / 10),\n", + " \"static_feature\": static_value,\n", + " \"static_feature_cat\": i % 3,\n", + " }\n", + " )\n", + "data_df = pd.DataFrame(data_list)\n", + "data_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "AxxPHK6AKSD2", + "outputId": "dd95173d-73c2-451b-8b67-c9cc7298cf9d" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":106: UserWarning: TimeSeries is part of an experimental rework of the pytorch-forecasting data layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. For beta testing, but not for stable production use. 
Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n",
+ "  warnings.warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset = TimeSeries(\n",
+ "    data=data_df,\n",
+ "    time=\"time_idx\",\n",
+ "    target=\"y\",\n",
+ "    group=[\"series_id\"],\n",
+ "    num=[\"x\", \"future_known_feature\", \"static_feature\"],\n",
+ "    cat=[\"category\", \"static_feature_cat\"],\n",
+ "    known=[\"future_known_feature\"],\n",
+ "    unknown=[\"x\", \"category\"],\n",
+ "    static=[\"static_feature\", \"static_feature_cat\"],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "id": "5U5Lr_ZFKX0s"
+ },
+ "outputs": [],
+ "source": [
+ "data_module = EncoderDecoderTimeSeriesDataModule(\n",
+ "    time_series_dataset=dataset,\n",
+ "    max_encoder_length=30,\n",
+ "    max_prediction_length=1,\n",
+ "    batch_size=32,\n",
+ "    categorical_encoders={\n",
+ "        \"category\": NaNLabelEncoder(add_nan=True),\n",
+ "        \"static_feature_cat\": NaNLabelEncoder(add_nan=True),\n",
+ "    },\n",
+ "    scalers={\n",
+ "        \"x\": StandardScaler(),\n",
+ "        \"future_known_feature\": StandardScaler(),\n",
+ "        \"static_feature\": StandardScaler(),\n",
+ "    },\n",
+ "    target_normalizer=TorchNormalizer(),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000,
+ "referenced_widgets": [
+ "(jupyter widget state ids truncated)"
+ ]
+ },
+ "id": "Si7bbZIULBZz",
+ "outputId": "0b2f26b3-e37c-4ab6-c234-90693745a8cd"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n",
+ "INFO:lightning.pytorch.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n",
+ "INFO: GPU available: False, used: False\n",
+ "INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False\n",
+ "INFO: TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n",
+ "INFO: HPU available: False, using: 0 HPUs\n",
+ "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n",
+ "INFO: \n",
+ "  | Name                  | Type               | Params | Mode \n",
+ "---------------------------------------------------------------------\n",
+ "0 | loss                  | MSELoss            | 0      | train\n",
+ "1 | encoder_var_selection | Sequential         | 709    | train\n",
+ "2 | decoder_var_selection | Sequential         | 193    | train\n",
+ "3 | static_context_linear | Linear             | 192    | train\n",
+ "4 | lstm_encoder          | LSTM               | 51.5 K | train\n",
+ "5 | lstm_decoder          | LSTM               | 50.4 K | train\n",
+ "6 | self_attention        | MultiheadAttention | 16.6 K | train\n",
+ "7 | pre_output            | Linear             | 4.2 K  | train\n",
+ "8 | output_layer          | Linear             | 65     | train\n",
+ "---------------------------------------------------------------------\n",
+ "123 K     Trainable params\n",
+ "0         Non-trainable params\n",
+ "123 K     Total params\n",
+ "0.495     Total estimated model params size (MB)\n",
+ "18        Modules in train mode\n",
+ "0         Modules in eval mode\n",
+ "INFO:lightning.pytorch.callbacks.model_summary:\n",
+ "  | Name                  | Type               | Params | Mode \n",
+ "---------------------------------------------------------------------\n",
+ "0 | loss                  | MSELoss            | 0      | train\n",
+ "1 | encoder_var_selection | Sequential         | 709    | train\n",
+ "2 | decoder_var_selection | Sequential         | 193    | train\n",
+ "3 | static_context_linear | Linear             | 192    | train\n",
+ "4 | lstm_encoder          | LSTM               | 51.5 K | train\n",
+ "5 | lstm_decoder          | LSTM               | 50.4 K | 
train\n", + "6 | self_attention | MultiheadAttention | 16.6 K | train\n", + "7 | pre_output | Linear | 4.2 K | train\n", + "8 | output_layer | Linear | 65 | train\n", + "---------------------------------------------------------------------\n", + "123 K Trainable params\n", + "0 Non-trainable params\n", + "123 K Total params\n", + "0.495 Total estimated model params size (MB)\n", + "18 Modules in train mode\n", + "0 Modules in eval mode\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Training model...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7de178cd43ab4104ba2445a057a5f1a4", + "version_major": 2, + "version_minor": 0 }, - "outputs": [], - "source": [ - "from typing import Any, Dict, List, Optional, Tuple, Union\n", - "\n", - "import numpy as np\n", - "import pandas as pd\n", - "import torch\n", - "from torch.utils.data import Dataset\n", - "\n", - "from pytorch_forecasting.data.timeseries import _coerce_to_dict" + "text/plain": [ + "Sanity Checking: | | 0/? [00:00 int:\n", - " \"\"\"Return number of time series in the dataset.\"\"\"\n", - " return len(self._group_ids)\n", - "\n", - " def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:\n", - " \"\"\"Get time series data for given index.\n", - "\n", - " Returns\n", - " -------\n", - " t : numpy.ndarray of shape (n_timepoints,)\n", - " Time index for each time point in the past or present. Aligned with `y`,\n", - " and `x` not ending in `f`.\n", - "\n", - " y : torch.Tensor of shape (n_timepoints, n_targets)\n", - " Target values for each time point. Rows are time points, aligned with `t`.\n", - "\n", - " x : torch.Tensor of shape (n_timepoints, n_features)\n", - " Features for each time point. Rows are time points, aligned with `t`.\n", - "\n", - " group : torch.Tensor of shape (n_groups,)\n", - " Group identifiers for time series instances.\n", - "\n", - " st : torch.Tensor of shape (n_static_features,)\n", - " Static features.\n", - "\n", - " cutoff_time : float or numpy.float64\n", - " Cutoff time for the time series instance.\n", - "\n", - " Other Returns\n", - " -------------\n", - " weights : torch.Tensor of shape (n_timepoints,), optional\n", - " Only included if weights are not `None`.\n", - " \"\"\"\n", - " group_id = self._group_ids[index]\n", - "\n", - " if self.group:\n", - " mask = self._groups[group_id]\n", - " data = self.data.loc[mask]\n", - " else:\n", - " data = self.data\n", - "\n", - " cutoff_time = data[self.time].max()\n", - "\n", - " result = {\n", - " \"t\": data[self.time].values,\n", - " \"y\": torch.tensor(data[self.target].values),\n", - " \"x\": torch.tensor(data[self.feature_cols].values),\n", - " \"group\": torch.tensor([hash(str(group_id))]),\n", - " \"st\": torch.tensor(data[self.static].iloc[0].values if self.static else []),\n", - " \"cutoff_time\": cutoff_time,\n", - " }\n", - "\n", - " if self.data_future is not None:\n", - " if self.group:\n", - " future_mask = self.data_future.groupby(self.group).groups[group_id]\n", - " future_data = self.data_future.loc[future_mask]\n", - " else:\n", - " future_data = self.data_future\n", - "\n", - " combined_times = np.concatenate(\n", - " [data[self.time].values, future_data[self.time].values]\n", - " )\n", - " combined_times = np.unique(combined_times)\n", - " combined_times.sort()\n", - "\n", - " num_timepoints = len(combined_times)\n", - " x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan)\n", - " y_merged = np.full((num_timepoints, 
len(self.target)), np.nan)\n", - "\n", - " current_time_indices = {t: i for i, t in enumerate(combined_times)}\n", - " for i, t in enumerate(data[self.time].values):\n", - " idx = current_time_indices[t]\n", - " x_merged[idx] = data[self.feature_cols].values[i]\n", - " y_merged[idx] = data[self.target].values[i]\n", - "\n", - " for i, t in enumerate(future_data[self.time].values):\n", - " if t in current_time_indices:\n", - " idx = current_time_indices[t]\n", - " for j, col in enumerate(self.known):\n", - " if col in self.feature_cols:\n", - " feature_idx = self.feature_cols.index(col)\n", - " x_merged[idx, feature_idx] = future_data[col].values[i]\n", - "\n", - " result.update(\n", - " {\n", - " \"t\": combined_times,\n", - " \"x\": torch.tensor(x_merged, dtype=torch.float32),\n", - " \"y\": torch.tensor(y_merged, dtype=torch.float32),\n", - " }\n", - " )\n", - "\n", - " if self.weight:\n", - " if self.data_future is not None and self.weight in self.data_future.columns:\n", - " weights_merged = np.full(num_timepoints, np.nan)\n", - " for i, t in enumerate(data[self.time].values):\n", - " idx = current_time_indices[t]\n", - " weights_merged[idx] = data[self.weight].values[i]\n", - "\n", - " for i, t in enumerate(future_data[self.time].values):\n", - " if t in current_time_indices and self.weight in future_data.columns:\n", - " idx = current_time_indices[t]\n", - " weights_merged[idx] = future_data[self.weight].values[i]\n", - "\n", - " result[\"weights\"] = torch.tensor(weights_merged, dtype=torch.float32)\n", - " else:\n", - " result[\"weights\"] = torch.tensor(\n", - " data[self.weight].values, dtype=torch.float32\n", - " )\n", - "\n", - " return result\n", - "\n", - " def get_metadata(self) -> Dict:\n", - " \"\"\"Return metadata about the dataset.\n", - "\n", - " Returns\n", - " -------\n", - " Dict\n", - " Dictionary containing:\n", - " - cols: column names for y, x, and static features\n", - " - col_type: mapping of columns to their types (F/C)\n", - " - col_known: mapping of columns to their future known status (K/U)\n", - " \"\"\"\n", - " return self.metadata" + "text/plain": [ + "Training: | | 0/? 
[00:00 List[Dict[str, Any]]:\n", - " \"\"\"Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset.\n", - "\n", - " Preprocessing steps\n", - " --------------------\n", - "\n", - " * Converts target (`y`) and features (`x`) to `torch.float32`.\n", - " * Masks time points that are at or before the cutoff time.\n", - " * Splits features into categorical and continuous subsets based on\n", - " predefined indices.\n", - "\n", - "\n", - " TODO: add scalers, target normalizers etc.\n", - " \"\"\"\n", - " sample = self.time_series_dataset[series_idx]\n", - "\n", - " target = sample[\"y\"]\n", - " features = sample[\"x\"]\n", - " times = sample[\"t\"]\n", - " cutoff_time = sample[\"cutoff_time\"]\n", - "\n", - " time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool)\n", - "\n", - " if isinstance(target, torch.Tensor):\n", - " target = target.float()\n", - " else:\n", - " target = torch.tensor(target, dtype=torch.float32)\n", - "\n", - " if isinstance(features, torch.Tensor):\n", - " features = features.float()\n", - " else:\n", - " features = torch.tensor(features, dtype=torch.float32)\n", - "\n", - " # TODO: add scalers, target normalizers etc.\n", - "\n", - " categorical = (\n", - " features[:, self.categorical_indices]\n", - " if self.categorical_indices\n", - " else torch.zeros((features.shape[0], 0))\n", - " )\n", - " continuous = (\n", - " features[:, self.continuous_indices]\n", - " if self.continuous_indices\n", - " else torch.zeros((features.shape[0], 0))\n", - " )\n", - "\n", - " return {\n", - " \"features\": {\"categorical\": categorical, \"continuous\": continuous},\n", - " \"target\": target,\n", - " \"static\": sample.get(\"st\", None),\n", - " \"group\": sample.get(\"group\", torch.tensor([0])),\n", - " \"length\": len(target),\n", - " \"time_mask\": time_mask,\n", - " \"times\": times,\n", - " \"cutoff_time\": cutoff_time,\n", - " }\n", - "\n", - " class _ProcessedEncoderDecoderDataset(Dataset):\n", - " \"\"\"PyTorch Dataset for processed encoder-decoder time series data.\n", - "\n", - " Parameters\n", - " ----------\n", - " dataset : TimeSeries\n", - " The base time series dataset that provides access to raw data and metadata.\n", - " data_module : EncoderDecoderTimeSeriesDataModule\n", - " The data module handling preprocessing and metadata configuration.\n", - " windows : List[Tuple[int, int, int, int]]\n", - " List of window tuples containing\n", - " (series_idx, start_idx, enc_length, pred_length).\n", - " add_relative_time_idx : bool, default=False\n", - " Whether to include relative time indices.\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " dataset: TimeSeries,\n", - " data_module: \"EncoderDecoderTimeSeriesDataModule\",\n", - " windows: List[Tuple[int, int, int, int]],\n", - " add_relative_time_idx: bool = False,\n", - " ):\n", - " self.dataset = dataset\n", - " self.data_module = data_module\n", - " self.windows = windows\n", - " self.add_relative_time_idx = add_relative_time_idx\n", - "\n", - " def __len__(self):\n", - " return len(self.windows)\n", - "\n", - " def __getitem__(self, idx):\n", - " \"\"\"Retrieve a processed time series window for dataloader input.\n", - "\n", - " x : dict\n", - " Dictionary containing model inputs:\n", - "\n", - " * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features)\n", - " Categorical features for the encoder.\n", - " * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features)\n", - " Continuous features for the encoder.\n", - " * ``decoder_cat`` : tensor of shape 
(pred_length, n_cat_features)\n", - " Categorical features for the decoder.\n", - " * ``decoder_cont`` : tensor of shape (pred_length, n_cont_features)\n", - " Continuous features for the decoder.\n", - " * ``encoder_lengths`` : tensor of shape (1,)\n", - " Length of the encoder sequence.\n", - " * ``decoder_lengths`` : tensor of shape (1,)\n", - " Length of the decoder sequence.\n", - " * ``decoder_target_lengths`` : tensor of shape (1,)\n", - " Length of the decoder target sequence.\n", - " * ``groups`` : tensor of shape (1,)\n", - " Group identifier for the time series instance.\n", - " * ``encoder_time_idx`` : tensor of shape (enc_length,)\n", - " Time indices for the encoder sequence.\n", - " * ``decoder_time_idx`` : tensor of shape (pred_length,)\n", - " Time indices for the decoder sequence.\n", - " * ``target_scale`` : tensor of shape (1,)\n", - " Scaling factor for the target values.\n", - " * ``encoder_mask`` : tensor of shape (enc_length,)\n", - " Boolean mask indicating valid encoder time points.\n", - " * ``decoder_mask`` : tensor of shape (pred_length,)\n", - " Boolean mask indicating valid decoder time points.\n", - "\n", - " If static features are present, the following keys are added:\n", - "\n", - " * ``static_categorical_features`` : tensor of shape\n", - " (1, n_static_cat_features), optional\n", - " Static categorical features, if available.\n", - " * ``static_continuous_features`` : tensor of shape (1, 0), optional\n", - " Placeholder for static continuous features (currently empty).\n", - "\n", - " y : tensor of shape ``(pred_length, n_targets)``\n", - " Target values for the decoder sequence.\n", - " \"\"\"\n", - " series_idx, start_idx, enc_length, pred_length = self.windows[idx]\n", - " data = self.data_module._preprocess_data(series_idx)\n", - "\n", - " end_idx = start_idx + enc_length + pred_length\n", - " encoder_indices = slice(start_idx, start_idx + enc_length)\n", - " decoder_indices = slice(start_idx + enc_length, end_idx)\n", - "\n", - " target_scale = data[\"target\"][encoder_indices]\n", - " target_scale = target_scale[~torch.isnan(target_scale)].abs().mean()\n", - " if torch.isnan(target_scale) or target_scale == 0:\n", - " target_scale = torch.tensor(1.0)\n", - "\n", - " encoder_mask = (\n", - " data[\"time_mask\"][encoder_indices]\n", - " if \"time_mask\" in data\n", - " else torch.ones(enc_length, dtype=torch.bool)\n", - " )\n", - " decoder_mask = (\n", - " data[\"time_mask\"][decoder_indices]\n", - " if \"time_mask\" in data\n", - " else torch.zeros(pred_length, dtype=torch.bool)\n", - " )\n", - "\n", - " x = {\n", - " \"encoder_cat\": data[\"features\"][\"categorical\"][encoder_indices],\n", - " \"encoder_cont\": data[\"features\"][\"continuous\"][encoder_indices],\n", - " \"decoder_cat\": data[\"features\"][\"categorical\"][decoder_indices],\n", - " \"decoder_cont\": data[\"features\"][\"continuous\"][decoder_indices],\n", - " \"encoder_lengths\": torch.tensor(enc_length),\n", - " \"decoder_lengths\": torch.tensor(pred_length),\n", - " \"decoder_target_lengths\": torch.tensor(pred_length),\n", - " \"groups\": data[\"group\"],\n", - " \"encoder_time_idx\": torch.arange(enc_length),\n", - " \"decoder_time_idx\": torch.arange(enc_length, enc_length + pred_length),\n", - " \"target_scale\": target_scale,\n", - " \"encoder_mask\": encoder_mask,\n", - " \"decoder_mask\": decoder_mask,\n", - " }\n", - " if data[\"static\"] is not None:\n", - " x[\"static_categorical_features\"] = data[\"static\"].unsqueeze(0)\n", - " x[\"static_continuous_features\"] = 
torch.zeros((1, 0))\n", - "\n", - " y = data[\"target\"][decoder_indices]\n", - " if y.ndim == 1:\n", - " y = y.unsqueeze(-1)\n", - "\n", - " return x, y\n", - "\n", - " def _create_windows(self, indices: torch.Tensor) -> List[Tuple[int, int, int, int]]:\n", - " \"\"\"Generate sliding windows for training, validation, and testing.\n", - "\n", - " Returns\n", - " -------\n", - " List[Tuple[int, int, int, int]]\n", - " A list of tuples, where each tuple consists of:\n", - " - ``series_idx`` : int\n", - " Index of the time series in `time_series_dataset`.\n", - " - ``start_idx`` : int\n", - " Start index of the encoder window.\n", - " - ``enc_length`` : int\n", - " Length of the encoder input sequence.\n", - " - ``pred_length`` : int\n", - " Length of the decoder output sequence.\n", - " \"\"\"\n", - " windows = []\n", - "\n", - " for idx in indices:\n", - " series_idx = idx.item()\n", - " sample = self.time_series_dataset[series_idx]\n", - " sequence_length = len(sample[\"y\"])\n", - "\n", - " if sequence_length < self.max_encoder_length + self.max_prediction_length:\n", - " continue\n", - "\n", - " effective_min_prediction_idx = (\n", - " self.min_prediction_idx\n", - " if self.min_prediction_idx is not None\n", - " else self.max_encoder_length\n", - " )\n", - "\n", - " max_prediction_idx = sequence_length - self.max_prediction_length + 1\n", - "\n", - " if max_prediction_idx <= effective_min_prediction_idx:\n", - " continue\n", - "\n", - " for start_idx in range(\n", - " 0, max_prediction_idx - effective_min_prediction_idx\n", - " ):\n", - " if (\n", - " start_idx + self.max_encoder_length + self.max_prediction_length\n", - " <= sequence_length\n", - " ):\n", - " windows.append(\n", - " (\n", - " series_idx,\n", - " start_idx,\n", - " self.max_encoder_length,\n", - " self.max_prediction_length,\n", - " )\n", - " )\n", - "\n", - " return windows\n", - "\n", - " def setup(self, stage: Optional[str] = None):\n", - " \"\"\"Prepare the datasets for training, validation, testing, or prediction.\n", - "\n", - " Parameters\n", - " ----------\n", - " stage : Optional[str], default=None\n", - " Specifies the stage of setup. 
Can be one of:\n", - " - ``\"fit\"`` : Prepares training and validation datasets.\n", - " - ``\"test\"`` : Prepares the test dataset.\n", - " - ``\"predict\"`` : Prepares the dataset for inference.\n", - " - ``None`` : Prepares ``fit`` datasets.\n", - " \"\"\"\n", - " total_series = len(self.time_series_dataset)\n", - " self._split_indices = torch.randperm(total_series)\n", - "\n", - " self._train_size = int(self.train_val_test_split[0] * total_series)\n", - " self._val_size = int(self.train_val_test_split[1] * total_series)\n", - "\n", - " self._train_indices = self._split_indices[: self._train_size]\n", - " self._val_indices = self._split_indices[\n", - " self._train_size : self._train_size + self._val_size\n", - " ]\n", - " self._test_indices = self._split_indices[self._train_size + self._val_size :]\n", - "\n", - " if stage is None or stage == \"fit\":\n", - " if not hasattr(self, \"train_dataset\") or not hasattr(self, \"val_dataset\"):\n", - " self.train_windows = self._create_windows(self._train_indices)\n", - " self.val_windows = self._create_windows(self._val_indices)\n", - "\n", - " self.train_dataset = self._ProcessedEncoderDecoderDataset(\n", - " self.time_series_dataset,\n", - " self,\n", - " self.train_windows,\n", - " self.add_relative_time_idx,\n", - " )\n", - " self.val_dataset = self._ProcessedEncoderDecoderDataset(\n", - " self.time_series_dataset,\n", - " self,\n", - " self.val_windows,\n", - " self.add_relative_time_idx,\n", - " )\n", - "\n", - " elif stage == \"test\":\n", - " if not hasattr(self, \"test_dataset\"):\n", - " self.test_windows = self._create_windows(self._test_indices)\n", - " self.test_dataset = self._ProcessedEncoderDecoderDataset(\n", - " self.time_series_dataset,\n", - " self,\n", - " self.test_windows,\n", - " self.add_relative_time_idx,\n", - " )\n", - " elif stage == \"predict\":\n", - " predict_indices = torch.arange(len(self.time_series_dataset))\n", - " self.predict_windows = self._create_windows(predict_indices)\n", - " self.predict_dataset = self._ProcessedEncoderDecoderDataset(\n", - " self.time_series_dataset,\n", - " self,\n", - " self.predict_windows,\n", - " self.add_relative_time_idx,\n", - " )\n", - "\n", - " def train_dataloader(self):\n", - " return DataLoader(\n", - " self.train_dataset,\n", - " batch_size=self.batch_size,\n", - " num_workers=self.num_workers,\n", - " shuffle=True,\n", - " collate_fn=self.collate_fn,\n", - " )\n", - "\n", - " def val_dataloader(self):\n", - " return DataLoader(\n", - " self.val_dataset,\n", - " batch_size=self.batch_size,\n", - " num_workers=self.num_workers,\n", - " collate_fn=self.collate_fn,\n", - " )\n", - "\n", - " def test_dataloader(self):\n", - " return DataLoader(\n", - " self.test_dataset,\n", - " batch_size=self.batch_size,\n", - " num_workers=self.num_workers,\n", - " collate_fn=self.collate_fn,\n", - " )\n", - "\n", - " def predict_dataloader(self):\n", - " return DataLoader(\n", - " self.predict_dataset,\n", - " batch_size=self.batch_size,\n", - " num_workers=self.num_workers,\n", - " collate_fn=self.collate_fn,\n", - " )\n", - "\n", - " @staticmethod\n", - " def collate_fn(batch):\n", - " x_batch = {\n", - " \"encoder_cat\": torch.stack([x[\"encoder_cat\"] for x, _ in batch]),\n", - " \"encoder_cont\": torch.stack([x[\"encoder_cont\"] for x, _ in batch]),\n", - " \"decoder_cat\": torch.stack([x[\"decoder_cat\"] for x, _ in batch]),\n", - " \"decoder_cont\": torch.stack([x[\"decoder_cont\"] for x, _ in batch]),\n", - " \"encoder_lengths\": torch.stack([x[\"encoder_lengths\"] for x, _ 
in batch]),\n", - " \"decoder_lengths\": torch.stack([x[\"decoder_lengths\"] for x, _ in batch]),\n", - " \"decoder_target_lengths\": torch.stack(\n", - " [x[\"decoder_target_lengths\"] for x, _ in batch]\n", - " ),\n", - " \"groups\": torch.stack([x[\"groups\"] for x, _ in batch]),\n", - " \"encoder_time_idx\": torch.stack([x[\"encoder_time_idx\"] for x, _ in batch]),\n", - " \"decoder_time_idx\": torch.stack([x[\"decoder_time_idx\"] for x, _ in batch]),\n", - " \"target_scale\": torch.stack([x[\"target_scale\"] for x, _ in batch]),\n", - " \"encoder_mask\": torch.stack([x[\"encoder_mask\"] for x, _ in batch]),\n", - " \"decoder_mask\": torch.stack([x[\"decoder_mask\"] for x, _ in batch]),\n", - " }\n", - "\n", - " if \"static_categorical_features\" in batch[0][0]:\n", - " x_batch[\"static_categorical_features\"] = torch.stack(\n", - " [x[\"static_categorical_features\"] for x, _ in batch]\n", - " )\n", - " x_batch[\"static_continuous_features\"] = torch.stack(\n", - " [x[\"static_continuous_features\"] for x, _ in batch]\n", - " )\n", - "\n", - " y_batch = torch.stack([y for _, y in batch])\n", - " return x_batch, y_batch" + "text/plain": [ + "Validation: | | 0/? [00:00\n", - "
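The collate_fn above is the last piece of the D2 layer: it stacks per-window samples into the batch dict consumed by the model. A minimal sketch of how one batch could be inspected, assuming the `data_module` built earlier in this notebook; the shapes in the comments assume batch_size=32, max_encoder_length=30, max_prediction_length=1 and are illustrative, not guaranteed:

    data_module.setup(stage="fit")  # builds train/val windows and datasets
    x_batch, y_batch = next(iter(data_module.train_dataloader()))
    print(x_batch["encoder_cont"].shape)  # e.g. torch.Size([32, 30, n_continuous])
    print(x_batch["encoder_cat"].shape)   # e.g. torch.Size([32, 30, n_categorical])
    print(x_batch["decoder_cont"].shape)  # e.g. torch.Size([32, 1, n_continuous])
    print(y_batch.shape)                  # e.g. torch.Size([32, 1, 1]), one-step-ahead target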
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
000-0.0712220.33976301.0000000.6264260
1010.3397630.18934800.9950040.6264260
2020.1893480.67598900.9800670.6264260
3030.6759890.79726100.9553360.6264260
4040.7972610.99501600.9210610.6264260
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "\n", - "
\n", - " \n", - "\n", - "\n", - "\n", - " \n", - "
\n", - "\n", - "
\n", - " \n" - ], - "text/plain": [ - " series_id time_idx x y category future_known_feature \\\n", - "0 0 0 -0.071222 0.339763 0 1.000000 \n", - "1 0 1 0.339763 0.189348 0 0.995004 \n", - "2 0 2 0.189348 0.675989 0 0.980067 \n", - "3 0 3 0.675989 0.797261 0 0.955336 \n", - "4 0 4 0.797261 0.995016 0 0.921061 \n", - "\n", - " static_feature static_feature_cat \n", - "0 0.626426 0 \n", - "1 0.626426 0 \n", - "2 0.626426 0 \n", - "3 0.626426 0 \n", - "4 0.626426 0 " - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from lightning.pytorch import Trainer\n", - "import pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "\n", - "from pytorch_forecasting.metrics import MAE, SMAPE\n", - "\n", - "num_series = 100\n", - "seq_length = 50\n", - "data_list = []\n", - "for i in range(num_series):\n", - " x = np.arange(seq_length)\n", - " y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)\n", - " category = i % 5\n", - " static_value = np.random.rand()\n", - " for t in range(seq_length - 1):\n", - " data_list.append(\n", - " {\n", - " \"series_id\": i,\n", - " \"time_idx\": t,\n", - " \"x\": y[t],\n", - " \"y\": y[t + 1],\n", - " \"category\": category,\n", - " \"future_known_feature\": np.cos(t / 10),\n", - " \"static_feature\": static_value,\n", - " \"static_feature_cat\": i % 3,\n", - " }\n", - " )\n", - "data_df = pd.DataFrame(data_list)\n", - "data_df.head()" + "text/plain": [ + "Validation: | | 0/? [00:00 Dict[str, torch.Tensor]:\n", - " \"\"\"\n", - " Forward pass of the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " x : Dict[str, torch.Tensor]\n", - " Dictionary containing input tensors\n", - "\n", - " Returns\n", - " -------\n", - " Dict[str, torch.Tensor]\n", - " Dictionary containing output tensors\n", - " \"\"\"\n", - " raise NotImplementedError(\"Forward method must be implemented by subclass.\")\n", - "\n", - " def training_step(\n", - " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", - " ) -> STEP_OUTPUT:\n", - " \"\"\"\n", - " Training step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input and target tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - "\n", - " Returns\n", - " -------\n", - " STEP_OUTPUT\n", - " Dictionary containing the loss and other metrics.\n", - " \"\"\"\n", - " x, y = batch\n", - " y_hat_dict = self(x)\n", - " y_hat = y_hat_dict[\"prediction\"]\n", - " loss = self.loss(y_hat, y)\n", - " self.log(\n", - " \"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True\n", - " )\n", - " self.log_metrics(y_hat, y, prefix=\"train\")\n", - " return {\"loss\": loss}\n", - "\n", - " def validation_step(\n", - " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", - " ) -> STEP_OUTPUT:\n", - " \"\"\"\n", - " Validation step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input and target tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - "\n", - " Returns\n", - " -------\n", - " STEP_OUTPUT\n", - " Dictionary containing the loss and other metrics.\n", - " \"\"\"\n", - " x, y = batch\n", - " y_hat_dict = self(x)\n", - " y_hat = y_hat_dict[\"prediction\"]\n", - " loss = self.loss(y_hat, y)\n", - " self.log(\n", - " \"val_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n", - " )\n", - " 
self.log_metrics(y_hat, y, prefix=\"val\")\n", - " return {\"val_loss\": loss}\n", - "\n", - " def test_step(\n", - " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", - " ) -> STEP_OUTPUT:\n", - " \"\"\"\n", - " Test step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input and target tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - "\n", - " Returns\n", - " -------\n", - " STEP_OUTPUT\n", - " Dictionary containing the loss and other metrics.\n", - " \"\"\"\n", - " x, y = batch\n", - " y_hat_dict = self(x)\n", - " y_hat = y_hat_dict[\"prediction\"]\n", - " loss = self.loss(y_hat, y)\n", - " self.log(\n", - " \"test_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n", - " )\n", - " self.log_metrics(y_hat, y, prefix=\"test\")\n", - " return {\"test_loss\": loss}\n", - "\n", - " def predict_step(\n", - " self,\n", - " batch: Tuple[Dict[str, torch.Tensor]],\n", - " batch_idx: int,\n", - " dataloader_idx: int = 0,\n", - " ) -> torch.Tensor:\n", - " \"\"\"\n", - " Prediction step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - " dataloader_idx : int\n", - " Index of the dataloader.\n", - "\n", - " Returns\n", - " -------\n", - " torch.Tensor\n", - " Predicted output tensor.\n", - " \"\"\"\n", - " x, _ = batch\n", - " y_hat = self(x)\n", - " return y_hat\n", - "\n", - " def configure_optimizers(self) -> Dict:\n", - " \"\"\"\n", - " Configure the optimizer and learning rate scheduler.\n", - "\n", - " Returns\n", - " -------\n", - " Dict\n", - " Dictionary containing the optimizer and scheduler configuration.\n", - " \"\"\"\n", - " optimizer = self._get_optimizer()\n", - " if self.lr_scheduler is not None:\n", - " scheduler = self._get_scheduler(optimizer)\n", - " if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n", - " return {\n", - " \"optimizer\": optimizer,\n", - " \"lr_scheduler\": {\n", - " \"scheduler\": scheduler,\n", - " \"monitor\": \"val_loss\",\n", - " },\n", - " }\n", - " else:\n", - " return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n", - " return {\"optimizer\": optimizer}\n", - "\n", - " def _get_optimizer(self) -> Optimizer:\n", - " \"\"\"\n", - " Get the optimizer based on the specified optimizer name and parameters.\n", - "\n", - " Returns\n", - " -------\n", - " Optimizer\n", - " The optimizer instance.\n", - " \"\"\"\n", - " if isinstance(self.optimizer, str):\n", - " if self.optimizer.lower() == \"adam\":\n", - " return torch.optim.Adam(self.parameters(), **self.optimizer_params)\n", - " elif self.optimizer.lower() == \"sgd\":\n", - " return torch.optim.SGD(self.parameters(), **self.optimizer_params)\n", - " else:\n", - " raise ValueError(f\"Optimizer {self.optimizer} not supported.\")\n", - " elif isinstance(self.optimizer, Optimizer):\n", - " return self.optimizer\n", - " else:\n", - " raise ValueError(\n", - " \"Optimizer must be either a string or \"\n", - " \"an instance of torch.optim.Optimizer.\"\n", - " )\n", - "\n", - " def _get_scheduler(\n", - " self, optimizer: Optimizer\n", - " ) -> torch.optim.lr_scheduler._LRScheduler:\n", - " \"\"\"\n", - " Get the lr scheduler based on the specified scheduler name and params.\n", - "\n", - " Parameters\n", - " ----------\n", - " optimizer : Optimizer\n", - " The optimizer 
instance.\n", - "\n", - " Returns\n", - " -------\n", - " torch.optim.lr_scheduler._LRScheduler\n", - " The learning rate scheduler instance.\n", - " \"\"\"\n", - " if self.lr_scheduler.lower() == \"reduce_lr_on_plateau\":\n", - " return torch.optim.lr_scheduler.ReduceLROnPlateau(\n", - " optimizer, **self.lr_scheduler_params\n", - " )\n", - " elif self.lr_scheduler.lower() == \"step_lr\":\n", - " return torch.optim.lr_scheduler.StepLR(\n", - " optimizer, **self.lr_scheduler_params\n", - " )\n", - " else:\n", - " raise ValueError(f\"Scheduler {self.lr_scheduler} not supported.\")\n", - "\n", - " def log_metrics(\n", - " self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = \"val\"\n", - " ) -> None:\n", - " \"\"\"\n", - " Log additional metrics during training, validation, or testing.\n", - "\n", - " Parameters\n", - " ----------\n", - " y_hat : torch.Tensor\n", - " Predicted output tensor.\n", - " y : torch.Tensor\n", - " Target output tensor.\n", - " prefix : str\n", - " Prefix for the logged metrics (e.g., \"train\", \"val\", \"test\").\n", - " \"\"\"\n", - " for metric in self.logging_metrics:\n", - " metric_value = metric(y_hat, y)\n", - " self.log(\n", - " f\"{prefix}_{metric.__class__.__name__}\",\n", - " metric_value,\n", - " on_step=False,\n", - " on_epoch=True,\n", - " prog_bar=True,\n", - " logger=True,\n", - " )" + "text/plain": [ + "Validation: | | 0/? [00:00 0 else None\n", - " )\n", - "\n", - " self.lstm_encoder = nn.LSTM(\n", - " input_size=total_feature_size,\n", - " hidden_size=hidden_size,\n", - " num_layers=num_layers,\n", - " dropout=dropout,\n", - " batch_first=True,\n", - " )\n", - "\n", - " self.lstm_decoder = nn.LSTM(\n", - " input_size=total_feature_size,\n", - " hidden_size=hidden_size,\n", - " num_layers=num_layers,\n", - " dropout=dropout,\n", - " batch_first=True,\n", - " )\n", - "\n", - " self.self_attention = nn.MultiheadAttention(\n", - " embed_dim=hidden_size,\n", - " num_heads=attention_head_size,\n", - " dropout=dropout,\n", - " batch_first=True,\n", - " )\n", - "\n", - " self.pre_output = nn.Linear(hidden_size, hidden_size)\n", - " self.output_layer = nn.Linear(hidden_size, output_size)\n", - "\n", - " def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n", - " \"\"\"\n", - " Forward pass of the TFT model.\n", - "\n", - " Parameters\n", - " ----------\n", - " x : Dict[str, torch.Tensor]\n", - " Dictionary containing input tensors:\n", - " - encoder_cat: Categorical encoder features\n", - " - encoder_cont: Continuous encoder features\n", - " - decoder_cat: Categorical decoder features\n", - " - decoder_cont: Continuous decoder features\n", - " - static_categorical_features: Static categorical features\n", - " - static_continuous_features: Static continuous features\n", - "\n", - " Returns\n", - " -------\n", - " Dict[str, torch.Tensor]\n", - " Dictionary containing output tensors:\n", - " - prediction: Prediction output (batch_size, prediction_length, output_size)\n", - " \"\"\"\n", - " batch_size = x[\"encoder_cont\"].shape[0]\n", - "\n", - " encoder_cat = x.get(\n", - " \"encoder_cat\",\n", - " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", - " )\n", - " encoder_cont = x.get(\n", - " \"encoder_cont\",\n", - " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", - " )\n", - " decoder_cat = x.get(\n", - " \"decoder_cat\",\n", - " torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", - " )\n", - " decoder_cont = x.get(\n", - " \"decoder_cont\",\n", - 
" torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", - " )\n", - "\n", - " encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2)\n", - " decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2)\n", - "\n", - " static_context = None\n", - " if self.static_context_linear is not None:\n", - " static_cat = x.get(\n", - " \"static_categorical_features\",\n", - " torch.zeros(batch_size, 0, device=self.device),\n", - " )\n", - " static_cont = x.get(\n", - " \"static_continuous_features\",\n", - " torch.zeros(batch_size, 0, device=self.device),\n", - " )\n", - "\n", - " if static_cat.size(2) == 0 and static_cont.size(2) == 0:\n", - " static_context = None\n", - " elif static_cat.size(2) == 0:\n", - " static_input = static_cont.to(\n", - " dtype=self.static_context_linear.weight.dtype\n", - " )\n", - " static_context = self.static_context_linear(static_input)\n", - " static_context = static_context.view(batch_size, self.hidden_size)\n", - " elif static_cont.size(2) == 0:\n", - " static_input = static_cat.to(\n", - " dtype=self.static_context_linear.weight.dtype\n", - " )\n", - " static_context = self.static_context_linear(static_input)\n", - " static_context = static_context.view(batch_size, self.hidden_size)\n", - " else:\n", - "\n", - " static_input = torch.cat([static_cont, static_cat], dim=1).to(\n", - " dtype=self.static_context_linear.weight.dtype\n", - " )\n", - " static_context = self.static_context_linear(static_input)\n", - " static_context = static_context.view(batch_size, self.hidden_size)\n", - "\n", - " encoder_weights = self.encoder_var_selection(encoder_input)\n", - " encoder_input = encoder_input * encoder_weights\n", - "\n", - " decoder_weights = self.decoder_var_selection(decoder_input)\n", - " decoder_input = decoder_input * decoder_weights\n", - "\n", - " if static_context is not None:\n", - " encoder_static_context = static_context.unsqueeze(1).expand(\n", - " -1, self.max_encoder_length, -1\n", - " )\n", - " decoder_static_context = static_context.unsqueeze(1).expand(\n", - " -1, self.max_prediction_length, -1\n", - " )\n", - "\n", - " encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n", - " encoder_output = encoder_output + encoder_static_context\n", - " decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n", - " decoder_output = decoder_output + decoder_static_context\n", - " else:\n", - " encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n", - " decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n", - "\n", - " sequence = torch.cat([encoder_output, decoder_output], dim=1)\n", - "\n", - " if static_context is not None:\n", - " expanded_static_context = static_context.unsqueeze(1).expand(\n", - " -1, sequence.size(1), -1\n", - " )\n", - "\n", - " attended_output, _ = self.self_attention(\n", - " sequence + expanded_static_context, sequence, sequence\n", - " )\n", - " else:\n", - " attended_output, _ = self.self_attention(sequence, sequence, sequence)\n", - "\n", - " decoder_attended = attended_output[:, -self.max_prediction_length :, :]\n", - "\n", - " output = nn.functional.relu(self.pre_output(decoder_attended))\n", - " prediction = self.output_layer(output)\n", - "\n", - " return {\"prediction\": prediction}" + "text/plain": [ + "Testing: | | 0/? 
[00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃ Test metric DataLoader 0 ┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ test_MAE 0.4637378454208374 │\n", - "│ test_SMAPE 1.0857858657836914 │\n", - "│ test_loss 0.014832879416644573 │\n", - "└───────────────────────────┴───────────────────────────┘\n", - "\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.4637378454208374 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.0857858657836914 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.014832879416644573 \u001b[0m\u001b[35m \u001b[0m│\n", - "└───────────────────────────┴───────────────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Prediction shape: torch.Size([32, 1, 1])\n", - "First prediction values: [[-0.00597369]]\n", - "First true values: [[-0.09480439]]\n", - "\n", - "TFT model test complete!\n" - ] - } + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃         Test metric         ┃        DataLoader 0         ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│          test_MAE           │     0.45287469029426575     │\n",
+       "│         test_SMAPE          │      0.942494809627533      │\n",
+       "│          test_loss          │     0.01396977063268423     │\n",
+       "└───────────────────────────┴───────────────────────────┘\n",
+       "
\n" ], - "source": [ - "model = TFT(\n", - " loss=nn.MSELoss(),\n", - " logging_metrics=[MAE(), SMAPE()],\n", - " optimizer=\"adam\",\n", - " optimizer_params={\"lr\": 1e-3},\n", - " lr_scheduler=\"reduce_lr_on_plateau\",\n", - " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", - " hidden_size=64,\n", - " num_layers=2,\n", - " attention_head_size=4,\n", - " dropout=0.1,\n", - " metadata=data_module.metadata,\n", - ")\n", - "\n", - "print(\"\\nTraining model...\")\n", - "trainer = Trainer(max_epochs=5, accelerator=\"auto\", devices=1, enable_progress_bar=True)\n", - "\n", - "trainer.fit(model, data_module)\n", - "\n", - "print(\"\\nEvaluating model...\")\n", - "test_metrics = trainer.test(model, data_module)\n", - "\n", - "model.eval()\n", - "with torch.no_grad():\n", - " test_batch = next(iter(data_module.test_dataloader()))\n", - " x_test, y_test = test_batch\n", - " y_pred = model(x_test)\n", - "\n", - " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n", - " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n", - " print(\"First true values:\", y_test[0].cpu().numpy())\n", - "print(\"\\nTFT model test complete!\")" + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.45287469029426575 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.942494809627533 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.01396977063268423 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zVRwi2MvLGgc" - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } + "metadata": {}, + "output_type": "display_data" }, - "nbformat": 4, - "nbformat_minor": 0 - } + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Prediction shape: torch.Size([32, 1, 1])\n", + "First prediction values: [[0.11045122]]\n", + "First true values: [[-0.0491814]]\n", + "\n", + "TFT model test complete!\n" + ] + } + ], + "source": [ + "model = TFT(\n", + " loss=nn.MSELoss(),\n", + " logging_metrics=[MAE(), SMAPE()],\n", + " optimizer=\"adam\",\n", + " optimizer_params={\"lr\": 1e-3},\n", + " lr_scheduler=\"reduce_lr_on_plateau\",\n", + " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", + " hidden_size=64,\n", + " num_layers=2,\n", + " attention_head_size=4,\n", + " dropout=0.1,\n", + " metadata=data_module.metadata,\n", + ")\n", + "\n", + "print(\"\\nTraining model...\")\n", + "trainer = Trainer(\n", + " max_epochs=5,\n", + " accelerator=\"auto\",\n", + " devices=1,\n", + " enable_progress_bar=True,\n", + " log_every_n_steps=10,\n", + ")\n", + "\n", + "trainer.fit(model, data_module)\n", + "\n", + "print(\"\\nEvaluating model...\")\n", + "test_metrics = trainer.test(model, data_module)\n", + "\n", + 
"model.eval()\n", + "with torch.no_grad():\n", + " test_batch = next(iter(data_module.test_dataloader()))\n", + " x_test, y_test = test_batch\n", + " y_pred = model(x_test)\n", + "\n", + " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n", + " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n", + " print(\"First true values:\", y_test[0].cpu().numpy())\n", + "print(\"\\nTFT model test complete!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zVRwi2MvLGgc" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From f1957169c56fcb0e1bd1052bc1cbf34ced575e78 Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Tue, 13 May 2025 00:22:46 +0530 Subject: [PATCH 26/30] add usage notebook for v2 version of timexer tslib specific D2 layer is pending --- examples/ptf_V2_example.ipynb | 6 +- examples/tslib_v2_example.ipynb | 1677 +++++++++++++++++ pytorch_forecasting/data/data_module.py | 2 + pytorch_forecasting/models/__init__.py | 2 + .../models/timexer/__init__.py | 29 + .../models/timexer/_timexer.py | 289 +++ .../models/timexer/sub_modules.py | 251 +++ 7 files changed, 2252 insertions(+), 4 deletions(-) create mode 100644 examples/tslib_v2_example.ipynb create mode 100644 pytorch_forecasting/models/timexer/__init__.py create mode 100644 pytorch_forecasting/models/timexer/_timexer.py create mode 100644 pytorch_forecasting/models/timexer/sub_modules.py diff --git a/examples/ptf_V2_example.ipynb b/examples/ptf_V2_example.ipynb index 2e39108f3..419f55c03 100644 --- a/examples/ptf_V2_example.ipynb +++ b/examples/ptf_V2_example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -156,7 +156,7 @@ } ], "source": [ - "!pip install pytorch-forecasting" + "pip install pytorch-forecasting" ] }, { @@ -539,8 +539,6 @@ } ], "source": [ - "\n", - "\n", "num_series = 100\n", "seq_length = 50\n", "data_list = []\n", diff --git a/examples/tslib_v2_example.ipynb b/examples/tslib_v2_example.ipynb new file mode 100644 index 000000000..ded1a3590 --- /dev/null +++ b/examples/tslib_v2_example.ipynb @@ -0,0 +1,1677 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c20f63be", + "metadata": {}, + "source": [ + "# NOTE\n", + "\n", + "This notebook is just an example to demonstrate D1 compatibility with the TimeXer model. Considering that there is no concrete design for a TSLib specific D2 layer, for the time being we are using the `EncoderDecoderDataModule` and `BaseModel` for implementing `TimeXer`. The implementation is rather confusing with many overlapping bits because the D2 isn't solely built for TSLib, but is a demonstration of how TimeXer works with the latest version of v2 rework and shows promise for more models from TSlib." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7563b0a7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: pytorch-forecasting in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (1.2.0)\n", + "Requirement already satisfied: numpy<=3.0.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from pytorch-forecasting) (2.2.2)\n", + "Requirement already satisfied: torch!=2.0.1,<3.0.0,>=2.0.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from pytorch-forecasting) (2.5.1)\n", + "Requirement already satisfied: lightning<3.0.0,>=2.0.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from pytorch-forecasting) (2.5.0.post0)\n", + "Requirement already satisfied: scipy<2.0,>=1.8 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from pytorch-forecasting) (1.15.1)\n", + "Requirement already satisfied: pandas<3.0.0,>=1.3.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from pytorch-forecasting) (2.2.3)\n", + "Requirement already satisfied: scikit-learn<2.0,>=1.2 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from pytorch-forecasting) (1.6.1)\n", + "Requirement already satisfied: PyYAML<8.0,>=5.4 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.0.2)\n", + "Requirement already satisfied: fsspec[http]<2026.0,>=2022.5.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2024.12.0)\n", + "Requirement already satisfied: lightning-utilities<2.0,>=0.10.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (0.11.9)\n", + "Requirement already satisfied: packaging<25.0,>=20.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (24.2)\n", + "Requirement already satisfied: torchmetrics<3.0,>=0.7.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.6.1)\n", + "Requirement already satisfied: tqdm<6.0,>=4.57.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.67.1)\n", + "Requirement already satisfied: typing-extensions<6.0,>=4.4.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.12.2)\n", + "Requirement already satisfied: pytorch-lightning in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2.5.0.post0)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2024.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in 
/home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.1)\n", + "Requirement already satisfied: joblib>=1.2.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (3.5.0)\n", + "Requirement already satisfied: filelock in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.17.0)\n", + "Requirement already satisfied: networkx in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.4.2)\n", + "Requirement already satisfied: jinja2 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.1.5)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (12.4.127)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (12.4.127)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (12.4.127)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (9.1.0.70)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (12.4.5.8)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (11.2.1.3)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (10.3.5.147)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (11.6.1.9)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (12.3.1.170)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (2.21.5)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (12.4.127)\n", + "Requirement 
already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (12.4.127)\n", + "Requirement already satisfied: triton==3.1.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.1.0)\n", + "Requirement already satisfied: setuptools in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (75.8.0)\n", + "Requirement already satisfied: sympy==1.13.1 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from sympy==1.13.1->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.0)\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.11.11)\n", + "Requirement already satisfied: six>=1.5 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas<3.0.0,>=1.3.0->pytorch-forecasting) (1.17.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from jinja2->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.0.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2.4.4)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.2)\n", + "Requirement already satisfied: attrs>=17.3.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (25.1.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.5.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.1.0)\n", + "Requirement already satisfied: propcache>=0.2.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (0.2.1)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.18.3)\n", + "Requirement already satisfied: 
idna>=2.0 in /home/pranav/Desktop/code/pytorch-forecasting/.venv/lib/python3.12/site-packages (from yarl<2.0,>=1.17.0->aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.10)\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.1.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install pytorch-forecasting\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "524fb344", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any, Dict, List, Optional, Tuple, Union\n", + "\n", + "from lightning.pytorch import Trainer\n", + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn.preprocessing import RobustScaler, StandardScaler\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import Optimizer\n", + "from torch.utils.data import Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "3f1b0019", + "metadata": {}, + "outputs": [], + "source": [ + "from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule\n", + "from pytorch_forecasting.data.encoders import (\n", + " EncoderNormalizer,\n", + " NaNLabelEncoder,\n", + " TorchNormalizer,\n", + ")\n", + "from pytorch_forecasting.data.timeseries import TimeSeries\n", + "from pytorch_forecasting.metrics import MAE, SMAPE\n", + "from pytorch_forecasting.models.timexer._timexer import TimeXer" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "453abaa2", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.microsoft.datawrangler.viewer.v0+json": { + "columns": [ + { + "name": "index", + "rawType": "int64", + "type": "integer" + }, + { + "name": "series_id", + "rawType": "int64", + "type": "integer" + }, + { + "name": "time_idx", + "rawType": "int64", + "type": "integer" + }, + { + "name": "x", + "rawType": "float64", + "type": "float" + }, + { + "name": "y", + "rawType": "float64", + "type": "float" + }, + { + "name": "static_feature", + "rawType": "float64", + "type": "float" + } + ], + "conversionMethod": "pd.DataFrame", + "ref": "1eaf488b-f9bd-44a2-be5d-6da4307bec55", + "rows": [ + [ + "0", + "0", + "0", + "-0.0149095238786125", + "0.060817338247073804", + "0.3881772925541149" + ], + [ + "1", + "0", + "1", + "0.060817338247073804", + "0.36381631980778206", + "0.3881772925541149" + ], + [ + "2", + "0", + "2", + "0.36381631980778206", + "0.5852826362835459", + "0.3881772925541149" + ], + [ + "3", + "0", + "3", + "0.5852826362835459", + "0.6755320182905834", + "0.3881772925541149" + ], + [ + "4", + "0", + "4", + "0.6755320182905834", + "0.7312102935054476", + "0.3881772925541149" + ] + ], + "shape": { + "columns": 5, + "rows": 5 + } + }, + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
series_idtime_idxxystatic_feature
000-0.0149100.0608170.388177
1010.0608170.3638160.388177
2020.3638160.5852830.388177
3030.5852830.6755320.388177
4040.6755320.7312100.388177
\n", + "
" + ], + "text/plain": [ + " series_id time_idx x y static_feature\n", + "0 0 0 -0.014910 0.060817 0.388177\n", + "1 0 1 0.060817 0.363816 0.388177\n", + "2 0 2 0.363816 0.585283 0.388177\n", + "3 0 3 0.585283 0.675532 0.388177\n", + "4 0 4 0.675532 0.731210 0.388177" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "num_series = 100\n", + "seq_length = 50\n", + "data_list = []\n", + "for i in range(num_series):\n", + " x = np.arange(seq_length)\n", + " y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)\n", + " category = i % 5\n", + " static_value = np.random.rand()\n", + " for t in range(seq_length - 1):\n", + " data_list.append(\n", + " {\n", + " \"series_id\": i,\n", + " \"time_idx\": t,\n", + " \"x\": y[t],\n", + " \"y\": y[t + 1],\n", + " \"static_feature\": static_value,\n", + " }\n", + " )\n", + "data_df = pd.DataFrame(data_list)\n", + "data_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ee0f975b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/pranav/Desktop/code/pytorch-forecasting/pytorch_forecasting/data/timeseries/_timeseries_v2.py:106: UserWarning: TimeSeries is part of an experimental rework of the pytorch-forecasting data layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. For beta testing, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "dataset = TimeSeries(\n", + " data=data_df,\n", + " time=\"time_idx\",\n", + " target=\"y\",\n", + " group=[\"series_id\"],\n", + " num=[\"x\", \"static_feature\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4f13b58f", + "metadata": {}, + "outputs": [], + "source": [ + "data_module = EncoderDecoderTimeSeriesDataModule(\n", + " time_series_dataset=dataset,\n", + " max_encoder_length=30,\n", + " max_prediction_length=1,\n", + " batch_size=16,\n", + " scalers={\"x\": StandardScaler(), \"static_feature\": StandardScaler()},\n", + " target_normalizer=TorchNormalizer(),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "33eb6e78", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'encoder_cat': 0,\n", + " 'encoder_cont': 2,\n", + " 'decoder_cat': 0,\n", + " 'decoder_cont': 0,\n", + " 'target': 1,\n", + " 'static_categorical_features': 0,\n", + " 'static_continuous_features': 0,\n", + " 'max_encoder_length': 30,\n", + " 'max_prediction_length': 1,\n", + " 'min_encoder_length': 30,\n", + " 'min_prediction_length': 1}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_module.setup(stage=\"fit\")\n", + "data_module.metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "7c943cd2", + "metadata": {}, + "outputs": [], + "source": [ + "model = TimeXer(\n", + " loss=nn.L1Loss(),\n", + " logging_metrics=[MAE(), SMAPE()],\n", + " context_length=30,\n", + " prediction_length=1,\n", + " task_name=\"long_term_forecast\",\n", + " features=\"MS\",\n", + " d_model=32,\n", + " n_heads=2,\n", + " e_layers=1,\n", + " d_ff=64,\n", + " dropout=0.1,\n", + " patch_length=1,\n", + " use_norm=False,\n", + " metadata=data_module.metadata,\n", + " optimizer=\"adam\",\n", + " optimizer_params={\"lr\": 1e-3},\n", + 
")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "a472a9b5", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n" + ] + } + ], + "source": [ + "trainer = Trainer(\n", + " max_epochs=5,\n", + " accelerator=\"auto\",\n", + " devices=1,\n", + " enable_progress_bar=True,\n", + " log_every_n_steps=10,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9ee9aa67", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "You are using a CUDA device ('NVIDIA GeForce RTX 3050 6GB Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params | Mode \n", + "----------------------------------------------------------------\n", + "0 | loss | L1Loss | 0 | train\n", + "1 | en_embedding | EnEmbedding | 64 | train\n", + "2 | ex_embedding | DataEmbedding_inverted | 992 | train\n", + "3 | encoder | Encoder | 12.9 K | train\n", + "4 | head | FlattenHead | 993 | train\n", + "----------------------------------------------------------------\n", + "14.9 K Trainable params\n", + "0 Non-trainable params\n", + "14.9 K Total params\n", + "0.060 Total estimated model params size (MB)\n", + "36 Modules in train mode\n", + "0 Modules in eval mode\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3913efebc2214dc489de8e1ff608c2f7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00 Dict[str, torch.Tensor]: + """ + Forecast for univariate or multivariate with single target (MS) case. + + Args: + x: Dictionary containing entries for encoder_cat, encoder_cont + """ + batch_size = x["encoder_cont"].shape[0] + encoder_cont = x["encoder_cont"] + encoder_time_idx = x.get("encoder_time_idx", None) + past_target = x.get( + "target", + torch.zeros(batch_size, self.prediction_length, 0, device=self.device), + ) + + if encoder_time_idx is not None and encoder_time_idx.dim() == 2: + # change [batch_size, time_steps] to [batch_size, time_steps, features] + encoder_time_idx = encoder_time_idx.unsqueeze(-1) + + en_embed, n_vars = self.en_embedding(past_target.permute(0, 2, 1)) + ex_embed = self.ex_embedding(encoder_cont, encoder_time_idx) + + enc_out = self.encoder(en_embed, ex_embed) + enc_out = torch.reshape( + enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]) + ) + + enc_out = enc_out.permute(0, 1, 3, 2) + + dec_out = self.head(enc_out) + if self.n_quantiles is not None: + dec_out = dec_out.permute(0, 2, 1, 3) + else: + dec_out = dec_out.permute(0, 2, 1) + + return dec_out + + def _forecast_multi(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forecast for multivariate with multiple targets (M) case. 
+ + Args: + x: Dictionary containing entries for encoder_cat, encoder_cont + Returns: + Dictionary with predictions + """ + + batch_size = x["encoder_cont"].shape[0] + encoder_cont = x.get( + "encoder_cont", + torch.zeros(batch_size, self.prediction_length, device=self.device), + ) + encoder_time_idx = x.get("encoder_time_idx", None) + encoder_targets = x.get( + "target", + torch.zeros(batch_size, self.prediction_length, device=self.device), + ) + en_embed, n_vars = self.en_embedding(encoder_targets.permute(0, 2, 1)) + ex_embed = self.ex_embedding(encoder_cont, encoder_time_idx) + + # batch_size x sequence_length x d_model + enc_out = self.encoder(en_embed, ex_embed) + + enc_out = torch.reshape( + enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]) + ) # batch_size x n_vars x sequence_length x d_model + + enc_out = enc_out.permute(0, 1, 3, 2) + + dec_out = self.head(enc_out) + if self.n_quantiles is not None: + dec_out = dec_out.permute(0, 2, 1, 3) + else: + dec_out = dec_out.permute(0, 2, 1) + + return dec_out + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass of the model. + + Args: + x: Dictionary containing model inputs + + Returns: + Dictionary with model outputs + """ + if ( + self.task_name == "long_term_forecast" + or self.task_name == "short_term_forecast" + ): # noqa: E501 + if self.features == "M": + out = self._forecast_multi(x) + else: + out = self._forecast(x) + prediction = out[:, : self.prediction_length, :] + + # note: prediction.size(2) is the number of target variables i.e n_targets + target_indices = range(prediction.size(2)) + + if self.n_quantiles is not None: + prediction = [prediction[..., i, :] for i in target_indices] + else: + + if len(target_indices) == 1: + prediction = prediction[..., 0] + else: + prediction = [prediction[..., i] for i in target_indices] + return {"prediction": prediction} + else: + return None diff --git a/pytorch_forecasting/models/timexer/sub_modules.py b/pytorch_forecasting/models/timexer/sub_modules.py new file mode 100644 index 000000000..c3506ec49 --- /dev/null +++ b/pytorch_forecasting/models/timexer/sub_modules.py @@ -0,0 +1,251 @@ +""" +Implementation of `nn.Modules` for TimeXer model. 
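+
+It contains the attention primitives (`FullAttention`, `AttentionLayer`),
+the endogenous patch embedding (`EnEmbedding`) with its positional encoding,
+the inverted exogenous embedding (`DataEmbedding_inverted`), the encoder
+stack (`Encoder`, `EncoderLayer`) and the output `FlattenHead`.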
+""" + +import math +from math import sqrt + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TriangularCausalMask: + def __init__(self, B, L, device="cpu"): + mask_shape = [B, 1, L, L] + with torch.no_grad(): + self._mask = torch.triu( + torch.ones(mask_shape, dtype=torch.bool), diagonal=1 + ).to(device) + + @property + def mask(self): + return self._mask + + +class FullAttention(nn.Module): + def __init__( + self, + mask_flag=True, + factor=5, + scale=None, + attention_dropout=0.1, + output_attention=False, + ): + super(FullAttention, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1.0 / sqrt(E) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + scores.masked_fill_(attn_mask.mask, -np.abs) + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return V.contiguous(), A + else: + return V.contiguous(), None + + +class AttentionLayer(nn.Module): + def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None): + super(AttentionLayer, self).__init__() + + d_keys = d_keys or (d_model // n_heads) + d_values = d_values or (d_model // n_heads) + + self.inner_attention = attention + self.query_projection = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_model, d_keys * n_heads) + self.value_projection = nn.Linear(d_model, d_values * n_heads) + self.out_projection = nn.Linear(d_values * n_heads, d_model) + self.n_heads = n_heads + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): + B, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + queries = self.query_projection(queries).view(B, L, H, -1) + keys = self.key_projection(keys).view(B, S, H, -1) + values = self.value_projection(values).view(B, S, H, -1) + + out, attn = self.inner_attention( + queries, keys, values, attn_mask, tau=tau, delta=delta + ) + out = out.view(B, L, -1) + + return self.out_projection(out), attn + + +class DataEmbedding_inverted(nn.Module): + def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): + super(DataEmbedding_inverted, self).__init__() + self.value_embedding = nn.Linear(c_in, d_model) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + x = x.permute(0, 2, 1) + # x: [Batch Variate Time] + if x_mark is None: + x = self.value_embedding(x) + else: + x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) + # x: [Batch Variate d_model] + return self.dropout(x) + + +class PositionalEmbedding(nn.Module): + def __init__(self, d_model, max_len=5000): + super(PositionalEmbedding, self).__init__() + # Compute the positional encodings once in log space. 
+ pe = torch.zeros(max_len, d_model).float() + pe.require_grad = False + + position = torch.arange(0, max_len).float().unsqueeze(1) + div_term = ( + torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) + ).exp() + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x): + return self.pe[:, : x.size(1)] + + +class FlattenHead(nn.Module): + def __init__(self, n_vars, nf, target_window, head_dropout=0, n_quantiles=None): + super().__init__() + self.n_vars = n_vars + self.flatten = nn.Flatten(start_dim=-2) + self.linear = nn.Linear(nf, target_window) + self.n_quantiles = n_quantiles + + if self.n_quantiles is not None: + self.linear = nn.Linear(nf, target_window * n_quantiles) + else: + self.linear = nn.Linear(nf, target_window) + self.dropout = nn.Dropout(head_dropout) + + def forward(self, x): + x = self.flatten(x) + x = self.linear(x) + x = self.dropout(x) + + if self.n_quantiles is not None: + batch_size, n_vars = x.shape[0], x.shape[1] + x = x.reshape(batch_size, n_vars, -1, self.n_quantiles) + return x + + +class EnEmbedding(nn.Module): + def __init__(self, n_vars, d_model, patch_len, dropout): + super(EnEmbedding, self).__init__() + + self.patch_len = patch_len + + self.value_embedding = nn.Linear(patch_len, d_model, bias=False) + self.glb_token = nn.Parameter(torch.randn(1, n_vars, 1, d_model)) + self.position_embedding = PositionalEmbedding(d_model) + + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + n_vars = x.shape[1] + glb = self.glb_token.repeat((x.shape[0], 1, 1, 1)) + + x = x.unfold(dimension=-1, size=self.patch_len, step=self.patch_len) + x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])) + # Input encoding + x = self.value_embedding(x) + self.position_embedding(x) + x = torch.reshape(x, (-1, n_vars, x.shape[-2], x.shape[-1])) + x = torch.cat([x, glb], dim=2) + x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])) + return self.dropout(x), n_vars + + +class Encoder(nn.Module): + def __init__(self, layers, norm_layer=None, projection=None): + super(Encoder, self).__init__() + self.layers = nn.ModuleList(layers) + self.norm = norm_layer + self.projection = projection + + def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): + for layer in self.layers: + x = layer( + x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta + ) + + if self.norm is not None: + x = self.norm(x) + + if self.projection is not None: + x = self.projection(x) + return x + + +class EncoderLayer(nn.Module): + def __init__( + self, + self_attention, + cross_attention, + d_model, + d_ff=None, + dropout=0.1, + activation="relu", + ): + super(EncoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.self_attention = self_attention + self.cross_attention = cross_attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): + B, L, D = cross.shape + x = x + self.dropout( + self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0] + ) + x = self.norm1(x) + + x_glb_ori = 
x[:, -1, :].unsqueeze(1) + x_glb = torch.reshape(x_glb_ori, (B, -1, D)) + x_glb_attn = self.dropout( + self.cross_attention( + x_glb, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta + )[0] + ) + x_glb_attn = torch.reshape( + x_glb_attn, (x_glb_attn.shape[0] * x_glb_attn.shape[1], x_glb_attn.shape[2]) + ).unsqueeze(1) + x_glb = x_glb_ori + x_glb_attn + x_glb = self.norm2(x_glb) + + y = x = torch.cat([x[:, :-1, :], x_glb], dim=1) + + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm3(x + y) From 2c14517964c1d39ca35828c168502daa4c0c0c17 Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Wed, 14 May 2025 00:55:47 +0530 Subject: [PATCH 27/30] dummy commit --- examples/tslib_v2_example.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tslib_v2_example.ipynb b/examples/tslib_v2_example.ipynb index ded1a3590..ea2f32d94 100644 --- a/examples/tslib_v2_example.ipynb +++ b/examples/tslib_v2_example.ipynb @@ -7,7 +7,7 @@ "source": [ "# NOTE\n", "\n", - "This notebook is just an example to demonstrate D1 compatibility with the TimeXer model. Considering that there is no concrete design for a TSLib specific D2 layer, for the time being we are using the `EncoderDecoderDataModule` and `BaseModel` for implementing `TimeXer`. The implementation is rather confusing with many overlapping bits because the D2 isn't solely built for TSLib, but is a demonstration of how TimeXer works with the latest version of v2 rework and shows promise for more models from TSlib." + "This notebook is just an example to demonstrate D1 compatibility with the `TimeXer` model. Considering that there is no concrete design for a TSLib specific D2 layer, for the time being we are using the `EncoderDecoderDataModule` and `BaseModel` for implementing `TimeXer`. The implementation is rather confusing with many overlapping bits because the D2 isn't solely built for TSLib, but is a demonstration of how TimeXer works with the latest version of v2 rework and shows promise for more models from TSlib." ] }, { From d84e1a0d9f61cca0dc45a687687819753896abb2 Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Wed, 14 May 2025 01:15:45 +0530 Subject: [PATCH 28/30] dummy commit to trigger code quality checks --- examples/tslib_v2_example.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tslib_v2_example.ipynb b/examples/tslib_v2_example.ipynb index ea2f32d94..44809a7ee 100644 --- a/examples/tslib_v2_example.ipynb +++ b/examples/tslib_v2_example.ipynb @@ -7,7 +7,7 @@ "source": [ "# NOTE\n", "\n", - "This notebook is just an example to demonstrate D1 compatibility with the `TimeXer` model. Considering that there is no concrete design for a TSLib specific D2 layer, for the time being we are using the `EncoderDecoderDataModule` and `BaseModel` for implementing `TimeXer`. The implementation is rather confusing with many overlapping bits because the D2 isn't solely built for TSLib, but is a demonstration of how TimeXer works with the latest version of v2 rework and shows promise for more models from TSlib." + "This notebook is just an example to demonstrate D1 compatibility with the `TimeXer` model. Considering that there is no concrete design for a TSLib specific D2 layer, for the time being we are using the `EncoderDecoderDataModule` and `BaseModel` for implementing `TimeXer`. 
The implementation is rather confusing with many overlapping bits because the D2 isn't solely built for TSLib, but is a demonstration of how `TimeXer` works with the latest version of v2 rework and shows promise for more models from TSlib." ] }, { From f87842f21565d276e35c194b666d9f8cf177c66d Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Wed, 14 May 2025 11:39:33 +0530 Subject: [PATCH 29/30] fix lint issues --- examples/ptf_V2_example.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ptf_V2_example.ipynb b/examples/ptf_V2_example.ipynb index 419f55c03..e8b85e9e9 100644 --- a/examples/ptf_V2_example.ipynb +++ b/examples/ptf_V2_example.ipynb @@ -156,7 +156,7 @@ } ], "source": [ - "pip install pytorch-forecasting" + "%pip install pytorch-forecasting" ] }, { From 21e1c632c2ff6e9b5b4f974ecb4a03b40513c582 Mon Sep 17 00:00:00 2001 From: Pranav Bhat Date: Fri, 30 May 2025 14:46:01 +0530 Subject: [PATCH 30/30] fix deprcated syntax to comply with latest code-quality checks --- examples/ptf_V2_example.ipynb | 2 - examples/tslib_v2_example.ipynb | 4 +- .../models/base/base_model_refactor.py | 42 +++++++++---------- .../tft_version_two.py | 16 +++---- .../models/timexer/_timexer.py | 16 +++---- .../models/timexer/sub_modules.py | 14 +++---- 6 files changed, 46 insertions(+), 48 deletions(-) diff --git a/examples/ptf_V2_example.ipynb b/examples/ptf_V2_example.ipynb index e8b85e9e9..71e90e9de 100644 --- a/examples/ptf_V2_example.ipynb +++ b/examples/ptf_V2_example.ipynb @@ -167,8 +167,6 @@ }, "outputs": [], "source": [ - "from typing import Any, Dict, List, Optional, Tuple, Union\n", - "\n", "from lightning.pytorch import Trainer\n", "import numpy as np\n", "import pandas as pd\n", diff --git a/examples/tslib_v2_example.ipynb b/examples/tslib_v2_example.ipynb index 44809a7ee..d68491d94 100644 --- a/examples/tslib_v2_example.ipynb +++ b/examples/tslib_v2_example.ipynb @@ -83,12 +83,12 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "524fb344", "metadata": {}, "outputs": [], "source": [ - "from typing import Any, Dict, List, Optional, Tuple, Union\n", + "from typing import Any, Optional, Union\n", "\n", "from lightning.pytorch import Trainer\n", "import numpy as np\n", diff --git a/pytorch_forecasting/models/base/base_model_refactor.py b/pytorch_forecasting/models/base/base_model_refactor.py index ccd2c2600..f03d70020 100644 --- a/pytorch_forecasting/models/base/base_model_refactor.py +++ b/pytorch_forecasting/models/base/base_model_refactor.py @@ -5,7 +5,7 @@ ######################################################################################## -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from lightning.pytorch import LightningModule from lightning.pytorch.utilities.types import STEP_OUTPUT @@ -18,11 +18,11 @@ class BaseModel(LightningModule): def __init__( self, loss: nn.Module, - logging_metrics: Optional[List[nn.Module]] = None, + logging_metrics: Optional[list[nn.Module]] = None, optimizer: Optional[Union[Optimizer, str]] = "adam", - optimizer_params: Optional[Dict] = None, + optimizer_params: Optional[dict] = None, lr_scheduler: Optional[str] = None, - lr_scheduler_params: Optional[Dict] = None, + lr_scheduler_params: Optional[dict] = None, ): """ Base model for time series forecasting. @@ -31,17 +31,17 @@ def __init__( ---------- loss : nn.Module Loss function to use for training. 
- logging_metrics : Optional[List[nn.Module]], optional - List of metrics to log during training, validation, and testing. + logging_metrics : Optional[list[nn.Module]], optional + list of metrics to log during training, validation, and testing. optimizer : Optional[Union[Optimizer, str]], optional Optimizer to use for training. Can be a string ("adam", "sgd") or an instance of `torch.optim.Optimizer`. - optimizer_params : Optional[Dict], optional + optimizer_params : Optional[dict], optional Parameters for the optimizer. lr_scheduler : Optional[str], optional Learning rate scheduler to use. Supported values: "reduce_lr_on_plateau", "step_lr". - lr_scheduler_params : Optional[Dict], optional + lr_scheduler_params : Optional[dict], optional Parameters for the learning rate scheduler. """ super().__init__() @@ -54,31 +54,31 @@ def __init__( lr_scheduler_params if lr_scheduler_params is not None else {} ) - def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the model. Parameters ---------- - x : Dict[str, torch.Tensor] + x : dict[str, torch.Tensor] Dictionary containing input tensors Returns ------- - Dict[str, torch.Tensor] + dict[str, torch.Tensor] Dictionary containing output tensors """ raise NotImplementedError("Forward method must be implemented by subclass.") def training_step( - self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + self, batch: tuple[dict[str, torch.Tensor]], batch_idx: int ) -> STEP_OUTPUT: """ Training step for the model. Parameters ---------- - batch : Tuple[Dict[str, torch.Tensor]] + batch : tuple[dict[str, torch.Tensor]] Batch of data containing input and target tensors. batch_idx : int Index of the batch. @@ -99,14 +99,14 @@ def training_step( return {"loss": loss} def validation_step( - self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + self, batch: tuple[dict[str, torch.Tensor]], batch_idx: int ) -> STEP_OUTPUT: """ Validation step for the model. Parameters ---------- - batch : Tuple[Dict[str, torch.Tensor]] + batch : tuple[dict[str, torch.Tensor]] Batch of data containing input and target tensors. batch_idx : int Index of the batch. @@ -127,14 +127,14 @@ def validation_step( return {"val_loss": loss} def test_step( - self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + self, batch: tuple[dict[str, torch.Tensor]], batch_idx: int ) -> STEP_OUTPUT: """ Test step for the model. Parameters ---------- - batch : Tuple[Dict[str, torch.Tensor]] + batch : tuple[dict[str, torch.Tensor]] Batch of data containing input and target tensors. batch_idx : int Index of the batch. @@ -156,7 +156,7 @@ def test_step( def predict_step( self, - batch: Tuple[Dict[str, torch.Tensor]], + batch: tuple[dict[str, torch.Tensor]], batch_idx: int, dataloader_idx: int = 0, ) -> torch.Tensor: @@ -165,7 +165,7 @@ def predict_step( Parameters ---------- - batch : Tuple[Dict[str, torch.Tensor]] + batch : tuple[dict[str, torch.Tensor]] Batch of data containing input tensors. batch_idx : int Index of the batch. @@ -181,13 +181,13 @@ def predict_step( y_hat = self(x) return y_hat - def configure_optimizers(self) -> Dict: + def configure_optimizers(self) -> dict: """ Configure the optimizer and learning rate scheduler. Returns ------- - Dict + dict Dictionary containing the optimizer and scheduler configuration. 
""" optimizer = self._get_optimizer() diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py index 2bfe407d7..571d08b6e 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py @@ -3,7 +3,7 @@ # experimental, please use with care. ######################################################################################## -from typing import Dict, List, Optional, Union +from typing import Optional, Union import torch import torch.nn as nn @@ -16,16 +16,16 @@ class TFT(BaseModel): def __init__( self, loss: nn.Module, - logging_metrics: Optional[List[nn.Module]] = None, + logging_metrics: Optional[list[nn.Module]] = None, optimizer: Optional[Union[Optimizer, str]] = "adam", - optimizer_params: Optional[Dict] = None, + optimizer_params: Optional[dict] = None, lr_scheduler: Optional[str] = None, - lr_scheduler_params: Optional[Dict] = None, + lr_scheduler_params: Optional[dict] = None, hidden_size: int = 64, num_layers: int = 2, attention_head_size: int = 4, dropout: float = 0.1, - metadata: Optional[Dict] = None, + metadata: Optional[dict] = None, output_size: int = 1, ): super().__init__( @@ -110,13 +110,13 @@ def __init__( self.pre_output = nn.Linear(hidden_size, hidden_size) self.output_layer = nn.Linear(hidden_size, self.output_size) - def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the TFT model. Parameters ---------- - x : Dict[str, torch.Tensor] + x : dict[str, torch.Tensor] Dictionary containing input tensors: - encoder_cat: Categorical encoder features - encoder_cont: Continuous encoder features @@ -127,7 +127,7 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: Returns ------- - Dict[str, torch.Tensor] + dict[str, torch.Tensor] Dictionary containing output tensors: - prediction: Prediction output (batch_size, prediction_length, output_size) """ diff --git a/pytorch_forecasting/models/timexer/_timexer.py b/pytorch_forecasting/models/timexer/_timexer.py index 7d067b544..221476c95 100644 --- a/pytorch_forecasting/models/timexer/_timexer.py +++ b/pytorch_forecasting/models/timexer/_timexer.py @@ -11,7 +11,7 @@ ###################################################### from copy import copy -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import lightning.pytorch as pl from lightning.pytorch import LightningModule, Trainer @@ -48,11 +48,11 @@ def __init__( context_length: int, prediction_length: int, loss: nn.Module, - logging_metrics: Optional[List[nn.Module]] = None, + logging_metrics: Optional[list[nn.Module]] = None, optimizer: Optional[Union[Optimizer, str]] = "adam", - optimizer_params: Optional[Dict] = None, + optimizer_params: Optional[dict] = None, lr_scheduler: Optional[str] = None, - lr_scheduler_params: Optional[Dict] = None, + lr_scheduler_params: Optional[dict] = None, task_name: str = "long_term_forecast", features: str = "MS", enc_in: int = None, @@ -67,7 +67,7 @@ def __init__( factor: int = 5, embed_type: str = "fixed", freq: str = "h", - metadata: Optional[Dict] = None, + metadata: Optional[dict] = None, target_positions: torch.LongTensor = None, ): """An implementation of the TimeXer model. 
@@ -176,7 +176,7 @@ def __init__( n_quantiles=self.n_quantiles, ) - def _forecast(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forecast for univariate or multivariate with single target (MS) case. @@ -213,7 +213,7 @@ def _forecast(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return dec_out - def _forecast_multi(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def _forecast_multi(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forecast for multivariate with multiple targets (M) case. @@ -253,7 +253,7 @@ def _forecast_multi(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor] return dec_out - def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the model. diff --git a/pytorch_forecasting/models/timexer/sub_modules.py b/pytorch_forecasting/models/timexer/sub_modules.py index c3506ec49..c13b9fc61 100644 --- a/pytorch_forecasting/models/timexer/sub_modules.py +++ b/pytorch_forecasting/models/timexer/sub_modules.py @@ -33,7 +33,7 @@ def __init__( attention_dropout=0.1, output_attention=False, ): - super(FullAttention, self).__init__() + super().__init__() self.scale = scale self.mask_flag = mask_flag self.output_attention = output_attention @@ -61,7 +61,7 @@ def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): class AttentionLayer(nn.Module): def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None): - super(AttentionLayer, self).__init__() + super().__init__() d_keys = d_keys or (d_model // n_heads) d_values = d_values or (d_model // n_heads) @@ -92,7 +92,7 @@ def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): class DataEmbedding_inverted(nn.Module): def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): - super(DataEmbedding_inverted, self).__init__() + super().__init__() self.value_embedding = nn.Linear(c_in, d_model) self.dropout = nn.Dropout(p=dropout) @@ -109,7 +109,7 @@ def forward(self, x, x_mark): class PositionalEmbedding(nn.Module): def __init__(self, d_model, max_len=5000): - super(PositionalEmbedding, self).__init__() + super().__init__() # Compute the positional encodings once in log space. pe = torch.zeros(max_len, d_model).float() pe.require_grad = False @@ -156,7 +156,7 @@ def forward(self, x): class EnEmbedding(nn.Module): def __init__(self, n_vars, d_model, patch_len, dropout): - super(EnEmbedding, self).__init__() + super().__init__() self.patch_len = patch_len @@ -182,7 +182,7 @@ def forward(self, x): class Encoder(nn.Module): def __init__(self, layers, norm_layer=None, projection=None): - super(Encoder, self).__init__() + super().__init__() self.layers = nn.ModuleList(layers) self.norm = norm_layer self.projection = projection @@ -211,7 +211,7 @@ def __init__( dropout=0.1, activation="relu", ): - super(EncoderLayer, self).__init__() + super().__init__() d_ff = d_ff or 4 * d_model self.self_attention = self_attention self.cross_attention = cross_attention
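
For reference, the pieces introduced in this series compose into a short
end-to-end pipeline. The following is a minimal sketch mirroring
examples/tslib_v2_example.ipynb and assuming the experimental v2 APIs above
(TimeSeries as the D1 layer, EncoderDecoderTimeSeriesDataModule as the D2
layer, and the v2 TimeXer model); the API is not stable and may change:

    import numpy as np
    import pandas as pd
    import torch.nn as nn
    from lightning.pytorch import Trainer
    from sklearn.preprocessing import StandardScaler

    from pytorch_forecasting.data.data_module import (
        EncoderDecoderTimeSeriesDataModule,
    )
    from pytorch_forecasting.data.encoders import TorchNormalizer
    from pytorch_forecasting.data.timeseries import TimeSeries
    from pytorch_forecasting.metrics import MAE, SMAPE
    from pytorch_forecasting.models.timexer._timexer import TimeXer

    # Toy data: 100 noisy sine series with a next-step target column "y".
    rows = []
    for i in range(100):
        y = np.sin(np.arange(50) / 5.0) + np.random.normal(scale=0.1, size=50)
        static_value = np.random.rand()
        for t in range(49):
            rows.append(
                {"series_id": i, "time_idx": t, "x": y[t], "y": y[t + 1],
                 "static_feature": static_value}
            )
    data_df = pd.DataFrame(rows)

    # D1 layer: raw data container.
    dataset = TimeSeries(
        data=data_df, time="time_idx", target="y",
        group=["series_id"], num=["x", "static_feature"],
    )

    # D2 layer: windowing, scaling and dataloaders.
    data_module = EncoderDecoderTimeSeriesDataModule(
        time_series_dataset=dataset,
        max_encoder_length=30,
        max_prediction_length=1,
        batch_size=16,
        scalers={"x": StandardScaler(), "static_feature": StandardScaler()},
        target_normalizer=TorchNormalizer(),
    )
    data_module.setup(stage="fit")  # populates data_module.metadata

    # Model layer: TimeXer built from the data module's metadata.
    model = TimeXer(
        loss=nn.L1Loss(),
        logging_metrics=[MAE(), SMAPE()],
        context_length=30,
        prediction_length=1,
        task_name="long_term_forecast",
        features="MS",
        d_model=32, n_heads=2, e_layers=1, d_ff=64,
        dropout=0.1,
        patch_length=1,
        use_norm=False,
        metadata=data_module.metadata,
    )

    Trainer(max_epochs=5, accelerator="auto", devices=1).fit(model, data_module)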