[DP] Functional DP for GPT-OSS #1137
Open
wenxindongwork wants to merge 8 commits into main from torch-dp-pr
+89 −28
Commits (8):
3fe16df  squash (wenxindongwork)
3eada3d  wip (wenxindongwork)
8970b80  only submit model dp (wenxindongwork)
270e511  wip (wenxindongwork)
a9d5154  wip (wenxindongwork)
b10487a  formatting (wenxindongwork)
a670581  wip (wenxindongwork)
9087e9c  formatting (wenxindongwork)
Reviewer: Personally, I prefer to avoid shard_map where possible, as it is prone to numeric errors when not used properly. With check_rep=False, unlike other ops, there is no safety check guaranteeing that a tensor's numerics actually adhere to its SPMD / sharding annotation. I prefer to use it only when it is really necessary (e.g., when calling a kernel).
Please modify this code to not use shard_map; you can refer to this PR, where I replaced an existing use of shard_map with a regular JAX function: #590
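A minimal sketch of the pattern being requested here, i.e. a regular JAX function with an explicit sharding constraint instead of a hand-written shard_map body (illustrative only, not the actual change from #590; the mesh and the 'dp' axis name are assumptions):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed setup: all devices on a single 'dp' mesh axis.
mesh = Mesh(jax.devices(), axis_names=('dp',))

@jax.jit
def scale(x, w):
    # Instead of writing per-shard code under shard_map, state the intended
    # layout and let GSPMD partition the op and insert any collectives.
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('dp')))
    return x * w
```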
Reviewer: For more context: with shard_map and check_rep=False, something like out_specs=P(None) only annotates the tensor as having that sharding; shard_map does not introduce any collective to ensure it.
This means the output tensor may not actually be replicated across devices, and each device can hold different values, because shard_map provides no such guarantee, which makes numeric issues really painful to debug.
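A small sketch of that failure mode (not code from this PR; it assumes at least two devices, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=2):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('dp',))

def per_shard(x):
    # The per-device result depends on the dp rank, so shards are NOT equal.
    return x + jax.lax.axis_index('dp')

f = shard_map(per_shard, mesh=mesh, in_specs=P('dp'),
              out_specs=P(None), check_rep=False)

x = jnp.zeros((jax.device_count(),))
y = jax.jit(f)(x)
# y is *annotated* as replicated (P(None)), but no collective was inserted
# to make that true: each device may hold a different block, which is the
# hard-to-debug numeric mismatch described above.
```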
Author: I tried not using shard_map but could not find a way that doesn't involve a for loop.
The complexity here is that jnp.repeat(bias, group_size, 0) expects bias and group_size to have the same size on dimension 0, but group_size.shape[0] = DP * num_experts whereas bias.shape[0] = num_experts.
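Concretely, with illustrative shapes (the names and values are placeholders, not the PR's code):

```python
import jax.numpy as jnp

DP, num_experts = 2, 4
bias = jnp.arange(num_experts, dtype=jnp.float32)           # shape (num_experts,)
group_sizes = jnp.ones(DP * num_experts, dtype=jnp.int32)   # shape (DP * num_experts,)

# jnp.repeat with an array of repeats requires len(repeats) == a.shape[axis],
# so this raises a shape error (4 bias rows vs 8 group sizes):
# jnp.repeat(bias, group_sizes, axis=0)
```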
Reviewer: If that's the case, can you do something like this?
Reviewer: Note that if compiler optimization works correctly, the first two ops (the jnp.repeat and the sharding constraint) will be a no-op, because the data is already present on each dp rank and we are just telling the compiler to treat it differently from that point on.
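A reconstructed sketch of the approach discussed here (the actual code suggestion is not visible in this thread; the mesh, the 'dp' axis name, and the helper name are assumptions, and jnp.tile is used for the expansion per the correction that follows below):

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=('dp',))
DP = mesh.shape['dp']

@partial(jax.jit, static_argnames='total_tokens')
def expand_bias(bias, group_sizes, total_tokens):
    # Step 1: (num_experts, ...) -> (DP * num_experts, ...), one full copy
    # of the bias per dp rank.
    expanded = jnp.tile(bias, (DP,) + (1,) * (bias.ndim - 1))
    # Step 2: annotate the result as sharded along dp. Every rank already
    # holds the same bias, so steps 1-2 should ideally compile to a no-op:
    # the data stays put and only the sharding annotation changes.
    expanded = jax.lax.with_sharding_constraint(
        expanded, NamedSharding(mesh, P('dp')))
    # Step 3: repeats and the repeated axis now have matching lengths
    # (DP * num_experts); total_repeat_length keeps the output shape static.
    return jnp.repeat(expanded, group_sizes, axis=0,
                      total_repeat_length=total_tokens)
```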
Author: Thanks Kyuyeun for this suggestion. For correctness, I have to use jnp.tile instead of jnp.repeat. However, I am noticing a performance drop (7575.50 vs 7781.92) when I do this instead of using shard_map. Maybe jnp.tile is not a no-op?
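The ordering difference is easy to see on a toy array: jnp.tile lays out whole-list copies per rank, which appears to be what a group_sizes array of length DP * num_experts expects here, while jnp.repeat interleaves per-expert copies:

```python
import jax.numpy as jnp

bias = jnp.array([0., 1., 2.])     # 3 "experts", DP = 2
jnp.repeat(bias, 2, axis=0)        # [0., 0., 1., 1., 2., 2.]  per-expert copies
jnp.tile(bias, 2)                  # [0., 1., 2., 0., 1., 2.]  per-rank copies
```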
Reviewer: Hmm, in theory it should be a no-op, because bias is already replicated across TPUs along the dp axis, and combining tile/repeat with a sharding constraint just tells the TPU to treat the copies as a separate, non-replicated tensor.
I'll run some tests locally and get back to you ASAP.