Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions fme/ace/test_ocean_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def _setup(
"thetao_1",
"sst",
"sea_ice_fraction",
"hfds",
"hfds_total_area",
]

# Add masks and idepths for data generation
Expand Down Expand Up @@ -441,7 +441,7 @@ def test_train_and_inference(tmp_path, very_fast_only: bool):
assert best_checkpoint_path.exists()

ds_prediction = xr.open_dataset(prediction_output_path, decode_timedelta=False)
for name in ["sst", "thetao_0", "thetao_1", "ocean_heat_content"]:
for name in ["sst", "thetao_0", "thetao_1", "ocean_heat_content", "hfds"]:
assert name in ds_prediction
# outputs should have some non-null values
assert not np.isnan(ds_prediction[name].values).all()
Expand Down
19 changes: 14 additions & 5 deletions fme/core/corrector/ocean.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,22 @@ def _force_conserve_ocean_heat_content(
name="ocean_heat_content",
)
try:
# First priority: pre-weighted heat flux in gen_data
net_energy_flux_into_ocean = (
gen.net_downward_surface_heat_flux + forcing.geothermal_heat_flux
) * forcing.sea_surface_fraction
gen.net_downward_surface_heat_flux_total_area
+ forcing.geothermal_heat_flux * forcing.sea_surface_fraction
)
except KeyError:
net_energy_flux_into_ocean = (
input.net_downward_surface_heat_flux + forcing.geothermal_heat_flux
) * forcing.sea_surface_fraction
try:
# Second priority: standard heat flux in gen_data
net_energy_flux_into_ocean = (
gen.net_downward_surface_heat_flux + forcing.geothermal_heat_flux
) * forcing.sea_surface_fraction
except KeyError:
# Third priority: standard heat flux in input_data
net_energy_flux_into_ocean = (
input.net_downward_surface_heat_flux + forcing.geothermal_heat_flux
) * forcing.sea_surface_fraction
energy_flux_global_mean = area_weighted_mean(
net_energy_flux_into_ocean,
keepdim=True,
Expand Down
25 changes: 19 additions & 6 deletions fme/core/corrector/test_ocean.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,14 @@ def test_sea_ice_thickness_correction():


@pytest.mark.parametrize(
"hfds_in_gen_data",
"hfds_type",
[
pytest.param(False, id="hfds_in_input"),
pytest.param(True, id="hfds_in_gen"),
pytest.param("input", id="hfds_in_input"),
pytest.param("gen", id="hfds_in_gen"),
pytest.param("total_area", id="hfds_total_area_in_gen"),
],
)
def test_ocean_heat_content_correction(hfds_in_gen_data):
def test_ocean_heat_content_correction(hfds_type):
config = OceanCorrectorConfig(
ocean_heat_content_correction=OceanHeatContentBudgetConfig(
method="scaled_temperature",
Expand All @@ -196,6 +197,8 @@ def test_ocean_heat_content_correction(hfds_in_gen_data):
idepth = torch.tensor([2.5, 10, 20])
depth_coordinate = DepthCoordinate(idepth, mask)

sea_surface_fraction = mask[:, :, :, 0]

input_data_dict = {
"thetao_0": torch.ones(nsamples, nlat, nlon),
"thetao_1": torch.ones(nsamples, nlat, nlon),
Expand All @@ -206,13 +209,23 @@ def test_ocean_heat_content_correction(hfds_in_gen_data):
"thetao_1": torch.ones(nsamples, nlat, nlon) * 2,
"sst": torch.ones(nsamples, nlat, nlon) * 2 + 273.15,
}
if hfds_in_gen_data:
if hfds_type == "gen":
gen_data_dict["hfds"] = torch.ones(nsamples, nlat, nlon)
elif hfds_type == "total_area":
# hfds_total_area is already weighted by sea_surface_fraction also
# include hfds with a different value to verify hfds_total_area takes
# priority
gen_data_dict["hfds"] = (
torch.ones(nsamples, nlat, nlon) * 100
) # should be ignored
gen_data_dict["hfds_total_area"] = (
torch.ones(nsamples, nlat, nlon) * sea_surface_fraction
)
else:
input_data_dict["hfds"] = torch.ones(nsamples, nlat, nlon)
forcing_data_dict = {
"hfgeou": torch.ones(nsamples, nlat, nlon),
"sea_surface_fraction": mask[:, :, :, 0],
"sea_surface_fraction": sea_surface_fraction,
}
input_data = OceanData(input_data_dict, depth_coordinate)
gen_data = OceanData(gen_data_dict, depth_coordinate)
Expand Down
17 changes: 16 additions & 1 deletion fme/core/ocean_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"ocean_sea_ice_fraction": ["ocean_sea_ice_fraction"],
"land_fraction": ["land_fraction"],
"net_downward_surface_heat_flux": ["hfds"],
"net_downward_surface_heat_flux_total_area": ["hfds_total_area"],
"geothermal_heat_flux": ["hfgeou"],
"sea_surface_fraction": ["sea_surface_fraction"],
}
Expand Down Expand Up @@ -159,7 +160,21 @@ def sea_surface_fraction(self) -> torch.Tensor:
@property
def net_downward_surface_heat_flux(self) -> torch.Tensor:
"""Net heat flux downward across the ocean surface (below the sea-ice)."""
return self._get("net_downward_surface_heat_flux")
try:
return self._get("net_downward_surface_heat_flux")
except KeyError:
# derive from the sea-surface-fraction-weighted version
return (
self.net_downward_surface_heat_flux_total_area
/ self.sea_surface_fraction
)

@property
def net_downward_surface_heat_flux_total_area(self) -> torch.Tensor:
"""Net heat flux downward across the ocean surface (below the sea-ice),
normalized by total grid cell area.
"""
return self._get("net_downward_surface_heat_flux_total_area")

@property
def geothermal_heat_flux(self) -> torch.Tensor:
Expand Down
9 changes: 9 additions & 0 deletions fme/core/ocean_derived_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,12 @@ def sea_ice_fraction(
) -> torch.Tensor:
"""Compute the sea ice fraction."""
return data.sea_ice_fraction


@register(VariableMetadata("W/m**2", "Surface ocean heat flux"), exists_ok=True)
def hfds(
data: OceanData,
timestep: datetime.timedelta,
) -> torch.Tensor:
"""Compute the net downward surface heat flux."""
return data.net_downward_surface_heat_flux