diff --git a/github_issue.md b/github_issue.md
new file mode 100644
index 0000000..b144b9b
--- /dev/null
+++ b/github_issue.md
@@ -0,0 +1,31 @@
+## [Feature] Expert-Choice Routing for MoEFeedForward
+
+The current `MoEFeedForward` supports token-choice routing where each
+token selects its top-k experts. I'd like to add expert-choice routing
+(Zhou et al., 2022 -- https://arxiv.org/abs/2202.09368) where each expert
+selects its top-C tokens.
+
+### Motivation
+
+Expert-choice routing provides natural load balancing -- each expert
+processes exactly C tokens by construction, eliminating the need for
+auxiliary losses. The Google paper showed it outperforms token-choice on
+language modeling and downstream tasks.
+
+### Proposed changes
+
+- Add `routing_mode` field (`'token_choice'` | `'expert_choice'`) to
+  `MoEFeedForward`, `TransformerBlock`, and `BaseExperimentConfig`
+- Add `_apply_expert_choice_moe()` following the `_apply_dense_moe`
+  dispatch pattern (einsums, no shard_map)
+- Add test configs `lm_moe_test` and `lm_moe_expert_choice_test`
+- Add equivalence and gradient tests to `MoETest`
+
+### Design notes
+
+- Backward compatible -- default remains `'token_choice'`
+- No new dependencies or files
+- Initial implementation uses dense dispatch (not sparse/GMM) for
+  simplicity; can be extended to sparse dispatch in a follow-up
+
+Would this be a welcome addition?
diff --git a/pr_description.md b/pr_description.md
new file mode 100644
index 0000000..e9fb5f3
--- /dev/null
+++ b/pr_description.md
@@ -0,0 +1,32 @@
+## Add expert-choice routing mode for MoEFeedForward
+
+Adds expert-choice routing (Zhou et al., 2022) where each expert selects
+its top-C tokens, providing natural load balancing without auxiliary losses.
+
+### Changes
+
+- `model_lib.py`: `routing_mode` field, `_apply_expert_choice_moe()`,
+  restructured `apply()` with early routing-mode branch
+- `config_lib.py`: `routing_mode` in `BaseExperimentConfig`,
+  `lm_moe_test` and `lm_moe_expert_choice_test` configs
+- `model_lib_test.py`: `simple_expert_choice_moe()` reference impl,
+  forward/gradient equivalence tests
+
+### How to test
+
+```bash
+# Expert-choice MoE local test
+python -m simply.main --experiment_config lm_moe_expert_choice_test --experiment_dir /tmp/moe_ec_test --alsologtostderr
+# All MoE unit tests (including existing + new)
+pytest simply/model_lib_test.py::MoETest -v
+```
+
+### Design decisions
+
+- Follows `_apply_dense_moe` dispatch pattern (einsums, not sparse GMM)
+- Capacity: C = num_experts_per_token x num_tokens / num_experts
+- `lbl_loss` is skipped when `routing_mode='expert_choice'` (load is
+  balanced by construction, so the auxiliary loss is unnecessary)
+- `routing_mode` is validated; unknown values raise `ValueError`
+- Token-choice path is unchanged: all 7 existing `simple_moe()`
+  equivalence tests pass with the same tolerances as before
\ No newline at end of file
diff --git a/simply/config_lib.py b/simply/config_lib.py
index 15168e7..87a9e3e 100644
--- a/simply/config_lib.py
+++ b/simply/config_lib.py
@@ -315,6 +315,7 @@ class BaseExperimentConfig(ExperimentConfig):
   tile_model_dim: int = 1024
   tile_expand_dim: int = 1024
   gmm_impl: str = 'ragged_dot'
+  routing_mode: str = 'token_choice'
   global_total_num_pages: int = 0
   local_total_num_pages: int = 0
   page_size: int = 0
@@ -1982,6 +1983,42 @@ def lm_rl_test():
   )
 
 
+@ExperimentConfigRegistry.register
+def lm_moe_test():
+  """Tiny token-choice MoE config for local testing."""
+  config = lm_test()
+  return dataclasses.replace(
+      config,
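+      # Token-choice baseline: each token picks 2 of 4 experts, so
+      # load is not balanced by construction and the auxiliary
+      # load-balancing loss (lbl_loss_weight) is still needed.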
+      sharding_config=moe_sharding(),
+      use_moe=True,
+      num_experts=4,
+      num_experts_per_token=2,
+      expert_capacity_factor=None,
+      lbl_loss_weight=0.01,
+      ffn_use_bias=False,
+  )
+
+
+@ExperimentConfigRegistry.register
+def lm_moe_expert_choice_test():
+  """Tiny expert-choice MoE config for local testing."""
+  config = lm_test()
+  return dataclasses.replace(
+      config,
+      sharding_config=moe_sharding(),
+      use_moe=True,
+      num_experts=4,
+      num_experts_per_token=1,
+      expert_capacity_factor=None,
+      routing_mode='expert_choice',
+      lbl_loss_weight=0.0,
+      ffn_use_bias=False,
+  )
+
+
 def get_default_mesh_shape(
     config: BaseExperimentConfig, mode: str = 'train',
     dcn_mesh_shape=None) -> Mapping[str, int]:
diff --git a/simply/model_lib.py b/simply/model_lib.py
index d6695ed..7bf78ea 100644
--- a/simply/model_lib.py
+++ b/simply/model_lib.py
@@ -613,6 +613,7 @@ class MoEFeedForward(FeedForward):
   tile_model_dim: int = 128
   tile_expand_dim: int = 128
   gmm_impl: str = 'ragged_dot'
+  routing_mode: str = 'token_choice'
 
   def setup(self):
     if self.ffn_use_bias:
@@ -688,45 +689,71 @@ def apply(
     # router_logits: [batch_size, seq_len, num_experts]
     router_logits = self.router.apply(params['router'], inputs)
     router_logits = router_logits.astype(jnp.float32)
-    # router_probs: [batch_size, seq_len, num_experts]
+    # Both routing modes use the same softmax router probs;
+    # they just consume them differently: token choice does
+    # topk per token, expert choice does topk per expert.
     router_probs = jax.nn.softmax(router_logits, axis=-1)
-    if self.num_experts_per_token == 1:
-      # Apply `softmax => topk` when k == 1 to avoid zero gradient
-      # on the router logits.
-      # selected_router_probs, selected_indices:
-      # [batch_size, seq_len, num_experts_per_token]
-      selected_router_probs, selected_indices = jax.lax.top_k(
-          router_probs, k=self.num_experts_per_token
-      )
-    else:
-      # Perform `topk => softmax` to get a normalized probability distribution.
-      # selected_router_logits, selected_indices:
-      # [batch_size, seq_len, num_experts_per_token]
-      selected_router_logits, selected_indices = jax.lax.top_k(
-          router_logits, k=self.num_experts_per_token
-      )
-      selected_router_probs = jax.nn.softmax(selected_router_logits, axis=-1)
-    selected_router_probs = jnp.asarray(
-        selected_router_probs, self.activation_dtype
-    )
-    router_probs = jnp.asarray(router_probs, self.activation_dtype)
-    if self.expert_capacity_factor is None:
-      outputs, ffn_extra_output = self._apply_sparse_moe(
-          params,
-          inputs,
-          selected_indices=selected_indices,
-          selected_weights=selected_router_probs,
-          inputs_mask=inputs_mask,
-      )
+
+    if self.routing_mode not in ('token_choice', 'expert_choice'):
+      raise ValueError(
+          f'Unknown routing_mode: {self.routing_mode!r}')
+
+    if self.routing_mode == 'expert_choice':
+      outputs, ffn_extra_output = (
+          self._apply_expert_choice_moe(
+              params,
+              inputs,
+              router_probs=router_probs,
+              inputs_mask=inputs_mask,
+          ))
     else:
-      outputs, ffn_extra_output = self._apply_dense_moe(
-          params,
-          inputs,
-          selected_indices=selected_indices,
-          selected_weights=selected_router_probs,
-          inputs_mask=inputs_mask,
+      if self.num_experts_per_token == 1:
+        # Apply `softmax => topk` when k == 1 to avoid zero
+        # gradient on the router logits.
+        # selected_router_probs, selected_indices:
+        # [batch_size, seq_len, num_experts_per_token]
+        selected_router_probs, selected_indices = (
+            jax.lax.top_k(
+                router_probs,
+                k=self.num_experts_per_token,
+            ))
+      else:
+        # Perform `topk => softmax` to get a normalized
+        # probability distribution.
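+        # (With k == 1 the softmax over a single selected logit is
+        # identically 1.0 and gives the router no gradient, hence the
+        # reversed `softmax => topk` order in the branch above.)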
+ # selected_router_logits, selected_indices: + # [batch_size, seq_len, num_experts_per_token] + selected_router_logits, selected_indices = ( + jax.lax.top_k( + router_logits, + k=self.num_experts_per_token, + )) + selected_router_probs = jax.nn.softmax( + selected_router_logits, axis=-1) + selected_router_probs = jnp.asarray( + selected_router_probs, self.activation_dtype ) + if self.expert_capacity_factor is None: + outputs, ffn_extra_output = self._apply_sparse_moe( + params, + inputs, + selected_indices=selected_indices, + selected_weights=selected_router_probs, + inputs_mask=inputs_mask, + ) + else: + outputs, ffn_extra_output = self._apply_dense_moe( + params, + inputs, + selected_indices=selected_indices, + selected_weights=selected_router_probs, + inputs_mask=inputs_mask, + ) load = ffn_extra_output['load'] + router_probs = jnp.asarray( + router_probs, self.activation_dtype) extra_output.update(ffn_extra_output) router_entropy = - jnp.sum(router_probs * jnp.where( router_probs > 0, jnp.log(router_probs), 0.0), axis=-1) @@ -738,7 +762,10 @@ def apply( ), 'gini': jnp.sum(load ** 2) * self.num_experts - 1, }) - if self.lbl_loss_weight > 0: + # Expert choice guarantees uniform load across experts, so + # the load balancing loss would just add a constant term. + if (self.lbl_loss_weight > 0 + and self.routing_mode != 'expert_choice'): if inputs_mask is None: inputs_mask = jnp.ones(shape=x.shape[:2], dtype=self.activation_dtype) else: @@ -763,6 +790,119 @@ def apply( ) return outputs, extra_output + def _apply_expert_choice_moe( + self, + params: PyTree, + inputs: Array, + router_probs: Array, + inputs_mask: Array | None = None, + ) -> tuple[Array, PyTree]: + """Expert choice routing (Zhou et al., 2022). + + Instead of each token picking its top k experts, each expert + picks its top C tokens from the full sequence. This flips + the selection axis: the router prob matrix is transposed + before topk. Every expert ends up with the same number of + tokens, so load is perfectly balanced without needing + lbl_loss. + + Uses dense dispatch (one_hot + einsum), same approach as + _apply_dense_moe. Could be extended to sparse dispatch later. + """ + extra_outputs = {} + batch_size, seq_len, _ = inputs.shape + num_tokens = batch_size * seq_len + + # In token choice, each of the T tokens picks k experts, + # giving k*T total (token, expert) assignments. To keep the + # same total work, each of the E experts gets a budget of + # C = k*T/E tokens. + expert_capacity = max( + 1, + int(self.num_experts_per_token * num_tokens + / self.num_experts), + ) + + inputs = einops.rearrange(inputs, 'b s d -> (b s) d') + router_probs_flat = einops.rearrange( + router_probs, 'b s e -> (b s) e') + + # Kill padding token probs so experts don't burn capacity + # on them. Without this, pad tokens can outcompete real + # tokens for expert slots. + if inputs_mask is not None: + token_mask = einops.rearrange( + inputs_mask, 'b s -> (b s) 1') + router_probs_flat = router_probs_flat * token_mask + + # Transpose to [E, T] so top_k operates per expert across + # all tokens, instead of per token across experts. + expert_probs = router_probs_flat.T + + # Each expert independently selects its top C tokens. + # selected_probs[e, c] = router prob for expert e's c'th + # pick, selected_indices[e, c] = which token it picked. 
+ selected_probs, selected_indices = jax.lax.top_k( + expert_probs, k=expert_capacity) + selected_probs = jnp.asarray( + selected_probs, self.activation_dtype) + + # Every expert processes exactly C tokens, so load is + # trivially 1/E for each. The downstream metrics code + # still expects a load vector so we provide one. + extra_outputs['load'] = jnp.ones( + self.num_experts, dtype=self.activation_dtype + ) / self.num_experts + + # One hot encode token indices for einsum gather. + # dispatch_onehot[e, c, t] = 1 iff expert e picked token t + # as its c'th slot. + dispatch_onehot = jax.nn.one_hot( + selected_indices, num_tokens, + dtype=self.activation_dtype) + + # Gather selected tokens into expert buffers. + # expert_inputs[e, c, d] = input embedding of the token + # sitting in expert e's c'th slot. + expert_inputs = jnp.einsum( + 'ect,td->ecd', dispatch_onehot, inputs) + expert_inputs = jnp.asarray( + expert_inputs, self.activation_dtype) + + # Expert FFN, same as _apply_dense_moe. + projected_inputs = self.ffn_0.apply( + params['ffn_0'], expert_inputs) + activation_fn = registry.FunctionRegistry.get( + self.ffn_activation) + if self.use_gated_activation_in_ffn: + gate = self.ffn_0_gate.apply( + params['ffn_0_gate'], expert_inputs) + gate = jnp.asarray( + activation_fn(gate), self.activation_dtype) + middle = gate * projected_inputs + else: + middle = jnp.asarray( + activation_fn(projected_inputs), + self.activation_dtype, + ) + expert_outputs = self.ffn_1.apply( + params['ffn_1'], middle) + + # Scatter back: weight each expert output by the router + # prob that caused the selection, then sum across experts. + # If a token was picked by multiple experts it gets a + # weighted combination; if picked by none it stays zero. + dispatch_weights = jnp.einsum( + 'ec,ect->ect', selected_probs, dispatch_onehot) + outputs = jnp.einsum( + 'ecd,ect->td', expert_outputs, dispatch_weights) + + outputs = einops.rearrange( + outputs, '(b s) d -> b s d', + b=batch_size, s=seq_len, + ) + return outputs, extra_outputs + def _apply_sparse_moe( self, params: PyTree, @@ -1562,6 +1702,7 @@ class TransformerBlock(module.SimplyModule): expert_capacity_factor: float | None = 0.0 lbl_loss_weight: float = 0.0 router_z_loss_weight: float = 0.0 + routing_mode: str = 'token_choice' # Mixed precision related. activation_dtype: DTypeLike = 'bfloat16' # Below are for experimental usage. @@ -1688,6 +1829,7 @@ def setup(self) -> None: expert_capacity_factor=self.expert_capacity_factor, router_z_loss_weight=self.router_z_loss_weight, lbl_loss_weight=self.lbl_loss_weight, + routing_mode=self.routing_mode, model_dim=self.model_dim, expand_factor=self.expand_factor, use_gated_activation_in_ffn=self.use_gated_activation_in_ffn, @@ -1887,6 +2029,7 @@ def _create_transformer_block(pattern): num_experts_per_token=config.num_experts_per_token, lbl_loss_weight=config.lbl_loss_weight, router_z_loss_weight=config.router_z_loss_weight, + routing_mode=config.routing_mode, tile_batch_seq=config.tile_batch_seq, tile_model_dim=config.tile_model_dim, tile_expand_dim=config.tile_expand_dim, diff --git a/simply/model_lib_test.py b/simply/model_lib_test.py index ebeee60..b70addb 100644 --- a/simply/model_lib_test.py +++ b/simply/model_lib_test.py @@ -1285,6 +1285,73 @@ def simple_moe( return outputs +def simple_expert_choice_moe( + params, inputs, inputs_mask, + num_experts_per_token, num_experts, + ffn_activation, + use_gated_activation_in_ffn, + activation_dtype): + # Reference expert choice impl for equivalence testing. 
Same + # math as _apply_expert_choice_moe but with raw param einsums + # instead of going through the module apply() machinery. + params = model_lib.get_raw_arrays(params) + router_w = jnp.asarray( + params['router']['w'], activation_dtype) + router_logits = jnp.einsum( + 'ie,bsi->bse', router_w, inputs) + router_probs = jax.nn.softmax(router_logits, axis=-1) + b, s, e = router_probs.shape + num_tokens = b * s + expert_capacity = max( + 1, num_experts_per_token * num_tokens // num_experts) + # Flip to [E, T] so top_k selects per expert. + probs_flat = router_probs.reshape(num_tokens, e) + if inputs_mask is not None: + probs_flat = ( + probs_flat * inputs_mask.reshape(num_tokens, 1)) + expert_probs = probs_flat.T # [E, T] + selected_probs, selected_indices = jax.lax.top_k( + expert_probs, k=expert_capacity) + selected_probs = jnp.asarray( + selected_probs, activation_dtype) + inputs_flat = inputs.reshape(num_tokens, -1) + # One hot dispatch, gather, FFN, weighted scatter. + dispatch_onehot = jax.nn.one_hot( + selected_indices, num_tokens, dtype=activation_dtype) + expert_inputs = jnp.einsum( + 'ect,td->ecd', dispatch_onehot, inputs_flat) + # Apply expert FFNs with raw weight einsums. + ffn0_w = jnp.asarray( + params['ffn_0']['w'], activation_dtype) + projected = jnp.einsum( + 'eio,ebi->ebo', ffn0_w, expert_inputs) + activation_fn = registry.FunctionRegistry.get( + ffn_activation) + if use_gated_activation_in_ffn: + ffn0_gate_w = jnp.asarray( + params['ffn_0_gate']['w'], activation_dtype) + gate = jnp.einsum( + 'eio,ebi->ebo', ffn0_gate_w, expert_inputs) + middle = ( + jnp.asarray(activation_fn(gate), activation_dtype) + * projected) + else: + middle = jnp.asarray( + activation_fn(projected), activation_dtype) + ffn1_w = jnp.asarray( + params['ffn_1']['w'], activation_dtype) + expert_outputs = jnp.einsum( + 'eio,ebi->ebo', ffn1_w, middle) + dispatch_weights = jnp.einsum( + 'ec,ect->ect', selected_probs, dispatch_onehot) + outputs = jnp.einsum( + 'ecd,ect->td', expert_outputs, dispatch_weights) + outputs = outputs.reshape(b, s, -1) + if inputs_mask is not None: + outputs = outputs * inputs_mask[..., None] + return outputs + + class MoETest(parameterized.TestCase): @parameterized.named_parameters( @@ -1398,6 +1465,147 @@ def loss2(params, inputs, inputs_mask): lambda x, y: np.testing.assert_allclose(x, y, rtol=1e-2, atol=1e-2), grad1, grad2) + @parameterized.named_parameters( + dict( + testcase_name='_expert_choice_no_gate', + use_gated_activation_in_ffn=False, + num_experts=4, + num_experts_per_token=1, + ), + dict( + testcase_name='_expert_choice_gated', + use_gated_activation_in_ffn=True, + num_experts=4, + num_experts_per_token=2, + ), + dict( + testcase_name='_expert_choice_single_expert', + use_gated_activation_in_ffn=True, + num_experts=1, + num_experts_per_token=1, + ), + ) + def test_expert_choice_moe_equivalence( + self, use_gated_activation_in_ffn, num_experts, + num_experts_per_token, + activation_dtype='bfloat16', + ): + sharding_config = config_lib.moe_sharding() + sharding_lib.set_default_mesh_shape( + mesh_shape=(1, 1, 1, 1), + axis_names=sharding_config.mesh_axis_names) + batch_size, seq_len, model_dim, expand_factor = ( + 2, 4, 4, 2) + segment_ids = jnp.array( + [[1, 2, 3, 0], [1, 0, 0, 1]]) + key = jax.random.PRNGKey(0) + input_key, prng_key = jax.random.split(key) + inputs = jax.random.normal( + input_key, + shape=(batch_size, seq_len, model_dim), + dtype=activation_dtype, + ) + inputs_mask = segment_ids != 0 + + moe_ffn = model_lib.MoEFeedForward( + model_dim=model_dim, + 
expand_factor=expand_factor, + sharding_config=sharding_config, + num_experts=num_experts, + num_experts_per_token=num_experts_per_token, + expert_capacity_factor=None, + ffn_use_bias=False, + use_gated_activation_in_ffn=( + use_gated_activation_in_ffn), + activation_dtype=activation_dtype, + routing_mode='expert_choice', + ) + + params = moe_ffn.init(prng_key) + moe_output, _ = moe_ffn.apply( + params, inputs, inputs_mask=inputs_mask) + simple_ec_fn = functools.partial( + simple_expert_choice_moe, + num_experts_per_token=num_experts_per_token, + num_experts=num_experts, + ffn_activation=moe_ffn.ffn_activation, + use_gated_activation_in_ffn=( + use_gated_activation_in_ffn), + activation_dtype=activation_dtype, + ) + simple_ec_output = simple_ec_fn( + params, inputs, inputs_mask=inputs_mask) + self.assertEqual( + moe_output.shape, simple_ec_output.shape) + self.assertEqual( + moe_output.dtype, simple_ec_output.dtype) + np.testing.assert_allclose( + moe_output, simple_ec_output, + rtol=1e-2, atol=1e-2) + + # Also check that gradients match, not just forward outputs. + def loss1(params, inputs, inputs_mask): + out, _ = moe_ffn.apply( + params, inputs, inputs_mask=inputs_mask) + return jnp.sum(out) / (batch_size * seq_len) + + def loss2(params, inputs, inputs_mask): + out = simple_ec_fn( + params, inputs, inputs_mask=inputs_mask) + return jnp.sum(out) / (batch_size * seq_len) + + grad1 = jax.grad(loss1)( + params, inputs, inputs_mask) + grad2 = jax.grad(loss2)( + params, inputs, inputs_mask) + jax.tree.map( + lambda x, y: np.testing.assert_allclose( + x, y, rtol=1e-2, atol=1e-2), + grad1, grad2) + + def test_token_choice_routing_mode_regression(self): + """Explicit routing_mode='token_choice' matches default.""" + sharding_config = config_lib.moe_sharding() + sharding_lib.set_default_mesh_shape( + mesh_shape=(1, 1, 1, 1), + axis_names=sharding_config.mesh_axis_names) + batch_size, seq_len, model_dim, expand_factor = ( + 2, 4, 4, 2) + segment_ids = jnp.array( + [[1, 2, 3, 0], [1, 0, 0, 1]]) + key = jax.random.PRNGKey(0) + input_key, prng_key = jax.random.split(key) + inputs = jax.random.normal( + input_key, + shape=(batch_size, seq_len, model_dim), + dtype='bfloat16', + ) + inputs_mask = segment_ids != 0 + + common_kwargs = dict( + model_dim=model_dim, + expand_factor=expand_factor, + sharding_config=sharding_config, + num_experts=4, + num_experts_per_token=2, + expert_capacity_factor=None, + ffn_use_bias=False, + use_gated_activation_in_ffn=True, + activation_dtype='bfloat16', + ) + moe_default = model_lib.MoEFeedForward( + **common_kwargs) + moe_explicit = model_lib.MoEFeedForward( + routing_mode='token_choice', **common_kwargs) + + params = moe_default.init(prng_key) + out_default, _ = moe_default.apply( + params, inputs, inputs_mask=inputs_mask) + out_explicit, _ = moe_explicit.apply( + params, inputs, inputs_mask=inputs_mask) + np.testing.assert_array_equal( + out_default, out_explicit) + if __name__ == '__main__': absltest.main()