From 87f88341697022b3e86572b95a587c8a5c31420a Mon Sep 17 00:00:00 2001
From: xuzhemin <757583912@qq.com>
Date: Mon, 20 Apr 2026 17:12:15 +0800
Subject: [PATCH] optimize moe dispatch and expose routing diagnostics

Replace the nested top-k/expert routing loops with a dispatch path that
makes a single pass over the experts, and publish per-step expert load
plus a lightweight load-balance signal for monitoring.

Made-with: Cursor
---
 open_mythos/main.py | 27 ++++++++++++++++++---------
 test_main.py        | 10 ++++++++++
 2 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/open_mythos/main.py b/open_mythos/main.py
index 238eeed..abcfab1 100644
--- a/open_mythos/main.py
+++ b/open_mythos/main.py
@@ -450,6 +450,8 @@ def __init__(self, cfg: MythosConfig):
         self.router = nn.Linear(cfg.dim, cfg.n_experts, bias=False)
         # load-balancing bias adjusted externally during training; not a gradient param
         self.register_buffer("router_bias", torch.zeros(cfg.n_experts))
+        self.register_buffer("last_expert_load", torch.zeros(cfg.n_experts))
+        self.register_buffer("last_load_balance_loss", torch.tensor(0.0))
 
         self.routed_experts = nn.ModuleList(
             [Expert(cfg.dim, cfg.expert_dim) for _ in range(cfg.n_experts)]
@@ -478,16 +480,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         topk_scores, topk_idx = scores.topk(self.topk, dim=-1)
         topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)  # renorm
 
-        # routed expert dispatch (token-level scatter)
+        # Expose lightweight routing diagnostics for training-time monitoring.
+        with torch.no_grad():
+            counts = torch.bincount(topk_idx.flatten(), minlength=self.n_experts).float()
+            self.last_expert_load.copy_(counts / counts.sum().clamp(min=1.0))
+            self.last_load_balance_loss.copy_(
+                (self.last_expert_load * scores.mean(dim=0)).sum() * self.n_experts
+            )
+
+        # routed expert dispatch (single pass over experts)
         out = torch.zeros_like(flat)
-        for i in range(self.topk):
-            expert_ids = topk_idx[:, i]
-            token_scores = topk_scores[:, i].unsqueeze(-1)
-            for eid in range(self.n_experts):
-                mask = expert_ids == eid
-                if not mask.any():
-                    continue
-                out[mask] += token_scores[mask] * self.routed_experts[eid](flat[mask])
+        for eid, expert in enumerate(self.routed_experts):
+            token_idx, k_slot = torch.where(topk_idx == eid)
+            if token_idx.numel() == 0:
+                continue
+            expert_out = expert(flat[token_idx])
+            weight = topk_scores[token_idx, k_slot].unsqueeze(-1)
+            out[token_idx] += weight * expert_out
 
         # shared experts always fire for every token
         for shared in self.shared_experts:
diff --git a/test_main.py b/test_main.py
index eac767f..a3d0aa4 100644
--- a/test_main.py
+++ b/test_main.py
@@ -374,6 +374,16 @@ def test_shared_experts_always_fire(self):
         out = self.moe(x)
         assert out.abs().sum() > 0
 
+    def test_routing_diagnostics_are_updated(self):
+        x = torch.randn(B, T, self.cfg.dim)
+        _ = self.moe(x)
+        assert torch.isfinite(self.moe.last_load_balance_loss)
+        assert torch.isclose(
+            self.moe.last_expert_load.sum(),
+            torch.tensor(1.0, device=self.moe.last_expert_load.device),
+            atol=1e-6,
+        )
+
 
 # ---------------------------------------------------------------------------
 # loop_index_embedding
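
Note for reviewers: a quick standalone parity check for the dispatch
refactor. The removed nested top-k/expert loops and the new single-pass
loop must produce identical outputs. The toy nn.Linear experts and the
shapes below are stand-ins for the repository's Expert modules, not code
from this patch.

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    n_tokens, dim, n_experts, topk = 32, 16, 4, 2
    flat = torch.randn(n_tokens, dim)
    experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_experts))

    scores = torch.softmax(torch.randn(n_tokens, n_experts), dim=-1)
    topk_scores, topk_idx = scores.topk(topk, dim=-1)
    topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)

    # Old path: loop over top-k slots, then over experts.
    old = torch.zeros_like(flat)
    for i in range(topk):
        expert_ids = topk_idx[:, i]
        token_scores = topk_scores[:, i].unsqueeze(-1)
        for eid in range(n_experts):
            mask = expert_ids == eid
            if mask.any():
                old[mask] += token_scores[mask] * experts[eid](flat[mask])

    # New path: one pass over experts, gathering every (token, slot) pair.
    new = torch.zeros_like(flat)
    for eid, expert in enumerate(experts):
        token_idx, k_slot = torch.where(topk_idx == eid)
        if token_idx.numel() == 0:
            continue
        weight = topk_scores[token_idx, k_slot].unsqueeze(-1)
        new[token_idx] += weight * expert(flat[token_idx])

    torch.testing.assert_close(old, new)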
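
And a minimal sketch of how a training loop might consume the new
diagnostics. Only last_expert_load, last_load_balance_loss, router_bias,
and n_experts come from the module itself; the hook name, the
TensorBoard-style writer, and the sign-based bias nudge with its BIAS_LR
constant are illustrative assumptions: the patch only documents that
router_bias is "adjusted externally during training".

    import torch

    BIAS_LR = 1e-3  # assumed step size for the external bias adjustment

    def log_and_balance(moe, writer, step):
        # Fraction of routed top-k slots handled by each expert last step.
        load = moe.last_expert_load
        writer.add_scalar("moe/load_balance_loss",
                          moe.last_load_balance_loss.item(), step)
        writer.add_scalar("moe/max_expert_load", load.max().item(), step)
        # Hypothetical balance step: push overloaded experts' bias down
        # and underloaded experts' bias up, toward uniform load 1/n_experts.
        with torch.no_grad():
            target = 1.0 / moe.n_experts
            moe.router_bias -= BIAS_LR * torch.sign(load - target)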