diff --git a/fme/ace/stepper/single_module.py b/fme/ace/stepper/single_module.py index eeaee0335..11b579c8b 100644 --- a/fme/ace/stepper/single_module.py +++ b/fme/ace/stepper/single_module.py @@ -823,20 +823,6 @@ def __init__( self._no_optimization = NullOptimization() self._parameter_initializer = parameter_initializer - def get_loss_obj() -> StepLoss: - loss_normalizer = step.get_loss_normalizer() - if config.loss is None: - raise ValueError("Loss is not configured") - return config.loss.build( - dataset_info.gridded_operations, - out_names=config.loss_names, - channel_dim=self.CHANNEL_DIM, - normalizer=loss_normalizer, - ) - - self._get_loss_obj = get_loss_obj - self._loss_obj: StepLoss | None = None - self._parameter_initializer.apply_weights( step.modules, ) @@ -863,11 +849,24 @@ def get_loss_obj() -> StepLoss: self._dataset_info = dataset_info self.forcing_deriver = config.derived_forcings.build(dataset_info) - @property - def loss_obj(self) -> StepLoss: - if self._loss_obj is None: - self._loss_obj = self._get_loss_obj() - return self._loss_obj + def build_loss(self, loss_config: StepLossConfig) -> StepLoss: + """Build a StepLoss from the given config using this stepper's normalizer + and dataset info. + + Args: + loss_config: The loss configuration to build from. + + Returns: + A StepLoss built using this stepper's loss normalizer, gridded + operations, loss variable names, and channel dimension. + """ + loss_normalizer = self._step_obj.get_loss_normalizer() + return loss_config.build( + self._dataset_info.gridded_operations, + out_names=self.loss_names, + channel_dim=self.CHANNEL_DIM, + normalizer=loss_normalizer, + ) @property def config(self) -> StepperConfig: @@ -928,15 +927,6 @@ def _append_training_history_from( if base_training_history is not None: self._training_history.extend(base_training_history) - @property - def effective_loss_scaling(self) -> TensorDict: - """ - Effective loss scalings used to normalize outputs before computing loss. - y_loss_normalized_i = (y_i - y_mean_i) / loss_scaling_i - where loss_scaling_i = loss_normalizer_std_i / weight_i. - """ - return self.loss_obj.effective_loss_scaling - def replace_multi_call(self, multi_call: MultiCallConfig | None): """ Replace the MultiCall object with a new one. Note this is only @@ -1472,7 +1462,7 @@ def __init__( self._prognostic_names = self._stepper.prognostic_names self._derive_func = self._stepper.derive_func - self._loss_obj = self._stepper.loss_obj + self._loss_obj = self._stepper.build_loss(config.loss) def train_on_batch( self, @@ -1660,7 +1650,7 @@ def effective_loss_scaling(self) -> TensorDict: y_loss_normalized_i = (y_i - y_mean_i) / loss_scaling_i where loss_scaling_i = loss_normalizer_std_i / weight_i. """ - return self._stepper.effective_loss_scaling + return self._loss_obj.effective_loss_scaling def _init_for_epoch(self, epoch: int | None): if ( diff --git a/fme/ace/stepper/test_single_module.py b/fme/ace/stepper/test_single_module.py index b6e02cfe4..c21bc17c0 100644 --- a/fme/ace/stepper/test_single_module.py +++ b/fme/ace/stepper/test_single_module.py @@ -1332,11 +1332,12 @@ def test_stepper_from_state_using_resnorm_has_correct_normalizer(): stepper_from_state = Stepper.from_state(orig_stepper.get_state()) for stepper in [orig_stepper, stepper_from_state]: - assert stepper.loss_obj._normalizer.means == { + loss = stepper.build_loss(StepLossConfig()) + assert loss._normalizer.means == { **residual_means, "diagnostic": full_field_means["diagnostic"], } - assert stepper.loss_obj._normalizer.stds == { + assert loss._normalizer.stds == { **residual_stds, "diagnostic": full_field_stds["diagnostic"], } diff --git a/fme/coupled/aggregator.py b/fme/coupled/aggregator.py index 33a80b847..37876a7a1 100644 --- a/fme/coupled/aggregator.py +++ b/fme/coupled/aggregator.py @@ -29,6 +29,7 @@ ) from fme.coupled.dataset_info import CoupledDatasetInfo from fme.coupled.stepper import CoupledTrainOutput +from fme.coupled.typing_ import CoupledTensorMapping class TrainAggregator(AggregatorABC[CoupledTrainOutput]): @@ -65,11 +66,10 @@ class OneStepAggregator(AggregatorABC[CoupledTrainOutput]): def __init__( self, dataset_info: CoupledDatasetInfo, + loss_scaling: CoupledTensorMapping, save_diagnostics: bool = True, output_dir: str | None = None, variable_metadata: Mapping[str, VariableMetadata] | None = None, - ocean_loss_scaling: TensorMapping | None = None, - atmosphere_loss_scaling: TensorMapping | None = None, ocean_channel_mean_names: Sequence[str] | None = None, atmosphere_channel_mean_names: Sequence[str] | None = None, ): @@ -79,13 +79,12 @@ def __init__( save_diagnostics: Whether to save diagnostics to disk. output_dir: Directory to write diagnostics to. variable_metadata: Metadata for each variable. - ocean_loss_scaling: Dictionary of variables and their scaling factors - used in loss computation for the ocean stepper. - atmosphere_loss_scaling: Dictionary of variables and their scaling factors - used in loss computation for the atmosphere stepper. + loss_scaling: Optional coupled mapping of variables and their + scaling factors used in loss computation for the stepper. ocean_channel_mean_names: Names to include in ocean channel-mean metrics. atmosphere_channel_mean_names: Names to include in atmosphere channel-mean metrics. + """ self._dist = Distributed.get_instance() self._loss = torch.tensor(0.0, device=get_device()) @@ -101,7 +100,7 @@ def __init__( if output_dir is not None else None ), - loss_scaling=ocean_loss_scaling, + loss_scaling=loss_scaling.ocean, channel_mean_names=ocean_channel_mean_names, ), "atmosphere": OneStepAggregator_( @@ -112,7 +111,7 @@ def __init__( if output_dir is not None else None ), - loss_scaling=atmosphere_loss_scaling, + loss_scaling=loss_scaling.atmosphere, channel_mean_names=atmosphere_channel_mean_names, ), } diff --git a/fme/coupled/loss.py b/fme/coupled/loss.py index 1311168d6..de6ac841e 100644 --- a/fme/coupled/loss.py +++ b/fme/coupled/loss.py @@ -1,11 +1,11 @@ import abc import dataclasses -from collections.abc import Callable import torch from fme.core.device import get_device -from fme.core.typing_ import TensorMapping +from fme.core.loss import StepLoss +from fme.core.typing_ import TensorDict, TensorMapping class StepPredictionABC(abc.ABC): @@ -24,6 +24,10 @@ class StepLossABC(abc.ABC): """ + @property + @abc.abstractmethod + def effective_loss_scaling(self) -> TensorDict: ... + @abc.abstractmethod def step_is_optimized(self, step: int) -> bool: """Returns True if the step is less than to the number of @@ -57,11 +61,11 @@ class LossContributionsConfig: def build( self, - loss_obj: Callable[[TensorMapping, TensorMapping, int], torch.Tensor], + loss_obj: StepLoss, time_dim: int, ) -> StepLossABC: if self.n_steps == 0 or self.weight == 0.0: - return NullLossContributions() + return NullLossContributions(loss_obj) return LossContributions( n_steps=self.n_steps, weight=self.weight, @@ -76,6 +80,16 @@ class NullLossContributions(StepLossABC): """ + def __init__( + self, + loss_obj: StepLoss, + ): + self._loss = loss_obj + + @property + def effective_loss_scaling(self) -> TensorDict: + return self._loss.effective_loss_scaling + def step_is_optimized(self, step: int) -> bool: return False @@ -90,7 +104,7 @@ def __init__( self, n_steps: float, weight: float, - loss_obj: Callable[[TensorMapping, TensorMapping, int], torch.Tensor], + loss_obj: StepLoss, time_dim: int, ): self._loss = loss_obj @@ -98,6 +112,10 @@ def __init__( self._weight = weight self._time_dim = time_dim + @property + def effective_loss_scaling(self) -> TensorDict: + return self._loss.effective_loss_scaling + def step_is_optimized(self, step: int) -> bool: """Returns True if the step is less than to the number of steps and weight is != 0. The first step number is assumed to be 0. diff --git a/fme/coupled/stepper.py b/fme/coupled/stepper.py index 9601f3a94..7f134d13c 100644 --- a/fme/coupled/stepper.py +++ b/fme/coupled/stepper.py @@ -2,7 +2,7 @@ import datetime import logging import pathlib -from collections.abc import Callable, Generator, Iterable +from collections.abc import Generator, Iterable from typing import Any, Literal import dacite @@ -33,6 +33,7 @@ from fme.core.generics.inference import PredictFunction from fme.core.generics.optimization import OptimizationABC from fme.core.generics.train_stepper import TrainOutputABC, TrainStepperABC +from fme.core.loss import StepLossConfig from fme.core.ocean import OceanConfig from fme.core.ocean_data import OceanData from fme.core.optimization import NullOptimization @@ -51,6 +52,7 @@ CoupledDataRequirements, CoupledPrognosticStateDataRequirements, ) +from fme.coupled.typing_ import CoupledTensorMapping @dataclasses.dataclass @@ -63,14 +65,10 @@ class ComponentConfig: timedelta: An ISO 8601 Duration string specifying the size of this component's stepper step. stepper: The single module stepper configuration for this component. - loss_contributions: The loss contributions configuration for this component. """ timedelta: str stepper: StepperConfig - loss_contributions: LossContributionsConfig = dataclasses.field( - default_factory=lambda: LossContributionsConfig() - ) @dataclasses.dataclass @@ -604,25 +602,6 @@ def get_stepper( dataset_info=dataset_info, ) - def get_ocean_loss( - self, - loss_obj: Callable[[TensorMapping, TensorMapping, int], torch.Tensor], - time_dim: int, - ) -> StepLossABC: - return self.ocean.loss_contributions.build(loss_obj, time_dim) - - def get_atmosphere_loss( - self, - loss_obj: Callable[[TensorMapping, TensorMapping, int], torch.Tensor], - time_dim: int, - ) -> StepLossABC: - return self.atmosphere.loss_contributions.build(loss_obj, time_dim) - - def get_loss( - self, ocean_loss: StepLossABC, atmosphere_loss: StepLossABC - ) -> "CoupledStepperTrainLoss": - return CoupledStepperTrainLoss(ocean_loss, atmosphere_loss) - def get_state(self): return dataclasses.asdict(self) @@ -642,6 +621,9 @@ def remove_deprecated_keys(cls, state: dict[str, Any]) -> dict[str, Any]: state_copy = state.copy() if "sst_mask_name" in state_copy: del state_copy["sst_mask_name"] + for component_key in ["ocean", "atmosphere"]: + if "loss_contributions" in state_copy[component_key]: + del state_copy[component_key]["loss_contributions"] return state_copy @@ -788,6 +770,13 @@ def __init__( "atmosphere": atmosphere_loss, } + @property + def effective_loss_scaling(self) -> CoupledTensorMapping: + return CoupledTensorMapping( + ocean=self._loss_objs["ocean"].effective_loss_scaling, + atmosphere=self._loss_objs["atmosphere"].effective_loss_scaling, + ) + def __call__( self, prediction: ComponentStepPrediction, @@ -799,15 +788,7 @@ def __call__( return None -class CoupledStepper( - TrainStepperABC[ - CoupledPrognosticState, - CoupledBatchData, - CoupledBatchData, - CoupledPairedData, - CoupledTrainOutput, - ], -): +class CoupledStepper: TIME_DIM = 1 def __init__( @@ -833,16 +814,6 @@ def __init__( self._dataset_info = dataset_info self._ocean_mask_provider = dataset_info.ocean_mask_provider - ocean_loss = self._config.get_ocean_loss( - self.ocean.loss_obj, - ocean.TIME_DIM, - ) - atmos_loss = self._config.get_atmosphere_loss( - self.atmosphere.loss_obj, - atmosphere.TIME_DIM, - ) - self._loss = self._config.get_loss(ocean_loss, atmos_loss) - _: PredictFunction[ # for type checking CoupledPrognosticState, CoupledBatchData, @@ -1319,12 +1290,169 @@ def predict( ), ) + def update_training_history(self, training_job: TrainingJob) -> None: + """ + Update the stepper's history of training jobs. + + Args: + training_job: The training job to add to the history. + """ + self.ocean.update_training_history(training_job) + self.atmosphere.update_training_history(training_job) + + @classmethod + def from_state(cls, state) -> "CoupledStepper": + ocean = Stepper.from_state(state["ocean_state"]) + atmosphere = Stepper.from_state(state["atmosphere_state"]) + config = CoupledStepperConfig.from_state(state["config"]) + if "dataset_info" in state: + dataset_info = CoupledDatasetInfo.from_state(state["dataset_info"]) + else: + # NOTE: this is included for backwards compatibility + dataset_info = CoupledDatasetInfo( + ocean=ocean.training_dataset_info, + atmosphere=atmosphere.training_dataset_info, + ) + return cls( + config=config, + ocean=ocean, + atmosphere=atmosphere, + dataset_info=dataset_info, + ) + + +@dataclasses.dataclass +class ComponentTrainingConfig: + loss: StepLossConfig + loss_contributions: LossContributionsConfig = dataclasses.field( + default_factory=lambda: LossContributionsConfig() + ) + + +@dataclasses.dataclass +class CoupledTrainStepperConfig: + """ + Configuration for training-specific aspects of a coupled stepper. + + Parameters: + ocean: The configuration for the ocean component. + atmosphere: The configuration for the atmosphere component. + """ + + ocean: ComponentTrainingConfig + atmosphere: ComponentTrainingConfig + + def get_train_stepper(self, stepper: CoupledStepper) -> "CoupledTrainStepper": + """ + Build a CoupledTrainStepper from this configuration and a CoupledStepper. + + Args: + stepper: The underlying coupled stepper for inference operations. + + Returns: + A CoupledTrainStepper wrapping the given stepper with training + functionality. + """ + ocean_step_loss = stepper.ocean.build_loss(self.ocean.loss) + atmos_step_loss = stepper.atmosphere.build_loss(self.atmosphere.loss) + ocean_loss = self.ocean.loss_contributions.build( + ocean_step_loss, stepper.ocean.TIME_DIM + ) + atmos_loss = self.atmosphere.loss_contributions.build( + atmos_step_loss, stepper.atmosphere.TIME_DIM + ) + loss = CoupledStepperTrainLoss(ocean_loss, atmos_loss) + return CoupledTrainStepper( + stepper=stepper, + loss=loss, + ) + + +class CoupledTrainStepper( + TrainStepperABC[ + CoupledPrognosticState, + CoupledBatchData, + CoupledBatchData, + CoupledPairedData, + CoupledTrainOutput, + ], +): + """ + Wrapper around CoupledStepper that adds training functionality. + + This class composes a CoupledStepper (for inference) with training-specific + loss configuration and implements the train_on_batch method. + """ + + def __init__( + self, + stepper: CoupledStepper, + loss: CoupledStepperTrainLoss, + ): + """ + Args: + stepper: The underlying coupled stepper for inference operations. + loss: The coupled loss object for computing per-component losses. + """ + self._stepper = stepper + self._loss = loss + + @property + def ocean(self) -> Stepper: + return self._stepper.ocean + + @property + def atmosphere(self) -> Stepper: + return self._stepper.atmosphere + + @property + def effective_loss_scaling(self) -> CoupledTensorMapping: + return self._loss.effective_loss_scaling + + @property + def modules(self) -> nn.ModuleList: + return self._stepper.modules + + @property + def n_ic_timesteps(self) -> int: + return self._stepper.n_ic_timesteps + + @property + def n_inner_steps(self) -> int: + """Number of atmosphere steps per ocean step.""" + return self._stepper.n_inner_steps + + def predict_paired( + self, + initial_condition: CoupledPrognosticState, + forcing: CoupledBatchData, + compute_derived_variables: bool = False, + ) -> tuple[CoupledPairedData, CoupledPrognosticState]: + return self._stepper.predict_paired( + initial_condition, forcing, compute_derived_variables + ) + + def set_train(self): + self._stepper.set_train() + + def set_eval(self): + self._stepper.set_eval() + + def get_state(self) -> dict[str, Any]: + return self._stepper.get_state() + + def load_state(self, state: dict[str, Any]): + self._stepper.load_state(state) + + def update_training_history(self, training_job: TrainingJob) -> None: + self._stepper.update_training_history(training_job) + def train_on_batch( self, data: CoupledBatchData, optimization: OptimizationABC, compute_derived_variables: bool = False, - ): + ) -> CoupledTrainOutput: """ Args: data: The coupled batch data, consisting of separate batches for ocean and @@ -1357,7 +1485,7 @@ def train_on_batch( metrics = ComponentStepMetrics() optimization.set_mode(self.modules) with optimization.autocast(): - output_generator = self.get_prediction_generator( + output_generator = self._stepper.get_prediction_generator( input_data, data, optimization, @@ -1390,7 +1518,7 @@ def train_on_batch( loss = optimization.get_accumulated_loss().detach() optimization.step_weights() - gen_data = self._process_prediction_generator_list(output_list, data) + gen_data = self._stepper._process_prediction_generator_list(output_list, data) ocean_stepped = TrainOutput( metrics=metrics.get_ocean_metrics(), gen_data=add_ensemble_dim(dict(gen_data.ocean_data.data)), @@ -1432,36 +1560,6 @@ def train_on_batch( return stepped - def update_training_history(self, training_job: TrainingJob) -> None: - """ - Update the stepper's history of training jobs. - - Args: - training_job: The training job to add to the history. - """ - self.ocean.update_training_history(training_job) - self.atmosphere.update_training_history(training_job) - - @classmethod - def from_state(cls, state) -> "CoupledStepper": - ocean = Stepper.from_state(state["ocean_state"]) - atmosphere = Stepper.from_state(state["atmosphere_state"]) - config = CoupledStepperConfig.from_state(state["config"]) - if "dataset_info" in state: - dataset_info = CoupledDatasetInfo.from_state(state["dataset_info"]) - else: - # NOTE: this is included for backwards compatibility - dataset_info = CoupledDatasetInfo( - ocean=ocean.training_dataset_info, - atmosphere=atmosphere.training_dataset_info, - ) - return cls( - config=config, - ocean=ocean, - atmosphere=atmosphere, - dataset_info=dataset_info, - ) - def load_coupled_stepper(checkpoint_path: str | pathlib.Path) -> CoupledStepper: logging.info(f"Loading trained coupled model checkpoint from {checkpoint_path}") diff --git a/fme/coupled/test_loss.py b/fme/coupled/test_loss.py index 72a1d4cf2..d72abbbec 100644 --- a/fme/coupled/test_loss.py +++ b/fme/coupled/test_loss.py @@ -4,6 +4,7 @@ import pytest import torch +from fme.core.loss import StepLoss from fme.core.typing_ import TensorMapping from .loss import LossContributionsConfig, StepLossABC, StepPredictionABC @@ -60,6 +61,10 @@ def __init__( self._loss_obj = loss_obj self._time_dim = time_dim + @property + def effective_loss_scaling(self): + raise NotImplementedError() + def step_is_optimized(self, step: int) -> bool: return step < 2 @@ -117,8 +122,9 @@ def mae_loss(gen, target, step: int): n_steps=6, weight=1 / 3, ) + mock_step_loss = Mock(spec=StepLoss, side_effect=mae_loss) atmosphere_loss = atmos_loss_config.build( - loss_obj=mae_loss, + loss_obj=mock_step_loss, time_dim=1, ) ocean_loss = _StepLoss(loss_obj=mae_loss) @@ -153,13 +159,12 @@ def test_null_loss_contributions(steps_thru_atmos_7, ocean_config_kwargs): # test LossContributionsConfig with n_steps = 0 atmos_loss_config = LossContributionsConfig() atmosphere_loss = atmos_loss_config.build( - loss_obj=lambda *_, **__: torch.tensor(5.25), + loss_obj=Mock(spec=StepLoss, return_value=torch.tensor(5.25)), time_dim=1, ) ocean_loss_config = LossContributionsConfig(**ocean_config_kwargs) - ocean_loss_obj = Mock(return_value=torch.tensor(42.0)) ocean_loss = ocean_loss_config.build( - loss_obj=ocean_loss_obj, + loss_obj=Mock(spec=StepLoss, return_value=torch.tensor(42.0)), time_dim=1, ) loss_obj = CoupledStepperTrainLoss( diff --git a/fme/coupled/test_stepper.py b/fme/coupled/test_stepper.py index ef0b611a5..d0f6a7947 100644 --- a/fme/coupled/test_stepper.py +++ b/fme/coupled/test_stepper.py @@ -44,10 +44,12 @@ ) from .stepper import ( ComponentConfig, + ComponentTrainingConfig, CoupledOceanFractionConfig, CoupledParameterInitConfig, CoupledStepper, CoupledStepperConfig, + CoupledTrainStepperConfig, ) NZ = 3 # number of vertical interface levels in mock data from get_data @@ -1501,7 +1503,12 @@ def test_train_on_batch_with_derived_variables(): n_forward_times_atmosphere=2, n_samples=3, ) - output = coupler.train_on_batch( + train_stepper_config = CoupledTrainStepperConfig( + ocean=ComponentTrainingConfig(loss=StepLossConfig(type="MSE")), + atmosphere=ComponentTrainingConfig(loss=StepLossConfig(type="MSE")), + ) + train_stepper = train_stepper_config.get_train_stepper(coupler) + output = train_stepper.train_on_batch( data=coupled_data.data, optimization=NullOptimization(), compute_derived_variables=True, @@ -1588,11 +1595,15 @@ def test_reloaded_stepper_gives_same_prediction(): n_forward_times_atmosphere=4, n_samples=1, ) - first_result = stepper.train_on_batch( + train_stepper_config = CoupledTrainStepperConfig( + ocean=ComponentTrainingConfig(loss=StepLossConfig(type="MSE")), + atmosphere=ComponentTrainingConfig(loss=StepLossConfig(type="MSE")), + ) + first_result = train_stepper_config.get_train_stepper(stepper).train_on_batch( data=data.data, optimization=NullOptimization(), ) - second_result = new_stepper.train_on_batch( + second_result = train_stepper_config.get_train_stepper(new_stepper).train_on_batch( data=data.data, optimization=NullOptimization(), ) diff --git a/fme/coupled/test_stepper_integrations.py b/fme/coupled/test_stepper_integrations.py index e0a684fe3..0cd2cf334 100644 --- a/fme/coupled/test_stepper_integrations.py +++ b/fme/coupled/test_stepper_integrations.py @@ -7,10 +7,12 @@ import fme from fme.ace.stepper.parameter_init import ParameterInitializationConfig from fme.core.coordinates import NullVerticalCoordinate +from fme.core.loss import StepLossConfig from fme.core.optimization import OptimizationConfig from fme.core.registry.module import ModuleSelector from .data_loading.data_typing import CoupledVerticalCoordinate +from .stepper import ComponentTrainingConfig, CoupledTrainStepperConfig from .test_stepper import ( CoupledDatasetInfoBuilder, get_stepper_and_batch, @@ -52,6 +54,12 @@ def test_stepper_gradient_accumulation_integration(): ), ) + train_stepper_config = CoupledTrainStepperConfig( + ocean=ComponentTrainingConfig(loss=StepLossConfig(type="MSE")), + atmosphere=ComponentTrainingConfig(loss=StepLossConfig(type="MSE")), + ) + train_stepper = train_stepper_config.get_train_stepper(coupler) + assert len(coupler.atmosphere.modules) == 1 assert len(coupler.ocean.modules) == 1 @@ -86,7 +94,7 @@ def ocean_hook_v1(module, grad_input, grad_output): optim = OptimizationConfig(use_gradient_accumulation=False).build( coupler.modules, 1 ) - _ = coupler.train_on_batch( + _ = train_stepper.train_on_batch( data=coupled_data.data, optimization=optim, ) @@ -118,7 +126,7 @@ def ocean_hook_v2(module, grad_input, grad_output): # with gradient accumulation, atmos steps detached optim = OptimizationConfig(use_gradient_accumulation=True).build(coupler.modules, 1) - _ = coupler.train_on_batch( + _ = train_stepper.train_on_batch( data=coupled_data.data, optimization=optim, ) diff --git a/fme/coupled/test_train.py b/fme/coupled/test_train.py index 7dde90e73..80e47bc8c 100644 --- a/fme/coupled/test_train.py +++ b/fme/coupled/test_train.py @@ -66,6 +66,17 @@ enable_automatic_mixed_precision: false lr: 0.0001 optimizer_type: Adam +train_stepper: + ocean: + loss: + type: MSE + loss_contributions: + weight: {loss_ocean_weight} + atmosphere: + loss: + type: MSE + loss_contributions: + n_steps: {loss_atmos_n_steps} stepper: sst_name: {ocean_sfc_temp_name} ocean_fraction_prediction: @@ -74,11 +85,7 @@ sea_ice_fraction_name_in_atmosphere: {atmos_sea_ice_frac_name} ocean: timedelta: 2D - loss_contributions: - weight: {loss_ocean_weight} stepper: - loss: - type: MSE input_masking: mask_value: 0 fill_value: 0.0 @@ -109,11 +116,7 @@ out_names: {ocean_out_names} atmosphere: timedelta: 1D - loss_contributions: - n_steps: {loss_atmos_n_steps} stepper: - loss: - type: MSE step: type: single_module config: diff --git a/fme/coupled/train/train.py b/fme/coupled/train/train.py index b10d3f9a8..eb60f6eb2 100644 --- a/fme/coupled/train/train.py +++ b/fme/coupled/train/train.py @@ -27,6 +27,7 @@ from fme.coupled.dataset_info import CoupledDatasetInfo from fme.coupled.stepper import CoupledTrainOutput from fme.coupled.train.train_config import TrainBuilders, TrainConfig +from fme.coupled.typing_ import CoupledTensorMapping def build_trainer(builder: TrainBuilders, config: TrainConfig) -> Trainer: @@ -58,8 +59,7 @@ def build_trainer(builder: TrainBuilders, config: TrainConfig) -> Trainer: n_timesteps_atmosphere=n_timesteps_atmosphere, ocean_normalize=stepper.ocean.normalizer.normalize, atmosphere_normalize=stepper.atmosphere.normalizer.normalize, - ocean_loss_scaling=stepper.ocean.effective_loss_scaling, - atmosphere_loss_scaling=stepper.atmosphere.effective_loss_scaling, + loss_scaling=stepper.effective_loss_scaling, save_per_epoch_diagnostics=config.save_per_epoch_diagnostics, output_dir=config.output_dir, ) @@ -89,8 +89,7 @@ def __init__( output_dir: str, ocean_normalize: Callable[[TensorMapping], TensorDict], atmosphere_normalize: Callable[[TensorMapping], TensorDict], - ocean_loss_scaling: TensorMapping | None = None, - atmosphere_loss_scaling: TensorMapping | None = None, + loss_scaling: CoupledTensorMapping, ocean_channel_mean_names: Sequence[str] | None = None, atmosphere_channel_mean_names: Sequence[str] | None = None, save_per_epoch_diagnostics: bool = False, @@ -103,8 +102,7 @@ def __init__( self.output_dir = output_dir self.ocean_normalize = ocean_normalize self.atmosphere_normalize = atmosphere_normalize - self.ocean_loss_scaling = ocean_loss_scaling - self.atmosphere_loss_scaling = atmosphere_loss_scaling + self.loss_scaling = loss_scaling self.ocean_channel_mean_names = ocean_channel_mean_names self.atmosphere_channel_mean_names = atmosphere_channel_mean_names self.save_per_epoch_diagnostics = save_per_epoch_diagnostics @@ -117,8 +115,7 @@ def get_validation_aggregator(self) -> OneStepAggregator: dataset_info=self.dataset_info, save_diagnostics=self.save_per_epoch_diagnostics, output_dir=os.path.join(self.output_dir, "val"), - ocean_loss_scaling=self.ocean_loss_scaling, - atmosphere_loss_scaling=self.atmosphere_loss_scaling, + loss_scaling=self.loss_scaling, ocean_channel_mean_names=self.ocean_channel_mean_names, atmosphere_channel_mean_names=self.atmosphere_channel_mean_names, ) diff --git a/fme/coupled/train/train_config.py b/fme/coupled/train/train_config.py index 95cf49796..11ae93ad5 100644 --- a/fme/coupled/train/train_config.py +++ b/fme/coupled/train/train_config.py @@ -23,7 +23,11 @@ CoupledDataRequirements, CoupledPrognosticStateDataRequirements, ) -from fme.coupled.stepper import CoupledStepper, CoupledStepperConfig +from fme.coupled.stepper import ( + CoupledStepperConfig, + CoupledTrainStepper, + CoupledTrainStepperConfig, +) @dataclasses.dataclass @@ -128,6 +132,7 @@ class TrainConfig: train_loader: CoupledDataLoaderConfig validation_loader: CoupledDataLoaderConfig stepper: CoupledStepperConfig + train_stepper: CoupledTrainStepperConfig optimization: OptimizationConfig logging: LoggingConfig max_epochs: int @@ -242,8 +247,9 @@ def atmosphere_timestep(self) -> datetime.timedelta: def ocean_timestep(self) -> datetime.timedelta: return self.config.stepper.ocean_timestep - def get_stepper(self, dataset_info: CoupledDatasetInfo) -> CoupledStepper: - return self.config.stepper.get_stepper(dataset_info) + def get_stepper(self, dataset_info: CoupledDatasetInfo) -> CoupledTrainStepper: + stepper = self.config.stepper.get_stepper(dataset_info) + return self.config.train_stepper.get_train_stepper(stepper) def get_ema(self, modules) -> EMATracker: return self.config.ema.build(modules) diff --git a/fme/coupled/typing_.py b/fme/coupled/typing_.py new file mode 100644 index 000000000..a2b01eac5 --- /dev/null +++ b/fme/coupled/typing_.py @@ -0,0 +1,9 @@ +import dataclasses + +from fme.core.typing_ import TensorMapping + + +@dataclasses.dataclass +class CoupledTensorMapping: + ocean: TensorMapping + atmosphere: TensorMapping