-
Notifications
You must be signed in to change notification settings - Fork 62
move weight update validation functions to util #573
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@JenniferWang has exported this pull request. If you are a Meta employee, you can view the originating Diff in D87005971. |
Summary: * Fix the weight update test * Extract common logic to a separate util function; see the next diff D87083010 for how to use them in verifying weights do get updated as part of infra verification when debugging a buggy run. Reviewed By: casteryh Differential Revision: D87005971
97f8002 to
4d79c23
Compare
Summary: * Fix the weight update test * Extract common logic to a separate util function; see the next diff D87083010 for how to use them in verifying weights do get updated as part of infra verification when debugging a buggy run. Reviewed By: casteryh Differential Revision: D87005971
4d79c23 to
9f23889
Compare
|
|
||
|
|
||
| @dataclass | ||
| class WeightVerificationResult: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why wouldn't this + the function verify_weights_changed belong in tests? I see no reason users should have access to this as a public API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See D87083010 on how we use it in verifying the infra set up -- for example, users want to write their own RL loop and want to verify that the weight sync is happening as expected.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I consider this a bit of an anti-pattern. The expectation for an API is that it does what it says it does. If we have an API that says it updates weights, the onus is on us to ensure that it actually does.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a slightly different opinion -- for example, how to verify that the Generator model is initialized with the correct weight when resuming from checkpointing? Every single component can potentially do the right thing but the user may configured something wrong.
| # TODO: Remove below param | ||
| _test_prev_params = {} | ||
|
|
||
| def __post_init__(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What was the reason for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is required to make the logger.info print to stdout?
Summary:
Differential Revision: D87005971