diff --git a/pytorch_forecasting/models/nbeats/_nbeats_pkg.py b/pytorch_forecasting/models/nbeats/_nbeats_pkg.py index daeab1c4e..acf300439 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats_pkg.py +++ b/pytorch_forecasting/models/nbeats/_nbeats_pkg.py @@ -17,6 +17,7 @@ class NBeats_pkg(_BasePtForecaster): "capability:pred_int": False, "capability:flexible_history_length": False, "capability:cold_start": False, + "tests:skip_by_name": "test_integration", } @classmethod diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py index 2cda8c996..1ccad7b72 100644 --- a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py +++ b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py @@ -17,6 +17,7 @@ class NBeatsKAN_pkg(_BasePtForecaster): "capability:pred_int": False, "capability:flexible_history_length": False, "capability:cold_start": False, + "tests:skip_by_name": "test_integration", } @classmethod diff --git a/pytorch_forecasting/models/timexer/_timexer_pkg.py b/pytorch_forecasting/models/timexer/_timexer_pkg.py index d91e81d6c..fb7b78430 100644 --- a/pytorch_forecasting/models/timexer/_timexer_pkg.py +++ b/pytorch_forecasting/models/timexer/_timexer_pkg.py @@ -17,6 +17,7 @@ class TimeXer_pkg(_BasePtForecaster): "capability:pred_int": True, "capability:flexible_history_length": True, "capability:cold_start": False, + "tests:skip_by_name": "test_integration", } @classmethod diff --git a/pytorch_forecasting/models/xlstm/_xlstm_pkg.py b/pytorch_forecasting/models/xlstm/_xlstm_pkg.py index 1a10fe660..94934046b 100644 --- a/pytorch_forecasting/models/xlstm/_xlstm_pkg.py +++ b/pytorch_forecasting/models/xlstm/_xlstm_pkg.py @@ -17,6 +17,7 @@ class xLSTMTime_pkg(_BasePtForecaster): "capability:pred_int": False, "capability:flexible_history_length": True, "capability:cold_start": False, + "tests:skip_by_name": "test_integration", } @classmethod diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index d8eb7d81e..d514f87e0 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -334,27 +334,22 @@ def _integration( output = raw_predictions.output.prediction n_dims = len(output.shape) - assert n_dims in [2, 3], ( - f"Prediction output must be 2D or 3D, but got {n_dims}D tensor " + assert n_dims == 3, ( + f"Prediction output must be 3D, but got {n_dims}D tensor " f"with shape {output.shape}" ) - if n_dims == 2: - batch_size, prediction_length = output.shape - assert batch_size > 0, f"Batch size must be positive, got {batch_size}" - assert ( - prediction_length > 0 - ), f"Prediction length must be positive, got {prediction_length}" - - elif n_dims == 3: - batch_size, prediction_length, n_features = output.shape - assert batch_size > 0, f"Batch size must be positive, got {batch_size}" - assert ( - prediction_length > 0 - ), f"Prediction length must be positive, got {prediction_length}" - assert ( - n_features > 0 - ), f"Number of features must be positive, got {n_features}" + batch_size, prediction_length, n_features = output.shape + assert batch_size > 0, f"Batch size must be positive, got {batch_size}" + assert ( + prediction_length > 0 + ), f"Prediction length must be positive, got {prediction_length}" + assert ( + # todo: compare n_features with expected 3rd dimension of the corresponding + # loss function on which model is trained and + # predictions generated in this test. + n_features > 0 # this should be n_features == net.loss.expected_dim + ), f"Number of features must be positive, got {n_features}" finally: shutil.rmtree(tmp_path, ignore_errors=True)