diff --git a/improver/calibration/load_and_train_quantile_regression_random_forest.py b/improver/calibration/load_and_train_quantile_regression_random_forest.py index db8845b948..bc342af45f 100644 --- a/improver/calibration/load_and_train_quantile_regression_random_forest.py +++ b/improver/calibration/load_and_train_quantile_regression_random_forest.py @@ -450,6 +450,11 @@ def filter_bad_sites( Tuple containing: - DataFrame containing the forecast data with bad sites removed. - DataFrame containing the truth data with bad sites removed. + + Raises: + ValueError: If the truth DataFrame is empty after removing NaNs. + ValueError: If there are no matching sites and times between the + forecast and truth DataFrames after removing NaNs. """ truth_df.dropna(subset=["ob_value"] + [*self.unique_site_id_keys], inplace=True) @@ -466,6 +471,13 @@ def filter_bad_sites( forecast_df = forecast_df[forecast_index.isin(truth_index)] truth_df = truth_df[truth_index.isin(forecast_index)] + if truth_df.empty: + msg = ( + "Empty truth DataFrame after finding the intersection of sites " + "and times between the truth DataFrame and the forecast DataFrame." 
+ ) + raise ValueError(msg) + return forecast_df, truth_df def process( diff --git a/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_train_quantile_regression_random_forest.py b/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_train_quantile_regression_random_forest.py index 46d81e910a..4bbe63c19a 100644 --- a/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_train_quantile_regression_random_forest.py +++ b/improver_tests/calibration/quantile_regression_random_forests_calibration/test_load_and_train_quantile_regression_random_forest.py @@ -733,7 +733,7 @@ def test_unexpected_loading( @pytest.mark.parametrize("include_noncube_static", [True, False]) @pytest.mark.parametrize("remove_target", [True, False]) @pytest.mark.parametrize("include_nans", [True, False]) -@pytest.mark.parametrize("include_latlon_nans", [True, False]) +@pytest.mark.parametrize("include_site_id_nans", [True, False]) @pytest.mark.parametrize("add_kwargs", [True, False]) @pytest.mark.parametrize( "site_id", ["wmo_id", "station_id", ["wmo_id"], ["latitude", "longitude"]] @@ -771,7 +771,7 @@ def test_prepare_and_train_qrf( include_noncube_static, remove_target, include_nans, - include_latlon_nans, + include_site_id_nans, add_kwargs, site_id, forecast_creation, @@ -823,12 +823,13 @@ def test_prepare_and_train_qrf( # Insert a NaN will result in this row being dropped. truth_df.loc[0, "ob_value"] = pd.NA - if include_latlon_nans: - # As latitude is not a feature, this NaN should be ignored. - if len(truth_df) == 1: - truth_df.loc[0, "latitude"] = pd.NA - else: - truth_df.loc[1, "latitude"] = pd.NA + if include_site_id_nans: + for key in site_id: + # A NaN in a site ID key results in this row being dropped. 
+ if len(truth_df) == 1: + truth_df.loc[0, key] = pd.NA + else: + truth_df.loc[1, key] = pd.NA if add_kwargs: kwargs = {"min_samples_leaf": 2} @@ -836,6 +837,10 @@ def test_prepare_and_train_qrf( if feature_config == {}: pytest.skip("No features to train on") + plugin_inputs = {"forecast_df": forecast_df, "truth_df": truth_df} + if include_static: + plugin_inputs["cube_inputs"] = iris.cube.CubeList([ancil_cube]) + # Create an instance of PrepareAndTrainQRF with the required parameters plugin = PrepareAndTrainQRF( feature_config=feature_config, @@ -848,18 +853,16 @@ def test_prepare_and_train_qrf( unique_site_id_keys=site_id, **(kwargs if add_kwargs else {}), ) - if truth_df["ob_value"].isna().all() or truth_df["latitude"].isna().all(): + + truth_subset_df = truth_df.dropna(subset=["ob_value"] + site_id) + merged_df = pd.merge( + forecast_df, truth_subset_df, on=[*site_id, "time"], how="inner" + ) + if merged_df.empty: with pytest.raises(ValueError, match="Empty truth DataFrame"): - plugin(forecast_df, truth_df) + plugin(**plugin_inputs) return - elif include_static: - qrf_model, transformation, pre_transform_addition = plugin( - forecast_df, truth_df, iris.cube.CubeList([ancil_cube]) - ) - else: - qrf_model, transformation, pre_transform_addition = plugin( - forecast_df, truth_df - ) + qrf_model, transformation, pre_transform_addition = plugin(**plugin_inputs) assert qrf_model.n_estimators == n_estimators assert qrf_model.max_depth == max_depth