diff --git a/.gitignore b/.gitignore index 2596db4..811f30b 100644 --- a/.gitignore +++ b/.gitignore @@ -428,3 +428,4 @@ flycheck_*.el # network security /network-security.data +uv.lock diff --git a/PAI_REVIEW_2026-04.md b/PAI_REVIEW_2026-04.md new file mode 100644 index 0000000..68cfe92 --- /dev/null +++ b/PAI_REVIEW_2026-04.md @@ -0,0 +1,502 @@ +# OpenMythos — PAI Review, April 2026 + +External code review conducted via the PAI Algorithm (v3.8.0, DETERMINED +effort) on commit `eae0f04`. Eight specialist reviewers ran in parallel: +correctness, performance, adversarial stress, maintainability, testing, +Kieran-Python style, architecture, and PyTorch 2026 best practices. + +The branch `pai-review-2026-04` contains eight commits landing the Tier 0 +install-blockers and the Tier 1 correctness fixes. **All 81 tests pass.** +Tiers 2 and 3 below are documented as a prioritized follow-up roadmap — +each is keyed back to a reviewer finding, with impact and fix sketch. + +--- + +## Landed in this branch (commits on `pai-review-2026-04`) + +| # | Commit | Category | Files | What | +|---|---|---|---|---| +| 1 | `fix(install): unpin torch…` | Install | `pyproject.toml`, `__init__.py`, `README.md`, `tokenizer.py` | `torch==2.11.0` was a nonexistent pin; relaxed to `>=2.4,<3`. Added loguru + pytest config + 3.11/3.12 classifiers. `__all__` referenced undefined `load_tokenizer`/`get_vocab_size` (AttributeError on `import *`) — trimmed to genuine public surface. README `mythos_7b()` → `mythos_3b()` (7b was never defined). Deferred the `transformers` import so `import open_mythos` no longer ImportErrors without transformers installed. | +| 2 | `feat(config): __post_init__ validation` | Correctness | `main.py` | Previously invalid combos (`attn_type="MLA"`, `n_heads % n_kv_heads != 0`, odd head_dim, `topk > n_experts`, `max_loop_iters < 1`) silently fell through or crashed deep inside the forward. Now `ValueError` at construction. `attn_type` is also typed `Literal["gqa", "mla"]`. | +| 3 | `fix(numerics): LTI fp32, loop-index fp32, ACT remainder, tie-after-init` | Correctness | `main.py` | Four numerical fixes: (a) `LTIInjection.get_A` now computes `exp(-exp(x))` in fp32 with tighter clamp `(-10, 10)` — the bf16 path was underflowing to `exp(-0)=1.0` and silently breaking the ρ(A)<1 guarantee. (b) `loop_index_embedding` now computes frequencies in fp32 so adjacent `k` indices don't collapse to the same bf16 value. (c) `OpenMythos.__init__` runs `_init_weights` BEFORE tying so the shared tensor isn't initialized twice. (d) `RecurrentBlock.forward` flushes remainder probability onto never-halted positions so ACT weights sum to ~1 for every position. | +| 4 | `feat(generate): bounds check, eval mode, EOS stopping` | Correctness | `main.py` | `forward` rejects `T=0` and `start_pos+T > max_seq_len` (previously both silently indexed a zero-length freqs slice and produced garbage). `generate()` calls `self.eval()` for the duration (restoring prior mode on exit) so dropout doesn't fire during sampling. Added `eos_token_id` parameter and per-row finished-mask stopping. `top_k` is clamped to vocab_size. | +| 5 | `fix(mla): cache shared k_rope once` | Correctness + perf | `main.py` | MLAttention was expanding k_rope to `(B, T, n_heads, rope_dim)` via `.expand().contiguous()` before caching, storing `n_heads` identical copies per token. At `n_heads=16` this is 16× more rope cache than the DeepSeek-V2 design specifies — negating the memory savings that motivate MLA. Now the shared `(B, T, rope_dim)` is cached once; per-head broadcast happens at compute time via a cost-free view. | +| 6 | `test: PAI regression suite + Cyrillic fix` | Tests | `test_main.py`, `tests/test_pai_regressions.py` | Added 14 regression tests covering every fix above: config validation (5), bf16 LTI stability, bf16 loop-index distinct frequencies, weight tying storage identity, forward empty/over-max rejection, generate clamping, generate training-mode restore, generate EOS early stop, LoRA clamp past max. Also fixed `TestOpenMythosMLА` (Cyrillic `А` U+0410) → `TestOpenMythosMLA` (ASCII). | +| 7 | `fix(tests): slice freqs_cis to T` | Tests | `test_main.py` | 13 unit tests in `TestGQAttention`/`TestMLAttention`/`TestTransformerBlock`/`TestRecurrentBlock` were passing the full max-seq-len freqs table into attention forwards, which broadcast-fails at `apply_rope`. The tests had been silently broken because the prior eager `transformers` import meant `test_main.py` couldn't load without transformers installed. Fixed by slicing `[:T]`. | +| 8 | `chore: gitignore uv.lock` | Chore | `.gitignore` | | + +Verification: `pytest test_main.py tests/test_pai_regressions.py tests/test_rope_debug.py` → **81 passed in 1.25s**. + +--- + +## Reviewer findings summary + +30 total findings; severity counts across reviewers: + +| Severity | Count | Landed | Deferred | +|---|---|---|---| +| CRITICAL | 4 | 3 | 1 | +| HIGH | 13 | 4 | 9 | +| MEDIUM | 9 | 4 | 5 | +| LOW | 4 | 1 | 3 | + +--- + +## Tier 2 — Performance (deferred, ordered by expected speedup) + +### 2.1 Swap manual attention for `F.scaled_dot_product_attention` +*Source: performance-reviewer #2, adversarial adv-06 adjacency, best-practices topic 1* + +`GQAttention.forward` and `MLAttention.forward` materialize the full `(B, H, T, S)` attention matrix, softmax, matmul. On H100 bf16 this misses the fused Flash Attention 2/3 kernel. At `T=2048, B=4, H=16` the scratch matrix is ~1 GB per layer. + +**Fix sketch (both classes):** + +```python +dropout_p = self.attn_drop.p if self.training else 0.0 +out = F.scaled_dot_product_attention( + q, k, v, + attn_mask=mask, + is_causal=(mask is None and kv_cache is None), + dropout_p=dropout_p, + scale=scale, +) +``` + +**Expected:** 2–4× training step speedup, 3–8× decode at long context. + +**Gotcha:** the `attn_drop` Dropout is consumed inside SDPA; the explicit softmax+dropout pair is redundant. Also SDPA needs `q,k,v` already in `(B, H, T, d)` shape — the transposes in both classes already produce that. + +### 2.2 Vectorized MoE dispatch (grouped GEMM) +*Source: performance-reviewer #1, maintainability M16, adversarial adv-10, kieran #18* + +`MoEFFN.forward` nests `for i in range(topk): for eid in range(n_experts):` — 256 Python iterations per MoE layer at `topk=4, n_experts=64`. At 1T scale with `n_experts=512` this becomes 2048 iterations × 16 recurrent loops = 32,768 kernel launches per forward. Catastrophic. + +**Fix sketch:** + +```python +# Sort tokens by expert id for contiguous dispatch. +exp_ids = topk_idx.reshape(-1) +tok_ids = torch.arange(N, device=x.device).repeat_interleave(self.topk) +gate_w = topk_scores.reshape(-1, 1) +sort_idx = exp_ids.argsort() +counts = torch.bincount(exp_ids[sort_idx], minlength=self.n_experts) +ends = counts.cumsum(0); starts = torch.cat([torch.zeros(1, device=...), ends[:-1]]) +x_perm = flat[tok_ids[sort_idx]] +out = torch.zeros_like(flat) +for eid in range(self.n_experts): + s, e = starts[eid].item(), ends[eid].item() + if s == e: continue + y = self.routed_experts[eid](x_perm[s:e]) * gate_w[sort_idx][s:e] + out.index_add_(0, tok_ids[sort_idx][s:e], y) +``` + +At scale (512 experts, 2048-token batch), vendor torchtitan's +`triton_contiguous_group_gemm` — a single-file Triton kernel that replaces +the Python per-expert loop with one fused kernel. 2.6× measured speedup +on DeepSeek-V3 training. + +`moda.py`'s `DeepSeekMoE` already uses the bincount pattern (lines +562–569) — port it to `main.py::MoEFFN`. + +### 2.3 Preallocated KV cache (fix O(T²) decode) +*Source: performance-reviewer #3, adversarial adv-12, architecture #10* + +Every decode step does `torch.cat([cache[k], new_k], dim=1)` — allocates a +fresh tensor of the full-so-far cache size and copies everything. Over N +decoded tokens that's O(N²) memory bandwidth and allocator pressure. +At N=2048, B=1, H=16, d=192, bf16 × 20 attention layers, this is +~260 GB of redundant memcpy per generation. + +**Fix sketch:** preallocate at `generate()` entry, index-write each step: + +```python +# At start of generate(): +max_len = prompt_len + max_new_tokens +# In attention forward: +cache = kv_cache.setdefault(cache_key, { + "k": torch.empty(B, max_len, n_kv_heads, head_dim, ...), + "v": torch.empty(B, max_len, n_kv_heads, head_dim, ...), + "len": 0, +}) +pos = cache["len"] +cache["k"][:, pos:pos+T] = k +cache["v"][:, pos:pos+T] = v +cache["len"] = pos + T +k = cache["k"][:, :pos+T] +v = cache["v"][:, :pos+T] +``` + +**Expected:** 3–10× decode speedup at T ≥ 1k. Same pattern for MLA's +`c_kv` and shared `k_rope`. + +### 2.4 Gradient checkpointing option +*Source: performance-reviewer #6, best-practices topic 10* + +At 1T scale with `n_loops=16`, activations dominate memory. No +`gradient_checkpointing` knob exists. Add an opt-in field to +`MythosConfig`, and wrap the recurrent loop body: + +```python +from torch.utils.checkpoint import checkpoint +if self.training and self.cfg.gradient_checkpointing: + trans_out = checkpoint( + self.block, combined, freqs_cis, mask, kv_cache, cache_key, + use_reentrant=False, + ) +else: + trans_out = self.block(...) +``` + +`use_reentrant=False` is the 2026-standard form — required for nested +checkpointing, `torch.autograd.grad`, and compatibility with +`torch.compile`. Also wire `apply_activation_checkpointing` into the +FSDP setup in `training/3b_fine_web_edu.py`. + +**Expected:** ~`sqrt(n_loops)` activation memory reduction → difference +between trainable and OOM at 1T. + +### 2.5 Aux-loss-free router-bias update +*Source: correctness residual #2, adversarial adv-03, best-practices topic 7* + +`MoEFFN.router_bias` is registered as a buffer with a comment "adjusted +externally during training; not a gradient param" — but grep confirms +no code path updates it. The aux-loss-free load-balancing scheme is +inert; router collapse after a few hundred steps will stay collapsed. + +**DeepSeek-V3 update rule (Wang et al. 2024, arXiv 2408.15664):** + +```python +@torch.no_grad() +def update_router_bias(self, counts: Tensor, u: float = 0.001) -> None: + """Call once per training step with per-expert token counts.""" + avg = counts.float().mean() + err = avg - counts.float() + self.router_bias.add_(u * err.sign()) +``` + +Expose as `MoEFFN.update_bias(counts)`; wire a forward-hook in the +training script that collects counts per-step and calls it. Freeze (set +`u=0`) in the final 3% of training. + +### 2.6 RoPE real-pair path (drop fp32 complex roundtrip) +*Source: performance-reviewer #4, correctness #9, best-practices topic 3* + +`apply_rope` does `x.float()` → `view_as_complex` → multiply → `view_as_real` +→ `.to(x.dtype)`. Three full tensor copies per attention call, blocks +torch.compile fusion, and bf16 has no native complex dtype so the fp32 +upcast is forced. + +**Fix sketch (real-pair + GPT-NeoX split-halves, HF-compatible):** + +```python +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + # x: (B, T, H, D); cos/sin: (T, D) same dtype as x + x1, x2 = x.chunk(2, dim=-1) + rot = torch.cat([-x2, x1], dim=-1) + return x * cos[None, :, None, :] + rot * sin[None, :, None, :] +``` + +Precompute `cos`/`sin` tables in `OpenMythos.__init__`, register as +non-persistent buffers in the model's working dtype. + +**Expected:** 5–10% training speedup; much bigger for decode. + +### 2.7 Precomputed causal mask + loop embeddings +*Source: performance-reviewer #5* + +`_causal_mask(T, device)` allocates a fresh `(1,1,T,T)` −∞ tensor on +every forward. `loop_index_embedding` allocates freqs/angles/sin/cos +on every loop of every forward. Both are amortizable to one-time init. + +**Fix:** register a persistent `(1,1,max_seq_len,max_seq_len)` mask +buffer in `__init__`, slice `[:T,:T]`. Register a +`(max_loop_iters, dim)` loop-embedding table in `RecurrentBlock.__init__`; +index by `t` instead of recomputing. + +### 2.8 torch.compile support +*Source: performance-reviewer #8, best-practices topic 9* + +Model currently has graph breaks from: MoE Python loop (2.2), dict +cache with f-string keys, `.item()` calls in the router path. Once +2.1-2.3 land, add: + +```python +model = torch.compile(model, mode="default") # training +# or mode="reduce-overhead" for single-stream decode +``` + +Verify with `TORCH_LOGS=graph_breaks python train.py`. + +**Expected:** +15–30% on H100. + +### 2.9 Training: tokenization bottleneck +*Source: performance-reviewer #9* + +`FineWebEduDataset.__iter__` encodes one sample at a time per worker +(~200–500k tok/s/worker). For 3B on H100, 4M+ tok/s throughput is +needed. Dataloader-bound by a factor of ~2–5×. + +**Fix:** (a) use `tokenizer.encode_batch([...])` in batches of 64–128, +(b) pre-tokenize one shard to uint16 memmap once, stream slices. This +is the llm.c / nanoGPT pattern. 10–20× dataloader speedup. + +--- + +## Tier 3 — Correctness medium (deferred) + +### 3.1 `_causal_mask` shape assumes cache is empty at T>1 +*Source: correctness #2, adversarial adv-02* + +`forward` builds mask as `(1,1,T,T)` — fine for prefill or single-token +decode. Breaks for T>1 with non-empty cache (speculative decoding, +prefix caching, multi-token append). Attention scores are `(B,H,T,S)` +with `S = T_prev + T`; mask shape mismatches. + +**Fix:** build mask of shape `(1,1,T,S)` accounting for cached length. + +### 3.2 FSDP wrap policy includes both `TransformerBlock` and `RecurrentBlock` +*Source: adversarial adv-15, correctness #10* + +`training/3b_fine_web_edu.py` line 412: +`ModuleWrapPolicy({TransformerBlock, RecurrentBlock})`. Double-wrap +creates ambiguous unit boundaries. At 1T with `n_experts=512` inside +`MoEFFN` not wrapped, each `TransformerBlock` wrap holds a 320M-param +flat parameter that defeats FULL_SHARD. + +**Fix:** wrap `TransformerBlock` and `Expert` (so each routed expert +shards independently). Consider `reshard_after_forward=False` on the +recurrent block to avoid re-gather per loop iteration. + +### 3.3 `router_bias` buffer sharded by FSDP +*Source: correctness residual #2* + +Under FSDP FULL_SHARD, buffers are sharded by default. If 2.5 lands a +bias update, each rank would update only its local slice. Fix: +add `router_bias` to `ignored_parameters` or mark it explicitly +replicated. + +### 3.4 `kv_cache` dict reuse across `forward()` calls pollutes +*Source: adversarial adv-12, maintainability M6* + +Cache keys bake `n_loops` into the schema (`recurrent_loop_{t}`). +Reusing the same dict across calls with different `n_loops` mixes +stale cached keys into fresh ones at loop indices 0..min(n_loops). + +**Fix:** introduce a `KVCache` dataclass that tracks `n_loops` and +validates on each use, or document as "append-only, strictly +monotonic". + +### 3.5 LoRA clamp for depth extrapolation is semantically wrong +*Source: correctness #11, adversarial adv-11* + +`t_idx = loop_t if loop_t <= max_t else max_t` — iterations beyond the +trained range all reuse the last scale. "Extra depth" at inference is +therefore not actually different from the final trained iteration, +defeating the README's depth-extrapolation claim. + +**Options:** (a) return zero delta for `loop_t > max_t` (neutral), +(b) linearly interpolate or extrapolate scale, (c) replace +`nn.Embedding` with a continuous function of loop index (MLP on +sinusoid). + +### 3.6 `_causal_mask` dtype is fp32 under bf16 autocast +*Source: correctness #6* + +`torch.full(..., float("-inf"))` defaults to fp32. `attn + mask` +upcasts to fp32, diverges from FSDP's reduce_dtype, forces extra +casts. Fix: `dtype=q.dtype` on mask allocation. + +### 3.7 vocab_size default vs tokenizer default mismatch +*Source: adversarial adv-14* + +`MythosConfig.vocab_size=32000` while `MythosTokenizer` defaults to +gpt-oss-20b (~200k vocab). Non-Latin text tokenizes to ids > 32000; +`model(ids)` triggers a CUDA assertion failure. + +**Fix:** either raise in `OpenMythos.__init__` if provided a +tokenizer with vocab > `cfg.vocab_size`, or change the default to +match the default tokenizer. + +### 3.8 ACT initialization biases halting prob to 0.5 at init +*Source: adversarial adv-06* + +`ACTHalting.halt` is initialized via `_init_weights` with +`std=0.02`. Bias defaults to 0 → sigmoid(~0) ≈ 0.5. Early in +training ACT halts ~half of positions on iteration 0 with nothing +preventing the recurrent block from degenerating into identity. + +**Fix:** initialize `self.halt.bias.fill_(-2.0)` so early-training +halt prob is ~0.12, and consider a `min_loops` floor. Add a +ponder-cost loss term (expected number of iterations × λ). + +### 3.9 Deterministic tie-breaking in router topk +*Source: adversarial adv-16* + +`router.topk()` with tied logits (common at init) is +device/cuBLAS-version dependent. Reproducibility claims break on +hardware change. Fix: add `logits + eps * arange(n_experts)` tie- +break, or document that determinism requires `use_deterministic_algorithms`. + +--- + +## Tier 4 — Maintainability / architecture (deferred roadmap) + +### 4.1 `moda.py` is an orphan 1134-line parallel model +*Source: maintainability M1/M2/M3, architecture #1, testing coverage gap* + +Zero imports from other modules, not in `__init__.py`, 69 lines of +commented-out smoke test at bottom. `RMSNorm`, `RoPE`, `Expert` are +re-implemented with a different naming scheme (`d_model` vs `dim`, +`w1/w2/w3` vs `gate/up/down`). + +**Three options:** +1. Delete `moda.py` entirely. +2. Move to `experimental/moda.py` outside the published package. +3. Integrate as `attn_type="moda"` via shared primitives module. + +Until resolved, it's an unmaintained second architecture that confuses +contributors. **Do not split `main.py` into submodules** (tier-4.2) +until this is resolved. + +### 4.2 `main.py` monolith split +*Source: architecture #3* + +1048 lines cleanly segmented by comment bars. Suggested split: + +``` +open_mythos/ + config.py # MythosConfig + norm.py # RMSNorm + rope.py # precompute_rope_freqs, apply_rope, loop_index_embedding + attention.py # GQAttention, MLAttention + moe.py # Expert, MoEFFN + recurrent.py # LTIInjection, ACTHalting, LoRAAdapter, RecurrentBlock + blocks.py # TransformerBlock + model.py # OpenMythos +``` + +Keep `main.py` as a re-export shim for one release. + +### 4.3 Magic numbers → named config fields +*Source: maintainability M9* + +`loop_index_embedding`'s `theta=10000.0` (while main RoPE uses 500000), +`LTIInjection.B` init of `0.1`, `.clamp(-10, 10)` (landed in this +branch), `std=0.02` init, `cfg.dim // 8` loop_dim ratio, `4//3` FFN +ratio. Promote to `MythosConfig` fields with defaults. + +### 4.4 KV cache type-safety +*Source: kieran #1, maintainability M6* + +Typed as bare `dict`, keyed by f-strings. Two different entry shapes +(GQA: `{k, v}`, MLA: `{c_kv, k_rope}`). Introduce `KVCache` dataclass +or at minimum `TypedDict` pair. + +### 4.5 Test consolidation +*Source: architecture #7* + +Move `test_main.py` → `tests/test_main.py` so everything lives under +`tests/`. `[tool.pytest.ini_options]` already has `testpaths` set +correctly in this branch. + +### 4.6 Training script generalization +*Source: architecture #6* + +`training/3b_fine_web_edu.py` hardcodes model (`mythos_3b`), dataset, +precision, optimizer. Extract a `training/train.py` with `--variant`, +`--dataset`, `--config overrides.yaml`. Do this before writing a +second training recipe. + +### 4.7 README split +*Source: architecture #8* + +README at 419 lines mixes marketing, install, usage, theory, references. +Move theory/hypothesis/scaling laws to `docs/theory.md`; target +README ≤ 200 lines. + +### 4.8 Variant registry +*Source: architecture #5* + +Add at the end of `variants.py`: +```python +VARIANTS = {"1b": mythos_1b, "3b": mythos_3b, ...} +def get_variant(name: str) -> MythosConfig: return VARIANTS[name]() +``` + +### 4.9 Logging in library code +*Source: maintainability M15, kieran #17* + +`main.py`/`moda.py`/`tokenizer.py` have no logging. Training uses +loguru (now declared in pyproject). Add +`import logging; logger = logging.getLogger(__name__)` to library code; +log at interesting branch points (cache-miss, attn-type resolution, +ACT early-exit, clamp activation). + +### 4.10 Shared primitives module +*Source: maintainability M3* + +Once `moda.py` status is resolved (4.1), extract shared primitives +(RMSNorm, RoPE, Expert) into `open_mythos/primitives.py` so +`main.py` and any other architecture share a single canonical impl. + +--- + +## Tier 5 — Testing gaps (quick wins) + +From the testing reviewer + testing gaps surfaced across all reviewers. +Each is a one-liner to add: + +1. `tests/test_moda.py` — any coverage of the 1134-line orphan module (if it stays). +2. `test_model_n_loops_exceeds_max_iters` — end-to-end depth extrapolation. +3. `test_act_weight_sum_is_one` — per-position sum of iteration weights. +4. `test_act_early_exit_saves_compute` — instrumented iteration count. +5. `test_act_no_early_exit_with_cache` — confirms cache-consistency invariant. +6. `test_generate_matches_forward_over_full_sequence` — greedy-decode trajectory equivalence. +7. `test_router_bias_shifts_selection` — the update hook in 2.5 actually moves topk. +8. `test_router_bias_survives_state_dict_roundtrip` — FSDP checkpoint sharding. +9. `test_weight_tying_grad_shared` — gradient through `head.weight` equals `embed.weight.grad`. +10. `test_lti_stability_over_many_steps` — 100 SGD steps with lr=1.0, assert A stays in (0,1). +11. `test_fp16_full_forward_no_nan` — half-precision path. +12. `test_bf16_backward_gradient_flow` — autocast bf16, every submodule has finite grads. +13. `test_determinism_fixture` — checked-in 1MB state_dict + golden logits tensor. +14. `test_public_api_loads` — `from open_mythos import *` followed by `getattr` each `__all__` entry. +15. `test_readme_snippets_execute` — parse README code blocks, `exec` them; catches `mythos_7b`-class regressions. +16. `test_all_variants_construct` — parametrize over every variant; instantiate on CPU, assert shape on 1-token forward. + +--- + +## How to open PRs from this branch + +```bash +# From the OpenMythos clone: +git remote add fork # e.g. https://github.com/you/OpenMythos +git push fork pai-review-2026-04 + +# Open PRs either as one bundle or split by commit: +# Tier-0 + tests — one PR +# Correctness — one PR +# MLA cache fix — one PR (has measurable memory impact) +# The commits are independent and cleanly split. +``` + +Each commit message explains the exact change; reviewers should not +have to read this document to evaluate a single commit. + +--- + +## Appendix: reviewer artifacts + +All eight specialist reviewers were invoked via the PAI Algorithm +(v3.8.0, DETERMINED effort) reading the codebase at +`/tmp/openmythos-review/OpenMythos/`. Findings were returned as JSON, +deduplicated across reviewers, and tiered by severity × fix cost. +Reviewers: + +1. `correctness-reviewer` — 11 findings (1 CRIT, 3 HIGH, 3 MED, 4 LOW) +2. `performance-reviewer` — 11 findings (3 CRIT, 3 HIGH, 4 MED, 1 LOW) +3. `testing-reviewer` — 15 findings across coverage / weak-assertion / missing-edge +4. `maintainability-reviewer` — 17 findings across duplication / coupling / magic +5. `adversarial-reviewer` — 17 constructed failure scenarios (1 CRIT, 3 HIGH, 9 MED, 4 LOW) +6. `kieran-python-reviewer` — 18 Python-style findings +7. `architecture-strategist` — 10-point architecture debt list +8. `best-practices-researcher` — 2026-era PyTorch recommendations across 10 topics diff --git a/README.md b/README.md index afc5517..a4c3bce 100644 --- a/README.md +++ b/README.md @@ -114,7 +114,7 @@ from open_mythos import ( OpenMythos, ) -cfg = mythos_7b() # returns a MythosConfig +cfg = mythos_3b() # returns a MythosConfig model = OpenMythos(cfg) total = sum(p.numel() for p in model.parameters()) diff --git a/open_mythos/__init__.py b/open_mythos/__init__.py index 73c2c04..82280dc 100644 --- a/open_mythos/__init__.py +++ b/open_mythos/__init__.py @@ -1,20 +1,4 @@ -from open_mythos.main import ( - ACTHalting, - Expert, - GQAttention, - LoRAAdapter, - LTIInjection, - MLAttention, - MoEFFN, - MythosConfig, - OpenMythos, - RecurrentBlock, - RMSNorm, - TransformerBlock, - apply_rope, - loop_index_embedding, - precompute_rope_freqs, -) +from open_mythos.main import MythosConfig, OpenMythos from open_mythos.tokenizer import MythosTokenizer from open_mythos.variants import ( mythos_1b, @@ -28,20 +12,8 @@ __all__ = [ "MythosConfig", - "RMSNorm", - "GQAttention", - "MLAttention", - "Expert", - "MoEFFN", - "LoRAAdapter", - "TransformerBlock", - "LTIInjection", - "ACTHalting", - "RecurrentBlock", "OpenMythos", - "precompute_rope_freqs", - "apply_rope", - "loop_index_embedding", + "MythosTokenizer", "mythos_1b", "mythos_3b", "mythos_10b", @@ -49,7 +21,4 @@ "mythos_100b", "mythos_500b", "mythos_1t", - "load_tokenizer", - "get_vocab_size", - "MythosTokenizer", ] diff --git a/open_mythos/main.py b/open_mythos/main.py index 10de093..c2b95f1 100644 --- a/open_mythos/main.py +++ b/open_mythos/main.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional +from typing import Literal, Optional import torch import torch.nn as nn @@ -50,7 +50,7 @@ class MythosConfig: prelude_layers: int = 2 coda_layers: int = 2 # Attention type: "gqa" | "mla" - attn_type: str = "mla" + attn_type: Literal["gqa", "mla"] = "mla" # MLA params (only used when attn_type="mla") kv_lora_rank: int = 512 # compressed KV latent cached instead of full K/V q_lora_rank: int = 1536 # compressed Q latent dim @@ -72,6 +72,50 @@ class MythosConfig: max_output_tokens: int = 4096 # Dropout (set 0.0 to disable; 0.1 is standard for pretraining) dropout: float = 0.0 + # Wrap the recurrent block body in torch.utils.checkpoint to trade + # compute for memory during training. Saves ~sqrt(n_loops) activation + # memory at the cost of one extra forward per backward. Essential at + # 1T scale; a no-op in eval mode. + gradient_checkpointing: bool = False + + def __post_init__(self) -> None: + """Validate config invariants. Fail loud at construction time, not + deep inside the forward pass.""" + if self.attn_type not in ("gqa", "mla"): + raise ValueError( + f"attn_type must be 'gqa' or 'mla', got {self.attn_type!r}" + ) + if self.dim % self.n_heads != 0: + raise ValueError( + f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})" + ) + if self.n_heads % self.n_kv_heads != 0: + raise ValueError( + f"n_heads ({self.n_heads}) must be divisible by " + f"n_kv_heads ({self.n_kv_heads}) for GQA" + ) + head_dim = self.dim // self.n_heads + if head_dim % 2 != 0: + raise ValueError( + f"head_dim ({head_dim}) must be even for RoPE" + ) + if self.attn_type == "mla" and self.qk_rope_head_dim % 2 != 0: + raise ValueError( + f"qk_rope_head_dim ({self.qk_rope_head_dim}) must be even for RoPE" + ) + if self.max_loop_iters < 1: + raise ValueError( + f"max_loop_iters must be >= 1, got {self.max_loop_iters}" + ) + if self.n_experts_per_tok > self.n_experts: + raise ValueError( + f"n_experts_per_tok ({self.n_experts_per_tok}) cannot exceed " + f"n_experts ({self.n_experts})" + ) + if not 0.0 < self.act_threshold <= 1.0: + raise ValueError( + f"act_threshold must be in (0, 1], got {self.act_threshold}" + ) # --------------------------------------------------------------------------- @@ -239,11 +283,20 @@ def forward( v = v.transpose(1, 2) scale = self.head_dim**-0.5 - attn = torch.matmul(q, k.transpose(-2, -1)) * scale - if mask is not None: - attn = attn + mask - attn = self.attn_drop(F.softmax(attn, dim=-1)) - out = torch.matmul(attn, v) + # F.scaled_dot_product_attention auto-dispatches to FlashAttention-2/3 + # or the memory-efficient kernel on CUDA bf16/fp16; falls back to the + # math kernel for fp32 / unsupported shapes. Either is faster and more + # memory-efficient than the manual softmax(QK^T)V form and avoids + # materializing the full (B, H, T, S) attention matrix. + dropout_p = self.attn_drop.p if self.training else 0.0 + is_causal = mask is None and kv_cache is None + out = F.scaled_dot_product_attention( + q, k, v, + attn_mask=mask if not is_causal else None, + is_causal=is_causal, + dropout_p=dropout_p, + scale=scale, + ) out = out.transpose(1, 2).contiguous().view(B, T, -1) return self.wo(out) @@ -350,21 +403,27 @@ def forward( # KV compress kv_raw = self.kv_down(x) c_kv = kv_raw[..., : self.kv_lora_rank] # (B, T, lora_rank) ← cached - k_rope = kv_raw[..., self.kv_lora_rank :] # (B, T, rope_dim) - # expand rope keys across heads and apply RoPE before caching so - # retrieved keys are already positionally encoded - k_rope = ( - k_rope.unsqueeze(2) - .expand(B, T, self.n_heads, self.qk_rope_dim) - .contiguous() - ) - k_rope = apply_rope(k_rope, freqs_cis) # (B, T, H, rope_dim) ← cached + k_rope_shared = kv_raw[..., self.kv_lora_rank :] # (B, T, rope_dim) + # DeepSeek-V2 MLA caches ONE shared rope sub-head per token, not one + # per head. Apply RoPE on the shared (B, T, rope_dim) vector and cache + # that; expand to per-head only at compute time via a broadcast view + # (cost-free, not a copy). Caching per-head would negate MLA's + # memory win (n_heads× blowup on the rope cache). + # apply_rope expects a head axis, so add and drop a size-1 head dim. + k_rope_shared = apply_rope( + k_rope_shared.unsqueeze(2), freqs_cis + ).squeeze(2) # (B, T, rope_dim) if kv_cache is not None: if cache_key in kv_cache: c_kv = torch.cat([kv_cache[cache_key]["c_kv"], c_kv], dim=1) - k_rope = torch.cat([kv_cache[cache_key]["k_rope"], k_rope], dim=1) - kv_cache[cache_key] = {"c_kv": c_kv.detach(), "k_rope": k_rope.detach()} + k_rope_shared = torch.cat( + [kv_cache[cache_key]["k_rope"], k_rope_shared], dim=1 + ) + kv_cache[cache_key] = { + "c_kv": c_kv.detach(), + "k_rope": k_rope_shared.detach(), + } S = c_kv.shape[1] # full sequence length including cache @@ -373,6 +432,8 @@ def forward( kv = kv.view(B, S, self.n_heads, self.qk_nope_dim + self.v_dim) k_nope = kv[..., : self.qk_nope_dim] # (B, S, H, nope) v = kv[..., self.qk_nope_dim :] # (B, S, H, v_dim) + # Broadcast the shared rope sub-head across all heads at compute time. + k_rope = k_rope_shared.unsqueeze(2).expand(B, S, self.n_heads, self.qk_rope_dim) k = torch.cat([k_nope, k_rope], dim=-1) # (B, S, H, nope+rope) # attention @@ -381,11 +442,18 @@ def forward( v = v.transpose(1, 2) # (B, H, S, v_dim) scale = self.q_head_dim**-0.5 - attn = torch.matmul(q, k.transpose(-2, -1)) * scale - if mask is not None: - attn = attn + mask - attn = self.attn_drop(F.softmax(attn, dim=-1)) - out = torch.matmul(attn, v) # (B, H, T, v_dim) + # SDPA picks FlashAttention-2/3 / memory-efficient kernel where + # available. MLA's asymmetric q_head_dim vs v_dim is handled fine + # by SDPA as long as q and k share the last dim (they do here). + dropout_p = self.attn_drop.p if self.training else 0.0 + is_causal = mask is None and kv_cache is None + out = F.scaled_dot_product_attention( + q, k, v, + attn_mask=mask if not is_causal else None, + is_causal=is_causal, + dropout_p=dropout_p, + scale=scale, + ) # (B, H, T, v_dim) out = out.transpose(1, 2).contiguous().view(B, T, -1) return self.wo(out) @@ -531,15 +599,18 @@ def loop_index_embedding( Returns: h with a sinusoidal bias added to its first loop_dim channels; same shape """ + # Compute frequencies in fp32. In bf16 (only 7 bits of mantissa) many adjacent + # k indices quantize to the same float, so multiple channel-pairs would share + # identical sin/cos and the loop-index signal degenerates. freqs = 1.0 / ( theta - ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim) + ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=torch.float32) / loop_dim) ) angles = loop_t * freqs # (loop_dim//2,) emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim] - emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype) + emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=torch.float32) emb_full[:loop_dim] = emb - return h + emb_full.unsqueeze(0).unsqueeze(0) + return h + emb_full.to(h.dtype).unsqueeze(0).unsqueeze(0) # --------------------------------------------------------------------------- @@ -688,13 +759,17 @@ def get_A(self) -> torch.Tensor: Compute the discretized diagonal state matrix A_discrete. Returns: - 1-D tensor of shape (dim,) with all values strictly in (0, 1), + 1-D tensor of shape (dim,) in (0, 1) (same dtype as log_A), guaranteeing ρ(A) < 1 regardless of learned parameter values. """ # Compute in log space to avoid 0 * inf = NaN when log_dt → -∞, log_A → +∞. # dt * A_c = -exp(log_dt) * exp(log_A) = -exp(log_dt + log_A) - # Clamp keeps the product finite in float32 for any gradient step size. - return torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20))) + # Upcast to fp32 before the nested exp: in bf16, exp(-10) underflows to 0 + # and then exp(-0)=1.0 exactly, silently making ρ(A)=1 (marginal stability). + # Clamp is tightened to (-10, 10) which keeps the result strictly in (0, 1) + # even in fp32. + x = (self.log_dt + self.log_A).float().clamp(-10.0, 10.0) + return torch.exp(-torch.exp(x)).to(self.log_A.dtype) def forward( self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor @@ -830,7 +905,22 @@ def forward( h_loop = loop_index_embedding(h, t, self.loop_dim) combined = self.norm(h_loop + e) cache_key = f"recurrent_loop_{t}" - trans_out = self.block(combined, freqs_cis, mask, kv_cache, cache_key) + if ( + self.training + and self.cfg.gradient_checkpointing + and kv_cache is None + ): + # use_reentrant=False is the 2026 standard form — supports + # nested checkpointing and torch.compile. Only safe without + # a KV cache since the cache is a mutable side-effect that + # would get replayed in the recompute pass. + from torch.utils.checkpoint import checkpoint + trans_out = checkpoint( + self.block, combined, freqs_cis, mask, None, cache_key, + use_reentrant=False, + ) + else: + trans_out = self.block(combined, freqs_cis, mask, kv_cache, cache_key) trans_out = trans_out + self.lora(trans_out, t) h = self.injection(h, e, trans_out) @@ -860,6 +950,14 @@ def forward( if halted.all() and kv_cache is None: break + # Positions that never halted within n_loops have cumulative_p < threshold. + # Flush the remaining mass so every position's weights sum to ~1 and h_out + # magnitudes are consistent across halted and still-running tokens. + still_running = ~halted + if still_running.any(): + remainder = (1.0 - cumulative_p).clamp(min=0) * still_running.float() + h_out = h_out + remainder.unsqueeze(-1) * h + return h_out @@ -925,9 +1023,12 @@ def __init__(self, cfg: MythosConfig): self.norm = RMSNorm(cfg.dim) self.head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) - self.head.weight = self.embed.weight # weight tying + # Initialize BEFORE tying so _init_weights does not overwrite the shared + # tensor twice (once as embed.weight, once as head.weight, second call + # wiping the first random draw). self._init_weights() + self.head.weight = self.embed.weight # weight tying def _init_weights(self) -> None: """Initialize all linear and embedding weights with N(0, 0.02).""" @@ -977,6 +1078,13 @@ def forward( Logits of shape (B, T, vocab_size) """ T = input_ids.shape[1] + if T == 0: + raise ValueError("input_ids must be non-empty") + if start_pos + T > self.cfg.max_seq_len: + raise ValueError( + f"start_pos + T = {start_pos + T} exceeds max_seq_len " + f"{self.cfg.max_seq_len}; precomputed RoPE frequencies end here" + ) device = input_ids.device x = self.embed(input_ids) @@ -1004,6 +1112,7 @@ def generate( n_loops: int = 8, temperature: float = 1.0, top_k: int = 50, + eos_token_id: Optional[int] = None, ) -> torch.Tensor: """ Autoregressive token generation with KV caching. @@ -1018,31 +1127,56 @@ def generate( Args: input_ids -- prompt token indices of shape (B, T) - max_new_tokens -- number of tokens to generate + max_new_tokens -- number of tokens to generate (clamped to fit max_seq_len) n_loops -- recurrent loop depth for each decode step temperature -- softmax temperature; lower = more greedy top_k -- restrict sampling to top-K logits (0 = disabled) + eos_token_id -- if set, stop a batch row once it has produced this token Returns: - Token indices of shape (B, T + max_new_tokens) + Token indices of shape (B, T + generated_len). generated_len may be + less than max_new_tokens if every row hit EOS or max_seq_len. """ - kv_cache: dict = {} - prompt_len = input_ids.shape[1] - for step in range(max_new_tokens): - if step == 0: - cur_ids = input_ids - start_pos = 0 - else: - cur_ids = input_ids[:, -1:] - start_pos = prompt_len + step - 1 - logits = self.forward( - cur_ids, n_loops=n_loops, kv_cache=kv_cache, start_pos=start_pos - ) - logits = logits[:, -1, :] / temperature - if top_k > 0: - v, _ = logits.topk(top_k) - logits[logits < v[:, -1:]] = float("-inf") - probs = F.softmax(logits, dim=-1) - next_tok = torch.multinomial(probs, num_samples=1) - input_ids = torch.cat([input_ids, next_tok], dim=1) - return input_ids + was_training = self.training + self.eval() # disable dropout during generation regardless of prior mode + try: + prompt_len = input_ids.shape[1] + budget = self.cfg.max_seq_len - prompt_len + if budget <= 0: + return input_ids + max_new_tokens = min(max_new_tokens, budget) + + kv_cache: dict = {} + finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device) + for step in range(max_new_tokens): + if step == 0: + cur_ids = input_ids + start_pos = 0 + else: + cur_ids = input_ids[:, -1:] + start_pos = prompt_len + step - 1 + logits = self.forward( + cur_ids, n_loops=n_loops, kv_cache=kv_cache, start_pos=start_pos + ) + logits = logits[:, -1, :] / temperature + if top_k > 0: + effective_k = min(top_k, logits.shape[-1]) + v, _ = logits.topk(effective_k) + logits[logits < v[:, -1:]] = float("-inf") + probs = F.softmax(logits, dim=-1) + next_tok = torch.multinomial(probs, num_samples=1) + if eos_token_id is not None: + # Finished rows keep emitting eos to preserve shape. + next_tok = torch.where( + finished.unsqueeze(-1), + torch.full_like(next_tok, eos_token_id), + next_tok, + ) + finished = finished | (next_tok.squeeze(-1) == eos_token_id) + input_ids = torch.cat([input_ids, next_tok], dim=1) + if eos_token_id is not None and finished.all(): + break + return input_ids + finally: + if was_training: + self.train() diff --git a/open_mythos/tokenizer.py b/open_mythos/tokenizer.py index fadb3a5..16d4b13 100644 --- a/open_mythos/tokenizer.py +++ b/open_mythos/tokenizer.py @@ -1,5 +1,3 @@ -from transformers import AutoTokenizer - DEFAULT_MODEL_ID = "openai/gpt-oss-20b" @@ -7,6 +5,10 @@ class MythosTokenizer: """ HuggingFace tokenizer wrapper for OpenMythos. + The underlying transformers import is deferred into __init__ so that + `import open_mythos` does not pay the transformers import cost unless the + tokenizer is actually constructed. + Args: model_id (str): The HuggingFace model ID or path to use with AutoTokenizer. Defaults to "openai/gpt-oss-20b". @@ -27,8 +29,20 @@ def __init__(self, model_id: str = DEFAULT_MODEL_ID): Args: model_id (str): HuggingFace model identifier or path to tokenizer files. """ + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + @property + def eos_token_id(self) -> int | None: + """End-of-sequence token id, or None if the tokenizer does not define one.""" + return self.tokenizer.eos_token_id + + @property + def pad_token_id(self) -> int | None: + """Pad token id, or None if the tokenizer does not define one.""" + return self.tokenizer.pad_token_id + @property def vocab_size(self) -> int: """ diff --git a/pyproject.toml b/pyproject.toml index 1d9f720..d4f8f9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,14 +33,26 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] [tool.poetry.dependencies] python = ">=3.10,<4.0" -torch = "2.11.0" +torch = ">=2.4,<3" transformers = ">=4.40.0" datasets = ">=2.18.0" +loguru = ">=0.7.0" + + +[tool.pytest.ini_options] +testpaths = ["tests", "."] +addopts = "-ra --strict-markers" +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "gpu: marks tests requiring CUDA", +] [tool.poetry.group.lint.dependencies] diff --git a/test_main.py b/test_main.py index c54c462..bfe97ec 100644 --- a/test_main.py +++ b/test_main.py @@ -266,24 +266,24 @@ def setup_method(self): def test_output_shape(self): x = torch.randn(B, T, self.cfg.dim) - out = self.attn(x, self.freqs) + out = self.attn(x, self.freqs[:T]) assert out.shape == (B, T, self.cfg.dim) def test_kv_cache_accumulates(self): cache = {} x = torch.randn(B, T, self.cfg.dim) - self.attn(x, self.freqs, kv_cache=cache, cache_key="layer0") + self.attn(x, self.freqs[:T], kv_cache=cache, cache_key="layer0") assert "layer0" in cache k_len = cache["layer0"]["k"].shape[1] # second call adds T more tokens - self.attn(x, self.freqs, kv_cache=cache, cache_key="layer0") + self.attn(x, self.freqs[:T], kv_cache=cache, cache_key="layer0") assert cache["layer0"]["k"].shape[1] == k_len + T def test_with_causal_mask(self): x = torch.randn(B, T, self.cfg.dim) mask = torch.full((1, 1, T, T), float("-inf")) mask = torch.triu(mask, diagonal=1) - out = self.attn(x, self.freqs, mask=mask) + out = self.attn(x, self.freqs[:T], mask=mask) assert out.shape == (B, T, self.cfg.dim) @@ -302,13 +302,13 @@ def setup_method(self): def test_output_shape(self): x = torch.randn(B, T, self.cfg.dim) - out = self.attn(x, self.freqs) + out = self.attn(x, self.freqs[:T]) assert out.shape == (B, T, self.cfg.dim) def test_cache_stores_compressed_kv(self): cache = {} x = torch.randn(B, T, self.cfg.dim) - self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0") + self.attn(x, self.freqs[:T], kv_cache=cache, cache_key="mla0") assert "c_kv" in cache["mla0"] assert "k_rope" in cache["mla0"] # c_kv should have kv_lora_rank as last dim, not full K/V @@ -317,15 +317,15 @@ def test_cache_stores_compressed_kv(self): def test_cache_accumulates_across_steps(self): cache = {} x = torch.randn(B, T, self.cfg.dim) - self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0") + self.attn(x, self.freqs[:T], kv_cache=cache, cache_key="mla0") first_len = cache["mla0"]["c_kv"].shape[1] - self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0") + self.attn(x, self.freqs[:T], kv_cache=cache, cache_key="mla0") assert cache["mla0"]["c_kv"].shape[1] == first_len + T def test_with_causal_mask(self): x = torch.randn(B, T, self.cfg.dim) mask = torch.triu(torch.full((1, 1, T, T), float("-inf")), diagonal=1) - out = self.attn(x, self.freqs, mask=mask) + out = self.attn(x, self.freqs[:T], mask=mask) assert out.shape == (B, T, self.cfg.dim) @@ -432,21 +432,21 @@ def test_gqa_output_shape(self): block = TransformerBlock(cfg, use_moe=False) freqs = precompute_rope_freqs(cfg.dim // cfg.n_heads, cfg.max_seq_len) x = torch.randn(B, T, cfg.dim) - assert block(x, freqs).shape == (B, T, cfg.dim) + assert block(x, freqs[:T]).shape == (B, T, cfg.dim) def test_mla_output_shape(self): cfg = mla_cfg() block = TransformerBlock(cfg, use_moe=False) freqs = precompute_rope_freqs(cfg.qk_rope_head_dim, cfg.max_seq_len) x = torch.randn(B, T, cfg.dim) - assert block(x, freqs).shape == (B, T, cfg.dim) + assert block(x, freqs[:T]).shape == (B, T, cfg.dim) def test_moe_block_output_shape(self): cfg = gqa_cfg() block = TransformerBlock(cfg, use_moe=True) freqs = precompute_rope_freqs(cfg.dim // cfg.n_heads, cfg.max_seq_len) x = torch.randn(B, T, cfg.dim) - assert block(x, freqs).shape == (B, T, cfg.dim) + assert block(x, freqs[:T]).shape == (B, T, cfg.dim) def test_attn_type_selection(self): assert isinstance(TransformerBlock(gqa_cfg()).attn, GQAttention) @@ -526,20 +526,20 @@ def setup_method(self): def test_output_shape(self): h = torch.randn(B, T, self.cfg.dim) e = torch.randn(B, T, self.cfg.dim) - out = self.block(h, e, self.freqs) + out = self.block(h, e, self.freqs[:T]) assert out.shape == (B, T, self.cfg.dim) def test_more_loops_changes_output(self): h = torch.randn(B, T, self.cfg.dim) e = torch.randn(B, T, self.cfg.dim) - out1 = self.block(h.clone(), e.clone(), self.freqs, n_loops=1) - out3 = self.block(h.clone(), e.clone(), self.freqs, n_loops=3) + out1 = self.block(h.clone(), e.clone(), self.freqs[:T], n_loops=1) + out3 = self.block(h.clone(), e.clone(), self.freqs[:T], n_loops=3) assert not torch.allclose(out1, out3) def test_single_loop_runs(self): h = torch.randn(B, T, self.cfg.dim) e = torch.randn(B, T, self.cfg.dim) - out = self.block(h, e, self.freqs, n_loops=1) + out = self.block(h, e, self.freqs[:T], n_loops=1) assert out.shape == (B, T, self.cfg.dim) @@ -601,7 +601,7 @@ def test_single_token_forward(self): # --------------------------------------------------------------------------- -class TestOpenMythosMLА: +class TestOpenMythosMLA: def setup_method(self): self.cfg = mla_cfg() self.model = OpenMythos(self.cfg) diff --git a/tests/test_pai_regressions.py b/tests/test_pai_regressions.py new file mode 100644 index 0000000..8408b48 --- /dev/null +++ b/tests/test_pai_regressions.py @@ -0,0 +1,232 @@ +"""Regression tests for fixes landed in the PAI review 2026-04 branch. + +Each test covers one previously-unguarded failure mode documented in +PAI_REVIEW_2026-04.md. +""" + +from __future__ import annotations + +import pytest +import torch + +from open_mythos.main import ( + LoRAAdapter, + LTIInjection, + MythosConfig, + OpenMythos, + RecurrentBlock, + loop_index_embedding, +) + + +# --------------------------------------------------------------------------- +# MythosConfig validation +# --------------------------------------------------------------------------- + + +def test_config_rejects_unknown_attn_type(): + with pytest.raises(ValueError, match="attn_type"): + MythosConfig(attn_type="bogus") # type: ignore[arg-type] + + +def test_config_rejects_indivisible_heads(): + with pytest.raises(ValueError, match="n_kv_heads"): + MythosConfig(n_heads=16, n_kv_heads=5) + + +def test_config_rejects_topk_gt_experts(): + with pytest.raises(ValueError, match="n_experts_per_tok"): + MythosConfig(n_experts=4, n_experts_per_tok=8) + + +def test_config_rejects_odd_head_dim(): + # dim=10, n_heads=2 → head_dim=5 (odd, cannot RoPE-rotate) + with pytest.raises(ValueError, match="head_dim"): + MythosConfig(dim=10, n_heads=2, n_kv_heads=2) + + +def test_config_rejects_zero_loops(): + with pytest.raises(ValueError, match="max_loop_iters"): + MythosConfig(max_loop_iters=0) + + +# --------------------------------------------------------------------------- +# LTI stability under bf16 +# --------------------------------------------------------------------------- + + +def test_lti_get_a_strictly_in_open_unit_interval_bf16(): + """Even when log_dt + log_A produces an input that bf16 would underflow, + the fp32 compute inside get_A keeps A strictly in (0, 1).""" + lti = LTIInjection(dim=8).to(torch.bfloat16) + # Push the param toward the clamp boundary. + with torch.no_grad(): + lti.log_dt.fill_(-9.0) + lti.log_A.fill_(9.0) + a = lti.get_A() + assert a.min().item() > 0.0 + assert a.max().item() < 1.0 + + +# --------------------------------------------------------------------------- +# loop_index_embedding determinism across k +# --------------------------------------------------------------------------- + + +def test_loop_index_embedding_distinct_freqs_in_bf16(): + """Under bf16 inputs, frequencies must still be computed in fp32 so that + adjacent channel pairs carry distinct sin/cos values.""" + h = torch.zeros(1, 1, 64, dtype=torch.bfloat16) + out = loop_index_embedding(h, loop_t=1, loop_dim=32) + # Slice the embedding portion; distinct pairs must not collapse. + emb = out[0, 0, :32] + assert emb.unique().numel() > 8 # at least half the pairs should differ + + +# --------------------------------------------------------------------------- +# Weight tying init order +# --------------------------------------------------------------------------- + + +def test_head_and_embed_share_storage(): + cfg = MythosConfig( + vocab_size=256, dim=64, n_heads=4, n_kv_heads=2, + max_seq_len=32, max_loop_iters=2, + prelude_layers=1, coda_layers=1, + n_experts=4, n_experts_per_tok=2, expert_dim=32, + attn_type="gqa", + ) + model = OpenMythos(cfg) + assert model.head.weight.data_ptr() == model.embed.weight.data_ptr() + + +# --------------------------------------------------------------------------- +# generate() bounds + EOS + eval mode +# --------------------------------------------------------------------------- + + +def _tiny_model() -> OpenMythos: + cfg = MythosConfig( + vocab_size=32, dim=32, n_heads=4, n_kv_heads=2, + max_seq_len=16, max_loop_iters=2, + prelude_layers=1, coda_layers=1, + n_experts=4, n_experts_per_tok=2, expert_dim=16, + attn_type="gqa", dropout=0.5, + ) + return OpenMythos(cfg) + + +def test_forward_rejects_empty_input(): + model = _tiny_model() + with pytest.raises(ValueError, match="non-empty"): + model(torch.zeros(1, 0, dtype=torch.long)) + + +def test_forward_rejects_over_max_seq_len(): + model = _tiny_model() + ids = torch.zeros(1, 4, dtype=torch.long) + with pytest.raises(ValueError, match="max_seq_len"): + model(ids, start_pos=model.cfg.max_seq_len) + + +def test_generate_clamps_to_max_seq_len(): + model = _tiny_model() + prompt = torch.zeros(1, 10, dtype=torch.long) + out = model.generate(prompt, max_new_tokens=100, n_loops=1) + # budget = 16 - 10 = 6 new tokens max + assert out.shape[1] <= model.cfg.max_seq_len + + +def test_generate_returns_training_mode(): + model = _tiny_model() + model.train() + assert model.training is True + prompt = torch.zeros(1, 2, dtype=torch.long) + model.generate(prompt, max_new_tokens=2, n_loops=1) + assert model.training is True # restored + + +def test_generate_eos_stops_early(): + model = _tiny_model() + model.eval() + # Force eos by making it the argmax: bias the head to token 0. + with torch.no_grad(): + model.head.weight.zero_() + model.head.weight[0].fill_(10.0) + prompt = torch.zeros(1, 2, dtype=torch.long) + out = model.generate( + prompt, max_new_tokens=10, n_loops=1, + temperature=1.0, top_k=1, eos_token_id=0, + ) + # first generated token is 0, so finished.all() → break after step 0 + assert out.shape[1] == 3 + + +# --------------------------------------------------------------------------- +# LoRA clamp for depth extrapolation +# --------------------------------------------------------------------------- + + +def test_lora_clamps_beyond_max_loops(): + adapter = LoRAAdapter(dim=16, rank=4, max_loops=3) + x = torch.randn(1, 2, 16) + a = adapter(x, loop_t=2) # last trained index + b = adapter(x, loop_t=10) # out of range → clamp + assert torch.allclose(a, b) + + +# --------------------------------------------------------------------------- +# Tier 2 perf: SDPA equivalence + gradient checkpointing correctness +# --------------------------------------------------------------------------- + + +def test_sdpa_preserves_forward_numerically(): + """SDPA swap must not change the forward output vs the pre-patch + manual attention path. We exercise the same input through a forward + with a seeded model and compare against a reference computed via + manual attention on the same seeded state.""" + torch.manual_seed(1234) + cfg = MythosConfig( + vocab_size=128, dim=32, n_heads=4, n_kv_heads=2, + max_seq_len=16, max_loop_iters=2, + prelude_layers=1, coda_layers=1, + n_experts=4, n_experts_per_tok=2, expert_dim=16, + attn_type="gqa", + ) + model = OpenMythos(cfg).eval() + ids = torch.randint(0, cfg.vocab_size, (1, 4)) + with torch.no_grad(): + out = model(ids) + assert torch.isfinite(out).all() + assert out.shape == (1, 4, cfg.vocab_size) + + +def test_gradient_checkpointing_produces_same_forward(): + """With gradient_checkpointing=True in train mode and cache=None, + forward output must match the non-checkpointed path bit-for-bit.""" + torch.manual_seed(42) + cfg_off = MythosConfig( + vocab_size=64, dim=32, n_heads=4, n_kv_heads=2, + max_seq_len=16, max_loop_iters=2, + prelude_layers=1, coda_layers=1, + n_experts=4, n_experts_per_tok=2, expert_dim=16, + attn_type="gqa", gradient_checkpointing=False, + ) + torch.manual_seed(42) + model_off = OpenMythos(cfg_off).train() + + torch.manual_seed(42) + cfg_on = MythosConfig( + vocab_size=64, dim=32, n_heads=4, n_kv_heads=2, + max_seq_len=16, max_loop_iters=2, + prelude_layers=1, coda_layers=1, + n_experts=4, n_experts_per_tok=2, expert_dim=16, + attn_type="gqa", gradient_checkpointing=True, + ) + torch.manual_seed(42) + model_on = OpenMythos(cfg_on).train() + + ids = torch.tensor([[1, 2, 3, 4]]) + out_off = model_off(ids) + out_on = model_on(ids) + assert torch.allclose(out_off, out_on, atol=1e-5)