diff --git a/open_mythos/main.py b/open_mythos/main.py
index 65b0fa8..303a902 100644
--- a/open_mythos/main.py
+++ b/open_mythos/main.py
@@ -381,11 +381,7 @@ def forward(
         k_rope = kv_raw[..., self.kv_lora_rank :]  # (B, T, rope_dim)
         # expand rope keys across heads and apply RoPE before caching so
         # retrieved keys are already positionally encoded
-        k_rope = (
-            k_rope.unsqueeze(2)
-            .expand(B, T, self.n_heads, self.qk_rope_dim)
-            .contiguous()
-        )
+        k_rope = k_rope.unsqueeze(2).repeat(1, 1, self.n_heads, 1)
         k_rope = apply_rope(k_rope, freqs_cis)  # (B, T, H, rope_dim) ← cached
 
         if kv_cache is not None:
@@ -517,14 +513,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         # routed expert dispatch (token-level scatter)
         out = torch.zeros_like(flat)
-        for i in range(self.topk):
-            expert_ids = topk_idx[:, i]
-            token_scores = topk_scores[:, i].unsqueeze(-1)
-            for eid in range(self.n_experts):
-                mask = expert_ids == eid
-                if not mask.any():
-                    continue
-                out[mask] += token_scores[mask] * self.routed_experts[eid](flat[mask])
+        # routed dispatch: still O(n_experts * topk) Python loops over expert masks
+        for i in range(self.topk):
+            expert_ids = topk_idx[:, i]
+            token_scores = topk_scores[:, i].unsqueeze(-1)
+            for eid in range(self.n_experts):
+                mask = expert_ids == eid
+                if mask.any():
+                    out[mask] += token_scores[mask] * self.routed_experts[eid](flat[mask])
 
         # shared experts always fire for every token
         for shared in self.shared_experts:
@@ -821,7 +817,6 @@ def __init__(self, cfg: MythosConfig):
         self.loop_dim = (
             cfg.dim // 8
         )  # fraction of channels receiving loop-index embedding
-
     def forward(
         self,
         h: torch.Tensor,
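
Note (not part of the patch): the routed dispatch above still builds a mask for every expert in every top-k slot. Below is a minimal sketch of a flatter variant that visits each expert once per forward pass and scatters results with index_add_. It assumes the same shapes the hunk implies (flat: (N, dim), topk_idx / topk_scores: (N, topk), routed_experts as a list of expert modules); the standalone helper name routed_dispatch is hypothetical, and whether it is actually faster depends on n_experts and batch size.

import torch

def routed_dispatch(flat, topk_idx, topk_scores, routed_experts):
    # flat: (N, dim); topk_idx, topk_scores: (N, topk)
    N, topk = topk_idx.shape
    out = torch.zeros_like(flat)
    # flatten (token, slot) pairs so each expert is handled in a single pass
    token_ids = torch.arange(N, device=flat.device).repeat_interleave(topk)  # (N*topk,)
    expert_ids = topk_idx.reshape(-1)                                        # (N*topk,)
    scores = topk_scores.reshape(-1, 1)                                      # (N*topk, 1)
    for eid, expert in enumerate(routed_experts):
        sel = expert_ids == eid
        if sel.any():
            idx = token_ids[sel]
            # accumulate weighted expert outputs back into the owning token rows
            out.index_add_(0, idx, scores[sel] * expert(flat[idx]))
    return out

This trades the topk * n_experts mask loop for a single loop over experts; it is only a sketch under the stated assumptions, not a drop-in replacement for the change above.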