Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions fme/ace/step/fcn3.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ class FCN3StepConfig(StepConfigABC):
default_factory=lambda: AtmosphereCorrectorConfig()
)
next_step_forcing_names: list[str] = dataclasses.field(default_factory=list)
prescribed_prognostic_names: list[str] = dataclasses.field(default_factory=list)
residual_prediction: bool = False

def __post_init__(self):
Expand Down Expand Up @@ -203,6 +204,12 @@ def __post_init__(self):
self.forcing_names + self.atmosphere_input_names + self.surface_input_names
)
self.out_names = self.atmosphere_output_names + self.surface_output_names
for name in self.prescribed_prognostic_names:
if name not in self.out_names:
raise ValueError(
f"prescribed_prognostic_name '{name}' must be in out_names: "
f"{self.out_names}"
)

@property
def n_ic_timesteps(self) -> int:
Expand Down Expand Up @@ -265,9 +272,11 @@ def output_names(self) -> list[str]:
def next_step_input_names(self) -> list[str]:
"""Names of variables provided in next_step_input_data."""
input_only_names = set(self.input_names).difference(self.output_names)
if self.ocean is None:
return list(input_only_names)
return list(input_only_names.union(self.ocean.forcing_names))
result = set(input_only_names)
if self.ocean is not None:
result = result.union(self.ocean.forcing_names)
result = result.union(self.prescribed_prognostic_names)
return list(result)

@property
def loss_names(self) -> list[str]:
Expand All @@ -285,6 +294,16 @@ def replace_ocean(self, ocean: OceanConfig | None):
def get_ocean(self) -> OceanConfig | None:
return self.ocean

def replace_prescribed_prognostic_names(self, names: list[str]) -> None:
"""Replace prescribed prognostic names (e.g. when loading from checkpoint)."""
for name in names:
if name not in self.out_names:
raise ValueError(
f"prescribed_prognostic_name '{name}' must be in out_names: "
f"{self.out_names}"
)
self.prescribed_prognostic_names = names

@classmethod
def _remove_deprecated_keys(cls, state: dict[str, Any]) -> dict[str, Any]:
state_copy = state.copy()
Expand Down Expand Up @@ -470,6 +489,7 @@ def network_call(input_norm: TensorDict) -> TensorDict:
ocean=self.ocean,
residual_prediction=self._config.residual_prediction,
prognostic_names=self.prognostic_names,
prescribed_prognostic_names=self._config.prescribed_prognostic_names,
)

def get_regularizer_loss(self):
Expand Down
41 changes: 41 additions & 0 deletions fme/ace/stepper/single_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class SingleModuleStepperConfig:
loss: The loss configuration.
corrector: The corrector configuration.
next_step_forcing_names: Names of forcing variables for the next timestep.
prescribed_prognostic_names: Prognostic variable names to overwrite from
forcing data at each step during inference (e.g. air_temperature_7).
loss_normalization: The normalization configuration for the loss.
residual_normalization: Optional alternative to configure loss normalization.
If provided, it will be used for all *prognostic* variables in loss scaling.
Expand All @@ -123,13 +125,20 @@ class SingleModuleStepperConfig:
default_factory=lambda: AtmosphereCorrectorConfig()
)
next_step_forcing_names: list[str] = dataclasses.field(default_factory=list)
prescribed_prognostic_names: list[str] = dataclasses.field(default_factory=list)
loss_normalization: NormalizationConfig | None = None
residual_normalization: NormalizationConfig | None = None
multi_call: MultiCallConfig | None = None
include_multi_call_in_loss: bool = False
residual_prediction: bool = False

def __post_init__(self):
for name in self.prescribed_prognostic_names:
if name not in self.out_names:
raise ValueError(
f"prescribed_prognostic_name '{name}' must be in out_names: "
f"{self.out_names}"
)
for name in self.next_step_forcing_names:
if name not in self.in_names:
raise ValueError(
Expand Down Expand Up @@ -300,6 +309,7 @@ def _to_single_module_step_config(
ocean=self.ocean,
corrector=self.corrector,
next_step_forcing_names=self.next_step_forcing_names,
prescribed_prognostic_names=self.prescribed_prognostic_names,
residual_prediction=self.residual_prediction,
)

Expand Down Expand Up @@ -713,6 +723,10 @@ def replace_ocean(self, ocean: OceanConfig | None):
def get_ocean(self) -> OceanConfig | None:
return self.step.get_ocean()

def replace_prescribed_prognostic_names(self, names: list[str]) -> None:
"""Replace prescribed prognostic names (e.g. when loading from checkpoint)."""
self.step.replace_prescribed_prognostic_names(names)

def replace_multi_call(
self, multi_call: MultiCallConfig | None, state: dict[str, Any]
) -> dict[str, Any]:
Expand Down Expand Up @@ -969,6 +983,21 @@ def replace_ocean(self, ocean: OceanConfig | None):
new_stepper._step_obj.load_state(self._step_obj.get_state())
self._step_obj = new_stepper._step_obj

def replace_prescribed_prognostic_names(self, names: list[str]) -> None:
"""
Replace prescribed prognostic names (e.g. when loading from checkpoint).

Args:
names: The new list of prescribed prognostic variable names.
"""
self._config.replace_prescribed_prognostic_names(names)
new_stepper: Stepper = self._config.get_stepper(
dataset_info=self._dataset_info,
apply_parameter_init=False,
)
new_stepper._step_obj.load_state(self._step_obj.get_state())
self._step_obj = new_stepper._step_obj

def replace_derived_forcings(self, derived_forcings: DerivedForcingsConfig):
"""
Replace the derived forcings configuration with a new one.
Expand Down Expand Up @@ -1726,11 +1755,14 @@ class StepperOverrideConfig:
serialized stepper.
derived_forcings: Derived forcings configuration to override that used in
producing a serialized stepper.
prescribed_prognostic_names: List of prognostic variable names to overwrite
from forcing at each step during inference (e.g. ["air_temperature_7"]).
"""

ocean: Literal["keep"] | OceanConfig | None = "keep"
multi_call: Literal["keep"] | MultiCallConfig | None = "keep"
derived_forcings: Literal["keep"] | DerivedForcingsConfig = "keep"
prescribed_prognostic_names: Literal["keep"] | list[str] = "keep"


def load_stepper_config(
Expand Down Expand Up @@ -1792,4 +1824,13 @@ def load_stepper(
"derived_forcings configuration."
)
stepper.replace_derived_forcings(override_config.derived_forcings)

if override_config.prescribed_prognostic_names != "keep":
logging.info(
"Overriding prescribed_prognostic_names with %s.",
override_config.prescribed_prognostic_names,
)
stepper.replace_prescribed_prognostic_names(
override_config.prescribed_prognostic_names
)
return stepper
Loading