Skip to content
1 change: 1 addition & 0 deletions pytorch_forecasting/models/nbeats/_nbeats_pkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class NBeats_pkg(_BasePtForecaster):
"capability:pred_int": False,
"capability:flexible_history_length": False,
"capability:cold_start": False,
"tests:skip_by_name": "test_integration",
}

@classmethod
Expand Down
1 change: 1 addition & 0 deletions pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class NBeatsKAN_pkg(_BasePtForecaster):
"capability:pred_int": False,
"capability:flexible_history_length": False,
"capability:cold_start": False,
"tests:skip_by_name": "test_integration",
}

@classmethod
Expand Down
1 change: 1 addition & 0 deletions pytorch_forecasting/models/timexer/_timexer_pkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class TimeXer_pkg(_BasePtForecaster):
"capability:pred_int": True,
"capability:flexible_history_length": True,
"capability:cold_start": False,
"tests:skip_by_name": "test_integration",
}

@classmethod
Expand Down
1 change: 1 addition & 0 deletions pytorch_forecasting/models/xlstm/_xlstm_pkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class xLSTMTime_pkg(_BasePtForecaster):
"capability:pred_int": False,
"capability:flexible_history_length": True,
"capability:cold_start": False,
"tests:skip_by_name": "test_integration",
}

@classmethod
Expand Down
31 changes: 13 additions & 18 deletions pytorch_forecasting/tests/test_all_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,27 +334,22 @@ def _integration(
output = raw_predictions.output.prediction
n_dims = len(output.shape)

assert n_dims in [2, 3], (
f"Prediction output must be 2D or 3D, but got {n_dims}D tensor "
assert n_dims == 3, (
f"Prediction output must be 3D, but got {n_dims}D tensor "
f"with shape {output.shape}"
)

if n_dims == 2:
batch_size, prediction_length = output.shape
assert batch_size > 0, f"Batch size must be positive, got {batch_size}"
assert (
prediction_length > 0
), f"Prediction length must be positive, got {prediction_length}"

elif n_dims == 3:
batch_size, prediction_length, n_features = output.shape
assert batch_size > 0, f"Batch size must be positive, got {batch_size}"
assert (
prediction_length > 0
), f"Prediction length must be positive, got {prediction_length}"
assert (
n_features > 0
), f"Number of features must be positive, got {n_features}"
batch_size, prediction_length, n_features = output.shape
assert batch_size > 0, f"Batch size must be positive, got {batch_size}"
assert (
prediction_length > 0
), f"Prediction length must be positive, got {prediction_length}"
assert (
# todo: compare n_features with expected 3rd dimension of the corresponding
# loss function on which model is trained and
# predictions generated in this test.
n_features > 0 # this should be n_features == net.loss.expected_dim
), f"Number of features must be positive, got {n_features}"
finally:
shutil.rmtree(tmp_path, ignore_errors=True)

Expand Down
Loading