diff --git a/fme/ace/step/fcn3.py b/fme/ace/step/fcn3.py index 240f2aa17..439524d7c 100644 --- a/fme/ace/step/fcn3.py +++ b/fme/ace/step/fcn3.py @@ -173,6 +173,7 @@ class FCN3StepConfig(StepConfigABC): default_factory=lambda: AtmosphereCorrectorConfig() ) next_step_forcing_names: list[str] = dataclasses.field(default_factory=list) + prescribed_prognostic_names: list[str] = dataclasses.field(default_factory=list) residual_prediction: bool = False def __post_init__(self): @@ -203,6 +204,12 @@ def __post_init__(self): self.forcing_names + self.atmosphere_input_names + self.surface_input_names ) self.out_names = self.atmosphere_output_names + self.surface_output_names + for name in self.prescribed_prognostic_names: + if name not in self.out_names: + raise ValueError( + f"prescribed_prognostic_name '{name}' must be in out_names: " + f"{self.out_names}" + ) @property def n_ic_timesteps(self) -> int: @@ -265,9 +272,11 @@ def output_names(self) -> list[str]: def next_step_input_names(self) -> list[str]: """Names of variables provided in next_step_input_data.""" input_only_names = set(self.input_names).difference(self.output_names) - if self.ocean is None: - return list(input_only_names) - return list(input_only_names.union(self.ocean.forcing_names)) + result = set(input_only_names) + if self.ocean is not None: + result = result.union(self.ocean.forcing_names) + result = result.union(self.prescribed_prognostic_names) + return list(result) @property def loss_names(self) -> list[str]: @@ -285,6 +294,16 @@ def replace_ocean(self, ocean: OceanConfig | None): def get_ocean(self) -> OceanConfig | None: return self.ocean + def replace_prescribed_prognostic_names(self, names: list[str]) -> None: + """Replace prescribed prognostic names (e.g. when loading from checkpoint).""" + for name in names: + if name not in self.out_names: + raise ValueError( + f"prescribed_prognostic_name '{name}' must be in out_names: " + f"{self.out_names}" + ) + self.prescribed_prognostic_names = names + @classmethod def _remove_deprecated_keys(cls, state: dict[str, Any]) -> dict[str, Any]: state_copy = state.copy() @@ -470,6 +489,7 @@ def network_call(input_norm: TensorDict) -> TensorDict: ocean=self.ocean, residual_prediction=self._config.residual_prediction, prognostic_names=self.prognostic_names, + prescribed_prognostic_names=self._config.prescribed_prognostic_names, ) def get_regularizer_loss(self): diff --git a/fme/ace/stepper/single_module.py b/fme/ace/stepper/single_module.py index eeaee0335..7d5277603 100644 --- a/fme/ace/stepper/single_module.py +++ b/fme/ace/stepper/single_module.py @@ -100,6 +100,8 @@ class SingleModuleStepperConfig: loss: The loss configuration. corrector: The corrector configuration. next_step_forcing_names: Names of forcing variables for the next timestep. + prescribed_prognostic_names: Prognostic variable names to overwrite from + forcing data at each step during inference (e.g. air_temperature_7). loss_normalization: The normalization configuration for the loss. residual_normalization: Optional alternative to configure loss normalization. If provided, it will be used for all *prognostic* variables in loss scaling. @@ -123,6 +125,7 @@ class SingleModuleStepperConfig: default_factory=lambda: AtmosphereCorrectorConfig() ) next_step_forcing_names: list[str] = dataclasses.field(default_factory=list) + prescribed_prognostic_names: list[str] = dataclasses.field(default_factory=list) loss_normalization: NormalizationConfig | None = None residual_normalization: NormalizationConfig | None = None multi_call: MultiCallConfig | None = None @@ -130,6 +133,12 @@ class SingleModuleStepperConfig: residual_prediction: bool = False def __post_init__(self): + for name in self.prescribed_prognostic_names: + if name not in self.out_names: + raise ValueError( + f"prescribed_prognostic_name '{name}' must be in out_names: " + f"{self.out_names}" + ) for name in self.next_step_forcing_names: if name not in self.in_names: raise ValueError( @@ -300,6 +309,7 @@ def _to_single_module_step_config( ocean=self.ocean, corrector=self.corrector, next_step_forcing_names=self.next_step_forcing_names, + prescribed_prognostic_names=self.prescribed_prognostic_names, residual_prediction=self.residual_prediction, ) @@ -713,6 +723,10 @@ def replace_ocean(self, ocean: OceanConfig | None): def get_ocean(self) -> OceanConfig | None: return self.step.get_ocean() + def replace_prescribed_prognostic_names(self, names: list[str]) -> None: + """Replace prescribed prognostic names (e.g. when loading from checkpoint).""" + self.step.replace_prescribed_prognostic_names(names) + def replace_multi_call( self, multi_call: MultiCallConfig | None, state: dict[str, Any] ) -> dict[str, Any]: @@ -969,6 +983,21 @@ def replace_ocean(self, ocean: OceanConfig | None): new_stepper._step_obj.load_state(self._step_obj.get_state()) self._step_obj = new_stepper._step_obj + def replace_prescribed_prognostic_names(self, names: list[str]) -> None: + """ + Replace prescribed prognostic names (e.g. when loading from checkpoint). + + Args: + names: The new list of prescribed prognostic variable names. + """ + self._config.replace_prescribed_prognostic_names(names) + new_stepper: Stepper = self._config.get_stepper( + dataset_info=self._dataset_info, + apply_parameter_init=False, + ) + new_stepper._step_obj.load_state(self._step_obj.get_state()) + self._step_obj = new_stepper._step_obj + def replace_derived_forcings(self, derived_forcings: DerivedForcingsConfig): """ Replace the derived forcings configuration with a new one. @@ -1726,11 +1755,14 @@ class StepperOverrideConfig: serialized stepper. derived_forcings: Derived forcings configuration to override that used in producing a serialized stepper. + prescribed_prognostic_names: List of prognostic variable names to overwrite + from forcing at each step during inference (e.g. ["air_temperature_7"]). """ ocean: Literal["keep"] | OceanConfig | None = "keep" multi_call: Literal["keep"] | MultiCallConfig | None = "keep" derived_forcings: Literal["keep"] | DerivedForcingsConfig = "keep" + prescribed_prognostic_names: Literal["keep"] | list[str] = "keep" def load_stepper_config( @@ -1792,4 +1824,13 @@ def load_stepper( "derived_forcings configuration." ) stepper.replace_derived_forcings(override_config.derived_forcings) + + if override_config.prescribed_prognostic_names != "keep": + logging.info( + "Overriding prescribed_prognostic_names with %s.", + override_config.prescribed_prognostic_names, + ) + stepper.replace_prescribed_prognostic_names( + override_config.prescribed_prognostic_names + ) return stepper diff --git a/fme/ace/stepper/test_single_module.py b/fme/ace/stepper/test_single_module.py index b6e02cfe4..028eab7bb 100644 --- a/fme/ace/stepper/test_single_module.py +++ b/fme/ace/stepper/test_single_module.py @@ -31,6 +31,7 @@ from fme.ace.stepper.single_module import ( AtmosphereCorrectorConfig, EpochNotProvidedError, + SingleModuleStepperConfig, Stepper, StepperConfig, StepperOverrideConfig, @@ -1207,6 +1208,144 @@ def test_predict_with_forcing(n_ensemble): assert new_input_state.time.equals(output.time[:, -1:]) +def test_predict_with_prescribed_prognostic(): + """Prescribed prognostic "a" is overwritten from forcing at each step.""" + stepper = _get_stepper( + ["a", "b"], + ["a"], + module_name="ChannelSum", + prescribed_prognostic_names=["a"], + ) + n_steps = 3 + input_data, forcing_data = get_data_for_predict(n_steps, forcing_names=["a", "b"]) + output, _ = stepper.predict(input_data, forcing_data) + # Output "a" should be the forcing value at each step, not the model prediction. + assert output.data["a"].size(dim=1) == n_steps + # Forcing has shape [batch, n_ic + n_steps, ...]; + # output steps use indices 1..n_steps. + expected_a = forcing_data.data["a"][:, 1 : n_steps + 1] + torch.testing.assert_close(output.data["a"], expected_a) + + +def test_prescribed_prognostic_config_validation_raises(): + """SingleModuleStepperConfig raises when prescribed_prognostic_name is not in + out_names.""" + with pytest.raises(ValueError) as err: + SingleModuleStepperConfig( + builder=ModuleSelector( + type="prebuilt", config={"module": torch.nn.Identity()} + ), + in_names=["a"], + out_names=["a"], + normalization=NormalizationConfig(means={"a": 0.0}, stds={"a": 1.0}), + prescribed_prognostic_names=["b"], + ) + assert "prescribed_prognostic_name" in str(err.value) + assert "out_names" in str(err.value) + + +def test_predict_with_prescribed_prognostic_multiple_variables(): + """Multiple prescribed prognostics are overwritten from forcing.""" + # Use AddOne (2 in -> 2 out) so we can prescribe both "a" and "b". + stepper = _get_stepper( + ["a", "b"], + ["a", "b"], + module_name="AddOne", + prescribed_prognostic_names=["a", "b"], + ) + n_steps = 2 + n_samples = 3 + index = xr.date_range("2000", freq="6h", periods=n_steps + 1, use_cftime=True) + forcing_time = xr.DataArray(np.stack(n_samples * [index]), dims=["sample", "time"]) + input_time = forcing_time.isel(time=[0]) + # Initial condition must include all prognostics (a, b). + input_data = BatchData.new_on_device( + data={ + "a": torch.rand(n_samples, 1, 5, 5).to(DEVICE), + "b": torch.rand(n_samples, 1, 5, 5).to(DEVICE), + }, + time=input_time, + labels=None, + ).get_start(prognostic_names=["a", "b"], n_ic_timesteps=1) + forcing_data = BatchData.new_on_device( + data={ + "a": torch.rand(3, n_steps + 1, 5, 5).to(DEVICE), + "b": torch.rand(3, n_steps + 1, 5, 5).to(DEVICE), + }, + time=forcing_time, + labels=None, + ) + output, _ = stepper.predict(input_data, forcing_data) + expected_a = forcing_data.data["a"][:, 1 : n_steps + 1] + expected_b = forcing_data.data["b"][:, 1 : n_steps + 1] + torch.testing.assert_close(output.data["a"], expected_a) + torch.testing.assert_close(output.data["b"], expected_b) + + +def test_predict_with_prescribed_prognostic_and_ocean(): + """Prescribed overwrite happens after ocean; both can be used together.""" + # Ocean prescribes "a" over mask; we also prescribe "b" from forcing everywhere. + stepper = _get_stepper( + ["a", "mask"], + ["a", "b"], + module_name="AddOne", + ocean_config=OceanConfig("a", "mask"), + prescribed_prognostic_names=["b"], + ) + n_steps = 2 + input_data, forcing_data = get_data_for_predict( + n_steps, forcing_names=["a", "b", "mask"] + ) + # Where mask==1, ocean overwrites "a" with forcing "a"; "b" is always overwritten. + output, _ = stepper.predict(input_data, forcing_data) + expected_b = forcing_data.data["b"][:, 1 : n_steps + 1] + torch.testing.assert_close(output.data["b"], expected_b) + # "a" should be prescribed by ocean where mask==1 + # (same as test_predict_with_ocean logic) + for n in range(n_steps): + previous_a = ( + input_data.as_batch_data().data["a"][:, 0] + if n == 0 + else output.data["a"][:, n - 1] + ) + expected_a_n = torch.where( + torch.round(forcing_data.data["mask"][:, n + 1]).to(int) == 1, + forcing_data.data["a"][:, n + 1], + previous_a + 1, + ) + torch.testing.assert_close(output.data["a"][:, n], expected_a_n) + + +def test_get_forcing_window_data_requirements_includes_prescribed_names(): + """Forcing window data requirements include prescribed_prognostic_names.""" + config = StepperConfig( + step=StepSelector( + type="single_module", + config=dataclasses.asdict( + SingleModuleStepConfig( + builder=ModuleSelector( + type="prebuilt", config={"module": torch.nn.Identity()} + ), + in_names=["a", "b"], + out_names=["a"], + normalization=NetworkAndLossNormalizationConfig( + network=NormalizationConfig( + means={"a": 0.0, "b": 0.0}, + stds={"a": 1.0, "b": 1.0}, + ), + ), + prescribed_prognostic_names=["a"], + ) + ), + ), + loss=StepLossConfig(type="MSE"), + derived_forcings=DerivedForcingsConfig(), + ) + requirements = config.get_forcing_window_data_requirements(n_forward_steps=5) + assert "a" in requirements.names + assert "b" in requirements.names + + def test_predict_with_ocean(): stepper = _get_stepper(["a"], ["a"], ocean_config=OceanConfig("a", "mask")) n_steps = 3 @@ -1510,6 +1649,87 @@ def test_load_stepper_and_load_stepper_config( assert isinstance(stepper.forcing_deriver, ForcingDeriver) +def _get_inner_single_module_config(stepper: Stepper): + """Get the inner SingleModuleStep config from a stepper + (MultiCallStep or single).""" + from fme.core.step.multi_call import MultiCallStep + + if isinstance(stepper._step_obj, MultiCallStep): + return stepper._step_obj._wrapped_step.config + return stepper._step_obj.config + + +def validate_stepper_prescribed_prognostic_names( + stepper: Stepper, expected: list[str] +) -> None: + """Assert the stepper's inner step config has the given + prescribed_prognostic_names.""" + config = _get_inner_single_module_config(stepper) + assert config.prescribed_prognostic_names == expected + + +def test_load_stepper_with_prescribed_prognostic_override( + tmp_path: pathlib.Path, very_fast_only: bool +): + """Loading with StepperOverrideConfig(prescribed_prognostic_names=...) applies the + override.""" + if very_fast_only: + pytest.skip("Skipping non-fast tests") + in_names = ["co2", "var", "a", "b"] + out_names = ["var", "a"] + stepper_path = tmp_path / "stepper" + horizontal = [DimSize("grid_yt", 4), DimSize("grid_xt", 8)] + dim_sizes = DimSizes( + n_time=9, + horizontal=horizontal, + nz_interface=4, + ) + save_plus_one_stepper( + stepper_path, + in_names, + out_names, + normalization_names=set(in_names + out_names), + mean=0.0, + std=1.0, + data_shape=dim_sizes.shape_nd, + ) + + # Load without override: prescribed_prognostic_names should be [] (default). + stepper = load_stepper(stepper_path) + validate_stepper_prescribed_prognostic_names(stepper, []) + + # Load with override: prescribed_prognostic_names should be ["var"]. + stepper_override = StepperOverrideConfig(prescribed_prognostic_names=["var"]) + stepper = load_stepper(stepper_path, stepper_override) + validate_stepper_prescribed_prognostic_names(stepper, ["var"]) + + # Predict with forcing including "var"; output "var" should come from forcing. + n_steps = 2 + n_samples = 3 + index = xr.date_range("2000", freq="6h", periods=n_steps + 1, use_cftime=True) + forcing_time = xr.DataArray(np.stack(n_samples * [index]), dims=["sample", "time"]) + input_time = forcing_time.isel(time=[0]) + input_data = BatchData.new_on_device( + data={ + "var": torch.rand(n_samples, 1, 4, 8).to(DEVICE), + "a": torch.rand(n_samples, 1, 4, 8).to(DEVICE), + }, + time=input_time, + labels=None, + ).get_start(prognostic_names=["var", "a"], n_ic_timesteps=1) + forcing_data = BatchData.new_on_device( + data={ + name: torch.rand(3, n_steps + 1, 4, 8).to(DEVICE) + for name in ["co2", "var", "a", "b"] + }, + time=forcing_time, + labels=None, + ) + output, _ = stepper.predict(input_data, forcing_data) + expected_var = forcing_data.data["var"][:, 1 : n_steps + 1] + torch.testing.assert_close(output.data["var"], expected_var) + + def get_regression_stepper_and_data( crps_training: bool = False, ) -> tuple[Stepper, TrainStepperConfig, BatchData]: diff --git a/fme/core/step/multi_call.py b/fme/core/step/multi_call.py index 422308334..ec3614d1b 100644 --- a/fme/core/step/multi_call.py +++ b/fme/core/step/multi_call.py @@ -194,6 +194,9 @@ def replace_ocean(self, ocean: OceanConfig | None): def get_ocean(self) -> OceanConfig | None: return self.wrapped_step.get_ocean() + def replace_prescribed_prognostic_names(self, names: list[str]) -> None: + self.wrapped_step.replace_prescribed_prognostic_names(names) + def replace_multi_call(self, multi_call: MultiCallConfig | None): self.config = multi_call diff --git a/fme/core/step/single_module.py b/fme/core/step/single_module.py index a7671ffab..3d0c1a6cd 100644 --- a/fme/core/step/single_module.py +++ b/fme/core/step/single_module.py @@ -50,6 +50,8 @@ class SingleModuleStepConfig(StepConfigABC): ocean: The ocean configuration. corrector: The corrector configuration. next_step_forcing_names: Names of forcing variables for the next timestep. + prescribed_prognostic_names: Prognostic variable names to overwrite from + forcing data at each step (e.g. for inference with observed values). residual_prediction: Whether to use residual prediction. """ @@ -63,10 +65,17 @@ class SingleModuleStepConfig(StepConfigABC): default_factory=lambda: AtmosphereCorrectorConfig() ) next_step_forcing_names: list[str] = dataclasses.field(default_factory=list) + prescribed_prognostic_names: list[str] = dataclasses.field(default_factory=list) residual_prediction: bool = False def __post_init__(self): self.crps_training = None # unused, kept for backwards compatibility + for name in self.prescribed_prognostic_names: + if name not in self.out_names: + raise ValueError( + f"prescribed_prognostic_name '{name}' must be in out_names: " + f"{self.out_names}" + ) for name in self.next_step_forcing_names: if name not in self.in_names: raise ValueError( @@ -153,9 +162,11 @@ def output_names(self) -> list[str]: def next_step_input_names(self) -> list[str]: """Names of variables provided in next_step_input_data.""" input_only_names = set(self.input_names).difference(self.output_names) - if self.ocean is None: - return list(input_only_names) - return list(input_only_names.union(self.ocean.forcing_names)) + result = set(input_only_names) + if self.ocean is not None: + result = result.union(self.ocean.forcing_names) + result = result.union(self.prescribed_prognostic_names) + return list(result) @property def loss_names(self) -> list[str]: @@ -173,6 +184,16 @@ def replace_ocean(self, ocean: OceanConfig | None): def get_ocean(self) -> OceanConfig | None: return self.ocean + def replace_prescribed_prognostic_names(self, names: list[str]) -> None: + """Replace prescribed prognostic names (e.g. when loading from checkpoint).""" + for name in names: + if name not in self.out_names: + raise ValueError( + f"prescribed_prognostic_name '{name}' must be in out_names: " + f"{self.out_names}" + ) + self.prescribed_prognostic_names = names + @classmethod def _remove_deprecated_keys(cls, state: dict[str, Any]) -> dict[str, Any]: state_copy = state.copy() @@ -353,6 +374,7 @@ def network_call(input_norm: TensorDict) -> TensorDict: ocean=self.ocean, residual_prediction=self._config.residual_prediction, prognostic_names=self.prognostic_names, + prescribed_prognostic_names=self._config.prescribed_prognostic_names, ) def get_regularizer_loss(self): @@ -394,6 +416,7 @@ def step_with_adjustments( ocean: Ocean | None, residual_prediction: bool, prognostic_names: list[str], + prescribed_prognostic_names: list[str] | None = None, ) -> TensorDict: """ Step the model forward one timestep given input data. @@ -413,10 +436,14 @@ def step_with_adjustments( ocean: The ocean model to use. residual_prediction: Whether to use residual prediction. prognostic_names: Names of prognostic variables. + prescribed_prognostic_names: Prognostic names to overwrite from + next_step_input_data after the ocean step (e.g. for inference). Returns: The denormalized output data at the next time step. """ + if prescribed_prognostic_names is None: + prescribed_prognostic_names = [] input_norm = normalizer.normalize(input) output_norm = network_calls(input_norm) if residual_prediction: @@ -426,4 +453,7 @@ def step_with_adjustments( output = corrector(input, output, next_step_input_data) if ocean is not None: output = ocean(input, output, next_step_input_data) + for name in prescribed_prognostic_names: + if name in next_step_input_data: + output = {**output, name: next_step_input_data[name]} return output diff --git a/fme/core/step/step.py b/fme/core/step/step.py index c011e1a93..62cdf4064 100644 --- a/fme/core/step/step.py +++ b/fme/core/step/step.py @@ -108,6 +108,10 @@ def replace_ocean(self, ocean: OceanConfig | None): def get_ocean(self) -> OceanConfig | None: pass + def replace_prescribed_prognostic_names(self, names: list[str]) -> None: + """Replace prescribed prognostic names (e.g. when loading from checkpoint).""" + pass + @abc.abstractmethod def load(self): """ @@ -211,6 +215,10 @@ def replace_ocean(self, ocean: OceanConfig | None): def get_ocean(self) -> OceanConfig | None: return self._step_config_instance.get_ocean() + def replace_prescribed_prognostic_names(self, names: list[str]) -> None: + self._step_config_instance.replace_prescribed_prognostic_names(names) + self.config = dataclasses.asdict(self._step_config_instance) + def load(self): self._step_config_instance.load() self.config = dataclasses.asdict(self._step_config_instance) diff --git a/fme/core/step/test_step.py b/fme/core/step/test_step.py index a0eb4e3f6..3cb5e53a0 100644 --- a/fme/core/step/test_step.py +++ b/fme/core/step/test_step.py @@ -653,3 +653,72 @@ def test_input_output_names_secondary_decoder_conflict(conflict: str): ), ) assert f"secondary_diagnostic_name is an {conflict} variable:" in str(err.value) + + +def test_prescribed_prognostic_names_must_be_in_out_names(): + """SingleModuleStepConfig raises when prescribed_prognostic_name is not in out_names + .""" + normalization = get_network_and_loss_normalization_config( + names=["a", "b", "c"], + ) + with pytest.raises(ValueError) as err: + SingleModuleStepConfig( + builder=ModuleSelector( + type="SphericalFourierNeuralOperatorNet", + config={"scale_factor": 1, "embed_dim": 4, "num_layers": 2}, + ), + in_names=["a", "b"], + out_names=["a"], + normalization=normalization, + prescribed_prognostic_names=["c"], + ) + assert "prescribed_prognostic_name" in str(err.value) + assert "out_names" in str(err.value) + assert "c" in str(err.value) + + +def test_step_with_prescribed_prognostic_overwrites_output(): + """Step output is overwritten for prescribed_prognostic_names from + next_step_input_data.""" + normalization = get_network_and_loss_normalization_config( + names=["forcing_shared", "forcing_rad", "diagnostic_main", "diagnostic_rad"], + ) + config = StepSelector( + type="single_module", + config=dataclasses.asdict( + SingleModuleStepConfig( + builder=ModuleSelector( + type="SphericalFourierNeuralOperatorNet", + config={ + "scale_factor": 1, + "embed_dim": 4, + "num_layers": 2, + }, + ), + in_names=["forcing_shared", "forcing_rad"], + out_names=["diagnostic_main", "diagnostic_rad"], + normalization=normalization, + prescribed_prognostic_names=["diagnostic_main"], + ), + ), + ) + img_shape = DEFAULT_IMG_SHAPE + n_samples = 2 + step = get_step(config, img_shape) + input_data = get_tensor_dict(step.input_names, img_shape, n_samples) + next_step_input_data = get_tensor_dict( + step.next_step_input_names, img_shape, n_samples + ) + prescribed_value = torch.full( + (n_samples,) + img_shape, 42.0, device=fme.get_device() + ) + next_step_input_data["diagnostic_main"] = prescribed_value + output = step.step( + args=StepArgs( + input=input_data, + next_step_input_data=next_step_input_data, + labels=None, + ), + wrapper=lambda x: x, + ) + torch.testing.assert_close(output["diagnostic_main"], prescribed_value)