Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ datasets/
todo.md
gpu_task.py
cmd.sh
*logs

# file
*.npz
Expand All @@ -19,6 +20,7 @@ cmd.sh
*.pyc
*.txt
*.core
*.ckpt

*.py[cod]
*$py.class
Expand Down
65 changes: 65 additions & 0 deletions basicts/configs/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
fit:
model:
class_path: basicts.model.BasicTimeSeriesForecastingModule
init_args:
lr: 2e-3
weight_decay: 1e-4
history_len: 12
horizon_len: 12
forward_features: [0, 1, 2]
target_features: [0]
metrics:
- MAE
- MAPE
- RMSE
scaler:
class_path: basicts.scaler.ZScoreScaler
init_args:
dataset_name: ${fit.data.init_args.dataset_name}
train_ratio: ${fit.data.init_args.train_val_test_ratio[0]}
norm_each_channel: False
rescale: True

data:
class_path: basicts.data.TimeSeriesForecastingModule
init_args:
dataset_class: basicts.data.TimeSeriesForecastingDataset
dataset_name: PEMS08
batch_size: 32
train_val_test_ratio: [0.6, 0.2, 0.2]
input_len: ${fit.model.init_args.history_len}
output_len: ${fit.model.init_args.horizon_len}
overlap: False

trainer:
max_epochs: 300
devices: auto
log_every_n_steps: 10
callbacks:
# Use rich for better progress bar and model summary (pip install rich)
# - class_path: lightning.pytorch.callbacks.RichProgressBar
# - class_path: lightning.pytorch.callbacks.RichModelSummary
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
monitor: val/loss
patience: 20
mode: min
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val/loss
mode: min
save_top_k: 1
filename: best
logger:
class_path: lightning.pytorch.loggers.TensorBoardLogger
init_args:
save_dir: examples/lightning_logs
name: ${fit.data.init_args.dataset_name}

test:
data: ${fit.data}
model: ${fit.model}
trainer: ${fit.trainer}
# Specify the checkpoint path to test; it is recommended to specify it on the command line, e.g., --ckpt_path=examples/lightning_logs/best.ckpt
# ckpt_path: examples/lightning_logs/best.ckpt

3 changes: 2 additions & 1 deletion basicts/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base_dataset import BaseDataset
from .simple_tsf_dataset import TimeSeriesForecastingDataset
from .tsf_datamodule import TimeSeriesForecastingModule

__all__ = ['BaseDataset', 'TimeSeriesForecastingDataset']
__all__ = ['BaseDataset', 'TimeSeriesForecastingDataset', 'TimeSeriesForecastingModule']
105 changes: 105 additions & 0 deletions basicts/data/tsf_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import List

from basicts.utils import get_regular_settings
import lightning.pytorch as pl
from importlib import import_module
from torch.utils.data import DataLoader


class TimeSeriesForecastingModule(pl.LightningDataModule):
    """Lightning data module for time-series forecasting datasets.

    The concrete dataset class is referenced by its dotted import path
    (e.g. ``"basicts.data.TimeSeriesForecastingDataset"``) so it can be
    selected from a YAML/CLI config without code changes.
    """

    def __init__(
        self,
        dataset_class: str,
        dataset_name: str,
        train_val_test_ratio: List[float],
        input_len: int,
        output_len: int,
        overlap: bool = False,
        batch_size: int = 32,
        num_workers: int = 0,
        pin_memory: bool = False,
        shuffle: bool = True,
        prefetch: bool = False,
    ):
        """Initialize the data module.

        Args:
            dataset_class: Dotted import path of the dataset class,
                e.g. ``"basicts.data.TimeSeriesForecastingDataset"``.
            dataset_name: Name of the dataset (e.g. ``"PEMS08"``).
            train_val_test_ratio: Three floats giving the train/val/test split.
            input_len: Length of the input (history) window.
            output_len: Length of the output (horizon) window.
            overlap: Whether adjacent samples may overlap.
            batch_size: Batch size for all dataloaders.
            num_workers: Number of dataloader worker processes.
            pin_memory: Whether dataloaders pin host memory.
            shuffle: Whether to shuffle the *training* data only.
            prefetch: Reserved for a prefetching dataloader; not implemented.

        Raises:
            NotImplementedError: If ``prefetch`` is True.
        """
        super().__init__()
        # Resolve the dataset class from its dotted import path.
        dataset_class_package, dataset_class_name = dataset_class.rsplit(".", 1)
        self.dataset_class = getattr(
            import_module(dataset_class_package), dataset_class_name
        )
        if prefetch:
            # TODO: implement a prefetching DataLoaderX and use it here.
            raise NotImplementedError("DataLoaderX is not implemented yet.")
        self.dataloader_class = DataLoader
        self.dataset_name = dataset_name
        self.train_val_test_ratio = train_val_test_ratio
        self.input_len = input_len
        self.output_len = output_len
        self.overlap = overlap
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.shuffle = shuffle
        self.prefetch = prefetch
        self.regular_settings = get_regular_settings(dataset_name)
        self.save_hyperparameters()

    @property
    def dataset_params(self):
        """Keyword arguments shared by the train/val/test dataset constructors."""
        return {
            "dataset_name": self.dataset_name,
            "train_val_test_ratio": self.train_val_test_ratio,
            "input_len": self.input_len,
            "output_len": self.output_len,
            "overlap": self.overlap,
        }

    def _make_dataloader(self, mode: str, shuffle: bool) -> DataLoader:
        """Build a dataset for *mode* and wrap it in a dataloader.

        Args:
            mode: One of ``"train"``, ``"valid"``, ``"test"``.
            shuffle: Whether the dataloader shuffles samples.

        Returns:
            A configured dataloader for the requested split.
        """
        dataset = self.dataset_class(**self.dataset_params, mode=mode)
        return self.dataloader_class(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def train_dataloader(self):
        """Build train dataset and dataloader.

        Returns:
            train data loader (DataLoader)
        """
        return self._make_dataloader("train", self.shuffle)

    def val_dataloader(self):
        """Build validation dataset and dataloader.

        Returns:
            validation data loader (DataLoader)
        """
        # Validation data is never shuffled.
        return self._make_dataloader("valid", False)

    def test_dataloader(self):
        """Build test dataset and dataloader.

        Returns:
            test data loader (DataLoader)
        """
        # Test data is never shuffled.
        return self._make_dataloader("test", False)
Loading