Separate regression tests for frozen and latest checkpoints #830
base: main
```diff
@@ -11,6 +11,7 @@
 from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates
 from fme.core.dataset_info import DatasetInfo
 from fme.core.labels import LabelEncoding
+from fme.core.rand import set_seed
 from fme.core.registry.module import Module

 from .module import CONDITIONAL_BUILDERS, ModuleConfig, ModuleSelector
```
```diff
@@ -79,8 +80,27 @@ def test_module_selector_raises_with_bad_config():
         ModuleSelector(type="mock", config={"non_existent_key": 1})


-def get_noise_conditioned_sfno_module_selector() -> ModuleSelector:
-    return ModuleSelector(
+def get_noise_conditioned_sfno_module() -> Module:
+    img_shape = (9, 18)
+    n_in_channels = 5
+    n_out_channels = 6
+    all_labels = {"a", "b"}
+    timestep = datetime.timedelta(hours=6)
+    device = fme.get_device()
+    horizontal_coordinate = LatLonCoordinates(
+        lat=torch.zeros(img_shape[0], device=device),
+        lon=torch.zeros(img_shape[1], device=device),
+    )
+    vertical_coordinate = HybridSigmaPressureCoordinate(
+        ak=torch.arange(7, device=device), bk=torch.arange(7, device=device)
+    )
+    dataset_info = DatasetInfo(
+        horizontal_coordinates=horizontal_coordinate,
+        vertical_coordinate=vertical_coordinate,
+        timestep=timestep,
+        all_labels=all_labels,
+    )
+    selector = ModuleSelector(
         type="NoiseConditionedSFNO",
         config={
             "embed_dim": 8,
```
```diff
@@ -94,6 +114,22 @@ def get_noise_conditioned_sfno_module_selector() -> ModuleSelector:
             "spectral_transform": "sht",
         },
     )
+    module = selector.build(
+        n_in_channels=n_in_channels,
+        n_out_channels=n_out_channels,
+        dataset_info=dataset_info,
+    )
+    return module
+
+
+def load_state(selector_name: str) -> dict[str, torch.Tensor]:
+    state_dict_path = DATA_DIR / f"{selector_name}_state_dict.pt"
+    if not state_dict_path.exists():
+        raise RuntimeError(
+            f"State dict for {selector_name} not found at {state_dict_path}. "
+            "Please make sure the checkpoint exists and is committed to the repo."
+        )
+    return torch.load(state_dict_path)


 def load_or_cache_state(selector_name: str, module: Module) -> dict[str, torch.Tensor]:
```
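The body of `load_or_cache_state` falls outside the hunks shown in this diff. Judging from its name and how the tests below use it, a plausible sketch (an assumption based on context, not the code in this PR) is:

```python
def load_or_cache_state(selector_name: str, module: Module) -> dict[str, torch.Tensor]:
    # Assumed behavior: if no cached checkpoint exists yet, save the freshly
    # built module's state as the new reference, then load and return it.
    # Uses the same DATA_DIR, Module, and torch imports as the test module above.
    state_dict_path = DATA_DIR / f"{selector_name}_state_dict.pt"
    if not state_dict_path.exists():
        torch.save(module.get_state(), state_dict_path)
    return torch.load(state_dict_path)
```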
```diff
@@ -110,40 +146,48 @@ def load_or_cache_state(selector_name: str, module: Module) -> dict[str, torch.Tensor]:
     )


-SELECTORS = {
-    "NoiseConditionedSFNO": get_noise_conditioned_sfno_module_selector(),
+FROZEN_BUILDERS = {
+    "dbc2925_ncsfno": get_noise_conditioned_sfno_module,
 }


 @pytest.mark.parametrize(
     "selector_name",
-    SELECTORS.keys(),
+    FROZEN_BUILDERS.keys(),
 )
-def test_module_backwards_compatibility(selector_name: str):
-    torch.manual_seed(0)
-    img_shape = (9, 18)
-    n_in_channels = 5
-    n_out_channels = 6
-    all_labels = {"a", "b"}
-    timestep = datetime.timedelta(hours=6)
-    device = fme.get_device()
-    horizontal_coordinate = LatLonCoordinates(
-        lat=torch.zeros(img_shape[0], device=device),
-        lon=torch.zeros(img_shape[1], device=device),
-    )
-    vertical_coordinate = HybridSigmaPressureCoordinate(
-        ak=torch.arange(7, device=device), bk=torch.arange(7, device=device)
-    )
-    dataset_info = DatasetInfo(
-        horizontal_coordinates=horizontal_coordinate,
-        vertical_coordinate=vertical_coordinate,
-        timestep=timestep,
-        all_labels=all_labels,
-    )
-    module = SELECTORS[selector_name].build(
-        n_in_channels=n_in_channels,
-        n_out_channels=n_out_channels,
-        dataset_info=dataset_info,
-    )
+def test_frozen_module_backwards_compatibility(selector_name: str):
+    """
+    Backwards compatibility for frozen releases from specific commits.
+    """
+    set_seed(0)
+    module = FROZEN_BUILDERS[selector_name]()
+    loaded_state_dict = load_state(selector_name)
     module.load_state(loaded_state_dict)
```
Comment on lines +158 to +165

Member:
I'm still a bit concerned that if we place no limits on new keys, we could inadvertently introduce changes in behavior that lead to regressions in inference skill. We could also save and reload the module config dict together with the artifact to verify that the config builds the same architecture. Of course there are a million other ways to change the module code that could lead to inference regressions, but I don't see why we should allow arbitrary new parameters that weren't present when the checkpoint was saved if there is a way to avoid it.

Contributor (Author):
That type of thing is supposed to be covered by the "produces the same result" test(s). If you're concerned about new parameters not affecting an initial prediction but affecting later ones after the first gradient update, I should add a second stage to those tests that takes a second step when checking for identical results.

I see now what you're saying; I hadn't been understanding it. In practice, what I have here won't catch any of the cases I care about updating the regression tests for, because we always add new parameters in a way that sets the weights to None (which doesn't get registered in the state dict). Really, what we need is to remember to update or write a new test whenever we add features that define new weights.

What I actually want to do is test that the config has no new keys, and force the user to build a new latest checkpoint when new config keys are added, saving the asdict'd config with the checkpoint. I'll see about adding that, along with your suggestion to verify that the config builds the same architecture.
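A minimal sketch of that plan might look like the following, assuming the module config is a dataclass; the helper names `save_latest_checkpoint` and `assert_no_new_config_keys` are hypothetical, not part of the fme API:

```python
import dataclasses

import torch


def save_latest_checkpoint(selector_name, config, module, data_dir):
    # Store the asdict'd config together with the weights in a single artifact.
    torch.save(
        {
            "config": dataclasses.asdict(config),
            "state_dict": module.get_state(),
        },
        data_dir / f"{selector_name}_checkpoint.pt",
    )


def assert_no_new_config_keys(selector_name, config, data_dir):
    # Fail if the current config defines keys that did not exist when the
    # checkpoint was saved, forcing the user to regenerate the checkpoint.
    checkpoint = torch.load(data_dir / f"{selector_name}_checkpoint.pt")
    new_keys = set(dataclasses.asdict(config)) - set(checkpoint["config"])
    assert not new_keys, (
        f"Config keys {new_keys} were added after the checkpoint for "
        f"{selector_name} was saved; regenerate and commit a new checkpoint."
    )
```

Rebuilding the module from the saved config dict and comparing the resulting state-dict keys and shapes would then cover the "config builds the same architecture" check suggested above.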
```diff
+
+
+LATEST_BUILDERS = {
+    "NoiseConditionedSFNO": get_noise_conditioned_sfno_module,
```
Member:
Should

Contributor (Author):
It was already a "latest" version.
```diff
+}
+
+
+@pytest.mark.parametrize(
+    "selector_name",
+    LATEST_BUILDERS.keys(),
+)
+def test_latest_module_backwards_compatibility(selector_name: str):
+    """
+    Backwards compatibility for the latest module implementations.
+
+    Should be kept up-to-date with the latest code changes.
+    """
+    set_seed(0)
+    module = LATEST_BUILDERS[selector_name]()
+    loaded_state_dict = load_or_cache_state(selector_name, module)
+    new_keys = set(module.get_state().keys()).difference(loaded_state_dict.keys())
+    module.load_state(loaded_state_dict)
+    assert not new_keys, (
+        f"New keys {new_keys} were added to the state dict of {selector_name}, "
+        "which need to be added to the checkpoint to maintain compatibility. "
+        "Please delete and regenerate the checkpoint to include these new keys, "
+        "and commit the updated checkpoint to the repo."
+    )
```
Good catch
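For context, the "produces the same result" test(s) referenced in the first review thread could take roughly this shape. This is a sketch under assumptions: the module is called directly on a single input tensor (the real call signature may differ), and the `f"{selector_name}_output.pt"` reference filename is hypothetical.

```python
def check_module_regression(selector_name: str, module: Module) -> None:
    # Run the rebuilt module on a deterministic input and compare against a
    # stored reference output committed alongside the state-dict checkpoint.
    set_seed(0)
    device = fme.get_device()
    # Input dims match the builder above: n_in_channels=5, img_shape=(9, 18).
    x = torch.randn(1, 5, 9, 18, device=device)
    with torch.no_grad():
        out = module(x)
    reference = torch.load(DATA_DIR / f"{selector_name}_output.pt")
    torch.testing.assert_close(out, reference)
```

Running this in both the frozen and the latest tests would catch new parameters that change predictions even when they load cleanly from an old state dict.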