diff --git a/docs/superpowers/plans/2026-04-27-stochastic-depth-training.md b/docs/superpowers/plans/2026-04-27-stochastic-depth-training.md
new file mode 100644
index 0000000..0abec6a
--- /dev/null
+++ b/docs/superpowers/plans/2026-04-27-stochastic-depth-training.md
@@ -0,0 +1,656 @@
# Stochastic Depth Training (Option B) Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Add a runtime-selectable stochastic-depth training recipe (no ACT weighting + random per-step `n_loops`) while keeping the existing ACT recipe fully intact and checkpoint-compatible.

**Architecture:** Thread a boolean `bypass_act` flag through `OpenMythos.forward() -> RecurrentBlock.forward()`. When `True`, skip the ACT weighted-sum accumulation and halting-driven early exit, returning the final hidden state directly. The training script samples `n_loops` uniformly per step when in stochastic-depth mode. `ACTHalting` and `LoRAAdapter` modules remain present in the model unchanged, so checkpoints are bit-compatible across modes.

**Tech Stack:** PyTorch (FSDP, distributed), pytest, loguru logger, ClearML.

**Spec:** `docs/superpowers/specs/2026-04-27-stochastic-depth-training-design.md`

---

## File Structure

| File | Action | Responsibility |
|------|--------|----------------|
| `open_mythos/main.py` | Modify | Add `bypass_act` parameter to `RecurrentBlock.forward()` and `OpenMythos.forward()` |
| `training/1b_poc_fineweb.py` | Modify | Add `recurrent_mode` / `stochastic_depth_min` / `stochastic_depth_max` variables; sample `n_loops` per step; log mode and per-step `n_loops` |
| `tests/test_stochastic_depth.py` | Create | New test module for `bypass_act` behavior, regression of ACT path, checkpoint cross-mode compatibility, smoke test of training step |

---

## Task 1: RecurrentBlock `bypass_act` — test first

**Files:**
- Create: `tests/test_stochastic_depth.py`
- Modify: `open_mythos/main.py` (RecurrentBlock.forward signature and body)

- [ ] **Step 1: Write the failing tests**

Create the file `tests/test_stochastic_depth.py` with the following content:

```python
"""Tests for stochastic-depth (Option B) training path: bypass_act flag."""

import pytest
import torch

from open_mythos.main import MythosConfig, OpenMythos, RecurrentBlock


def _small_cfg() -> MythosConfig:
    """Small CPU config used by the existing test suite."""
    return MythosConfig(
        vocab_size=128,
        dim=64,
        n_heads=4,
        n_kv_heads=2,
        max_seq_len=32,
        max_loop_iters=4,
        prelude_layers=1,
        coda_layers=1,
        attn_type="gqa",
        n_experts=2,
        n_shared_experts=1,
        n_experts_per_tok=2,
        expert_dim=64,
        act_threshold=0.99,
        lora_rank=4,
    )


def _build_block_inputs(cfg: MythosConfig, B: int = 2, T: int = 8):
    """Build the (h, e, freqs_cis) inputs needed by RecurrentBlock.forward."""
    torch.manual_seed(0)
    model = OpenMythos(cfg)
    input_ids = torch.randint(0, cfg.vocab_size, (B, T))
    x = model.embed(input_ids)
    freqs_cis = model.freqs_cis[:T]
    mask = model._causal_mask(T, x.device, x.dtype)
    for i, layer in enumerate(model.prelude):
        x = layer(x, freqs_cis, mask, None, cache_key=f"prelude_{i}")
    return model.recurrent, x.clone(), x.clone(), freqs_cis, mask


def test_recurrent_block_bypass_act_differs_from_act():
    """bypass_act=True should produce a different output than
bypass_act=False.""" + cfg = _small_cfg() + block, h, e, freqs_cis, mask = _build_block_inputs(cfg) + torch.manual_seed(1) + out_act = block(h.clone(), e, freqs_cis, mask, n_loops=4, bypass_act=False) + torch.manual_seed(1) + out_bypass = block(h.clone(), e, freqs_cis, mask, n_loops=4, bypass_act=True) + assert out_act.shape == out_bypass.shape + assert not torch.allclose(out_act, out_bypass, atol=1e-6), ( + "bypass_act=True should not equal ACT-weighted output" + ) + + +def test_recurrent_block_bypass_act_runs_full_n_loops(): + """With bypass_act=True there should be no early exit; all n_loops iterations run.""" + cfg = _small_cfg() + block, h, e, freqs_cis, mask = _build_block_inputs(cfg) + call_count = {"n": 0} + original_block = block.block.forward + + def counting_forward(*args, **kwargs): + call_count["n"] += 1 + return original_block(*args, **kwargs) + + block.block.forward = counting_forward + try: + _ = block(h, e, freqs_cis, mask, n_loops=3, bypass_act=True) + finally: + block.block.forward = original_block + assert call_count["n"] == 3, f"expected 3 block calls, got {call_count['n']}" + + +def test_recurrent_block_bypass_act_returns_final_h(): + """bypass_act=True output should match a manual iteration returning the final h.""" + from open_mythos.main import loop_index_embedding + + cfg = _small_cfg() + block, h, e, freqs_cis, mask = _build_block_inputs(cfg) + n_loops = 3 + + torch.manual_seed(1) + h_manual = h.clone() + for t in range(n_loops): + h_loop = loop_index_embedding(h_manual, t, block.loop_dim) + combined = block.norm(h_loop + e) + trans_out = block.block(combined, freqs_cis, mask, None, f"recurrent_loop_{t}") + trans_out = trans_out + block.lora(trans_out, t) + h_manual = block.injection(h_manual, e, trans_out) + + torch.manual_seed(1) + out_bypass = block(h.clone(), e, freqs_cis, mask, n_loops=n_loops, bypass_act=True) + + assert torch.allclose(out_bypass, h_manual, atol=1e-5), ( + "bypass_act=True should return the final hidden state after n_loops iterations" + ) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `pytest tests/test_stochastic_depth.py -v` +Expected: all three tests FAIL with `TypeError: forward() got an unexpected keyword argument 'bypass_act'`. + +- [ ] **Step 3: Add `bypass_act` parameter to `RecurrentBlock.forward()`** + +In `open_mythos/main.py`, modify `RecurrentBlock.forward()` (currently around lines 853–941). Replace the current `forward` method with this version: + +```python + def forward( + self, + h: torch.Tensor, + e: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor] = None, + n_loops: Optional[int] = None, + kv_cache: Optional[dict] = None, + bypass_act: bool = False, + ) -> torch.Tensor: + """ + Run the recurrent loop for up to n_loops iterations. + + Args: + h -- initial hidden state from the Prelude, shape (B, T, dim) + e -- encoded input frozen for injection each step, shape (B, T, dim) + freqs_cis -- precomputed RoPE frequencies + mask -- additive causal mask or None + n_loops -- number of loop iterations; defaults to cfg.max_loop_iters. + kv_cache -- cache dict passed through to the inner TransformerBlock; + each loop iteration uses a separate cache key + bypass_act -- if True, skip ACT weighting and return the final h directly + after running all n_loops iterations (used for Option B + stochastic-depth training). + + Returns: + ACT-weighted sum of hidden states across iterations when bypass_act=False, + or the final hidden state after n_loops iterations when bypass_act=True. 
+ Shape: (B, T, dim) in both cases. + """ + n_loops = n_loops or self.cfg.max_loop_iters + B, T, D = h.shape + + if not bypass_act: + halted = torch.zeros(B, T, device=h.device, dtype=torch.bool) + cumulative_p = torch.zeros(B, T, device=h.device) + h_out = torch.zeros_like(h) + + for t in range(n_loops): + h_loop = loop_index_embedding(h, t, self.loop_dim) + combined = self.norm(h_loop + e) + cache_key = f"recurrent_loop_{t}" + trans_out = self.block(combined, freqs_cis, mask, kv_cache, cache_key) + trans_out = trans_out + self.lora(trans_out, t) + h = self.injection(h, e, trans_out) + + if bypass_act: + continue + + p = self.act(h) # (B, T) + still_running = ~halted + + remainder = (1.0 - cumulative_p).clamp(min=0) + weight = torch.where( + cumulative_p + p >= self.cfg.act_threshold, + remainder, + p, + ) + weight = weight * still_running.float() + h_out = h_out + weight.unsqueeze(-1) * h + + cumulative_p = cumulative_p + p * still_running.float() + halted = halted | (cumulative_p >= self.cfg.act_threshold) + + if kv_cache is None: + all_halted = halted.all() + if torch.distributed.is_initialized(): + flag = torch.tensor( + [all_halted], dtype=torch.int32, device=h.device + ) + torch.distributed.all_reduce( + flag, op=torch.distributed.ReduceOp.MIN + ) + all_halted = flag.item() > 0 + if all_halted: + break + + if bypass_act: + return h + + not_halted = ~halted + if not_halted.any(): + final_remainder = (1.0 - cumulative_p).clamp(min=0) * not_halted.float() + h_out = h_out + final_remainder.unsqueeze(-1) * h + return h_out +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `pytest tests/test_stochastic_depth.py -v` +Expected: all three `test_recurrent_block_bypass_act_*` tests PASS. + +- [ ] **Step 5: Verify the ACT-mode regression — full existing suite** + +Run: `pytest tests/test_main.py -v` +Expected: same pass/fail counts as before this task (no newly broken tests; the 14 pre-existing failures remain). The goal is proving `bypass_act=False` (default) did not break the existing ACT behavior. + +- [ ] **Step 6: Commit** + +```bash +git add open_mythos/main.py tests/test_stochastic_depth.py +git commit -m "feat(model): add bypass_act flag to RecurrentBlock.forward + +Skips ACT weighting and returns the final hidden state directly. +Default bypass_act=False preserves the existing ACT code path. +" +``` + +--- + +## Task 2: Plumb `bypass_act` through `OpenMythos.forward()` + +**Files:** +- Modify: `open_mythos/main.py` (OpenMythos.forward signature and body) +- Modify: `tests/test_stochastic_depth.py` (add one test) + +- [ ] **Step 1: Write the failing test** + +Append this test to `tests/test_stochastic_depth.py`: + +```python +def test_openmythos_forward_bypass_act_propagates(): + """OpenMythos.forward(bypass_act=True) should route through RecurrentBlock with bypass_act=True.""" + cfg = _small_cfg() + torch.manual_seed(0) + model = OpenMythos(cfg) + input_ids = torch.randint(0, cfg.vocab_size, (2, 8)) + + torch.manual_seed(1) + logits_act = model(input_ids, n_loops=3, bypass_act=False) + torch.manual_seed(1) + logits_bypass = model(input_ids, n_loops=3, bypass_act=True) + + assert logits_act.shape == logits_bypass.shape + assert not torch.allclose(logits_act, logits_bypass, atol=1e-6), ( + "bypass_act should change model output" + ) +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `pytest tests/test_stochastic_depth.py::test_openmythos_forward_bypass_act_propagates -v` +Expected: FAIL with `TypeError: forward() got an unexpected keyword argument 'bypass_act'`. 
+ +- [ ] **Step 3: Add `bypass_act` to `OpenMythos.forward()`** + +In `open_mythos/main.py`, locate `OpenMythos.forward()` (currently around lines 1043–1086). Make two edits. + +First, update the signature and docstring (around lines 1044–1072). Replace the method definition header with: + +```python + def forward( + self, + input_ids: torch.Tensor, + n_loops: Optional[int] = None, + kv_cache: Optional[dict] = None, + start_pos: int = 0, + bypass_act: bool = False, + ) -> torch.Tensor: + """ + Forward pass through Prelude → Recurrent Block → Coda. + + Args: + input_ids -- token indices of shape (B, T) + n_loops -- recurrent loop depth; defaults to cfg.max_loop_iters. + Increase at inference to extrapolate to harder problems. + kv_cache -- dict mutated in-place for autoregressive KV caching; + pass an empty dict {} and reuse across decode steps + start_pos -- index of the first token in input_ids within the full + sequence; used to select the correct RoPE frequencies + during incremental decoding (0 for prefill, prompt_len + for each subsequent decode step) + bypass_act -- if True, RecurrentBlock skips ACT weighting and returns + the final hidden state directly. Default False preserves + the existing ACT behavior. + + Returns: + Logits of shape (B, T, vocab_size) + """ +``` + +Second, update the call to `self.recurrent(...)` — find the line that currently reads: + +```python + x = self.recurrent(x, e, freqs_cis, mask, n_loops, kv_cache) +``` + +Replace it with: + +```python + x = self.recurrent(x, e, freqs_cis, mask, n_loops, kv_cache, bypass_act) +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `pytest tests/test_stochastic_depth.py::test_openmythos_forward_bypass_act_propagates -v` +Expected: PASS. + +- [ ] **Step 5: Run the full new test file** + +Run: `pytest tests/test_stochastic_depth.py -v` +Expected: all four tests PASS. + +- [ ] **Step 6: Commit** + +```bash +git add open_mythos/main.py tests/test_stochastic_depth.py +git commit -m "feat(model): plumb bypass_act through OpenMythos.forward" +``` + +--- + +## Task 3: Checkpoint round-trip test across modes + +**Files:** +- Modify: `tests/test_stochastic_depth.py` (add one test) + +- [ ] **Step 1: Write the cross-mode checkpoint test** + +Append this test to `tests/test_stochastic_depth.py`: + +```python +def test_state_dict_compatible_across_modes(tmp_path): + """A checkpoint saved before toggling bypass_act should load without key mismatch.""" + cfg = _small_cfg() + torch.manual_seed(0) + model_a = OpenMythos(cfg) + ckpt_path = tmp_path / "model.pt" + torch.save(model_a.state_dict(), ckpt_path) + + torch.manual_seed(1) + model_b = OpenMythos(cfg) + state = torch.load(ckpt_path, map_location="cpu") + missing, unexpected = model_b.load_state_dict(state, strict=True) + assert not missing, f"unexpected missing keys: {missing}" + assert not unexpected, f"unexpected extra keys: {unexpected}" + + input_ids = torch.randint(0, cfg.vocab_size, (2, 8)) + torch.manual_seed(2) + logits_act = model_b(input_ids, n_loops=3, bypass_act=False) + torch.manual_seed(2) + logits_bypass = model_b(input_ids, n_loops=3, bypass_act=True) + assert logits_act.shape == logits_bypass.shape +``` + +- [ ] **Step 2: Run test to verify it passes** + +Run: `pytest tests/test_stochastic_depth.py::test_state_dict_compatible_across_modes -v` +Expected: PASS (no model code changes needed — the parameter set is already mode-independent). 
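This mode-independence is what enables the production use case described in the spec: resuming the ACT-trained run in stochastic-depth mode. A minimal sketch of that resume (the checkpoint name `step_0032000.pt` comes from the spec; the path and `n_loops=24` are illustrative, and the real resume goes through the FSDP training script):

```python
# Sketch only — path and loop depth are placeholders, not plan steps.
model = OpenMythos(cfg)
state = torch.load("step_0032000.pt", map_location="cpu")
model.load_state_dict(state, strict=True)  # same keys in either mode
logits = model(input_ids, n_loops=24, bypass_act=True)  # depths > 16 now usable
```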
+ +- [ ] **Step 3: Commit** + +```bash +git add tests/test_stochastic_depth.py +git commit -m "test: verify state_dict is compatible across ACT / stochastic_depth modes" +``` + +--- + +## Task 4: Training script — runtime mode toggle + per-step sampling + logging + +**Files:** +- Modify: `training/1b_poc_fineweb.py` + +- [ ] **Step 1: Add the `random` import** + +In `training/1b_poc_fineweb.py`, locate the import block near the top of the file (around line 29–46). Add `import random` alphabetically among the stdlib imports. For example, after `import os` (or wherever it fits alphabetically): + +```python +import random +``` + +- [ ] **Step 2: Add the three hyperparameters to the hyperparams block** + +In `training/1b_poc_fineweb.py`, locate the hyperparameter block that starts around line 398: + +```python + # ------------------------------------------------------------------ + # Hyperparameters (env-var configurable with defaults) + # ------------------------------------------------------------------ + seq_len = 2048 + micro_batch = 1 +``` + +Immediately before `seq_len = 2048`, insert the three new variables: + +```python + # Recurrent-depth training recipe (Option A: ACT, Option B: stochastic depth). + # Change recurrent_mode to "act" to use the original ACT halting recipe. + recurrent_mode = "stochastic_depth" # "act" or "stochastic_depth" + stochastic_depth_min = 1 + stochastic_depth_max = 32 + +``` + +- [ ] **Step 3: Add startup banner and ClearML hparams** + +Locate `training_hparams = {...}` (around line 423). Add the three new keys at the end of the dict (just before the closing `}`): + +```python + "recurrent_mode": recurrent_mode, + "stochastic_depth_min": stochastic_depth_min, + "stochastic_depth_max": stochastic_depth_max, +``` + +Then find the `if master:` block that logs hyperparameters (search for the earliest `logger.info` with "Parameters:" or the config banner near line 484). Immediately after the existing banner lines, add a dedicated mode line. For example, right after: + +```python + logger.info(f"Parameters: {param_count:,} | AMP dtype: {amp_dtype}") +``` + +(The exact wording may differ — find the existing "Parameters:" log line and insert the next line directly after it, inside the same `if master:` guard if present.) + +Add: + +```python + if master: + if recurrent_mode == "stochastic_depth": + logger.info( + f"Recurrent mode: stochastic_depth " + f"(n_loops sampled uniformly from [{stochastic_depth_min}, {stochastic_depth_max}])" + ) + else: + logger.info(f"Recurrent mode: act (n_loops = cfg.max_loop_iters = {cfg.max_loop_iters})") +``` + +- [ ] **Step 4: Sample `n_loops` per step and pass both flags to the forward** + +Locate the training loop forward call (around line 555–556): + +```python + with sync, amp_ctx: + logits = model(x) +``` + +Replace with: + +```python + if recurrent_mode == "stochastic_depth": + n_loops_this_step = random.randint(stochastic_depth_min, stochastic_depth_max) + bypass_act_this_step = True + else: + n_loops_this_step = None + bypass_act_this_step = False + + with sync, amp_ctx: + logits = model( + x, + n_loops=n_loops_this_step, + bypass_act=bypass_act_this_step, + ) +``` + +- [ ] **Step 5: Include mode and n_loops in the per-step stderr log and ClearML scalars** + +Locate the per-step logging block (around line 572–588). Modify the `logger.info(...)` call to include `mode=` and `n_loops=`. 
+ +Replace: + +```python + logger.info( + f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} " + f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} " + f"| {tok_per_sec / 1e6:.2f}M tok/s " + f"| {tokens_seen / 1e9:.1f}B tokens seen" + ) +``` + +with: + +```python + n_loops_display = ( + n_loops_this_step + if n_loops_this_step is not None + else cfg.max_loop_iters + ) + logger.info( + f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} " + f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} " + f"| {tok_per_sec / 1e6:.2f}M tok/s " + f"| {tokens_seen / 1e9:.1f}B tokens seen " + f"| mode={recurrent_mode} n_loops={n_loops_display}" + ) +``` + +Then, in the block of `log_clearml(...)` calls directly below, add one more scalar: + +```python + log_clearml("n_loops", float(n_loops_display), step) +``` + +- [ ] **Step 6: Run the full test suite to verify no regression in the training script** + +The training script is not directly unit-tested, but a syntax/import error would be caught by import. Run: + +```bash +python -c "import ast; ast.parse(open('training/1b_poc_fineweb.py').read()); print('OK')" +``` + +Expected: `OK`. + +- [ ] **Step 7: Commit** + +```bash +git add training/1b_poc_fineweb.py +git commit -m "feat(training): add stochastic-depth mode to training script + +New local variables (recurrent_mode, stochastic_depth_min, stochastic_depth_max) +control the recipe. Default recurrent_mode='stochastic_depth' samples n_loops +uniformly from [1, 32] and uses bypass_act=True. Set recurrent_mode='act' +for the original ACT halting recipe. + +Logs mode and per-step n_loops to stderr and ClearML. +" +``` + +--- + +## Task 5: Smoke-test training step in each mode + +**Files:** +- Modify: `tests/test_stochastic_depth.py` (add one test) + +- [ ] **Step 1: Write the smoke test** + +Append to `tests/test_stochastic_depth.py`: + +```python +def test_training_step_runs_in_each_mode(): + """One forward+backward+optimizer step works in both modes without error.""" + cfg = _small_cfg() + torch.manual_seed(0) + model = OpenMythos(cfg) + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + input_ids = torch.randint(0, cfg.vocab_size, (2, 8)) + targets = torch.randint(0, cfg.vocab_size, (2, 8)) + + # ACT mode + optimizer.zero_grad() + logits = model(input_ids, n_loops=None, bypass_act=False) + loss_act = torch.nn.functional.cross_entropy( + logits.view(-1, cfg.vocab_size), targets.view(-1) + ) + loss_act.backward() + optimizer.step() + assert torch.isfinite(loss_act), "ACT-mode loss must be finite" + + # Stochastic-depth mode + optimizer.zero_grad() + logits = model(input_ids, n_loops=3, bypass_act=True) + loss_sd = torch.nn.functional.cross_entropy( + logits.view(-1, cfg.vocab_size), targets.view(-1) + ) + loss_sd.backward() + optimizer.step() + assert torch.isfinite(loss_sd), "stochastic-depth-mode loss must be finite" +``` + +- [ ] **Step 2: Run the smoke test** + +Run: `pytest tests/test_stochastic_depth.py::test_training_step_runs_in_each_mode -v` +Expected: PASS. + +- [ ] **Step 3: Run the full new test file once more** + +Run: `pytest tests/test_stochastic_depth.py -v` +Expected: all 5 tests PASS. + +- [ ] **Step 4: Run lint/format** + +```bash +black tests/test_stochastic_depth.py training/1b_poc_fineweb.py open_mythos/main.py +ruff check --fix tests/test_stochastic_depth.py training/1b_poc_fineweb.py open_mythos/main.py +``` + +Expected: no changes required (or only whitespace fixes). If ruff/black makes edits, inspect and commit. 
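Optional hardening (not a required step): the spec's testing notes also call for verifying that the ACT head receives zero gradient in bypass mode. If desired, a minimal assertion — a sketch reusing the smoke test's locals — can be appended to `test_training_step_runs_in_each_mode`:

```python
    # In bypass mode the ACT halting head is never evaluated, so a fresh
    # backward pass must leave its parameters without gradients.
    optimizer.zero_grad(set_to_none=True)
    logits = model(input_ids, n_loops=3, bypass_act=True)
    torch.nn.functional.cross_entropy(
        logits.view(-1, cfg.vocab_size), targets.view(-1)
    ).backward()
    assert all(p.grad is None for p in model.recurrent.act.parameters())
```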
- [ ] **Step 5: Commit**

```bash
git add tests/test_stochastic_depth.py
git commit -m "test: smoke test one training step in each recurrent mode"
```

---

## Task 6: Push and verify end-to-end

- [ ] **Step 1: Confirm all new tests pass**

Run: `pytest tests/test_stochastic_depth.py -v`
Expected: 5 tests PASS.

- [ ] **Step 2: Confirm existing tests have no new failures**

Run: `pytest tests/test_main.py -v`
Expected: the 14 pre-existing failures remain (RoPE + LTI boundary); no new failures introduced.

- [ ] **Step 3: Push to origin**

```bash
git push origin main
```

---

## Post-implementation notes (not part of plan execution)

After this plan is merged, the **currently running 10B training job (56429) will auto-pick up the new default `recurrent_mode="stochastic_depth"` on the next preemption + resubmit** via `bash deploy/bluevela/bsub_1b_10b.sh`. The user has explicitly requested stochastic_depth as the default.

If the current ACT run should continue under ACT instead, set `recurrent_mode = "act"` at the top of `training/1b_poc_fineweb.py` before resubmitting. A mode switch mid-training will cause a transient loss spike of ~0.3–0.5 for a few hundred steps while the Coda re-adapts (documented in the spec).
diff --git a/docs/superpowers/specs/2026-04-27-stochastic-depth-training-design.md b/docs/superpowers/specs/2026-04-27-stochastic-depth-training-design.md
new file mode 100644
index 0000000..f31775b
--- /dev/null
+++ b/docs/superpowers/specs/2026-04-27-stochastic-depth-training-design.md
@@ -0,0 +1,119 @@
# Stochastic Depth Training (Option B) — Design Spec

**Date:** 2026-04-27
**Status:** Design, not yet implemented
**Related:** issue #5 (ACT depth-binding), `docs/logbook/2026-04-24-eval-and-analysis.md`

---

## Goal

Add a second training recipe — **stochastic depth without ACT weighting** — to the OpenMythos training pipeline, selectable per training run, while keeping the existing ACT recipe fully intact and checkpoint-compatible.

## Motivation

The 1B-token PoC evaluation (2026-04-24) confirmed the upstream finding: with ACT enabled, the model binds tightly to its trained recurrent depth (n_loops=16) and gains nothing from additional inference-time iterations. Depth extrapolation — a core advertised property of recurrent-depth transformers — is unreachable while ACT is on.

Upstream empirical work ([kyegomez/OpenMythos#28](https://github.com/kyegomez/OpenMythos/issues/28), a 13-run ablation) showed that the only recipe producing a monotonically decreasing PPL-vs-depth curve was:

- Disable ACT (return the final hidden state directly, no weighted sum)
- Train with random `n_loops` sampled per step

We want this recipe available as an alternative training strategy without abandoning the ACT path. The two should be freely switchable, including mid-training from the same checkpoint, so the model can be trained under different recipes in different phases.

## Design

### Runtime control

Three hyperparameters added directly to `training/1b_poc_fineweb.py` (local variables, not env vars — avoids env-var sprawl):

```python
recurrent_mode = "stochastic_depth"  # "act" or "stochastic_depth"
stochastic_depth_min = 1
stochastic_depth_max = 32
```

The default is `"stochastic_depth"`. To use the current ACT recipe, change to `"act"`.
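Concretely, these variables drive each training step as follows (a condensed sketch of the per-step dispatch specified in the next section; `x` and `model` are the training script's batch and FSDP-wrapped model):

```python
import random

if recurrent_mode == "stochastic_depth":
    n_loops = random.randint(stochastic_depth_min, stochastic_depth_max)  # inclusive
    bypass_act = True
else:  # "act"
    n_loops = None  # model falls back to cfg.max_loop_iters
    bypass_act = False

logits = model(x, n_loops=n_loops, bypass_act=bypass_act)
```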
### Per-step forward

In the training loop, before each forward pass:

- If `recurrent_mode == "stochastic_depth"`: sample `n_loops` uniformly from `[stochastic_depth_min, stochastic_depth_max]` inclusive, and call the model with `bypass_act=True`.
- If `recurrent_mode == "act"`: pass `n_loops=None` (uses `cfg.max_loop_iters`) and `bypass_act=False`.

**Logging:**

- At training startup (master rank only), print a clearly visible banner stating the active `recurrent_mode` and, if stochastic, the `[min, max]` sampling range. Example: `Recurrent mode: stochastic_depth (n_loops sampled from [1, 32])`.
- Add `recurrent_mode`, `stochastic_depth_min`, `stochastic_depth_max` to the ClearML `training_hparams` dict so they appear in the ClearML task configuration.
- Log per-step `n_loops` as a ClearML scalar so the sampling distribution is visible on the dashboard.
- Include `mode=` and `n_loops=` in the per-step stderr step line so they are visible in the job logs.

### Model changes

Two surgical additions to `open_mythos/main.py`:

1. **`OpenMythos.forward()`** — new parameter `bypass_act: bool = False`, plumbed through to `self.recurrent(...)`.
2. **`RecurrentBlock.forward()`** — new parameter `bypass_act: bool = False`:
   - When `False` (default): current behavior unchanged.
   - When `True`: skip ACT weighting accumulation, skip the `halted.all()` FSDP all-reduce, and return the final `h` directly after the last iteration.

The `ACTHalting` module stays present in the architecture regardless of mode. When bypassed, its weights simply receive no gradient that step.

### Checkpoint compatibility

The parameter set (state_dict keys and shapes) is **identical across modes**. A checkpoint saved in one mode loads cleanly in the other. This enables:

- Starting from an ACT-trained checkpoint and switching to stochastic depth (the current use case — resume from `step_0032000.pt`)
- Curriculum-style training: phases of ACT and phases of stochastic depth interleaved
- Direct A/B comparison on the same initialization

### Stability

Existing architectural guarantees make this design stable:

- **LTI injection** with guaranteed spectral radius < 1 (ZOH discretization) makes the recurrence contractive — the hidden state cannot explode across iterations.
- **Input re-injection** at every iteration prevents drift from the input signal.
- **RMSNorm** before every transformer block caps input magnitudes.

Upstream ablation confirmed monotonic PPL across depths 1→16 under this recipe.

Caveats: at `n_loops=32`, gradients through 32 shared blocks may partially vanish in the earliest iterations — not catastrophic, but worth monitoring. When switching modes mid-training, expect a transient loss spike (~0.3–0.5, lasting a few hundred steps) while the Coda re-adapts to the different hidden-state distribution.

**LoRA depth indexing:** `LoRAAdapter` is initialized with `cfg.max_loop_iters=16` scale embeddings. For `loop_t >= 16`, the adapter already clamps the index (lines 641–642) and reuses the depth-15 scale, so depths 16–31 share a single LoRA scale rather than having distinct learned scales. This is an acceptable trade-off: it keeps checkpoint compatibility (no shape change in the state_dict), and the LoRA delta is a small additive modulation anyway. If per-depth LoRA at extrapolation depths becomes important later, we can bump `cfg.max_loop_iters` to 32 and pad/re-initialize the LoRA scale embedding in a separate migration.

### Evaluation

No changes needed.
`evaluations/eval_checkpoint.py` already runs a depth sweep at `n_loops ∈ {1, 2, 4, 8, 12, 16, 24, 32}`, which gives a direct apples-to-apples comparison between Option A and Option B checkpoints.

## Scope (YAGNI)

**In scope:**

- Runtime mode toggle in the training script
- `bypass_act` flag plumbed through `OpenMythos.forward()` and `RecurrentBlock.forward()`
- Uniform random `n_loops` sampling in the training loop
- ClearML logging of `recurrent_mode` and per-step `n_loops`

**Out of scope (explicitly not doing):**

- Biased / non-uniform depth sampling distributions
- Automatic scheduling between modes (manual switch only)
- Removing or refactoring the ACT path
- Changing `MythosConfig` (no new fields; all control is at training-script level)
- Soft attention over loop outputs (Option C) — a separate future design if needed

## Testing

- Unit test: `RecurrentBlock.forward(bypass_act=True)` returns `h` at the requested `n_loops`, with no ACT accumulation applied. Parameter grads match expectation (the ACT module receives zero grad).
- Unit test: the `bypass_act=False` path produces output identical to the current implementation (regression).
- Unit test: checkpoint round-trip — save in one mode, load in the other, verify no state_dict mismatch.
- Smoke test: a small-config training loop runs one step in each mode without error.

## Success criteria

1. A single training run can be launched in either `"act"` or `"stochastic_depth"` mode by changing one variable.
2. The current ACT recipe is bit-identical to before when `recurrent_mode="act"`.
3. A checkpoint trained in one mode can be resumed in the other (the state_dict loads cleanly; training continues; the loss spike is transient).
4. After training ~1B tokens in stochastic_depth mode from a checkpoint, the depth sweep shows non-trivial generation at `n_loops > 16` (i.e., the depth-binding is reduced).

## Open questions

None currently. The range `[1, 32]` was chosen based on the upstream recipe and compute budget; it can be tuned later via the script-level variables without code changes.
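## Appendix: illustrative depth sweep

For success criterion 4, the quantity of interest is perplexity as a function of `n_loops`. A minimal sketch of such a sweep (illustrative only — `evaluations/eval_checkpoint.py` is the real harness, and its internals may differ):

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def depth_sweep(model, input_ids, targets, vocab_size):
    """Return {n_loops: perplexity}; a monotone decrease means depth-binding is reduced."""
    ppl = {}
    for n in (1, 2, 4, 8, 12, 16, 24, 32):
        logits = model(input_ids, n_loops=n, bypass_act=True)
        loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
        ppl[n] = float(loss.exp())
    return ppl
```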
diff --git a/open_mythos/__init__.py b/open_mythos/__init__.py index 73c2c04..64fcdad 100644 --- a/open_mythos/__init__.py +++ b/open_mythos/__init__.py @@ -49,7 +49,5 @@ "mythos_100b", "mythos_500b", "mythos_1t", - "load_tokenizer", - "get_vocab_size", "MythosTokenizer", ] diff --git a/open_mythos/main.py b/open_mythos/main.py index 65b0fa8..a957dce 100644 --- a/open_mythos/main.py +++ b/open_mythos/main.py @@ -112,8 +112,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: RMS-normalized tensor of the same shape, rescaled by self.weight """ - rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt() - return x * rms * self.weight + dtype = x.dtype + rms = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt() + return (x * rms * self.weight).to(dtype) # --------------------------------------------------------------------------- @@ -229,6 +230,7 @@ def forward( Output tensor of shape (B, T, dim) """ B, T, _ = x.shape + x = x.to(self.wq.weight.dtype) # align with FSDP param dtype q = self.wq(x).view(B, T, self.n_heads, self.head_dim) k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim) v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim) @@ -268,7 +270,9 @@ def forward( if mask is not None: attn = attn + mask attn = F.dropout( - F.softmax(attn, dim=-1), p=self.dropout_p, training=self.training + F.softmax(attn, dim=-1).to(v.dtype), + p=self.dropout_p, + training=self.training, ) out = torch.matmul(attn, v) out = out.transpose(1, 2).contiguous().view(B, T, -1) @@ -367,6 +371,7 @@ def forward( Output tensor of shape (B, T, dim) """ B, T, _ = x.shape + x = x.to(self.q_down.weight.dtype) # align with FSDP param dtype # Q c_q = self.q_norm(self.q_down(x)) @@ -412,7 +417,7 @@ def forward( attn = torch.matmul(q, k.transpose(-2, -1)) * scale if mask is not None: attn = attn + mask - attn = self.attn_drop(F.softmax(attn, dim=-1)) + attn = self.attn_drop(F.softmax(attn, dim=-1).to(v.dtype)) out = torch.matmul(attn, v) # (B, H, T, v_dim) out = out.transpose(1, 2).contiguous().view(B, T, -1) return self.wo(out) @@ -450,6 +455,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: Tensor of shape (..., dim) """ + x = x.to(self.gate.weight.dtype) # align with FSDP param dtype return self.down(F.silu(self.gate(x)) * self.up(x)) @@ -503,6 +509,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: of the weighted routed expert outputs """ B, T, D = x.shape + x = x.to(self.router.weight.dtype) # align with FSDP param dtype flat = x.view(B * T, D) # Aux-loss-free load balancing (DeepSeek-V3): the bias shifts only the @@ -513,18 +520,40 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: scores = F.softmax(logits, dim=-1) _, topk_idx = (logits + self.router_bias).topk(self.topk, dim=-1) topk_scores = scores.gather(-1, topk_idx) - topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True) # renorm - - # routed expert dispatch (token-level scatter) - out = torch.zeros_like(flat) - for i in range(self.topk): - expert_ids = topk_idx[:, i] - token_scores = topk_scores[:, i].unsqueeze(-1) - for eid in range(self.n_experts): - mask = expert_ids == eid - if not mask.any(): - continue - out[mask] += token_scores[mask] * self.routed_experts[eid](flat[mask]) + topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True).clamp( + min=1e-9 + ) + + # Grouped expert dispatch — one expert call per active expert. + # Flatten all topk (token, expert) pairs, sort by expert ID, + # run each expert once on its contiguous batch, scatter back. 
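+        # Worked example: with tokens [t0, t1] and topk=2, topk_idx=[[0, 2], [2, 1]]
+        # flattens to [0, 2, 2, 1]; the stable argsort groups the rows as
+        # expert 0 -> [t0], expert 1 -> [t1], expert 2 -> [t0, t1], so each
+        # active expert runs exactly once on a contiguous token batch.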
+ N = flat.size(0) + flat_expert_ids = topk_idx.view(-1) # (N*topk,) + flat_scores = topk_scores.view(-1, 1) # (N*topk, 1) + flat_tokens = flat.repeat_interleave(self.topk, dim=0) # (N*topk, D) + + sorted_order = flat_expert_ids.argsort(stable=True) + sorted_expert_ids = flat_expert_ids[sorted_order] + sorted_tokens = flat_tokens[sorted_order] + sorted_scores = flat_scores[sorted_order] + + unique_experts, counts = torch.unique_consecutive( + sorted_expert_ids, return_counts=True + ) + split_tokens = sorted_tokens.split(counts.tolist()) + split_scores = sorted_scores.split(counts.tolist()) + + expert_outputs = [] + for eid, tok_batch, sc_batch in zip( + unique_experts.tolist(), split_tokens, split_scores + ): + expert_outputs.append(sc_batch * self.routed_experts[eid](tok_batch)) + + sorted_out = torch.cat(expert_outputs, dim=0) + # Unsort back to original (N*topk,) order, then sum over topk dim + out_flat = torch.zeros_like(sorted_out) + out_flat[sorted_order] = sorted_out + out = out_flat.view(N, self.topk, D).sum(dim=1) # (N, D) # shared experts always fire for every token for shared in self.shared_experts: @@ -561,12 +590,16 @@ def loop_index_embedding( """ freqs = 1.0 / ( theta - ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim) + ** ( + torch.arange(0, loop_dim, 2, device=h.device, dtype=torch.float32) + / loop_dim + ) ) angles = loop_t * freqs # (loop_dim//2,) emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim] - emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype) + emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=torch.float32) emb_full[:loop_dim] = emb + emb_full = emb_full.to(h.dtype) return h + emb_full.unsqueeze(0).unsqueeze(0) @@ -615,8 +648,9 @@ def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor: max_t = self.scale.num_embeddings - 1 t_idx = loop_t if loop_t <= max_t else max_t s = self.scale(torch.tensor(t_idx, device=x.device)) # (rank,) + x = x.to(self.down.weight.dtype) # align with FSDP param dtype down = self.down(x) * s # (B, T, rank) - return down @ self.B # (B, T, dim) + return down @ self.B.to(down.dtype) # (B, T, dim) # --------------------------------------------------------------------------- @@ -777,7 +811,7 @@ def forward(self, h: torch.Tensor) -> torch.Tensor: Returns: Halting probability tensor of shape (B, T), values in (0, 1) """ - return torch.sigmoid(self.halt(h)).squeeze(-1) + return torch.sigmoid(self.halt(h.to(self.halt.weight.dtype))).squeeze(-1) # --------------------------------------------------------------------------- @@ -830,22 +864,28 @@ def forward( mask: Optional[torch.Tensor] = None, n_loops: Optional[int] = None, kv_cache: Optional[dict] = None, + bypass_act: bool = False, ) -> torch.Tensor: """ - Run the recurrent loop for up to n_loops iterations with ACT early exit. + Run the recurrent loop for up to n_loops iterations. Args: - h -- initial hidden state from the Prelude, shape (B, T, dim) - e -- encoded input frozen for injection each step, shape (B, T, dim) - freqs_cis-- precomputed RoPE frequencies - mask -- additive causal mask or None - n_loops -- number of loop iterations; defaults to cfg.max_loop_iters. - Can be increased at inference for deeper reasoning (depth extrapolation). 
- kv_cache -- cache dict passed through to the inner TransformerBlock; - each loop iteration uses a separate cache key + h -- initial hidden state from the Prelude, shape (B, T, dim) + e -- encoded input frozen for injection each step, shape (B, T, dim) + freqs_cis -- precomputed RoPE frequencies + mask -- additive causal mask or None + n_loops -- number of loop iterations; defaults to cfg.max_loop_iters. + Can be increased at inference for deeper reasoning (depth extrapolation). + kv_cache -- cache dict passed through to the inner TransformerBlock; + each loop iteration uses a separate cache key + bypass_act -- if True, skip ACT weighting and return the final h directly + after running all n_loops iterations (used for Option B + stochastic-depth training). Returns: - ACT-weighted sum of hidden states across iterations, shape (B, T, dim) + ACT-weighted sum of hidden states across iterations when bypass_act=False, + or the final hidden state after n_loops iterations when bypass_act=True. + Shape: (B, T, dim) in both cases. """ n_loops = n_loops or self.cfg.max_loop_iters B, T, D = h.shape @@ -862,6 +902,9 @@ def forward( trans_out = trans_out + self.lora(trans_out, t) h = self.injection(h, e, trans_out) + if bypass_act: + continue + p = self.act(h) # (B, T) still_running = ~halted @@ -885,9 +928,35 @@ def forward( # Only short-circuit when there is no KV cache to keep consistent. # With a cache, every loop depth must run on every forward pass so # later decode steps find populated keys at every cache_key. - if halted.all() and kv_cache is None: - break - + if kv_cache is None: + all_halted = halted.all() + # Under FSDP/DDP each rank has different data, so halted.all() + # can differ across ranks. If one rank breaks out of the loop + # while others continue, the FSDP all-gather inside self.block + # deadlocks (the exited rank never issues the collective). + # All-reduce with MIN so ranks only exit together. + # The all-reduce is unconditional — every rank must participate + # regardless of its local halting state. + if torch.distributed.is_initialized(): + flag = torch.tensor( + [all_halted], dtype=torch.int32, device=h.device + ) + torch.distributed.all_reduce( + flag, op=torch.distributed.ReduceOp.MIN + ) + all_halted = flag.item() > 0 + if all_halted: + break + + if bypass_act: + return h + + # Assign remainder weight for positions that never halted within n_loops. + # Without this, non-halted positions have weights summing to < 1.0. + not_halted = ~halted + if not_halted.any(): + final_remainder = (1.0 - cumulative_p).clamp(min=0) * not_halted.float() + h_out = h_out + final_remainder.unsqueeze(-1) * h return h_out @@ -995,20 +1064,24 @@ def forward( n_loops: Optional[int] = None, kv_cache: Optional[dict] = None, start_pos: int = 0, + bypass_act: bool = False, ) -> torch.Tensor: """ Forward pass through Prelude → Recurrent Block → Coda. Args: - input_ids -- token indices of shape (B, T) - n_loops -- recurrent loop depth; defaults to cfg.max_loop_iters. - Increase at inference to extrapolate to harder problems. - kv_cache -- dict mutated in-place for autoregressive KV caching; - pass an empty dict {} and reuse across decode steps - start_pos -- index of the first token in input_ids within the full - sequence; used to select the correct RoPE frequencies - during incremental decoding (0 for prefill, prompt_len - for each subsequent decode step) + input_ids -- token indices of shape (B, T) + n_loops -- recurrent loop depth; defaults to cfg.max_loop_iters. 
+ Increase at inference to extrapolate to harder problems. + kv_cache -- dict mutated in-place for autoregressive KV caching; + pass an empty dict {} and reuse across decode steps + start_pos -- index of the first token in input_ids within the full + sequence; used to select the correct RoPE frequencies + during incremental decoding (0 for prefill, prompt_len + for each subsequent decode step) + bypass_act -- if True, RecurrentBlock skips ACT weighting and returns + the final hidden state directly. Default False preserves + the existing ACT behavior. Returns: Logits of shape (B, T, vocab_size) @@ -1026,12 +1099,13 @@ def forward( x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"prelude_{i}") e = x # encoded input frozen for injection every loop - x = self.recurrent(x, e, freqs_cis, mask, n_loops, kv_cache) + x = self.recurrent(x, e, freqs_cis, mask, n_loops, kv_cache, bypass_act) for i, layer in enumerate(self.coda): x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"coda_{i}") - return self.head(self.norm(x)) + x = self.norm(x) + return self.head(x.to(self.head.weight.dtype)) @torch.no_grad() def generate( diff --git a/tests/test_act_fsdp_fix.py b/tests/test_act_fsdp_fix.py new file mode 100644 index 0000000..bdce868 --- /dev/null +++ b/tests/test_act_fsdp_fix.py @@ -0,0 +1,194 @@ +""" +Tests for the ACT early exit FSDP deadlock fix (issue #4). + +Verifies that: + - Single-process early exit still works (no regression) + - KV cache disables early exit (unchanged behavior) + - The all-reduce branch is skipped when torch.distributed is not initialized + - Loop runs all iterations when not all positions have halted + - The fix doesn't change model outputs +""" + +import torch +import pytest +from unittest.mock import patch + +from open_mythos.main import ( + MythosConfig, + OpenMythos, + RecurrentBlock, +) + + +B, T = 2, 8 + + +def small_cfg(**overrides) -> MythosConfig: + defaults = dict( + vocab_size=200, + dim=64, + n_heads=4, + n_kv_heads=2, + max_seq_len=32, + max_loop_iters=3, + prelude_layers=1, + coda_layers=1, + attn_type="mla", + n_experts=4, + n_shared_experts=1, + n_experts_per_tok=2, + expert_dim=16, + act_threshold=0.99, + lora_rank=4, + kv_lora_rank=16, + q_lora_rank=32, + qk_rope_head_dim=8, + qk_nope_head_dim=8, + v_head_dim=8, + ) + defaults.update(overrides) + return MythosConfig(**defaults) + + +class TestACTEarlyExitSingleProcess: + """Verify early exit still works in single-process (no dist initialized).""" + + def test_early_exit_when_all_halted(self): + """With very low threshold + high halt prob, loop should exit early.""" + cfg = small_cfg(act_threshold=0.01, max_loop_iters=16) + model = OpenMythos(cfg) + # Bias ACT to halt immediately + with torch.no_grad(): + model.recurrent.act.halt.weight.fill_(0.0) + model.recurrent.act.halt.bias.fill_(10.0) # sigmoid(10) ≈ 1.0 + + ids = torch.randint(0, cfg.vocab_size, (1, 4)) + # Should complete without hanging — early exit fires + logits = model(ids) + assert logits.shape == (1, 4, cfg.vocab_size) + assert not torch.isnan(logits).any() + + def test_dist_not_initialized_in_tests(self): + """Confirm torch.distributed is not initialized in test environment.""" + assert not torch.distributed.is_initialized() + + def test_early_exit_skips_iterations(self): + """When halting is immediate, fewer loop iterations should run. + + We verify this indirectly: with max_loop_iters=16 and immediate halting, + the forward pass should be fast (not 16x slower than needed). 
+ """ + cfg = small_cfg(act_threshold=0.01, max_loop_iters=16) + model = OpenMythos(cfg) + with torch.no_grad(): + model.recurrent.act.halt.weight.fill_(0.0) + model.recurrent.act.halt.bias.fill_(10.0) + + ids = torch.randint(0, cfg.vocab_size, (B, T)) + # Just verify it completes and produces valid output + logits = model(ids) + assert not torch.isnan(logits).any() + + +class TestACTNoEarlyExitWithKVCache: + """KV cache should disable early exit regardless of halting state.""" + + def test_kv_cache_prevents_early_exit(self): + """With KV cache, all loop iterations must run for cache consistency.""" + cfg = small_cfg(act_threshold=0.01, max_loop_iters=3) + model = OpenMythos(cfg) + # Bias ACT to halt immediately + with torch.no_grad(): + model.recurrent.act.halt.weight.fill_(0.0) + model.recurrent.act.halt.bias.fill_(10.0) + + ids = torch.randint(0, cfg.vocab_size, (1, 4)) + kv_cache = {} + logits = model(ids, kv_cache=kv_cache) + assert logits.shape == (1, 4, cfg.vocab_size) + + # Verify all 3 recurrent loop cache keys were populated + for t in range(cfg.max_loop_iters): + key = f"recurrent_loop_{t}" + assert key in kv_cache, ( + f"Cache key '{key}' missing — loop didn't run iteration {t}" + ) + + +class TestACTAllReduceBranch: + """Verify the all-reduce code path logic.""" + + def test_all_reduce_not_called_without_dist(self): + """When dist is not initialized, torch.distributed.all_reduce should + not be called.""" + cfg = small_cfg(act_threshold=0.01) + model = OpenMythos(cfg) + with torch.no_grad(): + model.recurrent.act.halt.weight.fill_(0.0) + model.recurrent.act.halt.bias.fill_(10.0) + + ids = torch.randint(0, cfg.vocab_size, (B, T)) + with patch("torch.distributed.all_reduce") as mock_ar: + model(ids) + mock_ar.assert_not_called() + + def test_all_reduce_would_be_called_if_dist_initialized(self): + """Verify the is_initialized() check gates the all_reduce call.""" + cfg = small_cfg(act_threshold=0.01) + model = OpenMythos(cfg) + with torch.no_grad(): + model.recurrent.act.halt.weight.fill_(0.0) + model.recurrent.act.halt.bias.fill_(10.0) + + ids = torch.randint(0, cfg.vocab_size, (B, T)) + with patch("torch.distributed.is_initialized", return_value=True), \ + patch("torch.distributed.all_reduce") as mock_ar: + model(ids) + # all_reduce should have been called at least once + assert mock_ar.call_count > 0 + + +class TestACTFixOutputEquivalence: + """The fix must not change model outputs in single-process mode.""" + + def test_output_deterministic(self): + """Same input produces same output — fix doesn't introduce randomness.""" + cfg = small_cfg() + model = OpenMythos(cfg) + model.eval() + + torch.manual_seed(42) + ids = torch.randint(0, cfg.vocab_size, (B, T)) + out1 = model(ids) + out2 = model(ids) + assert torch.allclose(out1, out2, atol=1e-6) + + def test_forward_backward_works(self): + """Full forward+backward completes without error after the fix.""" + cfg = small_cfg() + model = OpenMythos(cfg) + ids = torch.randint(0, cfg.vocab_size, (B, T)) + logits = model(ids) + loss = logits.sum() + loss.backward() + assert model.embed.weight.grad is not None + + def test_generate_works(self): + """Autoregressive generation still works after the fix.""" + cfg = small_cfg() + model = OpenMythos(cfg) + ids = torch.randint(0, cfg.vocab_size, (1, 4)) + out = model.generate(ids, max_new_tokens=4, n_loops=2) + assert out.shape == (1, 8) + + def test_many_loops_no_nan(self): + """Depth extrapolation still works.""" + cfg = small_cfg() + model = OpenMythos(cfg) + ids = torch.randint(0, 
cfg.vocab_size, (1, 4)) + logits = model(ids, n_loops=10) + assert not torch.isnan(logits).any() + + +if __name__ == "__main__": + pytest.main([__file__, "--verbose"]) diff --git a/tests/test_code_review_fixes.py b/tests/test_code_review_fixes.py new file mode 100644 index 0000000..43d0bea --- /dev/null +++ b/tests/test_code_review_fixes.py @@ -0,0 +1,716 @@ +""" +Tests for code review fixes and MoE dispatch optimization (2026-04-23). + +Covers: + - MoE grouped dispatch: correctness, edge cases, gradient flow + - ACT remainder for non-halted positions + - MoE score renormalization epsilon (div-by-zero guard) + - LoRAAdapter.B dtype safety + - loop_index_embedding float32 precision + - __init__.py public API exports +""" + +import importlib +import math + +import torch +import torch.nn as nn +import pytest +from unittest.mock import patch + +from open_mythos.main import ( + ACTHalting, + Expert, + LoRAAdapter, + MoEFFN, + MythosConfig, + OpenMythos, + RecurrentBlock, + TransformerBlock, + loop_index_embedding, + precompute_rope_freqs, +) + +# --------------------------------------------------------------------------- +# Shared test config — tiny dims for CPU speed +# --------------------------------------------------------------------------- + +B, T = 2, 8 + + +def small_cfg(**overrides) -> MythosConfig: + defaults = dict( + vocab_size=200, + dim=64, + n_heads=4, + n_kv_heads=2, + max_seq_len=32, + max_loop_iters=3, + prelude_layers=1, + coda_layers=1, + attn_type="gqa", + n_experts=4, + n_shared_experts=1, + n_experts_per_tok=2, + expert_dim=16, + act_threshold=0.99, + lora_rank=4, + kv_lora_rank=16, + q_lora_rank=32, + qk_rope_head_dim=8, + qk_nope_head_dim=8, + v_head_dim=8, + ) + defaults.update(overrides) + return MythosConfig(**defaults) + + +# =================================================================== +# MoE Grouped Dispatch +# =================================================================== + + +class TestMoEGroupedDispatch: + """Tests for the grouped/batched MoE dispatch replacing the nested loop.""" + + def setup_method(self): + self.cfg = small_cfg() + self.moe = MoEFFN(self.cfg) + + def test_output_shape_standard(self): + x = torch.randn(B, T, self.cfg.dim) + assert self.moe(x).shape == (B, T, self.cfg.dim) + + def test_single_token(self): + """Edge case: batch with only one token (B=1, T=1).""" + x = torch.randn(1, 1, self.cfg.dim) + out = self.moe(x) + assert out.shape == (1, 1, self.cfg.dim) + assert not torch.isnan(out).any() + + def test_large_batch(self): + """Stress test with larger batch to exercise grouping with many tokens.""" + x = torch.randn(8, 32, self.cfg.dim) + out = self.moe(x) + assert out.shape == (8, 32, self.cfg.dim) + assert not torch.isnan(out).any() + + def test_topk_1(self): + """Edge case: only one expert per token (topk=1).""" + cfg = small_cfg(n_experts_per_tok=1) + moe = MoEFFN(cfg) + x = torch.randn(B, T, cfg.dim) + out = moe(x) + assert out.shape == (B, T, cfg.dim) + assert not torch.isnan(out).any() + + def test_topk_equals_n_experts(self): + """Edge case: every expert is selected for every token.""" + cfg = small_cfg(n_experts=4, n_experts_per_tok=4) + moe = MoEFFN(cfg) + x = torch.randn(B, T, cfg.dim) + out = moe(x) + assert out.shape == (B, T, cfg.dim) + assert not torch.isnan(out).any() + + def test_all_tokens_same_expert(self): + """Force all tokens to route to the same expert via router_bias.""" + cfg = small_cfg(n_experts=4, n_experts_per_tok=2) + moe = MoEFFN(cfg) + # Overwhelm the router logits: bias expert 0 and 1 massively + 
moe.router_bias.data = torch.tensor( + [1000.0, 999.0, -1000.0, -1000.0] + ) + x = torch.randn(B, T, cfg.dim) + out = moe(x) + assert out.shape == (B, T, cfg.dim) + assert not torch.isnan(out).any() + + def test_no_nan_or_inf(self): + x = torch.randn(B, T, self.cfg.dim) + out = self.moe(x) + assert not torch.isnan(out).any() + assert not torch.isinf(out).any() + + def test_gradient_flows_through_routed_experts(self): + """Verify gradients reach routed expert parameters.""" + x = torch.randn(B, T, self.cfg.dim, requires_grad=True) + out = self.moe(x) + loss = out.sum() + loss.backward() + # At least some routed experts should have gradients + has_grad = any( + p.grad is not None and p.grad.abs().sum() > 0 + for exp in self.moe.routed_experts + for p in exp.parameters() + ) + assert has_grad, "No gradient flowed to any routed expert" + + def test_gradient_flows_through_shared_experts(self): + """Verify gradients reach shared expert parameters.""" + x = torch.randn(B, T, self.cfg.dim, requires_grad=True) + out = self.moe(x) + loss = out.sum() + loss.backward() + for shared in self.moe.shared_experts: + has_grad = any( + p.grad is not None and p.grad.abs().sum() > 0 + for p in shared.parameters() + ) + assert has_grad, "No gradient flowed to shared expert" + + def test_gradient_flows_to_input(self): + """Verify gradients propagate back to the input tensor.""" + x = torch.randn(B, T, self.cfg.dim, requires_grad=True) + out = self.moe(x) + loss = out.sum() + loss.backward() + assert x.grad is not None + assert x.grad.abs().sum() > 0 + + def test_router_gradient_exists(self): + """Verify the router weight receives gradients.""" + x = torch.randn(B, T, self.cfg.dim) + out = self.moe(x) + out.sum().backward() + assert self.moe.router.weight.grad is not None + assert self.moe.router.weight.grad.abs().sum() > 0 + + def test_deterministic_output(self): + """Same input should produce same output (no randomness in dispatch).""" + torch.manual_seed(42) + x = torch.randn(B, T, self.cfg.dim) + out1 = self.moe(x.clone()) + out2 = self.moe(x.clone()) + assert torch.allclose(out1, out2, atol=1e-6) + + def test_output_changes_with_different_input(self): + """Different inputs should produce different outputs.""" + x1 = torch.randn(B, T, self.cfg.dim) + x2 = torch.randn(B, T, self.cfg.dim) + out1 = self.moe(x1) + out2 = self.moe(x2) + assert not torch.allclose(out1, out2) + + def test_router_bias_shifts_expert_selection(self): + """Changing router_bias should change which experts are selected.""" + cfg = small_cfg(n_experts=4, n_experts_per_tok=1) + moe = MoEFFN(cfg) + x = torch.randn(1, 1, cfg.dim) + + moe.router_bias.data = torch.tensor([100.0, 0.0, 0.0, 0.0]) + out_biased_0 = moe(x.clone()).detach() + + moe.router_bias.data = torch.tensor([0.0, 0.0, 0.0, 100.0]) + out_biased_3 = moe(x.clone()).detach() + + # Different experts → different outputs (shared expert is the same, + # but routed contribution differs) + assert not torch.allclose(out_biased_0, out_biased_3) + + def test_only_shared_experts_when_routed_zeroed(self): + """Zeroing routed experts: output should match shared-only.""" + cfg = small_cfg() + moe = MoEFFN(cfg) + for exp in moe.routed_experts: + for p in exp.parameters(): + p.data.zero_() + x = torch.randn(B, T, cfg.dim) + out = moe(x) + # Recompute shared-only + flat = x.view(B * T, cfg.dim) + shared_out = sum(s(flat) for s in moe.shared_experts) + expected = shared_out.view(B, T, cfg.dim) + assert torch.allclose(out, expected, atol=1e-5) + + def test_scores_sum_to_one_per_token(self): + 
"""After renormalization, topk scores per token should sum to ~1.""" + cfg = small_cfg() + moe = MoEFFN(cfg) + x = torch.randn(B, T, cfg.dim) + flat = x.view(B * T, cfg.dim) + logits = moe.router(flat) + scores = torch.nn.functional.softmax(logits, dim=-1) + _, topk_idx = (logits + moe.router_bias).topk(moe.topk, dim=-1) + topk_scores = scores.gather(-1, topk_idx) + topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True).clamp( + min=1e-9 + ) + sums = topk_scores.sum(dim=-1) + assert torch.allclose(sums, torch.ones_like(sums), atol=1e-6) + + +# =================================================================== +# MoE Score Renormalization Epsilon (div-by-zero guard) +# =================================================================== + + +class TestMoEScoreEpsilon: + """Tests for the .clamp(min=1e-9) guard on score renormalization.""" + + def test_zero_scores_no_nan(self): + """If all topk softmax scores underflow to zero, output should not be NaN.""" + cfg = small_cfg() + moe = MoEFFN(cfg) + # Force router to produce extreme negative logits → softmax → ~0 + with torch.no_grad(): + moe.router.weight.fill_(-100.0) + x = torch.randn(B, T, cfg.dim) + out = moe(x) + assert not torch.isnan(out).any(), "NaN in output despite epsilon guard" + assert not torch.isinf(out).any(), "Inf in output despite epsilon guard" + + def test_near_zero_scores_bfloat16(self): + """Simulate bfloat16 underflow scenario with very small scores.""" + cfg = small_cfg() + moe = MoEFFN(cfg) + x = torch.randn(B, T, cfg.dim) + # Run in bfloat16 if available (the actual risk scenario) + if torch.cuda.is_available(): + moe = moe.to(torch.bfloat16).cuda() + x = x.to(torch.bfloat16).cuda() + out = moe(x) + assert not torch.isnan(out).any() + + def test_uniform_scores_stay_uniform(self): + """When all topk scores are equal, renorm should keep them equal.""" + cfg = small_cfg(n_experts=4, n_experts_per_tok=2) + moe = MoEFFN(cfg) + # Manually compute: equal softmax scores → equal after renorm + flat = torch.randn(4, cfg.dim) + logits = moe.router(flat) + # Make all logits equal so softmax is uniform + logits = torch.zeros_like(logits) + scores = torch.nn.functional.softmax(logits, dim=-1) + _, topk_idx = logits.topk(cfg.n_experts_per_tok, dim=-1) + topk_scores = scores.gather(-1, topk_idx) + topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True).clamp( + min=1e-9 + ) + # Each of 2 selected experts should get score 0.5 + assert torch.allclose( + topk_scores, torch.full_like(topk_scores, 0.5), atol=1e-5 + ) + + +# =================================================================== +# ACT Remainder for Non-Halted Positions +# =================================================================== + + +class TestACTRemainder: + """Tests for the post-loop remainder weight assignment. + + Uses the full OpenMythos model (MLA mode) to avoid the pre-existing + GQA RoPE dimension mismatch in RecurrentBlock-level tests. 
+ """ + + def _make_model(self, **cfg_overrides): + cfg = small_cfg(attn_type="mla", **cfg_overrides) + model = OpenMythos(cfg) + return model, cfg + + def test_output_not_all_zero_with_low_halting(self): + """With very high threshold, positions won't halt but should still + produce nonzero output via the remainder.""" + model, cfg = self._make_model(act_threshold=0.9999) + # Bias ACT to predict very low halting probability + with torch.no_grad(): + model.recurrent.act.halt.weight.fill_(0.0) + model.recurrent.act.halt.bias.fill_(-10.0) # sigmoid(-10) ≈ 0.00005 + ids = torch.randint(0, cfg.vocab_size, (B, T)) + logits = model(ids) + assert logits.abs().sum() > 0, "Output is all zeros — remainder not applied" + assert not torch.isnan(logits).any() + + def test_remainder_does_not_double_count_halted(self): + """Positions that halted normally should NOT get additional remainder.""" + model, cfg = self._make_model(act_threshold=0.01) + # Bias ACT to halt immediately (high halting prob) + with torch.no_grad(): + model.recurrent.act.halt.weight.fill_(0.0) + model.recurrent.act.halt.bias.fill_(10.0) # sigmoid(10) ≈ 0.99995 + ids = torch.randint(0, cfg.vocab_size, (B, T)) + # Run twice — if remainder double-counts, outputs would differ + logits1 = model(ids) + logits2 = model(ids) + assert torch.allclose(logits1, logits2, atol=1e-5) + + def test_single_loop_remainder(self): + """With n_loops=1 and no halting, remainder should provide weight.""" + model, cfg = self._make_model(act_threshold=0.9999) + with torch.no_grad(): + model.recurrent.act.halt.weight.fill_(0.0) + model.recurrent.act.halt.bias.fill_(-10.0) + ids = torch.randint(0, cfg.vocab_size, (1, 1)) + logits = model(ids, n_loops=1) + assert logits.shape == (1, 1, cfg.vocab_size) + assert not torch.isnan(logits).any() + assert logits.abs().sum() > 0 + + def test_no_nan_with_many_loops(self): + """Run many loops with low halting — should never produce NaN.""" + model, cfg = self._make_model(act_threshold=0.9999, max_loop_iters=16) + with torch.no_grad(): + model.recurrent.act.halt.weight.fill_(0.0) + model.recurrent.act.halt.bias.fill_(-5.0) # sigmoid(-5) ≈ 0.007 + ids = torch.randint(0, cfg.vocab_size, (B, T)) + logits = model(ids, n_loops=16) + assert not torch.isnan(logits).any() + assert not torch.isinf(logits).any() + + def test_low_threshold_all_halt_early(self): + """Very low threshold + high halting prob → all positions halt in loop 1.""" + model, cfg = self._make_model(act_threshold=0.01) + with torch.no_grad(): + model.recurrent.act.halt.weight.fill_(0.0) + model.recurrent.act.halt.bias.fill_(10.0) + ids = torch.randint(0, cfg.vocab_size, (B, T)) + logits = model(ids, n_loops=5) + assert not torch.isnan(logits).any() + # Should produce valid logits even if everything halts immediately + assert logits.shape == (B, T, cfg.vocab_size) + + +# =================================================================== +# ACT Halting Weight Invariants +# =================================================================== + + +class TestACTWeightInvariants: + """Verify that ACT weights (halted + remainder) sum correctly.""" + + def test_weights_sum_to_one_all_halt(self): + """When all positions halt, accumulated weights should sum to ~1.""" + cfg = small_cfg(act_threshold=0.5) + act = ACTHalting(cfg.dim) + # Force high halting prob so everything halts in 1 iteration + with torch.no_grad(): + act.halt.weight.fill_(0.0) + act.halt.bias.fill_(10.0) + + B_, T_ = 2, 4 + halted = torch.zeros(B_, T_, dtype=torch.bool) + cumulative_p = torch.zeros(B_, T_) 
+ total_weight = torch.zeros(B_, T_) + h = torch.randn(B_, T_, cfg.dim) + + for t in range(5): + p = act(h) + still_running = ~halted + remainder = (1.0 - cumulative_p).clamp(min=0) + weight = torch.where( + cumulative_p + p >= cfg.act_threshold, + remainder, + p, + ) + weight = weight * still_running.float() + total_weight = total_weight + weight + cumulative_p = cumulative_p + p * still_running.float() + halted = halted | (cumulative_p >= cfg.act_threshold) + + # Post-loop remainder for non-halted + not_halted = ~halted + if not_halted.any(): + final_remainder = (1.0 - cumulative_p).clamp(min=0) * not_halted.float() + total_weight = total_weight + final_remainder + + assert torch.allclose( + total_weight, torch.ones_like(total_weight), atol=1e-4 + ), f"Weights don't sum to 1: {total_weight}" + + def test_weights_sum_to_one_none_halt(self): + """When no positions halt within the loop, remainder ensures sum ~1.""" + cfg = small_cfg(act_threshold=0.9999) + act = ACTHalting(cfg.dim) + # Force very low halting prob + with torch.no_grad(): + act.halt.weight.fill_(0.0) + act.halt.bias.fill_(-10.0) # sigmoid(-10) ≈ 0.00005 + + B_, T_ = 2, 4 + halted = torch.zeros(B_, T_, dtype=torch.bool) + cumulative_p = torch.zeros(B_, T_) + total_weight = torch.zeros(B_, T_) + h = torch.randn(B_, T_, cfg.dim) + + for t in range(3): + p = act(h) + still_running = ~halted + remainder = (1.0 - cumulative_p).clamp(min=0) + weight = torch.where( + cumulative_p + p >= cfg.act_threshold, + remainder, + p, + ) + weight = weight * still_running.float() + total_weight = total_weight + weight + cumulative_p = cumulative_p + p * still_running.float() + halted = halted | (cumulative_p >= cfg.act_threshold) + + # Post-loop remainder + not_halted = ~halted + if not_halted.any(): + final_remainder = (1.0 - cumulative_p).clamp(min=0) * not_halted.float() + total_weight = total_weight + final_remainder + + assert torch.allclose( + total_weight, torch.ones_like(total_weight), atol=1e-4 + ), f"Weights don't sum to 1: {total_weight}" + + def test_weights_sum_to_one_mixed_halting(self): + """Mix of halted and non-halted positions: all weights sum to ~1.""" + cfg = small_cfg(act_threshold=0.5) + act = ACTHalting(cfg.dim) + + B_, T_ = 1, 8 + halted = torch.zeros(B_, T_, dtype=torch.bool) + cumulative_p = torch.zeros(B_, T_) + total_weight = torch.zeros(B_, T_) + h = torch.randn(B_, T_, cfg.dim) + + for t in range(3): + p = act(h) + still_running = ~halted + remainder = (1.0 - cumulative_p).clamp(min=0) + weight = torch.where( + cumulative_p + p >= cfg.act_threshold, + remainder, + p, + ) + weight = weight * still_running.float() + total_weight = total_weight + weight + cumulative_p = cumulative_p + p * still_running.float() + halted = halted | (cumulative_p >= cfg.act_threshold) + + # Post-loop remainder + not_halted = ~halted + if not_halted.any(): + final_remainder = (1.0 - cumulative_p).clamp(min=0) * not_halted.float() + total_weight = total_weight + final_remainder + + assert torch.allclose( + total_weight, torch.ones_like(total_weight), atol=1e-4 + ), f"Weights don't sum to 1: {total_weight}" + + +# =================================================================== +# LoRAAdapter.B dtype Safety +# =================================================================== + + +class TestLoRADtypeSafety: + """Tests for the defensive .to(down.dtype) cast on self.B.""" + + def test_float32_pass_through(self): + """Standard float32 — should work as before.""" + lora = LoRAAdapter(dim=64, rank=8, max_loops=10) + x = torch.randn(B, T, 64) 
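+        # loop_t selects the per-loop adapter slice (clamped to the last
+        # index when it exceeds max_loops; see test_loop_index_clamp below).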
+ out = lora(x, loop_t=0) + assert out.dtype == torch.float32 + assert out.shape == (B, T, 64) + + def test_float16_input(self): + """float16 input with float32 parameters — cast should prevent mismatch.""" + lora = LoRAAdapter(dim=64, rank=8, max_loops=10) + x = torch.randn(B, T, 64, dtype=torch.float16) + out = lora(x, loop_t=0) + assert out.shape == (B, T, 64) + assert not torch.isnan(out).any() + + def test_bfloat16_input(self): + """bfloat16 input — the actual FSDP mixed precision scenario.""" + lora = LoRAAdapter(dim=64, rank=8, max_loops=10) + x = torch.randn(B, T, 64, dtype=torch.bfloat16) + out = lora(x, loop_t=0) + assert out.shape == (B, T, 64) + assert not torch.isnan(out).any() + + def test_B_param_dtype_mismatch_handled(self): + """Manually set B to a different dtype — the cast should still work.""" + lora = LoRAAdapter(dim=64, rank=8, max_loops=10) + # Simulate FSDP casting down to bfloat16 but B staying float32 + lora.down = lora.down.to(torch.bfloat16) + # B is still float32 + assert lora.B.dtype == torch.float32 + x = torch.randn(B, T, 64, dtype=torch.bfloat16) + # Should not raise RuntimeError about dtype mismatch + out = lora(x, loop_t=0) + assert out.shape == (B, T, 64) + + def test_gradient_flows_through_B(self): + """Verify the dtype cast doesn't block gradient flow to self.B.""" + lora = LoRAAdapter(dim=64, rank=8, max_loops=10) + x = torch.randn(B, T, 64) + out = lora(x, loop_t=0) + out.sum().backward() + assert lora.B.grad is not None + assert lora.B.grad.abs().sum() > 0 + + def test_loop_index_clamp(self): + """Exceeding max_loops should clamp to last index, not crash.""" + lora = LoRAAdapter(dim=64, rank=8, max_loops=5) + x = torch.randn(B, T, 64) + # loop_t=10 > max_loops=5 → should clamp to index 4 + out = lora(x, loop_t=10) + assert out.shape == (B, T, 64) + assert not torch.isnan(out).any() + + +# =================================================================== +# loop_index_embedding Float32 Precision +# =================================================================== + + +class TestLoopIndexEmbeddingPrecision: + """Tests for computing trig in float32 then casting back.""" + + def test_bfloat16_input_no_error(self): + """bfloat16 hidden state should work without dtype errors.""" + h = torch.randn(B, T, 64, dtype=torch.bfloat16) + out = loop_index_embedding(h, loop_t=5, loop_dim=8) + assert out.dtype == torch.bfloat16 + assert out.shape == h.shape + + def test_float16_input_preserves_dtype(self): + """float16 input should return float16 output.""" + h = torch.randn(B, T, 64, dtype=torch.float16) + out = loop_index_embedding(h, loop_t=3, loop_dim=8) + assert out.dtype == torch.float16 + + def test_float32_input_preserves_dtype(self): + """float32 input should return float32 output.""" + h = torch.randn(B, T, 64, dtype=torch.float32) + out = loop_index_embedding(h, loop_t=3, loop_dim=8) + assert out.dtype == torch.float32 + + def test_precision_matches_float32_reference(self): + """bfloat16 computation should match float32 reference (via the fix).""" + h_f32 = torch.randn(1, 1, 64, dtype=torch.float32) + h_bf16 = h_f32.to(torch.bfloat16) + + out_f32 = loop_index_embedding(h_f32, loop_t=7, loop_dim=16) + out_bf16 = loop_index_embedding(h_bf16, loop_t=7, loop_dim=16) + + # The embedding itself should be computed with float32 precision, + # so the difference should be only from bf16 quantization of h, not + # from bf16 trig functions. 
+ diff = (out_f32 - out_bf16.float()).abs().max().item() + # bf16 has ~0.4% relative error; float32 trig vs bf16 trig would give + # much larger errors on high-frequency components + assert diff < 0.05, f"Precision gap too large: {diff}" + + def test_large_loop_index_no_nan(self): + """High loop indices should not produce NaN from overflow.""" + h = torch.randn(1, 1, 64, dtype=torch.bfloat16) + out = loop_index_embedding(h, loop_t=1000, loop_dim=8) + assert not torch.isnan(out).any() + + def test_loop_zero_is_nonzero_embedding(self): + """loop_t=0 should still add sin(0)/cos(0) = [0, ..., 1, ...] pattern.""" + h = torch.zeros(1, 1, 64, dtype=torch.float32) + out = loop_index_embedding(h, loop_t=0, loop_dim=8) + # sin(0)=0, cos(0)=1, so first 4 dims are 0, next 4 are 1 + # (because emb = cat([sin, cos])[:loop_dim]) + embedding = out[0, 0, :8] + assert embedding[:4].abs().sum() < 1e-5 # sin(0) = 0 + assert torch.allclose( + embedding[4:], torch.ones(4), atol=1e-5 + ) # cos(0) = 1 + + +# =================================================================== +# __init__.py Public API +# =================================================================== + + +class TestPublicAPI: + """Tests for the __init__.py exports.""" + + def test_no_broken_exports(self): + """Every symbol in __all__ should be importable.""" + import open_mythos + + for name in open_mythos.__all__: + assert hasattr(open_mythos, name), ( + f"'{name}' is in __all__ but not importable" + ) + + def test_removed_symbols_not_in_all(self): + """load_tokenizer and get_vocab_size should not be in __all__.""" + import open_mythos + + assert "load_tokenizer" not in open_mythos.__all__ + assert "get_vocab_size" not in open_mythos.__all__ + + def test_key_classes_exported(self): + """Core classes should remain in __all__.""" + import open_mythos + + required = [ + "MythosConfig", + "OpenMythos", + "MoEFFN", + "RecurrentBlock", + "MythosTokenizer", + ] + for name in required: + assert name in open_mythos.__all__, f"'{name}' missing from __all__" + + def test_import_from_package(self): + """Smoke test: importing key symbols from the package level.""" + from open_mythos import MythosConfig, OpenMythos, MoEFFN + + assert MythosConfig is not None + assert OpenMythos is not None + assert MoEFFN is not None + + +# =================================================================== +# Full Model Integration (exercises all fixes together) +# =================================================================== + + +class TestFullModelIntegration: + """End-to-end tests verifying all fixes work together in the full model.""" + + def setup_method(self): + self.cfg = small_cfg() + self.model = OpenMythos(self.cfg) + + def test_forward_no_nan(self): + ids = torch.randint(0, self.cfg.vocab_size, (B, T)) + logits = self.model(ids) + assert not torch.isnan(logits).any() + assert not torch.isinf(logits).any() + + def test_backward_no_error(self): + """Full forward+backward should work with all fixes in place.""" + ids = torch.randint(0, self.cfg.vocab_size, (B, T)) + logits = self.model(ids) + loss = logits.sum() + loss.backward() + # Check key parameters got gradients + assert self.model.embed.weight.grad is not None + + def test_generate_no_nan(self): + ids = torch.randint(0, self.cfg.vocab_size, (1, T)) + out = self.model.generate(ids, max_new_tokens=4, n_loops=2) + assert out.shape == (1, T + 4) + + def test_many_loops_no_nan(self): + """Depth extrapolation with many loops — exercises ACT remainder.""" + ids = torch.randint(0, self.cfg.vocab_size, (1, 4)) 
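+        # n_loops=10 goes well beyond the small config's loop budget, leaning
+        # on the ACT remainder (and the LoRA index clamp) to stay finite.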
+ logits = self.model(ids, n_loops=10) + assert not torch.isnan(logits).any() + + def test_single_token_input(self): + """Edge case: single token sequence.""" + ids = torch.randint(0, self.cfg.vocab_size, (1, 1)) + logits = self.model(ids) + assert logits.shape == (1, 1, self.cfg.vocab_size) + assert not torch.isnan(logits).any() + + +if __name__ == "__main__": + pytest.main([__file__, "--verbose"]) diff --git a/tests/test_components.py b/tests/test_components.py new file mode 100644 index 0000000..6b87ba2 --- /dev/null +++ b/tests/test_components.py @@ -0,0 +1,846 @@ +""" +Comprehensive component-level tests for every module in open_mythos/main.py. + +Covers: RMSNorm, precompute_rope_freqs, apply_rope, GQAttention, MLAttention, +Expert, TransformerBlock, LTIInjection, RecurrentBlock, and OpenMythos. + +All tests run on CPU with small configs (dim=64, vocab_size=200, etc.). +""" + +import pytest +import torch +import torch.nn as nn + +from open_mythos.main import ( + ACTHalting, + Expert, + GQAttention, + LoRAAdapter, + LTIInjection, + MLAttention, + MoEFFN, + MythosConfig, + OpenMythos, + RecurrentBlock, + RMSNorm, + TransformerBlock, + apply_rope, + loop_index_embedding, + precompute_rope_freqs, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +B, T = 2, 8 + + +def small_cfg(**overrides) -> MythosConfig: + defaults = dict( + vocab_size=200, + dim=64, + n_heads=4, + n_kv_heads=2, + max_seq_len=32, + max_loop_iters=3, + prelude_layers=1, + coda_layers=1, + attn_type="gqa", + n_experts=4, + n_shared_experts=1, + n_experts_per_tok=2, + expert_dim=16, + act_threshold=0.99, + lora_rank=4, + kv_lora_rank=16, + q_lora_rank=32, + qk_rope_head_dim=8, + qk_nope_head_dim=8, + v_head_dim=8, + ) + defaults.update(overrides) + return MythosConfig(**defaults) + + +# ===================================================================== +# TestRMSNorm +# ===================================================================== + + +class TestRMSNorm: + """Tests for the RMSNorm layer.""" + + def test_output_shape(self): + """Output matches input shape.""" + norm = RMSNorm(64) + x = torch.randn(B, T, 64) + out = norm(x) + assert out.shape == x.shape + + def test_normalization_magnitude(self): + """Output RMS is approximately 1 (within tolerance).""" + norm = RMSNorm(64) + x = torch.randn(B, T, 64) * 10.0 # large scale input + out = norm(x) + # With weight=1, the RMS of each output vector should be ~1 + rms = out.float().pow(2).mean(-1).sqrt() + assert torch.allclose(rms, torch.ones_like(rms), atol=0.1) + + def test_zero_input(self): + """Zero input produces zero output.""" + norm = RMSNorm(64) + x = torch.zeros(B, T, 64) + out = norm(x) + assert torch.allclose(out, torch.zeros_like(out)) + + def test_gradient_flows(self): + """Gradients reach the weight parameter.""" + norm = RMSNorm(64) + x = torch.randn(B, T, 64, requires_grad=True) + out = norm(x) + loss = out.sum() + loss.backward() + assert norm.weight.grad is not None + assert norm.weight.grad.abs().sum() > 0 + assert x.grad is not None + + def test_learned_weight_effect(self): + """Changing weight parameter changes output.""" + norm = RMSNorm(64) + x = torch.randn(B, T, 64) + out1 = norm(x).clone() + # Scale the weight by 2 + with torch.no_grad(): + norm.weight.mul_(2.0) + out2 = norm(x) + assert not torch.allclose(out1, out2) + # Outputs should be in a 2:1 ratio + ratio = out2 / (out1 + 1e-12) + assert 
torch.allclose(ratio[out1.abs() > 1e-6], torch.tensor(2.0), atol=0.01) + + def test_eps_prevents_nan(self): + """Very small input doesn't produce NaN.""" + norm = RMSNorm(64, eps=1e-6) + x = torch.full((B, T, 64), 1e-20) + out = norm(x) + assert not torch.isnan(out).any() + assert not torch.isinf(out).any() + + def test_preserves_dtype(self): + """float16 and bfloat16 inputs return same dtype.""" + norm = RMSNorm(64) + for dtype in [torch.float16, torch.bfloat16]: + x = torch.randn(B, T, 64, dtype=dtype) + out = norm(x) + assert out.dtype == dtype, f"Expected {dtype}, got {out.dtype}" + + +# ===================================================================== +# TestRoPE +# ===================================================================== + + +class TestRoPE: + """Tests for precompute_rope_freqs and apply_rope.""" + + def test_freqs_shape(self): + """precompute_rope_freqs returns (max_len, dim//2) complex tensor.""" + dim, max_len = 16, 32 + freqs = precompute_rope_freqs(dim, max_len) + assert freqs.shape == (max_len, dim // 2) + assert freqs.is_complex() + + def test_freqs_unit_magnitude(self): + """All phasors have magnitude 1.""" + freqs = precompute_rope_freqs(16, 32) + magnitudes = freqs.abs() + assert torch.allclose(magnitudes, torch.ones_like(magnitudes), atol=1e-6) + + def test_freqs_position_zero_identity(self): + """freqs[0] are all 1+0j (zero rotation).""" + freqs = precompute_rope_freqs(16, 32) + expected = torch.ones(8, dtype=torch.complex64) + assert torch.allclose(freqs[0], expected, atol=1e-6) + + def test_apply_rope_shape_preserved(self): + """Output shape matches input.""" + dim = 16 + freqs = precompute_rope_freqs(dim, T) + x = torch.randn(B, T, 4, dim) + out = apply_rope(x, freqs) + assert out.shape == x.shape + + def test_apply_rope_norm_preserved(self): + """RoPE is an isometry (norm doesn't change).""" + dim = 16 + freqs = precompute_rope_freqs(dim, T) + x = torch.randn(B, T, 4, dim) + out = apply_rope(x, freqs) + norms_in = x.float().norm(dim=-1) + norms_out = out.float().norm(dim=-1) + assert torch.allclose(norms_in, norms_out, atol=1e-5) + + def test_apply_rope_position_zero_identity(self): + """Position 0 doesn't change the tensor.""" + dim = 16 + freqs = precompute_rope_freqs(dim, 1) + x = torch.randn(B, 1, 4, dim) + out = apply_rope(x, freqs) + assert torch.allclose(x, out, atol=1e-6) + + def test_apply_rope_dtype_preserved(self): + """Preserves float16 and bfloat16.""" + dim = 16 + freqs = precompute_rope_freqs(dim, T) + for dtype in [torch.float16, torch.bfloat16]: + x = torch.randn(B, T, 4, dim, dtype=dtype) + out = apply_rope(x, freqs) + assert out.dtype == dtype, f"Expected {dtype}, got {out.dtype}" + + +# ===================================================================== +# TestGQAttention +# ===================================================================== + + +class TestGQAttention: + """Tests for Grouped Query Attention.""" + + @pytest.fixture + def gqa_setup(self): + cfg = small_cfg(attn_type="gqa") + attn = GQAttention(cfg) + head_dim = cfg.dim // cfg.n_heads + freqs = precompute_rope_freqs(head_dim, cfg.max_seq_len) + x = torch.randn(B, T, cfg.dim) + return attn, cfg, freqs, x + + def test_output_shape(self, gqa_setup): + """(B, T, dim) output.""" + attn, cfg, freqs, x = gqa_setup + out = attn(x, freqs[:T]) + assert out.shape == (B, T, cfg.dim) + + def test_forward_no_nan(self, gqa_setup): + """Standard forward pass has no NaN.""" + attn, cfg, freqs, x = gqa_setup + out = attn(x, freqs[:T]) + assert not torch.isnan(out).any() + + def 
test_kv_cache_populates(self, gqa_setup): + """Passing kv_cache dict gets populated.""" + attn, cfg, freqs, x = gqa_setup + cache = {} + attn(x, freqs[:T], kv_cache=cache, cache_key="layer0") + assert "layer0" in cache + assert "k" in cache["layer0"] + assert "v" in cache["layer0"] + assert cache["layer0"]["k"].shape[1] == T + assert cache["layer0"]["v"].shape[1] == T + + def test_kv_cache_decode_step(self, gqa_setup): + """Decode with cache produces correct shape.""" + attn, cfg, freqs, x = gqa_setup + cache = {} + # Prefill + attn(x, freqs[:T], kv_cache=cache, cache_key="layer0") + # Decode step: single token + x_decode = torch.randn(B, 1, cfg.dim) + out = attn(x_decode, freqs[T : T + 1], kv_cache=cache, cache_key="layer0") + assert out.shape == (B, 1, cfg.dim) + # Cache should now have T+1 entries + assert cache["layer0"]["k"].shape[1] == T + 1 + + def test_causal_mask_effect(self, gqa_setup): + """With mask, future tokens don't leak.""" + attn, cfg, freqs, x = gqa_setup + mask = OpenMythos._causal_mask(T, x.device, x.dtype) + out_masked = attn(x, freqs[:T], mask=mask) + out_unmasked = attn(x, freqs[:T], mask=None) + # Outputs should differ because the mask blocks future tokens + assert not torch.allclose(out_masked, out_unmasked, atol=1e-5) + + def test_gradient_flows(self, gqa_setup): + """Gradients reach wq, wk, wv, wo.""" + attn, cfg, freqs, x = gqa_setup + x = x.requires_grad_(True) + out = attn(x, freqs[:T]) + loss = out.sum() + loss.backward() + for name in ["wq", "wk", "wv", "wo"]: + param = getattr(attn, name).weight + assert param.grad is not None, f"No gradient for {name}" + assert param.grad.abs().sum() > 0, f"Zero gradient for {name}" + + def test_different_n_kv_heads(self): + """GQA grouping works with different ratios.""" + for n_kv_heads in [1, 2, 4]: + cfg = small_cfg(attn_type="gqa", n_heads=4, n_kv_heads=n_kv_heads) + attn = GQAttention(cfg) + head_dim = cfg.dim // cfg.n_heads + freqs = precompute_rope_freqs(head_dim, cfg.max_seq_len) + x = torch.randn(B, T, cfg.dim) + out = attn(x, freqs[:T]) + assert out.shape == (B, T, cfg.dim) + assert not torch.isnan(out).any() + + +# ===================================================================== +# TestMLAttention +# ===================================================================== + + +class TestMLAttention: + """Tests for Multi-Latent Attention.""" + + @pytest.fixture + def mla_setup(self): + cfg = small_cfg(attn_type="mla") + attn = MLAttention(cfg) + freqs = precompute_rope_freqs(cfg.qk_rope_head_dim, cfg.max_seq_len) + x = torch.randn(B, T, cfg.dim) + return attn, cfg, freqs, x + + def test_output_shape(self, mla_setup): + """(B, T, dim) output.""" + attn, cfg, freqs, x = mla_setup + out = attn(x, freqs[:T]) + assert out.shape == (B, T, cfg.dim) + + def test_forward_no_nan(self, mla_setup): + """Standard forward pass has no NaN.""" + attn, cfg, freqs, x = mla_setup + out = attn(x, freqs[:T]) + assert not torch.isnan(out).any() + + def test_kv_cache_populates(self, mla_setup): + """Cache stores c_kv and k_rope (not full K/V).""" + attn, cfg, freqs, x = mla_setup + cache = {} + attn(x, freqs[:T], kv_cache=cache, cache_key="mla0") + assert "mla0" in cache + assert "c_kv" in cache["mla0"] + assert "k_rope" in cache["mla0"] + # c_kv should have shape (B, T, kv_lora_rank) + assert cache["mla0"]["c_kv"].shape == (B, T, cfg.kv_lora_rank) + + def test_kv_cache_decode_step(self, mla_setup): + """Decode step with cache.""" + attn, cfg, freqs, x = mla_setup + cache = {} + # Prefill + attn(x, freqs[:T], kv_cache=cache, 
cache_key="mla0") + # Decode + x_decode = torch.randn(B, 1, cfg.dim) + out = attn(x_decode, freqs[T : T + 1], kv_cache=cache, cache_key="mla0") + assert out.shape == (B, 1, cfg.dim) + assert cache["mla0"]["c_kv"].shape[1] == T + 1 + + def test_cache_size_smaller_than_gqa(self): + """Verify MLA cache is smaller than equivalent GQA cache.""" + cfg = small_cfg(attn_type="mla") + mla = MLAttention(cfg) + freqs_mla = precompute_rope_freqs(cfg.qk_rope_head_dim, cfg.max_seq_len) + x = torch.randn(B, T, cfg.dim) + + mla_cache = {} + mla(x, freqs_mla[:T], kv_cache=mla_cache, cache_key="mla") + + # MLA stores c_kv (B, T, kv_lora_rank) + k_rope (B, T, n_heads, qk_rope_head_dim) + mla_size = ( + mla_cache["mla"]["c_kv"].numel() + mla_cache["mla"]["k_rope"].numel() + ) + + # Equivalent GQA stores k (B, T, n_kv_heads, head_dim) + v (same) + cfg_gqa = small_cfg(attn_type="gqa") + gqa = GQAttention(cfg_gqa) + head_dim = cfg_gqa.dim // cfg_gqa.n_heads + freqs_gqa = precompute_rope_freqs(head_dim, cfg_gqa.max_seq_len) + + gqa_cache = {} + gqa(x, freqs_gqa[:T], kv_cache=gqa_cache, cache_key="gqa") + + gqa_size = gqa_cache["gqa"]["k"].numel() + gqa_cache["gqa"]["v"].numel() + + assert mla_size < gqa_size, ( + f"MLA cache ({mla_size}) should be smaller than GQA cache ({gqa_size})" + ) + + def test_gradient_flows(self, mla_setup): + """Gradients reach key projections.""" + attn, cfg, freqs, x = mla_setup + x = x.requires_grad_(True) + out = attn(x, freqs[:T]) + loss = out.sum() + loss.backward() + for name in ["q_down", "q_up_nope", "q_up_rope", "kv_down", "kv_up", "wo"]: + param = getattr(attn, name).weight + assert param.grad is not None, f"No gradient for {name}" + assert param.grad.abs().sum() > 0, f"Zero gradient for {name}" + + +# ===================================================================== +# TestExpert +# ===================================================================== + + +class TestExpert: + """Tests for the SwiGLU Expert FFN.""" + + def test_output_shape(self): + """(B, T, dim) output.""" + expert = Expert(64, 16) + x = torch.randn(B, T, 64) + out = expert(x) + assert out.shape == (B, T, 64) + + def test_swiglu_forward(self): + """Basic forward pass works and is not trivially zero.""" + expert = Expert(64, 16) + x = torch.randn(B, T, 64) + out = expert(x) + assert not torch.isnan(out).any() + assert out.abs().sum() > 0 + + def test_gradient_flows(self): + """All three weight matrices get gradients.""" + expert = Expert(64, 16) + x = torch.randn(B, T, 64, requires_grad=True) + out = expert(x) + loss = out.sum() + loss.backward() + for name in ["gate", "up", "down"]: + param = getattr(expert, name).weight + assert param.grad is not None, f"No gradient for {name}" + assert param.grad.abs().sum() > 0, f"Zero gradient for {name}" + + def test_dtype_alignment(self): + """float16 input works with float32 params (the FSDP dtype cast).""" + expert = Expert(64, 16) # float32 params + x = torch.randn(B, T, 64, dtype=torch.float16) + out = expert(x) + # The expert casts x to param dtype internally, so output is float32 + assert not torch.isnan(out).any() + assert out.shape == (B, T, 64) + + +# ===================================================================== +# TestTransformerBlock +# ===================================================================== + + +class TestTransformerBlock: + """Tests for the pre-norm TransformerBlock.""" + + def test_output_shape_dense_ffn(self): + """With use_moe=False.""" + cfg = small_cfg(attn_type="gqa") + block = TransformerBlock(cfg, use_moe=False) + head_dim 
= cfg.dim // cfg.n_heads
+        freqs = precompute_rope_freqs(head_dim, cfg.max_seq_len)
+        x = torch.randn(B, T, cfg.dim)
+        out = block(x, freqs[:T])
+        assert out.shape == (B, T, cfg.dim)
+
+    def test_output_shape_moe_ffn(self):
+        """With use_moe=True."""
+        cfg = small_cfg(attn_type="gqa")
+        block = TransformerBlock(cfg, use_moe=True)
+        head_dim = cfg.dim // cfg.n_heads
+        freqs = precompute_rope_freqs(head_dim, cfg.max_seq_len)
+        x = torch.randn(B, T, cfg.dim)
+        out = block(x, freqs[:T])
+        assert out.shape == (B, T, cfg.dim)
+
+    def test_residual_connection(self):
+        """Output is not identical to just FFN(Attn(x)) -- residual adds input."""
+        cfg = small_cfg(attn_type="gqa")
+        block = TransformerBlock(cfg, use_moe=False)
+        head_dim = cfg.dim // cfg.n_heads
+        freqs = precompute_rope_freqs(head_dim, cfg.max_seq_len)
+        x = torch.randn(B, T, cfg.dim)
+        out = block(x, freqs[:T])
+        # With pre-norm residuals, out = x + f(x): the block must change the
+        # input (attention+FFN output is non-zero), while the additive skip
+        # keeps out strongly correlated with x.
+        diff = out - x
+        assert diff.abs().sum() > 0, "Block should modify input via attention+FFN"
+        cosine_sim = torch.nn.functional.cosine_similarity(
+            out.flatten(), x.flatten(), dim=0
+        )
+        assert cosine_sim > 0.5, "Residual connection should preserve input signal"
+
+    def test_forward_no_nan(self):
+        """Both GQA and MLA modes."""
+        for attn_type in ["gqa", "mla"]:
+            cfg = small_cfg(attn_type=attn_type)
+            block = TransformerBlock(cfg, use_moe=False)
+            if attn_type == "mla":
+                freqs = precompute_rope_freqs(cfg.qk_rope_head_dim, cfg.max_seq_len)
+            else:
+                head_dim = cfg.dim // cfg.n_heads
+                freqs = precompute_rope_freqs(head_dim, cfg.max_seq_len)
+            x = torch.randn(B, T, cfg.dim)
+            out = block(x, freqs[:T])
+            assert not torch.isnan(out).any(), f"NaN in {attn_type} mode"
+
+    def test_gradient_flows(self):
+        """Gradients propagate through the block."""
+        cfg = small_cfg(attn_type="gqa")
+        block = TransformerBlock(cfg, use_moe=False)
+        head_dim = cfg.dim // cfg.n_heads
+        freqs = precompute_rope_freqs(head_dim, cfg.max_seq_len)
+        x = torch.randn(B, T, cfg.dim, requires_grad=True)
+        out = block(x, freqs[:T])
+        loss = out.sum()
+        loss.backward()
+        assert x.grad is not None
+        assert x.grad.abs().sum() > 0
+        # Check that attention weights got gradients
+        assert block.attn.wq.weight.grad is not None
+
+
+# =====================================================================
+# TestLTIInjection
+# =====================================================================
+
+
+class TestLTIInjection:
+    """Tests for the LTI-stable injection module."""
+
+    def test_spectral_radius_below_one(self):
+        """get_A() values are all in (0, 1) -- THE key invariant."""
+        lti = LTIInjection(64)
+        A = lti.get_A()
+        assert (A > 0).all(), "A values must be strictly positive"
+        assert (A < 1).all(), "A values must be strictly less than 1"
+
+    def test_spectral_radius_extreme_params(self):
+        """Even with extreme log_A and log_dt, A stays in [0, 1] and is finite.
+
+        Mathematically A is in the open interval (0, 1), but float32 can round
+        to 0.0 when exp(-very_large) underflows, or to 1.0 when exp(-very_small)
+        rounds up. The important guarantee is: A never exceeds 1 and never goes
+        negative, so the system is non-explosive (spectral radius <= 1).
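+
+        The expectations below are consistent with a decay parameterised as
+        A = exp(-exp(log_A + log_dt)): the inner exp is strictly positive,
+        so the outer exp maps into (0, 1), up to float32 rounding at the
+        endpoints.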
+ """ + lti = LTIInjection(64) + + # Large positive params -> exp(log_dt + log_A) is huge -> A ~ 0 + with torch.no_grad(): + lti.log_A.fill_(10.0) + lti.log_dt.fill_(10.0) + A = lti.get_A() + assert (A >= 0).all() and (A <= 1).all(), "A out of [0,1] with large params" + assert torch.isfinite(A).all(), "A not finite with large params" + + # Large negative params -> exp(log_dt + log_A) is tiny -> A ~ 1 + with torch.no_grad(): + lti.log_A.fill_(-10.0) + lti.log_dt.fill_(-10.0) + A = lti.get_A() + assert (A >= 0).all() and (A <= 1).all(), "A out of [0,1] with negative params" + assert torch.isfinite(A).all(), "A not finite with negative params" + + # Mixed extremes + with torch.no_grad(): + lti.log_A.fill_(15.0) + lti.log_dt.fill_(-15.0) + A = lti.get_A() + assert (A >= 0).all() and (A <= 1).all(), "A out of [0,1] with mixed params" + assert torch.isfinite(A).all(), "A not finite with mixed params" + + # Moderate values -> A strictly in (0, 1) + with torch.no_grad(): + lti.log_A.fill_(0.0) + lti.log_dt.fill_(0.0) + A = lti.get_A() + assert (A > 0).all() and (A < 1).all(), "A out of (0,1) with moderate params" + + def test_forward_shape(self): + """Output matches input shape.""" + lti = LTIInjection(64) + h = torch.randn(B, T, 64) + e = torch.randn(B, T, 64) + trans_out = torch.randn(B, T, 64) + out = lti(h, e, trans_out) + assert out.shape == (B, T, 64) + + def test_stability_many_iterations(self): + """Iterated application doesn't explode.""" + lti = LTIInjection(64) + h = torch.randn(B, T, 64) + e = torch.randn(B, T, 64) * 0.1 + for _ in range(100): + trans_out = torch.zeros(B, T, 64) + h = lti(h, e, trans_out) + assert not torch.isnan(h).any(), "NaN after 100 iterations" + assert not torch.isinf(h).any(), "Inf after 100 iterations" + # The state should converge toward a fixed point since A < 1 + h_norm = h.norm() + assert h_norm < 1e6, f"State norm {h_norm} is too large after 100 steps" + + def test_gradient_flows(self): + """Gradients reach log_A, log_dt, B.""" + lti = LTIInjection(64) + h = torch.randn(B, T, 64, requires_grad=True) + e = torch.randn(B, T, 64) + trans_out = torch.randn(B, T, 64) + out = lti(h, e, trans_out) + loss = out.sum() + loss.backward() + assert lti.log_A.grad is not None, "No gradient for log_A" + assert lti.log_dt.grad is not None, "No gradient for log_dt" + assert lti.B.grad is not None, "No gradient for B" + assert lti.log_A.grad.abs().sum() > 0 + + +# ===================================================================== +# TestRecurrentBlock +# ===================================================================== + + +class TestRecurrentBlock: + """Tests for the RecurrentBlock with ACT, LoRA, and LTI.""" + + @pytest.fixture + def recurrent_setup(self): + cfg = small_cfg(attn_type="gqa") + block = RecurrentBlock(cfg) + head_dim = cfg.dim // cfg.n_heads + freqs = precompute_rope_freqs(head_dim, cfg.max_seq_len) + h = torch.randn(B, T, cfg.dim) + e = torch.randn(B, T, cfg.dim) + mask = OpenMythos._causal_mask(T, h.device, h.dtype) + return block, cfg, freqs, h, e, mask + + def test_output_shape_gqa(self, recurrent_setup): + """(B, T, dim) with GQA attention.""" + block, cfg, freqs, h, e, mask = recurrent_setup + out = block(h, e, freqs[:T], mask) + assert out.shape == (B, T, cfg.dim) + + def test_output_shape_mla(self): + """(B, T, dim) with MLA attention.""" + cfg = small_cfg(attn_type="mla") + block = RecurrentBlock(cfg) + freqs = precompute_rope_freqs(cfg.qk_rope_head_dim, cfg.max_seq_len) + h = torch.randn(B, T, cfg.dim) + e = torch.randn(B, T, cfg.dim) + mask 
= OpenMythos._causal_mask(T, h.device, h.dtype) + out = block(h, e, freqs[:T], mask) + assert out.shape == (B, T, cfg.dim) + + def test_loops_override(self, recurrent_setup): + """n_loops parameter changes behavior.""" + block, cfg, freqs, h, e, mask = recurrent_setup + torch.manual_seed(42) + out_2 = block(h, e, freqs[:T], mask, n_loops=2) + torch.manual_seed(42) + out_3 = block(h, e, freqs[:T], mask, n_loops=3) + # Different number of loops should yield different results + assert not torch.allclose(out_2, out_3, atol=1e-5) + + def test_act_early_exit_without_cache(self): + """When halted.all() is true and no cache, the loop breaks early.""" + # Use an extremely low threshold so halting triggers early + cfg = small_cfg(attn_type="gqa", act_threshold=0.01, max_loop_iters=10) + block = RecurrentBlock(cfg) + h = torch.randn(B, T, cfg.dim) + e = torch.randn(B, T, cfg.dim) + head_dim = cfg.dim // cfg.n_heads + freqs = precompute_rope_freqs(head_dim, cfg.max_seq_len) + mask = OpenMythos._causal_mask(T, h.device, h.dtype) + + # Bias the halting head strongly so sigmoid outputs near 1.0 + with torch.no_grad(): + block.act.halt.weight.fill_(0.0) + block.act.halt.bias.fill_(10.0) # sigmoid(10) ~ 1.0 + + out = block(h, e, freqs[:T], mask, n_loops=10) + assert out.shape == (B, T, cfg.dim) + assert not torch.isnan(out).any() + + def test_kv_cache_no_early_exit(self, recurrent_setup): + """With cache, loop always runs all iterations (no early exit).""" + block, cfg, freqs, h, e, mask = recurrent_setup + cache = {} + + # Bias halting high so it would normally exit early + with torch.no_grad(): + block.act.halt.weight.fill_(0.0) + block.act.halt.bias.fill_(10.0) + + out = block(h, e, freqs[:T], mask, n_loops=3, kv_cache=cache) + assert out.shape == (B, T, cfg.dim) + # All 3 loop iterations should have created cache entries + for t in range(3): + assert f"recurrent_loop_{t}" in cache, ( + f"Cache key for loop {t} missing -- early exit happened with cache" + ) + + def test_gradient_flows(self, recurrent_setup): + """End-to-end gradient through the recurrent block.""" + block, cfg, freqs, h, e, mask = recurrent_setup + h = h.requires_grad_(True) + out = block(h, e, freqs[:T], mask, n_loops=2) + loss = out.sum() + loss.backward() + assert h.grad is not None + assert h.grad.abs().sum() > 0 + # Check LTI gets gradients + assert block.injection.log_A.grad is not None + # Check LoRA gets gradients + assert block.lora.down.weight.grad is not None + + def test_depth_extrapolation(self, recurrent_setup): + """n_loops > max_loop_iters works (LoRA clamping).""" + block, cfg, freqs, h, e, mask = recurrent_setup + # cfg.max_loop_iters=3, so n_loops=5 exceeds it + out = block(h, e, freqs[:T], mask, n_loops=5) + assert out.shape == (B, T, cfg.dim) + assert not torch.isnan(out).any() + + +# ===================================================================== +# TestOpenMythosModel +# ===================================================================== + + +class TestOpenMythosModel: + """Tests for the full OpenMythos model.""" + + @pytest.fixture + def gqa_model(self): + cfg = small_cfg(attn_type="gqa") + model = OpenMythos(cfg) + model.eval() + return model, cfg + + @pytest.fixture + def mla_model(self): + cfg = small_cfg(attn_type="mla") + model = OpenMythos(cfg) + model.eval() + return model, cfg + + def test_weight_tying(self, gqa_model): + """head.weight is embed.weight.""" + model, cfg = gqa_model + assert model.head.weight is model.embed.weight + + def test_causal_mask_shape(self): + """_causal_mask returns 
correct shape.""" + mask = OpenMythos._causal_mask(T, torch.device("cpu"), torch.float32) + assert mask.shape == (1, 1, T, T) + + def test_causal_mask_values(self): + """Upper triangle is -inf, lower triangle and diagonal are 0.""" + mask = OpenMythos._causal_mask(T, torch.device("cpu"), torch.float32) + mask_2d = mask.squeeze(0).squeeze(0) + # Diagonal and below should be 0 + lower = torch.tril(torch.ones(T, T, dtype=torch.bool)) + assert (mask_2d[lower] == 0.0).all(), "Lower triangle should be 0" + # Above diagonal should be -inf + upper = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1) + assert (mask_2d[upper] == float("-inf")).all(), "Upper triangle should be -inf" + + def test_attn_type_gqa(self, gqa_model): + """Model works with attn_type='gqa'.""" + model, cfg = gqa_model + input_ids = torch.randint(0, cfg.vocab_size, (B, T)) + logits = model(input_ids) + assert logits.shape == (B, T, cfg.vocab_size) + assert not torch.isnan(logits).any() + + def test_attn_type_mla(self, mla_model): + """Model works with attn_type='mla'.""" + model, cfg = mla_model + input_ids = torch.randint(0, cfg.vocab_size, (B, T)) + logits = model(input_ids) + assert logits.shape == (B, T, cfg.vocab_size) + assert not torch.isnan(logits).any() + + def test_generate_basic(self, gqa_model): + """Generate produces correct shape.""" + model, cfg = gqa_model + input_ids = torch.randint(0, cfg.vocab_size, (1, 4)) + max_new = 3 + output = model.generate(input_ids, max_new_tokens=max_new, n_loops=2) + assert output.shape == (1, 4 + max_new) + + def test_generate_temperature(self, gqa_model): + """Temperature=0.01 is near-greedy (repeated runs are nearly identical).""" + model, cfg = gqa_model + input_ids = torch.randint(0, cfg.vocab_size, (1, 4)) + results = [] + for _ in range(3): + torch.manual_seed(0) + out = model.generate( + input_ids.clone(), max_new_tokens=3, n_loops=2, temperature=0.01 + ) + results.append(out) + # With very low temperature and same seed, all should be identical + assert torch.equal(results[0], results[1]) + assert torch.equal(results[1], results[2]) + + def test_generate_top_k(self, gqa_model): + """Top_k=1 is deterministic.""" + model, cfg = gqa_model + input_ids = torch.randint(0, cfg.vocab_size, (1, 4)) + out1 = model.generate( + input_ids.clone(), max_new_tokens=3, n_loops=2, top_k=1 + ) + out2 = model.generate( + input_ids.clone(), max_new_tokens=3, n_loops=2, top_k=1 + ) + # top_k=1 forces the argmax token each step, so results must match + assert torch.equal(out1, out2) + + def test_forward_with_kv_cache(self, gqa_model): + """Cache-based forward works.""" + model, cfg = gqa_model + input_ids = torch.randint(0, cfg.vocab_size, (B, T)) + cache = {} + logits = model(input_ids, kv_cache=cache, start_pos=0) + assert logits.shape == (B, T, cfg.vocab_size) + # Cache should be populated + assert len(cache) > 0 + + # Decode step + next_ids = torch.randint(0, cfg.vocab_size, (B, 1)) + logits_decode = model(next_ids, kv_cache=cache, start_pos=T) + assert logits_decode.shape == (B, 1, cfg.vocab_size) + + def test_start_pos_affects_rope(self, gqa_model): + """Different start_pos gives different results when used with KV cache. + + RoPE encodes *relative* positions: without a cache, shifting all Q and K + by the same offset cancels out in the dot product. The effect of + start_pos becomes visible during decode, where cached keys were encoded + at earlier positions and a new query is encoded at a different offset. 
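+
+        Concretely for this test: the cached keys sit at positions 0..3, so
+        decoding at start_pos=4 sees relative offsets 1..4, while decoding at
+        start_pos=10 sees offsets 7..10 -- different angles, hence different
+        attention weights.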
+ """ + model, cfg = gqa_model + + prompt = torch.randint(0, cfg.vocab_size, (1, 4)) + next_tok = torch.randint(0, cfg.vocab_size, (1, 1)) + + # Path A: prefill at pos 0, decode at pos 4 + cache_a = {} + model(prompt, kv_cache=cache_a, start_pos=0) + logits_a = model(next_tok, kv_cache=cache_a, start_pos=4) + + # Path B: prefill at pos 0, decode at pos 10 (wrong position) + cache_b = {} + model(prompt, kv_cache=cache_b, start_pos=0) + logits_b = model(next_tok, kv_cache=cache_b, start_pos=10) + + # The cached keys were encoded at positions 0..3 in both cases, but the + # query token is encoded at position 4 vs 10, changing the relative + # distances and therefore the attention weights via RoPE. + assert not torch.allclose(logits_a, logits_b, atol=1e-4), ( + "Different start_pos during decode should change logits via RoPE " + "relative position encoding" + ) diff --git a/tests/test_moda.py b/tests/test_moda.py new file mode 100644 index 0000000..2a5d111 --- /dev/null +++ b/tests/test_moda.py @@ -0,0 +1,710 @@ +"""Comprehensive tests for open_mythos/moda.py — MoDA + DeepSeek MoE architecture. + +Tests every public class: + RMSNorm, RotaryEmbedding, apply_rotary_emb, DeepSeekExpert, DeepSeekGate, + DeepSeekMoE, MoDAAttention, MoDABlock, MoDAModel. + +All tests use tiny configs (d_model=64, 4 experts) and run on CPU. +""" + +from __future__ import annotations + +import math + +import pytest +import torch +import torch.nn as nn + +from open_mythos.moda import ( + MoDAConfig, + RMSNorm, + RotaryEmbedding, + apply_rotary_emb, + _rotate_half, + DeepSeekExpert, + DeepSeekGate, + DeepSeekMoE, + _SharedFFN, + MoDAAttention, + MoDABlock, + MoDAModel, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +B, T = 2, 8 # batch, sequence length used across all tests + + +def tiny_cfg(**overrides) -> MoDAConfig: + defaults = dict( + vocab_size=200, + d_model=64, + n_layers=2, + n_heads_q=4, + n_heads_kv=2, + head_dim=16, + max_seq_len=32, + rope_base=10000.0, + attn_dropout=0.0, + norm_eps=1e-6, + n_shared_experts=1, + n_routed_experts=4, + n_activated_experts=2, + expert_hidden_dim=32, + moe_balance_alpha=0.001, + moe_score_func="softmax", + moe_n_groups=1, + moe_topk_groups=1, + moe_route_scale=1.0, + ) + defaults.update(overrides) + return MoDAConfig(**defaults) + + +# =========================================================================== +# TestMoDANorm +# =========================================================================== + + +class TestMoDANorm: + """Tests for the RMSNorm module in moda.py.""" + + def test_output_shape(self): + norm = RMSNorm(64) + x = torch.randn(B, T, 64) + out = norm(x) + assert out.shape == (B, T, 64) + + def test_normalization_effect(self): + """After RMSNorm with unit weight the RMS of each vector should be ~1.""" + norm = RMSNorm(64, eps=1e-8) + x = torch.randn(B, T, 64) * 10.0 # large-magnitude input + out = norm(x) + rms = out.pow(2).mean(-1).sqrt() + # With unit weight, RMS should be close to 1 + assert torch.allclose(rms, torch.ones_like(rms), atol=0.05) + + def test_gradient_flow(self): + norm = RMSNorm(64) + x = torch.randn(B, T, 64, requires_grad=True) + out = norm(x) + loss = out.sum() + loss.backward() + assert x.grad is not None + assert not torch.all(x.grad == 0) + assert norm.weight.grad is not None + + def test_learnable_weight(self): + """The weight parameter is initialized to ones and is learnable.""" + norm = RMSNorm(32) + 
assert norm.weight.shape == (32,)
+        assert torch.allclose(norm.weight.data, torch.ones(32))
+
+    def test_different_input_dims(self):
+        """Works with arbitrary leading dimensions."""
+        norm = RMSNorm(16)
+        for shape in [(16,), (3, 16), (2, 4, 16), (1, 2, 3, 16)]:
+            x = torch.randn(*shape)
+            out = norm(x)
+            assert out.shape == x.shape
+
+
+# ===========================================================================
+# TestRotaryEmbedding
+# ===========================================================================
+
+
+class TestRotaryEmbedding:
+    """Tests for RotaryEmbedding with lazy cache extension."""
+
+    def test_cache_shape(self):
+        dim, max_len = 16, 32
+        rope = RotaryEmbedding(dim, max_len)
+        cos, sin = rope(max_len)
+        # Shape: [1, 1, T, dim]
+        assert cos.shape == (1, 1, max_len, dim)
+        assert sin.shape == (1, 1, max_len, dim)
+
+    def test_lazy_extension(self):
+        """Requesting a length beyond the cache rebuilds it at twice that length."""
+        rope = RotaryEmbedding(16, max_seq_len=8)
+        # Initial cache covers 8 positions
+        cos, sin = rope(8)
+        assert cos.shape[2] == 8
+
+        # Request 12 > 8 => cache is rebuilt to 2 * 12 = 24
+        cos, sin = rope(12)
+        assert cos.shape[2] == 12
+        # Internal cache should have been rebuilt for 24
+        assert rope._cos.shape[2] == 24
+
+    def test_cos_sin_at_pos_zero(self):
+        """At position 0 all frequencies are 0, so cos=1 and sin=0."""
+        rope = RotaryEmbedding(16, max_seq_len=32)
+        cos, sin = rope(1)
+        assert torch.allclose(cos[0, 0, 0, :], torch.ones(16), atol=1e-6)
+        assert torch.allclose(sin[0, 0, 0, :], torch.zeros(16), atol=1e-6)
+
+    def test_values_within_bounds(self):
+        """cos and sin values are in [-1, 1]."""
+        rope = RotaryEmbedding(16, max_seq_len=64)
+        cos, sin = rope(64)
+        assert cos.min() >= -1.0 - 1e-6
+        assert cos.max() <= 1.0 + 1e-6
+        assert sin.min() >= -1.0 - 1e-6
+        assert sin.max() <= 1.0 + 1e-6
+
+
+# ===========================================================================
+# TestApplyRotaryEmb
+# ===========================================================================
+
+
+class TestApplyRotaryEmb:
+    """Tests for _rotate_half and apply_rotary_emb."""
+
+    def test_rotate_half_shape(self):
+        x = torch.randn(B, 4, T, 16)
+        out = _rotate_half(x)
+        assert out.shape == x.shape
+
+    def test_rotate_half_values(self):
+        """_rotate_half swaps halves with negation: [-x2, x1]."""
+        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
+        out = _rotate_half(x)
+        expected = torch.tensor([-3.0, -4.0, 1.0, 2.0])
+        assert torch.allclose(out, expected)
+
+    def test_shape_preserved(self):
+        rope = RotaryEmbedding(16, max_seq_len=32)
+        cos, sin = rope(T)
+        x = torch.randn(B, 4, T, 16)
+        out = apply_rotary_emb(x, cos, sin)
+        assert out.shape == x.shape
+
+    def test_norm_preserved(self):
+        """RoPE is a rotation so the L2 norm per position should be preserved."""
+        rope = RotaryEmbedding(16, max_seq_len=32)
+        cos, sin = rope(T)
+        x = torch.randn(B, 4, T, 16)
+        out = apply_rotary_emb(x, cos, sin)
+        # Compare norms per-position
+        x_norm = x.norm(dim=-1)
+        out_norm = out.norm(dim=-1)
+        assert torch.allclose(x_norm, out_norm, atol=1e-5)
+
+    def test_position_zero_identity(self):
+        """At position 0, cos=1 and sin=0, so RoPE is the identity."""
+        rope = RotaryEmbedding(16, max_seq_len=32)
+        cos, sin = rope(1)
+        x = torch.randn(B, 4, 1, 16)
+        out = apply_rotary_emb(x, cos, sin)
+        assert torch.allclose(x, out, atol=1e-6)
+
+
+# ===========================================================================
+# TestDeepSeekExpert
+# ===========================================================================
+
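+# Reference for the expert math exercised below: DeepSeekExpert computes the
+# standard SwiGLU FFN, out = w2(silu(w1(x)) * w3(x)), which
+# test_swiglu_forward asserts verbatim. A minimal sketch for orientation
+# (not the real module):
+#
+#     w1 = nn.Linear(d_model, hidden_dim, bias=False)  # gate projection
+#     w3 = nn.Linear(d_model, hidden_dim, bias=False)  # up projection
+#     w2 = nn.Linear(hidden_dim, d_model, bias=False)  # down projection
+#     out = w2(torch.nn.functional.silu(w1(x)) * w3(x))
+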
+class TestDeepSeekExpert: + """Tests for a single SwiGLU expert.""" + + def test_output_shape(self): + expert = DeepSeekExpert(d_model=64, hidden_dim=32) + x = torch.randn(B * T, 64) + out = expert(x) + assert out.shape == (B * T, 64) + + def test_swiglu_forward(self): + """Output equals w2(silu(w1(x)) * w3(x)).""" + expert = DeepSeekExpert(d_model=64, hidden_dim=32) + x = torch.randn(4, 64) + expected = expert.w2(torch.nn.functional.silu(expert.w1(x)) * expert.w3(x)) + actual = expert(x) + assert torch.allclose(actual, expected, atol=1e-6) + + def test_gradient_flow(self): + expert = DeepSeekExpert(d_model=64, hidden_dim=32) + x = torch.randn(4, 64, requires_grad=True) + out = expert(x) + out.sum().backward() + assert x.grad is not None + assert not torch.all(x.grad == 0) + # All three weight matrices should receive gradients + for name in ("w1", "w2", "w3"): + w = getattr(expert, name).weight + assert w.grad is not None, f"{name} has no gradient" + + def test_no_bias(self): + """Expert linear layers have no bias.""" + expert = DeepSeekExpert(d_model=64, hidden_dim=32) + for name in ("w1", "w2", "w3"): + assert getattr(expert, name).bias is None + + +# =========================================================================== +# TestDeepSeekGate +# =========================================================================== + + +class TestDeepSeekGate: + """Tests for the token-to-expert routing gate.""" + + def test_output_shapes(self): + gate = DeepSeekGate(d_model=64, n_routed_experts=4, n_activated=2) + x = torch.randn(B * T, 64) + weights, indices, scores = gate(x) + assert weights.shape == (B * T, 2) + assert indices.shape == (B * T, 2) + assert scores.shape == (B * T, 4) + + def test_topk_selection(self): + """Indices should be in [0, n_routed_experts).""" + gate = DeepSeekGate(d_model=64, n_routed_experts=4, n_activated=2) + x = torch.randn(B * T, 64) + _, indices, _ = gate(x) + assert indices.min() >= 0 + assert indices.max() < 4 + + def test_softmax_mode(self): + """With softmax, scores should sum to 1 per token.""" + gate = DeepSeekGate( + d_model=64, n_routed_experts=4, n_activated=2, score_func="softmax" + ) + x = torch.randn(B * T, 64) + _, _, scores = gate(x) + row_sums = scores.sum(dim=-1) + assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5) + + def test_sigmoid_mode(self): + """With sigmoid, selected weights are re-normalised to sum to 1 per token.""" + gate = DeepSeekGate( + d_model=64, n_routed_experts=4, n_activated=2, score_func="sigmoid" + ) + x = torch.randn(B * T, 64) + weights, _, _ = gate(x) + row_sums = weights.sum(dim=-1) + assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5) + + def test_route_scale(self): + """Weights should be scaled by route_scale.""" + gate_1 = DeepSeekGate( + d_model=64, n_routed_experts=4, n_activated=2, route_scale=1.0 + ) + gate_2 = DeepSeekGate( + d_model=64, n_routed_experts=4, n_activated=2, route_scale=2.0 + ) + # Copy weights so routing is identical + gate_2.weight.data.copy_(gate_1.weight.data) + x = torch.randn(B * T, 64) + w1, _, _ = gate_1(x) + w2, _, _ = gate_2(x) + assert torch.allclose(w2, w1 * 2.0, atol=1e-5) + + def test_no_bias_by_default(self): + gate = DeepSeekGate(d_model=64, n_routed_experts=4, n_activated=2) + assert gate.bias is None + + def test_with_bias(self): + gate = DeepSeekGate( + d_model=64, n_routed_experts=4, n_activated=2, use_bias=True + ) + assert gate.bias is not None + assert gate.bias.shape == (4,) + # Bias is initialized to zero + assert 
torch.allclose(gate.bias.data, torch.zeros(4)) + + def test_indices_unique_per_token(self): + """Each token selects distinct experts.""" + gate = DeepSeekGate(d_model=64, n_routed_experts=8, n_activated=3) + x = torch.randn(B * T, 64) + _, indices, _ = gate(x) + for row in range(indices.shape[0]): + unique = indices[row].unique() + assert len(unique) == indices.shape[1] + + +# =========================================================================== +# TestDeepSeekMoE +# =========================================================================== + + +class TestDeepSeekMoE: + """Tests for the full MoE layer.""" + + def test_forward_shape(self): + cfg = tiny_cfg() + moe = DeepSeekMoE(cfg) + x = torch.randn(B, T, 64) + out, _ = moe(x) + assert out.shape == (B, T, 64) + + def test_shared_plus_routed_combination(self): + """Output is non-zero and differs from input, showing both paths contribute.""" + cfg = tiny_cfg() + moe = DeepSeekMoE(cfg) + x = torch.randn(B, T, 64) + out, _ = moe(x) + assert not torch.allclose(out, x, atol=1e-3) + assert not torch.all(out == 0) + + def test_balance_loss_in_training(self): + cfg = tiny_cfg(moe_balance_alpha=0.01) + moe = DeepSeekMoE(cfg) + moe.train() + x = torch.randn(B, T, 64) + _, balance_loss = moe(x) + assert balance_loss is not None + assert balance_loss.dim() == 0 # scalar + assert balance_loss.item() >= 0.0 + + def test_no_balance_loss_in_eval(self): + cfg = tiny_cfg(moe_balance_alpha=0.01) + moe = DeepSeekMoE(cfg) + moe.eval() + x = torch.randn(B, T, 64) + _, balance_loss = moe(x) + assert balance_loss is None + + def test_no_balance_loss_when_alpha_zero(self): + cfg = tiny_cfg(moe_balance_alpha=0.0) + moe = DeepSeekMoE(cfg) + moe.train() + x = torch.randn(B, T, 64) + _, balance_loss = moe(x) + assert balance_loss is None + + def test_gradient_flow(self): + cfg = tiny_cfg() + moe = DeepSeekMoE(cfg) + moe.train() + x = torch.randn(B, T, 64, requires_grad=True) + out, bal = moe(x) + loss = out.sum() + if bal is not None: + loss = loss + bal + loss.backward() + assert x.grad is not None + assert not torch.all(x.grad == 0) + + def test_shared_expert_hidden_dim(self): + """Shared experts FFN hidden is n_shared_experts * expert_hidden_dim.""" + cfg = tiny_cfg(n_shared_experts=2, expert_hidden_dim=32) + moe = DeepSeekMoE(cfg) + assert moe.shared_experts.w1.out_features == 64 # 2 * 32 + + +# =========================================================================== +# TestMoDAAttention +# =========================================================================== + + +class TestMoDAAttention: + """Tests for MoDA attention (sequence + depth KV).""" + + def _make_rope(self, cfg): + rope = RotaryEmbedding(cfg.head_dim, cfg.max_seq_len, cfg.rope_base) + return rope(T) + + def test_output_shape(self): + cfg = tiny_cfg() + attn = MoDAAttention(cfg) + cos, sin = self._make_rope(cfg) + x = torch.randn(B, T, 64) + out = attn(x, [], [], cos, sin) + assert out.shape == (B, T, 64) + + def test_gqa_head_expansion(self): + """_expand_kv repeats KV heads to match query heads.""" + cfg = tiny_cfg(n_heads_q=4, n_heads_kv=2, head_dim=16) + attn = MoDAAttention(cfg) + kv = torch.randn(B, 2, T, 16) # [B, Hk, T, d] + expanded = attn._expand_kv(kv) + assert expanded.shape == (B, 4, T, 16) + # Head 0 of expanded should equal head 0 of original + assert torch.allclose(expanded[:, 0], kv[:, 0]) + assert torch.allclose(expanded[:, 1], kv[:, 0]) + assert torch.allclose(expanded[:, 2], kv[:, 1]) + assert torch.allclose(expanded[:, 3], kv[:, 1]) + + def 
test_gqa_no_expansion_when_equal(self): + """When n_heads_q == n_heads_kv, _expand_kv is identity.""" + cfg = tiny_cfg(n_heads_q=4, n_heads_kv=4, head_dim=16) + attn = MoDAAttention(cfg) + kv = torch.randn(B, 4, T, 16) + expanded = attn._expand_kv(kv) + assert expanded is kv # same object, no copy + + def test_forward_with_empty_depth_cache(self): + """Standard causal attention when no depth entries are present.""" + cfg = tiny_cfg() + attn = MoDAAttention(cfg) + cos, sin = self._make_rope(cfg) + x = torch.randn(B, T, 64) + out = attn(x, [], [], cos, sin) + assert out.shape == (B, T, 64) + assert torch.isfinite(out).all() + + def test_forward_with_depth_cache_entries(self): + """Attention integrates depth KV entries from preceding layers.""" + cfg = tiny_cfg() + attn = MoDAAttention(cfg) + cos, sin = self._make_rope(cfg) + x = torch.randn(B, T, 64) + + # Simulate 2 preceding layers each producing depth KV + depth_k = [torch.randn(B, cfg.n_heads_kv, T, cfg.head_dim) for _ in range(2)] + depth_v = [torch.randn(B, cfg.n_heads_kv, T, cfg.head_dim) for _ in range(2)] + + out = attn(x, depth_k, depth_v, cos, sin) + assert out.shape == (B, T, 64) + assert torch.isfinite(out).all() + + def test_depth_cache_changes_output(self): + """Adding depth cache entries should change the attention output.""" + cfg = tiny_cfg() + attn = MoDAAttention(cfg) + cos, sin = self._make_rope(cfg) + x = torch.randn(B, T, 64) + + out_empty = attn(x, [], [], cos, sin) + + depth_k = [torch.randn(B, cfg.n_heads_kv, T, cfg.head_dim)] + depth_v = [torch.randn(B, cfg.n_heads_kv, T, cfg.head_dim)] + out_depth = attn(x, depth_k, depth_v, cos, sin) + + assert not torch.allclose(out_empty, out_depth, atol=1e-4) + + def test_invalid_gqa_config(self): + """n_heads_q must be divisible by n_heads_kv.""" + cfg = tiny_cfg(n_heads_q=5, n_heads_kv=2) + with pytest.raises(ValueError, match="divisible"): + MoDAAttention(cfg) + + +# =========================================================================== +# TestMoDABlock +# =========================================================================== + + +class TestMoDABlock: + """Tests for a single MoDA + MoE transformer block.""" + + def _make_rope(self, cfg): + rope = RotaryEmbedding(cfg.head_dim, cfg.max_seq_len, cfg.rope_base) + return rope(T) + + def test_forward_shape(self): + cfg = tiny_cfg() + block = MoDABlock(cfg) + cos, sin = self._make_rope(cfg) + x = torch.randn(B, T, 64) + x_out, k_write, v_write, bal = block(x, [], [], cos, sin) + assert x_out.shape == (B, T, 64) + + def test_returns_four_values(self): + """Forward returns (x, k_write, v_write, balance_loss).""" + cfg = tiny_cfg() + block = MoDABlock(cfg) + cos, sin = self._make_rope(cfg) + x = torch.randn(B, T, 64) + result = block(x, [], [], cos, sin) + assert len(result) == 4 + + def test_k_v_write_shapes(self): + """Depth write projections produce [B, Hk, T, head_dim].""" + cfg = tiny_cfg() + block = MoDABlock(cfg) + cos, sin = self._make_rope(cfg) + x = torch.randn(B, T, 64) + _, k_write, v_write, _ = block(x, [], [], cos, sin) + expected_shape = (B, cfg.n_heads_kv, T, cfg.head_dim) + assert k_write.shape == expected_shape + assert v_write.shape == expected_shape + + def test_balance_loss_scalar_in_training(self): + cfg = tiny_cfg(moe_balance_alpha=0.01) + block = MoDABlock(cfg) + block.train() + cos, sin = self._make_rope(cfg) + x = torch.randn(B, T, 64) + _, _, _, bal = block(x, [], [], cos, sin) + assert bal is not None + assert bal.dim() == 0 + + def test_balance_loss_none_in_eval(self): + cfg = 
tiny_cfg(moe_balance_alpha=0.01) + block = MoDABlock(cfg) + block.eval() + cos, sin = self._make_rope(cfg) + x = torch.randn(B, T, 64) + _, _, _, bal = block(x, [], [], cos, sin) + assert bal is None + + def test_depth_cache_stacking(self): + """Simulate two consecutive blocks building up the depth cache.""" + cfg = tiny_cfg() + block0 = MoDABlock(cfg) + block1 = MoDABlock(cfg) + cos, sin = self._make_rope(cfg) + x = torch.randn(B, T, 64) + + depth_k, depth_v = [], [] + x, k0, v0, _ = block0(x, depth_k, depth_v, cos, sin) + depth_k.append(k0) + depth_v.append(v0) + + # Block 1 sees 1 depth entry from block 0 + x, k1, v1, _ = block1(x, depth_k, depth_v, cos, sin) + depth_k.append(k1) + depth_v.append(v1) + + assert len(depth_k) == 2 + assert len(depth_v) == 2 + + +# =========================================================================== +# TestMoDAModel +# =========================================================================== + + +class TestMoDAModel: + """Tests for the full MoDA + MoE language model.""" + + def test_forward_shape_logits(self): + cfg = tiny_cfg() + model = MoDAModel(cfg) + model.eval() + ids = torch.randint(0, cfg.vocab_size, (B, T)) + logits, loss = model(ids) + assert logits.shape == (B, T, cfg.vocab_size) + assert loss is None + + def test_loss_computation_with_labels(self): + cfg = tiny_cfg() + model = MoDAModel(cfg) + model.train() + ids = torch.randint(0, cfg.vocab_size, (B, T)) + labels = torch.randint(0, cfg.vocab_size, (B, T)) + logits, loss = model(ids, labels=labels) + assert logits.shape == (B, T, cfg.vocab_size) + assert loss is not None + assert loss.dim() == 0 + assert loss.item() > 0.0 + + def test_weight_tying(self): + cfg = tiny_cfg() + model = MoDAModel(cfg) + assert model.lm_head.weight is model.embed.weight + + def test_num_parameters(self): + cfg = tiny_cfg() + model = MoDAModel(cfg) + n_all = model.num_parameters(trainable_only=False) + n_train = model.num_parameters(trainable_only=True) + assert n_all > 0 + assert n_train == n_all # all params are trainable by default + + # Freeze some params and check trainable count drops + for p in model.embed.parameters(): + p.requires_grad_(False) + n_train_frozen = model.num_parameters(trainable_only=True) + assert n_train_frozen < n_all + + def test_forward_without_labels_returns_none_loss(self): + cfg = tiny_cfg() + model = MoDAModel(cfg) + model.eval() + ids = torch.randint(0, cfg.vocab_size, (B, T)) + logits, loss = model(ids) + assert loss is None + + def test_sequence_length_validation(self): + cfg = tiny_cfg(max_seq_len=16) + model = MoDAModel(cfg) + ids = torch.randint(0, cfg.vocab_size, (1, 20)) # exceeds 16 + with pytest.raises(ValueError, match="exceeds max_seq_len"): + model(ids) + + def test_loss_includes_balance_loss(self): + """When training with balance_alpha > 0, loss includes the balance term.""" + cfg = tiny_cfg(moe_balance_alpha=0.1) + model = MoDAModel(cfg) + model.train() + ids = torch.randint(0, cfg.vocab_size, (B, T)) + labels = torch.randint(0, cfg.vocab_size, (B, T)) + + # Get loss with balance + _, loss_with = model(ids, labels=labels) + + # Get loss without balance + cfg_no_bal = tiny_cfg(moe_balance_alpha=0.0) + model_no_bal = MoDAModel(cfg_no_bal) + model_no_bal.train() + # Copy weights so LM loss is comparable + model_no_bal.load_state_dict(model.state_dict(), strict=False) + _, loss_without = model_no_bal(ids, labels=labels) + + # Balance loss adds a non-negative term; loss_with >= loss_without in general + # (due to different routing from different gate inits this is 
approximate;
+        # in practice the weights are copied via load_state_dict above, so the
+        # checks below stay deliberately loose)
+        assert loss_with is not None
+        assert loss_without is not None
+
+    def test_gradient_flow_full_model(self):
+        cfg = tiny_cfg()
+        model = MoDAModel(cfg)
+        model.train()
+        ids = torch.randint(0, cfg.vocab_size, (B, T))
+        labels = torch.randint(0, cfg.vocab_size, (B, T))
+        _, loss = model(ids, labels=labels)
+        loss.backward()
+
+        # Check gradients reach the embedding
+        assert model.embed.weight.grad is not None
+        assert not torch.all(model.embed.weight.grad == 0)
+
+    def test_extra_repr(self):
+        """extra_repr returns a meaningful string."""
+        cfg = tiny_cfg()
+        model = MoDAModel(cfg)
+        r = model.extra_repr()
+        assert "vocab=200" in r
+        assert "d_model=64" in r
+        assert "layers=2" in r
+
+    def test_depth_cache_grows_with_layers(self):
+        """A 3-layer model runs forward cleanly while the depth cache grows."""
+        cfg = tiny_cfg(n_layers=3)
+        model = MoDAModel(cfg)
+        model.eval()
+        ids = torch.randint(0, cfg.vocab_size, (B, T))
+        # Forward runs without error for 3 layers
+        logits, _ = model(ids)
+        assert logits.shape == (B, T, cfg.vocab_size)
+
+    def test_ignore_index_in_loss(self):
+        """Labels with -100 at some positions are excluded from the LM loss."""
+        cfg = tiny_cfg()
+        model = MoDAModel(cfg)
+        model.train()
+        ids = torch.randint(0, cfg.vocab_size, (B, T))
+
+        # Fully valid labels
+        labels_full = torch.randint(0, cfg.vocab_size, (B, T))
+        _, loss_full = model(ids, labels=labels_full)
+
+        # Partially masked labels (mask the second half)
+        labels_partial = labels_full.clone()
+        labels_partial[:, T // 2 :] = -100
+        _, loss_partial = model(ids, labels=labels_partial)
+
+        # Both losses are finite scalars over different position counts; we
+        # assert validity rather than a specific ordering between them.
+        assert loss_full is not None and torch.isfinite(loss_full)
+        assert loss_partial is not None and torch.isfinite(loss_partial)
+        assert loss_full.dim() == 0
+        assert loss_partial.dim() == 0
diff --git a/tests/test_moe_before_after.py b/tests/test_moe_before_after.py
new file mode 100644
index 0000000..2b827df
--- /dev/null
+++ b/tests/test_moe_before_after.py
@@ -0,0 +1,264 @@
+"""
+Before/After comparison: MoE dispatch optimization.
+
+Verifies that the new grouped dispatch (sort-by-expert, batch-per-expert)
+produces identical numerical results to the old nested-loop dispatch.
+""" + +import torch +import torch.nn.functional as F +import pytest + +from open_mythos.main import Expert, MoEFFN, MythosConfig + + +# --------------------------------------------------------------------------- +# Reference: OLD nested-loop dispatch (pre-optimization, commit before 65cd807) +# --------------------------------------------------------------------------- + +def old_moe_dispatch(moe: MoEFFN, x: torch.Tensor) -> torch.Tensor: + """Reimplementation of the old MoE forward — nested for-loops.""" + B, T, D = x.shape + x = x.to(moe.router.weight.dtype) + flat = x.view(B * T, D) + + logits = moe.router(flat) + scores = F.softmax(logits, dim=-1) + _, topk_idx = (logits + moe.router_bias).topk(moe.topk, dim=-1) + topk_scores = scores.gather(-1, topk_idx) + # OLD code: no .clamp(min=1e-9) + topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True) + + # OLD code: nested for-loops + out = torch.zeros_like(flat) + for i in range(moe.topk): + expert_ids = topk_idx[:, i] + token_scores = topk_scores[:, i].unsqueeze(-1) + for eid in range(moe.n_experts): + mask = expert_ids == eid + if not mask.any(): + continue + out[mask] += token_scores[mask] * moe.routed_experts[eid](flat[mask]) + + for shared in moe.shared_experts: + out = out + shared(flat) + + return out.view(B, T, D) + + +# --------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + +def small_cfg(**overrides) -> MythosConfig: + defaults = dict( + vocab_size=200, + dim=64, + n_heads=4, + n_kv_heads=2, + max_seq_len=32, + max_loop_iters=3, + prelude_layers=1, + coda_layers=1, + attn_type="gqa", + n_experts=4, + n_shared_experts=1, + n_experts_per_tok=2, + expert_dim=16, + act_threshold=0.99, + lora_rank=4, + kv_lora_rank=16, + q_lora_rank=32, + qk_rope_head_dim=8, + qk_nope_head_dim=8, + v_head_dim=8, + ) + defaults.update(overrides) + return MythosConfig(**defaults) + + +B, T = 2, 8 + + +# =================================================================== +# Numerical equivalence tests +# =================================================================== + + +class TestMoEBeforeAfterEquivalence: + """Verify that old and new MoE dispatch produce identical results.""" + + def test_basic_equivalence(self): + """Standard batch: old and new dispatch should match within float32 tolerance.""" + torch.manual_seed(42) + cfg = small_cfg() + moe = MoEFFN(cfg) + moe.eval() + + x = torch.randn(B, T, cfg.dim) + new_out = moe(x) + old_out = old_moe_dispatch(moe, x) + + assert torch.allclose(new_out, old_out, atol=1e-5), ( + f"Max diff: {(new_out - old_out).abs().max().item()}" + ) + + def test_equivalence_single_token(self): + """Single token (B=1, T=1).""" + torch.manual_seed(123) + cfg = small_cfg() + moe = MoEFFN(cfg) + moe.eval() + + x = torch.randn(1, 1, cfg.dim) + new_out = moe(x) + old_out = old_moe_dispatch(moe, x) + + assert torch.allclose(new_out, old_out, atol=1e-5), ( + f"Max diff: {(new_out - old_out).abs().max().item()}" + ) + + def test_equivalence_topk_1(self): + """Top-1 routing.""" + torch.manual_seed(7) + cfg = small_cfg(n_experts_per_tok=1) + moe = MoEFFN(cfg) + moe.eval() + + x = torch.randn(B, T, cfg.dim) + new_out = moe(x) + old_out = old_moe_dispatch(moe, x) + + assert torch.allclose(new_out, old_out, atol=1e-5), ( + f"Max diff: {(new_out - old_out).abs().max().item()}" + ) + + def test_equivalence_all_experts_selected(self): + """Every expert selected for every token (topk == n_experts).""" + 
torch.manual_seed(99)
+        cfg = small_cfg(n_experts=4, n_experts_per_tok=4)
+        moe = MoEFFN(cfg)
+        moe.eval()
+
+        x = torch.randn(B, T, cfg.dim)
+        new_out = moe(x)
+        old_out = old_moe_dispatch(moe, x)
+
+        assert torch.allclose(new_out, old_out, atol=1e-5), (
+            f"Max diff: {(new_out - old_out).abs().max().item()}"
+        )
+
+    def test_equivalence_forced_single_expert(self):
+        """Force all tokens to the same expert via router_bias."""
+        torch.manual_seed(0)
+        cfg = small_cfg(n_experts=4, n_experts_per_tok=2)
+        moe = MoEFFN(cfg)
+        moe.eval()
+
+        moe.router_bias.data = torch.tensor([1000.0, 999.0, -1000.0, -1000.0])
+        x = torch.randn(B, T, cfg.dim)
+        new_out = moe(x)
+        old_out = old_moe_dispatch(moe, x)
+
+        assert torch.allclose(new_out, old_out, atol=1e-5), (
+            f"Max diff: {(new_out - old_out).abs().max().item()}"
+        )
+
+    def test_equivalence_larger_batch(self):
+        """Larger batch stress test."""
+        torch.manual_seed(314)
+        cfg = small_cfg(n_experts=8, n_experts_per_tok=3)
+        moe = MoEFFN(cfg)
+        moe.eval()
+
+        x = torch.randn(4, 16, cfg.dim)
+        new_out = moe(x)
+        old_out = old_moe_dispatch(moe, x)
+
+        assert torch.allclose(new_out, old_out, atol=1e-5), (
+            f"Max diff: {(new_out - old_out).abs().max().item()}"
+        )
+
+    def test_gradient_equivalence(self):
+        """Gradients w.r.t. input should match between old and new dispatch."""
+        torch.manual_seed(77)
+        cfg = small_cfg()
+        moe = MoEFFN(cfg)
+
+        x1 = torch.randn(B, T, cfg.dim, requires_grad=True)
+        x2 = x1.clone().detach().requires_grad_(True)
+
+        new_out = moe(x1)
+        new_out.sum().backward()
+
+        # Reset grads in the MoE
+        moe.zero_grad()
+
+        old_out = old_moe_dispatch(moe, x2)
+        old_out.sum().backward()
+
+        assert x1.grad is not None and x2.grad is not None
+        assert torch.allclose(x1.grad, x2.grad, atol=1e-4), (
+            f"Max grad diff: {(x1.grad - x2.grad).abs().max().item()}"
+        )
+
+    def test_epsilon_guard_difference(self):
+        """The .clamp(min=1e-9) is a safety net; verify it doesn't change
+        normal-case results (scores are never actually zero in practice)."""
+        torch.manual_seed(42)
+        cfg = small_cfg()
+        moe = MoEFFN(cfg)
+        moe.eval()
+
+        x = torch.randn(B, T, cfg.dim)
+        flat = x.view(B * T, cfg.dim)
+        logits = moe.router(flat)
+        scores = F.softmax(logits, dim=-1)
+        _, topk_idx = (logits + moe.router_bias).topk(moe.topk, dim=-1)
+        topk_scores = scores.gather(-1, topk_idx)
+
+        # Without epsilon
+        renorm_no_eps = topk_scores / topk_scores.sum(dim=-1, keepdim=True)
+        # With epsilon
+        renorm_with_eps = topk_scores / topk_scores.sum(dim=-1, keepdim=True).clamp(min=1e-9)
+
+        # In normal cases, these should be identical
+        assert torch.allclose(renorm_no_eps, renorm_with_eps, atol=1e-9)
+
+
+class TestMoEDispatchPerformanceCharacteristics:
+    """Verify the new grouped dispatch has the same semantic behavior."""
+
+    def test_each_expert_called_exactly_once_per_batch(self):
+        """In grouped dispatch, each active expert is called exactly once,
+        with all its assigned tokens batched together."""
+        torch.manual_seed(42)
+        cfg = small_cfg(n_experts=4, n_experts_per_tok=2)
+        moe = MoEFFN(cfg)
+        moe.eval()
+
+        x = torch.randn(B, T, cfg.dim)
+        flat = x.view(B * T, cfg.dim)
+
+        logits = moe.router(flat)
+        _, topk_idx = (logits + moe.router_bias).topk(moe.topk, dim=-1)
+
+        flat_expert_ids = topk_idx.view(-1)
+        unique_experts = torch.unique(flat_expert_ids)
+
+        # Each unique expert in the routing should appear at least once
+        assert len(unique_experts) > 0
+        assert len(unique_experts) <= cfg.n_experts
+
+        # Count real forward calls per routed expert, using the same
+        # instance-level forward-wrapping trick as tests/test_stochastic_depth.py.
+        calls = [0] * cfg.n_experts
+        originals = [expert.forward for expert in moe.routed_experts]
+
+        def _counting(eid, orig):
+            def wrapped(*args, **kwargs):
+                calls[eid] += 1
+                return orig(*args, **kwargs)
+            return wrapped
+
+        for eid, expert in enumerate(moe.routed_experts):
+            expert.forward = _counting(eid, expert.forward)
+        try:
+            _ = moe(x)
+        finally:
+            for eid, expert in enumerate(moe.routed_experts):
+                expert.forward = originals[eid]
+
+        for eid in unique_experts.tolist():
+            assert calls[eid] == 1, f"expert {eid} called {calls[eid]} times"
+
+    def test_output_preserves_batch_structure(self):
+        """Output shape must match
input shape through the dispatch.""" + cfg = small_cfg() + moe = MoEFFN(cfg) + for b, t in [(1, 1), (1, 16), (4, 8), (8, 1)]: + x = torch.randn(b, t, cfg.dim) + out = moe(x) + assert out.shape == (b, t, cfg.dim) + + +if __name__ == "__main__": + pytest.main([__file__, "--verbose"]) diff --git a/tests/test_stochastic_depth.py b/tests/test_stochastic_depth.py new file mode 100644 index 0000000..cea3c9c --- /dev/null +++ b/tests/test_stochastic_depth.py @@ -0,0 +1,170 @@ +"""Tests for stochastic-depth (Option B) training path: bypass_act flag.""" + +import torch + +from open_mythos.main import MythosConfig, OpenMythos + + +def _small_cfg() -> MythosConfig: + """Small CPU config used by the existing test suite.""" + return MythosConfig( + vocab_size=128, + dim=64, + n_heads=4, + n_kv_heads=2, + max_seq_len=32, + max_loop_iters=4, + prelude_layers=1, + coda_layers=1, + attn_type="gqa", + n_experts=2, + n_shared_experts=1, + n_experts_per_tok=2, + expert_dim=64, + act_threshold=0.99, + lora_rank=4, + ) + + +def _build_block_inputs(cfg: MythosConfig, B: int = 2, T: int = 8): + """Build the (h, e, freqs_cis) inputs needed by RecurrentBlock.forward.""" + torch.manual_seed(0) + model = OpenMythos(cfg) + input_ids = torch.randint(0, cfg.vocab_size, (B, T)) + x = model.embed(input_ids) + freqs_cis = model.freqs_cis[:T] + mask = model._causal_mask(T, x.device, x.dtype) + for i, layer in enumerate(model.prelude): + x = layer(x, freqs_cis, mask, None, cache_key=f"prelude_{i}") + return model.recurrent, x.clone(), x.clone(), freqs_cis, mask + + +def test_recurrent_block_bypass_act_differs_from_act(): + """bypass_act=True should produce a different output than bypass_act=False.""" + cfg = _small_cfg() + block, h, e, freqs_cis, mask = _build_block_inputs(cfg) + torch.manual_seed(1) + out_act = block(h.clone(), e, freqs_cis, mask, n_loops=4, bypass_act=False) + torch.manual_seed(1) + out_bypass = block(h.clone(), e, freqs_cis, mask, n_loops=4, bypass_act=True) + assert out_act.shape == out_bypass.shape + assert not torch.allclose( + out_act, out_bypass, atol=1e-6 + ), "bypass_act=True should not equal ACT-weighted output" + + +def test_recurrent_block_bypass_act_runs_full_n_loops(): + """With bypass_act=True there should be no early exit; all n_loops iterations run.""" + cfg = _small_cfg() + block, h, e, freqs_cis, mask = _build_block_inputs(cfg) + call_count = {"n": 0} + original_block = block.block.forward + + def counting_forward(*args, **kwargs): + call_count["n"] += 1 + return original_block(*args, **kwargs) + + block.block.forward = counting_forward + try: + _ = block(h, e, freqs_cis, mask, n_loops=3, bypass_act=True) + finally: + block.block.forward = original_block + assert call_count["n"] == 3, f"expected 3 block calls, got {call_count['n']}" + + +def test_recurrent_block_bypass_act_returns_final_h(): + """bypass_act=True output should match a manual iteration returning the final h.""" + from open_mythos.main import loop_index_embedding + + cfg = _small_cfg() + block, h, e, freqs_cis, mask = _build_block_inputs(cfg) + n_loops = 3 + + torch.manual_seed(1) + h_manual = h.clone() + for t in range(n_loops): + h_loop = loop_index_embedding(h_manual, t, block.loop_dim) + combined = block.norm(h_loop + e) + trans_out = block.block(combined, freqs_cis, mask, None, f"recurrent_loop_{t}") + trans_out = trans_out + block.lora(trans_out, t) + h_manual = block.injection(h_manual, e, trans_out) + + torch.manual_seed(1) + out_bypass = block(h.clone(), e, freqs_cis, mask, n_loops=n_loops, bypass_act=True) + + 
assert torch.allclose( + out_bypass, h_manual, atol=1e-5 + ), "bypass_act=True should return the final hidden state after n_loops iterations" + + +def test_openmythos_forward_bypass_act_propagates(): + """OpenMythos.forward(bypass_act=True) should route through RecurrentBlock with bypass_act=True.""" + cfg = _small_cfg() + torch.manual_seed(0) + model = OpenMythos(cfg) + input_ids = torch.randint(0, cfg.vocab_size, (2, 8)) + + torch.manual_seed(1) + logits_act = model(input_ids, n_loops=3, bypass_act=False) + torch.manual_seed(1) + logits_bypass = model(input_ids, n_loops=3, bypass_act=True) + + assert logits_act.shape == logits_bypass.shape + assert not torch.allclose( + logits_act, logits_bypass, atol=1e-6 + ), "bypass_act should change model output" + + +def test_state_dict_compatible_across_modes(tmp_path): + """state_dict round-trips cleanly and the loaded model works in both ACT and bypass modes.""" + cfg = _small_cfg() + torch.manual_seed(0) + model_a = OpenMythos(cfg) + ckpt_path = tmp_path / "model.pt" + torch.save(model_a.state_dict(), ckpt_path) + + torch.manual_seed(1) + model_b = OpenMythos(cfg) + state = torch.load(ckpt_path, map_location="cpu") + # strict=True raises if any keys are missing or unexpected, which is the + # actual compatibility check. + model_b.load_state_dict(state, strict=True) + + input_ids = torch.randint(0, cfg.vocab_size, (2, 8)) + torch.manual_seed(2) + logits_act = model_b(input_ids, n_loops=3, bypass_act=False) + torch.manual_seed(2) + logits_bypass = model_b(input_ids, n_loops=3, bypass_act=True) + assert logits_act.shape == logits_bypass.shape + assert torch.isfinite(logits_act).all(), "ACT logits must be finite" + assert torch.isfinite(logits_bypass).all(), "bypass logits must be finite" + + +def test_training_step_runs_in_each_mode(): + """One forward+backward+optimizer step works in both modes without error.""" + cfg = _small_cfg() + torch.manual_seed(0) + model = OpenMythos(cfg) + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + input_ids = torch.randint(0, cfg.vocab_size, (2, 8)) + targets = torch.randint(0, cfg.vocab_size, (2, 8)) + + # ACT mode + optimizer.zero_grad() + logits = model(input_ids, n_loops=None, bypass_act=False) + loss_act = torch.nn.functional.cross_entropy( + logits.view(-1, cfg.vocab_size), targets.view(-1) + ) + loss_act.backward() + optimizer.step() + assert torch.isfinite(loss_act), "ACT-mode loss must be finite" + + # Stochastic-depth mode + optimizer.zero_grad() + logits = model(input_ids, n_loops=3, bypass_act=True) + loss_sd = torch.nn.functional.cross_entropy( + logits.view(-1, cfg.vocab_size), targets.view(-1) + ) + loss_sd.backward() + optimizer.step() + assert torch.isfinite(loss_sd), "stochastic-depth-mode loss must be finite" diff --git a/tests/test_variants.py b/tests/test_variants.py new file mode 100644 index 0000000..b90ceae --- /dev/null +++ b/tests/test_variants.py @@ -0,0 +1,118 @@ +"""Comprehensive tests for open_mythos/variants.py factory functions.""" + +import pytest +import torch + +from open_mythos.variants import ( + mythos_1b, + mythos_3b, + mythos_10b, + mythos_50b, + mythos_100b, + mythos_500b, + mythos_1t, +) +from open_mythos.main import MythosConfig, OpenMythos + +# Ordered from smallest to largest scale. +ALL_FACTORIES = [ + mythos_1b, + mythos_3b, + mythos_10b, + mythos_50b, + mythos_100b, + mythos_500b, + mythos_1t, +] + +# Configs that are small enough to actually instantiate on CPU without OOM. 
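+# A sketch for eyeballing the bigger variants without allocating real weights
+# (assumes the configs build cleanly under PyTorch's meta device; not used by
+# the tests below):
+#
+#     with torch.device("meta"):
+#         OpenMythos(mythos_1t())
+#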
+SMALL_FACTORIES = [mythos_1b, mythos_3b]
+
+
+class TestVariantConfigs:
+    """Tests for every variant factory function in variants.py."""
+
+    @pytest.mark.parametrize("factory", ALL_FACTORIES, ids=lambda f: f.__name__)
+    def test_each_variant_returns_config(self, factory):
+        """All 7 factory functions return a MythosConfig instance."""
+        cfg = factory()
+        assert isinstance(cfg, MythosConfig)
+
+    @pytest.mark.parametrize("factory", SMALL_FACTORIES, ids=lambda f: f.__name__)
+    def test_each_variant_instantiates_model(self, factory):
+        """1b and 3b configs can create an OpenMythos model on CPU."""
+        cfg = factory()
+        model = OpenMythos(cfg)
+        assert isinstance(model, torch.nn.Module)
+        # Sanity: model should have parameters.
+        param_count = sum(p.numel() for p in model.parameters())
+        assert param_count > 0
+
+    @pytest.mark.parametrize("factory", ALL_FACTORIES, ids=lambda f: f.__name__)
+    def test_dim_divisible_by_n_heads(self, factory):
+        """dim must be evenly divisible by n_heads."""
+        cfg = factory()
+        assert cfg.dim % cfg.n_heads == 0, (
+            f"{factory.__name__}: dim={cfg.dim} not divisible by n_heads={cfg.n_heads}"
+        )
+
+    @pytest.mark.parametrize("factory", ALL_FACTORIES, ids=lambda f: f.__name__)
+    def test_n_heads_divisible_by_n_kv_heads(self, factory):
+        """n_heads must be evenly divisible by n_kv_heads (for GQA grouping)."""
+        cfg = factory()
+        assert cfg.n_heads % cfg.n_kv_heads == 0, (
+            f"{factory.__name__}: n_heads={cfg.n_heads} not divisible by "
+            f"n_kv_heads={cfg.n_kv_heads}"
+        )
+
+    @pytest.mark.parametrize("factory", ALL_FACTORIES, ids=lambda f: f.__name__)
+    def test_vocab_size_positive(self, factory):
+        """All configs must have a positive vocab_size."""
+        cfg = factory()
+        assert cfg.vocab_size > 0
+
+    def test_dimensions_increase_with_scale(self):
+        """dim must strictly increase from 1b -> 3b -> 10b -> ...
-> 1t.""" + dims = [f().dim for f in ALL_FACTORIES] + for i in range(len(dims) - 1): + assert dims[i] < dims[i + 1], ( + f"dim did not increase: {ALL_FACTORIES[i].__name__} " + f"(dim={dims[i]}) >= {ALL_FACTORIES[i+1].__name__} " + f"(dim={dims[i+1]})" + ) + + def test_expert_count_increases_or_stays(self): + """Larger models should have n_experts >= the previous scale.""" + expert_counts = [f().n_experts for f in ALL_FACTORIES] + for i in range(len(expert_counts) - 1): + assert expert_counts[i] <= expert_counts[i + 1], ( + f"n_experts decreased: {ALL_FACTORIES[i].__name__} " + f"({expert_counts[i]}) > {ALL_FACTORIES[i+1].__name__} " + f"({expert_counts[i+1]})" + ) + + @pytest.mark.parametrize("factory", ALL_FACTORIES, ids=lambda f: f.__name__) + def test_max_loop_iters_positive(self, factory): + """All configs must have positive max_loop_iters.""" + cfg = factory() + assert cfg.max_loop_iters > 0 + + @pytest.mark.parametrize("factory", ALL_FACTORIES, ids=lambda f: f.__name__) + def test_attn_type_is_mla(self, factory): + """All variants use Multi-Latent Attention.""" + cfg = factory() + assert cfg.attn_type == "mla" + + @pytest.mark.parametrize("factory", ALL_FACTORIES, ids=lambda f: f.__name__) + def test_act_threshold_valid(self, factory): + """act_threshold must be in the range (0, 1].""" + cfg = factory() + assert 0.0 < cfg.act_threshold <= 1.0, ( + f"{factory.__name__}: act_threshold={cfg.act_threshold} out of (0, 1]" + ) + + @pytest.mark.parametrize("factory", ALL_FACTORIES, ids=lambda f: f.__name__) + def test_rope_theta_positive(self, factory): + """All configs must have a positive rope_theta.""" + cfg = factory() + assert cfg.rope_theta > 0.0 diff --git a/training/1b_fine_web_edu.py b/training/1b_fine_web_edu.py new file mode 100644 index 0000000..d8f8e08 --- /dev/null +++ b/training/1b_fine_web_edu.py @@ -0,0 +1,707 @@ +#!/usr/bin/env python3 +""" +OpenMythos 1B pretraining on FineWeb-Edu with FSDP + AdamW + optional ClearML. + +Supports both the original ACT recipe and the new stochastic-depth recipe +(Option B) via the `recurrent_mode` hyperparameter in main(). Checkpoints +are compatible across modes. + +Single GPU: + python training/1b_fine_web_edu.py + +Multi-GPU: + torchrun --nproc_per_node=N training/1b_fine_web_edu.py + +Dataset: expects FineWeb-Edu parquet files at DATASET_PATH (see docs/datasets.md +for preparation instructions). Uses direct pyarrow parquet reading rather than +the HuggingFace `datasets` streaming iterator (~17,000x faster for local files). 
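+
+Example invocation (illustrative; any directory of prepared .parquet shards
+works as DATASET_PATH):
+
+    DATASET_PATH=/data/fineweb-edu TARGET_TOKENS=10 \
+        torchrun --nproc_per_node=4 training/1b_fine_web_edu.py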
+
+Environment variables (optional):
+    DATASET_PATH  -- local path to FineWeb-Edu parquet files
+                     (default: ./data/fineweb-edu; set empty to fall back to
+                     HuggingFace streaming)
+    OUTPUT_DIR    -- checkpoint + log directory (default: ./output/experiments)
+    TARGET_TOKENS -- token budget in billions (default: 10)
+    HF_TOKEN      -- HuggingFace token, for tokenizer download
+
+ClearML tracking (optional — set all three to enable):
+    CLEARML_API_HOST
+    CLEARML_API_ACCESS_KEY
+    CLEARML_API_SECRET_KEY
+    CLEARML_PROJECT  -- ClearML project name (default: openmythos)
+    EXPERIMENT_NAME  -- ClearML task name (default: 1b-fine-web-edu)
+"""
+
+import os
+import math
+import random
+import time
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from loguru import logger
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+    ShardingStrategy,
+    MixedPrecision,
+    FullStateDictConfig,
+    StateDictType,
+)
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+from torch.utils.data import IterableDataset, DataLoader, get_worker_info
+from contextlib import nullcontext
+
+import glob as _glob
+import pyarrow.parquet as pq
+from datasets import load_dataset
+
+from open_mythos import OpenMythos
+from open_mythos.main import TransformerBlock, RecurrentBlock
+from open_mythos.variants import mythos_1b
+from open_mythos.tokenizer import MythosTokenizer
+
+# ---------------------------------------------------------------------------
+# ClearML (lazy — only initialized on rank 0)
+# ---------------------------------------------------------------------------
+
+_clearml_task = None
+_clearml_logger = None
+
+
+def init_clearml(cfg, training_hparams: dict, timeout: int = 30):
+    """Initialize ClearML tracking on rank 0. No-op if unreachable or missing."""
+    global _clearml_task, _clearml_logger
+    import signal
+
+    def _timeout_handler(signum, frame):
+        raise TimeoutError("ClearML init timed out")
+
+    try:
+        from clearml import Task
+
+        project = os.environ.get("CLEARML_PROJECT", "openmythos")
+        task_name = os.environ.get("EXPERIMENT_NAME", "1b-fine-web-edu")
+
+        # Task.init can hang if the ClearML server is unreachable (e.g.,
+        # the network is restricted). Use a SIGALRM timeout to fail fast.
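+        # NOTE: SIGALRM is POSIX-only and usable only from the main thread; on
+        # platforms without it, the AttributeError from signal.SIGALRM falls
+        # through to the except handler below and training continues untracked.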
+ old_handler = signal.signal(signal.SIGALRM, _timeout_handler) + signal.alarm(timeout) + try: + _clearml_task = Task.init(project_name=project, task_name=task_name) + _clearml_task.connect(vars(cfg), name="model_config") + _clearml_task.connect(training_hparams, name="training_hparams") + _clearml_logger = _clearml_task.get_logger() + logger.info(f"ClearML initialized: project={project}, task={task_name}") + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + except Exception as e: + logger.warning( + f"ClearML init failed (training continues without tracking): {e}" + ) + + +def log_clearml(series: str, value: float, step: int): + """Report a scalar to ClearML if available.""" + if _clearml_logger is not None: + _clearml_logger.report_scalar("train", series, iteration=step, value=value) + + +def log_clearml_text(title: str, text: str): + """Log text to ClearML if available.""" + if _clearml_logger is not None: + _clearml_logger.report_text(f"## {title}\n\n{text}") + + +def register_clearml_artifact(name: str, path: str): + """Register a file artifact in ClearML if available.""" + if _clearml_task is not None: + _clearml_task.upload_artifact(name, artifact_object=path) + + +# --------------------------------------------------------------------------- +# Dataset +# --------------------------------------------------------------------------- + + +class FineWebEduDataset(IterableDataset): + """ + FineWeb-Edu loader yielding fixed-length (input, target) pairs. + + Supports two modes: + - Local parquet: loads from a directory of .parquet files (no internet needed) + - Streaming: pulls shards on demand from HuggingFace (requires internet) + + Documents are concatenated into a rolling buffer and sliced into + fixed-length chunks. Sharding is two-dimensional: world_size ranks x + num_workers DataLoader workers per rank. + """ + + def __init__( + self, + encoding, + seq_len: int, + rank: int, + world_size: int, + dataset_path: str = "", + dataset_subset: str = "sample-10BT", + ): + self.encoding = encoding + self.seq_len = seq_len + self.rank = rank + self.world_size = world_size + self.dataset_path = dataset_path + self.dataset_subset = dataset_subset + + def _get_parquet_files(self, shard_index: int, total_shards: int) -> list[str]: + """Return the subset of parquet files assigned to this shard.""" + all_files = sorted(_glob.glob(os.path.join(self.dataset_path, "*.parquet"))) + if not all_files: + raise FileNotFoundError(f"No .parquet files found in {self.dataset_path}") + return [f for i, f in enumerate(all_files) if i % total_shards == shard_index] + + def _iter_parquet(self, shard_index: int, total_shards: int): + """Read local parquet files directly via pyarrow. 
Loops infinitely.""" + files = self._get_parquet_files(shard_index, total_shards) + if not files: + return + + buf: list[int] = [] + while True: + for parquet_path in files: + table = pq.read_table(parquet_path, columns=["text"]) + text_column = table.column("text") + del table + + for text_value in text_column: + text = text_value.as_py() + if text: + buf.extend(self.encoding.encode(text)) + while len(buf) >= self.seq_len + 1: + chunk = buf[: self.seq_len + 1] + buf = buf[self.seq_len + 1 :] + yield ( + torch.tensor(chunk[:-1], dtype=torch.long), + torch.tensor(chunk[1:], dtype=torch.long), + ) + del text_column + + def _iter_streaming(self, shard_index: int, total_shards: int): + """HuggingFace streaming fallback (requires internet).""" + ds = load_dataset( + "HuggingFaceFW/fineweb-edu", + name=self.dataset_subset, + split="train", + streaming=True, + ).shard(num_shards=total_shards, index=shard_index) + + buf: list[int] = [] + for sample in ds: + buf.extend(self.encoding.encode(sample["text"])) + while len(buf) >= self.seq_len + 1: + chunk = buf[: self.seq_len + 1] + buf = buf[self.seq_len + 1 :] + yield ( + torch.tensor(chunk[:-1], dtype=torch.long), + torch.tensor(chunk[1:], dtype=torch.long), + ) + + def __iter__(self): + worker = get_worker_info() + num_workers = worker.num_workers if worker else 1 + worker_id = worker.id if worker else 0 + + total_shards = self.world_size * num_workers + shard_index = self.rank * num_workers + worker_id + + if self.dataset_path: + yield from self._iter_parquet(shard_index, total_shards) + else: + yield from self._iter_streaming(shard_index, total_shards) + + +# --------------------------------------------------------------------------- +# LR schedule: linear warmup -> cosine decay +# --------------------------------------------------------------------------- + + +def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float: + if step < warmup: + return max_lr * step / warmup + if step >= total: + return min_lr + decay = (step - warmup) / (total - warmup) + return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay)) + + +# --------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- + + +def _list_ckpts(ckpt_dir: str) -> list[str]: + if not os.path.isdir(ckpt_dir): + return [] + return sorted( + os.path.join(ckpt_dir, f) + for f in os.listdir(ckpt_dir) + if f.startswith("step_") and f.endswith(".pt") + ) + + +def save_checkpoint( + model, + optimizer, + step: int, + cfg, + vocab_size: int, + ckpt_dir: str, + ddp: bool, + master: bool, + keep_last: int = 3, +) -> None: + if ddp: + with FSDP.state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + model_state = model.state_dict() + optim_state = FSDP.optim_state_dict(model, optimizer) + else: + model_state = model.state_dict() + optim_state = optimizer.state_dict() + + if not master: + return + + os.makedirs(ckpt_dir, exist_ok=True) + final_path = os.path.join(ckpt_dir, f"step_{step:07d}.pt") + tmp_path = final_path + ".tmp" + torch.save( + { + "step": step, + "model": model_state, + "optimizer": optim_state, + "cfg": cfg, + "vocab_size": vocab_size, + }, + tmp_path, + ) + os.replace(tmp_path, final_path) + + for old in _list_ckpts(ckpt_dir)[:-keep_last]: + try: + os.remove(old) + except OSError as exc: + logger.warning(f"Failed to prune old checkpoint {old}: {exc}") + + 
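+    # The tmp-file write followed by os.replace above is atomic on POSIX, so a
+    # crash mid-save cannot leave a truncated step_*.pt for _list_ckpts to pick
+    # up on the next resume.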
logger.success(f"Checkpoint saved -> {final_path}") + + +def load_checkpoint(model, optimizer, path: str, ddp: bool) -> int: + ckpt = torch.load(path, map_location="cpu", weights_only=False) + + if ddp: + with FSDP.state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=False), + ): + model.load_state_dict(ckpt["model"]) + optim_state = FSDP.optim_state_dict_to_load( + model=model, + optim=optimizer, + optim_state_dict=ckpt["optimizer"], + ) + optimizer.load_state_dict(optim_state) + else: + model.load_state_dict(ckpt["model"]) + optimizer.load_state_dict(ckpt["optimizer"]) + + return int(ckpt["step"]) + + +# --------------------------------------------------------------------------- +# Post-training generation test +# --------------------------------------------------------------------------- + + +GENERATION_PROMPTS = [ + "The purpose of education is", + "In the beginning, there was", + "The most important scientific discovery", +] + + +def run_generation_test(cfg, ckpt_dir, encoding, device: str): + """ + Reconstruct a raw model from the latest checkpoint and generate text. + + Under FSDP, calling model.module.generate() while parameters are still + sharded across ranks produces incorrect output or deadlocks. Instead, + this function loads the fully-gathered checkpoint (saved by rank 0) into + a fresh, unwrapped model on a single GPU after the process group has + been torn down. Safe for both single-GPU and post-FSDP scenarios. + """ + logger.info("Running post-training generation test...") + + ckpts = _list_ckpts(ckpt_dir) + if not ckpts: + logger.warning("No checkpoint found — skipping generation test.") + return + + ckpt = torch.load(ckpts[-1], map_location=device, weights_only=False) + raw_model = OpenMythos(cfg) + raw_model.load_state_dict(ckpt["model"]) + raw_model = raw_model.to(device) + raw_model.eval() + + results = [] + for prompt_text in GENERATION_PROMPTS: + tokens = encoding.encode(prompt_text) + input_ids = torch.tensor([tokens], dtype=torch.long, device=device) + + with torch.no_grad(): + output_ids = raw_model.generate( + input_ids, + max_new_tokens=128, + temperature=0.8, + top_k=40, + ) + + generated = encoding.decode(output_ids[0].tolist()) + result = f"**Prompt:** {prompt_text}\n**Generated:** {generated}\n" + results.append(result) + logger.info(f"\n{result}") + + all_results = "\n---\n".join(results) + log_clearml_text("Generation Samples", all_results) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + # ------------------------------------------------------------------ + # Distributed init + # ------------------------------------------------------------------ + ddp = int(os.environ.get("RANK", -1)) != -1 + if ddp: + dist.init_process_group("nccl") + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + device = f"cuda:{local_rank}" + torch.cuda.set_device(device) + else: + rank = local_rank = 0 + world_size = 1 + device = "cuda" if torch.cuda.is_available() else "cpu" + + master = rank == 0 + + if master: + logger.info( + f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}" + ) + + # ------------------------------------------------------------------ + # Tokenizer + # ------------------------------------------------------------------ + encoding = MythosTokenizer() + 
vocab_size = encoding.vocab_size + + if master: + logger.info(f"Tokenizer: gpt-oss-20b | Vocab size: {vocab_size:,}") + + # ------------------------------------------------------------------ + # Hyperparameters (env-var configurable with defaults) + # ------------------------------------------------------------------ + # Recurrent-depth training recipe (Option A: ACT, Option B: stochastic depth). + # Change recurrent_mode to "act" to use the original ACT halting recipe. + recurrent_mode = "stochastic_depth" # "act" or "stochastic_depth" + stochastic_depth_min = 1 + stochastic_depth_max = 32 + + seq_len = 2048 + micro_batch = 1 + target_tokens_b = int(os.environ.get("TARGET_TOKENS", "10")) + target_tokens = target_tokens_b * 1_000_000_000 + grad_accum = max(1, 16 // (world_size * micro_batch)) + global_batch_tok = world_size * micro_batch * grad_accum * seq_len + total_steps = target_tokens // global_batch_tok + warmup_steps = 2000 + lr = 3e-4 + min_lr = 3e-5 + wd = 0.1 + log_every = 1 + ckpt_every = 1000 + output_dir = os.environ.get( + "OUTPUT_DIR", "./output/experiments" + ) + ckpt_dir = os.path.join(output_dir, "checkpoints") + dataset_path = os.environ.get( + "DATASET_PATH", "./data/fineweb-edu" + ) + dataset_subset = "sample-10BT" + + training_hparams = { + "seq_len": seq_len, + "micro_batch": micro_batch, + "target_tokens": target_tokens, + "grad_accum": grad_accum, + "global_batch_tok": global_batch_tok, + "total_steps": total_steps, + "warmup_steps": warmup_steps, + "lr": lr, + "min_lr": min_lr, + "weight_decay": wd, + "log_every": log_every, + "ckpt_every": ckpt_every, + "output_dir": output_dir, + "dataset_path": dataset_path, + "dataset_subset": dataset_subset, + "world_size": world_size, + "recurrent_mode": recurrent_mode, + "stochastic_depth_min": stochastic_depth_min, + "stochastic_depth_max": stochastic_depth_max, + } + + if master: + logger.info( + f"seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | " + f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,} | " + f"target_tokens={target_tokens_b}B" + ) + + # ------------------------------------------------------------------ + # Model + # ------------------------------------------------------------------ + cfg = mythos_1b() + cfg.vocab_size = vocab_size + cfg.max_seq_len = seq_len + + bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported() + amp_dtype = torch.bfloat16 if bf16_ok else torch.float16 + + model = OpenMythos(cfg) + + if ddp: + mp_policy = MixedPrecision( + param_dtype=amp_dtype, + reduce_dtype=amp_dtype, + buffer_dtype=amp_dtype, + ) + wrap_policy = ModuleWrapPolicy({TransformerBlock, RecurrentBlock}) + model = FSDP( + model, + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=mp_policy, + auto_wrap_policy=wrap_policy, + device_id=local_rank, + ) + else: + model = model.to(device) + amp_ctx = ( + torch.amp.autocast(device_type="cuda", dtype=amp_dtype) + if "cuda" in device + else nullcontext() + ) + + amp_ctx = nullcontext() if ddp else amp_ctx # type: ignore[possibly-undefined] + + if master: + n_params = sum(p.numel() for p in model.parameters()) + logger.info(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}") + + if master: + if recurrent_mode == "stochastic_depth": + logger.info( + f"Recurrent mode: stochastic_depth " + f"(n_loops sampled uniformly from [{stochastic_depth_min}, {stochastic_depth_max}])" + ) + else: + logger.info( + f"Recurrent mode: act (n_loops = cfg.max_loop_iters = {cfg.max_loop_iters})" + ) + + # 
------------------------------------------------------------------ + # ClearML init (after model is built so we can log config) + # ------------------------------------------------------------------ + if master: + init_clearml(cfg, training_hparams) + + # ------------------------------------------------------------------ + # Optimizer + # ------------------------------------------------------------------ + optimizer = torch.optim.AdamW( + model.parameters(), + lr=lr, + weight_decay=wd, + betas=(0.9, 0.95), + fused="cuda" in device, + ) + + # ------------------------------------------------------------------ + # Resume from latest checkpoint (if any) + # ------------------------------------------------------------------ + start_step = 0 + existing_ckpts = _list_ckpts(ckpt_dir) + if existing_ckpts: + latest = existing_ckpts[-1] + if master: + logger.info(f"Resuming from checkpoint: {latest}") + start_step = load_checkpoint(model, optimizer, latest, ddp) + if master: + logger.success(f"Resumed at step {start_step}") + + # ------------------------------------------------------------------ + # Dataset + DataLoader + # ------------------------------------------------------------------ + dataset = FineWebEduDataset( + encoding, + seq_len, + rank, + world_size, + dataset_path=dataset_path, + dataset_subset=dataset_subset, + ) + loader = DataLoader(dataset, batch_size=micro_batch, num_workers=4, pin_memory=True) + + # ------------------------------------------------------------------ + # Training loop + # ------------------------------------------------------------------ + if master: + os.makedirs(ckpt_dir, exist_ok=True) + + model.train() + data_iter = iter(loader) + t0 = time.perf_counter() + step = start_step + + while step < total_steps: + cur_lr = get_lr(step, warmup_steps, total_steps, lr, min_lr) + for g in optimizer.param_groups: + g["lr"] = cur_lr + + optimizer.zero_grad() + loss_accum = 0.0 + + # Sample n_loops once per optimizer step. With FSDP/DDP, all ranks must + # run the same number of recurrent iterations to avoid all-gather + # ordering mismatch (same bug class as the ACT early-exit deadlock in + # commit 6c5659c). Broadcast from rank 0 so all ranks agree. 
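+        # (An equivalent design would seed `random` identically on every rank
+        # and sample locally without a broadcast; the explicit broadcast is
+        # used here so correctness does not depend on RNG state staying in
+        # sync across resumes.)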
+ if recurrent_mode == "stochastic_depth": + if master: + n_loops_this_step = random.randint( + stochastic_depth_min, stochastic_depth_max + ) + else: + n_loops_this_step = 0 + if ddp: + nl_tensor = torch.tensor( + [n_loops_this_step], device=device, dtype=torch.int64 + ) + dist.broadcast(nl_tensor, src=0) + n_loops_this_step = int(nl_tensor.item()) + bypass_act_this_step = True + else: + n_loops_this_step = None + bypass_act_this_step = False + + for micro_step in range(grad_accum): + try: + x, y = next(data_iter) + except StopIteration: + data_iter = iter(loader) + x, y = next(data_iter) + + x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True) + y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True) + + sync = ( + nullcontext() + if (not ddp or micro_step == grad_accum - 1) + else model.no_sync() + ) + + with sync, amp_ctx: + logits = model( + x, + n_loops=n_loops_this_step, + bypass_act=bypass_act_this_step, + ) + loss = nn.functional.cross_entropy( + logits.view(-1, vocab_size), y.view(-1) + ) + loss = loss / grad_accum + + loss.backward() + loss_accum += loss.item() + + if ddp: + grad_norm = model.clip_grad_norm_(1.0) + else: + grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + step += 1 + + if master and step % log_every == 0: + dt = time.perf_counter() - t0 + tok_per_sec = global_batch_tok * log_every / dt + tokens_seen = step * global_batch_tok + + n_loops_display = ( + n_loops_this_step + if n_loops_this_step is not None + else cfg.max_loop_iters + ) + logger.info( + f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} " + f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} " + f"| {tok_per_sec / 1e6:.2f}M tok/s " + f"| {tokens_seen / 1e9:.1f}B tokens seen " + f"| mode={recurrent_mode} n_loops={n_loops_display}" + ) + + log_clearml("loss", loss_accum, step) + log_clearml("grad_norm", float(grad_norm), step) + log_clearml("lr", cur_lr, step) + log_clearml("throughput_mtok_s", tok_per_sec / 1e6, step) + log_clearml("tokens_seen_B", tokens_seen / 1e9, step) + log_clearml("n_loops", float(n_loops_display), step) + + t0 = time.perf_counter() + + if step % ckpt_every == 0: + save_checkpoint( + model, optimizer, step, cfg, vocab_size, ckpt_dir, ddp, master + ) + + # Final checkpoint + if step > start_step and step % ckpt_every != 0: + save_checkpoint(model, optimizer, step, cfg, vocab_size, ckpt_dir, ddp, master) + + # ------------------------------------------------------------------ + # Tear down distributed process group before generation + # ------------------------------------------------------------------ + if ddp: + dist.barrier() + dist.destroy_process_group() + + # ------------------------------------------------------------------ + # Post-training generation test (rank 0 only) + # ------------------------------------------------------------------ + # Reconstruct a fresh model from the checkpoint so we don't need + # FSDP — the process group is already torn down at this point. 
+    if master:
+        # When ddp was used, `device` is already f"cuda:{local_rank}", so the
+        # same device string is correct in both single- and multi-GPU runs.
+        run_generation_test(cfg, ckpt_dir, encoding, device)
+        logger.success("Training complete.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/training/3b_fine_web_edu.py b/training/3b_fine_web_edu.py
index e980302..6179974 100644
--- a/training/3b_fine_web_edu.py
+++ b/training/3b_fine_web_edu.py
@@ -436,7 +436,7 @@ def main():
     # Optimizer
     # ------------------------------------------------------------------
     optimizer = torch.optim.AdamW(
-        model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
+        model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused="cuda" in device
     )
 
     # ------------------------------------------------------------------
diff --git a/training/requirements.txt b/training/requirements.txt
index e3348c5..dc7a6f4 100644
--- a/training/requirements.txt
+++ b/training/requirements.txt
@@ -1,4 +1,6 @@
 torch>=2.11.0
 datasets>=3.6.0
 loguru>=0.7.3
-open-mythos
\ No newline at end of file
+open-mythos
+clearml>=1.16.0
+pyarrow>=15.0.0