diff --git a/.gitignore b/.gitignore index 96cc9655..f16c0180 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ datasets/ todo.md gpu_task.py cmd.sh +*logs # file *.npz @@ -19,6 +20,7 @@ cmd.sh *.pyc *.txt *.core +*.ckpt *.py[cod] *$py.class diff --git a/basicts/configs/default.yaml b/basicts/configs/default.yaml new file mode 100644 index 00000000..e8a41892 --- /dev/null +++ b/basicts/configs/default.yaml @@ -0,0 +1,65 @@ +fit: + model: + class_path: basicts.model.BasicTimeSeriesForecastingModule + init_args: + lr: 2e-3 + weight_decay: 1e-4 + history_len: 12 + horizon_len: 12 + forward_features: [0, 1, 2] + target_features: [0] + metrics: + - MAE + - MAPE + - RMSE + scaler: + class_path: basicts.scaler.ZScoreScaler + init_args: + dataset_name: ${fit.data.init_args.dataset_name} + train_ratio: ${fit.data.init_args.train_val_test_ratio[0]} + norm_each_channel: False + rescale: True + + data: + class_path: basicts.data.TimeSeriesForecastingModule + init_args: + dataset_class: basicts.data.TimeSeriesForecastingDataset + dataset_name: PEMS08 + batch_size: 32 + train_val_test_ratio: [0.6, 0.2, 0.2] + input_len: ${fit.model.init_args.history_len} + output_len: ${fit.model.init_args.horizon_len} + overlap: False + + trainer: + max_epochs: 300 + devices: auto + log_every_n_steps: 10 + callbacks: + # Use rich for better progress bar and model summary (pip install rich) + # - class_path: lightning.pytorch.callbacks.RichProgressBar + # - class_path: lightning.pytorch.callbacks.RichModelSummary + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/loss + patience: 20 + mode: min + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/loss + mode: min + save_top_k: 1 + filename: best + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: examples/lightning_logs + name: ${fit.data.init_args.dataset_name} + +test: + data: ${fit.data} + model: ${fit.model} + trainer: 
${fit.trainer} + # Specify the checkpoint path to test, it is recommended to specify in the command line, e.g., --ckpt_path=examples/lightning_logs/best.ckpt + # ckpt_path: examples/lightning_logs/best.ckpt + \ No newline at end of file diff --git a/basicts/data/__init__.py b/basicts/data/__init__.py index 60a24036..a1751307 100644 --- a/basicts/data/__init__.py +++ b/basicts/data/__init__.py @@ -1,4 +1,5 @@ from .base_dataset import BaseDataset from .simple_tsf_dataset import TimeSeriesForecastingDataset +from .tsf_datamodule import TimeSeriesForecastingModule -__all__ = ['BaseDataset', 'TimeSeriesForecastingDataset'] +__all__ = ['BaseDataset', 'TimeSeriesForecastingDataset', 'TimeSeriesForecastingModule'] diff --git a/basicts/data/tsf_datamodule.py b/basicts/data/tsf_datamodule.py new file mode 100644 index 00000000..0c466a7a --- /dev/null +++ b/basicts/data/tsf_datamodule.py @@ -0,0 +1,105 @@ +from typing import List + +from basicts.utils import get_regular_settings +import lightning.pytorch as pl +from importlib import import_module +from torch.utils.data import DataLoader + + +class TimeSeriesForecastingModule(pl.LightningDataModule): + def __init__( + self, + dataset_class: str, + dataset_name: str, + train_val_test_ratio: List[float], + input_len: int, + output_len: int, + overlap: bool = False, + batch_size: int = 32, + num_workers: int = 0, + pin_memory: bool = False, + shuffle: bool = True, + prefetch: bool = False, + ): + super().__init__() + dataset_class_packeg, dataset_class_name = dataset_class.rsplit(".", 1) + self.dataset_class = getattr( + import_module(dataset_class_packeg), dataset_class_name + ) + if prefetch: + # todo: implement DataLoaderX + # self.dataloader_class = DataLoaderX + raise NotImplementedError("DataLoaderX is not implemented yet.") + else: + self.dataloader_class = DataLoader + self.dataset_name = dataset_name + self.train_val_test_ratio = train_val_test_ratio + self.input_len = input_len + self.output_len = output_len + 
self.overlap = overlap + self.batch_size = batch_size + self.num_workers = num_workers + self.pin_memory = pin_memory + self.shuffle = shuffle + self.prefetch = prefetch + self.regular_settings = get_regular_settings(dataset_name) + self.save_hyperparameters() + # self.train_set = TimeSeriesForecastingDataset() + + @property + def dataset_params(self): + return { + "dataset_name": self.dataset_name, + "train_val_test_ratio": self.train_val_test_ratio, + "input_len": self.input_len, + "output_len": self.output_len, + "overlap": self.overlap, + } + + def train_dataloader(self): + """Build train dataset and dataloader. + + Returns: + train data loader (DataLoader) + """ + dataset = self.dataset_class(**self.dataset_params, mode="train") + loader = self.dataloader_class( + dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + ) + return loader + + def val_dataloader(self): + """Build validation dataset and dataloader. + + Returns: + validation data loader (DataLoader) + """ + dataset = self.dataset_class(**self.dataset_params, mode="valid") + loader = self.dataloader_class( + dataset, + batch_size=self.batch_size, + # shuffle=self.shuffle, # No need to shuffle validation data + num_workers=self.num_workers, + pin_memory=self.pin_memory, + ) + return loader + + def test_dataloader(self): + """Build test dataset and dataloader. 
+ + Returns: + test data loader (DataLoader) + """ + dataset = self.dataset_class(**self.dataset_params, mode="test") + loader = self.dataloader_class( + dataset, + batch_size=self.batch_size, + # shuffle=self.shuffle, # No need to shuffle test data + num_workers=self.num_workers, + pin_memory=self.pin_memory, + ) + return loader diff --git a/basicts/model.py b/basicts/model.py new file mode 100644 index 00000000..70324fcf --- /dev/null +++ b/basicts/model.py @@ -0,0 +1,326 @@ +import functools +import inspect +import lightning.pytorch as pl +from typing import Any, Callable, Dict, Optional, List + +import numpy as np +import torch + +from basicts.metrics import ALL_METRICS, masked_mae +from basicts.utils import load_dataset_desc + + +class BasicTimeSeriesForecastingModule(pl.LightningModule): + def __init__( + self, + lr: float, + weight_decay: float, + history_len: int, + horizon_len: int, + metrics: Optional[List[str]] = None, + forward_features: Optional[List[int]] = None, + target_features: Optional[List[int]] = None, + target_time_series: Optional[List[int]] = None, + scaler: Any = None, + null_val: Any = None, + dataset_name: str = None, + evaluation_horizons: Optional[List[int]] = None, + ): + super().__init__() + self.lr = lr + self.weight_decay = weight_decay + self.history_len = history_len + self.horizon_len = horizon_len + self.forward_features = forward_features + self.target_features = target_features + self.target_time_series = target_time_series + self.scaler = scaler + self.dataset_name = dataset_name + self.save_hyperparameters() + if evaluation_horizons is None: + evaluation_horizons = [3, 6, 12] + self.evaluation_horizons = evaluation_horizons + assert len(self.evaluation_horizons) == 0 or min(self.evaluation_horizons) >= 1, 'The horizon should start counting from 1.' 
+ + self.dataset_desc = load_dataset_desc(dataset_name) if self.dataset_name else None + if null_val is None and self.dataset_desc is not None: + null_val = self.dataset_desc['regular_settings'].get('NULL_VAL', np.nan) + self.null_val = null_val + self.metric_func_dict = self.init_metrics(metrics) + if "loss" not in self.metric_func_dict: + if hasattr(self, "loss_func"): + self.metric_func_dict["loss"] = self.loss_func + else: + # self.logger.info('No loss function is provided. Using masked_mae as default.') + self.metric_func_dict["loss"] = masked_mae + + def init_metrics(self, metrics: Optional[List[str]]) -> Dict[str, Callable]: + if metrics is None: + return ALL_METRICS + return {name: ALL_METRICS[name] for name in metrics} + + def basicts_forward(self, data: Dict, **kwargs) -> Dict: + """ + The forward function of original runner. + + Performs the forward pass for training, validation, and testing. + + Args: + data (Dict): A dictionary containing 'target' (future data) and 'inputs' (history data) (normalized by self.scaler). + + Returns: + Dict: A dictionary containing the keys: + - 'inputs': Selected input features. + - 'prediction': Model predictions. + - 'target': Selected target features. + + Raises: + AssertionError: If the shape of the model output does not match [B, L, N]. 
+ """ + + data = self.preprocessing(data) + + # Preprocess input data + future_data, history_data = data["target"], data["inputs"] + # history_data = self.to_running_device(history_data) # Shape: [B, L, N, C] + # future_data = self.to_running_device(future_data) # Shape: [B, L, N, C] + batch_size, length, num_nodes, _ = future_data.shape + + # Select input features + history_data = self.select_input_features(history_data) + future_data_4_dec = self.select_input_features(future_data) + + train = self.trainer.training + # epoch = self.trainer.current_epoch + # batch_seen = self.trainer.global_step + + if not train: + # For non-training phases, use only temporal features + future_data_4_dec[..., 0] = torch.empty_like(future_data_4_dec[..., 0]) + + # Forward pass through the model + model_return = self( + history_data=history_data, + future_data=future_data_4_dec, + # batch_seen=batch_seen, + # epoch=epoch, + # train=train, + ) + + # Parse model return + if isinstance(model_return, torch.Tensor): + model_return = {"prediction": model_return} + if "inputs" not in model_return: + model_return["inputs"] = self.select_target_features(history_data) + if "target" not in model_return: + model_return["target"] = self.select_target_features(future_data) + + # Ensure the output shape is correct + assert list(model_return["prediction"].shape)[:3] == [ + batch_size, + length, + num_nodes, + ], "The shape of the output is incorrect. Ensure it matches [B, L, N, C]." + + model_return = self.postprocessing(model_return) + + return model_return + + def metric_forward(self, metric_func, args: Dict) -> torch.Tensor: + """Compute metrics using the given metric function. + + Args: + metric_func (function or functools.partial): Metric function. + args (Dict): Arguments for metrics computation. + + Returns: + torch.Tensor: Computed metric value. 
+ """ + + covariate_names = inspect.signature(metric_func).parameters.keys() + args = {k: v for k, v in args.items() if k in covariate_names} + + if isinstance(metric_func, functools.partial): + if 'null_val' not in metric_func.keywords and 'null_val' in covariate_names: # null_val is required but not provided + args['null_val'] = self.null_val + metric_item = metric_func(**args) + elif callable(metric_func): + if 'null_val' in covariate_names: # null_val is required + args['null_val'] = self.null_val + metric_item = metric_func(**args) + else: + raise TypeError(f'Unknown metric type: {type(metric_func)}') + return metric_item + + + def training_step(self, batch, batch_idx): + forward_return = self.basicts_forward(batch) + metrics = {} + for metric_name, metric_func in self.metric_func_dict.items(): + metric_item = self.metric_forward(metric_func, forward_return) + metrics[f"train/{metric_name}"] = metric_item + self.log(f"train/{metric_name}", metric_item, on_step=True, prog_bar=metric_name == "loss") + + # self.log_dict(metrics, on_step=True) + return metrics["train/loss"] + + def validation_step(self, batch, batch_idx): + forward_return = self.basicts_forward(batch) + metrics = {} + for metric_name, metric_func in self.metric_func_dict.items(): + metric_item = self.metric_forward(metric_func, forward_return) + metrics[f"val/{metric_name}"] = metric_item + self.log(f"val/{metric_name}", metric_item, on_step=False, on_epoch=True, prog_bar=metric_name == "loss") + # self.log_dict(metrics, on_step=False, on_epoch=True) + return forward_return, metrics + + def test_step(self, batch, batch_idx): + forward_return = self.basicts_forward(batch) + + # returns_all = {'prediction': prediction, 'target': target, 'inputs': inputs} + metrics_results = self.compute_evaluation_metrics(forward_return) + + metrics = {} + for metric_name, metric_value in metrics_results.items(): + metrics[f"test/{metric_name}"] = metric_value + self.log(f"test/{metric_name}", metric_value, 
on_step=False, on_epoch=True, prog_bar=metric_name == "loss") + # self.log_dict(metrics, on_step=False, on_epoch=True) + return forward_return, metrics_results + + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr, weight_decay=self.weight_decay + ) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.trainer.max_epochs) + return [optimizer], [scheduler] + + def preprocessing(self, input_data: Dict, scale_keys=["target", "inputs"]) -> Dict: + """Preprocess data. + + Args: + input_data (Dict): Dictionary containing data to be processed. + scale_keys (Optional[Union[List[str], str]], optional): 'all' means scale all, None means ignore all; it can also be specified by a list of str. Defaults to ["target", "inputs"]. + + Returns: + Dict: Processed data. + """ + if scale_keys is None: + scale_keys = [] + elif scale_keys == "all": + scale_keys = input_data.keys() + + if self.scaler is not None: + for k in scale_keys: + input_data[k] = self.scaler.transform(input_data[k]) + # TODO: add more preprocessing steps as needed. + return input_data + + def postprocessing( + self, input_data: Dict, scale_keys=["target", "inputs", "prediction"] + ) -> Dict: + """Postprocess data. + + Args: + input_data (Dict): Dictionary containing data to be processed. + scale_keys (Optional[Union[List[str], str]], optional): 'all' means scale all, None means ignore all; it can also be specified by a list of str. Defaults to ["target", "inputs", "prediction"]. + + Returns: + Dict: Processed data. 
+ """ + + # rescale data + if self.scaler is not None and self.scaler.rescale: + if scale_keys is None: + scale_keys = [] + elif scale_keys == "all": + scale_keys = input_data.keys() + for k in scale_keys: + input_data[k] = self.scaler.inverse_transform(input_data[k]) + + # subset forecasting + if self.target_time_series is not None: + input_data["target"] = input_data["target"][ + :, :, self.target_time_series, : + ] + input_data["prediction"] = input_data["prediction"][ + :, :, self.target_time_series, : + ] + + # TODO: add more postprocessing steps as needed. + return input_data + + def select_input_features(self, data: torch.Tensor) -> torch.Tensor: + """ + Selects input features based on the forward features specified in the configuration. + + Args: + data (torch.Tensor): Input history data with shape [B, L, N, C1]. + + Returns: + torch.Tensor: Data with selected features with shape [B, L, N, C2]. + """ + + if self.forward_features is not None: + data = data[:, :, :, self.forward_features] + return data + + def select_target_features(self, data: torch.Tensor) -> torch.Tensor: + """ + Selects target features based on the target features specified in the configuration. + + Args: + data (torch.Tensor): Model prediction data with shape [B, L, N, C1]. + + Returns: + torch.Tensor: Data with selected target features and shape [B, L, N, C2]. + """ + + data = data[:, :, :, self.target_features] + return data + + def select_target_time_series(self, data: torch.Tensor) -> torch.Tensor: + """ + Select target time series based on the target time series specified in the configuration. + + Args: + data (torch.Tensor): Model prediction data with shape [B, L, N1, C]. + + Returns: + torch.Tensor: Data with selected target time series and shape [B, L, N2, C]. + """ + + data = data[:, :, self.target_time_series, :] + return data + + def compute_evaluation_metrics(self, returns_all: Dict): + """Compute metrics for evaluating model performance during the test process. 
+ + Args: + returns_all (Dict): Must contain keys: inputs, prediction, target. + """ + + metrics_results = {} + for i in self.evaluation_horizons: + pred = returns_all['prediction'][:, i - 1, :, :] + real = returns_all['target'][:, i - 1, :, :] + + # metrics_results[f'horizon_{i + 1}'] = {} + # metric_repr = '' + for metric_name, metric_func in self.metric_func_dict.items(): + if metric_name.lower() == 'mase': + continue # MASE needs to be calculated after all horizons + if metric_name.lower() == 'loss': + continue + metric_item = self.metric_forward(metric_func, {'prediction': pred, 'target': real}) + # metric_repr += f', Test {metric_name}: {metric_item.item():.4f}' + metrics_results[f'{metric_name}/{i}'] = metric_item.item() + # self.logger.info(f'Evaluate best model on test data for horizon {i + 1}{metric_repr}') + + # metrics_results['overall'] = {} + for metric_name, metric_func in self.metric_func_dict.items(): + if metric_name.lower() == 'loss': + continue + metric_item = self.metric_forward(metric_func, returns_all) + metrics_results[f'{metric_name}/overall'] = metric_item.item() + + return metrics_results \ No newline at end of file diff --git a/examples/arch.py b/examples/arch.py index 623fd4c0..7766a3a7 100644 --- a/examples/arch.py +++ b/examples/arch.py @@ -1,9 +1,13 @@ # pylint: disable=unused-argument +from typing import Any, List, Optional +import numpy as np import torch from torch import nn +from basicts.model import BasicTimeSeriesForecastingModule -class MultiLayerPerceptron(nn.Module): + +class MultiLayerPerceptron(BasicTimeSeriesForecastingModule): """ A simple Multi-Layer Perceptron (MLP) model with two fully connected layers. @@ -16,7 +20,20 @@ class MultiLayerPerceptron(nn.Module): act (nn.ReLU): The ReLU activation function applied between the two layers. 
""" - def __init__(self, history_seq_len: int, prediction_seq_len: int, hidden_dim: int): + def __init__( + self, + lr: float, + weight_decay: float, + history_len: int, + horizon_len: int, + hidden_dim: int, + metrics: Optional[List[str]] = None, + forward_features: Optional[List[int]] = None, + target_features: Optional[List[int]] = None, + target_time_series: Optional[List[int]] = None, + scaler: Any = None, + null_val: Any = np.nan, + ): """ Initialize the MultiLayerPerceptron model. @@ -24,28 +41,41 @@ def __init__(self, history_seq_len: int, prediction_seq_len: int, hidden_dim: in history_seq_len (int): The length of the input history sequence. prediction_seq_len (int): The length of the output prediction sequence. hidden_dim (int): The number of units in the hidden layer. + """ - super().__init__() - self.fc1 = nn.Linear(history_seq_len, hidden_dim) - self.fc2 = nn.Linear(hidden_dim, prediction_seq_len) + super().__init__( + lr=lr, + weight_decay=weight_decay, + history_len=history_len, + horizon_len=horizon_len, + metrics=metrics, + forward_features=forward_features, + target_features=target_features, + target_time_series=target_time_series, + scaler=scaler, + null_val=null_val, + ) + self.fc1 = nn.Linear(history_len, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, horizon_len) self.act = nn.ReLU() - def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_seen: int, epoch: int, train: bool) -> torch.Tensor: + def forward( + self, + history_data: torch.Tensor, + future_data: torch.Tensor, + ) -> torch.Tensor: """ Perform a forward pass through the network. Args: history_data (torch.Tensor): A tensor containing historical data, typically of shape `[B, L, N, C]`. future_data (torch.Tensor): A tensor containing future data, typically of shape `[B, L, N, C]`. - batch_seen (int): The number of batches seen so far during training. - epoch (int): The current epoch number. - train (bool): Flag indicating whether the model is in training mode. 
Returns: torch.Tensor: The output prediction tensor, typically of shape `[B, L, N, C]`. """ - history_data = history_data[..., 0].transpose(1, 2) # [B, L, N, C] -> [B, N, L] + history_data = history_data[..., 0].transpose(1, 2) # [B, L, N, C] -> [B, N, L] # [B, N, L] --h=act(fc1(x))--> [B, N, D] --fc2(h)--> [B, N, L] -> [B, L, N] prediction = self.fc2(self.act(self.fc1(history_data))).transpose(1, 2) diff --git a/examples/lightning_config.yaml b/examples/lightning_config.yaml new file mode 100644 index 00000000..e27f4999 --- /dev/null +++ b/examples/lightning_config.yaml @@ -0,0 +1,67 @@ +fit: + model: + class_path: examples.arch.MultiLayerPerceptron + init_args: + lr: 2e-3 + weight_decay: 1e-4 + history_len: 12 + horizon_len: 12 + hidden_dim: 64 + forward_features: [0, 1, 2] + target_features: [0] + metrics: + - MAE + - MAPE + - RMSE + scaler: + class_path: basicts.scaler.ZScoreScaler + init_args: + dataset_name: ${fit.data.init_args.dataset_name} + train_ratio: ${fit.data.init_args.train_val_test_ratio[0]} + norm_each_channel: False + rescale: True + data: + class_path: basicts.data.TimeSeriesForecastingModule + init_args: + dataset_class: basicts.data.TimeSeriesForecastingDataset + dataset_name: PEMS08 + train_val_test_ratio: [0.6, 0.2, 0.2] + # input_len: 12 + # output_len: 12 + # To enable variable interpolation, first install omegaconf (pip install omegaconf) + input_len: ${fit.model.init_args.history_len} + output_len: ${fit.model.init_args.horizon_len} + overlap: False + trainer: + max_epochs: 10 + devices: auto + log_every_n_steps: 10 + callbacks: + # Use rich for better progress bar and model summary (pip install rich) + - class_path: lightning.pytorch.callbacks.RichProgressBar + - class_path: lightning.pytorch.callbacks.RichModelSummary + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/loss + patience: 10 + mode: min + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/loss + mode: 
min + save_top_k: 1 + dirpath: ${fit.trainer.logger.init_args.save_dir}/${fit.trainer.logger.init_args.name} + filename: best + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: examples/lightning_logs + name: ${fit.data.init_args.dataset_name} + +test: + data: ${fit.data} + model: ${fit.model} + trainer: ${fit.trainer} + # Specify the checkpoint path to test, it is recommended to specify in the command line, e.g., --ckpt_path=examples/lightning_logs/best.ckpt + # ckpt_path: examples/lightning_logs/best.ckpt + \ No newline at end of file diff --git a/experiments/run.py b/experiments/run.py new file mode 100644 index 00000000..4dcfd231 --- /dev/null +++ b/experiments/run.py @@ -0,0 +1,50 @@ +# Run a baseline model in BasicTS framework. +# pylint: disable=wrong-import-position +import sys +from pathlib import Path + +FILE_PATH = Path(__file__).resolve() +PROJECT_DIR = FILE_PATH.parents[1] # PROJECT_DIR: BasicTS +BASICTS_DIR = PROJECT_DIR / "basicts" # BASICTS_DIR: BasicTS/basicts +sys.path.append(PROJECT_DIR.as_posix()) +sys.path.append(BASICTS_DIR.as_posix()) + +from lightning.pytorch.cli import LightningCLI + +# from basicts.data.tsf_datamodule import TimeSeriesForecastingModule +# from basicts.model import BasicTimeSeriesForecastingModule + +class BasicTSCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + super().add_arguments_to_parser(parser) + + parser.link_arguments("data.init_args.dataset_name", "model.init_args.dataset_name") + + +def run(): + cli = BasicTSCLI( + run=True, + trainer_defaults={}, + + parser_kwargs={"parser_mode": "omegaconf", + "default_config_files": [(BASICTS_DIR/"configs"/"default.yaml").as_posix()], + }, + save_config_kwargs={"overwrite": True, "save_to_log_dir": True}, + ) + logger = cli.trainer.logger + + # Log hyperparameters + trainer_hparam_names = ['max_epochs', 'min_epochs', 'precision', 'overfit_batches', 'gradient_clip_val', 'gradient_clip_algorithm', 'accelerator', 
'strategy', 'limit_train_batches', 'limit_val_batches', 'limit_test_batches'] + trainer_hparams = {k: cli.config_dump['trainer'][k] for k in trainer_hparam_names} + logger.log_hyperparams(cli.datamodule.hparams) + logger.log_hyperparams(cli.model.hparams) + logger.log_hyperparams(trainer_hparams) + + + if cli.subcommand in ("fit", "validate") and not cli.trainer.fast_dev_run: + # After a fit or validate subcommand finishes, automatically run a follow-up test pass. + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") + + + if __name__ == "__main__": + run() diff --git a/requirements.txt b/requirements.txt index 8b9a1a03..17e754dc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,8 @@ scikit-learn tables sympy openpyxl -setuptools==59.5.0 -numpy==1.24.4 +rich +lightning[pytorch-extra]==2.5.0.post0 +# omegaconf +# setuptools==59.5.0 +# numpy==1.24.4