From 252598d2ce3f31244a422cd9206961776ea79615 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 6 Apr 2025 18:43:51 +0530 Subject: [PATCH 01/43] D1, D2 layer commit --- pytorch_forecasting/data/data_module.py | 633 ++++++++++++++++++++++++ pytorch_forecasting/data/timeseries.py | 257 ++++++++++ 2 files changed, 890 insertions(+) create mode 100644 pytorch_forecasting/data/data_module.py diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py new file mode 100644 index 000000000..56917696d --- /dev/null +++ b/pytorch_forecasting/data/data_module.py @@ -0,0 +1,633 @@ +####################################################################################### +# Disclaimer: This data-module is still work in progress and experimental, please +# use with care. This data-module is a basic skeleton of how the data-handling pipeline +# may look like in the future. +# This is D2 layer that will handle the preprocessing and data loaders. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. +####################################################################################### + +from typing import Any, Dict, List, Optional, Tuple, Union + +from lightning.pytorch import LightningDataModule +from sklearn.preprocessing import RobustScaler, StandardScaler +import torch +from torch.utils.data import DataLoader, Dataset + +from pytorch_forecasting.data.encoders import ( + EncoderNormalizer, + NaNLabelEncoder, + TorchNormalizer, +) +from pytorch_forecasting.data.timeseries import TimeSeries, _coerce_to_dict + +NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] + + +class EncoderDecoderTimeSeriesDataModule(LightningDataModule): + """ + Lightning DataModule for processing time series data in an encoder-decoder format. + + This module handles preprocessing, splitting, and batching of time series data + for use in deep learning models. It supports categorical and continuous features, + various scalers, and automatic target normalization. + + Parameters + ---------- + time_series_dataset : TimeSeries + The dataset containing time series data. + max_encoder_length : int, default=30 + Maximum length of the encoder input sequence. + min_encoder_length : Optional[int], default=None + Minimum length of the encoder input sequence. + Defaults to `max_encoder_length` if not specified. + max_prediction_length : int, default=1 + Maximum length of the decoder output sequence. + min_prediction_length : Optional[int], default=None + Minimum length of the decoder output sequence. + Defaults to `max_prediction_length` if not specified. + min_prediction_idx : Optional[int], default=None + Minimum index from which predictions start. + allow_missing_timesteps : bool, default=False + Whether to allow missing timesteps in the dataset. + add_relative_time_idx : bool, default=False + Whether to add a relative time index feature. + add_target_scales : bool, default=False + Whether to add target scaling information. + add_encoder_length : Union[bool, str], default="auto" + Whether to include encoder length information. + target_normalizer : + Union[NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None], + default="auto" + Normalizer for the target variable. If "auto", uses `RobustScaler`. + + categorical_encoders : Optional[Dict[str, NaNLabelEncoder]], default=None + Dictionary of categorical encoders. + + scalers : + Optional[Dict[str, Union[StandardScaler, RobustScaler, + TorchNormalizer, EncoderNormalizer]]], default=None + Dictionary of feature scalers. + + randomize_length : Union[None, Tuple[float, float], bool], default=False + Whether to randomize input sequence length. + batch_size : int, default=32 + Batch size for DataLoader. + num_workers : int, default=0 + Number of workers for DataLoader. + train_val_test_split : tuple, default=(0.7, 0.15, 0.15) + Proportions for train, validation, and test dataset splits. + """ + + def __init__( + self, + time_series_dataset: TimeSeries, + max_encoder_length: int = 30, + min_encoder_length: Optional[int] = None, + max_prediction_length: int = 1, + min_prediction_length: Optional[int] = None, + min_prediction_idx: Optional[int] = None, + allow_missing_timesteps: bool = False, + add_relative_time_idx: bool = False, + add_target_scales: bool = False, + add_encoder_length: Union[bool, str] = "auto", + target_normalizer: Union[ + NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None + ] = "auto", + categorical_encoders: Optional[Dict[str, NaNLabelEncoder]] = None, + scalers: Optional[ + Dict[ + str, + Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer], + ] + ] = None, + randomize_length: Union[None, Tuple[float, float], bool] = False, + batch_size: int = 32, + num_workers: int = 0, + train_val_test_split: tuple = (0.7, 0.15, 0.15), + ): + super().__init__() + self.time_series_dataset = time_series_dataset + self.time_series_metadata = time_series_dataset.get_metadata() + + self.max_encoder_length = max_encoder_length + self.min_encoder_length = min_encoder_length or max_encoder_length + self.max_prediction_length = max_prediction_length + self.min_prediction_length = min_prediction_length or max_prediction_length + self.min_prediction_idx = min_prediction_idx + + self.allow_missing_timesteps = allow_missing_timesteps + self.add_relative_time_idx = add_relative_time_idx + self.add_target_scales = add_target_scales + self.add_encoder_length = add_encoder_length + self.randomize_length = randomize_length + + self.batch_size = batch_size + self.num_workers = num_workers + self.train_val_test_split = train_val_test_split + + if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto": + self.target_normalizer = RobustScaler() + else: + self.target_normalizer = target_normalizer + + self.categorical_encoders = _coerce_to_dict(categorical_encoders) + self.scalers = _coerce_to_dict(scalers) + + self.categorical_indices = [] + self.continuous_indices = [] + self._metadata = None + + for idx, col in enumerate(self.time_series_metadata["cols"]["x"]): + if self.time_series_metadata["col_type"].get(col) == "C": + self.categorical_indices.append(idx) + else: + self.continuous_indices.append(idx) + + def _prepare_metadata(self): + """Prepare metadata for model initialisation. + + Returns + ------- + dict + dictionary containing the following keys: + + * ``encoder_cat``: Number of categorical variables in the encoder. + Computed as ``len(self.categorical_indices)``, which counts the + categorical feature indices. + * ``encoder_cont``: Number of continuous variables in the encoder. + Computed as ``len(self.continuous_indices)``, which counts the + continuous feature indices. + * ``decoder_cat``: Number of categorical variables in the decoder that + are known in advance. + Computed by filtering ``self.time_series_metadata["cols"]["x"]`` + where col_type == "C"(categorical) and col_known == "K" (known) + * ``decoder_cont``: Number of continuous variables in the decoder that + are known in advance. + Computed by filtering ``self.time_series_metadata["cols"]["x"]`` + where col_type == "F"(continuous) and col_known == "K"(known) + * ``target``: Number of target variables. + Computed as ``len(self.time_series_metadata["cols"]["y"])``, which + gives the number of output target columns.. + * ``static_categorical_features``: Number of static categorical features + Computed by filtering ``self.time_series_metadata["cols"]["st"]`` + (static features) where col_type == "C" (categorical). + * ``static_continuous_features``: Number of static continuous features + Computed as difference of + ``len(self.time_series_metadata["cols"]["st"])`` (static features) + and static_categorical_features that gives static continuous feature + * ``max_encoder_length``: maximum encoder length + Taken directly from `self.max_encoder_length`. + * ``max_prediction_length``: maximum prediction length + Taken directly from `self.max_prediction_length`. + * ``min_encoder_length``: minimum encoder length + Taken directly from `self.min_encoder_length`. + * ``min_prediction_length``: minimum prediction length + Taken directly from `self.min_prediction_length`. + + """ + encoder_cat_count = len(self.categorical_indices) + encoder_cont_count = len(self.continuous_indices) + + decoder_cat_count = len( + [ + col + for col in self.time_series_metadata["cols"]["x"] + if self.time_series_metadata["col_type"].get(col) == "C" + and self.time_series_metadata["col_known"].get(col) == "K" + ] + ) + decoder_cont_count = len( + [ + col + for col in self.time_series_metadata["cols"]["x"] + if self.time_series_metadata["col_type"].get(col) == "F" + and self.time_series_metadata["col_known"].get(col) == "K" + ] + ) + + target_count = len(self.time_series_metadata["cols"]["y"]) + metadata = { + "encoder_cat": encoder_cat_count, + "encoder_cont": encoder_cont_count, + "decoder_cat": decoder_cat_count, + "decoder_cont": decoder_cont_count, + "target": target_count, + } + if self.time_series_metadata["cols"]["st"]: + static_cat_count = len( + [ + col + for col in self.time_series_metadata["cols"]["st"] + if self.time_series_metadata["col_type"].get(col) == "C" + ] + ) + static_cont_count = ( + len(self.time_series_metadata["cols"]["st"]) - static_cat_count + ) + + metadata["static_categorical_features"] = static_cat_count + metadata["static_continuous_features"] = static_cont_count + else: + metadata["static_categorical_features"] = 0 + metadata["static_continuous_features"] = 0 + + metadata.update( + { + "max_encoder_length": self.max_encoder_length, + "max_prediction_length": self.max_prediction_length, + "min_encoder_length": self.min_encoder_length, + "min_prediction_length": self.min_prediction_length, + } + ) + + return metadata + + @property + def metadata(self): + """Compute metadata for model initialization. + + This property returns a dictionary containing the shapes and key information + related to the time series model. The metadata includes: + + * ``encoder_cat``: Number of categorical variables in the encoder. + * ``encoder_cont``: Number of continuous variables in the encoder. + * ``decoder_cat``: Number of categorical variables in the decoder that are + known in advance. + * ``decoder_cont``: Number of continuous variables in the decoder that are + known in advance. + * ``target``: Number of target variables. + + If static features are present, the following keys are added: + + * ``static_categorical_features``: Number of static categorical features + * ``static_continuous_features``: Number of static continuous features + + It also contains the following information: + + * ``max_encoder_length``: maximum encoder length + * ``max_prediction_length``: maximum prediction length + * ``min_encoder_length``: minimum encoder length + * ``min_prediction_length``: minimum prediction length + """ + if self._metadata is None: + self._metadata = self._prepare_metadata() + return self._metadata + + def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: + """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset. + + Preprocessing steps + -------------------- + + * Converts target (`y`) and features (`x`) to `torch.float32`. + * Masks time points that are at or before the cutoff time. + * Splits features into categorical and continuous subsets based on + predefined indices. + + + TODO: add scalers, target normalizers etc. + """ + processed_data = [] + + for idx in indices: + sample = self.time_series_dataset[idx.item()] + + target = sample["y"] + features = sample["x"] + times = sample["t"] + cutoff_time = sample["cutoff_time"] + + time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) + + if isinstance(target, torch.Tensor): + target = target.float() + else: + target = torch.tensor(target, dtype=torch.float32) + + if isinstance(features, torch.Tensor): + features = features.float() + else: + features = torch.tensor(features, dtype=torch.float32) + + # TODO: add scalers, target normalizers etc. + + categorical = ( + features[:, self.categorical_indices] + if self.categorical_indices + else torch.zeros((features.shape[0], 0)) + ) + continuous = ( + features[:, self.continuous_indices] + if self.continuous_indices + else torch.zeros((features.shape[0], 0)) + ) + + processed_data.append( + { + "features": {"categorical": categorical, "continuous": continuous}, + "target": target, + "static": sample.get("st", None), + "group": sample.get("group", torch.tensor([0])), + "length": len(target), + "time_mask": time_mask, + "times": times, + "cutoff_time": cutoff_time, + } + ) + + return processed_data + + class _ProcessedEncoderDecoderDataset(Dataset): + """PyTorch Dataset for processed encoder-decoder time series data. + + Parameters + ---------- + processed_data : List[Dict[str, Any]] + List of preprocessed time series samples. + windows : List[Tuple[int, int, int, int]] + List of window tuples containing + (series_idx, start_idx, enc_length, pred_length). + add_relative_time_idx : bool, default=False + Whether to include relative time indices. + """ + + def __init__( + self, + processed_data: List[Dict[str, Any]], + windows: List[Tuple[int, int, int, int]], + add_relative_time_idx: bool = False, + ): + self.processed_data = processed_data + self.windows = windows + self.add_relative_time_idx = add_relative_time_idx + + def __len__(self): + return len(self.windows) + + def __getitem__(self, idx): + """Retrieve a processed time series window for dataloader input. + + x : dict + Dictionary containing model inputs: + + * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features) + Categorical features for the encoder. + * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features) + Continuous features for the encoder. + * ``decoder_cat`` : tensor of shape (pred_length, n_cat_features) + Categorical features for the decoder. + * ``decoder_cont`` : tensor of shape (pred_length, n_cont_features) + Continuous features for the decoder. + * ``encoder_lengths`` : tensor of shape (1,) + Length of the encoder sequence. + * ``decoder_lengths`` : tensor of shape (1,) + Length of the decoder sequence. + * ``decoder_target_lengths`` : tensor of shape (1,) + Length of the decoder target sequence. + * ``groups`` : tensor of shape (1,) + Group identifier for the time series instance. + * ``encoder_time_idx`` : tensor of shape (enc_length,) + Time indices for the encoder sequence. + * ``decoder_time_idx`` : tensor of shape (pred_length,) + Time indices for the decoder sequence. + * ``target_scale`` : tensor of shape (1,) + Scaling factor for the target values. + * ``encoder_mask`` : tensor of shape (enc_length,) + Boolean mask indicating valid encoder time points. + * ``decoder_mask`` : tensor of shape (pred_length,) + Boolean mask indicating valid decoder time points. + + If static features are present, the following keys are added: + + * ``static_categorical_features`` : tensor of shape + (1, n_static_cat_features), optional + Static categorical features, if available. + * ``static_continuous_features`` : tensor of shape (1, 0), optional + Placeholder for static continuous features (currently empty). + + y : tensor of shape ``(pred_length, n_targets)`` + Target values for the decoder sequence. + """ + series_idx, start_idx, enc_length, pred_length = self.windows[idx] + data = self.processed_data[series_idx] + + end_idx = start_idx + enc_length + pred_length + encoder_indices = slice(start_idx, start_idx + enc_length) + decoder_indices = slice(start_idx + enc_length, end_idx) + + target_scale = data["target"][encoder_indices] + target_scale = target_scale[~torch.isnan(target_scale)].abs().mean() + if torch.isnan(target_scale) or target_scale == 0: + target_scale = torch.tensor(1.0) + + encoder_mask = ( + data["time_mask"][encoder_indices] + if "time_mask" in data + else torch.ones(enc_length, dtype=torch.bool) + ) + decoder_mask = ( + data["time_mask"][decoder_indices] + if "time_mask" in data + else torch.zeros(pred_length, dtype=torch.bool) + ) + + x = { + "encoder_cat": data["features"]["categorical"][encoder_indices], + "encoder_cont": data["features"]["continuous"][encoder_indices], + "decoder_cat": data["features"]["categorical"][decoder_indices], + "decoder_cont": data["features"]["continuous"][decoder_indices], + "encoder_lengths": torch.tensor(enc_length), + "decoder_lengths": torch.tensor(pred_length), + "decoder_target_lengths": torch.tensor(pred_length), + "groups": data["group"], + "encoder_time_idx": torch.arange(enc_length), + "decoder_time_idx": torch.arange(enc_length, enc_length + pred_length), + "target_scale": target_scale, + "encoder_mask": encoder_mask, + "decoder_mask": decoder_mask, + } + if data["static"] is not None: + x["static_categorical_features"] = data["static"].unsqueeze(0) + x["static_continuous_features"] = torch.zeros((1, 0)) + + y = data["target"][decoder_indices] + if y.ndim == 1: + y = y.unsqueeze(-1) + + return x, y + + def _create_windows( + self, processed_data: List[Dict[str, Any]] + ) -> List[Tuple[int, int, int, int]]: + """Generate sliding windows for training, validation, and testing. + + Returns + ------- + List[Tuple[int, int, int, int]] + A list of tuples, where each tuple consists of: + - ``series_idx`` : int + Index of the time series in `processed_data`. + - ``start_idx`` : int + Start index of the encoder window. + - ``enc_length`` : int + Length of the encoder input sequence. + - ``pred_length`` : int + Length of the decoder output sequence. + """ + windows = [] + + for idx, data in enumerate(processed_data): + sequence_length = data["length"] + + if sequence_length < self.max_encoder_length + self.max_prediction_length: + continue + + effective_min_prediction_idx = ( + self.min_prediction_idx + if self.min_prediction_idx is not None + else self.max_encoder_length + ) + + max_prediction_idx = sequence_length - self.max_prediction_length + 1 + + if max_prediction_idx <= effective_min_prediction_idx: + continue + + for start_idx in range( + 0, max_prediction_idx - effective_min_prediction_idx + ): + if ( + start_idx + self.max_encoder_length + self.max_prediction_length + <= sequence_length + ): + windows.append( + ( + idx, + start_idx, + self.max_encoder_length, + self.max_prediction_length, + ) + ) + + return windows + + def setup(self, stage: Optional[str] = None): + """Prepare the datasets for training, validation, testing, or prediction. + + Parameters + ---------- + stage : Optional[str], default=None + Specifies the stage of setup. Can be one of: + - ``"fit"`` : Prepares training and validation datasets. + - ``"test"`` : Prepares the test dataset. + - ``"predict"`` : Prepares the dataset for inference. + - ``None`` : Prepares all datasets. + """ + total_series = len(self.time_series_dataset) + self._split_indices = torch.randperm(total_series) + + self._train_size = int(self.train_val_test_split[0] * total_series) + self._val_size = int(self.train_val_test_split[1] * total_series) + + self._train_indices = self._split_indices[: self._train_size] + self._val_indices = self._split_indices[ + self._train_size : self._train_size + self._val_size + ] + self._test_indices = self._split_indices[self._train_size + self._val_size :] + + if stage is None or stage == "fit": + if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"): + self.train_processed = self._preprocess_data(self._train_indices) + self.val_processed = self._preprocess_data(self._val_indices) + + self.train_windows = self._create_windows(self.train_processed) + self.val_windows = self._create_windows(self.val_processed) + + self.train_dataset = self._ProcessedEncoderDecoderDataset( + self.train_processed, self.train_windows, self.add_relative_time_idx + ) + self.val_dataset = self._ProcessedEncoderDecoderDataset( + self.val_processed, self.val_windows, self.add_relative_time_idx + ) + # print(self.val_dataset[0]) + + elif stage is None or stage == "test": + if not hasattr(self, "test_dataset"): + self.test_processed = self._preprocess_data(self._test_indices) + self.test_windows = self._create_windows(self.test_processed) + + self.test_dataset = self._ProcessedEncoderDecoderDataset( + self.test_processed, self.test_windows, self.add_relative_time_idx + ) + elif stage == "predict": + predict_indices = torch.arange(len(self.time_series_dataset)) + self.predict_processed = self._preprocess_data(predict_indices) + self.predict_windows = self._create_windows(self.predict_processed) + self.predict_dataset = self._ProcessedEncoderDecoderDataset( + self.predict_processed, self.predict_windows, self.add_relative_time_idx + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + collate_fn=self.collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + def predict_dataloader(self): + return DataLoader( + self.predict_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + @staticmethod + def collate_fn(batch): + x_batch = { + "encoder_cat": torch.stack([x["encoder_cat"] for x, _ in batch]), + "encoder_cont": torch.stack([x["encoder_cont"] for x, _ in batch]), + "decoder_cat": torch.stack([x["decoder_cat"] for x, _ in batch]), + "decoder_cont": torch.stack([x["decoder_cont"] for x, _ in batch]), + "encoder_lengths": torch.stack([x["encoder_lengths"] for x, _ in batch]), + "decoder_lengths": torch.stack([x["decoder_lengths"] for x, _ in batch]), + "decoder_target_lengths": torch.stack( + [x["decoder_target_lengths"] for x, _ in batch] + ), + "groups": torch.stack([x["groups"] for x, _ in batch]), + "encoder_time_idx": torch.stack([x["encoder_time_idx"] for x, _ in batch]), + "decoder_time_idx": torch.stack([x["decoder_time_idx"] for x, _ in batch]), + "target_scale": torch.stack([x["target_scale"] for x, _ in batch]), + "encoder_mask": torch.stack([x["encoder_mask"] for x, _ in batch]), + "decoder_mask": torch.stack([x["decoder_mask"] for x, _ in batch]), + } + + if "static_categorical_features" in batch[0][0]: + x_batch["static_categorical_features"] = torch.stack( + [x["static_categorical_features"] for x, _ in batch] + ) + x_batch["static_continuous_features"] = torch.stack( + [x["static_continuous_features"] for x, _ in batch] + ) + + y_batch = torch.stack([y for _, y in batch]) + return x_batch, y_batch diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 336eecd5f..bc8300300 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2657,6 +2657,8 @@ def _coerce_to_list(obj): """ if obj is None: return [] + if isinstance(obj, str): + return [obj] return list(obj) @@ -2668,3 +2670,258 @@ def _coerce_to_dict(obj): if obj is None: return {} return deepcopy(obj) + + +####################################################################################### +# Disclaimer: This dataset class is still work in progress and experimental, please +# use with care. This class is a basic skeleton of how the data-handling pipeline may +# look like in the future. +# This is the D1 layer that is a "Raw Dataset Layer" mainly for raw data ingestion +# and turning the data to tensors. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. +####################################################################################### + + +class TimeSeries(Dataset): + """PyTorch Dataset for time series data stored in pandas DataFrame. + + Parameters + ---------- + data : pd.DataFrame + data frame with sequence data. + Column names must all be str, and contain str as referred to below. + data_future : pd.DataFrame, optional, default=None + data frame with future data. + Column names must all be str, and contain str as referred to below. + May contain only columns that are in time, group, weight, known, or static. + time : str, optional, default = first col not in group_ids, weight, target, static. + integer typed column denoting the time index within ``data``. + This column is used to determine the sequence of samples. + If there are no missing observations, + the time index should increase by ``+1`` for each subsequent sample. + The first time_idx for each series does not necessarily + have to be ``0`` but any value is allowed. + target : str or List[str], optional, default = last column (at iloc -1) + column(s) in ``data`` denoting the forecasting target. + Can be categorical or numerical dtype. + group : List[str], optional, default = None + list of column names identifying a time series instance within ``data``. + This means that the ``group`` together uniquely identify an instance, + and ``group`` together with ``time`` uniquely identify a single observation + within a time series instance. + If ``None``, the dataset is assumed to be a single time series. + weight : str, optional, default=None + column name for weights. + If ``None``, it is assumed that there is no weight column. + num : list of str, optional, default = all columns with dtype in "fi" + list of numerical variables in ``data``, + list may also contain list of str, which are then grouped together. + cat : list of str, optional, default = all columns with dtype in "Obc" + list of categorical variables in ``data``, + list may also contain list of str, which are then grouped together + (e.g. useful for product categories). + known : list of str, optional, default = all variables + list of variables that change over time and are known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for special days or promotion categories). + unknown : list of str, optional, default = no variables + list of variables that are not known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for weather categories). + static : list of str, optional, default = all variables not in known, unknown + list of variables that do not change over time, + list may also contain list of str, which are then grouped together. + """ + + def __init__( + self, + data: pd.DataFrame, + data_future: Optional[pd.DataFrame] = None, + time: Optional[str] = None, + target: Optional[Union[str, List[str]]] = None, + group: Optional[List[str]] = None, + weight: Optional[str] = None, + num: Optional[List[Union[str, List[str]]]] = None, + cat: Optional[List[Union[str, List[str]]]] = None, + known: Optional[List[Union[str, List[str]]]] = None, + unknown: Optional[List[Union[str, List[str]]]] = None, + static: Optional[List[Union[str, List[str]]]] = None, + ): + + self.data = data + self.data_future = data_future + self.time = time + self.target = _coerce_to_list(target) + self.group = _coerce_to_list(group) + self.weight = weight + self.num = _coerce_to_list(num) + self.cat = _coerce_to_list(cat) + self.known = _coerce_to_list(known) + self.unknown = _coerce_to_list(unknown) + self.static = _coerce_to_list(static) + + self.feature_cols = [ + col + for col in data.columns + if col not in [self.time] + self.group + [self.weight] + self.target + ] + if self.group: + self._groups = self.data.groupby(self.group).groups + self._group_ids = list(self._groups.keys()) + else: + self._groups = {"_single_group": self.data.index} + self._group_ids = ["_single_group"] + + self._prepare_metadata() + + def _prepare_metadata(self): + """Prepare metadata for the dataset. + + The funcion returns metadata that contains: + + * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] } + Names of columns for y, x, and static features. + List elements are in same order as column dimensions. + Columns not appearing are assumed to be named (x0, x1, etc.), + (y0, y1, etc.), (st0, st1, etc.). + * ``col_type``: dict[str, str] + maps column names to data types "F" (numerical) and "C" (categorical). + Column names not occurring are assumed "F". + * ``col_known``: dict[str, str] + maps column names to "K" (future known) or "U" (future unknown). + Column names not occurring are assumed "K". + """ + self.metadata = { + "cols": { + "y": self.target, + "x": self.feature_cols, + "st": self.static, + }, + "col_type": {}, + "col_known": {}, + } + + all_cols = self.target + self.feature_cols + self.static + for col in all_cols: + self.metadata["col_type"][col] = "C" if col in self.cat else "F" + + self.metadata["col_known"][col] = "K" if col in self.known else "U" + + def __len__(self) -> int: + """Return number of time series in the dataset.""" + return len(self._group_ids) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get time series data for given index. + + It returns: + + * ``t``: ``numpy.ndarray`` of shape (n_timepoints,) + Time index for each time point in the past or present. Aligned with ``y``, + and ``x`` not ending in ``f``. + * ``y``: tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with ``t``. + * ``x``: tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with ``t``. + * ``group``: tensor of shape (n_groups) + Group identifiers for time series instances. + * ``st``: tensor of shape (n_static_features) + Static features. + * ``cutoff_time``: float or ``numpy.float64`` + Cutoff time for the time series instance. + + Optionally, the following str-keyed entry can be included: + + * ``weights``: tensor of shape (n_timepoints), only if weight is not None + """ + group_id = self._group_ids[index] + + if self.group: + mask = self._groups[group_id] + data = self.data.loc[mask] + else: + data = self.data + + cutoff_time = data[self.time].max() + + result = { + "t": data[self.time].values, + "y": torch.tensor(data[self.target].values), + "x": torch.tensor(data[self.feature_cols].values), + "group": torch.tensor([hash(str(group_id))]), + "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), + "cutoff_time": cutoff_time, + } + + if self.data_future is not None: + if self.group: + future_mask = self.data_future.groupby(self.group).groups[group_id] + future_data = self.data_future.loc[future_mask] + else: + future_data = self.data_future + + combined_times = np.concatenate( + [data[self.time].values, future_data[self.time].values] + ) + combined_times = np.unique(combined_times) + combined_times.sort() + + num_timepoints = len(combined_times) + x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) + y_merged = np.full((num_timepoints, len(self.target)), np.nan) + + current_time_indices = {t: i for i, t in enumerate(combined_times)} + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + x_merged[idx] = data[self.feature_cols].values[i] + y_merged[idx] = data[self.target].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices: + idx = current_time_indices[t] + for j, col in enumerate(self.known): + if col in self.feature_cols: + feature_idx = self.feature_cols.index(col) + x_merged[idx, feature_idx] = future_data[col].values[i] + + result.update( + { + "t": combined_times, + "x": torch.tensor(x_merged, dtype=torch.float32), + "y": torch.tensor(y_merged, dtype=torch.float32), + } + ) + + if self.weight: + if self.data_future is not None and self.weight in self.data_future.columns: + weights_merged = np.full(num_timepoints, np.nan) + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + weights_merged[idx] = data[self.weight].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices and self.weight in future_data.columns: + idx = current_time_indices[t] + weights_merged[idx] = future_data[self.weight].values[i] + + result["weights"] = torch.tensor(weights_merged, dtype=torch.float32) + else: + result["weights"] = torch.tensor( + data[self.weight].values, dtype=torch.float32 + ) + + return result + + def get_metadata(self) -> Dict: + """Return metadata about the dataset. + + Returns + ------- + Dict + Dictionary containing: + - cols: column names for y, x, and static features + - col_type: mapping of columns to their types (F/C) + - col_known: mapping of columns to their future known status (K/U) + """ + return self.metadata From d0d1c3ec7fb3bdee8e80d9ff83cd43e8990a5319 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 6 Apr 2025 18:47:46 +0530 Subject: [PATCH 02/43] remove one comment --- pytorch_forecasting/data/data_module.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 56917696d..2958f1705 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -550,7 +550,6 @@ def setup(self, stage: Optional[str] = None): self.val_dataset = self._ProcessedEncoderDecoderDataset( self.val_processed, self.val_windows, self.add_relative_time_idx ) - # print(self.val_dataset[0]) elif stage is None or stage == "test": if not hasattr(self, "test_dataset"): From 80e64d218a744557bd493ea07547f0f42b029573 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 6 Apr 2025 19:07:01 +0530 Subject: [PATCH 03/43] model layer commit --- .../models/base/base_model_refactor.py | 283 ++++++++++++++++++ .../tft_version_two.py | 218 ++++++++++++++ 2 files changed, 501 insertions(+) create mode 100644 pytorch_forecasting/models/base/base_model_refactor.py create mode 100644 pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py diff --git a/pytorch_forecasting/models/base/base_model_refactor.py b/pytorch_forecasting/models/base/base_model_refactor.py new file mode 100644 index 000000000..ccd2c2600 --- /dev/null +++ b/pytorch_forecasting/models/base/base_model_refactor.py @@ -0,0 +1,283 @@ +######################################################################################## +# Disclaimer: This baseclass is still work in progress and experimental, please +# use with care. This class is a basic skeleton of how the base classes may look like +# in the version-2. +######################################################################################## + + +from typing import Dict, List, Optional, Tuple, Union + +from lightning.pytorch import LightningModule +from lightning.pytorch.utilities.types import STEP_OUTPUT +import torch +import torch.nn as nn +from torch.optim import Optimizer + + +class BaseModel(LightningModule): + def __init__( + self, + loss: nn.Module, + logging_metrics: Optional[List[nn.Module]] = None, + optimizer: Optional[Union[Optimizer, str]] = "adam", + optimizer_params: Optional[Dict] = None, + lr_scheduler: Optional[str] = None, + lr_scheduler_params: Optional[Dict] = None, + ): + """ + Base model for time series forecasting. + + Parameters + ---------- + loss : nn.Module + Loss function to use for training. + logging_metrics : Optional[List[nn.Module]], optional + List of metrics to log during training, validation, and testing. + optimizer : Optional[Union[Optimizer, str]], optional + Optimizer to use for training. + Can be a string ("adam", "sgd") or an instance of `torch.optim.Optimizer`. + optimizer_params : Optional[Dict], optional + Parameters for the optimizer. + lr_scheduler : Optional[str], optional + Learning rate scheduler to use. + Supported values: "reduce_lr_on_plateau", "step_lr". + lr_scheduler_params : Optional[Dict], optional + Parameters for the learning rate scheduler. + """ + super().__init__() + self.loss = loss + self.logging_metrics = logging_metrics if logging_metrics is not None else [] + self.optimizer = optimizer + self.optimizer_params = optimizer_params if optimizer_params is not None else {} + self.lr_scheduler = lr_scheduler + self.lr_scheduler_params = ( + lr_scheduler_params if lr_scheduler_params is not None else {} + ) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass of the model. + + Parameters + ---------- + x : Dict[str, torch.Tensor] + Dictionary containing input tensors + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary containing output tensors + """ + raise NotImplementedError("Forward method must be implemented by subclass.") + + def training_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Training step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. + """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="train") + return {"loss": loss} + + def validation_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Validation step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. + """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="val") + return {"val_loss": loss} + + def test_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Test step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. + """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="test") + return {"test_loss": loss} + + def predict_step( + self, + batch: Tuple[Dict[str, torch.Tensor]], + batch_idx: int, + dataloader_idx: int = 0, + ) -> torch.Tensor: + """ + Prediction step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input tensors. + batch_idx : int + Index of the batch. + dataloader_idx : int + Index of the dataloader. + + Returns + ------- + torch.Tensor + Predicted output tensor. + """ + x, _ = batch + y_hat = self(x) + return y_hat + + def configure_optimizers(self) -> Dict: + """ + Configure the optimizer and learning rate scheduler. + + Returns + ------- + Dict + Dictionary containing the optimizer and scheduler configuration. + """ + optimizer = self._get_optimizer() + if self.lr_scheduler is not None: + scheduler = self._get_scheduler(optimizer) + if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + }, + } + else: + return {"optimizer": optimizer, "lr_scheduler": scheduler} + return {"optimizer": optimizer} + + def _get_optimizer(self) -> Optimizer: + """ + Get the optimizer based on the specified optimizer name and parameters. + + Returns + ------- + Optimizer + The optimizer instance. + """ + if isinstance(self.optimizer, str): + if self.optimizer.lower() == "adam": + return torch.optim.Adam(self.parameters(), **self.optimizer_params) + elif self.optimizer.lower() == "sgd": + return torch.optim.SGD(self.parameters(), **self.optimizer_params) + else: + raise ValueError(f"Optimizer {self.optimizer} not supported.") + elif isinstance(self.optimizer, Optimizer): + return self.optimizer + else: + raise ValueError( + "Optimizer must be either a string or " + "an instance of torch.optim.Optimizer." + ) + + def _get_scheduler( + self, optimizer: Optimizer + ) -> torch.optim.lr_scheduler._LRScheduler: + """ + Get the lr scheduler based on the specified scheduler name and params. + + Parameters + ---------- + optimizer : Optimizer + The optimizer instance. + + Returns + ------- + torch.optim.lr_scheduler._LRScheduler + The learning rate scheduler instance. + """ + if self.lr_scheduler.lower() == "reduce_lr_on_plateau": + return torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, **self.lr_scheduler_params + ) + elif self.lr_scheduler.lower() == "step_lr": + return torch.optim.lr_scheduler.StepLR( + optimizer, **self.lr_scheduler_params + ) + else: + raise ValueError(f"Scheduler {self.lr_scheduler} not supported.") + + def log_metrics( + self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = "val" + ) -> None: + """ + Log additional metrics during training, validation, or testing. + + Parameters + ---------- + y_hat : torch.Tensor + Predicted output tensor. + y : torch.Tensor + Target output tensor. + prefix : str + Prefix for the logged metrics (e.g., "train", "val", "test"). + """ + for metric in self.logging_metrics: + metric_value = metric(y_hat, y) + self.log( + f"{prefix}_{metric.__class__.__name__}", + metric_value, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py new file mode 100644 index 000000000..30f70f98e --- /dev/null +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py @@ -0,0 +1,218 @@ +######################################################################################## +# Disclaimer: This implementation is based on the new version of data pipeline and is +# experimental, please use with care. +######################################################################################## + +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from torch.optim import Optimizer + +from pytorch_forecasting.models.base.base_model_refactor import BaseModel + + +class TFT(BaseModel): + def __init__( + self, + loss: nn.Module, + logging_metrics: Optional[List[nn.Module]] = None, + optimizer: Optional[Union[Optimizer, str]] = "adam", + optimizer_params: Optional[Dict] = None, + lr_scheduler: Optional[str] = None, + lr_scheduler_params: Optional[Dict] = None, + hidden_size: int = 64, + num_layers: int = 2, + attention_head_size: int = 4, + dropout: float = 0.1, + metadata: Optional[Dict] = None, + output_size: int = 1, + ): + super().__init__( + loss=loss, + logging_metrics=logging_metrics, + optimizer=optimizer, + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + lr_scheduler_params=lr_scheduler_params, + ) + self.hidden_size = hidden_size + self.num_layers = num_layers + self.attention_head_size = attention_head_size + self.dropout = dropout + self.metadata = metadata + self.output_size = output_size + + self.max_encoder_length = self.metadata["max_encoder_length"] + self.max_prediction_length = self.metadata["max_prediction_length"] + self.encoder_cont = self.metadata["encoder_cont"] + self.encoder_cat = self.metadata["encoder_cat"] + self.static_categorical_features = self.metadata["static_categorical_features"] + self.static_continuous_features = self.metadata["static_continuous_features"] + + total_feature_size = self.encoder_cont + self.encoder_cat + total_static_size = ( + self.static_categorical_features + self.static_continuous_features + ) + + self.encoder_var_selection = nn.Sequential( + nn.Linear(total_feature_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, total_feature_size), + nn.Sigmoid(), + ) + + self.decoder_var_selection = nn.Sequential( + nn.Linear(total_feature_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, total_feature_size), + nn.Sigmoid(), + ) + + self.static_context_linear = ( + nn.Linear(total_static_size, hidden_size) if total_static_size > 0 else None + ) + + self.lstm_encoder = nn.LSTM( + input_size=total_feature_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + batch_first=True, + ) + + self.lstm_decoder = nn.LSTM( + input_size=total_feature_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + batch_first=True, + ) + + self.self_attention = nn.MultiheadAttention( + embed_dim=hidden_size, + num_heads=attention_head_size, + dropout=dropout, + batch_first=True, + ) + + self.pre_output = nn.Linear(hidden_size, hidden_size) + self.output_layer = nn.Linear(hidden_size, output_size) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass of the TFT model. + + Parameters + ---------- + x : Dict[str, torch.Tensor] + Dictionary containing input tensors: + - encoder_cat: Categorical encoder features + - encoder_cont: Continuous encoder features + - decoder_cat: Categorical decoder features + - decoder_cont: Continuous decoder features + - static_categorical_features: Static categorical features + - static_continuous_features: Static continuous features + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary containing output tensors: + - prediction: Prediction output (batch_size, prediction_length, output_size) + """ + batch_size = x["encoder_cont"].shape[0] + + encoder_cat = x.get( + "encoder_cat", + torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device), + ) + encoder_cont = x.get( + "encoder_cont", + torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device), + ) + decoder_cat = x.get( + "decoder_cat", + torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device), + ) + decoder_cont = x.get( + "decoder_cont", + torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device), + ) + + encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2) + decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2) + + static_context = None + if self.static_context_linear is not None: + static_cat = x.get( + "static_categorical_features", + torch.zeros(batch_size, 0, device=self.device), + ) + static_cont = x.get( + "static_continuous_features", + torch.zeros(batch_size, 0, device=self.device), + ) + + if static_cat.size(2) == 0 and static_cont.size(2) == 0: + static_context = None + elif static_cat.size(2) == 0: + static_input = static_cont.to( + dtype=self.static_context_linear.weight.dtype + ) + static_context = self.static_context_linear(static_input) + static_context = static_context.view(batch_size, self.hidden_size) + elif static_cont.size(2) == 0: + static_input = static_cat.to( + dtype=self.static_context_linear.weight.dtype + ) + static_context = self.static_context_linear(static_input) + static_context = static_context.view(batch_size, self.hidden_size) + else: + + static_input = torch.cat([static_cont, static_cat], dim=1).to( + dtype=self.static_context_linear.weight.dtype + ) + static_context = self.static_context_linear(static_input) + static_context = static_context.view(batch_size, self.hidden_size) + + encoder_weights = self.encoder_var_selection(encoder_input) + encoder_input = encoder_input * encoder_weights + + decoder_weights = self.decoder_var_selection(decoder_input) + decoder_input = decoder_input * decoder_weights + + if static_context is not None: + encoder_static_context = static_context.unsqueeze(1).expand( + -1, self.max_encoder_length, -1 + ) + decoder_static_context = static_context.unsqueeze(1).expand( + -1, self.max_prediction_length, -1 + ) + + encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input) + encoder_output = encoder_output + encoder_static_context + decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n)) + decoder_output = decoder_output + decoder_static_context + else: + encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input) + decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n)) + + sequence = torch.cat([encoder_output, decoder_output], dim=1) + + if static_context is not None: + expanded_static_context = static_context.unsqueeze(1).expand( + -1, sequence.size(1), -1 + ) + + attended_output, _ = self.self_attention( + sequence + expanded_static_context, sequence, sequence + ) + else: + attended_output, _ = self.self_attention(sequence, sequence, sequence) + + decoder_attended = attended_output[:, -self.max_prediction_length :, :] + + output = nn.functional.relu(self.pre_output(decoder_attended)) + prediction = self.output_layer(output) + + return {"prediction": prediction} From 0319c29bf08a9d100b2b4d711ded0082f172e7f9 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 6 Apr 2025 19:19:15 +0530 Subject: [PATCH 04/43] Example notebook --- examples/ptf_V2_example.ipynb | 5137 +++++++++++++++++++++++++++++++++ 1 file changed, 5137 insertions(+) create mode 100644 examples/ptf_V2_example.ipynb diff --git a/examples/ptf_V2_example.ipynb b/examples/ptf_V2_example.ipynb new file mode 100644 index 000000000..031d9d634 --- /dev/null +++ b/examples/ptf_V2_example.ipynb @@ -0,0 +1,5137 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2630DaOEI4AJ", + "outputId": "96798236-d2f1-4436-c047-49c3771d56c7" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting pytorch-forecasting\n", + " Downloading pytorch_forecasting-1.3.0-py3-none-any.whl.metadata (13 kB)\n", + "Requirement already satisfied: numpy<=3.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.0.2)\n", + "Requirement already satisfied: torch!=2.0.1,<3.0.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.6.0+cu124)\n", + "Collecting lightning<3.0.0,>=2.0.0 (from pytorch-forecasting)\n", + " Downloading lightning-2.5.1-py3-none-any.whl.metadata (39 kB)\n", + "Requirement already satisfied: scipy<2.0,>=1.8 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.14.1)\n", + "Requirement already satisfied: pandas<3.0.0,>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.2.2)\n", + "Requirement already satisfied: scikit-learn<2.0,>=1.2 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.6.1)\n", + "Requirement already satisfied: PyYAML<8.0,>=5.4 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.0.2)\n", + "Requirement already satisfied: fsspec<2026.0,>=2022.5.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2025.3.2)\n", + "Collecting lightning-utilities<2.0,>=0.10.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)\n", + "Requirement already satisfied: packaging<25.0,>=20.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (24.2)\n", + "Collecting torchmetrics<3.0,>=0.7.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading torchmetrics-1.7.0-py3-none-any.whl.metadata (21 kB)\n", + "Requirement already satisfied: tqdm<6.0,>=4.57.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.67.1)\n", + "Requirement already satisfied: typing-extensions<6.0,>=4.4.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.13.0)\n", + "Collecting pytorch-lightning (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading pytorch_lightning-2.5.1-py3-none-any.whl.metadata (20 kB)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", + "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (3.6.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.18.0)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.4.2)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.1.6)\n", + "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-curand-cu12==10.3.5.147 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (0.6.2)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (2.21.5)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (12.4.127)\n", + "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.2.0)\n", + "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.0)\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.11.15)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from lightning-utilities<2.0,>=0.10.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (75.2.0)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas<3.0.0,>=1.3.0->pytorch-forecasting) (1.17.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.0.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2.6.1)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.2)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (25.3.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.5.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.3.1)\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (0.3.1)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.18.3)\n", + "Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.11/dist-packages (from yarl<2.0,>=1.17.0->aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.10)\n", + "Downloading pytorch_forecasting-1.3.0-py3-none-any.whl (197 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m197.7/197.7 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading lightning-2.5.1-py3-none-any.whl (818 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m818.9/818.9 kB\u001b[0m \u001b[31m16.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m4.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m41.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m29.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m30.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m10.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m7.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m31.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading lightning_utilities-0.14.3-py3-none-any.whl (28 kB)\n", + "Downloading torchmetrics-1.7.0-py3-none-any.whl (960 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m960.9/960.9 kB\u001b[0m \u001b[31m35.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading pytorch_lightning-2.5.1-py3-none-any.whl (822 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.0/823.0 kB\u001b[0m \u001b[31m31.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, lightning-utilities, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torchmetrics, pytorch-lightning, lightning, pytorch-forecasting\n", + " Attempting uninstall: nvidia-nvjitlink-cu12\n", + " Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n", + " Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n", + " Attempting uninstall: nvidia-curand-cu12\n", + " Found existing installation: nvidia-curand-cu12 10.3.6.82\n", + " Uninstalling nvidia-curand-cu12-10.3.6.82:\n", + " Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n", + " Attempting uninstall: nvidia-cufft-cu12\n", + " Found existing installation: nvidia-cufft-cu12 11.2.3.61\n", + " Uninstalling nvidia-cufft-cu12-11.2.3.61:\n", + " Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n", + " Attempting uninstall: nvidia-cuda-runtime-cu12\n", + " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n", + " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cuda-cupti-cu12\n", + " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cublas-cu12\n", + " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n", + " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n", + " Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n", + " Attempting uninstall: nvidia-cusparse-cu12\n", + " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n", + " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n", + " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n", + " Attempting uninstall: nvidia-cudnn-cu12\n", + " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n", + " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n", + " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n", + " Attempting uninstall: nvidia-cusolver-cu12\n", + " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n", + " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n", + " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n", + "Successfully installed lightning-2.5.1 lightning-utilities-0.14.3 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch-forecasting-1.3.0 pytorch-lightning-2.5.1 torchmetrics-1.7.0\n" + ] + } + ], + "source": [ + "!pip install pytorch-forecasting" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "M7PQerTbI_tM" + }, + "outputs": [], + "source": [ + "from typing import Any, Dict, List, Optional, Tuple, Union\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "from torch.utils.data import Dataset\n", + "\n", + "from pytorch_forecasting.data.timeseries import _coerce_to_dict" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "XmL5ukG9JDTD" + }, + "outputs": [], + "source": [ + "def _coerce_to_list(obj):\n", + " \"\"\"Coerce object to list.\n", + "\n", + " None is coerced to empty list, otherwise list constructor is used.\n", + " \"\"\"\n", + " if obj is None:\n", + " return []\n", + " if isinstance(obj, str):\n", + " return [obj]\n", + " return list(obj)\n", + "\n", + "\n", + "class TimeSeries(Dataset):\n", + " \"\"\"PyTorch Dataset for time series data stored in pandas DataFrame.\n", + "\n", + " Parameters\n", + " ----------\n", + " data : pd.DataFrame\n", + " data frame with sequence data.\n", + " Column names must all be str, and contain str as referred to below.\n", + " data_future : pd.DataFrame, optional, default=None\n", + " data frame with future data.\n", + " Column names must all be str, and contain str as referred to below.\n", + " May contain only columns that are in time, group, weight, known, or static.\n", + " time : str, optional, default = first col not in group_ids, weight, target, static.\n", + " integer typed column denoting the time index within ``data``.\n", + " This column is used to determine the sequence of samples.\n", + " If there are no missing observations,\n", + " the time index should increase by ``+1`` for each subsequent sample.\n", + " The first time_idx for each series does not necessarily\n", + " have to be ``0`` but any value is allowed.\n", + " target : str or List[str], optional, default = last column (at iloc -1)\n", + " column(s) in ``data`` denoting the forecasting target.\n", + " Can be categorical or numerical dtype.\n", + " group : List[str], optional, default = None\n", + " list of column names identifying a time series instance within ``data``.\n", + " This means that the ``group`` together uniquely identify an instance,\n", + " and ``group`` together with ``time`` uniquely identify a single observation\n", + " within a time series instance.\n", + " If ``None``, the dataset is assumed to be a single time series.\n", + " weight : str, optional, default=None\n", + " column name for weights.\n", + " If ``None``, it is assumed that there is no weight column.\n", + " num : list of str, optional, default = all columns with dtype in \"fi\"\n", + " list of numerical variables in ``data``,\n", + " list may also contain list of str, which are then grouped together.\n", + " cat : list of str, optional, default = all columns with dtype in \"Obc\"\n", + " list of categorical variables in ``data``,\n", + " list may also contain list of str, which are then grouped together\n", + " (e.g. useful for product categories).\n", + " known : list of str, optional, default = all variables\n", + " list of variables that change over time and are known in the future,\n", + " list may also contain list of str, which are then grouped together\n", + " (e.g. useful for special days or promotion categories).\n", + " unknown : list of str, optional, default = no variables\n", + " list of variables that are not known in the future,\n", + " list may also contain list of str, which are then grouped together\n", + " (e.g. useful for weather categories).\n", + " static : list of str, optional, default = all variables not in known, unknown\n", + " list of variables that do not change over time,\n", + " list may also contain list of str, which are then grouped together.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " data: pd.DataFrame,\n", + " data_future: Optional[pd.DataFrame] = None,\n", + " time: Optional[str] = None,\n", + " target: Optional[Union[str, List[str]]] = None,\n", + " group: Optional[List[str]] = None,\n", + " weight: Optional[str] = None,\n", + " num: Optional[List[Union[str, List[str]]]] = None,\n", + " cat: Optional[List[Union[str, List[str]]]] = None,\n", + " known: Optional[List[Union[str, List[str]]]] = None,\n", + " unknown: Optional[List[Union[str, List[str]]]] = None,\n", + " static: Optional[List[Union[str, List[str]]]] = None,\n", + " ):\n", + "\n", + " self.data = data\n", + " self.data_future = data_future\n", + " self.time = time\n", + " self.target = _coerce_to_list(target)\n", + " self.group = _coerce_to_list(group)\n", + " self.weight = weight\n", + " self.num = _coerce_to_list(num)\n", + " self.cat = _coerce_to_list(cat)\n", + " self.known = _coerce_to_list(known)\n", + " self.unknown = _coerce_to_list(unknown)\n", + " self.static = _coerce_to_list(static)\n", + "\n", + " self.feature_cols = [\n", + " col\n", + " for col in data.columns\n", + " if col not in [self.time] + self.group + [self.weight] + self.target\n", + " ]\n", + " if self.group:\n", + " self._groups = self.data.groupby(self.group).groups\n", + " self._group_ids = list(self._groups.keys())\n", + " else:\n", + " self._groups = {\"_single_group\": self.data.index}\n", + " self._group_ids = [\"_single_group\"]\n", + "\n", + " self._prepare_metadata()\n", + "\n", + " def _prepare_metadata(self):\n", + " \"\"\"Prepare metadata for the dataset.\n", + "\n", + " The funcion returns metadata that contains:\n", + "\n", + " * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] }\n", + " Names of columns for y, x, and static features.\n", + " List elements are in same order as column dimensions.\n", + " Columns not appearing are assumed to be named (x0, x1, etc.),\n", + " (y0, y1, etc.), (st0, st1, etc.).\n", + " * ``col_type``: dict[str, str]\n", + " maps column names to data types \"F\" (numerical) and \"C\" (categorical).\n", + " Column names not occurring are assumed \"F\".\n", + " * ``col_known``: dict[str, str]\n", + " maps column names to \"K\" (future known) or \"U\" (future unknown).\n", + " Column names not occurring are assumed \"K\".\n", + " \"\"\"\n", + " self.metadata = {\n", + " \"cols\": {\n", + " \"y\": self.target,\n", + " \"x\": self.feature_cols,\n", + " \"st\": self.static,\n", + " },\n", + " \"col_type\": {},\n", + " \"col_known\": {},\n", + " }\n", + "\n", + " all_cols = self.target + self.feature_cols + self.static\n", + " for col in all_cols:\n", + " self.metadata[\"col_type\"][col] = \"C\" if col in self.cat else \"F\"\n", + "\n", + " self.metadata[\"col_known\"][col] = \"K\" if col in self.known else \"U\"\n", + "\n", + " def __len__(self) -> int:\n", + " \"\"\"Return number of time series in the dataset.\"\"\"\n", + " return len(self._group_ids)\n", + "\n", + " def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:\n", + " \"\"\"Get time series data for given index.\n", + "\n", + " It returns:\n", + "\n", + " * ``t``: ``numpy.ndarray`` of shape (n_timepoints,)\n", + " Time index for each time point in the past or present. Aligned with ``y``,\n", + " and ``x`` not ending in ``f``.\n", + " * ``y``: tensor of shape (n_timepoints, n_targets)\n", + " Target values for each time point. Rows are time points, aligned with ``t``.\n", + " * ``x``: tensor of shape (n_timepoints, n_features)\n", + " Features for each time point. Rows are time points, aligned with ``t``.\n", + " * ``group``: tensor of shape (n_groups)\n", + " Group identifiers for time series instances.\n", + " * ``st``: tensor of shape (n_static_features)\n", + " Static features.\n", + " * ``cutoff_time``: float or ``numpy.float64``\n", + " Cutoff time for the time series instance.\n", + "\n", + " Optionally, the following str-keyed entry can be included:\n", + "\n", + " * ``weights``: tensor of shape (n_timepoints), only if weight is not None\n", + " \"\"\"\n", + " group_id = self._group_ids[index]\n", + "\n", + " if self.group:\n", + " mask = self._groups[group_id]\n", + " data = self.data.loc[mask]\n", + " else:\n", + " data = self.data\n", + "\n", + " cutoff_time = data[self.time].max()\n", + "\n", + " result = {\n", + " \"t\": data[self.time].values,\n", + " \"y\": torch.tensor(data[self.target].values),\n", + " \"x\": torch.tensor(data[self.feature_cols].values),\n", + " \"group\": torch.tensor([hash(str(group_id))]),\n", + " \"st\": torch.tensor(data[self.static].iloc[0].values if self.static else []),\n", + " \"cutoff_time\": cutoff_time,\n", + " }\n", + "\n", + " if self.data_future is not None:\n", + " if self.group:\n", + " future_mask = self.data_future.groupby(self.group).groups[group_id]\n", + " future_data = self.data_future.loc[future_mask]\n", + " else:\n", + " future_data = self.data_future\n", + "\n", + " combined_times = np.concatenate(\n", + " [data[self.time].values, future_data[self.time].values]\n", + " )\n", + " combined_times = np.unique(combined_times)\n", + " combined_times.sort()\n", + "\n", + " num_timepoints = len(combined_times)\n", + " x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan)\n", + " y_merged = np.full((num_timepoints, len(self.target)), np.nan)\n", + "\n", + " current_time_indices = {t: i for i, t in enumerate(combined_times)}\n", + " for i, t in enumerate(data[self.time].values):\n", + " idx = current_time_indices[t]\n", + " x_merged[idx] = data[self.feature_cols].values[i]\n", + " y_merged[idx] = data[self.target].values[i]\n", + "\n", + " for i, t in enumerate(future_data[self.time].values):\n", + " if t in current_time_indices:\n", + " idx = current_time_indices[t]\n", + " for j, col in enumerate(self.known):\n", + " if col in self.feature_cols:\n", + " feature_idx = self.feature_cols.index(col)\n", + " x_merged[idx, feature_idx] = future_data[col].values[i]\n", + "\n", + " result.update(\n", + " {\n", + " \"t\": combined_times,\n", + " \"x\": torch.tensor(x_merged, dtype=torch.float32),\n", + " \"y\": torch.tensor(y_merged, dtype=torch.float32),\n", + " }\n", + " )\n", + "\n", + " if self.weight:\n", + " if self.data_future is not None and self.weight in self.data_future.columns:\n", + " weights_merged = np.full(num_timepoints, np.nan)\n", + " for i, t in enumerate(data[self.time].values):\n", + " idx = current_time_indices[t]\n", + " weights_merged[idx] = data[self.weight].values[i]\n", + "\n", + " for i, t in enumerate(future_data[self.time].values):\n", + " if t in current_time_indices and self.weight in future_data.columns:\n", + " idx = current_time_indices[t]\n", + " weights_merged[idx] = future_data[self.weight].values[i]\n", + "\n", + " result[\"weights\"] = torch.tensor(weights_merged, dtype=torch.float32)\n", + " else:\n", + " result[\"weights\"] = torch.tensor(\n", + " data[self.weight].values, dtype=torch.float32\n", + " )\n", + "\n", + " return result\n", + "\n", + " def get_metadata(self) -> Dict:\n", + " \"\"\"Return metadata about the dataset.\n", + "\n", + " Returns\n", + " -------\n", + " Dict\n", + " Dictionary containing:\n", + " - cols: column names for y, x, and static features\n", + " - col_type: mapping of columns to their types (F/C)\n", + " - col_known: mapping of columns to their future known status (K/U)\n", + " \"\"\"\n", + " return self.metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "id": "0Rw9LgsXJI5V" + }, + "outputs": [], + "source": [ + "from typing import Dict, List, Optional, Union\n", + "\n", + "from lightning.pytorch import LightningDataModule\n", + "from sklearn.preprocessing import RobustScaler, StandardScaler\n", + "import torch\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "from pytorch_forecasting.data.encoders import (\n", + " EncoderNormalizer,\n", + " NaNLabelEncoder,\n", + " TorchNormalizer,\n", + ")\n", + "\n", + "NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer]\n", + "\n", + "\n", + "class EncoderDecoderTimeSeriesDataModule(LightningDataModule):\n", + " \"\"\"\n", + " Lightning DataModule for processing time series data in an encoder-decoder format.\n", + "\n", + " This module handles preprocessing, splitting, and batching of time series data\n", + " for use in deep learning models. It supports categorical and continuous features,\n", + " various scalers, and automatic target normalization.\n", + "\n", + " Parameters\n", + " ----------\n", + " time_series_dataset : TimeSeries\n", + " The dataset containing time series data.\n", + " max_encoder_length : int, default=30\n", + " Maximum length of the encoder input sequence.\n", + " min_encoder_length : Optional[int], default=None\n", + " Minimum length of the encoder input sequence.\n", + " Defaults to `max_encoder_length` if not specified.\n", + " max_prediction_length : int, default=1\n", + " Maximum length of the decoder output sequence.\n", + " min_prediction_length : Optional[int], default=None\n", + " Minimum length of the decoder output sequence.\n", + " Defaults to `max_prediction_length` if not specified.\n", + " min_prediction_idx : Optional[int], default=None\n", + " Minimum index from which predictions start.\n", + " allow_missing_timesteps : bool, default=False\n", + " Whether to allow missing timesteps in the dataset.\n", + " add_relative_time_idx : bool, default=False\n", + " Whether to add a relative time index feature.\n", + " add_target_scales : bool, default=False\n", + " Whether to add target scaling information.\n", + " add_encoder_length : Union[bool, str], default=\"auto\"\n", + " Whether to include encoder length information.\n", + " target_normalizer :\n", + " Union[NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None],\n", + " default=\"auto\"\n", + " Normalizer for the target variable. If \"auto\", uses `RobustScaler`.\n", + "\n", + " categorical_encoders : Optional[Dict[str, NaNLabelEncoder]], default=None\n", + " Dictionary of categorical encoders.\n", + "\n", + " scalers :\n", + " Optional[Dict[str, Union[StandardScaler, RobustScaler,\n", + " TorchNormalizer, EncoderNormalizer]]], default=None\n", + " Dictionary of feature scalers.\n", + "\n", + " randomize_length : Union[None, Tuple[float, float], bool], default=False\n", + " Whether to randomize input sequence length.\n", + " batch_size : int, default=32\n", + " Batch size for DataLoader.\n", + " num_workers : int, default=0\n", + " Number of workers for DataLoader.\n", + " train_val_test_split : tuple, default=(0.7, 0.15, 0.15)\n", + " Proportions for train, validation, and test dataset splits.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " time_series_dataset: TimeSeries,\n", + " max_encoder_length: int = 30,\n", + " min_encoder_length: Optional[int] = None,\n", + " max_prediction_length: int = 1,\n", + " min_prediction_length: Optional[int] = None,\n", + " min_prediction_idx: Optional[int] = None,\n", + " allow_missing_timesteps: bool = False,\n", + " add_relative_time_idx: bool = False,\n", + " add_target_scales: bool = False,\n", + " add_encoder_length: Union[bool, str] = \"auto\",\n", + " target_normalizer: Union[\n", + " NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None\n", + " ] = \"auto\",\n", + " categorical_encoders: Optional[Dict[str, NaNLabelEncoder]] = None,\n", + " scalers: Optional[\n", + " Dict[\n", + " str,\n", + " Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer],\n", + " ]\n", + " ] = None,\n", + " randomize_length: Union[None, Tuple[float, float], bool] = False,\n", + " batch_size: int = 32,\n", + " num_workers: int = 0,\n", + " train_val_test_split: tuple = (0.7, 0.15, 0.15),\n", + " ):\n", + " super().__init__()\n", + " self.time_series_dataset = time_series_dataset\n", + " self.time_series_metadata = time_series_dataset.get_metadata()\n", + "\n", + " self.max_encoder_length = max_encoder_length\n", + " self.min_encoder_length = min_encoder_length or max_encoder_length\n", + " self.max_prediction_length = max_prediction_length\n", + " self.min_prediction_length = min_prediction_length or max_prediction_length\n", + " self.min_prediction_idx = min_prediction_idx\n", + "\n", + " self.allow_missing_timesteps = allow_missing_timesteps\n", + " self.add_relative_time_idx = add_relative_time_idx\n", + " self.add_target_scales = add_target_scales\n", + " self.add_encoder_length = add_encoder_length\n", + " self.randomize_length = randomize_length\n", + "\n", + " self.batch_size = batch_size\n", + " self.num_workers = num_workers\n", + " self.train_val_test_split = train_val_test_split\n", + "\n", + " if isinstance(target_normalizer, str) and target_normalizer.lower() == \"auto\":\n", + " self.target_normalizer = RobustScaler()\n", + " else:\n", + " self.target_normalizer = target_normalizer\n", + "\n", + " self.categorical_encoders = _coerce_to_dict(categorical_encoders)\n", + " self.scalers = _coerce_to_dict(scalers)\n", + "\n", + " self.categorical_indices = []\n", + " self.continuous_indices = []\n", + " self._metadata = None\n", + "\n", + " for idx, col in enumerate(self.time_series_metadata[\"cols\"][\"x\"]):\n", + " if self.time_series_metadata[\"col_type\"].get(col) == \"C\":\n", + " self.categorical_indices.append(idx)\n", + " else:\n", + " self.continuous_indices.append(idx)\n", + "\n", + " def _prepare_metadata(self):\n", + " \"\"\"Prepare metadata for model initialisation.\n", + "\n", + " Returns\n", + " -------\n", + " dict\n", + " dictionary containing the following keys:\n", + "\n", + " * ``encoder_cat``: Number of categorical variables in the encoder.\n", + " Computed as ``len(self.categorical_indices)``, which counts the\n", + " categorical feature indices.\n", + " * ``encoder_cont``: Number of continuous variables in the encoder.\n", + " Computed as ``len(self.continuous_indices)``, which counts the\n", + " continuous feature indices.\n", + " * ``decoder_cat``: Number of categorical variables in the decoder that\n", + " are known in advance.\n", + " Computed by filtering ``self.time_series_metadata[\"cols\"][\"x\"]``\n", + " where col_type == \"C\"(categorical) and col_known == \"K\" (known)\n", + " * ``decoder_cont``: Number of continuous variables in the decoder that\n", + " are known in advance.\n", + " Computed by filtering ``self.time_series_metadata[\"cols\"][\"x\"]``\n", + " where col_type == \"F\"(continuous) and col_known == \"K\"(known)\n", + " * ``target``: Number of target variables.\n", + " Computed as ``len(self.time_series_metadata[\"cols\"][\"y\"])``, which\n", + " gives the number of output target columns..\n", + " * ``static_categorical_features``: Number of static categorical features\n", + " Computed by filtering ``self.time_series_metadata[\"cols\"][\"st\"]``\n", + " (static features) where col_type == \"C\" (categorical).\n", + " * ``static_continuous_features``: Number of static continuous features\n", + " Computed as difference of\n", + " ``len(self.time_series_metadata[\"cols\"][\"st\"])`` (static features)\n", + " and static_categorical_features that gives static continuous feature\n", + " * ``max_encoder_length``: maximum encoder length\n", + " Taken directly from `self.max_encoder_length`.\n", + " * ``max_prediction_length``: maximum prediction length\n", + " Taken directly from `self.max_prediction_length`.\n", + " * ``min_encoder_length``: minimum encoder length\n", + " Taken directly from `self.min_encoder_length`.\n", + " * ``min_prediction_length``: minimum prediction length\n", + " Taken directly from `self.min_prediction_length`.\n", + "\n", + " \"\"\"\n", + " encoder_cat_count = len(self.categorical_indices)\n", + " encoder_cont_count = len(self.continuous_indices)\n", + "\n", + " decoder_cat_count = len(\n", + " [\n", + " col\n", + " for col in self.time_series_metadata[\"cols\"][\"x\"]\n", + " if self.time_series_metadata[\"col_type\"].get(col) == \"C\"\n", + " and self.time_series_metadata[\"col_known\"].get(col) == \"K\"\n", + " ]\n", + " )\n", + " decoder_cont_count = len(\n", + " [\n", + " col\n", + " for col in self.time_series_metadata[\"cols\"][\"x\"]\n", + " if self.time_series_metadata[\"col_type\"].get(col) == \"F\"\n", + " and self.time_series_metadata[\"col_known\"].get(col) == \"K\"\n", + " ]\n", + " )\n", + "\n", + " target_count = len(self.time_series_metadata[\"cols\"][\"y\"])\n", + " metadata = {\n", + " \"encoder_cat\": encoder_cat_count,\n", + " \"encoder_cont\": encoder_cont_count,\n", + " \"decoder_cat\": decoder_cat_count,\n", + " \"decoder_cont\": decoder_cont_count,\n", + " \"target\": target_count,\n", + " }\n", + " if self.time_series_metadata[\"cols\"][\"st\"]:\n", + " static_cat_count = len(\n", + " [\n", + " col\n", + " for col in self.time_series_metadata[\"cols\"][\"st\"]\n", + " if self.time_series_metadata[\"col_type\"].get(col) == \"C\"\n", + " ]\n", + " )\n", + " static_cont_count = (\n", + " len(self.time_series_metadata[\"cols\"][\"st\"]) - static_cat_count\n", + " )\n", + "\n", + " metadata[\"static_categorical_features\"] = static_cat_count\n", + " metadata[\"static_continuous_features\"] = static_cont_count\n", + " else:\n", + " metadata[\"static_categorical_features\"] = 0\n", + " metadata[\"static_continuous_features\"] = 0\n", + "\n", + " metadata.update(\n", + " {\n", + " \"max_encoder_length\": self.max_encoder_length,\n", + " \"max_prediction_length\": self.max_prediction_length,\n", + " \"min_encoder_length\": self.min_encoder_length,\n", + " \"min_prediction_length\": self.min_prediction_length,\n", + " }\n", + " )\n", + "\n", + " return metadata\n", + "\n", + " @property\n", + " def metadata(self):\n", + " \"\"\"Compute metadata for model initialization.\n", + "\n", + " This property returns a dictionary containing the shapes and key information\n", + " related to the time series model. The metadata includes:\n", + "\n", + " * ``encoder_cat``: Number of categorical variables in the encoder.\n", + " * ``encoder_cont``: Number of continuous variables in the encoder.\n", + " * ``decoder_cat``: Number of categorical variables in the decoder that are\n", + " known in advance.\n", + " * ``decoder_cont``: Number of continuous variables in the decoder that are\n", + " known in advance.\n", + " * ``target``: Number of target variables.\n", + "\n", + " If static features are present, the following keys are added:\n", + "\n", + " * ``static_categorical_features``: Number of static categorical features\n", + " * ``static_continuous_features``: Number of static continuous features\n", + "\n", + " It also contains the following information:\n", + "\n", + " * ``max_encoder_length``: maximum encoder length\n", + " * ``max_prediction_length``: maximum prediction length\n", + " * ``min_encoder_length``: minimum encoder length\n", + " * ``min_prediction_length``: minimum prediction length\n", + " \"\"\"\n", + " if self._metadata is None:\n", + " self._metadata = self._prepare_metadata()\n", + " return self._metadata\n", + "\n", + " def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]:\n", + " \"\"\"Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset.\n", + "\n", + " Preprocessing steps\n", + " --------------------\n", + "\n", + " * Converts target (`y`) and features (`x`) to `torch.float32`.\n", + " * Masks time points that are at or before the cutoff time.\n", + " * Splits features into categorical and continuous subsets based on\n", + " predefined indices.\n", + "\n", + "\n", + " TODO: add scalers, target normalizers etc.\n", + " \"\"\"\n", + " processed_data = []\n", + "\n", + " for idx in indices:\n", + " sample = self.time_series_dataset[idx.item()]\n", + "\n", + " target = sample[\"y\"]\n", + " features = sample[\"x\"]\n", + " times = sample[\"t\"]\n", + " cutoff_time = sample[\"cutoff_time\"]\n", + "\n", + " time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool)\n", + "\n", + " if isinstance(target, torch.Tensor):\n", + " target = target.float()\n", + " else:\n", + " target = torch.tensor(target, dtype=torch.float32)\n", + "\n", + " if isinstance(features, torch.Tensor):\n", + " features = features.float()\n", + " else:\n", + " features = torch.tensor(features, dtype=torch.float32)\n", + "\n", + " # TODO: add scalers, target normalizers etc.\n", + "\n", + " categorical = (\n", + " features[:, self.categorical_indices]\n", + " if self.categorical_indices\n", + " else torch.zeros((features.shape[0], 0))\n", + " )\n", + " continuous = (\n", + " features[:, self.continuous_indices]\n", + " if self.continuous_indices\n", + " else torch.zeros((features.shape[0], 0))\n", + " )\n", + "\n", + " processed_data.append(\n", + " {\n", + " \"features\": {\"categorical\": categorical, \"continuous\": continuous},\n", + " \"target\": target,\n", + " \"static\": sample.get(\"st\", None),\n", + " \"group\": sample.get(\"group\", torch.tensor([0])),\n", + " \"length\": len(target),\n", + " \"time_mask\": time_mask,\n", + " \"times\": times,\n", + " \"cutoff_time\": cutoff_time,\n", + " }\n", + " )\n", + "\n", + " return processed_data\n", + "\n", + " class _ProcessedEncoderDecoderDataset(Dataset):\n", + " \"\"\"PyTorch Dataset for processed encoder-decoder time series data.\n", + "\n", + " Parameters\n", + " ----------\n", + " processed_data : List[Dict[str, Any]]\n", + " List of preprocessed time series samples.\n", + " windows : List[Tuple[int, int, int, int]]\n", + " List of window tuples containing\n", + " (series_idx, start_idx, enc_length, pred_length).\n", + " add_relative_time_idx : bool, default=False\n", + " Whether to include relative time indices.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " processed_data: List[Dict[str, Any]],\n", + " windows: List[Tuple[int, int, int, int]],\n", + " add_relative_time_idx: bool = False,\n", + " ):\n", + " self.processed_data = processed_data\n", + " self.windows = windows\n", + " self.add_relative_time_idx = add_relative_time_idx\n", + "\n", + " def __len__(self):\n", + " return len(self.windows)\n", + "\n", + " def __getitem__(self, idx):\n", + " \"\"\"Retrieve a processed time series window for dataloader input.\n", + "\n", + " x : dict\n", + " Dictionary containing model inputs:\n", + "\n", + " * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features)\n", + " Categorical features for the encoder.\n", + " * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features)\n", + " Continuous features for the encoder.\n", + " * ``decoder_cat`` : tensor of shape (pred_length, n_cat_features)\n", + " Categorical features for the decoder.\n", + " * ``decoder_cont`` : tensor of shape (pred_length, n_cont_features)\n", + " Continuous features for the decoder.\n", + " * ``encoder_lengths`` : tensor of shape (1,)\n", + " Length of the encoder sequence.\n", + " * ``decoder_lengths`` : tensor of shape (1,)\n", + " Length of the decoder sequence.\n", + " * ``decoder_target_lengths`` : tensor of shape (1,)\n", + " Length of the decoder target sequence.\n", + " * ``groups`` : tensor of shape (1,)\n", + " Group identifier for the time series instance.\n", + " * ``encoder_time_idx`` : tensor of shape (enc_length,)\n", + " Time indices for the encoder sequence.\n", + " * ``decoder_time_idx`` : tensor of shape (pred_length,)\n", + " Time indices for the decoder sequence.\n", + " * ``target_scale`` : tensor of shape (1,)\n", + " Scaling factor for the target values.\n", + " * ``encoder_mask`` : tensor of shape (enc_length,)\n", + " Boolean mask indicating valid encoder time points.\n", + " * ``decoder_mask`` : tensor of shape (pred_length,)\n", + " Boolean mask indicating valid decoder time points.\n", + "\n", + " If static features are present, the following keys are added:\n", + "\n", + " * ``static_categorical_features`` : tensor of shape\n", + " (1, n_static_cat_features), optional\n", + " Static categorical features, if available.\n", + " * ``static_continuous_features`` : tensor of shape (1, 0), optional\n", + " Placeholder for static continuous features (currently empty).\n", + "\n", + " y : tensor of shape ``(pred_length, n_targets)``\n", + " Target values for the decoder sequence.\n", + " \"\"\"\n", + " series_idx, start_idx, enc_length, pred_length = self.windows[idx]\n", + " data = self.processed_data[series_idx]\n", + "\n", + " end_idx = start_idx + enc_length + pred_length\n", + " encoder_indices = slice(start_idx, start_idx + enc_length)\n", + " decoder_indices = slice(start_idx + enc_length, end_idx)\n", + "\n", + " target_scale = data[\"target\"][encoder_indices]\n", + " target_scale = target_scale[~torch.isnan(target_scale)].abs().mean()\n", + " if torch.isnan(target_scale) or target_scale == 0:\n", + " target_scale = torch.tensor(1.0)\n", + "\n", + " encoder_mask = (\n", + " data[\"time_mask\"][encoder_indices]\n", + " if \"time_mask\" in data\n", + " else torch.ones(enc_length, dtype=torch.bool)\n", + " )\n", + " decoder_mask = (\n", + " data[\"time_mask\"][decoder_indices]\n", + " if \"time_mask\" in data\n", + " else torch.zeros(pred_length, dtype=torch.bool)\n", + " )\n", + "\n", + " x = {\n", + " \"encoder_cat\": data[\"features\"][\"categorical\"][encoder_indices],\n", + " \"encoder_cont\": data[\"features\"][\"continuous\"][encoder_indices],\n", + " \"decoder_cat\": data[\"features\"][\"categorical\"][decoder_indices],\n", + " \"decoder_cont\": data[\"features\"][\"continuous\"][decoder_indices],\n", + " \"encoder_lengths\": torch.tensor(enc_length),\n", + " \"decoder_lengths\": torch.tensor(pred_length),\n", + " \"decoder_target_lengths\": torch.tensor(pred_length),\n", + " \"groups\": data[\"group\"],\n", + " \"encoder_time_idx\": torch.arange(enc_length),\n", + " \"decoder_time_idx\": torch.arange(enc_length, enc_length + pred_length),\n", + " \"target_scale\": target_scale,\n", + " \"encoder_mask\": encoder_mask,\n", + " \"decoder_mask\": decoder_mask,\n", + " }\n", + " if data[\"static\"] is not None:\n", + " x[\"static_categorical_features\"] = data[\"static\"].unsqueeze(0)\n", + " x[\"static_continuous_features\"] = torch.zeros((1, 0))\n", + "\n", + " y = data[\"target\"][decoder_indices]\n", + " if y.ndim == 1:\n", + " y = y.unsqueeze(-1)\n", + "\n", + " return x, y\n", + "\n", + " def _create_windows(\n", + " self, processed_data: List[Dict[str, Any]]\n", + " ) -> List[Tuple[int, int, int, int]]:\n", + " \"\"\"Generate sliding windows for training, validation, and testing.\n", + "\n", + " Returns\n", + " -------\n", + " List[Tuple[int, int, int, int]]\n", + " A list of tuples, where each tuple consists of:\n", + " - ``series_idx`` : int\n", + " Index of the time series in `processed_data`.\n", + " - ``start_idx`` : int\n", + " Start index of the encoder window.\n", + " - ``enc_length`` : int\n", + " Length of the encoder input sequence.\n", + " - ``pred_length`` : int\n", + " Length of the decoder output sequence.\n", + " \"\"\"\n", + " windows = []\n", + "\n", + " for idx, data in enumerate(processed_data):\n", + " sequence_length = data[\"length\"]\n", + "\n", + " if sequence_length < self.max_encoder_length + self.max_prediction_length:\n", + " continue\n", + "\n", + " effective_min_prediction_idx = (\n", + " self.min_prediction_idx\n", + " if self.min_prediction_idx is not None\n", + " else self.max_encoder_length\n", + " )\n", + "\n", + " max_prediction_idx = sequence_length - self.max_prediction_length + 1\n", + "\n", + " if max_prediction_idx <= effective_min_prediction_idx:\n", + " continue\n", + "\n", + " for start_idx in range(\n", + " 0, max_prediction_idx - effective_min_prediction_idx\n", + " ):\n", + " if (\n", + " start_idx + self.max_encoder_length + self.max_prediction_length\n", + " <= sequence_length\n", + " ):\n", + " windows.append(\n", + " (\n", + " idx,\n", + " start_idx,\n", + " self.max_encoder_length,\n", + " self.max_prediction_length,\n", + " )\n", + " )\n", + "\n", + " return windows\n", + "\n", + " def setup(self, stage: Optional[str] = None):\n", + " \"\"\"Prepare the datasets for training, validation, testing, or prediction.\n", + "\n", + " Parameters\n", + " ----------\n", + " stage : Optional[str], default=None\n", + " Specifies the stage of setup. Can be one of:\n", + " - ``\"fit\"`` : Prepares training and validation datasets.\n", + " - ``\"test\"`` : Prepares the test dataset.\n", + " - ``\"predict\"`` : Prepares the dataset for inference.\n", + " - ``None`` : Prepares all datasets.\n", + " \"\"\"\n", + " total_series = len(self.time_series_dataset)\n", + " self._split_indices = torch.randperm(total_series)\n", + "\n", + " self._train_size = int(self.train_val_test_split[0] * total_series)\n", + " self._val_size = int(self.train_val_test_split[1] * total_series)\n", + "\n", + " self._train_indices = self._split_indices[: self._train_size]\n", + " self._val_indices = self._split_indices[\n", + " self._train_size : self._train_size + self._val_size\n", + " ]\n", + " self._test_indices = self._split_indices[self._train_size + self._val_size :]\n", + "\n", + " if stage is None or stage == \"fit\":\n", + " if not hasattr(self, \"train_dataset\") or not hasattr(self, \"val_dataset\"):\n", + " self.train_processed = self._preprocess_data(self._train_indices)\n", + " self.val_processed = self._preprocess_data(self._val_indices)\n", + "\n", + " self.train_windows = self._create_windows(self.train_processed)\n", + " self.val_windows = self._create_windows(self.val_processed)\n", + "\n", + " self.train_dataset = self._ProcessedEncoderDecoderDataset(\n", + " self.train_processed, self.train_windows, self.add_relative_time_idx\n", + " )\n", + " self.val_dataset = self._ProcessedEncoderDecoderDataset(\n", + " self.val_processed, self.val_windows, self.add_relative_time_idx\n", + " )\n", + " # print(self.val_dataset[0])\n", + "\n", + " elif stage is None or stage == \"test\":\n", + " if not hasattr(self, \"test_dataset\"):\n", + " self.test_processed = self._preprocess_data(self._test_indices)\n", + " self.test_windows = self._create_windows(self.test_processed)\n", + "\n", + " self.test_dataset = self._ProcessedEncoderDecoderDataset(\n", + " self.test_processed, self.test_windows, self.add_relative_time_idx\n", + " )\n", + " elif stage == \"predict\":\n", + " predict_indices = torch.arange(len(self.time_series_dataset))\n", + " self.predict_processed = self._preprocess_data(predict_indices)\n", + " self.predict_windows = self._create_windows(self.predict_processed)\n", + " self.predict_dataset = self._ProcessedEncoderDecoderDataset(\n", + " self.predict_processed, self.predict_windows, self.add_relative_time_idx\n", + " )\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(\n", + " self.train_dataset,\n", + " batch_size=self.batch_size,\n", + " num_workers=self.num_workers,\n", + " shuffle=True,\n", + " collate_fn=self.collate_fn,\n", + " )\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(\n", + " self.val_dataset,\n", + " batch_size=self.batch_size,\n", + " num_workers=self.num_workers,\n", + " collate_fn=self.collate_fn,\n", + " )\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(\n", + " self.test_dataset,\n", + " batch_size=self.batch_size,\n", + " num_workers=self.num_workers,\n", + " collate_fn=self.collate_fn,\n", + " )\n", + "\n", + " def predict_dataloader(self):\n", + " return DataLoader(\n", + " self.predict_dataset,\n", + " batch_size=self.batch_size,\n", + " num_workers=self.num_workers,\n", + " collate_fn=self.collate_fn,\n", + " )\n", + "\n", + " @staticmethod\n", + " def collate_fn(batch):\n", + " x_batch = {\n", + " \"encoder_cat\": torch.stack([x[\"encoder_cat\"] for x, _ in batch]),\n", + " \"encoder_cont\": torch.stack([x[\"encoder_cont\"] for x, _ in batch]),\n", + " \"decoder_cat\": torch.stack([x[\"decoder_cat\"] for x, _ in batch]),\n", + " \"decoder_cont\": torch.stack([x[\"decoder_cont\"] for x, _ in batch]),\n", + " \"encoder_lengths\": torch.stack([x[\"encoder_lengths\"] for x, _ in batch]),\n", + " \"decoder_lengths\": torch.stack([x[\"decoder_lengths\"] for x, _ in batch]),\n", + " \"decoder_target_lengths\": torch.stack(\n", + " [x[\"decoder_target_lengths\"] for x, _ in batch]\n", + " ),\n", + " \"groups\": torch.stack([x[\"groups\"] for x, _ in batch]),\n", + " \"encoder_time_idx\": torch.stack([x[\"encoder_time_idx\"] for x, _ in batch]),\n", + " \"decoder_time_idx\": torch.stack([x[\"decoder_time_idx\"] for x, _ in batch]),\n", + " \"target_scale\": torch.stack([x[\"target_scale\"] for x, _ in batch]),\n", + " \"encoder_mask\": torch.stack([x[\"encoder_mask\"] for x, _ in batch]),\n", + " \"decoder_mask\": torch.stack([x[\"decoder_mask\"] for x, _ in batch]),\n", + " }\n", + "\n", + " if \"static_categorical_features\" in batch[0][0]:\n", + " x_batch[\"static_categorical_features\"] = torch.stack(\n", + " [x[\"static_categorical_features\"] for x, _ in batch]\n", + " )\n", + " x_batch[\"static_continuous_features\"] = torch.stack(\n", + " [x[\"static_continuous_features\"] for x, _ in batch]\n", + " )\n", + "\n", + " y_batch = torch.stack([y for _, y in batch])\n", + " return x_batch, y_batch" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "WX-FRdusJSVN", + "outputId": "730f0fe2-f5af-4871-859d-9a4043bbeac7" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6723608094711289,\n \"min\": -1.2295384314749236,\n \"max\": 1.3194322331654313,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.3958947786657995,\n 0.7816648993958805,\n -0.9655256111265276\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6766981377008536,\n \"min\": -1.2295384314749236,\n \"max\": 1.3194322331654313,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.5707668517530808,\n 0.5020485177883972,\n -0.7734543579445009\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n -0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2783997753182645,\n \"min\": 0.011560494046953695,\n \"max\": 0.9996855497257285,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.43066463390546217,\n 0.08751257529405387,\n 0.3593350820130162\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", + "type": "dataframe", + "variable_name": "data_df" + }, + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
0000.0231380.24983401.0000000.4546680
1010.2498340.21382100.9950040.4546680
2020.2138210.67182900.9800670.4546680
3030.6718290.78104200.9553360.4546680
4040.7810420.70609200.9210610.4546680
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "text/plain": [ + " series_id time_idx x y category future_known_feature \\\n", + "0 0 0 0.023138 0.249834 0 1.000000 \n", + "1 0 1 0.249834 0.213821 0 0.995004 \n", + "2 0 2 0.213821 0.671829 0 0.980067 \n", + "3 0 3 0.671829 0.781042 0 0.955336 \n", + "4 0 4 0.781042 0.706092 0 0.921061 \n", + "\n", + " static_feature static_feature_cat \n", + "0 0.454668 0 \n", + "1 0.454668 0 \n", + "2 0.454668 0 \n", + "3 0.454668 0 \n", + "4 0.454668 0 " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from lightning.pytorch import Trainer\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "\n", + "from pytorch_forecasting.metrics import MAE, SMAPE\n", + "\n", + "num_series = 100\n", + "seq_length = 50\n", + "data_list = []\n", + "for i in range(num_series):\n", + " x = np.arange(seq_length)\n", + " y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)\n", + " category = i % 5\n", + " static_value = np.random.rand()\n", + " for t in range(seq_length - 1):\n", + " data_list.append(\n", + " {\n", + " \"series_id\": i,\n", + " \"time_idx\": t,\n", + " \"x\": y[t],\n", + " \"y\": y[t + 1],\n", + " \"category\": category,\n", + " \"future_known_feature\": np.cos(t / 10),\n", + " \"static_feature\": static_value,\n", + " \"static_feature_cat\": i % 3,\n", + " }\n", + " )\n", + "data_df = pd.DataFrame(data_list)\n", + "data_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "id": "AxxPHK6AKSD2" + }, + "outputs": [], + "source": [ + "dataset = TimeSeries(\n", + " data=data_df,\n", + " time=\"time_idx\",\n", + " target=\"y\",\n", + " group=[\"series_id\"],\n", + " num=[\"x\", \"future_known_feature\", \"static_feature\"],\n", + " cat=[\"category\", \"static_feature_cat\"],\n", + " known=[\"future_known_feature\"],\n", + " unknown=[\"x\", \"category\"],\n", + " static=[\"static_feature\", \"static_feature_cat\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "id": "5U5Lr_ZFKX0s" + }, + "outputs": [], + "source": [ + "data_module = EncoderDecoderTimeSeriesDataModule(\n", + " time_series_dataset=dataset,\n", + " max_encoder_length=30,\n", + " max_prediction_length=1,\n", + " batch_size=32,\n", + " categorical_encoders={\n", + " \"category\": NaNLabelEncoder(add_nan=True),\n", + " \"static_feature_cat\": NaNLabelEncoder(add_nan=True),\n", + " },\n", + " scalers={\n", + " \"x\": StandardScaler(),\n", + " \"future_known_feature\": StandardScaler(),\n", + " \"static_feature\": StandardScaler(),\n", + " },\n", + " target_normalizer=TorchNormalizer(),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "I8NgHxNqK9uV" + }, + "outputs": [], + "source": [ + "from typing import Dict, List, Optional, Union\n", + "\n", + "from lightning.pytorch import LightningModule\n", + "from lightning.pytorch.utilities.types import STEP_OUTPUT\n", + "import torch\n", + "from torch.optim import Optimizer\n", + "\n", + "\n", + "class BaseModel(LightningModule):\n", + " def __init__(\n", + " self,\n", + " loss: nn.Module,\n", + " logging_metrics: Optional[List[nn.Module]] = None,\n", + " optimizer: Optional[Union[Optimizer, str]] = \"adam\",\n", + " optimizer_params: Optional[Dict] = None,\n", + " lr_scheduler: Optional[str] = None,\n", + " lr_scheduler_params: Optional[Dict] = None,\n", + " ):\n", + " \"\"\"\n", + " Base model for time series forecasting.\n", + "\n", + " Parameters\n", + " ----------\n", + " loss : nn.Module\n", + " Loss function to use for training.\n", + " logging_metrics : Optional[List[nn.Module]], optional\n", + " List of metrics to log during training, validation, and testing.\n", + " optimizer : Optional[Union[Optimizer, str]], optional\n", + " Optimizer to use for training.\n", + " Can be a string (\"adam\", \"sgd\") or an instance of `torch.optim.Optimizer`.\n", + " optimizer_params : Optional[Dict], optional\n", + " Parameters for the optimizer.\n", + " lr_scheduler : Optional[str], optional\n", + " Learning rate scheduler to use.\n", + " Supported values: \"reduce_lr_on_plateau\", \"step_lr\".\n", + " lr_scheduler_params : Optional[Dict], optional\n", + " Parameters for the learning rate scheduler.\n", + " \"\"\"\n", + " super().__init__()\n", + " self.loss = loss\n", + " self.logging_metrics = logging_metrics if logging_metrics is not None else []\n", + " self.optimizer = optimizer\n", + " self.optimizer_params = optimizer_params if optimizer_params is not None else {}\n", + " self.lr_scheduler = lr_scheduler\n", + " self.lr_scheduler_params = (\n", + " lr_scheduler_params if lr_scheduler_params is not None else {}\n", + " )\n", + "\n", + " def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n", + " \"\"\"\n", + " Forward pass of the model.\n", + "\n", + " Parameters\n", + " ----------\n", + " x : Dict[str, torch.Tensor]\n", + " Dictionary containing input tensors\n", + "\n", + " Returns\n", + " -------\n", + " Dict[str, torch.Tensor]\n", + " Dictionary containing output tensors\n", + " \"\"\"\n", + " raise NotImplementedError(\"Forward method must be implemented by subclass.\")\n", + "\n", + " def training_step(\n", + " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", + " ) -> STEP_OUTPUT:\n", + " \"\"\"\n", + " Training step for the model.\n", + "\n", + " Parameters\n", + " ----------\n", + " batch : Tuple[Dict[str, torch.Tensor]]\n", + " Batch of data containing input and target tensors.\n", + " batch_idx : int\n", + " Index of the batch.\n", + "\n", + " Returns\n", + " -------\n", + " STEP_OUTPUT\n", + " Dictionary containing the loss and other metrics.\n", + " \"\"\"\n", + " x, y = batch\n", + " y_hat_dict = self(x)\n", + " y_hat = y_hat_dict[\"prediction\"]\n", + " loss = self.loss(y_hat, y)\n", + " self.log(\n", + " \"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True\n", + " )\n", + " self.log_metrics(y_hat, y, prefix=\"train\")\n", + " return {\"loss\": loss}\n", + "\n", + " def validation_step(\n", + " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", + " ) -> STEP_OUTPUT:\n", + " \"\"\"\n", + " Validation step for the model.\n", + "\n", + " Parameters\n", + " ----------\n", + " batch : Tuple[Dict[str, torch.Tensor]]\n", + " Batch of data containing input and target tensors.\n", + " batch_idx : int\n", + " Index of the batch.\n", + "\n", + " Returns\n", + " -------\n", + " STEP_OUTPUT\n", + " Dictionary containing the loss and other metrics.\n", + " \"\"\"\n", + " x, y = batch\n", + " y_hat_dict = self(x)\n", + " y_hat = y_hat_dict[\"prediction\"]\n", + " loss = self.loss(y_hat, y)\n", + " self.log(\n", + " \"val_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n", + " )\n", + " self.log_metrics(y_hat, y, prefix=\"val\")\n", + " return {\"val_loss\": loss}\n", + "\n", + " def test_step(\n", + " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", + " ) -> STEP_OUTPUT:\n", + " \"\"\"\n", + " Test step for the model.\n", + "\n", + " Parameters\n", + " ----------\n", + " batch : Tuple[Dict[str, torch.Tensor]]\n", + " Batch of data containing input and target tensors.\n", + " batch_idx : int\n", + " Index of the batch.\n", + "\n", + " Returns\n", + " -------\n", + " STEP_OUTPUT\n", + " Dictionary containing the loss and other metrics.\n", + " \"\"\"\n", + " x, y = batch\n", + " y_hat_dict = self(x)\n", + " y_hat = y_hat_dict[\"prediction\"]\n", + " loss = self.loss(y_hat, y)\n", + " self.log(\n", + " \"test_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n", + " )\n", + " self.log_metrics(y_hat, y, prefix=\"test\")\n", + " return {\"test_loss\": loss}\n", + "\n", + " def predict_step(\n", + " self,\n", + " batch: Tuple[Dict[str, torch.Tensor]],\n", + " batch_idx: int,\n", + " dataloader_idx: int = 0,\n", + " ) -> torch.Tensor:\n", + " \"\"\"\n", + " Prediction step for the model.\n", + "\n", + " Parameters\n", + " ----------\n", + " batch : Tuple[Dict[str, torch.Tensor]]\n", + " Batch of data containing input tensors.\n", + " batch_idx : int\n", + " Index of the batch.\n", + " dataloader_idx : int\n", + " Index of the dataloader.\n", + "\n", + " Returns\n", + " -------\n", + " torch.Tensor\n", + " Predicted output tensor.\n", + " \"\"\"\n", + " x, _ = batch\n", + " y_hat = self(x)\n", + " return y_hat\n", + "\n", + " def configure_optimizers(self) -> Dict:\n", + " \"\"\"\n", + " Configure the optimizer and learning rate scheduler.\n", + "\n", + " Returns\n", + " -------\n", + " Dict\n", + " Dictionary containing the optimizer and scheduler configuration.\n", + " \"\"\"\n", + " optimizer = self._get_optimizer()\n", + " if self.lr_scheduler is not None:\n", + " scheduler = self._get_scheduler(optimizer)\n", + " if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n", + " return {\n", + " \"optimizer\": optimizer,\n", + " \"lr_scheduler\": {\n", + " \"scheduler\": scheduler,\n", + " \"monitor\": \"val_loss\",\n", + " },\n", + " }\n", + " else:\n", + " return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n", + " return {\"optimizer\": optimizer}\n", + "\n", + " def _get_optimizer(self) -> Optimizer:\n", + " \"\"\"\n", + " Get the optimizer based on the specified optimizer name and parameters.\n", + "\n", + " Returns\n", + " -------\n", + " Optimizer\n", + " The optimizer instance.\n", + " \"\"\"\n", + " if isinstance(self.optimizer, str):\n", + " if self.optimizer.lower() == \"adam\":\n", + " return torch.optim.Adam(self.parameters(), **self.optimizer_params)\n", + " elif self.optimizer.lower() == \"sgd\":\n", + " return torch.optim.SGD(self.parameters(), **self.optimizer_params)\n", + " else:\n", + " raise ValueError(f\"Optimizer {self.optimizer} not supported.\")\n", + " elif isinstance(self.optimizer, Optimizer):\n", + " return self.optimizer\n", + " else:\n", + " raise ValueError(\n", + " \"Optimizer must be either a string or \"\n", + " \"an instance of torch.optim.Optimizer.\"\n", + " )\n", + "\n", + " def _get_scheduler(\n", + " self, optimizer: Optimizer\n", + " ) -> torch.optim.lr_scheduler._LRScheduler:\n", + " \"\"\"\n", + " Get the lr scheduler based on the specified scheduler name and params.\n", + "\n", + " Parameters\n", + " ----------\n", + " optimizer : Optimizer\n", + " The optimizer instance.\n", + "\n", + " Returns\n", + " -------\n", + " torch.optim.lr_scheduler._LRScheduler\n", + " The learning rate scheduler instance.\n", + " \"\"\"\n", + " if self.lr_scheduler.lower() == \"reduce_lr_on_plateau\":\n", + " return torch.optim.lr_scheduler.ReduceLROnPlateau(\n", + " optimizer, **self.lr_scheduler_params\n", + " )\n", + " elif self.lr_scheduler.lower() == \"step_lr\":\n", + " return torch.optim.lr_scheduler.StepLR(\n", + " optimizer, **self.lr_scheduler_params\n", + " )\n", + " else:\n", + " raise ValueError(f\"Scheduler {self.lr_scheduler} not supported.\")\n", + "\n", + " def log_metrics(\n", + " self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = \"val\"\n", + " ) -> None:\n", + " \"\"\"\n", + " Log additional metrics during training, validation, or testing.\n", + "\n", + " Parameters\n", + " ----------\n", + " y_hat : torch.Tensor\n", + " Predicted output tensor.\n", + " y : torch.Tensor\n", + " Target output tensor.\n", + " prefix : str\n", + " Prefix for the logged metrics (e.g., \"train\", \"val\", \"test\").\n", + " \"\"\"\n", + " for metric in self.logging_metrics:\n", + " metric_value = metric(y_hat, y)\n", + " self.log(\n", + " f\"{prefix}_{metric.__class__.__name__}\",\n", + " metric_value,\n", + " on_step=False,\n", + " on_epoch=True,\n", + " prog_bar=True,\n", + " logger=True,\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "n5EBeGucK_k0" + }, + "outputs": [], + "source": [ + "from typing import Dict, List, Optional, Union\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import Optimizer\n", + "\n", + "\n", + "class TFT(BaseModel):\n", + " def __init__(\n", + " self,\n", + " loss: nn.Module,\n", + " logging_metrics: Optional[List[nn.Module]] = None,\n", + " optimizer: Optional[Union[Optimizer, str]] = \"adam\",\n", + " optimizer_params: Optional[Dict] = None,\n", + " lr_scheduler: Optional[str] = None,\n", + " lr_scheduler_params: Optional[Dict] = None,\n", + " hidden_size: int = 64,\n", + " num_layers: int = 2,\n", + " attention_head_size: int = 4,\n", + " dropout: float = 0.1,\n", + " metadata: Optional[Dict] = None,\n", + " output_size: int = 1,\n", + " ):\n", + " super().__init__(\n", + " loss=loss,\n", + " logging_metrics=logging_metrics,\n", + " optimizer=optimizer,\n", + " optimizer_params=optimizer_params,\n", + " lr_scheduler=lr_scheduler,\n", + " lr_scheduler_params=lr_scheduler_params,\n", + " )\n", + " self.hidden_size = hidden_size\n", + " self.num_layers = num_layers\n", + " self.attention_head_size = attention_head_size\n", + " self.dropout = dropout\n", + " self.metadata = metadata\n", + " self.output_size = output_size\n", + "\n", + " self.max_encoder_length = self.metadata[\"max_encoder_length\"]\n", + " self.max_prediction_length = self.metadata[\"max_prediction_length\"]\n", + " self.encoder_cont = self.metadata[\"encoder_cont\"]\n", + " self.encoder_cat = self.metadata[\"encoder_cat\"]\n", + " self.static_categorical_features = self.metadata[\"static_categorical_features\"]\n", + " self.static_continuous_features = self.metadata[\"static_continuous_features\"]\n", + "\n", + " total_feature_size = self.encoder_cont + self.encoder_cat\n", + " total_static_size = (\n", + " self.static_categorical_features + self.static_continuous_features\n", + " )\n", + "\n", + " self.encoder_var_selection = nn.Sequential(\n", + " nn.Linear(total_feature_size, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_size, total_feature_size),\n", + " nn.Sigmoid(),\n", + " )\n", + "\n", + " self.decoder_var_selection = nn.Sequential(\n", + " nn.Linear(total_feature_size, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_size, total_feature_size),\n", + " nn.Sigmoid(),\n", + " )\n", + "\n", + " self.static_context_linear = (\n", + " nn.Linear(total_static_size, hidden_size) if total_static_size > 0 else None\n", + " )\n", + "\n", + " self.lstm_encoder = nn.LSTM(\n", + " input_size=total_feature_size,\n", + " hidden_size=hidden_size,\n", + " num_layers=num_layers,\n", + " dropout=dropout,\n", + " batch_first=True,\n", + " )\n", + "\n", + " self.lstm_decoder = nn.LSTM(\n", + " input_size=total_feature_size,\n", + " hidden_size=hidden_size,\n", + " num_layers=num_layers,\n", + " dropout=dropout,\n", + " batch_first=True,\n", + " )\n", + "\n", + " self.self_attention = nn.MultiheadAttention(\n", + " embed_dim=hidden_size,\n", + " num_heads=attention_head_size,\n", + " dropout=dropout,\n", + " batch_first=True,\n", + " )\n", + "\n", + " self.pre_output = nn.Linear(hidden_size, hidden_size)\n", + " self.output_layer = nn.Linear(hidden_size, output_size)\n", + "\n", + " def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n", + " \"\"\"\n", + " Forward pass of the TFT model.\n", + "\n", + " Parameters\n", + " ----------\n", + " x : Dict[str, torch.Tensor]\n", + " Dictionary containing input tensors:\n", + " - encoder_cat: Categorical encoder features\n", + " - encoder_cont: Continuous encoder features\n", + " - decoder_cat: Categorical decoder features\n", + " - decoder_cont: Continuous decoder features\n", + " - static_categorical_features: Static categorical features\n", + " - static_continuous_features: Static continuous features\n", + "\n", + " Returns\n", + " -------\n", + " Dict[str, torch.Tensor]\n", + " Dictionary containing output tensors:\n", + " - prediction: Prediction output (batch_size, prediction_length, output_size)\n", + " \"\"\"\n", + " batch_size = x[\"encoder_cont\"].shape[0]\n", + "\n", + " encoder_cat = x.get(\n", + " \"encoder_cat\",\n", + " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", + " )\n", + " encoder_cont = x.get(\n", + " \"encoder_cont\",\n", + " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", + " )\n", + " decoder_cat = x.get(\n", + " \"decoder_cat\",\n", + " torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", + " )\n", + " decoder_cont = x.get(\n", + " \"decoder_cont\",\n", + " torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", + " )\n", + "\n", + " encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2)\n", + " decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2)\n", + "\n", + " static_context = None\n", + " if self.static_context_linear is not None:\n", + " static_cat = x.get(\n", + " \"static_categorical_features\",\n", + " torch.zeros(batch_size, 0, device=self.device),\n", + " )\n", + " static_cont = x.get(\n", + " \"static_continuous_features\",\n", + " torch.zeros(batch_size, 0, device=self.device),\n", + " )\n", + "\n", + " if static_cat.size(2) == 0 and static_cont.size(2) == 0:\n", + " static_context = None\n", + " elif static_cat.size(2) == 0:\n", + " static_input = static_cont.to(\n", + " dtype=self.static_context_linear.weight.dtype\n", + " )\n", + " static_context = self.static_context_linear(static_input)\n", + " static_context = static_context.view(batch_size, self.hidden_size)\n", + " elif static_cont.size(2) == 0:\n", + " static_input = static_cat.to(\n", + " dtype=self.static_context_linear.weight.dtype\n", + " )\n", + " static_context = self.static_context_linear(static_input)\n", + " static_context = static_context.view(batch_size, self.hidden_size)\n", + " else:\n", + "\n", + " static_input = torch.cat([static_cont, static_cat], dim=1).to(\n", + " dtype=self.static_context_linear.weight.dtype\n", + " )\n", + " static_context = self.static_context_linear(static_input)\n", + " static_context = static_context.view(batch_size, self.hidden_size)\n", + "\n", + " encoder_weights = self.encoder_var_selection(encoder_input)\n", + " encoder_input = encoder_input * encoder_weights\n", + "\n", + " decoder_weights = self.decoder_var_selection(decoder_input)\n", + " decoder_input = decoder_input * decoder_weights\n", + "\n", + " if static_context is not None:\n", + " encoder_static_context = static_context.unsqueeze(1).expand(\n", + " -1, self.max_encoder_length, -1\n", + " )\n", + " decoder_static_context = static_context.unsqueeze(1).expand(\n", + " -1, self.max_prediction_length, -1\n", + " )\n", + "\n", + " encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n", + " encoder_output = encoder_output + encoder_static_context\n", + " decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n", + " decoder_output = decoder_output + decoder_static_context\n", + " else:\n", + " encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n", + " decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n", + "\n", + " sequence = torch.cat([encoder_output, decoder_output], dim=1)\n", + "\n", + " if static_context is not None:\n", + " expanded_static_context = static_context.unsqueeze(1).expand(\n", + " -1, sequence.size(1), -1\n", + " )\n", + "\n", + " attended_output, _ = self.self_attention(\n", + " sequence + expanded_static_context, sequence, sequence\n", + " )\n", + " else:\n", + " attended_output, _ = self.self_attention(sequence, sequence, sequence)\n", + "\n", + " decoder_attended = attended_output[:, -self.max_prediction_length :, :]\n", + "\n", + " output = nn.functional.relu(self.pre_output(decoder_attended))\n", + " prediction = self.output_layer(output)\n", + "\n", + " return {\"prediction\": prediction}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "f422383b4af64a26b16842297d5c84a6", + "51455b67f5e741b192b5b832aff3d3d3", + "ccd9a65e258244f4bd29c741c3cb4441", + "8b5d5e147c8142328d30e2492b57e2c1", + "3529aafd3f884efe8ba4bb3e33a05c5c", + "3852c1b855934e6e862ceaf37a855700", + "5ad79ce792314292a0425b6f03b85881", + "e9dcea6508f344caa5f1f99e2bfdd4db", + "2b8d1923ff3145789f77661364163cd1", + "2b2eb8cd85ba4c16aefa42634b5458cb", + "edc278a53bb74140ad1a416406f0b0bb", + "7b4b93e92400404abfd2c31129cda96f", + "35e7d924e4e44ceaad836bde8fef17d8", + "6922f0547118463e9468196bf9b1a5c2", + "119506368b47405786c806373b0bb67d", + "27ec3bfc950847afb2850ad9ea9bdd49", + "1016ad0f475740ef9e2d5054323d22e7", + "b232c9df604a4ce59f422cf60499605f", + "f9cb515fd148446ca5fcb47fab504a74", + "adf882d15ee84b2aa3ad6145e2c913bf", + "f9b429aef6904d6995836f1c1279d38d", + "8151f2de21b44422876341b10e6b8568", + "8e46e7b0ea3b429293dc6358980dffe2", + "831387d48f5f4b8cb0662c5df77aefa4", + "d08f30db32ec440f94d88727adfc3ee6", + "9dd7c13be6a54e6b843ccb8afd59fb6b", + "2db1853b41104999ab6884b33b217f0b", + "b868fe32ecc44b50b823ae533755a8ef", + "0d2d00cf7992454484dfd939928aca14", + "f6131d3c810442c5ad4f71177dac1f5e", + "8997cb6735c341bc985d4aacb6e20999", + "df6002b4a12a49fe8d62210c5ebd5b06", + "adaafe80e3ee4a5c8d6d987800630f2e", + "2828e899a43b4f17b0bd7c86a701152b", + "059bbf019de944fda861994e0e7439dd", + "197c1200910e4256945d64dd9ab33902", + "4a22ad1ba0c240a3af31a06d117220fb", + "3ac302bd9b054b9dbe0b76ba7575db68", + "78f9afd80d634c08965990e66297e5b6", + "b791ccce40544f6ba57198a0392df981", + "bff2fcab4d574dc0bcc7382a17f6b080", + "607dc1dee95147f4a39b16170a08f780", + "0c40053cf7ba496d87691bfeee0975e4", + "7b989ec8edc148cb9c3e58aa61637a3b", + "94a5cc9098f5445593b3212857bd0da2", + "c236df5d6fdf4bd1b9fdf687819bdffb", + "e3fcdbd89eb64510b3ed6947d728de1f", + "52a0d1617601472590aa3b5641236890", + "ef103e08597e424a81ff70060b4edc77", + "d25815a46ca64bbea97e6b7f4fff954a", + "3eb0147ac4d345faae3713d64eb0e66c", + "a561990cc8204ef08effecce2855a882", + "afd962bafddc4be2b6ec83067402a19a", + "c9de18b7a60148e1b35c0dfb0b20e0d5", + "5f4abc4cadbc4fc7aae645d53f2a03aa", + "02a2df129afd4ec5a0e6a11ad8d67ee3", + "eb40207588114030aeb2e66e6a66858b", + "344cd98e1e4f49dba72ad6df49695e1d", + "4f5df150e0044ee798c04fa2c722f6c0", + "5307c74c93dc472bbeeb88e6bfb487f1", + "8792c8ef2b3d43748b0062327d901bb3", + "fa845f388d7844a388f079d7cb115824", + "3d59af9594a445d68dbcc8e51994fcbf", + "fb7a0d9d587e4806a7519d9201f7f0aa", + "13d0c7ce5a394744a44439556d43d24d", + "d09998b10967447e917c8ac88abf26db", + "e03f176873e6485a91118a9cb4d7fa4d", + "3866642b5f6c4eaa8f3982da9efa2a4b", + "fb1c1f0130934f24b163d87846f4210c", + "0c32b264b85b437f825531554a51e7e1", + "056e8d3d437c40c889f19e316fbd38ac", + "f2fb8e95caad4be792063e44a9401f2c", + "46fd6f7300514bd9b11abeafc51970ec", + "642bce0f7ab848459f4b5c91068113d4", + "dc1b070cf7bd43ed8c422e2c1cec6438", + "2fad0082ed35466caaf6948f6ff4aac0", + "2d9cd71eec004370b5780a7814d4dd7b", + "db2a9324292e415e8765117e3dab8356", + "fae0d1779ea045108616dad057ed24e1", + "d22ee6b52bbf45bdbe05eb3c23b966c9", + "e579bbd44a5241f7a7fa027147295ec3", + "9eb0c593adf7413da7a4abeec48af166", + "57cb9f2255e4487383465ea0a438a01b", + "b3d493c1744248fc92301d57dcc07fe9", + "c8ddbeb5fdb9447c88561abb59edac29", + "6fd91510b57043c093e4a1fd8302fec3", + "e623887237994d72bfbff0fb2d76e697", + "debe1c18531a401e8e3e93791787e1f8" + ] + }, + "id": "Si7bbZIULBZz", + "outputId": "8c1308d6-c1b3-4a67-f56c-f11c225b592c" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.\n", + "INFO:lightning.pytorch.utilities.rank_zero:You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.\n", + "INFO: GPU available: False, used: False\n", + "INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False\n", + "INFO: TPU available: False, using: 0 TPU cores\n", + "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", + "INFO: HPU available: False, using: 0 HPUs\n", + "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n", + "INFO: \n", + " | Name | Type | Params | Mode \n", + "---------------------------------------------------------------------\n", + "0 | loss | MSELoss | 0 | train\n", + "1 | encoder_var_selection | Sequential | 709 | train\n", + "2 | decoder_var_selection | Sequential | 709 | train\n", + "3 | static_context_linear | Linear | 192 | train\n", + "4 | lstm_encoder | LSTM | 51.5 K | train\n", + "5 | lstm_decoder | LSTM | 51.5 K | train\n", + "6 | self_attention | MultiheadAttention | 16.6 K | train\n", + "7 | pre_output | Linear | 4.2 K | train\n", + "8 | output_layer | Linear | 65 | train\n", + "---------------------------------------------------------------------\n", + "125 K Trainable params\n", + "0 Non-trainable params\n", + "125 K Total params\n", + "0.502 Total estimated model params size (MB)\n", + "18 Modules in train mode\n", + "0 Modules in eval mode\n", + "INFO:lightning.pytorch.callbacks.model_summary:\n", + " | Name | Type | Params | Mode \n", + "---------------------------------------------------------------------\n", + "0 | loss | MSELoss | 0 | train\n", + "1 | encoder_var_selection | Sequential | 709 | train\n", + "2 | decoder_var_selection | Sequential | 709 | train\n", + "3 | static_context_linear | Linear | 192 | train\n", + "4 | lstm_encoder | LSTM | 51.5 K | train\n", + "5 | lstm_decoder | LSTM | 51.5 K | train\n", + "6 | self_attention | MultiheadAttention | 16.6 K | train\n", + "7 | pre_output | Linear | 4.2 K | train\n", + "8 | output_layer | Linear | 65 | train\n", + "---------------------------------------------------------------------\n", + "125 K Trainable params\n", + "0 Non-trainable params\n", + "125 K Total params\n", + "0.502 Total estimated model params size (MB)\n", + "18 Modules in train mode\n", + "0 Modules in eval mode\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Training model...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f422383b4af64a26b16842297d5c84a6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test_MAE 0.4572468101978302 │\n", + "│ test_SMAPE 1.0497652292251587 │\n", + "│ test_loss 0.022910255938768387 │\n", + "└───────────────────────────┴───────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.4572468101978302 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.0497652292251587 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.022910255938768387 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Prediction shape: torch.Size([32, 1, 1])\n", + "First prediction values: [[-0.06341379]]\n", + "First true values: [[0.08132173]]\n", + "\n", + "TFT model test complete!\n" + ] + } + ], + "source": [ + "model = TFT(\n", + " loss=nn.MSELoss(),\n", + " logging_metrics=[MAE(), SMAPE()],\n", + " optimizer=\"adam\",\n", + " optimizer_params={\"lr\": 1e-3},\n", + " lr_scheduler=\"reduce_lr_on_plateau\",\n", + " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", + " hidden_size=64,\n", + " num_layers=2,\n", + " attention_head_size=4,\n", + " dropout=0.1,\n", + " metadata=data_module.metadata,\n", + ")\n", + "\n", + "print(\"\\nTraining model...\")\n", + "trainer = Trainer(max_epochs=5, accelerator=\"auto\", devices=1, enable_progress_bar=True)\n", + "\n", + "trainer.fit(model, data_module)\n", + "\n", + "print(\"\\nEvaluating model...\")\n", + "test_metrics = trainer.test(model, data_module)\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " test_batch = next(iter(data_module.test_dataloader()))\n", + " x_test, y_test = test_batch\n", + " y_pred = model(x_test)\n", + "\n", + " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n", + " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n", + " print(\"First true values:\", y_test[0].cpu().numpy())\n", + "print(\"\\nTFT model test complete!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zVRwi2MvLGgc" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "02a2df129afd4ec5a0e6a11ad8d67ee3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_eb40207588114030aeb2e66e6a66858b", + "IPY_MODEL_344cd98e1e4f49dba72ad6df49695e1d", + "IPY_MODEL_4f5df150e0044ee798c04fa2c722f6c0" + ], + "layout": "IPY_MODEL_5307c74c93dc472bbeeb88e6bfb487f1" + } + }, + "056e8d3d437c40c889f19e316fbd38ac": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "059bbf019de944fda861994e0e7439dd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_78f9afd80d634c08965990e66297e5b6", + "placeholder": "​", + "style": "IPY_MODEL_b791ccce40544f6ba57198a0392df981", + "value": "Validation DataLoader 0: 100%" + } + }, + "0c32b264b85b437f825531554a51e7e1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2fad0082ed35466caaf6948f6ff4aac0", + "placeholder": "​", + "style": "IPY_MODEL_2d9cd71eec004370b5780a7814d4dd7b", + "value": " 9/9 [00:00<00:00, 31.33it/s]" + } + }, + "0c40053cf7ba496d87691bfeee0975e4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0d2d00cf7992454484dfd939928aca14": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "1016ad0f475740ef9e2d5054323d22e7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "119506368b47405786c806373b0bb67d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f9b429aef6904d6995836f1c1279d38d", + "placeholder": "​", + "style": "IPY_MODEL_8151f2de21b44422876341b10e6b8568", + "value": " 42/42 [00:02<00:00, 15.49it/s, v_num=2, train_loss_step=0.010, val_loss=0.0243, val_MAE=0.477, val_SMAPE=1.120, train_loss_epoch=0.0156, train_MAE=0.480, train_SMAPE=1.020]" + } + }, + "13d0c7ce5a394744a44439556d43d24d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "197c1200910e4256945d64dd9ab33902": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bff2fcab4d574dc0bcc7382a17f6b080", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_607dc1dee95147f4a39b16170a08f780", + "value": 9 + } + }, + "27ec3bfc950847afb2850ad9ea9bdd49": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "100%" + } + }, + "2828e899a43b4f17b0bd7c86a701152b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_059bbf019de944fda861994e0e7439dd", + "IPY_MODEL_197c1200910e4256945d64dd9ab33902", + "IPY_MODEL_4a22ad1ba0c240a3af31a06d117220fb" + ], + "layout": "IPY_MODEL_3ac302bd9b054b9dbe0b76ba7575db68" + } + }, + "2b2eb8cd85ba4c16aefa42634b5458cb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2b8d1923ff3145789f77661364163cd1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "2d9cd71eec004370b5780a7814d4dd7b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "2db1853b41104999ab6884b33b217f0b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "2fad0082ed35466caaf6948f6ff4aac0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "344cd98e1e4f49dba72ad6df49695e1d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3d59af9594a445d68dbcc8e51994fcbf", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_fb7a0d9d587e4806a7519d9201f7f0aa", + "value": 9 + } + }, + "3529aafd3f884efe8ba4bb3e33a05c5c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "35e7d924e4e44ceaad836bde8fef17d8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1016ad0f475740ef9e2d5054323d22e7", + "placeholder": "​", + "style": "IPY_MODEL_b232c9df604a4ce59f422cf60499605f", + "value": "Epoch 4: 100%" + } + }, + "3852c1b855934e6e862ceaf37a855700": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3866642b5f6c4eaa8f3982da9efa2a4b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f2fb8e95caad4be792063e44a9401f2c", + "placeholder": "​", + "style": "IPY_MODEL_46fd6f7300514bd9b11abeafc51970ec", + "value": "Validation DataLoader 0: 100%" + } + }, + "3ac302bd9b054b9dbe0b76ba7575db68": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "3d59af9594a445d68dbcc8e51994fcbf": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3eb0147ac4d345faae3713d64eb0e66c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "46fd6f7300514bd9b11abeafc51970ec": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "4a22ad1ba0c240a3af31a06d117220fb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0c40053cf7ba496d87691bfeee0975e4", + "placeholder": "​", + "style": "IPY_MODEL_7b989ec8edc148cb9c3e58aa61637a3b", + "value": " 9/9 [00:00<00:00, 25.66it/s]" + } + }, + "4f5df150e0044ee798c04fa2c722f6c0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_13d0c7ce5a394744a44439556d43d24d", + "placeholder": "​", + "style": "IPY_MODEL_d09998b10967447e917c8ac88abf26db", + "value": " 9/9 [00:00<00:00, 32.96it/s]" + } + }, + "51455b67f5e741b192b5b832aff3d3d3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3852c1b855934e6e862ceaf37a855700", + "placeholder": "​", + "style": "IPY_MODEL_5ad79ce792314292a0425b6f03b85881", + "value": "Sanity Checking DataLoader 0: 100%" + } + }, + "52a0d1617601472590aa3b5641236890": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c9de18b7a60148e1b35c0dfb0b20e0d5", + "placeholder": "​", + "style": "IPY_MODEL_5f4abc4cadbc4fc7aae645d53f2a03aa", + "value": " 9/9 [00:00<00:00, 34.85it/s]" + } + }, + "5307c74c93dc472bbeeb88e6bfb487f1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "57cb9f2255e4487383465ea0a438a01b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5ad79ce792314292a0425b6f03b85881": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5f4abc4cadbc4fc7aae645d53f2a03aa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "607dc1dee95147f4a39b16170a08f780": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "642bce0f7ab848459f4b5c91068113d4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6922f0547118463e9468196bf9b1a5c2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f9cb515fd148446ca5fcb47fab504a74", + "max": 42, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_adf882d15ee84b2aa3ad6145e2c913bf", + "value": 42 + } + }, + "6fd91510b57043c093e4a1fd8302fec3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "78f9afd80d634c08965990e66297e5b6": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7b4b93e92400404abfd2c31129cda96f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_35e7d924e4e44ceaad836bde8fef17d8", + "IPY_MODEL_6922f0547118463e9468196bf9b1a5c2", + "IPY_MODEL_119506368b47405786c806373b0bb67d" + ], + "layout": "IPY_MODEL_27ec3bfc950847afb2850ad9ea9bdd49" + } + }, + "7b989ec8edc148cb9c3e58aa61637a3b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "8151f2de21b44422876341b10e6b8568": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "831387d48f5f4b8cb0662c5df77aefa4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b868fe32ecc44b50b823ae533755a8ef", + "placeholder": "​", + "style": "IPY_MODEL_0d2d00cf7992454484dfd939928aca14", + "value": "Validation DataLoader 0: 100%" + } + }, + "8792c8ef2b3d43748b0062327d901bb3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8997cb6735c341bc985d4aacb6e20999": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "8b5d5e147c8142328d30e2492b57e2c1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2b2eb8cd85ba4c16aefa42634b5458cb", + "placeholder": "​", + "style": "IPY_MODEL_edc278a53bb74140ad1a416406f0b0bb", + "value": " 2/2 [00:00<00:00, 21.87it/s]" + } + }, + "8e46e7b0ea3b429293dc6358980dffe2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_831387d48f5f4b8cb0662c5df77aefa4", + "IPY_MODEL_d08f30db32ec440f94d88727adfc3ee6", + "IPY_MODEL_9dd7c13be6a54e6b843ccb8afd59fb6b" + ], + "layout": "IPY_MODEL_2db1853b41104999ab6884b33b217f0b" + } + }, + "94a5cc9098f5445593b3212857bd0da2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_c236df5d6fdf4bd1b9fdf687819bdffb", + "IPY_MODEL_e3fcdbd89eb64510b3ed6947d728de1f", + "IPY_MODEL_52a0d1617601472590aa3b5641236890" + ], + "layout": "IPY_MODEL_ef103e08597e424a81ff70060b4edc77" + } + }, + "9dd7c13be6a54e6b843ccb8afd59fb6b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_df6002b4a12a49fe8d62210c5ebd5b06", + "placeholder": "​", + "style": "IPY_MODEL_adaafe80e3ee4a5c8d6d987800630f2e", + "value": " 9/9 [00:00<00:00, 33.87it/s]" + } + }, + "9eb0c593adf7413da7a4abeec48af166": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "100%" + } + }, + "a561990cc8204ef08effecce2855a882": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "adaafe80e3ee4a5c8d6d987800630f2e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "adf882d15ee84b2aa3ad6145e2c913bf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "afd962bafddc4be2b6ec83067402a19a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "b232c9df604a4ce59f422cf60499605f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b3d493c1744248fc92301d57dcc07fe9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b791ccce40544f6ba57198a0392df981": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b868fe32ecc44b50b823ae533755a8ef": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bff2fcab4d574dc0bcc7382a17f6b080": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c236df5d6fdf4bd1b9fdf687819bdffb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d25815a46ca64bbea97e6b7f4fff954a", + "placeholder": "​", + "style": "IPY_MODEL_3eb0147ac4d345faae3713d64eb0e66c", + "value": "Validation DataLoader 0: 100%" + } + }, + "c8ddbeb5fdb9447c88561abb59edac29": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c9de18b7a60148e1b35c0dfb0b20e0d5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ccd9a65e258244f4bd29c741c3cb4441": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e9dcea6508f344caa5f1f99e2bfdd4db", + "max": 2, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_2b8d1923ff3145789f77661364163cd1", + "value": 2 + } + }, + "d08f30db32ec440f94d88727adfc3ee6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f6131d3c810442c5ad4f71177dac1f5e", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_8997cb6735c341bc985d4aacb6e20999", + "value": 9 + } + }, + "d09998b10967447e917c8ac88abf26db": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "d22ee6b52bbf45bdbe05eb3c23b966c9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c8ddbeb5fdb9447c88561abb59edac29", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_6fd91510b57043c093e4a1fd8302fec3", + "value": 9 + } + }, + "d25815a46ca64bbea97e6b7f4fff954a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "db2a9324292e415e8765117e3dab8356": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_fae0d1779ea045108616dad057ed24e1", + "IPY_MODEL_d22ee6b52bbf45bdbe05eb3c23b966c9", + "IPY_MODEL_e579bbd44a5241f7a7fa027147295ec3" + ], + "layout": "IPY_MODEL_9eb0c593adf7413da7a4abeec48af166" + } + }, + "dc1b070cf7bd43ed8c422e2c1cec6438": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "debe1c18531a401e8e3e93791787e1f8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "df6002b4a12a49fe8d62210c5ebd5b06": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e03f176873e6485a91118a9cb4d7fa4d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_3866642b5f6c4eaa8f3982da9efa2a4b", + "IPY_MODEL_fb1c1f0130934f24b163d87846f4210c", + "IPY_MODEL_0c32b264b85b437f825531554a51e7e1" + ], + "layout": "IPY_MODEL_056e8d3d437c40c889f19e316fbd38ac" + } + }, + "e3fcdbd89eb64510b3ed6947d728de1f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a561990cc8204ef08effecce2855a882", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_afd962bafddc4be2b6ec83067402a19a", + "value": 9 + } + }, + "e579bbd44a5241f7a7fa027147295ec3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e623887237994d72bfbff0fb2d76e697", + "placeholder": "​", + "style": "IPY_MODEL_debe1c18531a401e8e3e93791787e1f8", + "value": " 9/9 [00:00<00:00, 33.35it/s]" + } + }, + "e623887237994d72bfbff0fb2d76e697": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e9dcea6508f344caa5f1f99e2bfdd4db": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "eb40207588114030aeb2e66e6a66858b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8792c8ef2b3d43748b0062327d901bb3", + "placeholder": "​", + "style": "IPY_MODEL_fa845f388d7844a388f079d7cb115824", + "value": "Validation DataLoader 0: 100%" + } + }, + "edc278a53bb74140ad1a416406f0b0bb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "ef103e08597e424a81ff70060b4edc77": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "f2fb8e95caad4be792063e44a9401f2c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f422383b4af64a26b16842297d5c84a6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_51455b67f5e741b192b5b832aff3d3d3", + "IPY_MODEL_ccd9a65e258244f4bd29c741c3cb4441", + "IPY_MODEL_8b5d5e147c8142328d30e2492b57e2c1" + ], + "layout": "IPY_MODEL_3529aafd3f884efe8ba4bb3e33a05c5c" + } + }, + "f6131d3c810442c5ad4f71177dac1f5e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f9b429aef6904d6995836f1c1279d38d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f9cb515fd148446ca5fcb47fab504a74": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fa845f388d7844a388f079d7cb115824": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "fae0d1779ea045108616dad057ed24e1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_57cb9f2255e4487383465ea0a438a01b", + "placeholder": "​", + "style": "IPY_MODEL_b3d493c1744248fc92301d57dcc07fe9", + "value": "Testing DataLoader 0: 100%" + } + }, + "fb1c1f0130934f24b163d87846f4210c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_642bce0f7ab848459f4b5c91068113d4", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_dc1b070cf7bd43ed8c422e2c1cec6438", + "value": 9 + } + }, + "fb7a0d9d587e4806a7519d9201f7f0aa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From 6364780ae121298e3d98a2c14c6f6747bf62a7b4 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 6 Apr 2025 19:34:57 +0530 Subject: [PATCH 05/43] update docstring --- pytorch_forecasting/data/timeseries.py | 44 +++++++++++++++----------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index bc8300300..9da02d3a0 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2815,25 +2815,31 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: """Get time series data for given index. - It returns: - - * ``t``: ``numpy.ndarray`` of shape (n_timepoints,) - Time index for each time point in the past or present. Aligned with ``y``, - and ``x`` not ending in ``f``. - * ``y``: tensor of shape (n_timepoints, n_targets) - Target values for each time point. Rows are time points, aligned with ``t``. - * ``x``: tensor of shape (n_timepoints, n_features) - Features for each time point. Rows are time points, aligned with ``t``. - * ``group``: tensor of shape (n_groups) - Group identifiers for time series instances. - * ``st``: tensor of shape (n_static_features) - Static features. - * ``cutoff_time``: float or ``numpy.float64`` - Cutoff time for the time series instance. - - Optionally, the following str-keyed entry can be included: - - * ``weights``: tensor of shape (n_timepoints), only if weight is not None + Returns + ------- + t : numpy.ndarray of shape (n_timepoints,) + Time index for each time point in the past or present. Aligned with `y`, + and `x` not ending in `f`. + + y : torch.Tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with `t`. + + x : torch.Tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with `t`. + + group : torch.Tensor of shape (n_groups,) + Group identifiers for time series instances. + + st : torch.Tensor of shape (n_static_features,) + Static features. + + cutoff_time : float or numpy.float64 + Cutoff time for the time series instance. + + Other Returns + ------------- + weights : torch.Tensor of shape (n_timepoints,), optional + Only included if weights are not `None`. """ group_id = self._group_ids[index] From 257183ce4d2b1f7fd40c95ecd7dc38c8004a017b Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 11 Apr 2025 01:54:50 +0530 Subject: [PATCH 06/43] update data_module.py --- pytorch_forecasting/data/data_module.py | 160 ++++++++++++------------ 1 file changed, 80 insertions(+), 80 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 2958f1705..c796b85fa 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -1,15 +1,6 @@ -####################################################################################### -# Disclaimer: This data-module is still work in progress and experimental, please -# use with care. This data-module is a basic skeleton of how the data-handling pipeline -# may look like in the future. -# This is D2 layer that will handle the preprocessing and data loaders. -# For now, this pipeline handles the simplest situation: The whole data can be loaded -# into the memory. -####################################################################################### - from typing import Any, Dict, List, Optional, Tuple, Union -from lightning.pytorch import LightningDataModule +from lightning.pytorch import LightningDataModule, LightningModule from sklearn.preprocessing import RobustScaler, StandardScaler import torch from torch.utils.data import DataLoader, Dataset @@ -19,7 +10,11 @@ NaNLabelEncoder, TorchNormalizer, ) -from pytorch_forecasting.data.timeseries import TimeSeries, _coerce_to_dict +from pytorch_forecasting.data.timeseries import ( + TimeSeries, + _coerce_to_dict, + _coerce_to_list, +) NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] @@ -274,7 +269,7 @@ def metadata(self): self._metadata = self._prepare_metadata() return self._metadata - def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: + def _preprocess_data(self, series_idx: torch.Tensor) -> List[Dict[str, Any]]: """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset. Preprocessing steps @@ -288,63 +283,58 @@ def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: TODO: add scalers, target normalizers etc. """ - processed_data = [] + sample = self.time_series_dataset[series_idx] - for idx in indices: - sample = self.time_series_dataset[idx.item()] + target = sample["y"] + features = sample["x"] + times = sample["t"] + cutoff_time = sample["cutoff_time"] - target = sample["y"] - features = sample["x"] - times = sample["t"] - cutoff_time = sample["cutoff_time"] + time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) - time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) - - if isinstance(target, torch.Tensor): - target = target.float() - else: - target = torch.tensor(target, dtype=torch.float32) - - if isinstance(features, torch.Tensor): - features = features.float() - else: - features = torch.tensor(features, dtype=torch.float32) + if isinstance(target, torch.Tensor): + target = target.float() + else: + target = torch.tensor(target, dtype=torch.float32) - # TODO: add scalers, target normalizers etc. + if isinstance(features, torch.Tensor): + features = features.float() + else: + features = torch.tensor(features, dtype=torch.float32) - categorical = ( - features[:, self.categorical_indices] - if self.categorical_indices - else torch.zeros((features.shape[0], 0)) - ) - continuous = ( - features[:, self.continuous_indices] - if self.continuous_indices - else torch.zeros((features.shape[0], 0)) - ) + # TODO: add scalers, target normalizers etc. - processed_data.append( - { - "features": {"categorical": categorical, "continuous": continuous}, - "target": target, - "static": sample.get("st", None), - "group": sample.get("group", torch.tensor([0])), - "length": len(target), - "time_mask": time_mask, - "times": times, - "cutoff_time": cutoff_time, - } - ) + categorical = ( + features[:, self.categorical_indices] + if self.categorical_indices + else torch.zeros((features.shape[0], 0)) + ) + continuous = ( + features[:, self.continuous_indices] + if self.continuous_indices + else torch.zeros((features.shape[0], 0)) + ) - return processed_data + return { + "features": {"categorical": categorical, "continuous": continuous}, + "target": target, + "static": sample.get("st", None), + "group": sample.get("group", torch.tensor([0])), + "length": len(target), + "time_mask": time_mask, + "times": times, + "cutoff_time": cutoff_time, + } class _ProcessedEncoderDecoderDataset(Dataset): """PyTorch Dataset for processed encoder-decoder time series data. Parameters ---------- - processed_data : List[Dict[str, Any]] - List of preprocessed time series samples. + dataset : TimeSeries + The base time series dataset that provides access to raw data and metadata. + data_module : EncoderDecoderTimeSeriesDataModule + The data module handling preprocessing and metadata configuration. windows : List[Tuple[int, int, int, int]] List of window tuples containing (series_idx, start_idx, enc_length, pred_length). @@ -354,11 +344,13 @@ class _ProcessedEncoderDecoderDataset(Dataset): def __init__( self, - processed_data: List[Dict[str, Any]], + dataset: TimeSeries, + data_module: "EncoderDecoderTimeSeriesDataModule", windows: List[Tuple[int, int, int, int]], add_relative_time_idx: bool = False, ): - self.processed_data = processed_data + self.dataset = dataset + self.data_module = data_module self.windows = windows self.add_relative_time_idx = add_relative_time_idx @@ -410,7 +402,7 @@ def __getitem__(self, idx): Target values for the decoder sequence. """ series_idx, start_idx, enc_length, pred_length = self.windows[idx] - data = self.processed_data[series_idx] + data = self.data_module._preprocess_data(series_idx) end_idx = start_idx + enc_length + pred_length encoder_indices = slice(start_idx, start_idx + enc_length) @@ -457,9 +449,7 @@ def __getitem__(self, idx): return x, y - def _create_windows( - self, processed_data: List[Dict[str, Any]] - ) -> List[Tuple[int, int, int, int]]: + def _create_windows(self, indices: torch.Tensor) -> List[Tuple[int, int, int, int]]: """Generate sliding windows for training, validation, and testing. Returns @@ -477,8 +467,10 @@ def _create_windows( """ windows = [] - for idx, data in enumerate(processed_data): - sequence_length = data["length"] + for idx in indices: + series_idx = idx.item() + sample = self.time_series_dataset[series_idx] + sequence_length = len(sample["y"]) if sequence_length < self.max_encoder_length + self.max_prediction_length: continue @@ -503,7 +495,7 @@ def _create_windows( ): windows.append( ( - idx, + series_idx, start_idx, self.max_encoder_length, self.max_prediction_length, @@ -538,33 +530,41 @@ def setup(self, stage: Optional[str] = None): if stage is None or stage == "fit": if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"): - self.train_processed = self._preprocess_data(self._train_indices) - self.val_processed = self._preprocess_data(self._val_indices) - - self.train_windows = self._create_windows(self.train_processed) - self.val_windows = self._create_windows(self.val_processed) + self.train_windows = self._create_windows(self._train_indices) + self.val_windows = self._create_windows(self._val_indices) self.train_dataset = self._ProcessedEncoderDecoderDataset( - self.train_processed, self.train_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.train_windows, + self.add_relative_time_idx, ) self.val_dataset = self._ProcessedEncoderDecoderDataset( - self.val_processed, self.val_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.val_windows, + self.add_relative_time_idx, ) - elif stage is None or stage == "test": + elif stage == "test": if not hasattr(self, "test_dataset"): - self.test_processed = self._preprocess_data(self._test_indices) - self.test_windows = self._create_windows(self.test_processed) - + self.test_windows = self._create_windows(self._test_indices) self.test_dataset = self._ProcessedEncoderDecoderDataset( - self.test_processed, self.test_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.test_windows, + self, + self.add_relative_time_idx, ) elif stage == "predict": predict_indices = torch.arange(len(self.time_series_dataset)) - self.predict_processed = self._preprocess_data(predict_indices) - self.predict_windows = self._create_windows(self.predict_processed) + self.predict_windows = self._create_windows(predict_indices) self.predict_dataset = self._ProcessedEncoderDecoderDataset( - self.predict_processed, self.predict_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.predict_windows, + self, + self.add_relative_time_idx, ) def train_dataloader(self): From 9cdcb195c4c9e3f9b6d0e76ef3b6ed889bc14998 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 11 Apr 2025 01:56:55 +0530 Subject: [PATCH 07/43] update data_module.py --- pytorch_forecasting/data/data_module.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index c796b85fa..9a4a5bf5e 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -553,7 +553,6 @@ def setup(self, stage: Optional[str] = None): self.time_series_dataset, self, self.test_windows, - self, self.add_relative_time_idx, ) elif stage == "predict": @@ -563,7 +562,6 @@ def setup(self, stage: Optional[str] = None): self.time_series_dataset, self, self.predict_windows, - self, self.add_relative_time_idx, ) From ac56d4fd56aeeb1287f162559c67e785de4446f4 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 11 Apr 2025 02:05:58 +0530 Subject: [PATCH 08/43] Add disclaimer --- pytorch_forecasting/data/data_module.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 9a4a5bf5e..b33a11d47 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -1,6 +1,15 @@ +####################################################################################### +# Disclaimer: This data-module is still work in progress and experimental, please +# use with care. This data-module is a basic skeleton of how the data-handling pipeline +# may look like in the future. +# This is D2 layer that will handle the preprocessing and data loaders. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. +####################################################################################### + from typing import Any, Dict, List, Optional, Tuple, Union -from lightning.pytorch import LightningDataModule, LightningModule +from lightning.pytorch import LightningDataModule from sklearn.preprocessing import RobustScaler, StandardScaler import torch from torch.utils.data import DataLoader, Dataset @@ -13,7 +22,6 @@ from pytorch_forecasting.data.timeseries import ( TimeSeries, _coerce_to_dict, - _coerce_to_list, ) NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] From 25bc7ee135584d8fa6c83f5273fbbf04a3775c99 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 11 Apr 2025 02:12:33 +0530 Subject: [PATCH 09/43] update notebook as well --- examples/ptf_V2_example.ipynb | 2110 ++++++++++++++++----------------- 1 file changed, 1055 insertions(+), 1055 deletions(-) diff --git a/examples/ptf_V2_example.ipynb b/examples/ptf_V2_example.ipynb index 031d9d634..0d61de395 100644 --- a/examples/ptf_V2_example.ipynb +++ b/examples/ptf_V2_example.ipynb @@ -8,7 +8,7 @@ "base_uri": "https://localhost:8080/" }, "id": "2630DaOEI4AJ", - "outputId": "96798236-d2f1-4436-c047-49c3771d56c7" + "outputId": "a6f99bf0-957b-431a-f512-6abba6629768" }, "outputs": [ { @@ -30,9 +30,9 @@ " Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)\n", "Requirement already satisfied: packaging<25.0,>=20.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (24.2)\n", "Collecting torchmetrics<3.0,>=0.7.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading torchmetrics-1.7.0-py3-none-any.whl.metadata (21 kB)\n", + " Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)\n", "Requirement already satisfied: tqdm<6.0,>=4.57.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.67.1)\n", - "Requirement already satisfied: typing-extensions<6.0,>=4.4.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.13.0)\n", + "Requirement already satisfied: typing-extensions<6.0,>=4.4.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.13.1)\n", "Collecting pytorch-lightning (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", " Downloading pytorch_lightning-2.5.1-py3-none-any.whl.metadata (20 kB)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2.8.2)\n", @@ -77,39 +77,39 @@ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.2)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (25.3.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.5.0)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.3.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.3.2)\n", "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (0.3.1)\n", "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.18.3)\n", "Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.11/dist-packages (from yarl<2.0,>=1.17.0->aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.10)\n", "Downloading pytorch_forecasting-1.3.0-py3-none-any.whl (197 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m197.7/197.7 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m197.7/197.7 kB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading lightning-2.5.1-py3-none-any.whl (818 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m818.9/818.9 kB\u001b[0m \u001b[31m16.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m818.9/818.9 kB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m4.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m41.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m58.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m29.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m55.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m30.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m26.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m10.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m15.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m7.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m31.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m84.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading lightning_utilities-0.14.3-py3-none-any.whl (28 kB)\n", - "Downloading torchmetrics-1.7.0-py3-none-any.whl (960 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m960.9/960.9 kB\u001b[0m \u001b[31m35.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "Downloading torchmetrics-1.7.1-py3-none-any.whl (961 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m961.5/961.5 kB\u001b[0m \u001b[31m51.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading pytorch_lightning-2.5.1-py3-none-any.whl (822 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.0/823.0 kB\u001b[0m \u001b[31m31.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.0/823.0 kB\u001b[0m \u001b[31m40.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, lightning-utilities, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torchmetrics, pytorch-lightning, lightning, pytorch-forecasting\n", " Attempting uninstall: nvidia-nvjitlink-cu12\n", " Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n", @@ -151,7 +151,7 @@ " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n", " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n", " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n", - "Successfully installed lightning-2.5.1 lightning-utilities-0.14.3 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch-forecasting-1.3.0 pytorch-lightning-2.5.1 torchmetrics-1.7.0\n" + "Successfully installed lightning-2.5.1 lightning-utilities-0.14.3 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch-forecasting-1.3.0 pytorch-lightning-2.5.1 torchmetrics-1.7.1\n" ] } ], @@ -161,7 +161,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 2, "metadata": { "id": "M7PQerTbI_tM" }, @@ -179,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 3, "metadata": { "id": "XmL5ukG9JDTD" }, @@ -443,7 +443,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 12, "metadata": { "id": "0Rw9LgsXJI5V" }, @@ -451,7 +451,7 @@ "source": [ "from typing import Dict, List, Optional, Union\n", "\n", - "from lightning.pytorch import LightningDataModule\n", + "from lightning.pytorch import LightningDataModule, LightningModule\n", "from sklearn.preprocessing import RobustScaler, StandardScaler\n", "import torch\n", "from torch.utils.data import DataLoader, Dataset\n", @@ -715,7 +715,7 @@ " self._metadata = self._prepare_metadata()\n", " return self._metadata\n", "\n", - " def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]:\n", + " def _preprocess_data(self, series_idx: torch.Tensor) -> List[Dict[str, Any]]:\n", " \"\"\"Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset.\n", "\n", " Preprocessing steps\n", @@ -729,55 +729,48 @@ "\n", " TODO: add scalers, target normalizers etc.\n", " \"\"\"\n", - " processed_data = []\n", - "\n", - " for idx in indices:\n", - " sample = self.time_series_dataset[idx.item()]\n", - "\n", - " target = sample[\"y\"]\n", - " features = sample[\"x\"]\n", - " times = sample[\"t\"]\n", - " cutoff_time = sample[\"cutoff_time\"]\n", + " sample = self.time_series_dataset[series_idx]\n", "\n", - " time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool)\n", + " target = sample[\"y\"]\n", + " features = sample[\"x\"]\n", + " times = sample[\"t\"]\n", + " cutoff_time = sample[\"cutoff_time\"]\n", "\n", - " if isinstance(target, torch.Tensor):\n", - " target = target.float()\n", - " else:\n", - " target = torch.tensor(target, dtype=torch.float32)\n", + " time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool)\n", "\n", - " if isinstance(features, torch.Tensor):\n", - " features = features.float()\n", - " else:\n", - " features = torch.tensor(features, dtype=torch.float32)\n", + " if isinstance(target, torch.Tensor):\n", + " target = target.float()\n", + " else:\n", + " target = torch.tensor(target, dtype=torch.float32)\n", "\n", - " # TODO: add scalers, target normalizers etc.\n", + " if isinstance(features, torch.Tensor):\n", + " features = features.float()\n", + " else:\n", + " features = torch.tensor(features, dtype=torch.float32)\n", "\n", - " categorical = (\n", - " features[:, self.categorical_indices]\n", - " if self.categorical_indices\n", - " else torch.zeros((features.shape[0], 0))\n", - " )\n", - " continuous = (\n", - " features[:, self.continuous_indices]\n", - " if self.continuous_indices\n", - " else torch.zeros((features.shape[0], 0))\n", - " )\n", + " # TODO: add scalers, target normalizers etc.\n", "\n", - " processed_data.append(\n", - " {\n", - " \"features\": {\"categorical\": categorical, \"continuous\": continuous},\n", - " \"target\": target,\n", - " \"static\": sample.get(\"st\", None),\n", - " \"group\": sample.get(\"group\", torch.tensor([0])),\n", - " \"length\": len(target),\n", - " \"time_mask\": time_mask,\n", - " \"times\": times,\n", - " \"cutoff_time\": cutoff_time,\n", - " }\n", - " )\n", + " categorical = (\n", + " features[:, self.categorical_indices]\n", + " if self.categorical_indices\n", + " else torch.zeros((features.shape[0], 0))\n", + " )\n", + " continuous = (\n", + " features[:, self.continuous_indices]\n", + " if self.continuous_indices\n", + " else torch.zeros((features.shape[0], 0))\n", + " )\n", "\n", - " return processed_data\n", + " return {\n", + " \"features\": {\"categorical\": categorical, \"continuous\": continuous},\n", + " \"target\": target,\n", + " \"static\": sample.get(\"st\", None),\n", + " \"group\": sample.get(\"group\", torch.tensor([0])),\n", + " \"length\": len(target),\n", + " \"time_mask\": time_mask,\n", + " \"times\": times,\n", + " \"cutoff_time\": cutoff_time,\n", + " }\n", "\n", " class _ProcessedEncoderDecoderDataset(Dataset):\n", " \"\"\"PyTorch Dataset for processed encoder-decoder time series data.\n", @@ -795,11 +788,13 @@ "\n", " def __init__(\n", " self,\n", - " processed_data: List[Dict[str, Any]],\n", + " dataset: TimeSeries,\n", + " data_module: \"EncoderDecoderTimeSeriesDataModule\",\n", " windows: List[Tuple[int, int, int, int]],\n", " add_relative_time_idx: bool = False,\n", " ):\n", - " self.processed_data = processed_data\n", + " self.dataset = dataset\n", + " self.data_module = data_module\n", " self.windows = windows\n", " self.add_relative_time_idx = add_relative_time_idx\n", "\n", @@ -850,8 +845,9 @@ " y : tensor of shape ``(pred_length, n_targets)``\n", " Target values for the decoder sequence.\n", " \"\"\"\n", + "\n", " series_idx, start_idx, enc_length, pred_length = self.windows[idx]\n", - " data = self.processed_data[series_idx]\n", + " data = self.data_module._preprocess_data(series_idx)\n", "\n", " end_idx = start_idx + enc_length + pred_length\n", " encoder_indices = slice(start_idx, start_idx + enc_length)\n", @@ -898,9 +894,7 @@ "\n", " return x, y\n", "\n", - " def _create_windows(\n", - " self, processed_data: List[Dict[str, Any]]\n", - " ) -> List[Tuple[int, int, int, int]]:\n", + " def _create_windows(self, indices: torch.Tensor) -> List[Tuple[int, int, int, int]]:\n", " \"\"\"Generate sliding windows for training, validation, and testing.\n", "\n", " Returns\n", @@ -918,8 +912,10 @@ " \"\"\"\n", " windows = []\n", "\n", - " for idx, data in enumerate(processed_data):\n", - " sequence_length = data[\"length\"]\n", + " for idx in indices:\n", + " series_idx = idx.item()\n", + " sample = self.time_series_dataset[series_idx]\n", + " sequence_length = len(sample[\"y\"])\n", "\n", " if sequence_length < self.max_encoder_length + self.max_prediction_length:\n", " continue\n", @@ -944,7 +940,7 @@ " ):\n", " windows.append(\n", " (\n", - " idx,\n", + " series_idx,\n", " start_idx,\n", " self.max_encoder_length,\n", " self.max_prediction_length,\n", @@ -979,34 +975,39 @@ "\n", " if stage is None or stage == \"fit\":\n", " if not hasattr(self, \"train_dataset\") or not hasattr(self, \"val_dataset\"):\n", - " self.train_processed = self._preprocess_data(self._train_indices)\n", - " self.val_processed = self._preprocess_data(self._val_indices)\n", - "\n", - " self.train_windows = self._create_windows(self.train_processed)\n", - " self.val_windows = self._create_windows(self.val_processed)\n", + " self.train_windows = self._create_windows(self._train_indices)\n", + " self.val_windows = self._create_windows(self._val_indices)\n", "\n", " self.train_dataset = self._ProcessedEncoderDecoderDataset(\n", - " self.train_processed, self.train_windows, self.add_relative_time_idx\n", + " self.time_series_dataset,\n", + " self,\n", + " self.train_windows,\n", + " self.add_relative_time_idx,\n", " )\n", " self.val_dataset = self._ProcessedEncoderDecoderDataset(\n", - " self.val_processed, self.val_windows, self.add_relative_time_idx\n", + " self.time_series_dataset,\n", + " self,\n", + " self.val_windows,\n", + " self.add_relative_time_idx,\n", " )\n", - " # print(self.val_dataset[0])\n", "\n", - " elif stage is None or stage == \"test\":\n", + " elif stage == \"test\":\n", " if not hasattr(self, \"test_dataset\"):\n", - " self.test_processed = self._preprocess_data(self._test_indices)\n", - " self.test_windows = self._create_windows(self.test_processed)\n", - "\n", + " self.test_windows = self._create_windows(self._test_indices)\n", " self.test_dataset = self._ProcessedEncoderDecoderDataset(\n", - " self.test_processed, self.test_windows, self.add_relative_time_idx\n", + " self.time_series_dataset,\n", + " self,\n", + " self.test_windows,\n", + " self.add_relative_time_idx,\n", " )\n", " elif stage == \"predict\":\n", " predict_indices = torch.arange(len(self.time_series_dataset))\n", - " self.predict_processed = self._preprocess_data(predict_indices)\n", - " self.predict_windows = self._create_windows(self.predict_processed)\n", + " self.predict_windows = self._create_windows(predict_indices)\n", " self.predict_dataset = self._ProcessedEncoderDecoderDataset(\n", - " self.predict_processed, self.predict_windows, self.add_relative_time_idx\n", + " self.time_series_dataset,\n", + " self,\n", + " self.predict_windows,\n", + " self.add_relative_time_idx,\n", " )\n", "\n", " def train_dataloader(self):\n", @@ -1076,26 +1077,26 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "WX-FRdusJSVN", - "outputId": "730f0fe2-f5af-4871-859d-9a4043bbeac7" + "outputId": "2b7ae9bd-0bee-4c05-a512-61193b462274" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { - "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6723608094711289,\n \"min\": -1.2295384314749236,\n \"max\": 1.3194322331654313,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.3958947786657995,\n 0.7816648993958805,\n -0.9655256111265276\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6766981377008536,\n \"min\": -1.2295384314749236,\n \"max\": 1.3194322331654313,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.5707668517530808,\n 0.5020485177883972,\n -0.7734543579445009\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n -0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2783997753182645,\n \"min\": 0.011560494046953695,\n \"max\": 0.9996855497257285,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.43066463390546217,\n 0.08751257529405387,\n 0.3593350820130162\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", + "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6729063389612355,\n \"min\": -1.298580096280071,\n \"max\": 1.2952368656723652,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.4192357980400779,\n 0.7249265048690496,\n -0.966494107115657\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6772588242164835,\n \"min\": -1.298580096280071,\n \"max\": 1.2952368656723652,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.6408981347050933,\n 0.758204270990956,\n -0.6996142581496693\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n -0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2897600721149767,\n \"min\": 0.009824875416031609,\n \"max\": 0.9977597461896692,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.18589175791370482,\n 0.8789086949480461,\n 0.010073414332526953\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", "type": "dataframe", "variable_name": "data_df" }, "text/html": [ "\n", - "
\n", + "
\n", "
\n", "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
000-0.0712220.33976301.0000000.6264260
1010.3397630.18934800.9950040.6264260
2020.1893480.67598900.9800670.6264260
3030.6759890.79726100.9553360.6264260
4040.7972610.99501600.9210610.6264260
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "\n", - "
\n", - " \n", - "\n", - "\n", - "\n", - " \n", - "
\n", - "\n", - "
\n", - "
\n" + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting pytorch-forecasting\n", + " Downloading pytorch_forecasting-1.3.0-py3-none-any.whl.metadata (13 kB)\n", + "Requirement already satisfied: numpy<=3.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.0.2)\n", + "Requirement already satisfied: torch!=2.0.1,<3.0.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.6.0+cu124)\n", + "Collecting lightning<3.0.0,>=2.0.0 (from pytorch-forecasting)\n", + " Downloading lightning-2.5.1-py3-none-any.whl.metadata (39 kB)\n", + "Requirement already satisfied: scipy<2.0,>=1.8 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.14.1)\n", + "Requirement already satisfied: pandas<3.0.0,>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.2.2)\n", + "Requirement already satisfied: scikit-learn<2.0,>=1.2 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.6.1)\n", + "Requirement already satisfied: PyYAML<8.0,>=5.4 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.0.2)\n", + "Requirement already satisfied: fsspec<2026.0,>=2022.5.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2025.3.2)\n", + "Collecting lightning-utilities<2.0,>=0.10.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)\n", + "Requirement already satisfied: packaging<25.0,>=20.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (24.2)\n", + "Collecting torchmetrics<3.0,>=0.7.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)\n", + "Requirement already satisfied: tqdm<6.0,>=4.57.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.67.1)\n", + "Requirement already satisfied: typing-extensions<6.0,>=4.4.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.13.1)\n", + "Collecting pytorch-lightning (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading pytorch_lightning-2.5.1-py3-none-any.whl.metadata (20 kB)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", + "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (3.6.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.18.0)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.4.2)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.1.6)\n", + "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-curand-cu12==10.3.5.147 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (0.6.2)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (2.21.5)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (12.4.127)\n", + "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.2.0)\n", + "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.0)\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.11.15)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from lightning-utilities<2.0,>=0.10.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (75.2.0)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas<3.0.0,>=1.3.0->pytorch-forecasting) (1.17.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.0.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2.6.1)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.2)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (25.3.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.5.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.3.2)\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (0.3.1)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.18.3)\n", + "Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.11/dist-packages (from yarl<2.0,>=1.17.0->aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.10)\n", + "Downloading pytorch_forecasting-1.3.0-py3-none-any.whl (197 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m197.7/197.7 kB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading lightning-2.5.1-py3-none-any.whl (818 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m818.9/818.9 kB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m58.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m55.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m26.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m15.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m84.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading lightning_utilities-0.14.3-py3-none-any.whl (28 kB)\n", + "Downloading torchmetrics-1.7.1-py3-none-any.whl (961 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m961.5/961.5 kB\u001b[0m \u001b[31m51.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading pytorch_lightning-2.5.1-py3-none-any.whl (822 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.0/823.0 kB\u001b[0m \u001b[31m40.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, lightning-utilities, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torchmetrics, pytorch-lightning, lightning, pytorch-forecasting\n", + " Attempting uninstall: nvidia-nvjitlink-cu12\n", + " Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n", + " Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n", + " Attempting uninstall: nvidia-curand-cu12\n", + " Found existing installation: nvidia-curand-cu12 10.3.6.82\n", + " Uninstalling nvidia-curand-cu12-10.3.6.82:\n", + " Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n", + " Attempting uninstall: nvidia-cufft-cu12\n", + " Found existing installation: nvidia-cufft-cu12 11.2.3.61\n", + " Uninstalling nvidia-cufft-cu12-11.2.3.61:\n", + " Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n", + " Attempting uninstall: nvidia-cuda-runtime-cu12\n", + " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n", + " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cuda-cupti-cu12\n", + " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cublas-cu12\n", + " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n", + " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n", + " Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n", + " Attempting uninstall: nvidia-cusparse-cu12\n", + " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n", + " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n", + " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n", + " Attempting uninstall: nvidia-cudnn-cu12\n", + " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n", + " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n", + " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n", + " Attempting uninstall: nvidia-cusolver-cu12\n", + " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n", + " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n", + " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n", + "Successfully installed lightning-2.5.1 lightning-utilities-0.14.3 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch-forecasting-1.3.0 pytorch-lightning-2.5.1 torchmetrics-1.7.1\n" + ] + } ], - "text/plain": [ - " series_id time_idx x y category future_known_feature \\\n", - "0 0 0 -0.071222 0.339763 0 1.000000 \n", - "1 0 1 0.339763 0.189348 0 0.995004 \n", - "2 0 2 0.189348 0.675989 0 0.980067 \n", - "3 0 3 0.675989 0.797261 0 0.955336 \n", - "4 0 4 0.797261 0.995016 0 0.921061 \n", - "\n", - " static_feature static_feature_cat \n", - "0 0.626426 0 \n", - "1 0.626426 0 \n", - "2 0.626426 0 \n", - "3 0.626426 0 \n", - "4 0.626426 0 " + "source": [ + "!pip install pytorch-forecasting" ] }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from lightning.pytorch import Trainer\n", - "import pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "\n", - "from pytorch_forecasting.metrics import MAE, SMAPE\n", - "\n", - "num_series = 100\n", - "seq_length = 50\n", - "data_list = []\n", - "for i in range(num_series):\n", - " x = np.arange(seq_length)\n", - " y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)\n", - " category = i % 5\n", - " static_value = np.random.rand()\n", - " for t in range(seq_length - 1):\n", - " data_list.append(\n", - " {\n", - " \"series_id\": i,\n", - " \"time_idx\": t,\n", - " \"x\": y[t],\n", - " \"y\": y[t + 1],\n", - " \"category\": category,\n", - " \"future_known_feature\": np.cos(t / 10),\n", - " \"static_feature\": static_value,\n", - " \"static_feature_cat\": i % 3,\n", - " }\n", - " )\n", - "data_df = pd.DataFrame(data_list)\n", - "data_df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "AxxPHK6AKSD2" - }, - "outputs": [], - "source": [ - "dataset = TimeSeries(\n", - " data=data_df,\n", - " time=\"time_idx\",\n", - " target=\"y\",\n", - " group=[\"series_id\"],\n", - " num=[\"x\", \"future_known_feature\", \"static_feature\"],\n", - " cat=[\"category\", \"static_feature_cat\"],\n", - " known=[\"future_known_feature\"],\n", - " unknown=[\"x\", \"category\"],\n", - " static=[\"static_feature\", \"static_feature_cat\"],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "5U5Lr_ZFKX0s" - }, - "outputs": [], - "source": [ - "data_module = EncoderDecoderTimeSeriesDataModule(\n", - " time_series_dataset=dataset,\n", - " max_encoder_length=30,\n", - " max_prediction_length=1,\n", - " batch_size=32,\n", - " categorical_encoders={\n", - " \"category\": NaNLabelEncoder(add_nan=True),\n", - " \"static_feature_cat\": NaNLabelEncoder(add_nan=True),\n", - " },\n", - " scalers={\n", - " \"x\": StandardScaler(),\n", - " \"future_known_feature\": StandardScaler(),\n", - " \"static_feature\": StandardScaler(),\n", - " },\n", - " target_normalizer=TorchNormalizer(),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "I8NgHxNqK9uV" - }, - "outputs": [], - "source": [ - "from typing import Dict, List, Optional, Union\n", - "\n", - "from lightning.pytorch.utilities.types import STEP_OUTPUT\n", - "import torch\n", - "from torch.optim import Optimizer\n", - "\n", - "\n", - "class BaseModel(LightningModule):\n", - " def __init__(\n", - " self,\n", - " loss: nn.Module,\n", - " logging_metrics: Optional[List[nn.Module]] = None,\n", - " optimizer: Optional[Union[Optimizer, str]] = \"adam\",\n", - " optimizer_params: Optional[Dict] = None,\n", - " lr_scheduler: Optional[str] = None,\n", - " lr_scheduler_params: Optional[Dict] = None,\n", - " ):\n", - " \"\"\"\n", - " Base model for time series forecasting.\n", - "\n", - " Parameters\n", - " ----------\n", - " loss : nn.Module\n", - " Loss function to use for training.\n", - " logging_metrics : Optional[List[nn.Module]], optional\n", - " List of metrics to log during training, validation, and testing.\n", - " optimizer : Optional[Union[Optimizer, str]], optional\n", - " Optimizer to use for training.\n", - " Can be a string (\"adam\", \"sgd\") or an instance of `torch.optim.Optimizer`.\n", - " optimizer_params : Optional[Dict], optional\n", - " Parameters for the optimizer.\n", - " lr_scheduler : Optional[str], optional\n", - " Learning rate scheduler to use.\n", - " Supported values: \"reduce_lr_on_plateau\", \"step_lr\".\n", - " lr_scheduler_params : Optional[Dict], optional\n", - " Parameters for the learning rate scheduler.\n", - " \"\"\"\n", - " super().__init__()\n", - " self.loss = loss\n", - " self.logging_metrics = logging_metrics if logging_metrics is not None else []\n", - " self.optimizer = optimizer\n", - " self.optimizer_params = optimizer_params if optimizer_params is not None else {}\n", - " self.lr_scheduler = lr_scheduler\n", - " self.lr_scheduler_params = (\n", - " lr_scheduler_params if lr_scheduler_params is not None else {}\n", - " )\n", - "\n", - " def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n", - " \"\"\"\n", - " Forward pass of the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " x : Dict[str, torch.Tensor]\n", - " Dictionary containing input tensors\n", - "\n", - " Returns\n", - " -------\n", - " Dict[str, torch.Tensor]\n", - " Dictionary containing output tensors\n", - " \"\"\"\n", - " raise NotImplementedError(\"Forward method must be implemented by subclass.\")\n", - "\n", - " def training_step(\n", - " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", - " ) -> STEP_OUTPUT:\n", - " \"\"\"\n", - " Training step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input and target tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - "\n", - " Returns\n", - " -------\n", - " STEP_OUTPUT\n", - " Dictionary containing the loss and other metrics.\n", - " \"\"\"\n", - " x, y = batch\n", - " y_hat_dict = self(x)\n", - " y_hat = y_hat_dict[\"prediction\"]\n", - " loss = self.loss(y_hat, y)\n", - " self.log(\n", - " \"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True\n", - " )\n", - " self.log_metrics(y_hat, y, prefix=\"train\")\n", - " return {\"loss\": loss}\n", - "\n", - " def validation_step(\n", - " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", - " ) -> STEP_OUTPUT:\n", - " \"\"\"\n", - " Validation step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input and target tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - "\n", - " Returns\n", - " -------\n", - " STEP_OUTPUT\n", - " Dictionary containing the loss and other metrics.\n", - " \"\"\"\n", - " x, y = batch\n", - " y_hat_dict = self(x)\n", - " y_hat = y_hat_dict[\"prediction\"]\n", - " loss = self.loss(y_hat, y)\n", - " self.log(\n", - " \"val_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n", - " )\n", - " self.log_metrics(y_hat, y, prefix=\"val\")\n", - " return {\"val_loss\": loss}\n", - "\n", - " def test_step(\n", - " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", - " ) -> STEP_OUTPUT:\n", - " \"\"\"\n", - " Test step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input and target tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - "\n", - " Returns\n", - " -------\n", - " STEP_OUTPUT\n", - " Dictionary containing the loss and other metrics.\n", - " \"\"\"\n", - " x, y = batch\n", - " y_hat_dict = self(x)\n", - " y_hat = y_hat_dict[\"prediction\"]\n", - " loss = self.loss(y_hat, y)\n", - " self.log(\n", - " \"test_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n", - " )\n", - " self.log_metrics(y_hat, y, prefix=\"test\")\n", - " return {\"test_loss\": loss}\n", - "\n", - " def predict_step(\n", - " self,\n", - " batch: Tuple[Dict[str, torch.Tensor]],\n", - " batch_idx: int,\n", - " dataloader_idx: int = 0,\n", - " ) -> torch.Tensor:\n", - " \"\"\"\n", - " Prediction step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - " dataloader_idx : int\n", - " Index of the dataloader.\n", - "\n", - " Returns\n", - " -------\n", - " torch.Tensor\n", - " Predicted output tensor.\n", - " \"\"\"\n", - " x, _ = batch\n", - " y_hat = self(x)\n", - " return y_hat\n", - "\n", - " def configure_optimizers(self) -> Dict:\n", - " \"\"\"\n", - " Configure the optimizer and learning rate scheduler.\n", - "\n", - " Returns\n", - " -------\n", - " Dict\n", - " Dictionary containing the optimizer and scheduler configuration.\n", - " \"\"\"\n", - " optimizer = self._get_optimizer()\n", - " if self.lr_scheduler is not None:\n", - " scheduler = self._get_scheduler(optimizer)\n", - " if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n", - " return {\n", - " \"optimizer\": optimizer,\n", - " \"lr_scheduler\": {\n", - " \"scheduler\": scheduler,\n", - " \"monitor\": \"val_loss\",\n", - " },\n", - " }\n", - " else:\n", - " return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n", - " return {\"optimizer\": optimizer}\n", - "\n", - " def _get_optimizer(self) -> Optimizer:\n", - " \"\"\"\n", - " Get the optimizer based on the specified optimizer name and parameters.\n", - "\n", - " Returns\n", - " -------\n", - " Optimizer\n", - " The optimizer instance.\n", - " \"\"\"\n", - " if isinstance(self.optimizer, str):\n", - " if self.optimizer.lower() == \"adam\":\n", - " return torch.optim.Adam(self.parameters(), **self.optimizer_params)\n", - " elif self.optimizer.lower() == \"sgd\":\n", - " return torch.optim.SGD(self.parameters(), **self.optimizer_params)\n", - " else:\n", - " raise ValueError(f\"Optimizer {self.optimizer} not supported.\")\n", - " elif isinstance(self.optimizer, Optimizer):\n", - " return self.optimizer\n", - " else:\n", - " raise ValueError(\n", - " \"Optimizer must be either a string or \"\n", - " \"an instance of torch.optim.Optimizer.\"\n", - " )\n", - "\n", - " def _get_scheduler(\n", - " self, optimizer: Optimizer\n", - " ) -> torch.optim.lr_scheduler._LRScheduler:\n", - " \"\"\"\n", - " Get the lr scheduler based on the specified scheduler name and params.\n", - "\n", - " Parameters\n", - " ----------\n", - " optimizer : Optimizer\n", - " The optimizer instance.\n", - "\n", - " Returns\n", - " -------\n", - " torch.optim.lr_scheduler._LRScheduler\n", - " The learning rate scheduler instance.\n", - " \"\"\"\n", - " if self.lr_scheduler.lower() == \"reduce_lr_on_plateau\":\n", - " return torch.optim.lr_scheduler.ReduceLROnPlateau(\n", - " optimizer, **self.lr_scheduler_params\n", - " )\n", - " elif self.lr_scheduler.lower() == \"step_lr\":\n", - " return torch.optim.lr_scheduler.StepLR(\n", - " optimizer, **self.lr_scheduler_params\n", - " )\n", - " else:\n", - " raise ValueError(f\"Scheduler {self.lr_scheduler} not supported.\")\n", - "\n", - " def log_metrics(\n", - " self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = \"val\"\n", - " ) -> None:\n", - " \"\"\"\n", - " Log additional metrics during training, validation, or testing.\n", - "\n", - " Parameters\n", - " ----------\n", - " y_hat : torch.Tensor\n", - " Predicted output tensor.\n", - " y : torch.Tensor\n", - " Target output tensor.\n", - " prefix : str\n", - " Prefix for the logged metrics (e.g., \"train\", \"val\", \"test\").\n", - " \"\"\"\n", - " for metric in self.logging_metrics:\n", - " metric_value = metric(y_hat, y)\n", - " self.log(\n", - " f\"{prefix}_{metric.__class__.__name__}\",\n", - " metric_value,\n", - " on_step=False,\n", - " on_epoch=True,\n", - " prog_bar=True,\n", - " logger=True,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "n5EBeGucK_k0" - }, - "outputs": [], - "source": [ - "from typing import Dict, List, Optional, Tuple, Union\n", - "\n", - "import torch\n", - "import torch.nn as nn\n", - "from torch.optim import Optimizer\n", - "\n", - "\n", - "class TFT(BaseModel):\n", - " def __init__(\n", - " self,\n", - " loss: nn.Module,\n", - " logging_metrics: Optional[List[nn.Module]] = None,\n", - " optimizer: Optional[Union[Optimizer, str]] = \"adam\",\n", - " optimizer_params: Optional[Dict] = None,\n", - " lr_scheduler: Optional[str] = None,\n", - " lr_scheduler_params: Optional[Dict] = None,\n", - " hidden_size: int = 64,\n", - " num_layers: int = 2,\n", - " attention_head_size: int = 4,\n", - " dropout: float = 0.1,\n", - " metadata: Optional[Dict] = None,\n", - " output_size: int = 1,\n", - " ):\n", - " super().__init__(\n", - " loss=loss,\n", - " logging_metrics=logging_metrics,\n", - " optimizer=optimizer,\n", - " optimizer_params=optimizer_params,\n", - " lr_scheduler=lr_scheduler,\n", - " lr_scheduler_params=lr_scheduler_params,\n", - " )\n", - " self.hidden_size = hidden_size\n", - " self.num_layers = num_layers\n", - " self.attention_head_size = attention_head_size\n", - " self.dropout = dropout\n", - " self.metadata = metadata\n", - " self.output_size = output_size\n", - "\n", - " self.max_encoder_length = self.metadata[\"max_encoder_length\"]\n", - " self.max_prediction_length = self.metadata[\"max_prediction_length\"]\n", - " self.encoder_cont = self.metadata[\"encoder_cont\"]\n", - " self.encoder_cat = self.metadata[\"encoder_cat\"]\n", - " self.static_categorical_features = self.metadata[\"static_categorical_features\"]\n", - " self.static_continuous_features = self.metadata[\"static_continuous_features\"]\n", - "\n", - " total_feature_size = self.encoder_cont + self.encoder_cat\n", - " total_static_size = (\n", - " self.static_categorical_features + self.static_continuous_features\n", - " )\n", - "\n", - " self.encoder_var_selection = nn.Sequential(\n", - " nn.Linear(total_feature_size, hidden_size),\n", - " nn.ReLU(),\n", - " nn.Linear(hidden_size, total_feature_size),\n", - " nn.Sigmoid(),\n", - " )\n", - "\n", - " self.decoder_var_selection = nn.Sequential(\n", - " nn.Linear(total_feature_size, hidden_size),\n", - " nn.ReLU(),\n", - " nn.Linear(hidden_size, total_feature_size),\n", - " nn.Sigmoid(),\n", - " )\n", - "\n", - " self.static_context_linear = (\n", - " nn.Linear(total_static_size, hidden_size) if total_static_size > 0 else None\n", - " )\n", - "\n", - " self.lstm_encoder = nn.LSTM(\n", - " input_size=total_feature_size,\n", - " hidden_size=hidden_size,\n", - " num_layers=num_layers,\n", - " dropout=dropout,\n", - " batch_first=True,\n", - " )\n", - "\n", - " self.lstm_decoder = nn.LSTM(\n", - " input_size=total_feature_size,\n", - " hidden_size=hidden_size,\n", - " num_layers=num_layers,\n", - " dropout=dropout,\n", - " batch_first=True,\n", - " )\n", - "\n", - " self.self_attention = nn.MultiheadAttention(\n", - " embed_dim=hidden_size,\n", - " num_heads=attention_head_size,\n", - " dropout=dropout,\n", - " batch_first=True,\n", - " )\n", - "\n", - " self.pre_output = nn.Linear(hidden_size, hidden_size)\n", - " self.output_layer = nn.Linear(hidden_size, output_size)\n", - "\n", - " def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n", - " \"\"\"\n", - " Forward pass of the TFT model.\n", - "\n", - " Parameters\n", - " ----------\n", - " x : Dict[str, torch.Tensor]\n", - " Dictionary containing input tensors:\n", - " - encoder_cat: Categorical encoder features\n", - " - encoder_cont: Continuous encoder features\n", - " - decoder_cat: Categorical decoder features\n", - " - decoder_cont: Continuous decoder features\n", - " - static_categorical_features: Static categorical features\n", - " - static_continuous_features: Static continuous features\n", - "\n", - " Returns\n", - " -------\n", - " Dict[str, torch.Tensor]\n", - " Dictionary containing output tensors:\n", - " - prediction: Prediction output (batch_size, prediction_length, output_size)\n", - " \"\"\"\n", - " batch_size = x[\"encoder_cont\"].shape[0]\n", - "\n", - " encoder_cat = x.get(\n", - " \"encoder_cat\",\n", - " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", - " )\n", - " encoder_cont = x.get(\n", - " \"encoder_cont\",\n", - " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", - " )\n", - " decoder_cat = x.get(\n", - " \"decoder_cat\",\n", - " torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", - " )\n", - " decoder_cont = x.get(\n", - " \"decoder_cont\",\n", - " torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", - " )\n", - "\n", - " encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2)\n", - " decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2)\n", - "\n", - " static_context = None\n", - " if self.static_context_linear is not None:\n", - " static_cat = x.get(\n", - " \"static_categorical_features\",\n", - " torch.zeros(batch_size, 0, device=self.device),\n", - " )\n", - " static_cont = x.get(\n", - " \"static_continuous_features\",\n", - " torch.zeros(batch_size, 0, device=self.device),\n", - " )\n", - "\n", - " if static_cat.size(2) == 0 and static_cont.size(2) == 0:\n", - " static_context = None\n", - " elif static_cat.size(2) == 0:\n", - " static_input = static_cont.to(\n", - " dtype=self.static_context_linear.weight.dtype\n", - " )\n", - " static_context = self.static_context_linear(static_input)\n", - " static_context = static_context.view(batch_size, self.hidden_size)\n", - " elif static_cont.size(2) == 0:\n", - " static_input = static_cat.to(\n", - " dtype=self.static_context_linear.weight.dtype\n", - " )\n", - " static_context = self.static_context_linear(static_input)\n", - " static_context = static_context.view(batch_size, self.hidden_size)\n", - " else:\n", - "\n", - " static_input = torch.cat([static_cont, static_cat], dim=1).to(\n", - " dtype=self.static_context_linear.weight.dtype\n", - " )\n", - " static_context = self.static_context_linear(static_input)\n", - " static_context = static_context.view(batch_size, self.hidden_size)\n", - "\n", - " encoder_weights = self.encoder_var_selection(encoder_input)\n", - " encoder_input = encoder_input * encoder_weights\n", - "\n", - " decoder_weights = self.decoder_var_selection(decoder_input)\n", - " decoder_input = decoder_input * decoder_weights\n", - "\n", - " if static_context is not None:\n", - " encoder_static_context = static_context.unsqueeze(1).expand(\n", - " -1, self.max_encoder_length, -1\n", - " )\n", - " decoder_static_context = static_context.unsqueeze(1).expand(\n", - " -1, self.max_prediction_length, -1\n", - " )\n", - "\n", - " encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n", - " encoder_output = encoder_output + encoder_static_context\n", - " decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n", - " decoder_output = decoder_output + decoder_static_context\n", - " else:\n", - " encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n", - " decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n", - "\n", - " sequence = torch.cat([encoder_output, decoder_output], dim=1)\n", - "\n", - " if static_context is not None:\n", - " expanded_static_context = static_context.unsqueeze(1).expand(\n", - " -1, sequence.size(1), -1\n", - " )\n", - "\n", - " attended_output, _ = self.self_attention(\n", - " sequence + expanded_static_context, sequence, sequence\n", - " )\n", - " else:\n", - " attended_output, _ = self.self_attention(sequence, sequence, sequence)\n", - "\n", - " decoder_attended = attended_output[:, -self.max_prediction_length :, :]\n", - "\n", - " output = nn.functional.relu(self.pre_output(decoder_attended))\n", - " prediction = self.output_layer(output)\n", - "\n", - " return {\"prediction\": prediction}" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000, - "referenced_widgets": [ - "78e4ca50a25146818b654814147ed1ab", - "8cd83f0e9663436b967df6e5d38766ae", - "4c7ee7e81def457aacb0d898f92437c0", - "788def26447b409da872d076b4ba8061", - "407fe1dea38d48c99473c4019089bc90", - "6084560f9ea74d2c9e106a9500bc65eb", - "f44964101be74d03a83601eed71d8d23", - "7c642543b9e446e8acbfc8ffb47510eb", - "445b8b2d05594cf091b5275315795516", - "d3f55c1a7193496dae6d6651a5e1024f", - "c89be46611364ee5abe04a5e2f60a66d", - "e2b1c2ad8a1e4456852ad49086775ba5", - "bf6e9a42c1fb45d1ad77cd21355869b6", - "a19520255a734aa5ae9bde6727c0ea58", - "2379f649cc0649ba99cca78a0d2231d8", - "8e42e30ac1ce4c33acc6b2070bfead23", - "8855eee7f91b4f54b412c7d78212521e", - "146e9284ec9543e0babeb82674d9b9be", - "caced8bb622949c7ad2dc714da459791", - "d178aa86336342aa96c6af585f887a03", - "3a3cfb2dfc094e76906f912562c3f16e", - "6e36dd66970e430f9d9ceffa97abeffb", - "a71969db0c5644efb9161637a12a32d2", - "425461aab83744a4a78e5288dc0f6cff", - "8b804e67c6da46ac845b32c5be62e884", - "9fa24325d05f43a0be3f436c469a4c32", - "b6655e268f754b72b083604a50845afb", - "e44c84c0476a424a9be0c307170813e6", - "4a3eb3595f72469292bd3cfba1a6a077", - "b4aeaefcd9ed4c508b45047df1e8f1f1", - "4fd1738d6c3c402e84620ceae44e87cf", - "003bf82570884cc1aa2dfeb489e1216e", - "00f4b77ed87e491cbfe77c9ec290ac19", - "07e51fe161674795a4be13573c3236dd", - "931eae1ec2834b5d89702f49018260d8", - "5d8f0b2318f849008b8aaa72e3a8fba9", - "d6d0a0893f7b428db860633f37f9f4ab", - "fb7da6100bff4f359705dec33b913c39", - "22dff2b9d78b405db32b95df05046bd9", - "fd3515421fc5438e8894377037c3049f", - "cd50f28b37af46fbac6c8e63656e4384", - "917cd4785e3e4563aab3e21a458ff61d", - "dcfde39e3b7842a6832f32f3efc72fb6", - "628c517725d4470f86e1c949d4321a45", - "3564138e49ce41ffa18e3fd3ab85abb6", - "23f2ad3634f64492b6ab1758c447038d", - "759cb2575bac454ab2e5b4e859cf4e6c", - "da7a77693dcf43d884d2681d3be0831c", - "16238363d4094d68892999adc0b9a0b3", - "188c04e4fb2a4214a0b2d93b42858efc", - "709f4b4d2e3f454eb61f2addc49da64d", - "2f5f5c4f08714d9bb2d48b9891d4969c", - "62161138282642a190fb53dcd61548f1", - "757cea1affed4f7888f3136158a13600", - "0b817bd6c784496594763bee72b78031", - "8144af2dce5044e481c9f4db6f47f1f6", - "af91a519e65946e88e9ad12d924776b1", - "6a4d0f0907ac470e898e70d88f399fcd", - "13f2f627d06d462589ecea442d0d0fd0", - "27e1f807b99f43bc814f20c8e2dc543c", - "3c8961406f994fe4accde9d16cdb6fe6", - "6dc9efabf69e45fbbdc4425fc8ce0c3f", - "77bf17a9e3554cc1beef372c861d5472", - "7f7592567346419db6d24d3c6a70bfc8", - "62b2bfcf51a7466a9d6170b99b6ec793", - "dd7cb24a33194ea99d166085fef36cbf", - "b2297524052f4abca043a42f97c10bc6", - "bbaba48376ed4f8695fdcd77d62614d2", - "0d0effdc3edb47ba9d6c086946ba86d6", - "57d47150e3ce4de4aec78def4295d087", - "a4fd00d13e624d41b2564d403fee0b6c", - "7afadad53d034e0d840a7c4837d42dea", - "83fdd35e50f84646a5b9757148cb97f3", - "18ae939c9c704565bb4937931e326aba", - "8e49dbeadca3499b95b05e2c4211d057", - "1aedd9b0c2f14939b2ad59494e02aed8", - "4b6a957a4fb44af58947b5a1fef7882e", - "9595f2def65a44f1a66965b07ff2a51f", - "70e131c2803d4f6fa59d99b8ec1d5e93", - "6f13514495d642a281ec2e70ff772c76", - "8b22fbff367f480b87331935a62d4cfd", - "becf001d1a4f42e88edd89e0849c008c", - "ab1836ed5b934de6ae2fe761f0be60db", - "d83f83a0be0e42f8b4f6458053286a66", - "96d02a5000f74366ab53ef25ea8d95d9", - "6b631a66a94a4b8ab2c90daf4cf146b8", - "5962ef174f574964a1f403e663e18f80", - "6cc6d52b1d974074adfe12afcc2d3e2c" - ] - }, - "id": "Si7bbZIULBZz", - "outputId": "b5e8d9c9-e1a3-4632-b764-a3b7b1931012" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.\n", - "INFO:lightning.pytorch.utilities.rank_zero:You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.\n", - "INFO: GPU available: False, used: False\n", - "INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False\n", - "INFO: TPU available: False, using: 0 TPU cores\n", - "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", - "INFO: HPU available: False, using: 0 HPUs\n", - "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n", - "INFO: \n", - " | Name | Type | Params | Mode \n", - "---------------------------------------------------------------------\n", - "0 | loss | MSELoss | 0 | train\n", - "1 | encoder_var_selection | Sequential | 709 | train\n", - "2 | decoder_var_selection | Sequential | 709 | train\n", - "3 | static_context_linear | Linear | 192 | train\n", - "4 | lstm_encoder | LSTM | 51.5 K | train\n", - "5 | lstm_decoder | LSTM | 51.5 K | train\n", - "6 | self_attention | MultiheadAttention | 16.6 K | train\n", - "7 | pre_output | Linear | 4.2 K | train\n", - "8 | output_layer | Linear | 65 | train\n", - "---------------------------------------------------------------------\n", - "125 K Trainable params\n", - "0 Non-trainable params\n", - "125 K Total params\n", - "0.502 Total estimated model params size (MB)\n", - "18 Modules in train mode\n", - "0 Modules in eval mode\n", - "INFO:lightning.pytorch.callbacks.model_summary:\n", - " | Name | Type | Params | Mode \n", - "---------------------------------------------------------------------\n", - "0 | loss | MSELoss | 0 | train\n", - "1 | encoder_var_selection | Sequential | 709 | train\n", - "2 | decoder_var_selection | Sequential | 709 | train\n", - "3 | static_context_linear | Linear | 192 | train\n", - "4 | lstm_encoder | LSTM | 51.5 K | train\n", - "5 | lstm_decoder | LSTM | 51.5 K | train\n", - "6 | self_attention | MultiheadAttention | 16.6 K | train\n", - "7 | pre_output | Linear | 4.2 K | train\n", - "8 | output_layer | Linear | 65 | train\n", - "---------------------------------------------------------------------\n", - "125 K Trainable params\n", - "0 Non-trainable params\n", - "125 K Total params\n", - "0.502 Total estimated model params size (MB)\n", - "18 Modules in train mode\n", - "0 Modules in eval mode\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Training model...\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "78e4ca50a25146818b654814147ed1ab", - "version_major": 2, - "version_minor": 0 + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "M7PQerTbI_tM" }, - "text/plain": [ - "Sanity Checking: | | 0/? [00:00 int:\n", + " \"\"\"Return number of time series in the dataset.\"\"\"\n", + " return len(self._group_ids)\n", + "\n", + " def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:\n", + " \"\"\"Get time series data for given index.\n", + "\n", + " Returns\n", + " -------\n", + " t : numpy.ndarray of shape (n_timepoints,)\n", + " Time index for each time point in the past or present. Aligned with `y`,\n", + " and `x` not ending in `f`.\n", + "\n", + " y : torch.Tensor of shape (n_timepoints, n_targets)\n", + " Target values for each time point. Rows are time points, aligned with `t`.\n", + "\n", + " x : torch.Tensor of shape (n_timepoints, n_features)\n", + " Features for each time point. Rows are time points, aligned with `t`.\n", + "\n", + " group : torch.Tensor of shape (n_groups,)\n", + " Group identifiers for time series instances.\n", + "\n", + " st : torch.Tensor of shape (n_static_features,)\n", + " Static features.\n", + "\n", + " cutoff_time : float or numpy.float64\n", + " Cutoff time for the time series instance.\n", + "\n", + " Other Returns\n", + " -------------\n", + " weights : torch.Tensor of shape (n_timepoints,), optional\n", + " Only included if weights are not `None`.\n", + " \"\"\"\n", + " group_id = self._group_ids[index]\n", + "\n", + " if self.group:\n", + " mask = self._groups[group_id]\n", + " data = self.data.loc[mask]\n", + " else:\n", + " data = self.data\n", + "\n", + " cutoff_time = data[self.time].max()\n", + "\n", + " result = {\n", + " \"t\": data[self.time].values,\n", + " \"y\": torch.tensor(data[self.target].values),\n", + " \"x\": torch.tensor(data[self.feature_cols].values),\n", + " \"group\": torch.tensor([hash(str(group_id))]),\n", + " \"st\": torch.tensor(data[self.static].iloc[0].values if self.static else []),\n", + " \"cutoff_time\": cutoff_time,\n", + " }\n", + "\n", + " if self.data_future is not None:\n", + " if self.group:\n", + " future_mask = self.data_future.groupby(self.group).groups[group_id]\n", + " future_data = self.data_future.loc[future_mask]\n", + " else:\n", + " future_data = self.data_future\n", + "\n", + " combined_times = np.concatenate(\n", + " [data[self.time].values, future_data[self.time].values]\n", + " )\n", + " combined_times = np.unique(combined_times)\n", + " combined_times.sort()\n", + "\n", + " num_timepoints = len(combined_times)\n", + " x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan)\n", + " y_merged = np.full((num_timepoints, len(self.target)), np.nan)\n", + "\n", + " current_time_indices = {t: i for i, t in enumerate(combined_times)}\n", + " for i, t in enumerate(data[self.time].values):\n", + " idx = current_time_indices[t]\n", + " x_merged[idx] = data[self.feature_cols].values[i]\n", + " y_merged[idx] = data[self.target].values[i]\n", + "\n", + " for i, t in enumerate(future_data[self.time].values):\n", + " if t in current_time_indices:\n", + " idx = current_time_indices[t]\n", + " for j, col in enumerate(self.known):\n", + " if col in self.feature_cols:\n", + " feature_idx = self.feature_cols.index(col)\n", + " x_merged[idx, feature_idx] = future_data[col].values[i]\n", + "\n", + " result.update(\n", + " {\n", + " \"t\": combined_times,\n", + " \"x\": torch.tensor(x_merged, dtype=torch.float32),\n", + " \"y\": torch.tensor(y_merged, dtype=torch.float32),\n", + " }\n", + " )\n", + "\n", + " if self.weight:\n", + " if self.data_future is not None and self.weight in self.data_future.columns:\n", + " weights_merged = np.full(num_timepoints, np.nan)\n", + " for i, t in enumerate(data[self.time].values):\n", + " idx = current_time_indices[t]\n", + " weights_merged[idx] = data[self.weight].values[i]\n", + "\n", + " for i, t in enumerate(future_data[self.time].values):\n", + " if t in current_time_indices and self.weight in future_data.columns:\n", + " idx = current_time_indices[t]\n", + " weights_merged[idx] = future_data[self.weight].values[i]\n", + "\n", + " result[\"weights\"] = torch.tensor(weights_merged, dtype=torch.float32)\n", + " else:\n", + " result[\"weights\"] = torch.tensor(\n", + " data[self.weight].values, dtype=torch.float32\n", + " )\n", + "\n", + " return result\n", + "\n", + " def get_metadata(self) -> Dict:\n", + " \"\"\"Return metadata about the dataset.\n", + "\n", + " Returns\n", + " -------\n", + " Dict\n", + " Dictionary containing:\n", + " - cols: column names for y, x, and static features\n", + " - col_type: mapping of columns to their types (F/C)\n", + " - col_known: mapping of columns to their future known status (K/U)\n", + " \"\"\"\n", + " return self.metadata" ] }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "a71969db0c5644efb9161637a12a32d2", - "version_major": 2, - "version_minor": 0 + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0Rw9LgsXJI5V" }, - "text/plain": [ - "Validation: | | 0/? [00:00 List[Dict[str, Any]]:\n", + " \"\"\"Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset.\n", + "\n", + " Preprocessing steps\n", + " --------------------\n", + "\n", + " * Converts target (`y`) and features (`x`) to `torch.float32`.\n", + " * Masks time points that are at or before the cutoff time.\n", + " * Splits features into categorical and continuous subsets based on\n", + " predefined indices.\n", + "\n", + "\n", + " TODO: add scalers, target normalizers etc.\n", + " \"\"\"\n", + " sample = self.time_series_dataset[series_idx]\n", + "\n", + " target = sample[\"y\"]\n", + " features = sample[\"x\"]\n", + " times = sample[\"t\"]\n", + " cutoff_time = sample[\"cutoff_time\"]\n", + "\n", + " time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool)\n", + "\n", + " if isinstance(target, torch.Tensor):\n", + " target = target.float()\n", + " else:\n", + " target = torch.tensor(target, dtype=torch.float32)\n", + "\n", + " if isinstance(features, torch.Tensor):\n", + " features = features.float()\n", + " else:\n", + " features = torch.tensor(features, dtype=torch.float32)\n", + "\n", + " # TODO: add scalers, target normalizers etc.\n", + "\n", + " categorical = (\n", + " features[:, self.categorical_indices]\n", + " if self.categorical_indices\n", + " else torch.zeros((features.shape[0], 0))\n", + " )\n", + " continuous = (\n", + " features[:, self.continuous_indices]\n", + " if self.continuous_indices\n", + " else torch.zeros((features.shape[0], 0))\n", + " )\n", + "\n", + " return {\n", + " \"features\": {\"categorical\": categorical, \"continuous\": continuous},\n", + " \"target\": target,\n", + " \"static\": sample.get(\"st\", None),\n", + " \"group\": sample.get(\"group\", torch.tensor([0])),\n", + " \"length\": len(target),\n", + " \"time_mask\": time_mask,\n", + " \"times\": times,\n", + " \"cutoff_time\": cutoff_time,\n", + " }\n", + "\n", + " class _ProcessedEncoderDecoderDataset(Dataset):\n", + " \"\"\"PyTorch Dataset for processed encoder-decoder time series data.\n", + "\n", + " Parameters\n", + " ----------\n", + " dataset : TimeSeries\n", + " The base time series dataset that provides access to raw data and metadata.\n", + " data_module : EncoderDecoderTimeSeriesDataModule\n", + " The data module handling preprocessing and metadata configuration.\n", + " windows : List[Tuple[int, int, int, int]]\n", + " List of window tuples containing\n", + " (series_idx, start_idx, enc_length, pred_length).\n", + " add_relative_time_idx : bool, default=False\n", + " Whether to include relative time indices.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " dataset: TimeSeries,\n", + " data_module: \"EncoderDecoderTimeSeriesDataModule\",\n", + " windows: List[Tuple[int, int, int, int]],\n", + " add_relative_time_idx: bool = False,\n", + " ):\n", + " self.dataset = dataset\n", + " self.data_module = data_module\n", + " self.windows = windows\n", + " self.add_relative_time_idx = add_relative_time_idx\n", + "\n", + " def __len__(self):\n", + " return len(self.windows)\n", + "\n", + " def __getitem__(self, idx):\n", + " \"\"\"Retrieve a processed time series window for dataloader input.\n", + "\n", + " x : dict\n", + " Dictionary containing model inputs:\n", + "\n", + " * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features)\n", + " Categorical features for the encoder.\n", + " * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features)\n", + " Continuous features for the encoder.\n", + " * ``decoder_cat`` : tensor of shape (pred_length, n_cat_features)\n", + " Categorical features for the decoder.\n", + " * ``decoder_cont`` : tensor of shape (pred_length, n_cont_features)\n", + " Continuous features for the decoder.\n", + " * ``encoder_lengths`` : tensor of shape (1,)\n", + " Length of the encoder sequence.\n", + " * ``decoder_lengths`` : tensor of shape (1,)\n", + " Length of the decoder sequence.\n", + " * ``decoder_target_lengths`` : tensor of shape (1,)\n", + " Length of the decoder target sequence.\n", + " * ``groups`` : tensor of shape (1,)\n", + " Group identifier for the time series instance.\n", + " * ``encoder_time_idx`` : tensor of shape (enc_length,)\n", + " Time indices for the encoder sequence.\n", + " * ``decoder_time_idx`` : tensor of shape (pred_length,)\n", + " Time indices for the decoder sequence.\n", + " * ``target_scale`` : tensor of shape (1,)\n", + " Scaling factor for the target values.\n", + " * ``encoder_mask`` : tensor of shape (enc_length,)\n", + " Boolean mask indicating valid encoder time points.\n", + " * ``decoder_mask`` : tensor of shape (pred_length,)\n", + " Boolean mask indicating valid decoder time points.\n", + "\n", + " If static features are present, the following keys are added:\n", + "\n", + " * ``static_categorical_features`` : tensor of shape\n", + " (1, n_static_cat_features), optional\n", + " Static categorical features, if available.\n", + " * ``static_continuous_features`` : tensor of shape (1, 0), optional\n", + " Placeholder for static continuous features (currently empty).\n", + "\n", + " y : tensor of shape ``(pred_length, n_targets)``\n", + " Target values for the decoder sequence.\n", + " \"\"\"\n", + " series_idx, start_idx, enc_length, pred_length = self.windows[idx]\n", + " data = self.data_module._preprocess_data(series_idx)\n", + "\n", + " end_idx = start_idx + enc_length + pred_length\n", + " encoder_indices = slice(start_idx, start_idx + enc_length)\n", + " decoder_indices = slice(start_idx + enc_length, end_idx)\n", + "\n", + " target_scale = data[\"target\"][encoder_indices]\n", + " target_scale = target_scale[~torch.isnan(target_scale)].abs().mean()\n", + " if torch.isnan(target_scale) or target_scale == 0:\n", + " target_scale = torch.tensor(1.0)\n", + "\n", + " encoder_mask = (\n", + " data[\"time_mask\"][encoder_indices]\n", + " if \"time_mask\" in data\n", + " else torch.ones(enc_length, dtype=torch.bool)\n", + " )\n", + " decoder_mask = (\n", + " data[\"time_mask\"][decoder_indices]\n", + " if \"time_mask\" in data\n", + " else torch.zeros(pred_length, dtype=torch.bool)\n", + " )\n", + "\n", + " x = {\n", + " \"encoder_cat\": data[\"features\"][\"categorical\"][encoder_indices],\n", + " \"encoder_cont\": data[\"features\"][\"continuous\"][encoder_indices],\n", + " \"decoder_cat\": data[\"features\"][\"categorical\"][decoder_indices],\n", + " \"decoder_cont\": data[\"features\"][\"continuous\"][decoder_indices],\n", + " \"encoder_lengths\": torch.tensor(enc_length),\n", + " \"decoder_lengths\": torch.tensor(pred_length),\n", + " \"decoder_target_lengths\": torch.tensor(pred_length),\n", + " \"groups\": data[\"group\"],\n", + " \"encoder_time_idx\": torch.arange(enc_length),\n", + " \"decoder_time_idx\": torch.arange(enc_length, enc_length + pred_length),\n", + " \"target_scale\": target_scale,\n", + " \"encoder_mask\": encoder_mask,\n", + " \"decoder_mask\": decoder_mask,\n", + " }\n", + " if data[\"static\"] is not None:\n", + " x[\"static_categorical_features\"] = data[\"static\"].unsqueeze(0)\n", + " x[\"static_continuous_features\"] = torch.zeros((1, 0))\n", + "\n", + " y = data[\"target\"][decoder_indices]\n", + " if y.ndim == 1:\n", + " y = y.unsqueeze(-1)\n", + "\n", + " return x, y\n", + "\n", + " def _create_windows(self, indices: torch.Tensor) -> List[Tuple[int, int, int, int]]:\n", + " \"\"\"Generate sliding windows for training, validation, and testing.\n", + "\n", + " Returns\n", + " -------\n", + " List[Tuple[int, int, int, int]]\n", + " A list of tuples, where each tuple consists of:\n", + " - ``series_idx`` : int\n", + " Index of the time series in `time_series_dataset`.\n", + " - ``start_idx`` : int\n", + " Start index of the encoder window.\n", + " - ``enc_length`` : int\n", + " Length of the encoder input sequence.\n", + " - ``pred_length`` : int\n", + " Length of the decoder output sequence.\n", + " \"\"\"\n", + " windows = []\n", + "\n", + " for idx in indices:\n", + " series_idx = idx.item()\n", + " sample = self.time_series_dataset[series_idx]\n", + " sequence_length = len(sample[\"y\"])\n", + "\n", + " if sequence_length < self.max_encoder_length + self.max_prediction_length:\n", + " continue\n", + "\n", + " effective_min_prediction_idx = (\n", + " self.min_prediction_idx\n", + " if self.min_prediction_idx is not None\n", + " else self.max_encoder_length\n", + " )\n", + "\n", + " max_prediction_idx = sequence_length - self.max_prediction_length + 1\n", + "\n", + " if max_prediction_idx <= effective_min_prediction_idx:\n", + " continue\n", + "\n", + " for start_idx in range(\n", + " 0, max_prediction_idx - effective_min_prediction_idx\n", + " ):\n", + " if (\n", + " start_idx + self.max_encoder_length + self.max_prediction_length\n", + " <= sequence_length\n", + " ):\n", + " windows.append(\n", + " (\n", + " series_idx,\n", + " start_idx,\n", + " self.max_encoder_length,\n", + " self.max_prediction_length,\n", + " )\n", + " )\n", + "\n", + " return windows\n", + "\n", + " def setup(self, stage: Optional[str] = None):\n", + " \"\"\"Prepare the datasets for training, validation, testing, or prediction.\n", + "\n", + " Parameters\n", + " ----------\n", + " stage : Optional[str], default=None\n", + " Specifies the stage of setup. Can be one of:\n", + " - ``\"fit\"`` : Prepares training and validation datasets.\n", + " - ``\"test\"`` : Prepares the test dataset.\n", + " - ``\"predict\"`` : Prepares the dataset for inference.\n", + " - ``None`` : Prepares ``fit`` datasets.\n", + " \"\"\"\n", + " total_series = len(self.time_series_dataset)\n", + " self._split_indices = torch.randperm(total_series)\n", + "\n", + " self._train_size = int(self.train_val_test_split[0] * total_series)\n", + " self._val_size = int(self.train_val_test_split[1] * total_series)\n", + "\n", + " self._train_indices = self._split_indices[: self._train_size]\n", + " self._val_indices = self._split_indices[\n", + " self._train_size : self._train_size + self._val_size\n", + " ]\n", + " self._test_indices = self._split_indices[self._train_size + self._val_size :]\n", + "\n", + " if stage is None or stage == \"fit\":\n", + " if not hasattr(self, \"train_dataset\") or not hasattr(self, \"val_dataset\"):\n", + " self.train_windows = self._create_windows(self._train_indices)\n", + " self.val_windows = self._create_windows(self._val_indices)\n", + "\n", + " self.train_dataset = self._ProcessedEncoderDecoderDataset(\n", + " self.time_series_dataset,\n", + " self,\n", + " self.train_windows,\n", + " self.add_relative_time_idx,\n", + " )\n", + " self.val_dataset = self._ProcessedEncoderDecoderDataset(\n", + " self.time_series_dataset,\n", + " self,\n", + " self.val_windows,\n", + " self.add_relative_time_idx,\n", + " )\n", + "\n", + " elif stage == \"test\":\n", + " if not hasattr(self, \"test_dataset\"):\n", + " self.test_windows = self._create_windows(self._test_indices)\n", + " self.test_dataset = self._ProcessedEncoderDecoderDataset(\n", + " self.time_series_dataset,\n", + " self,\n", + " self.test_windows,\n", + " self.add_relative_time_idx,\n", + " )\n", + " elif stage == \"predict\":\n", + " predict_indices = torch.arange(len(self.time_series_dataset))\n", + " self.predict_windows = self._create_windows(predict_indices)\n", + " self.predict_dataset = self._ProcessedEncoderDecoderDataset(\n", + " self.time_series_dataset,\n", + " self,\n", + " self.predict_windows,\n", + " self.add_relative_time_idx,\n", + " )\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(\n", + " self.train_dataset,\n", + " batch_size=self.batch_size,\n", + " num_workers=self.num_workers,\n", + " shuffle=True,\n", + " collate_fn=self.collate_fn,\n", + " )\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(\n", + " self.val_dataset,\n", + " batch_size=self.batch_size,\n", + " num_workers=self.num_workers,\n", + " collate_fn=self.collate_fn,\n", + " )\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(\n", + " self.test_dataset,\n", + " batch_size=self.batch_size,\n", + " num_workers=self.num_workers,\n", + " collate_fn=self.collate_fn,\n", + " )\n", + "\n", + " def predict_dataloader(self):\n", + " return DataLoader(\n", + " self.predict_dataset,\n", + " batch_size=self.batch_size,\n", + " num_workers=self.num_workers,\n", + " collate_fn=self.collate_fn,\n", + " )\n", + "\n", + " @staticmethod\n", + " def collate_fn(batch):\n", + " x_batch = {\n", + " \"encoder_cat\": torch.stack([x[\"encoder_cat\"] for x, _ in batch]),\n", + " \"encoder_cont\": torch.stack([x[\"encoder_cont\"] for x, _ in batch]),\n", + " \"decoder_cat\": torch.stack([x[\"decoder_cat\"] for x, _ in batch]),\n", + " \"decoder_cont\": torch.stack([x[\"decoder_cont\"] for x, _ in batch]),\n", + " \"encoder_lengths\": torch.stack([x[\"encoder_lengths\"] for x, _ in batch]),\n", + " \"decoder_lengths\": torch.stack([x[\"decoder_lengths\"] for x, _ in batch]),\n", + " \"decoder_target_lengths\": torch.stack(\n", + " [x[\"decoder_target_lengths\"] for x, _ in batch]\n", + " ),\n", + " \"groups\": torch.stack([x[\"groups\"] for x, _ in batch]),\n", + " \"encoder_time_idx\": torch.stack([x[\"encoder_time_idx\"] for x, _ in batch]),\n", + " \"decoder_time_idx\": torch.stack([x[\"decoder_time_idx\"] for x, _ in batch]),\n", + " \"target_scale\": torch.stack([x[\"target_scale\"] for x, _ in batch]),\n", + " \"encoder_mask\": torch.stack([x[\"encoder_mask\"] for x, _ in batch]),\n", + " \"decoder_mask\": torch.stack([x[\"decoder_mask\"] for x, _ in batch]),\n", + " }\n", + "\n", + " if \"static_categorical_features\" in batch[0][0]:\n", + " x_batch[\"static_categorical_features\"] = torch.stack(\n", + " [x[\"static_categorical_features\"] for x, _ in batch]\n", + " )\n", + " x_batch[\"static_continuous_features\"] = torch.stack(\n", + " [x[\"static_continuous_features\"] for x, _ in batch]\n", + " )\n", + "\n", + " y_batch = torch.stack([y for _, y in batch])\n", + " return x_batch, y_batch" ] }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "07e51fe161674795a4be13573c3236dd", - "version_major": 2, - "version_minor": 0 + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "WX-FRdusJSVN", + "outputId": "2b7ae9bd-0bee-4c05-a512-61193b462274" }, - "text/plain": [ - "Validation: | | 0/? [00:00\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
000-0.0712220.33976301.0000000.6264260
1010.3397630.18934800.9950040.6264260
2020.1893480.67598900.9800670.6264260
3030.6759890.79726100.9553360.6264260
4040.7972610.99501600.9210610.6264260
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "text/plain": [ + " series_id time_idx x y category future_known_feature \\\n", + "0 0 0 -0.071222 0.339763 0 1.000000 \n", + "1 0 1 0.339763 0.189348 0 0.995004 \n", + "2 0 2 0.189348 0.675989 0 0.980067 \n", + "3 0 3 0.675989 0.797261 0 0.955336 \n", + "4 0 4 0.797261 0.995016 0 0.921061 \n", + "\n", + " static_feature static_feature_cat \n", + "0 0.626426 0 \n", + "1 0.626426 0 \n", + "2 0.626426 0 \n", + "3 0.626426 0 \n", + "4 0.626426 0 " + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from lightning.pytorch import Trainer\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "\n", + "from pytorch_forecasting.metrics import MAE, SMAPE\n", + "\n", + "num_series = 100\n", + "seq_length = 50\n", + "data_list = []\n", + "for i in range(num_series):\n", + " x = np.arange(seq_length)\n", + " y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)\n", + " category = i % 5\n", + " static_value = np.random.rand()\n", + " for t in range(seq_length - 1):\n", + " data_list.append(\n", + " {\n", + " \"series_id\": i,\n", + " \"time_idx\": t,\n", + " \"x\": y[t],\n", + " \"y\": y[t + 1],\n", + " \"category\": category,\n", + " \"future_known_feature\": np.cos(t / 10),\n", + " \"static_feature\": static_value,\n", + " \"static_feature_cat\": i % 3,\n", + " }\n", + " )\n", + "data_df = pd.DataFrame(data_list)\n", + "data_df.head()" ] }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "3564138e49ce41ffa18e3fd3ab85abb6", - "version_major": 2, - "version_minor": 0 + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "AxxPHK6AKSD2" }, - "text/plain": [ - "Validation: | | 0/? [00:00 Dict[str, torch.Tensor]:\n", + " \"\"\"\n", + " Forward pass of the model.\n", + "\n", + " Parameters\n", + " ----------\n", + " x : Dict[str, torch.Tensor]\n", + " Dictionary containing input tensors\n", + "\n", + " Returns\n", + " -------\n", + " Dict[str, torch.Tensor]\n", + " Dictionary containing output tensors\n", + " \"\"\"\n", + " raise NotImplementedError(\"Forward method must be implemented by subclass.\")\n", + "\n", + " def training_step(\n", + " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", + " ) -> STEP_OUTPUT:\n", + " \"\"\"\n", + " Training step for the model.\n", + "\n", + " Parameters\n", + " ----------\n", + " batch : Tuple[Dict[str, torch.Tensor]]\n", + " Batch of data containing input and target tensors.\n", + " batch_idx : int\n", + " Index of the batch.\n", + "\n", + " Returns\n", + " -------\n", + " STEP_OUTPUT\n", + " Dictionary containing the loss and other metrics.\n", + " \"\"\"\n", + " x, y = batch\n", + " y_hat_dict = self(x)\n", + " y_hat = y_hat_dict[\"prediction\"]\n", + " loss = self.loss(y_hat, y)\n", + " self.log(\n", + " \"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True\n", + " )\n", + " self.log_metrics(y_hat, y, prefix=\"train\")\n", + " return {\"loss\": loss}\n", + "\n", + " def validation_step(\n", + " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", + " ) -> STEP_OUTPUT:\n", + " \"\"\"\n", + " Validation step for the model.\n", + "\n", + " Parameters\n", + " ----------\n", + " batch : Tuple[Dict[str, torch.Tensor]]\n", + " Batch of data containing input and target tensors.\n", + " batch_idx : int\n", + " Index of the batch.\n", + "\n", + " Returns\n", + " -------\n", + " STEP_OUTPUT\n", + " Dictionary containing the loss and other metrics.\n", + " \"\"\"\n", + " x, y = batch\n", + " y_hat_dict = self(x)\n", + " y_hat = y_hat_dict[\"prediction\"]\n", + " loss = self.loss(y_hat, y)\n", + " self.log(\n", + " \"val_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n", + " )\n", + " self.log_metrics(y_hat, y, prefix=\"val\")\n", + " return {\"val_loss\": loss}\n", + "\n", + " def test_step(\n", + " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", + " ) -> STEP_OUTPUT:\n", + " \"\"\"\n", + " Test step for the model.\n", + "\n", + " Parameters\n", + " ----------\n", + " batch : Tuple[Dict[str, torch.Tensor]]\n", + " Batch of data containing input and target tensors.\n", + " batch_idx : int\n", + " Index of the batch.\n", + "\n", + " Returns\n", + " -------\n", + " STEP_OUTPUT\n", + " Dictionary containing the loss and other metrics.\n", + " \"\"\"\n", + " x, y = batch\n", + " y_hat_dict = self(x)\n", + " y_hat = y_hat_dict[\"prediction\"]\n", + " loss = self.loss(y_hat, y)\n", + " self.log(\n", + " \"test_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n", + " )\n", + " self.log_metrics(y_hat, y, prefix=\"test\")\n", + " return {\"test_loss\": loss}\n", + "\n", + " def predict_step(\n", + " self,\n", + " batch: Tuple[Dict[str, torch.Tensor]],\n", + " batch_idx: int,\n", + " dataloader_idx: int = 0,\n", + " ) -> torch.Tensor:\n", + " \"\"\"\n", + " Prediction step for the model.\n", + "\n", + " Parameters\n", + " ----------\n", + " batch : Tuple[Dict[str, torch.Tensor]]\n", + " Batch of data containing input tensors.\n", + " batch_idx : int\n", + " Index of the batch.\n", + " dataloader_idx : int\n", + " Index of the dataloader.\n", + "\n", + " Returns\n", + " -------\n", + " torch.Tensor\n", + " Predicted output tensor.\n", + " \"\"\"\n", + " x, _ = batch\n", + " y_hat = self(x)\n", + " return y_hat\n", + "\n", + " def configure_optimizers(self) -> Dict:\n", + " \"\"\"\n", + " Configure the optimizer and learning rate scheduler.\n", + "\n", + " Returns\n", + " -------\n", + " Dict\n", + " Dictionary containing the optimizer and scheduler configuration.\n", + " \"\"\"\n", + " optimizer = self._get_optimizer()\n", + " if self.lr_scheduler is not None:\n", + " scheduler = self._get_scheduler(optimizer)\n", + " if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n", + " return {\n", + " \"optimizer\": optimizer,\n", + " \"lr_scheduler\": {\n", + " \"scheduler\": scheduler,\n", + " \"monitor\": \"val_loss\",\n", + " },\n", + " }\n", + " else:\n", + " return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n", + " return {\"optimizer\": optimizer}\n", + "\n", + " def _get_optimizer(self) -> Optimizer:\n", + " \"\"\"\n", + " Get the optimizer based on the specified optimizer name and parameters.\n", + "\n", + " Returns\n", + " -------\n", + " Optimizer\n", + " The optimizer instance.\n", + " \"\"\"\n", + " if isinstance(self.optimizer, str):\n", + " if self.optimizer.lower() == \"adam\":\n", + " return torch.optim.Adam(self.parameters(), **self.optimizer_params)\n", + " elif self.optimizer.lower() == \"sgd\":\n", + " return torch.optim.SGD(self.parameters(), **self.optimizer_params)\n", + " else:\n", + " raise ValueError(f\"Optimizer {self.optimizer} not supported.\")\n", + " elif isinstance(self.optimizer, Optimizer):\n", + " return self.optimizer\n", + " else:\n", + " raise ValueError(\n", + " \"Optimizer must be either a string or \"\n", + " \"an instance of torch.optim.Optimizer.\"\n", + " )\n", + "\n", + " def _get_scheduler(\n", + " self, optimizer: Optimizer\n", + " ) -> torch.optim.lr_scheduler._LRScheduler:\n", + " \"\"\"\n", + " Get the lr scheduler based on the specified scheduler name and params.\n", + "\n", + " Parameters\n", + " ----------\n", + " optimizer : Optimizer\n", + " The optimizer instance.\n", + "\n", + " Returns\n", + " -------\n", + " torch.optim.lr_scheduler._LRScheduler\n", + " The learning rate scheduler instance.\n", + " \"\"\"\n", + " if self.lr_scheduler.lower() == \"reduce_lr_on_plateau\":\n", + " return torch.optim.lr_scheduler.ReduceLROnPlateau(\n", + " optimizer, **self.lr_scheduler_params\n", + " )\n", + " elif self.lr_scheduler.lower() == \"step_lr\":\n", + " return torch.optim.lr_scheduler.StepLR(\n", + " optimizer, **self.lr_scheduler_params\n", + " )\n", + " else:\n", + " raise ValueError(f\"Scheduler {self.lr_scheduler} not supported.\")\n", + "\n", + " def log_metrics(\n", + " self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = \"val\"\n", + " ) -> None:\n", + " \"\"\"\n", + " Log additional metrics during training, validation, or testing.\n", + "\n", + " Parameters\n", + " ----------\n", + " y_hat : torch.Tensor\n", + " Predicted output tensor.\n", + " y : torch.Tensor\n", + " Target output tensor.\n", + " prefix : str\n", + " Prefix for the logged metrics (e.g., \"train\", \"val\", \"test\").\n", + " \"\"\"\n", + " for metric in self.logging_metrics:\n", + " metric_value = metric(y_hat, y)\n", + " self.log(\n", + " f\"{prefix}_{metric.__class__.__name__}\",\n", + " metric_value,\n", + " on_step=False,\n", + " on_epoch=True,\n", + " prog_bar=True,\n", + " logger=True,\n", + " )" ] }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO: `Trainer.fit` stopped: `max_epochs=5` reached.\n", - "INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Evaluating model...\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9595f2def65a44f1a66965b07ff2a51f", - "version_major": 2, - "version_minor": 0 + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "n5EBeGucK_k0" }, - "text/plain": [ - "Testing: | | 0/? [00:00 0 else None\n", + " )\n", + "\n", + " self.lstm_encoder = nn.LSTM(\n", + " input_size=total_feature_size,\n", + " hidden_size=hidden_size,\n", + " num_layers=num_layers,\n", + " dropout=dropout,\n", + " batch_first=True,\n", + " )\n", + "\n", + " self.lstm_decoder = nn.LSTM(\n", + " input_size=total_feature_size,\n", + " hidden_size=hidden_size,\n", + " num_layers=num_layers,\n", + " dropout=dropout,\n", + " batch_first=True,\n", + " )\n", + "\n", + " self.self_attention = nn.MultiheadAttention(\n", + " embed_dim=hidden_size,\n", + " num_heads=attention_head_size,\n", + " dropout=dropout,\n", + " batch_first=True,\n", + " )\n", + "\n", + " self.pre_output = nn.Linear(hidden_size, hidden_size)\n", + " self.output_layer = nn.Linear(hidden_size, output_size)\n", + "\n", + " def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n", + " \"\"\"\n", + " Forward pass of the TFT model.\n", + "\n", + " Parameters\n", + " ----------\n", + " x : Dict[str, torch.Tensor]\n", + " Dictionary containing input tensors:\n", + " - encoder_cat: Categorical encoder features\n", + " - encoder_cont: Continuous encoder features\n", + " - decoder_cat: Categorical decoder features\n", + " - decoder_cont: Continuous decoder features\n", + " - static_categorical_features: Static categorical features\n", + " - static_continuous_features: Static continuous features\n", + "\n", + " Returns\n", + " -------\n", + " Dict[str, torch.Tensor]\n", + " Dictionary containing output tensors:\n", + " - prediction: Prediction output (batch_size, prediction_length, output_size)\n", + " \"\"\"\n", + " batch_size = x[\"encoder_cont\"].shape[0]\n", + "\n", + " encoder_cat = x.get(\n", + " \"encoder_cat\",\n", + " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", + " )\n", + " encoder_cont = x.get(\n", + " \"encoder_cont\",\n", + " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", + " )\n", + " decoder_cat = x.get(\n", + " \"decoder_cat\",\n", + " torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", + " )\n", + " decoder_cont = x.get(\n", + " \"decoder_cont\",\n", + " torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", + " )\n", + "\n", + " encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2)\n", + " decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2)\n", + "\n", + " static_context = None\n", + " if self.static_context_linear is not None:\n", + " static_cat = x.get(\n", + " \"static_categorical_features\",\n", + " torch.zeros(batch_size, 0, device=self.device),\n", + " )\n", + " static_cont = x.get(\n", + " \"static_continuous_features\",\n", + " torch.zeros(batch_size, 0, device=self.device),\n", + " )\n", + "\n", + " if static_cat.size(2) == 0 and static_cont.size(2) == 0:\n", + " static_context = None\n", + " elif static_cat.size(2) == 0:\n", + " static_input = static_cont.to(\n", + " dtype=self.static_context_linear.weight.dtype\n", + " )\n", + " static_context = self.static_context_linear(static_input)\n", + " static_context = static_context.view(batch_size, self.hidden_size)\n", + " elif static_cont.size(2) == 0:\n", + " static_input = static_cat.to(\n", + " dtype=self.static_context_linear.weight.dtype\n", + " )\n", + " static_context = self.static_context_linear(static_input)\n", + " static_context = static_context.view(batch_size, self.hidden_size)\n", + " else:\n", + "\n", + " static_input = torch.cat([static_cont, static_cat], dim=1).to(\n", + " dtype=self.static_context_linear.weight.dtype\n", + " )\n", + " static_context = self.static_context_linear(static_input)\n", + " static_context = static_context.view(batch_size, self.hidden_size)\n", + "\n", + " encoder_weights = self.encoder_var_selection(encoder_input)\n", + " encoder_input = encoder_input * encoder_weights\n", + "\n", + " decoder_weights = self.decoder_var_selection(decoder_input)\n", + " decoder_input = decoder_input * decoder_weights\n", + "\n", + " if static_context is not None:\n", + " encoder_static_context = static_context.unsqueeze(1).expand(\n", + " -1, self.max_encoder_length, -1\n", + " )\n", + " decoder_static_context = static_context.unsqueeze(1).expand(\n", + " -1, self.max_prediction_length, -1\n", + " )\n", + "\n", + " encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n", + " encoder_output = encoder_output + encoder_static_context\n", + " decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n", + " decoder_output = decoder_output + decoder_static_context\n", + " else:\n", + " encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n", + " decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n", + "\n", + " sequence = torch.cat([encoder_output, decoder_output], dim=1)\n", + "\n", + " if static_context is not None:\n", + " expanded_static_context = static_context.unsqueeze(1).expand(\n", + " -1, sequence.size(1), -1\n", + " )\n", + "\n", + " attended_output, _ = self.self_attention(\n", + " sequence + expanded_static_context, sequence, sequence\n", + " )\n", + " else:\n", + " attended_output, _ = self.self_attention(sequence, sequence, sequence)\n", + "\n", + " decoder_attended = attended_output[:, -self.max_prediction_length :, :]\n", + "\n", + " output = nn.functional.relu(self.pre_output(decoder_attended))\n", + " prediction = self.output_layer(output)\n", + "\n", + " return {\"prediction\": prediction}" ] }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
-       "┃        Test metric               DataLoader 0        ┃\n",
-       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
-       "│         test_MAE              0.4637378454208374     │\n",
-       "│        test_SMAPE             1.0857858657836914     │\n",
-       "│         test_loss            0.014832879416644573    │\n",
-       "└───────────────────────────┴───────────────────────────┘\n",
-       "
\n" + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "78e4ca50a25146818b654814147ed1ab", + "8cd83f0e9663436b967df6e5d38766ae", + "4c7ee7e81def457aacb0d898f92437c0", + "788def26447b409da872d076b4ba8061", + "407fe1dea38d48c99473c4019089bc90", + "6084560f9ea74d2c9e106a9500bc65eb", + "f44964101be74d03a83601eed71d8d23", + "7c642543b9e446e8acbfc8ffb47510eb", + "445b8b2d05594cf091b5275315795516", + "d3f55c1a7193496dae6d6651a5e1024f", + "c89be46611364ee5abe04a5e2f60a66d", + "e2b1c2ad8a1e4456852ad49086775ba5", + "bf6e9a42c1fb45d1ad77cd21355869b6", + "a19520255a734aa5ae9bde6727c0ea58", + "2379f649cc0649ba99cca78a0d2231d8", + "8e42e30ac1ce4c33acc6b2070bfead23", + "8855eee7f91b4f54b412c7d78212521e", + "146e9284ec9543e0babeb82674d9b9be", + "caced8bb622949c7ad2dc714da459791", + "d178aa86336342aa96c6af585f887a03", + "3a3cfb2dfc094e76906f912562c3f16e", + "6e36dd66970e430f9d9ceffa97abeffb", + "a71969db0c5644efb9161637a12a32d2", + "425461aab83744a4a78e5288dc0f6cff", + "8b804e67c6da46ac845b32c5be62e884", + "9fa24325d05f43a0be3f436c469a4c32", + "b6655e268f754b72b083604a50845afb", + "e44c84c0476a424a9be0c307170813e6", + "4a3eb3595f72469292bd3cfba1a6a077", + "b4aeaefcd9ed4c508b45047df1e8f1f1", + "4fd1738d6c3c402e84620ceae44e87cf", + "003bf82570884cc1aa2dfeb489e1216e", + "00f4b77ed87e491cbfe77c9ec290ac19", + "07e51fe161674795a4be13573c3236dd", + "931eae1ec2834b5d89702f49018260d8", + "5d8f0b2318f849008b8aaa72e3a8fba9", + "d6d0a0893f7b428db860633f37f9f4ab", + "fb7da6100bff4f359705dec33b913c39", + "22dff2b9d78b405db32b95df05046bd9", + "fd3515421fc5438e8894377037c3049f", + "cd50f28b37af46fbac6c8e63656e4384", + "917cd4785e3e4563aab3e21a458ff61d", + "dcfde39e3b7842a6832f32f3efc72fb6", + "628c517725d4470f86e1c949d4321a45", + "3564138e49ce41ffa18e3fd3ab85abb6", + "23f2ad3634f64492b6ab1758c447038d", + "759cb2575bac454ab2e5b4e859cf4e6c", + "da7a77693dcf43d884d2681d3be0831c", + "16238363d4094d68892999adc0b9a0b3", + "188c04e4fb2a4214a0b2d93b42858efc", + "709f4b4d2e3f454eb61f2addc49da64d", + "2f5f5c4f08714d9bb2d48b9891d4969c", + "62161138282642a190fb53dcd61548f1", + "757cea1affed4f7888f3136158a13600", + "0b817bd6c784496594763bee72b78031", + "8144af2dce5044e481c9f4db6f47f1f6", + "af91a519e65946e88e9ad12d924776b1", + "6a4d0f0907ac470e898e70d88f399fcd", + "13f2f627d06d462589ecea442d0d0fd0", + "27e1f807b99f43bc814f20c8e2dc543c", + "3c8961406f994fe4accde9d16cdb6fe6", + "6dc9efabf69e45fbbdc4425fc8ce0c3f", + "77bf17a9e3554cc1beef372c861d5472", + "7f7592567346419db6d24d3c6a70bfc8", + "62b2bfcf51a7466a9d6170b99b6ec793", + "dd7cb24a33194ea99d166085fef36cbf", + "b2297524052f4abca043a42f97c10bc6", + "bbaba48376ed4f8695fdcd77d62614d2", + "0d0effdc3edb47ba9d6c086946ba86d6", + "57d47150e3ce4de4aec78def4295d087", + "a4fd00d13e624d41b2564d403fee0b6c", + "7afadad53d034e0d840a7c4837d42dea", + "83fdd35e50f84646a5b9757148cb97f3", + "18ae939c9c704565bb4937931e326aba", + "8e49dbeadca3499b95b05e2c4211d057", + "1aedd9b0c2f14939b2ad59494e02aed8", + "4b6a957a4fb44af58947b5a1fef7882e", + "9595f2def65a44f1a66965b07ff2a51f", + "70e131c2803d4f6fa59d99b8ec1d5e93", + "6f13514495d642a281ec2e70ff772c76", + "8b22fbff367f480b87331935a62d4cfd", + "becf001d1a4f42e88edd89e0849c008c", + "ab1836ed5b934de6ae2fe761f0be60db", + "d83f83a0be0e42f8b4f6458053286a66", + "96d02a5000f74366ab53ef25ea8d95d9", + "6b631a66a94a4b8ab2c90daf4cf146b8", + "5962ef174f574964a1f403e663e18f80", + "6cc6d52b1d974074adfe12afcc2d3e2c" + ] + }, + "id": "Si7bbZIULBZz", + "outputId": "b5e8d9c9-e1a3-4632-b764-a3b7b1931012" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.\n", + "INFO:lightning.pytorch.utilities.rank_zero:You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.\n", + "INFO: GPU available: False, used: False\n", + "INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False\n", + "INFO: TPU available: False, using: 0 TPU cores\n", + "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", + "INFO: HPU available: False, using: 0 HPUs\n", + "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n", + "INFO: \n", + " | Name | Type | Params | Mode \n", + "---------------------------------------------------------------------\n", + "0 | loss | MSELoss | 0 | train\n", + "1 | encoder_var_selection | Sequential | 709 | train\n", + "2 | decoder_var_selection | Sequential | 709 | train\n", + "3 | static_context_linear | Linear | 192 | train\n", + "4 | lstm_encoder | LSTM | 51.5 K | train\n", + "5 | lstm_decoder | LSTM | 51.5 K | train\n", + "6 | self_attention | MultiheadAttention | 16.6 K | train\n", + "7 | pre_output | Linear | 4.2 K | train\n", + "8 | output_layer | Linear | 65 | train\n", + "---------------------------------------------------------------------\n", + "125 K Trainable params\n", + "0 Non-trainable params\n", + "125 K Total params\n", + "0.502 Total estimated model params size (MB)\n", + "18 Modules in train mode\n", + "0 Modules in eval mode\n", + "INFO:lightning.pytorch.callbacks.model_summary:\n", + " | Name | Type | Params | Mode \n", + "---------------------------------------------------------------------\n", + "0 | loss | MSELoss | 0 | train\n", + "1 | encoder_var_selection | Sequential | 709 | train\n", + "2 | decoder_var_selection | Sequential | 709 | train\n", + "3 | static_context_linear | Linear | 192 | train\n", + "4 | lstm_encoder | LSTM | 51.5 K | train\n", + "5 | lstm_decoder | LSTM | 51.5 K | train\n", + "6 | self_attention | MultiheadAttention | 16.6 K | train\n", + "7 | pre_output | Linear | 4.2 K | train\n", + "8 | output_layer | Linear | 65 | train\n", + "---------------------------------------------------------------------\n", + "125 K Trainable params\n", + "0 Non-trainable params\n", + "125 K Total params\n", + "0.502 Total estimated model params size (MB)\n", + "18 Modules in train mode\n", + "0 Modules in eval mode\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Training model...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "78e4ca50a25146818b654814147ed1ab", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test_MAE 0.4637378454208374 │\n", + "│ test_SMAPE 1.0857858657836914 │\n", + "│ test_loss 0.014832879416644573 │\n", + "└───────────────────────────┴───────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.4637378454208374 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.0857858657836914 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.014832879416644573 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Prediction shape: torch.Size([32, 1, 1])\n", + "First prediction values: [[-0.00597369]]\n", + "First true values: [[-0.09480439]]\n", + "\n", + "TFT model test complete!\n" + ] + } ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.4637378454208374 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.0857858657836914 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.014832879416644573 \u001b[0m\u001b[35m \u001b[0m│\n", - "└───────────────────────────┴───────────────────────────┘\n" + "source": [ + "model = TFT(\n", + " loss=nn.MSELoss(),\n", + " logging_metrics=[MAE(), SMAPE()],\n", + " optimizer=\"adam\",\n", + " optimizer_params={\"lr\": 1e-3},\n", + " lr_scheduler=\"reduce_lr_on_plateau\",\n", + " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", + " hidden_size=64,\n", + " num_layers=2,\n", + " attention_head_size=4,\n", + " dropout=0.1,\n", + " metadata=data_module.metadata,\n", + ")\n", + "\n", + "print(\"\\nTraining model...\")\n", + "trainer = Trainer(max_epochs=5, accelerator=\"auto\", devices=1, enable_progress_bar=True)\n", + "\n", + "trainer.fit(model, data_module)\n", + "\n", + "print(\"\\nEvaluating model...\")\n", + "test_metrics = trainer.test(model, data_module)\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " test_batch = next(iter(data_module.test_dataloader()))\n", + " x_test, y_test = test_batch\n", + " y_pred = model(x_test)\n", + "\n", + " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n", + " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n", + " print(\"First true values:\", y_test[0].cpu().numpy())\n", + "print(\"\\nTFT model test complete!\")" ] }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Prediction shape: torch.Size([32, 1, 1])\n", - "First prediction values: [[-0.00597369]]\n", - "First true values: [[-0.09480439]]\n", - "\n", - "TFT model test complete!\n" - ] - } - ], - "source": [ - "model = TFT(\n", - " loss=nn.MSELoss(),\n", - " logging_metrics=[MAE(), SMAPE()],\n", - " optimizer=\"adam\",\n", - " optimizer_params={\"lr\": 1e-3},\n", - " lr_scheduler=\"reduce_lr_on_plateau\",\n", - " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", - " hidden_size=64,\n", - " num_layers=2,\n", - " attention_head_size=4,\n", - " dropout=0.1,\n", - " metadata=data_module.metadata,\n", - ")\n", - "\n", - "print(\"\\nTraining model...\")\n", - "trainer = Trainer(max_epochs=5, accelerator=\"auto\", devices=1, enable_progress_bar=True)\n", - "\n", - "trainer.fit(model, data_module)\n", - "\n", - "print(\"\\nEvaluating model...\")\n", - "test_metrics = trainer.test(model, data_module)\n", - "\n", - "model.eval()\n", - "with torch.no_grad():\n", - " test_batch = next(iter(data_module.test_dataloader()))\n", - " x_test, y_test = test_batch\n", - " y_pred = model(x_test)\n", - "\n", - " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n", - " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n", - " print(\"First true values:\", y_test[0].cpu().numpy())\n", - "print(\"\\nTFT model test complete!\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zVRwi2MvLGgc" - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "003bf82570884cc1aa2dfeb489e1216e": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "00f4b77ed87e491cbfe77c9ec290ac19": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "07e51fe161674795a4be13573c3236dd": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_931eae1ec2834b5d89702f49018260d8", - "IPY_MODEL_5d8f0b2318f849008b8aaa72e3a8fba9", - "IPY_MODEL_d6d0a0893f7b428db860633f37f9f4ab" - ], - "layout": "IPY_MODEL_fb7da6100bff4f359705dec33b913c39" - } - }, - "0b817bd6c784496594763bee72b78031": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "0d0effdc3edb47ba9d6c086946ba86d6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_18ae939c9c704565bb4937931e326aba", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_8e49dbeadca3499b95b05e2c4211d057", - "value": 9 - } - }, - "13f2f627d06d462589ecea442d0d0fd0": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_62b2bfcf51a7466a9d6170b99b6ec793", - "placeholder": "​", - "style": "IPY_MODEL_dd7cb24a33194ea99d166085fef36cbf", - "value": " 9/9 [00:00<00:00, 11.96it/s]" - } - }, - "146e9284ec9543e0babeb82674d9b9be": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "16238363d4094d68892999adc0b9a0b3": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "188c04e4fb2a4214a0b2d93b42858efc": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "18ae939c9c704565bb4937931e326aba": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "1aedd9b0c2f14939b2ad59494e02aed8": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "22dff2b9d78b405db32b95df05046bd9": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2379f649cc0649ba99cca78a0d2231d8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_3a3cfb2dfc094e76906f912562c3f16e", - "placeholder": "​", - "style": "IPY_MODEL_6e36dd66970e430f9d9ceffa97abeffb", - "value": " 42/42 [00:05<00:00,  7.01it/s, v_num=1, train_loss_step=0.00959, val_loss=0.0166, val_MAE=0.472, val_SMAPE=1.050, train_loss_epoch=0.0169, train_MAE=0.464, train_SMAPE=1.030]" - } - }, - "23f2ad3634f64492b6ab1758c447038d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_188c04e4fb2a4214a0b2d93b42858efc", - "placeholder": "​", - "style": "IPY_MODEL_709f4b4d2e3f454eb61f2addc49da64d", - "value": "Validation DataLoader 0: 100%" - } - }, - "27e1f807b99f43bc814f20c8e2dc543c": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "2f5f5c4f08714d9bb2d48b9891d4969c": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "3564138e49ce41ffa18e3fd3ab85abb6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_23f2ad3634f64492b6ab1758c447038d", - "IPY_MODEL_759cb2575bac454ab2e5b4e859cf4e6c", - "IPY_MODEL_da7a77693dcf43d884d2681d3be0831c" - ], - "layout": "IPY_MODEL_16238363d4094d68892999adc0b9a0b3" - } - }, - "3a3cfb2dfc094e76906f912562c3f16e": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "3c8961406f994fe4accde9d16cdb6fe6": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "407fe1dea38d48c99473c4019089bc90": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "425461aab83744a4a78e5288dc0f6cff": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_e44c84c0476a424a9be0c307170813e6", - "placeholder": "​", - "style": "IPY_MODEL_4a3eb3595f72469292bd3cfba1a6a077", - "value": "Validation DataLoader 0: 100%" - } - }, - "445b8b2d05594cf091b5275315795516": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "4a3eb3595f72469292bd3cfba1a6a077": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "4b6a957a4fb44af58947b5a1fef7882e": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "4c7ee7e81def457aacb0d898f92437c0": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_7c642543b9e446e8acbfc8ffb47510eb", - "max": 2, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_445b8b2d05594cf091b5275315795516", - "value": 2 - } - }, - "4fd1738d6c3c402e84620ceae44e87cf": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "57d47150e3ce4de4aec78def4295d087": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_1aedd9b0c2f14939b2ad59494e02aed8", - "placeholder": "​", - "style": "IPY_MODEL_4b6a957a4fb44af58947b5a1fef7882e", - "value": " 9/9 [00:00<00:00, 12.24it/s]" - } - }, - "5962ef174f574964a1f403e663e18f80": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "5d8f0b2318f849008b8aaa72e3a8fba9": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_cd50f28b37af46fbac6c8e63656e4384", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_917cd4785e3e4563aab3e21a458ff61d", - "value": 9 - } - }, - "6084560f9ea74d2c9e106a9500bc65eb": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "62161138282642a190fb53dcd61548f1": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "628c517725d4470f86e1c949d4321a45": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "62b2bfcf51a7466a9d6170b99b6ec793": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "6a4d0f0907ac470e898e70d88f399fcd": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_77bf17a9e3554cc1beef372c861d5472", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_7f7592567346419db6d24d3c6a70bfc8", - "value": 9 - } - }, - "6b631a66a94a4b8ab2c90daf4cf146b8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "6cc6d52b1d974074adfe12afcc2d3e2c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "6dc9efabf69e45fbbdc4425fc8ce0c3f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "6e36dd66970e430f9d9ceffa97abeffb": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "6f13514495d642a281ec2e70ff772c76": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_96d02a5000f74366ab53ef25ea8d95d9", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_6b631a66a94a4b8ab2c90daf4cf146b8", - "value": 9 - } - }, - "709f4b4d2e3f454eb61f2addc49da64d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "70e131c2803d4f6fa59d99b8ec1d5e93": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_ab1836ed5b934de6ae2fe761f0be60db", - "placeholder": "​", - "style": "IPY_MODEL_d83f83a0be0e42f8b4f6458053286a66", - "value": "Testing DataLoader 0: 100%" - } - }, - "757cea1affed4f7888f3136158a13600": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "759cb2575bac454ab2e5b4e859cf4e6c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_2f5f5c4f08714d9bb2d48b9891d4969c", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_62161138282642a190fb53dcd61548f1", - "value": 9 - } - }, - "77bf17a9e3554cc1beef372c861d5472": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "788def26447b409da872d076b4ba8061": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_d3f55c1a7193496dae6d6651a5e1024f", - "placeholder": "​", - "style": "IPY_MODEL_c89be46611364ee5abe04a5e2f60a66d", - "value": " 2/2 [00:00<00:00, 18.07it/s]" - } - }, - "78e4ca50a25146818b654814147ed1ab": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_8cd83f0e9663436b967df6e5d38766ae", - "IPY_MODEL_4c7ee7e81def457aacb0d898f92437c0", - "IPY_MODEL_788def26447b409da872d076b4ba8061" - ], - "layout": "IPY_MODEL_407fe1dea38d48c99473c4019089bc90" - } - }, - "7afadad53d034e0d840a7c4837d42dea": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "7c642543b9e446e8acbfc8ffb47510eb": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "7f7592567346419db6d24d3c6a70bfc8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "8144af2dce5044e481c9f4db6f47f1f6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_af91a519e65946e88e9ad12d924776b1", - "IPY_MODEL_6a4d0f0907ac470e898e70d88f399fcd", - "IPY_MODEL_13f2f627d06d462589ecea442d0d0fd0" - ], - "layout": "IPY_MODEL_27e1f807b99f43bc814f20c8e2dc543c" - } - }, - "83fdd35e50f84646a5b9757148cb97f3": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "8855eee7f91b4f54b412c7d78212521e": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "8b22fbff367f480b87331935a62d4cfd": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_5962ef174f574964a1f403e663e18f80", - "placeholder": "​", - "style": "IPY_MODEL_6cc6d52b1d974074adfe12afcc2d3e2c", - "value": " 9/9 [00:00<00:00, 12.57it/s]" - } - }, - "8b804e67c6da46ac845b32c5be62e884": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_b4aeaefcd9ed4c508b45047df1e8f1f1", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_4fd1738d6c3c402e84620ceae44e87cf", - "value": 9 - } - }, - "8cd83f0e9663436b967df6e5d38766ae": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_6084560f9ea74d2c9e106a9500bc65eb", - "placeholder": "​", - "style": "IPY_MODEL_f44964101be74d03a83601eed71d8d23", - "value": "Sanity Checking DataLoader 0: 100%" - } - }, - "8e42e30ac1ce4c33acc6b2070bfead23": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": "100%" - } - }, - "8e49dbeadca3499b95b05e2c4211d057": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "917cd4785e3e4563aab3e21a458ff61d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "931eae1ec2834b5d89702f49018260d8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_22dff2b9d78b405db32b95df05046bd9", - "placeholder": "​", - "style": "IPY_MODEL_fd3515421fc5438e8894377037c3049f", - "value": "Validation DataLoader 0: 100%" - } - }, - "9595f2def65a44f1a66965b07ff2a51f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_70e131c2803d4f6fa59d99b8ec1d5e93", - "IPY_MODEL_6f13514495d642a281ec2e70ff772c76", - "IPY_MODEL_8b22fbff367f480b87331935a62d4cfd" - ], - "layout": "IPY_MODEL_becf001d1a4f42e88edd89e0849c008c" - } - }, - "96d02a5000f74366ab53ef25ea8d95d9": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "9fa24325d05f43a0be3f436c469a4c32": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_003bf82570884cc1aa2dfeb489e1216e", - "placeholder": "​", - "style": "IPY_MODEL_00f4b77ed87e491cbfe77c9ec290ac19", - "value": " 9/9 [00:00<00:00, 12.36it/s]" - } - }, - "a19520255a734aa5ae9bde6727c0ea58": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_caced8bb622949c7ad2dc714da459791", - "max": 42, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_d178aa86336342aa96c6af585f887a03", - "value": 42 - } - }, - "a4fd00d13e624d41b2564d403fee0b6c": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "a71969db0c5644efb9161637a12a32d2": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_425461aab83744a4a78e5288dc0f6cff", - "IPY_MODEL_8b804e67c6da46ac845b32c5be62e884", - "IPY_MODEL_9fa24325d05f43a0be3f436c469a4c32" - ], - "layout": "IPY_MODEL_b6655e268f754b72b083604a50845afb" - } - }, - "ab1836ed5b934de6ae2fe761f0be60db": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "af91a519e65946e88e9ad12d924776b1": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_3c8961406f994fe4accde9d16cdb6fe6", - "placeholder": "​", - "style": "IPY_MODEL_6dc9efabf69e45fbbdc4425fc8ce0c3f", - "value": "Validation DataLoader 0: 100%" - } - }, - "b2297524052f4abca043a42f97c10bc6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_bbaba48376ed4f8695fdcd77d62614d2", - "IPY_MODEL_0d0effdc3edb47ba9d6c086946ba86d6", - "IPY_MODEL_57d47150e3ce4de4aec78def4295d087" - ], - "layout": "IPY_MODEL_a4fd00d13e624d41b2564d403fee0b6c" - } - }, - "b4aeaefcd9ed4c508b45047df1e8f1f1": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "b6655e268f754b72b083604a50845afb": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "bbaba48376ed4f8695fdcd77d62614d2": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_7afadad53d034e0d840a7c4837d42dea", - "placeholder": "​", - "style": "IPY_MODEL_83fdd35e50f84646a5b9757148cb97f3", - "value": "Validation DataLoader 0: 100%" - } - }, - "becf001d1a4f42e88edd89e0849c008c": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": "100%" - } - }, - "bf6e9a42c1fb45d1ad77cd21355869b6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_8855eee7f91b4f54b412c7d78212521e", - "placeholder": "​", - "style": "IPY_MODEL_146e9284ec9543e0babeb82674d9b9be", - "value": "Epoch 4: 100%" - } - }, - "c89be46611364ee5abe04a5e2f60a66d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "caced8bb622949c7ad2dc714da459791": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "cd50f28b37af46fbac6c8e63656e4384": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "d178aa86336342aa96c6af585f887a03": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "d3f55c1a7193496dae6d6651a5e1024f": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "d6d0a0893f7b428db860633f37f9f4ab": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_dcfde39e3b7842a6832f32f3efc72fb6", - "placeholder": "​", - "style": "IPY_MODEL_628c517725d4470f86e1c949d4321a45", - "value": " 9/9 [00:00<00:00, 12.22it/s]" - } - }, - "d83f83a0be0e42f8b4f6458053286a66": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "da7a77693dcf43d884d2681d3be0831c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_757cea1affed4f7888f3136158a13600", - "placeholder": "​", - "style": "IPY_MODEL_0b817bd6c784496594763bee72b78031", - "value": " 9/9 [00:00<00:00, 12.37it/s]" - } - }, - "dcfde39e3b7842a6832f32f3efc72fb6": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "dd7cb24a33194ea99d166085fef36cbf": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "e2b1c2ad8a1e4456852ad49086775ba5": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_bf6e9a42c1fb45d1ad77cd21355869b6", - "IPY_MODEL_a19520255a734aa5ae9bde6727c0ea58", - "IPY_MODEL_2379f649cc0649ba99cca78a0d2231d8" - ], - "layout": "IPY_MODEL_8e42e30ac1ce4c33acc6b2070bfead23" - } - }, - "e44c84c0476a424a9be0c307170813e6": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f44964101be74d03a83601eed71d8d23": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zVRwi2MvLGgc" + }, + "outputs": [], + "source": [] } - }, - "fb7da6100bff4f359705dec33b913c39": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" } }, - "fd3515421fc5438e8894377037c3049f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - } + "nbformat": 4, + "nbformat_minor": 0 } - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} From 20aafb749cfebdb1f9789b4dff5120fa8527db74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 30 Apr 2025 18:40:01 +0200 Subject: [PATCH 15/43] refactor file --- pytorch_forecasting/data/__init__.py | 3 +- .../data/timeseries/__init__.py | 9 + .../data/timeseries/_coerce.py | 25 ++ .../_timeseries.py} | 286 +----------------- .../data/timeseries/_timeseries_v2.py | 276 +++++++++++++++++ 5 files changed, 314 insertions(+), 285 deletions(-) create mode 100644 pytorch_forecasting/data/timeseries/__init__.py create mode 100644 pytorch_forecasting/data/timeseries/_coerce.py rename pytorch_forecasting/data/{timeseries.py => timeseries/_timeseries.py} (90%) create mode 100644 pytorch_forecasting/data/timeseries/_timeseries_v2.py diff --git a/pytorch_forecasting/data/__init__.py b/pytorch_forecasting/data/__init__.py index 301c8394d..17be285a0 100644 --- a/pytorch_forecasting/data/__init__.py +++ b/pytorch_forecasting/data/__init__.py @@ -13,10 +13,11 @@ TorchNormalizer, ) from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler -from pytorch_forecasting.data.timeseries import TimeSeriesDataSet +from pytorch_forecasting.data.timeseries import TimeSeries, TimeSeriesDataSet __all__ = [ "TimeSeriesDataSet", + "TimeSeries", "NaNLabelEncoder", "GroupNormalizer", "TorchNormalizer", diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py new file mode 100644 index 000000000..7734cccf2 --- /dev/null +++ b/pytorch_forecasting/data/timeseries/__init__.py @@ -0,0 +1,9 @@ +"""Data loaders for time series data.""" + +from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries +from pytorch_forecasting.data.timeseries._timeseries import TimeSeriesDataSet + +__all__ = [ + "TimeSeriesDataSet", + "TimeSeries", +] diff --git a/pytorch_forecasting/data/timeseries/_coerce.py b/pytorch_forecasting/data/timeseries/_coerce.py new file mode 100644 index 000000000..328431aa8 --- /dev/null +++ b/pytorch_forecasting/data/timeseries/_coerce.py @@ -0,0 +1,25 @@ +"""Coercion functions for various data types.""" + +from copy import deepcopy + + +def _coerce_to_list(obj): + """Coerce object to list. + + None is coerced to empty list, otherwise list constructor is used. + """ + if obj is None: + return [] + if isinstance(obj, str): + return [obj] + return list(obj) + + +def _coerce_to_dict(obj): + """Coerce object to dict. + + None is coerce to empty dict, otherwise deepcopy is used. + """ + if obj is None: + return {} + return deepcopy(obj) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py similarity index 90% rename from pytorch_forecasting/data/timeseries.py rename to pytorch_forecasting/data/timeseries/_timeseries.py index fda08d561..263e0ea3a 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries/_timeseries.py @@ -9,7 +9,7 @@ from copy import copy as _copy, deepcopy from functools import lru_cache import inspect -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Optional, Type, TypeVar, Union import warnings import numpy as np @@ -31,6 +31,7 @@ TorchNormalizer, ) from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler +from pytorch_forecasting.data.timeseries._coerce import _coerce_to_dict, _coerce_to_list from pytorch_forecasting.utils import repr_class from pytorch_forecasting.utils._dependencies import _check_matplotlib @@ -2663,286 +2664,3 @@ def __repr__(self) -> str: attributes=self.get_parameters(), extra_attributes=dict(length=len(self)), ) - - -def _coerce_to_list(obj): - """Coerce object to list. - - None is coerced to empty list, otherwise list constructor is used. - """ - if obj is None: - return [] - if isinstance(obj, str): - return [obj] - return list(obj) - - -def _coerce_to_dict(obj): - """Coerce object to dict. - - None is coerce to empty dict, otherwise deepcopy is used. - """ - if obj is None: - return {} - return deepcopy(obj) - - -####################################################################################### -# Disclaimer: This dataset class is still work in progress and experimental, please -# use with care. This class is a basic skeleton of how the data-handling pipeline may -# look like in the future. -# This is the D1 layer that is a "Raw Dataset Layer" mainly for raw data ingestion -# and turning the data to tensors. -# For now, this pipeline handles the simplest situation: The whole data can be loaded -# into the memory. -####################################################################################### - - -class TimeSeries(Dataset): - """PyTorch Dataset for time series data stored in pandas DataFrame. - - Parameters - ---------- - data : pd.DataFrame - data frame with sequence data. - Column names must all be str, and contain str as referred to below. - data_future : pd.DataFrame, optional, default=None - data frame with future data. - Column names must all be str, and contain str as referred to below. - May contain only columns that are in time, group, weight, known, or static. - time : str, optional, default = first col not in group_ids, weight, target, static. - integer typed column denoting the time index within ``data``. - This column is used to determine the sequence of samples. - If there are no missing observations, - the time index should increase by ``+1`` for each subsequent sample. - The first time_idx for each series does not necessarily - have to be ``0`` but any value is allowed. - target : str or List[str], optional, default = last column (at iloc -1) - column(s) in ``data`` denoting the forecasting target. - Can be categorical or numerical dtype. - group : List[str], optional, default = None - list of column names identifying a time series instance within ``data``. - This means that the ``group`` together uniquely identify an instance, - and ``group`` together with ``time`` uniquely identify a single observation - within a time series instance. - If ``None``, the dataset is assumed to be a single time series. - weight : str, optional, default=None - column name for weights. - If ``None``, it is assumed that there is no weight column. - num : list of str, optional, default = all columns with dtype in "fi" - list of numerical variables in ``data``, - list may also contain list of str, which are then grouped together. - cat : list of str, optional, default = all columns with dtype in "Obc" - list of categorical variables in ``data``, - list may also contain list of str, which are then grouped together - (e.g. useful for product categories). - known : list of str, optional, default = all variables - list of variables that change over time and are known in the future, - list may also contain list of str, which are then grouped together - (e.g. useful for special days or promotion categories). - unknown : list of str, optional, default = no variables - list of variables that are not known in the future, - list may also contain list of str, which are then grouped together - (e.g. useful for weather categories). - static : list of str, optional, default = all variables not in known, unknown - list of variables that do not change over time, - list may also contain list of str, which are then grouped together. - """ - - def __init__( - self, - data: pd.DataFrame, - data_future: Optional[pd.DataFrame] = None, - time: Optional[str] = None, - target: Optional[Union[str, List[str]]] = None, - group: Optional[List[str]] = None, - weight: Optional[str] = None, - num: Optional[List[Union[str, List[str]]]] = None, - cat: Optional[List[Union[str, List[str]]]] = None, - known: Optional[List[Union[str, List[str]]]] = None, - unknown: Optional[List[Union[str, List[str]]]] = None, - static: Optional[List[Union[str, List[str]]]] = None, - ): - - self.data = data - self.data_future = data_future - self.time = time - self.target = _coerce_to_list(target) - self.group = _coerce_to_list(group) - self.weight = weight - self.num = _coerce_to_list(num) - self.cat = _coerce_to_list(cat) - self.known = _coerce_to_list(known) - self.unknown = _coerce_to_list(unknown) - self.static = _coerce_to_list(static) - - self.feature_cols = [ - col - for col in data.columns - if col not in [self.time] + self.group + [self.weight] + self.target - ] - if self.group: - self._groups = self.data.groupby(self.group).groups - self._group_ids = list(self._groups.keys()) - else: - self._groups = {"_single_group": self.data.index} - self._group_ids = ["_single_group"] - - self._prepare_metadata() - - def _prepare_metadata(self): - """Prepare metadata for the dataset. - - The funcion returns metadata that contains: - - * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] } - Names of columns for y, x, and static features. - List elements are in same order as column dimensions. - Columns not appearing are assumed to be named (x0, x1, etc.), - (y0, y1, etc.), (st0, st1, etc.). - * ``col_type``: dict[str, str] - maps column names to data types "F" (numerical) and "C" (categorical). - Column names not occurring are assumed "F". - * ``col_known``: dict[str, str] - maps column names to "K" (future known) or "U" (future unknown). - Column names not occurring are assumed "K". - """ - self.metadata = { - "cols": { - "y": self.target, - "x": self.feature_cols, - "st": self.static, - }, - "col_type": {}, - "col_known": {}, - } - - all_cols = self.target + self.feature_cols + self.static - for col in all_cols: - self.metadata["col_type"][col] = "C" if col in self.cat else "F" - - self.metadata["col_known"][col] = "K" if col in self.known else "U" - - def __len__(self) -> int: - """Return number of time series in the dataset.""" - return len(self._group_ids) - - def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: - """Get time series data for given index. - - Returns - ------- - t : numpy.ndarray of shape (n_timepoints,) - Time index for each time point in the past or present. Aligned with `y`, - and `x` not ending in `f`. - - y : torch.Tensor of shape (n_timepoints, n_targets) - Target values for each time point. Rows are time points, aligned with `t`. - - x : torch.Tensor of shape (n_timepoints, n_features) - Features for each time point. Rows are time points, aligned with `t`. - - group : torch.Tensor of shape (n_groups,) - Group identifiers for time series instances. - - st : torch.Tensor of shape (n_static_features,) - Static features. - - cutoff_time : float or numpy.float64 - Cutoff time for the time series instance. - - Other Returns - ------------- - weights : torch.Tensor of shape (n_timepoints,), optional - Only included if weights are not `None`. - """ - group_id = self._group_ids[index] - - if self.group: - mask = self._groups[group_id] - data = self.data.loc[mask] - else: - data = self.data - - cutoff_time = data[self.time].max() - - result = { - "t": data[self.time].values, - "y": torch.tensor(data[self.target].values), - "x": torch.tensor(data[self.feature_cols].values), - "group": torch.tensor([hash(str(group_id))]), - "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), - "cutoff_time": cutoff_time, - } - - if self.data_future is not None: - if self.group: - future_mask = self.data_future.groupby(self.group).groups[group_id] - future_data = self.data_future.loc[future_mask] - else: - future_data = self.data_future - - combined_times = np.concatenate( - [data[self.time].values, future_data[self.time].values] - ) - combined_times = np.unique(combined_times) - combined_times.sort() - - num_timepoints = len(combined_times) - x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) - y_merged = np.full((num_timepoints, len(self.target)), np.nan) - - current_time_indices = {t: i for i, t in enumerate(combined_times)} - for i, t in enumerate(data[self.time].values): - idx = current_time_indices[t] - x_merged[idx] = data[self.feature_cols].values[i] - y_merged[idx] = data[self.target].values[i] - - for i, t in enumerate(future_data[self.time].values): - if t in current_time_indices: - idx = current_time_indices[t] - for j, col in enumerate(self.known): - if col in self.feature_cols: - feature_idx = self.feature_cols.index(col) - x_merged[idx, feature_idx] = future_data[col].values[i] - - result.update( - { - "t": combined_times, - "x": torch.tensor(x_merged, dtype=torch.float32), - "y": torch.tensor(y_merged, dtype=torch.float32), - } - ) - - if self.weight: - if self.data_future is not None and self.weight in self.data_future.columns: - weights_merged = np.full(num_timepoints, np.nan) - for i, t in enumerate(data[self.time].values): - idx = current_time_indices[t] - weights_merged[idx] = data[self.weight].values[i] - - for i, t in enumerate(future_data[self.time].values): - if t in current_time_indices and self.weight in future_data.columns: - idx = current_time_indices[t] - weights_merged[idx] = future_data[self.weight].values[i] - - result["weights"] = torch.tensor(weights_merged, dtype=torch.float32) - else: - result["weights"] = torch.tensor( - data[self.weight].values, dtype=torch.float32 - ) - - return result - - def get_metadata(self) -> Dict: - """Return metadata about the dataset. - - Returns - ------- - Dict - Dictionary containing: - - cols: column names for y, x, and static features - - col_type: mapping of columns to their types (F/C) - - col_known: mapping of columns to their future known status (K/U) - """ - return self.metadata diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py new file mode 100644 index 000000000..53bf7228d --- /dev/null +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -0,0 +1,276 @@ +""" +Timeseries dataset - v2 prototype. + +Beta version, experimental - use for testing but not in production. +""" + +from typing import Dict, List, Optional, Union +import warnings + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset + +from pytorch_forecasting.data.timeseries._coerce import _coerce_to_list + + +####################################################################################### +# Disclaimer: This dataset class is still work in progress and experimental, please +# use with care. This class is a basic skeleton of how the data-handling pipeline may +# look like in the future. +# This is the D1 layer that is a "Raw Dataset Layer" mainly for raw data ingestion +# and turning the data to tensors. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. +####################################################################################### + + +class TimeSeries(Dataset): + """PyTorch Dataset for time series data stored in pandas DataFrame. + + Parameters + ---------- + data : pd.DataFrame + data frame with sequence data. + Column names must all be str, and contain str as referred to below. + data_future : pd.DataFrame, optional, default=None + data frame with future data. + Column names must all be str, and contain str as referred to below. + May contain only columns that are in time, group, weight, known, or static. + time : str, optional, default = first col not in group_ids, weight, target, static. + integer typed column denoting the time index within ``data``. + This column is used to determine the sequence of samples. + If there are no missing observations, + the time index should increase by ``+1`` for each subsequent sample. + The first time_idx for each series does not necessarily + have to be ``0`` but any value is allowed. + target : str or List[str], optional, default = last column (at iloc -1) + column(s) in ``data`` denoting the forecasting target. + Can be categorical or numerical dtype. + group : List[str], optional, default = None + list of column names identifying a time series instance within ``data``. + This means that the ``group`` together uniquely identify an instance, + and ``group`` together with ``time`` uniquely identify a single observation + within a time series instance. + If ``None``, the dataset is assumed to be a single time series. + weight : str, optional, default=None + column name for weights. + If ``None``, it is assumed that there is no weight column. + num : list of str, optional, default = all columns with dtype in "fi" + list of numerical variables in ``data``, + list may also contain list of str, which are then grouped together. + cat : list of str, optional, default = all columns with dtype in "Obc" + list of categorical variables in ``data``, + list may also contain list of str, which are then grouped together + (e.g. useful for product categories). + known : list of str, optional, default = all variables + list of variables that change over time and are known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for special days or promotion categories). + unknown : list of str, optional, default = no variables + list of variables that are not known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for weather categories). + static : list of str, optional, default = all variables not in known, unknown + list of variables that do not change over time, + list may also contain list of str, which are then grouped together. + """ + + def __init__( + self, + data: pd.DataFrame, + data_future: Optional[pd.DataFrame] = None, + time: Optional[str] = None, + target: Optional[Union[str, List[str]]] = None, + group: Optional[List[str]] = None, + weight: Optional[str] = None, + num: Optional[List[Union[str, List[str]]]] = None, + cat: Optional[List[Union[str, List[str]]]] = None, + known: Optional[List[Union[str, List[str]]]] = None, + unknown: Optional[List[Union[str, List[str]]]] = None, + static: Optional[List[Union[str, List[str]]]] = None, + ): + + self.data = data + self.data_future = data_future + self.time = time + self.target = _coerce_to_list(target) + self.group = _coerce_to_list(group) + self.weight = weight + self.num = _coerce_to_list(num) + self.cat = _coerce_to_list(cat) + self.known = _coerce_to_list(known) + self.unknown = _coerce_to_list(unknown) + self.static = _coerce_to_list(static) + + self.feature_cols = [ + col + for col in data.columns + if col not in [self.time] + self.group + [self.weight] + self.target + ] + if self.group: + self._groups = self.data.groupby(self.group).groups + self._group_ids = list(self._groups.keys()) + else: + self._groups = {"_single_group": self.data.index} + self._group_ids = ["_single_group"] + + self._prepare_metadata() + + def _prepare_metadata(self): + """Prepare metadata for the dataset. + + The funcion returns metadata that contains: + + * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] } + Names of columns for y, x, and static features. + List elements are in same order as column dimensions. + Columns not appearing are assumed to be named (x0, x1, etc.), + (y0, y1, etc.), (st0, st1, etc.). + * ``col_type``: dict[str, str] + maps column names to data types "F" (numerical) and "C" (categorical). + Column names not occurring are assumed "F". + * ``col_known``: dict[str, str] + maps column names to "K" (future known) or "U" (future unknown). + Column names not occurring are assumed "K". + """ + self.metadata = { + "cols": { + "y": self.target, + "x": self.feature_cols, + "st": self.static, + }, + "col_type": {}, + "col_known": {}, + } + + all_cols = self.target + self.feature_cols + self.static + for col in all_cols: + self.metadata["col_type"][col] = "C" if col in self.cat else "F" + + self.metadata["col_known"][col] = "K" if col in self.known else "U" + + def __len__(self) -> int: + """Return number of time series in the dataset.""" + return len(self._group_ids) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get time series data for given index. + + Returns + ------- + t : numpy.ndarray of shape (n_timepoints,) + Time index for each time point in the past or present. Aligned with `y`, + and `x` not ending in `f`. + + y : torch.Tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with `t`. + + x : torch.Tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with `t`. + + group : torch.Tensor of shape (n_groups,) + Group identifiers for time series instances. + + st : torch.Tensor of shape (n_static_features,) + Static features. + + cutoff_time : float or numpy.float64 + Cutoff time for the time series instance. + + Other Returns + ------------- + weights : torch.Tensor of shape (n_timepoints,), optional + Only included if weights are not `None`. + """ + group_id = self._group_ids[index] + + if self.group: + mask = self._groups[group_id] + data = self.data.loc[mask] + else: + data = self.data + + cutoff_time = data[self.time].max() + + result = { + "t": data[self.time].values, + "y": torch.tensor(data[self.target].values), + "x": torch.tensor(data[self.feature_cols].values), + "group": torch.tensor([hash(str(group_id))]), + "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), + "cutoff_time": cutoff_time, + } + + if self.data_future is not None: + if self.group: + future_mask = self.data_future.groupby(self.group).groups[group_id] + future_data = self.data_future.loc[future_mask] + else: + future_data = self.data_future + + combined_times = np.concatenate( + [data[self.time].values, future_data[self.time].values] + ) + combined_times = np.unique(combined_times) + combined_times.sort() + + num_timepoints = len(combined_times) + x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) + y_merged = np.full((num_timepoints, len(self.target)), np.nan) + + current_time_indices = {t: i for i, t in enumerate(combined_times)} + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + x_merged[idx] = data[self.feature_cols].values[i] + y_merged[idx] = data[self.target].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices: + idx = current_time_indices[t] + for j, col in enumerate(self.known): + if col in self.feature_cols: + feature_idx = self.feature_cols.index(col) + x_merged[idx, feature_idx] = future_data[col].values[i] + + result.update( + { + "t": combined_times, + "x": torch.tensor(x_merged, dtype=torch.float32), + "y": torch.tensor(y_merged, dtype=torch.float32), + } + ) + + if self.weight: + if self.data_future is not None and self.weight in self.data_future.columns: + weights_merged = np.full(num_timepoints, np.nan) + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + weights_merged[idx] = data[self.weight].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices and self.weight in future_data.columns: + idx = current_time_indices[t] + weights_merged[idx] = future_data[self.weight].values[i] + + result["weights"] = torch.tensor(weights_merged, dtype=torch.float32) + else: + result["weights"] = torch.tensor( + data[self.weight].values, dtype=torch.float32 + ) + + return result + + def get_metadata(self) -> Dict: + """Return metadata about the dataset. + + Returns + ------- + Dict + Dictionary containing: + - cols: column names for y, x, and static features + - col_type: mapping of columns to their types (F/C) + - col_known: mapping of columns to their future known status (K/U) + """ + return self.metadata From 043820dd3be3041a019fd9cd2cb1e681d25a79a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 30 Apr 2025 18:43:50 +0200 Subject: [PATCH 16/43] warning --- .../data/timeseries/_timeseries_v2.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index 53bf7228d..1c91d2525 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -104,6 +104,18 @@ def __init__( self.unknown = _coerce_to_list(unknown) self.static = _coerce_to_list(static) + warnings.warn( + "TimeSeries is part of an experimental rework of the " + "pytorch-forecasting data layer, " + "scheduled for release with v2.0.0. " + "The API is not stable and may change without prior warning. " + "For beta testing, but not for stable production use. " + "Feedback and suggestions are very welcome in " + "pytorch-forecasting issue 1736, " + "https://github.com/sktime/pytorch-forecasting/issues/1736", + UserWarning, + ) + self.feature_cols = [ col for col in data.columns From 1720a15e9cff3e5c3ebcd0bf3ec03995d068e4b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 May 2025 13:58:09 +0200 Subject: [PATCH 17/43] linting --- pytorch_forecasting/data/timeseries/__init__.py | 2 +- pytorch_forecasting/data/timeseries/_timeseries_v2.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py index 7734cccf2..85973267a 100644 --- a/pytorch_forecasting/data/timeseries/__init__.py +++ b/pytorch_forecasting/data/timeseries/__init__.py @@ -1,7 +1,7 @@ """Data loaders for time series data.""" -from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries from pytorch_forecasting.data.timeseries._timeseries import TimeSeriesDataSet +from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries __all__ = [ "TimeSeriesDataSet", diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index 1c91d2525..76972ab4d 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -14,7 +14,6 @@ from pytorch_forecasting.data.timeseries._coerce import _coerce_to_list - ####################################################################################### # Disclaimer: This dataset class is still work in progress and experimental, please # use with care. This class is a basic skeleton of how the data-handling pipeline may From af44474d16b3fcdf5e99acb4b9d1f7345119d8cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 May 2025 14:21:58 +0200 Subject: [PATCH 18/43] move coercion to utils --- pytorch_forecasting/data/data_module.py | 6 ++---- pytorch_forecasting/{data/timeseries => utils}/_coerce.py | 0 2 files changed, 2 insertions(+), 4 deletions(-) rename pytorch_forecasting/{data/timeseries => utils}/_coerce.py (100%) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 1203e83ac..9d3ebbedb 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -19,10 +19,8 @@ NaNLabelEncoder, TorchNormalizer, ) -from pytorch_forecasting.data.timeseries import ( - TimeSeries, - _coerce_to_dict, -) +from pytorch_forecasting.data.timeseries import TimeSeries +from pytorch_forecasting.utils._coerce import _coerce_to_dict NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] diff --git a/pytorch_forecasting/data/timeseries/_coerce.py b/pytorch_forecasting/utils/_coerce.py similarity index 100% rename from pytorch_forecasting/data/timeseries/_coerce.py rename to pytorch_forecasting/utils/_coerce.py From a3cb8b736b0b134c8faa97f5ef2993deb28fb75b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 May 2025 14:22:18 +0200 Subject: [PATCH 19/43] linting --- pytorch_forecasting/data/timeseries/_timeseries.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/data/timeseries/_timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py index 263e0ea3a..30fe9e0bb 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries.py +++ b/pytorch_forecasting/data/timeseries/_timeseries.py @@ -31,8 +31,8 @@ TorchNormalizer, ) from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler -from pytorch_forecasting.data.timeseries._coerce import _coerce_to_dict, _coerce_to_list from pytorch_forecasting.utils import repr_class +from pytorch_forecasting.utils._coerce import _coerce_to_dict, _coerce_to_list from pytorch_forecasting.utils._dependencies import _check_matplotlib From 75d7fb54d8405ef493197c5a4d2fc86a5e9e9d5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 May 2025 14:25:51 +0200 Subject: [PATCH 20/43] Update _timeseries_v2.py --- pytorch_forecasting/data/timeseries/_timeseries_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index 76972ab4d..afa45725b 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -12,7 +12,7 @@ import torch from torch.utils.data import Dataset -from pytorch_forecasting.data.timeseries._coerce import _coerce_to_list +from pytorch_forecasting.utils._coerce import _coerce_to_list ####################################################################################### # Disclaimer: This dataset class is still work in progress and experimental, please From 1b946e699be9db2e201a2361779a695356a0460b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 May 2025 14:30:13 +0200 Subject: [PATCH 21/43] Update __init__.py --- pytorch_forecasting/data/timeseries/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py index 85973267a..b359a0aa9 100644 --- a/pytorch_forecasting/data/timeseries/__init__.py +++ b/pytorch_forecasting/data/timeseries/__init__.py @@ -1,9 +1,15 @@ """Data loaders for time series data.""" -from pytorch_forecasting.data.timeseries._timeseries import TimeSeriesDataSet +from pytorch_forecasting.data.timeseries._timeseries import ( + _find_end_indices, + check_for_nonfinite, + TimeSeriesDataSet, +) from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries __all__ = [ + "_find_end_indices", + "check_for_nonfinite", "TimeSeriesDataSet", "TimeSeries", ] From 3edb08b7ea1b97d06b47b0ebcc83aaef9bec8083 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 May 2025 14:33:17 +0200 Subject: [PATCH 22/43] Update __init__.py --- pytorch_forecasting/data/timeseries/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py index b359a0aa9..788c08201 100644 --- a/pytorch_forecasting/data/timeseries/__init__.py +++ b/pytorch_forecasting/data/timeseries/__init__.py @@ -1,9 +1,9 @@ """Data loaders for time series data.""" from pytorch_forecasting.data.timeseries._timeseries import ( + TimeSeriesDataSet, _find_end_indices, check_for_nonfinite, - TimeSeriesDataSet, ) from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries From e350291c110f567e69946e0e113f2471b7472738 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 11 May 2025 22:10:01 +0530 Subject: [PATCH 23/43] update tests --- tests/test_data/test_data_module.py | 72 ++++++++++++++--------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index c14e3d8f4..4051b852c 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -9,7 +9,7 @@ @pytest.fixture def sample_timeseries_data(): """Create a sample time series dataset with only numerical values.""" - num_groups = 5 + num_groups = 10 seq_length = 100 groups = [] @@ -128,22 +128,22 @@ def test_metadata_property(data_module): assert metadata["decoder_cont"] == 1 # Only known_future marked as known -# def test_setup(data_module): -# """Test the setup method that prepares the datasets.""" -# data_module.setup(stage="fit") -# print(data_module._val_indices) -# assert hasattr(data_module, "train_dataset") -# assert hasattr(data_module, "val_dataset") -# assert len(data_module.train_windows) > 0 -# assert len(data_module.val_windows) > 0 -# -# data_module.setup(stage="test") -# assert hasattr(data_module, "test_dataset") -# assert len(data_module.test_windows) > 0 -# -# data_module.setup(stage="predict") -# assert hasattr(data_module, "predict_dataset") -# assert len(data_module.predict_windows) > 0 +def test_setup(data_module): + """Test the setup method that prepares the datasets.""" + data_module.setup(stage="fit") + print(data_module._val_indices) + assert hasattr(data_module, "train_dataset") + assert hasattr(data_module, "val_dataset") + assert len(data_module.train_windows) > 0 + assert len(data_module.val_windows) > 0 + + data_module.setup(stage="test") + assert hasattr(data_module, "test_dataset") + assert len(data_module.test_windows) > 0 + + data_module.setup(stage="predict") + assert hasattr(data_module, "predict_dataset") + assert len(data_module.predict_windows) > 0 def test_create_windows(data_module): @@ -407,25 +407,25 @@ def test_with_static_features(): assert "static_continuous_features" in x -# def test_different_train_val_test_split(sample_timeseries_data): -# """Test with different train/val/test split ratios.""" -# dm = EncoderDecoderTimeSeriesDataModule( -# time_series_dataset=sample_timeseries_data, -# max_encoder_length=24, -# max_prediction_length=12, -# batch_size=4, -# train_val_test_split=(0.8, 0.1, 0.1), -# ) -# -# dm.setup() -# -# total_series = len(sample_timeseries_data) -# expected_train = int(0.8 * total_series) -# expected_val = int(0.1 * total_series) -# -# assert len(dm._train_indices) == expected_train -# assert len(dm._val_indices) == expected_val -# assert len(dm._test_indices) == total_series - expected_train - expected_val +def test_different_train_val_test_split(sample_timeseries_data): + """Test with different train/val/test split ratios.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=4, + train_val_test_split=(0.8, 0.1, 0.1), + ) + + dm.setup() + + total_series = len(sample_timeseries_data) + expected_train = int(0.8 * total_series) + expected_val = int(0.1 * total_series) + + assert len(dm._train_indices) == expected_train + assert len(dm._val_indices) == expected_val + assert len(dm._test_indices) == total_series - expected_train - expected_val def test_multivariate_target(): From 3099691d3cc792bd528f50ff3c51a0fa4a9ce28a Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 12 May 2025 00:22:27 +0530 Subject: [PATCH 24/43] update tft_v2 --- .../tft_version_two.py | 65 +++++++++++-------- 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py index 30f70f98e..2bfe407d7 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py @@ -36,6 +36,8 @@ def __init__( lr_scheduler=lr_scheduler, lr_scheduler_params=lr_scheduler_params, ) + self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"]) + self.hidden_size = hidden_size self.num_layers = num_layers self.attention_head_size = attention_head_size @@ -47,42 +49,51 @@ def __init__( self.max_prediction_length = self.metadata["max_prediction_length"] self.encoder_cont = self.metadata["encoder_cont"] self.encoder_cat = self.metadata["encoder_cat"] - self.static_categorical_features = self.metadata["static_categorical_features"] - self.static_continuous_features = self.metadata["static_continuous_features"] - - total_feature_size = self.encoder_cont + self.encoder_cat - total_static_size = ( - self.static_categorical_features + self.static_continuous_features - ) - - self.encoder_var_selection = nn.Sequential( - nn.Linear(total_feature_size, hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, total_feature_size), - nn.Sigmoid(), - ) - - self.decoder_var_selection = nn.Sequential( - nn.Linear(total_feature_size, hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, total_feature_size), - nn.Sigmoid(), - ) + self.encoder_input_dim = self.encoder_cont + self.encoder_cat + self.decoder_cont = self.metadata["decoder_cont"] + self.decoder_cat = self.metadata["decoder_cat"] + self.decoder_input_dim = self.decoder_cont + self.decoder_cat + self.static_cat_dim = self.metadata.get("static_categorical_features", 0) + self.static_cont_dim = self.metadata.get("static_continuous_features", 0) + self.static_input_dim = self.static_cat_dim + self.static_cont_dim + + if self.encoder_input_dim > 0: + self.encoder_var_selection = nn.Sequential( + nn.Linear(self.encoder_input_dim, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, self.encoder_input_dim), + nn.Sigmoid(), + ) + else: + self.encoder_var_selection = None + + if self.decoder_input_dim > 0: + self.decoder_var_selection = nn.Sequential( + nn.Linear(self.decoder_input_dim, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, self.decoder_input_dim), + nn.Sigmoid(), + ) + else: + self.decoder_var_selection = None - self.static_context_linear = ( - nn.Linear(total_static_size, hidden_size) if total_static_size > 0 else None - ) + if self.static_input_dim > 0: + self.static_context_linear = nn.Linear(self.static_input_dim, hidden_size) + else: + self.static_context_linear = None + _lstm_encoder_input_actual_dim = self.encoder_input_dim self.lstm_encoder = nn.LSTM( - input_size=total_feature_size, + input_size=max(1, _lstm_encoder_input_actual_dim), hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, batch_first=True, ) + _lstm_decoder_input_actual_dim = self.decoder_input_dim self.lstm_decoder = nn.LSTM( - input_size=total_feature_size, + input_size=max(1, _lstm_decoder_input_actual_dim), hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, @@ -97,7 +108,7 @@ def __init__( ) self.pre_output = nn.Linear(hidden_size, hidden_size) - self.output_layer = nn.Linear(hidden_size, output_size) + self.output_layer = nn.Linear(hidden_size, self.output_size) def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ From d6e62bbb5da93f0cfc903b1edd9b68a8e2806243 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 12 May 2025 00:51:22 +0530 Subject: [PATCH 25/43] update notebook --- examples/ptf_V2_example.ipynb | 3360 ++++++++++----------------------- 1 file changed, 989 insertions(+), 2371 deletions(-) diff --git a/examples/ptf_V2_example.ipynb b/examples/ptf_V2_example.ipynb index 90d8e93be..2e39108f3 100644 --- a/examples/ptf_V2_example.ipynb +++ b/examples/ptf_V2_example.ipynb @@ -1,2404 +1,1022 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "2630DaOEI4AJ", - "outputId": "a6f99bf0-957b-431a-f512-6abba6629768" + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2630DaOEI4AJ", + "outputId": "b4cc7100-2a9c-41a3-e890-164b90c91c03" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting pytorch-forecasting\n", + " Downloading pytorch_forecasting-1.3.0-py3-none-any.whl.metadata (13 kB)\n", + "Requirement already satisfied: numpy<=3.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.0.2)\n", + "Requirement already satisfied: torch!=2.0.1,<3.0.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.6.0+cu124)\n", + "Collecting lightning<3.0.0,>=2.0.0 (from pytorch-forecasting)\n", + " Downloading lightning-2.5.1.post0-py3-none-any.whl.metadata (39 kB)\n", + "Requirement already satisfied: scipy<2.0,>=1.8 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.15.2)\n", + "Requirement already satisfied: pandas<3.0.0,>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.2.2)\n", + "Requirement already satisfied: scikit-learn<2.0,>=1.2 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.6.1)\n", + "Requirement already satisfied: PyYAML<8.0,>=5.4 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.0.2)\n", + "Requirement already satisfied: fsspec<2026.0,>=2022.5.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2025.3.2)\n", + "Collecting lightning-utilities<2.0,>=0.10.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)\n", + "Requirement already satisfied: packaging<25.0,>=20.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (24.2)\n", + "Collecting torchmetrics<3.0,>=0.7.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)\n", + "Requirement already satisfied: tqdm<6.0,>=4.57.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.67.1)\n", + "Requirement already satisfied: typing-extensions<6.0,>=4.4.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.13.2)\n", + "Collecting pytorch-lightning (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading pytorch_lightning-2.5.1.post0-py3-none-any.whl.metadata (20 kB)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", + "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (3.6.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.18.0)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.4.2)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.1.6)\n", + "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-curand-cu12==10.3.5.147 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (0.6.2)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (2.21.5)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (12.4.127)\n", + "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.2.0)\n", + "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.0)\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.11.15)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from lightning-utilities<2.0,>=0.10.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (75.2.0)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas<3.0.0,>=1.3.0->pytorch-forecasting) (1.17.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.0.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2.6.1)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.2)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (25.3.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.6.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.4.3)\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (0.3.1)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.20.0)\n", + "Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.11/dist-packages (from yarl<2.0,>=1.17.0->aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.10)\n", + "Downloading pytorch_forecasting-1.3.0-py3-none-any.whl (197 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m197.7/197.7 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading lightning-2.5.1.post0-py3-none-any.whl (819 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m819.0/819.0 kB\u001b[0m \u001b[31m16.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m37.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m20.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m12.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m64.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading lightning_utilities-0.14.3-py3-none-any.whl (28 kB)\n", + "Downloading torchmetrics-1.7.1-py3-none-any.whl (961 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m961.5/961.5 kB\u001b[0m \u001b[31m35.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading pytorch_lightning-2.5.1.post0-py3-none-any.whl (823 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.1/823.1 kB\u001b[0m \u001b[31m37.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, lightning-utilities, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torchmetrics, pytorch-lightning, lightning, pytorch-forecasting\n", + " Attempting uninstall: nvidia-nvjitlink-cu12\n", + " Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n", + " Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n", + " Attempting uninstall: nvidia-curand-cu12\n", + " Found existing installation: nvidia-curand-cu12 10.3.6.82\n", + " Uninstalling nvidia-curand-cu12-10.3.6.82:\n", + " Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n", + " Attempting uninstall: nvidia-cufft-cu12\n", + " Found existing installation: nvidia-cufft-cu12 11.2.3.61\n", + " Uninstalling nvidia-cufft-cu12-11.2.3.61:\n", + " Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n", + " Attempting uninstall: nvidia-cuda-runtime-cu12\n", + " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n", + " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cuda-cupti-cu12\n", + " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cublas-cu12\n", + " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n", + " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n", + " Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n", + " Attempting uninstall: nvidia-cusparse-cu12\n", + " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n", + " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n", + " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n", + " Attempting uninstall: nvidia-cudnn-cu12\n", + " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n", + " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n", + " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n", + " Attempting uninstall: nvidia-cusolver-cu12\n", + " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n", + " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n", + " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n", + "Successfully installed lightning-2.5.1.post0 lightning-utilities-0.14.3 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch-forecasting-1.3.0 pytorch-lightning-2.5.1.post0 torchmetrics-1.7.1\n" + ] + } + ], + "source": [ + "!pip install pytorch-forecasting" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "M7PQerTbI_tM" + }, + "outputs": [], + "source": [ + "from typing import Any, Dict, List, Optional, Tuple, Union\n", + "\n", + "from lightning.pytorch import Trainer\n", + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn.preprocessing import RobustScaler, StandardScaler\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import Optimizer\n", + "from torch.utils.data import Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DGTyf3vct-Jk" + }, + "outputs": [], + "source": [ + "from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule\n", + "from pytorch_forecasting.data.encoders import (\n", + " EncoderNormalizer,\n", + " NaNLabelEncoder,\n", + " TorchNormalizer,\n", + ")\n", + "from pytorch_forecasting.data.timeseries import TimeSeries\n", + "from pytorch_forecasting.metrics import MAE, SMAPE\n", + "from pytorch_forecasting.models.temporal_fusion_transformer.tft_version_two import TFT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "WX-FRdusJSVN", + "outputId": "d162e241-3076-415c-db39-8c571bbaa282" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6702424716860947,\n \"min\": -1.2572875930191487,\n \"max\": 1.347291996576924,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.40064266948811306,\n 0.688757012378203,\n -0.9278241195910876\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6742204890063661,\n \"min\": -1.2572875930191487,\n \"max\": 1.347291996576924,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.6039073395571968,\n 0.5832480743546181,\n -0.801772762118357\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n -0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2912918986931303,\n \"min\": 0.007584244652032224,\n \"max\": 0.9959799570401108,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.4307103381570838,\n 0.6664272198589233,\n 0.16731443141739688\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", + "type": "dataframe", + "variable_name": "data_df" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting pytorch-forecasting\n", - " Downloading pytorch_forecasting-1.3.0-py3-none-any.whl.metadata (13 kB)\n", - "Requirement already satisfied: numpy<=3.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.0.2)\n", - "Requirement already satisfied: torch!=2.0.1,<3.0.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.6.0+cu124)\n", - "Collecting lightning<3.0.0,>=2.0.0 (from pytorch-forecasting)\n", - " Downloading lightning-2.5.1-py3-none-any.whl.metadata (39 kB)\n", - "Requirement already satisfied: scipy<2.0,>=1.8 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.14.1)\n", - "Requirement already satisfied: pandas<3.0.0,>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.2.2)\n", - "Requirement already satisfied: scikit-learn<2.0,>=1.2 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.6.1)\n", - "Requirement already satisfied: PyYAML<8.0,>=5.4 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.0.2)\n", - "Requirement already satisfied: fsspec<2026.0,>=2022.5.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2025.3.2)\n", - "Collecting lightning-utilities<2.0,>=0.10.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)\n", - "Requirement already satisfied: packaging<25.0,>=20.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (24.2)\n", - "Collecting torchmetrics<3.0,>=0.7.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)\n", - "Requirement already satisfied: tqdm<6.0,>=4.57.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.67.1)\n", - "Requirement already satisfied: typing-extensions<6.0,>=4.4.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.13.1)\n", - "Collecting pytorch-lightning (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading pytorch_lightning-2.5.1-py3-none-any.whl.metadata (20 kB)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2.8.2)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", - "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", - "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (1.4.2)\n", - "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (3.6.0)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.18.0)\n", - "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.4.2)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.1.6)\n", - "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", - "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", - "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", - "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", - "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", - "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", - "Collecting nvidia-curand-cu12==10.3.5.147 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", - "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", - "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", - "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (0.6.2)\n", - "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (2.21.5)\n", - "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (12.4.127)\n", - "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", - "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.2.0)\n", - "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.13.1)\n", - "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.0)\n", - "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.11.15)\n", - "Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from lightning-utilities<2.0,>=0.10.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (75.2.0)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas<3.0.0,>=1.3.0->pytorch-forecasting) (1.17.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.0.2)\n", - "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2.6.1)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.2)\n", - "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (25.3.0)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.5.0)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.3.2)\n", - "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (0.3.1)\n", - "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.18.3)\n", - "Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.11/dist-packages (from yarl<2.0,>=1.17.0->aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.10)\n", - "Downloading pytorch_forecasting-1.3.0-py3-none-any.whl (197 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m197.7/197.7 kB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading lightning-2.5.1-py3-none-any.whl (818 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m818.9/818.9 kB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m58.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m55.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m26.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m15.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m84.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading lightning_utilities-0.14.3-py3-none-any.whl (28 kB)\n", - "Downloading torchmetrics-1.7.1-py3-none-any.whl (961 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m961.5/961.5 kB\u001b[0m \u001b[31m51.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading pytorch_lightning-2.5.1-py3-none-any.whl (822 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.0/823.0 kB\u001b[0m \u001b[31m40.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, lightning-utilities, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torchmetrics, pytorch-lightning, lightning, pytorch-forecasting\n", - " Attempting uninstall: nvidia-nvjitlink-cu12\n", - " Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n", - " Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n", - " Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n", - " Attempting uninstall: nvidia-curand-cu12\n", - " Found existing installation: nvidia-curand-cu12 10.3.6.82\n", - " Uninstalling nvidia-curand-cu12-10.3.6.82:\n", - " Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n", - " Attempting uninstall: nvidia-cufft-cu12\n", - " Found existing installation: nvidia-cufft-cu12 11.2.3.61\n", - " Uninstalling nvidia-cufft-cu12-11.2.3.61:\n", - " Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n", - " Attempting uninstall: nvidia-cuda-runtime-cu12\n", - " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n", - " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n", - " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n", - " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n", - " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n", - " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n", - " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n", - " Attempting uninstall: nvidia-cuda-cupti-cu12\n", - " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n", - " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n", - " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n", - " Attempting uninstall: nvidia-cublas-cu12\n", - " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n", - " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n", - " Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n", - " Attempting uninstall: nvidia-cusparse-cu12\n", - " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n", - " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n", - " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n", - " Attempting uninstall: nvidia-cudnn-cu12\n", - " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n", - " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n", - " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n", - " Attempting uninstall: nvidia-cusolver-cu12\n", - " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n", - " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n", - " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n", - "Successfully installed lightning-2.5.1 lightning-utilities-0.14.3 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch-forecasting-1.3.0 pytorch-lightning-2.5.1 torchmetrics-1.7.1\n" - ] - } + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
0000.2009680.23229801.0000000.687290
1010.2322980.33666900.9950040.687290
2020.3366690.63606300.9800670.687290
3030.6360630.92771000.9553360.687290
4040.9277101.00855400.9210610.687290
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" ], - "source": [ - "!pip install pytorch-forecasting" + "text/plain": [ + " series_id time_idx x y category future_known_feature \\\n", + "0 0 0 0.200968 0.232298 0 1.000000 \n", + "1 0 1 0.232298 0.336669 0 0.995004 \n", + "2 0 2 0.336669 0.636063 0 0.980067 \n", + "3 0 3 0.636063 0.927710 0 0.955336 \n", + "4 0 4 0.927710 1.008554 0 0.921061 \n", + "\n", + " static_feature static_feature_cat \n", + "0 0.68729 0 \n", + "1 0.68729 0 \n", + "2 0.68729 0 \n", + "3 0.68729 0 \n", + "4 0.68729 0 " ] }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "M7PQerTbI_tM" + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "\n", + "num_series = 100\n", + "seq_length = 50\n", + "data_list = []\n", + "for i in range(num_series):\n", + " x = np.arange(seq_length)\n", + " y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)\n", + " category = i % 5\n", + " static_value = np.random.rand()\n", + " for t in range(seq_length - 1):\n", + " data_list.append(\n", + " {\n", + " \"series_id\": i,\n", + " \"time_idx\": t,\n", + " \"x\": y[t],\n", + " \"y\": y[t + 1],\n", + " \"category\": category,\n", + " \"future_known_feature\": np.cos(t / 10),\n", + " \"static_feature\": static_value,\n", + " \"static_feature_cat\": i % 3,\n", + " }\n", + " )\n", + "data_df = pd.DataFrame(data_list)\n", + "data_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "AxxPHK6AKSD2", + "outputId": "dd95173d-73c2-451b-8b67-c9cc7298cf9d" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":106: UserWarning: TimeSeries is part of an experimental rework of the pytorch-forecasting data layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. For beta testing, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "dataset = TimeSeries(\n", + " data=data_df,\n", + " time=\"time_idx\",\n", + " target=\"y\",\n", + " group=[\"series_id\"],\n", + " num=[\"x\", \"future_known_feature\", \"static_feature\"],\n", + " cat=[\"category\", \"static_feature_cat\"],\n", + " known=[\"future_known_feature\"],\n", + " unknown=[\"x\", \"category\"],\n", + " static=[\"static_feature\", \"static_feature_cat\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "id": "5U5Lr_ZFKX0s" + }, + "outputs": [], + "source": [ + "data_module = EncoderDecoderTimeSeriesDataModule(\n", + " time_series_dataset=dataset,\n", + " max_encoder_length=30,\n", + " max_prediction_length=1,\n", + " batch_size=32,\n", + " categorical_encoders={\n", + " \"category\": NaNLabelEncoder(add_nan=True),\n", + " \"static_feature_cat\": NaNLabelEncoder(add_nan=True),\n", + " },\n", + " scalers={\n", + " \"x\": StandardScaler(),\n", + " \"future_known_feature\": StandardScaler(),\n", + " \"static_feature\": StandardScaler(),\n", + " },\n", + " target_normalizer=TorchNormalizer(),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "7de178cd43ab4104ba2445a057a5f1a4", + "40598800b1234eaeb769f18e3c27865c", + "05a21a6e18ca46e280a8c66d6a73cf81", + "e4a86f54cf33447f8959864533cbae72", + "d5ba9114135f4ec1818790a62beb865a", + "39ab0f799ae64a8c9c0d6041b17c0ba5", + "6b98518240b744fe8a708b37a2dcaabe", + "274c42415f834e679f45d02d7d1c01d7", + "c88fdff0012c4e4e85563adb36287774", + "c2490a4409f34a608f6073ff6b9426eb", + "1a5e8e6d619740c9b0fd752a8f886b0c", + "ceff330da01b4ed39eafa8820cb6a5ca", + "d1d794dcb83746a48628280e7d552a70", + "31e24662df144d7196e01b873ffed137", + "07c7871be3274f45a93f01c6003b35fb", + "86ab236d48d04f1980836770cfe61b0d", + "73b46d96f2d54ea8b132b409b0739588", + "0b8814890f1d4143842acce4df31d93e", + "7960939d00d844f094e1bcdc1acda7f9", + "2dcacb2bf93c4d16bfb8657670499fc5", + "ed187855e406486ca0ea2259f3e2f43c", + "c73f370b888f4893be08d51b08b23a87", + "ea890982ba5a4c3d8e15e6bbd7285f6a", + "02f4487acf82401b972b4593da15de15", + "cca521da65e946abad580e9db9f2ac6b", + "b4c9075a5c1148ee8e422b2fcf86d90b", + "57dd80719a2e444cab8345bf5086a2f7", + "14e2d31e04e9447ba4f0c55000e0abe2", + "b35a713d4d7041109db487d10c55aaed", + "59ca144ddc884bb5a1c038e96cdf0dc0", + "95f9b03446af41fb83876f625aae5d75", + "bd2154cf3b04468b93a3bec23b5e34fa", + "f4713a710f274dd389691d7c12a7e740", + "07571714d67e4b8793ee76b0fe151e67", + "4e7273caa91147019971bf75ffe71e49", + "b751dbe4e93341eaa7d1a7683c277d83", + "11e1c349cb894caa9fd77333f7ababb9", + "8c9ac67ac6af488fb16e64c834634a30", + "704f28911d674088b8dfc240c6e28449", + "67da733b7a254020ad4d8d3877eb4494", + "72858abca77d430a8d009ad72127b331", + "86e82ab383a54ef7aa7c0abec6faf1be", + "60f097a304044e3db6dda60ff381776a", + "334e7a44226f4414a825e2d81e9571c1", + "13246130b24145d680092a5a3929546e", + "e64a93136fbf468b8e503cd202dbc986", + "559df10442ac42fc8033bdf014864334", + "67e6440ba0424db18646d47beff2e37b", + "36707c56fcfb4fadbfff37d44ee52d5b", + "484702bf9d854cb6a964de47d6975aa6", + "ddb67e4d1749424299a2c40b36810809", + "6f1e8c2aa06a48548a7750869d3f056e", + "d7df1875fbbb4c78909d76c5b81ffd95", + "cb60f83efc234cd6a90212125fe841a1", + "dfca7d9f0f7844ea953ce4b659493695", + "14ed4e565d5a4b319b0c38557a935b92", + "48e7f65e9dcc48289f44724da248875d", + "8220ff572b31472e91dc5d7553ca43b9", + "fcd1b8b77a7f4c108ba02de9779f5bb2", + "a53e4183893242ea884db9f25e439d94", + "e54fdc83afdf492a9540bc892bdc262a", + "04ce55b4c81a46b09d55b8a0dc3fae00", + "86eddf636c2b45fc93a69f3c8a260b1c", + "287ce5b9bafb402e9698f20779f52386", + "ef80c0b7306046738cfd4e16f5af1afc", + "c0bb5f38ec9346119441c6cb8fd71c5a", + "a54b4228eaf34dc1b40ee8d40500e069", + "28cb77e3b48a4c669dcaa79609d332c2", + "8acc35455cb84ece9e7b3d84d8870a7a", + "afe9a01608f449479cc491f75e095d55", + "9b5971f9e8d44e8a872b782ef49d306a", + "1538291371bc45b2bcb7a687c6b8f79c", + "d595a038c20e4d349e392ff1795dd418", + "a67fa2461a18428892e57790074fd5b6", + "6733a54e084a4d348a85e574985720b7", + "204594abf12e43d4abf0fa35544fb64a", + "eee2d2ef09294652bd94ff768ffb99f0", + "c8b8814037464155935670b776378ab9", + "0b6b787693af4b548ad59ad7ba6c921f", + "6b7172a2b1fd4ddea3dba3d435bf36fe", + "ed0ffcd137e94edebad5e41956ad7466", + "0344079173a84c448cd6edbac612d970", + "890f17882dc44e76be06dd88e6d07cff", + "1c552ec98dcd4325800ee8c9dddc398a", + "89d54947a3d146e4b089a22a2711a234", + "91e7ef398d1646bcbfb598480d949c74", + "13173089480c49b18c4ba016da69a855", + "6b717d2759ca445da7642a2cf20f22b5" + ] + }, + "id": "Si7bbZIULBZz", + "outputId": "0b2f26b3-e37c-4ab6-c234-90693745a8cd" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n", + "INFO:lightning.pytorch.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n", + "INFO: GPU available: False, used: False\n", + "INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False\n", + "INFO: TPU available: False, using: 0 TPU cores\n", + "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", + "INFO: HPU available: False, using: 0 HPUs\n", + "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n", + "INFO: \n", + " | Name | Type | Params | Mode \n", + "---------------------------------------------------------------------\n", + "0 | loss | MSELoss | 0 | train\n", + "1 | encoder_var_selection | Sequential | 709 | train\n", + "2 | decoder_var_selection | Sequential | 193 | train\n", + "3 | static_context_linear | Linear | 192 | train\n", + "4 | lstm_encoder | LSTM | 51.5 K | train\n", + "5 | lstm_decoder | LSTM | 50.4 K | train\n", + "6 | self_attention | MultiheadAttention | 16.6 K | train\n", + "7 | pre_output | Linear | 4.2 K | train\n", + "8 | output_layer | Linear | 65 | train\n", + "---------------------------------------------------------------------\n", + "123 K Trainable params\n", + "0 Non-trainable params\n", + "123 K Total params\n", + "0.495 Total estimated model params size (MB)\n", + "18 Modules in train mode\n", + "0 Modules in eval mode\n", + "INFO:lightning.pytorch.callbacks.model_summary:\n", + " | Name | Type | Params | Mode \n", + "---------------------------------------------------------------------\n", + "0 | loss | MSELoss | 0 | train\n", + "1 | encoder_var_selection | Sequential | 709 | train\n", + "2 | decoder_var_selection | Sequential | 193 | train\n", + "3 | static_context_linear | Linear | 192 | train\n", + "4 | lstm_encoder | LSTM | 51.5 K | train\n", + "5 | lstm_decoder | LSTM | 50.4 K | train\n", + "6 | self_attention | MultiheadAttention | 16.6 K | train\n", + "7 | pre_output | Linear | 4.2 K | train\n", + "8 | output_layer | Linear | 65 | train\n", + "---------------------------------------------------------------------\n", + "123 K Trainable params\n", + "0 Non-trainable params\n", + "123 K Total params\n", + "0.495 Total estimated model params size (MB)\n", + "18 Modules in train mode\n", + "0 Modules in eval mode\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Training model...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7de178cd43ab4104ba2445a057a5f1a4", + "version_major": 2, + "version_minor": 0 }, - "outputs": [], - "source": [ - "from typing import Any, Dict, List, Optional, Tuple, Union\n", - "\n", - "import numpy as np\n", - "import pandas as pd\n", - "import torch\n", - "from torch.utils.data import Dataset\n", - "\n", - "from pytorch_forecasting.data.timeseries import _coerce_to_dict" + "text/plain": [ + "Sanity Checking: | | 0/? [00:00 int:\n", - " \"\"\"Return number of time series in the dataset.\"\"\"\n", - " return len(self._group_ids)\n", - "\n", - " def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:\n", - " \"\"\"Get time series data for given index.\n", - "\n", - " Returns\n", - " -------\n", - " t : numpy.ndarray of shape (n_timepoints,)\n", - " Time index for each time point in the past or present. Aligned with `y`,\n", - " and `x` not ending in `f`.\n", - "\n", - " y : torch.Tensor of shape (n_timepoints, n_targets)\n", - " Target values for each time point. Rows are time points, aligned with `t`.\n", - "\n", - " x : torch.Tensor of shape (n_timepoints, n_features)\n", - " Features for each time point. Rows are time points, aligned with `t`.\n", - "\n", - " group : torch.Tensor of shape (n_groups,)\n", - " Group identifiers for time series instances.\n", - "\n", - " st : torch.Tensor of shape (n_static_features,)\n", - " Static features.\n", - "\n", - " cutoff_time : float or numpy.float64\n", - " Cutoff time for the time series instance.\n", - "\n", - " Other Returns\n", - " -------------\n", - " weights : torch.Tensor of shape (n_timepoints,), optional\n", - " Only included if weights are not `None`.\n", - " \"\"\"\n", - " group_id = self._group_ids[index]\n", - "\n", - " if self.group:\n", - " mask = self._groups[group_id]\n", - " data = self.data.loc[mask]\n", - " else:\n", - " data = self.data\n", - "\n", - " cutoff_time = data[self.time].max()\n", - "\n", - " result = {\n", - " \"t\": data[self.time].values,\n", - " \"y\": torch.tensor(data[self.target].values),\n", - " \"x\": torch.tensor(data[self.feature_cols].values),\n", - " \"group\": torch.tensor([hash(str(group_id))]),\n", - " \"st\": torch.tensor(data[self.static].iloc[0].values if self.static else []),\n", - " \"cutoff_time\": cutoff_time,\n", - " }\n", - "\n", - " if self.data_future is not None:\n", - " if self.group:\n", - " future_mask = self.data_future.groupby(self.group).groups[group_id]\n", - " future_data = self.data_future.loc[future_mask]\n", - " else:\n", - " future_data = self.data_future\n", - "\n", - " combined_times = np.concatenate(\n", - " [data[self.time].values, future_data[self.time].values]\n", - " )\n", - " combined_times = np.unique(combined_times)\n", - " combined_times.sort()\n", - "\n", - " num_timepoints = len(combined_times)\n", - " x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan)\n", - " y_merged = np.full((num_timepoints, len(self.target)), np.nan)\n", - "\n", - " current_time_indices = {t: i for i, t in enumerate(combined_times)}\n", - " for i, t in enumerate(data[self.time].values):\n", - " idx = current_time_indices[t]\n", - " x_merged[idx] = data[self.feature_cols].values[i]\n", - " y_merged[idx] = data[self.target].values[i]\n", - "\n", - " for i, t in enumerate(future_data[self.time].values):\n", - " if t in current_time_indices:\n", - " idx = current_time_indices[t]\n", - " for j, col in enumerate(self.known):\n", - " if col in self.feature_cols:\n", - " feature_idx = self.feature_cols.index(col)\n", - " x_merged[idx, feature_idx] = future_data[col].values[i]\n", - "\n", - " result.update(\n", - " {\n", - " \"t\": combined_times,\n", - " \"x\": torch.tensor(x_merged, dtype=torch.float32),\n", - " \"y\": torch.tensor(y_merged, dtype=torch.float32),\n", - " }\n", - " )\n", - "\n", - " if self.weight:\n", - " if self.data_future is not None and self.weight in self.data_future.columns:\n", - " weights_merged = np.full(num_timepoints, np.nan)\n", - " for i, t in enumerate(data[self.time].values):\n", - " idx = current_time_indices[t]\n", - " weights_merged[idx] = data[self.weight].values[i]\n", - "\n", - " for i, t in enumerate(future_data[self.time].values):\n", - " if t in current_time_indices and self.weight in future_data.columns:\n", - " idx = current_time_indices[t]\n", - " weights_merged[idx] = future_data[self.weight].values[i]\n", - "\n", - " result[\"weights\"] = torch.tensor(weights_merged, dtype=torch.float32)\n", - " else:\n", - " result[\"weights\"] = torch.tensor(\n", - " data[self.weight].values, dtype=torch.float32\n", - " )\n", - "\n", - " return result\n", - "\n", - " def get_metadata(self) -> Dict:\n", - " \"\"\"Return metadata about the dataset.\n", - "\n", - " Returns\n", - " -------\n", - " Dict\n", - " Dictionary containing:\n", - " - cols: column names for y, x, and static features\n", - " - col_type: mapping of columns to their types (F/C)\n", - " - col_known: mapping of columns to their future known status (K/U)\n", - " \"\"\"\n", - " return self.metadata" + "text/plain": [ + "Training: | | 0/? [00:00 List[Dict[str, Any]]:\n", - " \"\"\"Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset.\n", - "\n", - " Preprocessing steps\n", - " --------------------\n", - "\n", - " * Converts target (`y`) and features (`x`) to `torch.float32`.\n", - " * Masks time points that are at or before the cutoff time.\n", - " * Splits features into categorical and continuous subsets based on\n", - " predefined indices.\n", - "\n", - "\n", - " TODO: add scalers, target normalizers etc.\n", - " \"\"\"\n", - " sample = self.time_series_dataset[series_idx]\n", - "\n", - " target = sample[\"y\"]\n", - " features = sample[\"x\"]\n", - " times = sample[\"t\"]\n", - " cutoff_time = sample[\"cutoff_time\"]\n", - "\n", - " time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool)\n", - "\n", - " if isinstance(target, torch.Tensor):\n", - " target = target.float()\n", - " else:\n", - " target = torch.tensor(target, dtype=torch.float32)\n", - "\n", - " if isinstance(features, torch.Tensor):\n", - " features = features.float()\n", - " else:\n", - " features = torch.tensor(features, dtype=torch.float32)\n", - "\n", - " # TODO: add scalers, target normalizers etc.\n", - "\n", - " categorical = (\n", - " features[:, self.categorical_indices]\n", - " if self.categorical_indices\n", - " else torch.zeros((features.shape[0], 0))\n", - " )\n", - " continuous = (\n", - " features[:, self.continuous_indices]\n", - " if self.continuous_indices\n", - " else torch.zeros((features.shape[0], 0))\n", - " )\n", - "\n", - " return {\n", - " \"features\": {\"categorical\": categorical, \"continuous\": continuous},\n", - " \"target\": target,\n", - " \"static\": sample.get(\"st\", None),\n", - " \"group\": sample.get(\"group\", torch.tensor([0])),\n", - " \"length\": len(target),\n", - " \"time_mask\": time_mask,\n", - " \"times\": times,\n", - " \"cutoff_time\": cutoff_time,\n", - " }\n", - "\n", - " class _ProcessedEncoderDecoderDataset(Dataset):\n", - " \"\"\"PyTorch Dataset for processed encoder-decoder time series data.\n", - "\n", - " Parameters\n", - " ----------\n", - " dataset : TimeSeries\n", - " The base time series dataset that provides access to raw data and metadata.\n", - " data_module : EncoderDecoderTimeSeriesDataModule\n", - " The data module handling preprocessing and metadata configuration.\n", - " windows : List[Tuple[int, int, int, int]]\n", - " List of window tuples containing\n", - " (series_idx, start_idx, enc_length, pred_length).\n", - " add_relative_time_idx : bool, default=False\n", - " Whether to include relative time indices.\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " dataset: TimeSeries,\n", - " data_module: \"EncoderDecoderTimeSeriesDataModule\",\n", - " windows: List[Tuple[int, int, int, int]],\n", - " add_relative_time_idx: bool = False,\n", - " ):\n", - " self.dataset = dataset\n", - " self.data_module = data_module\n", - " self.windows = windows\n", - " self.add_relative_time_idx = add_relative_time_idx\n", - "\n", - " def __len__(self):\n", - " return len(self.windows)\n", - "\n", - " def __getitem__(self, idx):\n", - " \"\"\"Retrieve a processed time series window for dataloader input.\n", - "\n", - " x : dict\n", - " Dictionary containing model inputs:\n", - "\n", - " * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features)\n", - " Categorical features for the encoder.\n", - " * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features)\n", - " Continuous features for the encoder.\n", - " * ``decoder_cat`` : tensor of shape (pred_length, n_cat_features)\n", - " Categorical features for the decoder.\n", - " * ``decoder_cont`` : tensor of shape (pred_length, n_cont_features)\n", - " Continuous features for the decoder.\n", - " * ``encoder_lengths`` : tensor of shape (1,)\n", - " Length of the encoder sequence.\n", - " * ``decoder_lengths`` : tensor of shape (1,)\n", - " Length of the decoder sequence.\n", - " * ``decoder_target_lengths`` : tensor of shape (1,)\n", - " Length of the decoder target sequence.\n", - " * ``groups`` : tensor of shape (1,)\n", - " Group identifier for the time series instance.\n", - " * ``encoder_time_idx`` : tensor of shape (enc_length,)\n", - " Time indices for the encoder sequence.\n", - " * ``decoder_time_idx`` : tensor of shape (pred_length,)\n", - " Time indices for the decoder sequence.\n", - " * ``target_scale`` : tensor of shape (1,)\n", - " Scaling factor for the target values.\n", - " * ``encoder_mask`` : tensor of shape (enc_length,)\n", - " Boolean mask indicating valid encoder time points.\n", - " * ``decoder_mask`` : tensor of shape (pred_length,)\n", - " Boolean mask indicating valid decoder time points.\n", - "\n", - " If static features are present, the following keys are added:\n", - "\n", - " * ``static_categorical_features`` : tensor of shape\n", - " (1, n_static_cat_features), optional\n", - " Static categorical features, if available.\n", - " * ``static_continuous_features`` : tensor of shape (1, 0), optional\n", - " Placeholder for static continuous features (currently empty).\n", - "\n", - " y : tensor of shape ``(pred_length, n_targets)``\n", - " Target values for the decoder sequence.\n", - " \"\"\"\n", - " series_idx, start_idx, enc_length, pred_length = self.windows[idx]\n", - " data = self.data_module._preprocess_data(series_idx)\n", - "\n", - " end_idx = start_idx + enc_length + pred_length\n", - " encoder_indices = slice(start_idx, start_idx + enc_length)\n", - " decoder_indices = slice(start_idx + enc_length, end_idx)\n", - "\n", - " target_scale = data[\"target\"][encoder_indices]\n", - " target_scale = target_scale[~torch.isnan(target_scale)].abs().mean()\n", - " if torch.isnan(target_scale) or target_scale == 0:\n", - " target_scale = torch.tensor(1.0)\n", - "\n", - " encoder_mask = (\n", - " data[\"time_mask\"][encoder_indices]\n", - " if \"time_mask\" in data\n", - " else torch.ones(enc_length, dtype=torch.bool)\n", - " )\n", - " decoder_mask = (\n", - " data[\"time_mask\"][decoder_indices]\n", - " if \"time_mask\" in data\n", - " else torch.zeros(pred_length, dtype=torch.bool)\n", - " )\n", - "\n", - " x = {\n", - " \"encoder_cat\": data[\"features\"][\"categorical\"][encoder_indices],\n", - " \"encoder_cont\": data[\"features\"][\"continuous\"][encoder_indices],\n", - " \"decoder_cat\": data[\"features\"][\"categorical\"][decoder_indices],\n", - " \"decoder_cont\": data[\"features\"][\"continuous\"][decoder_indices],\n", - " \"encoder_lengths\": torch.tensor(enc_length),\n", - " \"decoder_lengths\": torch.tensor(pred_length),\n", - " \"decoder_target_lengths\": torch.tensor(pred_length),\n", - " \"groups\": data[\"group\"],\n", - " \"encoder_time_idx\": torch.arange(enc_length),\n", - " \"decoder_time_idx\": torch.arange(enc_length, enc_length + pred_length),\n", - " \"target_scale\": target_scale,\n", - " \"encoder_mask\": encoder_mask,\n", - " \"decoder_mask\": decoder_mask,\n", - " }\n", - " if data[\"static\"] is not None:\n", - " x[\"static_categorical_features\"] = data[\"static\"].unsqueeze(0)\n", - " x[\"static_continuous_features\"] = torch.zeros((1, 0))\n", - "\n", - " y = data[\"target\"][decoder_indices]\n", - " if y.ndim == 1:\n", - " y = y.unsqueeze(-1)\n", - "\n", - " return x, y\n", - "\n", - " def _create_windows(self, indices: torch.Tensor) -> List[Tuple[int, int, int, int]]:\n", - " \"\"\"Generate sliding windows for training, validation, and testing.\n", - "\n", - " Returns\n", - " -------\n", - " List[Tuple[int, int, int, int]]\n", - " A list of tuples, where each tuple consists of:\n", - " - ``series_idx`` : int\n", - " Index of the time series in `time_series_dataset`.\n", - " - ``start_idx`` : int\n", - " Start index of the encoder window.\n", - " - ``enc_length`` : int\n", - " Length of the encoder input sequence.\n", - " - ``pred_length`` : int\n", - " Length of the decoder output sequence.\n", - " \"\"\"\n", - " windows = []\n", - "\n", - " for idx in indices:\n", - " series_idx = idx.item()\n", - " sample = self.time_series_dataset[series_idx]\n", - " sequence_length = len(sample[\"y\"])\n", - "\n", - " if sequence_length < self.max_encoder_length + self.max_prediction_length:\n", - " continue\n", - "\n", - " effective_min_prediction_idx = (\n", - " self.min_prediction_idx\n", - " if self.min_prediction_idx is not None\n", - " else self.max_encoder_length\n", - " )\n", - "\n", - " max_prediction_idx = sequence_length - self.max_prediction_length + 1\n", - "\n", - " if max_prediction_idx <= effective_min_prediction_idx:\n", - " continue\n", - "\n", - " for start_idx in range(\n", - " 0, max_prediction_idx - effective_min_prediction_idx\n", - " ):\n", - " if (\n", - " start_idx + self.max_encoder_length + self.max_prediction_length\n", - " <= sequence_length\n", - " ):\n", - " windows.append(\n", - " (\n", - " series_idx,\n", - " start_idx,\n", - " self.max_encoder_length,\n", - " self.max_prediction_length,\n", - " )\n", - " )\n", - "\n", - " return windows\n", - "\n", - " def setup(self, stage: Optional[str] = None):\n", - " \"\"\"Prepare the datasets for training, validation, testing, or prediction.\n", - "\n", - " Parameters\n", - " ----------\n", - " stage : Optional[str], default=None\n", - " Specifies the stage of setup. Can be one of:\n", - " - ``\"fit\"`` : Prepares training and validation datasets.\n", - " - ``\"test\"`` : Prepares the test dataset.\n", - " - ``\"predict\"`` : Prepares the dataset for inference.\n", - " - ``None`` : Prepares ``fit`` datasets.\n", - " \"\"\"\n", - " total_series = len(self.time_series_dataset)\n", - " self._split_indices = torch.randperm(total_series)\n", - "\n", - " self._train_size = int(self.train_val_test_split[0] * total_series)\n", - " self._val_size = int(self.train_val_test_split[1] * total_series)\n", - "\n", - " self._train_indices = self._split_indices[: self._train_size]\n", - " self._val_indices = self._split_indices[\n", - " self._train_size : self._train_size + self._val_size\n", - " ]\n", - " self._test_indices = self._split_indices[self._train_size + self._val_size :]\n", - "\n", - " if stage is None or stage == \"fit\":\n", - " if not hasattr(self, \"train_dataset\") or not hasattr(self, \"val_dataset\"):\n", - " self.train_windows = self._create_windows(self._train_indices)\n", - " self.val_windows = self._create_windows(self._val_indices)\n", - "\n", - " self.train_dataset = self._ProcessedEncoderDecoderDataset(\n", - " self.time_series_dataset,\n", - " self,\n", - " self.train_windows,\n", - " self.add_relative_time_idx,\n", - " )\n", - " self.val_dataset = self._ProcessedEncoderDecoderDataset(\n", - " self.time_series_dataset,\n", - " self,\n", - " self.val_windows,\n", - " self.add_relative_time_idx,\n", - " )\n", - "\n", - " elif stage == \"test\":\n", - " if not hasattr(self, \"test_dataset\"):\n", - " self.test_windows = self._create_windows(self._test_indices)\n", - " self.test_dataset = self._ProcessedEncoderDecoderDataset(\n", - " self.time_series_dataset,\n", - " self,\n", - " self.test_windows,\n", - " self.add_relative_time_idx,\n", - " )\n", - " elif stage == \"predict\":\n", - " predict_indices = torch.arange(len(self.time_series_dataset))\n", - " self.predict_windows = self._create_windows(predict_indices)\n", - " self.predict_dataset = self._ProcessedEncoderDecoderDataset(\n", - " self.time_series_dataset,\n", - " self,\n", - " self.predict_windows,\n", - " self.add_relative_time_idx,\n", - " )\n", - "\n", - " def train_dataloader(self):\n", - " return DataLoader(\n", - " self.train_dataset,\n", - " batch_size=self.batch_size,\n", - " num_workers=self.num_workers,\n", - " shuffle=True,\n", - " collate_fn=self.collate_fn,\n", - " )\n", - "\n", - " def val_dataloader(self):\n", - " return DataLoader(\n", - " self.val_dataset,\n", - " batch_size=self.batch_size,\n", - " num_workers=self.num_workers,\n", - " collate_fn=self.collate_fn,\n", - " )\n", - "\n", - " def test_dataloader(self):\n", - " return DataLoader(\n", - " self.test_dataset,\n", - " batch_size=self.batch_size,\n", - " num_workers=self.num_workers,\n", - " collate_fn=self.collate_fn,\n", - " )\n", - "\n", - " def predict_dataloader(self):\n", - " return DataLoader(\n", - " self.predict_dataset,\n", - " batch_size=self.batch_size,\n", - " num_workers=self.num_workers,\n", - " collate_fn=self.collate_fn,\n", - " )\n", - "\n", - " @staticmethod\n", - " def collate_fn(batch):\n", - " x_batch = {\n", - " \"encoder_cat\": torch.stack([x[\"encoder_cat\"] for x, _ in batch]),\n", - " \"encoder_cont\": torch.stack([x[\"encoder_cont\"] for x, _ in batch]),\n", - " \"decoder_cat\": torch.stack([x[\"decoder_cat\"] for x, _ in batch]),\n", - " \"decoder_cont\": torch.stack([x[\"decoder_cont\"] for x, _ in batch]),\n", - " \"encoder_lengths\": torch.stack([x[\"encoder_lengths\"] for x, _ in batch]),\n", - " \"decoder_lengths\": torch.stack([x[\"decoder_lengths\"] for x, _ in batch]),\n", - " \"decoder_target_lengths\": torch.stack(\n", - " [x[\"decoder_target_lengths\"] for x, _ in batch]\n", - " ),\n", - " \"groups\": torch.stack([x[\"groups\"] for x, _ in batch]),\n", - " \"encoder_time_idx\": torch.stack([x[\"encoder_time_idx\"] for x, _ in batch]),\n", - " \"decoder_time_idx\": torch.stack([x[\"decoder_time_idx\"] for x, _ in batch]),\n", - " \"target_scale\": torch.stack([x[\"target_scale\"] for x, _ in batch]),\n", - " \"encoder_mask\": torch.stack([x[\"encoder_mask\"] for x, _ in batch]),\n", - " \"decoder_mask\": torch.stack([x[\"decoder_mask\"] for x, _ in batch]),\n", - " }\n", - "\n", - " if \"static_categorical_features\" in batch[0][0]:\n", - " x_batch[\"static_categorical_features\"] = torch.stack(\n", - " [x[\"static_categorical_features\"] for x, _ in batch]\n", - " )\n", - " x_batch[\"static_continuous_features\"] = torch.stack(\n", - " [x[\"static_continuous_features\"] for x, _ in batch]\n", - " )\n", - "\n", - " y_batch = torch.stack([y for _, y in batch])\n", - " return x_batch, y_batch" + "text/plain": [ + "Validation: | | 0/? [00:00\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
000-0.0712220.33976301.0000000.6264260
1010.3397630.18934800.9950040.6264260
2020.1893480.67598900.9800670.6264260
3030.6759890.79726100.9553360.6264260
4040.7972610.99501600.9210610.6264260
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "\n", - "
\n", - " \n", - "\n", - "\n", - "\n", - " \n", - "
\n", - "\n", - "
\n", - " \n" - ], - "text/plain": [ - " series_id time_idx x y category future_known_feature \\\n", - "0 0 0 -0.071222 0.339763 0 1.000000 \n", - "1 0 1 0.339763 0.189348 0 0.995004 \n", - "2 0 2 0.189348 0.675989 0 0.980067 \n", - "3 0 3 0.675989 0.797261 0 0.955336 \n", - "4 0 4 0.797261 0.995016 0 0.921061 \n", - "\n", - " static_feature static_feature_cat \n", - "0 0.626426 0 \n", - "1 0.626426 0 \n", - "2 0.626426 0 \n", - "3 0.626426 0 \n", - "4 0.626426 0 " - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from lightning.pytorch import Trainer\n", - "import pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "\n", - "from pytorch_forecasting.metrics import MAE, SMAPE\n", - "\n", - "num_series = 100\n", - "seq_length = 50\n", - "data_list = []\n", - "for i in range(num_series):\n", - " x = np.arange(seq_length)\n", - " y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)\n", - " category = i % 5\n", - " static_value = np.random.rand()\n", - " for t in range(seq_length - 1):\n", - " data_list.append(\n", - " {\n", - " \"series_id\": i,\n", - " \"time_idx\": t,\n", - " \"x\": y[t],\n", - " \"y\": y[t + 1],\n", - " \"category\": category,\n", - " \"future_known_feature\": np.cos(t / 10),\n", - " \"static_feature\": static_value,\n", - " \"static_feature_cat\": i % 3,\n", - " }\n", - " )\n", - "data_df = pd.DataFrame(data_list)\n", - "data_df.head()" + "text/plain": [ + "Validation: | | 0/? [00:00 Dict[str, torch.Tensor]:\n", - " \"\"\"\n", - " Forward pass of the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " x : Dict[str, torch.Tensor]\n", - " Dictionary containing input tensors\n", - "\n", - " Returns\n", - " -------\n", - " Dict[str, torch.Tensor]\n", - " Dictionary containing output tensors\n", - " \"\"\"\n", - " raise NotImplementedError(\"Forward method must be implemented by subclass.\")\n", - "\n", - " def training_step(\n", - " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", - " ) -> STEP_OUTPUT:\n", - " \"\"\"\n", - " Training step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input and target tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - "\n", - " Returns\n", - " -------\n", - " STEP_OUTPUT\n", - " Dictionary containing the loss and other metrics.\n", - " \"\"\"\n", - " x, y = batch\n", - " y_hat_dict = self(x)\n", - " y_hat = y_hat_dict[\"prediction\"]\n", - " loss = self.loss(y_hat, y)\n", - " self.log(\n", - " \"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True\n", - " )\n", - " self.log_metrics(y_hat, y, prefix=\"train\")\n", - " return {\"loss\": loss}\n", - "\n", - " def validation_step(\n", - " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", - " ) -> STEP_OUTPUT:\n", - " \"\"\"\n", - " Validation step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input and target tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - "\n", - " Returns\n", - " -------\n", - " STEP_OUTPUT\n", - " Dictionary containing the loss and other metrics.\n", - " \"\"\"\n", - " x, y = batch\n", - " y_hat_dict = self(x)\n", - " y_hat = y_hat_dict[\"prediction\"]\n", - " loss = self.loss(y_hat, y)\n", - " self.log(\n", - " \"val_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n", - " )\n", - " self.log_metrics(y_hat, y, prefix=\"val\")\n", - " return {\"val_loss\": loss}\n", - "\n", - " def test_step(\n", - " self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int\n", - " ) -> STEP_OUTPUT:\n", - " \"\"\"\n", - " Test step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input and target tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - "\n", - " Returns\n", - " -------\n", - " STEP_OUTPUT\n", - " Dictionary containing the loss and other metrics.\n", - " \"\"\"\n", - " x, y = batch\n", - " y_hat_dict = self(x)\n", - " y_hat = y_hat_dict[\"prediction\"]\n", - " loss = self.loss(y_hat, y)\n", - " self.log(\n", - " \"test_loss\", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True\n", - " )\n", - " self.log_metrics(y_hat, y, prefix=\"test\")\n", - " return {\"test_loss\": loss}\n", - "\n", - " def predict_step(\n", - " self,\n", - " batch: Tuple[Dict[str, torch.Tensor]],\n", - " batch_idx: int,\n", - " dataloader_idx: int = 0,\n", - " ) -> torch.Tensor:\n", - " \"\"\"\n", - " Prediction step for the model.\n", - "\n", - " Parameters\n", - " ----------\n", - " batch : Tuple[Dict[str, torch.Tensor]]\n", - " Batch of data containing input tensors.\n", - " batch_idx : int\n", - " Index of the batch.\n", - " dataloader_idx : int\n", - " Index of the dataloader.\n", - "\n", - " Returns\n", - " -------\n", - " torch.Tensor\n", - " Predicted output tensor.\n", - " \"\"\"\n", - " x, _ = batch\n", - " y_hat = self(x)\n", - " return y_hat\n", - "\n", - " def configure_optimizers(self) -> Dict:\n", - " \"\"\"\n", - " Configure the optimizer and learning rate scheduler.\n", - "\n", - " Returns\n", - " -------\n", - " Dict\n", - " Dictionary containing the optimizer and scheduler configuration.\n", - " \"\"\"\n", - " optimizer = self._get_optimizer()\n", - " if self.lr_scheduler is not None:\n", - " scheduler = self._get_scheduler(optimizer)\n", - " if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n", - " return {\n", - " \"optimizer\": optimizer,\n", - " \"lr_scheduler\": {\n", - " \"scheduler\": scheduler,\n", - " \"monitor\": \"val_loss\",\n", - " },\n", - " }\n", - " else:\n", - " return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n", - " return {\"optimizer\": optimizer}\n", - "\n", - " def _get_optimizer(self) -> Optimizer:\n", - " \"\"\"\n", - " Get the optimizer based on the specified optimizer name and parameters.\n", - "\n", - " Returns\n", - " -------\n", - " Optimizer\n", - " The optimizer instance.\n", - " \"\"\"\n", - " if isinstance(self.optimizer, str):\n", - " if self.optimizer.lower() == \"adam\":\n", - " return torch.optim.Adam(self.parameters(), **self.optimizer_params)\n", - " elif self.optimizer.lower() == \"sgd\":\n", - " return torch.optim.SGD(self.parameters(), **self.optimizer_params)\n", - " else:\n", - " raise ValueError(f\"Optimizer {self.optimizer} not supported.\")\n", - " elif isinstance(self.optimizer, Optimizer):\n", - " return self.optimizer\n", - " else:\n", - " raise ValueError(\n", - " \"Optimizer must be either a string or \"\n", - " \"an instance of torch.optim.Optimizer.\"\n", - " )\n", - "\n", - " def _get_scheduler(\n", - " self, optimizer: Optimizer\n", - " ) -> torch.optim.lr_scheduler._LRScheduler:\n", - " \"\"\"\n", - " Get the lr scheduler based on the specified scheduler name and params.\n", - "\n", - " Parameters\n", - " ----------\n", - " optimizer : Optimizer\n", - " The optimizer instance.\n", - "\n", - " Returns\n", - " -------\n", - " torch.optim.lr_scheduler._LRScheduler\n", - " The learning rate scheduler instance.\n", - " \"\"\"\n", - " if self.lr_scheduler.lower() == \"reduce_lr_on_plateau\":\n", - " return torch.optim.lr_scheduler.ReduceLROnPlateau(\n", - " optimizer, **self.lr_scheduler_params\n", - " )\n", - " elif self.lr_scheduler.lower() == \"step_lr\":\n", - " return torch.optim.lr_scheduler.StepLR(\n", - " optimizer, **self.lr_scheduler_params\n", - " )\n", - " else:\n", - " raise ValueError(f\"Scheduler {self.lr_scheduler} not supported.\")\n", - "\n", - " def log_metrics(\n", - " self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = \"val\"\n", - " ) -> None:\n", - " \"\"\"\n", - " Log additional metrics during training, validation, or testing.\n", - "\n", - " Parameters\n", - " ----------\n", - " y_hat : torch.Tensor\n", - " Predicted output tensor.\n", - " y : torch.Tensor\n", - " Target output tensor.\n", - " prefix : str\n", - " Prefix for the logged metrics (e.g., \"train\", \"val\", \"test\").\n", - " \"\"\"\n", - " for metric in self.logging_metrics:\n", - " metric_value = metric(y_hat, y)\n", - " self.log(\n", - " f\"{prefix}_{metric.__class__.__name__}\",\n", - " metric_value,\n", - " on_step=False,\n", - " on_epoch=True,\n", - " prog_bar=True,\n", - " logger=True,\n", - " )" + "text/plain": [ + "Validation: | | 0/? [00:00 0 else None\n", - " )\n", - "\n", - " self.lstm_encoder = nn.LSTM(\n", - " input_size=total_feature_size,\n", - " hidden_size=hidden_size,\n", - " num_layers=num_layers,\n", - " dropout=dropout,\n", - " batch_first=True,\n", - " )\n", - "\n", - " self.lstm_decoder = nn.LSTM(\n", - " input_size=total_feature_size,\n", - " hidden_size=hidden_size,\n", - " num_layers=num_layers,\n", - " dropout=dropout,\n", - " batch_first=True,\n", - " )\n", - "\n", - " self.self_attention = nn.MultiheadAttention(\n", - " embed_dim=hidden_size,\n", - " num_heads=attention_head_size,\n", - " dropout=dropout,\n", - " batch_first=True,\n", - " )\n", - "\n", - " self.pre_output = nn.Linear(hidden_size, hidden_size)\n", - " self.output_layer = nn.Linear(hidden_size, output_size)\n", - "\n", - " def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n", - " \"\"\"\n", - " Forward pass of the TFT model.\n", - "\n", - " Parameters\n", - " ----------\n", - " x : Dict[str, torch.Tensor]\n", - " Dictionary containing input tensors:\n", - " - encoder_cat: Categorical encoder features\n", - " - encoder_cont: Continuous encoder features\n", - " - decoder_cat: Categorical decoder features\n", - " - decoder_cont: Continuous decoder features\n", - " - static_categorical_features: Static categorical features\n", - " - static_continuous_features: Static continuous features\n", - "\n", - " Returns\n", - " -------\n", - " Dict[str, torch.Tensor]\n", - " Dictionary containing output tensors:\n", - " - prediction: Prediction output (batch_size, prediction_length, output_size)\n", - " \"\"\"\n", - " batch_size = x[\"encoder_cont\"].shape[0]\n", - "\n", - " encoder_cat = x.get(\n", - " \"encoder_cat\",\n", - " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", - " )\n", - " encoder_cont = x.get(\n", - " \"encoder_cont\",\n", - " torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device),\n", - " )\n", - " decoder_cat = x.get(\n", - " \"decoder_cat\",\n", - " torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", - " )\n", - " decoder_cont = x.get(\n", - " \"decoder_cont\",\n", - " torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device),\n", - " )\n", - "\n", - " encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2)\n", - " decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2)\n", - "\n", - " static_context = None\n", - " if self.static_context_linear is not None:\n", - " static_cat = x.get(\n", - " \"static_categorical_features\",\n", - " torch.zeros(batch_size, 0, device=self.device),\n", - " )\n", - " static_cont = x.get(\n", - " \"static_continuous_features\",\n", - " torch.zeros(batch_size, 0, device=self.device),\n", - " )\n", - "\n", - " if static_cat.size(2) == 0 and static_cont.size(2) == 0:\n", - " static_context = None\n", - " elif static_cat.size(2) == 0:\n", - " static_input = static_cont.to(\n", - " dtype=self.static_context_linear.weight.dtype\n", - " )\n", - " static_context = self.static_context_linear(static_input)\n", - " static_context = static_context.view(batch_size, self.hidden_size)\n", - " elif static_cont.size(2) == 0:\n", - " static_input = static_cat.to(\n", - " dtype=self.static_context_linear.weight.dtype\n", - " )\n", - " static_context = self.static_context_linear(static_input)\n", - " static_context = static_context.view(batch_size, self.hidden_size)\n", - " else:\n", - "\n", - " static_input = torch.cat([static_cont, static_cat], dim=1).to(\n", - " dtype=self.static_context_linear.weight.dtype\n", - " )\n", - " static_context = self.static_context_linear(static_input)\n", - " static_context = static_context.view(batch_size, self.hidden_size)\n", - "\n", - " encoder_weights = self.encoder_var_selection(encoder_input)\n", - " encoder_input = encoder_input * encoder_weights\n", - "\n", - " decoder_weights = self.decoder_var_selection(decoder_input)\n", - " decoder_input = decoder_input * decoder_weights\n", - "\n", - " if static_context is not None:\n", - " encoder_static_context = static_context.unsqueeze(1).expand(\n", - " -1, self.max_encoder_length, -1\n", - " )\n", - " decoder_static_context = static_context.unsqueeze(1).expand(\n", - " -1, self.max_prediction_length, -1\n", - " )\n", - "\n", - " encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n", - " encoder_output = encoder_output + encoder_static_context\n", - " decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n", - " decoder_output = decoder_output + decoder_static_context\n", - " else:\n", - " encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)\n", - " decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))\n", - "\n", - " sequence = torch.cat([encoder_output, decoder_output], dim=1)\n", - "\n", - " if static_context is not None:\n", - " expanded_static_context = static_context.unsqueeze(1).expand(\n", - " -1, sequence.size(1), -1\n", - " )\n", - "\n", - " attended_output, _ = self.self_attention(\n", - " sequence + expanded_static_context, sequence, sequence\n", - " )\n", - " else:\n", - " attended_output, _ = self.self_attention(sequence, sequence, sequence)\n", - "\n", - " decoder_attended = attended_output[:, -self.max_prediction_length :, :]\n", - "\n", - " output = nn.functional.relu(self.pre_output(decoder_attended))\n", - " prediction = self.output_layer(output)\n", - "\n", - " return {\"prediction\": prediction}" + "text/plain": [ + "Testing: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃ Test metric DataLoader 0 ┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ test_MAE 0.4637378454208374 │\n", - "│ test_SMAPE 1.0857858657836914 │\n", - "│ test_loss 0.014832879416644573 │\n", - "└───────────────────────────┴───────────────────────────┘\n", - "\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.4637378454208374 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.0857858657836914 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.014832879416644573 \u001b[0m\u001b[35m \u001b[0m│\n", - "└───────────────────────────┴───────────────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Prediction shape: torch.Size([32, 1, 1])\n", - "First prediction values: [[-0.00597369]]\n", - "First true values: [[-0.09480439]]\n", - "\n", - "TFT model test complete!\n" - ] - } + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃        Test metric               DataLoader 0        ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│         test_MAE              0.45287469029426575    │\n",
+       "│        test_SMAPE              0.942494809627533     │\n",
+       "│         test_loss             0.01396977063268423    │\n",
+       "└───────────────────────────┴───────────────────────────┘\n",
+       "
\n" ], - "source": [ - "model = TFT(\n", - " loss=nn.MSELoss(),\n", - " logging_metrics=[MAE(), SMAPE()],\n", - " optimizer=\"adam\",\n", - " optimizer_params={\"lr\": 1e-3},\n", - " lr_scheduler=\"reduce_lr_on_plateau\",\n", - " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", - " hidden_size=64,\n", - " num_layers=2,\n", - " attention_head_size=4,\n", - " dropout=0.1,\n", - " metadata=data_module.metadata,\n", - ")\n", - "\n", - "print(\"\\nTraining model...\")\n", - "trainer = Trainer(max_epochs=5, accelerator=\"auto\", devices=1, enable_progress_bar=True)\n", - "\n", - "trainer.fit(model, data_module)\n", - "\n", - "print(\"\\nEvaluating model...\")\n", - "test_metrics = trainer.test(model, data_module)\n", - "\n", - "model.eval()\n", - "with torch.no_grad():\n", - " test_batch = next(iter(data_module.test_dataloader()))\n", - " x_test, y_test = test_batch\n", - " y_pred = model(x_test)\n", - "\n", - " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n", - " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n", - " print(\"First true values:\", y_test[0].cpu().numpy())\n", - "print(\"\\nTFT model test complete!\")" + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.45287469029426575 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.942494809627533 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.01396977063268423 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zVRwi2MvLGgc" - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } + "metadata": {}, + "output_type": "display_data" }, - "nbformat": 4, - "nbformat_minor": 0 - } + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Prediction shape: torch.Size([32, 1, 1])\n", + "First prediction values: [[0.11045122]]\n", + "First true values: [[-0.0491814]]\n", + "\n", + "TFT model test complete!\n" + ] + } + ], + "source": [ + "model = TFT(\n", + " loss=nn.MSELoss(),\n", + " logging_metrics=[MAE(), SMAPE()],\n", + " optimizer=\"adam\",\n", + " optimizer_params={\"lr\": 1e-3},\n", + " lr_scheduler=\"reduce_lr_on_plateau\",\n", + " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", + " hidden_size=64,\n", + " num_layers=2,\n", + " attention_head_size=4,\n", + " dropout=0.1,\n", + " metadata=data_module.metadata,\n", + ")\n", + "\n", + "print(\"\\nTraining model...\")\n", + "trainer = Trainer(\n", + " max_epochs=5,\n", + " accelerator=\"auto\",\n", + " devices=1,\n", + " enable_progress_bar=True,\n", + " log_every_n_steps=10,\n", + ")\n", + "\n", + "trainer.fit(model, data_module)\n", + "\n", + "print(\"\\nEvaluating model...\")\n", + "test_metrics = trainer.test(model, data_module)\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " test_batch = next(iter(data_module.test_dataloader()))\n", + " x_test, y_test = test_batch\n", + " y_pred = model(x_test)\n", + "\n", + " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n", + " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n", + " print(\"First true values:\", y_test[0].cpu().numpy())\n", + "print(\"\\nTFT model test complete!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zVRwi2MvLGgc" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From 77cb979808d83cbcfb4e7c3ed5ffd888c0828d31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 13 May 2025 08:14:03 +0200 Subject: [PATCH 26/43] warnings and init attr handling --- pytorch_forecasting/data/data_module.py | 44 +++++++++---- .../data/timeseries/_timeseries_v2.py | 61 +++++++++++-------- 2 files changed, 67 insertions(+), 38 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 9d3ebbedb..690fb6057 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -8,6 +8,7 @@ ####################################################################################### from typing import Any, Dict, List, Optional, Tuple, Union +from warnings import warn from lightning.pytorch import LightningDataModule from sklearn.preprocessing import RobustScaler, StandardScaler @@ -107,33 +108,50 @@ def __init__( num_workers: int = 0, train_val_test_split: tuple = (0.7, 0.15, 0.15), ): - super().__init__() - self.time_series_dataset = time_series_dataset - self.time_series_metadata = time_series_dataset.get_metadata() + self.time_series_dataset = time_series_dataset self.max_encoder_length = max_encoder_length - self.min_encoder_length = min_encoder_length or max_encoder_length + self.min_encoder_length = min_encoder_length self.max_prediction_length = max_prediction_length - self.min_prediction_length = min_prediction_length or max_prediction_length + self.min_prediction_length = min_prediction_length self.min_prediction_idx = min_prediction_idx - self.allow_missing_timesteps = allow_missing_timesteps self.add_relative_time_idx = add_relative_time_idx self.add_target_scales = add_target_scales self.add_encoder_length = add_encoder_length self.randomize_length = randomize_length - + self.target_normalizer = target_normalizer + self.categorical_encoders = categorical_encoders + self.scalers = scalers self.batch_size = batch_size self.num_workers = num_workers self.train_val_test_split = train_val_test_split + warn( + "TimeSeries is part of an experimental rework of the " + "pytorch-forecasting data layer, " + "scheduled for release with v2.0.0. " + "The API is not stable and may change without prior warning. " + "For beta testing, but not for stable production use. " + "Feedback and suggestions are very welcome in " + "pytorch-forecasting issue 1736, " + "https://github.com/sktime/pytorch-forecasting/issues/1736", + UserWarning, + ) + + super().__init__() + + # handle defaults and derived attributes if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto": - self.target_normalizer = RobustScaler() + self._target_normalizer = RobustScaler() else: - self.target_normalizer = target_normalizer + self._target_normalizer = target_normalizer - self.categorical_encoders = _coerce_to_dict(categorical_encoders) - self.scalers = _coerce_to_dict(scalers) + self.time_series_metadata = time_series_dataset.get_metadata() + self._min_prediction_length = min_prediction_length or max_prediction_length + self._min_encoder_length = min_encoder_length or max_encoder_length + self._categorical_encoders = _coerce_to_dict(categorical_encoders) + self._scalers = _coerce_to_dict(scalers) self.categorical_indices = [] self.continuous_indices = [] @@ -237,8 +255,8 @@ def _prepare_metadata(self): { "max_encoder_length": self.max_encoder_length, "max_prediction_length": self.max_prediction_length, - "min_encoder_length": self.min_encoder_length, - "min_prediction_length": self.min_prediction_length, + "min_encoder_length": self._min_encoder_length, + "min_prediction_length": self._min_prediction_length, } ) diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index afa45725b..1f0ba6820 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -5,7 +5,7 @@ """ from typing import Dict, List, Optional, Union -import warnings +from warnings import warn import numpy as np import pandas as pd @@ -94,16 +94,16 @@ def __init__( self.data = data self.data_future = data_future self.time = time - self.target = _coerce_to_list(target) - self.group = _coerce_to_list(group) + self.target = target + self.group = group self.weight = weight - self.num = _coerce_to_list(num) - self.cat = _coerce_to_list(cat) - self.known = _coerce_to_list(known) - self.unknown = _coerce_to_list(unknown) - self.static = _coerce_to_list(static) + self.num = num + self.cat = cat + self.known = known + self.unknown = unknown + self.static = static - warnings.warn( + warn( "TimeSeries is part of an experimental rework of the " "pytorch-forecasting data layer, " "scheduled for release with v2.0.0. " @@ -115,13 +115,24 @@ def __init__( UserWarning, ) + super.__init__() + + # handle defaults, coercion, and derived attributes + self._target = _coerce_to_list(target) + self._group = _coerce_to_list(group) + self._num = _coerce_to_list(num) + self._cat = _coerce_to_list(cat) + self._known = _coerce_to_list(known) + self._unknown = _coerce_to_list(unknown) + self._static = _coerce_to_list(static) + self.feature_cols = [ col for col in data.columns - if col not in [self.time] + self.group + [self.weight] + self.target + if col not in [self.time] + self._group + [self.weight] + self._target ] - if self.group: - self._groups = self.data.groupby(self.group).groups + if self._group: + self._groups = self.data.groupby(self._group).groups self._group_ids = list(self._groups.keys()) else: self._groups = {"_single_group": self.data.index} @@ -148,19 +159,19 @@ def _prepare_metadata(self): """ self.metadata = { "cols": { - "y": self.target, + "y": self._target, "x": self.feature_cols, - "st": self.static, + "st": self._static, }, "col_type": {}, "col_known": {}, } - all_cols = self.target + self.feature_cols + self.static + all_cols = self._target + self.feature_cols + self._static for col in all_cols: - self.metadata["col_type"][col] = "C" if col in self.cat else "F" + self.metadata["col_type"][col] = "C" if col in self._cat else "F" - self.metadata["col_known"][col] = "K" if col in self.known else "U" + self.metadata["col_known"][col] = "K" if col in self._known else "U" def __len__(self) -> int: """Return number of time series in the dataset.""" @@ -197,7 +208,7 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: """ group_id = self._group_ids[index] - if self.group: + if self._group: mask = self._groups[group_id] data = self.data.loc[mask] else: @@ -207,16 +218,16 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: result = { "t": data[self.time].values, - "y": torch.tensor(data[self.target].values), + "y": torch.tensor(data[self._target].values), "x": torch.tensor(data[self.feature_cols].values), "group": torch.tensor([hash(str(group_id))]), - "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), + "st": torch.tensor(data[self._static].iloc[0].values if self._static else []), "cutoff_time": cutoff_time, } if self.data_future is not None: - if self.group: - future_mask = self.data_future.groupby(self.group).groups[group_id] + if self._group: + future_mask = self.data_future.groupby(self._group).groups[group_id] future_data = self.data_future.loc[future_mask] else: future_data = self.data_future @@ -229,18 +240,18 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: num_timepoints = len(combined_times) x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) - y_merged = np.full((num_timepoints, len(self.target)), np.nan) + y_merged = np.full((num_timepoints, len(self._target)), np.nan) current_time_indices = {t: i for i, t in enumerate(combined_times)} for i, t in enumerate(data[self.time].values): idx = current_time_indices[t] x_merged[idx] = data[self.feature_cols].values[i] - y_merged[idx] = data[self.target].values[i] + y_merged[idx] = data[self._target].values[i] for i, t in enumerate(future_data[self.time].values): if t in current_time_indices: idx = current_time_indices[t] - for j, col in enumerate(self.known): + for j, col in enumerate(self._known): if col in self.feature_cols: feature_idx = self.feature_cols.index(col) x_merged[idx, feature_idx] = future_data[col].values[i] From f8c94e626010d165cf022e0fd3f0a22c994759c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 13 May 2025 08:25:53 +0200 Subject: [PATCH 27/43] simplify TimeSeries.__getitem__ --- .../data/timeseries/_timeseries_v2.py | 73 +++++++++++-------- 1 file changed, 44 insertions(+), 29 deletions(-) diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index 1f0ba6820..5e24f6454 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -206,54 +206,69 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: weights : torch.Tensor of shape (n_timepoints,), optional Only included if weights are not `None`. """ - group_id = self._group_ids[index] - - if self._group: - mask = self._groups[group_id] + time = self.time + feature_cols = self.feature_cols + _target = self._target + _known = self._known + _static = self._static + _group = self._group + _groups = self._groups + _group_ids = self._group_ids + weight = self.weight + data_future = self.data_future + + group_id = _group_ids[index] + + if _group: + mask = _groups[group_id] data = self.data.loc[mask] else: data = self.data - cutoff_time = data[self.time].max() + cutoff_time = data[time].max() + + data_vals = data[time].values + data_tgt_vals = data[_target].values + data_feat_vals = data[feature_cols].values result = { - "t": data[self.time].values, - "y": torch.tensor(data[self._target].values), - "x": torch.tensor(data[self.feature_cols].values), + "t": data_vals, + "y": torch.tensor(data_tgt_vals), + "x": torch.tensor(data_feat_vals), "group": torch.tensor([hash(str(group_id))]), - "st": torch.tensor(data[self._static].iloc[0].values if self._static else []), + "st": torch.tensor(data[_static].iloc[0].values if _static else []), "cutoff_time": cutoff_time, } - if self.data_future is not None: - if self._group: - future_mask = self.data_future.groupby(self._group).groups[group_id] + if data_future is not None: + if _group: + future_mask = self.data_future.groupby(_group).groups[group_id] future_data = self.data_future.loc[future_mask] else: future_data = self.data_future - combined_times = np.concatenate( - [data[self.time].values, future_data[self.time].values] - ) + data_fut_vals = future_data[time].values + + combined_times = np.concatenate([data_vals, data_fut_vals]) combined_times = np.unique(combined_times) combined_times.sort() num_timepoints = len(combined_times) - x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) - y_merged = np.full((num_timepoints, len(self._target)), np.nan) + x_merged = np.full((num_timepoints, len(feature_cols)), np.nan) + y_merged = np.full((num_timepoints, len(_target)), np.nan) current_time_indices = {t: i for i, t in enumerate(combined_times)} - for i, t in enumerate(data[self.time].values): + for i, t in enumerate(data_vals): idx = current_time_indices[t] - x_merged[idx] = data[self.feature_cols].values[i] - y_merged[idx] = data[self._target].values[i] + x_merged[idx] = data_feat_vals[i] + y_merged[idx] = data_tgt_vals[i] - for i, t in enumerate(future_data[self.time].values): + for i, t in enumerate(data_fut_vals): if t in current_time_indices: idx = current_time_indices[t] - for j, col in enumerate(self._known): - if col in self.feature_cols: - feature_idx = self.feature_cols.index(col) + for j, col in enumerate(_known): + if col in feature_cols: + feature_idx = feature_cols.index(col) x_merged[idx, feature_idx] = future_data[col].values[i] result.update( @@ -264,17 +279,17 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: } ) - if self.weight: + if weight: if self.data_future is not None and self.weight in self.data_future.columns: weights_merged = np.full(num_timepoints, np.nan) - for i, t in enumerate(data[self.time].values): + for i, t in enumerate(data_vals): idx = current_time_indices[t] - weights_merged[idx] = data[self.weight].values[i] + weights_merged[idx] = data[weight].values[i] - for i, t in enumerate(future_data[self.time].values): + for i, t in enumerate(data_fut_vals): if t in current_time_indices and self.weight in future_data.columns: idx = current_time_indices[t] - weights_merged[idx] = future_data[self.weight].values[i] + weights_merged[idx] = future_data[weight].values[i] result["weights"] = torch.tensor(weights_merged, dtype=torch.float32) else: From c289255286540b96ddcf5667851f06edf7af0c7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 13 May 2025 08:36:17 +0200 Subject: [PATCH 28/43] Update _timeseries_v2.py --- pytorch_forecasting/data/timeseries/_timeseries_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index 5e24f6454..178b273bc 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -115,7 +115,7 @@ def __init__( UserWarning, ) - super.__init__() + super().__init__() # handle defaults, coercion, and derived attributes self._target = _coerce_to_list(target) From 9467f387287f3ba4a56ef1a1a4673c2215deb355 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 13 May 2025 08:44:38 +0200 Subject: [PATCH 29/43] Update data_module.py --- pytorch_forecasting/data/data_module.py | 65 ++++++++++++------------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 690fb6057..7b0d45312 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -171,39 +171,38 @@ def _prepare_metadata(self): dict dictionary containing the following keys: - * ``encoder_cat``: Number of categorical variables in the encoder. - Computed as ``len(self.categorical_indices)``, which counts the - categorical feature indices. - * ``encoder_cont``: Number of continuous variables in the encoder. - Computed as ``len(self.continuous_indices)``, which counts the - continuous feature indices. - * ``decoder_cat``: Number of categorical variables in the decoder that - are known in advance. - Computed by filtering ``self.time_series_metadata["cols"]["x"]`` - where col_type == "C"(categorical) and col_known == "K" (known) - * ``decoder_cont``: Number of continuous variables in the decoder that - are known in advance. - Computed by filtering ``self.time_series_metadata["cols"]["x"]`` - where col_type == "F"(continuous) and col_known == "K"(known) - * ``target``: Number of target variables. - Computed as ``len(self.time_series_metadata["cols"]["y"])``, which - gives the number of output target columns.. - * ``static_categorical_features``: Number of static categorical features - Computed by filtering ``self.time_series_metadata["cols"]["st"]`` - (static features) where col_type == "C" (categorical). - * ``static_continuous_features``: Number of static continuous features - Computed as difference of - ``len(self.time_series_metadata["cols"]["st"])`` (static features) - and static_categorical_features that gives static continuous feature - * ``max_encoder_length``: maximum encoder length - Taken directly from `self.max_encoder_length`. - * ``max_prediction_length``: maximum prediction length - Taken directly from `self.max_prediction_length`. - * ``min_encoder_length``: minimum encoder length - Taken directly from `self.min_encoder_length`. - * ``min_prediction_length``: minimum prediction length - Taken directly from `self.min_prediction_length`. - + * ``encoder_cat``: Number of categorical variables in the encoder. + Computed as ``len(self.categorical_indices)``, which counts the + categorical feature indices. + * ``encoder_cont``: Number of continuous variables in the encoder. + Computed as ``len(self.continuous_indices)``, which counts the + continuous feature indices. + * ``decoder_cat``: Number of categorical variables in the decoder that + are known in advance. + Computed by filtering ``self.time_series_metadata["cols"]["x"]`` + where col_type == "C"(categorical) and col_known == "K" (known) + * ``decoder_cont``: Number of continuous variables in the decoder that + are known in advance. + Computed by filtering ``self.time_series_metadata["cols"]["x"]`` + where col_type == "F"(continuous) and col_known == "K"(known) + * ``target``: Number of target variables. + Computed as ``len(self.time_series_metadata["cols"]["y"])``, which + gives the number of output target columns.. + * ``static_categorical_features``: Number of static categorical features + Computed by filtering ``self.time_series_metadata["cols"]["st"]`` + (static features) where col_type == "C" (categorical). + * ``static_continuous_features``: Number of static continuous features + Computed as difference of + ``len(self.time_series_metadata["cols"]["st"])`` (static features) + and static_categorical_features that gives static continuous feature + * ``max_encoder_length``: maximum encoder length + Taken directly from `self.max_encoder_length`. + * ``max_prediction_length``: maximum prediction length + Taken directly from `self.max_prediction_length`. + * ``min_encoder_length``: minimum encoder length + Taken directly from `self.min_encoder_length`. + * ``min_prediction_length``: minimum prediction length + Taken directly from `self.min_prediction_length`. """ encoder_cat_count = len(self.categorical_indices) encoder_cont_count = len(self.continuous_indices) From c3b40ad0f3298e84b70b12a050614da3909799e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 13 May 2025 08:50:43 +0200 Subject: [PATCH 30/43] backwards compat of private/public attrs --- pytorch_forecasting/data/data_module.py | 8 ++++++++ pytorch_forecasting/data/timeseries/_timeseries_v2.py | 10 ++++++++++ 2 files changed, 18 insertions(+) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 7b0d45312..c8252014d 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -163,6 +163,14 @@ def __init__( else: self.continuous_indices.append(idx) + # overwrite __init__ params for upwards compatibility with AS PRs + # todo: should we avoid this and ensure classes are dataclass-like? + self.min_prediction_length = self._min_prediction_length + self.min_encoder_length = self._min_encoder_length + self.categorical_encoders = self._categorical_encoders + self.scalers = self._scalers + self.target_normalizer = self._target_normalizer + def _prepare_metadata(self): """Prepare metadata for model initialisation. diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index 178b273bc..d5ecbcabb 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -140,6 +140,16 @@ def __init__( self._prepare_metadata() + # overwrite __init__ params for upwards compatibility with AS PRs + # todo: should we avoid this and ensure classes are dataclass-like? + self.group = self._group + self.target = self._target + self.num = self._num + self.cat = self._cat + self.known = self._known + self.unknown = self._unknown + self.static = self._static + def _prepare_metadata(self): """Prepare metadata for the dataset. From 38c28dc031ecebddca3385bb0f1c58b4423a1b35 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 14 May 2025 18:51:05 +0530 Subject: [PATCH 31/43] add tests --- .../tft_version_two.py | 38 +- tests/test_models/test_tft_v2.py | 367 ++++++++++++++++++ 2 files changed, 398 insertions(+), 7 deletions(-) create mode 100644 tests/test_models/test_tft_v2.py diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py index 2bfe407d7..1a1634356 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py @@ -157,11 +157,11 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: if self.static_context_linear is not None: static_cat = x.get( "static_categorical_features", - torch.zeros(batch_size, 0, device=self.device), + torch.zeros(batch_size, 1, 0, device=self.device), ) static_cont = x.get( "static_continuous_features", - torch.zeros(batch_size, 0, device=self.device), + torch.zeros(batch_size, 1, 0, device=self.device), ) if static_cat.size(2) == 0 and static_cont.size(2) == 0: @@ -180,17 +180,41 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: static_context = static_context.view(batch_size, self.hidden_size) else: - static_input = torch.cat([static_cont, static_cat], dim=1).to( + static_input = torch.cat([static_cont, static_cat], dim=2).to( dtype=self.static_context_linear.weight.dtype ) static_context = self.static_context_linear(static_input) static_context = static_context.view(batch_size, self.hidden_size) - encoder_weights = self.encoder_var_selection(encoder_input) - encoder_input = encoder_input * encoder_weights + if self.encoder_var_selection is not None: + encoder_weights = self.encoder_var_selection(encoder_input) + encoder_input = encoder_input * encoder_weights + else: + if self.encoder_input_dim == 0: + encoder_input = torch.zeros( + batch_size, + self.max_encoder_length, + 1, + device=self.device, + dtype=encoder_input.dtype, + ) + else: + encoder_input = encoder_input - decoder_weights = self.decoder_var_selection(decoder_input) - decoder_input = decoder_input * decoder_weights + if self.decoder_var_selection is not None: + decoder_weights = self.decoder_var_selection(decoder_input) + decoder_input = decoder_input * decoder_weights + else: + if self.decoder_input_dim == 0: + decoder_input = torch.zeros( + batch_size, + self.max_prediction_length, + 1, + device=self.device, + dtype=decoder_input.dtype, + ) + else: + decoder_input = decoder_input if static_context is not None: encoder_static_context = static_context.unsqueeze(1).expand( diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py new file mode 100644 index 000000000..e69d3d06d --- /dev/null +++ b/tests/test_models/test_tft_v2.py @@ -0,0 +1,367 @@ +import numpy as np +import pandas as pd +import pytest +import torch +import torch.nn as nn + +from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule +from pytorch_forecasting.data.timeseries import TimeSeries +from pytorch_forecasting.models.temporal_fusion_transformer.tft_version_two import TFT + +BATCH_SIZE_TEST = 2 +MAX_ENCODER_LENGTH_TEST = 10 +MAX_PREDICTION_LENGTH_TEST = 5 +HIDDEN_SIZE_TEST = 8 +OUTPUT_SIZE_TEST = 1 +ATTENTION_HEAD_SIZE_TEST = 2 +NUM_LAYERS_TEST = 1 +DROPOUT_TEST = 0.1 + + +def get_default_test_metadata( + enc_cont=2, + enc_cat=1, + dec_cont=1, + dec_cat=1, + static_cat=1, + static_cont=1, + output_size=OUTPUT_SIZE_TEST, +): + return { + "max_encoder_length": MAX_ENCODER_LENGTH_TEST, + "max_prediction_length": MAX_PREDICTION_LENGTH_TEST, + "encoder_cont": enc_cont, + "encoder_cat": enc_cat, + "decoder_cont": dec_cont, + "decoder_cat": dec_cat, + "static_categorical_features": static_cat, + "static_continuous_features": static_cont, + "target": output_size, + } + + +def create_tft_input_batch_for_test(metadata, batch_size=BATCH_SIZE_TEST, device="cpu"): + def _get_dim_val(key): + return metadata.get(key, 0) + + x = { + "encoder_cont": torch.randn( + batch_size, + metadata["max_encoder_length"], + _get_dim_val("encoder_cont"), + device=device, + ), + "encoder_cat": torch.randn( + batch_size, + metadata["max_encoder_length"], + _get_dim_val("encoder_cat"), + device=device, + ), + "decoder_cont": torch.randn( + batch_size, + metadata["max_prediction_length"], + _get_dim_val("decoder_cont"), + device=device, + ), + "decoder_cat": torch.randn( + batch_size, + metadata["max_prediction_length"], + _get_dim_val("decoder_cat"), + device=device, + ), + "static_categorical_features": torch.randn( + batch_size, 1, _get_dim_val("static_categorical_features"), device=device + ), + "static_continuous_features": torch.randn( + batch_size, 1, _get_dim_val("static_continuous_features"), device=device + ), + "encoder_lengths": torch.full( + (batch_size,), + metadata["max_encoder_length"], + dtype=torch.long, + device=device, + ), + "decoder_lengths": torch.full( + (batch_size,), + metadata["max_prediction_length"], + dtype=torch.long, + device=device, + ), + "groups": torch.arange(batch_size, device=device).unsqueeze(1), + "encoder_time_idx": torch.stack( + [torch.arange(metadata["max_encoder_length"], device=device)] * batch_size + ), + "decoder_time_idx": torch.stack( + [ + torch.arange( + metadata["max_encoder_length"], + metadata["max_encoder_length"] + metadata["max_prediction_length"], + device=device, + ) + ] + * batch_size + ), + "target_scale": torch.ones((batch_size, 1), device=device), + } + return x + + +dummy_loss_for_test = nn.MSELoss() + + +@pytest.fixture(scope="module") +def tft_model_params_fixture_func(): + return { + "loss": dummy_loss_for_test, + "hidden_size": HIDDEN_SIZE_TEST, + "num_layers": NUM_LAYERS_TEST, + "attention_head_size": ATTENTION_HEAD_SIZE_TEST, + "dropout": DROPOUT_TEST, + "output_size": OUTPUT_SIZE_TEST, + } + + +class TestTFTInitialization: + def test_basic_initialization(self, tft_model_params_fixture_func): + metadata = get_default_test_metadata(output_size=OUTPUT_SIZE_TEST) + model = TFT(**tft_model_params_fixture_func, metadata=metadata) + assert model.hidden_size == HIDDEN_SIZE_TEST + assert model.num_layers == NUM_LAYERS_TEST + assert hasattr(model, "metadata") and model.metadata == metadata + assert ( + model.encoder_input_dim + == metadata["encoder_cont"] + metadata["encoder_cat"] + ) + assert ( + model.static_input_dim + == metadata["static_categorical_features"] + + metadata["static_continuous_features"] + ) + assert isinstance(model.lstm_encoder, nn.LSTM) + assert model.lstm_encoder.input_size == max(1, model.encoder_input_dim) + assert isinstance(model.self_attention, nn.MultiheadAttention) + if hasattr(model, "hparams") and model.hparams: + assert model.hparams.get("hidden_size") == HIDDEN_SIZE_TEST + assert model.output_size == OUTPUT_SIZE_TEST + + def test_initialization_no_time_varying_features( + self, tft_model_params_fixture_func + ): + metadata = get_default_test_metadata( + enc_cont=0, enc_cat=0, dec_cont=0, dec_cat=0, output_size=OUTPUT_SIZE_TEST + ) + model = TFT(**tft_model_params_fixture_func, metadata=metadata) + assert model.encoder_input_dim == 0 + assert model.encoder_var_selection is None + assert model.lstm_encoder.input_size == 1 + assert model.decoder_input_dim == 0 + assert model.decoder_var_selection is None + assert model.lstm_decoder.input_size == 1 + + def test_initialization_no_static_features(self, tft_model_params_fixture_func): + metadata = get_default_test_metadata( + static_cat=0, static_cont=0, output_size=OUTPUT_SIZE_TEST + ) + model = TFT(**tft_model_params_fixture_func, metadata=metadata) + assert model.static_input_dim == 0 + assert model.static_context_linear is None + + +class TestTFTForwardPass: + @pytest.mark.parametrize( + "enc_c, enc_k, dec_c, dec_k, stat_c, stat_k", + [ + (2, 1, 1, 1, 1, 1), + (2, 0, 1, 0, 0, 0), + (0, 0, 0, 0, 1, 1), + (0, 0, 0, 0, 0, 0), + (1, 0, 1, 0, 1, 0), + (1, 0, 1, 0, 0, 1), + ], + ) + def test_forward_pass_configs( + self, tft_model_params_fixture_func, enc_c, enc_k, dec_c, dec_k, stat_c, stat_k + ): + current_tft_actual_output_size = tft_model_params_fixture_func["output_size"] + metadata = get_default_test_metadata( + enc_cont=enc_c, + enc_cat=enc_k, + dec_cont=dec_c, + dec_cat=dec_k, + static_cat=stat_c, + static_cont=stat_k, + output_size=current_tft_actual_output_size, + ) + model_params = tft_model_params_fixture_func.copy() + model_params["output_size"] = current_tft_actual_output_size + model = TFT(**model_params, metadata=metadata) + model.eval() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + x = create_tft_input_batch_for_test( + metadata, batch_size=BATCH_SIZE_TEST, device=device + ) + output_dict = model(x) + predictions = output_dict["prediction"] + assert predictions.shape == ( + BATCH_SIZE_TEST, + MAX_PREDICTION_LENGTH_TEST, + current_tft_actual_output_size, + ) + assert not torch.isnan(predictions).any(), "NaNs in prediction" + assert not torch.isinf(predictions).any(), "Infs in prediction" + + +@pytest.fixture +def sample_pandas_data_for_test(): + """Create sample data ensuring all feature columns are numeric (float32).""" + series_len = MAX_ENCODER_LENGTH_TEST + MAX_PREDICTION_LENGTH_TEST + 5 + num_groups = 6 + data = [] + + for i in range(num_groups): + static_cont_val = np.float32(i * 10.0) + static_cat_code = np.float32(i % 2) + + df_group = pd.DataFrame( + { + "time_idx": np.arange(series_len, dtype=np.int64), + "group_id_str": np.repeat(f"g{i}", series_len), + "target": np.random.rand(series_len).astype(np.float32) + i, + "enc_cont1": np.random.rand(series_len).astype(np.float32), + "enc_cat1_codes": np.random.randint(0, 3, series_len).astype( + np.float32 + ), + "dec_known_cont": np.sin(np.arange(series_len) / 5.0).astype( + np.float32 + ), + "dec_known_cat_codes": np.random.randint(0, 2, series_len).astype( + np.float32 + ), + "static_cat_feat_codes": np.full( + series_len, static_cat_code, dtype=np.float32 + ), + "static_cont_feat": np.full( + series_len, static_cont_val, dtype=np.float32 + ), + } + ) + data.append(df_group) + + df = pd.concat(data, ignore_index=True) + + df["group_id"] = df["group_id_str"].astype("category") + df.drop(columns=["group_id_str"], inplace=True) + + return df + + +@pytest.fixture +def timeseries_obj_for_test(sample_pandas_data_for_test): + df = sample_pandas_data_for_test + + return TimeSeries( + data=df, + time="time_idx", + target="target", + group=["group_id"], + num=[ + "enc_cont1", + "enc_cat1_codes", + "dec_known_cont", + "dec_known_cat_codes", + "static_cat_feat_codes", + "static_cont_feat", + ], + cat=[], + known=["dec_known_cont", "dec_known_cat_codes", "time_idx"], + static=["static_cat_feat_codes", "static_cont_feat"], + ) + + +@pytest.fixture +def data_module_for_test(timeseries_obj_for_test): + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=timeseries_obj_for_test, + batch_size=BATCH_SIZE_TEST, + max_encoder_length=MAX_ENCODER_LENGTH_TEST, + max_prediction_length=MAX_PREDICTION_LENGTH_TEST, + train_val_test_split=(0.5, 0.25, 0.25), + num_workers=0, # Added for consistency + ) + dm.setup("fit") + dm.setup("test") + return dm + + +class TestTFTWithDataModule: + def test_model_with_datamodule_integration( + self, tft_model_params_fixture_func, data_module_for_test + ): + dm = data_module_for_test + model_metadata_from_dm = dm.metadata + + assert ( + model_metadata_from_dm["encoder_cont"] == 6 + ), f"Actual encoder_cont: {model_metadata_from_dm['encoder_cont']}" + assert ( + model_metadata_from_dm["encoder_cat"] == 0 + ), f"Actual encoder_cat: {model_metadata_from_dm['encoder_cat']}" + assert ( + model_metadata_from_dm["decoder_cont"] == 2 + ), f"Actual decoder_cont: {model_metadata_from_dm['decoder_cont']}" + assert ( + model_metadata_from_dm["decoder_cat"] == 0 + ), f"Actual decoder_cat: {model_metadata_from_dm['decoder_cat']}" + assert ( + model_metadata_from_dm["static_categorical_features"] == 0 + ), f"Actual static_cat: {model_metadata_from_dm['static_categorical_features']}" + assert ( + model_metadata_from_dm["static_continuous_features"] == 2 + ), f"Actual static_cont: {model_metadata_from_dm['static_continuous_features']}" + assert model_metadata_from_dm["target"] == 1 + + tft_init_args = tft_model_params_fixture_func.copy() + tft_init_args["output_size"] = model_metadata_from_dm["target"] + model = TFT(**tft_init_args, metadata=model_metadata_from_dm) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() + + train_loader = dm.train_dataloader() + batch_x, batch_y = next(iter(train_loader)) + + actual_batch_size = batch_x["encoder_cont"].shape[0] + batch_x = {k: v.to(device) for k, v in batch_x.items()} + batch_y = batch_y.to(device) + + assert ( + batch_x["encoder_cont"].shape[2] == model_metadata_from_dm["encoder_cont"] + ) + assert batch_x["encoder_cat"].shape[2] == model_metadata_from_dm["encoder_cat"] + assert ( + batch_x["decoder_cont"].shape[2] == model_metadata_from_dm["decoder_cont"] + ) + assert batch_x["decoder_cat"].shape[2] == model_metadata_from_dm["decoder_cat"] + # assert ( + # batch_x["static_categorical_features"].shape[2] + # == model_metadata_from_dm["static_categorical_features"] + # ) + # assert ( + # batch_x["static_continuous_features"].shape[2] + # == model_metadata_from_dm["static_continuous_features"] + # ) + + output_dict = model(batch_x) + predictions = output_dict["prediction"] + assert predictions.shape == ( + actual_batch_size, + MAX_PREDICTION_LENGTH_TEST, + model_metadata_from_dm["target"], + ) + assert not torch.isnan(predictions).any() + assert batch_y.shape == ( + actual_batch_size, + MAX_PREDICTION_LENGTH_TEST, + model_metadata_from_dm["target"], + ) From 9d80eb822e47c92e3b542cd70fe98103e00bd829 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 14 May 2025 19:10:57 +0530 Subject: [PATCH 32/43] add tests --- tests/test_models/test_tft_v2.py | 311 +++++++++++++++---------------- 1 file changed, 152 insertions(+), 159 deletions(-) diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py index e69d3d06d..0455ad818 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/test_tft_v2.py @@ -121,95 +121,92 @@ def tft_model_params_fixture_func(): } -class TestTFTInitialization: - def test_basic_initialization(self, tft_model_params_fixture_func): - metadata = get_default_test_metadata(output_size=OUTPUT_SIZE_TEST) - model = TFT(**tft_model_params_fixture_func, metadata=metadata) - assert model.hidden_size == HIDDEN_SIZE_TEST - assert model.num_layers == NUM_LAYERS_TEST - assert hasattr(model, "metadata") and model.metadata == metadata - assert ( - model.encoder_input_dim - == metadata["encoder_cont"] + metadata["encoder_cat"] - ) - assert ( - model.static_input_dim - == metadata["static_categorical_features"] - + metadata["static_continuous_features"] - ) - assert isinstance(model.lstm_encoder, nn.LSTM) - assert model.lstm_encoder.input_size == max(1, model.encoder_input_dim) - assert isinstance(model.self_attention, nn.MultiheadAttention) - if hasattr(model, "hparams") and model.hparams: - assert model.hparams.get("hidden_size") == HIDDEN_SIZE_TEST - assert model.output_size == OUTPUT_SIZE_TEST - - def test_initialization_no_time_varying_features( - self, tft_model_params_fixture_func - ): - metadata = get_default_test_metadata( - enc_cont=0, enc_cat=0, dec_cont=0, dec_cat=0, output_size=OUTPUT_SIZE_TEST - ) - model = TFT(**tft_model_params_fixture_func, metadata=metadata) - assert model.encoder_input_dim == 0 - assert model.encoder_var_selection is None - assert model.lstm_encoder.input_size == 1 - assert model.decoder_input_dim == 0 - assert model.decoder_var_selection is None - assert model.lstm_decoder.input_size == 1 - - def test_initialization_no_static_features(self, tft_model_params_fixture_func): - metadata = get_default_test_metadata( - static_cat=0, static_cont=0, output_size=OUTPUT_SIZE_TEST - ) - model = TFT(**tft_model_params_fixture_func, metadata=metadata) - assert model.static_input_dim == 0 - assert model.static_context_linear is None - - -class TestTFTForwardPass: - @pytest.mark.parametrize( - "enc_c, enc_k, dec_c, dec_k, stat_c, stat_k", - [ - (2, 1, 1, 1, 1, 1), - (2, 0, 1, 0, 0, 0), - (0, 0, 0, 0, 1, 1), - (0, 0, 0, 0, 0, 0), - (1, 0, 1, 0, 1, 0), - (1, 0, 1, 0, 0, 1), - ], +# Converted from TestTFTInitialization class +def test_basic_initialization(tft_model_params_fixture_func): + metadata = get_default_test_metadata(output_size=OUTPUT_SIZE_TEST) + model = TFT(**tft_model_params_fixture_func, metadata=metadata) + assert model.hidden_size == HIDDEN_SIZE_TEST + assert model.num_layers == NUM_LAYERS_TEST + assert hasattr(model, "metadata") and model.metadata == metadata + assert model.encoder_input_dim == metadata["encoder_cont"] + metadata["encoder_cat"] + assert ( + model.static_input_dim + == metadata["static_categorical_features"] + + metadata["static_continuous_features"] ) - def test_forward_pass_configs( - self, tft_model_params_fixture_func, enc_c, enc_k, dec_c, dec_k, stat_c, stat_k - ): - current_tft_actual_output_size = tft_model_params_fixture_func["output_size"] - metadata = get_default_test_metadata( - enc_cont=enc_c, - enc_cat=enc_k, - dec_cont=dec_c, - dec_cat=dec_k, - static_cat=stat_c, - static_cont=stat_k, - output_size=current_tft_actual_output_size, - ) - model_params = tft_model_params_fixture_func.copy() - model_params["output_size"] = current_tft_actual_output_size - model = TFT(**model_params, metadata=metadata) - model.eval() - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - x = create_tft_input_batch_for_test( - metadata, batch_size=BATCH_SIZE_TEST, device=device - ) - output_dict = model(x) - predictions = output_dict["prediction"] - assert predictions.shape == ( - BATCH_SIZE_TEST, - MAX_PREDICTION_LENGTH_TEST, - current_tft_actual_output_size, - ) - assert not torch.isnan(predictions).any(), "NaNs in prediction" - assert not torch.isinf(predictions).any(), "Infs in prediction" + assert isinstance(model.lstm_encoder, nn.LSTM) + assert model.lstm_encoder.input_size == max(1, model.encoder_input_dim) + assert isinstance(model.self_attention, nn.MultiheadAttention) + if hasattr(model, "hparams") and model.hparams: + assert model.hparams.get("hidden_size") == HIDDEN_SIZE_TEST + assert model.output_size == OUTPUT_SIZE_TEST + + +def test_initialization_no_time_varying_features(tft_model_params_fixture_func): + metadata = get_default_test_metadata( + enc_cont=0, enc_cat=0, dec_cont=0, dec_cat=0, output_size=OUTPUT_SIZE_TEST + ) + model = TFT(**tft_model_params_fixture_func, metadata=metadata) + assert model.encoder_input_dim == 0 + assert model.encoder_var_selection is None + assert model.lstm_encoder.input_size == 1 + assert model.decoder_input_dim == 0 + assert model.decoder_var_selection is None + assert model.lstm_decoder.input_size == 1 + + +def test_initialization_no_static_features(tft_model_params_fixture_func): + metadata = get_default_test_metadata( + static_cat=0, static_cont=0, output_size=OUTPUT_SIZE_TEST + ) + model = TFT(**tft_model_params_fixture_func, metadata=metadata) + assert model.static_input_dim == 0 + assert model.static_context_linear is None + + +# Converted from TestTFTForwardPass class +@pytest.mark.parametrize( + "enc_c, enc_k, dec_c, dec_k, stat_c, stat_k", + [ + (2, 1, 1, 1, 1, 1), + (2, 0, 1, 0, 0, 0), + (0, 0, 0, 0, 1, 1), + (0, 0, 0, 0, 0, 0), + (1, 0, 1, 0, 1, 0), + (1, 0, 1, 0, 0, 1), + ], +) +def test_forward_pass_configs( + tft_model_params_fixture_func, enc_c, enc_k, dec_c, dec_k, stat_c, stat_k +): + current_tft_actual_output_size = tft_model_params_fixture_func["output_size"] + metadata = get_default_test_metadata( + enc_cont=enc_c, + enc_cat=enc_k, + dec_cont=dec_c, + dec_cat=dec_k, + static_cat=stat_c, + static_cont=stat_k, + output_size=current_tft_actual_output_size, + ) + model_params = tft_model_params_fixture_func.copy() + model_params["output_size"] = current_tft_actual_output_size + model = TFT(**model_params, metadata=metadata) + model.eval() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + x = create_tft_input_batch_for_test( + metadata, batch_size=BATCH_SIZE_TEST, device=device + ) + output_dict = model(x) + predictions = output_dict["prediction"] + assert predictions.shape == ( + BATCH_SIZE_TEST, + MAX_PREDICTION_LENGTH_TEST, + current_tft_actual_output_size, + ) + assert not torch.isnan(predictions).any(), "NaNs in prediction" + assert not torch.isinf(predictions).any(), "Infs in prediction" @pytest.fixture @@ -294,74 +291,70 @@ def data_module_for_test(timeseries_obj_for_test): return dm -class TestTFTWithDataModule: - def test_model_with_datamodule_integration( - self, tft_model_params_fixture_func, data_module_for_test - ): - dm = data_module_for_test - model_metadata_from_dm = dm.metadata - - assert ( - model_metadata_from_dm["encoder_cont"] == 6 - ), f"Actual encoder_cont: {model_metadata_from_dm['encoder_cont']}" - assert ( - model_metadata_from_dm["encoder_cat"] == 0 - ), f"Actual encoder_cat: {model_metadata_from_dm['encoder_cat']}" - assert ( - model_metadata_from_dm["decoder_cont"] == 2 - ), f"Actual decoder_cont: {model_metadata_from_dm['decoder_cont']}" - assert ( - model_metadata_from_dm["decoder_cat"] == 0 - ), f"Actual decoder_cat: {model_metadata_from_dm['decoder_cat']}" - assert ( - model_metadata_from_dm["static_categorical_features"] == 0 - ), f"Actual static_cat: {model_metadata_from_dm['static_categorical_features']}" - assert ( - model_metadata_from_dm["static_continuous_features"] == 2 - ), f"Actual static_cont: {model_metadata_from_dm['static_continuous_features']}" - assert model_metadata_from_dm["target"] == 1 - - tft_init_args = tft_model_params_fixture_func.copy() - tft_init_args["output_size"] = model_metadata_from_dm["target"] - model = TFT(**tft_init_args, metadata=model_metadata_from_dm) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - model.eval() - - train_loader = dm.train_dataloader() - batch_x, batch_y = next(iter(train_loader)) - - actual_batch_size = batch_x["encoder_cont"].shape[0] - batch_x = {k: v.to(device) for k, v in batch_x.items()} - batch_y = batch_y.to(device) - - assert ( - batch_x["encoder_cont"].shape[2] == model_metadata_from_dm["encoder_cont"] - ) - assert batch_x["encoder_cat"].shape[2] == model_metadata_from_dm["encoder_cat"] - assert ( - batch_x["decoder_cont"].shape[2] == model_metadata_from_dm["decoder_cont"] - ) - assert batch_x["decoder_cat"].shape[2] == model_metadata_from_dm["decoder_cat"] - # assert ( - # batch_x["static_categorical_features"].shape[2] - # == model_metadata_from_dm["static_categorical_features"] - # ) - # assert ( - # batch_x["static_continuous_features"].shape[2] - # == model_metadata_from_dm["static_continuous_features"] - # ) - - output_dict = model(batch_x) - predictions = output_dict["prediction"] - assert predictions.shape == ( - actual_batch_size, - MAX_PREDICTION_LENGTH_TEST, - model_metadata_from_dm["target"], - ) - assert not torch.isnan(predictions).any() - assert batch_y.shape == ( - actual_batch_size, - MAX_PREDICTION_LENGTH_TEST, - model_metadata_from_dm["target"], - ) +# Converted from TestTFTWithDataModule class +def test_model_with_datamodule_integration( + tft_model_params_fixture_func, data_module_for_test +): + dm = data_module_for_test + model_metadata_from_dm = dm.metadata + + assert ( + model_metadata_from_dm["encoder_cont"] == 6 + ), f"Actual encoder_cont: {model_metadata_from_dm['encoder_cont']}" + assert ( + model_metadata_from_dm["encoder_cat"] == 0 + ), f"Actual encoder_cat: {model_metadata_from_dm['encoder_cat']}" + assert ( + model_metadata_from_dm["decoder_cont"] == 2 + ), f"Actual decoder_cont: {model_metadata_from_dm['decoder_cont']}" + assert ( + model_metadata_from_dm["decoder_cat"] == 0 + ), f"Actual decoder_cat: {model_metadata_from_dm['decoder_cat']}" + assert ( + model_metadata_from_dm["static_categorical_features"] == 0 + ), f"Actual static_cat: {model_metadata_from_dm['static_categorical_features']}" + assert ( + model_metadata_from_dm["static_continuous_features"] == 2 + ), f"Actual static_cont: {model_metadata_from_dm['static_continuous_features']}" + assert model_metadata_from_dm["target"] == 1 + + tft_init_args = tft_model_params_fixture_func.copy() + tft_init_args["output_size"] = model_metadata_from_dm["target"] + model = TFT(**tft_init_args, metadata=model_metadata_from_dm) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() + + train_loader = dm.train_dataloader() + batch_x, batch_y = next(iter(train_loader)) + + actual_batch_size = batch_x["encoder_cont"].shape[0] + batch_x = {k: v.to(device) for k, v in batch_x.items()} + batch_y = batch_y.to(device) + + assert batch_x["encoder_cont"].shape[2] == model_metadata_from_dm["encoder_cont"] + assert batch_x["encoder_cat"].shape[2] == model_metadata_from_dm["encoder_cat"] + assert batch_x["decoder_cont"].shape[2] == model_metadata_from_dm["decoder_cont"] + assert batch_x["decoder_cat"].shape[2] == model_metadata_from_dm["decoder_cat"] + # assert ( + # batch_x["static_categorical_features"].shape[2] + # == model_metadata_from_dm["static_categorical_features"] + # ) + # assert ( + # batch_x["static_continuous_features"].shape[2] + # == model_metadata_from_dm["static_continuous_features"] + # ) + + output_dict = model(batch_x) + predictions = output_dict["prediction"] + assert predictions.shape == ( + actual_batch_size, + MAX_PREDICTION_LENGTH_TEST, + model_metadata_from_dm["target"], + ) + assert not torch.isnan(predictions).any() + assert batch_y.shape == ( + actual_batch_size, + MAX_PREDICTION_LENGTH_TEST, + model_metadata_from_dm["target"], + ) From a8ccfe36d383191ba6bd23902543aed40dbe0d39 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 14 May 2025 19:17:36 +0530 Subject: [PATCH 33/43] add tests --- tests/test_models/test_tft_v2.py | 37 +++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py index 0455ad818..ae74d59fc 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/test_tft_v2.py @@ -123,6 +123,14 @@ def tft_model_params_fixture_func(): # Converted from TestTFTInitialization class def test_basic_initialization(tft_model_params_fixture_func): + """Test basic initialization of the TFT model with default metadata. + + Verifies: + - Model attributes match the provided metadata (e.g., hidden_size, num_layers). + - Proper construction of key model components (LSTM, attention, etc.). + - Correct dimensionality of input layers based on metadata. + - Model retains metadata and hyperparameters as expected. + """ metadata = get_default_test_metadata(output_size=OUTPUT_SIZE_TEST) model = TFT(**tft_model_params_fixture_func, metadata=metadata) assert model.hidden_size == HIDDEN_SIZE_TEST @@ -143,6 +151,13 @@ def test_basic_initialization(tft_model_params_fixture_func): def test_initialization_no_time_varying_features(tft_model_params_fixture_func): + """Test TFT initialization with no time-varying (encoder/decoder) features. + + Verifies: + - Model handles zero encoder/decoder input dimensions correctly. + - Skips creation of encoder/decoder variable selection networks. + - Defaults to input size 1 for LSTMs when no time-varying features exist. + """ metadata = get_default_test_metadata( enc_cont=0, enc_cat=0, dec_cont=0, dec_cat=0, output_size=OUTPUT_SIZE_TEST ) @@ -156,6 +171,12 @@ def test_initialization_no_time_varying_features(tft_model_params_fixture_func): def test_initialization_no_static_features(tft_model_params_fixture_func): + """Test TFT initialization with no static features. + + Verifies: + - Model static input dim is 0. + - Static context linear layer is not created. + """ metadata = get_default_test_metadata( static_cat=0, static_cont=0, output_size=OUTPUT_SIZE_TEST ) @@ -179,6 +200,13 @@ def test_initialization_no_static_features(tft_model_params_fixture_func): def test_forward_pass_configs( tft_model_params_fixture_func, enc_c, enc_k, dec_c, dec_k, stat_c, stat_k ): + """Test TFT forward pass across multiple feature configurations. + + Verifies: + - Model can forward pass without errors for varying combinations of input types. + - Output prediction tensor has expected shape. + - Output contains no NaNs or infinities. + """ current_tft_actual_output_size = tft_model_params_fixture_func["output_size"] metadata = get_default_test_metadata( enc_cont=enc_c, @@ -211,7 +239,6 @@ def test_forward_pass_configs( @pytest.fixture def sample_pandas_data_for_test(): - """Create sample data ensuring all feature columns are numeric (float32).""" series_len = MAX_ENCODER_LENGTH_TEST + MAX_PREDICTION_LENGTH_TEST + 5 num_groups = 6 data = [] @@ -295,6 +322,14 @@ def data_module_for_test(timeseries_obj_for_test): def test_model_with_datamodule_integration( tft_model_params_fixture_func, data_module_for_test ): + """Integration test to ensure TFT works correctly with data module. + + Verifies: + - Metadata inferred from data module matches expected input dimensions. + - Model processes real dataloader batches correctly. + - Output and target tensors from model and data module align in shape. + - No NaNs in predictions. + """ dm = data_module_for_test model_metadata_from_dm = dm.metadata From f900ba5e4d4912573e7dc79c398386e683d5e807 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 14 May 2025 19:24:21 +0530 Subject: [PATCH 34/43] add more docstrings --- tests/test_models/test_tft_v2.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py index ae74d59fc..d79eac874 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/test_tft_v2.py @@ -27,6 +27,7 @@ def get_default_test_metadata( static_cont=1, output_size=OUTPUT_SIZE_TEST, ): + """Return a dict representing default metadata for TFT model initialization.""" return { "max_encoder_length": MAX_ENCODER_LENGTH_TEST, "max_prediction_length": MAX_PREDICTION_LENGTH_TEST, @@ -41,6 +42,8 @@ def get_default_test_metadata( def create_tft_input_batch_for_test(metadata, batch_size=BATCH_SIZE_TEST, device="cpu"): + """Create a synthetic input batch dictionary for testing TFT forward passes.""" + def _get_dim_val(key): return metadata.get(key, 0) @@ -111,6 +114,7 @@ def _get_dim_val(key): @pytest.fixture(scope="module") def tft_model_params_fixture_func(): + """Create a default set of model parameters for TFT.""" return { "loss": dummy_loss_for_test, "hidden_size": HIDDEN_SIZE_TEST, @@ -121,7 +125,6 @@ def tft_model_params_fixture_func(): } -# Converted from TestTFTInitialization class def test_basic_initialization(tft_model_params_fixture_func): """Test basic initialization of the TFT model with default metadata. @@ -239,6 +242,7 @@ def test_forward_pass_configs( @pytest.fixture def sample_pandas_data_for_test(): + """Create synthetic multivariate time series data as a pandas DataFrame.""" series_len = MAX_ENCODER_LENGTH_TEST + MAX_PREDICTION_LENGTH_TEST + 5 num_groups = 6 data = [] @@ -282,6 +286,7 @@ def sample_pandas_data_for_test(): @pytest.fixture def timeseries_obj_for_test(sample_pandas_data_for_test): + """Convert sample DataFrame into a TimeSeries object.""" df = sample_pandas_data_for_test return TimeSeries( @@ -305,6 +310,7 @@ def timeseries_obj_for_test(sample_pandas_data_for_test): @pytest.fixture def data_module_for_test(timeseries_obj_for_test): + """Initialize and sets up an EncoderDecoderTimeSeriesDataModule.""" dm = EncoderDecoderTimeSeriesDataModule( time_series_dataset=timeseries_obj_for_test, batch_size=BATCH_SIZE_TEST, @@ -318,7 +324,6 @@ def data_module_for_test(timeseries_obj_for_test): return dm -# Converted from TestTFTWithDataModule class def test_model_with_datamodule_integration( tft_model_params_fixture_func, data_module_for_test ): From ed1b79936df9c4cb18c29393f964228997001b98 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 14 May 2025 19:26:40 +0530 Subject: [PATCH 35/43] add note about the commented out tests --- tests/test_models/test_tft_v2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py index d79eac874..57a50e75e 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/test_tft_v2.py @@ -188,7 +188,6 @@ def test_initialization_no_static_features(tft_model_params_fixture_func): assert model.static_context_linear is None -# Converted from TestTFTForwardPass class @pytest.mark.parametrize( "enc_c, enc_k, dec_c, dec_k, stat_c, stat_k", [ @@ -334,6 +333,8 @@ def test_model_with_datamodule_integration( - Model processes real dataloader batches correctly. - Output and target tensors from model and data module align in shape. - No NaNs in predictions. + + Note: The commented out tests are to test a bug in data_module """ dm = data_module_for_test model_metadata_from_dm = dm.metadata From c0ceb8a16703573144e3d0bd3aa6ab978157a341 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 17 May 2025 02:08:06 +0530 Subject: [PATCH 36/43] add the commented out tests --- tests/test_models/test_tft_v2.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py index 57a50e75e..f541082ce 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/test_tft_v2.py @@ -316,7 +316,6 @@ def data_module_for_test(timeseries_obj_for_test): max_encoder_length=MAX_ENCODER_LENGTH_TEST, max_prediction_length=MAX_PREDICTION_LENGTH_TEST, train_val_test_split=(0.5, 0.25, 0.25), - num_workers=0, # Added for consistency ) dm.setup("fit") dm.setup("test") @@ -377,14 +376,14 @@ def test_model_with_datamodule_integration( assert batch_x["encoder_cat"].shape[2] == model_metadata_from_dm["encoder_cat"] assert batch_x["decoder_cont"].shape[2] == model_metadata_from_dm["decoder_cont"] assert batch_x["decoder_cat"].shape[2] == model_metadata_from_dm["decoder_cat"] - # assert ( - # batch_x["static_categorical_features"].shape[2] - # == model_metadata_from_dm["static_categorical_features"] - # ) - # assert ( - # batch_x["static_continuous_features"].shape[2] - # == model_metadata_from_dm["static_continuous_features"] - # ) + assert ( + batch_x["static_categorical_features"].shape[2] + == model_metadata_from_dm["static_categorical_features"] + ) + assert ( + batch_x["static_continuous_features"].shape[2] + == model_metadata_from_dm["static_continuous_features"] + ) output_dict = model(batch_x) predictions = output_dict["prediction"] From 3828c260d4b32ee7fcd9fc300776126c70f6a3b6 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 17 May 2025 02:09:16 +0530 Subject: [PATCH 37/43] remove note --- tests/test_models/test_tft_v2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py index f541082ce..791ea10ef 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/test_tft_v2.py @@ -332,8 +332,6 @@ def test_model_with_datamodule_integration( - Model processes real dataloader batches correctly. - Output and target tensors from model and data module align in shape. - No NaNs in predictions. - - Note: The commented out tests are to test a bug in data_module """ dm = data_module_for_test model_metadata_from_dm = dm.metadata From 30b541b2910c461e3e488e19137c1242c0b0627b Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 21 May 2025 00:52:29 +0530 Subject: [PATCH 38/43] make the modules private --- .../{base_model_refactor.py => _base_model_v2.py} | 13 +++++++++++++ .../{tft_version_two.py => _tft_v2.py} | 2 +- .../test_models/{test_tft_v2.py => _test_tft_v2.py} | 2 +- 3 files changed, 15 insertions(+), 2 deletions(-) rename pytorch_forecasting/models/base/{base_model_refactor.py => _base_model_v2.py} (93%) rename pytorch_forecasting/models/temporal_fusion_transformer/{tft_version_two.py => _tft_v2.py} (99%) rename tests/test_models/{test_tft_v2.py => _test_tft_v2.py} (99%) diff --git a/pytorch_forecasting/models/base/base_model_refactor.py b/pytorch_forecasting/models/base/_base_model_v2.py similarity index 93% rename from pytorch_forecasting/models/base/base_model_refactor.py rename to pytorch_forecasting/models/base/_base_model_v2.py index ccd2c2600..ddefc29fb 100644 --- a/pytorch_forecasting/models/base/base_model_refactor.py +++ b/pytorch_forecasting/models/base/_base_model_v2.py @@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Tuple, Union +from warnings import warn from lightning.pytorch import LightningModule from lightning.pytorch.utilities.types import STEP_OUTPUT @@ -53,6 +54,18 @@ def __init__( self.lr_scheduler_params = ( lr_scheduler_params if lr_scheduler_params is not None else {} ) + self.model_name = self.__class__.__name__ + warn( + f"The Model '{self.model_name}' is part of an experimental rework" + "of the pytorch-forecasting model layer, scheduled for release with v2.0.0." + " The API is not stable and may change without prior warning. " + "This class is intended for beta testing and as a basic skeleton, " + "but not for stable production use. " + "Feedback and suggestions are very welcome in " + "pytorch-forecasting issue 1736, " + "https://github.com/sktime/pytorch-forecasting/issues/1736", + UserWarning, + ) def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py similarity index 99% rename from pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py rename to pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py index 1a1634356..a0cf7d39e 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py @@ -9,7 +9,7 @@ import torch.nn as nn from torch.optim import Optimizer -from pytorch_forecasting.models.base.base_model_refactor import BaseModel +from pytorch_forecasting.models.base._base_model_v2 import BaseModel class TFT(BaseModel): diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/_test_tft_v2.py similarity index 99% rename from tests/test_models/test_tft_v2.py rename to tests/test_models/_test_tft_v2.py index 791ea10ef..13d92d5db 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/_test_tft_v2.py @@ -6,7 +6,7 @@ from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule from pytorch_forecasting.data.timeseries import TimeSeries -from pytorch_forecasting.models.temporal_fusion_transformer.tft_version_two import TFT +from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT BATCH_SIZE_TEST = 2 MAX_ENCODER_LENGTH_TEST = 10 From 469ddc72cbab5d5b54c4f2ed0301a698b48b0f1f Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 2 Jun 2025 00:51:15 +0530 Subject: [PATCH 39/43] move the nb to tutorials --- docs/source/tutorials/ptf_V2_example.ipynb | 3742 ++++++++++++++++++++ examples/ptf_V2_example.ipynb | 1022 ------ pytorch_forecasting/data/examples.py | 24 + 3 files changed, 3766 insertions(+), 1022 deletions(-) create mode 100644 docs/source/tutorials/ptf_V2_example.ipynb delete mode 100644 examples/ptf_V2_example.ipynb diff --git a/docs/source/tutorials/ptf_V2_example.ipynb b/docs/source/tutorials/ptf_V2_example.ipynb new file mode 100644 index 000000000..9e2c5b61c --- /dev/null +++ b/docs/source/tutorials/ptf_V2_example.ipynb @@ -0,0 +1,3742 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "id": "M7PQerTbI_tM" + }, + "outputs": [], + "source": [ + "from lightning.pytorch import Trainer\n", + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn.preprocessing import RobustScaler, StandardScaler\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import Optimizer\n", + "from torch.utils.data import Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DGTyf3vct-Jk" + }, + "outputs": [], + "source": [ + "from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule\n", + "from pytorch_forecasting.data.encoders import (\n", + " NaNLabelEncoder,\n", + " TorchNormalizer,\n", + ")\n", + "from pytorch_forecasting.data.timeseries import TimeSeries\n", + "from pytorch_forecasting.metrics import MAE, SMAPE\n", + "from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT\n", + "from pytorch_forecasting.data.examples import load_toydata" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "WX-FRdusJSVN", + "outputId": "8a3ff2f5-afac-4aa8-fc9b-197434d23b10" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6708702925213534,\n \"min\": -1.3075502393215899,\n \"max\": 1.3507284163939648,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.5301916832922879,\n 0.9038568051261738,\n -1.0667842789852795\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6751171140336862,\n \"min\": -1.3075502393215899,\n \"max\": 1.3507284163939648,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.6349349027782476,\n 0.6498298627180364,\n -0.6540777514915888\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n -0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2671504392334601,\n \"min\": 0.021493369564058562,\n \"max\": 0.9758804007580262,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.7394229429308296,\n 0.8404599998401439,\n 0.24249095973810553\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", + "type": "dataframe", + "variable_name": "data_df" + }, + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
000-0.0932870.30890801.0000000.5551660
1010.3089080.27633000.9950040.5551660
2020.2763300.48430200.9800670.5551660
3030.4843020.64893800.9553360.5551660
4040.6489380.94295600.9210610.5551660
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "text/plain": [ + " series_id time_idx ... static_feature static_feature_cat\n", + "0 0 0 ... 0.555166 0\n", + "1 0 1 ... 0.555166 0\n", + "2 0 2 ... 0.555166 0\n", + "3 0 3 ... 0.555166 0\n", + "4 0 4 ... 0.555166 0\n", + "\n", + "[5 rows x 8 columns]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "num_series = 100\n", + "seq_length = 50\n", + "data_df = load_toydata(num_series, seq_length)\n", + "data_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "AxxPHK6AKSD2", + "outputId": "ff7c3e3d-7cb7-405b-d668-ed65e66c0dde" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/content/pytorch-forecasting/pytorch_forecasting/data/timeseries/_timeseries_v2.py:105: UserWarning: TimeSeries is part of an experimental rework of the pytorch-forecasting data layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. For beta testing, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", + " warn(\n" + ] + } + ], + "source": [ + "dataset = TimeSeries(\n", + " data=data_df,\n", + " time=\"time_idx\",\n", + " target=\"y\",\n", + " group=[\"series_id\"],\n", + " num=[\"x\", \"future_known_feature\", \"static_feature\"],\n", + " cat=[\"category\", \"static_feature_cat\"],\n", + " known=[\"future_known_feature\"],\n", + " unknown=[\"x\", \"category\"],\n", + " static=[\"static_feature\", \"static_feature_cat\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5U5Lr_ZFKX0s", + "outputId": "3b51be31-c103-4c4a-8eec-1fc72e15bdff" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/content/pytorch-forecasting/pytorch_forecasting/data/data_module.py:129: UserWarning: TimeSeries is part of an experimental rework of the pytorch-forecasting data layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. For beta testing, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", + " warn(\n" + ] + } + ], + "source": [ + "data_module = EncoderDecoderTimeSeriesDataModule(\n", + " time_series_dataset=dataset,\n", + " max_encoder_length=30,\n", + " max_prediction_length=1,\n", + " batch_size=32,\n", + " categorical_encoders={\n", + " \"category\": NaNLabelEncoder(add_nan=True),\n", + " \"static_feature_cat\": NaNLabelEncoder(add_nan=True),\n", + " },\n", + " scalers={\n", + " \"x\": StandardScaler(),\n", + " \"future_known_feature\": StandardScaler(),\n", + " \"static_feature\": StandardScaler(),\n", + " },\n", + " target_normalizer=TorchNormalizer(),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "45783fa479844c42b468e57f59d04893", + "e1233d07b0f54d609f5e80d7db0aa319", + "3da9fc2e0fec4f55a9a04a14d0865f17", + "e309d524e11d4fdfa9833ee588f623f4", + "967c9daec0574b9c9347c764b513722c", + "32dfa18d4a764de4845dd937f3e4f099", + "d1edfa752fcd4f2993cbfa18442cd49d", + "5786865f15ff40b085a5bdc908fafafb", + "42d2d7a93a9c4e4caec8e9cfb9078a8f", + "01a5c809e4254e43a1ce9678173eb4a9", + "43bca6c15e324927a42bdeef4aabcfa0", + "f8c3866ecdef4360a695e16d5b83bb63", + "ec6637f83f4244d3ba2b2a3635a1ef85", + "7378401a35094aa0bd416c33931d7c03", + "5b47c482771448e9af0bf88e684e421f", + "83c313f300244406a536595f694d794b", + "9e5dde09d59845f6b0f78f9ab86f509c", + "e63282c1bc2e49a89036864eea8d1315", + "56917958a5a8402c9cecda208ded82f2", + "f914f3a966e24f8eb0e6c23ec3ed5b61", + "fc4ec9c1331a40eab9aa68a15a7a89fe", + "2e907b4f2f674731933839fe854b7915", + "d2fb0c3c65ee43f3875409f6f603697d", + "941c97f36d2d4c7780fa034234011c22", + "0a19ccb1dabc4f689ebbfc2aa8a2855c", + "c29831a4eaa7492f808f0655993f2907", + "d96e9c61b63f4365aeccd88605cc85d0", + "6a0d5b88cbef4b0c9eee5669fabe29b1", + "66fb47ed48ad4c5d9d8269eeead488c2", + "29aa1ffff46e40758d9a65d6412c1a14", + "37719701ed6e4474a84c8f7bedfef5ea", + "f95586183a2142b2b2458eee62135678", + "049d94a2e6e24c6d95280581a7c680ef", + "24f8555a1ec5453095322162a016c235", + "25d2cdd4420c42ac9ada8fc1c0b4a102", + "b32e630900154feea6b63dcb1e6fc136", + "31738edb3858416baa14371956256152", + "b4d7dafb73af4633a23d1fbf197eb289", + "69f243809bc74cbdaf413772c4184902", + "2f34c1ad49dc46c1b0958e1f78433f7f", + "d98887028e6b4b48bd428d81fa6b74f5", + "9078b5fa96dc41eaa8db4e8b0748e56f", + "9dc235acfcac469ab653a149792dcd0d", + "27e9ab9d495543d7aadd4f418fbf6d9d", + "399c8534021d4f808130d746199f887e", + "ac37c8ac19004341958dd04791943ccf", + "7a52d6588e2440bfa3b6bb78f353d6cb", + "76d7913b8b9c440289a0bfd89bf46674", + "2c624ad70fa442d9ba15d77d32c3b4fb", + "1fcc462941bf4f97813d83f30d04a7e2", + "495b5207e43840e3a6649516295bdbb3", + "ddd73bc058724056ab643b0634f7b5db", + "aef199a692bc4338a63f81b38c94ad90", + "47d85d85ca9e4f5a8ac72e2c5b1d0a0f", + "5a8b4b3a2a214889aed4690c3736aa6f", + "9be2ea5a2a0b4932bdefb4dcc6ed7e3b", + "513343d79e4a4e0492c2a2e29aa79879", + "be5cd97e0aee4455b417f5fbea068f96", + "2902b8469715416fa64c61d3bd26613c", + "051d93365fe148c18fc0fa90af70ea45", + "bcd6636168424a9797d695e67db39dd8", + "36eafed5df36478b80e59d32675f03b1", + "834bf5c188ed497db28148bbd0ab8694", + "891dd939561b4078a091f3536336c11e", + "4b63bd8fd47641968594ec0bb7919955", + "7d60450c60184959accb717f31270816", + "7c19c43c559c40ab8801aaf0bd801355", + "41a0c0a294d94ae7911c6275862786e2", + "bb41bd772f5c4a528998e0369b51dffd", + "a0e9638c78b6405ba95e7d2a89cbe800", + "99820b7ff3c74f6ea2e89d27701786b3", + "ca32b0845d63450ea10b2f74e45b85ab", + "63d9975f089242e1a2e25e9f667615b8", + "243678f5c0ab4f2c8bce997e67490134", + "c2d1809f9e894c3eafec87e34580c5a7", + "2f50b7ba59b546a9a0a139a9880c7439", + "7af2d15ad21c4b54911d9f505ee8a1be", + "cef82d8c32104bae84227b9aa8687824", + "4944ace504f946c79b60161890c968a0", + "7475bee2c4d8459fbb1a8c026c87cc53", + "813b34681d614379ba78e28d7605fff5", + "82db490ebcbc493e9cdcdeed02675bb1", + "3ffd7aa2c8ea4098882524150409893f", + "8197a6b6b0be4b6c828e8a242e275869", + "d4ff8288c3254e0d9fa1b00bcf38aa78", + "d13c22944bdf4b6b882a08956d3475f1", + "c14874c570ea4310a9edfd148ea2fe73", + "8f7a1cbbb81b42cc8262e28dc0d1f278" + ] + }, + "id": "Si7bbZIULBZz", + "outputId": "4bf7f692-fbb8-4e1e-afe8-6f7a4ebe89ca" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/content/pytorch-forecasting/pytorch_forecasting/models/base/_base_model_v2.py:58: UserWarning: The Model 'TFT' is part of an experimental reworkof the pytorch-forecasting model layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. This class is intended for beta testing and as a basic skeleton, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", + " warn(\n", + "INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n", + "INFO:lightning.pytorch.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n", + "INFO: GPU available: False, used: False\n", + "INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False\n", + "INFO: TPU available: False, using: 0 TPU cores\n", + "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", + "INFO: HPU available: False, using: 0 HPUs\n", + "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Training model...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO: \n", + " | Name | Type | Params | Mode \n", + "---------------------------------------------------------------------\n", + "0 | loss | MSELoss | 0 | train\n", + "1 | encoder_var_selection | Sequential | 709 | train\n", + "2 | decoder_var_selection | Sequential | 193 | train\n", + "3 | static_context_linear | Linear | 192 | train\n", + "4 | lstm_encoder | LSTM | 51.5 K | train\n", + "5 | lstm_decoder | LSTM | 50.4 K | train\n", + "6 | self_attention | MultiheadAttention | 16.6 K | train\n", + "7 | pre_output | Linear | 4.2 K | train\n", + "8 | output_layer | Linear | 65 | train\n", + "---------------------------------------------------------------------\n", + "123 K Trainable params\n", + "0 Non-trainable params\n", + "123 K Total params\n", + "0.495 Total estimated model params size (MB)\n", + "18 Modules in train mode\n", + "0 Modules in eval mode\n", + "INFO:lightning.pytorch.callbacks.model_summary:\n", + " | Name | Type | Params | Mode \n", + "---------------------------------------------------------------------\n", + "0 | loss | MSELoss | 0 | train\n", + "1 | encoder_var_selection | Sequential | 709 | train\n", + "2 | decoder_var_selection | Sequential | 193 | train\n", + "3 | static_context_linear | Linear | 192 | train\n", + "4 | lstm_encoder | LSTM | 51.5 K | train\n", + "5 | lstm_decoder | LSTM | 50.4 K | train\n", + "6 | self_attention | MultiheadAttention | 16.6 K | train\n", + "7 | pre_output | Linear | 4.2 K | train\n", + "8 | output_layer | Linear | 65 | train\n", + "---------------------------------------------------------------------\n", + "123 K Trainable params\n", + "0 Non-trainable params\n", + "123 K Total params\n", + "0.495 Total estimated model params size (MB)\n", + "18 Modules in train mode\n", + "0 Modules in eval mode\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "45783fa479844c42b468e57f59d04893", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test_MAE 0.5018836855888367 │\n", + "│ test_SMAPE 1.1201428174972534 │\n", + "│ test_loss 0.012998351827263832 │\n", + "└───────────────────────────┴───────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5018836855888367 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.1201428174972534 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.012998351827263832 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Prediction shape: torch.Size([32, 1, 1])\n", + "First prediction values: [[-0.12206323]]\n", + "First true values: [[-0.04691907]]\n", + "\n", + "TFT model test complete!\n" + ] + } + ], + "source": [ + "model = TFT(\n", + " loss=nn.MSELoss(),\n", + " logging_metrics=[MAE(), SMAPE()],\n", + " optimizer=\"adam\",\n", + " optimizer_params={\"lr\": 1e-3},\n", + " lr_scheduler=\"reduce_lr_on_plateau\",\n", + " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", + " hidden_size=64,\n", + " num_layers=2,\n", + " attention_head_size=4,\n", + " dropout=0.1,\n", + " metadata=data_module.metadata,\n", + ")\n", + "\n", + "print(\"\\nTraining model...\")\n", + "trainer = Trainer(\n", + " max_epochs=5,\n", + " accelerator=\"auto\",\n", + " devices=1,\n", + " enable_progress_bar=True,\n", + " log_every_n_steps=10,\n", + ")\n", + "\n", + "trainer.fit(model, data_module)\n", + "\n", + "print(\"\\nEvaluating model...\")\n", + "test_metrics = trainer.test(model, data_module)\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " test_batch = next(iter(data_module.test_dataloader()))\n", + " x_test, y_test = test_batch\n", + " y_pred = model(x_test)\n", + "\n", + " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n", + " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n", + " print(\"First true values:\", y_test[0].cpu().numpy())\n", + "print(\"\\nTFT model test complete!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zVRwi2MvLGgc" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.3" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "01a5c809e4254e43a1ce9678173eb4a9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "049d94a2e6e24c6d95280581a7c680ef": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "051d93365fe148c18fc0fa90af70ea45": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "0a19ccb1dabc4f689ebbfc2aa8a2855c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_29aa1ffff46e40758d9a65d6412c1a14", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_37719701ed6e4474a84c8f7bedfef5ea", + "tabbable": null, + "tooltip": null, + "value": 9 + } + }, + "1fcc462941bf4f97813d83f30d04a7e2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "243678f5c0ab4f2c8bce997e67490134": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "24f8555a1ec5453095322162a016c235": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_25d2cdd4420c42ac9ada8fc1c0b4a102", + "IPY_MODEL_b32e630900154feea6b63dcb1e6fc136", + "IPY_MODEL_31738edb3858416baa14371956256152" + ], + "layout": "IPY_MODEL_b4d7dafb73af4633a23d1fbf197eb289", + "tabbable": null, + "tooltip": null + } + }, + "25d2cdd4420c42ac9ada8fc1c0b4a102": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_69f243809bc74cbdaf413772c4184902", + "placeholder": "​", + "style": "IPY_MODEL_2f34c1ad49dc46c1b0958e1f78433f7f", + "tabbable": null, + "tooltip": null, + "value": "Validation DataLoader 0: 100%" + } + }, + "27e9ab9d495543d7aadd4f418fbf6d9d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "2902b8469715416fa64c61d3bd26613c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_4b63bd8fd47641968594ec0bb7919955", + "placeholder": "​", + "style": "IPY_MODEL_7d60450c60184959accb717f31270816", + "tabbable": null, + "tooltip": null, + "value": " 9/9 [00:00<00:00, 11.61it/s]" + } + }, + "29aa1ffff46e40758d9a65d6412c1a14": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2c624ad70fa442d9ba15d77d32c3b4fb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "2e907b4f2f674731933839fe854b7915": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "2f34c1ad49dc46c1b0958e1f78433f7f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "2f50b7ba59b546a9a0a139a9880c7439": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "31738edb3858416baa14371956256152": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_9dc235acfcac469ab653a149792dcd0d", + "placeholder": "​", + "style": "IPY_MODEL_27e9ab9d495543d7aadd4f418fbf6d9d", + "tabbable": null, + "tooltip": null, + "value": " 9/9 [00:00<00:00, 11.50it/s]" + } + }, + "32dfa18d4a764de4845dd937f3e4f099": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "36eafed5df36478b80e59d32675f03b1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "37719701ed6e4474a84c8f7bedfef5ea": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "399c8534021d4f808130d746199f887e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_ac37c8ac19004341958dd04791943ccf", + "IPY_MODEL_7a52d6588e2440bfa3b6bb78f353d6cb", + "IPY_MODEL_76d7913b8b9c440289a0bfd89bf46674" + ], + "layout": "IPY_MODEL_2c624ad70fa442d9ba15d77d32c3b4fb", + "tabbable": null, + "tooltip": null + } + }, + "3da9fc2e0fec4f55a9a04a14d0865f17": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_5786865f15ff40b085a5bdc908fafafb", + "max": 2, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_42d2d7a93a9c4e4caec8e9cfb9078a8f", + "tabbable": null, + "tooltip": null, + "value": 2 + } + }, + "3ffd7aa2c8ea4098882524150409893f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "41a0c0a294d94ae7911c6275862786e2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_ca32b0845d63450ea10b2f74e45b85ab", + "placeholder": "​", + "style": "IPY_MODEL_63d9975f089242e1a2e25e9f667615b8", + "tabbable": null, + "tooltip": null, + "value": "Validation DataLoader 0: 100%" + } + }, + "42d2d7a93a9c4e4caec8e9cfb9078a8f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "43bca6c15e324927a42bdeef4aabcfa0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "45783fa479844c42b468e57f59d04893": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e1233d07b0f54d609f5e80d7db0aa319", + "IPY_MODEL_3da9fc2e0fec4f55a9a04a14d0865f17", + "IPY_MODEL_e309d524e11d4fdfa9833ee588f623f4" + ], + "layout": "IPY_MODEL_967c9daec0574b9c9347c764b513722c", + "tabbable": null, + "tooltip": null + } + }, + "47d85d85ca9e4f5a8ac72e2c5b1d0a0f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4944ace504f946c79b60161890c968a0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_3ffd7aa2c8ea4098882524150409893f", + "placeholder": "​", + "style": "IPY_MODEL_8197a6b6b0be4b6c828e8a242e275869", + "tabbable": null, + "tooltip": null, + "value": "Testing DataLoader 0: 100%" + } + }, + "495b5207e43840e3a6649516295bdbb3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "4b63bd8fd47641968594ec0bb7919955": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "513343d79e4a4e0492c2a2e29aa79879": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_bcd6636168424a9797d695e67db39dd8", + "placeholder": "​", + "style": "IPY_MODEL_36eafed5df36478b80e59d32675f03b1", + "tabbable": null, + "tooltip": null, + "value": "Validation DataLoader 0: 100%" + } + }, + "56917958a5a8402c9cecda208ded82f2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5786865f15ff40b085a5bdc908fafafb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5a8b4b3a2a214889aed4690c3736aa6f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "5b47c482771448e9af0bf88e684e421f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_fc4ec9c1331a40eab9aa68a15a7a89fe", + "placeholder": "​", + "style": "IPY_MODEL_2e907b4f2f674731933839fe854b7915", + "tabbable": null, + "tooltip": null, + "value": " 42/42 [00:06<00:00,  6.45it/s, v_num=0, train_loss_step=0.00881, val_loss=0.0153, val_MAE=0.487, val_SMAPE=1.110, train_loss_epoch=0.0182, train_MAE=0.457, train_SMAPE=0.993]" + } + }, + "63d9975f089242e1a2e25e9f667615b8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "66fb47ed48ad4c5d9d8269eeead488c2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "69f243809bc74cbdaf413772c4184902": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6a0d5b88cbef4b0c9eee5669fabe29b1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7378401a35094aa0bd416c33931d7c03": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_56917958a5a8402c9cecda208ded82f2", + "max": 42, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f914f3a966e24f8eb0e6c23ec3ed5b61", + "tabbable": null, + "tooltip": null, + "value": 42 + } + }, + "7475bee2c4d8459fbb1a8c026c87cc53": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_d4ff8288c3254e0d9fa1b00bcf38aa78", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_d13c22944bdf4b6b882a08956d3475f1", + "tabbable": null, + "tooltip": null, + "value": 9 + } + }, + "76d7913b8b9c440289a0bfd89bf46674": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_47d85d85ca9e4f5a8ac72e2c5b1d0a0f", + "placeholder": "​", + "style": "IPY_MODEL_5a8b4b3a2a214889aed4690c3736aa6f", + "tabbable": null, + "tooltip": null, + "value": " 9/9 [00:00<00:00, 11.67it/s]" + } + }, + "7a52d6588e2440bfa3b6bb78f353d6cb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_ddd73bc058724056ab643b0634f7b5db", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_aef199a692bc4338a63f81b38c94ad90", + "tabbable": null, + "tooltip": null, + "value": 9 + } + }, + "7af2d15ad21c4b54911d9f505ee8a1be": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "7c19c43c559c40ab8801aaf0bd801355": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_41a0c0a294d94ae7911c6275862786e2", + "IPY_MODEL_bb41bd772f5c4a528998e0369b51dffd", + "IPY_MODEL_a0e9638c78b6405ba95e7d2a89cbe800" + ], + "layout": "IPY_MODEL_99820b7ff3c74f6ea2e89d27701786b3", + "tabbable": null, + "tooltip": null + } + }, + "7d60450c60184959accb717f31270816": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "813b34681d614379ba78e28d7605fff5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_c14874c570ea4310a9edfd148ea2fe73", + "placeholder": "​", + "style": "IPY_MODEL_8f7a1cbbb81b42cc8262e28dc0d1f278", + "tabbable": null, + "tooltip": null, + "value": " 9/9 [00:00<00:00,  9.42it/s]" + } + }, + "8197a6b6b0be4b6c828e8a242e275869": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "82db490ebcbc493e9cdcdeed02675bb1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "100%" + } + }, + "834bf5c188ed497db28148bbd0ab8694": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "83c313f300244406a536595f694d794b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "100%" + } + }, + "891dd939561b4078a091f3536336c11e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "8f7a1cbbb81b42cc8262e28dc0d1f278": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "9078b5fa96dc41eaa8db4e8b0748e56f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "941c97f36d2d4c7780fa034234011c22": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_6a0d5b88cbef4b0c9eee5669fabe29b1", + "placeholder": "​", + "style": "IPY_MODEL_66fb47ed48ad4c5d9d8269eeead488c2", + "tabbable": null, + "tooltip": null, + "value": "Validation DataLoader 0: 100%" + } + }, + "967c9daec0574b9c9347c764b513722c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "99820b7ff3c74f6ea2e89d27701786b3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "9be2ea5a2a0b4932bdefb4dcc6ed7e3b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_513343d79e4a4e0492c2a2e29aa79879", + "IPY_MODEL_be5cd97e0aee4455b417f5fbea068f96", + "IPY_MODEL_2902b8469715416fa64c61d3bd26613c" + ], + "layout": "IPY_MODEL_051d93365fe148c18fc0fa90af70ea45", + "tabbable": null, + "tooltip": null + } + }, + "9dc235acfcac469ab653a149792dcd0d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9e5dde09d59845f6b0f78f9ab86f509c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a0e9638c78b6405ba95e7d2a89cbe800": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_2f50b7ba59b546a9a0a139a9880c7439", + "placeholder": "​", + "style": "IPY_MODEL_7af2d15ad21c4b54911d9f505ee8a1be", + "tabbable": null, + "tooltip": null, + "value": " 9/9 [00:01<00:00,  8.20it/s]" + } + }, + "ac37c8ac19004341958dd04791943ccf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_1fcc462941bf4f97813d83f30d04a7e2", + "placeholder": "​", + "style": "IPY_MODEL_495b5207e43840e3a6649516295bdbb3", + "tabbable": null, + "tooltip": null, + "value": "Validation DataLoader 0: 100%" + } + }, + "aef199a692bc4338a63f81b38c94ad90": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "b32e630900154feea6b63dcb1e6fc136": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_d98887028e6b4b48bd428d81fa6b74f5", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_9078b5fa96dc41eaa8db4e8b0748e56f", + "tabbable": null, + "tooltip": null, + "value": 9 + } + }, + "b4d7dafb73af4633a23d1fbf197eb289": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "bb41bd772f5c4a528998e0369b51dffd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_243678f5c0ab4f2c8bce997e67490134", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_c2d1809f9e894c3eafec87e34580c5a7", + "tabbable": null, + "tooltip": null, + "value": 9 + } + }, + "bcd6636168424a9797d695e67db39dd8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "be5cd97e0aee4455b417f5fbea068f96": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_834bf5c188ed497db28148bbd0ab8694", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_891dd939561b4078a091f3536336c11e", + "tabbable": null, + "tooltip": null, + "value": 9 + } + }, + "c14874c570ea4310a9edfd148ea2fe73": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c29831a4eaa7492f808f0655993f2907": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_f95586183a2142b2b2458eee62135678", + "placeholder": "​", + "style": "IPY_MODEL_049d94a2e6e24c6d95280581a7c680ef", + "tabbable": null, + "tooltip": null, + "value": " 9/9 [00:00<00:00, 11.33it/s]" + } + }, + "c2d1809f9e894c3eafec87e34580c5a7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "ca32b0845d63450ea10b2f74e45b85ab": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cef82d8c32104bae84227b9aa8687824": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_4944ace504f946c79b60161890c968a0", + "IPY_MODEL_7475bee2c4d8459fbb1a8c026c87cc53", + "IPY_MODEL_813b34681d614379ba78e28d7605fff5" + ], + "layout": "IPY_MODEL_82db490ebcbc493e9cdcdeed02675bb1", + "tabbable": null, + "tooltip": null + } + }, + "d13c22944bdf4b6b882a08956d3475f1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "d1edfa752fcd4f2993cbfa18442cd49d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "d2fb0c3c65ee43f3875409f6f603697d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_941c97f36d2d4c7780fa034234011c22", + "IPY_MODEL_0a19ccb1dabc4f689ebbfc2aa8a2855c", + "IPY_MODEL_c29831a4eaa7492f808f0655993f2907" + ], + "layout": "IPY_MODEL_d96e9c61b63f4365aeccd88605cc85d0", + "tabbable": null, + "tooltip": null + } + }, + "d4ff8288c3254e0d9fa1b00bcf38aa78": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d96e9c61b63f4365aeccd88605cc85d0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "d98887028e6b4b48bd428d81fa6b74f5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ddd73bc058724056ab643b0634f7b5db": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e1233d07b0f54d609f5e80d7db0aa319": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_32dfa18d4a764de4845dd937f3e4f099", + "placeholder": "​", + "style": "IPY_MODEL_d1edfa752fcd4f2993cbfa18442cd49d", + "tabbable": null, + "tooltip": null, + "value": "Sanity Checking DataLoader 0: 100%" + } + }, + "e309d524e11d4fdfa9833ee588f623f4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_01a5c809e4254e43a1ce9678173eb4a9", + "placeholder": "​", + "style": "IPY_MODEL_43bca6c15e324927a42bdeef4aabcfa0", + "tabbable": null, + "tooltip": null, + "value": " 2/2 [00:00<00:00,  2.57it/s]" + } + }, + "e63282c1bc2e49a89036864eea8d1315": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "ec6637f83f4244d3ba2b2a3635a1ef85": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_9e5dde09d59845f6b0f78f9ab86f509c", + "placeholder": "​", + "style": "IPY_MODEL_e63282c1bc2e49a89036864eea8d1315", + "tabbable": null, + "tooltip": null, + "value": "Epoch 4: 100%" + } + }, + "f8c3866ecdef4360a695e16d5b83bb63": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_ec6637f83f4244d3ba2b2a3635a1ef85", + "IPY_MODEL_7378401a35094aa0bd416c33931d7c03", + "IPY_MODEL_5b47c482771448e9af0bf88e684e421f" + ], + "layout": "IPY_MODEL_83c313f300244406a536595f694d794b", + "tabbable": null, + "tooltip": null + } + }, + "f914f3a966e24f8eb0e6c23ec3ed5b61": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "f95586183a2142b2b2458eee62135678": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fc4ec9c1331a40eab9aa68a15a7a89fe": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/ptf_V2_example.ipynb b/examples/ptf_V2_example.ipynb deleted file mode 100644 index 2e39108f3..000000000 --- a/examples/ptf_V2_example.ipynb +++ /dev/null @@ -1,1022 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "2630DaOEI4AJ", - "outputId": "b4cc7100-2a9c-41a3-e890-164b90c91c03" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting pytorch-forecasting\n", - " Downloading pytorch_forecasting-1.3.0-py3-none-any.whl.metadata (13 kB)\n", - "Requirement already satisfied: numpy<=3.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.0.2)\n", - "Requirement already satisfied: torch!=2.0.1,<3.0.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.6.0+cu124)\n", - "Collecting lightning<3.0.0,>=2.0.0 (from pytorch-forecasting)\n", - " Downloading lightning-2.5.1.post0-py3-none-any.whl.metadata (39 kB)\n", - "Requirement already satisfied: scipy<2.0,>=1.8 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.15.2)\n", - "Requirement already satisfied: pandas<3.0.0,>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.2.2)\n", - "Requirement already satisfied: scikit-learn<2.0,>=1.2 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.6.1)\n", - "Requirement already satisfied: PyYAML<8.0,>=5.4 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.0.2)\n", - "Requirement already satisfied: fsspec<2026.0,>=2022.5.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2025.3.2)\n", - "Collecting lightning-utilities<2.0,>=0.10.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)\n", - "Requirement already satisfied: packaging<25.0,>=20.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (24.2)\n", - "Collecting torchmetrics<3.0,>=0.7.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)\n", - "Requirement already satisfied: tqdm<6.0,>=4.57.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.67.1)\n", - "Requirement already satisfied: typing-extensions<6.0,>=4.4.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.13.2)\n", - "Collecting pytorch-lightning (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading pytorch_lightning-2.5.1.post0-py3-none-any.whl.metadata (20 kB)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2.9.0.post0)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", - "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", - "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (1.4.2)\n", - "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (3.6.0)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.18.0)\n", - "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.4.2)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.1.6)\n", - "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", - "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", - "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", - "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", - "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", - "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", - "Collecting nvidia-curand-cu12==10.3.5.147 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", - "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", - "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", - "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (0.6.2)\n", - "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (2.21.5)\n", - "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (12.4.127)\n", - "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", - " Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", - "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.2.0)\n", - "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.13.1)\n", - "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.0)\n", - "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.11.15)\n", - "Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from lightning-utilities<2.0,>=0.10.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (75.2.0)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas<3.0.0,>=1.3.0->pytorch-forecasting) (1.17.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.0.2)\n", - "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2.6.1)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.2)\n", - "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (25.3.0)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.6.0)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.4.3)\n", - "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (0.3.1)\n", - "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.20.0)\n", - "Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.11/dist-packages (from yarl<2.0,>=1.17.0->aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.10)\n", - "Downloading pytorch_forecasting-1.3.0-py3-none-any.whl (197 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m197.7/197.7 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading lightning-2.5.1.post0-py3-none-any.whl (819 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m819.0/819.0 kB\u001b[0m \u001b[31m16.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m37.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m20.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m12.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m64.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading lightning_utilities-0.14.3-py3-none-any.whl (28 kB)\n", - "Downloading torchmetrics-1.7.1-py3-none-any.whl (961 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m961.5/961.5 kB\u001b[0m \u001b[31m35.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading pytorch_lightning-2.5.1.post0-py3-none-any.whl (823 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.1/823.1 kB\u001b[0m \u001b[31m37.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, lightning-utilities, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torchmetrics, pytorch-lightning, lightning, pytorch-forecasting\n", - " Attempting uninstall: nvidia-nvjitlink-cu12\n", - " Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n", - " Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n", - " Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n", - " Attempting uninstall: nvidia-curand-cu12\n", - " Found existing installation: nvidia-curand-cu12 10.3.6.82\n", - " Uninstalling nvidia-curand-cu12-10.3.6.82:\n", - " Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n", - " Attempting uninstall: nvidia-cufft-cu12\n", - " Found existing installation: nvidia-cufft-cu12 11.2.3.61\n", - " Uninstalling nvidia-cufft-cu12-11.2.3.61:\n", - " Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n", - " Attempting uninstall: nvidia-cuda-runtime-cu12\n", - " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n", - " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n", - " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n", - " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n", - " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n", - " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n", - " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n", - " Attempting uninstall: nvidia-cuda-cupti-cu12\n", - " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n", - " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n", - " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n", - " Attempting uninstall: nvidia-cublas-cu12\n", - " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n", - " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n", - " Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n", - " Attempting uninstall: nvidia-cusparse-cu12\n", - " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n", - " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n", - " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n", - " Attempting uninstall: nvidia-cudnn-cu12\n", - " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n", - " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n", - " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n", - " Attempting uninstall: nvidia-cusolver-cu12\n", - " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n", - " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n", - " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n", - "Successfully installed lightning-2.5.1.post0 lightning-utilities-0.14.3 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch-forecasting-1.3.0 pytorch-lightning-2.5.1.post0 torchmetrics-1.7.1\n" - ] - } - ], - "source": [ - "!pip install pytorch-forecasting" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "M7PQerTbI_tM" - }, - "outputs": [], - "source": [ - "from typing import Any, Dict, List, Optional, Tuple, Union\n", - "\n", - "from lightning.pytorch import Trainer\n", - "import numpy as np\n", - "import pandas as pd\n", - "from sklearn.preprocessing import RobustScaler, StandardScaler\n", - "import torch\n", - "import torch.nn as nn\n", - "from torch.optim import Optimizer\n", - "from torch.utils.data import Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "DGTyf3vct-Jk" - }, - "outputs": [], - "source": [ - "from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule\n", - "from pytorch_forecasting.data.encoders import (\n", - " EncoderNormalizer,\n", - " NaNLabelEncoder,\n", - " TorchNormalizer,\n", - ")\n", - "from pytorch_forecasting.data.timeseries import TimeSeries\n", - "from pytorch_forecasting.metrics import MAE, SMAPE\n", - "from pytorch_forecasting.models.temporal_fusion_transformer.tft_version_two import TFT" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 206 - }, - "id": "WX-FRdusJSVN", - "outputId": "d162e241-3076-415c-db39-8c571bbaa282" - }, - "outputs": [ - { - "data": { - "application/vnd.google.colaboratory.intrinsic+json": { - "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6702424716860947,\n \"min\": -1.2572875930191487,\n \"max\": 1.347291996576924,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.40064266948811306,\n 0.688757012378203,\n -0.9278241195910876\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6742204890063661,\n \"min\": -1.2572875930191487,\n \"max\": 1.347291996576924,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.6039073395571968,\n 0.5832480743546181,\n -0.801772762118357\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n -0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2912918986931303,\n \"min\": 0.007584244652032224,\n \"max\": 0.9959799570401108,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.4307103381570838,\n 0.6664272198589233,\n 0.16731443141739688\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", - "type": "dataframe", - "variable_name": "data_df" - }, - "text/html": [ - "\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
0000.2009680.23229801.0000000.687290
1010.2322980.33666900.9950040.687290
2020.3366690.63606300.9800670.687290
3030.6360630.92771000.9553360.687290
4040.9277101.00855400.9210610.687290
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "\n", - "
\n", - " \n", - "\n", - "\n", - "\n", - " \n", - "
\n", - "\n", - "
\n", - "
\n" - ], - "text/plain": [ - " series_id time_idx x y category future_known_feature \\\n", - "0 0 0 0.200968 0.232298 0 1.000000 \n", - "1 0 1 0.232298 0.336669 0 0.995004 \n", - "2 0 2 0.336669 0.636063 0 0.980067 \n", - "3 0 3 0.636063 0.927710 0 0.955336 \n", - "4 0 4 0.927710 1.008554 0 0.921061 \n", - "\n", - " static_feature static_feature_cat \n", - "0 0.68729 0 \n", - "1 0.68729 0 \n", - "2 0.68729 0 \n", - "3 0.68729 0 \n", - "4 0.68729 0 " - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "\n", - "num_series = 100\n", - "seq_length = 50\n", - "data_list = []\n", - "for i in range(num_series):\n", - " x = np.arange(seq_length)\n", - " y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)\n", - " category = i % 5\n", - " static_value = np.random.rand()\n", - " for t in range(seq_length - 1):\n", - " data_list.append(\n", - " {\n", - " \"series_id\": i,\n", - " \"time_idx\": t,\n", - " \"x\": y[t],\n", - " \"y\": y[t + 1],\n", - " \"category\": category,\n", - " \"future_known_feature\": np.cos(t / 10),\n", - " \"static_feature\": static_value,\n", - " \"static_feature_cat\": i % 3,\n", - " }\n", - " )\n", - "data_df = pd.DataFrame(data_list)\n", - "data_df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "AxxPHK6AKSD2", - "outputId": "dd95173d-73c2-451b-8b67-c9cc7298cf9d" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - ":106: UserWarning: TimeSeries is part of an experimental rework of the pytorch-forecasting data layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. For beta testing, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "dataset = TimeSeries(\n", - " data=data_df,\n", - " time=\"time_idx\",\n", - " target=\"y\",\n", - " group=[\"series_id\"],\n", - " num=[\"x\", \"future_known_feature\", \"static_feature\"],\n", - " cat=[\"category\", \"static_feature_cat\"],\n", - " known=[\"future_known_feature\"],\n", - " unknown=[\"x\", \"category\"],\n", - " static=[\"static_feature\", \"static_feature_cat\"],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "5U5Lr_ZFKX0s" - }, - "outputs": [], - "source": [ - "data_module = EncoderDecoderTimeSeriesDataModule(\n", - " time_series_dataset=dataset,\n", - " max_encoder_length=30,\n", - " max_prediction_length=1,\n", - " batch_size=32,\n", - " categorical_encoders={\n", - " \"category\": NaNLabelEncoder(add_nan=True),\n", - " \"static_feature_cat\": NaNLabelEncoder(add_nan=True),\n", - " },\n", - " scalers={\n", - " \"x\": StandardScaler(),\n", - " \"future_known_feature\": StandardScaler(),\n", - " \"static_feature\": StandardScaler(),\n", - " },\n", - " target_normalizer=TorchNormalizer(),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000, - "referenced_widgets": [ - "7de178cd43ab4104ba2445a057a5f1a4", - "40598800b1234eaeb769f18e3c27865c", - "05a21a6e18ca46e280a8c66d6a73cf81", - "e4a86f54cf33447f8959864533cbae72", - "d5ba9114135f4ec1818790a62beb865a", - "39ab0f799ae64a8c9c0d6041b17c0ba5", - "6b98518240b744fe8a708b37a2dcaabe", - "274c42415f834e679f45d02d7d1c01d7", - "c88fdff0012c4e4e85563adb36287774", - "c2490a4409f34a608f6073ff6b9426eb", - "1a5e8e6d619740c9b0fd752a8f886b0c", - "ceff330da01b4ed39eafa8820cb6a5ca", - "d1d794dcb83746a48628280e7d552a70", - "31e24662df144d7196e01b873ffed137", - "07c7871be3274f45a93f01c6003b35fb", - "86ab236d48d04f1980836770cfe61b0d", - "73b46d96f2d54ea8b132b409b0739588", - "0b8814890f1d4143842acce4df31d93e", - "7960939d00d844f094e1bcdc1acda7f9", - "2dcacb2bf93c4d16bfb8657670499fc5", - "ed187855e406486ca0ea2259f3e2f43c", - "c73f370b888f4893be08d51b08b23a87", - "ea890982ba5a4c3d8e15e6bbd7285f6a", - "02f4487acf82401b972b4593da15de15", - "cca521da65e946abad580e9db9f2ac6b", - "b4c9075a5c1148ee8e422b2fcf86d90b", - "57dd80719a2e444cab8345bf5086a2f7", - "14e2d31e04e9447ba4f0c55000e0abe2", - "b35a713d4d7041109db487d10c55aaed", - "59ca144ddc884bb5a1c038e96cdf0dc0", - "95f9b03446af41fb83876f625aae5d75", - "bd2154cf3b04468b93a3bec23b5e34fa", - "f4713a710f274dd389691d7c12a7e740", - "07571714d67e4b8793ee76b0fe151e67", - "4e7273caa91147019971bf75ffe71e49", - "b751dbe4e93341eaa7d1a7683c277d83", - "11e1c349cb894caa9fd77333f7ababb9", - "8c9ac67ac6af488fb16e64c834634a30", - "704f28911d674088b8dfc240c6e28449", - "67da733b7a254020ad4d8d3877eb4494", - "72858abca77d430a8d009ad72127b331", - "86e82ab383a54ef7aa7c0abec6faf1be", - "60f097a304044e3db6dda60ff381776a", - "334e7a44226f4414a825e2d81e9571c1", - "13246130b24145d680092a5a3929546e", - "e64a93136fbf468b8e503cd202dbc986", - "559df10442ac42fc8033bdf014864334", - "67e6440ba0424db18646d47beff2e37b", - "36707c56fcfb4fadbfff37d44ee52d5b", - "484702bf9d854cb6a964de47d6975aa6", - "ddb67e4d1749424299a2c40b36810809", - "6f1e8c2aa06a48548a7750869d3f056e", - "d7df1875fbbb4c78909d76c5b81ffd95", - "cb60f83efc234cd6a90212125fe841a1", - "dfca7d9f0f7844ea953ce4b659493695", - "14ed4e565d5a4b319b0c38557a935b92", - "48e7f65e9dcc48289f44724da248875d", - "8220ff572b31472e91dc5d7553ca43b9", - "fcd1b8b77a7f4c108ba02de9779f5bb2", - "a53e4183893242ea884db9f25e439d94", - "e54fdc83afdf492a9540bc892bdc262a", - "04ce55b4c81a46b09d55b8a0dc3fae00", - "86eddf636c2b45fc93a69f3c8a260b1c", - "287ce5b9bafb402e9698f20779f52386", - "ef80c0b7306046738cfd4e16f5af1afc", - "c0bb5f38ec9346119441c6cb8fd71c5a", - "a54b4228eaf34dc1b40ee8d40500e069", - "28cb77e3b48a4c669dcaa79609d332c2", - "8acc35455cb84ece9e7b3d84d8870a7a", - "afe9a01608f449479cc491f75e095d55", - "9b5971f9e8d44e8a872b782ef49d306a", - "1538291371bc45b2bcb7a687c6b8f79c", - "d595a038c20e4d349e392ff1795dd418", - "a67fa2461a18428892e57790074fd5b6", - "6733a54e084a4d348a85e574985720b7", - "204594abf12e43d4abf0fa35544fb64a", - "eee2d2ef09294652bd94ff768ffb99f0", - "c8b8814037464155935670b776378ab9", - "0b6b787693af4b548ad59ad7ba6c921f", - "6b7172a2b1fd4ddea3dba3d435bf36fe", - "ed0ffcd137e94edebad5e41956ad7466", - "0344079173a84c448cd6edbac612d970", - "890f17882dc44e76be06dd88e6d07cff", - "1c552ec98dcd4325800ee8c9dddc398a", - "89d54947a3d146e4b089a22a2711a234", - "91e7ef398d1646bcbfb598480d949c74", - "13173089480c49b18c4ba016da69a855", - "6b717d2759ca445da7642a2cf20f22b5" - ] - }, - "id": "Si7bbZIULBZz", - "outputId": "0b2f26b3-e37c-4ab6-c234-90693745a8cd" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n", - "INFO:lightning.pytorch.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n", - "INFO: GPU available: False, used: False\n", - "INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False\n", - "INFO: TPU available: False, using: 0 TPU cores\n", - "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", - "INFO: HPU available: False, using: 0 HPUs\n", - "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n", - "INFO: \n", - " | Name | Type | Params | Mode \n", - "---------------------------------------------------------------------\n", - "0 | loss | MSELoss | 0 | train\n", - "1 | encoder_var_selection | Sequential | 709 | train\n", - "2 | decoder_var_selection | Sequential | 193 | train\n", - "3 | static_context_linear | Linear | 192 | train\n", - "4 | lstm_encoder | LSTM | 51.5 K | train\n", - "5 | lstm_decoder | LSTM | 50.4 K | train\n", - "6 | self_attention | MultiheadAttention | 16.6 K | train\n", - "7 | pre_output | Linear | 4.2 K | train\n", - "8 | output_layer | Linear | 65 | train\n", - "---------------------------------------------------------------------\n", - "123 K Trainable params\n", - "0 Non-trainable params\n", - "123 K Total params\n", - "0.495 Total estimated model params size (MB)\n", - "18 Modules in train mode\n", - "0 Modules in eval mode\n", - "INFO:lightning.pytorch.callbacks.model_summary:\n", - " | Name | Type | Params | Mode \n", - "---------------------------------------------------------------------\n", - "0 | loss | MSELoss | 0 | train\n", - "1 | encoder_var_selection | Sequential | 709 | train\n", - "2 | decoder_var_selection | Sequential | 193 | train\n", - "3 | static_context_linear | Linear | 192 | train\n", - "4 | lstm_encoder | LSTM | 51.5 K | train\n", - "5 | lstm_decoder | LSTM | 50.4 K | train\n", - "6 | self_attention | MultiheadAttention | 16.6 K | train\n", - "7 | pre_output | Linear | 4.2 K | train\n", - "8 | output_layer | Linear | 65 | train\n", - "---------------------------------------------------------------------\n", - "123 K Trainable params\n", - "0 Non-trainable params\n", - "123 K Total params\n", - "0.495 Total estimated model params size (MB)\n", - "18 Modules in train mode\n", - "0 Modules in eval mode\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Training model...\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7de178cd43ab4104ba2445a057a5f1a4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Sanity Checking: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃ Test metric DataLoader 0 ┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ test_MAE 0.45287469029426575 │\n", - "│ test_SMAPE 0.942494809627533 │\n", - "│ test_loss 0.01396977063268423 │\n", - "└───────────────────────────┴───────────────────────────┘\n", - "\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.45287469029426575 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.942494809627533 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.01396977063268423 \u001b[0m\u001b[35m \u001b[0m│\n", - "└───────────────────────────┴───────────────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Prediction shape: torch.Size([32, 1, 1])\n", - "First prediction values: [[0.11045122]]\n", - "First true values: [[-0.0491814]]\n", - "\n", - "TFT model test complete!\n" - ] - } - ], - "source": [ - "model = TFT(\n", - " loss=nn.MSELoss(),\n", - " logging_metrics=[MAE(), SMAPE()],\n", - " optimizer=\"adam\",\n", - " optimizer_params={\"lr\": 1e-3},\n", - " lr_scheduler=\"reduce_lr_on_plateau\",\n", - " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", - " hidden_size=64,\n", - " num_layers=2,\n", - " attention_head_size=4,\n", - " dropout=0.1,\n", - " metadata=data_module.metadata,\n", - ")\n", - "\n", - "print(\"\\nTraining model...\")\n", - "trainer = Trainer(\n", - " max_epochs=5,\n", - " accelerator=\"auto\",\n", - " devices=1,\n", - " enable_progress_bar=True,\n", - " log_every_n_steps=10,\n", - ")\n", - "\n", - "trainer.fit(model, data_module)\n", - "\n", - "print(\"\\nEvaluating model...\")\n", - "test_metrics = trainer.test(model, data_module)\n", - "\n", - "model.eval()\n", - "with torch.no_grad():\n", - " test_batch = next(iter(data_module.test_dataloader()))\n", - " x_test, y_test = test_batch\n", - " y_pred = model(x_test)\n", - "\n", - " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n", - " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n", - " print(\"First true values:\", y_test[0].cpu().numpy())\n", - "print(\"\\nTFT model test complete!\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zVRwi2MvLGgc" - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/pytorch_forecasting/data/examples.py b/pytorch_forecasting/data/examples.py index 1adab65d6..dc24b349f 100644 --- a/pytorch_forecasting/data/examples.py +++ b/pytorch_forecasting/data/examples.py @@ -109,3 +109,27 @@ def generate_ar_data( ) return data + + +def load_toydata(num_series, seq_length): + data_list = [] + for i in range(num_series): + x = np.arange(seq_length) + y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length) + category = i % 5 + static_value = np.random.rand() + for t in range(seq_length - 1): + data_list.append( + { + "series_id": i, + "time_idx": t, + "x": y[t], + "y": y[t + 1], + "category": category, + "future_known_feature": np.cos(t / 10), + "static_feature": static_value, + "static_feature_cat": i % 3, + } + ) + data_df = pd.DataFrame(data_list) + return data_df From 2c1fd59b8bc2dcc5c1f1a0976c35cef7ddacf26c Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 2 Jun 2025 00:55:09 +0530 Subject: [PATCH 40/43] update notebook --- docs/source/tutorials/ptf_V2_example.ipynb | 2884 -------------------- 1 file changed, 2884 deletions(-) diff --git a/docs/source/tutorials/ptf_V2_example.ipynb b/docs/source/tutorials/ptf_V2_example.ipynb index 9e2c5b61c..5912c962f 100644 --- a/docs/source/tutorials/ptf_V2_example.ipynb +++ b/docs/source/tutorials/ptf_V2_example.ipynb @@ -851,2890 +851,6 @@ "language_info": { "name": "python", "version": "3.12.3" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "01a5c809e4254e43a1ce9678173eb4a9": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "049d94a2e6e24c6d95280581a7c680ef": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "051d93365fe148c18fc0fa90af70ea45": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "0a19ccb1dabc4f689ebbfc2aa8a2855c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_29aa1ffff46e40758d9a65d6412c1a14", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_37719701ed6e4474a84c8f7bedfef5ea", - "tabbable": null, - "tooltip": null, - "value": 9 - } - }, - "1fcc462941bf4f97813d83f30d04a7e2": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "243678f5c0ab4f2c8bce997e67490134": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "24f8555a1ec5453095322162a016c235": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_25d2cdd4420c42ac9ada8fc1c0b4a102", - "IPY_MODEL_b32e630900154feea6b63dcb1e6fc136", - "IPY_MODEL_31738edb3858416baa14371956256152" - ], - "layout": "IPY_MODEL_b4d7dafb73af4633a23d1fbf197eb289", - "tabbable": null, - "tooltip": null - } - }, - "25d2cdd4420c42ac9ada8fc1c0b4a102": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_69f243809bc74cbdaf413772c4184902", - "placeholder": "​", - "style": "IPY_MODEL_2f34c1ad49dc46c1b0958e1f78433f7f", - "tabbable": null, - "tooltip": null, - "value": "Validation DataLoader 0: 100%" - } - }, - "27e9ab9d495543d7aadd4f418fbf6d9d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "2902b8469715416fa64c61d3bd26613c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_4b63bd8fd47641968594ec0bb7919955", - "placeholder": "​", - "style": "IPY_MODEL_7d60450c60184959accb717f31270816", - "tabbable": null, - "tooltip": null, - "value": " 9/9 [00:00<00:00, 11.61it/s]" - } - }, - "29aa1ffff46e40758d9a65d6412c1a14": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2c624ad70fa442d9ba15d77d32c3b4fb": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "2e907b4f2f674731933839fe854b7915": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "2f34c1ad49dc46c1b0958e1f78433f7f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "2f50b7ba59b546a9a0a139a9880c7439": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "31738edb3858416baa14371956256152": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_9dc235acfcac469ab653a149792dcd0d", - "placeholder": "​", - "style": "IPY_MODEL_27e9ab9d495543d7aadd4f418fbf6d9d", - "tabbable": null, - "tooltip": null, - "value": " 9/9 [00:00<00:00, 11.50it/s]" - } - }, - "32dfa18d4a764de4845dd937f3e4f099": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "36eafed5df36478b80e59d32675f03b1": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "37719701ed6e4474a84c8f7bedfef5ea": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "399c8534021d4f808130d746199f887e": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_ac37c8ac19004341958dd04791943ccf", - "IPY_MODEL_7a52d6588e2440bfa3b6bb78f353d6cb", - "IPY_MODEL_76d7913b8b9c440289a0bfd89bf46674" - ], - "layout": "IPY_MODEL_2c624ad70fa442d9ba15d77d32c3b4fb", - "tabbable": null, - "tooltip": null - } - }, - "3da9fc2e0fec4f55a9a04a14d0865f17": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_5786865f15ff40b085a5bdc908fafafb", - "max": 2, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_42d2d7a93a9c4e4caec8e9cfb9078a8f", - "tabbable": null, - "tooltip": null, - "value": 2 - } - }, - "3ffd7aa2c8ea4098882524150409893f": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "41a0c0a294d94ae7911c6275862786e2": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_ca32b0845d63450ea10b2f74e45b85ab", - "placeholder": "​", - "style": "IPY_MODEL_63d9975f089242e1a2e25e9f667615b8", - "tabbable": null, - "tooltip": null, - "value": "Validation DataLoader 0: 100%" - } - }, - "42d2d7a93a9c4e4caec8e9cfb9078a8f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "43bca6c15e324927a42bdeef4aabcfa0": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "45783fa479844c42b468e57f59d04893": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_e1233d07b0f54d609f5e80d7db0aa319", - "IPY_MODEL_3da9fc2e0fec4f55a9a04a14d0865f17", - "IPY_MODEL_e309d524e11d4fdfa9833ee588f623f4" - ], - "layout": "IPY_MODEL_967c9daec0574b9c9347c764b513722c", - "tabbable": null, - "tooltip": null - } - }, - "47d85d85ca9e4f5a8ac72e2c5b1d0a0f": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "4944ace504f946c79b60161890c968a0": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_3ffd7aa2c8ea4098882524150409893f", - "placeholder": "​", - "style": "IPY_MODEL_8197a6b6b0be4b6c828e8a242e275869", - "tabbable": null, - "tooltip": null, - "value": "Testing DataLoader 0: 100%" - } - }, - "495b5207e43840e3a6649516295bdbb3": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "4b63bd8fd47641968594ec0bb7919955": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "513343d79e4a4e0492c2a2e29aa79879": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_bcd6636168424a9797d695e67db39dd8", - "placeholder": "​", - "style": "IPY_MODEL_36eafed5df36478b80e59d32675f03b1", - "tabbable": null, - "tooltip": null, - "value": "Validation DataLoader 0: 100%" - } - }, - "56917958a5a8402c9cecda208ded82f2": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "5786865f15ff40b085a5bdc908fafafb": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "5a8b4b3a2a214889aed4690c3736aa6f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "5b47c482771448e9af0bf88e684e421f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_fc4ec9c1331a40eab9aa68a15a7a89fe", - "placeholder": "​", - "style": "IPY_MODEL_2e907b4f2f674731933839fe854b7915", - "tabbable": null, - "tooltip": null, - "value": " 42/42 [00:06<00:00,  6.45it/s, v_num=0, train_loss_step=0.00881, val_loss=0.0153, val_MAE=0.487, val_SMAPE=1.110, train_loss_epoch=0.0182, train_MAE=0.457, train_SMAPE=0.993]" - } - }, - "63d9975f089242e1a2e25e9f667615b8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "66fb47ed48ad4c5d9d8269eeead488c2": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "69f243809bc74cbdaf413772c4184902": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "6a0d5b88cbef4b0c9eee5669fabe29b1": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "7378401a35094aa0bd416c33931d7c03": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_56917958a5a8402c9cecda208ded82f2", - "max": 42, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_f914f3a966e24f8eb0e6c23ec3ed5b61", - "tabbable": null, - "tooltip": null, - "value": 42 - } - }, - "7475bee2c4d8459fbb1a8c026c87cc53": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_d4ff8288c3254e0d9fa1b00bcf38aa78", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_d13c22944bdf4b6b882a08956d3475f1", - "tabbable": null, - "tooltip": null, - "value": 9 - } - }, - "76d7913b8b9c440289a0bfd89bf46674": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_47d85d85ca9e4f5a8ac72e2c5b1d0a0f", - "placeholder": "​", - "style": "IPY_MODEL_5a8b4b3a2a214889aed4690c3736aa6f", - "tabbable": null, - "tooltip": null, - "value": " 9/9 [00:00<00:00, 11.67it/s]" - } - }, - "7a52d6588e2440bfa3b6bb78f353d6cb": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_ddd73bc058724056ab643b0634f7b5db", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_aef199a692bc4338a63f81b38c94ad90", - "tabbable": null, - "tooltip": null, - "value": 9 - } - }, - "7af2d15ad21c4b54911d9f505ee8a1be": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "7c19c43c559c40ab8801aaf0bd801355": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_41a0c0a294d94ae7911c6275862786e2", - "IPY_MODEL_bb41bd772f5c4a528998e0369b51dffd", - "IPY_MODEL_a0e9638c78b6405ba95e7d2a89cbe800" - ], - "layout": "IPY_MODEL_99820b7ff3c74f6ea2e89d27701786b3", - "tabbable": null, - "tooltip": null - } - }, - "7d60450c60184959accb717f31270816": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "813b34681d614379ba78e28d7605fff5": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_c14874c570ea4310a9edfd148ea2fe73", - "placeholder": "​", - "style": "IPY_MODEL_8f7a1cbbb81b42cc8262e28dc0d1f278", - "tabbable": null, - "tooltip": null, - "value": " 9/9 [00:00<00:00,  9.42it/s]" - } - }, - "8197a6b6b0be4b6c828e8a242e275869": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "82db490ebcbc493e9cdcdeed02675bb1": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": "100%" - } - }, - "834bf5c188ed497db28148bbd0ab8694": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "83c313f300244406a536595f694d794b": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": "100%" - } - }, - "891dd939561b4078a091f3536336c11e": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "8f7a1cbbb81b42cc8262e28dc0d1f278": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "9078b5fa96dc41eaa8db4e8b0748e56f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "941c97f36d2d4c7780fa034234011c22": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_6a0d5b88cbef4b0c9eee5669fabe29b1", - "placeholder": "​", - "style": "IPY_MODEL_66fb47ed48ad4c5d9d8269eeead488c2", - "tabbable": null, - "tooltip": null, - "value": "Validation DataLoader 0: 100%" - } - }, - "967c9daec0574b9c9347c764b513722c": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "99820b7ff3c74f6ea2e89d27701786b3": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "9be2ea5a2a0b4932bdefb4dcc6ed7e3b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_513343d79e4a4e0492c2a2e29aa79879", - "IPY_MODEL_be5cd97e0aee4455b417f5fbea068f96", - "IPY_MODEL_2902b8469715416fa64c61d3bd26613c" - ], - "layout": "IPY_MODEL_051d93365fe148c18fc0fa90af70ea45", - "tabbable": null, - "tooltip": null - } - }, - "9dc235acfcac469ab653a149792dcd0d": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "9e5dde09d59845f6b0f78f9ab86f509c": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "a0e9638c78b6405ba95e7d2a89cbe800": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_2f50b7ba59b546a9a0a139a9880c7439", - "placeholder": "​", - "style": "IPY_MODEL_7af2d15ad21c4b54911d9f505ee8a1be", - "tabbable": null, - "tooltip": null, - "value": " 9/9 [00:01<00:00,  8.20it/s]" - } - }, - "ac37c8ac19004341958dd04791943ccf": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_1fcc462941bf4f97813d83f30d04a7e2", - "placeholder": "​", - "style": "IPY_MODEL_495b5207e43840e3a6649516295bdbb3", - "tabbable": null, - "tooltip": null, - "value": "Validation DataLoader 0: 100%" - } - }, - "aef199a692bc4338a63f81b38c94ad90": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "b32e630900154feea6b63dcb1e6fc136": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_d98887028e6b4b48bd428d81fa6b74f5", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_9078b5fa96dc41eaa8db4e8b0748e56f", - "tabbable": null, - "tooltip": null, - "value": 9 - } - }, - "b4d7dafb73af4633a23d1fbf197eb289": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "bb41bd772f5c4a528998e0369b51dffd": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_243678f5c0ab4f2c8bce997e67490134", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_c2d1809f9e894c3eafec87e34580c5a7", - "tabbable": null, - "tooltip": null, - "value": 9 - } - }, - "bcd6636168424a9797d695e67db39dd8": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "be5cd97e0aee4455b417f5fbea068f96": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_834bf5c188ed497db28148bbd0ab8694", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_891dd939561b4078a091f3536336c11e", - "tabbable": null, - "tooltip": null, - "value": 9 - } - }, - "c14874c570ea4310a9edfd148ea2fe73": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "c29831a4eaa7492f808f0655993f2907": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_f95586183a2142b2b2458eee62135678", - "placeholder": "​", - "style": "IPY_MODEL_049d94a2e6e24c6d95280581a7c680ef", - "tabbable": null, - "tooltip": null, - "value": " 9/9 [00:00<00:00, 11.33it/s]" - } - }, - "c2d1809f9e894c3eafec87e34580c5a7": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "ca32b0845d63450ea10b2f74e45b85ab": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "cef82d8c32104bae84227b9aa8687824": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_4944ace504f946c79b60161890c968a0", - "IPY_MODEL_7475bee2c4d8459fbb1a8c026c87cc53", - "IPY_MODEL_813b34681d614379ba78e28d7605fff5" - ], - "layout": "IPY_MODEL_82db490ebcbc493e9cdcdeed02675bb1", - "tabbable": null, - "tooltip": null - } - }, - "d13c22944bdf4b6b882a08956d3475f1": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "d1edfa752fcd4f2993cbfa18442cd49d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "d2fb0c3c65ee43f3875409f6f603697d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_941c97f36d2d4c7780fa034234011c22", - "IPY_MODEL_0a19ccb1dabc4f689ebbfc2aa8a2855c", - "IPY_MODEL_c29831a4eaa7492f808f0655993f2907" - ], - "layout": "IPY_MODEL_d96e9c61b63f4365aeccd88605cc85d0", - "tabbable": null, - "tooltip": null - } - }, - "d4ff8288c3254e0d9fa1b00bcf38aa78": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "d96e9c61b63f4365aeccd88605cc85d0": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "d98887028e6b4b48bd428d81fa6b74f5": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "ddd73bc058724056ab643b0634f7b5db": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "e1233d07b0f54d609f5e80d7db0aa319": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_32dfa18d4a764de4845dd937f3e4f099", - "placeholder": "​", - "style": "IPY_MODEL_d1edfa752fcd4f2993cbfa18442cd49d", - "tabbable": null, - "tooltip": null, - "value": "Sanity Checking DataLoader 0: 100%" - } - }, - "e309d524e11d4fdfa9833ee588f623f4": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_01a5c809e4254e43a1ce9678173eb4a9", - "placeholder": "​", - "style": "IPY_MODEL_43bca6c15e324927a42bdeef4aabcfa0", - "tabbable": null, - "tooltip": null, - "value": " 2/2 [00:00<00:00,  2.57it/s]" - } - }, - "e63282c1bc2e49a89036864eea8d1315": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "ec6637f83f4244d3ba2b2a3635a1ef85": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_9e5dde09d59845f6b0f78f9ab86f509c", - "placeholder": "​", - "style": "IPY_MODEL_e63282c1bc2e49a89036864eea8d1315", - "tabbable": null, - "tooltip": null, - "value": "Epoch 4: 100%" - } - }, - "f8c3866ecdef4360a695e16d5b83bb63": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_ec6637f83f4244d3ba2b2a3635a1ef85", - "IPY_MODEL_7378401a35094aa0bd416c33931d7c03", - "IPY_MODEL_5b47c482771448e9af0bf88e684e421f" - ], - "layout": "IPY_MODEL_83c313f300244406a536595f694d794b", - "tabbable": null, - "tooltip": null - } - }, - "f914f3a966e24f8eb0e6c23ec3ed5b61": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "f95586183a2142b2b2458eee62135678": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "fc4ec9c1331a40eab9aa68a15a7a89fe": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - } - } } }, "nbformat": 4, From c15854c2d11177950abb5cba73397066a1fba5c8 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 4 Jun 2025 00:02:45 +0530 Subject: [PATCH 41/43] add markdown --- docs/source/tutorials/ptf_V2_example.ipynb | 63 ++++++++++++++++++---- 1 file changed, 54 insertions(+), 9 deletions(-) diff --git a/docs/source/tutorials/ptf_V2_example.ipynb b/docs/source/tutorials/ptf_V2_example.ipynb index 5912c962f..062bb8bf7 100644 --- a/docs/source/tutorials/ptf_V2_example.ipynb +++ b/docs/source/tutorials/ptf_V2_example.ipynb @@ -1,5 +1,21 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Example Notebook for a basic vignette for `pytorch-forecasting v2` Data pipeline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + ":warning: The \"Data Pipeline\" showcased here is part of an experimental rework of the `pytorch-forecasting` data layer, planned for release in v2.0.0. The API is currently unstable and subject to change without prior notice. This notebook serves as a basic demonstration of the intended workflow and is not recommended for use in production environments. Feedback and suggestions are highly encouraged — please share them in issue 1736.\n", + "
\n" + ] + }, { "cell_type": "code", "execution_count": 16, @@ -37,9 +53,17 @@ "from pytorch_forecasting.data.examples import load_toydata" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load Data\n", + "We generate a synthetic dataset using `load_toydata` that creates a `pandas` DataFrame with just numerical values as for now the pipeline assumes the data to be numerical only" + ] + }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -373,15 +397,23 @@ } ], "source": [ - "num_series = 100\n", - "seq_length = 50\n", + "num_series = 100 # Number of individual time series to generate\n", + "seq_length = 50 # Length of each time series\n", "data_df = load_toydata(num_series, seq_length)\n", "data_df.head()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Create the dataset and datamodule\n", + "We create a `TimeSeries` dataset instance that returns the raw data in terms of tensors, then this \"raw data\" is sent to the `data_module`that will internally handle the dataloaders and preprocessing" + ] + }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -400,13 +432,14 @@ } ], "source": [ + "# create dataset that returns the raw data in terms of tensors\n", "dataset = TimeSeries(\n", " data=data_df,\n", " time=\"time_idx\",\n", " target=\"y\",\n", " group=[\"series_id\"],\n", - " num=[\"x\", \"future_known_feature\", \"static_feature\"],\n", - " cat=[\"category\", \"static_feature_cat\"],\n", + " num=[\"x\", \"future_known_feature\", \"static_feature\"], # numerical features\n", + " cat=[\"category\", \"static_feature_cat\"], # categorical features\n", " known=[\"future_known_feature\"],\n", " unknown=[\"x\", \"category\"],\n", " static=[\"static_feature\", \"static_feature_cat\"],\n", @@ -415,7 +448,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -434,6 +467,7 @@ } ], "source": [ + "# create the data_module that handles the dataloaders and preprocessing\n", "data_module = EncoderDecoderTimeSeriesDataModule(\n", " time_series_dataset=dataset,\n", " max_encoder_length=30,\n", @@ -452,9 +486,16 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Initialise and train the model" + ] + }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -789,6 +830,7 @@ } ], "source": [ + "# Initialise the Model\n", "model = TFT(\n", " loss=nn.MSELoss(),\n", " logging_metrics=[MAE(), SMAPE()],\n", @@ -800,9 +842,11 @@ " num_layers=2,\n", " attention_head_size=4,\n", " dropout=0.1,\n", - " metadata=data_module.metadata,\n", + " metadata=data_module.metadata, # pass the metadata from the datamodule to the model\n", + " # to initialise important params like `encoder_cont` etc\n", ")\n", "\n", + "# Train the model\n", "print(\"\\nTraining model...\")\n", "trainer = Trainer(\n", " max_epochs=5,\n", @@ -814,6 +858,7 @@ "\n", "trainer.fit(model, data_module)\n", "\n", + "# Evaluate the model\n", "print(\"\\nEvaluating model...\")\n", "test_metrics = trainer.test(model, data_module)\n", "\n", From 396d2a4237c4c1ec102d429ea88c281c95815361 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 7 Jun 2025 02:16:12 +0530 Subject: [PATCH 42/43] update notebook --- docs/source/tutorials/ptf_V2_example.ipynb | 1940 +++++++++++--------- 1 file changed, 1097 insertions(+), 843 deletions(-) diff --git a/docs/source/tutorials/ptf_V2_example.ipynb b/docs/source/tutorials/ptf_V2_example.ipynb index 062bb8bf7..0c97d365f 100644 --- a/docs/source/tutorials/ptf_V2_example.ipynb +++ b/docs/source/tutorials/ptf_V2_example.ipynb @@ -1,903 +1,1157 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Example Notebook for a basic vignette for `pytorch-forecasting v2` Data pipeline" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "
\n", - ":warning: The \"Data Pipeline\" showcased here is part of an experimental rework of the `pytorch-forecasting` data layer, planned for release in v2.0.0. The API is currently unstable and subject to change without prior notice. This notebook serves as a basic demonstration of the intended workflow and is not recommended for use in production environments. Feedback and suggestions are highly encouraged — please share them in issue 1736.\n", - "
\n" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "M7PQerTbI_tM" - }, - "outputs": [], - "source": [ - "from lightning.pytorch import Trainer\n", - "import numpy as np\n", - "import pandas as pd\n", - "from sklearn.preprocessing import RobustScaler, StandardScaler\n", - "import torch\n", - "import torch.nn as nn\n", - "from torch.optim import Optimizer\n", - "from torch.utils.data import Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "DGTyf3vct-Jk" - }, - "outputs": [], - "source": [ - "from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule\n", - "from pytorch_forecasting.data.encoders import (\n", - " NaNLabelEncoder,\n", - " TorchNormalizer,\n", - ")\n", - "from pytorch_forecasting.data.timeseries import TimeSeries\n", - "from pytorch_forecasting.metrics import MAE, SMAPE\n", - "from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT\n", - "from pytorch_forecasting.data.examples import load_toydata" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load Data\n", - "We generate a synthetic dataset using `load_toydata` that creates a `pandas` DataFrame with just numerical values as for now the pipeline assumes the data to be numerical only" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 206 + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "rzVbXsEBxnF-" + }, + "source": [ + "# Example Notebook for a basic vignette for `pytorch-forecasting v2` Model Training and Inference" + ] }, - "id": "WX-FRdusJSVN", - "outputId": "8a3ff2f5-afac-4aa8-fc9b-197434d23b10" - }, - "outputs": [ { - "data": { - "application/vnd.google.colaboratory.intrinsic+json": { - "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6708702925213534,\n \"min\": -1.3075502393215899,\n \"max\": 1.3507284163939648,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.5301916832922879,\n 0.9038568051261738,\n -1.0667842789852795\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6751171140336862,\n \"min\": -1.3075502393215899,\n \"max\": 1.3507284163939648,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.6349349027782476,\n 0.6498298627180364,\n -0.6540777514915888\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n -0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2671504392334601,\n \"min\": 0.021493369564058562,\n \"max\": 0.9758804007580262,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.7394229429308296,\n 0.8404599998401439,\n 0.24249095973810553\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", - "type": "dataframe", - "variable_name": "data_df" + "cell_type": "markdown", + "metadata": { + "id": "yt0uZV7Px-40" }, - "text/html": [ - "\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
000-0.0932870.30890801.0000000.5551660
1010.3089080.27633000.9950040.5551660
2020.2763300.48430200.9800670.5551660
3030.4843020.64893800.9553360.5551660
4040.6489380.94295600.9210610.5551660
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "\n", - " \n", - "\n", - " \n", - "
\n", - "\n", - "\n", - "
\n", - " \n", - "\n", - "\n", - "\n", - " \n", - "
\n", - "\n", - "
\n", - "
\n" - ], - "text/plain": [ - " series_id time_idx ... static_feature static_feature_cat\n", - "0 0 0 ... 0.555166 0\n", - "1 0 1 ... 0.555166 0\n", - "2 0 2 ... 0.555166 0\n", - "3 0 3 ... 0.555166 0\n", - "4 0 4 ... 0.555166 0\n", - "\n", - "[5 rows x 8 columns]" + "source": [ + "
\n", + ":warning: The vignette showcased here is part of an experimental rework of the `pytorch-forecasting` data layer, planned for release in v2.0.0. The API is currently unstable and subject to change without prior notice. This notebook serves as a basic demonstration of the intended workflow and is not recommended for use in production environments. Feedback and suggestions are highly encouraged — please share them in issue 1736.\n", + "
\n" ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "num_series = 100 # Number of individual time series to generate\n", - "seq_length = 50 # Length of each time series\n", - "data_df = load_toydata(num_series, seq_length)\n", - "data_df.head()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Create the dataset and datamodule\n", - "We create a `TimeSeries` dataset instance that returns the raw data in terms of tensors, then this \"raw data\" is sent to the `data_module`that will internally handle the dataloaders and preprocessing" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "AxxPHK6AKSD2", - "outputId": "ff7c3e3d-7cb7-405b-d668-ed65e66c0dde" - }, - "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "/content/pytorch-forecasting/pytorch_forecasting/data/timeseries/_timeseries_v2.py:105: UserWarning: TimeSeries is part of an experimental rework of the pytorch-forecasting data layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. For beta testing, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", - " warn(\n" - ] - } - ], - "source": [ - "# create dataset that returns the raw data in terms of tensors\n", - "dataset = TimeSeries(\n", - " data=data_df,\n", - " time=\"time_idx\",\n", - " target=\"y\",\n", - " group=[\"series_id\"],\n", - " num=[\"x\", \"future_known_feature\", \"static_feature\"], # numerical features\n", - " cat=[\"category\", \"static_feature_cat\"], # categorical features\n", - " known=[\"future_known_feature\"],\n", - " unknown=[\"x\", \"category\"],\n", - " static=[\"static_feature\", \"static_feature_cat\"],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "markdown", + "metadata": { + "id": "6D9ARyp05R0t" + }, + "source": [ + "In this vignette, we demonstrate how to train and evaluate the **Temporal Fusion Transformer (TFT)** using the new `TimeSeries` and `DataModule` API from the v2 pipeline.\n", + "\n", + "\n", + "## Steps\n", + "\n", + "1. **Load Data** \n", + "2. **Create Dataset & DataModule** \n", + "3. **Initialize, Train & Run Inference with the Model**\n", + "\n", + "\n", + "\n", + "### Load Data\n", + "\n", + "We generate a synthetic dataset using `load_toydata` which returns a `pandas` DataFrame with purely numerical values. \n", + "*(Note: The current pipeline assumes all inputs are numerical only.)*\n", + "\n", + "\n", + "\n", + "\n", + "### Create Dataset & DataModule\n", + "\n", + "- `TimeSeries` returns the raw data in terms of tensors .\n", + "- `DataModule` wraps the dataset, handles splits, preprocessing, batching, and exposes `metadata` for the model initialisation.\n", + "\n", + "\n", + "\n", + "### Initialize the Model\n", + "\n", + "We initialize the TFT model using the `metadata` provided by the `DataModule`. This metadata includes all required dimensional info for the encoder, decoder, and static inputs.\n", + "\n", + "\n", + "\n", + "### Train the Model\n", + "\n", + "We use a `Trainer` from PyTorch Lightning to train the model\n", + "\n", + "### Run Inference\n", + "\n", + "After training, we can make predictions using the trained model\n" + ] }, - "id": "5U5Lr_ZFKX0s", - "outputId": "3b51be31-c103-4c4a-8eec-1fc72e15bdff" - }, - "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "/content/pytorch-forecasting/pytorch_forecasting/data/data_module.py:129: UserWarning: TimeSeries is part of an experimental rework of the pytorch-forecasting data layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. For beta testing, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", - " warn(\n" - ] - } - ], - "source": [ - "# create the data_module that handles the dataloaders and preprocessing\n", - "data_module = EncoderDecoderTimeSeriesDataModule(\n", - " time_series_dataset=dataset,\n", - " max_encoder_length=30,\n", - " max_prediction_length=1,\n", - " batch_size=32,\n", - " categorical_encoders={\n", - " \"category\": NaNLabelEncoder(add_nan=True),\n", - " \"static_feature_cat\": NaNLabelEncoder(add_nan=True),\n", - " },\n", - " scalers={\n", - " \"x\": StandardScaler(),\n", - " \"future_known_feature\": StandardScaler(),\n", - " \"static_feature\": StandardScaler(),\n", - " },\n", - " target_normalizer=TorchNormalizer(),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Initialise and train the model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000, - "referenced_widgets": [ - "45783fa479844c42b468e57f59d04893", - "e1233d07b0f54d609f5e80d7db0aa319", - "3da9fc2e0fec4f55a9a04a14d0865f17", - "e309d524e11d4fdfa9833ee588f623f4", - "967c9daec0574b9c9347c764b513722c", - "32dfa18d4a764de4845dd937f3e4f099", - "d1edfa752fcd4f2993cbfa18442cd49d", - "5786865f15ff40b085a5bdc908fafafb", - "42d2d7a93a9c4e4caec8e9cfb9078a8f", - "01a5c809e4254e43a1ce9678173eb4a9", - "43bca6c15e324927a42bdeef4aabcfa0", - "f8c3866ecdef4360a695e16d5b83bb63", - "ec6637f83f4244d3ba2b2a3635a1ef85", - "7378401a35094aa0bd416c33931d7c03", - "5b47c482771448e9af0bf88e684e421f", - "83c313f300244406a536595f694d794b", - "9e5dde09d59845f6b0f78f9ab86f509c", - "e63282c1bc2e49a89036864eea8d1315", - "56917958a5a8402c9cecda208ded82f2", - "f914f3a966e24f8eb0e6c23ec3ed5b61", - "fc4ec9c1331a40eab9aa68a15a7a89fe", - "2e907b4f2f674731933839fe854b7915", - "d2fb0c3c65ee43f3875409f6f603697d", - "941c97f36d2d4c7780fa034234011c22", - "0a19ccb1dabc4f689ebbfc2aa8a2855c", - "c29831a4eaa7492f808f0655993f2907", - "d96e9c61b63f4365aeccd88605cc85d0", - "6a0d5b88cbef4b0c9eee5669fabe29b1", - "66fb47ed48ad4c5d9d8269eeead488c2", - "29aa1ffff46e40758d9a65d6412c1a14", - "37719701ed6e4474a84c8f7bedfef5ea", - "f95586183a2142b2b2458eee62135678", - "049d94a2e6e24c6d95280581a7c680ef", - "24f8555a1ec5453095322162a016c235", - "25d2cdd4420c42ac9ada8fc1c0b4a102", - "b32e630900154feea6b63dcb1e6fc136", - "31738edb3858416baa14371956256152", - "b4d7dafb73af4633a23d1fbf197eb289", - "69f243809bc74cbdaf413772c4184902", - "2f34c1ad49dc46c1b0958e1f78433f7f", - "d98887028e6b4b48bd428d81fa6b74f5", - "9078b5fa96dc41eaa8db4e8b0748e56f", - "9dc235acfcac469ab653a149792dcd0d", - "27e9ab9d495543d7aadd4f418fbf6d9d", - "399c8534021d4f808130d746199f887e", - "ac37c8ac19004341958dd04791943ccf", - "7a52d6588e2440bfa3b6bb78f353d6cb", - "76d7913b8b9c440289a0bfd89bf46674", - "2c624ad70fa442d9ba15d77d32c3b4fb", - "1fcc462941bf4f97813d83f30d04a7e2", - "495b5207e43840e3a6649516295bdbb3", - "ddd73bc058724056ab643b0634f7b5db", - "aef199a692bc4338a63f81b38c94ad90", - "47d85d85ca9e4f5a8ac72e2c5b1d0a0f", - "5a8b4b3a2a214889aed4690c3736aa6f", - "9be2ea5a2a0b4932bdefb4dcc6ed7e3b", - "513343d79e4a4e0492c2a2e29aa79879", - "be5cd97e0aee4455b417f5fbea068f96", - "2902b8469715416fa64c61d3bd26613c", - "051d93365fe148c18fc0fa90af70ea45", - "bcd6636168424a9797d695e67db39dd8", - "36eafed5df36478b80e59d32675f03b1", - "834bf5c188ed497db28148bbd0ab8694", - "891dd939561b4078a091f3536336c11e", - "4b63bd8fd47641968594ec0bb7919955", - "7d60450c60184959accb717f31270816", - "7c19c43c559c40ab8801aaf0bd801355", - "41a0c0a294d94ae7911c6275862786e2", - "bb41bd772f5c4a528998e0369b51dffd", - "a0e9638c78b6405ba95e7d2a89cbe800", - "99820b7ff3c74f6ea2e89d27701786b3", - "ca32b0845d63450ea10b2f74e45b85ab", - "63d9975f089242e1a2e25e9f667615b8", - "243678f5c0ab4f2c8bce997e67490134", - "c2d1809f9e894c3eafec87e34580c5a7", - "2f50b7ba59b546a9a0a139a9880c7439", - "7af2d15ad21c4b54911d9f505ee8a1be", - "cef82d8c32104bae84227b9aa8687824", - "4944ace504f946c79b60161890c968a0", - "7475bee2c4d8459fbb1a8c026c87cc53", - "813b34681d614379ba78e28d7605fff5", - "82db490ebcbc493e9cdcdeed02675bb1", - "3ffd7aa2c8ea4098882524150409893f", - "8197a6b6b0be4b6c828e8a242e275869", - "d4ff8288c3254e0d9fa1b00bcf38aa78", - "d13c22944bdf4b6b882a08956d3475f1", - "c14874c570ea4310a9edfd148ea2fe73", - "8f7a1cbbb81b42cc8262e28dc0d1f278" - ] + "cell_type": "markdown", + "metadata": { + "id": "QyMFNk4MyY_b" + }, + "source": [ + "# 1. Load Data\n", + "We generate a synthetic dataset using `load_toydata` that creates a `pandas` DataFrame with just numerical values as for now **the pipeline assumes the data to be numerical only**." + ] }, - "id": "Si7bbZIULBZz", - "outputId": "4bf7f692-fbb8-4e1e-afe8-6f7a4ebe89ca" - }, - "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "/content/pytorch-forecasting/pytorch_forecasting/models/base/_base_model_v2.py:58: UserWarning: The Model 'TFT' is part of an experimental reworkof the pytorch-forecasting model layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. This class is intended for beta testing and as a basic skeleton, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", - " warn(\n", - "INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n", - "INFO:lightning.pytorch.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n", - "INFO: GPU available: False, used: False\n", - "INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False\n", - "INFO: TPU available: False, using: 0 TPU cores\n", - "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", - "INFO: HPU available: False, using: 0 HPUs\n", - "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n" - ] + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "RkgOT4kiy_RU" + }, + "outputs": [], + "source": [ + "from pytorch_forecasting.data.examples import load_toydata" + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Training model...\n" - ] + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "WX-FRdusJSVN", + "outputId": "e481484c-b0c3-4026-c933-a9dc047617c5" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6718152112102416,\n \"min\": -1.26886060400029,\n \"max\": 1.3107688985416461,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.1974946533075752,\n 0.8954011563960967,\n -0.802070871866599\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6763475728129391,\n \"min\": -1.26886060400029,\n \"max\": 1.3107688985416461,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.4987517118113735,\n 0.6086435548017073,\n -1.0256970040347706\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n -0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2888495447542638,\n \"min\": 0.0028672051720661784,\n \"max\": 0.990604862265208,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.5971646697158197,\n 0.12749395651151985,\n 0.32838971618312873\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", + "type": "dataframe", + "variable_name": "data_df" + }, + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
000-0.1379880.01682401.0000000.9275570
1010.0168240.28729100.9950040.9275570
2020.2872910.59958700.9800670.9275570
3030.5995870.77935200.9553360.9275570
4040.7793520.87614800.9210610.9275570
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "text/plain": [ + " series_id time_idx x y category future_known_feature \\\n", + "0 0 0 -0.137988 0.016824 0 1.000000 \n", + "1 0 1 0.016824 0.287291 0 0.995004 \n", + "2 0 2 0.287291 0.599587 0 0.980067 \n", + "3 0 3 0.599587 0.779352 0 0.955336 \n", + "4 0 4 0.779352 0.876148 0 0.921061 \n", + "\n", + " static_feature static_feature_cat \n", + "0 0.927557 0 \n", + "1 0.927557 0 \n", + "2 0.927557 0 \n", + "3 0.927557 0 \n", + "4 0.927557 0 " + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "num_series = 100 # Number of individual time series to generate\n", + "seq_length = 50 # Length of each time series\n", + "data_df = load_toydata(num_series, seq_length)\n", + "data_df.head()" + ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO: \n", - " | Name | Type | Params | Mode \n", - "---------------------------------------------------------------------\n", - "0 | loss | MSELoss | 0 | train\n", - "1 | encoder_var_selection | Sequential | 709 | train\n", - "2 | decoder_var_selection | Sequential | 193 | train\n", - "3 | static_context_linear | Linear | 192 | train\n", - "4 | lstm_encoder | LSTM | 51.5 K | train\n", - "5 | lstm_decoder | LSTM | 50.4 K | train\n", - "6 | self_attention | MultiheadAttention | 16.6 K | train\n", - "7 | pre_output | Linear | 4.2 K | train\n", - "8 | output_layer | Linear | 65 | train\n", - "---------------------------------------------------------------------\n", - "123 K Trainable params\n", - "0 Non-trainable params\n", - "123 K Total params\n", - "0.495 Total estimated model params size (MB)\n", - "18 Modules in train mode\n", - "0 Modules in eval mode\n", - "INFO:lightning.pytorch.callbacks.model_summary:\n", - " | Name | Type | Params | Mode \n", - "---------------------------------------------------------------------\n", - "0 | loss | MSELoss | 0 | train\n", - "1 | encoder_var_selection | Sequential | 709 | train\n", - "2 | decoder_var_selection | Sequential | 193 | train\n", - "3 | static_context_linear | Linear | 192 | train\n", - "4 | lstm_encoder | LSTM | 51.5 K | train\n", - "5 | lstm_decoder | LSTM | 50.4 K | train\n", - "6 | self_attention | MultiheadAttention | 16.6 K | train\n", - "7 | pre_output | Linear | 4.2 K | train\n", - "8 | output_layer | Linear | 65 | train\n", - "---------------------------------------------------------------------\n", - "123 K Trainable params\n", - "0 Non-trainable params\n", - "123 K Total params\n", - "0.495 Total estimated model params size (MB)\n", - "18 Modules in train mode\n", - "0 Modules in eval mode\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "RYQ5CdNUyc2q" + }, + "source": [ + "# 2. Create the dataset and datamodule\n", + "We create a `TimeSeries` dataset instance that returns the raw data in terms of tensors, then this \"raw data\" is sent to the `data_module`that will internally handle the dataloaders and preprocessing" + ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "45783fa479844c42b468e57f59d04893", - "version_major": 2, - "version_minor": 0 + "cell_type": "markdown", + "metadata": { + "id": "ONe8Eo1zzvCH" }, - "text/plain": [ - "Sanity Checking: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃ Test metric DataLoader 0 ┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ test_MAE 0.5018836855888367 │\n", - "│ test_SMAPE 1.1201428174972534 │\n", - "│ test_loss 0.012998351827263832 │\n", - "└───────────────────────────┴───────────────────────────┘\n", - "\n" + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9qbjnTxnyh4H", + "outputId": "f59bf985-ffaa-4980-c890-39a80dfcc598" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/content/pytorch-forecasting/pytorch_forecasting/models/base/_base_model_v2.py:58: UserWarning: The Model 'TFT' is part of an experimental reworkof the pytorch-forecasting model layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. This class is intended for beta testing and as a basic skeleton, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", + " warn(\n" + ] + } ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5018836855888367 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.1201428174972534 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.012998351827263832 \u001b[0m\u001b[35m \u001b[0m│\n", - "└───────────────────────────┴───────────────────────────┘\n" + "source": [ + "# Initialise the Model\n", + "model = TFT(\n", + " loss=nn.MSELoss(),\n", + " logging_metrics=[MAE(), SMAPE()],\n", + " optimizer=\"adam\",\n", + " optimizer_params={\"lr\": 1e-3},\n", + " lr_scheduler=\"reduce_lr_on_plateau\",\n", + " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", + " hidden_size=64,\n", + " num_layers=2,\n", + " attention_head_size=4,\n", + " dropout=0.1,\n", + " metadata=data_module.metadata, # pass the metadata from the datamodule to the model\n", + " # to initialise important params like `encoder_cont` etc\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "svdoye-d8F-z" + }, + "source": [ + "We use a `Trainer` from PyTorch Lightning to train the model:\n", + "\n", + "```python\n", + "trainer = Trainer(max_epochs=5, ...)\n", + "trainer.fit(model, data_module)\n", + "```\n", + "\n", + "The `Trainer`:\n", + "- Pulls data from `data_module`\n", + "- Handles device placement\n", + "- Logs training progress and metrics\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "RTSmUu9RytS8" + }, + "outputs": [], + "source": [ + "from lightning.pytorch import Trainer" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Prediction shape: torch.Size([32, 1, 1])\n", - "First prediction values: [[-0.12206323]]\n", - "First true values: [[-0.04691907]]\n", - "\n", - "TFT model test complete!\n" - ] + "cell_type": "code", + "execution_count": 21, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 930, + "referenced_widgets": [ + "edc196b34a0b49fa992fd89909a8414f", + "76e801c32da348788b0b92fd2c2849c1", + "cd6d925c2dcf4ae995f62d8f20219fde", + "283f67f9ef6a47f8b3d4a389a18b8ccc", + "8e19d683c674403b87cbb7c162a0ab33", + "525fe2cef0444d558d870c394cf67e81", + "640318d6256a4451957fce6481625f6b", + "0820bd0fd60f4e0180899f18615d9ee3", + "90cb2b96e18948679eda809a9123565c", + "5ffb5772ae4f4764b23643cfa547a615", + "04f140582780402a94e3391bfcbffa91", + "4abfcb0ade1b47e1b015e7efd75cd563", + "483c0919c7c740af89e6c7a7470591f7", + "18849509bcfc41dd89011efb5690e49b", + "098eb27809314e7ca15c5b9bce21b46d", + "53f7fc6124fb4dda8bfb147078212ada", + "8c2821dc533c4a1ba81d68836047dc54", + "cb303a9d13084d47b0714d717f482a80", + "491aa926179b41dba21a008dc6e9a4fc", + "363967a782db4627a3e16a726b4993c6", + "0d28163142aa465784871ec647efa504", + "a6d821b9d9c9453b87740c3e4d86ca77", + "2301684175b0454897af0fc5e534f52f", + "58fe6653d5e944be8fdd072cd3541b8c", + "00fb395e4cb5472faf221affc0244780", + "1cd0c86474a443b8aae6466405dd3807", + "cfcb344ba8cb49a78a8f934657252855", + "4f455332437a465bb300ccf50405358d", + "fde9c4616cd84e629998819c18adfd1f", + "405e6f4ab49b445ba7e942e0d35419eb", + "2ee2fe3ea3c04592851fef7e6dda5115", + "713bcea77e1744179502f6a187239ce1", + "531c73851dcd469791d6ebdc9aa9da52", + "b0ed53cd20bd47b5aac5fecaacb9bfed", + "da87d95793f84907819ab4c506f6936b", + "c8a64619bf574d43a01595e3c8690759", + "8d0c66b5cc984b8db7d151a970897266", + "a723c1055f4f4c0d9c7b2c8c3469fa9d", + "99dcc8dd7e88463a95ddcf86c67ec153", + "94c1a818e78a47ee9bb46f1c8b1dad34", + "903b0891a3014701b117a391d777e012", + "f6053af158304fc1b307a296cdbe7290", + "f6c93805c03545458f1af0fffaf77fb0", + "52bbe8a40baf4a8f989281bcc225a067", + "c9eb21c70ac243d1a21f50f287b4731f", + "100a61bf42b0498ba241ddd9ef7a96f3", + "1cce85eee5bc4700aa966c05e35ea042", + "2dcaeafc8bc84db786f9fa99aeda39db", + "7a9cf1446cc14d51a0a729822eb1cabc", + "1c6f0e1be6f8475485ba6354ccc68597", + "7ef67c160b8740e9932486f857470542", + "282e2d3a169a4e18b46ca85dbfe82371", + "84a96c68feba48b6880de08722de8743", + "c111af2e297641908cb1835d9a4d06ff", + "d94d97f0852b450c8ea56e6ef4e44ac5", + "1dd910df34e44cdc98586e3094dfddd4", + "86c6678f86d944eeba1fe982197cb7d3", + "b609e7acf93944a98f11b450e01222f8", + "0baf540b09454f61873cd6a3088a3c9d", + "7bc8c547554c481481d0653193dcd917", + "2d92b66364844b5491dabee7a0b90686", + "99a8fec0a22f4949a3d16fc7f48aeea6", + "0fb99264aa0644cda962905dfe1d6997", + "f3d688e29f4b4432ac370d6a1e76ec5c", + "d4cd0f3144784991ba9ec0b178a6c2e7", + "83a1e407cf60401b932c6ee6e72d90db", + "99443f56f20a4c6da81ddb107ab84490", + "2da428c9a3b541bc8f667d5715fdcb0d", + "f483a26167424f66b8b3c537537afd77", + "bc01c61ee2a14f3bbea23d04bb7e9dc9", + "b59e64ed8d9d46dd8803bfe25602c391", + "cc1d7b4924784a08999f41e1e3c989ac", + "8c6ad939d53942e48d8d6fb4c30b666d", + "49d44a5db92b4b3e979e88d14072a16c", + "36d0bfd0973c43e4958af452d2c7025c", + "8a90ddfc1fd446a6894c2771195a4419", + "aa94917a2f364a4ea9033000730d48b7" + ] + }, + "id": "aB_ayE_eykXp", + "outputId": "02c49d3e-2124-4b0b-8ca4-b2d886662613" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n", + "INFO:lightning.pytorch.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n", + "INFO: GPU available: False, used: False\n", + "INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False\n", + "INFO: TPU available: False, using: 0 TPU cores\n", + "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", + "INFO: HPU available: False, using: 0 HPUs\n", + "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n", + "INFO: \n", + " | Name | Type | Params | Mode\n", + "--------------------------------------------------------------------\n", + "0 | loss | MSELoss | 0 | eval\n", + "1 | encoder_var_selection | Sequential | 709 | eval\n", + "2 | decoder_var_selection | Sequential | 193 | eval\n", + "3 | static_context_linear | Linear | 192 | eval\n", + "4 | lstm_encoder | LSTM | 51.5 K | eval\n", + "5 | lstm_decoder | LSTM | 50.4 K | eval\n", + "6 | self_attention | MultiheadAttention | 16.6 K | eval\n", + "7 | pre_output | Linear | 4.2 K | eval\n", + "8 | output_layer | Linear | 65 | eval\n", + "--------------------------------------------------------------------\n", + "123 K Trainable params\n", + "0 Non-trainable params\n", + "123 K Total params\n", + "0.495 Total estimated model params size (MB)\n", + "0 Modules in train mode\n", + "18 Modules in eval mode\n", + "INFO:lightning.pytorch.callbacks.model_summary:\n", + " | Name | Type | Params | Mode\n", + "--------------------------------------------------------------------\n", + "0 | loss | MSELoss | 0 | eval\n", + "1 | encoder_var_selection | Sequential | 709 | eval\n", + "2 | decoder_var_selection | Sequential | 193 | eval\n", + "3 | static_context_linear | Linear | 192 | eval\n", + "4 | lstm_encoder | LSTM | 51.5 K | eval\n", + "5 | lstm_decoder | LSTM | 50.4 K | eval\n", + "6 | self_attention | MultiheadAttention | 16.6 K | eval\n", + "7 | pre_output | Linear | 4.2 K | eval\n", + "8 | output_layer | Linear | 65 | eval\n", + "--------------------------------------------------------------------\n", + "123 K Trainable params\n", + "0 Non-trainable params\n", + "123 K Total params\n", + "0.495 Total estimated model params size (MB)\n", + "0 Modules in train mode\n", + "18 Modules in eval mode\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Training model...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "edc196b34a0b49fa992fd89909a8414f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test_MAE 0.48676350712776184 │\n", + "│ test_SMAPE 1.031250238418579 │\n", + "│ test_loss 0.012420947663486004 │\n", + "└───────────────────────────┴───────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.48676350712776184 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.031250238418579 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.012420947663486004 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Prediction shape: torch.Size([32, 1, 1])\n", + "First prediction values: [[-0.11366826]]\n", + "First true values: [[-0.12978955]]\n", + "\n", + "TFT model test complete!\n" + ] + } + ], + "source": [ + "# Evaluate the model\n", + "print(\"\\nEvaluating model...\")\n", + "test_metrics = trainer.test(model, data_module)\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " test_batch = next(iter(data_module.test_dataloader()))\n", + " x_test, y_test = test_batch\n", + " y_pred = model(x_test)\n", + "\n", + " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n", + " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n", + " print(\"First true values:\", y_test[0].cpu().numpy())\n", + "print(\"\\nTFT model test complete!\")" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" } - ], - "source": [ - "# Initialise the Model\n", - "model = TFT(\n", - " loss=nn.MSELoss(),\n", - " logging_metrics=[MAE(), SMAPE()],\n", - " optimizer=\"adam\",\n", - " optimizer_params={\"lr\": 1e-3},\n", - " lr_scheduler=\"reduce_lr_on_plateau\",\n", - " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", - " hidden_size=64,\n", - " num_layers=2,\n", - " attention_head_size=4,\n", - " dropout=0.1,\n", - " metadata=data_module.metadata, # pass the metadata from the datamodule to the model\n", - " # to initialise important params like `encoder_cont` etc\n", - ")\n", - "\n", - "# Train the model\n", - "print(\"\\nTraining model...\")\n", - "trainer = Trainer(\n", - " max_epochs=5,\n", - " accelerator=\"auto\",\n", - " devices=1,\n", - " enable_progress_bar=True,\n", - " log_every_n_steps=10,\n", - ")\n", - "\n", - "trainer.fit(model, data_module)\n", - "\n", - "# Evaluate the model\n", - "print(\"\\nEvaluating model...\")\n", - "test_metrics = trainer.test(model, data_module)\n", - "\n", - "model.eval()\n", - "with torch.no_grad():\n", - " test_batch = next(iter(data_module.test_dataloader()))\n", - " x_test, y_test = test_batch\n", - " y_pred = model(x_test)\n", - "\n", - " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n", - " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n", - " print(\"First true values:\", y_test[0].cpu().numpy())\n", - "print(\"\\nTFT model test complete!\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zVRwi2MvLGgc" - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" }, - "language_info": { - "name": "python", - "version": "3.12.3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 + "nbformat": 4, + "nbformat_minor": 0 } From 5e71cf61c26359903a4677b9e2ff649ed578d873 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 7 Jun 2025 16:29:15 +0200 Subject: [PATCH 43/43] Update ptf_V2_example.ipynb --- docs/source/tutorials/ptf_V2_example.ipynb | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/source/tutorials/ptf_V2_example.ipynb b/docs/source/tutorials/ptf_V2_example.ipynb index 0c97d365f..2036a4784 100644 --- a/docs/source/tutorials/ptf_V2_example.ipynb +++ b/docs/source/tutorials/ptf_V2_example.ipynb @@ -6,7 +6,7 @@ "id": "rzVbXsEBxnF-" }, "source": [ - "# Example Notebook for a basic vignette for `pytorch-forecasting v2` Model Training and Inference" + "# `pytorch-forecasting v2` Model Training and Inference - Beta API" ] }, { @@ -16,7 +16,9 @@ }, "source": [ "
\n", - ":warning: The vignette showcased here is part of an experimental rework of the `pytorch-forecasting` data layer, planned for release in v2.0.0. The API is currently unstable and subject to change without prior notice. This notebook serves as a basic demonstration of the intended workflow and is not recommended for use in production environments. Feedback and suggestions are highly encouraged — please share them in issue 1736.\n", + ":warning: The vignette showcased here is part of an experimental rework of the `pytorch-forecasting` data layer, planned for release in v2.0.0. The API is currently unstable and subject to change without prior notice.\n", + "\n", + "Feedback and suggestions are highly encouraged — please share them in issue 1736.\n", "
\n" ] },