31 changes: 31 additions & 0 deletions github_issue.md
@@ -0,0 +1,31 @@
## [Feature] Expert-Choice Routing for MoEFeedForward

The current `MoEFeedForward` supports token-choice routing where each
token selects its top-k experts. I'd like to add expert-choice routing
(Zhou et al., 2022 -- https://arxiv.org/abs/2202.09368) where each expert
selects its top-C tokens.
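
For concreteness, a minimal standalone sketch of the axis flip in JAX
(illustrative shapes and names, not the `MoEFeedForward` API):

```python
import jax
import jax.numpy as jnp

T, E, k = 8, 4, 2  # tokens, experts, experts per token (illustrative)
router_probs = jax.nn.softmax(
    jax.random.normal(jax.random.PRNGKey(0), (T, E)), axis=-1)

# Token choice: each token picks its top-k experts
# (top_k over the expert axis).
tc_probs, tc_expert_ids = jax.lax.top_k(router_probs, k=k)   # both [T, k]

# Expert choice: transpose so each expert picks its top-C tokens instead.
C = k * T // E  # same total number of (token, expert) assignments
ec_probs, ec_token_ids = jax.lax.top_k(router_probs.T, k=C)  # both [E, C]
```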

### Motivation

Expert-choice routing provides natural load balancing -- each expert
processes exactly C tokens by construction, eliminating the need for
auxiliary losses. The Google paper showed it outperforms token-choice on
language modeling and downstream tasks.

### Proposed changes

- Add `routing_mode` field (`'token_choice'` | `'expert_choice'`) to
`MoEFeedForward`, `TransformerBlock`, and `BaseExperimentConfig` (usage
sketched after this list)
- Add `_apply_expert_choice_moe()` following the `_apply_dense_moe`
dispatch pattern (einsums, no shard_map)
- Add test configs `lm_moe_test` and `lm_moe_expert_choice_test`
- Add equivalence and gradient tests to `MoETest`
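
As a usage sketch, mirroring the `dataclasses.replace` pattern of the
existing test configs (the exact field values here are illustrative):

```python
import dataclasses

# Hypothetical: enable the proposed mode on top of the tiny lm_test config.
config = dataclasses.replace(
    lm_test(),
    use_moe=True,
    num_experts=4,
    num_experts_per_token=1,
    routing_mode='expert_choice',  # proposed new field
    lbl_loss_weight=0.0,  # load is balanced by construction
)
```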

### Design notes

- Backward compatible -- default remains `'token_choice'`
- No new dependencies or files
- Initial implementation uses dense dispatch (not sparse/GMM) for
simplicity (see the sketch below); can be extended to sparse dispatch in
a follow-up
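
A minimal sketch of that dense one-hot + einsum dispatch (standalone,
with illustrative shapes; the variable names are not the real API):

```python
import jax
import jax.numpy as jnp

E, C, T, D = 4, 2, 8, 16  # experts, capacity, tokens, model dim (illustrative)
inputs = jnp.ones((T, D))
# Pretend per-expert top_k already ran: each expert holds C token ids + probs.
selected_indices = jnp.arange(E * C, dtype=jnp.int32).reshape(E, C)
selected_probs = jnp.full((E, C), 0.25)

# dispatch_onehot[e, c, t] = 1 iff expert e picked token t as its c-th slot.
dispatch_onehot = jax.nn.one_hot(selected_indices, T)
# Gather tokens into expert buffers: [E, C, D].
expert_inputs = jnp.einsum('ect,td->ecd', dispatch_onehot, inputs)
expert_outputs = expert_inputs  # stand-in for the expert FFN
# Scatter back, weighting by router prob; unpicked tokens stay zero.
dispatch_weights = jnp.einsum('ec,ect->ect', selected_probs, dispatch_onehot)
outputs = jnp.einsum('ecd,ect->td', expert_outputs, dispatch_weights)
```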

Would this be a welcome addition?
23 changes: 23 additions & 0 deletions pr_description.md
@@ -0,0 +1,23 @@
## Add expert-choice routing mode for MoEFeedForward

Adds expert-choice routing (Zhou et al., 2022) where each expert selects its top-C tokens, providing natural load balancing without auxiliary losses.

### Changes

- `model_lib.py`: `routing_mode` field, `_apply_expert_choice_moe()`, restructured `apply()` with early routing-mode branch
- `config_lib.py`: `routing_mode` in `BaseExperimentConfig`, `lm_moe_test` and `lm_moe_expert_choice_test` configs
- `model_lib_test.py`: `simple_expert_choice_moe()` reference impl, forward/gradient equivalence tests

### How to test

```bash
# Expert-choice MoE local test
python -m simply.main --experiment_config lm_moe_expert_choice_test --experiment_dir /tmp/moe_ec_test --alsologtostderr

# All MoE unit tests (including existing + new)
pytest simply/model_lib_test.py::MoETest -v
```

### Design decisions

- Follows `_apply_dense_moe` dispatch pattern (einsums, not sparse GMM)
- Capacity: C = num_experts_per_token x num_tokens / num_experts (worked example below)
- `lbl_loss` is skipped when `routing_mode='expert_choice'` (load is balanced by construction, so the auxiliary loss is unnecessary)
- `routing_mode` is validated; unknown values raise `ValueError`
- Token-choice path is unchanged: all 7 existing `simple_moe()` equivalence tests pass with the same tolerances as before
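
A quick worked instance of the capacity formula, using the same
`max(1, int(...))` form as the implementation (the shapes here are
hypothetical, not the `lm_test` defaults):

```python
num_experts_per_token = 2  # k
num_tokens = 1024          # batch_size * seq_len
num_experts = 4            # E
expert_capacity = max(1, num_experts_per_token * num_tokens // num_experts)
print(expert_capacity)  # 512 -- each expert processes exactly 512 tokens
```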
34 changes: 34 additions & 0 deletions simply/config_lib.py
@@ -315,6 +315,7 @@ class BaseExperimentConfig(ExperimentConfig):
tile_model_dim: int = 1024
tile_expand_dim: int = 1024
gmm_impl: str = 'ragged_dot'
routing_mode: str = 'token_choice'
global_total_num_pages: int = 0
local_total_num_pages: int = 0
page_size: int = 0
@@ -1982,6 +1983,39 @@ def lm_rl_test():
)


@ExperimentConfigRegistry.register
def lm_moe_test():
"""Tiny token-choice MoE config for local testing."""
config = lm_test()
return dataclasses.replace(
config,
sharding_config=moe_sharding(),
use_moe=True,
num_experts=4,
num_experts_per_token=2,
expert_capacity_factor=None,
lbl_loss_weight=0.01,
ffn_use_bias=False,
)


@ExperimentConfigRegistry.register
def lm_moe_expert_choice_test():
"""Tiny expert-choice MoE config for local testing."""
config = lm_test()
return dataclasses.replace(
config,
sharding_config=moe_sharding(),
use_moe=True,
num_experts=4,
num_experts_per_token=1,
expert_capacity_factor=None,
routing_mode='expert_choice',
lbl_loss_weight=0.0,
ffn_use_bias=False,
)


def get_default_mesh_shape(
config: BaseExperimentConfig, mode: str = 'train',
dcn_mesh_shape=None) -> Mapping[str, int]:
215 changes: 179 additions & 36 deletions simply/model_lib.py
@@ -613,6 +613,7 @@ class MoEFeedForward(FeedForward):
tile_model_dim: int = 128
tile_expand_dim: int = 128
gmm_impl: str = 'ragged_dot'
routing_mode: str = 'token_choice'

def setup(self):
if self.ffn_use_bias:
@@ -688,45 +689,68 @@ def apply(
# router_logits: [batch_size, seq_len, num_experts]
router_logits = self.router.apply(params['router'], inputs)
router_logits = router_logits.astype(jnp.float32)
# router_probs: [batch_size, seq_len, num_experts]
# Both routing modes use the same softmax router probs;
# they just consume them differently: token choice does
# topk per token, expert choice does topk per expert.
router_probs = jax.nn.softmax(router_logits, axis=-1)
if self.num_experts_per_token == 1:
# Apply `softmax => topk` when k == 1 to avoid zero gradient
# on the router logits.
# selected_router_probs, selected_indices:
# [batch_size, seq_len, num_experts_per_token]
selected_router_probs, selected_indices = jax.lax.top_k(
router_probs, k=self.num_experts_per_token
)
else:
# Perform `topk => softmax` to get a normalized probability distribution.
# selected_router_logits, selected_indices:
# [batch_size, seq_len, num_experts_per_token]
selected_router_logits, selected_indices = jax.lax.top_k(
router_logits, k=self.num_experts_per_token
)
selected_router_probs = jax.nn.softmax(selected_router_logits, axis=-1)
selected_router_probs = jnp.asarray(
selected_router_probs, self.activation_dtype
)
router_probs = jnp.asarray(router_probs, self.activation_dtype)
if self.expert_capacity_factor is None:
outputs, ffn_extra_output = self._apply_sparse_moe(
params,
inputs,
selected_indices=selected_indices,
selected_weights=selected_router_probs,
inputs_mask=inputs_mask,
)

if self.routing_mode not in ('token_choice', 'expert_choice'):
raise ValueError(
f'Unknown routing_mode: {self.routing_mode!r}')

if self.routing_mode == 'expert_choice':
outputs, ffn_extra_output = (
self._apply_expert_choice_moe(
params,
inputs,
router_probs=router_probs,
inputs_mask=inputs_mask,
))
else:
outputs, ffn_extra_output = self._apply_dense_moe(
params,
inputs,
selected_indices=selected_indices,
selected_weights=selected_router_probs,
inputs_mask=inputs_mask,
if self.num_experts_per_token == 1:
# Apply `softmax => topk` when k == 1 to avoid zero
# gradient on the router logits.
# selected_router_probs, selected_indices:
# [batch_size, seq_len, num_experts_per_token]
selected_router_probs, selected_indices = (
jax.lax.top_k(
router_probs,
k=self.num_experts_per_token,
))
else:
# Perform `topk => softmax` to get a normalized
# probability distribution.
# selected_router_logits, selected_indices:
# [batch_size, seq_len, num_experts_per_token]
selected_router_logits, selected_indices = (
jax.lax.top_k(
router_logits,
k=self.num_experts_per_token,
))
selected_router_probs = jax.nn.softmax(
selected_router_logits, axis=-1)
selected_router_probs = jnp.asarray(
selected_router_probs, self.activation_dtype
)
if self.expert_capacity_factor is None:
outputs, ffn_extra_output = self._apply_sparse_moe(
params,
inputs,
selected_indices=selected_indices,
selected_weights=selected_router_probs,
inputs_mask=inputs_mask,
)
else:
outputs, ffn_extra_output = self._apply_dense_moe(
params,
inputs,
selected_indices=selected_indices,
selected_weights=selected_router_probs,
inputs_mask=inputs_mask,
)
load = ffn_extra_output['load']
router_probs = jnp.asarray(
router_probs, self.activation_dtype)
extra_output.update(ffn_extra_output)
router_entropy = - jnp.sum(router_probs * jnp.where(
router_probs > 0, jnp.log(router_probs), 0.0), axis=-1)
@@ -738,7 +762,10 @@ def apply(
),
'gini': jnp.sum(load ** 2) * self.num_experts - 1,
})
if self.lbl_loss_weight > 0:
# Expert choice guarantees uniform load across experts, so
# the load balancing loss would just add a constant term.
if (self.lbl_loss_weight > 0
and self.routing_mode != 'expert_choice'):
if inputs_mask is None:
inputs_mask = jnp.ones(shape=x.shape[:2], dtype=self.activation_dtype)
else:
Expand All @@ -763,6 +790,119 @@ def apply(
)
return outputs, extra_output

def _apply_expert_choice_moe(
self,
params: PyTree,
inputs: Array,
router_probs: Array,
inputs_mask: Array | None = None,
) -> tuple[Array, PyTree]:
"""Expert choice routing (Zhou et al., 2022).

Instead of each token picking its top k experts, each expert
picks its top C tokens from the full sequence. This flips
the selection axis: the router prob matrix is transposed
before topk. Every expert ends up with the same number of
tokens, so load is perfectly balanced without needing
lbl_loss.

Uses dense dispatch (one_hot + einsum), same approach as
_apply_dense_moe. Could be extended to sparse dispatch later.
"""
extra_outputs = {}
batch_size, seq_len, _ = inputs.shape
num_tokens = batch_size * seq_len

# In token choice, each of the T tokens picks k experts,
# giving k*T total (token, expert) assignments. To keep the
# same total work, each of the E experts gets a budget of
# C = k*T/E tokens.
expert_capacity = max(
1,
int(self.num_experts_per_token * num_tokens
/ self.num_experts),
)

inputs = einops.rearrange(inputs, 'b s d -> (b s) d')
router_probs_flat = einops.rearrange(
router_probs, 'b s e -> (b s) e')

# Kill padding token probs so experts don't burn capacity
# on them. Without this, pad tokens can outcompete real
# tokens for expert slots.
if inputs_mask is not None:
token_mask = einops.rearrange(
inputs_mask, 'b s -> (b s) 1')
router_probs_flat = router_probs_flat * token_mask

# Transpose to [E, T] so top_k operates per expert across
# all tokens, instead of per token across experts.
expert_probs = router_probs_flat.T

# Each expert independently selects its top C tokens.
# selected_probs[e, c] = router prob for expert e's c'th
# pick, selected_indices[e, c] = which token it picked.
selected_probs, selected_indices = jax.lax.top_k(
expert_probs, k=expert_capacity)
selected_probs = jnp.asarray(
selected_probs, self.activation_dtype)

# Every expert processes exactly C tokens, so load is
# trivially 1/E for each. The downstream metrics code
# still expects a load vector so we provide one.
extra_outputs['load'] = jnp.ones(
self.num_experts, dtype=self.activation_dtype
) / self.num_experts

# One hot encode token indices for einsum gather.
# dispatch_onehot[e, c, t] = 1 iff expert e picked token t
# as its c'th slot.
dispatch_onehot = jax.nn.one_hot(
selected_indices, num_tokens,
dtype=self.activation_dtype)

# Gather selected tokens into expert buffers.
# expert_inputs[e, c, d] = input embedding of the token
# sitting in expert e's c'th slot.
expert_inputs = jnp.einsum(
'ect,td->ecd', dispatch_onehot, inputs)
expert_inputs = jnp.asarray(
expert_inputs, self.activation_dtype)

# Expert FFN, same as _apply_dense_moe.
projected_inputs = self.ffn_0.apply(
params['ffn_0'], expert_inputs)
activation_fn = registry.FunctionRegistry.get(
self.ffn_activation)
if self.use_gated_activation_in_ffn:
gate = self.ffn_0_gate.apply(
params['ffn_0_gate'], expert_inputs)
gate = jnp.asarray(
activation_fn(gate), self.activation_dtype)
middle = gate * projected_inputs
else:
middle = jnp.asarray(
activation_fn(projected_inputs),
self.activation_dtype,
)
expert_outputs = self.ffn_1.apply(
params['ffn_1'], middle)

# Scatter back: weight each expert output by the router
# prob that caused the selection, then sum across experts.
# If a token was picked by multiple experts it gets a
# weighted combination; if picked by none it stays zero.
dispatch_weights = jnp.einsum(
'ec,ect->ect', selected_probs, dispatch_onehot)
outputs = jnp.einsum(
'ecd,ect->td', expert_outputs, dispatch_weights)

outputs = einops.rearrange(
outputs, '(b s) d -> b s d',
b=batch_size, s=seq_len,
)
return outputs, extra_outputs

def _apply_sparse_moe(
self,
params: PyTree,
@@ -1562,6 +1702,7 @@ class TransformerBlock(module.SimplyModule):
expert_capacity_factor: float | None = 0.0
lbl_loss_weight: float = 0.0
router_z_loss_weight: float = 0.0
routing_mode: str = 'token_choice'
# Mixed precision related.
activation_dtype: DTypeLike = 'bfloat16'
# Below are for experimental usage.
@@ -1688,6 +1829,7 @@ def setup(self) -> None:
expert_capacity_factor=self.expert_capacity_factor,
router_z_loss_weight=self.router_z_loss_weight,
lbl_loss_weight=self.lbl_loss_weight,
routing_mode=self.routing_mode,
model_dim=self.model_dim,
expand_factor=self.expand_factor,
use_gated_activation_in_ffn=self.use_gated_activation_in_ffn,
@@ -1887,6 +2029,7 @@ def _create_transformer_block(pattern):
num_experts_per_token=config.num_experts_per_token,
lbl_loss_weight=config.lbl_loss_weight,
router_z_loss_weight=config.router_z_loss_weight,
routing_mode=config.routing_mode,
tile_batch_seq=config.tile_batch_seq,
tile_model_dim=config.tile_model_dim,
tile_expand_dim=config.tile_expand_dim,