From 7c1115ebed50e55a8d7f1fca823f49af949760d9 Mon Sep 17 00:00:00 2001
From: triunex
Date: Sat, 25 Apr 2026 23:26:15 +0530
Subject: [PATCH] feat: add return_latents to extract continuous reasoning
 trajectory

---
 open_mythos/main.py | 28 +++++++++++++++++++++++-----
 1 file changed, 23 insertions(+), 5 deletions(-)

diff --git a/open_mythos/main.py b/open_mythos/main.py
index 65b0fa8..595a38e 100644
--- a/open_mythos/main.py
+++ b/open_mythos/main.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Tuple, List, Union
 
 import torch
 import torch.nn as nn
@@ -830,7 +830,8 @@ def forward(
         mask: Optional[torch.Tensor] = None,
         n_loops: Optional[int] = None,
         kv_cache: Optional[dict] = None,
-    ) -> torch.Tensor:
+        return_latents: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         """
         Run the recurrent loop for up to n_loops iterations with ACT early exit.
 
@@ -843,9 +844,11 @@
             Can be increased at inference for deeper reasoning (depth extrapolation).
         kv_cache -- cache dict passed through to the inner TransformerBlock;
             each loop iteration uses a separate cache key
+        return_latents -- if True, also return the hidden states from each executed loop iteration
 
         Returns:
             ACT-weighted sum of hidden states across iterations, shape (B, T, dim)
+            If return_latents is True, returns (h_out, latents) instead
         """
         n_loops = n_loops or self.cfg.max_loop_iters
         B, T, D = h.shape
@@ -853,6 +856,7 @@
         halted = torch.zeros(B, T, device=h.device, dtype=torch.bool)
         cumulative_p = torch.zeros(B, T, device=h.device)
         h_out = torch.zeros_like(h)
+        latents = []
 
         for t in range(n_loops):
             h_loop = loop_index_embedding(h, t, self.loop_dim)
@@ -861,6 +865,9 @@
             trans_out = self.block(combined, freqs_cis, mask, kv_cache, cache_key)
             trans_out = trans_out + self.lora(trans_out, t)
             h = self.injection(h, e, trans_out)
+
+            if return_latents:
+                latents.append(h.detach().clone())
 
             p = self.act(h)  # (B, T)
             still_running = ~halted
@@ -888,6 +895,8 @@
             if halted.all() and kv_cache is None:
                 break
 
+        if return_latents:
+            return h_out, latents
         return h_out
 
 
@@ -995,7 +1004,8 @@ def forward(
         n_loops: Optional[int] = None,
         kv_cache: Optional[dict] = None,
         start_pos: int = 0,
-    ) -> torch.Tensor:
+        return_latents: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         """
         Forward pass through Prelude → Recurrent Block → Coda.
 
@@ -1009,9 +1019,11 @@
             sequence; used to select the correct RoPE frequencies during
             incremental decoding (0 for prefill, prompt_len for each
             subsequent decode step)
+        return_latents -- if True, also return the recurrent block's reasoning trajectory
 
         Returns:
             Logits of shape (B, T, vocab_size)
+            If return_latents is True, returns (logits, latents) instead
         """
         T = input_ids.shape[1]
         device = input_ids.device
@@ -1026,12 +1038,18 @@
             x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"prelude_{i}")
 
         e = x  # encoded input frozen for injection every loop
-        x = self.recurrent(x, e, freqs_cis, mask, n_loops, kv_cache)
+        if return_latents:
+            x, latents = self.recurrent(x, e, freqs_cis, mask, n_loops, kv_cache, return_latents=True)
+        else:
+            x = self.recurrent(x, e, freqs_cis, mask, n_loops, kv_cache)
 
         for i, layer in enumerate(self.coda):
            x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"coda_{i}")
 
-        return self.head(self.norm(x))
+        logits = self.head(self.norm(x))
+        if return_latents:
+            return logits, latents
+        return logits
 
     @torch.no_grad()
     def generate(
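
Usage sketch (editor's addendum, not part of the patch): how a caller would
extract the trajectory. The model class name, config, and dummy input below
are placeholders chosen for illustration; only the
forward(..., return_latents=True) contract comes from the diff above.

    import torch

    # Hypothetical setup -- MythosModel and config are assumed names,
    # not identifiers confirmed by open_mythos/main.py.
    model = MythosModel(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (1, 16))  # (B, T) dummy prompt

    with torch.no_grad():
        # return_latents=True switches the return type from logits alone
        # to a (logits, latents) tuple, per the new signature.
        logits, latents = model(input_ids, return_latents=True)

    print(logits.shape)      # (1, 16, vocab_size)
    print(len(latents))      # executed loop iterations; at most max_loop_iters,
                             # fewer if ACT halts every position early
    print(latents[0].shape)  # (1, 16, dim) -- hidden state after iteration 0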
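Because the latents are the post-injection hidden states from successive loop
iterations, one quick diagnostic they enable is checking whether the recurrent
block is settling toward a fixed point, which is useful when experimenting
with depth extrapolation (raising n_loops at inference). A minimal sketch
over the latents list from the example above:

    import torch.nn.functional as F

    # Mean cosine similarity between consecutive iterations, averaged over
    # batch and sequence positions. Values near 1.0 mean the hidden state
    # has mostly stopped moving, so extra loop iterations change little.
    for t in range(1, len(latents)):
        sim = F.cosine_similarity(latents[t], latents[t - 1], dim=-1).mean()
        print(f"iter {t - 1} -> {t}: cos sim = {sim.item():.4f}")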