
Conversation

Collaborator

@wenxindongwork wenxindongwork commented Nov 20, 2025

Description

Add functional DP support for GPT-OSS Torchax backend.

Verified that baseline throughput is unchanged (5037.82); DP=2 throughput is 1.54x the baseline (7781.92).

Validated numerical correctness with offline_inference.py

Full details: https://paste.googleplex.com/5240826907197440

Tests

https://buildkite.com/tpu-commons/tpu-inference-ci/builds/5712
https://buildkite.com/tpu-commons/tpu-inference-ci/builds/5749

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.


Collaborator

@kyuyeunk kyuyeunk left a comment


Please add torchax DP-related unit tests.

total_repeat_length=m // mesh.shape["data"])
return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)

gmm_result = shard_map(
Collaborator


Personally, I prefer not to use shard_map if possible, as it is really prone to numeric errors when not used properly. With check_rep=False, unlike other ops, there is no safety mechanism that guarantees the numerics of a tensor actually adhere to the declared SPMD / sharding annotation.

I prefer using it only when it's really necessary (e.g. when calling a kernel).

Please modify this code to not use shard_map; you can refer to this PR, where I replaced an existing use of shard_map with a regular JAX function: #590

Collaborator

@kyuyeunk kyuyeunk Nov 20, 2025


For more context: when using shard_map with check_rep=False and something like out_specs=P(None), the spec only annotates the tensor as having that sharding; shard_map does not introduce any collective to enforce it.

In other words, the output tensor may not actually be replicated across devices, and each device can hold different values, because shard_map provides no guarantee here - which makes it really painful to debug when there's a numeric issue.
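
For illustration only (not the code in this PR), a minimal sketch of that footgun, assuming a 1-D mesh with a "data" axis:

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("data",))

def per_shard_sum(x):
    # Each device reduces only its own shard, so per-device results differ.
    return jnp.sum(x, keepdims=True)

f = shard_map(per_shard_sum, mesh=mesh,
              in_specs=P("data"), out_specs=P(None),
              check_rep=False)

x = jnp.arange(4 * jax.device_count(), dtype=jnp.float32)
y = f(x)
# out_specs=P(None) only claims that y is replicated; with check_rep=False
# shard_map inserts no all-reduce, so each device can hold a different y.
# Returning jax.lax.psum(jnp.sum(x, keepdims=True), "data") instead would
# make the replication real.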

Collaborator Author


I tried not to use shard_map but could not figure out a way that doesn't involve a for loop.

The complexity here is that jnp.repeat(bias, group_size, 0) expects bias and group_size to have the same size on dimension 0, but group_size.shape[0] = DP * num_experts whereas bias.shape[0] = num_experts.
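
A toy repro of that shape constraint (hypothetical sizes, not the PR's shapes):

import jax.numpy as jnp

num_experts, dp, model_dim = 2, 2, 4          # hypothetical sizes
bias = jnp.ones((num_experts, model_dim))     # shape (2, 4)
group_size = jnp.array([3, 1, 2, 2])          # shape (dp * num_experts,) = (4,)

# jnp.repeat with an array of repeat counts needs exactly one count per row
# of bias, so the line below raises a shape error (4 counts for 2 rows):
# jnp.repeat(bias, group_size, axis=0)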

Collaborator


If that's the case, can you do something like this?

# convert (experts, model_dim) to (experts * dp_size, model_dim)
bias = jnp.repeat(bias, dp_size, 0)

# (optional; may or may not be needed) match bias's sharding with group_size's sharding
bias = jax.lax.with_sharding_constraint(bias, P("data", "model"))

# now bias.shape[0] and group_size.shape[0] match
rhs_bias = jnp.repeat(bias, group_size, 0)

Collaborator


Note that if compiler optimization works correctly, the first two lines (the jnp.repeat and the sharding constraint) will be a no-op, because the data is already present on each DP rank and we are just telling the compiler to treat it differently from that point on.

Collaborator Author

@wenxindongwork wenxindongwork Nov 20, 2025


Thanks Kyuyeun for the suggestion. For correctness, I have to use jnp.tile instead of jnp.repeat; however, I am noticing a performance drop (7575.50 vs 7781.92) when I do this instead of shard_map. Maybe because jnp.tile is not a no-op?

        rhs_bias = jnp.tile(rhs_bias, (mesh.shape["data"], 1)) 
        # adding the sharding constraint does not make a difference 
        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
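
To spell out the repeat-vs-tile point with a toy example (hypothetical sizes; assuming group_sizes is laid out as DP consecutive copies of the per-expert group counts, as the shapes above suggest):

import jax.numpy as jnp

dp = 2
bias = jnp.array([[0., 0.],    # row 0 belongs to expert 0
                  [1., 1.]])   # row 1 belongs to expert 1

jnp.repeat(bias, dp, axis=0)   # rows ordered e0, e0, e1, e1 - wrong order here
jnp.tile(bias, (dp, 1))        # rows ordered e0, e1, e0, e1 - matches group_sizes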

Collaborator


Hmm, in theory it should be a no-op, because the bias is already replicated across TPUs along the dp axis, and combining tile/repeat with the sharding constraint just tells the compiler to treat it as a separate, non-replicated tensor from then on.

I'll do some tests locally and get back to you asap.

Collaborator

@kyuyeunk kyuyeunk left a comment


Please add torchax DP-related unit tests.

Also, please address this comment.

@wenxindongwork
Collaborator Author

Added an e2e model parallelism test for Llama3.1 1b for torchax.

