Fix backwards compatibility with pre-group checkpoints #828
Conversation
jpdunc23 left a comment
Thanks for the quick fix!
Could you add a test by creating an old SFNO checkpoint from a commit before c2d9151 and ensuring it loads via the usual inference pathway without errors (but without running inference)?
```python
self.out_channels = out_channels

# rewrite old checkpoints on load
self.register_load_state_dict_pre_hook(self._pre_load_hook)
```
Nice!
Noting that I generated a checkpoint (still running locally, not committed) using dbc2925, which should be just before the change. The test is currently failing.

I verified the test runs using a checkpoint from before that commit (the test initially failed). That checkpoint was too large, so I regenerated it; the pushed code includes a checkpoint generated from the current code.
jpdunc23 left a comment
Nice approach using register_load_state_dict_pre_hook!
Left a comment with some concerns about the new regression test.
```python
assert not new_keys, (
    f"New keys added to state dict: {new_keys}. "
    "Please delete and re-generate the regression target."
)
```
I think it's important that we retain this version of NoiseConditionedSFNO_state_dict.pt to avoid regressions in _add_singleton_group_dim, especially as we've had two model releases with the older format.
Maybe you could call it NoiseConditionedSFNO_state_dict_dbc2925.pt and parameterize this test including the commit SHA?
More generally, I don't think we should encourage regeneration of the regression target when there are new keys in the state dict, since presumably those new things could lead to big changes in inference behavior. We could maybe handle new keys here on a case-by-case basis, e.g. where we know that given an existing config the new key is just a no-op / identity op, but even that seems risky since we couldn't guarantee that future changes won't cause that key to modify inference behavior. Ideally, future changes just wouldn't add the new keys to the module unless a new config parameter is turned on, but maybe you have something specific in mind?
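Parameterizing the test by commit SHA, as suggested above, could look roughly like this. Only the `dbc2925` tag comes from the discussion; the fixture contents and assertions are hypothetical stand-ins for loading each retained checkpoint through the usual inference pathway.

```python
import pytest

# Retained regression targets, tagged by the commit that produced them.
# Here each entry is a stand-in mapping of key -> tensor shape.
LEGACY_TARGETS = {
    "dbc2925": {"filter.weight": (4, 4)},     # pre-group-dim format
    "current": {"filter.weight": (1, 4, 4)},  # with singleton group dim
}


@pytest.mark.parametrize("commit", sorted(LEGACY_TARGETS))
def test_checkpoint_loads(commit):
    shapes = LEGACY_TARGETS[commit]
    # Stand-in for "loads via the usual inference pathway without errors":
    # the real test would build the model and call load_state_dict here.
    assert all(len(shape) >= 2 for shape in shapes.values())
```

Keeping one target per retained format means a PR that breaks either the old or the new layout fails a clearly labeled test case.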
Rather than pick the commit hash before the break, we should retain checkpoints from the specific ACE releases that used NoiseConditionedSFNO.
I think it's important to have a "latest" checkpoint that will break if any backwards incompatibilities are introduced. If the checkpoint we're testing doesn't have a key, then a PR removing that key won't cause the test to fail, but would cause newer checkpoints having that key to no longer load. It is important that this regeneration message only appear after we've already tested that the current model is backwards-compatible, so that we don't update the checkpoint in a commit that also breaks backwards compatibility.
For this PR I'll just remove this check, we can add these two things in another PR.
jpdunc23 left a comment
Punting on handling legacy checkpoints for now so we can get this merged. Approved.
A previous PR #788 introduced a change to the weight shape to include a group dimension. This prevents loading existing checkpoints that don't have this dimension. This PR adds a hook to correctly handle these previous checkpoints.
Resolves #826