From 3bab673c3a39e27223bac9f150df8f673f7809f2 Mon Sep 17 00:00:00 2001
From: SohaibAhmed121
Date: Mon, 13 Jan 2025 11:39:18 +0500
Subject: [PATCH 01/21] Refactored NBeats and added comments for KAN block and
 NBeats.

---
 pytorch_forecasting/models/nbeats/__init__.py | 246 +++++++-
 .../models/nbeats/kan_layer.py                | 532 ++++++++++++++++++
 .../models/nbeats/sub_modules.py              | 224 +++++++-
 3 files changed, 987 insertions(+), 15 deletions(-)
 create mode 100644 pytorch_forecasting/models/nbeats/kan_layer.py

diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py
index 8d00392cc..114e4efac 100644
--- a/pytorch_forecasting/models/nbeats/__init__.py
+++ b/pytorch_forecasting/models/nbeats/__init__.py
@@ -8,7 +8,8 @@
 from torch import nn

 from pytorch_forecasting.data import TimeSeriesDataSet
-from pytorch_forecasting.data.encoders import NaNLabelEncoder
+
+# from pytorch_forecasting.data.encoders import NaNLabelEncoder
 from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric
 from pytorch_forecasting.models.base_model import BaseModel
 from pytorch_forecasting.models.nbeats.sub_modules import NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock
@@ -26,6 +27,19 @@ def __init__(
         expansion_coefficient_lengths: Optional[List[int]] = None,
         prediction_length: int = 1,
         context_length: int = 1,
+        use_kan: bool = False,
+        num_grids: int = 5,
+        k: int = 3,
+        noise_scale: float = 0.5,
+        scale_base_mu: float = 0.0,
+        scale_base_sigma: float = 1.0,
+        scale_sp: float = 1.0,
+        base_fun: callable = torch.nn.SiLU(),
+        grid_eps: float = 0.02,
+        grid_range: List[int] = [-1, 1],
+        sp_trainable: bool = True,
+        sb_trainable: bool = True,
+        sparse_init: bool = False,
         dropout: float = 0.1,
         learning_rate: float = 1e-2,
         log_interval: int = -1,
@@ -76,6 +90,24 @@ def __init__(
             prediction_length: Length of the prediction. Also known as 'horizon'.
             context_length: Number of time units that condition the predictions. Also known as 'lookback period'.
                 Should be between 1-10 times the prediction length.
+            num_grids : Parameter for KAN layer. The number of grid intervals G. Default: 5.
+            k : Parameter for KAN layer. The order of the piecewise polynomial. Default: 3.
+            noise_scale : Parameter for KAN layer. The scale of noise injected at initialization. Default: 0.5.
+            scale_base_mu : Parameter for KAN layer. The scale of the residual function b(x) is initialized to be
+                N(scale_base_mu, scale_base_sigma^2). Default: 0.0
+            scale_base_sigma : Parameter for KAN layer. The scale of the residual function b(x) is initialized to be
+                N(scale_base_mu, scale_base_sigma^2). Default: 1.0
+            scale_sp : Parameter for KAN layer. The scale of the spline function spline(x). Default: 1.0
+            base_fun : Parameter for KAN layer. Residual function b(x). Default: torch.nn.SiLU()
+            grid_eps : Parameter for KAN layer. When grid_eps = 1, the grid is uniform; when grid_eps = 0,
+                the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the
+                two extremes. Default: 0.02
+            grid_range : Parameter for KAN layer. List/np.array of shape (2,) setting the range of grids.
+                Default: [-1, 1].
+            sp_trainable : Parameter for KAN layer. If True, scale_sp is trainable. Default: True.
+            sb_trainable : Parameter for KAN layer. If True, scale_base is trainable. Default: True.
+            sparse_init : Parameter for KAN layer. If sparse_init = True, sparse initialization is applied.
+                Default: False.
            backcast_loss_ratio: weight of backcast in comparison to forecast when calculating the loss.
                A weight of 1.0 means that forecast and backcast losses are weighted the same (regardless
                of backcast and forecast lengths). Defaults to 0.0, i.e. no weight.
@@ -103,6 +135,23 @@ def __init__(
             logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
         if loss is None:
             loss = MASE()
+        # Bundle KAN parameters into a dictionary
+        self.kan_params = {
+            "use_kan": use_kan,
+            "num_grids": num_grids,
+            "k": k,
+            "noise_scale": noise_scale,
+            "scale_base_mu": scale_base_mu,
+            "scale_base_sigma": scale_base_sigma,
+            "scale_sp": scale_sp,
+            "base_fun": base_fun,
+            "grid_eps": grid_eps,
+            "grid_range": grid_range,
+            "sp_trainable": sp_trainable,
+            "sb_trainable": sb_trainable,
+            "sparse_init": sparse_init,
+        }
+
         self.save_hyperparameters()
         super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)

@@ -118,6 +167,7 @@ def __init__(
                     backcast_length=context_length,
                     forecast_length=prediction_length,
                     dropout=self.hparams.dropout,
+                    kan_params=self.hparams.kan_params,
                 )
             elif stack_type == "seasonality":
                 net_block = NBEATSSeasonalBlock(
@@ -127,6 +177,7 @@ def __init__(
                     forecast_length=prediction_length,
                     min_period=self.hparams.expansion_coefficient_lengths[stack_id],
                     dropout=self.hparams.dropout,
+                    kan_params=self.hparams.kan_params,
                 )
             elif stack_type == "trend":
                 net_block = NBEATSTrendBlock(
@@ -136,6 +187,7 @@ def __init__(
                     backcast_length=context_length,
                     forecast_length=prediction_length,
                     dropout=self.hparams.dropout,
+                    kan_params=self.hparams.kan_params,
                 )
             else:
                 raise ValueError(f"Unknown stack type {stack_type}")
@@ -374,3 +426,195 @@ def plot_interpretation(
         fig.legend()

         return fig
+
+
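Note that the KAN path added by this commit is opt-in. A minimal, hypothetical sketch of enabling it (assuming, as elsewhere in pytorch-forecasting, that `from_dataset` forwards extra keyword arguments to `__init__`; `training` is the dataset built in the example that follows):

```python
import torch

# Hypothetical usage: switch the block MLPs to KAN layers.
net_kan = NBeats.from_dataset(
    training,
    learning_rate=1e-3,
    widths=[32, 512],
    backcast_loss_ratio=1.0,
    use_kan=True,              # enable KAN blocks instead of MLP stacks
    num_grids=5,               # G: number of grid intervals per spline
    k=3,                       # order of the piecewise polynomial
    base_fun=torch.nn.SiLU(),  # residual function b(x)
)
```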
+import warnings
+
+warnings.filterwarnings("ignore")
+import lightning.pytorch as pl
+from lightning.pytorch.callbacks import EarlyStopping
+import pandas as pd
+import torch
+
+from pytorch_forecasting import TimeSeriesDataSet
+from pytorch_forecasting.data import NaNLabelEncoder
+from pytorch_forecasting.data.examples import generate_ar_data
+
+
+data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42)
+data["static"] = 2
+data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D")
+data.head()
+
+# create dataset and dataloaders
+max_encoder_length = 60
+max_prediction_length = 20
+
+training_cutoff = data["time_idx"].max() - max_prediction_length
+
+context_length = max_encoder_length
+prediction_length = max_prediction_length
+
+training = TimeSeriesDataSet(
+    data[lambda x: x.time_idx <= training_cutoff],
+    time_idx="time_idx",
+    target="value",
+    categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
+    group_ids=["series"],
+    # only unknown variable is "value" - and N-Beats can also not take any additional variables
+    time_varying_unknown_reals=["value"],
+    max_encoder_length=context_length,
+    max_prediction_length=prediction_length,
+)
+
+validation = TimeSeriesDataSet.from_dataset(training, data, min_prediction_idx=training_cutoff + 1)
+batch_size = 128
+train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
+val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)
+
+pl.seed_everything(42)
+trainer = pl.Trainer(accelerator="auto", gradient_clip_val=0.01)
+net = NBeats.from_dataset(
+    training,
+    learning_rate=1e-3,
+    log_interval=10,
+    log_val_interval=1,
+    weight_decay=1e-2,
+    widths=[32, 512],
+    backcast_loss_ratio=1.0,
+    num_block_layers=[3, 3],
+)
+
+early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
+trainer = pl.Trainer(
+    max_epochs=1,
+    accelerator="auto",
+    enable_model_summary=True,
+    gradient_clip_val=0.1,
+    callbacks=[early_stop_callback],
+    limit_train_batches=150,
+)
+
+trainer.fit(
+    net,
+    train_dataloaders=train_dataloader,
+    val_dataloaders=val_dataloader,
+)
+
+best_model_path = trainer.checkpoint_callback.best_model_path
+best_model = NBeats.load_from_checkpoint(best_model_path)
+
+import matplotlib.pyplot as plt
+
+raw_predictions = best_model.predict(val_dataloader, mode="raw", return_x=True)
+
+for idx in range(10):  # plot 10 examples
+    figure = best_model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=idx, add_loss_to_title=True)
+    plt.show()
diff --git a/pytorch_forecasting/models/nbeats/kan_layer.py b/pytorch_forecasting/models/nbeats/kan_layer.py
new file mode 100644
index 000000000..13aa77802
--- /dev/null
+++ b/pytorch_forecasting/models/nbeats/kan_layer.py
@@ -0,0 +1,532 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def B_batch(x, grid, k=0, extend=True, device="cpu"):
+    """
+    evaluate x on B-spline bases
+
+    Args:
+    -----
+        x : 2D torch.tensor
+            inputs, shape (number of splines, number of samples)
+        grid : 2D torch.tensor
+            grids, shape (number of splines, number of grid points)
+        k : int
+            the piecewise polynomial order of splines.
+        extend : bool
+            If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True
+        device : str
+            device
+
+    Returns:
+    --------
+        spline values : 3D torch.tensor
+            shape (batch, in_dim, G+k). G: the number of grid intervals, k: spline order.
+
+    Example
+    -------
+    >>> from kan.spline import B_batch
+    >>> x = torch.rand(100,2)
+    >>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11)
+    >>> B_batch(x, grid, k=3).shape
+    """
+
+    x = x.unsqueeze(dim=2)
+    grid = grid.unsqueeze(dim=0)
+
+    if k == 0:
+        value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:])
+    else:
+        B_km1 = B_batch(x[:, :, 0], grid=grid[0], k=k - 1)
+
+        value = (x - grid[:, :, : -(k + 1)]) / (grid[:, :, k:-1] - grid[:, :, : -(k + 1)]) * B_km1[:, :, :-1] + (
+            grid[:, :, k + 1 :] - x
+        ) / (grid[:, :, k + 1 :] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:]
+
+    # in case grid is degenerate
+    value = torch.nan_to_num(value)
+    return value
+
+
+def coef2curve(x_eval, grid, coef, k, device="cpu"):
+    """
+    converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves
+    (summing up B_batch results over the B-spline basis).
+
+    Args:
+    -----
+        x_eval : 2D torch.tensor
+            shape (batch, in_dim)
+        grid : 2D torch.tensor
+            shape (in_dim, G+2k). G: the number of grid intervals; k: spline order.
+        coef : 3D torch.tensor
+            shape (in_dim, out_dim, G+k)
+        k : int
+            the piecewise polynomial order of splines.
+        device : str
+            device
+
+    Returns:
+    --------
+        y_eval : 3D torch.tensor
+            shape (batch, in_dim, out_dim)
+
+    """
+
+    b_splines = B_batch(x_eval, grid, k=k)
+    y_eval = torch.einsum("ijk,jlk->ijl", b_splines, coef.to(b_splines.device))
+
+    return y_eval
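The recursion in `B_batch` is the standard Cox-de Boor construction: order-0 bases are indicator functions of the grid intervals, and an order-k basis blends two neighbouring order-(k-1) bases. A minimal shape check, a sketch against the functions in this file (`extend_grid` is defined a little further down):

```python
import torch

# G = 10 grid intervals (11 points) per input dimension; spline order k = 3.
base_grid = torch.linspace(-1, 1, steps=11)[None, :].expand(2, 11)
grid = extend_grid(base_grid, k_extend=3)  # pad 3 points on each side -> 17 points
x = torch.rand(100, 2)                     # (batch, in_dim)
bases = B_batch(x, grid, k=3)
print(bases.shape)  # torch.Size([100, 2, 13]), i.e. (batch, in_dim, G + k)
```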
+
+
+def curve2coef(x_eval, y_eval, grid, k):
+    """
+    converting B-spline curves to B-spline coefficients using least squares.
+
+    Args:
+    -----
+        x_eval : 2D torch.tensor
+            shape (batch, in_dim)
+        y_eval : 3D torch.tensor
+            shape (batch, in_dim, out_dim)
+        grid : 2D torch.tensor
+            shape (in_dim, G+2k)
+        k : int
+            spline order
+
+    Returns:
+    --------
+        coef : 3D torch.tensor
+            shape (in_dim, out_dim, G+k)
+    """
+    batch = x_eval.shape[0]
+    in_dim = x_eval.shape[1]
+    out_dim = y_eval.shape[2]
+    n_coef = grid.shape[1] - k - 1
+    mat = B_batch(x_eval, grid, k)
+    # design matrix per (input, output) pair: (in_dim, out_dim, batch, n_coef)
+    mat = mat.permute(1, 0, 2)[:, None, :, :].expand(in_dim, out_dim, batch, n_coef)
+    y_eval = y_eval.permute(1, 2, 0).unsqueeze(dim=3)
+    try:
+        coef = torch.linalg.lstsq(mat, y_eval).solution[:, :, :, 0]
+    except Exception as e:
+        print(f"lstsq failed with error: {e}")
+
+        # manual pseudo-inverse (disabled)
+        """lamb=1e-8
+        XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), mat)
+        Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), y_eval)
+        n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2]
+        identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n).to(device)
+        A = XtX + lamb * identity
+        B = Xty
+        coef = (A.pinverse() @ B)[:,:,:,0]"""
+
+    return coef
+
+
+def extend_grid(grid, k_extend=0):
+    """
+    extend the grid by k_extend uniformly spaced points on both ends
+    """
+    h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
+
+    for i in range(k_extend):
+        grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
+        grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)
+
+    return grid
+
+
+def sparse_mask(in_dim, out_dim):
+    """
+    get a sparse 0/1 connection mask: each input is connected to its nearest
+    output and each output to its nearest input
+    """
+    in_coord = torch.arange(in_dim) * 1 / in_dim + 1 / (2 * in_dim)
+    out_coord = torch.arange(out_dim) * 1 / out_dim + 1 / (2 * out_dim)
+
+    dist_mat = torch.abs(out_coord[:, None] - in_coord[None, :])
+    in_nearest = torch.argmin(dist_mat, dim=0)
+    in_connection = torch.stack([torch.arange(in_dim), in_nearest]).permute(1, 0)
+    out_nearest = torch.argmin(dist_mat, dim=1)
+    out_connection = torch.stack([out_nearest, torch.arange(out_dim)]).permute(1, 0)
+    all_connection = torch.cat([in_connection, out_connection], dim=0)
+    mask = torch.zeros(in_dim, out_dim)
+    mask[all_connection[:, 0], all_connection[:, 1]] = 1.0
+
+    return mask
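For intuition about `sparse_mask`, a tiny worked example (hand-checked against the code above): inputs and outputs are placed at evenly spaced coordinates on [0, 1], each input connects to its nearest output and each output to its nearest input, so every row and every column of the mask contains at least one 1.

```python
import torch

print(sparse_mask(in_dim=4, out_dim=2))
# tensor([[1., 0.],
#         [1., 0.],
#         [0., 1.],
#         [0., 1.]])
```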
+
+
+class KANLayer(nn.Module):
+    """
+    KANLayer class
+
+
+    Attributes:
+    -----------
+        in_dim: int
+            input dimension
+        out_dim: int
+            output dimension
+        num: int
+            the number of grid intervals
+        k: int
+            the piecewise polynomial order of splines
+        noise_scale: float
+            spline scale at initialization
+        coef: 2D torch.tensor
+            coefficients of B-spline bases
+        scale_base_mu: float
+            magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = scale_base_mu
+        scale_base_sigma: float
+            magnitude of the residual function b(x) is drawn from N(mu, sigma^2), sigma = scale_base_sigma
+        scale_sp: float
+            magnitude of the spline function spline(x)
+        base_fun: fun
+            residual function b(x)
+        mask: 1D torch.float
+            mask of spline functions. setting some element of the mask to zero means setting the
+            corresponding activation to zero function.
+        grid_eps: float in [0,1]
+            a hyperparameter used in update_grid_from_samples. When grid_eps = 1, the grid is uniform;
+            when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1
+            interpolates between the two extremes.
+    """
+
+    def __init__(
+        self,
+        in_dim=3,
+        out_dim=2,
+        num=5,
+        k=3,
+        noise_scale=0.5,
+        scale_base_mu=0.0,
+        scale_base_sigma=1.0,
+        scale_sp=1.0,
+        base_fun=torch.nn.SiLU(),
+        grid_eps=0.02,
+        grid_range=[-1, 1],
+        sp_trainable=True,
+        sb_trainable=True,
+        device="cpu",
+        sparse_init=False,
+    ):
+        """
+        initialize a KANLayer
+
+        Args:
+        -----
+            in_dim : int
+                input dimension. Default: 3.
+            out_dim : int
+                output dimension. Default: 2.
+            num : int
+                the number of grid intervals = G. Default: 5.
+            k : int
+                the order of piecewise polynomial. Default: 3.
+            noise_scale : float
+                the scale of noise injected at initialization. Default: 0.5.
+            scale_base_mu : float
+                the scale of the residual function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
+            scale_base_sigma : float
+                the scale of the residual function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
+            scale_sp : float
+                the scale of the spline function spline(x).
+            base_fun : function
+                residual function b(x). Default: torch.nn.SiLU()
+            grid_eps : float
+                When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using
+                percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
+            grid_range : list/np.array of shape (2,)
+                setting the range of grids. Default: [-1,1].
+            sp_trainable : bool
+                If True, scale_sp is trainable
+            sb_trainable : bool
+                If True, scale_base is trainable
+            sparse_init : bool
+                if sparse_init = True, sparse initialization is applied.
+
+        Returns:
+        --------
+            self
+
+        Example
+        -------
+        >>> from kan.KANLayer import *
+        >>> model = KANLayer(in_dim=3, out_dim=5)
+        >>> (model.in_dim, model.out_dim)
+        """
+        super(KANLayer, self).__init__()
+        # size
+        self.out_dim = out_dim
+        self.in_dim = in_dim
+        self.num = num
+        self.k = k
+
+        # uniform grid over grid_range, extended by k points on each side
+        grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[None, :].expand(self.in_dim, num + 1)
+        grid = extend_grid(grid, k_extend=k)
+        self.grid = torch.nn.Parameter(grid).requires_grad_(False)
+        noises = (torch.rand(self.num + 1, self.in_dim, self.out_dim) - 1 / 2) * noise_scale / num
+
+        # fit the initial spline coefficients to the injected noise
+        self.coef = torch.nn.Parameter(curve2coef(self.grid[:, k:-k].permute(1, 0), noises, self.grid, k))
+
+        if sparse_init:
+            self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_(False)
+        else:
+            self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_(False)
+
+        self.scale_base = torch.nn.Parameter(
+            scale_base_mu * 1 / np.sqrt(in_dim)
+            + scale_base_sigma * (torch.rand(in_dim, out_dim) * 2 - 1) * 1 / np.sqrt(in_dim)
+        ).requires_grad_(sb_trainable)
+        self.scale_sp = torch.nn.Parameter(
+            torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask
+        ).requires_grad_(
+            sp_trainable
+        )  # make scale trainable
+        self.base_fun = base_fun
+
+        self.grid_eps = grid_eps
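Reading ahead to `forward`, each edge (i, j) of the layer applies its own learned univariate activation; in effect the layer computes

```latex
\phi_{ij}(x_i) = m_{ij}\left( s^{\mathrm{base}}_{ij}\, b(x_i) + s^{\mathrm{sp}}_{ij}\, \mathrm{spline}_{ij}(x_i) \right),
\qquad
y_j = \sum_{i=1}^{\mathrm{in\_dim}} \phi_{ij}(x_i),
```

where b is `base_fun` (SiLU by default), spline_ij is the B-spline curve parameterized by `coef`, m is the 0/1 `mask`, and s^base, s^sp are `scale_base` and `scale_sp`.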
+
+    def forward(self, x):
+        """
+        KANLayer forward given input x
+
+        Args:
+        -----
+            x : 2D torch.float
+                inputs, shape (number of samples, input dimension)
+
+        Returns:
+        --------
+            y : 2D torch.float
+                outputs, shape (number of samples, output dimension)
+            preacts : 3D torch.float
+                fan out x into activations, shape (number of samples, output dimension, input dimension)
+            postacts : 3D torch.float
+                the outputs of activation functions with preacts as inputs
+            postspline : 3D torch.float
+                the outputs of spline functions with preacts as inputs
+
+        Example
+        -------
+        >>> from kan.KANLayer import *
+        >>> model = KANLayer(in_dim=3, out_dim=5)
+        >>> x = torch.normal(0, 1, size=(100, 3))
+        >>> y, preacts, postacts, postspline = model(x)
+        >>> y.shape, preacts.shape, postacts.shape, postspline.shape
+        """
+
+        base = self.base_fun(x)  # (batch, in_dim)
+        y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k)
+        y = self.scale_base[None, :, :] * base[:, :, None] + self.scale_sp[None, :, :] * y
+        y = self.mask[None, :, :] * y
+        y = torch.sum(y, dim=1)
+        return y
+
+    def update_grid_from_samples(self, x, mode="sample"):
+        """
+        update grid from samples
+
+        Args:
+        -----
+            x : 2D torch.float
+                inputs, shape (number of samples, input dimension)
+
+        Returns:
+        --------
+            None
+
+        Example
+        -------
+        >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
+        >>> print(model.grid.data)
+        >>> x = torch.linspace(-3,3,steps=100)[:,None]
+        >>> model.update_grid_from_samples(x)
+        >>> print(model.grid.data)
+        """
+
+        batch = x.shape[0]
+        # x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size)
+        # .permute(1, 0)
+        x_pos = torch.sort(x, dim=0)[0]
+        y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
+        num_interval = self.grid.shape[1] - 1 - 2 * self.k
+
+        def get_grid(num_interval):
+            ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
+            grid_adaptive = x_pos[ids, :].permute(1, 0)
+            margin = 0.00
+            h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) / num_interval
+            grid_uniform = (
+                grid_adaptive[:, [0]]
+                - margin
+                + h
+                * torch.arange(
+                    num_interval + 1,
+                )[
+                    None, :
+                ].to(x.device)
+            )
+            grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
+            return grid
+
+        grid = get_grid(num_interval)
+
+        if mode == "grid":
+            sample_grid = get_grid(2 * num_interval)
+            x_pos = sample_grid.permute(1, 0)
+            y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
+
+        self.grid.data = extend_grid(grid, k_extend=self.k)
+        # print('x_pos 2', x_pos.shape)
+        # print('y_eval 2', y_eval.shape)
+        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)
+
+    def initialize_grid_from_parent(self, parent, x, mode="sample"):
+        """
+        update grid from a parent KANLayer & samples
+
+        Args:
+        -----
+            parent : KANLayer
+                a parent KANLayer (whose grid is usually coarser than the current model)
+            x : 2D torch.float
+                inputs, shape (number of samples, input dimension)
+
+        Returns:
+        --------
+            None
+
+        Example
+        -------
+        >>> batch = 100
+        >>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
+        >>> print(parent_model.grid.data)
+        >>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3)
+        >>> x = torch.normal(0,1,size=(batch, 1))
+        >>> model.initialize_grid_from_parent(parent_model, x)
+        >>> print(model.grid.data)
+        """
+        # shrink grid
+        x_pos = torch.sort(x, dim=0)[0]
+        y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
+        num_interval = self.grid.shape[1] - 1 - 2 * self.k
+
+        """
+        # based on samples
+        def get_grid(num_interval):
+            ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
+            grid_adaptive = x_pos[ids, :].permute(1,0)
+            h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
+            grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
+            grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
+            return grid"""
+
+        # based on interpolating parent grid
+        def get_grid(num_interval):
+            x_pos = parent.grid[:, parent.k : -parent.k]
+            # print('x_pos', x_pos)
+            sp2 = KANLayer(
+                in_dim=1, out_dim=self.in_dim, k=1, num=x_pos.shape[1] - 1, scale_base_mu=0.0, scale_base_sigma=0.0
+            ).to(x.device)
+
+            # print('sp2_grid', sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim))
+            # print('sp2_coef_shape', sp2.coef.shape)
+            sp2_coef = curve2coef(
+                sp2.grid[:, sp2.k : -sp2.k].permute(1, 0).expand(-1, self.in_dim),
+                x_pos.permute(1, 0).unsqueeze(dim=2),
+                sp2.grid[:, :],
+                k=1,
+            ).permute(1, 0, 2)
+            sp2.coef.data = sp2_coef
+            percentile = torch.linspace(-1, 1, self.num + 1).to(self.device)
+            grid = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0)
+            return grid
+
+        grid = get_grid(num_interval)
+
+        if mode == "grid":
+            sample_grid = get_grid(2 * num_interval)
+            x_pos = sample_grid.permute(1, 0)
+            y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
+
+        grid = extend_grid(grid, k_extend=self.k)
+        self.grid.data = grid
+        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)
+
+    def get_subset(self, in_id, out_id):
+        """
+        get a smaller KANLayer from a larger KANLayer (used for pruning)
+
+        Args:
+        -----
+            in_id : list
+                id of selected input neurons
+            out_id : list
+                id of selected output neurons
+
+        Returns:
+        --------
+            spb : KANLayer
+
+        Example
+        -------
+        >>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3)
+        >>> kanlayer_small = kanlayer_large.get_subset([0,9],[1,2,3])
+        >>> kanlayer_small.in_dim, kanlayer_small.out_dim
+        (2, 3)
+        """
+        spb = KANLayer(len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun)
+        spb.grid.data = self.grid[in_id]
+        spb.coef.data = self.coef[in_id][:, out_id]
+        spb.scale_base.data = self.scale_base[in_id][:, out_id]
+        spb.scale_sp.data = self.scale_sp[in_id][:, out_id]
+        spb.mask.data = self.mask[in_id][:, out_id]
+
+        spb.in_dim = len(in_id)
+        spb.out_dim = len(out_id)
+        return spb
+
+    def swap(self, i1, i2, mode="in"):
+        """
+        swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out')
+
+        Args:
+        -----
+            i1 : int
+            i2 : int
+            mode : str
+                mode = 'in' or 'out'
+
+        Returns:
+        --------
+            None
+
+        Example
+        -------
+        >>> from kan.KANLayer import *
+        >>> model = KANLayer(in_dim=2, out_dim=2, num=5, k=3)
+        >>> print(model.coef)
+        >>> model.swap(0,1,mode='in')
+        >>> print(model.coef)
+        """
+        with torch.no_grad():
+
+            def swap_(data, i1, i2, mode="in"):
+                if mode == "in":
+                    data[i1], data[i2] = data[i2].clone(), data[i1].clone()
+                elif mode == "out":
+                    data[:, i1], data[:, i2] = data[:, i2].clone(), data[:, i1].clone()
+
+            if mode == "in":
+                swap_(self.grid.data, i1, i2, mode="in")
+            swap_(self.coef.data, i1, i2, mode=mode)
+            swap_(self.scale_base.data, i1, i2, mode=mode)
+            swap_(self.scale_sp.data, i1, i2, mode=mode)
+            swap_(self.mask.data, i1, i2, mode=mode)
diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py
index e300d452f..f78ba1ac4 100644
--- a/pytorch_forecasting/models/nbeats/sub_modules.py
+++ b/pytorch_forecasting/models/nbeats/sub_modules.py
@@ -8,6 +8,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from kan_layer import KANLayer


 def linear(input_size, output_size, bias=True, dropout: int = None):
@@ -33,6 +34,65 @@ def linspace(backcast_length: int, forecast_length: int, centered: bool = False)
     return b_ls, f_ls


 class NBEATSBlock(nn.Module):
     def __init__(
         self,
@@ -43,6 +103,7 @@ def __init__(
         forecast_length=5,
         share_thetas=False,
         dropout=0.1,
+        kan_params={},
     ):
         super().__init__()
         self.units = units
@@ -50,6 +111,52 @@ def __init__(
         self.backcast_length = backcast_length
         self.forecast_length = forecast_length
         self.share_thetas = share_thetas
+        self.kan_params = kan_params
+
+        # fall back to the MLP stack when no KAN configuration is provided
+        if self.kan_params.get("use_kan", False):
+            layers = [
+                KANLayer(
+                    in_dim=backcast_length,
+                    out_dim=units,
+                    num=self.kan_params["num_grids"],
+                    k=self.kan_params["k"],
+                    noise_scale=self.kan_params["noise_scale"],
+                    scale_base_mu=self.kan_params["scale_base_mu"],
+                    scale_base_sigma=self.kan_params["scale_base_sigma"],
+                    scale_sp=self.kan_params["scale_sp"],
+                    base_fun=self.kan_params["base_fun"],
+                    grid_eps=self.kan_params["grid_eps"],
+                    grid_range=self.kan_params["grid_range"],
+                    sp_trainable=self.kan_params["sp_trainable"],
+                    sb_trainable=self.kan_params["sb_trainable"],
+                    sparse_init=self.kan_params["sparse_init"],
+                )
+            ]
+            # Additional KANLayers for deeper structure
+            for _ in range(num_block_layers - 1):
+                layers.extend(
+                    [
+                        KANLayer(
+                            in_dim=units,
+                            out_dim=units,
+                            num=self.kan_params["num_grids"],
+                            k=self.kan_params["k"],
+                            noise_scale=self.kan_params["noise_scale"],
+                            scale_base_mu=self.kan_params["scale_base_mu"],
+                            scale_base_sigma=self.kan_params["scale_base_sigma"],
+                            scale_sp=self.kan_params["scale_sp"],
+                            base_fun=self.kan_params["base_fun"],
+                            grid_eps=self.kan_params["grid_eps"],
+                            grid_range=self.kan_params["grid_range"],
+                            sp_trainable=self.kan_params["sp_trainable"],
+                            sb_trainable=self.kan_params["sb_trainable"],
+                            sparse_init=self.kan_params["sparse_init"],
+                            device="cpu",
+                        )
+                    ]
+                )
+
+            self.fc = nn.Sequential(*layers)

         fc_stack = [
             nn.Linear(backcast_length, units),
@@ -80,7 +187,41 @@ def __init__(
         nb_harmonics=None,
         min_period=1,
         dropout=0.1,
+        kan_params={},
     ):
+        """
+        Initialize NBEATSSeasonalBlock.
+
+        Args:
+            units: The number of units in the MLP/KAN layers. Default: 256.
+            thetas_dim: The dimension of the parameterized output for the block. If None, it is inferred. Default: None.
+            num_block_layers: Number of fully connected MLP/KAN layers. Default: 4.
+            backcast_length: The length of the backcast. Defines how many time units from the past are used to
+                predict the future. Default: 10.
+            forecast_length: The length of the forecast, i.e., the number of time steps ahead to predict. Default: 5.
+            nb_harmonics: The number of harmonics in the seasonal function (relevant for seasonal models).
+                Default: None (no seasonality).
+            min_period: The minimum period used for seasonal patterns. Default: 1.
+            dropout: The dropout rate applied to the fully connected MLP layers to prevent overfitting. Default: 0.1.
+            kan_params (dict): Parameters specific to the KAN layers (used when modeling with KAN).
+                Default: empty dictionary.
+                Contains:
+                    num_grids (int): The number of grid intervals for KAN.
+                    k (int): The order of the piecewise polynomial for KAN.
+                    noise_scale (float): The scale of noise injected at initialization.
+                    scale_base_mu (float): The scale of the residual function initialized to
+                        N(scale_base_mu, scale_base_sigma^2).
+                    scale_base_sigma (float): The scale of the residual function initialized to
+                        N(scale_base_mu, scale_base_sigma^2).
+                    scale_sp (float): The scale of the spline function spline(x) in KAN.
+                    base_fun (function): The residual function used by KAN (e.g., torch.nn.SiLU()).
+                    grid_eps (float): Determines the partitioning of the grid. If 1, the grid is uniform; if 0,
+                        the grid is partitioned by percentiles.
+                    grid_range (list or np.array): The range of the grid, given as a list of two values.
+                    sp_trainable (bool): If True, scale_sp is trainable.
+                    sb_trainable (bool): If True, scale_base is trainable.
+                    sparse_init (bool): If True, applies sparse initialization.
+        """
         if nb_harmonics:
             thetas_dim = nb_harmonics
         else:
@@ -95,6 +236,7 @@ def __init__(
             forecast_length=forecast_length,
             share_thetas=True,
             dropout=dropout,
+            kan_params=kan_params,
         )

         backcast_linspace, forecast_linspace = linspace(backcast_length, forecast_length, centered=False)
@@ -117,6 +259,7 @@ def __init__(
         self.register_buffer("S_forecast", torch.cat([s1_f, s2_f]))

     def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Computes the backcast and forecast outputs for the given input tensor."""
         x = super().forward(x)
         amplitudes_backward = self.theta_b_fc(x)
         backcast = amplitudes_backward.mm(self.S_backcast)
@@ -126,19 +269,46 @@ def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
         return backcast, forecast

     def get_frequencies(self, n):
+        """
+        Generates frequency values based on the backcast and forecast lengths.
+        """
         return np.linspace(0, (self.backcast_length + self.forecast_length) / self.min_period, n)


 class NBEATSTrendBlock(NBEATSBlock):
     def __init__(
-        self,
-        units,
-        thetas_dim,
-        num_block_layers=4,
-        backcast_length=10,
-        forecast_length=5,
-        dropout=0.1,
+        self, units, thetas_dim, num_block_layers=4, backcast_length=10, forecast_length=5, dropout=0.1, kan_params={}
     ):
+        """
+        Initialize NBEATSTrendBlock.
+
+        Args:
+            units: The number of units in the MLP/KAN layers. Default: 256.
+            thetas_dim: The dimension of the parameterized output for the block. If None, it is inferred. Default: None.
+            num_block_layers: Number of fully connected MLP/KAN layers. Default: 4.
+            backcast_length: The length of the backcast. Defines how many time units from the past are used to
+                predict the future. Default: 10.
+            forecast_length: The length of the forecast, i.e., the number of time steps ahead to predict. Default: 5.
+            dropout: The dropout rate applied to the fully connected MLP layers to prevent overfitting. Default: 0.1.
+            kan_params (dict): Parameters specific to the KAN layers (used when modeling with KAN).
+                Default: empty dictionary.
+                Contains:
+                    num_grids (int): The number of grid intervals for KAN.
+                    k (int): The order of the piecewise polynomial for KAN.
+                    noise_scale (float): The scale of noise injected at initialization.
+                    scale_base_mu (float): The scale of the residual function initialized to
+                        N(scale_base_mu, scale_base_sigma^2).
+                    scale_base_sigma (float): The scale of the residual function initialized to
+                        N(scale_base_mu, scale_base_sigma^2).
+                    scale_sp (float): The scale of the spline function spline(x) in KAN.
+                    base_fun (function): The residual function used by KAN (e.g., torch.nn.SiLU()).
+                    grid_eps (float): Determines the partitioning of the grid. If 1, the grid is uniform; if 0,
+                        the grid is partitioned by percentiles.
+                    grid_range (list or np.array): The range of the grid, given as a list of two values.
+                    sp_trainable (bool): If True, scale_sp is trainable.
+                    sb_trainable (bool): If True, scale_base is trainable.
+                    sparse_init (bool): If True, applies sparse initialization.
+        """
         super().__init__(
             units=units,
             thetas_dim=thetas_dim,
@@ -147,6 +317,7 @@ def __init__(
             forecast_length=forecast_length,
             share_thetas=True,
             dropout=dropout,
+            kan_params=kan_params,
         )

         backcast_linspace, forecast_linspace = linspace(backcast_length, forecast_length, centered=True)
@@ -167,14 +338,38 @@ def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:

 class NBEATSGenericBlock(NBEATSBlock):
     def __init__(
-        self,
-        units,
-        thetas_dim,
-        num_block_layers=4,
-        backcast_length=10,
-        forecast_length=5,
-        dropout=0.1,
+        self, units, thetas_dim, num_block_layers=4, backcast_length=10, forecast_length=5, dropout=0.1, kan_params={}
     ):
+        """
+        Initialize NBEATSGenericBlock.
+
+        Args:
+            units: The number of units in the MLP/KAN layers. Default: 256.
+            thetas_dim: The dimension of the parameterized output for the block. If None, it is inferred. Default: None.
+            num_block_layers: Number of fully connected MLP/KAN layers. Default: 4.
+            backcast_length: The length of the backcast. Defines how many time units from the past are used to
+                predict the future. Default: 10.
+            forecast_length: The length of the forecast, i.e., the number of time steps ahead to predict. Default: 5.
+            dropout: The dropout rate applied to the fully connected MLP layers to prevent overfitting. Default: 0.1.
+            kan_params (dict): Parameters specific to the KAN layers (used when modeling with KAN).
+                Default: empty dictionary.
+                Contains:
+                    num_grids (int): The number of grid intervals for KAN.
+                    k (int): The order of the piecewise polynomial for KAN.
+                    noise_scale (float): The scale of noise injected at initialization.
+                    scale_base_mu (float): The scale of the residual function initialized to
+                        N(scale_base_mu, scale_base_sigma^2).
+                    scale_base_sigma (float): The scale of the residual function initialized to
+                        N(scale_base_mu, scale_base_sigma^2).
+                    scale_sp (float): The scale of the spline function spline(x) in KAN.
+                    base_fun (function): The residual function used by KAN (e.g., torch.nn.SiLU()).
+                    grid_eps (float): Determines the partitioning of the grid. If 1, the grid is uniform; if 0,
+                        the grid is partitioned by percentiles.
+                    grid_range (list or np.array): The range of the grid, given as a list of two values.
+                    sp_trainable (bool): If True, scale_sp is trainable.
+                    sb_trainable (bool): If True, scale_base is trainable.
+                    sparse_init (bool): If True, applies sparse initialization.
+ """ super().__init__( units=units, thetas_dim=thetas_dim, @@ -182,6 +377,7 @@ def __init__( backcast_length=backcast_length, forecast_length=forecast_length, dropout=dropout, + kan_params=kan_params, ) self.backcast_fc = nn.Linear(thetas_dim, backcast_length) From 41d74039d9f1f1db5556ad05eb9cd128b0970a12 Mon Sep 17 00:00:00 2001 From: SohaibAhmed121 Date: Mon, 13 Jan 2025 15:17:16 +0500 Subject: [PATCH 02/21] End to end integrated Kolmogorov Arnold Networks in NBeats. Also refactored NBeats. --- pytorch_forecasting/models/nbeats/_nbeats.py | 8 +- .../models/nbeats/kan_layer.py | 146 ++++++++--------- .../models/nbeats/sub_modules.py | 152 ++++++++++++------ 3 files changed, 180 insertions(+), 126 deletions(-) diff --git a/pytorch_forecasting/models/nbeats/_nbeats.py b/pytorch_forecasting/models/nbeats/_nbeats.py index d390bc3b1..872e244b1 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats.py +++ b/pytorch_forecasting/models/nbeats/_nbeats.py @@ -175,7 +175,7 @@ def __init__( "sparse_init": sparse_init, } - self.save_hyperparameters() + self.save_hyperparameters(ignore=["loss", "logging_metrics"]) super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) # setup stacks @@ -190,7 +190,7 @@ def __init__( backcast_length=context_length, forecast_length=prediction_length, dropout=self.hparams.dropout, - kan_params=self.hparams.kan_params, + kan_params=self.kan_params, ) elif stack_type == "seasonality": net_block = NBEATSSeasonalBlock( @@ -200,7 +200,7 @@ def __init__( forecast_length=prediction_length, min_period=self.hparams.expansion_coefficient_lengths[stack_id], dropout=self.hparams.dropout, - kan_params=self.hparams.kan_params, + kan_params=self.kan_params, ) elif stack_type == "trend": net_block = NBEATSTrendBlock( @@ -210,7 +210,7 @@ def __init__( backcast_length=context_length, forecast_length=prediction_length, dropout=self.hparams.dropout, - kan_params=self.hparams.kan_params, + kan_params=self.kan_params, ) else: raise ValueError(f"Unknown stack type {stack_type}") diff --git a/pytorch_forecasting/models/nbeats/kan_layer.py b/pytorch_forecasting/models/nbeats/kan_layer.py index 13aa77802..7b5d991df 100644 --- a/pytorch_forecasting/models/nbeats/kan_layer.py +++ b/pytorch_forecasting/models/nbeats/kan_layer.py @@ -1,9 +1,9 @@ +import numpy as np import torch import torch.nn as nn -import numpy as np -def B_batch(x, grid, k=0, extend=True, device="cpu"): +def B_batch(x, grid, k=0, extend=True): """ evaludate x on B-spline bases @@ -16,14 +16,14 @@ def B_batch(x, grid, k=0, extend=True, device="cpu"): k : int the piecewise polynomial order of splines. extend : bool - If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True - device : str - devicde + If True, k points are extended on both ends. If False, no extension + (zero boundary condition). Default: True Returns: -------- spline values : 3D torch.tensor - shape (batch, in_dim, G+k). G: the number of grid intervals, k: spline order. + shape (batch, in_dim, G+k). G: the number of grid intervals, + k: spline order. 
    Example
    -------
    >>> from kan.spline import B_batch
    >>> x = torch.rand(100,2)
    >>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11)
    >>> B_batch(x, grid, k=3).shape
@@ -41,16 +41,20 @@ def B_batch(x, grid, k=0, extend=True):
     else:
         B_km1 = B_batch(x[:, :, 0], grid=grid[0], k=k - 1)

-        value = (x - grid[:, :, : -(k + 1)]) / (grid[:, :, k:-1] - grid[:, :, : -(k + 1)]) * B_km1[:, :, :-1] + (
-            grid[:, :, k + 1 :] - x
-        ) / (grid[:, :, k + 1 :] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:]
+        value = (x - grid[:, :, : -(k + 1)]) / (
+            grid[:, :, k:-1] - grid[:, :, : -(k + 1)]
+        ) * B_km1[:, :, :-1] + (grid[:, :, k + 1 :] - x) / (
+            grid[:, :, k + 1 :] - grid[:, :, 1:(-k)]
+        ) * B_km1[
+            :, :, 1:
+        ]

     # in case grid is degenerate
     value = torch.nan_to_num(value)
     return value


-def coef2curve(x_eval, grid, coef, k, device="cpu"):
+def coef2curve(x_eval, grid, coef, k):
     """
     converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves
     (summing up B_batch results over the B-spline basis).

@@ -65,8 +69,6 @@ def coef2curve(x_eval, grid, coef, k):
             shape (in_dim, out_dim, G+k)
         k : int
             the piecewise polynomial order of splines.
-        device : str
-            device

     Returns:
     --------
@@ -76,7 +78,7 @@ def coef2curve(x_eval, grid, coef, k):
     """

     b_splines = B_batch(x_eval, grid, k=k)
-    y_eval = torch.einsum("ijk,jlk->ijl", b_splines, coef.to(b_splines.device))
+    y_eval = torch.einsum("ijk,jlk->ijl", b_splines, coef.to(b_splines))

     return y_eval

@@ -121,7 +123,7 @@ def curve2coef(x_eval, y_eval, grid, k):
         XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), mat)
         Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), y_eval)
         n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2]
-        identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n).to(device)
+        identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n)
         A = XtX + lamb * identity
         B = Xty
         coef = (A.pinverse() @ B)[:,:,:,0]"""
@@ -164,38 +166,6 @@ def sparse_mask(in_dim, out_dim):
 class KANLayer(nn.Module):
     """
     KANLayer class
-
-
-    Attributes:
-    -----------
-        in_dim: int
-            input dimension
-        out_dim: int
-            output dimension
-        num: int
-            the number of grid intervals
-        k: int
-            the piecewise polynomial order of splines
-        noise_scale: float
-            spline scale at initialization
-        coef: 2D torch.tensor
-            coefficients of B-spline bases
-        scale_base_mu: float
-            magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = scale_base_mu
-        scale_base_sigma: float
-            magnitude of the residual function b(x) is drawn from N(mu, sigma^2), sigma = scale_base_sigma
-        scale_sp: float
-            magnitude of the spline function spline(x)
-        base_fun: fun
-            residual function b(x)
-        mask: 1D torch.float
-            mask of spline functions. setting some element of the mask to zero means setting the
-            corresponding activation to zero function.
-        grid_eps: float in [0,1]
-            a hyperparameter used in update_grid_from_samples. When grid_eps = 1, the grid is uniform;
-            when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1
-            interpolates between the two extremes.
     """

     def __init__(
@@ -203,16 +183,6 @@ def __init__(
         grid_range=[-1, 1],
         sp_trainable=True,
         sb_trainable=True,
-        device="cpu",
         sparse_init=False,
     ):
         """
         initialize a KANLayer

         Args:
         -----
             in_dim : int
                 input dimension. Default: 3.
             out_dim : int
                 output dimension. Default: 2.
             num : int
                 the number of grid intervals = G. Default: 5.
             k : int
                 the order of piecewise polynomial. Default: 3.
             noise_scale : float
                 the scale of noise injected at initialization. Default: 0.5.
             scale_base_mu : float
-                the scale of the residual function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
+                the scale of the residual function b(x) is initialized to be
+                N(scale_base_mu, scale_base_sigma^2).
             scale_base_sigma : float
-                the scale of the residual function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
+                the scale of the residual function b(x) is initialized to be
+                N(scale_base_mu, scale_base_sigma^2).
             scale_sp : float
                 the scale of the spline function spline(x).
             base_fun : function
                 residual function b(x). Default: torch.nn.SiLU()
             grid_eps : float
-                When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using
-                percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
+                When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is
+                partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates
+                between the two extremes.
             grid_range : list/np.array of shape (2,)
                 setting the range of grids. Default: [-1,1].
             sp_trainable : bool
                 If True, scale_sp is trainable
             sb_trainable : bool
                 If True, scale_base is trainable
             sparse_init : bool
                 if sparse_init = True, sparse initialization is applied.
@@ -240,36 +240,36 @@ def __init__(
-        grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[None, :].expand(self.in_dim, num + 1)
+        grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[
+            None, :
+        ].expand(self.in_dim, num + 1)
         grid = extend_grid(grid, k_extend=k)
         self.grid = torch.nn.Parameter(grid).requires_grad_(False)
-        noises = (torch.rand(self.num + 1, self.in_dim, self.out_dim) - 1 / 2) * noise_scale / num
+        noises = (
+            (torch.rand(self.num + 1, self.in_dim, self.out_dim) - 1 / 2)
+            * noise_scale
+            / num
+        )

-        self.coef = torch.nn.Parameter(curve2coef(self.grid[:, k:-k].permute(1, 0), noises, self.grid, k))
+        self.coef = torch.nn.Parameter(
+            curve2coef(self.grid[:, k:-k].permute(1, 0), noises, self.grid, k)
+        )

         if sparse_init:
-            self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_(False)
+            self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_(
+                False
+            )
         else:
-            self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_(False)
+            self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_(
+                False
+            )

         self.scale_base = torch.nn.Parameter(
             scale_base_mu * 1 / np.sqrt(in_dim)
-            + scale_base_sigma * (torch.rand(in_dim, out_dim) * 2 - 1) * 1 / np.sqrt(in_dim)
+            + scale_base_sigma
+            * (torch.rand(in_dim, out_dim) * 2 - 1)
+            * 1
+            / np.sqrt(in_dim)
         ).requires_grad_(sb_trainable)
         self.scale_sp = torch.nn.Parameter(
             torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask
@@ -307,7 +294,8 @@ def forward(self, x):
             y : 2D torch.float
                 outputs, shape (number of samples, output dimension)
             preacts : 3D torch.float
-                fan out x into activations, shape (number of samples, output dimension, input dimension)
+                fan out x into activations, shape (number of samples,
+                output dimension, input dimension)
             postacts : 3D torch.float
                 the outputs of activation functions with preacts as inputs
             postspline : 3D torch.float
                 the outputs of spline functions with preacts as inputs
@@ -324,7 +312,10 @@ def forward(self, x):

         base = self.base_fun(x)  # (batch, in_dim)
         y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k)
-        y = self.scale_base[None, :, :] * base[:, :, None] + self.scale_sp[None, :, :] * y
+        y = (
+            self.scale_base[None, :, :] * base[:, :, None]
+            + self.scale_sp[None, :, :] * y
+        )
         y = self.mask[None, :, :] * y
         y = torch.sum(y, dim=1)
         return y
@@ -352,8 +343,8 @@ def update_grid_from_samples(self, x, mode="sample"):

         batch = x.shape[0]
-        # x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size)
-        # .permute(1, 0)
+        # x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ))
+        # .reshape(batch, self.size).permute(1, 0)
         x_pos = torch.sort(x,
dim=0)[0] y_eval = coef2curve(x_pos, self.grid, self.coef, self.k) num_interval = self.grid.shape[1] - 1 - 2 * self.k @@ -362,16 +353,16 @@ def get_grid(num_interval): ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] grid_adaptive = x_pos[ids, :].permute(1, 0) margin = 0.00 - h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) / num_interval + h = ( + grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin + ) / num_interval grid_uniform = ( grid_adaptive[:, [0]] - margin + h * torch.arange( num_interval + 1, - )[ - None, : - ].to(x.device) + )[None, :] ) grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive return grid @@ -384,8 +375,6 @@ def get_grid(num_interval): y_eval = coef2curve(x_pos, self.grid, self.coef, self.k) self.grid.data = extend_grid(grid, k_extend=self.k) - # print('x_pos 2', x_pos.shape) - # print('y_eval 2', y_eval.shape) self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k) def initialize_grid_from_parent(self, parent, x, mode="sample"): @@ -424,7 +413,8 @@ def get_grid(num_interval): ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] grid_adaptive = x_pos[ids, :].permute(1,0) h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval - grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device) + grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,) + [None, :] grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive return grid""" @@ -433,11 +423,14 @@ def get_grid(num_interval): x_pos = parent.grid[:, parent.k : -parent.k] # print('x_pos', x_pos) sp2 = KANLayer( - in_dim=1, out_dim=self.in_dim, k=1, num=x_pos.shape[1] - 1, scale_base_mu=0.0, scale_base_sigma=0.0 - ).to(x.device) + in_dim=1, + out_dim=self.in_dim, + k=1, + num=x_pos.shape[1] - 1, + scale_base_mu=0.0, + scale_base_sigma=0.0, + ) - # print('sp2_grid', sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim)) - # print('sp2_coef_shape', sp2.coef.shape) sp2_coef = curve2coef( sp2.grid[:, sp2.k : -sp2.k].permute(1, 0).expand(-1, self.in_dim), x_pos.permute(1, 0).unsqueeze(dim=2), @@ -445,7 +438,7 @@ def get_grid(num_interval): k=1, ).permute(1, 0, 2) sp2.coef.data = sp2_coef - percentile = torch.linspace(-1, 1, self.num + 1).to(self.device) + percentile = torch.linspace(-1, 1, self.num + 1) grid = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0) return grid @@ -482,7 +475,9 @@ def get_subset(self, in_id, out_id): >>> kanlayer_small.in_dim, kanlayer_small.out_dim (2, 3) """ - spb = KANLayer(len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun) + spb = KANLayer( + len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun + ) spb.grid.data = self.grid[in_id] spb.coef.data = self.coef[in_id][:, out_id] spb.scale_base.data = self.scale_base[in_id][:, out_id] @@ -495,7 +490,8 @@ def get_subset(self, in_id, out_id): def swap(self, i1, i2, mode="in"): """ - swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out') + swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output + (if mode == 'out') Args: ----- diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py index 3e6193d49..b80382c64 100644 --- a/pytorch_forecasting/models/nbeats/sub_modules.py +++ b/pytorch_forecasting/models/nbeats/sub_modules.py @@ -12,6 +12,9 @@ def linear(input_size, output_size, bias=True, dropout: int = None): + """ + Initialize linear 
layers for the MLP block.
+    """
     lin = nn.Linear(input_size, output_size, bias=bias)
     if dropout is not None:
         return nn.Sequential(nn.Dropout(dropout), lin)
@@ -22,6 +25,9 @@ def linspace(
     backcast_length: int, forecast_length: int, centered: bool = False
 ) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Generate linearly spaced values for the backcast and forecast horizons.
+    """
     if centered:
         norm = max(backcast_length, forecast_length)
         start = -backcast_length
@@ -46,16 +52,48 @@ def __init__(
         num_block_layers=4,
         backcast_length=10,
         forecast_length=5,
-        share_thetas=False,
         dropout=0.1,
         kan_params={},
     ):
+        """
+        Initialize NBEATSBlock.
+
+        Args:
+            units: The number of units in the MLP/KAN layers.
+            thetas_dim: The dimension of the parameterized output for the block.
+            num_block_layers: Number of fully connected MLP/KAN layers. Default: 4.
+            backcast_length: The length of the backcast. Defines how many time units
+                from the past are used to predict the future. Default: 10.
+            forecast_length: The length of the forecast, i.e., the number of time steps
+                ahead to predict. Default: 5.
+            dropout: The dropout rate applied to the fully connected MLP layers to
+                prevent overfitting. Default: 0.1.
+            kan_params (dict): Parameters specific to the KAN layers
+                (used when modeling with KAN). Default: empty dictionary.
+                Contains:
+                    num_grids (int): The number of grid intervals for KAN.
+                    k (int): The order of the piecewise polynomial for KAN.
+                    noise_scale (float): The scale of noise injected at initialization.
+                    scale_base_mu (float): The scale of the residual function
+                        initialized to N(scale_base_mu, scale_base_sigma^2).
+                    scale_base_sigma (float): The scale of the residual function
+                        initialized to N(scale_base_mu, scale_base_sigma^2).
+                    scale_sp (float): The scale of the spline function spline(x) in KAN.
+                    base_fun (function): The residual function used by
+                        KAN (e.g., torch.nn.SiLU()).
+                    grid_eps (float): Determines the partitioning of the grid. If 1,
+                        the grid is uniform; if 0, grid is partitioned by percentiles.
+                    grid_range (list or np.array): The range of the grid, given as
+                        a list of two values.
+                    sp_trainable (bool): If True, scale_sp is trainable.
+                    sb_trainable (bool): If True, scale_base is trainable.
+                    sparse_init (bool): If True, applies sparse initialization.
+ """ super().__init__() self.units = units self.thetas_dim = thetas_dim self.backcast_length = backcast_length self.forecast_length = forecast_length - self.share_thetas = share_thetas self.kan_params = kan_params if self.kan_params["use_kan"]: @@ -77,47 +115,63 @@ def __init__( sparse_init=self.kan_params["sparse_init"], ) ] - # Additional KANLayers for deeper structure + + # Add additional layers for deeper structure for _ in range(num_block_layers - 1): - layers.extend( - [ - KANLayer( - in_dim=units, - out_dim=units, - num=self.kan_params["num_grids"], - k=self.kan_params["k"], - noise_scale=self.kan_params["noise_scale"], - scale_base_mu=self.kan_params["scale_base_mu"], - scale_base_sigma=self.kan_params["scale_base_sigma"], - scale_sp=self.kan_params["scale_sp"], - base_fun=self.kan_params["base_fun"], - grid_eps=self.kan_params["grid_eps"], - grid_range=self.kan_params["grid_range"], - sp_trainable=self.kan_params["sp_trainable"], - sb_trainable=self.kan_params["sb_trainable"], - sparse_init=self.kan_params["sparse_init"], - device="cpu", # Assuming you are using the "cpu" device - ) - ] + layers.append( + KANLayer( + in_dim=units, + out_dim=units, + num=self.kan_params["num_grids"], + k=self.kan_params["k"], + noise_scale=self.kan_params["noise_scale"], + scale_base_mu=self.kan_params["scale_base_mu"], + scale_base_sigma=self.kan_params["scale_base_sigma"], + scale_sp=self.kan_params["scale_sp"], + base_fun=self.kan_params["base_fun"], + grid_eps=self.kan_params["grid_eps"], + grid_range=self.kan_params["grid_range"], + sp_trainable=self.kan_params["sp_trainable"], + sb_trainable=self.kan_params["sb_trainable"], + sparse_init=self.kan_params["sparse_init"], + ) ) - self.fc = nn.Sequential(*layers) + # Define the fully connected layers + self.fc = nn.Sequential(*layers) + + # Define the theta layers + self.theta_f_fc = self.theta_b_fc = KANLayer( + in_dim=units, + out_dim=thetas_dim, + num=self.kan_params["num_grids"], + k=self.kan_params["k"], + noise_scale=self.kan_params["noise_scale"], + scale_base_mu=self.kan_params["scale_base_mu"], + scale_base_sigma=self.kan_params["scale_base_sigma"], + scale_sp=self.kan_params["scale_sp"], + base_fun=self.kan_params["base_fun"], + grid_eps=self.kan_params["grid_eps"], + grid_range=self.kan_params["grid_range"], + sp_trainable=self.kan_params["sp_trainable"], + sb_trainable=self.kan_params["sb_trainable"], + sparse_init=self.kan_params["sparse_init"], + ) - fc_stack = [ - nn.Linear(backcast_length, units), - nn.ReLU(), - ] - for _ in range(num_block_layers - 1): - fc_stack.extend([linear(units, units, dropout=dropout), nn.ReLU()]) - self.fc = nn.Sequential(*fc_stack) - - if share_thetas: - self.theta_f_fc = self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False) else: - self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False) - self.theta_f_fc = nn.Linear(units, thetas_dim, bias=False) + fc_stack = [ + nn.Linear(backcast_length, units), + nn.ReLU(), + ] + for _ in range(num_block_layers - 1): + fc_stack.extend([linear(units, units, dropout=dropout), nn.ReLU()]) + self.fc = nn.Sequential(*fc_stack) + self.theta_f_fc = self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False) def forward(self, x): + """ + Pass through the fully connected mlp/kan layers and returns the output. + """ return self.fc(x) @@ -138,9 +192,9 @@ def __init__( Initialize NBeatsSeasonalBlock Args: - units: The number of units in the mlp/kan layers. Default: 256. + units: The number of units in the mlp/kan layers. 
             thetas_dim: The dimension of the parameterized output for the block.
-                If None, it is inferred. Default: None.
+                If None, it is inferred.
             num_block_layers: Number of fully connected mlp/kan layers. Default: 4.
             backcast_length: The length of the backcast. Defines how many time units
                 from the past are used to predict the future. Default: 10.
@@ -184,7 +238,6 @@ def __init__(
             num_block_layers=num_block_layers,
             backcast_length=backcast_length,
             forecast_length=forecast_length,
-            share_thetas=True,
             dropout=dropout,
             kan_params=kan_params,
         )
@@ -219,7 +272,9 @@ def __init__(
         self.register_buffer("S_forecast", torch.cat([s1_f, s2_f]))

     def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Computes the backcast and forecast outputs for the given input tensor."""
+        """
+        Computes the backcast and forecast outputs for the given input tensor.
+        """
         x = super().forward(x)
         amplitudes_backward = self.theta_b_fc(x)
         backcast = amplitudes_backward.mm(self.S_backcast)
@@ -252,9 +307,9 @@ def __init__(
         Initialize NBeatsTrendBlock

         Args:
-            units: The number of units in the mlp/kan layers. Default: 256.
+            units: The number of units in the mlp/kan layers.
             thetas_dim: The dimension of the parameterized output for the block.
-                If None, it is inferred. Default: None.
+                If None, it is inferred.
             num_block_layers: Number of fully connected mlp/kan layers. Default: 4.
             backcast_length: The length of the backcast. Defines how many time units
                 from the past are used to predict the future. Default: 10.
@@ -289,7 +344,6 @@ def __init__(
             num_block_layers=num_block_layers,
             backcast_length=backcast_length,
             forecast_length=forecast_length,
-            share_thetas=True,
             dropout=dropout,
             kan_params=kan_params,
         )
@@ -313,6 +367,9 @@ def __init__(
         self.register_buffer("T_forecast", coefficients * norm)

     def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Computes the backcast and forecast outputs for the given input tensor.
+        """
         x = super().forward(x)
         backcast = self.theta_b_fc(x).mm(self.T_backcast)
         forecast = self.theta_f_fc(x).mm(self.T_forecast)
@@ -334,9 +391,9 @@ def __init__(
         Initialize NBeatsGenericBlock

         Args:
-            units: The number of units in the mlp/kan layers. Default: 256.
+            units: The number of units in the mlp/kan layers.
             thetas_dim: The dimension of the parameterized output for the block.
-                If None, it is inferred. Default: None.
+                If None, it is inferred.
             num_block_layers: Number of fully connected mlp/kan layers. Default: 4.
             backcast_length: The length of the backcast. Defines how many time units
                 from the past are used to predict the future. Default: 10.
@@ -379,9 +436,10 @@ def __init__(
         self.forecast_fc = nn.Linear(thetas_dim, forecast_length)

     def forward(self, x):
+        """
+        Computes the backcast and forecast outputs for the given input tensor.
+        """
         x = super().forward(x)
-
         theta_b = F.relu(self.theta_b_fc(x))
         theta_f = F.relu(self.theta_f_fc(x))
-
         return self.backcast_fc(theta_b), self.forecast_fc(theta_f)

From 594102d1d13409d5fb00c7686480b05a3cf77f00 Mon Sep 17 00:00:00 2001
From: SohaibAhmed121
Date: Mon, 13 Jan 2025 15:43:36 +0500
Subject: [PATCH 03/21] Resolved import error.
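
The previous import treated kan_layer as a top-level module, which only
resolves when the file happens to be on sys.path (for example when running
scripts from inside models/nbeats). A sketch of the two variants
(illustrative only; the diff below applies the second form):

    # fails once pytorch-forecasting is installed as a package:
    # from kan_layer import KANLayer

    # works in every context: fully qualified package path
    from pytorch_forecasting.models.nbeats.kan_layer import KANLayer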
--- pytorch_forecasting/models/nbeats/sub_modules.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py index b80382c64..c236507a1 100644 --- a/pytorch_forecasting/models/nbeats/sub_modules.py +++ b/pytorch_forecasting/models/nbeats/sub_modules.py @@ -4,12 +4,13 @@ from typing import Tuple -from kan_layer import KANLayer import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from pytorch_forecasting.models.nbeats.kan_layer import KANLayer + def linear(input_size, output_size, bias=True, dropout: int = None): """ From c8ccfaf6a51e57af6cf810f5f822bc0752d9c1cd Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Thu, 23 Jan 2025 04:46:56 -0800 Subject: [PATCH 04/21] Refactored NBEATS and added support for grid updation during training while using KAN blocks in NBEATS. --- pytorch_forecasting/models/nbeats/_nbeats.py | 106 ++++++----- .../models/nbeats/grid_callback.py | 38 ++++ .../models/nbeats/kan_layer.py | 167 +----------------- .../models/nbeats/sub_modules.py | 89 +++++----- 4 files changed, 150 insertions(+), 250 deletions(-) create mode 100644 pytorch_forecasting/models/nbeats/grid_callback.py diff --git a/pytorch_forecasting/models/nbeats/_nbeats.py b/pytorch_forecasting/models/nbeats/_nbeats.py index beb66314e..77fffd332 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats.py +++ b/pytorch_forecasting/models/nbeats/_nbeats.py @@ -30,19 +30,6 @@ def __init__( expansion_coefficient_lengths: Optional[List[int]] = None, prediction_length: int = 1, context_length: int = 1, - use_kan: bool = False, - num_grids: int = 5, - k: int = 3, - noise_scale: float = 0.5, - scale_base_mu: float = 0.0, - scale_base_sigma: float = 1.0, - scale_sp: float = 1.0, - base_fun: callable = torch.nn.SiLU(), - grid_eps: float = 0.02, - grid_range: List[int] = [-1, 1], - sp_trainable: bool = True, - sb_trainable: bool = True, - sparse_init: bool = False, dropout: float = 0.1, learning_rate: float = 1e-2, log_interval: int = -1, @@ -53,6 +40,19 @@ def __init__( reduce_on_plateau_patience: int = 1000, backcast_loss_ratio: float = 0.0, logging_metrics: nn.ModuleList = None, + use_kan: bool = False, + num: int = 5, + k: int = 3, + noise_scale: float = 0.5, + scale_base_mu: float = 0.0, + scale_base_sigma: float = 1.0, + scale_sp: float = 1.0, + base_fun: callable = None, + grid_eps: float = 0.02, + grid_range: List[int] = None, + sp_trainable: bool = True, + sb_trainable: bool = True, + sparse_init: bool = False, **kwargs, ): """ @@ -101,45 +101,49 @@ def __init__( context_length: Number of time units that condition the predictions. Also known as 'lookback period'. Should be between 1-10 times the prediction length. - num_grids : Parameter for KAN layer. the number of grid intervals = G. - Default: 5. + backcast_loss_ratio: weight of backcast in comparison to forecast when + calculating the loss. A weight of 1.0 means that forecast and + backcast loss is weighted the same (regardless of backcast and forecast + lengths). Defaults to 0.0, i.e. no weight. + loss: loss to optimize. Defaults to MASE(). + log_gradient_flow: if to log gradient flow, this takes time and should be + only done to diagnose training failures. + reduce_on_plateau_patience (int): patience after which learning rate is + reduced by a factor of 10 + logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that + are logged during training. 
Defaults to
+                nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
+            use_kan: flag parameter to decide usage of KAN blocks in NBEATS. if true,
+                kan layers are used in nbeats block else mlp layers are used. Default:
+                false.
+            num : Parameter for KAN layer. the number of grid intervals = G.
+                Default: 5. used when use_kan is True.
             k : Parameter for KAN layer. the order of piecewise polynomial. Default: 3.
+                used when use_kan is True.
             noise_scale : Parameter for KAN layer. the scale of noise injected at
-                initialization. Default: 0.1.
+                initialization. Default: 0.5. used when use_kan is True.
             scale_base_mu : Parameter for KAN layer. the scale of the residual
                 function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2).
-                Deafult: 0.0
+                Default: 0.0. used when use_kan is True.
             scale_base_sigma : Parameter for KAN layer. the scale of the residual
                 function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2).
-                Deafult: 1.0
+                Default: 1.0. used when use_kan is True.
             scale_sp : Parameter for KAN layer. the scale of the base function
-                spline(x). Deafult: 1.0
+                spline(x). Default: 1.0. used when use_kan is True.
             base_fun : Parameter for KAN layer. residual function b(x).
-                Default: torch.nn.SiLU()
+                Default: None. used when use_kan is True.
             grid_eps : Parameter for KAN layer. When grid_eps = 1, the grid is uniform;
                 when grid_eps = 0, the grid is partitioned using percentiles of samples.
-                0 < grid_eps < 1 interpolates between the two extremes. Deafult: 0.02
+                0 < grid_eps < 1 interpolates between the two extremes. Default: 0.02.
+                used when use_kan is True.
             grid_range : Parameter for KAN layer. list/np.array of shape (2,). setting
-                the range of grids.
-                Default: [-1,1].
+                the range of grids. Default: None. used when use_kan is True.
             sp_trainable : Parameter for KAN layer. If true, scale_sp is trainable.
-                Default: True.
+                Default: True. used when use_kan is True.
             sb_trainable : Parameter for KAN layer. If true, scale_base is trainable.
-                Default: True.
+                Default: True. used when use_kan is True.
             sparse_init : Parameter for KAN layer. if sparse_init = True, sparse
-                initialization is applied. Default: False.
+                initialization is applied. Default: False. used when use_kan is True.
             **kwargs: additional arguments to :py:class:`~BaseModel`.
""" # noqa: E501 if expansion_coefficient_lengths is None: @@ -154,14 +158,17 @@ def __init__( num_blocks = [3, 3] if stack_types is None: stack_types = ["trend", "seasonality"] + if base_fun is None: + base_fun = torch.nn.SiLU() + if grid_range is None: + grid_range = [-1, 1] if logging_metrics is None: logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) if loss is None: loss = MASE() # Bundle KAN parameters into a dictionary self.kan_params = { - "use_kan": use_kan, - "num_grids": num_grids, + "num": num, "k": k, "noise_scale": noise_scale, "scale_base_mu": scale_base_mu, @@ -174,6 +181,7 @@ def __init__( "sb_trainable": sb_trainable, "sparse_init": sparse_init, } + self.use_kan = use_kan self.save_hyperparameters(ignore=["loss", "logging_metrics"]) super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) @@ -191,6 +199,7 @@ def __init__( forecast_length=prediction_length, dropout=self.hparams.dropout, kan_params=self.kan_params, + use_kan=use_kan, ) elif stack_type == "seasonality": net_block = NBEATSSeasonalBlock( @@ -201,6 +210,7 @@ def __init__( min_period=self.hparams.expansion_coefficient_lengths[stack_id], dropout=self.hparams.dropout, kan_params=self.kan_params, + use_kan=use_kan, ) elif stack_type == "trend": net_block = NBEATSTrendBlock( @@ -211,6 +221,7 @@ def __init__( forecast_length=prediction_length, dropout=self.hparams.dropout, kan_params=self.kan_params, + use_kan=use_kan, ) else: raise ValueError(f"Unknown stack type {stack_type}") @@ -291,6 +302,21 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: ), ) + def update_kan_grid(self): + """ + Updates grid of KAN layers when using KAN layers in NBEATSBlock. + """ + if self.use_kan: + for block in self.net_blocks: + # updation logic taken from + # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682 + for i, layer in enumerate(block.fc): + # update basis KAN layers' grid + layer.update_grid_from_samples(block.outputs[i]) + # update theta backward and theta forward KAN layers' grid + block.theta_b_fc.update_grid_from_samples(block.outputs[i + 1]) + block.theta_f_fc.update_grid_from_samples(block.outputs[i + 1]) + @classmethod def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): """ diff --git a/pytorch_forecasting/models/nbeats/grid_callback.py b/pytorch_forecasting/models/nbeats/grid_callback.py new file mode 100644 index 000000000..3c36b1ef4 --- /dev/null +++ b/pytorch_forecasting/models/nbeats/grid_callback.py @@ -0,0 +1,38 @@ +from lightning.pytorch.callbacks import Callback + + +class GridUpdateCallback(Callback): + """ + Custom callback to update the grid of the model during training at regular + intervals. + + Attributes: + update_interval (int): The frequency at which the grid is updated. + """ + + def __init__(self, update_interval): + """ + Initializes the callback with the given update interval. + + Args: + update_interval (int): The frequency at which the grid is updated. + """ + self.update_interval = update_interval + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + """ + Hook that is called at the end of each training batch. + Updates the grid of KAN layers if the current step is a multiple of the update + interval. + + Args: + trainer (Trainer): The PyTorch Lightning Trainer object. + pl_module (LightningModule): The model being trained (LightningModule). + outputs (Any): Outputs from the model for the current batch. + batch (Any): The current batch of data. 
+ batch_idx (int): Index of the current batch. + """ + # Check if the current step is a multiple of the update interval + if (trainer.global_step + 1) % self.update_interval == 0: + # Call the model's update_kan_grid method + pl_module.update_kan_grid() diff --git a/pytorch_forecasting/models/nbeats/kan_layer.py b/pytorch_forecasting/models/nbeats/kan_layer.py index 7b5d991df..d9357bb46 100644 --- a/pytorch_forecasting/models/nbeats/kan_layer.py +++ b/pytorch_forecasting/models/nbeats/kan_layer.py @@ -1,3 +1,6 @@ +# The following implementation of KANLayer is inspired by the pykan library. +# Reference: https://github.com/KindXiaoming/pykan/blob/master/kan/KANLayer.py + import numpy as np import torch import torch.nn as nn @@ -105,7 +108,6 @@ def curve2coef(x_eval, y_eval, grid, k): coef : 3D torch.tensor shape (in_dim, out_dim, G+k) """ - # print('haha', x_eval.shape, y_eval.shape, grid.shape) batch = x_eval.shape[0] in_dim = x_eval.shape[1] out_dim = y_eval.shape[2] @@ -118,16 +120,6 @@ def curve2coef(x_eval, y_eval, grid, k): except Exception as e: print(f"lstsq failed with error: {e}") - # manual psuedo-inverse - """lamb=1e-8 - XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), mat) - Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), y_eval) - n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2] - identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n) - A = XtX + lamb * identity - B = Xty - coef = (A.pinverse() @ B)[:,:,:,0]""" - return coef @@ -343,8 +335,6 @@ def update_grid_from_samples(self, x, mode="sample"): """ batch = x.shape[0] - # x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, )) - # .reshape(batch, self.size).permute(1, 0) x_pos = torch.sort(x, dim=0)[0] y_eval = coef2curve(x_pos, self.grid, self.coef, self.k) num_interval = self.grid.shape[1] - 1 - 2 * self.k @@ -368,7 +358,6 @@ def get_grid(num_interval): return grid grid = get_grid(num_interval) - if mode == "grid": sample_grid = get_grid(2 * num_interval) x_pos = sample_grid.permute(1, 0) @@ -376,153 +365,3 @@ def get_grid(num_interval): self.grid.data = extend_grid(grid, k_extend=self.k) self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k) - - def initialize_grid_from_parent(self, parent, x, mode="sample"): - """ - update grid from a parent KANLayer & samples - - Args: - ----- - parent : KANLayer - a parent KANLayer (whose grid is usually coarser than the current model) - x : 2D torch.float - inputs, shape (number of samples, input dimension) - - Returns: - -------- - None - - Example - ------- - >>> batch = 100 - >>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3) - >>> print(parent_model.grid.data) - >>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3) - >>> x = torch.normal(0,1,size=(batch, 1)) - >>> model.initialize_grid_from_parent(parent_model, x) - >>> print(model.grid.data) - """ - # shrink grid - x_pos = torch.sort(x, dim=0)[0] - y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k) - num_interval = self.grid.shape[1] - 1 - 2 * self.k - - """ - # based on samples - def get_grid(num_interval): - ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] - grid_adaptive = x_pos[ids, :].permute(1,0) - h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval - grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,) - [None, :] - grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive - return grid""" - - # based on interpolating parent grid - def get_grid(num_interval): - x_pos = 
parent.grid[:, parent.k : -parent.k]
-            # print('x_pos', x_pos)
-            sp2 = KANLayer(
-                in_dim=1,
-                out_dim=self.in_dim,
-                k=1,
-                num=x_pos.shape[1] - 1,
-                scale_base_mu=0.0,
-                scale_base_sigma=0.0,
-            )
-
-            sp2_coef = curve2coef(
-                sp2.grid[:, sp2.k : -sp2.k].permute(1, 0).expand(-1, self.in_dim),
-                x_pos.permute(1, 0).unsqueeze(dim=2),
-                sp2.grid[:, :],
-                k=1,
-            ).permute(1, 0, 2)
-            sp2.coef.data = sp2_coef
-            percentile = torch.linspace(-1, 1, self.num + 1)
-            grid = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0)
-            return grid
-
-        grid = get_grid(num_interval)
-
-        if mode == "grid":
-            sample_grid = get_grid(2 * num_interval)
-            x_pos = sample_grid.permute(1, 0)
-            y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
-
-        grid = extend_grid(grid, k_extend=self.k)
-        self.grid.data = grid
-        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)
-
-    def get_subset(self, in_id, out_id):
-        """
-        get a smaller KANLayer from a larger KANLayer (used for pruning)
-
-        Args:
-        -----
-        in_id : list
-            id of selected input neurons
-        out_id : list
-            id of selected output neurons
-
-        Returns:
-        --------
-        spb : KANLayer
-
-        Example
-        -------
-        >>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3)
-        >>> kanlayer_small = kanlayer_large.get_subset([0,9],[1,2,3])
-        >>> kanlayer_small.in_dim, kanlayer_small.out_dim
-        (2, 3)
-        """
-        spb = KANLayer(
-            len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun
-        )
-        spb.grid.data = self.grid[in_id]
-        spb.coef.data = self.coef[in_id][:, out_id]
-        spb.scale_base.data = self.scale_base[in_id][:, out_id]
-        spb.scale_sp.data = self.scale_sp[in_id][:, out_id]
-        spb.mask.data = self.mask[in_id][:, out_id]
-
-        spb.in_dim = len(in_id)
-        spb.out_dim = len(out_id)
-        return spb
-
-    def swap(self, i1, i2, mode="in"):
-        """
-        swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output
-        (if mode == 'out')
-
-        Args:
-        -----
-        i1 : int
-        i2 : int
-        mode : str
-            mode = 'in' or 'out'
-
-        Returns:
-        --------
-        None
-
-        Example
-        -------
-        >>> from kan.KANLayer import *
-        >>> model = KANLayer(in_dim=2, out_dim=2, num=5, k=3)
-        >>> print(model.coef)
-        >>> model.swap(0,1,mode='in')
-        >>> print(model.coef)
-        """
-        with torch.no_grad():
-
-            def swap_(data, i1, i2, mode="in"):
-                if mode == "in":
-                    data[i1], data[i2] = data[i2].clone(), data[i1].clone()
-                elif mode == "out":
-                    data[:, i1], data[:, i2] = data[:, i2].clone(), data[:, i1].clone()
-
-            if mode == "in":
-                swap_(self.grid.data, i1, i2, mode="in")
-                swap_(self.coef.data, i1, i2, mode=mode)
-                swap_(self.scale_base.data, i1, i2, mode=mode)
-                swap_(self.scale_sp.data, i1, i2, mode=mode)
-                swap_(self.mask.data, i1, i2, mode=mode)
diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py
index c236507a1..cb0b32962 100644
--- a/pytorch_forecasting/models/nbeats/sub_modules.py
+++ b/pytorch_forecasting/models/nbeats/sub_modules.py
@@ -54,7 +54,8 @@ def __init__(
         backcast_length=10,
         forecast_length=5,
         dropout=0.1,
-        kan_params={},
+        kan_params=None,
+        use_kan=False,
     ):
         """
         Initialize NBeatsBlock
@@ -70,7 +71,7 @@ def __init__(
             dropout: The dropout rate applied to the fully connected mlp layers to
                 prevent overfitting. Default: 0.1.
             kan_params (dict): Parameters specific to the KAN layer
-                (used for modeling using KAN). Default: empty dictionary.
+                (used for modeling using KAN). Default: None.
                 Contains:
                     num_grids (int): The number of grid intervals for KAN.
                     k (int): The order of the piecewise polynomial for KAN.
@@ -89,6 +90,9 @@ def __init__(
             sp_trainable (bool): If True, the scale_sp is trainable.
             sb_trainable (bool): If True, the scale_base is trainable.
             sparse_init (bool): If True, applies sparse initialization.
+            use_kan: flag parameter to decide usage of KAN blocks in NBEATS. if true,
+                kan layers are used in nbeats block else mlp layers are used. Default:
+                false.
         """
         super().__init__()
         self.units = units
@@ -96,24 +100,14 @@ def __init__(
         self.backcast_length = backcast_length
         self.forecast_length = forecast_length
         self.kan_params = kan_params
+        self.use_kan = use_kan

-        if self.kan_params["use_kan"]:
+        if self.use_kan:
             layers = [
                 KANLayer(
                     in_dim=backcast_length,
                     out_dim=units,
-                    num=self.kan_params["num_grids"],
-                    k=self.kan_params["k"],
-                    noise_scale=self.kan_params["noise_scale"],
-                    scale_base_mu=self.kan_params["scale_base_mu"],
-                    scale_base_sigma=self.kan_params["scale_base_sigma"],
-                    scale_sp=self.kan_params["scale_sp"],
-                    base_fun=self.kan_params["base_fun"],
-                    grid_eps=self.kan_params["grid_eps"],
-                    grid_range=self.kan_params["grid_range"],
-                    sp_trainable=self.kan_params["sp_trainable"],
-                    sb_trainable=self.kan_params["sb_trainable"],
-                    sparse_init=self.kan_params["sparse_init"],
+                    **self.kan_params,
                 )
             ]

@@ -123,18 +117,7 @@ def __init__(
                     KANLayer(
                         in_dim=units,
                         out_dim=units,
-                        num=self.kan_params["num_grids"],
-                        k=self.kan_params["k"],
-                        noise_scale=self.kan_params["noise_scale"],
-                        scale_base_mu=self.kan_params["scale_base_mu"],
-                        scale_base_sigma=self.kan_params["scale_base_sigma"],
-                        scale_sp=self.kan_params["scale_sp"],
-                        base_fun=self.kan_params["base_fun"],
-                        grid_eps=self.kan_params["grid_eps"],
-                        grid_range=self.kan_params["grid_range"],
-                        sp_trainable=self.kan_params["sp_trainable"],
-                        sb_trainable=self.kan_params["sb_trainable"],
-                        sparse_init=self.kan_params["sparse_init"],
+                        **self.kan_params,
                     )
                 )

@@ -145,18 +128,7 @@ def __init__(
             self.theta_f_fc = self.theta_b_fc = KANLayer(
                 in_dim=units,
                 out_dim=thetas_dim,
-                num=self.kan_params["num_grids"],
-                k=self.kan_params["k"],
-                noise_scale=self.kan_params["noise_scale"],
-                scale_base_mu=self.kan_params["scale_base_mu"],
-                scale_base_sigma=self.kan_params["scale_base_sigma"],
-                scale_sp=self.kan_params["scale_sp"],
-                base_fun=self.kan_params["base_fun"],
-                grid_eps=self.kan_params["grid_eps"],
-                grid_range=self.kan_params["grid_range"],
-                sp_trainable=self.kan_params["sp_trainable"],
-                sb_trainable=self.kan_params["sb_trainable"],
-                sparse_init=self.kan_params["sparse_init"],
+                **self.kan_params,
             )

         else:
@@ -173,7 +145,17 @@ def forward(self, x):
         """
         Pass the input through the fully connected mlp/kan layers and return the output.
         """
-        return self.fc(x)
+        # outputs logic taken from
+        # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682
+        self.outputs = []
+        self.outputs.append(x.clone().detach())
+        for layer in self.fc:
+            x = layer(x)  # Pass data through the current layer
+            # store outputs for updating grids of theta_fc when using KAN
+            self.outputs.append(x.clone().detach())
+        # for updating grids of theta_b_fc and theta_f_fc when using KAN
+        self.outputs.append(x.clone().detach())
+        return x  # Return final output


 class NBEATSSeasonalBlock(NBEATSBlock):
@@ -187,7 +169,8 @@ def __init__(
         nb_harmonics=None,
         min_period=1,
         dropout=0.1,
-        kan_params={},
+        kan_params=None,
+        use_kan=False,
     ):
         """
         Initialize NBeatsSeasonalBlock
@@ -207,7 +190,7 @@ def __init__(
             dropout: The dropout rate applied to the fully connected mlp layers to
                 prevent overfitting. Default: 0.1.
             kan_params (dict): Parameters specific to the KAN layer
-                (used for modeling using KAN). Default: empty dictionary.
+                (used for modeling using KAN). Default: None.
                 Contains:
                     num_grids (int): The number of grid intervals for KAN.
                     k (int): The order of the piecewise polynomial for KAN.
@@ -226,6 +209,9 @@ def __init__(
             sp_trainable (bool): If True, the scale_sp is trainable.
             sb_trainable (bool): If True, the scale_base is trainable.
             sparse_init (bool): If True, applies sparse initialization.
+            use_kan: flag parameter to decide usage of KAN blocks in NBEATS. if true,
+                kan layers are used in nbeats block else mlp layers are used. Default:
+                false.
         """
         if nb_harmonics:
             thetas_dim = nb_harmonics
@@ -241,6 +227,7 @@ def __init__(
             forecast_length=forecast_length,
             dropout=dropout,
             kan_params=kan_params,
+            use_kan=use_kan,
         )

         backcast_linspace, forecast_linspace = linspace(
@@ -302,7 +289,8 @@ def __init__(
         backcast_length=10,
         forecast_length=5,
         dropout=0.1,
-        kan_params={},
+        kan_params=None,
+        use_kan=False,
     ):
         """
         Initialize NBeatsTrendBlock
@@ -319,7 +307,7 @@ def __init__(
             dropout: The dropout rate applied to the fully connected mlp layers to
                 prevent overfitting. Default: 0.1.
             kan_params (dict): Parameters specific to the KAN layer
-                (used for modeling using KAN). Default: empty dictionary.
+                (used for modeling using KAN). Default: None.
                 Contains:
                     num_grids (int): The number of grid intervals for KAN.
                     k (int): The order of the piecewise polynomial for KAN.
@@ -338,6 +326,9 @@ def __init__(
             sp_trainable (bool): If True, the scale_sp is trainable.
             sb_trainable (bool): If True, the scale_base is trainable.
             sparse_init (bool): If True, applies sparse initialization.
+            use_kan: flag parameter to decide usage of KAN blocks in NBEATS. if true,
+                kan layers are used in nbeats block else mlp layers are used. Default:
+                false.
         """
         super().__init__(
             units=units,
@@ -347,6 +338,7 @@ def __init__(
             forecast_length=forecast_length,
             dropout=dropout,
             kan_params=kan_params,
+            use_kan=use_kan,
         )

         backcast_linspace, forecast_linspace = linspace(
@@ -386,7 +378,8 @@ def __init__(
         backcast_length=10,
         forecast_length=5,
         dropout=0.1,
-        kan_params={},
+        kan_params=None,
+        use_kan=False,
     ):
         """
         Initialize NBeatsGenericBlock
@@ -403,7 +396,7 @@ def __init__(
             dropout: The dropout rate applied to the fully connected mlp layers to
                 prevent overfitting. Default: 0.1.
             kan_params (dict): Parameters specific to the KAN layer
-                (used for modeling using KAN). Default: empty dictionary.
+                (used for modeling using KAN). Default: None.
                 Contains:
                     num_grids (int): The number of grid intervals for KAN.
                     k (int): The order of the piecewise polynomial for KAN.
@@ -422,6 +415,9 @@ def __init__(
             sp_trainable (bool): If True, the scale_sp is trainable.
             sb_trainable (bool): If True, the scale_base is trainable.
             sparse_init (bool): If True, applies sparse initialization.
+            use_kan: flag parameter to decide usage of KAN blocks in NBEATS. if true,
+                kan layers are used in nbeats block else mlp layers are used. Default:
+                false.
         """
         super().__init__(
             units=units,
@@ -431,6 +427,7 @@ def __init__(
             forecast_length=forecast_length,
             dropout=dropout,
             kan_params=kan_params,
+            use_kan=use_kan,
         )

         self.backcast_fc = nn.Linear(thetas_dim, backcast_length)

From 348da97e3ed7f6b93fb011a9903c3965c4457600 Mon Sep 17 00:00:00 2001
From: Sohaib-Ahmed21
Date: Thu, 23 Jan 2025 05:49:46 -0800
Subject: [PATCH 05/21] Refactored comments.
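
The reworded comments document the caching contract between
NBEATSBlock.forward and the KAN grid updates: forward stores the input of
every layer in self.fc in self.outputs and appends the final activation a
second time, so the shared theta heads also have a sample batch to rebuild
their grids from. A condensed sketch of the consumer side (mirrors
update_kan_grid in _nbeats.py; 'block' is an illustrative variable):

    # one cached tensor per fc layer ...
    for i, layer in enumerate(block.fc):
        layer.update_grid_from_samples(block.outputs[i])
    # ... plus the duplicated last activation for both theta heads
    block.theta_b_fc.update_grid_from_samples(block.outputs[i + 1])
    block.theta_f_fc.update_grid_from_samples(block.outputs[i + 1])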
--- pytorch_forecasting/models/nbeats/sub_modules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py index cb0b32962..7ddf17a20 100644 --- a/pytorch_forecasting/models/nbeats/sub_modules.py +++ b/pytorch_forecasting/models/nbeats/sub_modules.py @@ -151,9 +151,9 @@ def forward(self, x): self.outputs.append(x.clone().detach()) for layer in self.fc: x = layer(x) # Pass data through the current layer - # store outputs for updating grids of theta_fc when using KAN + # storing outputs for updating grids of self.fc when using KAN self.outputs.append(x.clone().detach()) - # for updating grids of theta_b_fc and theta_f_fc when using KAN + # storing for updating grids of theta_b_fc and theta_f_fc when using KAN self.outputs.append(x.clone().detach()) return x # Return final output From 1ab0da0dd14577e3d51d3f7756e9215e8bd59057 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Sat, 1 Feb 2025 11:56:57 -0800 Subject: [PATCH 06/21] Added example to use grid_update_callback and added correct device to tensors during training. --- examples/nbeats_with_kan.py | 106 ++++++++++++++++++ .../models/nbeats/grid_callback.py | 4 + .../models/nbeats/kan_layer.py | 5 +- 3 files changed, 111 insertions(+), 4 deletions(-) create mode 100644 examples/nbeats_with_kan.py diff --git a/examples/nbeats_with_kan.py b/examples/nbeats_with_kan.py new file mode 100644 index 000000000..925e8dcf0 --- /dev/null +++ b/examples/nbeats_with_kan.py @@ -0,0 +1,106 @@ +import sys + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import EarlyStopping +import pandas as pd + +from pytorch_forecasting import NBeats, TimeSeriesDataSet +from pytorch_forecasting.data import NaNLabelEncoder +from pytorch_forecasting.data.examples import generate_ar_data +from pytorch_forecasting.models.nbeats.grid_callback import GridUpdateCallback + +sys.path.append("..") + + +print("load data") +data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100) +data["static"] = 2 +data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") +validation = data.series.sample(20) + + +max_encoder_length = 150 +max_prediction_length = 20 + +training_cutoff = data["time_idx"].max() - max_prediction_length + +context_length = max_encoder_length +prediction_length = max_prediction_length + +training = TimeSeriesDataSet( + data[lambda x: x.time_idx < training_cutoff], + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + min_encoder_length=context_length, + max_encoder_length=context_length, + max_prediction_length=prediction_length, + min_prediction_length=prediction_length, + time_varying_unknown_reals=["value"], + randomize_length=None, + add_relative_time_idx=False, + add_target_scales=False, +) + +validation = TimeSeriesDataSet.from_dataset( + training, data, min_prediction_idx=training_cutoff +) +batch_size = 128 +train_dataloader = training.to_dataloader( + train=True, batch_size=batch_size, num_workers=0 +) +val_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 +) + + +early_stop_callback = EarlyStopping( + monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min" +) +# updates KAN layers' grid after every 3 steps during training +grid_update_callback = GridUpdateCallback(update_interval=3) + +trainer = pl.Trainer( + max_epochs=1, + accelerator="auto", + 
gradient_clip_val=0.1, + callbacks=[early_stop_callback, grid_update_callback], + limit_train_batches=15, + # limit_val_batches=1, + # fast_dev_run=True, + # logger=logger, + # profiler=True, +) + + +net = NBeats.from_dataset( + training, + learning_rate=3e-2, + log_interval=10, + log_val_interval=1, + log_gradient_flow=False, + weight_decay=1e-2, + use_kan=True, +) +print(f"Number of parameters in network: {net.size() / 1e3:.1f}k") + +# # find optimal learning rate +# # remove logging and artificial epoch size +# net.hparams.log_interval = -1 +# net.hparams.log_val_interval = -1 +# trainer.limit_train_batches = 1.0 +# # run learning rate finder +# res = Tuner(trainer).lr_find( +# net, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5, max_lr=1e2 # noqa: E501 +# ) +# print(f"suggested learning rate: {res.suggestion()}") +# fig = res.plot(show=True, suggest=True) +# fig.show() +# net.hparams.learning_rate = res.suggestion() + +trainer.fit( + net, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, +) diff --git a/pytorch_forecasting/models/nbeats/grid_callback.py b/pytorch_forecasting/models/nbeats/grid_callback.py index 3c36b1ef4..d311cfe84 100644 --- a/pytorch_forecasting/models/nbeats/grid_callback.py +++ b/pytorch_forecasting/models/nbeats/grid_callback.py @@ -6,6 +6,10 @@ class GridUpdateCallback(Callback): Custom callback to update the grid of the model during training at regular intervals. + Example: + See the full example in: + `examples/nbeats_with_kan.py` + Attributes: update_interval (int): The frequency at which the grid is updated. """ diff --git a/pytorch_forecasting/models/nbeats/kan_layer.py b/pytorch_forecasting/models/nbeats/kan_layer.py index d9357bb46..1f7a18a1c 100644 --- a/pytorch_forecasting/models/nbeats/kan_layer.py +++ b/pytorch_forecasting/models/nbeats/kan_layer.py @@ -349,10 +349,7 @@ def get_grid(num_interval): grid_uniform = ( grid_adaptive[:, [0]] - margin - + h - * torch.arange( - num_interval + 1, - )[None, :] + + h * torch.arange(num_interval + 1, device=h.device)[None, :] ) grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive return grid From 05350c2d193b949ba12d206f181e605e4935b618 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Thu, 20 Feb 2025 04:53:14 -0800 Subject: [PATCH 07/21] Refactored code for NBEATSKAN and introduced it as separate model/entity using adapter for common functionality. --- docs/source/models.rst | 1 + examples/nbeats_with_kan.py | 5 +- pytorch_forecasting/__init__.py | 2 + pytorch_forecasting/models/__init__.py | 3 +- pytorch_forecasting/models/nbeats/__init__.py | 9 +- pytorch_forecasting/models/nbeats/_nbeats.py | 406 +----------------- .../models/nbeats/_nbeatskan.py | 235 ++++++++++ .../models/nbeats/nbeats_adapter.py | 322 ++++++++++++++ .../models/nbeats/sub_modules.py | 23 +- 9 files changed, 599 insertions(+), 407 deletions(-) create mode 100644 pytorch_forecasting/models/nbeats/_nbeatskan.py create mode 100644 pytorch_forecasting/models/nbeats/nbeats_adapter.py diff --git a/docs/source/models.rst b/docs/source/models.rst index f8ac486af..71569c724 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -27,6 +27,7 @@ and you should take into account. 
Here is an overview over the pros and cons of :py:class:`~pytorch_forecasting.models.rnn.RecurrentNetwork`, "x", "x", "x", "", "", "", "", "x", "", 2 :py:class:`~pytorch_forecasting.models.mlp.DecoderMLP`, "x", "x", "x", "x", "", "x", "", "x", "x", 1 :py:class:`~pytorch_forecasting.models.nbeats.NBeats`, "", "", "x", "", "", "", "", "", "", 1 + :py:class:`~pytorch_forecasting.models.nbeats.NBeatsKAN`, "", "", "x", "", "", "", "", "", "", 1 :py:class:`~pytorch_forecasting.models.nhits.NHiTS`, "x", "x", "x", "", "", "", "", "", "", 1 :py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "x [#deepvar]_ ", "x", "", 3 :py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4 diff --git a/examples/nbeats_with_kan.py b/examples/nbeats_with_kan.py index 925e8dcf0..952a2acce 100644 --- a/examples/nbeats_with_kan.py +++ b/examples/nbeats_with_kan.py @@ -4,7 +4,7 @@ from lightning.pytorch.callbacks import EarlyStopping import pandas as pd -from pytorch_forecasting import NBeats, TimeSeriesDataSet +from pytorch_forecasting import NBeatsKAN, TimeSeriesDataSet from pytorch_forecasting.data import NaNLabelEncoder from pytorch_forecasting.data.examples import generate_ar_data from pytorch_forecasting.models.nbeats.grid_callback import GridUpdateCallback @@ -74,14 +74,13 @@ ) -net = NBeats.from_dataset( +net = NBeatsKAN.from_dataset( training, learning_rate=3e-2, log_interval=10, log_val_interval=1, log_gradient_flow=False, weight_decay=1e-2, - use_kan=True, ) print(f"Number of parameters in network: {net.size() / 1e3:.1f}k") diff --git a/pytorch_forecasting/__init__.py b/pytorch_forecasting/__init__.py index eabfe481f..e1b150d51 100644 --- a/pytorch_forecasting/__init__.py +++ b/pytorch_forecasting/__init__.py @@ -43,6 +43,7 @@ DeepAR, MultiEmbedding, NBeats, + NBeatsKAN, NHiTS, RecurrentNetwork, TemporalFusionTransformer, @@ -71,6 +72,7 @@ "MultiNormalizer", "TemporalFusionTransformer", "NBeats", + "NBeatsKAN", "NHiTS", "Baseline", "DeepAR", diff --git a/pytorch_forecasting/models/__init__.py b/pytorch_forecasting/models/__init__.py index d4173f620..9b92ef30c 100644 --- a/pytorch_forecasting/models/__init__.py +++ b/pytorch_forecasting/models/__init__.py @@ -11,7 +11,7 @@ from pytorch_forecasting.models.baseline import Baseline from pytorch_forecasting.models.deepar import DeepAR from pytorch_forecasting.models.mlp import DecoderMLP -from pytorch_forecasting.models.nbeats import NBeats +from pytorch_forecasting.models.nbeats import NBeats, NBeatsKAN from pytorch_forecasting.models.nhits import NHiTS from pytorch_forecasting.models.nn import GRU, LSTM, MultiEmbedding, get_rnn from pytorch_forecasting.models.rnn import RecurrentNetwork @@ -21,6 +21,7 @@ __all__ = [ "NBeats", + "NBeatsKAN", "NHiTS", "TemporalFusionTransformer", "RecurrentNetwork", diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py index b3264272d..87c1fe7fb 100644 --- a/pytorch_forecasting/models/nbeats/__init__.py +++ b/pytorch_forecasting/models/nbeats/__init__.py @@ -1,10 +1,17 @@ """N-Beats model for timeseries forecasting without covariates.""" from pytorch_forecasting.models.nbeats._nbeats import NBeats +from pytorch_forecasting.models.nbeats._nbeatskan import NBeatsKAN from pytorch_forecasting.models.nbeats.sub_modules import ( NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock, ) -__all__ = ["NBeats", "NBEATSGenericBlock", "NBEATSSeasonalBlock", 
"NBEATSTrendBlock"] +__all__ = [ + "NBeats", + "NBeatsKAN", + "NBEATSGenericBlock", + "NBEATSSeasonalBlock", + "NBEATSTrendBlock", +] diff --git a/pytorch_forecasting/models/nbeats/_nbeats.py b/pytorch_forecasting/models/nbeats/_nbeats.py index 77fffd332..f85067e22 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats.py +++ b/pytorch_forecasting/models/nbeats/_nbeats.py @@ -2,24 +2,20 @@ N-Beats model for timeseries forecasting without covariates. """ -from typing import Dict, List, Optional +from typing import List, Optional -import torch from torch import nn -from pytorch_forecasting.data import TimeSeriesDataSet -from pytorch_forecasting.data.encoders import NaNLabelEncoder from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric -from pytorch_forecasting.models.base_model import BaseModel +from pytorch_forecasting.models.nbeats.nbeats_adapter import NBeatsAdapter from pytorch_forecasting.models.nbeats.sub_modules import ( NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock, ) -from pytorch_forecasting.utils._dependencies import _check_matplotlib -class NBeats(BaseModel): +class NBeats(NBeatsAdapter): def __init__( self, stack_types: Optional[List[str]] = None, @@ -40,19 +36,6 @@ def __init__( reduce_on_plateau_patience: int = 1000, backcast_loss_ratio: float = 0.0, logging_metrics: nn.ModuleList = None, - use_kan: bool = False, - num: int = 5, - k: int = 3, - noise_scale: float = 0.5, - scale_base_mu: float = 0.0, - scale_base_sigma: float = 1.0, - scale_sp: float = 1.0, - base_fun: callable = None, - grid_eps: float = 0.02, - grid_range: List[int] = None, - sp_trainable: bool = True, - sb_trainable: bool = True, - sparse_init: bool = False, **kwargs, ): """ @@ -70,23 +53,23 @@ def __init__( Args: stack_types: One of the following values: “generic”, “seasonality" or - “trend". A list of strings of length 1 or ‘num_stacks’. Default and + “trend". A list of strings of length 1 or 'num_stacks'. Default and recommended value for generic mode: [“generic”] Recommended value for interpretable mode: [“trend”,”seasonality”]. num_blocks: The number of blocks per stack. A list of ints of length 1 or - ‘num_stacks’. Default and recommended value for generic mode: [1] + 'num_stacks'. Default and recommended value for generic mode: [1] Recommended value for interpretable mode: [3] num_block_layers: Number of fully connected layers with ReLu activation per block. - A list of ints of length 1 or ‘num_stacks’. Default and recommended + A list of ints of length 1 or 'num_stacks'. Default and recommended value for generic mode: [4] Recommended value for interpretable mode: [4]. width: Widths of the fully connected layers with ReLu activation in the - blocks. A list of ints of length 1 or ‘num_stacks’. Default and + blocks. A list of ints of length 1 or 'num_stacks'. Default and recommended value for generic mode: [512]. Recommended value for interpretable mode: [256, 2048] sharing: Whether the weights are shared with the other blocks per stack. - A list of ints of length 1 or ‘num_stacks’. Default and recommended + A list of ints of length 1 or 'num_stacks'. Default and recommended value for generic mode: [False]. Recommended value for interpretable mode: [True]. expansion_coefficient_length: If the type is “G” (generic), then the length @@ -95,7 +78,7 @@ def __init__( polynomial. If the type is “S” (seasonal) then this is the minimum period allowed, e.g. 2 for changes every timestep. A list of ints of length 1 or - ‘num_stacks’. 
Default value for generic mode: [32] Recommended value for + 'num_stacks'. Default value for generic mode: [32] Recommended value for interpretable mode: [3] prediction_length: Length of the prediction. Also known as 'horizon'. context_length: Number of time units that condition the predictions. @@ -113,39 +96,9 @@ def __init__( logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training. Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) - use_kan: flag parameter to decide usage of KAN blocks in NBEATS. if true, - kan layers are used in nbeats block else mlp layers are used. Default: - false. - num : Parameter for KAN layer. the number of grid intervals = G. - Default: 5. used when use_kan is True. - k : Parameter for KAN layer. the order of piecewise polynomial. Default: 3. - used when use_kan is True. - noise_scale : Parameter for KAN layer. the scale of noise injected at - initialization. Default: 0.1. used when use_kan is True. - scale_base_mu : Parameter for KAN layer. the scale of the residual - function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2). - Deafult: 0.0. used when use_kan is True. - scale_base_sigma : Parameter for KAN layer. the scale of the residual - function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2). - Deafult: 1.0. used when use_kan is True. - scale_sp : Parameter for KAN layer. the scale of the base function - spline(x). Deafult: 1.0. used when use_kan is True. - base_fun : Parameter for KAN layer. residual function b(x). - Default: None. used when use_kan is True. - grid_eps : Parameter for KAN layer. When grid_eps = 1, the grid is uniform; - when grid_eps = 0, the grid is partitioned using percentiles of samples. - 0 < grid_eps < 1 interpolates between the two extremes. Deafult: 0.02. - used when use_kan is True. - grid_range : Parameter for KAN layer. list/np.array of shape (2,). setting - the range of grids. Default: None. used when use_kan is True. - sp_trainable : Parameter for KAN layer. If true, scale_sp is trainable. - Default: True. used when use_kan is True. - sb_trainable : Parameter for KAN layer. If true, scale_base is trainable. - Default: True. used when use_kan is True. - sparse_init : Parameter for KAN layer. if sparse_init = True, sparse - initialization is applied. Default: False. used when use_kan is True. **kwargs: additional arguments to :py:class:`~BaseModel`. 
""" # noqa: E501 + if expansion_coefficient_lengths is None: expansion_coefficient_lengths = [3, 7] if sharing is None: @@ -158,34 +111,13 @@ def __init__( num_blocks = [3, 3] if stack_types is None: stack_types = ["trend", "seasonality"] - if base_fun is None: - base_fun = torch.nn.SiLU() - if grid_range is None: - grid_range = [-1, 1] if logging_metrics is None: logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) if loss is None: loss = MASE() - # Bundle KAN parameters into a dictionary - self.kan_params = { - "num": num, - "k": k, - "noise_scale": noise_scale, - "scale_base_mu": scale_base_mu, - "scale_base_sigma": scale_base_sigma, - "scale_sp": scale_sp, - "base_fun": base_fun, - "grid_eps": grid_eps, - "grid_range": grid_range, - "sp_trainable": sp_trainable, - "sb_trainable": sb_trainable, - "sparse_init": sparse_init, - } - self.use_kan = use_kan self.save_hyperparameters(ignore=["loss", "logging_metrics"]) super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) - # setup stacks self.net_blocks = nn.ModuleList() for stack_id, stack_type in enumerate(stack_types): @@ -197,9 +129,7 @@ def __init__( num_block_layers=self.hparams.num_block_layers[stack_id], backcast_length=context_length, forecast_length=prediction_length, - dropout=self.hparams.dropout, - kan_params=self.kan_params, - use_kan=use_kan, + dropout=dropout, ) elif stack_type == "seasonality": net_block = NBEATSSeasonalBlock( @@ -207,10 +137,8 @@ def __init__( num_block_layers=self.hparams.num_block_layers[stack_id], backcast_length=context_length, forecast_length=prediction_length, - min_period=self.hparams.expansion_coefficient_lengths[stack_id], - dropout=self.hparams.dropout, - kan_params=self.kan_params, - use_kan=use_kan, + min_period=expansion_coefficient_lengths[stack_id], + dropout=dropout, ) elif stack_type == "trend": net_block = NBEATSTrendBlock( @@ -219,315 +147,9 @@ def __init__( num_block_layers=self.hparams.num_block_layers[stack_id], backcast_length=context_length, forecast_length=prediction_length, - dropout=self.hparams.dropout, - kan_params=self.kan_params, - use_kan=use_kan, + dropout=dropout, ) else: raise ValueError(f"Unknown stack type {stack_type}") self.net_blocks.append(net_block) - - def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """ - Pass forward of network. - - Args: - x (Dict[str, torch.Tensor]): input from dataloader generated from - :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. 
- - Returns: - Dict[str, torch.Tensor]: output of model - """ - target = x["encoder_cont"][..., 0] - - timesteps = self.hparams.context_length + self.hparams.prediction_length - generic_forecast = [ - torch.zeros( - (target.size(0), timesteps), dtype=torch.float32, device=self.device - ) - ] - trend_forecast = [ - torch.zeros( - (target.size(0), timesteps), dtype=torch.float32, device=self.device - ) - ] - seasonal_forecast = [ - torch.zeros( - (target.size(0), timesteps), dtype=torch.float32, device=self.device - ) - ] - forecast = torch.zeros( - (target.size(0), self.hparams.prediction_length), - dtype=torch.float32, - device=self.device, - ) - - backcast = target # initialize backcast - for i, block in enumerate(self.net_blocks): - # evaluate block - backcast_block, forecast_block = block(backcast) - - # add for interpretation - full = torch.cat([backcast_block.detach(), forecast_block.detach()], dim=1) - if isinstance(block, NBEATSTrendBlock): - trend_forecast.append(full) - elif isinstance(block, NBEATSSeasonalBlock): - seasonal_forecast.append(full) - else: - generic_forecast.append(full) - - # update backcast and forecast - backcast = ( - backcast - backcast_block - ) # do not use backcast -= backcast_block as this signifies an inline operation # noqa : E501 - forecast = forecast + forecast_block - - return self.to_network_output( - prediction=self.transform_output(forecast, target_scale=x["target_scale"]), - backcast=self.transform_output( - prediction=target - backcast, target_scale=x["target_scale"] - ), - trend=self.transform_output( - torch.stack(trend_forecast, dim=0).sum(0), - target_scale=x["target_scale"], - ), - seasonality=self.transform_output( - torch.stack(seasonal_forecast, dim=0).sum(0), - target_scale=x["target_scale"], - ), - generic=self.transform_output( - torch.stack(generic_forecast, dim=0).sum(0), - target_scale=x["target_scale"], - ), - ) - - def update_kan_grid(self): - """ - Updates grid of KAN layers when using KAN layers in NBEATSBlock. - """ - if self.use_kan: - for block in self.net_blocks: - # updation logic taken from - # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682 - for i, layer in enumerate(block.fc): - # update basis KAN layers' grid - layer.update_grid_from_samples(block.outputs[i]) - # update theta backward and theta forward KAN layers' grid - block.theta_b_fc.update_grid_from_samples(block.outputs[i + 1]) - block.theta_f_fc.update_grid_from_samples(block.outputs[i + 1]) - - @classmethod - def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): - """ - Convenience function to create network from :py:class - `~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. - - Args: - dataset (TimeSeriesDataSet): dataset where sole predictor is the target. - **kwargs: additional arguments to be passed to ``__init__`` method. 
- - Returns: - NBeats - """ # noqa: E501 - new_kwargs = { - "prediction_length": dataset.max_prediction_length, - "context_length": dataset.max_encoder_length, - } - new_kwargs.update(kwargs) - - # validate arguments - assert isinstance( - dataset.target, str - ), "only one target is allowed (passed as string to dataset)" - assert not isinstance( - dataset.target_normalizer, NaNLabelEncoder - ), "only regression tasks are supported - target must not be categorical" - assert dataset.min_encoder_length == dataset.max_encoder_length, ( - "only fixed encoder length is allowed," - " but min_encoder_length != max_encoder_length" - ) - - assert dataset.max_prediction_length == dataset.min_prediction_length, ( - "only fixed prediction length is allowed," - " but max_prediction_length != min_prediction_length" - ) - - assert ( - dataset.randomize_length is None - ), "length has to be fixed, but randomize_length is not None" - assert ( - not dataset.add_relative_time_idx - ), "add_relative_time_idx has to be False" - - assert ( - len(dataset.flat_categoricals) == 0 - and len(dataset.reals) == 1 - and len(dataset._time_varying_unknown_reals) == 1 - and dataset._time_varying_unknown_reals[0] == dataset.target - ), ( - "The only variable as input should be the" - " target which is part of time_varying_unknown_reals" - ) - - # initialize class - return super().from_dataset(dataset, **new_kwargs) - - def step(self, x, y, batch_idx) -> Dict[str, torch.Tensor]: - """ - Take training / validation step. - """ - log, out = super().step(x, y, batch_idx=batch_idx) - - if ( - self.hparams.backcast_loss_ratio > 0 and not self.predicting - ): # add loss from backcast - backcast = out["backcast"] - backcast_weight = ( - self.hparams.backcast_loss_ratio - * self.hparams.prediction_length - / self.hparams.context_length - ) - backcast_weight = backcast_weight / (backcast_weight + 1) # normalize - forecast_weight = 1 - backcast_weight - if isinstance(self.loss, MASE): - backcast_loss = ( - self.loss(backcast, x["encoder_target"], x["decoder_target"]) - * backcast_weight - ) - else: - backcast_loss = ( - self.loss(backcast, x["encoder_target"]) * backcast_weight - ) - label = ["val", "train"][self.training] - self.log( - f"{label}_backcast_loss", - backcast_loss, - on_epoch=True, - on_step=self.training, - batch_size=len(x["decoder_target"]), - ) - self.log( - f"{label}_forecast_loss", - log["loss"], - on_epoch=True, - on_step=self.training, - batch_size=len(x["decoder_target"]), - ) - log["loss"] = log["loss"] * forecast_weight + backcast_loss - - self.log_interpretation(x, out, batch_idx=batch_idx) - return log, out - - def log_interpretation(self, x, out, batch_idx): - """ - Log interpretation of network predictions in tensorboard. 
- """ - mpl_available = _check_matplotlib("log_interpretation", raise_error=False) - - # Don't log figures if matplotlib or add_figure is not available - if not mpl_available or not self._logger_supports("add_figure"): - return None - - label = ["val", "train"][self.training] - if self.log_interval > 0 and batch_idx % self.log_interval == 0: - fig = self.plot_interpretation(x, out, idx=0) - name = f"{label.capitalize()} interpretation of item 0 in " - if self.training: - name += f"step {self.global_step}" - else: - name += f"batch {batch_idx}" - self.logger.experiment.add_figure(name, fig, global_step=self.global_step) - - def plot_interpretation( - self, - x: Dict[str, torch.Tensor], - output: Dict[str, torch.Tensor], - idx: int, - ax=None, - plot_seasonality_and_generic_on_secondary_axis: bool = False, - ): - """ - Plot interpretation. - - Plot two pannels: prediction and backcast vs actuals and - decomposition of prediction into trend, seasonality and generic forecast. - - Args: - x (Dict[str, torch.Tensor]): network input - output (Dict[str, torch.Tensor]): network output - idx (int): index of sample for which to plot the interpretation. - ax (List[matplotlib axes], optional): list of two matplotlib axes onto which - to plot the interpretation. Defaults to None. - plot_seasonality_and_generic_on_secondary_axis (bool, optional): if to plot - seasonality and generic forecast on secondary axis in second panel. - Defaults to False. - - Returns: - plt.Figure: matplotlib figure - """ # noqa: E501 - _check_matplotlib("plot_interpretation") - - import matplotlib.pyplot as plt - - if ax is None: - fig, ax = plt.subplots(2, 1, figsize=(6, 8)) - else: - fig = ax[0].get_figure() - - time = torch.arange( - -self.hparams.context_length, self.hparams.prediction_length - ) - - # plot target vs prediction - ax[0].plot( - time, - torch.cat([x["encoder_target"][idx], x["decoder_target"][idx]]) - .detach() - .cpu(), - label="target", - ) - ax[0].plot( - time, - torch.cat( - [ - output["backcast"][idx].detach(), - output["prediction"][idx].detach(), - ], - dim=0, - ).cpu(), - label="prediction", - ) - ax[0].set_xlabel("Time") - - # plot blocks - prop_cycle = iter(plt.rcParams["axes.prop_cycle"]) - next(prop_cycle) # prediction - next(prop_cycle) # observations - if plot_seasonality_and_generic_on_secondary_axis: - ax2 = ax[1].twinx() - ax2.set_ylabel("Seasonality / Generic") - else: - ax2 = ax[1] - for title in ["trend", "seasonality", "generic"]: - if title not in self.hparams.stack_types: - continue - if title == "trend": - ax[1].plot( - time, - output[title][idx].detach().cpu(), - label=title.capitalize(), - c=next(prop_cycle)["color"], - ) - else: - ax2.plot( - time, - output[title][idx].detach().cpu(), - label=title.capitalize(), - c=next(prop_cycle)["color"], - ) - ax[1].set_xlabel("Time") - ax[1].set_ylabel("Decomposition") - - fig.legend() - return fig diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan.py b/pytorch_forecasting/models/nbeats/_nbeatskan.py new file mode 100644 index 000000000..ea9591645 --- /dev/null +++ b/pytorch_forecasting/models/nbeats/_nbeatskan.py @@ -0,0 +1,235 @@ +""" +N-Beats model with KAN blocks for timeseries forecasting without covariates. 
+""" + +from typing import List, Optional + +import torch +from torch import nn + +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric +from pytorch_forecasting.models.nbeats.nbeats_adapter import NBeatsAdapter +from pytorch_forecasting.models.nbeats.sub_modules import ( + NBEATSGenericBlock, + NBEATSSeasonalBlock, + NBEATSTrendBlock, +) + + +class NBeatsKAN(NBeatsAdapter): + def __init__( + self, + stack_types: Optional[List[str]] = None, + num_blocks: Optional[List[int]] = None, + num_block_layers: Optional[List[int]] = None, + widths: Optional[List[int]] = None, + sharing: Optional[List[bool]] = None, + expansion_coefficient_lengths: Optional[List[int]] = None, + prediction_length: int = 1, + context_length: int = 1, + dropout: float = 0.1, + learning_rate: float = 1e-2, + log_interval: int = -1, + log_gradient_flow: bool = False, + log_val_interval: int = None, + weight_decay: float = 1e-3, + loss: MultiHorizonMetric = None, + reduce_on_plateau_patience: int = 1000, + backcast_loss_ratio: float = 0.0, + logging_metrics: nn.ModuleList = None, + num: int = 5, + k: int = 3, + noise_scale: float = 0.5, + scale_base_mu: float = 0.0, + scale_base_sigma: float = 1.0, + scale_sp: float = 1.0, + base_fun: callable = None, + grid_eps: float = 0.02, + grid_range: List[int] = None, + sp_trainable: bool = True, + sb_trainable: bool = True, + sparse_init: bool = False, + **kwargs, + ): + """ + Initialize NBeats Model - use its :py:meth:`~from_dataset` method if possible. + + Based on the article + `N-BEATS: Neural basis expansion analysis for interpretable time series + forecasting `_. The network has (if + used as ensemble) outperformed all other methods including ensembles of + traditional statical methods in the M4 competition. The M4 competition is + arguably the most important benchmark for univariate time series forecasting. + + The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently + shown to consistently outperform N-BEATS. + + Args: + stack_types: One of the following values: “generic”, “seasonality" or + “trend". A list of strings of length 1 or 'num_stacks'. Default and + recommended value for generic mode: [“generic”] Recommended value for + interpretable mode: [“trend”,”seasonality”]. + num_blocks: The number of blocks per stack. A list of ints of length 1 or + 'num_stacks'. Default and recommended value for generic mode: [1] + Recommended value for interpretable mode: [3] + num_block_layers: Number of fully connected layers with ReLu activation per + block. + A list of ints of length 1 or 'num_stacks'. Default and recommended + value for generic mode: [4] Recommended value for interpretable mode: + [4]. + width: Widths of the fully connected layers with ReLu activation in the + blocks. A list of ints of length 1 or 'num_stacks'. Default and + recommended value for generic mode: [512]. Recommended value for + interpretable mode: [256, 2048] + sharing: Whether the weights are shared with the other blocks per stack. + A list of ints of length 1 or 'num_stacks'. Default and recommended + value for generic mode: [False]. Recommended value for interpretable + mode: [True]. + expansion_coefficient_length: If the type is “G” (generic), then the length + of the expansion coefficient. + If type is “T” (trend), then it corresponds to the degree of the + polynomial. + If the type is “S” (seasonal) then this is the minimum period allowed, + e.g. 2 for changes every timestep. A list of ints of length 1 or + 'num_stacks'. 
Default value for generic mode: [32] Recommended value for
+                interpretable mode: [3]
+            prediction_length: Length of the prediction. Also known as 'horizon'.
+            context_length: Number of time units that condition the predictions.
+                Also known as 'lookback period'.
+                Should be between 1-10 times the prediction length.
+            backcast_loss_ratio: weight of backcast in comparison to forecast when
+                calculating the loss. A weight of 1.0 means that forecast and
+                backcast loss is weighted the same (regardless of backcast and forecast
+                lengths). Defaults to 0.0, i.e. no weight.
+            loss: loss to optimize. Defaults to MASE().
+            log_gradient_flow: if to log gradient flow, this takes time and should be
+                only done to diagnose training failures.
+            reduce_on_plateau_patience (int): patience after which learning rate is
+                reduced by a factor of 10
+            logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that
+                are logged during training. Defaults to
+                nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
+            num : Parameter for KAN layer. The number of grid intervals = G.
+                Default: 5.
+            k : Parameter for KAN layer. The order of the piecewise polynomial.
+                Default: 3.
+            noise_scale : Parameter for KAN layer. The scale of noise injected at
+                initialization. Default: 0.5.
+            scale_base_mu : Parameter for KAN layer. The scale of the residual
+                function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
+                Default: 0.0.
+            scale_base_sigma : Parameter for KAN layer. The scale of the residual
+                function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
+                Default: 1.0.
+            scale_sp : Parameter for KAN layer. The scale of the base function
+                spline(x). Default: 1.0.
+            base_fun : Parameter for KAN layer. Residual function b(x).
+                Default: None (falls back to torch.nn.SiLU()).
+            grid_eps : Parameter for KAN layer. When grid_eps = 1, the grid is uniform;
+                when grid_eps = 0, the grid is partitioned using percentiles of samples.
+                0 < grid_eps < 1 interpolates between the two extremes. Default: 0.02.
+            grid_range : Parameter for KAN layer. list/np.array of shape (2,) setting
+                the range of grids. Default: None (falls back to [-1, 1]).
+            sp_trainable : Parameter for KAN layer. If true, scale_sp is trainable.
+                Default: True.
+            sb_trainable : Parameter for KAN layer. If true, scale_base is trainable.
+                Default: True.
+            sparse_init : Parameter for KAN layer. If sparse_init = True, sparse
+                initialization is applied. Default: False.
+            **kwargs: additional arguments to :py:class:`~BaseModel`.
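+
+        Example:
+            Minimal construction sketch (argument values are illustrative
+            only; prefer :py:meth:`~from_dataset` when a dataset is
+            available). Grids of the KAN layers can be refreshed during
+            training via ``update_kan_grid``:
+
+            >>> net = NBeatsKAN(prediction_length=10, context_length=30, num=7, k=3)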
+ """ # noqa: E501 + + if base_fun is None: + base_fun = torch.nn.SiLU() + if grid_range is None: + grid_range = [-1, 1] + if expansion_coefficient_lengths is None: + expansion_coefficient_lengths = [3, 7] + if sharing is None: + sharing = [True, True] + if widths is None: + widths = [32, 512] + if num_block_layers is None: + num_block_layers = [3, 3] + if num_blocks is None: + num_blocks = [3, 3] + if stack_types is None: + stack_types = ["trend", "seasonality"] + if logging_metrics is None: + logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) + if loss is None: + loss = MASE() + + self.save_hyperparameters(ignore=["loss", "logging_metrics"]) + super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) + + # Bundle KAN parameters into a dictionary + kan_params = { + "num": num, + "k": k, + "noise_scale": noise_scale, + "scale_base_mu": scale_base_mu, + "scale_base_sigma": scale_base_sigma, + "scale_sp": scale_sp, + "base_fun": base_fun, + "grid_eps": grid_eps, + "grid_range": grid_range, + "sp_trainable": sp_trainable, + "sb_trainable": sb_trainable, + "sparse_init": sparse_init, + } + self.kan_params = kan_params + # setup stacks + self.net_blocks = nn.ModuleList() + for stack_id, stack_type in enumerate(stack_types): + for _ in range(num_blocks[stack_id]): + if stack_type == "generic": + net_block = NBEATSGenericBlock( + units=self.hparams.widths[stack_id], + thetas_dim=self.hparams.expansion_coefficient_lengths[stack_id], + num_block_layers=self.hparams.num_block_layers[stack_id], + backcast_length=context_length, + forecast_length=prediction_length, + dropout=dropout, + kan_params=self.kan_params, + use_kan=True, + ) + elif stack_type == "seasonality": + net_block = NBEATSSeasonalBlock( + units=self.hparams.widths[stack_id], + num_block_layers=self.hparams.num_block_layers[stack_id], + backcast_length=context_length, + forecast_length=prediction_length, + min_period=expansion_coefficient_lengths[stack_id], + dropout=dropout, + kan_params=self.kan_params, + use_kan=True, + ) + elif stack_type == "trend": + net_block = NBEATSTrendBlock( + units=self.hparams.widths[stack_id], + thetas_dim=self.hparams.expansion_coefficient_lengths[stack_id], + num_block_layers=self.hparams.num_block_layers[stack_id], + backcast_length=context_length, + forecast_length=prediction_length, + dropout=dropout, + kan_params=self.kan_params, + use_kan=True, + ) + else: + raise ValueError(f"Unknown stack type {stack_type}") + + self.net_blocks.append(net_block) + + def update_kan_grid(self): + """ + Updates grid of KAN layers when using KAN layers in NBEATSBlock. + """ + for block in self.net_blocks: + # updation logic taken from + # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682 + for i, layer in enumerate(block.fc): + # update basis KAN layers' grid + layer.update_grid_from_samples(block.outputs[i]) + # update theta backward and theta forward KAN layers' grid + block.theta_b_fc.update_grid_from_samples(block.outputs[i + 1]) + block.theta_f_fc.update_grid_from_samples(block.outputs[i + 1]) diff --git a/pytorch_forecasting/models/nbeats/nbeats_adapter.py b/pytorch_forecasting/models/nbeats/nbeats_adapter.py new file mode 100644 index 000000000..d08d4c5ca --- /dev/null +++ b/pytorch_forecasting/models/nbeats/nbeats_adapter.py @@ -0,0 +1,322 @@ +""" +N-Beats model adapter for timeseries forecasting without covariates. 
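+
+The adapter bundles the forward pass, dataset validation and interpretation
+plotting shared by the N-Beats model variants.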
+""" + +from typing import Dict, List, Optional + +import torch +from torch import nn + +from pytorch_forecasting.data import TimeSeriesDataSet +from pytorch_forecasting.data.encoders import NaNLabelEncoder +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric +from pytorch_forecasting.models.base_model import BaseModel +from pytorch_forecasting.models.nbeats.sub_modules import ( + NBEATSGenericBlock, + NBEATSSeasonalBlock, + NBEATSTrendBlock, +) +from pytorch_forecasting.utils._dependencies import _check_matplotlib + + +class NBeatsAdapter(BaseModel): + def __init__( + self, + **kwargs, + ): + """ + Initialize NBeats Adapter. + + Args: + **kwargs: additional arguments to :py:class:`~BaseModel`. + """ # noqa: E501 + super().__init__(**kwargs) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Pass forward of network. + + Args: + x (Dict[str, torch.Tensor]): input from dataloader generated from + :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. + + Returns: + Dict[str, torch.Tensor]: output of model + """ + target = x["encoder_cont"][..., 0] + + timesteps = self.hparams.context_length + self.hparams.prediction_length + generic_forecast = [ + torch.zeros( + (target.size(0), timesteps), dtype=torch.float32, device=self.device + ) + ] + trend_forecast = [ + torch.zeros( + (target.size(0), timesteps), dtype=torch.float32, device=self.device + ) + ] + seasonal_forecast = [ + torch.zeros( + (target.size(0), timesteps), dtype=torch.float32, device=self.device + ) + ] + forecast = torch.zeros( + (target.size(0), self.hparams.prediction_length), + dtype=torch.float32, + device=self.device, + ) + + backcast = target # initialize backcast + for i, block in enumerate(self.net_blocks): + # evaluate block + backcast_block, forecast_block = block(backcast) + + # add for interpretation + full = torch.cat([backcast_block.detach(), forecast_block.detach()], dim=1) + if isinstance(block, NBEATSTrendBlock): + trend_forecast.append(full) + elif isinstance(block, NBEATSSeasonalBlock): + seasonal_forecast.append(full) + else: + generic_forecast.append(full) + + # update backcast and forecast + backcast = ( + backcast - backcast_block + ) # do not use backcast -= backcast_block as this signifies an inline operation # noqa : E501 + forecast = forecast + forecast_block + + return self.to_network_output( + prediction=self.transform_output(forecast, target_scale=x["target_scale"]), + backcast=self.transform_output( + prediction=target - backcast, target_scale=x["target_scale"] + ), + trend=self.transform_output( + torch.stack(trend_forecast, dim=0).sum(0), + target_scale=x["target_scale"], + ), + seasonality=self.transform_output( + torch.stack(seasonal_forecast, dim=0).sum(0), + target_scale=x["target_scale"], + ), + generic=self.transform_output( + torch.stack(generic_forecast, dim=0).sum(0), + target_scale=x["target_scale"], + ), + ) + + @classmethod + def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): + """ + Convenience function to create network from :py:class + `~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. + + Args: + dataset (TimeSeriesDataSet): dataset where sole predictor is the target. + **kwargs: additional arguments to be passed to ``__init__`` method. 
+ + Returns: + NBeats + """ # noqa: E501 + new_kwargs = { + "prediction_length": dataset.max_prediction_length, + "context_length": dataset.max_encoder_length, + } + new_kwargs.update(kwargs) + + # validate arguments + assert isinstance( + dataset.target, str + ), "only one target is allowed (passed as string to dataset)" + assert not isinstance( + dataset.target_normalizer, NaNLabelEncoder + ), "only regression tasks are supported - target must not be categorical" + assert dataset.min_encoder_length == dataset.max_encoder_length, ( + "only fixed encoder length is allowed," + " but min_encoder_length != max_encoder_length" + ) + + assert dataset.max_prediction_length == dataset.min_prediction_length, ( + "only fixed prediction length is allowed," + " but max_prediction_length != min_prediction_length" + ) + + assert ( + dataset.randomize_length is None + ), "length has to be fixed, but randomize_length is not None" + assert ( + not dataset.add_relative_time_idx + ), "add_relative_time_idx has to be False" + + assert ( + len(dataset.flat_categoricals) == 0 + and len(dataset.reals) == 1 + and len(dataset._time_varying_unknown_reals) == 1 + and dataset._time_varying_unknown_reals[0] == dataset.target + ), ( + "The only variable as input should be the" + " target which is part of time_varying_unknown_reals" + ) + + # initialize class + return super().from_dataset(dataset, **new_kwargs) + + def step(self, x, y, batch_idx) -> Dict[str, torch.Tensor]: + """ + Take training / validation step. + """ + log, out = super().step(x, y, batch_idx=batch_idx) + + if ( + self.hparams.backcast_loss_ratio > 0 and not self.predicting + ): # add loss from backcast + backcast = out["backcast"] + backcast_weight = ( + self.hparams.backcast_loss_ratio + * self.hparams.prediction_length + / self.hparams.context_length + ) + backcast_weight = backcast_weight / (backcast_weight + 1) # normalize + forecast_weight = 1 - backcast_weight + if isinstance(self.loss, MASE): + backcast_loss = ( + self.loss(backcast, x["encoder_target"], x["decoder_target"]) + * backcast_weight + ) + else: + backcast_loss = ( + self.loss(backcast, x["encoder_target"]) * backcast_weight + ) + label = ["val", "train"][self.training] + self.log( + f"{label}_backcast_loss", + backcast_loss, + on_epoch=True, + on_step=self.training, + batch_size=len(x["decoder_target"]), + ) + self.log( + f"{label}_forecast_loss", + log["loss"], + on_epoch=True, + on_step=self.training, + batch_size=len(x["decoder_target"]), + ) + log["loss"] = log["loss"] * forecast_weight + backcast_loss + + self.log_interpretation(x, out, batch_idx=batch_idx) + return log, out + + def log_interpretation(self, x, out, batch_idx): + """ + Log interpretation of network predictions in tensorboard. 
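+
+        Figures are only logged when matplotlib is available, the logger
+        supports ``add_figure`` and ``log_interval`` is positive.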
+ """ + mpl_available = _check_matplotlib("log_interpretation", raise_error=False) + + # Don't log figures if matplotlib or add_figure is not available + if not mpl_available or not self._logger_supports("add_figure"): + return None + + label = ["val", "train"][self.training] + if self.log_interval > 0 and batch_idx % self.log_interval == 0: + fig = self.plot_interpretation(x, out, idx=0) + name = f"{label.capitalize()} interpretation of item 0 in " + if self.training: + name += f"step {self.global_step}" + else: + name += f"batch {batch_idx}" + self.logger.experiment.add_figure(name, fig, global_step=self.global_step) + + def plot_interpretation( + self, + x: Dict[str, torch.Tensor], + output: Dict[str, torch.Tensor], + idx: int, + ax=None, + plot_seasonality_and_generic_on_secondary_axis: bool = False, + ): + """ + Plot interpretation. + + Plot two pannels: prediction and backcast vs actuals and + decomposition of prediction into trend, seasonality and generic forecast. + + Args: + x (Dict[str, torch.Tensor]): network input + output (Dict[str, torch.Tensor]): network output + idx (int): index of sample for which to plot the interpretation. + ax (List[matplotlib axes], optional): list of two matplotlib axes onto which + to plot the interpretation. Defaults to None. + plot_seasonality_and_generic_on_secondary_axis (bool, optional): if to plot + seasonality and generic forecast on secondary axis in second panel. + Defaults to False. + + Returns: + plt.Figure: matplotlib figure + """ # noqa: E501 + _check_matplotlib("plot_interpretation") + + import matplotlib.pyplot as plt + + if ax is None: + fig, ax = plt.subplots(2, 1, figsize=(6, 8)) + else: + fig = ax[0].get_figure() + + time = torch.arange( + -self.hparams.context_length, self.hparams.prediction_length + ) + + # plot target vs prediction + ax[0].plot( + time, + torch.cat([x["encoder_target"][idx], x["decoder_target"][idx]]) + .detach() + .cpu(), + label="target", + ) + ax[0].plot( + time, + torch.cat( + [ + output["backcast"][idx].detach(), + output["prediction"][idx].detach(), + ], + dim=0, + ).cpu(), + label="prediction", + ) + ax[0].set_xlabel("Time") + + # plot blocks + prop_cycle = iter(plt.rcParams["axes.prop_cycle"]) + next(prop_cycle) # prediction + next(prop_cycle) # observations + if plot_seasonality_and_generic_on_secondary_axis: + ax2 = ax[1].twinx() + ax2.set_ylabel("Seasonality / Generic") + else: + ax2 = ax[1] + for title in ["trend", "seasonality", "generic"]: + if title not in self.hparams.stack_types: + continue + if title == "trend": + ax[1].plot( + time, + output[title][idx].detach().cpu(), + label=title.capitalize(), + c=next(prop_cycle)["color"], + ) + else: + ax2.plot( + time, + output[title][idx].detach().cpu(), + label=title.capitalize(), + c=next(prop_cycle)["color"], + ) + ax[1].set_xlabel("Time") + ax[1].set_ylabel("Decomposition") + + fig.legend() + return fig diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py index 7ddf17a20..492017e5b 100644 --- a/pytorch_forecasting/models/nbeats/sub_modules.py +++ b/pytorch_forecasting/models/nbeats/sub_modules.py @@ -145,17 +145,20 @@ def forward(self, x): """ Pass through the fully connected mlp/kan layers and returns the output. 
""" - # outputs logic taken from - # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682 - self.outputs = [] - self.outputs.append(x.clone().detach()) - for layer in self.fc: - x = layer(x) # Pass data through the current layer - # storing outputs for updating grids of self.fc when using KAN + if self.use_kan: + # save outputs to be used in updating grid in kan layers during training + # outputs logic taken from + # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682 + self.outputs = [] + self.outputs.append(x.clone().detach()) + for layer in self.fc: + x = layer(x) # Pass data through the current layer + # storing outputs for updating grids of self.fc when using KAN + self.outputs.append(x.clone().detach()) + # storing for updating grids of theta_b_fc and theta_f_fc when using KAN self.outputs.append(x.clone().detach()) - # storing for updating grids of theta_b_fc and theta_f_fc when using KAN - self.outputs.append(x.clone().detach()) - return x # Return final output + return x # Return final output + return self.fc(x) class NBEATSSeasonalBlock(NBEATSBlock): From 7070f8b429854ef040b82b5cf4659f489738dcaa Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Sat, 22 Feb 2025 23:13:35 -0800 Subject: [PATCH 08/21] Made modules private. --- examples/nbeats_with_kan.py | 2 +- pytorch_forecasting/models/nbeats/__init__.py | 4 ++++ .../models/nbeats/{grid_callback.py => _grid_callback.py} | 0 .../models/nbeats/{kan_layer.py => _kan_layer.py} | 0 pytorch_forecasting/models/nbeats/_nbeats.py | 2 +- .../models/nbeats/{nbeats_adapter.py => _nbeats_adapter.py} | 0 pytorch_forecasting/models/nbeats/_nbeatskan.py | 2 +- pytorch_forecasting/models/nbeats/sub_modules.py | 2 +- 8 files changed, 8 insertions(+), 4 deletions(-) rename pytorch_forecasting/models/nbeats/{grid_callback.py => _grid_callback.py} (100%) rename pytorch_forecasting/models/nbeats/{kan_layer.py => _kan_layer.py} (100%) rename pytorch_forecasting/models/nbeats/{nbeats_adapter.py => _nbeats_adapter.py} (100%) diff --git a/examples/nbeats_with_kan.py b/examples/nbeats_with_kan.py index 952a2acce..6a018ce5d 100644 --- a/examples/nbeats_with_kan.py +++ b/examples/nbeats_with_kan.py @@ -7,7 +7,7 @@ from pytorch_forecasting import NBeatsKAN, TimeSeriesDataSet from pytorch_forecasting.data import NaNLabelEncoder from pytorch_forecasting.data.examples import generate_ar_data -from pytorch_forecasting.models.nbeats.grid_callback import GridUpdateCallback +from pytorch_forecasting.models.nbeats import GridUpdateCallback sys.path.append("..") diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py index 87c1fe7fb..b588093af 100644 --- a/pytorch_forecasting/models/nbeats/__init__.py +++ b/pytorch_forecasting/models/nbeats/__init__.py @@ -1,6 +1,8 @@ """N-Beats model for timeseries forecasting without covariates.""" +from pytorch_forecasting.models.nbeats._grid_callback import GridUpdateCallback from pytorch_forecasting.models.nbeats._nbeats import NBeats +from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter from pytorch_forecasting.models.nbeats._nbeatskan import NBeatsKAN from pytorch_forecasting.models.nbeats.sub_modules import ( NBEATSGenericBlock, @@ -14,4 +16,6 @@ "NBEATSGenericBlock", "NBEATSSeasonalBlock", "NBEATSTrendBlock", + "NBeatsAdapter", + "GridUpdateCallback", ] diff --git a/pytorch_forecasting/models/nbeats/grid_callback.py b/pytorch_forecasting/models/nbeats/_grid_callback.py similarity index 100% rename from 
pytorch_forecasting/models/nbeats/grid_callback.py rename to pytorch_forecasting/models/nbeats/_grid_callback.py diff --git a/pytorch_forecasting/models/nbeats/kan_layer.py b/pytorch_forecasting/models/nbeats/_kan_layer.py similarity index 100% rename from pytorch_forecasting/models/nbeats/kan_layer.py rename to pytorch_forecasting/models/nbeats/_kan_layer.py diff --git a/pytorch_forecasting/models/nbeats/_nbeats.py b/pytorch_forecasting/models/nbeats/_nbeats.py index f85067e22..3326bb5a9 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats.py +++ b/pytorch_forecasting/models/nbeats/_nbeats.py @@ -7,7 +7,7 @@ from torch import nn from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric -from pytorch_forecasting.models.nbeats.nbeats_adapter import NBeatsAdapter +from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter from pytorch_forecasting.models.nbeats.sub_modules import ( NBEATSGenericBlock, NBEATSSeasonalBlock, diff --git a/pytorch_forecasting/models/nbeats/nbeats_adapter.py b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py similarity index 100% rename from pytorch_forecasting/models/nbeats/nbeats_adapter.py rename to pytorch_forecasting/models/nbeats/_nbeats_adapter.py diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan.py b/pytorch_forecasting/models/nbeats/_nbeatskan.py index ea9591645..9df6b3d2e 100644 --- a/pytorch_forecasting/models/nbeats/_nbeatskan.py +++ b/pytorch_forecasting/models/nbeats/_nbeatskan.py @@ -8,7 +8,7 @@ from torch import nn from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric -from pytorch_forecasting.models.nbeats.nbeats_adapter import NBeatsAdapter +from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter from pytorch_forecasting.models.nbeats.sub_modules import ( NBEATSGenericBlock, NBEATSSeasonalBlock, diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py index 492017e5b..e1ea1288f 100644 --- a/pytorch_forecasting/models/nbeats/sub_modules.py +++ b/pytorch_forecasting/models/nbeats/sub_modules.py @@ -9,7 +9,7 @@ import torch.nn as nn import torch.nn.functional as F -from pytorch_forecasting.models.nbeats.kan_layer import KANLayer +from pytorch_forecasting.models.nbeats._kan_layer import KANLayer def linear(input_size, output_size, bias=True, dropout: int = None): From 14ca66f9a8c823fa3bb1e5d71d4c49ea0ba53205 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Sat, 5 Jul 2025 02:12:32 -0700 Subject: [PATCH 09/21] Address deprecated typing classes --- pytorch_forecasting/models/nbeats/_kan_layer.py | 10 +++------- .../models/nbeats/_nbeats_adapter.py | 10 +++++----- pytorch_forecasting/models/nbeats/_nbeatskan.py | 16 ++++++++-------- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/pytorch_forecasting/models/nbeats/_kan_layer.py b/pytorch_forecasting/models/nbeats/_kan_layer.py index 1f7a18a1c..0d1703e44 100644 --- a/pytorch_forecasting/models/nbeats/_kan_layer.py +++ b/pytorch_forecasting/models/nbeats/_kan_layer.py @@ -48,9 +48,7 @@ def B_batch(x, grid, k=0, extend=True): grid[:, :, k:-1] - grid[:, :, : -(k + 1)] ) * B_km1[:, :, :-1] + (grid[:, :, k + 1 :] - x) / ( grid[:, :, k + 1 :] - grid[:, :, 1:(-k)] - ) * B_km1[ - :, :, 1: - ] + ) * B_km1[:, :, 1:] # in case grid is degenerate value = torch.nan_to_num(value) @@ -225,7 +223,7 @@ def __init__( >>> model = KANLayer(in_dim=3, out_dim=5) >>> (model.in_dim, model.out_dim) """ - super(KANLayer, self).__init__() + 
super().__init__() # size self.out_dim = out_dim self.in_dim = in_dim @@ -265,9 +263,7 @@ def __init__( ).requires_grad_(sb_trainable) self.scale_sp = torch.nn.Parameter( torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask - ).requires_grad_( - sp_trainable - ) # make scale trainable + ).requires_grad_(sp_trainable) # make scale trainable self.base_fun = base_fun self.grid_eps = grid_eps diff --git a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py index d08d4c5ca..1d99b4a6b 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py +++ b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py @@ -2,7 +2,7 @@ N-Beats model adapter for timeseries forecasting without covariates. """ -from typing import Dict, List, Optional +from typing import Optional import torch from torch import nn @@ -32,7 +32,7 @@ def __init__( """ # noqa: E501 super().__init__(**kwargs) - def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Pass forward of network. @@ -162,7 +162,7 @@ def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): # initialize class return super().from_dataset(dataset, **new_kwargs) - def step(self, x, y, batch_idx) -> Dict[str, torch.Tensor]: + def step(self, x, y, batch_idx) -> dict[str, torch.Tensor]: """ Take training / validation step. """ @@ -230,8 +230,8 @@ def log_interpretation(self, x, out, batch_idx): def plot_interpretation( self, - x: Dict[str, torch.Tensor], - output: Dict[str, torch.Tensor], + x: dict[str, torch.Tensor], + output: dict[str, torch.Tensor], idx: int, ax=None, plot_seasonality_and_generic_on_secondary_axis: bool = False, diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan.py b/pytorch_forecasting/models/nbeats/_nbeatskan.py index 9df6b3d2e..855755cec 100644 --- a/pytorch_forecasting/models/nbeats/_nbeatskan.py +++ b/pytorch_forecasting/models/nbeats/_nbeatskan.py @@ -2,7 +2,7 @@ N-Beats model with KAN blocks for timeseries forecasting without covariates. 
""" -from typing import List, Optional +from typing import Optional import torch from torch import nn @@ -19,12 +19,12 @@ class NBeatsKAN(NBeatsAdapter): def __init__( self, - stack_types: Optional[List[str]] = None, - num_blocks: Optional[List[int]] = None, - num_block_layers: Optional[List[int]] = None, - widths: Optional[List[int]] = None, - sharing: Optional[List[bool]] = None, - expansion_coefficient_lengths: Optional[List[int]] = None, + stack_types: Optional[list[str]] = None, + num_blocks: Optional[list[int]] = None, + num_block_layers: Optional[list[int]] = None, + widths: Optional[list[int]] = None, + sharing: Optional[list[bool]] = None, + expansion_coefficient_lengths: Optional[list[int]] = None, prediction_length: int = 1, context_length: int = 1, dropout: float = 0.1, @@ -45,7 +45,7 @@ def __init__( scale_sp: float = 1.0, base_fun: callable = None, grid_eps: float = 0.02, - grid_range: List[int] = None, + grid_range: list[int] = None, sp_trainable: bool = True, sb_trainable: bool = True, sparse_init: bool = False, From 89a9a4f8b3157e88ebd8c3265432cc8696c7ea12 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Sat, 5 Jul 2025 05:25:02 -0700 Subject: [PATCH 10/21] Refactor code with proper docstrings and cleaner structure --- .../models/nbeats/_kan_layer.py | 118 ++++++++++++------ .../models/nbeats/_nbeats_adapter.py | 6 +- .../models/nbeats/sub_modules.py | 55 +++++--- 3 files changed, 119 insertions(+), 60 deletions(-) diff --git a/pytorch_forecasting/models/nbeats/_kan_layer.py b/pytorch_forecasting/models/nbeats/_kan_layer.py index 0d1703e44..a33f43e67 100644 --- a/pytorch_forecasting/models/nbeats/_kan_layer.py +++ b/pytorch_forecasting/models/nbeats/_kan_layer.py @@ -6,9 +6,9 @@ import torch.nn as nn -def B_batch(x, grid, k=0, extend=True): +def b_batch(x, grid, k=0): """ - evaludate x on B-spline bases + evaluate x on B-spline bases Args: ----- @@ -30,10 +30,16 @@ def B_batch(x, grid, k=0, extend=True): Example ------- + Install the `pykan` package first: + >>> pip install pykan + Then use: + >>> from kan.spline import B_batch - >>> x = torch.rand(100,2) - >>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11) + >>> import torch + >>> x = torch.rand(100, 2) + >>> grid = torch.linspace(-1, 1, steps=11)[None, :].expand(2, 11) >>> B_batch(x, grid, k=3).shape + """ x = x.unsqueeze(dim=2) @@ -42,7 +48,7 @@ def B_batch(x, grid, k=0, extend=True): if k == 0: value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:]) else: - B_km1 = B_batch(x[:, :, 0], grid=grid[0], k=k - 1) + B_km1 = b_batch(x[:, :, 0], grid=grid[0], k=k - 1) value = (x - grid[:, :, : -(k + 1)]) / ( grid[:, :, k:-1] - grid[:, :, : -(k + 1)] @@ -58,7 +64,7 @@ def B_batch(x, grid, k=0, extend=True): def coef2curve(x_eval, grid, coef, k): """ converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves - (summing up B_batch results over B-spline basis). + (summing up b_batch results over B-spline basis). 
Args: ----- @@ -78,7 +84,7 @@ def coef2curve(x_eval, grid, coef, k): """ - b_splines = B_batch(x_eval, grid, k=k) + b_splines = b_batch(x_eval, grid, k=k) y_eval = torch.einsum("ijk,jlk->ijl", b_splines, coef.to(b_splines)) return y_eval @@ -110,7 +116,7 @@ def curve2coef(x_eval, y_eval, grid, k): in_dim = x_eval.shape[1] out_dim = y_eval.shape[2] n_coef = grid.shape[1] - k - 1 - mat = B_batch(x_eval, grid, k) + mat = b_batch(x_eval, grid, k) mat = mat.permute(1, 0, 2)[:, None, :, :].expand(in_dim, out_dim, batch, n_coef) y_eval = y_eval.permute(1, 2, 0).unsqueeze(dim=3) try: @@ -123,7 +129,19 @@ def curve2coef(x_eval, y_eval, grid, k): def extend_grid(grid, k_extend=0): """ - extend grid + Extend a grid tensor by padding both ends with equal spacing. + + Args: + ----- + grid : torch.Tensor + Grid of shape (in_dim, grid_points). + k_extend : int + Number of points to extend on both ends. + + Returns: + -------- + grid : torch.Tensor + Extended grid of shape (in_dim, grid_points + 2 * k_extend). """ h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) @@ -136,7 +154,19 @@ def extend_grid(grid, k_extend=0): def sparse_mask(in_dim, out_dim): """ - get sparse mask + Generate a sparse connection mask between input and output units. + + Args: + ----- + in_dim : int + Number of input units. + out_dim : int + Number of output units. + + Returns: + -------- + mask : torch.Tensor + Sparse binary mask of shape (in_dim, out_dim). """ in_coord = torch.arange(in_dim) * 1 / in_dim + 1 / (2 * in_dim) out_coord = torch.arange(out_dim) * 1 / out_dim + 1 / (2 * out_dim) @@ -168,15 +198,15 @@ def __init__( scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, - base_fun=torch.nn.SiLU(), + base_fun=None, grid_eps=0.02, - grid_range=[-1, 1], + grid_range=None, sp_trainable=True, sb_trainable=True, sparse_init=False, ): """' - initialize a KANLayer + Initialize a KANLayer Args: ----- @@ -199,13 +229,13 @@ def __init__( scale_sp : float the scale of the base function spline(x). base_fun : function - residual function b(x). Default: torch.nn.SiLU() + residual function b(x). Default: None grid_eps : float When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. grid_range : list/np.array of shape (2,) - setting the range of grids. Default: [-1,1]. + setting the range of grids. Default: None. sp_trainable : bool If true, scale_sp is trainable sb_trainable : bool @@ -219,11 +249,21 @@ def __init__( Example ------- + Install the `pykan` package first: + >>> pip install pykan + Then use: + >>> from kan.KANLayer import * >>> model = KANLayer(in_dim=3, out_dim=5) >>> (model.in_dim, model.out_dim) """ super().__init__() + + # Handle mutable parameters + if grid_range is None: + grid_range = [-1, 1] + if base_fun is None: + base_fun = torch.nn.SiLU() # size self.out_dim = out_dim self.in_dim = in_dim @@ -274,23 +314,23 @@ def forward(self, x): Args: ----- - x : 2D torch.float - inputs, shape (number of samples, input dimension) + x : torch.Tensor + Input tensor of shape (batch_size, in_dim), where: + - batch_size is the number of input samples. + - in_dim is the input feature dimension. 
Returns: -------- - y : 2D torch.float - outputs, shape (number of samples, output dimension) - preacts : 3D torch.float - fan out x into activations, shape (number of sampels, - output dimension, input dimension) - postacts : 3D torch.float - the outputs of activation functions with preacts as inputs - postspline : 3D torch.float - the outputs of spline functions with preacts as inputs + y : torch.Tensor + Output tensor, the result of applying spline and residual + transformations followed by weighted summation. Example ------- + Install the `pykan` package first: + >>> pip install pykan + Then use: + >>> from kan.KANLayer import * >>> model = KANLayer(in_dim=3, out_dim=5) >>> x = torch.normal(0,1,size=(100,3)) @@ -308,7 +348,7 @@ def forward(self, x): y = torch.sum(y, dim=1) return y - def update_grid_from_samples(self, x, mode="sample"): + def update_grid_from_samples(self, x): """ update grid from samples @@ -336,25 +376,29 @@ def update_grid_from_samples(self, x, mode="sample"): num_interval = self.grid.shape[1] - 1 - 2 * self.k def get_grid(num_interval): + """ + Generate adaptive or uniform grid points from sorted input samples. + + Args: + ----- + num_interval : int + Number of intervals between grid points. + + Returns: + -------- + grid : torch.Tensor + New grid of shape (in_dim, num_interval + 1). + """ ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] grid_adaptive = x_pos[ids, :].permute(1, 0) - margin = 0.00 - h = ( - grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin - ) / num_interval + h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]]) / num_interval grid_uniform = ( grid_adaptive[:, [0]] - - margin + h * torch.arange(num_interval + 1, device=h.device)[None, :] ) grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive return grid grid = get_grid(num_interval) - if mode == "grid": - sample_grid = get_grid(2 * num_interval) - x_pos = sample_grid.permute(1, 0) - y_eval = coef2curve(x_pos, self.grid, self.coef, self.k) - self.grid.data = extend_grid(grid, k_extend=self.k) self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k) diff --git a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py index 1d99b4a6b..a6ed1218c 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py +++ b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py @@ -5,14 +5,12 @@ from typing import Optional import torch -from torch import nn from pytorch_forecasting.data import TimeSeriesDataSet from pytorch_forecasting.data.encoders import NaNLabelEncoder -from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric +from pytorch_forecasting.metrics import MASE from pytorch_forecasting.models.base_model import BaseModel from pytorch_forecasting.models.nbeats.sub_modules import ( - NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock, ) @@ -239,7 +237,7 @@ def plot_interpretation( """ Plot interpretation. - Plot two pannels: prediction and backcast vs actuals and + Plot two panels: prediction and backcast vs actuals and decomposition of prediction into trend, seasonality and generic forecast. Args: diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py index f08099fa8..e14d847c1 100644 --- a/pytorch_forecasting/models/nbeats/sub_modules.py +++ b/pytorch_forecasting/models/nbeats/sub_modules.py @@ -68,31 +68,48 @@ def __init__( ahead to predict. Default: 5. 
dropout: The dropout rate applied to the fully connected mlp layers
                to prevent overfitting. Default: 0.1.
-            kan_params (dict): Parameters specific to the KAN layer
-                (used for modeling using KAN). Default: None.
+            kan_params (dict): Configuration dictionary for the KAN layer. Only
+                required if `use_kan=True`. If `kan_params` is not provided and
+                `use_kan=True`, default values will be used. Default: None.
                 Contains:
-                num_grids (int): The number of grid intervals for KAN.
-                k (int): The order of the piecewise polynomial for KAN.
-                noise_scale (float): The scale of noise injected at initialization.
-                scale_base_mu (float): The scale of the residual function
-                    initialized to N(scale_base_mu, scale_base_sigma^2).
-                scale_base_sigma (float): The scale of the residual function
-                    initialized to N(scale_base_mu, scale_base_sigma^2).
-                scale_sp (float): The scale of the base function spline(x) in KAN.
-                base_fun (function): The residual function used by
-                    KAN (e.g., torch.nn.SiLU()).
-                grid_eps (float): Determines the partitioning of the grid. If 1,
-                    the grid is uniform; if 0, grid is partitioned by percentiles.
-                grid_range (list or np.array): The range of the grid, given as
-                    a list of two values.
-                sp_trainable (bool): If True, the scale_sp is trainable.
-                sb_trainable (bool): If True, the scale_base is trainable.
-                sparse_init (bool): If True, applies sparse initialization.
+                - num (int): Number of grid intervals. Default: 5.
+                - k (int): Order of the piecewise polynomial. Default: 3.
+                - noise_scale (float): Initialization noise scale. Default: 0.5.
+                - scale_base_mu (float): Mean for residual function init.
+                  Default: 0.0.
+                - scale_base_sigma (float): Std for residual function init.
+                  Default: 1.0.
+                - scale_sp (float): Scale for spline function. Default: 1.0.
+                - base_fun (nn.Module): Base function. Default: torch.nn.SiLU().
+                - grid_eps (float): 0 → quantile grid, 1 → uniform. Default: 0.02.
+                - grid_range (list): Range of the spline grid. Default: [-1, 1].
+                - sp_trainable (bool): Whether scale_sp is trainable. Default: True.
+                - sb_trainable (bool): Whether scale_base is trainable.
+                  Default: True.
+                - sparse_init (bool): Apply sparse init to KAN. Default: False.
             use_kan: Flag deciding whether KAN blocks are used in NBEATS.
                 If True, KAN layers are used in the NBEATS blocks, otherwise
                 MLP layers are used. Default: False.
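+
+        Example:
+            A construction sketch with KAN layers enabled (argument values
+            are illustrative; an omitted ``kan_params`` falls back to the
+            defaults listed above):
+
+            >>> block = NBEATSGenericBlock(
+            ...     units=64,
+            ...     thetas_dim=8,
+            ...     backcast_length=30,
+            ...     forecast_length=10,
+            ...     use_kan=True,
+            ... )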
""" super().__init__() + + if use_kan and kan_params is None: + # Define default parameters for KAN if not provided + kan_params = dict( + num=5, + k=3, + noise_scale=0.5, + scale_base_mu=0.0, + scale_base_sigma=1.0, + scale_sp=1.0, + base_fun=torch.nn.SiLU(), + grid_eps=0.02, + grid_range=[-1, 1], + sp_trainable=True, + sb_trainable=True, + sparse_init=False, + ) + self.units = units self.thetas_dim = thetas_dim self.backcast_length = backcast_length From 0c43448b21cd534b6ef482f2ffd83212edafb1a9 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Sat, 5 Jul 2025 06:08:38 -0700 Subject: [PATCH 11/21] Refactor examples in docstring --- pytorch_forecasting/models/nbeats/_kan_layer.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/models/nbeats/_kan_layer.py b/pytorch_forecasting/models/nbeats/_kan_layer.py index a33f43e67..290bae753 100644 --- a/pytorch_forecasting/models/nbeats/_kan_layer.py +++ b/pytorch_forecasting/models/nbeats/_kan_layer.py @@ -30,6 +30,9 @@ def b_batch(x, grid, k=0): Example ------- + The following is an example from the original `pykan` library, adapted here + for illustration within the PyTorch Forecasting integration. + Install the `pykan` package first: >>> pip install pykan Then use: @@ -249,6 +252,9 @@ def __init__( Example ------- + The following is an example from the original `pykan` library, adapted here + for illustration within the PyTorch Forecasting integration. + Install the `pykan` package first: >>> pip install pykan Then use: @@ -327,6 +333,9 @@ def forward(self, x): Example ------- + The following is an example from the original `pykan` library, adapted here + for illustration within the PyTorch Forecasting integration. + Install the `pykan` package first: >>> pip install pykan Then use: @@ -334,8 +343,8 @@ def forward(self, x): >>> from kan.KANLayer import * >>> model = KANLayer(in_dim=3, out_dim=5) >>> x = torch.normal(0,1,size=(100,3)) - >>> y, preacts, postacts, postspline = model(x) - >>> y.shape, preacts.shape, postacts.shape, postspline.shape + >>> y, _, _, _ = model(x) + >>> y.shape """ base = self.base_fun(x) # (batch, in_dim) From 2da4d13e48e13f0ceebe84c5a2eb18dcafa86e6a Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Sun, 6 Jul 2025 00:50:30 -0700 Subject: [PATCH 12/21] Include NBEATSKAN package container --- pytorch_forecasting/models/nbeats/__init__.py | 2 + .../models/nbeats/_nbeatskan_pkg.py | 39 +++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py index 112f36c67..21537fadd 100644 --- a/pytorch_forecasting/models/nbeats/__init__.py +++ b/pytorch_forecasting/models/nbeats/__init__.py @@ -5,6 +5,7 @@ from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter from pytorch_forecasting.models.nbeats._nbeats_pkg import NBeats_pkg from pytorch_forecasting.models.nbeats._nbeatskan import NBeatsKAN +from pytorch_forecasting.models.nbeats._nbeatskan_pkg import NBeatsKAN_pkg from pytorch_forecasting.models.nbeats.sub_modules import ( NBEATSGenericBlock, NBEATSSeasonalBlock, @@ -16,6 +17,7 @@ "NBeatsKAN", "NBEATSGenericBlock", "NBeats_pkg", + "NBeatsKAN_pkg", "NBEATSSeasonalBlock", "NBEATSTrendBlock", "NBeatsAdapter", diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py new file mode 100644 index 000000000..3002c6b3b --- /dev/null 
+++ b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py @@ -0,0 +1,39 @@ +"""NBeatsKAN package container.""" + +from pytorch_forecasting.models.base._base_object import _BasePtForecaster + + +class NBeatsKAN_pkg(_BasePtForecaster): + """NBeatsKAN package container.""" + + _tags = { + "info:name": "NBeatsKAN", + "info:compute": 1, + "authors": ["Sohaib-Ahmed21"], + "capability:exogenous": False, + "capability:multivariate": False, + "capability:pred_int": False, + "capability:flexible_history_length": False, + "capability:cold_start": False, + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + from pytorch_forecasting.models import NBeatsKAN + + return NBeatsKAN + + @classmethod + def get_test_train_params(cls): + """Return testing parameter settings for the trainer.""" + return [{"backcast_loss_ratio": 1.0}] + + @classmethod + def _get_test_dataloaders_from(cls, params): + """Get dataloaders from parameters.""" + from pytorch_forecasting.tests._data_scenarios import ( + dataloaders_fixed_window_without_covariates, + ) + + return dataloaders_fixed_window_without_covariates() From eb9c79d4cdea56529e5e32c2d5ec3be09598079a Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Mon, 7 Jul 2025 11:36:00 -0700 Subject: [PATCH 13/21] Refactor and enhance docstrings to follow NumPy style, include KAN references, and extend NBEATSKAN test cases. --- .../models/nbeats/_grid_callback.py | 42 +-- .../models/nbeats/_kan_layer.py | 298 ++++++++-------- pytorch_forecasting/models/nbeats/_nbeats.py | 133 ++++---- .../models/nbeats/_nbeats_adapter.py | 73 ++-- .../models/nbeats/_nbeatskan.py | 222 +++++++----- .../models/nbeats/_nbeatskan_pkg.py | 20 +- .../models/nbeats/sub_modules.py | 323 +++++++++--------- 7 files changed, 607 insertions(+), 504 deletions(-) diff --git a/pytorch_forecasting/models/nbeats/_grid_callback.py b/pytorch_forecasting/models/nbeats/_grid_callback.py index d311cfe84..dabba0fb2 100644 --- a/pytorch_forecasting/models/nbeats/_grid_callback.py +++ b/pytorch_forecasting/models/nbeats/_grid_callback.py @@ -6,35 +6,39 @@ class GridUpdateCallback(Callback): Custom callback to update the grid of the model during training at regular intervals. - Example: - See the full example in: - `examples/nbeats_with_kan.py` - - Attributes: - update_interval (int): The frequency at which the grid is updated. + Parameters + ---------- + update_interval : int + The frequency at which the grid is updated. + + Examples + -------- + See the full example in: + `examples/nbeats_with_kan.py` """ def __init__(self, update_interval): - """ - Initializes the callback with the given update interval. - - Args: - update_interval (int): The frequency at which the grid is updated. - """ self.update_interval = update_interval def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): """ - Hook that is called at the end of each training batch. + Hook called at the end of each training batch. + Updates the grid of KAN layers if the current step is a multiple of the update interval. - Args: - trainer (Trainer): The PyTorch Lightning Trainer object. - pl_module (LightningModule): The model being trained (LightningModule). - outputs (Any): Outputs from the model for the current batch. - batch (Any): The current batch of data. - batch_idx (int): Index of the current batch. + Parameters + ---------- + trainer : Trainer + The PyTorch Lightning Trainer object. + pl_module : LightningModule + The model being trained (LightningModule). 
+ outputs : Any + Outputs from the model for the current batch. + batch : Any + The current batch of data. + batch_idx : int + Index of the current batch. """ # Check if the current step is a multiple of the update interval if (trainer.global_step + 1) % self.update_interval == 0: diff --git a/pytorch_forecasting/models/nbeats/_kan_layer.py b/pytorch_forecasting/models/nbeats/_kan_layer.py index 290bae753..a54edf960 100644 --- a/pytorch_forecasting/models/nbeats/_kan_layer.py +++ b/pytorch_forecasting/models/nbeats/_kan_layer.py @@ -8,33 +8,33 @@ def b_batch(x, grid, k=0): """ - evaluate x on B-spline bases - - Args: - ----- - x : 2D torch.tensor - inputs, shape (number of splines, number of samples) - grid : 2D torch.tensor - grids, shape (number of splines, number of grid points) - k : int - the piecewise polynomial order of splines. - extend : bool - If True, k points are extended on both ends. If False, no extension - (zero boundary condition). Default: True - - Returns: - -------- - spline values : 3D torch.tensor - shape (batch, in_dim, G+k). G: the number of grid intervals, - k: spline order. - - Example + Evaluate x on B-spline bases + + Parameters + ---------- + x : torch.Tensor + 2D tensor of inputs, shape (number of splines, number of samples). + grid : torch.Tensor + 2D tensor of grids, shape (number of splines, number of grid points). + k : int + The piecewise polynomial order of splines. + extend : bool + If True, k points are extended on both ends. If False, no extension + (zero boundary condition). Default: True. + + Returns ------- + spline values : torch.Tensor + 3D tensor of shape (batch, in_dim, G+k), where G is the number of + grid intervals and k is the spline order. + + Examples + -------- The following is an example from the original `pykan` library, adapted here for illustration within the PyTorch Forecasting integration. Install the `pykan` package first: - >>> pip install pykan + pip install pykan Then use: >>> from kan.spline import B_batch @@ -42,7 +42,6 @@ def b_batch(x, grid, k=0): >>> x = torch.rand(100, 2) >>> grid = torch.linspace(-1, 1, steps=11)[None, :].expand(2, 11) >>> B_batch(x, grid, k=3).shape - """ x = x.unsqueeze(dim=2) @@ -66,25 +65,25 @@ def b_batch(x, grid, k=0): def coef2curve(x_eval, grid, coef, k): """ - converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves + Converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up b_batch results over B-spline basis). - Args: - ----- - x_eval : 2D torch.tensor - shape (batch, in_dim) - grid : 2D torch.tensor - shape (in_dim, G+2k). G: the number of grid intervals; k: spline order. - coef : 3D torch.tensor - shape (in_dim, out_dim, G+k) - k : int - the piecewise polynomial order of splines. - - Returns: - -------- - y_eval : 3D torch.tensor - shape (batch, in_dim, out_dim) - + Parameters + ---------- + x_eval : torch.Tensor + 2D tensor of shape (batch, in_dim). + grid : torch.Tensor + 2D tensor of shape (in_dim, G+2k). G: the number of grid intervals; + k: spline order. + coef : torch.Tensor + 3D tensor of shape (in_dim, out_dim, G+k). + k : int + The piecewise polynomial order of splines. + + Returns + ------- + y_eval : torch.Tensor + 3D tensor of shape (batch, in_dim, out_dim). """ b_splines = b_batch(x_eval, grid, k=k) @@ -95,25 +94,25 @@ def coef2curve(x_eval, grid, coef, k): def curve2coef(x_eval, y_eval, grid, k): """ - converting B-spline curves to B-spline coefficients using least squares. 
- - Args: - ----- - x_eval : 2D torch.tensor - shape (batch, in_dim) - y_eval : 3D torch.tensor - shape (batch, in_dim, out_dim) - grid : 2D torch.tensor - shape (in_dim, grid+2*k) - k : int - spline order - lamb : float - regularized least square lambda - - Returns: - -------- - coef : 3D torch.tensor - shape (in_dim, out_dim, G+k) + Estimate spline coefficients via batched least squares. + + Parameters + ---------- + x_eval : torch.Tensor + 2D tensor of shape (batch, in_dim). + y_eval : torch.Tensor + 3D tensor of shape (batch, in_dim, out_dim). + grid : torch.Tensor + 2D tensor of shape (in_dim, grid + 2 * k). + k : int + Spline order. + lamb : float + Regularized least square lambda. + + Returns + ------- + coef : torch.Tensor + 3D tensor of shape (in_dim, out_dim, G + k). """ batch = x_eval.shape[0] in_dim = x_eval.shape[1] @@ -134,17 +133,17 @@ def extend_grid(grid, k_extend=0): """ Extend a grid tensor by padding both ends with equal spacing. - Args: - ----- - grid : torch.Tensor - Grid of shape (in_dim, grid_points). - k_extend : int - Number of points to extend on both ends. + Parameters + ---------- + grid : torch.Tensor + Grid of shape (in_dim, grid_points). + k_extend : int + Number of points to extend on both ends. - Returns: - -------- - grid : torch.Tensor - Extended grid of shape (in_dim, grid_points + 2 * k_extend). + Returns + ------- + grid : torch.Tensor + Extended grid of shape (in_dim, grid_points + 2 * k_extend). """ h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) @@ -159,17 +158,17 @@ def sparse_mask(in_dim, out_dim): """ Generate a sparse connection mask between input and output units. - Args: - ----- - in_dim : int - Number of input units. - out_dim : int - Number of output units. + Parameters + ---------- + in_dim : int + Number of input units. + out_dim : int + Number of output units. - Returns: - -------- - mask : torch.Tensor - Sparse binary mask of shape (in_dim, out_dim). + Returns + ------- + mask : torch.Tensor + Sparse binary mask of shape (in_dim, out_dim). """ in_coord = torch.arange(in_dim) * 1 / in_dim + 1 / (2 * in_dim) out_coord = torch.arange(out_dim) * 1 / out_dim + 1 / (2 * out_dim) @@ -188,7 +187,59 @@ def sparse_mask(in_dim, out_dim): class KANLayer(nn.Module): """ - KANLayer class + Initialize a KANLayer + + Parameters + ---------- + in_dim : int + input dimension. Default: 2. + out_dim : int + output dimension. Default: 3. + num : int + the number of grid intervals = G. Default: 5. + k : int + the order of piecewise polynomial. Default: 3. + noise_scale : float + the scale of noise injected at initialization. Default: 0.1. + scale_base_mu : float + the scale of the residual function b(x) is intialized to be + N(scale_base_mu, scale_base_sigma^2). + scale_base_sigma : float + the scale of the residual function b(x) is intialized to be + N(scale_base_mu, scale_base_sigma^2). + scale_sp : float + the scale of the base function spline(x). + base_fun : function + residual function b(x). Default: None + grid_eps : float + When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is + partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates + between the two extremes. + grid_range : list or np.array of shape (2,) + setting the range of grids. Default: None. + sp_trainable : bool + If true, scale_sp is trainable. + sb_trainable : bool + If true, scale_base is trainable. + sparse_init : bool + if sparse_init = True, sparse initialization is applied. 
+ + Returns + ------- + self : reference to self + + Examples + -------- + The following is an example from the original `pykan` library, adapted here + for illustration within the PyTorch Forecasting integration. + + Install the `pykan` package first: + pip install pykan + Then use: + + >>> from kan.KANLayer import * + >>> model = KANLayer(in_dim=3, out_dim=5) + >>> (model.in_dim, model.out_dim) """ def __init__( @@ -208,61 +259,6 @@ def __init__( sb_trainable=True, sparse_init=False, ): - """' - Initialize a KANLayer - - Args: - ----- - in_dim : int - input dimension. Default: 2. - out_dim : int - output dimension. Default: 3. - num : int - the number of grid intervals = G. Default: 5. - k : int - the order of piecewise polynomial. Default: 3. - noise_scale : float - the scale of noise injected at initialization. Default: 0.1. - scale_base_mu : float - the scale of the residual function b(x) is intialized to be - N(scale_base_mu, scale_base_sigma^2). - scale_base_sigma : float - the scale of the residual function b(x) is intialized to be - N(scale_base_mu, scale_base_sigma^2). - scale_sp : float - the scale of the base function spline(x). - base_fun : function - residual function b(x). Default: None - grid_eps : float - When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is - partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates - between the two extremes. - grid_range : list/np.array of shape (2,) - setting the range of grids. Default: None. - sp_trainable : bool - If true, scale_sp is trainable - sb_trainable : bool - If true, scale_base is trainable - sparse_init : bool - if sparse_init = True, sparse initialization is applied. - - Returns: - -------- - self - - Example - ------- - The following is an example from the original `pykan` library, adapted here - for illustration within the PyTorch Forecasting integration. - - Install the `pykan` package first: - >>> pip install pykan - Then use: - - >>> from kan.KANLayer import * - >>> model = KANLayer(in_dim=3, out_dim=5) - >>> (model.in_dim, model.out_dim) - """ super().__init__() # Handle mutable parameters @@ -318,26 +314,26 @@ def forward(self, x): """ KANLayer forward given input x - Args: + Parameters ----- x : torch.Tensor Input tensor of shape (batch_size, in_dim), where: - batch_size is the number of input samples. - in_dim is the input feature dimension. - Returns: + Returns -------- y : torch.Tensor Output tensor, the result of applying spline and residual transformations followed by weighted summation. - Example - ------- + Examples + -------- The following is an example from the original `pykan` library, adapted here for illustration within the PyTorch Forecasting integration. Install the `pykan` package first: - >>> pip install pykan + pip install pykan Then use: >>> from kan.KANLayer import * @@ -359,18 +355,18 @@ def forward(self, x): def update_grid_from_samples(self, x): """ - update grid from samples + Update grid from samples - Args: + Parameters ----- - x : 2D torch.float - inputs, shape (number of samples, input dimension) + x : 2D torch.float + inputs, shape (number of samples, input dimension) Returns: -------- - None + None - Example + Examples ------- >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3) >>> print(model.grid.data) @@ -388,15 +384,15 @@ def get_grid(num_interval): """ Generate adaptive or uniform grid points from sorted input samples. - Args: + Parameters ----- - num_interval : int - Number of intervals between grid points. 
+ num_interval : int + Number of intervals between grid points. Returns: -------- - grid : torch.Tensor - New grid of shape (in_dim, num_interval + 1). + grid : torch.Tensor + New grid of shape (in_dim, num_interval + 1). """ ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] grid_adaptive = x_pos[ids, :].permute(1, 0) diff --git a/pytorch_forecasting/models/nbeats/_nbeats.py b/pytorch_forecasting/models/nbeats/_nbeats.py index 68004efff..02609944d 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats.py +++ b/pytorch_forecasting/models/nbeats/_nbeats.py @@ -16,6 +16,78 @@ class NBeats(NBeatsAdapter): + """ + Initialize NBeats Model - use its :py:meth:`~from_dataset` method if possible. + + Based on the article + `N-BEATS: Neural basis expansion analysis for interpretable time series + forecasting `_. The network has (if + used as ensemble) outperformed all other methods including ensembles of + traditional statical methods in the M4 competition. The M4 competition is + arguably the most important benchmark for univariate time series forecasting. + + The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently + shown to consistently outperform N-BEATS. + + Parameters + ---------- + stack_types : list of str + One of the following values “generic”, “seasonality” or “trend”. + A list of strings of length 1 or `num_stacks`. Default and recommended + value for generic mode is ["generic"]. Recommended value for interpretable + mode is ["trend","seasonality"]. + num_blocks : list of int + The number of blocks per stack. Length 1 or `num_stacks`. Default for + generic mode is [1], interpretable mode is [3]. + num_block_layers : list of int + Number of fully connected layers with ReLU activation per block. Length 1 + or `num_stacks`. Default [4] for both modes. + width : list of int + Widths of fully connected layers with ReLU activation. List length 1 or + `num_stacks`. Default [512] for generic; [256, 2048] for interpretable. + sharing : list of bool + Whether weights are shared across blocks in a stack. List length 1 or + `num_stacks`. Default [False] for generic; [True] for interpretable. + expansion_coefficient_length : list of int + If type is "G", length of expansion coefficient; if "T", degree of + polynomial; if "S", minimum period (e.g., 2 for every timestep). List + length 1 or `num_stacks`. Default [32] for generic; [3] for interpretable. + prediction_length : int + Length of the forecast horizon. + context_length : int + Number of time units conditioning the predictions (lookback period). + Should be between 1-10x `prediction_length`. + dropout : float + Dropout probability applied in the network. Helps prevent overfitting. + Default is 0.1. + learning_rate : float + Learning rate used by the optimizer during training. Default is 1e-2. + log_interval : int + Interval (in steps) at which training logs are recorded. If -1, logging + is disabled. Default is -1. + log_gradient_flow : bool + Whether to log gradient flow during training. Useful for diagnosing + vanishing/exploding gradients. Default is False. + log_val_interval : int + Interval (in steps) at which validation metrics are logged. If None, + uses default logging behavior. Default is None. + weight_decay : float + Weight decay (L2 regularization) coefficient used by the optimizer to + reduce overfitting. Default is 1e-3. + loss + Loss to optimize. Defaults to `MASE()`. + reduce_on_plateau_patience : int + Patience after which learning rate is reduced by factor of 10. 
+ backcast_loss_ratio : float + Weight of backcast loss relative to forecast loss. 1.0 gives equal weight; + default 0.0 means no backcast loss. + logging_metrics : nn.ModuleList of MultiHorizonMetric + List of metrics logged during training. Defaults to + nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]). + **kwargs + Additional arguments forwarded to :py:class:`~BaseModel`. + """ # noqa: E501 + def __init__( self, stack_types: Optional[list[str]] = None, @@ -38,67 +110,6 @@ def __init__( logging_metrics: nn.ModuleList = None, **kwargs, ): - """ - Initialize NBeats Model - use its :py:meth:`~from_dataset` method if possible. - - Based on the article - `N-BEATS: Neural basis expansion analysis for interpretable time series - forecasting `_. The network has (if - used as ensemble) outperformed all other methods including ensembles of - traditional statical methods in the M4 competition. The M4 competition is - arguably the most important benchmark for univariate time series forecasting. - - The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently - shown to consistently outperform N-BEATS. - - Args: - stack_types: One of the following values: “generic”, “seasonality" or - “trend". A list of strings of length 1 or 'num_stacks'. Default and - recommended value for generic mode: [“generic”] Recommended value for - interpretable mode: [“trend”,”seasonality”]. - num_blocks: The number of blocks per stack. A list of ints of length 1 or - 'num_stacks'. Default and recommended value for generic mode: [1] - Recommended value for interpretable mode: [3] - num_block_layers: Number of fully connected layers with ReLu activation per - block. - A list of ints of length 1 or 'num_stacks'. Default and recommended - value for generic mode: [4] Recommended value for interpretable mode: - [4]. - width: Widths of the fully connected layers with ReLu activation in the - blocks. A list of ints of length 1 or 'num_stacks'. Default and - recommended value for generic mode: [512]. Recommended value for - interpretable mode: [256, 2048] - sharing: Whether the weights are shared with the other blocks per stack. - A list of ints of length 1 or 'num_stacks'. Default and recommended - value for generic mode: [False]. Recommended value for interpretable - mode: [True]. - expansion_coefficient_length: If the type is “G” (generic), then the length - of the expansion coefficient. - If type is “T” (trend), then it corresponds to the degree of the - polynomial. - If the type is “S” (seasonal) then this is the minimum period allowed, - e.g. 2 for changes every timestep. A list of ints of length 1 or - 'num_stacks'. Default value for generic mode: [32] Recommended value for - interpretable mode: [3] - prediction_length: Length of the prediction. Also known as 'horizon'. - context_length: Number of time units that condition the predictions. - Also known as 'lookback period'. - Should be between 1-10 times the prediction length. - backcast_loss_ratio: weight of backcast in comparison to forecast when - calculating the loss. A weight of 1.0 means that forecast and - backcast loss is weighted the same (regardless of backcast and forecast - lengths). Defaults to 0.0, i.e. no weight. - loss: loss to optimize. Defaults to MASE(). - log_gradient_flow: if to log gradient flow, this takes time and should be - only done to diagnose training failures. 
- reduce_on_plateau_patience (int): patience after which learning rate is - reduced by a factor of 10 - logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that - are logged during training. Defaults to - nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) - **kwargs: additional arguments to :py:class:`~BaseModel`. - """ # noqa: E501 - if expansion_coefficient_lengths is None: expansion_coefficient_lengths = [3, 7] if sharing is None: diff --git a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py index a6ed1218c..df5a2a08e 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py +++ b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py @@ -18,28 +18,35 @@ class NBeatsAdapter(BaseModel): + """ + Initialize NBeats Adapter. + + Parameters + ---------- + **kwargs + additional arguments to :py:class:`~BaseModel`. + """ # noqa: E501 + def __init__( self, **kwargs, ): - """ - Initialize NBeats Adapter. - - Args: - **kwargs: additional arguments to :py:class:`~BaseModel`. - """ # noqa: E501 super().__init__(**kwargs) def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Pass forward of network. - Args: - x (Dict[str, torch.Tensor]): input from dataloader generated from - :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. + Parameters + ---------- + x : dict of str to torch.Tensor + input from dataloader generated from + :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. - Returns: - Dict[str, torch.Tensor]: output of model + Returns + ------- + dict of str to torch.Tensor + output of model """ target = x["encoder_cont"][..., 0] @@ -110,12 +117,16 @@ def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): Convenience function to create network from :py:class `~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. - Args: - dataset (TimeSeriesDataSet): dataset where sole predictor is the target. - **kwargs: additional arguments to be passed to ``__init__`` method. + Parameters + ---------- + dataset : TimeSeriesDataSet + dataset where sole predictor is the target. + **kwargs + additional arguments to be passed to ``__init__`` method. - Returns: - NBeats + Returns + ------- + NBeats """ # noqa: E501 new_kwargs = { "prediction_length": dataset.max_prediction_length, @@ -240,18 +251,24 @@ def plot_interpretation( Plot two panels: prediction and backcast vs actuals and decomposition of prediction into trend, seasonality and generic forecast. - Args: - x (Dict[str, torch.Tensor]): network input - output (Dict[str, torch.Tensor]): network output - idx (int): index of sample for which to plot the interpretation. - ax (List[matplotlib axes], optional): list of two matplotlib axes onto which - to plot the interpretation. Defaults to None. - plot_seasonality_and_generic_on_secondary_axis (bool, optional): if to plot - seasonality and generic forecast on secondary axis in second panel. - Defaults to False. - - Returns: - plt.Figure: matplotlib figure + Parameters + ---------- + x : dict of str to torch.Tensor + network input + output : dict of str to torch.Tensor + network output + idx : int + index of sample for which to plot the interpretation. + ax : list of matplotlib.axes + list of two matplotlib axes onto which to plot the interpretation. Defaults to None. + plot_seasonality_and_generic_on_secondary_axis : bool + if to plot seasonality and generic forecast on secondary axis in second panel. + Defaults to False. 
+
+        Returns
+        -------
+        matplotlib.figure.Figure
+            matplotlib figure
         """  # noqa: E501
         _check_matplotlib("plot_interpretation")

diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan.py b/pytorch_forecasting/models/nbeats/_nbeatskan.py
index 855755cec..7cee36d10 100644
--- a/pytorch_forecasting/models/nbeats/_nbeatskan.py
+++ b/pytorch_forecasting/models/nbeats/_nbeatskan.py
@@ -17,6 +17,136 @@


 class NBeatsKAN(NBeatsAdapter):
+    """
+    Initialize NBeats Model - use its :py:meth:`~from_dataset` method if possible.
+
+    Based on the article
+    `N-BEATS: Neural basis expansion analysis for interpretable time series
+    forecasting <http://arxiv.org/abs/1905.10437>`_. The network has (if
+    used as ensemble) outperformed all other methods including ensembles of
+    traditional statistical methods in the M4 competition. The M4 competition is
+    arguably the most important benchmark for univariate time series forecasting.
+
+    The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently
+    been shown to consistently outperform N-BEATS.
+
+    Parameters
+    ----------
+    stack_types : list of str
+        One of the following values: “generic”, “seasonality" or
+        “trend". A list of strings of length 1 or 'num_stacks'. Default and
+        recommended value for generic mode: [“generic”] Recommended value for
+        interpretable mode: [“trend”,”seasonality”].
+    num_blocks : list of int
+        The number of blocks per stack. A list of ints of length 1 or
+        'num_stacks'. Default and recommended value for generic mode: [1]
+        Recommended value for interpretable mode: [3]
+    num_block_layers : list of int
+        Number of fully connected layers with ReLU activation per block.
+        A list of ints of length 1 or 'num_stacks'. Default and recommended
+        value for generic mode: [4] Recommended value for interpretable mode:
+        [4].
+    widths : list of int
+        Widths of the fully connected layers with ReLU activation in the
+        blocks. A list of ints of length 1 or 'num_stacks'. Default and
+        recommended value for generic mode: [512]. Recommended value for
+        interpretable mode: [256, 2048]
+    sharing : list of bool
+        Whether the weights are shared with the other blocks per stack.
+        A list of bools of length 1 or 'num_stacks'. Default and recommended
+        value for generic mode: [False]. Recommended value for interpretable
+        mode: [True].
+    expansion_coefficient_lengths : list of int
+        If the type is “G” (generic), then the length of the expansion coefficient.
+        If type is “T” (trend), then it corresponds to the degree of the
+        polynomial.
+        If the type is “S” (seasonal) then this is the minimum period allowed,
+        e.g. 2 for changes every timestep. A list of ints of length 1 or
+        'num_stacks'. Default value for generic mode: [32] Recommended value for
+        interpretable mode: [3]
+    prediction_length : int
+        Length of the prediction. Also known as 'horizon'.
+    context_length : int
+        Number of time units that condition the predictions.
+        Also known as 'lookback period'.
+        Should be between 1-10 times the prediction length.
+    backcast_loss_ratio : float
+        Weight of backcast in comparison to forecast when calculating the loss.
+        A weight of 1.0 means that forecast and backcast losses are weighted the same
+        (regardless of backcast and forecast lengths). Defaults to 0.0, i.e. no weight.
+    loss : MultiHorizonMetric
+        Loss to optimize. Defaults to MASE().
+    log_gradient_flow : bool
+        Whether to log gradient flow; this takes time and should only be done to
+        diagnose training failures.
+    reduce_on_plateau_patience : int
+        Patience after which learning rate is reduced by a factor of 10
+    logging_metrics : nn.ModuleList of MultiHorizonMetric
+        List of metrics that are logged during training. Defaults to
+        nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
+    num : int
+        Parameter for KAN layer. The number of grid intervals = G.
+        Default: 5.
+    k : int
+        Parameter for KAN layer. The order of piecewise polynomial. Default: 3.
+    noise_scale : float
+        Parameter for KAN layer. The scale of noise injected at initialization.
+        Default: 0.5.
+    scale_base_mu : float
+        Parameter for KAN layer. The scale of the residual function b(x) is initialized
+        to be N(scale_base_mu, scale_base_sigma^2). Default: 0.0.
+    scale_base_sigma : float
+        Parameter for KAN layer. The scale of the residual function b(x) is initialized
+        to be N(scale_base_mu, scale_base_sigma^2). Default: 1.0.
+    scale_sp : float
+        Parameter for KAN layer. The scale of the base function spline(x). Default: 1.0.
+    base_fun : callable
+        Parameter for KAN layer. Residual function b(x). Default: None (resolved to torch.nn.SiLU()).
+    grid_eps : float
+        Parameter for KAN layer. When grid_eps = 1, the grid is uniform;
+        when grid_eps = 0, the grid is partitioned using percentiles of samples.
+        0 < grid_eps < 1 interpolates between the two extremes. Default: 0.02.
+    grid_range : list of int
+        Parameter for KAN layer. list/np.array of shape (2,). Sets the range of grids.
+        Default: None (resolved to [-1, 1]).
+    sp_trainable : bool
+        Parameter for KAN layer. If true, scale_sp is trainable. Default: True.
+    sb_trainable : bool
+        Parameter for KAN layer. If true, scale_base is trainable. Default: True.
+    sparse_init : bool
+        Parameter for KAN layer. If sparse_init = True, sparse initialization is applied.
+        Default: False.
+    **kwargs
+        Additional arguments to :py:class:`~BaseModel`.
+
+    Examples
+    --------
+    See the full example in:
+    `examples/nbeats_with_kan.py`
+
+    Notes
+    -----
+    The KAN blocks are based on the Kolmogorov-Arnold representation theorem and replace fixed MLP edge weights
+    with learnable univariate spline functions. This allows KAN-augmented N-BEATS to better capture complex patterns,
+    improve interpretability, and achieve parameter efficiency. Additionally, when applied in a doubly-residual
+    adversarial framework, the model excels at zero-shot time-series forecasting across markets.
+
+    Key differences from original N-BEATS:
+    - MLP layers are replaced by KAN layers with spline-based edge functions.
+    - Each weight is a trainable function, not a scalar.
+    - Enables visualization of learned functions and better domain adaptation.
+    - Yields improved accuracy and interpretability with fewer parameters.
+
+    References
+    ----------
+    .. [1] Z. Liu et al. (2024), “KAN: Kolmogorov-Arnold Networks”
+        propose replacing MLP weights with spline-based learnable edge functions, enabling improved accuracy,
+        interpretability, and scaling behavior compared to standard MLPs.
+    .. [2] A. Bhattacharya & N. Haq (2024), “Zero Shot Time Series Forecasting Using Kolmogorov Arnold Networks”
+        incorporate KAN layers into a doubly-residual N-BEATS architecture with adversarial domain adaptation,
+        achieving strong zero-shot cross-market electricity price forecasting performance.
+    """  # noqa: E501
+
     def __init__(
         self,
         stack_types: Optional[list[str]] = None,
@@ -51,93 +181,6 @@ def __init__(
         sparse_init: bool = False,
         **kwargs,
     ):
-        """
-        Initialize NBeats Model - use its :py:meth:`~from_dataset` method if possible.
- - Based on the article - `N-BEATS: Neural basis expansion analysis for interpretable time series - forecasting `_. The network has (if - used as ensemble) outperformed all other methods including ensembles of - traditional statical methods in the M4 competition. The M4 competition is - arguably the most important benchmark for univariate time series forecasting. - - The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently - shown to consistently outperform N-BEATS. - - Args: - stack_types: One of the following values: “generic”, “seasonality" or - “trend". A list of strings of length 1 or 'num_stacks'. Default and - recommended value for generic mode: [“generic”] Recommended value for - interpretable mode: [“trend”,”seasonality”]. - num_blocks: The number of blocks per stack. A list of ints of length 1 or - 'num_stacks'. Default and recommended value for generic mode: [1] - Recommended value for interpretable mode: [3] - num_block_layers: Number of fully connected layers with ReLu activation per - block. - A list of ints of length 1 or 'num_stacks'. Default and recommended - value for generic mode: [4] Recommended value for interpretable mode: - [4]. - width: Widths of the fully connected layers with ReLu activation in the - blocks. A list of ints of length 1 or 'num_stacks'. Default and - recommended value for generic mode: [512]. Recommended value for - interpretable mode: [256, 2048] - sharing: Whether the weights are shared with the other blocks per stack. - A list of ints of length 1 or 'num_stacks'. Default and recommended - value for generic mode: [False]. Recommended value for interpretable - mode: [True]. - expansion_coefficient_length: If the type is “G” (generic), then the length - of the expansion coefficient. - If type is “T” (trend), then it corresponds to the degree of the - polynomial. - If the type is “S” (seasonal) then this is the minimum period allowed, - e.g. 2 for changes every timestep. A list of ints of length 1 or - 'num_stacks'. Default value for generic mode: [32] Recommended value for - interpretable mode: [3] - prediction_length: Length of the prediction. Also known as 'horizon'. - context_length: Number of time units that condition the predictions. - Also known as 'lookback period'. - Should be between 1-10 times the prediction length. - backcast_loss_ratio: weight of backcast in comparison to forecast when - calculating the loss. A weight of 1.0 means that forecast and - backcast loss is weighted the same (regardless of backcast and forecast - lengths). Defaults to 0.0, i.e. no weight. - loss: loss to optimize. Defaults to MASE(). - log_gradient_flow: if to log gradient flow, this takes time and should be - only done to diagnose training failures. - reduce_on_plateau_patience (int): patience after which learning rate is - reduced by a factor of 10 - logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that - are logged during training. Defaults to - nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) - num : Parameter for KAN layer. the number of grid intervals = G. - Default: 5. - k : Parameter for KAN layer. the order of piecewise polynomial. Default: 3. - noise_scale : Parameter for KAN layer. the scale of noise injected at - initialization. Default: 0.1. - scale_base_mu : Parameter for KAN layer. the scale of the residual - function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2). - Deafult: 0.0. - scale_base_sigma : Parameter for KAN layer. 
the scale of the residual - function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2). - Deafult: 1.0. - scale_sp : Parameter for KAN layer. the scale of the base function - spline(x). Deafult: 1.0. - base_fun : Parameter for KAN layer. residual function b(x). - Default: None. - grid_eps : Parameter for KAN layer. When grid_eps = 1, the grid is uniform; - when grid_eps = 0, the grid is partitioned using percentiles of samples. - 0 < grid_eps < 1 interpolates between the two extremes. Deafult: 0.02. - grid_range : Parameter for KAN layer. list/np.array of shape (2,). setting - the range of grids. Default: None. - sp_trainable : Parameter for KAN layer. If true, scale_sp is trainable. - Default: True. - sb_trainable : Parameter for KAN layer. If true, scale_base is trainable. - Default: True. - sparse_init : Parameter for KAN layer. if sparse_init = True, sparse - initialization is applied. Default: False. - **kwargs: additional arguments to :py:class:`~BaseModel`. - """ # noqa: E501 - if base_fun is None: base_fun = torch.nn.SiLU() if grid_range is None: @@ -223,6 +266,11 @@ def __init__( def update_kan_grid(self): """ Updates grid of KAN layers when using KAN layers in NBEATSBlock. + + Examples + -------- + See the full example in: + `examples/nbeats_with_kan.py` """ for block in self.net_blocks: # updation logic taken from diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py index 3002c6b3b..180b4a406 100644 --- a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py +++ b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py @@ -27,7 +27,25 @@ def get_model_cls(cls): @classmethod def get_test_train_params(cls): """Return testing parameter settings for the trainer.""" - return [{"backcast_loss_ratio": 1.0}] + return [ + {"backcast_loss_ratio": 0.0}, # pure forecast loss + {"backcast_loss_ratio": 1.0}, # equal forecast/backcast + { + "stack_types": ["generic"], + "expansion_coefficient_lengths": [16], + }, + { + "num_blocks": [1, 2], + "num_block_layers": [2, 3], + }, # varying block structure + { + "num": 7, + "k": 4, + "sparse_init": True, + "grid_range": [-0.5, 0.5], + "sp_trainable": False, + }, # complex KAN config + ] @classmethod def _get_test_dataloaders_from(cls, params): diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py index e14d847c1..aec948949 100644 --- a/pytorch_forecasting/models/nbeats/sub_modules.py +++ b/pytorch_forecasting/models/nbeats/sub_modules.py @@ -44,6 +44,54 @@ def linspace( class NBEATSBlock(nn.Module): + """ + Initialize an N-BEATS block using either MLP or KAN layers. + + Parameters + ---------- + units : int + Number of units in each layer. + thetas_dim : int + Output dimension of the theta layers. + num_block_layers : int + Number of hidden layers in the block. Default is 4. + backcast_length : int + Length of the input (past) sequence. Default is 10. + forecast_length : int + Length of the output (future) sequence. Default is 5. + dropout : float + Dropout rate for regularization. Default is 0.1. + kan_params : dict + Dictionary of parameters for KAN layers. Only required if `use_kan=True`. + Default values will be used if not provided. Includes: + - num : int, default=5 + Number of grid intervals. + - k : int, default=3 + Order of piecewise polynomial. + - noise_scale : float, default=0.5 + Initialization noise scale. + - scale_base_mu : float, default=0.0 + Mean for residual function initialization. 
+ - scale_base_sigma : float, default=1.0 + Std deviation for residual function initialization. + - scale_sp : float, default=1.0 + Scale for the spline function. + - base_fun : nn.Module, default=torch.nn.SiLU() + Base function module. + - grid_eps : float, default=0.02 + Determines grid spacing (0 for quantile, 1 for uniform). + - grid_range : list of float, default=[-1, 1] + Range of the spline grid. + - sp_trainable : bool, default=True + Whether scale_sp is trainable. + - sb_trainable : bool, default=True + Whether scale_base is trainable. + - sparse_init : bool, default=False + Whether to apply sparse initialization. + use_kan : bool + If True, uses KAN layers instead of MLP. Default is False. + """ + def __init__( self, units, @@ -55,42 +103,6 @@ def __init__( kan_params=None, use_kan=False, ): - """ - Initialize NBeatsSeasonalBlock - - Args: - units: The number of units in the mlp/kan layers. - thetas_dim: The dimension of the parameterized output for the block. - num_block_layers: Number of fully connected mlp/kan layers. Default: 4. - backcast_length: The length of the backcast. Defines how many time units - from the past are used to predict the future. Default: 10. - forecast_length: The length of the forecast, i.e., the number of time steps - ahead to predict. Default: 5. - dropout: The dropout rate applied to the fully connected mlp layers to - prevent overfitting. Default: 0.1. - kan_params (dict): Configuration dictionary for the KAN layer. Only - required if `use_kan=True`. If `kan_params` is not provided and - `use_kan=True`, default values will be used. Default: None. - Contains: - - num (int): Number of grid intervals. Default: 5. - - k (int): Order of the piecewise polynomial. Default: 3. - - noise_scale (float): Initialization noise scale. Default: 0.5. - - scale_base_mu (float): Mean for residual function init. - Default: 0.0. - - scale_base_sigma (float): Std for residual function init. - Default: 1.0. - - scale_sp (float): Scale for spline function. Default: 1.0. - - base_fun (nn.Module): Base function. Default: torch.nn.SiLU(). - - grid_eps (float): 0 → quantile grid, 1 → uniform. Default: 0.02. - - grid_range (list): Range of the spline grid. Default: [-1, 1]. - - sp_trainable (bool): Whether scale_sp is trainable. Default: True. - - sb_trainable (bool): Whether scale_base is trainable. - Default: True. - - sparse_init (bool): Apply sparse init to KAN. Default: False. - use_kan: flag parameter to decide usage of KAN blocks in NBEATS. if true, - kan layers are used in nbeats block else mlp layers are used. Default: - false. - """ super().__init__() if use_kan and kan_params is None: @@ -158,7 +170,17 @@ def __init__( def forward(self, x): """ - Pass through the fully connected mlp/kan layers and returns the output. + Forward pass through the block using either MLP or KAN layers. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Output tensor after processing through the block. """ if self.use_kan: # save outputs to be used in updating grid in kan layers during training @@ -177,6 +199,33 @@ def forward(self, x): class NBEATSSeasonalBlock(NBEATSBlock): + """ + Initialize a Seasonal N-BEATS block with Fourier-based seasonality modeling. + + Parameters + ---------- + units : int + Number of units in each hidden layer. + thetas_dim : int + Output dimension of theta layers. Inferred from harmonics if not provided. + num_block_layers : int + Number of layers in the block. Default is 4. 
+ backcast_length : int + Length of the input (past) sequence. Default is 10. + forecast_length : int + Length of the output (future) sequence. Default is 5. + nb_harmonics : int + Number of harmonics for Fourier features. Default is None. + min_period : int + Minimum period for seasonality. Default is 1. + dropout : float + Dropout rate. Default is 0.1. + kan_params : dict + Dictionary of KAN layer parameters. See NBEATSBlock for details. + use_kan : bool + If True, uses KAN instead of MLP. Default is False. + """ + def __init__( self, units, @@ -190,47 +239,6 @@ def __init__( kan_params=None, use_kan=False, ): - """ - Initialize NBeatsSeasonalBlock - - Args: - units: The number of units in the mlp/kan layers. - thetas_dim: The dimension of the parameterized output for the block. - If None, it is inferred. - num_block_layers: Number of fully connected mlp/kan layers. Default: 4. - backcast_length: The length of the backcast. Defines how many time units - from the past are used to predict the future. Default: 10. - forecast_length: The length of the forecast, i.e., the number of time steps - ahead to predict. Default: 5. - nb_harmonics: The number of harmonics in the seasonal function (relevant for - seasonal models). Default: None (no seasonality). - min_period: The minimum period used for seasonal patterns. Default: 1. - dropout: The dropout rate applied to the fully connected mlp layers to - prevent overfitting. Default: 0.1. - kan_params (dict): Parameters specific to the KAN layer - (used for modeling using KAN). Default: None. - Contains: - num_grids (int): The number of grid intervals for KAN. - k (int): The order of the piecewise polynomial for KAN. - noise_scale (float): The scale of noise injected at initialization. - scale_base_mu (float): The scale of the residual function - initialized to N(scale_base_mu, scale_base_sigma^2). - scale_base_sigma (float): The scale of the residual function - initialized to N(scale_base_mu, scale_base_sigma^2). - scale_sp (float): The scale of the base function spline(x) in KAN. - base_fun (function): The residual function used by - KAN (e.g., torch.nn.SiLU()). - grid_eps (float): Determines the partitioning of the grid. If 1, - the grid is uniform; if 0, grid is partitioned by percentiles. - grid_range (list or np.array): The range of the grid, given as - a list of two values. - sp_trainable (bool): If True, the scale_sp is trainable. - sb_trainable (bool): If True, the scale_base is trainable. - sparse_init (bool): If True, applies sparse initialization. - use_kan: flag parameter to decide usage of KAN blocks in NBEATS. if true, - kan layers are used in nbeats block else mlp layers are used. Default: - false. - """ if nb_harmonics: thetas_dim = nb_harmonics else: @@ -279,7 +287,17 @@ def __init__( def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]: """ - Computes the backcast and forecast outputs for the given input tensor. + Compute seasonal backcast and forecast outputs using input tensor. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, backcast_length). + + Returns + ------- + tuple of torch.Tensor + Tuple (backcast, forecast), each of shape (batch_size, time_steps). """ x = super().forward(x) amplitudes_backward = self.theta_b_fc(x) @@ -299,6 +317,29 @@ def get_frequencies(self, n): class NBEATSTrendBlock(NBEATSBlock): + """ + Initialize a Trend N-BEATS block using polynomial basis functions. + + Parameters + ---------- + units : int + Number of units in each hidden layer. 
+ thetas_dim : int + Output dimension of theta layers (number of polynomial terms). + num_block_layers : int + Number of hidden layers. Default is 4. + backcast_length : int + Length of input sequence. Default is 10. + forecast_length : int + Length of output sequence. Default is 5. + dropout : float + Dropout rate. Default is 0.1. + kan_params : dict + KAN layer parameters. See NBEATSBlock for details. + use_kan : bool + If True, uses KAN instead of MLP. Default is False. + """ + def __init__( self, units, @@ -310,44 +351,6 @@ def __init__( kan_params=None, use_kan=False, ): - """ - Initialize NBeatsSeasonalBlock - - Args: - units: The number of units in the mlp/kan layers. - thetas_dim: The dimension of the parameterized output for the block. - If None, it is inferred. - num_block_layers: Number of fully connected mlp/kan layers. Default: 4. - backcast_length: The length of the backcast. Defines how many time units - from the past are used to predict the future. Default: 10. - forecast_length: The length of the forecast, i.e., the number of time steps - ahead to predict. Default: 5. - dropout: The dropout rate applied to the fully connected mlp layers to - prevent overfitting. Default: 0.1. - kan_params (dict): Parameters specific to the KAN layer - (used for modeling using KAN). Default: None. - Contains: - num_grids (int): The number of grid intervals for KAN. - k (int): The order of the piecewise polynomial for KAN. - noise_scale (float): The scale of noise injected at initialization. - scale_base_mu (float): The scale of the residual function - initialized to N(scale_base_mu, scale_base_sigma^2). - scale_base_sigma (float): The scale of the residual function - initialized to N(scale_base_mu, scale_base_sigma^2). - scale_sp (float): The scale of the base function spline(x) in KAN. - base_fun (function): The residual function used by - KAN (e.g., torch.nn.SiLU()). - grid_eps (float): Determines the partitioning of the grid. If 1, - the grid is uniform; if 0, grid is partitioned by percentiles. - grid_range (list or np.array): The range of the grid, given as - a list of two values. - sp_trainable (bool): If True, the scale_sp is trainable. - sb_trainable (bool): If True, the scale_base is trainable. - sparse_init (bool): If True, applies sparse initialization. - use_kan: flag parameter to decide usage of KAN blocks in NBEATS. if true, - kan layers are used in nbeats block else mlp layers are used. Default: - false. - """ super().__init__( units=units, thetas_dim=thetas_dim, @@ -379,8 +382,19 @@ def __init__( def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]: """ - Computes the backcast and forecast outputs for the given input tensor. + Compute backcast and forecast outputs using input tensor. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, backcast_length). + + Returns + ------- + tuple of torch.Tensor + Tuple (backcast, forecast). """ + x = super().forward(x) backcast = self.theta_b_fc(x).mm(self.T_backcast) forecast = self.theta_f_fc(x).mm(self.T_forecast) @@ -388,6 +402,29 @@ def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]: class NBEATSGenericBlock(NBEATSBlock): + """ + Initialize a Generic N-BEATS block using linear mapping of theta outputs. + + Parameters + ---------- + units : int + Number of units in each hidden layer. + thetas_dim : int + Dimension of the theta parameter. + num_block_layers : int + Number of hidden layers. Default is 4. + backcast_length : int + Length of past input. Default is 10. 
+    forecast_length : int
+        Length of future prediction. Default is 5.
+    dropout : float
+        Dropout rate. Default is 0.1.
+    kan_params : dict
+        KAN layer parameters. See NBEATSBlock for details.
+    use_kan : bool
+        If True, uses KAN instead of MLP. Default is False.
+    """
+
     def __init__(
         self,
         units,
@@ -399,44 +436,6 @@ def __init__(
         kan_params=None,
         use_kan=False,
     ):
-        """
-        Initialize NBeatsSeasonalBlock
-
-        Args:
-            units: The number of units in the mlp/kan layers.
-            thetas_dim: The dimension of the parameterized output for the block.
-                If None, it is inferred.
-            num_block_layers: Number of fully connected mlp/kan layers. Default: 4.
-            backcast_length: The length of the backcast. Defines how many time units
-                from the past are used to predict the future. Default: 10.
-            forecast_length: The length of the forecast, i.e., the number of time steps
-                ahead to predict. Default: 5.
-            dropout: The dropout rate applied to the fully connected mlp layers to
-                prevent overfitting. Default: 0.1.
-            kan_params (dict): Parameters specific to the KAN layer
-                (used for modeling using KAN). Default: None.
-                Contains:
-                    num_grids (int): The number of grid intervals for KAN.
-                    k (int): The order of the piecewise polynomial for KAN.
-                    noise_scale (float): The scale of noise injected at initialization.
-                    scale_base_mu (float): The scale of the residual function
-                        initialized to N(scale_base_mu, scale_base_sigma^2).
-                    scale_base_sigma (float): The scale of the residual function
-                        initialized to N(scale_base_mu, scale_base_sigma^2).
-                    scale_sp (float): The scale of the base function spline(x) in KAN.
-                    base_fun (function): The residual function used by
-                        KAN (e.g., torch.nn.SiLU()).
-                    grid_eps (float): Determines the partitioning of the grid. If 1,
-                        the grid is uniform; if 0, grid is partitioned by percentiles.
-                    grid_range (list or np.array): The range of the grid, given as
-                        a list of two values.
-                    sp_trainable (bool): If True, the scale_sp is trainable.
-                    sb_trainable (bool): If True, the scale_base is trainable.
-                    sparse_init (bool): If True, applies sparse initialization.
-            use_kan: flag parameter to decide usage of KAN blocks in NBEATS. if true,
-                kan layers are used in nbeats block else mlp layers are used. Default:
-                false.
-        """
         super().__init__(
             units=units,
             thetas_dim=thetas_dim,
@@ -453,7 +452,17 @@ def __init__(

     def forward(self, x):
         """
-        Computes the backcast and forecast outputs for the given input tensor.
+        Compute backcast and forecast outputs using the input tensor.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (batch_size, backcast_length).
+
+        Returns
+        -------
+        tuple of torch.Tensor
+            Tuple (backcast, forecast).
         """
         x = super().forward(x)
         theta_b = F.relu(self.theta_b_fc(x))
From ec4844e6a08a3b5284dbbd1fc5e589ec11e04c06 Mon Sep 17 00:00:00 2001
From: Sohaib-Ahmed21
Date: Sun, 17 Aug 2025 04:02:03 -0700
Subject: [PATCH 14/21] Restructure KAN and NBeats layers to include them in
 pytorch_forecasting/layers for better maintainability.
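
After this move, the KAN building blocks live under `pytorch_forecasting.layers`
and are re-exported from `layers/_kan/__init__.py`. A minimal sketch of the new
import path, assuming the single-tensor return of `KANLayer.forward` documented
earlier in this series (the dimensions and batch size below are illustrative):

    import torch

    from pytorch_forecasting.layers._kan import KANLayer

    layer = KANLayer(in_dim=3, out_dim=5)  # defaults: num=5 grid intervals, k=3 spline order
    x = torch.rand(64, 3)                  # (batch_size, in_dim)
    y = layer(x)                           # (batch_size, out_dim): spline(x) plus residual b(x)
    layer.update_grid_from_samples(x)      # refit the spline grid to the sample distribution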
--- pytorch_forecasting/layers/_kan/__init__.py | 7 + .../nbeats => layers/_kan}/_kan_layer.py | 184 +---------------- pytorch_forecasting/layers/_kan/_utils.py | 185 ++++++++++++++++++ .../layers/_nbeats/__init__.py | 17 ++ .../_nbeats/_blocks.py} | 36 +--- pytorch_forecasting/layers/_nbeats/_utils.py | 39 ++++ pytorch_forecasting/models/nbeats/__init__.py | 8 - pytorch_forecasting/models/nbeats/_nbeats.py | 6 +- .../models/nbeats/_nbeats_adapter.py | 6 +- .../models/nbeats/_nbeatskan.py | 6 +- 10 files changed, 265 insertions(+), 229 deletions(-) create mode 100644 pytorch_forecasting/layers/_kan/__init__.py rename pytorch_forecasting/{models/nbeats => layers/_kan}/_kan_layer.py (58%) create mode 100644 pytorch_forecasting/layers/_kan/_utils.py create mode 100644 pytorch_forecasting/layers/_nbeats/__init__.py rename pytorch_forecasting/{models/nbeats/sub_modules.py => layers/_nbeats/_blocks.py} (92%) create mode 100644 pytorch_forecasting/layers/_nbeats/_utils.py diff --git a/pytorch_forecasting/layers/_kan/__init__.py b/pytorch_forecasting/layers/_kan/__init__.py new file mode 100644 index 000000000..55e296e87 --- /dev/null +++ b/pytorch_forecasting/layers/_kan/__init__.py @@ -0,0 +1,7 @@ +""" +KAN (Kolmogorov Arnold Network) layer implementation. +""" + +from pytorch_forecasting.layers._kan._kan_layer import KANLayer + +__all__ = ["KANLayer"] diff --git a/pytorch_forecasting/models/nbeats/_kan_layer.py b/pytorch_forecasting/layers/_kan/_kan_layer.py similarity index 58% rename from pytorch_forecasting/models/nbeats/_kan_layer.py rename to pytorch_forecasting/layers/_kan/_kan_layer.py index a54edf960..217d92e92 100644 --- a/pytorch_forecasting/models/nbeats/_kan_layer.py +++ b/pytorch_forecasting/layers/_kan/_kan_layer.py @@ -5,184 +5,12 @@ import torch import torch.nn as nn - -def b_batch(x, grid, k=0): - """ - Evaluate x on B-spline bases - - Parameters - ---------- - x : torch.Tensor - 2D tensor of inputs, shape (number of splines, number of samples). - grid : torch.Tensor - 2D tensor of grids, shape (number of splines, number of grid points). - k : int - The piecewise polynomial order of splines. - extend : bool - If True, k points are extended on both ends. If False, no extension - (zero boundary condition). Default: True. - - Returns - ------- - spline values : torch.Tensor - 3D tensor of shape (batch, in_dim, G+k), where G is the number of - grid intervals and k is the spline order. - - Examples - -------- - The following is an example from the original `pykan` library, adapted here - for illustration within the PyTorch Forecasting integration. - - Install the `pykan` package first: - pip install pykan - Then use: - - >>> from kan.spline import B_batch - >>> import torch - >>> x = torch.rand(100, 2) - >>> grid = torch.linspace(-1, 1, steps=11)[None, :].expand(2, 11) - >>> B_batch(x, grid, k=3).shape - """ - - x = x.unsqueeze(dim=2) - grid = grid.unsqueeze(dim=0) - - if k == 0: - value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:]) - else: - B_km1 = b_batch(x[:, :, 0], grid=grid[0], k=k - 1) - - value = (x - grid[:, :, : -(k + 1)]) / ( - grid[:, :, k:-1] - grid[:, :, : -(k + 1)] - ) * B_km1[:, :, :-1] + (grid[:, :, k + 1 :] - x) / ( - grid[:, :, k + 1 :] - grid[:, :, 1:(-k)] - ) * B_km1[:, :, 1:] - - # in case grid is degenerate - value = torch.nan_to_num(value) - return value - - -def coef2curve(x_eval, grid, coef, k): - """ - Converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves - (summing up b_batch results over B-spline basis). 
- - Parameters - ---------- - x_eval : torch.Tensor - 2D tensor of shape (batch, in_dim). - grid : torch.Tensor - 2D tensor of shape (in_dim, G+2k). G: the number of grid intervals; - k: spline order. - coef : torch.Tensor - 3D tensor of shape (in_dim, out_dim, G+k). - k : int - The piecewise polynomial order of splines. - - Returns - ------- - y_eval : torch.Tensor - 3D tensor of shape (batch, in_dim, out_dim). - """ - - b_splines = b_batch(x_eval, grid, k=k) - y_eval = torch.einsum("ijk,jlk->ijl", b_splines, coef.to(b_splines)) - - return y_eval - - -def curve2coef(x_eval, y_eval, grid, k): - """ - Estimate spline coefficients via batched least squares. - - Parameters - ---------- - x_eval : torch.Tensor - 2D tensor of shape (batch, in_dim). - y_eval : torch.Tensor - 3D tensor of shape (batch, in_dim, out_dim). - grid : torch.Tensor - 2D tensor of shape (in_dim, grid + 2 * k). - k : int - Spline order. - lamb : float - Regularized least square lambda. - - Returns - ------- - coef : torch.Tensor - 3D tensor of shape (in_dim, out_dim, G + k). - """ - batch = x_eval.shape[0] - in_dim = x_eval.shape[1] - out_dim = y_eval.shape[2] - n_coef = grid.shape[1] - k - 1 - mat = b_batch(x_eval, grid, k) - mat = mat.permute(1, 0, 2)[:, None, :, :].expand(in_dim, out_dim, batch, n_coef) - y_eval = y_eval.permute(1, 2, 0).unsqueeze(dim=3) - try: - coef = torch.linalg.lstsq(mat, y_eval).solution[:, :, :, 0] - except Exception as e: - print(f"lstsq failed with error: {e}") - - return coef - - -def extend_grid(grid, k_extend=0): - """ - Extend a grid tensor by padding both ends with equal spacing. - - Parameters - ---------- - grid : torch.Tensor - Grid of shape (in_dim, grid_points). - k_extend : int - Number of points to extend on both ends. - - Returns - ------- - grid : torch.Tensor - Extended grid of shape (in_dim, grid_points + 2 * k_extend). - """ - h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) - - for i in range(k_extend): - grid = torch.cat([grid[:, [0]] - h, grid], dim=1) - grid = torch.cat([grid, grid[:, [-1]] + h], dim=1) - - return grid - - -def sparse_mask(in_dim, out_dim): - """ - Generate a sparse connection mask between input and output units. - - Parameters - ---------- - in_dim : int - Number of input units. - out_dim : int - Number of output units. - - Returns - ------- - mask : torch.Tensor - Sparse binary mask of shape (in_dim, out_dim). - """ - in_coord = torch.arange(in_dim) * 1 / in_dim + 1 / (2 * in_dim) - out_coord = torch.arange(out_dim) * 1 / out_dim + 1 / (2 * out_dim) - - dist_mat = torch.abs(out_coord[:, None] - in_coord[None, :]) - in_nearest = torch.argmin(dist_mat, dim=0) - in_connection = torch.stack([torch.arange(in_dim), in_nearest]).permute(1, 0) - out_nearest = torch.argmin(dist_mat, dim=1) - out_connection = torch.stack([out_nearest, torch.arange(out_dim)]).permute(1, 0) - all_connection = torch.cat([in_connection, out_connection], dim=0) - mask = torch.zeros(in_dim, out_dim) - mask[all_connection[:, 0], all_connection[:, 1]] = 1.0 - - return mask +from pytorch_forecasting.layers._kan._utils import ( + coef2curve, + curve2coef, + extend_grid, + sparse_mask, +) class KANLayer(nn.Module): diff --git a/pytorch_forecasting/layers/_kan/_utils.py b/pytorch_forecasting/layers/_kan/_utils.py new file mode 100644 index 000000000..705a3f6a5 --- /dev/null +++ b/pytorch_forecasting/layers/_kan/_utils.py @@ -0,0 +1,185 @@ +""" +Utility functions for KAN (Kolmogorov Arnold Network) Layer. 
+Contains B-spline computations, curve transformations, and grid manipulation functions. +""" + +import torch + + +def b_batch(x, grid, k=0): + """ + Evaluate x on B-spline bases + + Parameters + ---------- + x : torch.Tensor + 2D tensor of inputs, shape (number of splines, number of samples). + grid : torch.Tensor + 2D tensor of grids, shape (number of splines, number of grid points). + k : int + The piecewise polynomial order of splines. + extend : bool + If True, k points are extended on both ends. If False, no extension + (zero boundary condition). Default: True. + + Returns + ------- + spline values : torch.Tensor + 3D tensor of shape (batch, in_dim, G+k), where G is the number of + grid intervals and k is the spline order. + + Examples + -------- + The following is an example from the original `pykan` library, adapted here + for illustration within the PyTorch Forecasting integration. + + Install the `pykan` package first: + pip install pykan + Then use: + + >>> from kan.spline import B_batch + >>> import torch + >>> x = torch.rand(100, 2) + >>> grid = torch.linspace(-1, 1, steps=11)[None, :].expand(2, 11) + >>> B_batch(x, grid, k=3).shape + """ + + x = x.unsqueeze(dim=2) + grid = grid.unsqueeze(dim=0) + + if k == 0: + value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:]) + else: + B_km1 = b_batch(x[:, :, 0], grid=grid[0], k=k - 1) + + value = (x - grid[:, :, : -(k + 1)]) / ( + grid[:, :, k:-1] - grid[:, :, : -(k + 1)] + ) * B_km1[:, :, :-1] + (grid[:, :, k + 1 :] - x) / ( + grid[:, :, k + 1 :] - grid[:, :, 1:(-k)] + ) * B_km1[:, :, 1:] + + # in case grid is degenerate + value = torch.nan_to_num(value) + return value + + +def coef2curve(x_eval, grid, coef, k): + """ + Converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves + (summing up b_batch results over B-spline basis). + + Parameters + ---------- + x_eval : torch.Tensor + 2D tensor of shape (batch, in_dim). + grid : torch.Tensor + 2D tensor of shape (in_dim, G+2k). G: the number of grid intervals; + k: spline order. + coef : torch.Tensor + 3D tensor of shape (in_dim, out_dim, G+k). + k : int + The piecewise polynomial order of splines. + + Returns + ------- + y_eval : torch.Tensor + 3D tensor of shape (batch, in_dim, out_dim). + """ + + b_splines = b_batch(x_eval, grid, k=k) + y_eval = torch.einsum("ijk,jlk->ijl", b_splines, coef.to(b_splines)) + + return y_eval + + +def curve2coef(x_eval, y_eval, grid, k): + """ + Estimate spline coefficients via batched least squares. + + Parameters + ---------- + x_eval : torch.Tensor + 2D tensor of shape (batch, in_dim). + y_eval : torch.Tensor + 3D tensor of shape (batch, in_dim, out_dim). + grid : torch.Tensor + 2D tensor of shape (in_dim, grid + 2 * k). + k : int + Spline order. + lamb : float + Regularized least square lambda. + + Returns + ------- + coef : torch.Tensor + 3D tensor of shape (in_dim, out_dim, G + k). + """ + batch = x_eval.shape[0] + in_dim = x_eval.shape[1] + out_dim = y_eval.shape[2] + n_coef = grid.shape[1] - k - 1 + mat = b_batch(x_eval, grid, k) + mat = mat.permute(1, 0, 2)[:, None, :, :].expand(in_dim, out_dim, batch, n_coef) + y_eval = y_eval.permute(1, 2, 0).unsqueeze(dim=3) + try: + coef = torch.linalg.lstsq(mat, y_eval).solution[:, :, :, 0] + except Exception as e: + print(f"lstsq failed with error: {e}") + + return coef + + +def extend_grid(grid, k_extend=0): + """ + Extend a grid tensor by padding both ends with equal spacing. + + Parameters + ---------- + grid : torch.Tensor + Grid of shape (in_dim, grid_points). 
+ k_extend : int + Number of points to extend on both ends. + + Returns + ------- + grid : torch.Tensor + Extended grid of shape (in_dim, grid_points + 2 * k_extend). + """ + h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) + + for i in range(k_extend): + grid = torch.cat([grid[:, [0]] - h, grid], dim=1) + grid = torch.cat([grid, grid[:, [-1]] + h], dim=1) + + return grid + + +def sparse_mask(in_dim, out_dim): + """ + Generate a sparse connection mask between input and output units. + + Parameters + ---------- + in_dim : int + Number of input units. + out_dim : int + Number of output units. + + Returns + ------- + mask : torch.Tensor + Sparse binary mask of shape (in_dim, out_dim). + """ + in_coord = torch.arange(in_dim) * 1 / in_dim + 1 / (2 * in_dim) + out_coord = torch.arange(out_dim) * 1 / out_dim + 1 / (2 * out_dim) + + dist_mat = torch.abs(out_coord[:, None] - in_coord[None, :]) + in_nearest = torch.argmin(dist_mat, dim=0) + in_connection = torch.stack([torch.arange(in_dim), in_nearest]).permute(1, 0) + out_nearest = torch.argmin(dist_mat, dim=1) + out_connection = torch.stack([out_nearest, torch.arange(out_dim)]).permute(1, 0) + all_connection = torch.cat([in_connection, out_connection], dim=0) + mask = torch.zeros(in_dim, out_dim) + mask[all_connection[:, 0], all_connection[:, 1]] = 1.0 + + return mask diff --git a/pytorch_forecasting/layers/_nbeats/__init__.py b/pytorch_forecasting/layers/_nbeats/__init__.py new file mode 100644 index 000000000..daf47de2d --- /dev/null +++ b/pytorch_forecasting/layers/_nbeats/__init__.py @@ -0,0 +1,17 @@ +""" +Implementation of N-BEATS model blocks and utilities. +""" + +from pytorch_forecasting.layers._nbeats._blocks import ( + NBEATSBlock, + NBEATSGenericBlock, + NBEATSSeasonalBlock, + NBEATSTrendBlock, +) + +__all__ = [ + "NBEATSBlock", + "NBEATSGenericBlock", + "NBEATSSeasonalBlock", + "NBEATSTrendBlock", +] diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/layers/_nbeats/_blocks.py similarity index 92% rename from pytorch_forecasting/models/nbeats/sub_modules.py rename to pytorch_forecasting/layers/_nbeats/_blocks.py index aec948949..80e7657da 100644 --- a/pytorch_forecasting/models/nbeats/sub_modules.py +++ b/pytorch_forecasting/layers/_nbeats/_blocks.py @@ -7,40 +7,8 @@ import torch.nn as nn import torch.nn.functional as F -from pytorch_forecasting.models.nbeats._kan_layer import KANLayer - - -def linear(input_size, output_size, bias=True, dropout: int = None): - """ - Initialize linear layers for MLP block layers. - """ - lin = nn.Linear(input_size, output_size, bias=bias) - if dropout is not None: - return nn.Sequential(nn.Dropout(dropout), lin) - else: - return lin - - -def linspace( - backcast_length: int, forecast_length: int, centered: bool = False -) -> tuple[np.ndarray, np.ndarray]: - """ - Generate linear spaced values for backcast and forecast. 
- """ - if centered: - norm = max(backcast_length, forecast_length) - start = -backcast_length - stop = forecast_length - 1 - else: - norm = backcast_length + forecast_length - start = 0 - stop = backcast_length + forecast_length - 1 - lin_space = np.linspace( - start / norm, stop / norm, backcast_length + forecast_length, dtype=np.float32 - ) - b_ls = lin_space[:backcast_length] - f_ls = lin_space[backcast_length:] - return b_ls, f_ls +from pytorch_forecasting.layers._kan._kan_layer import KANLayer +from pytorch_forecasting.layers._nbeats._utils import linear, linspace class NBEATSBlock(nn.Module): diff --git a/pytorch_forecasting/layers/_nbeats/_utils.py b/pytorch_forecasting/layers/_nbeats/_utils.py new file mode 100644 index 000000000..0b884d4e1 --- /dev/null +++ b/pytorch_forecasting/layers/_nbeats/_utils.py @@ -0,0 +1,39 @@ +""" +Utility functions for N-BEATS model implementation. +""" + +import numpy as np +import torch.nn as nn + + +def linear(input_size, output_size, bias=True, dropout: int = None): + """ + Initialize linear layers for MLP block layers. + """ + lin = nn.Linear(input_size, output_size, bias=bias) + if dropout is not None: + return nn.Sequential(nn.Dropout(dropout), lin) + else: + return lin + + +def linspace( + backcast_length: int, forecast_length: int, centered: bool = False +) -> tuple[np.ndarray, np.ndarray]: + """ + Generate linear spaced values for backcast and forecast. + """ + if centered: + norm = max(backcast_length, forecast_length) + start = -backcast_length + stop = forecast_length - 1 + else: + norm = backcast_length + forecast_length + start = 0 + stop = backcast_length + forecast_length - 1 + lin_space = np.linspace( + start / norm, stop / norm, backcast_length + forecast_length, dtype=np.float32 + ) + b_ls = lin_space[:backcast_length] + f_ls = lin_space[backcast_length:] + return b_ls, f_ls diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py index 21537fadd..b88046780 100644 --- a/pytorch_forecasting/models/nbeats/__init__.py +++ b/pytorch_forecasting/models/nbeats/__init__.py @@ -6,20 +6,12 @@ from pytorch_forecasting.models.nbeats._nbeats_pkg import NBeats_pkg from pytorch_forecasting.models.nbeats._nbeatskan import NBeatsKAN from pytorch_forecasting.models.nbeats._nbeatskan_pkg import NBeatsKAN_pkg -from pytorch_forecasting.models.nbeats.sub_modules import ( - NBEATSGenericBlock, - NBEATSSeasonalBlock, - NBEATSTrendBlock, -) __all__ = [ "NBeats", "NBeatsKAN", - "NBEATSGenericBlock", "NBeats_pkg", "NBeatsKAN_pkg", - "NBEATSSeasonalBlock", - "NBEATSTrendBlock", "NBeatsAdapter", "GridUpdateCallback", ] diff --git a/pytorch_forecasting/models/nbeats/_nbeats.py b/pytorch_forecasting/models/nbeats/_nbeats.py index cdf7e8920..5a160e2cc 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats.py +++ b/pytorch_forecasting/models/nbeats/_nbeats.py @@ -6,13 +6,13 @@ from torch import nn -from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric -from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter -from pytorch_forecasting.models.nbeats.sub_modules import ( +from pytorch_forecasting.layers._nbeats._blocks import ( NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock, ) +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric +from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter class NBeats(NBeatsAdapter): diff --git a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py 
b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py index df5a2a08e..9ec8d6324 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py +++ b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py @@ -8,12 +8,12 @@ from pytorch_forecasting.data import TimeSeriesDataSet from pytorch_forecasting.data.encoders import NaNLabelEncoder -from pytorch_forecasting.metrics import MASE -from pytorch_forecasting.models.base_model import BaseModel -from pytorch_forecasting.models.nbeats.sub_modules import ( +from pytorch_forecasting.layers._nbeats._blocks import ( NBEATSSeasonalBlock, NBEATSTrendBlock, ) +from pytorch_forecasting.metrics import MASE +from pytorch_forecasting.models.base_model import BaseModel from pytorch_forecasting.utils._dependencies import _check_matplotlib diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan.py b/pytorch_forecasting/models/nbeats/_nbeatskan.py index 7cee36d10..5fffd474b 100644 --- a/pytorch_forecasting/models/nbeats/_nbeatskan.py +++ b/pytorch_forecasting/models/nbeats/_nbeatskan.py @@ -7,13 +7,13 @@ import torch from torch import nn -from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric -from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter -from pytorch_forecasting.models.nbeats.sub_modules import ( +from pytorch_forecasting.layers._nbeats._blocks import ( NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock, ) +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric +from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter class NBeatsKAN(NBeatsAdapter): From 698f2427de515a6317a221f91027edd2c5c96be4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 26 Aug 2025 16:26:14 +0200 Subject: [PATCH 15/21] rename get_cls --- pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py index 180b4a406..62315c9c9 100644 --- a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py +++ b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py @@ -18,7 +18,7 @@ class NBeatsKAN_pkg(_BasePtForecaster): } @classmethod - def get_model_cls(cls): + def get_cls(cls): """Get model class.""" from pytorch_forecasting.models import NBeatsKAN From 1570e021432deed879259ee3d5b8b41d0269d281 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 26 Aug 2025 17:51:38 +0200 Subject: [PATCH 16/21] add _pkg pointer --- pytorch_forecasting/models/nbeats/_nbeatskan.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan.py b/pytorch_forecasting/models/nbeats/_nbeatskan.py index 5fffd474b..555483734 100644 --- a/pytorch_forecasting/models/nbeats/_nbeatskan.py +++ b/pytorch_forecasting/models/nbeats/_nbeatskan.py @@ -147,6 +147,13 @@ class NBeatsKAN(NBeatsAdapter): achieving strong zero-shot cross-market electricity price forecasting performance. 
""" # noqa: E501 + @classmethod + def _pkg(cls): + """Package for the model.""" + from pytorch_forecasting.models.nbeats._nbeatskan_pkg import NBeatsKAN_pkg + + return NBeatsKAN_pkg + def __init__( self, stack_types: Optional[list[str]] = None, From a6929b1d4c504a3dcfe498629b147c7f7b112d85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 27 Aug 2025 06:29:48 +0200 Subject: [PATCH 17/21] Update _nbeatskan_pkg.py --- pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py index 62315c9c9..c10190028 100644 --- a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py +++ b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py @@ -25,8 +25,17 @@ def get_cls(cls): return NBeatsKAN @classmethod - def get_test_train_params(cls): - """Return testing parameter settings for the trainer.""" + def get_base_test_params(cls): + """Return testing parameter settings for the trainer. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + `create_test_instance` uses the first (or only) dictionary in `params` + """ return [ {"backcast_loss_ratio": 0.0}, # pure forecast loss {"backcast_loss_ratio": 1.0}, # equal forecast/backcast From d61b2b5b218bb43f9cbdbc566944f13041359e83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 27 Aug 2025 06:59:24 +0200 Subject: [PATCH 18/21] Update _nbeatskan_pkg.py --- pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py index c10190028..dd1d35058 100644 --- a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py +++ b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py @@ -9,6 +9,8 @@ class NBeatsKAN_pkg(_BasePtForecaster): _tags = { "info:name": "NBeatsKAN", "info:compute": 1, + "info:pred_type": ["point"], + "info:y_type": ["numeric"], "authors": ["Sohaib-Ahmed21"], "capability:exogenous": False, "capability:multivariate": False, From 92213aabf1d337a84acc884b3d2525423c726b39 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Wed, 27 Aug 2025 12:01:50 -0700 Subject: [PATCH 19/21] Solve failing TweedieLoss test with NBeatsKAN --- .../models/nbeats/_nbeatskan_pkg.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py index dd1d35058..2cda8c996 100644 --- a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py +++ b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py @@ -60,9 +60,24 @@ def get_base_test_params(cls): @classmethod def _get_test_dataloaders_from(cls, params): - """Get dataloaders from parameters.""" + loss = params.get("loss", None) + data_loader_kwargs = params.get("data_loader_kwargs", {}) + from pytorch_forecasting.metrics import TweedieLoss from pytorch_forecasting.tests._data_scenarios import ( + data_with_covariates, dataloaders_fixed_window_without_covariates, + make_dataloaders, ) + if isinstance(loss, TweedieLoss): + dwc = data_with_covariates() + dl_default_kwargs = dict( + target="target", + 
time_varying_unknown_reals=["target"], + add_relative_time_idx=False, + ) + dl_default_kwargs.update(data_loader_kwargs) + dataloaders_with_covariates = make_dataloaders(dwc, **dl_default_kwargs) + return dataloaders_with_covariates + return dataloaders_fixed_window_without_covariates() From 8212b2da8944e67904d90dccd5e3403fb1042cf4 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Wed, 27 Aug 2025 20:41:31 -0700 Subject: [PATCH 20/21] Adjust docstring example of b_batch function --- pytorch_forecasting/layers/_kan/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/layers/_kan/_utils.py b/pytorch_forecasting/layers/_kan/_utils.py index 705a3f6a5..fbbcb65d4 100644 --- a/pytorch_forecasting/layers/_kan/_utils.py +++ b/pytorch_forecasting/layers/_kan/_utils.py @@ -37,11 +37,12 @@ def b_batch(x, grid, k=0): pip install pykan Then use: - >>> from kan.spline import B_batch + >>> from pytorch_forecasting.layers._kan._utils import b_batch >>> import torch >>> x = torch.rand(100, 2) >>> grid = torch.linspace(-1, 1, steps=11)[None, :].expand(2, 11) - >>> B_batch(x, grid, k=3).shape + >>> b_batch(x, grid, k=3).shape + torch.Size([100, 2, 7]) """ x = x.unsqueeze(dim=2) From 086132289535219e456a0448616202a8ac0ee11b Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Mon, 1 Sep 2025 12:17:50 -0700 Subject: [PATCH 21/21] Add compatibility imports for NBEATS' blocks --- pytorch_forecasting/models/nbeats/__init__.py | 15 ++++++++++++++- pytorch_forecasting/models/nbeats/sub_modules.py | 12 ++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 pytorch_forecasting/models/nbeats/sub_modules.py diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py index b88046780..5377ff67f 100644 --- a/pytorch_forecasting/models/nbeats/__init__.py +++ b/pytorch_forecasting/models/nbeats/__init__.py @@ -1,5 +1,15 @@ -"""N-Beats model for timeseries forecasting without covariates.""" +""" +N-Beats model for timeseries forecasting without covariates. +# TODO v2: remove compatibility imports, kept to avoid breaking existing code. +""" + +# Import blocks from new location for backward compatibility +from pytorch_forecasting.layers._nbeats._blocks import ( + NBEATSGenericBlock, + NBEATSSeasonalBlock, + NBEATSTrendBlock, +) from pytorch_forecasting.models.nbeats._grid_callback import GridUpdateCallback from pytorch_forecasting.models.nbeats._nbeats import NBeats from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter @@ -12,6 +22,9 @@ "NBeatsKAN", "NBeats_pkg", "NBeatsKAN_pkg", + "NBEATSGenericBlock", + "NBEATSSeasonalBlock", + "NBEATSTrendBlock", "NBeatsAdapter", "GridUpdateCallback", ] diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py new file mode 100644 index 000000000..c7ec58972 --- /dev/null +++ b/pytorch_forecasting/models/nbeats/sub_modules.py @@ -0,0 +1,12 @@ +""" +Backward-compatibility shim for N-BEATS blocks. +Real implementations live in `pytorch_forecasting.layers._nbeats._blocks`. + +# TODO v2: remove this file. +""" + +from pytorch_forecasting.layers._nbeats._blocks import ( + NBEATSGenericBlock, + NBEATSSeasonalBlock, + NBEATSTrendBlock, +)
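
As a quick sanity check of the compatibility shim above, both the legacy and the
new import paths should resolve to the same class objects; a minimal sketch:

    # legacy path, kept alive by the sub_modules.py shim
    from pytorch_forecasting.models.nbeats.sub_modules import NBEATSGenericBlock as OldBlock

    # new canonical location after the restructuring
    from pytorch_forecasting.layers._nbeats import NBEATSGenericBlock as NewBlock

    assert OldBlock is NewBlock  # the shim re-exports the class, it does not duplicate it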