Enh: Add lightweight OpenMythos runtime metadata #59
Open: anphonic wants to merge 7 commits into kyegomez:main from anphonic:openmythos-runtime-metadata
Changes from all commits (7 commits):
- 463856c Add lightweight OpenMythos runtime metadata (anphonic)
- 95fcf0a Make tokenizer imports lazy (anphonic)
- bda5ca9 Validate attention type and update docs (anphonic)
- b5388b9 Add MythosConfig sanity checks (anphonic)
- b20e8ff Tighten MythosConfig validation (anphonic)
- 936d24d Tighten config bounds (anphonic)
- 2bd4d05 Cache lazy OpenMythos exports (anphonic)
@@ -0,0 +1,36 @@
# Copilot Instructions

## Build, test, and lint commands

- Install dependencies for local work with `poetry install --with dev`. The repo also keeps a minimal `requirements.txt` for `pip install -r requirements.txt`.
- Build the package with `poetry build`.
- Run the main unit suite with `poetry run pytest tests/test_main.py`.
- Run a single test with `poetry run pytest tests/test_main.py::TestOpenMythosGQA::test_generate_shape`.
- Run tokenizer tests with `poetry run pytest tests/test_tokenizer.py`. These load the default Hugging Face tokenizer (`openai/gpt-oss-20b`), so they are heavier than `tests/test_main.py`.
- Run the RoPE debug script with `poetry run python tests/test_rope_debug.py`.
- Lint with `poetry run ruff check .`.
- Check formatting with `poetry run black --check .`.
- The documented training entrypoints are `poetry run python training/3b_fine_web_edu.py` and `poetry run torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") training/3b_fine_web_edu.py`.
- Benchmark utilities are executable scripts under `tests/`, not pytest files: `poetry run python tests/small_benchmark.py` and `poetry run python tests/bench_vs_transformer.py`.

## High-level architecture

- `open_mythos/config.py` defines `MythosConfig`. `open_mythos/main.py` is the primary implementation for both attention backends, the MoE FFN, the recurrent block, and the top-level `OpenMythos` model.
- The model pipeline is fixed: token embedding -> Prelude (`cfg.prelude_layers` dense `TransformerBlock`s run once) -> Recurrent Block (one shared `TransformerBlock` looped `n_loops` times with LoRA, ACT halting, and LTI-stable input injection) -> Coda (`cfg.coda_layers` dense `TransformerBlock`s run once) -> `RMSNorm` -> tied LM head. See the schematic sketch after this list.
- The key architectural invariant is that the encoded input `e` is frozen after the Prelude and injected on every recurrent iteration. The recurrent update path is `loop_index_embedding -> TransformerBlock(use_moe=True) -> LoRAAdapter -> LTIInjection -> ACTHalting`.
- `cfg.attn_type` switches the whole attention/cache path. `GQAttention` stores full K/V by layer, while `MLAttention` stores compressed `c_kv` plus `k_rope` and reconstructs K/V on demand. `OpenMythos.forward()` also switches RoPE buffers based on `attn_type`.
- `open_mythos/variants.py` contains the named model scales (`mythos_1b` through `mythos_1t`) as `MythosConfig` factories. `open_mythos/tokenizer.py` is a small wrapper around `transformers.AutoTokenizer` with `openai/gpt-oss-20b` as the default tokenizer.
- `training/3b_fine_web_edu.py` is a standalone FSDP training script for the 3B variant. It streams FineWeb-Edu shards, uses `MythosTokenizer`, and keeps checkpointing/training logic out of the library module.
- `docs/open_mythos.md` is the detailed API and architecture reference. `docs/datasets.md` holds the training dataset recommendations and token-budget guidance referenced by the training script/README.
- `open_mythos/moda.py` is a separate experimental MoDA + MoE implementation, not part of the main `OpenMythos` export surface.

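The fixed pipeline and the frozen-`e` invariant are easiest to see end to end in a short sketch. This is a schematic illustration only, not the actual `open_mythos/main.py` code; the function and argument names (`open_mythos_forward`, `injected`, `loop_index`) are invented for clarity:

```python
# Schematic sketch of the fixed OpenMythos pipeline (illustrative names only).
def open_mythos_forward(embed, prelude, recurrent, coda, norm, head,
                        tokens, n_loops):
    e = embed(tokens)
    for block in prelude:      # Prelude: dense blocks, run once
        e = block(e)
    # `e` is frozen from here on; every recurrent iteration re-injects it.
    h = e
    for t in range(n_loops):   # Recurrent Block: one shared block, looped
        h = recurrent(h, injected=e, loop_index=t)
    for block in coda:         # Coda: dense blocks, run once
        h = block(h)
    return head(norm(h))       # RMSNorm, then tied LM head
```
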
## Key conventions

- For tests, smoke runs, and examples that should execute quickly, use the tiny helper configs from `tests/test_main.py` (`gqa_cfg()` / `mla_cfg()`) or similarly small custom configs; `MythosConfig()` defaults to a much larger research model. See the sketch after this list.
- Preserve decode-position semantics when touching inference code: `start_pos` chooses the RoPE slice for incremental decoding, and each cache entry uses a deterministic key (`prelude_{i}`, `recurrent_loop_{t}`, `coda_{i}`).
- Do not short-circuit recurrent execution when a KV cache is active. `RecurrentBlock` only exits early on ACT convergence when `kv_cache is None`; cached decoding relies on every loop depth writing its cache entry on every step.
- Keep the embedding/LM head weight tying intact (`self.head.weight = self.embed.weight`), and keep `router_bias` as a registered buffer rather than a trainable parameter.
- Prelude and Coda are always dense FFN blocks (`use_moe=False`); the recurrent block is the only place that uses the MoE FFN (`use_moe=True`).
- Named variants and default configs are MLA-first. If you switch code or tests to GQA, make sure the config still provides valid MLA fields, because shared config helpers are reused across both modes.
- Use the variant helpers that actually exist in `open_mythos/variants.py`: `mythos_1b`, `mythos_3b`, `mythos_10b`, `mythos_50b`, `mythos_100b`, `mythos_500b`, and `mythos_1t`. The README currently shows a stale `mythos_7b()` example.
- The benchmark scripts live under `tests/` and are intended to be run directly as scripts. Their docstrings still reference a `benchmarks/` path that does not exist in this repository.
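
A tiny custom config in the spirit of those helpers might look like the sketch below. The specific values are assumptions for illustration, not the ones defined in `tests/test_main.py`; they are chosen to satisfy the `__post_init__` checks in `open_mythos/config.py`:

```python
from open_mythos.config import MythosConfig

# Hypothetical tiny config for fast smoke tests (values are illustrative).
tiny = MythosConfig(
    vocab_size=256,
    dim=64,              # 64 / 4 heads = head_dim 16, even as RoPE requires
    n_heads=4,
    n_kv_heads=2,        # divides n_heads, satisfying the GQA checks
    max_seq_len=128,
    max_loop_iters=2,
    prelude_layers=1,
    coda_layers=1,
    attn_type="gqa",     # MLA defaults remain valid, per the convention above
    n_experts=4,
    n_experts_per_tok=2,
    expert_dim=32,
)
```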
open_mythos/config.py (new file):

@@ -0,0 +1,166 @@
from dataclasses import asdict, dataclass


@dataclass
class MythosConfig:
    """
    Hyperparameter configuration for OpenMythos.

    Core:
        vocab_size        -- token vocabulary size
        dim               -- model hidden dimension
        n_heads           -- number of query attention heads
        n_kv_heads        -- number of key/value heads (GQA; ignored by MLA)
        max_seq_len       -- maximum sequence length for RoPE precomputation
        max_loop_iters    -- default recurrent loop depth T at inference
        prelude_layers    -- number of standard transformer layers before the loop
        coda_layers       -- number of standard transformer layers after the loop

    Attention (attn_type selects between the two):
        attn_type         -- "gqa" for Grouped Query Attention, "mla" for Multi-Latent Attention
        kv_lora_rank      -- [MLA] compressed KV latent dimension stored in the cache
        q_lora_rank       -- [MLA] compressed Q latent dimension
        qk_rope_head_dim  -- [MLA] per-head dims that receive RoPE
        qk_nope_head_dim  -- [MLA] per-head dims without positional encoding
        v_head_dim        -- [MLA] per-head value dimension

    MoE FFN (used inside the recurrent block):
        n_experts         -- total number of routed expert FFNs
        n_shared_experts  -- number of always-active shared experts
        n_experts_per_tok -- top-K experts selected per token by the router
        expert_dim        -- hidden dimension inside each fine-grained expert

    Other:
        act_threshold     -- ACT halting threshold (cumulative probability to stop looping)
        rope_theta        -- RoPE base frequency
        lora_rank         -- rank of the per-loop depth-wise LoRA adapter
    """

    vocab_size: int = 32000
    dim: int = 2048
    n_heads: int = 16
    n_kv_heads: int = 4  # GQA: fewer KV heads than Q heads
    max_seq_len: int = 4096
    max_loop_iters: int = 16  # T -- recurrent depth at inference
    prelude_layers: int = 2
    coda_layers: int = 2
    # Attention type: "gqa" | "mla"
    attn_type: str = "mla"
    # MLA params (only used when attn_type="mla")
    kv_lora_rank: int = 512  # compressed KV latent cached instead of full K/V
    q_lora_rank: int = 1536  # compressed Q latent dim
    qk_rope_head_dim: int = 64  # per-head dims that receive RoPE
    qk_nope_head_dim: int = 128  # per-head dims without RoPE
    v_head_dim: int = 128  # per-head value dim
    # MoE
    n_experts: int = 64
    n_shared_experts: int = 2
    n_experts_per_tok: int = 4  # top-K routed
    expert_dim: int = 512  # fine-grained: dim // (n_experts // n_experts_per_tok)
    # ACT halting
    act_threshold: float = 0.99
    # RoPE
    rope_theta: float = 500000.0
    # LoRA depth adaptation
    lora_rank: int = 16
    # Maximum tokens to generate per forward pass
    max_output_tokens: int = 4096
    # Dropout (set 0.0 to disable; 0.1 is standard for pretraining)
    dropout: float = 0.0

    def __post_init__(self) -> None:
        if self.attn_type not in {"gqa", "mla"}:
            raise ValueError(
                f"Unsupported attn_type {self.attn_type!r}; expected 'gqa' or 'mla'"
            )
        if self.vocab_size <= 0:
            raise ValueError("vocab_size must be positive")
        if self.dim <= 0:
            raise ValueError("dim must be positive")
        if self.n_heads <= 0:
            raise ValueError("n_heads must be positive")
        if self.n_kv_heads <= 0:
            raise ValueError("n_kv_heads must be positive")
        if self.max_seq_len <= 0:
            raise ValueError("max_seq_len must be positive")
        if self.max_loop_iters <= 0:
            raise ValueError("max_loop_iters must be positive")
        if self.prelude_layers < 0:
            raise ValueError("prelude_layers must be non-negative")
        if self.coda_layers < 0:
            raise ValueError("coda_layers must be non-negative")
        if self.kv_lora_rank <= 0:
            raise ValueError("kv_lora_rank must be positive")
        if self.q_lora_rank <= 0:
            raise ValueError("q_lora_rank must be positive")
        if self.qk_rope_head_dim <= 0:
            raise ValueError("qk_rope_head_dim must be positive")
        if self.qk_nope_head_dim <= 0:
            raise ValueError("qk_nope_head_dim must be positive")
        if self.v_head_dim <= 0:
            raise ValueError("v_head_dim must be positive")
        if self.n_experts <= 0:
            raise ValueError("n_experts must be positive")
        if self.n_shared_experts < 0:
            raise ValueError("n_shared_experts must be non-negative")
        if self.n_experts_per_tok <= 0:
            raise ValueError("n_experts_per_tok must be positive")
        if self.n_experts_per_tok > self.n_experts:
            raise ValueError(
                "n_experts_per_tok must be less than or equal to n_experts"
            )
        if self.expert_dim <= 0:
            raise ValueError("expert_dim must be positive")
        if not 0.0 < self.act_threshold <= 1.0:
            raise ValueError("act_threshold must be in the interval (0, 1]")
        if self.rope_theta <= 0:
            raise ValueError("rope_theta must be positive")
        if self.lora_rank <= 0:
            raise ValueError("lora_rank must be positive")
        if self.max_output_tokens <= 0:
            raise ValueError("max_output_tokens must be positive")
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError("dropout must be in the interval [0, 1]")
        if self.dim % self.n_heads != 0:
            raise ValueError("dim must be divisible by n_heads")
        if (self.dim // self.n_heads) % 2 != 0:
            raise ValueError("dim // n_heads must be even for RoPE")
        if self.attn_type == "gqa":
            if self.n_kv_heads > self.n_heads:
                raise ValueError("n_kv_heads must be less than or equal to n_heads")
            if self.n_heads % self.n_kv_heads != 0:
                raise ValueError("n_heads must be divisible by n_kv_heads")
        if self.attn_type == "mla" and self.qk_rope_head_dim % 2 != 0:
            raise ValueError("qk_rope_head_dim must be even for MLA RoPE")

    def to_dict(self) -> dict[str, object]:
        """Return a plain-Python config dictionary for serialization."""
        return asdict(self)

    def runtime_profile(self) -> dict[str, object]:
        """
        Describe the stable runtime-facing capabilities of this config.

        GatesOfMythos can use this to validate routing decisions without
        inspecting model internals or loading a heavyweight instance first.
        """
        return {
            "model_name": "OpenMythos",
            "attn_type": self.attn_type,
            "supports_kv_cache": True,
            "supports_incremental_decode": True,
            "uses_moe": True,
            "uses_act_halting": True,
            "uses_lti_injection": True,
            "max_context_tokens": self.max_seq_len,
            "max_loop_iters": self.max_loop_iters,
            "max_output_tokens": self.max_output_tokens,
            "cache_layout": {
                "prelude": [f"prelude_{i}" for i in range(self.prelude_layers)],
                "recurrent": "recurrent_loop_{t}",
                "coda": [f"coda_{i}" for i in range(self.coda_layers)],
            },
            "attention_backend": (
                "multi_latent" if self.attn_type == "mla" else "grouped_query"
            ),
        }
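
A brief usage sketch for the module above; the import path follows the architecture notes (`open_mythos/config.py` defines `MythosConfig`), and the printed values follow directly from the code shown:

```python
from open_mythos.config import MythosConfig

cfg = MythosConfig()  # defaults: MLA attention, 2 prelude and 2 coda layers
profile = cfg.runtime_profile()
print(profile["attention_backend"])        # "multi_latent"
print(profile["cache_layout"]["prelude"])  # ["prelude_0", "prelude_1"]

# __post_init__ rejects inconsistent settings at construction time:
try:
    MythosConfig(attn_type="flash")
except ValueError as err:
    print(err)  # Unsupported attn_type 'flash'; expected 'gqa' or 'mla'
```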