From 9b4f17d225473b31df972c69ed1095abbd334aca Mon Sep 17 00:00:00 2001 From: WuZhen <498721344@qq.com> Date: Tue, 18 Mar 2025 15:10:55 +0800 Subject: [PATCH 1/6] add an example on lightning backend --- basicts/data/__init__.py | 3 +- basicts/data/tsf_datamodule.py | 107 +++++++++++++ basicts/model.py | 277 +++++++++++++++++++++++++++++++++ examples/arch.py | 52 +++++-- examples/lightning_config.yaml | 67 ++++++++ requirements.txt | 9 +- run.py | 43 +++++ 7 files changed, 544 insertions(+), 14 deletions(-) create mode 100644 basicts/data/tsf_datamodule.py create mode 100644 basicts/model.py create mode 100644 examples/lightning_config.yaml create mode 100644 run.py diff --git a/basicts/data/__init__.py b/basicts/data/__init__.py index 60a24036..a1751307 100644 --- a/basicts/data/__init__.py +++ b/basicts/data/__init__.py @@ -1,4 +1,5 @@ from .base_dataset import BaseDataset from .simple_tsf_dataset import TimeSeriesForecastingDataset +from .tsf_datamodule import TimeSeriesForecastingModule -__all__ = ['BaseDataset', 'TimeSeriesForecastingDataset'] +__all__ = ['BaseDataset', 'TimeSeriesForecastingDataset', 'TimeSeriesForecastingModule'] diff --git a/basicts/data/tsf_datamodule.py b/basicts/data/tsf_datamodule.py new file mode 100644 index 00000000..d5f45d0e --- /dev/null +++ b/basicts/data/tsf_datamodule.py @@ -0,0 +1,107 @@ +from typing import List + +from basicts.utils import get_regular_settings +from .base_dataset import BaseDataset +import lightning.pytorch as pl +from importlib import import_module +from torch.utils.data import DataLoader + + +class TimeSeriesForecastingModule(pl.LightningDataModule): + def __init__( + self, + dataset_class: str, + dataset_name: str, + train_val_test_ratio: List[float], + input_len: int, + output_len: int, + overlap: bool = False, + batch_size: int = 32, + num_workers: int = 0, + pin_memory: bool = False, + shuffle: bool = False, + prefetch: bool = False, + ): + super().__init__() + dataset_class_packeg, dataset_class_name = 
dataset_class.rsplit(".", 1) + self.dataset_class = getattr( + import_module(dataset_class_packeg), dataset_class_name + ) + if prefetch: + # todo: implement DataLoaderX + # self.dataloader_class = DataLoaderX + raise NotImplementedError("DataLoaderX is not implemented yet.") + else: + self.dataloader_class = DataLoader + self.dataset_name = dataset_name + self.train_val_test_ratio = train_val_test_ratio + self.input_len = input_len + self.output_len = output_len + self.overlap = overlap + self.batch_size = batch_size + self.num_workers = num_workers + self.pin_memory = pin_memory + self.shuffle = shuffle + self.prefetch = prefetch + self.regular_settings = get_regular_settings(dataset_name) + + + # self.train_set = TimeSeriesForecastingDataset() + + @property + def dataset_params(self): + return { + "dataset_name": self.dataset_name, + "train_val_test_ratio": self.train_val_test_ratio, + "input_len": self.input_len, + "output_len": self.output_len, + "overlap": self.overlap, + } + + def train_dataloader(self): + """Build train dataset and dataloader. + + Returns: + train data loader (DataLoader) + """ + dataset = self.dataset_class(**self.dataset_params, mode="train") + loader = self.dataloader_class( + dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + ) + return loader + + def val_dataloader(self): + """Build validation dataset and dataloader. + + Returns: + validation data loader (DataLoader) + """ + dataset = self.dataset_class(**self.dataset_params, mode="valid") + loader = self.dataloader_class( + dataset, + batch_size=self.batch_size, + # shuffle=self.shuffle, # No need to shuffle validation data + num_workers=self.num_workers, + pin_memory=self.pin_memory, + ) + return loader + + def test_dataloader(self): + """Build test dataset and dataloader. 
+ + Returns: + test data loader (DataLoader) + """ + dataset = self.dataset_class(**self.dataset_params, mode="test") + loader = self.dataloader_class( + dataset, + batch_size=self.batch_size, + # shuffle=self.shuffle, # No need to shuffle test data + num_workers=self.num_workers, + pin_memory=self.pin_memory, + ) + return loader diff --git a/basicts/model.py b/basicts/model.py new file mode 100644 index 00000000..3f363a46 --- /dev/null +++ b/basicts/model.py @@ -0,0 +1,277 @@ +from functools import wraps +import functools +import inspect +import sched +import lightning.pytorch as pl +from typing import Any, Callable, Dict, Optional, Union, List + +import numpy as np +import torch + +from basicts.metrics import ALL_METRICS, masked_mae +from basicts.scaler import BaseScaler + + +class BasicTimeSeriesForecastingModule(pl.LightningModule): + def __init__( + self, + lr: float, + weight_decay: float, + history_len: int, + horizon_len: int, + metrics: Optional[List[str]] = None, + forward_features: Optional[List[int]] = None, + target_features: Optional[List[int]] = None, + target_time_series: Optional[List[int]] = None, + scaler: Any = None, + null_val: Any = np.nan, + ): + super().__init__() + self.lr = lr + self.weight_decay = weight_decay + self.history_len = history_len + self.horizon_len = horizon_len + self.forward_features = forward_features + self.target_features = target_features + self.target_time_series = target_time_series + self.scaler = scaler + self.null_val = null_val + + self.metric_func_dict = self.init_metrics(metrics) + if "loss" not in self.metric_func_dict: + if hasattr(self, "loss_func"): + self.metric_func_dict["loss"] = self.loss_func + else: + # self.logger.info('No loss function is provided. 
Using masked_mae as default.') + self.metric_func_dict["loss"] = masked_mae + + def init_metrics(self, metrics: Optional[List[str]]) -> Dict[str, Callable]: + if metrics is None: + return ALL_METRICS + return {name: ALL_METRICS[name] for name in metrics} + + def basicts_forward(self, data: Dict, **kwargs) -> Dict: + """ + The forward function of original runner. + + Performs the forward pass for training, validation, and testing. + + Args: + data (Dict): A dictionary containing 'target' (future data) and 'inputs' (history data) (normalized by self.scaler). + + Returns: + Dict: A dictionary containing the keys: + - 'inputs': Selected input features. + - 'prediction': Model predictions. + - 'target': Selected target features. + + Raises: + AssertionError: If the shape of the model output does not match [B, L, N]. + """ + + data = self.preprocessing(data) + + # Preprocess input data + future_data, history_data = data["target"], data["inputs"] + # history_data = self.to_running_device(history_data) # Shape: [B, L, N, C] + # future_data = self.to_running_device(future_data) # Shape: [B, L, N, C] + batch_size, length, num_nodes, _ = future_data.shape + + # Select input features + history_data = self.select_input_features(history_data) + future_data_4_dec = self.select_input_features(future_data) + + train = self.trainer.training + # epoch = self.trainer.current_epoch + # batch_seen = self.trainer.global_step + + if not train: + # For non-training phases, use only temporal features + future_data_4_dec[..., 0] = torch.empty_like(future_data_4_dec[..., 0]) + + # Forward pass through the model + model_return = self( + history_data=history_data, + future_data=future_data_4_dec, + # batch_seen=batch_seen, + # epoch=epoch, + # train=train, + ) + + # Parse model return + if isinstance(model_return, torch.Tensor): + model_return = {"prediction": model_return} + if "inputs" not in model_return: + model_return["inputs"] = self.select_target_features(history_data) + if "target" not 
in model_return: + model_return["target"] = self.select_target_features(future_data) + + # Ensure the output shape is correct + assert list(model_return["prediction"].shape)[:3] == [ + batch_size, + length, + num_nodes, + ], "The shape of the output is incorrect. Ensure it matches [B, L, N, C]." + + model_return = self.postprocessing(model_return) + + return model_return + + def metric_forward(self, metric_func, args: Dict) -> torch.Tensor: + """Compute metrics using the given metric function. + + Args: + metric_func (function or functools.partial): Metric function. + args (Dict): Arguments for metrics computation. + + Returns: + torch.Tensor: Computed metric value. + """ + + covariate_names = inspect.signature(metric_func).parameters.keys() + args = {k: v for k, v in args.items() if k in covariate_names} + + if isinstance(metric_func, functools.partial): + if 'null_val' not in metric_func.keywords and 'null_val' in covariate_names: # null_val is required but not provided + args['null_val'] = self.null_val + metric_item = metric_func(**args) + elif callable(metric_func): + if 'null_val' in covariate_names: # null_val is required + args['null_val'] = self.null_val + metric_item = metric_func(**args) + else: + raise TypeError(f'Unknown metric type: {type(metric_func)}') + return metric_item + + + def training_step(self, batch, batch_idx): + forward_return = self.basicts_forward(batch) + metrics = {} + for metric_name, metric_func in self.metric_func_dict.items(): + metric_item = self.metric_forward(metric_func, forward_return) + metrics[f"train/{metric_name}"] = metric_item + self.log_dict(metrics, on_step=True) + return metrics["train/loss"] + + def validation_step(self, batch, batch_idx): + forward_return = self.basicts_forward(batch) + metrics = {} + for metric_name, metric_func in self.metric_func_dict.items(): + metric_item = self.metric_forward(metric_func, forward_return) + metrics[f"val/{metric_name}"] = metric_item + self.log_dict(metrics, on_step=False, 
on_epoch=True) + return metrics["val/loss"] + + def test_step(self, batch, batch_idx): + forward_return = self.basicts_forward(batch) + metrics = {} + for metric_name, metric_func in self.metric_func_dict.items(): + metric_item = self.metric_forward(metric_func, forward_return) + metrics[f"test/{metric_name}"] = metric_item + self.log_dict(metrics, on_step=False, on_epoch=True) + return metrics["test/loss"] + + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr, weight_decay=self.weight_decay + ) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9) + return [optimizer], [scheduler] + + def preprocessing(self, input_data: Dict, scale_keys=["target", "inputs"]) -> Dict: + """Preprocess data. + + Args: + input_data (Dict): Dictionary containing data to be processed. + scale_keys (Optional[Union[List[str], str]], optional): 'all' means scale all, None means ingore all, it also can be specificated by a list of str. Defaults to None. + + Returns: + Dict: Processed data. + """ + if scale_keys is None: + scale_keys = [] + elif scale_keys == "all": + scale_keys = input_data.keys() + + if self.scaler is not None: + for k in scale_keys: + input_data[k] = self.scaler.transform(input_data[k]) + # TODO: add more preprocessing steps as needed. + return input_data + + def postprocessing( + self, input_data: Dict, scale_keys=["target", "inputs", "prediction"] + ) -> Dict: + """Postprocess data. + + Args: + input_data (Dict): Dictionary containing data to be processed. + scale_keys (Optional[Union[List[str], str]], optional): 'all' means scale all, None means ingore all, it also can be specificated by a list of str. Defaults to None. + + Returns: + Dict: Processed data. 
+ """ + + # rescale data + if self.scaler is not None and self.scaler.rescale: + if scale_keys is None: + scale_keys = [] + elif scale_keys == "all": + scale_keys = input_data.keys() + for k in scale_keys: + input_data[k] = self.scaler.inverse_transform(input_data[k]) + + # subset forecasting + if self.target_time_series is not None: + input_data["target"] = input_data["target"][ + :, :, self.target_time_series, : + ] + input_data["prediction"] = input_data["prediction"][ + :, :, self.target_time_series, : + ] + + # TODO: add more postprocessing steps as needed. + return input_data + + def select_input_features(self, data: torch.Tensor) -> torch.Tensor: + """ + Selects input features based on the forward features specified in the configuration. + + Args: + data (torch.Tensor): Input history data with shape [B, L, N, C1]. + + Returns: + torch.Tensor: Data with selected features with shape [B, L, N, C2]. + """ + + if self.forward_features is not None: + data = data[:, :, :, self.forward_features] + return data + + def select_target_features(self, data: torch.Tensor) -> torch.Tensor: + """ + Selects target features based on the target features specified in the configuration. + + Args: + data (torch.Tensor): Model prediction data with shape [B, L, N, C1]. + + Returns: + torch.Tensor: Data with selected target features and shape [B, L, N, C2]. + """ + + data = data[:, :, :, self.target_features] + return data + + def select_target_time_series(self, data: torch.Tensor) -> torch.Tensor: + """ + Select target time series based on the target time series specified in the configuration. + + Args: + data (torch.Tensor): Model prediction data with shape [B, L, N1, C]. + + Returns: + torch.Tensor: Data with selected target time series and shape [B, L, N2, C]. 
+ """ + + data = data[:, :, self.target_time_series, :] + return data diff --git a/examples/arch.py b/examples/arch.py index 623fd4c0..04eca178 100644 --- a/examples/arch.py +++ b/examples/arch.py @@ -1,9 +1,15 @@ # pylint: disable=unused-argument +from typing import Any, List, Optional +import numpy as np import torch from torch import nn +import lightning.pytorch as pl +from basicts.model import BasicTimeSeriesForecastingModule +from basicts.scaler import BaseScaler -class MultiLayerPerceptron(nn.Module): + +class MultiLayerPerceptron(BasicTimeSeriesForecastingModule): """ A simple Multi-Layer Perceptron (MLP) model with two fully connected layers. @@ -16,7 +22,20 @@ class MultiLayerPerceptron(nn.Module): act (nn.ReLU): The ReLU activation function applied between the two layers. """ - def __init__(self, history_seq_len: int, prediction_seq_len: int, hidden_dim: int): + def __init__( + self, + lr: float, + weight_decay: float, + history_len: int, + horizon_len: int, + hidden_dim: int, + metrics: Optional[List[str]] = None, + forward_features: Optional[List[int]] = None, + target_features: Optional[List[int]] = None, + target_time_series: Optional[List[int]] = None, + scaler: Any = None, + null_val: Any = np.nan, + ): """ Initialize the MultiLayerPerceptron model. @@ -24,28 +43,41 @@ def __init__(self, history_seq_len: int, prediction_seq_len: int, hidden_dim: in history_seq_len (int): The length of the input history sequence. prediction_seq_len (int): The length of the output prediction sequence. hidden_dim (int): The number of units in the hidden layer. 
+ """ - super().__init__() - self.fc1 = nn.Linear(history_seq_len, hidden_dim) - self.fc2 = nn.Linear(hidden_dim, prediction_seq_len) + super().__init__( + lr=lr, + weight_decay=weight_decay, + history_len=history_len, + horizon_len=horizon_len, + metrics=metrics, + forward_features=forward_features, + target_features=target_features, + target_time_series=target_time_series, + scaler=scaler, + null_val=null_val, + ) + self.fc1 = nn.Linear(history_len, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, horizon_len) self.act = nn.ReLU() - def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_seen: int, epoch: int, train: bool) -> torch.Tensor: + def forward( + self, + history_data: torch.Tensor, + future_data: torch.Tensor, + ) -> torch.Tensor: """ Perform a forward pass through the network. Args: history_data (torch.Tensor): A tensor containing historical data, typically of shape `[B, L, N, C]`. future_data (torch.Tensor): A tensor containing future data, typically of shape `[B, L, N, C]`. - batch_seen (int): The number of batches seen so far during training. - epoch (int): The current epoch number. - train (bool): Flag indicating whether the model is in training mode. Returns: torch.Tensor: The output prediction tensor, typically of shape `[B, L, N, C]`. 
""" - history_data = history_data[..., 0].transpose(1, 2) # [B, L, N, C] -> [B, N, L] + history_data = history_data[..., 0].transpose(1, 2) # [B, L, N, C] -> [B, N, L] # [B, N, L] --h=act(fc1(x))--> [B, N, D] --fc2(h)--> [B, N, L] -> [B, L, N] prediction = self.fc2(self.act(self.fc1(history_data))).transpose(1, 2) diff --git a/examples/lightning_config.yaml b/examples/lightning_config.yaml new file mode 100644 index 00000000..e27f4999 --- /dev/null +++ b/examples/lightning_config.yaml @@ -0,0 +1,67 @@ +fit: + model: + class_path: examples.arch.MultiLayerPerceptron + init_args: + lr: 2e-3 + weight_decay: 1e-4 + history_len: 12 + horizon_len: 12 + hidden_dim: 64 + forward_features: [0, 1, 2] + target_features: [0] + metrics: + - MAE + - MAPE + - RMSE + scaler: + class_path: basicts.scaler.ZScoreScaler + init_args: + dataset_name: ${fit.data.init_args.dataset_name} + train_ratio: ${fit.data.init_args.train_val_test_ratio[0]} + norm_each_channel: False + rescale: True + data: + class_path: basicts.data.TimeSeriesForecastingModule + init_args: + dataset_class: basicts.data.TimeSeriesForecastingDataset + dataset_name: PEMS08 + train_val_test_ratio: [0.6, 0.2, 0.2] + # input_len: 12 + # output_len: 12 + # To enable variable interpolation, first install omegaconf (pip install omegaconf) + input_len: ${fit.model.init_args.history_len} + output_len: ${fit.model.init_args.horizon_len} + overlap: False + trainer: + max_epochs: 10 + devices: auto + log_every_n_steps: 10 + callbacks: + # Use rich for better progress bar and model summary (pip install rich) + - class_path: lightning.pytorch.callbacks.RichProgressBar + - class_path: lightning.pytorch.callbacks.RichModelSummary + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/loss + patience: 10 + mode: min + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/loss + mode: min + save_top_k: 1 + dirpath: 
${fit.trainer.logger.init_args.save_dir}/${fit.trainer.logger.init_args.name} + filename: best + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: examples/lightning_logs + name: ${fit.data.init_args.dataset_name} + +test: + data: ${fit.data} + model: ${fit.model} + trainer: ${fit.trainer} + # Specify the checkpoint path to test; it is recommended to specify it on the command line, e.g., --ckpt_path=examples/lightning_logs/best.ckpt + # ckpt_path: examples/lightning_logs/best.ckpt + \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8b9a1a03..c34552d1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -easy-torch +# easy-torch easydict packaging setproctitle @@ -7,5 +7,8 @@ scikit-learn tables sympy openpyxl -setuptools==59.5.0 -numpy==1.24.4 +rich +lightning[pytorch-extra] +# omegaconf +# setuptools==59.5.0 +# numpy==1.24.4 diff --git a/run.py b/run.py new file mode 100644 index 00000000..6005ed2b --- /dev/null +++ b/run.py @@ -0,0 +1,43 @@ +# Run a baseline model in BasicTS framework. 
+# pylint: disable=wrong-import-position +import os +import sys +from pathlib import Path + +FILE_PATH = Path(__file__).resolve() +PROJECT_DIR = FILE_PATH.parent # PROJECT_DIR +BASICTS_DIR = PROJECT_DIR / "basicts" # BASICTS_DIR +sys.path.append(PROJECT_DIR.as_posix()) +sys.path.append(BASICTS_DIR.as_posix()) +# os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from lightning.pytorch.cli import LightningCLI + +from basicts.data.tsf_datamodule import TimeSeriesForecastingModule +from basicts.model import BasicTimeSeriesForecastingModule +# from .baselines + + +class BasicTSCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + super().add_arguments_to_parser(parser) + + # parser.link_arguments("model.init_args.null_val", "data.regular_settings[NULL_VAL]") + # parser.link_arguments("model.init_args.history_len", "data.init_args.input_len") + # parser.link_arguments("data.init_args.prediction_len", "data.init_args.output_len") + + +def run(): + cli = BasicTSCLI( + run=True, + trainer_defaults={}, + parser_kwargs={"parser_mode": "omegaconf"}, # pip install omegaconf + save_config_kwargs={"overwrite": True, "save_to_log_dir": True}, + ) + if cli.subcommand in ("fit", "validate") and not cli.trainer.fast_dev_run: + # 被动执行了 fit 或者 validate,追加一个 test + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") + + +if __name__ == "__main__": + run() From 82dce2a7ae3534bd2c5696b5bb1a4a4a55060f61 Mon Sep 17 00:00:00 2001 From: WuZhen <498721344@qq.com> Date: Tue, 18 Mar 2025 15:49:16 +0800 Subject: [PATCH 2/6] add an example on lightning backend --- .gitignore | 2 ++ examples/lightning_config.yaml | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 96cc9655..b69c63f1 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ todo.md gpu_task.py cmd.sh + # file *.npz *.npy @@ -19,6 +20,7 @@ cmd.sh *.pyc *.txt *.core +*.ckpt *.py[cod] *$py.class diff --git a/examples/lightning_config.yaml 
b/examples/lightning_config.yaml index e27f4999..a12a7651 100644 --- a/examples/lightning_config.yaml +++ b/examples/lightning_config.yaml @@ -38,8 +38,8 @@ fit: log_every_n_steps: 10 callbacks: # Use rich for better progress bar and model summary (pip install rich) - - class_path: lightning.pytorch.callbacks.RichProgressBar - - class_path: lightning.pytorch.callbacks.RichModelSummary + # - class_path: lightning.pytorch.callbacks.RichProgressBar + # - class_path: lightning.pytorch.callbacks.RichModelSummary - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: monitor: val/loss From 207a92237985a87c4e0fa9e935ac7983985ca242 Mon Sep 17 00:00:00 2001 From: WuZhen <498721344@qq.com> Date: Tue, 18 Mar 2025 19:41:32 +0800 Subject: [PATCH 3/6] remove unused import --- basicts/data/tsf_datamodule.py | 2 -- basicts/model.py | 5 +---- examples/arch.py | 2 -- 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/basicts/data/tsf_datamodule.py b/basicts/data/tsf_datamodule.py index d5f45d0e..ccace7d2 100644 --- a/basicts/data/tsf_datamodule.py +++ b/basicts/data/tsf_datamodule.py @@ -1,7 +1,6 @@ from typing import List from basicts.utils import get_regular_settings -from .base_dataset import BaseDataset import lightning.pytorch as pl from importlib import import_module from torch.utils.data import DataLoader @@ -45,7 +44,6 @@ def __init__( self.prefetch = prefetch self.regular_settings = get_regular_settings(dataset_name) - # self.train_set = TimeSeriesForecastingDataset() @property diff --git a/basicts/model.py b/basicts/model.py index 3f363a46..849ca306 100644 --- a/basicts/model.py +++ b/basicts/model.py @@ -1,15 +1,12 @@ -from functools import wraps import functools import inspect -import sched import lightning.pytorch as pl -from typing import Any, Callable, Dict, Optional, Union, List +from typing import Any, Callable, Dict, Optional, List import numpy as np import torch from basicts.metrics import ALL_METRICS, masked_mae -from basicts.scaler import 
BaseScaler class BasicTimeSeriesForecastingModule(pl.LightningModule): diff --git a/examples/arch.py b/examples/arch.py index 04eca178..7766a3a7 100644 --- a/examples/arch.py +++ b/examples/arch.py @@ -3,10 +3,8 @@ import numpy as np import torch from torch import nn -import lightning.pytorch as pl from basicts.model import BasicTimeSeriesForecastingModule -from basicts.scaler import BaseScaler class MultiLayerPerceptron(BasicTimeSeriesForecastingModule): From 071b601fac20f60d8770bfb77bb081c6343827f8 Mon Sep 17 00:00:00 2001 From: WuZhen <498721344@qq.com> Date: Wed, 19 Mar 2025 12:17:20 +0800 Subject: [PATCH 4/6] feat: add default config, add compute_evaluation_metrics function in base model, move run.py to experiments folder. --- .gitignore | 2 +- basicts/configs/default.yaml | 65 ++++++++++++++++++++++++++++++++++++ basicts/model.py | 56 ++++++++++++++++++++++++++++--- run.py => experiments/run.py | 21 +++++------- 4 files changed, 126 insertions(+), 18 deletions(-) create mode 100644 basicts/configs/default.yaml rename run.py => experiments/run.py (54%) diff --git a/.gitignore b/.gitignore index b69c63f1..f16c0180 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ datasets/ todo.md gpu_task.py cmd.sh - +*logs # file *.npz diff --git a/basicts/configs/default.yaml b/basicts/configs/default.yaml new file mode 100644 index 00000000..e8a41892 --- /dev/null +++ b/basicts/configs/default.yaml @@ -0,0 +1,65 @@ +fit: + model: + class_path: basicts.model.BasicTimeSeriesForecastingModule + init_args: + lr: 2e-3 + weight_decay: 1e-4 + history_len: 12 + horizon_len: 12 + forward_features: [0, 1, 2] + target_features: [0] + metrics: + - MAE + - MAPE + - RMSE + scaler: + class_path: basicts.scaler.ZScoreScaler + init_args: + dataset_name: ${fit.data.init_args.dataset_name} + train_ratio: ${fit.data.init_args.train_val_test_ratio[0]} + norm_each_channel: False + rescale: True + + data: + class_path: basicts.data.TimeSeriesForecastingModule + init_args: + dataset_class: 
basicts.data.TimeSeriesForecastingDataset + dataset_name: PEMS08 + batch_size: 32 + train_val_test_ratio: [0.6, 0.2, 0.2] + input_len: ${fit.model.init_args.history_len} + output_len: ${fit.model.init_args.horizon_len} + overlap: False + + trainer: + max_epochs: 300 + devices: auto + log_every_n_steps: 10 + callbacks: + # Use rich for better progress bar and model summary (pip install rich) + # - class_path: lightning.pytorch.callbacks.RichProgressBar + # - class_path: lightning.pytorch.callbacks.RichModelSummary + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/loss + patience: 20 + mode: min + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/loss + mode: min + save_top_k: 1 + filename: best + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: examples/lightning_logs + name: ${fit.data.init_args.dataset_name} + +test: + data: ${fit.data} + model: ${fit.model} + trainer: ${fit.trainer} + # Specify the checkpoint path to test; it is recommended to specify it on the command line, e.g., --ckpt_path=examples/lightning_logs/best.ckpt + # ckpt_path: examples/lightning_logs/best.ckpt + \ No newline at end of file diff --git a/basicts/model.py b/basicts/model.py index 849ca306..04cdc1bf 100644 --- a/basicts/model.py +++ b/basicts/model.py @@ -7,6 +7,7 @@ import torch from basicts.metrics import ALL_METRICS, masked_mae +from basicts.utils import load_dataset_desc class BasicTimeSeriesForecastingModule(pl.LightningModule): @@ -22,6 +23,8 @@ def __init__( target_time_series: Optional[List[int]] = None, scaler: Any = None, null_val: Any = np.nan, + dataset_name: str = None, + evaluation_horizons: Optional[List[int]] = None, ): super().__init__() self.lr = lr @@ -33,7 +36,14 @@ def __init__( self.target_time_series = target_time_series self.scaler = scaler self.null_val = null_val + self.dataset_name = dataset_name + if evaluation_horizons is None: + 
evaluation_horizons = [3, 6, 12] + self.evaluation_horizons = evaluation_horizons + assert len(self.evaluation_horizons) == 0 or min(self.evaluation_horizons) >= 1, 'The horizon should start counting from 1.' + + self.dataset_desc = load_dataset_desc(dataset_name) if self.dataset_name else None self.metric_func_dict = self.init_metrics(metrics) if "loss" not in self.metric_func_dict: if hasattr(self, "loss_func"): @@ -161,13 +171,16 @@ def validation_step(self, batch, batch_idx): def test_step(self, batch, batch_idx): forward_return = self.basicts_forward(batch) + + # returns_all = {'prediction': prediction, 'target': target, 'inputs': inputs} + metrics_results = self.compute_evaluation_metrics(forward_return) + metrics = {} - for metric_name, metric_func in self.metric_func_dict.items(): - metric_item = self.metric_forward(metric_func, forward_return) - metrics[f"test/{metric_name}"] = metric_item + for metric_name, metric_value in metrics_results.items(): + metrics[f"test/{metric_name}"] = metric_value self.log_dict(metrics, on_step=False, on_epoch=True) - return metrics["test/loss"] - + return forward_return, metrics_results + def configure_optimizers(self): optimizer = torch.optim.Adam( self.parameters(), lr=self.lr, weight_decay=self.weight_decay @@ -272,3 +285,36 @@ def select_target_time_series(self, data: torch.Tensor) -> torch.Tensor: data = data[:, :, self.target_time_series, :] return data + + def compute_evaluation_metrics(self, returns_all: Dict): + """Compute metrics for evaluating model performance during the test process. + + Args: + returns_all (Dict): Must contain keys: inputs, prediction, target. 
+ """ + + metrics_results = {} + for i in self.evaluation_horizons: + pred = returns_all['prediction'][:, i - 1, :, :] + real = returns_all['target'][:, i - 1, :, :] + + # metrics_results[f'horizon_{i + 1}'] = {} + # metric_repr = '' + for metric_name, metric_func in self.metric_func_dict.items(): + if metric_name.lower() == 'mase': + continue # MASE needs to be calculated after all horizons + if metric_name.lower() == 'loss': + continue + metric_item = self.metric_forward(metric_func, {'prediction': pred, 'target': real}) + # metric_repr += f', Test {metric_name}: {metric_item.item():.4f}' + metrics_results[f'{metric_name}/{i}'] = metric_item.item() + # self.logger.info(f'Evaluate best model on test data for horizon {i + 1}{metric_repr}') + + # metrics_results['overall'] = {} + for metric_name, metric_func in self.metric_func_dict.items(): + if metric_name.lower() == 'loss': + continue + metric_item = self.metric_forward(metric_func, returns_all) + metrics_results[f'{metric_name}/overall'] = metric_item.item() + + return metrics_results \ No newline at end of file diff --git a/run.py b/experiments/run.py similarity index 54% rename from run.py rename to experiments/run.py index 6005ed2b..649d5427 100644 --- a/run.py +++ b/experiments/run.py @@ -1,37 +1,34 @@ # Run a baseline model in BasicTS framework. 
# pylint: disable=wrong-import-position -import os import sys from pathlib import Path FILE_PATH = Path(__file__).resolve() -PROJECT_DIR = FILE_PATH.parent # PROJECT_DIR -BASICTS_DIR = PROJECT_DIR / "basicts" # BASICTS_DIR +PROJECT_DIR = FILE_PATH.parents[1] # PROJECT_DIR: BasicTS +BASICTS_DIR = PROJECT_DIR / "basicts" # BASICTS_DIR: BasicTS/basicts sys.path.append(PROJECT_DIR.as_posix()) sys.path.append(BASICTS_DIR.as_posix()) -# os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from lightning.pytorch.cli import LightningCLI -from basicts.data.tsf_datamodule import TimeSeriesForecastingModule -from basicts.model import BasicTimeSeriesForecastingModule -# from .baselines - +# from basicts.data.tsf_datamodule import TimeSeriesForecastingModule +# from basicts.model import BasicTimeSeriesForecastingModule class BasicTSCLI(LightningCLI): def add_arguments_to_parser(self, parser): super().add_arguments_to_parser(parser) - # parser.link_arguments("model.init_args.null_val", "data.regular_settings[NULL_VAL]") - # parser.link_arguments("model.init_args.history_len", "data.init_args.input_len") - # parser.link_arguments("data.init_args.prediction_len", "data.init_args.output_len") + parser.link_arguments("data.init_args.dataset_name", "model.init_args.dataset_name") def run(): cli = BasicTSCLI( run=True, trainer_defaults={}, - parser_kwargs={"parser_mode": "omegaconf"}, # pip install omegaconf + + parser_kwargs={"parser_mode": "omegaconf", + "default_config_files": [(BASICTS_DIR/"configs"/"default.yaml").as_posix()], + }, save_config_kwargs={"overwrite": True, "save_to_log_dir": True}, ) if cli.subcommand in ("fit", "validate") and not cli.trainer.fast_dev_run: From 1eac875be8192d2c3dda45f616f879cb2849e7e1 Mon Sep 17 00:00:00 2001 From: WuZhen <498721344@qq.com> Date: Wed, 19 Mar 2025 13:16:03 +0800 Subject: [PATCH 5/6] feat: save and log hyperparams, fetch null_val from dataset description --- basicts/data/tsf_datamodule.py | 2 +- basicts/model.py | 
8 +++++--- examples/lightning_config.yaml | 4 ++-- experiments/run.py | 10 ++++++++++ 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/basicts/data/tsf_datamodule.py b/basicts/data/tsf_datamodule.py index ccace7d2..e1c3a9fe 100644 --- a/basicts/data/tsf_datamodule.py +++ b/basicts/data/tsf_datamodule.py @@ -43,7 +43,7 @@ def __init__( self.shuffle = shuffle self.prefetch = prefetch self.regular_settings = get_regular_settings(dataset_name) - + self.save_hyperparameters() # self.train_set = TimeSeriesForecastingDataset() @property diff --git a/basicts/model.py b/basicts/model.py index 04cdc1bf..21bcc9c2 100644 --- a/basicts/model.py +++ b/basicts/model.py @@ -22,7 +22,7 @@ def __init__( target_features: Optional[List[int]] = None, target_time_series: Optional[List[int]] = None, scaler: Any = None, - null_val: Any = np.nan, + null_val: Any = None, dataset_name: str = None, evaluation_horizons: Optional[List[int]] = None, ): @@ -35,15 +35,17 @@ def __init__( self.target_features = target_features self.target_time_series = target_time_series self.scaler = scaler - self.null_val = null_val self.dataset_name = dataset_name - + self.save_hyperparameters() if evaluation_horizons is None: evaluation_horizons = [3, 6, 12] self.evaluation_horizons = evaluation_horizons assert len(self.evaluation_horizons) == 0 or min(self.evaluation_horizons) >= 1, 'The horizon should start counting from 1.' 
self.dataset_desc = load_dataset_desc(dataset_name) if self.dataset_name else None + if null_val is None and self.dataset_desc is not None: + null_val = self.dataset_desc['regular_settings'].get('NULL_VAL', np.nan) + self.null_val = null_val self.metric_func_dict = self.init_metrics(metrics) if "loss" not in self.metric_func_dict: if hasattr(self, "loss_func"): diff --git a/examples/lightning_config.yaml b/examples/lightning_config.yaml index a12a7651..e27f4999 100644 --- a/examples/lightning_config.yaml +++ b/examples/lightning_config.yaml @@ -38,8 +38,8 @@ fit: log_every_n_steps: 10 callbacks: # Use rich for better progress bar and model summary (pip install rich) - # - class_path: lightning.pytorch.callbacks.RichProgressBar - # - class_path: lightning.pytorch.callbacks.RichModelSummary + - class_path: lightning.pytorch.callbacks.RichProgressBar + - class_path: lightning.pytorch.callbacks.RichModelSummary - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: monitor: val/loss diff --git a/experiments/run.py b/experiments/run.py index 649d5427..4dcfd231 100644 --- a/experiments/run.py +++ b/experiments/run.py @@ -31,6 +31,16 @@ def run(): }, save_config_kwargs={"overwrite": True, "save_to_log_dir": True}, ) + logger = cli.trainer.logger + + # Log hyperparameters + trainer_hparam_names = ['max_epochs', 'min_epochs', 'precision', 'overfit_batches', 'gradient_clip_val', 'gradient_clip_algorithm', 'accelerator', 'strategy', 'limit_train_batches', 'limit_val_batches', 'limit_test_batches'] + trainer_hparams = {k: cli.config_dump['trainer'][k] for k in trainer_hparam_names} + logger.log_hyperparams(cli.datamodule.hparams) + logger.log_hyperparams(cli.model.hparams) + logger.log_hyperparams(trainer_hparams) + + if cli.subcommand in ("fit", "validate") and not cli.trainer.fast_dev_run: # 被动执行了 fit 或者 validate,追加一个 test cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") From 075c464416bb7952eef809f40e382b39d7c02818 Mon Sep 17 00:00:00 2001 From: 
WuZhen <498721344@qq.com> Date: Thu, 15 May 2025 21:01:35 +0800 Subject: [PATCH 6/6] Update config and metric logging --- basicts/data/tsf_datamodule.py | 2 +- basicts/model.py | 14 +++++++++----- requirements.txt | 4 ++-- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/basicts/data/tsf_datamodule.py b/basicts/data/tsf_datamodule.py index e1c3a9fe..0c466a7a 100644 --- a/basicts/data/tsf_datamodule.py +++ b/basicts/data/tsf_datamodule.py @@ -18,7 +18,7 @@ def __init__( batch_size: int = 32, num_workers: int = 0, pin_memory: bool = False, - shuffle: bool = False, + shuffle: bool = True, prefetch: bool = False, ): super().__init__() diff --git a/basicts/model.py b/basicts/model.py index 21bcc9c2..70324fcf 100644 --- a/basicts/model.py +++ b/basicts/model.py @@ -159,7 +159,9 @@ def training_step(self, batch, batch_idx): for metric_name, metric_func in self.metric_func_dict.items(): metric_item = self.metric_forward(metric_func, forward_return) metrics[f"train/{metric_name}"] = metric_item - self.log_dict(metrics, on_step=True) + self.log(f"train/{metric_name}", metric_item, on_step=True, prog_bar=metric_name == "loss") + + # self.log_dict(metrics, on_step=True) return metrics["train/loss"] def validation_step(self, batch, batch_idx): @@ -168,8 +170,9 @@ def validation_step(self, batch, batch_idx): for metric_name, metric_func in self.metric_func_dict.items(): metric_item = self.metric_forward(metric_func, forward_return) metrics[f"val/{metric_name}"] = metric_item - self.log_dict(metrics, on_step=False, on_epoch=True) - return metrics["val/loss"] + self.log(f"val/{metric_name}", metric_item, on_step=False, on_epoch=True, prog_bar=metric_name == "loss") + # self.log_dict(metrics, on_step=False, on_epoch=True) + return forward_return, metrics def test_step(self, batch, batch_idx): forward_return = self.basicts_forward(batch) @@ -180,14 +183,15 @@ def test_step(self, batch, batch_idx): metrics = {} for metric_name, metric_value in metrics_results.items(): 
metrics[f"test/{metric_name}"] = metric_value - self.log_dict(metrics, on_step=False, on_epoch=True) + self.log(f"test/{metric_name}", metric_value, on_step=False, on_epoch=True, prog_bar=metric_name == "loss") + # self.log_dict(metrics, on_step=False, on_epoch=True) return forward_return, metrics_results def configure_optimizers(self): optimizer = torch.optim.Adam( self.parameters(), lr=self.lr, weight_decay=self.weight_decay ) - scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.trainer.max_epochs) return [optimizer], [scheduler] def preprocessing(self, input_data: Dict, scale_keys=["target", "inputs"]) -> Dict: diff --git a/requirements.txt b/requirements.txt index c34552d1..17e754dc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -# easy-torch +easy-torch easydict packaging setproctitle @@ -8,7 +8,7 @@ tables sympy openpyxl rich -lightning[pytorch-extra] +lightning[pytorch-extra]==2.5.0.post0 # omegaconf # setuptools==59.5.0 # numpy==1.24.4