27 changes: 18 additions & 9 deletions open_mythos/main.py
@@ -450,6 +450,8 @@ def __init__(self, cfg: MythosConfig):
         self.router = nn.Linear(cfg.dim, cfg.n_experts, bias=False)
         # load-balancing bias adjusted externally during training; not a gradient param
         self.register_buffer("router_bias", torch.zeros(cfg.n_experts))
+        self.register_buffer("last_expert_load", torch.zeros(cfg.n_experts))
+        self.register_buffer("last_load_balance_loss", torch.tensor(0.0))

         self.routed_experts = nn.ModuleList(
             [Expert(cfg.dim, cfg.expert_dim) for _ in range(cfg.n_experts)]
@@ -478,16 +480,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         topk_scores, topk_idx = scores.topk(self.topk, dim=-1)
         topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)  # renorm

-        # routed expert dispatch (token-level scatter)
+        # Expose lightweight routing diagnostics for training-time monitoring.
+        with torch.no_grad():
+            counts = torch.bincount(topk_idx.flatten(), minlength=self.n_experts).float()
+            self.last_expert_load.copy_(counts / counts.sum().clamp(min=1.0))
+            self.last_load_balance_loss.copy_(
+                (self.last_expert_load * scores.mean(dim=0)).sum() * self.n_experts
+            )
+
+        # routed expert dispatch (single pass over experts)
         out = torch.zeros_like(flat)
-        for i in range(self.topk):
-            expert_ids = topk_idx[:, i]
-            token_scores = topk_scores[:, i].unsqueeze(-1)
-            for eid in range(self.n_experts):
-                mask = expert_ids == eid
-                if not mask.any():
-                    continue
-                out[mask] += token_scores[mask] * self.routed_experts[eid](flat[mask])
+        for eid, expert in enumerate(self.routed_experts):
+            token_idx, k_slot = torch.where(topk_idx == eid)
+            if token_idx.numel() == 0:
+                continue
+            expert_out = expert(flat[token_idx])
+            weight = topk_scores[token_idx, k_slot].unsqueeze(-1)
+            out[token_idx] += weight * expert_out

         # shared experts always fire for every token
         for shared in self.shared_experts:
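The dispatch rewrite replaces the old k × n_experts masked passes with a single gather per expert; each selected token is still processed by each of its chosen experts exactly once. A standalone sketch (not part of the PR) that checks the two loops agree on random data; Expert is stood in by nn.Linear, and every name here is local to the snippet:

import torch
import torch.nn as nn

torch.manual_seed(0)
n_tokens, dim, n_experts, topk = 32, 16, 4, 2
flat = torch.randn(n_tokens, dim)
experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_experts))
scores = torch.softmax(torch.randn(n_tokens, n_experts), dim=-1)
topk_scores, topk_idx = scores.topk(topk, dim=-1)
topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)

# old dispatch: iterate over top-k slots, then experts (k x n_experts masked passes)
out_old = torch.zeros_like(flat)
for i in range(topk):
    expert_ids = topk_idx[:, i]
    token_scores = topk_scores[:, i].unsqueeze(-1)
    for eid in range(n_experts):
        mask = expert_ids == eid
        if mask.any():
            out_old[mask] += token_scores[mask] * experts[eid](flat[mask])

# new dispatch: one pass per expert, gathering all (token, slot) pairs at once
out_new = torch.zeros_like(flat)
for eid, expert in enumerate(experts):
    token_idx, k_slot = torch.where(topk_idx == eid)
    if token_idx.numel() == 0:
        continue
    weight = topk_scores[token_idx, k_slot].unsqueeze(-1)
    out_new[token_idx] += weight * expert(flat[token_idx])

assert torch.allclose(out_old, out_new, atol=1e-6)

The in-place accumulation in the new loop is safe because topk returns distinct expert indices per token, so token_idx contains no duplicates within a single expert's pass; the outputs match up to floating-point accumulation order.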
10 changes: 10 additions & 0 deletions test_main.py
@@ -374,6 +374,16 @@ def test_shared_experts_always_fire(self):
         out = self.moe(x)
         assert out.abs().sum() > 0

+    def test_routing_diagnostics_are_updated(self):
+        x = torch.randn(B, T, self.cfg.dim)
+        _ = self.moe(x)
+        assert torch.isfinite(self.moe.last_load_balance_loss)
+        assert torch.isclose(
+            self.moe.last_expert_load.sum(),
+            torch.tensor(1.0, device=self.moe.last_expert_load.device),
+            atol=1e-6,
+        )
+

 # ---------------------------------------------------------------------------
 # loop_index_embedding
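For context on how the new buffers might be consumed: last_load_balance_loss is written under torch.no_grad() and stored in a buffer, so it is a monitoring signal rather than a differentiable auxiliary loss, and the existing comment states that router_bias is "adjusted externally during training". A minimal sketch of such an external update, assuming a sign-based step toward uniform load; the helper name, step size, and update rule are hypothetical, not code from this PR:

import torch

@torch.no_grad()
def update_router_bias(moe, step_size: float = 1e-3) -> None:
    # moe is assumed to expose the router_bias and last_expert_load buffers
    # from the diff, both of shape (n_experts,).
    n_experts = moe.router_bias.numel()
    uniform = 1.0 / n_experts
    # Push the bias down for overloaded experts and up for underloaded ones,
    # in the spirit of auxiliary-loss-free load balancing.
    moe.router_bias -= step_size * torch.sign(moe.last_expert_load - uniform)

# Hypothetical training-step usage:
#     _ = model(batch)                       # forward pass fills the buffers
#     for moe in moe_layers:                 # however the model exposes them
#         update_router_bias(moe)
#         log_scalar("moe/load_balance_loss", moe.last_load_balance_loss.item())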