34 changes: 34 additions & 0 deletions README.md

@@ -322,6 +322,40 @@ If Mythos uses this technique, each loop is not a repetition — it is a distinct

---

## Depth-extrapolation recipe (experimental)

There is an empirical trade-off between the "each loop is a distinct computational phase" design above and the **inference-time depth-extrapolation** property (more loops at inference → lower perplexity). A 13-run ablation study on a 117.8M-param OpenMythos trained on ~491M tokens of FineWeb-Edu ([details](https://github.com/kyegomez/OpenMythos/issues/28)) found that:

- Training at a **fixed** `n_loops=k` produces a sharp V-shaped PPL curve at inference: a deep minimum at `n_loops=k`, with PPL rising sharply on either side. **More compute at inference does not help** outside the trained depth.
- Training with **stochastic depth sampling** (uniform over `n_loops ∈ {4, 6, 8, 12, 16}`) produces a **flat plateau** instead — any inference depth in the training range gives essentially the same PPL. Robust to depth choice, but still no monotonic scaling.
- The primary mechanism binding the model to its trained depth is the **ACT halting** module. Disabling the ACT output aggregation (returning the final-iteration `h` instead of the ACT-weighted sum) is the only one of the interventions tested (the others being loop-index embedding, per-loop LoRA, LTI trainability, MoE router freezing, and LTI-carry breaking) that qualitatively changes the loop-scaling curve.

To enable monotonic inference-time depth extrapolation on OpenMythos, train with the following knobs (see `open_mythos/main.py`):

```python
cfg = mythos_1b()
cfg.disable_act = True # return last-iteration h, no ACT-weighted sum
# And during training, sample n_loops per step:
# n_loops = random.choice([4, 6, 8, 12, 16])
# logits = model(input_ids, n_loops=n_loops)
```
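
Expanded into a runnable step, a minimal sketch of that stochastic-depth training loop is below. It assumes only the calling convention from the comment above (`model(input_ids, n_loops=...)` returning next-token logits); `get_batch`, the optimizer settings, and `num_steps` are placeholders rather than the values used in Issue #28.

```python
import random

import torch
import torch.nn.functional as F

depth_choices = [4, 6, 8, 12, 16]           # training-depth support from the ablation
num_steps = 1_000                           # placeholder
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)  # placeholder hyperparameters

for step in range(num_steps):
    input_ids, targets = get_batch()        # placeholder loader yielding (B, T) id tensors
    n_loops = random.choice(depth_choices)  # resample the recurrence depth every step
    logits = model(input_ids, n_loops=n_loops)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
```

Resampling `n_loops` on every step is what produces the flat-plateau behaviour described above; a fixed depth would reproduce the V-shaped curve instead.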

This loses the adaptive per-token-compute behaviour that ACT provides, in exchange for a PPL curve that decreases monotonically with inference loops and saturates (matching the Saunshi et al. 2025 / Parcae depth-extrapolation shape). In the reference run (`disable_act_random` in Issue #28), PPL at `n_loops=1` is 131 and PPL at `n_loops=12` is 59, compared with PPL 1217 → 65 for the same model trained with default ACT + stochastic depth.
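
To check that shape on a held-out split, a depth sweep along the following lines measures teacher-forced perplexity at each inference depth. This is a sketch: it assumes the same `model(input_ids, n_loops=...)` signature and a placeholder `eval_batches` iterable of `(input_ids, targets)` pairs, and it reproduces the shape of the curves in Issue #28, not the exact numbers.

```python
import math

import torch
import torch.nn.functional as F

@torch.no_grad()
def ppl_at_depth(model, eval_batches, n_loops: int) -> float:
    """Teacher-forced perplexity at a fixed inference depth (sketch)."""
    total_nll, total_tokens = 0.0, 0
    for input_ids, targets in eval_batches:  # placeholder iterable of (B, T) pairs
        logits = model(input_ids, n_loops=n_loops)
        nll = F.cross_entropy(
            logits.view(-1, logits.size(-1)), targets.view(-1), reduction="sum"
        )
        total_nll += nll.item()
        total_tokens += targets.numel()
    return math.exp(total_nll / total_tokens)

for depth in [1, 2, 4, 6, 8, 12, 16]:
    print(f"n_loops={depth}: PPL={ppl_at_depth(model, eval_batches, n_loops=depth):.1f}")
```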

Additional ablation flags available on `MythosConfig`:

| Flag | Default | Effect of the non-default setting |
|---|---|---|
| `loop_index_embedding` | `True` | `False`: skip the sinusoidal loop-index injection into `h` at each step |
| `use_per_loop_lora` | `True` | `False`: skip the per-loop `LoRAAdapter.scale` application |
| `disable_act` | `False` | `True`: return `h` from the last loop iteration instead of the ACT-weighted sum (★ primary depth-extrapolation knob) |
| `freeze_moe_router` | `False` | `True`: freeze `MoEFFN.router.weight` (`requires_grad=False`); must be applied externally on the model after construction (see the sketch below) |
| `break_recurrence` | `False` | `True`: replace `h = A·h_t + B·e + trans_out` with `h = trans_out` (kill the LTI state carry) |

Default values preserve the original architecture exactly, so this is a zero-impact addition for existing users.
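
For example, a hypothetical usage sketch (only the flag names and the `MoEFFN.router.weight` parameter path come from the table above; `build_model` stands in for whatever constructor the repo exposes):

```python
cfg = mythos_1b()
cfg.loop_index_embedding = False   # example ablation: drop the loop-index injection
cfg.freeze_moe_router = True       # records intent only; the freeze is applied below

model = build_model(cfg)           # placeholder for the repo's actual model constructor

# freeze_moe_router must be applied externally, after construction:
if cfg.freeze_moe_router:
    for name, param in model.named_parameters():
        if name.endswith("router.weight"):  # matches MoEFFN.router.weight in every layer
            param.requires_grad = False
```

Because `freeze_moe_router` only records intent, the `requires_grad` loop above is what actually freezes the routers.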

---

## The Overthinking Problem

More loops is not always better. Beyond a certain depth, excessive recurrence **degrades predictions** — the hidden state drifts past the solution and into noise. This is the "overthinking" failure mode.
27 changes: 24 additions & 3 deletions open_mythos/main.py

@@ -79,6 +79,18 @@ class MythosConfig:
    max_output_tokens: int = 4096
    # Dropout (set 0.0 to disable; 0.1 is standard for pretraining)
    dropout: float = 0.0
    # --- Ablation flags (default values preserve the original architecture) ---
    # These expose five mechanistic knobs that control how each recurrent step
    # differs from every other. Disabling them individually or in combination
    # allows practitioners to trade inference-time depth extrapolation
    # (monotonic PPL scaling with n_loops at inference) against the default
    # peak-quality-at-trained-depth behaviour. See the "Depth-extrapolation
    # recipe" section in README.md for the recipe that restores monotonic scaling.
    loop_index_embedding: bool = True   # inject a sinusoidal loop-index signal into h each step
    use_per_loop_lora: bool = True      # apply the per-loop LoRA adapter to trans_out
    disable_act: bool = False           # return last-iteration h instead of the ACT-weighted sum
    freeze_moe_router: bool = False     # intent flag; apply requires_grad=False to MoEFFN.router.weight externally
    break_recurrence: bool = False      # h = trans_out instead of the LTI carry h = injection(h, e, trans_out)


# ---------------------------------------------------------------------------
@@ -855,12 +867,19 @@ def forward(
        h_out = torch.zeros_like(h)

        for t in range(n_loops):
            if self.cfg.loop_index_embedding:
                h_loop = loop_index_embedding(h, t, self.loop_dim)
            else:
                h_loop = h  # ablation: no per-loop index signal
            combined = self.norm(h_loop + e)
            cache_key = f"recurrent_loop_{t}"
            trans_out = self.block(combined, freqs_cis, mask, kv_cache, cache_key)
            if self.cfg.use_per_loop_lora:
                trans_out = trans_out + self.lora(trans_out, t)
            if self.cfg.break_recurrence:
                h = trans_out  # ablation: drop the LTI state carry
            else:
                h = self.injection(h, e, trans_out)

            p = self.act(h)  # (B, T)
            still_running = ~halted
@@ -888,6 +907,8 @@
            if halted.all() and kv_cache is None:
                break

        if self.cfg.disable_act:
            # Ablation: skip the ACT-weighted aggregation and return the
            # hidden state from the last loop iteration.
            return h
        return h_out

