Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions fme/core/models/conditional_sfno/s2convolutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,39 @@ def __init__(

if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
self.in_channels = in_channels
self.out_channels = out_channels

# rewrite old checkpoints on load
self.register_load_state_dict_pre_hook(self._add_singleton_group_dim)

@staticmethod
def _add_singleton_group_dim(
module: "SpectralConvS2",
state_dict: dict[str, torch.Tensor],
prefix: str,
local_metadata: dict,
strict: bool,
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
key = prefix + "weight"
if key not in state_dict:
return

weight = state_dict[key]

ungrouped_shape = (
module.in_channels,
module.out_channels,
module.modes_lat_local,
2,
)

if weight.shape == ungrouped_shape:
state_dict[key] = weight.view(1, *ungrouped_shape)

def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover
dtype = x.dtype
residual = x
Expand Down
76 changes: 76 additions & 0 deletions fme/core/registry/test_module_registry.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import dataclasses
import datetime
import pathlib
from collections.abc import Iterable

import dacite
import pytest
import torch

import fme
from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates
from fme.core.dataset_info import DatasetInfo
from fme.core.labels import LabelEncoding
from fme.core.registry.module import Module

from .module import CONDITIONAL_BUILDERS, ModuleConfig, ModuleSelector

DATA_DIR = pathlib.Path(__file__).parent / "testdata"


class MockModule(torch.nn.Module):
def __init__(self, param_shapes: Iterable[tuple[int, ...]]):
Expand Down Expand Up @@ -71,3 +77,73 @@ def test_build_conditional():
def test_module_selector_raises_with_bad_config():
with pytest.raises(dacite.UnexpectedDataError):
ModuleSelector(type="mock", config={"non_existent_key": 1})


def get_noise_conditioned_sfno_module_selector() -> ModuleSelector:
return ModuleSelector(
type="NoiseConditionedSFNO",
config={
"embed_dim": 8,
"noise_embed_dim": 4,
"noise_type": "isotropic",
"filter_type": "linear",
"use_mlp": True,
"num_layers": 4,
"operator_type": "dhconv",
"affine_norms": True,
"spectral_transform": "sht",
},
)


def load_or_cache_state(selector_name: str, module: Module) -> dict[str, torch.Tensor]:
state_dict_path = DATA_DIR / f"{selector_name}_state_dict.pt"
if state_dict_path.exists():
return torch.load(state_dict_path)
else:
state_dict = module.get_state()
torch.save(state_dict, state_dict_path)
raise RuntimeError(
f"State dict for {selector_name} not found. "
f"Created a new one at {state_dict_path}. "
"Please commit it to the repo and run the test again."
)


SELECTORS = {
"NoiseConditionedSFNO": get_noise_conditioned_sfno_module_selector(),
}


@pytest.mark.parametrize(
"selector_name",
SELECTORS.keys(),
)
def test_module_backwards_compatibility(selector_name: str):
torch.manual_seed(0)
img_shape = (9, 18)
n_in_channels = 5
n_out_channels = 6
all_labels = {"a", "b"}
timestep = datetime.timedelta(hours=6)
device = fme.get_device()
horizontal_coordinate = LatLonCoordinates(
lat=torch.zeros(img_shape[0], device=device),
lon=torch.zeros(img_shape[1], device=device),
)
vertical_coordinate = HybridSigmaPressureCoordinate(
ak=torch.arange(7, device=device), bk=torch.arange(7, device=device)
)
dataset_info = DatasetInfo(
horizontal_coordinates=horizontal_coordinate,
vertical_coordinate=vertical_coordinate,
timestep=timestep,
all_labels=all_labels,
)
module = SELECTORS[selector_name].build(
n_in_channels=n_in_channels,
n_out_channels=n_out_channels,
dataset_info=dataset_info,
)
loaded_state_dict = load_or_cache_state(selector_name, module)
module.load_state(loaded_state_dict)
Binary file not shown.