diff --git a/fme/core/registry/module.py b/fme/core/registry/module.py index b1ee41d79..1f28a224f 100644 --- a/fme/core/registry/module.py +++ b/fme/core/registry/module.py @@ -161,7 +161,7 @@ def build( n_in_channels: int, n_out_channels: int, dataset_info: DatasetInfo, - ) -> nn.Module: + ) -> Module: """ Build a nn.Module given information about the input and output channels and the dataset. diff --git a/fme/core/registry/test_module_registry.py b/fme/core/registry/test_module_registry.py index d2dce58e9..c2abc7ff9 100644 --- a/fme/core/registry/test_module_registry.py +++ b/fme/core/registry/test_module_registry.py @@ -11,6 +11,7 @@ from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates from fme.core.dataset_info import DatasetInfo from fme.core.labels import LabelEncoding +from fme.core.rand import set_seed from fme.core.registry.module import Module from .module import CONDITIONAL_BUILDERS, ModuleConfig, ModuleSelector @@ -79,8 +80,27 @@ def test_module_selector_raises_with_bad_config(): ModuleSelector(type="mock", config={"non_existent_key": 1}) -def get_noise_conditioned_sfno_module_selector() -> ModuleSelector: - return ModuleSelector( +def get_noise_conditioned_sfno_module() -> Module: + img_shape = (9, 18) + n_in_channels = 5 + n_out_channels = 6 + all_labels = {"a", "b"} + timestep = datetime.timedelta(hours=6) + device = fme.get_device() + horizontal_coordinate = LatLonCoordinates( + lat=torch.zeros(img_shape[0], device=device), + lon=torch.zeros(img_shape[1], device=device), + ) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7, device=device), bk=torch.arange(7, device=device) + ) + dataset_info = DatasetInfo( + horizontal_coordinates=horizontal_coordinate, + vertical_coordinate=vertical_coordinate, + timestep=timestep, + all_labels=all_labels, + ) + selector = ModuleSelector( type="NoiseConditionedSFNO", config={ "embed_dim": 8, @@ -94,6 +114,22 @@ def get_noise_conditioned_sfno_module_selector() -> ModuleSelector: "spectral_transform": "sht", }, ) + module = selector.build( + n_in_channels=n_in_channels, + n_out_channels=n_out_channels, + dataset_info=dataset_info, + ) + return module + + +def load_state(selector_name: str) -> dict[str, torch.Tensor]: + state_dict_path = DATA_DIR / f"{selector_name}_state_dict.pt" + if not state_dict_path.exists(): + raise RuntimeError( + f"State dict for {selector_name} not found at {state_dict_path}. " + "Please make sure the checkpoint exists and is committed to the repo." + ) + return torch.load(state_dict_path) def load_or_cache_state(selector_name: str, module: Module) -> dict[str, torch.Tensor]: @@ -110,40 +146,48 @@ def load_or_cache_state(selector_name: str, module: Module) -> dict[str, torch.T ) -SELECTORS = { - "NoiseConditionedSFNO": get_noise_conditioned_sfno_module_selector(), +FROZEN_BUILDERS = { + "dbc2925_ncsfno": get_noise_conditioned_sfno_module, } @pytest.mark.parametrize( "selector_name", - SELECTORS.keys(), + FROZEN_BUILDERS.keys(), ) -def test_module_backwards_compatibility(selector_name: str): - torch.manual_seed(0) - img_shape = (9, 18) - n_in_channels = 5 - n_out_channels = 6 - all_labels = {"a", "b"} - timestep = datetime.timedelta(hours=6) - device = fme.get_device() - horizontal_coordinate = LatLonCoordinates( - lat=torch.zeros(img_shape[0], device=device), - lon=torch.zeros(img_shape[1], device=device), - ) - vertical_coordinate = HybridSigmaPressureCoordinate( - ak=torch.arange(7, device=device), bk=torch.arange(7, device=device) - ) - dataset_info = DatasetInfo( - horizontal_coordinates=horizontal_coordinate, - vertical_coordinate=vertical_coordinate, - timestep=timestep, - all_labels=all_labels, - ) - module = SELECTORS[selector_name].build( - n_in_channels=n_in_channels, - n_out_channels=n_out_channels, - dataset_info=dataset_info, - ) +def test_frozen_module_backwards_compatibility(selector_name: str): + """ + Backwards compatibility for frozen releases from specific commits. + """ + set_seed(0) + module = FROZEN_BUILDERS[selector_name]() + loaded_state_dict = load_state(selector_name) + module.load_state(loaded_state_dict) + + +LATEST_BUILDERS = { + "NoiseConditionedSFNO": get_noise_conditioned_sfno_module, +} + + +@pytest.mark.parametrize( + "selector_name", + LATEST_BUILDERS.keys(), +) +def test_latest_module_backwards_compatibility(selector_name: str): + """ + Backwards compatibility for the latest module implementations. + + Should be kept up-to-date with the latest code changes. + """ + set_seed(0) + module = LATEST_BUILDERS[selector_name]() loaded_state_dict = load_or_cache_state(selector_name, module) + new_keys = set(module.get_state().keys()).difference(loaded_state_dict.keys()) module.load_state(loaded_state_dict) + assert not new_keys, ( + f"New keys {new_keys} were added to the state dict of {selector_name}, " + "which need to be added to the checkpoint to maintain comaptibility. " + "Please delete and regenerate the checkpoint to include these new keys, " + "and commit the updated checkpoint to the repo." + ) diff --git a/fme/core/registry/testdata/dbc2925_ncsfno_state_dict.pt b/fme/core/registry/testdata/dbc2925_ncsfno_state_dict.pt new file mode 100644 index 000000000..f05bfa735 Binary files /dev/null and b/fme/core/registry/testdata/dbc2925_ncsfno_state_dict.pt differ