36 changes: 36 additions & 0 deletions .github/copilot-instructions.md
@@ -0,0 +1,36 @@
# Copilot Instructions

## Build, test, and lint commands

- Install dependencies for local work with `poetry install --with dev`. The repo also keeps a minimal `requirements.txt` for `pip install -r requirements.txt`.
- Build the package with `poetry build`.
- Run the main unit suite with `poetry run pytest tests/test_main.py`.
- Run a single test with `poetry run pytest tests/test_main.py::TestOpenMythosGQA::test_generate_shape`.
- Run tokenizer tests with `poetry run pytest tests/test_tokenizer.py`. These load the default Hugging Face tokenizer (`openai/gpt-oss-20b`), so they are heavier than `tests/test_main.py`.
- Run the RoPE debug script with `poetry run python tests/test_rope_debug.py`.
- Lint with `poetry run ruff check .`.
- Check formatting with `poetry run black --check .`.
- The documented training entrypoints are `poetry run python training/3b_fine_web_edu.py` and `poetry run torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") training/3b_fine_web_edu.py`.
- Benchmark utilities are executable scripts under `tests/`, not pytest files: `poetry run python tests/small_benchmark.py` and `poetry run python tests/bench_vs_transformer.py`.

## High-level architecture

- `open_mythos/config.py` defines `MythosConfig`. `open_mythos/main.py` holds the primary implementation: both attention backends, the MoE FFN, the recurrent block, and the top-level `OpenMythos` model.
- The model pipeline is fixed: token embedding -> Prelude (`cfg.prelude_layers` dense `TransformerBlock`s run once) -> Recurrent Block (one shared `TransformerBlock` looped `n_loops` times with LoRA, ACT halting, and LTI-stable input injection) -> Coda (`cfg.coda_layers` dense `TransformerBlock`s run once) -> `RMSNorm` -> tied LM head. A minimal sketch of this flow follows this list.
- The key architectural invariant is that the encoded input `e` is frozen after the Prelude and injected on every recurrent iteration. The recurrent update path is `loop_index_embedding -> TransformerBlock(use_moe=True) -> LoRAAdapter -> LTIInjection -> ACTHalting`.
- `cfg.attn_type` switches the whole attention/cache path. `GQAttention` stores full K/V by layer, while `MLAttention` stores compressed `c_kv` plus `k_rope` and reconstructs K/V on demand. `OpenMythos.forward()` also switches RoPE buffers based on `attn_type`.
- `open_mythos/variants.py` contains the named model scales (`mythos_1b` through `mythos_1t`) as `MythosConfig` factories. `open_mythos/tokenizer.py` is a small wrapper around `transformers.AutoTokenizer` with `openai/gpt-oss-20b` as the default tokenizer.
- `training/3b_fine_web_edu.py` is a standalone FSDP training script for the 3B variant. It streams FineWeb-Edu shards, uses `MythosTokenizer`, and keeps checkpointing/training logic out of the library module.
- `docs/open_mythos.md` is the detailed API and architecture reference. `docs/datasets.md` holds the training dataset recommendations and token-budget guidance referenced by the training script/README.
- `open_mythos/moda.py` is a separate experimental MoDA + MoE implementation, not part of the main `OpenMythos` export surface.
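
The following is a minimal, illustrative sketch of that fixed pipeline, not the real forward signature: attribute names such as `embed`, `prelude`, `recurrent`, `coda`, `norm`, and `head` are assumed for readability, and the real blocks in `open_mythos/main.py` also take RoPE buffers, masks, `start_pos`, and an optional KV cache.

```python
import torch


def forward_sketch(model, tokens: torch.Tensor, n_loops: int) -> torch.Tensor:
    # Prelude: cfg.prelude_layers dense TransformerBlocks, run once.
    h = model.embed(tokens)
    for block in model.prelude:
        h = block(h)

    e = h  # frozen encoded input, injected on every recurrent iteration
    s = h

    # Recurrent block: one shared TransformerBlock looped n_loops times.
    # Inside the real block the update path is loop_index_embedding ->
    # TransformerBlock(use_moe=True) -> LoRAAdapter -> LTIInjection -> ACTHalting.
    for t in range(n_loops):
        s = model.recurrent(s, e, loop_index=t)

    # Coda: cfg.coda_layers dense TransformerBlocks, run once.
    for block in model.coda:
        s = block(s)

    return model.head(model.norm(s))  # RMSNorm, then tied LM head
```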

## Key conventions

- For tests, smoke runs, and examples that should execute quickly, use the tiny helper configs from `tests/test_main.py` (`gqa_cfg()` / `mla_cfg()`) or similarly small custom configs (see the config sketch after this list). `MythosConfig()` defaults to a much larger research model.
- Preserve decode-position semantics when touching inference code: `start_pos` chooses the RoPE slice for incremental decoding, and each cache entry uses a deterministic key (`prelude_{i}`, `recurrent_loop_{t}`, `coda_{i}`).
- Do not short-circuit recurrent execution when a KV cache is active. `RecurrentBlock` only exits early on ACT convergence when `kv_cache is None`; cached decoding relies on every loop depth writing its cache entry on every step.
- Keep the embedding/LM head weight tying intact (`self.head.weight = self.embed.weight`), and keep `router_bias` as a registered buffer rather than a trainable parameter.
- Prelude and Coda are always dense FFN blocks (`use_moe=False`); the recurrent block is the only place that uses the MoE FFN (`use_moe=True`).
- Named variants and default configs are MLA-first. If you switch code or tests to GQA, make sure the config still provides valid MLA fields because shared config helpers are reused across both modes.
- Use the variant helpers that actually exist in `open_mythos/variants.py`: `mythos_1b`, `mythos_3b`, `mythos_10b`, `mythos_50b`, `mythos_100b`, `mythos_500b`, and `mythos_1t`. The README currently shows a stale `mythos_7b()` example.
- The benchmark scripts live under `tests/` and are intended to be run directly as scripts. Their docstrings still reference a `benchmarks/` path that does not exist in this repository.
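
As a concrete illustration of the tiny-config convention, here is a hypothetical small config in the spirit of `gqa_cfg()`/`mla_cfg()`; the exact values the test helpers use may differ, but any quick-running config should stay in roughly this size range and keep the MLA fields valid even in GQA mode.

```python
from open_mythos import MythosConfig

# Hypothetical tiny config for fast tests/smoke runs; values are illustrative,
# not copied from tests/test_main.py. MythosConfig() defaults are far larger.
tiny = MythosConfig(
    vocab_size=128,
    dim=64,
    n_heads=4,
    n_kv_heads=2,            # valid GQA grouping: n_heads % n_kv_heads == 0
    max_seq_len=64,
    max_loop_iters=2,
    prelude_layers=1,
    coda_layers=1,
    attn_type="gqa",
    # MLA fields must remain valid even in GQA mode, since shared helpers read them.
    kv_lora_rank=16,
    q_lora_rank=32,
    qk_rope_head_dim=8,
    qk_nope_head_dim=8,
    v_head_dim=8,
    n_experts=4,
    n_shared_experts=1,
    n_experts_per_tok=2,
    expert_dim=32,
)

# Cache entries follow the deterministic key convention described above:
# prelude_0, recurrent_loop_0 .. recurrent_loop_{T-1}, coda_0.
print(tiny.runtime_profile()["cache_layout"])
```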
51 changes: 33 additions & 18 deletions open_mythos/__init__.py
@@ -1,21 +1,7 @@
from open_mythos.main import (
ACTHalting,
Expert,
GQAttention,
LoRAAdapter,
LTIInjection,
MLAttention,
MoEFFN,
MythosConfig,
OpenMythos,
RecurrentBlock,
RMSNorm,
TransformerBlock,
apply_rope,
loop_index_embedding,
precompute_rope_freqs,
)
from open_mythos.tokenizer import MythosTokenizer
from importlib import import_module

from open_mythos.config import MythosConfig
from open_mythos.tokenizer import MythosTokenizer, get_vocab_size, load_tokenizer
from open_mythos.variants import (
mythos_1b,
mythos_1t,
@@ -26,6 +12,24 @@
mythos_500b,
)

_MAIN_EXPORTS = {
"ACTHalting",
"Expert",
"GQAttention",
"LoRAAdapter",
"LTIInjection",
"MLAttention",
"MoEFFN",
"OpenMythos",
"RecurrentBlock",
"RMSNorm",
"TransformerBlock",
"apply_rope",
"loop_index_embedding",
"precompute_rope_freqs",
}
_MAIN_MODULE = None

__all__ = [
"MythosConfig",
"RMSNorm",
@@ -53,3 +57,14 @@
"get_vocab_size",
"MythosTokenizer",
]


def __getattr__(name: str):
if name in _MAIN_EXPORTS:
global _MAIN_MODULE
if _MAIN_MODULE is None:
_MAIN_MODULE = import_module("open_mythos.main")
value = getattr(_MAIN_MODULE, name)
globals()[name] = value
return value
raise AttributeError(f"module 'open_mythos' has no attribute {name!r}")
166 changes: 166 additions & 0 deletions open_mythos/config.py
@@ -0,0 +1,166 @@
from dataclasses import asdict, dataclass


@dataclass
class MythosConfig:
"""
Hyperparameter configuration for OpenMythos.

Core:
vocab_size -- token vocabulary size
dim -- model hidden dimension
n_heads -- number of query attention heads
n_kv_heads -- number of key/value heads (GQA; ignored by MLA)
max_seq_len -- maximum sequence length for RoPE precomputation
max_loop_iters -- default recurrent loop depth T at inference
prelude_layers -- number of standard transformer layers before the loop
coda_layers -- number of standard transformer layers after the loop

Attention (attn_type selects between the two):
attn_type -- "gqa" for Grouped Query Attention, "mla" for Multi-Latent Attention
kv_lora_rank -- [MLA] compressed KV latent dimension stored in the cache
q_lora_rank -- [MLA] compressed Q latent dimension
qk_rope_head_dim-- [MLA] per-head dims that receive RoPE
qk_nope_head_dim-- [MLA] per-head dims without positional encoding
v_head_dim -- [MLA] per-head value dimension

MoE FFN (used inside the recurrent block):
n_experts -- total number of routed expert FFNs
n_shared_experts-- number of always-active shared experts
n_experts_per_tok-- top-K experts selected per token by the router
expert_dim -- hidden dimension inside each fine-grained expert

Other:
act_threshold -- ACT halting threshold (cumulative probability to stop looping)
rope_theta -- RoPE base frequency
lora_rank -- rank of the per-loop depth-wise LoRA adapter
"""

vocab_size: int = 32000
dim: int = 2048
n_heads: int = 16
n_kv_heads: int = 4 # GQA: fewer KV heads than Q heads
max_seq_len: int = 4096
max_loop_iters: int = 16 # T — recurrent depth at inference
prelude_layers: int = 2
coda_layers: int = 2
# Attention type: "gqa" | "mla"
attn_type: str = "mla"
# MLA params (only used when attn_type="mla")
kv_lora_rank: int = 512 # compressed KV latent cached instead of full K/V
q_lora_rank: int = 1536 # compressed Q latent dim
qk_rope_head_dim: int = 64 # per-head dims that receive RoPE
qk_nope_head_dim: int = 128 # per-head dims without RoPE
v_head_dim: int = 128 # per-head value dim
# MoE
n_experts: int = 64
n_shared_experts: int = 2
n_experts_per_tok: int = 4 # top-K routed
expert_dim: int = 512 # fine-grained: dim // (n_experts // n_experts_per_tok)
# ACT halting
act_threshold: float = 0.99
# RoPE
rope_theta: float = 500000.0
# LoRA depth adaptation
lora_rank: int = 16
# Maximum tokens to generate per forward pass
max_output_tokens: int = 4096
# Dropout (set 0.0 to disable; 0.1 is standard for pretraining)
dropout: float = 0.0

def __post_init__(self) -> None:
if self.attn_type not in {"gqa", "mla"}:
raise ValueError(
f"Unsupported attn_type {self.attn_type!r}; expected 'gqa' or 'mla'"
)
if self.vocab_size <= 0:
raise ValueError("vocab_size must be positive")
if self.dim <= 0:
raise ValueError("dim must be positive")
if self.n_heads <= 0:
raise ValueError("n_heads must be positive")
if self.n_kv_heads <= 0:
raise ValueError("n_kv_heads must be positive")
if self.max_seq_len <= 0:
raise ValueError("max_seq_len must be positive")
if self.max_loop_iters <= 0:
raise ValueError("max_loop_iters must be positive")
if self.prelude_layers < 0:
raise ValueError("prelude_layers must be non-negative")
if self.coda_layers < 0:
raise ValueError("coda_layers must be non-negative")
if self.kv_lora_rank <= 0:
raise ValueError("kv_lora_rank must be positive")
if self.q_lora_rank <= 0:
raise ValueError("q_lora_rank must be positive")
if self.qk_rope_head_dim <= 0:
raise ValueError("qk_rope_head_dim must be positive")
if self.qk_nope_head_dim <= 0:
raise ValueError("qk_nope_head_dim must be positive")
if self.v_head_dim <= 0:
raise ValueError("v_head_dim must be positive")
if self.n_experts <= 0:
raise ValueError("n_experts must be positive")
if self.n_shared_experts < 0:
raise ValueError("n_shared_experts must be non-negative")
if self.n_experts_per_tok <= 0:
raise ValueError("n_experts_per_tok must be positive")
if self.n_experts_per_tok > self.n_experts:
raise ValueError(
"n_experts_per_tok must be less than or equal to n_experts"
)
if self.expert_dim <= 0:
raise ValueError("expert_dim must be positive")
if not 0.0 < self.act_threshold <= 1.0:
raise ValueError("act_threshold must be in the interval (0, 1]")
if self.rope_theta <= 0:
raise ValueError("rope_theta must be positive")
if self.lora_rank <= 0:
raise ValueError("lora_rank must be positive")
if self.max_output_tokens <= 0:
raise ValueError("max_output_tokens must be positive")
if not 0.0 <= self.dropout <= 1.0:
raise ValueError("dropout must be in the interval [0, 1]")
if self.dim % self.n_heads != 0:
raise ValueError("dim must be divisible by n_heads")
if (self.dim // self.n_heads) % 2 != 0:
raise ValueError("dim // n_heads must be even for RoPE")
if self.attn_type == "gqa":
if self.n_kv_heads > self.n_heads:
raise ValueError("n_kv_heads must be less than or equal to n_heads")
if self.n_heads % self.n_kv_heads != 0:
raise ValueError("n_heads must be divisible by n_kv_heads")
if self.attn_type == "mla" and self.qk_rope_head_dim % 2 != 0:
raise ValueError("qk_rope_head_dim must be even for MLA RoPE")

def to_dict(self) -> dict[str, object]:
"""Return a plain-Python config dictionary for serialization."""
return asdict(self)

def runtime_profile(self) -> dict[str, object]:
"""
Describe the stable runtime-facing capabilities of this config.

GatesOfMythos can use this to validate routing decisions without
inspecting model internals or loading a heavyweight instance first.
"""
return {
"model_name": "OpenMythos",
"attn_type": self.attn_type,
"supports_kv_cache": True,
"supports_incremental_decode": True,
"uses_moe": True,
"uses_act_halting": True,
"uses_lti_injection": True,
"max_context_tokens": self.max_seq_len,
"max_loop_iters": self.max_loop_iters,
"max_output_tokens": self.max_output_tokens,
"cache_layout": {
"prelude": [f"prelude_{i}" for i in range(self.prelude_layers)],
"recurrent": "recurrent_loop_{t}",
"coda": [f"coda_{i}" for i in range(self.coda_layers)],
},
"attention_backend": (
"multi_latent" if self.attn_type == "mla" else "grouped_query"
),
}
75 changes: 6 additions & 69 deletions open_mythos/main.py
@@ -1,10 +1,11 @@
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from open_mythos.config import MythosConfig

try:
from flash_attn import flash_attn_func

@@ -13,74 +14,6 @@
_HAS_FLASH_ATTN = False


@dataclass
class MythosConfig:
"""
Hyperparameter configuration for OpenMythos.

Core:
vocab_size -- token vocabulary size
dim -- model hidden dimension
n_heads -- number of query attention heads
n_kv_heads -- number of key/value heads (GQA; ignored by MLA)
max_seq_len -- maximum sequence length for RoPE precomputation
max_loop_iters -- default recurrent loop depth T at inference
prelude_layers -- number of standard transformer layers before the loop
coda_layers -- number of standard transformer layers after the loop

Attention (attn_type selects between the two):
attn_type -- "gqa" for Grouped Query Attention, "mla" for Multi-Latent Attention
kv_lora_rank -- [MLA] compressed KV latent dimension stored in the cache
q_lora_rank -- [MLA] compressed Q latent dimension
qk_rope_head_dim-- [MLA] per-head dims that receive RoPE
qk_nope_head_dim-- [MLA] per-head dims without positional encoding
v_head_dim -- [MLA] per-head value dimension

MoE FFN (used inside the recurrent block):
n_experts -- total number of routed expert FFNs
n_shared_experts-- number of always-active shared experts
n_experts_per_tok-- top-K experts selected per token by the router
expert_dim -- hidden dimension inside each fine-grained expert

Other:
act_threshold -- ACT halting threshold (cumulative probability to stop looping)
rope_theta -- RoPE base frequency
lora_rank -- rank of the per-loop depth-wise LoRA adapter
"""

vocab_size: int = 32000
dim: int = 2048
n_heads: int = 16
n_kv_heads: int = 4 # GQA: fewer KV heads than Q heads
max_seq_len: int = 4096
max_loop_iters: int = 16 # T — recurrent depth at inference
prelude_layers: int = 2
coda_layers: int = 2
# Attention type: "gqa" | "mla"
attn_type: str = "mla"
# MLA params (only used when attn_type="mla")
kv_lora_rank: int = 512 # compressed KV latent cached instead of full K/V
q_lora_rank: int = 1536 # compressed Q latent dim
qk_rope_head_dim: int = 64 # per-head dims that receive RoPE
qk_nope_head_dim: int = 128 # per-head dims without RoPE
v_head_dim: int = 128 # per-head value dim
# MoE
n_experts: int = 64
n_shared_experts: int = 2
n_experts_per_tok: int = 4 # top-K routed
expert_dim: int = 512 # fine-grained: dim // (n_experts // n_experts_per_tok)
# ACT halting
act_threshold: float = 0.99
# RoPE
rope_theta: float = 500000.0
# LoRA depth adaptation
lora_rank: int = 16
# Maximum tokens to generate per forward pass
max_output_tokens: int = 4096
# Dropout (set 0.0 to disable; 0.1 is standard for pretraining)
dropout: float = 0.0


# ---------------------------------------------------------------------------
# RMSNorm
# ---------------------------------------------------------------------------
@@ -965,6 +898,10 @@ def _init_weights(self) -> None:
elif isinstance(m, nn.Embedding):
nn.init.normal_(m.weight, std=0.02)

def describe(self) -> dict[str, object]:
"""Return the stable runtime profile for orchestration and routing."""
return self.cfg.runtime_profile()

@staticmethod
def _causal_mask(
seq_len: int, device: torch.device, dtype: torch.dtype