From 88ba01fcd06d8466d720558c1aa9e3d8b9dd169d Mon Sep 17 00:00:00 2001 From: Gavin Evans Date: Tue, 23 Sep 2025 11:06:49 +0100 Subject: [PATCH] Correction to handle providing no parquet files when training, and addition of CLI tests. --- ...train_quantile_regression_random_forest.py | 4 +- ...train_quantile_regression_random_forest.py | 2 + improver_tests/acceptance/SHA256SUMS | 1 + ...apply_quantile_regression_random_forest.py | 28 ++++++++++++- ...train_quantile_regression_random_forest.py | 40 +++++++++++++++++++ ...train_quantile_regression_random_forest.py | 2 +- 6 files changed, 73 insertions(+), 4 deletions(-) diff --git a/improver/calibration/load_and_train_quantile_regression_random_forest.py b/improver/calibration/load_and_train_quantile_regression_random_forest.py index 5cf9b8f5a8..81626d9154 100644 --- a/improver/calibration/load_and_train_quantile_regression_random_forest.py +++ b/improver/calibration/load_and_train_quantile_regression_random_forest.py @@ -270,13 +270,13 @@ def process( # If there are no parquet files, return None. if not parquets: - return None + return None, None, None forecast_table_path, truth_table_path = identify_parquet_type(parquets) # If either the forecast or truth parquet files are missing, return None. if not forecast_table_path or not truth_table_path: - return None + return None, None, None forecast_periods = self._parse_forecast_periods() forecast_df, truth_df = self._read_parquet_files( diff --git a/improver/cli/train_quantile_regression_random_forest.py b/improver/cli/train_quantile_regression_random_forest.py index b69dd1f87e..e018d5f630 100644 --- a/improver/cli/train_quantile_regression_random_forest.py +++ b/improver/cli/train_quantile_regression_random_forest.py @@ -122,6 +122,8 @@ def process( training_length=training_length, unique_site_id_keys=unique_site_id_keys, )(file_paths) + if forecast_df is None or truth_df is None or cube_inputs is None: + return None result = PrepareAndTrainQRF( feature_config=feature_config, target_cf_name=target_cf_name, diff --git a/improver_tests/acceptance/SHA256SUMS b/improver_tests/acceptance/SHA256SUMS index 8fe8c48a99..f59d51ecb7 100644 --- a/improver_tests/acceptance/SHA256SUMS +++ b/improver_tests/acceptance/SHA256SUMS @@ -84,6 +84,7 @@ eeb021922eb3f61c4dcc4c5096140d458d8b91c3f431ee3e54cdd06c08029f12 ./apply-night- 3a6b4b6e2931e4b58d970c4b034247269a96c372e5a4f97a8ba62a895071fa95 ./apply-night-mask/uk_prob/invalid_input.nc ed8ab78a9f55b54bf0a49f191d3eb33daae30e0d05b9051275a70c0e697aac71 ./apply-night-mask/uk_prob/kgo.nc 809a446327626d007ac288b8277520730863bc596e98116bca5c4afb0d531e96 ./apply-night-mask/uk_prob/valid_input.nc +0f74c2d6e3caf30c4f971da2dc757b1ae659062e31194e4da8e5266a1f23d2af ./apply-quantile-regression-random-forest/added_comment_kgo.nc d2f7d8389b33cde359dd253def2aaf31afbc27557f389bf0f744a28b24e145dd ./apply-quantile-regression-random-forest/config.json 8db1b35bde734c16340a6e42454b9ac146ea32f7c012c17c5e64815069af41bf ./apply-quantile-regression-random-forest/input_forecast.nc 0f74c2d6e3caf30c4f971da2dc757b1ae659062e31194e4da8e5266a1f23d2af ./apply-quantile-regression-random-forest/with_transformation_kgo.nc diff --git a/improver_tests/acceptance/test_apply_quantile_regression_random_forest.py b/improver_tests/acceptance/test_apply_quantile_regression_random_forest.py index 6f43606dc1..de1121d8f5 100644 --- a/improver_tests/acceptance/test_apply_quantile_regression_random_forest.py +++ b/improver_tests/acceptance/test_apply_quantile_regression_random_forest.py @@ -18,7 +18,8 @@ ["without_transformation", "with_transformation"], ) def test_basic(tmp_path, transformation): - """Test""" + """Test apply-quantile-regression-random-forest CLI with and without a + transformation applied.""" kgo_dir = acc.kgo_root() / "apply-quantile-regression-random-forest/" kgo_path = kgo_dir / f"{transformation}_kgo.nc" qrf_path = kgo_dir / f"{transformation}_input.pickle" @@ -38,3 +39,28 @@ def test_basic(tmp_path, transformation): run_cli(args) acc.compare(output_path, kgo_path, atol=LOOSE_TOLERANCE) + + +def test_missing_qrf_model(tmp_path): + """Test that if no QRF model is provided, the result matches the input forecast + with the exception of the comment attribute.""" + kgo_dir = acc.kgo_root() / "apply-quantile-regression-random-forest/" + kgo_path = kgo_dir / "added_comment_kgo.nc" + forecast_path = kgo_dir / "input_forecast.nc" + config_path = kgo_dir / "config.json" + output_path = tmp_path / "output.nc" + args = [ + forecast_path, + "--feature-config", + config_path, + "--target-cf-name", + "air_temperature", + "--output", + output_path, + ] + + run_cli(args) + acc.compare( + output_path, forecast_path, atol=LOOSE_TOLERANCE, exclude_attributes=["comment"] + ) + acc.compare(output_path, kgo_path, atol=LOOSE_TOLERANCE) diff --git a/improver_tests/acceptance/test_train_quantile_regression_random_forest.py b/improver_tests/acceptance/test_train_quantile_regression_random_forest.py index 3fbb571ded..97669d4878 100644 --- a/improver_tests/acceptance/test_train_quantile_regression_random_forest.py +++ b/improver_tests/acceptance/test_train_quantile_regression_random_forest.py @@ -74,3 +74,43 @@ def test_basic( run_cli(compulsory_args + named_args) acc.compare(output_path, kgo_path, file_type="pickled_forest") + + +def test_missing_inputs( + tmp_path, +): + """ + Test train-quantile-regression-random-forest CLI with missing parquet inputs. + """ + kgo_dir = acc.kgo_root() / CLI + config_path = kgo_dir / "config.json" + output_path = tmp_path / "output.pickle" + compulsory_args = [] + named_args = [ + "--feature-config", + config_path, + "--parquet-diagnostic-names", + "temperature_at_screen_level", + "--target-cf-name", + "air_temperature", + "--forecast-periods", + "6:18:6", + "--cycletime", + "20250804T0000Z", + "--training-length", + "2", + "--experiment", + "mix-latestblend", + "--n-estimators", + "10", + "--max-depth", + "5", + "--random-state", + "42", + "--compression-level", + "5", + "--output", + output_path, + ] + + assert run_cli(compulsory_args + named_args) is None diff --git a/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_train_quantile_regression_random_forest.py b/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_train_quantile_regression_random_forest.py index b345dee47e..b2b814b33f 100644 --- a/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_train_quantile_regression_random_forest.py +++ b/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_train_quantile_regression_random_forest.py @@ -529,7 +529,7 @@ def test_load_for_qrf_no_paths(tmp_path, make_files): ) result = plugin(file_paths) # Expecting None since no valid paths are provided - assert result is None + assert result == (None, None, None) @pytest.mark.parametrize(