[DP] Functional DP for GPT-OSS #1137
base: main
Conversation
1efb3dc to b10487a
kyuyeunk left a comment:
Please add torchax dp related unit tests.
total_repeat_length=m // mesh.shape["data"])
return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)

gmm_result = shard_map(
Personally, I prefer not using shard_map if possible, as it is really prone to numeric errors when not used properly. With check_rep=False, unlike other ops, there isn't any safety feature that guarantees the numerics of a tensor adhere to a proper SPMD / sharding annotation.
I prefer using it only when it's really necessary (e.g., when calling a kernel).
Please modify this code to not use shard_map; you can refer to this PR where I replaced an existing use of shard_map with a regular JAX function: #590
For more context, when using shard_map with check_rep=False and something like 'out_spec=P(None)', it only annotates the tensor as having that sharding; shard_map does not introduce any collective to ensure it.
Meaning, it is possible that the output tensor's numerics are not replicated across devices and every device holds different values, because shard_map provides no guarantees - which makes it really painful to debug when there's a numeric issue.
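For readers less familiar with this failure mode, here is a minimal sketch (toy code, not from this PR; the function and the one-axis mesh are made up) of what the comment above describes:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=("data",))

def per_shard(x):
    # Device-dependent result: each shard scales its block differently,
    # so the output is NOT actually replicated along "data".
    return x * (jax.lax.axis_index("data") + 1)

f = shard_map(per_shard, mesh=mesh,
              in_specs=P("data"),
              out_specs=P(None),   # merely labels the output as replicated
              check_rep=False)     # disables the check that would catch this
# Anything downstream that trusts the P(None) annotation may read different
# numerics depending on which device's buffer it happens to use.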
I tried to not use shard_map but did not figure out a way that doesn't involve a for loop.
The complexity here is that jnp.repeat(bias, group_size, 0) expects bias and group_size to have the same size on dimension 0, but group_size.shape[0] = DP * num_experts whereas bias.shape[0] = num_experts.
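To make the mismatch concrete, a toy example (made-up sizes, not the PR's code):

import jax.numpy as jnp

num_experts, dp = 2, 2
bias = jnp.ones((num_experts, 4))        # (num_experts, model_dim)
group_sizes = jnp.array([1, 2, 3, 1])    # dp * num_experts entries
# jnp.repeat(bias, group_sizes, 0) would raise here: group_sizes has 4 entries
# but bias has only 2 rows along axis 0, so the sizes must be reconciled first.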
if that's the case, can you do something like this?
# convert (experts, model_dim) to (experts * dp_size, model_dim)
bias = jnp.repeat(bias, dp_size, 0)
# (optional; may or may not be needed) match bias's sharding with group_size's sharding
bias = jax.lax.with_sharding_constraint(bias, P("data", "model"))
# Now bias.shape[0] and group_size.shape[0] match
rhs_bias = jnp.repeat(bias, group_size, 0)
Note that if compiler optimization works correctly, the first two lines (the jnp.repeat and the sharding constraint) will be a no-op, because the data is already present in each dp rank and we are just telling the compiler to treat it differently from that point on.
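One way to check that (a sketch; it assumes the MoE function can be lowered with representative sharded example inputs) is to look at the compiled HLO for collectives around the bias handling:

import jax

def collectives_in_hlo(fn, *example_args):
    # Lower and compile fn, then scan the HLO text for collective ops that
    # would indicate the repeat / sharding constraint is not a no-op.
    hlo = jax.jit(fn).lower(*example_args).compile().as_text()
    return [op for op in ("all-gather", "all-to-all", "collective-permute", "all-reduce")
            if op in hlo]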
Thanks Kyuyeun for this suggestion. For correctness, I have to use jnp.tile instead of jnp.repeat. However, I am noticing a performance drop (7575.50 vs 7781.92) when I do this instead of shard_map. Maybe jnp.tile is not a no-op?
rhs_bias = jnp.tile(rhs_bias, (mesh.shape["data"], 1))
# adding the sharding constraint does not make a difference
rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
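For context on the tile-vs-repeat choice: presumably group_sizes here is the per-rank group sizes concatenated, so the expert rows have to repeat in blocks rather than element-wise. A toy illustration (made-up sizes, not the PR's code):

import jax.numpy as jnp

num_experts, dp = 2, 2
bias = jnp.array([[10.0], [20.0]])       # (num_experts, 1): one row per expert
group_sizes = jnp.array([1, 2, 3, 1])    # per-rank sizes concatenated: rank0=[1,2], rank1=[3,1]

tiled = jnp.tile(bias, (dp, 1))          # rows e0, e1, e0, e1: lines up with group_sizes
wrong = jnp.repeat(bias, dp, 0)          # rows e0, e0, e1, e1: pairs e1's count with e0's bias

rhs_bias = jnp.repeat(tiled, group_sizes, 0, total_repeat_length=int(group_sizes.sum()))
# rhs_bias rows: e0, e1, e1, e0, e0, e0, e1 (one bias row per token, per rank)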
Hmm, in theory it should be a no-op,
because bias is already replicated across TPUs along the dp axis, and combining tile/repeat with a sharding constraint just tells the TPU to treat it as a separate non-replicated tensor.
I'll do some tests locally and get back to you ASAP.
kyuyeunk left a comment:
Please add torchax dp related unit tests.
Also, please address this comment.
Added e2e model parallelism test for Llama3.1 1b for torchax.
Description
Add functional DP support for the GPT-OSS Torchax backend.
Verified baseline throughput is unchanged (5037.82); DP=2 throughput is 1.54x (7781.92).
Validated numerical correctness with offline_inference.py.
Full details: https://paste.googleplex.com/5240826907197440
Tests
https://buildkite.com/tpu-commons/tpu-inference-ci/builds/5712
https://buildkite.com/tpu-commons/tpu-inference-ci/builds/5749
Checklist
Before submitting this PR, please make sure: