Skip to content

为什么使用BasicTS最新的版本复现不了论文中的结果,差距还比较大 #2

@klayc-gzl

Description

@klayc-gzl

比如PEMS08数据集:
使用的PEMS08.py:

# -*- coding: utf-8 -*-

"""
Created on Thu Aug 14 22:57:12 2025

@author: HP
"""

import os
import sys
import torch
from easydict import EasyDict
# Put the directory three path components above this file on sys.path so the
# `basicts` package imported below can be resolved.
sys.path.append(os.path.abspath(__file__ + '/../../..'))

from basicts.metrics import masked_mae, masked_mape, masked_rmse
from basicts.data import TimeSeriesForecastingDataset
from basicts.runners import SimpleTimeSeriesForecastingRunner
from basicts.scaler import ZScoreScaler
from basicts.utils import get_regular_settings, load_adj

from basicts.models.M3Net import M3NetForForecasting

############################## Hot Parameters ##############################

# Dataset & Metrics configuration
DATA_NAME = 'PEMS08'  # Dataset name
# All values below are read from the settings returned for DATA_NAME.
regular_settings = get_regular_settings(DATA_NAME)
INPUT_LEN = regular_settings['INPUT_LEN']  # Length of input sequence
OUTPUT_LEN = regular_settings['OUTPUT_LEN']  # Length of output sequence
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO']  # Train/Validation/Test split ratios
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL']  # Whether to normalize each channel of the data
RESCALE = regular_settings['RESCALE']  # Whether to rescale the data
NULL_VAL = regular_settings['NULL_VAL']  # Null value in the data

# Model architecture and parameters
MODEL_ARCH = M3NetForForecasting
# Keyword arguments stored for the model; sequence lengths follow the
# dataset's regular settings.
MODEL_PARAM = {
    "num_nodes": 170,
    "num_group": 10,
    "input_len": INPUT_LEN,
    "input_dim": 1,
    "embed_dim": 32,
    "output_len": OUTPUT_LEN,
    "num_layer": 3,
    "if_node": True,
    "node_dim": 32,
    "if_T_i_D": True,
    "if_D_i_W": True,
    "temp_dim_tid": 32,
    "temp_dim_diw": 32,
    "time_of_day_size": 288,
    "day_of_week_size": 7
}
NUM_EPOCHS = 200

############################## General Configuration ##############################
# Root configuration object; every section below is attached to it.
CFG = EasyDict()
# General settings
CFG.DESCRIPTION = 'An Example Config'
CFG.GPU_NUM = 1  # Number of GPUs to use (0 for CPU mode)
# Runner
CFG.RUNNER = SimpleTimeSeriesForecastingRunner

############################## Dataset Configuration ##############################
CFG.DATASET = EasyDict()
# Dataset settings
CFG.DATASET.NAME = DATA_NAME
CFG.DATASET.TYPE = TimeSeriesForecastingDataset
CFG.DATASET.PARAM = EasyDict({
    'dataset_name': DATA_NAME,
    'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,
    'input_len': INPUT_LEN,
    'output_len': OUTPUT_LEN,
    'use_timestamps': True,  # Enable timestamps for time embeddings
    # 'mode' is automatically set by the runner
})

############################## Scaler Configuration ##############################
CFG.SCALER = EasyDict()
# Scaler settings
CFG.SCALER.TYPE = ZScoreScaler  # Scaler class
CFG.SCALER.PARAM = EasyDict({
    'dataset_name': DATA_NAME,
    'train_ratio': TRAIN_VAL_TEST_RATIO[0],  # first entry of the train/val/test ratios
    'norm_each_channel': NORM_EACH_CHANNEL,
    'rescale': RESCALE,
})

############################## Model Configuration ##############################
CFG.MODEL = EasyDict()
# Model settings
# Use the class's own name (`__name__`); a plain `.name` attribute is not
# defined on a class by default.
CFG.MODEL.NAME = MODEL_ARCH.__name__
CFG.MODEL.ARCH = MODEL_ARCH
CFG.MODEL.PARAM = MODEL_PARAM
# NOTE(review): only feature index 0 is forwarded, while MODEL_PARAM enables
# "if_T_i_D" and "if_D_i_W". Confirm the time-of-day / day-of-week inputs
# still reach the model (presumably via 'use_timestamps' in the dataset
# settings) — TODO verify against the BasicTS version in use.
CFG.MODEL.FORWARD_FEATURES = [0]
CFG.MODEL.TARGET_FEATURES = [0]

############################## Metrics Configuration ##############################

CFG.METRICS = EasyDict()
# Metrics settings
CFG.METRICS.FUNCS = EasyDict({
    'MAE': masked_mae,
    'MAPE': masked_mape,
    'RMSE': masked_rmse,
})
CFG.METRICS.TARGET = 'MAE'  # must be one of the keys in CFG.METRICS.FUNCS
CFG.METRICS.NULL_VAL = NULL_VAL

############################## Training Configuration ##############################
CFG.TRAIN = EasyDict()
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS
# Checkpoint directory:
# checkpoints/<model class name>/<dataset>_<epochs>_<input_len>_<output_len>
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
    'checkpoints',
    MODEL_ARCH.__name__,  # class name; `.name` is not defined on a class by default
    '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)])
)
CFG.TRAIN.LOSS = masked_mae
# Optimizer settings
CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = "Adam"
CFG.TRAIN.OPTIM.PARAM = {
    "lr": 0.002,
    "weight_decay": 0.0001,
}
# Learning rate scheduler settings
CFG.TRAIN.LR_SCHEDULER = EasyDict()
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
CFG.TRAIN.LR_SCHEDULER.PARAM = {
    "milestones": [25, 50, 75, 100, 125, 150],
    "gamma": 0.5
}
CFG.TRAIN.CLIP_GRAD_PARAM = {
    'max_norm': 5.0
}
# Train data loader settings
CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.DATA.BATCH_SIZE = 64
CFG.TRAIN.DATA.SHUFFLE = True

############################## Validation Configuration ##############################
# Validation section: interval of 1 and a batch size of 64.
CFG.VAL = EasyDict({
    'INTERVAL': 1,
    'DATA': EasyDict({'BATCH_SIZE': 64}),
})

############################## Test Configuration ##############################
# Test section: same interval and batch size as validation.
CFG.TEST = EasyDict({
    'INTERVAL': 1,
    'DATA': EasyDict({'BATCH_SIZE': 64}),
})

############################## Evaluation Configuration ##############################

CFG.EVAL = EasyDict()
# Evaluation parameters
CFG.EVAL.HORIZONS = [3, 6, 12]  # Prediction horizons for evaluation. Default: []
CFG.EVAL.USE_GPU = True  # Whether to use GPU for evaluation. Default: True

Image

结果:
{
"overall": {
"MAE": 14.954769656820202,
"MAPE": 0.12561433297944194,
"RMSE": 24.119266296187238
},
"horizon_3": {
"MAE": 13.996050395190213,
"MAPE": 0.12702000695021787,
"RMSE": 22.37340842347487
},
"horizon_6": {
"MAE": 14.935872885383887,
"MAPE": 0.12214047373116645,
"RMSE": 24.13685706515097
},
"horizon_12": {
"MAE": 16.515607816716763,
"MAPE": 0.12552956363138265,
"RMSE": 26.635417352043813
}
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions