@staticmethod
def _add_singleton_group_dim(
    module: "SpectralConvS2",
    state_dict: dict[str, torch.Tensor],
    prefix: str,
    local_metadata: dict,
    strict: bool,
    missing_keys: list[str],
    unexpected_keys: list[str],
    error_msgs: list[str],
) -> None:
    """Rewrite a 4-D ``weight`` entry into the 5-D layout with a leading 1.

    Written with the load-state-dict pre-hook signature. Only ``module``,
    ``state_dict`` and ``prefix`` are read; the remaining arguments are
    accepted and ignored.

    If ``state_dict[prefix + "weight"]`` has exactly the shape
    ``(in_channels, out_channels, modes_lat_local, 2)`` taken from
    ``module``, the entry is replaced in place by a view with one extra
    leading dimension of size 1. A missing key or any other shape leaves
    ``state_dict`` untouched.
    """
    weight_key = f"{prefix}weight"
    stored = state_dict.get(weight_key) if weight_key in state_dict else None
    if stored is None:
        # Nothing stored under this prefix, so there is nothing to rewrite.
        return

    legacy_shape = (
        module.in_channels,
        module.out_channels,
        module.modes_lat_local,
        2,
    )
    if stored.shape != legacy_shape:
        # Any other layout is passed through untouched.
        return

    state_dict[weight_key] = stored.view(1, *legacy_shape)
def test_module_selector_raises_with_bad_config():
    """A selector config with an unrecognised key raises UnexpectedDataError."""
    with pytest.raises(dacite.UnexpectedDataError):
        ModuleSelector(type="mock", config={"non_existent_key": 1})


def get_noise_conditioned_sfno_module_selector() -> ModuleSelector:
    """Return the NoiseConditionedSFNO selector used by the compatibility test."""
    return ModuleSelector(
        type="NoiseConditionedSFNO",
        config={
            "embed_dim": 8,
            "noise_embed_dim": 4,
            "noise_type": "isotropic",
            "filter_type": "linear",
            "use_mlp": True,
            "num_layers": 4,
            "operator_type": "dhconv",
            "affine_norms": True,
            "spectral_transform": "sht",
        },
    )


def load_or_cache_state(selector_name: str, module: Module) -> dict[str, torch.Tensor]:
    """Load the reference state dict for ``selector_name`` from ``DATA_DIR``.

    Args:
        selector_name: Used to build the file name
            ``{selector_name}_state_dict.pt`` inside ``DATA_DIR``.
        module: Module whose ``get_state()`` result is written out when no
            reference file exists yet.

    Returns:
        The state dict stored in the reference file.

    Raises:
        RuntimeError: If the reference file did not exist. A new one is
            written first so it can be committed, and the error tells the
            developer to do so and re-run.
    """
    state_dict_path = DATA_DIR / f"{selector_name}_state_dict.pt"
    if state_dict_path.exists():
        # Remap storages onto the current device so that a file cached on an
        # accelerator machine still loads on a CPU-only one.
        return torch.load(state_dict_path, map_location=fme.get_device())

    # The data directory may not exist on a fresh checkout; without this,
    # torch.save fails before the instructive error below can be raised.
    state_dict_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(module.get_state(), state_dict_path)
    raise RuntimeError(
        f"State dict for {selector_name} not found. "
        f"Created a new one at {state_dict_path}. "
        "Please commit it to the repo and run the test again."
    )


# Selectors exercised by the backwards-compatibility test, keyed by the name
# used for their reference state-dict file.
SELECTORS = {
    "NoiseConditionedSFNO": get_noise_conditioned_sfno_module_selector(),
}


@pytest.mark.parametrize("selector_name", list(SELECTORS))
def test_module_backwards_compatibility(selector_name: str):
    """Build the selected module and load its stored reference state into it.

    The test passes when ``load_or_cache_state`` finds the reference file and
    ``module.load_state`` accepts its contents without raising.
    """
    # Seeded before the build; presumably so that a freshly cached state is
    # reproducible -- confirm against how the builders initialise weights.
    torch.manual_seed(0)
    img_shape = (9, 18)
    n_in_channels = 5
    n_out_channels = 6
    all_labels = {"a", "b"}
    timestep = datetime.timedelta(hours=6)
    device = fme.get_device()
    horizontal_coordinate = LatLonCoordinates(
        lat=torch.zeros(img_shape[0], device=device),
        lon=torch.zeros(img_shape[1], device=device),
    )
    vertical_coordinate = HybridSigmaPressureCoordinate(
        ak=torch.arange(7, device=device), bk=torch.arange(7, device=device)
    )
    dataset_info = DatasetInfo(
        horizontal_coordinates=horizontal_coordinate,
        vertical_coordinate=vertical_coordinate,
        timestep=timestep,
        all_labels=all_labels,
    )
    module = SELECTORS[selector_name].build(
        n_in_channels=n_in_channels,
        n_out_channels=n_out_channels,
        dataset_info=dataset_info,
    )
    loaded_state_dict = load_or_cache_state(selector_name, module)
    module.load_state(loaded_state_dict)