diff --git a/fme/ace/stepper/single_module.py b/fme/ace/stepper/single_module.py
index 3399a9fbf..03dcc6939 100644
--- a/fme/ace/stepper/single_module.py
+++ b/fme/ace/stepper/single_module.py
@@ -1,3 +1,4 @@
+import abc
 import contextlib
 import dataclasses
 import datetime
@@ -801,6 +802,559 @@ def probabilities_from_time_length(value: TimeLength) -> TimeLengthProbabilities
     return TimeLengthProbabilities.from_constant(value)
 
 
+class InferenceStepper(abc.ABC):
+    @property
+    @abc.abstractmethod
+    def config(self) -> StepperConfig: ...
+
+    @property
+    @abc.abstractmethod
+    def derive_func(self) -> Callable[[TensorMapping, TensorMapping], TensorDict]: ...
+
+    @property
+    @abc.abstractmethod
+    def surface_temperature_name(self) -> str | None: ...
+
+    @property
+    @abc.abstractmethod
+    def ocean_fraction_name(self) -> str | None: ...
+
+    @abc.abstractmethod
+    def prescribe_sst(
+        self,
+        mask_data: TensorMapping,
+        gen_data: TensorMapping,
+        target_data: TensorMapping,
+    ) -> TensorDict:
+        """
+        Prescribe sea surface temperature onto the generated surface temperature field.
+
+        Args:
+            mask_data: Source for the prescriber mask field.
+            gen_data: Contains the generated surface temperature field.
+            target_data: Contains the target surface temperature that will
+                be prescribed onto the generated one according to the mask.
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def training_dataset_info(self) -> DatasetInfo: ...
+
+    @property
+    @abc.abstractmethod
+    def training_variable_metadata(self) -> Mapping[str, VariableMetadata]: ...
+
+    @property
+    @abc.abstractmethod
+    def training_history(self) -> TrainingHistory: ...
+
+    @property
+    @abc.abstractmethod
+    def effective_loss_scaling(self) -> TensorDict:
+        """
+        Effective loss scalings used to normalize outputs before computing loss.
+        y_loss_normalized_i = (y_i - y_mean_i) / loss_scaling_i
+        where loss_scaling_i = loss_normalizer_std_i / weight_i.
+        """
+        ...
+
+    @abc.abstractmethod
+    def replace_multi_call(self, multi_call: MultiCallConfig | None):
+        """
+        Replace the MultiCall object with a new one. Note this is only
+        meant to be used at inference time and may result in the loss
+        function being unusable.
+
+        Args:
+            multi_call: The new multi_call configuration or None.
+        """
+        ...
+
+    @abc.abstractmethod
+    def replace_ocean(self, ocean: OceanConfig | None):
+        """
+        Replace the ocean model with a new one.
+
+        Args:
+            ocean: The new ocean model configuration or None.
+        """
+        ...
+
+    @abc.abstractmethod
+    def replace_derived_forcings(self, derived_forcings: DerivedForcingsConfig):
+        """
+        Replace the derived forcings configuration with a new one.
+
+        Args:
+            derived_forcings: The new derived forcings configuration.
+        """
+        ...
+
+    @abc.abstractmethod
+    def get_base_weights(self) -> Weights | None:
+        """
+        Get the base weights of the stepper.
+
+        Returns:
+            A list of weight dictionaries for each module in the stepper.
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def prognostic_names(self) -> list[str]: ...
+
+    @property
+    @abc.abstractmethod
+    def out_names(self) -> list[str]: ...
+
+    @property
+    @abc.abstractmethod
+    def loss_names(self) -> list[str]: ...
+
+    @property
+    @abc.abstractmethod
+    def n_ic_timesteps(self) -> int: ...
+
+    @property
+    @abc.abstractmethod
+    def modules(self) -> nn.ModuleList:
+        """
+        Returns:
+            A list of modules being trained.
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def normalizer(self) -> StandardNormalizer: ...
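+
+    # Illustrative only (not part of the interface): a worked example of the
+    # effective_loss_scaling formula above, with hypothetical numbers. For a
+    # variable with loss_normalizer_std = 2.0 and weight = 4.0, we get
+    # loss_scaling = 2.0 / 4.0 = 0.5, so the loss sees (y - y_mean) / 0.5;
+    # larger weights therefore amplify a variable's contribution to the loss.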
+
+    @abc.abstractmethod
+    def step(
+        self,
+        args: StepArgs,
+        wrapper: Callable[[nn.Module], nn.Module] = lambda x: x,
+    ) -> TensorDict:
+        """
+        Step the model forward one timestep given input data.
+
+        Args:
+            args: The arguments to the step function.
+            wrapper: Wrapper to apply over each nn.Module before calling.
+
+        Returns:
+            The denormalized output data at the next time step.
+        """
+        ...
+
+    @abc.abstractmethod
+    def get_prediction_generator(
+        self,
+        initial_condition: PrognosticState,
+        forcing_data: BatchData,
+        n_forward_steps: int,
+        optimizer: OptimizationABC,
+    ) -> Generator[TensorDict, None, None]:
+        """
+        Predict multiple steps forward given initial condition and forcing data.
+
+        Uses low-level inputs and does not compute derived variables, to separate
+        concerns from the `predict` method.
+
+        Args:
+            initial_condition: The initial condition, containing tensors of shape
+                [n_batch, self.n_ic_timesteps, n_lat, n_lon].
+            forcing_data: The forcing data, containing tensors of shape
+                [n_batch, n_forward_steps + self.n_ic_timesteps, n_lat, n_lon].
+            n_forward_steps: The number of forward steps to predict, corresponding
+                to the data shapes of forcing_data.
+            optimizer: The optimizer to use for updating the module.
+
+        Returns:
+            Generator yielding the output data at each timestep.
+        """
+        ...
+
+    @abc.abstractmethod
+    def predict(
+        self,
+        initial_condition: PrognosticState,
+        forcing: BatchData,
+        compute_derived_variables: bool = False,
+        compute_derived_forcings: bool = True,
+    ) -> tuple[BatchData, PrognosticState]:
+        """
+        Predict multiple steps forward given initial condition and reference data.
+
+        Args:
+            initial_condition: Prognostic state data with tensors of shape
+                [n_batch, self.n_ic_timesteps, n_lat, n_lon]. This data is assumed
+                to contain all prognostic variables and be denormalized.
+            forcing: Contains tensors of shape
+                [n_batch, self.n_ic_timesteps + n_forward_steps, n_lat, n_lon]. This
+                contains the forcing and ocean data for the initial condition and all
+                subsequent timesteps.
+            compute_derived_variables: Whether to compute derived variables for the
+                prediction.
+            compute_derived_forcings: Whether to compute derived forcing variables for
+                the prediction. Only used to disable computing the derived forcings
+                if they have been computed ahead of time.
+
+        Returns:
+            A tuple of 1) batch data containing the prediction and 2) the
+            prediction's final state, which can be used as a new initial condition.
+        """
+        ...
+
+    @abc.abstractmethod
+    def predict_paired(
+        self,
+        initial_condition: PrognosticState,
+        forcing: BatchData,
+        compute_derived_variables: bool = False,
+    ) -> tuple[PairedData, PrognosticState]:
+        """
+        Predict multiple steps forward given initial condition and reference data.
+
+        Args:
+            initial_condition: Prognostic state data with tensors of shape
+                [n_batch, self.n_ic_timesteps, n_lat, n_lon]. This data is assumed
+                to contain all prognostic variables and be denormalized.
+            forcing: Contains tensors of shape
+                [n_batch, self.n_ic_timesteps + n_forward_steps, n_lat, n_lon]. This
+                contains the forcing and ocean data for the initial condition and all
+                subsequent timesteps.
+            compute_derived_variables: Whether to compute derived variables for the
+                prediction.
+
+        Returns:
+            A tuple of 1) a paired data object, containing the prediction paired with
+            all target/forcing data at the same timesteps, and 2) the prediction's
+            final state, which can be used as a new initial condition.
+        """
+        ...
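+
+    # Illustrative rollout sketch (hypothetical names, not part of the
+    # interface): given an InferenceStepper `stepper`, a PrognosticState `ic`,
+    # and BatchData `forcing` covering the initial condition plus the forward
+    # steps, a segmented rollout looks like:
+    #
+    #     prediction, final_state = stepper.predict(ic, forcing)
+    #     prediction2, _ = stepper.predict(final_state, next_forcing)
+    #
+    # where `final_state` seeds the next segment as its initial condition.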
+
+    @abc.abstractmethod
+    def get_forward_data(
+        self, data: BatchData, compute_derived_variables: bool = False
+    ) -> BatchData: ...
+
+    @abc.abstractmethod
+    def train_on_batch(
+        self,
+        data: BatchData,
+        optimization: OptimizationABC,
+        compute_derived_variables: bool = False,
+    ) -> TrainOutput:
+        """
+        Train the model on a batch of data with one or more forward steps.
+
+        If gradient accumulation is used by the optimization, the computational graph is
+        detached between steps to reduce memory consumption. This means the model learns
+        how to deal with inputs on step N but does not try to improve the behavior at
+        step N by modifying the behavior for step N-1.
+
+        Args:
+            data: The batch data where each tensor in data.data has shape
+                [n_sample, n_forward_steps + self.n_ic_timesteps, n_lat, n_lon].
+            optimization: The optimization class to use for updating the module.
+                Use `NullOptimization` to disable training.
+            compute_derived_variables: Whether to compute derived variables for the
+                prediction and target data.
+
+        Returns:
+            The loss metrics, the generated data, the normalized generated data,
+            and the normalized batch data.
+        """
+        ...
+
+    @abc.abstractmethod
+    def update_training_history(self, training_job: TrainingJob) -> None:
+        """
+        Update the stepper's history of training jobs.
+
+        Args:
+            training_job: The training job to add to the history.
+        """
+        ...
+
+    @abc.abstractmethod
+    def get_state(self):
+        """
+        Returns:
+            The state of the stepper.
+        """
+        ...
+
+    @abc.abstractmethod
+    def load_state(self, state: dict[str, Any]) -> None:
+        """
+        Load the state of the stepper.
+
+        Args:
+            state: The state to load.
+        """
+        ...
+
+    @classmethod
+    @abc.abstractmethod
+    def from_state(cls, state) -> "InferenceStepper":
+        """
+        Create a stepper from a serialized state.
+
+        Args:
+            state: The state to load.
+
+        Returns:
+            The stepper.
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def TIME_DIM(self) -> int: ...
+
+    @property
+    @abc.abstractmethod
+    def CHANNEL_DIM(self) -> int: ...
+
+    @abc.abstractmethod
+    def set_train(self):
+        """
+        Set the stepper to training mode.
+        """
+        ...
+
+    @abc.abstractmethod
+    def set_eval(self):
+        """
+        Set the stepper to evaluation mode.
+        """
+        ...
+
+
+class TrainStepper(abc.ABC):
+    @property
+    @abc.abstractmethod
+    def inference(self) -> InferenceStepper: ...
+
+    @property
+    @abc.abstractmethod
+    def loss_obj(self) -> StepLoss: ...
+
+    @property
+    @abc.abstractmethod
+    def config(self) -> StepperConfig: ...
+
+    @property
+    @abc.abstractmethod
+    def derive_func(self) -> Callable[[TensorMapping, TensorMapping], TensorDict]: ...
+
+    @property
+    @abc.abstractmethod
+    def surface_temperature_name(self) -> str | None: ...
+
+    @property
+    @abc.abstractmethod
+    def ocean_fraction_name(self) -> str | None: ...
+
+    @abc.abstractmethod
+    def prescribe_sst(
+        self,
+        mask_data: TensorMapping,
+        gen_data: TensorMapping,
+        target_data: TensorMapping,
+    ) -> TensorDict:
+        """
+        Prescribe sea surface temperature onto the generated surface temperature field.
+
+        Args:
+            mask_data: Source for the prescriber mask field.
+            gen_data: Contains the generated surface temperature field.
+            target_data: Contains the target surface temperature that will
+                be prescribed onto the generated one according to the mask.
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def training_dataset_info(self) -> DatasetInfo: ...
+
+    @property
+    @abc.abstractmethod
+    def training_variable_metadata(self) -> Mapping[str, VariableMetadata]: ...
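+
+    # Illustrative sketch (hypothetical names): training code is expected to
+    # hold a TrainStepper and hand its `inference` view to inference-only
+    # code paths, e.g.
+    #
+    #     outputs = train_stepper.train_on_batch(batch, optimization)
+    #     rollout, _ = train_stepper.inference.predict(ic, forcing)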
+
+    @property
+    @abc.abstractmethod
+    def training_history(self) -> TrainingHistory: ...
+
+    @property
+    @abc.abstractmethod
+    def effective_loss_scaling(self) -> TensorDict:
+        """
+        Effective loss scalings used to normalize outputs before computing loss.
+        y_loss_normalized_i = (y_i - y_mean_i) / loss_scaling_i
+        where loss_scaling_i = loss_normalizer_std_i / weight_i.
+        """
+        ...
+
+    @abc.abstractmethod
+    def replace_multi_call(self, multi_call: MultiCallConfig | None):
+        """
+        Replace the MultiCall object with a new one. Note this is only
+        meant to be used at inference time and may result in the loss
+        function being unusable.
+
+        Args:
+            multi_call: The new multi_call configuration or None.
+        """
+        ...
+
+    @abc.abstractmethod
+    def replace_ocean(self, ocean: OceanConfig | None):
+        """
+        Replace the ocean model with a new one.
+
+        Args:
+            ocean: The new ocean model configuration or None.
+        """
+        ...
+
+    @abc.abstractmethod
+    def replace_derived_forcings(self, derived_forcings: DerivedForcingsConfig):
+        """
+        Replace the derived forcings configuration with a new one.
+
+        Args:
+            derived_forcings: The new derived forcings configuration.
+        """
+        ...
+
+    @abc.abstractmethod
+    def get_base_weights(self) -> Weights | None:
+        """
+        Get the base weights of the stepper.
+
+        Returns:
+            A list of weight dictionaries for each module in the stepper.
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def prognostic_names(self) -> list[str]: ...
+
+    @property
+    @abc.abstractmethod
+    def out_names(self) -> list[str]: ...
+
+    @property
+    @abc.abstractmethod
+    def loss_names(self) -> list[str]: ...
+
+    @property
+    @abc.abstractmethod
+    def n_ic_timesteps(self) -> int: ...
+
+    @property
+    @abc.abstractmethod
+    def modules(self) -> nn.ModuleList:
+        """
+        Returns:
+            A list of modules being trained.
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def normalizer(self) -> StandardNormalizer: ...
+
+    @abc.abstractmethod
+    def get_forward_data(
+        self, data: BatchData, compute_derived_variables: bool = False
+    ) -> BatchData: ...
+
+    @abc.abstractmethod
+    def train_on_batch(
+        self,
+        data: BatchData,
+        optimization: OptimizationABC,
+        compute_derived_variables: bool = False,
+    ) -> TrainOutput:
+        """
+        Train the model on a batch of data with one or more forward steps.
+
+        If gradient accumulation is used by the optimization, the computational graph is
+        detached between steps to reduce memory consumption. This means the model learns
+        how to deal with inputs on step N but does not try to improve the behavior at
+        step N by modifying the behavior for step N-1.
+
+        Args:
+            data: The batch data where each tensor in data.data has shape
+                [n_sample, n_forward_steps + self.n_ic_timesteps, n_lat, n_lon].
+            optimization: The optimization class to use for updating the module.
+                Use `NullOptimization` to disable training.
+            compute_derived_variables: Whether to compute derived variables for the
+                prediction and target data.
+
+        Returns:
+            The loss metrics, the generated data, the normalized generated data,
+            and the normalized batch data.
+        """
+        ...
+
+    @abc.abstractmethod
+    def update_training_history(self, training_job: TrainingJob) -> None:
+        """
+        Update the stepper's history of training jobs.
+
+        Args:
+            training_job: The training job to add to the history.
+        """
+        ...
+
+    @abc.abstractmethod
+    def get_state(self):
+        """
+        Returns:
+            The state of the stepper.
+        """
+        ...
+
+    @abc.abstractmethod
+    def load_state(self, state: dict[str, Any]) -> None:
+        """
+        Load the state of the stepper.
+
+        Args:
+            state: The state to load.
+        """
+        ...
+
+    @classmethod
+    @abc.abstractmethod
+    def from_state(cls, state) -> "TrainStepper":
+        """
+        Create a stepper from a serialized state.
+
+        Args:
+            state: The state to load.
+
+        Returns:
+            The stepper.
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def TIME_DIM(self) -> int: ...
+
+    @property
+    @abc.abstractmethod
+    def CHANNEL_DIM(self) -> int: ...
+
+
 class Stepper(
     TrainStepperABC[
         PrognosticState,
@@ -808,7 +1362,9 @@ class Stepper(
         BatchData,
         PairedData,
         TrainOutput,
-    ]
+    ],
+    TrainStepper,
+    InferenceStepper,
 ):
     """
     Stepper class for selectable step configurations.
@@ -906,6 +1462,10 @@ def get_loss_obj() -> StepLoss:
         self._dataset_info = dataset_info
         self._forcing_deriver = config.derived_forcings.build(dataset_info)
 
+    @property
+    def inference(self) -> InferenceStepper:
+        return self
+
     def _init_for_epoch(self, epoch: int | None):
         if (
             epoch is None
diff --git a/fme/coupled/stepper.py b/fme/coupled/stepper.py
index 9601f3a94..8fc3367e2 100644
--- a/fme/coupled/stepper.py
+++ b/fme/coupled/stepper.py
@@ -25,7 +25,7 @@
     Weights,
     WeightsAndHistoryLoader,
 )
-from fme.ace.stepper.single_module import StepperConfig
+from fme.ace.stepper.single_module import InferenceStepper, StepperConfig, TrainStepper
 from fme.ace.stepper.single_module import (
     load_weights_and_history as load_uncoupled_weights_and_history,
 )
@@ -150,7 +150,7 @@ def build_ocean_data(
 
 
 def _load_stepper_weights_and_history_factory(
-    stepper: Stepper,
+    stepper: TrainStepper,
 ) -> WeightsAndHistoryLoader:
     def load_stepper_weights_and_history(*_) -> StepperWeightsAndHistory:
         return_weights: Weights = []
@@ -799,22 +799,14 @@ def __call__(
         return None
 
 
-class CoupledStepper(
-    TrainStepperABC[
-        CoupledPrognosticState,
-        CoupledBatchData,
-        CoupledBatchData,
-        CoupledPairedData,
-        CoupledTrainOutput,
-    ],
-):
+class CoupledInferenceStepper:
     TIME_DIM = 1
 
     def __init__(
         self,
         config: CoupledStepperConfig,
-        ocean: Stepper,
-        atmosphere: Stepper,
+        ocean: InferenceStepper,
+        atmosphere: InferenceStepper,
         dataset_info: CoupledDatasetInfo,
     ):
         """
@@ -833,16 +825,6 @@ def __init__(
         self._dataset_info = dataset_info
         self._ocean_mask_provider = dataset_info.ocean_mask_provider
 
-        ocean_loss = self._config.get_ocean_loss(
-            self.ocean.loss_obj,
-            ocean.TIME_DIM,
-        )
-        atmos_loss = self._config.get_atmosphere_loss(
-            self.atmosphere.loss_obj,
-            atmosphere.TIME_DIM,
-        )
-        self._loss = self._config.get_loss(ocean_loss, atmos_loss)
-
         _: PredictFunction[  # for type checking
             CoupledPrognosticState,
             CoupledBatchData,
@@ -1319,6 +1301,175 @@ def predict(
             ),
         )
 
+    def update_training_history(self, training_job: TrainingJob) -> None:
+        """
+        Update the stepper's history of training jobs.
+
+        Args:
+            training_job: The training job to add to the history.
+ """ + self.ocean.update_training_history(training_job) + self.atmosphere.update_training_history(training_job) + + @classmethod + def from_state(cls, state) -> "CoupledInferenceStepper": + ocean = Stepper.from_state(state["ocean_state"]) + atmosphere = Stepper.from_state(state["atmosphere_state"]) + config = CoupledStepperConfig.from_state(state["config"]) + if "dataset_info" in state: + dataset_info = CoupledDatasetInfo.from_state(state["dataset_info"]) + else: + # NOTE: this is included for backwards compatibility + dataset_info = CoupledDatasetInfo( + ocean=ocean.training_dataset_info, + atmosphere=atmosphere.training_dataset_info, + ) + return cls( + config=config, + ocean=ocean, + atmosphere=atmosphere, + dataset_info=dataset_info, + ) + + +class CoupledStepper( + TrainStepperABC[ + CoupledPrognosticState, + CoupledBatchData, + CoupledBatchData, + CoupledPairedData, + CoupledTrainOutput, + ], +): + TIME_DIM = 1 + + def __init__( + self, + config: CoupledStepperConfig, + ocean: TrainStepper, + atmosphere: TrainStepper, + dataset_info: CoupledDatasetInfo, + ): + """ + Args: + config: The configuration. + ocean: The ocean stepper. + atmosphere: The atmosphere stepper. + dataset_info: The CoupledDatasetInfo. + """ + if ocean.n_ic_timesteps != 1 or atmosphere.n_ic_timesteps != 1: + raise ValueError("Only n_ic_timesteps = 1 is currently supported.") + + self.inference = CoupledInferenceStepper( + config=config, + ocean=ocean.inference, + atmosphere=atmosphere.inference, + dataset_info=dataset_info, + ) + + self.ocean = ocean + self.atmosphere = atmosphere + self._config = config + self._dataset_info = dataset_info + self._ocean_mask_provider = dataset_info.ocean_mask_provider + + ocean_loss = self._config.get_ocean_loss( + self.ocean.loss_obj, + ocean.TIME_DIM, + ) + atmos_loss = self._config.get_atmosphere_loss( + self.atmosphere.loss_obj, + atmosphere.TIME_DIM, + ) + self._loss = self._config.get_loss(ocean_loss, atmos_loss) + + _: PredictFunction[ # for type checking + CoupledPrognosticState, + CoupledBatchData, + CoupledPairedData, + ] = self.predict_paired + + @property + def modules(self) -> nn.ModuleList: + return nn.ModuleList([*self.atmosphere.modules, *self.ocean.modules]) + + def set_train(self): + self.inference.set_train() + + def set_eval(self): + self.inference.set_eval() + + def get_state(self): + """ + Returns: + The state of the coupled stepper. 
+ """ + return { + "config": self._config.get_state(), + "atmosphere_state": self.atmosphere.get_state(), + "ocean_state": self.ocean.get_state(), + "dataset_info": self._dataset_info.to_state(), + } + + def load_state(self, state: dict[str, Any]): + self.atmosphere.load_state(state["atmosphere_state"]) + self.ocean.load_state(state["ocean_state"]) + + @property + def training_dataset_info(self) -> CoupledDatasetInfo: + return self._dataset_info + + @property + def n_ic_timesteps(self) -> int: + return 1 + + @property + def n_inner_steps(self) -> int: + """Number of atmosphere steps per ocean step.""" + return self._config.n_inner_steps + + def get_prediction_generator( + self, + initial_condition: CoupledPrognosticState, + forcing_data: CoupledBatchData, + optimizer: OptimizationABC, + ) -> Generator[ComponentStepPrediction, None, None]: + return self.inference.get_prediction_generator( + initial_condition, forcing_data, optimizer + ) + + def _process_prediction_generator_list( + self, + output_list: list[ComponentStepPrediction], + forcing_data: CoupledBatchData, + ) -> CoupledBatchData: + atmos_data = process_prediction_generator_list( + [x.data for x in output_list if x.realm == "atmosphere"], + time=forcing_data.atmosphere_data.time[:, self.atmosphere.n_ic_timesteps :], + horizontal_dims=forcing_data.atmosphere_data.horizontal_dims, + labels=forcing_data.atmosphere_data.labels, + ) + ocean_data = process_prediction_generator_list( + [x.data for x in output_list if x.realm == "ocean"], + time=forcing_data.ocean_data.time[:, self.ocean.n_ic_timesteps :], + horizontal_dims=forcing_data.ocean_data.horizontal_dims, + labels=forcing_data.ocean_data.labels, + ) + return CoupledBatchData(ocean_data=ocean_data, atmosphere_data=atmos_data) + + def predict_paired( + self, + initial_condition: CoupledPrognosticState, + forcing: CoupledBatchData, + compute_derived_variables: bool = False, + ) -> tuple[CoupledPairedData, CoupledPrognosticState]: + """ + Predict multiple steps forward given initial condition and reference data. + """ + return self.inference.predict_paired( + initial_condition, forcing, compute_derived_variables + ) + def train_on_batch( self, data: CoupledBatchData, diff --git a/fme/coupled/test_stepper.py b/fme/coupled/test_stepper.py index ef0b611a5..a8e83c4d2 100644 --- a/fme/coupled/test_stepper.py +++ b/fme/coupled/test_stepper.py @@ -1229,7 +1229,7 @@ def test__get_atmosphere_forcings( .clone() .expand(*shape_atmos) ) - new_atmos_forcings = coupler._get_atmosphere_forcings( + new_atmos_forcings = coupler.inference._get_atmosphere_forcings( atmos_forcing_data, forcings_from_ocean ) for name in expected_atmos_forcings: @@ -1271,7 +1271,7 @@ def test__get_ocean_forcings(): "exog": atmos_forcings["exog"].mean(dim=1), "a_diag": atmos_gen["a_diag"].mean(dim=1), } - new_ocean_forcings = coupler._get_ocean_forcings( + new_ocean_forcings = coupler.inference._get_ocean_forcings( ocean_data, atmos_gen, atmos_forcings ) assert new_ocean_forcings.keys() == expected_ocean_forcings.keys()
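
Usage sketch (illustrative, with hypothetical variable names, not part of the diff):
with this split, an inference-only coupled stepper can be restored from a checkpoint
produced by CoupledStepper.get_state(), without building the training-time loss
objects:

    from fme.coupled.stepper import CoupledInferenceStepper

    # `checkpoint_state` is assumed to be the dict returned by
    # CoupledStepper.get_state(); from_state reads its "ocean_state",
    # "atmosphere_state", "config" and (if present) "dataset_info" entries.
    stepper = CoupledInferenceStepper.from_state(checkpoint_state)
    stepper.set_eval()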