78 changes: 78 additions & 0 deletions experiments/README.md
@@ -0,0 +1,78 @@
# Loop-scaling validation experiments

A minimal, reproducible training + evaluation pipeline that measures how
OpenMythos's validation perplexity changes as you vary the number of
recurrent loops at **inference time**.

The pipeline trains two comparison models (same ~118M-parameter MLA+MoE
backbone, same ~491M tokens of FineWeb-Edu, same optimizer / schedule)
that differ only in their training-time loop strategy (the two sampling modes are sketched below):

| Run | Training `n_loops` per step | Role |
|---|---|---|
| `looped_8` | fixed at 8 | the default OpenMythos training style |
| `baseline_1` | fixed at 1 | dense-equivalent ablation |
| `looped_random` (optional) | uniformly sampled from `{4, 6, 8, 12, 16}` | tests whether random-loop training gives monotonic depth extrapolation |
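
The sketch below illustrates the two sampling modes; `sample_n_loops` is a hypothetical helper shown for orientation, not the actual `train.py` implementation.

```python
import random


def sample_n_loops(
    mode: str,
    max_loop_iters: int,
    loop_choices: list[int],
    rng: random.Random,
) -> int:
    """Pick the loop count for one training step (illustrative only)."""
    if mode == "fixed":
        # looped_8 -> always 8, baseline_1 -> always 1
        return max_loop_iters
    if mode == "random_set":
        # e.g. uniformly from {4, 6, 8, 12, 16}
        return rng.choice(loop_choices)
    raise ValueError(f"unknown loop_sample_mode: {mode}")
```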

After training, `evaluate.py` sweeps `n_loops ∈ {1, 2, 4, 6, 8, 12, 16}`
at inference on a held-out FineWeb-Edu slice (`--skip_docs 2_000_000`
ensures no train/val overlap) and logs PPL + generation samples.
`plot_results.py` produces three figures: training loss, ρ(A) over
steps, and the inference-time loop-scaling curve.
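
Conceptually, the sweep reduces to the loop sketched below; it assumes (hypothetically) that the model's forward pass accepts an `n_loops` argument and returns logits, while the real CLI and checkpoint handling live in `evaluate.py`.

```python
import math

import torch
import torch.nn.functional as F


@torch.no_grad()
def ppl_at_n_loops(model, loader, n_loops, num_batches=50, device="cuda"):
    """Validation perplexity at one inference-time loop count (illustrative)."""
    model.eval()
    total_nll, total_tokens = 0.0, 0
    for i, (x, y) in enumerate(loader):
        if i >= num_batches:
            break
        x, y = x.to(device), y.to(device)
        logits = model(x, n_loops=n_loops)  # assumed forward signature
        nll = F.cross_entropy(
            logits.view(-1, logits.size(-1)), y.view(-1), reduction="sum"
        )
        total_nll += nll.item()
        total_tokens += y.numel()
    return math.exp(total_nll / total_tokens)


# sweep = {n: ppl_at_n_loops(model, val_loader, n) for n in (1, 2, 4, 6, 8, 12, 16)}
```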

## Usage

Requires a single GPU with ≥ 48 GB VRAM for `batch_size=32` at `n_loops=8`
(H100 80 GB, A100 80 GB, or A40 48 GB). On H100 SXM each 15k-step run
takes ~4 hours; `looped_random` needs smaller batches to fit `n_loops=16`
and takes ~3.5 hours.

```bash
cd experiments
pip install matplotlib datasets transformers loguru

# Looped (recommended default)
python train.py --run_name looped_8 --max_loop_iters 8 --max_steps 15000

# Baseline for comparison (trains ~3× faster since n_loops=1)
python train.py --run_name baseline_1 --max_loop_iters 1 --max_steps 15000

# Optional: random-loop training for depth-extrapolation ablation
python train.py --run_name looped_random \
    --max_loop_iters 16 \
    --loop_sample_mode random_set --loop_choices 4 6 8 12 16 \
    --batch_size 16 --grad_accum_steps 2 --max_steps 15000

# Inference-time loop sweep + generation samples
python evaluate.py --ckpt /workspace/runs/looped_8/ckpt_15000.pt \
    --loop_grid 1 2 4 6 8 12 16
python evaluate.py --ckpt /workspace/runs/baseline_1/ckpt_15000.pt \
    --loop_grid 1

python plot_results.py --runs_dir /workspace/runs --out_dir /workspace/runs/figs
```

Or run all three phases end-to-end with default settings:

```bash
bash run_all.sh # drives looped_8 + baseline_1 + evaluate + plot
```

## Files

| File | Purpose |
|---|---|
| `config.py` | `mythos_150m()` MLA+MoE config (actual param count 117.8M) and `TrainConfig` dataclass |
| `data.py` | Streaming FineWeb-Edu loader with `skip_docs` for clean train/val split |
| `train.py` | AdamW + cosine schedule training with per-step `n_loops` logging; supports `--loop_sample_mode {fixed,random_set}` |
| `evaluate.py` | Loads a checkpoint, runs PPL sweep over `--loop_grid`, emits generation samples at trained and 2× loops |
| `plot_results.py` | Parses all `<run>/train.log` + `<run>/eval_ckpt_*.json` under a runs directory and draws three comparison figures |
| `run_all.sh` | Orchestrator: looped_8 → baseline_1 → eval → plot |

## What the logs contain

`train.log` is tab-separated with the header row
`step tokens n_loops lr loss grad_norm rho_A step_s tok_per_s gpu_mem_gb`.
The `n_loops` column records the value actually used at each training step
(constant in `fixed` mode, varying in `random_set` mode), so you can slice
losses by training loop depth after the fact.
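
For example, a small pandas sketch of that slicing (pandas is an extra dependency assumed only for this snippet, not something the pipeline installs):

```python
import pandas as pd

# train.log has a single tab-separated header row (columns listed above)
df = pd.read_csv("/workspace/runs/looped_random/train.log", sep="\t")

# Mean training loss per loop depth -- most informative for random_set runs
print(df.groupby("n_loops")["loss"].mean())
```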
88 changes: 88 additions & 0 deletions experiments/config.py
@@ -0,0 +1,88 @@
"""
Experiment config for OpenMythos loop-scaling validation.

Two model variants with the same parameter count (per-step compute scales with the number of loops):
- looped: max_loop_iters=8 (the OpenMythos architecture)
- baseline: max_loop_iters=1 (equivalent to a plain transformer)

Training data is FineWeb-Edu sample-10BT (streaming); we train on ~0.5B tokens (15k steps × 32,768 tokens/step ≈ 491M).
"""

from dataclasses import dataclass, field, asdict
from open_mythos.main import MythosConfig


def mythos_150m(max_loop_iters: int = 8) -> MythosConfig:
"""
~150M parameter config tuned for a single H100.

With max_loop_iters=8 and MoE (16 experts, top-2), activated params per
token ≈ 80M; total params ≈ 150M. The looped block is run 8 times so the
effective compute per forward matches a ~8x deeper plain transformer.
"""
    return MythosConfig(
        vocab_size=50257,
        dim=768,
        n_heads=12,
        n_kv_heads=4,
        max_seq_len=1024,
        max_loop_iters=max_loop_iters,
        prelude_layers=2,
        coda_layers=2,
        attn_type="mla",
        kv_lora_rank=192,
        q_lora_rank=384,
        qk_rope_head_dim=32,
        qk_nope_head_dim=32,
        v_head_dim=32,
        n_experts=16,
        n_shared_experts=1,
        n_experts_per_tok=2,
        expert_dim=1536,
        act_threshold=0.99,
        rope_theta=500000.0,
        lora_rank=16,
        dropout=0.0,
    )


@dataclass
class TrainConfig:
    # Model
    max_loop_iters: int = 8
    run_name: str = "looped_8"

    # Data
    dataset_name: str = "HuggingFaceFW/fineweb-edu"
    dataset_config: str = "sample-10BT"
    tokenizer: str = "gpt2"
    seq_len: int = 1024

    # Training (H100 can fit B=32 directly; adjust to 16 for A40-class GPUs)
    batch_size: int = 32
    grad_accum_steps: int = 1  # 32 * 1024 = 32,768 tokens/step
    max_steps: int = 15000  # 15k * 32k = ~490M tokens
    learning_rate: float = 3e-4
    min_lr: float = 3e-5
    warmup_steps: int = 500
    weight_decay: float = 0.1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0

    # Logging & checkpointing
    log_every: int = 20
    eval_every: int = 2000
    ckpt_every: int = 5000
    output_dir: str = "/workspace/runs"

    # Precision
    dtype: str = "bfloat16"

    # Eval
    eval_seq_len: int = 1024
    eval_batch_size: int = 8
    eval_num_batches: int = 50

    def to_dict(self):
        return asdict(self)
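

# --- Usage sketch (illustrative only; not used by the pipeline) ---
# Token-budget arithmetic from the defaults above:
#   tokens/step = batch_size * grad_accum_steps * seq_len = 32 * 1 * 1024 = 32,768
#   total tokens ~= max_steps * tokens/step = 15,000 * 32,768 ~= 491.5M
if __name__ == "__main__":
    cfg = TrainConfig()
    tokens_per_step = cfg.batch_size * cfg.grad_accum_steps * cfg.seq_len
    print(f"tokens/step : {tokens_per_step:,}")
    print(f"total tokens: {tokens_per_step * cfg.max_steps:,}")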
102 changes: 102 additions & 0 deletions experiments/data.py
@@ -0,0 +1,102 @@
"""
Streaming FineWeb-Edu dataloader.

Packs concatenated documents into fixed-length (input, target) pairs of
length seq_len, where target = input shifted by one. Each DataLoader worker
pulls a disjoint shard of the HuggingFace streaming dataset so workers never
overlap.
"""

from typing import Iterator

import torch
from torch.utils.data import IterableDataset, DataLoader, get_worker_info
from datasets import load_dataset
from transformers import AutoTokenizer


class FineWebEduStream(IterableDataset):
    def __init__(
        self,
        tokenizer,
        seq_len: int,
        dataset_name: str,
        dataset_config: str,
        split: str = "train",
        skip_docs: int = 0,
    ):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.dataset_name = dataset_name
        self.dataset_config = dataset_config
        self.split = split
        self.skip_docs = skip_docs
        self.eos_id = tokenizer.eos_token_id or 0

    def __iter__(self) -> Iterator[tuple]:
        worker = get_worker_info()
        num_workers = worker.num_workers if worker else 1
        worker_id = worker.id if worker else 0

        ds = load_dataset(
            self.dataset_name,
            name=self.dataset_config,
            split=self.split,
            streaming=True,
        )
        if self.skip_docs > 0:
            ds = ds.skip(self.skip_docs)
        if num_workers > 1:
            ds = ds.shard(num_shards=num_workers, index=worker_id)

        buffer: list[int] = []
        need = self.seq_len + 1

        for doc in ds:
            text = doc.get("text", "")
            if not text:
                continue
            ids = self.tokenizer.encode(text, add_special_tokens=False)
            ids.append(self.eos_id)
            buffer.extend(ids)

            while len(buffer) >= need:
                chunk = buffer[:need]
                buffer = buffer[need - 1 :]  # keep last token as start of next
                x = torch.tensor(chunk[:-1], dtype=torch.long)
                y = torch.tensor(chunk[1:], dtype=torch.long)
                yield x, y


def build_loader(
    tokenizer,
    seq_len: int,
    batch_size: int,
    dataset_name: str,
    dataset_config: str,
    num_workers: int = 2,
    split: str = "train",
    skip_docs: int = 0,
) -> DataLoader:
    ds = FineWebEduStream(
        tokenizer=tokenizer,
        seq_len=seq_len,
        dataset_name=dataset_name,
        dataset_config=dataset_config,
        split=split,
        skip_docs=skip_docs,
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=(num_workers > 0),
    )


def get_tokenizer(name: str = "gpt2"):
    tok = AutoTokenizer.from_pretrained(name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    return tok
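

# --- Usage sketch (illustrative only; not part of the training pipeline) ---
# Streams a couple of packed (input, target) pairs as a quick smoke test;
# num_workers=0 keeps it single-process, the real runs use build_loader's defaults.
if __name__ == "__main__":
    tok = get_tokenizer("gpt2")
    loader = build_loader(
        tokenizer=tok,
        seq_len=1024,
        batch_size=2,
        dataset_name="HuggingFaceFW/fineweb-edu",
        dataset_config="sample-10BT",
        num_workers=0,
        skip_docs=0,
    )
    x, y = next(iter(loader))
    print(x.shape, y.shape)  # torch.Size([2, 1024]) for both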