diff --git a/README.md b/README.md
index af2d774..ff1b697 100644
--- a/README.md
+++ b/README.md
@@ -322,6 +322,40 @@ If Mythos uses this technique, each loop is not a repetition — it is a distinc
 
 ---
 
+## Depth-extrapolation recipe (experimental)
+
+There is an empirical trade-off between the "each loop is a distinct computational phase" design above and the **inference-time depth-extrapolation** property (more loops at inference → lower perplexity). A 13-run ablation study on a 117.8M-param OpenMythos trained on ~491M tokens of FineWeb-Edu ([details](https://github.com/kyegomez/OpenMythos/issues/28)) found that:
+
+- Training at a **fixed** `n_loops=k` produces a sharp V-shaped PPL curve at inference: a deep minimum at `n_loops=k`, with PPL rising sharply on either side. **More compute at inference does not help** outside the trained depth.
+- Training with **stochastic depth sampling** (uniform over `n_loops ∈ {4, 6, 8, 12, 16}`) produces a **flat plateau** instead — any inference depth in the training range gives essentially the same PPL. Robust to depth choice, but still no monotonic scaling.
+- The primary mechanism binding the model to its trained depth is the **ACT halting** module. Disabling the ACT output aggregation (returning the final-iteration `h` instead of the ACT-weighted sum) is the only one of the interventions tested (the other five being loop-index embedding, per-loop LoRA, LTI trainability, MoE router freezing, and LTI-carry breaking) that qualitatively changes the loop-scaling curve.
+
+To enable monotonic inference-time depth extrapolation on OpenMythos, train with the following knobs (see `open_mythos/main.py`):
+
+```python
+cfg = mythos_1b()
+cfg.disable_act = True  # return last-iteration h, no ACT-weighted sum
+# And during training, sample n_loops per step:
+# n_loops = random.choice([4, 6, 8, 12, 16])
+# logits = model(input_ids, n_loops=n_loops)
+```
+
+This loses the adaptive per-token-compute behaviour that ACT provides, in exchange for a PPL curve that decreases monotonically with inference loops and saturates (matching the Saunshi et al. 2025 / Parcae depth-extrapolation shape). In the reference run (`disable_act_random` in Issue #28), PPL at `n_loops=1` is 131 and PPL at `n_loops=12` is 59, compared with 1217 → 65 for the same model trained with default ACT + stochastic depth.
+
+Additional ablation flags available on `MythosConfig`:
+
+| Flag | Default | Effect when toggled from the default |
+|---|---|---|
+| `loop_index_embedding` | `True` | Skip the sinusoidal loop-index injection into `h` at each step |
+| `use_per_loop_lora` | `True` | Skip the per-loop `LoRAAdapter.scale` application |
+| `disable_act` | `False` | Return `h` from the last loop iteration instead of the ACT-weighted sum (★ primary depth-extrapolation knob) |
+| `freeze_moe_router` | `False` | Set `requires_grad=False` on `MoEFFN.router.weight` (must be set externally on the model after construction) |
+| `break_recurrence` | `False` | Replace `h = A·h_t + B·e + trans_out` with `h = trans_out` (kill LTI state carry) |
+
+Default values preserve the original architecture exactly, so this is a zero-impact addition for existing users.
+
+---
+
 ## The Overthinking Problem
 
 More loops is not always better. Beyond a certain depth, excessive recurrence **degrades predictions** — the hidden state drifts past the solution and into noise. This is the "overthinking" failure mode.
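As a self-contained illustration of the recipe above — not the real OpenMythos API — the toy recurrence below stands in for the shared-weight recurrent block (`recurrent_forward`, `training_step`, and the `0.5 * h + x` update are all hypothetical). It shows the qualitative shape the recipe aims for: with last-iteration output and depth sampled per training step, adding loops at inference improves the result monotonically and then saturates at a fixed point.

```python
import random

DEPTHS = [4, 6, 8, 12, 16]  # training-time depth pool from the recipe above

def recurrent_forward(x, n_loops):
    """Apply one shared 'block' n_loops times and return the last h
    (the disable_act=True behaviour: no ACT-weighted aggregation)."""
    h = 0.0
    for _ in range(n_loops):
        h = 0.5 * h + x  # same weights every loop; h converges toward 2*x
    return h

def training_step(x, rng):
    # Stochastic depth sampling: draw a fresh n_loops each step so the
    # model never specializes to a single trained depth.
    n_loops = rng.choice(DEPTHS)
    return recurrent_forward(x, n_loops), n_loops

# At inference, more loops monotonically approach the fixed point 2.0
# and then saturate -- the depth-extrapolation curve shape.
outs = [recurrent_forward(1.0, n) for n in (1, 4, 12, 32)]
assert outs == sorted(outs)         # monotone improvement with depth
assert abs(outs[-1] - 2.0) < 1e-6   # saturation near the fixed point
```

A fixed-depth analogue of this toy would have no reason to behave sensibly away from its trained depth, which is the V-shaped curve the ablation study reports.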
diff --git a/open_mythos/main.py b/open_mythos/main.py
index 65b0fa8..bf4fb44 100644
--- a/open_mythos/main.py
+++ b/open_mythos/main.py
@@ -79,6 +79,18 @@ class MythosConfig:
     max_output_tokens: int = 4096
     # Dropout (set 0.0 to disable; 0.1 is standard for pretraining)
     dropout: float = 0.0
+    # --- Ablation flags (default values preserve the original architecture) ---
+    # These expose five mechanistic knobs that control how each recurrent step
+    # differs from every other. Disabling them individually or in combination
+    # allows practitioners to trade inference-time depth-extrapolation
+    # (monotonic PPL scaling with n_loops at inference) against the default
+    # peak-quality-at-trained-depth behaviour. See the "Depth extrapolation"
+    # section in README.md for the recipe that restores monotonic scaling.
+    loop_index_embedding: bool = True
+    use_per_loop_lora: bool = True
+    disable_act: bool = False
+    freeze_moe_router: bool = False
+    break_recurrence: bool = False
 
 
 # ---------------------------------------------------------------------------
@@ -855,12 +867,19 @@ def forward(
         h_out = torch.zeros_like(h)
 
         for t in range(n_loops):
-            h_loop = loop_index_embedding(h, t, self.loop_dim)
+            if self.cfg.loop_index_embedding:
+                h_loop = loop_index_embedding(h, t, self.loop_dim)
+            else:
+                h_loop = h
             combined = self.norm(h_loop + e)
             cache_key = f"recurrent_loop_{t}"
             trans_out = self.block(combined, freqs_cis, mask, kv_cache, cache_key)
-            trans_out = trans_out + self.lora(trans_out, t)
-            h = self.injection(h, e, trans_out)
+            if self.cfg.use_per_loop_lora:
+                trans_out = trans_out + self.lora(trans_out, t)
+            if self.cfg.break_recurrence:
+                h = trans_out
+            else:
+                h = self.injection(h, e, trans_out)
 
             p = self.act(h)  # (B, T)
             still_running = ~halted
@@ -888,6 +907,8 @@ def forward(
             if halted.all() and kv_cache is None:
                 break
 
+        if self.cfg.disable_act:
+            return h
         return h_out
 
 
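To make the semantics of the `disable_act` branch above concrete, here is a minimal scalar sketch (hypothetical values standing in for hidden-state tensors, not the real `forward`): with ACT the output is a halting-weighted sum over per-loop states, while `disable_act=True` returns the last-iteration `h` directly.

```python
def aggregate(hs, ps, disable_act):
    """hs: per-loop hidden states; ps: ACT halting weights (assumed to
    sum to 1). Mirrors the h_out accumulation vs. the `return h` path."""
    if disable_act:
        return hs[-1]                               # last-iteration h
    return sum(p * h for p, h in zip(ps, hs))       # ACT-weighted sum (h_out)

hs = [1.0, 1.5, 1.75]   # toy states after loops 1..3
ps = [0.2, 0.3, 0.5]    # toy halting weights
assert aggregate(hs, ps, disable_act=True) == 1.75
assert abs(aggregate(hs, ps, disable_act=False) - 1.525) < 1e-9
```

The weighted sum is what binds the output to the trained depth (the weights are learned for a particular halting profile); dropping it is the single change the ablation found necessary for monotonic loop scaling.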