From 753cc50ee1db4e5b0346fe221152953b390d5c35 Mon Sep 17 00:00:00 2001 From: Gavin Evans Date: Mon, 1 Sep 2025 13:35:10 +0100 Subject: [PATCH] Alter saving of QRF model to use clize, rather than saving within the plugin. --- ...train_quantile_regression_random_forest.py | 6 +-- .../quantile_regression_random_forest.py | 13 +----- improver/cli/__init__.py | 8 +++- ...train_quantile_regression_random_forest.py | 9 +--- ...train_quantile_regression_random_forest.py | 2 +- ...apply_quantile_regression_random_forest.py | 4 +- ...train_quantile_regression_random_forest.py | 37 ++++------------ .../test_quantile_regression_random_forest.py | 37 +++++----------- improver_tests/cli/test_init.py | 42 +++++++++++++++---- 9 files changed, 66 insertions(+), 92 deletions(-) diff --git a/improver/calibration/load_and_train_quantile_regression_random_forest.py b/improver/calibration/load_and_train_quantile_regression_random_forest.py index f3c1aa12f1..421ca110cb 100644 --- a/improver/calibration/load_and_train_quantile_regression_random_forest.py +++ b/improver/calibration/load_and_train_quantile_regression_random_forest.py @@ -43,7 +43,6 @@ def __init__( random_state: Optional[int] = None, transformation: Optional[str] = None, pre_transform_addition: float = 0, - compression: int = 5, ): """Initialise the LoadAndTrainQRF plugin.""" self.feature_config = feature_config @@ -59,7 +58,6 @@ def __init__( self.random_state = random_state self.transformation = transformation self.pre_transform_addition = pre_transform_addition - self.compression = compression self.quantile_forest_installed = quantile_forest_package_available() def _split_cubes_and_parquet_files( @@ -355,7 +353,7 @@ def process( forecast_df = self._add_features_to_df(forecast_df, cube_inputs) forecast_df, truth_df = self.filter_bad_sites(forecast_df, truth_df) - TrainQuantileRegressionRandomForests( + result = TrainQuantileRegressionRandomForests( target_name=self.target_cf_name, feature_config=self.feature_config, n_estimators=self.n_estimators, @@ -364,6 +362,6 @@ def process( random_state=self.random_state, transformation=self.transformation, pre_transform_addition=self.pre_transform_addition, - compression=self.compression, model_output=model_output, )(forecast_df, truth_df) + return result diff --git a/improver/calibration/quantile_regression_random_forest.py b/improver/calibration/quantile_regression_random_forest.py index fdaae25116..48c10c041a 100644 --- a/improver/calibration/quantile_regression_random_forest.py +++ b/improver/calibration/quantile_regression_random_forest.py @@ -6,7 +6,6 @@ from typing import Optional -import joblib import numpy as np import pandas as pd @@ -206,8 +205,6 @@ def __init__( random_state: Optional[int] = None, transformation: Optional[str] = None, pre_transform_addition: np.float32 = 0, - compression: int = 5, - model_output: Optional[str] = None, **kwargs, ) -> None: """Initialise the plugin. @@ -246,10 +243,6 @@ def __init__( Transformation to be applied to the data before fitting. pre_transform_addition (float): Value to be added before transformation. - compression (int): - Compression level for saving the model. - model_output (str): - Full path including model file name that will store the pickled model. kwargs: Additional keyword arguments for the quantile regression model. @@ -264,8 +257,6 @@ def __init__( self.transformation = transformation _check_valid_transformation(self.transformation) self.pre_transform_addition = pre_transform_addition - self.compression = compression - self.output = model_output self.kwargs = kwargs self.expected_coordinate_order = ["forecast_reference_time", "forecast_period"] @@ -353,9 +344,7 @@ def process( target_values = combined_df["ob_value"].values # Fit the quantile regression model - qrf_model = self.fit_qrf(feature_values, target_values) - - joblib.dump(qrf_model, self.output, compress=self.compression) + return self.fit_qrf(feature_values, target_values) class ApplyQuantileRegressionRandomForests(PostProcessingPlugin): diff --git a/improver/cli/__init__.py b/improver/cli/__init__.py index 79542ad4cf..7a6be5a5c6 100644 --- a/improver/cli/__init__.py +++ b/improver/cli/__init__.py @@ -312,6 +312,7 @@ def with_output( pass_through_output=False, compression_level=1, least_significant_digit: int = None, + output_file_type="netCDF", **kwargs, ): """Add `output` keyword only argument. @@ -346,15 +347,20 @@ def with_output( Returns: Result of calling `wrapped` or None if `output` is given. """ + import joblib + from improver.utilities.save import save_netcdf result = wrapped(*args, **kwargs) - if output and result: + if output and output.endswith(".nc"): save_netcdf(result, output, compression_level, least_significant_digit) if pass_through_output: return ObjectAsStr(result, output) return + elif output and output.endswith((".pickle", ".pkl")): + joblib.dump(result, output, compress=compression_level) + return return result diff --git a/improver/cli/train_quantile_regression_random_forest.py b/improver/cli/train_quantile_regression_random_forest.py index 2c10a44421..eb87faf018 100644 --- a/improver/cli/train_quantile_regression_random_forest.py +++ b/improver/cli/train_quantile_regression_random_forest.py @@ -9,6 +9,7 @@ @cli.clizefy +@cli.with_output def process( *file_paths: cli.inputpath, feature_config: cli.inputjson, @@ -24,8 +25,6 @@ def process( random_state: int = None, transformation: str = None, pre_transform_addition: float = 0, - compression: int = 5, - output: str = None, ): """Training a model using Quantile Regression Random Forest. @@ -92,10 +91,6 @@ def process( Transformation to be applied to the data before fitting. pre_transform_addition (float): Value to be added before transformation. - compression (int): - Compression level for saving the model. - output (str): - Full path including model file name that will store the pickled model. Returns: None: The function creates a pickle file. @@ -119,9 +114,7 @@ def process( random_state=random_state, transformation=transformation, pre_transform_addition=pre_transform_addition, - compression=compression, )( file_paths, - model_output=output, ) return result diff --git a/improver_tests/acceptance/test_train_quantile_regression_random_forest.py b/improver_tests/acceptance/test_train_quantile_regression_random_forest.py index 64feccd97e..75073355c2 100644 --- a/improver_tests/acceptance/test_train_quantile_regression_random_forest.py +++ b/improver_tests/acceptance/test_train_quantile_regression_random_forest.py @@ -59,7 +59,7 @@ def test_basic( "5", "--random-state", "42", - "--compression", + "--compression-level", "5", "--output", output_path, diff --git a/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_apply_quantile_regression_random_forest.py b/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_apply_quantile_regression_random_forest.py index 78a2645e76..9fbeaced5f 100644 --- a/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_apply_quantile_regression_random_forest.py +++ b/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_apply_quantile_regression_random_forest.py @@ -78,7 +78,6 @@ def test_load_and_apply_qrf( feature_config = {"wind_speed_at_10m": ["mean", "std", "latitude", "longitude"]} model_output = _run_train_qrf( - tmp_path, feature_config, n_estimators, max_depth, @@ -98,6 +97,7 @@ def test_load_and_apply_qrf( ], realization_data=[2, 6, 10], truth_data=[4.2, 6.2, 4.1, 5.1], + tmp_path=tmp_path, ) frt = "20170103T0000Z" @@ -182,7 +182,6 @@ def test_unexpected( quantiles = [0.5] model_output = _run_train_qrf( - tmp_path, feature_config, n_estimators, max_depth, @@ -202,6 +201,7 @@ def test_unexpected( ], realization_data=[2, 6, 10], truth_data=[4.2, 6.2, 4.1, 5.1], + tmp_path=tmp_path, ) frt = "20170103T0000Z" diff --git a/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_train_quantile_regression_random_forest.py b/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_train_quantile_regression_random_forest.py index 7f3d2e609b..396a0f6bbe 100644 --- a/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_train_quantile_regression_random_forest.py +++ b/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_train_quantile_regression_random_forest.py @@ -5,7 +5,6 @@ """Unit tests for the LoadAndTrainQRF plugin.""" import iris -import joblib import numpy as np import pandas as pd import pytest @@ -380,10 +379,6 @@ def test_load_and_train_qrf( truth_path = truth_creation(tmp_path) file_paths = [forecast_path, truth_path] - model_output_dir = tmp_path / "train_qrf" - model_output_dir.mkdir(parents=True) - model_output = str(model_output_dir / "qrf_model.pkl") - if include_static: ancil_path = _create_ancil_file(tmp_path, sorted(list(set(wmo_ids)))) file_paths.append(ancil_path) @@ -407,9 +402,7 @@ def test_load_and_train_qrf( transformation="log", pre_transform_addition=1, ) - plugin(file_paths, model_output=model_output) - - qrf_model = joblib.load(model_output) + qrf_model = plugin(file_paths) assert qrf_model.n_estimators == n_estimators assert qrf_model.max_depth == max_depth @@ -446,10 +439,6 @@ def test_load_and_train_qrf_no_paths(tmp_path, make_files): for file_path in file_paths: (tmp_path / file_path).mkdir(parents=True, exist_ok=True) - model_output_dir = tmp_path / "train_qrf" - model_output_dir.mkdir(parents=True) - model_output = str(model_output_dir / "qrf_model.pkl") - plugin = LoadAndTrainQRF( experiment="latestblend", feature_config=feature_config, @@ -464,11 +453,9 @@ def test_load_and_train_qrf_no_paths(tmp_path, make_files): transformation="log", pre_transform_addition=1, ) - result = plugin(file_paths, model_output=model_output) + result = plugin(file_paths) # Expecting None since no valid paths are provided assert result is None - # Check if the model output file is not created - assert not (model_output_dir / "qrf_model.pkl").exists() @pytest.mark.parametrize( @@ -491,10 +478,6 @@ def test_load_and_train_qrf_mismatches(tmp_path, cycletime, forecast_periods): tmp_path / "partition" / "truth_table/", ] - model_output_dir = tmp_path / "train_qrf" - model_output_dir.mkdir(parents=True) - model_output = str(model_output_dir / "qrf_model.pkl") - plugin = LoadAndTrainQRF( experiment="latestblend", feature_config=feature_config, @@ -509,11 +492,9 @@ def test_load_and_train_qrf_mismatches(tmp_path, cycletime, forecast_periods): transformation="log", pre_transform_addition=1, ) - result = plugin(file_paths, model_output=model_output) + result = plugin(file_paths) # Expecting None since no valid paths are provided assert result is None - # Check if the model output file is not created - assert not (model_output_dir / "qrf_model.pkl").exists() @pytest.mark.parametrize( @@ -581,10 +562,6 @@ def test_unexpected( truth_path = truth_creation(tmp_path) file_paths = [forecast_path, truth_path] - model_output_dir = tmp_path / "train_qrf" - model_output_dir.mkdir(parents=True) - model_output = str(model_output_dir / "qrf_model.pkl") - # Create an instance of LoadAndTrainQRF with the required parameters plugin = LoadAndTrainQRF( experiment="latestblend", @@ -603,7 +580,7 @@ def test_unexpected( if exception == "non_matching_truth": with pytest.raises(IOError, match="The requested filepath"): - plugin(file_paths, model_output=model_output) + plugin(file_paths) elif exception == "missing_static_feature": feature_config = { "wind_speed_at_10m": ["mean", "std"], @@ -622,13 +599,13 @@ def test_unexpected( plugin.process(file_paths=file_paths) elif exception == "no_percentile_realization": with pytest.raises(ValueError, match="The forecast parquet file"): - plugin(file_paths, model_output=model_output) + plugin(file_paths) elif exception == "alternative_forecast_period": with pytest.raises(ValueError, match="The forecast_periods argument"): - plugin(file_paths, model_output=model_output) + plugin(file_paths) elif exception == "no_quantile_forest_package": plugin.quantile_forest_installed = False - result = plugin(file_paths, model_output=model_output) + result = plugin(file_paths) assert result is None else: raise ValueError(f"Unknown exception type: {exception}") diff --git a/improver_tests/calibration/quantile_regression_random_forests_calibration/test_quantile_regression_random_forest.py b/improver_tests/calibration/quantile_regression_random_forests_calibration/test_quantile_regression_random_forest.py index c2ad1e47d7..f48d89cc3b 100644 --- a/improver_tests/calibration/quantile_regression_random_forests_calibration/test_quantile_regression_random_forest.py +++ b/improver_tests/calibration/quantile_regression_random_forests_calibration/test_quantile_regression_random_forest.py @@ -138,7 +138,6 @@ def _create_ancil_file(return_cube=False): def _run_train_qrf( - tmp_path, feature_config, n_estimators, max_depth, @@ -162,6 +161,7 @@ def _run_train_qrf( ], realization_data=[2, 6, 10], truth_data=[4.2, 3.8, 5.8, 6, 7, 7.3, 9.1, 9.5], + tmp_path=None, ): realization_data = np.array(realization_data, dtype=np.float32) forecast_dfs = [] @@ -196,10 +196,6 @@ def _run_train_qrf( forecast_df["wind_speed_at_10m"] + 10, dtype=np.float32 ) - model_output_dir = tmp_path / "train_qrf" - model_output_dir.mkdir(parents=True) - model_output = str(model_output_dir / "qrf_model.pkl") - plugin = TrainQuantileRegressionRandomForests( target_name="wind_speed_at_10m", feature_config=feature_config, @@ -208,12 +204,14 @@ def _run_train_qrf( random_state=random_state, transformation=transformation, pre_transform_addition=pre_transform_addition, - compression=compression, - model_output=model_output, **extra_kwargs, ) - plugin.process(forecast_df, truth_df) - return model_output + result = plugin.process(forecast_df, truth_df) + if tmp_path is not None: + model_output = tmp_path / "qrf_model.pickle" + joblib.dump(result, model_output, compress=compression) + return model_output + return result def test_quantile_forest_package_available(): @@ -532,7 +530,6 @@ def test_check_valid_transformation(transformation): ], ) def test_train_qrf_single_lead_times( - tmp_path, n_estimators, max_depth, random_state, @@ -548,8 +545,7 @@ def test_train_qrf_single_lead_times( feature_config = {"wind_speed_at_10m": ["mean", "std", "latitude", "longitude"]} - model_output = _run_train_qrf( - tmp_path, + qrf_model = _run_train_qrf( feature_config, n_estimators, max_depth, @@ -570,7 +566,6 @@ def test_train_qrf_single_lead_times( realization_data=[2, 6, 10], truth_data=[4.2, 4.1, 4.2, 4.1], ) - qrf_model = joblib.load(model_output) assert qrf_model.n_estimators == n_estimators assert qrf_model.max_depth == max_depth @@ -600,7 +595,6 @@ def test_train_qrf_single_lead_times( ], ) def test_train_qrf_multiple_lead_times( - tmp_path, n_estimators, max_depth, random_state, @@ -616,8 +610,7 @@ def test_train_qrf_multiple_lead_times( feature_config = {"wind_speed_at_10m": ["mean", "std", "latitude", "longitude"]} - model_output = _run_train_qrf( - tmp_path, + qrf_model = _run_train_qrf( feature_config, n_estimators, max_depth, @@ -628,7 +621,6 @@ def test_train_qrf_multiple_lead_times( extra_kwargs, include_static, ) - qrf_model = joblib.load(model_output) assert qrf_model.n_estimators == n_estimators assert qrf_model.max_depth == max_depth @@ -664,7 +656,6 @@ def test_train_qrf_multiple_lead_times( ], ) def test_alternative_feature_configs( - tmp_path, feature_config, data, include_static, @@ -683,7 +674,6 @@ def test_alternative_feature_configs( if "pressure_at_mean_sea_level" in feature_config: with pytest.raises(ValueError, match=expected): _run_train_qrf( - tmp_path, feature_config, n_estimators, max_depth, @@ -696,8 +686,7 @@ def test_alternative_feature_configs( ) return - model_output = _run_train_qrf( - tmp_path, + qrf_model = _run_train_qrf( feature_config, n_estimators, max_depth, @@ -708,7 +697,6 @@ def test_alternative_feature_configs( extra_kwargs, include_static, ) - qrf_model = joblib.load(model_output) assert qrf_model.n_estimators == n_estimators assert qrf_model.max_depth == max_depth @@ -757,7 +745,6 @@ def test_alternative_feature_configs( ], ) def test_apply_qrf( - tmp_path, quantiles, transformation, pre_transform_addition, @@ -772,8 +759,7 @@ def test_apply_qrf( compression = 5 extra_kwargs = {} - model_output = _run_train_qrf( - tmp_path, + qrf_model = _run_train_qrf( feature_config, n_estimators, max_depth, @@ -784,7 +770,6 @@ def test_apply_qrf( extra_kwargs, include_static, ) - qrf_model = joblib.load(model_output) frt = "20170103T0000Z" vt = "20170103T1200Z" diff --git a/improver_tests/cli/test_init.py b/improver_tests/cli/test_init.py index a0d56ecec6..b8a8ea5bda 100644 --- a/improver_tests/cli/test_init.py +++ b/improver_tests/cli/test_init.py @@ -191,8 +191,8 @@ def test_with_output(self, m): """Tests that save_netcdf is called with object and string, default compression_level=1 and default least_significant_digit=None""" # pylint disable is needed as it can't see the wrappers output kwarg. - result = wrapped_with_output.cli("argv[0]", "2", "--output=foo") - m.assert_called_with(4, "foo", 1, None) + result = wrapped_with_output.cli("argv[0]", "2", "--output=foo.nc") + m.assert_called_with(4, "foo.nc", 1, None) self.assertEqual(result, None) @patch("improver.utilities.save.save_netcdf") @@ -200,9 +200,9 @@ def test_with_output_compression_level(self, m): """Tests save_netcdf, compression-level=9 and default least-significant-digit=None""" # pylint disable is needed as it can't see the wrappers output kwarg. result = wrapped_with_output.cli( - "argv[0]", "2", "--output=foo", "--compression-level=9" + "argv[0]", "2", "--output=foo.nc", "--compression-level=9" ) - m.assert_called_with(4, "foo", 9, None) + m.assert_called_with(4, "foo.nc", 9, None) self.assertEqual(result, None) @patch("improver.utilities.save.save_netcdf") @@ -210,9 +210,9 @@ def test_with_output_no_compression(self, m): """Tests save_netcdf, compression-level=0 and default least-significant-digit=None""" # pylint disable is needed as it can't see the wrappers output kwarg. result = wrapped_with_output.cli( - "argv[0]", "2", "--output=foo", "--compression-level=0" + "argv[0]", "2", "--output=foo.nc", "--compression-level=0" ) - m.assert_called_with(4, "foo", 0, None) + m.assert_called_with(4, "foo.nc", 0, None) self.assertEqual(result, None) @patch("improver.utilities.save.save_netcdf") @@ -222,11 +222,37 @@ def test_with_output_with_least_significant_figure(self, m): result = wrapped_with_output.cli( "argv[0]", "2", - "--output=foo", + "--output=foo.nc", "--compression-level=0", "--least-significant-digit=2", ) - m.assert_called_with(4, "foo", 0, 2) + m.assert_called_with(4, "foo.nc", 0, 2) + self.assertEqual(result, None) + + @patch("joblib.dump") + def test_with_output_pickle(self, m): + """Tests that joblib.dump is called with object and string, default + compression_level=1 and default least_significant_digit=None""" + # pylint disable is needed as it can't see the wrappers output kwarg. + result = wrapped_with_output.cli( + "argv[0]", + "2", + "--output=foo.pickle", + ) + m.assert_called_with(4, "foo.pickle", compress=1) + self.assertEqual(result, None) + + @patch("joblib.dump") + def test_with_output_pkl(self, m): + """Tests that joblib.dump is called with object and string, default + compression_level=1 and default least_significant_digit=None""" + # pylint disable is needed as it can't see the wrappers output kwarg. + result = wrapped_with_output.cli( + "argv[0]", + "2", + "--output=foo.pkl", + ) + m.assert_called_with(4, "foo.pkl", compress=1) self.assertEqual(result, None)