Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def __init__(
random_state: Optional[int] = None,
transformation: Optional[str] = None,
pre_transform_addition: float = 0,
compression: int = 5,
):
"""Initialise the LoadAndTrainQRF plugin."""
self.feature_config = feature_config
Expand All @@ -59,7 +58,6 @@ def __init__(
self.random_state = random_state
self.transformation = transformation
self.pre_transform_addition = pre_transform_addition
self.compression = compression
self.quantile_forest_installed = quantile_forest_package_available()

def _split_cubes_and_parquet_files(
Expand Down Expand Up @@ -355,7 +353,7 @@ def process(
forecast_df = self._add_features_to_df(forecast_df, cube_inputs)
forecast_df, truth_df = self.filter_bad_sites(forecast_df, truth_df)

TrainQuantileRegressionRandomForests(
result = TrainQuantileRegressionRandomForests(
target_name=self.target_cf_name,
feature_config=self.feature_config,
n_estimators=self.n_estimators,
Expand All @@ -364,6 +362,6 @@ def process(
random_state=self.random_state,
transformation=self.transformation,
pre_transform_addition=self.pre_transform_addition,
compression=self.compression,
model_output=model_output,
)(forecast_df, truth_df)
return result
13 changes: 1 addition & 12 deletions improver/calibration/quantile_regression_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from typing import Optional

import joblib
import numpy as np
import pandas as pd

Expand Down Expand Up @@ -206,8 +205,6 @@ def __init__(
random_state: Optional[int] = None,
transformation: Optional[str] = None,
pre_transform_addition: np.float32 = 0,
compression: int = 5,
model_output: Optional[str] = None,
**kwargs,
) -> None:
"""Initialise the plugin.
Expand Down Expand Up @@ -246,10 +243,6 @@ def __init__(
Transformation to be applied to the data before fitting.
pre_transform_addition (float):
Value to be added before transformation.
compression (int):
Compression level for saving the model.
model_output (str):
Full path including model file name that will store the pickled model.
kwargs:
Additional keyword arguments for the quantile regression model.

Expand All @@ -264,8 +257,6 @@ def __init__(
self.transformation = transformation
_check_valid_transformation(self.transformation)
self.pre_transform_addition = pre_transform_addition
self.compression = compression
self.output = model_output
self.kwargs = kwargs
self.expected_coordinate_order = ["forecast_reference_time", "forecast_period"]

Expand Down Expand Up @@ -353,9 +344,7 @@ def process(
target_values = combined_df["ob_value"].values

# Fit the quantile regression model
qrf_model = self.fit_qrf(feature_values, target_values)

joblib.dump(qrf_model, self.output, compress=self.compression)
return self.fit_qrf(feature_values, target_values)


class ApplyQuantileRegressionRandomForests(PostProcessingPlugin):
Expand Down
8 changes: 7 additions & 1 deletion improver/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def with_output(
pass_through_output=False,
compression_level=1,
least_significant_digit: int = None,
output_file_type="netCDF",
**kwargs,
):
"""Add `output` keyword only argument.
Expand Down Expand Up @@ -346,15 +347,20 @@ def with_output(
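Returns:
Result of calling `wrapped`. None is returned instead when the result
has been written to `output` (a path ending in ".nc", ".pickle" or
".pkl"), except that a netCDF save with `pass_through_output` set
returns an ObjectAsStr wrapping the result and the output path.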
"""
import joblib

from improver.utilities.save import save_netcdf

result = wrapped(*args, **kwargs)

if output and result:
if output and output.endswith(".nc"):
save_netcdf(result, output, compression_level, least_significant_digit)
if pass_through_output:
return ObjectAsStr(result, output)
return
elif output and output.endswith((".pickle", ".pkl")):
joblib.dump(result, output, compress=compression_level)
return
return result


Expand Down
9 changes: 1 addition & 8 deletions improver/cli/train_quantile_regression_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@


@cli.clizefy
@cli.with_output
def process(
*file_paths: cli.inputpath,
feature_config: cli.inputjson,
Expand All @@ -24,8 +25,6 @@ def process(
random_state: int = None,
transformation: str = None,
pre_transform_addition: float = 0,
compression: int = 5,
output: str = None,
):
"""Training a model using Quantile Regression Random Forest.

Expand Down Expand Up @@ -92,10 +91,6 @@ def process(
Transformation to be applied to the data before fitting.
pre_transform_addition (float):
Value to be added before transformation.
compression (int):
Compression level for saving the model.
output (str):
Full path including model file name that will store the pickled model.
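Returns:
The trained quantile regression random forest model, or None if no
model could be trained. When an output path ending in ".pickle" or
".pkl" is supplied, the with_output decorator writes the returned
object to that path using joblib.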
Expand All @@ -119,9 +114,7 @@ def process(
random_state=random_state,
transformation=transformation,
pre_transform_addition=pre_transform_addition,
compression=compression,
)(
file_paths,
model_output=output,
)
return result
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_basic(
"5",
"--random-state",
"42",
"--compression",
"--compression-level",
"5",
"--output",
output_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def test_load_and_apply_qrf(
feature_config = {"wind_speed_at_10m": ["mean", "std", "latitude", "longitude"]}

model_output = _run_train_qrf(
tmp_path,
feature_config,
n_estimators,
max_depth,
Expand All @@ -98,6 +97,7 @@ def test_load_and_apply_qrf(
],
realization_data=[2, 6, 10],
truth_data=[4.2, 6.2, 4.1, 5.1],
tmp_path=tmp_path,
)

frt = "20170103T0000Z"
Expand Down Expand Up @@ -182,7 +182,6 @@ def test_unexpected(
quantiles = [0.5]

model_output = _run_train_qrf(
tmp_path,
feature_config,
n_estimators,
max_depth,
Expand All @@ -202,6 +201,7 @@ def test_unexpected(
],
realization_data=[2, 6, 10],
truth_data=[4.2, 6.2, 4.1, 5.1],
tmp_path=tmp_path,
)

frt = "20170103T0000Z"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
"""Unit tests for the LoadAndTrainQRF plugin."""

import iris
import joblib
import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -380,10 +379,6 @@ def test_load_and_train_qrf(
truth_path = truth_creation(tmp_path)
file_paths = [forecast_path, truth_path]

model_output_dir = tmp_path / "train_qrf"
model_output_dir.mkdir(parents=True)
model_output = str(model_output_dir / "qrf_model.pkl")

if include_static:
ancil_path = _create_ancil_file(tmp_path, sorted(list(set(wmo_ids))))
file_paths.append(ancil_path)
Expand All @@ -407,9 +402,7 @@ def test_load_and_train_qrf(
transformation="log",
pre_transform_addition=1,
)
plugin(file_paths, model_output=model_output)

qrf_model = joblib.load(model_output)
qrf_model = plugin(file_paths)

assert qrf_model.n_estimators == n_estimators
assert qrf_model.max_depth == max_depth
Expand Down Expand Up @@ -446,10 +439,6 @@ def test_load_and_train_qrf_no_paths(tmp_path, make_files):
for file_path in file_paths:
(tmp_path / file_path).mkdir(parents=True, exist_ok=True)

model_output_dir = tmp_path / "train_qrf"
model_output_dir.mkdir(parents=True)
model_output = str(model_output_dir / "qrf_model.pkl")

plugin = LoadAndTrainQRF(
experiment="latestblend",
feature_config=feature_config,
Expand All @@ -464,11 +453,9 @@ def test_load_and_train_qrf_no_paths(tmp_path, make_files):
transformation="log",
pre_transform_addition=1,
)
result = plugin(file_paths, model_output=model_output)
result = plugin(file_paths)
# Expecting None since no valid paths are provided
assert result is None
# Check if the model output file is not created
assert not (model_output_dir / "qrf_model.pkl").exists()


@pytest.mark.parametrize(
Expand All @@ -491,10 +478,6 @@ def test_load_and_train_qrf_mismatches(tmp_path, cycletime, forecast_periods):
tmp_path / "partition" / "truth_table/",
]

model_output_dir = tmp_path / "train_qrf"
model_output_dir.mkdir(parents=True)
model_output = str(model_output_dir / "qrf_model.pkl")

plugin = LoadAndTrainQRF(
experiment="latestblend",
feature_config=feature_config,
Expand All @@ -509,11 +492,9 @@ def test_load_and_train_qrf_mismatches(tmp_path, cycletime, forecast_periods):
transformation="log",
pre_transform_addition=1,
)
result = plugin(file_paths, model_output=model_output)
result = plugin(file_paths)
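# Expecting None for the parametrized cycletime / forecast_periods
# mismatch case (NOTE(review): presumably no forecast/truth rows match;
# confirm against the fixture data).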
assert result is None
# Check if the model output file is not created
assert not (model_output_dir / "qrf_model.pkl").exists()


@pytest.mark.parametrize(
Expand Down Expand Up @@ -581,10 +562,6 @@ def test_unexpected(
truth_path = truth_creation(tmp_path)
file_paths = [forecast_path, truth_path]

model_output_dir = tmp_path / "train_qrf"
model_output_dir.mkdir(parents=True)
model_output = str(model_output_dir / "qrf_model.pkl")

# Create an instance of LoadAndTrainQRF with the required parameters
plugin = LoadAndTrainQRF(
experiment="latestblend",
Expand All @@ -603,7 +580,7 @@ def test_unexpected(

if exception == "non_matching_truth":
with pytest.raises(IOError, match="The requested filepath"):
plugin(file_paths, model_output=model_output)
plugin(file_paths)
elif exception == "missing_static_feature":
feature_config = {
"wind_speed_at_10m": ["mean", "std"],
Expand All @@ -622,13 +599,13 @@ def test_unexpected(
plugin.process(file_paths=file_paths)
elif exception == "no_percentile_realization":
with pytest.raises(ValueError, match="The forecast parquet file"):
plugin(file_paths, model_output=model_output)
plugin(file_paths)
elif exception == "alternative_forecast_period":
with pytest.raises(ValueError, match="The forecast_periods argument"):
plugin(file_paths, model_output=model_output)
plugin(file_paths)
elif exception == "no_quantile_forest_package":
plugin.quantile_forest_installed = False
result = plugin(file_paths, model_output=model_output)
result = plugin(file_paths)
assert result is None
else:
raise ValueError(f"Unknown exception type: {exception}")
Loading