50 changes: 20 additions & 30 deletions fme/ace/stepper/single_module.py
@@ -823,20 +823,6 @@ def __init__(
self._no_optimization = NullOptimization()
self._parameter_initializer = parameter_initializer

def get_loss_obj() -> StepLoss:
loss_normalizer = step.get_loss_normalizer()
if config.loss is None:
raise ValueError("Loss is not configured")
return config.loss.build(
dataset_info.gridded_operations,
out_names=config.loss_names,
channel_dim=self.CHANNEL_DIM,
normalizer=loss_normalizer,
)

self._get_loss_obj = get_loss_obj
self._loss_obj: StepLoss | None = None

self._parameter_initializer.apply_weights(
step.modules,
)
@@ -863,11 +849,24 @@ def get_loss_obj() -> StepLoss:
self._dataset_info = dataset_info
self.forcing_deriver = config.derived_forcings.build(dataset_info)

@property
def loss_obj(self) -> StepLoss:
if self._loss_obj is None:
self._loss_obj = self._get_loss_obj()
return self._loss_obj
def build_loss(self, loss_config: StepLossConfig) -> StepLoss:
"""Build a StepLoss from the given config using this stepper's normalizer
and dataset info.

Args:
loss_config: The loss configuration to build from.

Returns:
A StepLoss built using this stepper's loss normalizer, gridded
operations, loss variable names, and channel dimension.
"""
loss_normalizer = self._step_obj.get_loss_normalizer()
return loss_config.build(
self._dataset_info.gridded_operations,
out_names=self.loss_names,
channel_dim=self.CHANNEL_DIM,
normalizer=loss_normalizer,
)

@property
def config(self) -> StepperConfig:
@@ -928,15 +927,6 @@ def _append_training_history_from(
if base_training_history is not None:
self._training_history.extend(base_training_history)

@property
def effective_loss_scaling(self) -> TensorDict:
"""
Effective loss scalings used to normalize outputs before computing loss.
y_loss_normalized_i = (y_i - y_mean_i) / loss_scaling_i
where loss_scaling_i = loss_normalizer_std_i / weight_i.
"""
return self.loss_obj.effective_loss_scaling

def replace_multi_call(self, multi_call: MultiCallConfig | None):
"""
Replace the MultiCall object with a new one. Note this is only
@@ -1472,7 +1462,7 @@ def __init__(

self._prognostic_names = self._stepper.prognostic_names
self._derive_func = self._stepper.derive_func
self._loss_obj = self._stepper.loss_obj
self._loss_obj = self._stepper.build_loss(config.loss)

def train_on_batch(
self,
@@ -1660,7 +1650,7 @@ def effective_loss_scaling(self) -> TensorDict:
y_loss_normalized_i = (y_i - y_mean_i) / loss_scaling_i
where loss_scaling_i = loss_normalizer_std_i / weight_i.
"""
return self._stepper.effective_loss_scaling
return self._loss_obj.effective_loss_scaling

def _init_for_epoch(self, epoch: int | None):
if (
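Taken together, the single_module.py changes move loss construction out of the stepper (the cached `loss_obj` property and the `effective_loss_scaling` passthrough are gone) and into an explicit `build_loss(config)` call owned by the trainer. A minimal runnable sketch of the new ownership pattern, using simplified stand-ins rather than the real fme classes (`StepLoss`, `StepLossConfig`, and the `std / weight` scaling rule are modeled on the docstrings above):

```python
from dataclasses import dataclass, field


@dataclass
class StepLoss:
    # Stand-in for fme.core.loss.StepLoss: only the scaling attribute is modeled.
    effective_loss_scaling: dict[str, float]


@dataclass
class StepLossConfig:
    # Stand-in config; per-variable weights default to 1.0.
    weights: dict[str, float] = field(default_factory=dict)

    def build(self, normalizer_stds: dict[str, float]) -> StepLoss:
        # loss_scaling_i = loss_normalizer_std_i / weight_i, per the docstring above.
        return StepLoss(
            effective_loss_scaling={
                name: std / self.weights.get(name, 1.0)
                for name, std in normalizer_stds.items()
            }
        )


class Stepper:
    # Stand-in stepper: holds normalizer stds, builds losses on request.
    def __init__(self, normalizer_stds: dict[str, float]):
        self._stds = normalizer_stds

    def build_loss(self, config: StepLossConfig) -> StepLoss:
        return config.build(self._stds)


# The trainer, not the stepper, now owns the loss object:
stepper = Stepper({"air_temperature": 2.5})
loss_obj = stepper.build_loss(StepLossConfig(weights={"air_temperature": 0.5}))
assert loss_obj.effective_loss_scaling == {"air_temperature": 5.0}
```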
5 changes: 3 additions & 2 deletions fme/ace/stepper/test_single_module.py
@@ -1332,11 +1332,12 @@ def test_stepper_from_state_using_resnorm_has_correct_normalizer():
stepper_from_state = Stepper.from_state(orig_stepper.get_state())

for stepper in [orig_stepper, stepper_from_state]:
assert stepper.loss_obj._normalizer.means == {
loss = stepper.build_loss(StepLossConfig())
assert loss._normalizer.means == {
**residual_means,
"diagnostic": full_field_means["diagnostic"],
}
assert stepper.loss_obj._normalizer.stds == {
assert loss._normalizer.stds == {
**residual_stds,
"diagnostic": full_field_stds["diagnostic"],
}
15 changes: 7 additions & 8 deletions fme/coupled/aggregator.py
@@ -29,6 +29,7 @@
)
from fme.coupled.dataset_info import CoupledDatasetInfo
from fme.coupled.stepper import CoupledTrainOutput
from fme.coupled.typing_ import CoupledTensorMapping


class TrainAggregator(AggregatorABC[CoupledTrainOutput]):
@@ -65,11 +66,10 @@ class OneStepAggregator(AggregatorABC[CoupledTrainOutput]):
def __init__(
self,
dataset_info: CoupledDatasetInfo,
loss_scaling: CoupledTensorMapping,
save_diagnostics: bool = True,
output_dir: str | None = None,
variable_metadata: Mapping[str, VariableMetadata] | None = None,
ocean_loss_scaling: TensorMapping | None = None,
atmosphere_loss_scaling: TensorMapping | None = None,
ocean_channel_mean_names: Sequence[str] | None = None,
atmosphere_channel_mean_names: Sequence[str] | None = None,
):
@@ -79,13 +79,12 @@ def __init__(
save_diagnostics: Whether to save diagnostics to disk.
output_dir: Directory to write diagnostics to.
variable_metadata: Metadata for each variable.
ocean_loss_scaling: Dictionary of variables and their scaling factors
used in loss computation for the ocean stepper.
atmosphere_loss_scaling: Dictionary of variables and their scaling factors
used in loss computation for the atmosphere stepper.
loss_scaling: Coupled mapping of variables and their scaling
factors used in loss computation for each component stepper.
ocean_channel_mean_names: Names to include in ocean channel-mean metrics.
atmosphere_channel_mean_names: Names to include in atmosphere channel-mean
metrics.

"""
self._dist = Distributed.get_instance()
self._loss = torch.tensor(0.0, device=get_device())
@@ -101,7 +100,7 @@ def __init__(
if output_dir is not None
else None
),
loss_scaling=ocean_loss_scaling,
loss_scaling=loss_scaling.ocean,
channel_mean_names=ocean_channel_mean_names,
),
"atmosphere": OneStepAggregator_(
@@ -112,7 +111,7 @@ def __init__(
if output_dir is not None
else None
),
loss_scaling=atmosphere_loss_scaling,
loss_scaling=loss_scaling.atmosphere,
channel_mean_names=atmosphere_channel_mean_names,
),
}
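The aggregator now takes a single required `loss_scaling` argument instead of two optional per-component kwargs. Its type, `CoupledTensorMapping`, is not shown in this diff; judging from the `.ocean` / `.atmosphere` accesses above, it is presumably something like the following sketch (hypothetical, not the real definition in fme/coupled/typing_.py):

```python
from collections.abc import Mapping
from dataclasses import dataclass

import torch


@dataclass
class CoupledTensorMapping:
    # Hypothetical shape inferred from usage in this diff; the real class
    # lives in fme.coupled.typing_ and may differ.
    ocean: Mapping[str, torch.Tensor]
    atmosphere: Mapping[str, torch.Tensor]


# Callers now pass one coupled mapping rather than two optional kwargs:
loss_scaling = CoupledTensorMapping(
    ocean={"sst": torch.tensor(1.0)},
    atmosphere={"air_temperature": torch.tensor(2.0)},
)
```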
28 changes: 23 additions & 5 deletions fme/coupled/loss.py
@@ -1,11 +1,11 @@
import abc
import dataclasses
from collections.abc import Callable

import torch

from fme.core.device import get_device
from fme.core.typing_ import TensorMapping
from fme.core.loss import StepLoss
from fme.core.typing_ import TensorDict, TensorMapping


class StepPredictionABC(abc.ABC):
@@ -24,6 +24,10 @@ class StepLossABC(abc.ABC):

"""

@property
@abc.abstractmethod
def effective_loss_scaling(self) -> TensorDict: ...

@abc.abstractmethod
def step_is_optimized(self, step: int) -> bool:
"""Returns True if the step is less than to the number of
@@ -57,11 +61,11 @@ class LossContributionsConfig:

def build(
self,
loss_obj: Callable[[TensorMapping, TensorMapping, int], torch.Tensor],
loss_obj: StepLoss,
time_dim: int,
) -> StepLossABC:
if self.n_steps == 0 or self.weight == 0.0:
return NullLossContributions()
return NullLossContributions(loss_obj)
Review comment (Member Author):
This preserves the existing behavior where we used a component stepper's effective_loss_scaling to compute mse_fractional_components metrics even if the stepper had no loss contribution in coupled training.

return LossContributions(
n_steps=self.n_steps,
weight=self.weight,
@@ -76,6 +80,16 @@ class NullLossContributions(StepLossABC):

"""

def __init__(
self,
loss_obj: StepLoss,
):
self._loss = loss_obj

@property
def effective_loss_scaling(self) -> TensorDict:
return self._loss.effective_loss_scaling

def step_is_optimized(self, step: int) -> bool:
return False

@@ -90,14 +104,18 @@ def __init__(
self,
n_steps: float,
weight: float,
loss_obj: Callable[[TensorMapping, TensorMapping, int], torch.Tensor],
loss_obj: StepLoss,
time_dim: int,
):
self._loss = loss_obj
self._n_steps = n_steps
self._weight = weight
self._time_dim = time_dim

@property
def effective_loss_scaling(self) -> TensorDict:
return self._loss.effective_loss_scaling

def step_is_optimized(self, step: int) -> bool:
"""Returns True if the step is less than to the number of steps and
weight is != 0. The first step number is assumed to be 0.
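With these changes, both branches of `LossContributionsConfig.build` receive the `StepLoss`, so `effective_loss_scaling` stays available even for a component whose loss contribution is disabled (the behavior the review comment above calls out). A small runnable sketch of that guarantee, using a stand-in for `StepLoss` rather than the real fme class:

```python
class FakeStepLoss:
    # Stand-in for fme.core.loss.StepLoss; only the attribute we need here.
    effective_loss_scaling = {"sst": 4.0}


class NullContributions:
    # Mirrors NullLossContributions above: never optimized, but still
    # forwards the wrapped loss object's scaling for metrics.
    def __init__(self, loss_obj):
        self._loss = loss_obj

    def step_is_optimized(self, step: int) -> bool:
        return False

    @property
    def effective_loss_scaling(self):
        return self._loss.effective_loss_scaling


null = NullContributions(FakeStepLoss())
assert not null.step_is_optimized(0)
assert null.effective_loss_scaling == {"sst": 4.0}  # metrics keep working
```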