From 6ab317fa6f61f073a4cc6146ea44d7b168967eff Mon Sep 17 00:00:00 2001 From: James Duncan Date: Tue, 3 Feb 2026 13:43:16 -0800 Subject: [PATCH 1/2] Handle pre-weighted hfds --- fme/ace/test_ocean_train.py | 2 +- fme/core/corrector/ocean.py | 19 ++++++++++++++----- fme/core/corrector/test_ocean.py | 25 +++++++++++++++++++------ fme/core/ocean_data.py | 17 ++++++++++++++++- fme/core/ocean_derived_variables.py | 9 +++++++++ 5 files changed, 59 insertions(+), 13 deletions(-) diff --git a/fme/ace/test_ocean_train.py b/fme/ace/test_ocean_train.py index 2ed903323..367be9b6e 100644 --- a/fme/ace/test_ocean_train.py +++ b/fme/ace/test_ocean_train.py @@ -302,7 +302,7 @@ def _setup( "thetao_1", "sst", "sea_ice_fraction", - "hfds", + "hfds_total_area", ] # Add masks and idepths for data generation diff --git a/fme/core/corrector/ocean.py b/fme/core/corrector/ocean.py index b7571b069..b8027ba11 100644 --- a/fme/core/corrector/ocean.py +++ b/fme/core/corrector/ocean.py @@ -187,13 +187,22 @@ def _force_conserve_ocean_heat_content( name="ocean_heat_content", ) try: + # First priority: pre-weighted heat flux in gen_data net_energy_flux_into_ocean = ( - gen.net_downward_surface_heat_flux + forcing.geothermal_heat_flux - ) * forcing.sea_surface_fraction + gen.net_downward_surface_heat_flux_total_area + + forcing.geothermal_heat_flux * forcing.sea_surface_fraction + ) except KeyError: - net_energy_flux_into_ocean = ( - input.net_downward_surface_heat_flux + forcing.geothermal_heat_flux - ) * forcing.sea_surface_fraction + try: + # Second priority: standard heat flux in gen_data + net_energy_flux_into_ocean = ( + gen.net_downward_surface_heat_flux + forcing.geothermal_heat_flux + ) * forcing.sea_surface_fraction + except KeyError: + # Third priority: standard heat flux in input_data + net_energy_flux_into_ocean = ( + input.net_downward_surface_heat_flux + forcing.geothermal_heat_flux + ) * forcing.sea_surface_fraction energy_flux_global_mean = area_weighted_mean( net_energy_flux_into_ocean, keepdim=True, diff --git a/fme/core/corrector/test_ocean.py b/fme/core/corrector/test_ocean.py index fb4c89221..7488c2c5b 100644 --- a/fme/core/corrector/test_ocean.py +++ b/fme/core/corrector/test_ocean.py @@ -166,13 +166,14 @@ def test_sea_ice_thickness_correction(): @pytest.mark.parametrize( - "hfds_in_gen_data", + "hfds_type", [ - pytest.param(False, id="hfds_in_input"), - pytest.param(True, id="hfds_in_gen"), + pytest.param("input", id="hfds_in_input"), + pytest.param("gen", id="hfds_in_gen"), + pytest.param("total_area", id="hfds_total_area_in_gen"), ], ) -def test_ocean_heat_content_correction(hfds_in_gen_data): +def test_ocean_heat_content_correction(hfds_type): config = OceanCorrectorConfig( ocean_heat_content_correction=OceanHeatContentBudgetConfig( method="scaled_temperature", @@ -196,6 +197,8 @@ def test_ocean_heat_content_correction(hfds_in_gen_data): idepth = torch.tensor([2.5, 10, 20]) depth_coordinate = DepthCoordinate(idepth, mask) + sea_surface_fraction = mask[:, :, :, 0] + input_data_dict = { "thetao_0": torch.ones(nsamples, nlat, nlon), "thetao_1": torch.ones(nsamples, nlat, nlon), @@ -206,13 +209,23 @@ def test_ocean_heat_content_correction(hfds_in_gen_data): "thetao_1": torch.ones(nsamples, nlat, nlon) * 2, "sst": torch.ones(nsamples, nlat, nlon) * 2 + 273.15, } - if hfds_in_gen_data: + if hfds_type == "gen": gen_data_dict["hfds"] = torch.ones(nsamples, nlat, nlon) + elif hfds_type == "total_area": + # hfds_total_area is already weighted by sea_surface_fraction also + # include hfds with a different value to verify hfds_total_area takes + # priority + gen_data_dict["hfds"] = ( + torch.ones(nsamples, nlat, nlon) * 100 + ) # should be ignored + gen_data_dict["hfds_total_area"] = ( + torch.ones(nsamples, nlat, nlon) * sea_surface_fraction + ) else: input_data_dict["hfds"] = torch.ones(nsamples, nlat, nlon) forcing_data_dict = { "hfgeou": torch.ones(nsamples, nlat, nlon), - "sea_surface_fraction": mask[:, :, :, 0], + "sea_surface_fraction": sea_surface_fraction, } input_data = OceanData(input_data_dict, depth_coordinate) gen_data = OceanData(gen_data_dict, depth_coordinate) diff --git a/fme/core/ocean_data.py b/fme/core/ocean_data.py index 66e2d237f..1129ab33f 100644 --- a/fme/core/ocean_data.py +++ b/fme/core/ocean_data.py @@ -22,6 +22,7 @@ "ocean_sea_ice_fraction": ["ocean_sea_ice_fraction"], "land_fraction": ["land_fraction"], "net_downward_surface_heat_flux": ["hfds"], + "net_downward_surface_heat_flux_by_total_area": ["hfds_total_area"], "geothermal_heat_flux": ["hfgeou"], "sea_surface_fraction": ["sea_surface_fraction"], } @@ -159,7 +160,21 @@ def sea_surface_fraction(self) -> torch.Tensor: @property def net_downward_surface_heat_flux(self) -> torch.Tensor: """Net heat flux downward across the ocean surface (below the sea-ice).""" - return self._get("net_downward_surface_heat_flux") + try: + return self._get("net_downward_surface_heat_flux") + except KeyError: + # derive from the sea-surface-fraction-weighted version + return ( + self.net_downward_surface_heat_flux_total_area + / self.sea_surface_fraction + ) + + @property + def net_downward_surface_heat_flux_total_area(self) -> torch.Tensor: + """Net heat flux downward across the ocean surface (below the sea-ice), + normalized by total grid cell area. + """ + return self._get("net_downward_surface_heat_flux_by_total_area") @property def geothermal_heat_flux(self) -> torch.Tensor: diff --git a/fme/core/ocean_derived_variables.py b/fme/core/ocean_derived_variables.py index f02581022..22eed4d73 100644 --- a/fme/core/ocean_derived_variables.py +++ b/fme/core/ocean_derived_variables.py @@ -176,3 +176,12 @@ def sea_ice_fraction( ) -> torch.Tensor: """Compute the sea ice fraction.""" return data.sea_ice_fraction + + +@register(VariableMetadata("W/m**2", "Surface ocean heat flux"), exists_ok=True) +def hfds( + data: OceanData, + timestep: datetime.timedelta, +) -> torch.Tensor: + """Compute the sea ice fraction.""" + return data.net_downward_surface_heat_flux From 172abe8e2f3ea9cefc2f1c0d86766d2950ab03be Mon Sep 17 00:00:00 2001 From: James Duncan Date: Thu, 5 Feb 2026 13:35:11 -0800 Subject: [PATCH 2/2] Address review comments --- fme/ace/test_ocean_train.py | 2 +- fme/core/ocean_data.py | 4 ++-- fme/core/ocean_derived_variables.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fme/ace/test_ocean_train.py b/fme/ace/test_ocean_train.py index 367be9b6e..5f56be6de 100644 --- a/fme/ace/test_ocean_train.py +++ b/fme/ace/test_ocean_train.py @@ -441,7 +441,7 @@ def test_train_and_inference(tmp_path, very_fast_only: bool): assert best_checkpoint_path.exists() ds_prediction = xr.open_dataset(prediction_output_path, decode_timedelta=False) - for name in ["sst", "thetao_0", "thetao_1", "ocean_heat_content"]: + for name in ["sst", "thetao_0", "thetao_1", "ocean_heat_content", "hfds"]: assert name in ds_prediction # outputs should have some non-null values assert not np.isnan(ds_prediction[name].values).all() diff --git a/fme/core/ocean_data.py b/fme/core/ocean_data.py index 1129ab33f..6f27a077f 100644 --- a/fme/core/ocean_data.py +++ b/fme/core/ocean_data.py @@ -22,7 +22,7 @@ "ocean_sea_ice_fraction": ["ocean_sea_ice_fraction"], "land_fraction": ["land_fraction"], "net_downward_surface_heat_flux": ["hfds"], - "net_downward_surface_heat_flux_by_total_area": ["hfds_total_area"], + "net_downward_surface_heat_flux_total_area": ["hfds_total_area"], "geothermal_heat_flux": ["hfgeou"], "sea_surface_fraction": ["sea_surface_fraction"], } @@ -174,7 +174,7 @@ def net_downward_surface_heat_flux_total_area(self) -> torch.Tensor: """Net heat flux downward across the ocean surface (below the sea-ice), normalized by total grid cell area. """ - return self._get("net_downward_surface_heat_flux_by_total_area") + return self._get("net_downward_surface_heat_flux_total_area") @property def geothermal_heat_flux(self) -> torch.Tensor: diff --git a/fme/core/ocean_derived_variables.py b/fme/core/ocean_derived_variables.py index 22eed4d73..f243135ca 100644 --- a/fme/core/ocean_derived_variables.py +++ b/fme/core/ocean_derived_variables.py @@ -183,5 +183,5 @@ def hfds( data: OceanData, timestep: datetime.timedelta, ) -> torch.Tensor: - """Compute the sea ice fraction.""" + """Compute the net downward surface heat flux.""" return data.net_downward_surface_heat_flux