Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions open_mythos/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple, List, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -830,7 +830,8 @@ def forward(
mask: Optional[torch.Tensor] = None,
n_loops: Optional[int] = None,
kv_cache: Optional[dict] = None,
) -> torch.Tensor:
return_latents: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
"""
Run the recurrent loop for up to n_loops iterations with ACT early exit.

Expand All @@ -843,16 +844,19 @@ def forward(
Can be increased at inference for deeper reasoning (depth extrapolation).
kv_cache -- cache dict passed through to the inner TransformerBlock;
each loop iteration uses a separate cache key
return_latents -- if True, also return the list of hidden states across iterations

Returns:
ACT-weighted sum of hidden states across iterations, shape (B, T, dim)
If return_latents is True, returns (h_out, latents)
"""
n_loops = n_loops or self.cfg.max_loop_iters
B, T, D = h.shape

halted = torch.zeros(B, T, device=h.device, dtype=torch.bool)
cumulative_p = torch.zeros(B, T, device=h.device)
h_out = torch.zeros_like(h)
latents = []

for t in range(n_loops):
h_loop = loop_index_embedding(h, t, self.loop_dim)
Expand All @@ -861,6 +865,9 @@ def forward(
trans_out = self.block(combined, freqs_cis, mask, kv_cache, cache_key)
trans_out = trans_out + self.lora(trans_out, t)
h = self.injection(h, e, trans_out)

if return_latents:
latents.append(h.detach().clone())

p = self.act(h) # (B, T)
still_running = ~halted
Expand Down Expand Up @@ -888,6 +895,8 @@ def forward(
if halted.all() and kv_cache is None:
break

if return_latents:
return h_out, latents
return h_out


Expand Down Expand Up @@ -995,7 +1004,8 @@ def forward(
n_loops: Optional[int] = None,
kv_cache: Optional[dict] = None,
start_pos: int = 0,
) -> torch.Tensor:
return_latents: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
"""
Forward pass through Prelude → Recurrent Block → Coda.

Expand All @@ -1009,9 +1019,11 @@ def forward(
sequence; used to select the correct RoPE frequencies
during incremental decoding (0 for prefill, prompt_len
for each subsequent decode step)
return_latents -- if True, also return the reasoning trajectory from the recurrent block

Returns:
Logits of shape (B, T, vocab_size)
If return_latents is True, returns (logits, latents)
"""
T = input_ids.shape[1]
device = input_ids.device
Expand All @@ -1026,12 +1038,18 @@ def forward(
x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"prelude_{i}")

e = x # encoded input frozen for injection every loop
x = self.recurrent(x, e, freqs_cis, mask, n_loops, kv_cache)
if return_latents:
x, latents = self.recurrent(x, e, freqs_cis, mask, n_loops, kv_cache, return_latents=True)
else:
x = self.recurrent(x, e, freqs_cis, mask, n_loops, kv_cache)

for i, layer in enumerate(self.coda):
x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"coda_{i}")

return self.head(self.norm(x))
logits = self.head(self.norm(x))
if return_latents:
return logits, latents
return logits

@torch.no_grad()
def generate(
Expand Down