diff --git a/experiments/README.md b/experiments/README.md new file mode 100644 index 0000000..26d5100 --- /dev/null +++ b/experiments/README.md @@ -0,0 +1,78 @@ +# Loop-scaling validation experiments + +A minimal, reproducible training + evaluation pipeline that measures how +OpenMythos's validation perplexity changes as you vary the number of +recurrent loops at **inference time**. + +The pipeline trains two comparison models (same ~118M-parameter MLA+MoE +backbone, same ~491M tokens of FineWeb-Edu, same optimizer / schedule) +that differ only in their training-time loop strategy: + +| Run | Training `n_loops` per step | Role | +|---|---|---| +| `looped_8` | fixed at 8 | the default OpenMythos training style | +| `baseline_1` | fixed at 1 | dense-equivalent ablation | +| `looped_random` (optional) | uniformly sampled from `{4, 6, 8, 12, 16}` | tests whether random-loop training gives monotonic depth extrapolation | + +After training, `evaluate.py` sweeps `n_loops ∈ {1, 2, 4, 6, 8, 12, 16}` +at inference on a held-out FineWeb-Edu slice (`--skip_docs 2_000_000` +ensures no train/val overlap) and logs PPL + generation samples. +`plot_results.py` produces three figures: training loss, ρ(A) over +steps, and the inference-time loop-scaling curve. + +## Usage + +Requires a single GPU with ≥ 48 GB VRAM for `batch_size=32` at `n_loops=8` +(H100 80 GB, A100 80 GB, or A40 48 GB). On H100 SXM each 15k-step run +takes ~4 hours; `looped_random` needs smaller batches to fit `n_loops=16` +and takes ~3.5 hours. + +```bash +cd experiments +pip install matplotlib datasets transformers loguru + +# Looped (recommended default) +python train.py --run_name looped_8 --max_loop_iters 8 --max_steps 15000 + +# Baseline for comparison (trains ~3× faster since n_loops=1) +python train.py --run_name baseline_1 --max_loop_iters 1 --max_steps 15000 + +# Optional: random-loop training for depth-extrapolation ablation +python train.py --run_name looped_random \ + --max_loop_iters 16 \ + --loop_sample_mode random_set --loop_choices 4 6 8 12 16 \ + --batch_size 16 --grad_accum_steps 2 --max_steps 15000 + +# Inference-time loop sweep + generation samples +python evaluate.py --ckpt /workspace/runs/looped_8/ckpt_15000.pt \ + --loop_grid 1 2 4 6 8 12 16 +python evaluate.py --ckpt /workspace/runs/baseline_1/ckpt_15000.pt \ + --loop_grid 1 + +python plot_results.py --runs_dir /workspace/runs --out_dir /workspace/runs/figs +``` + +Or run all three phases end-to-end with default settings: + +```bash +bash run_all.sh # drives looped_8 + baseline_1 + evaluate + plot +``` + +## Files + +| File | Purpose | +|---|---| +| `config.py` | `mythos_150m()` MLA+MoE config (actual param count 117.8M) and `TrainConfig` dataclass | +| `data.py` | Streaming FineWeb-Edu loader with `skip_docs` for clean train/val split | +| `train.py` | AdamW + cosine schedule training with per-step `n_loops` logging; supports `--loop_sample_mode {fixed,random_set}` | +| `evaluate.py` | Loads a checkpoint, runs PPL sweep over `--loop_grid`, emits generation samples at trained and 2× loops | +| `plot_results.py` | Parses all `/train.log` + `/eval_ckpt_*.json` under a runs directory and draws three comparison figures | +| `run_all.sh` | Orchestrator: looped_8 → baseline_1 → eval → plot | + +## What the logs contain + +`train.log` is tab-separated with headers +`step tokens n_loops lr loss grad_norm rho_A step_s tok_per_s gpu_mem_gb`. 
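For quick ad-hoc inspection outside `plot_results.py`, the log can be loaded directly with pandas (a minimal sketch; pandas is not otherwise a dependency of this pipeline, so install it separately if you want this):

```python
import pandas as pd

# One row per logging interval (every `log_every` steps).
df = pd.read_csv("/workspace/runs/looped_8/train.log", sep="\t")

print(df[["step", "loss", "rho_A"]].tail())   # end-of-run snapshot
print(df.groupby("n_loops")["loss"].mean())   # mean loss per training-loop depth
```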
The `n_loops` column records the value actually used at each training
step (constant in `fixed` mode, varying in `random_set` mode) so you can
post-hoc slice losses by training-loop-depth, as in the `groupby` example above.
diff --git a/experiments/config.py b/experiments/config.py
new file mode 100644
index 0000000..6a7353f
--- /dev/null
+++ b/experiments/config.py
@@ -0,0 +1,88 @@
"""
Experiment config for OpenMythos loop-scaling validation.

Two model variants with identical param count but different compute:
- looped: max_loop_iters=8 (the OpenMythos architecture)
- baseline: max_loop_iters=1 (equivalent to a plain transformer)

Training data is FineWeb-Edu sample-10BT (streaming); we train on ~490M tokens.
"""

from dataclasses import dataclass, field, asdict
from open_mythos.main import MythosConfig


def mythos_150m(max_loop_iters: int = 8) -> MythosConfig:
    """
    ~118M parameter config tuned for a single H100.

    With max_loop_iters=8 and MoE (16 experts, top-2), activated params per
    token ≈ 80M; total params ≈ 118M (117.8M measured). The looped block is run 8 times so the
    effective compute per forward matches a ~8x deeper plain transformer.
    """
    return MythosConfig(
        vocab_size=50257,
        dim=768,
        n_heads=12,
        n_kv_heads=4,
        max_seq_len=1024,
        max_loop_iters=max_loop_iters,
        prelude_layers=2,
        coda_layers=2,
        attn_type="mla",
        kv_lora_rank=192,
        q_lora_rank=384,
        qk_rope_head_dim=32,
        qk_nope_head_dim=32,
        v_head_dim=32,
        n_experts=16,
        n_shared_experts=1,
        n_experts_per_tok=2,
        expert_dim=1536,
        act_threshold=0.99,
        rope_theta=500000.0,
        lora_rank=16,
        dropout=0.0,
    )


@dataclass
class TrainConfig:
    # Model
    max_loop_iters: int = 8
    run_name: str = "looped_8"

    # Data
    dataset_name: str = "HuggingFaceFW/fineweb-edu"
    dataset_config: str = "sample-10BT"
    tokenizer: str = "gpt2"
    seq_len: int = 1024

    # Training (H100 can fit B=32 directly; adjust to 16 for A40-class GPUs)
    batch_size: int = 32
    grad_accum_steps: int = 1  # 32 * 1024 = 32,768 tokens/step
    max_steps: int = 15000  # 15k * 32k = ~490M tokens
    learning_rate: float = 3e-4
    min_lr: float = 3e-5
    warmup_steps: int = 500
    weight_decay: float = 0.1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0

    # Logging & checkpointing
    log_every: int = 20
    eval_every: int = 2000
    ckpt_every: int = 5000
    output_dir: str = "/workspace/runs"

    # Precision
    dtype: str = "bfloat16"

    # Eval
    eval_seq_len: int = 1024
    eval_batch_size: int = 8
    eval_num_batches: int = 50

    def to_dict(self):
        return asdict(self)
diff --git a/experiments/data.py b/experiments/data.py
new file mode 100644
index 0000000..224c514
--- /dev/null
+++ b/experiments/data.py
@@ -0,0 +1,102 @@
"""
Streaming FineWeb-Edu dataloader.

Packs concatenated documents into fixed-length (input, target) pairs of
length seq_len, where target = input shifted by one. Each DataLoader worker
pulls a disjoint shard of the HuggingFace streaming dataset so workers never
overlap.
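For example, with seq_len=4 the token stream t0 t1 t2 t3 t4 t5 t6 t7 ... yields
(input=[t0 t1 t2 t3], target=[t1 t2 t3 t4]); the buffer then advances by
seq_len tokens, so the next pair is (input=[t4 t5 t6 t7], target=[t5 t6 t7 t8])
and no token is skipped between consecutive samples.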
+""" + +from typing import Iterator + +import torch +from torch.utils.data import IterableDataset, DataLoader, get_worker_info +from datasets import load_dataset +from transformers import AutoTokenizer + + +class FineWebEduStream(IterableDataset): + def __init__( + self, + tokenizer, + seq_len: int, + dataset_name: str, + dataset_config: str, + split: str = "train", + skip_docs: int = 0, + ): + self.tokenizer = tokenizer + self.seq_len = seq_len + self.dataset_name = dataset_name + self.dataset_config = dataset_config + self.split = split + self.skip_docs = skip_docs + self.eos_id = tokenizer.eos_token_id or 0 + + def __iter__(self) -> Iterator[tuple]: + worker = get_worker_info() + num_workers = worker.num_workers if worker else 1 + worker_id = worker.id if worker else 0 + + ds = load_dataset( + self.dataset_name, + name=self.dataset_config, + split=self.split, + streaming=True, + ) + if self.skip_docs > 0: + ds = ds.skip(self.skip_docs) + if num_workers > 1: + ds = ds.shard(num_shards=num_workers, index=worker_id) + + buffer: list[int] = [] + need = self.seq_len + 1 + + for doc in ds: + text = doc.get("text", "") + if not text: + continue + ids = self.tokenizer.encode(text, add_special_tokens=False) + ids.append(self.eos_id) + buffer.extend(ids) + + while len(buffer) >= need: + chunk = buffer[:need] + buffer = buffer[need - 1 :] # keep last token as start of next + x = torch.tensor(chunk[:-1], dtype=torch.long) + y = torch.tensor(chunk[1:], dtype=torch.long) + yield x, y + + +def build_loader( + tokenizer, + seq_len: int, + batch_size: int, + dataset_name: str, + dataset_config: str, + num_workers: int = 2, + split: str = "train", + skip_docs: int = 0, +) -> DataLoader: + ds = FineWebEduStream( + tokenizer=tokenizer, + seq_len=seq_len, + dataset_name=dataset_name, + dataset_config=dataset_config, + split=split, + skip_docs=skip_docs, + ) + return DataLoader( + ds, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=True, + persistent_workers=(num_workers > 0), + ) + + +def get_tokenizer(name: str = "gpt2"): + tok = AutoTokenizer.from_pretrained(name) + if tok.pad_token_id is None: + tok.pad_token = tok.eos_token + return tok diff --git a/experiments/evaluate.py b/experiments/evaluate.py new file mode 100644 index 0000000..24116b3 --- /dev/null +++ b/experiments/evaluate.py @@ -0,0 +1,156 @@ +""" +OpenMythos loop-scaling evaluation. + +Given a trained checkpoint, compute held-out perplexity on FineWeb-Edu's +validation stream while varying the number of inference-time recurrent +loops. This is the central test of the "more loops = deeper reasoning" +claim: a vanilla transformer cannot do this, a looped transformer should +show monotonically decreasing PPL that plateaus. + +Also emits generation samples at different loop counts for qualitative +comparison. 
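Perplexity is computed as exp(total NLL / total tokens) over a fixed
in-memory cache of validation batches, so every n_loops value in the sweep
is scored on exactly the same data.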
+ +Usage: + python evaluate.py --ckpt /workspace/runs/looped_8/ckpt_30000.pt + python evaluate.py --ckpt /workspace/runs/baseline_1/ckpt_30000.pt \\ + --loop_grid 1 +""" + +import argparse +import json +import math +from pathlib import Path + +import torch +import torch.nn.functional as F + +from open_mythos.main import OpenMythos, MythosConfig + +from data import build_loader, get_tokenizer + + +@torch.no_grad() +def compute_ppl(model, loader_iter, n_loops: int, num_batches: int, vocab_size: int): + model.eval() + total_loss = 0.0 + total_tokens = 0 + for _ in range(num_batches): + x, y = next(loader_iter) + x = x.to("cuda", non_blocking=True) + y = y.to("cuda", non_blocking=True) + logits = model(x, n_loops=n_loops) + loss = F.cross_entropy( + logits.float().view(-1, vocab_size), + y.view(-1), + reduction="sum", + ) + total_loss += loss.item() + total_tokens += y.numel() + return math.exp(total_loss / total_tokens), total_loss / total_tokens + + +@torch.no_grad() +def generate_sample(model, tokenizer, prompt: str, n_loops: int, max_new_tokens: int = 64): + model.eval() + ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda") + out = model.generate(ids, max_new_tokens=max_new_tokens, n_loops=n_loops, temperature=0.8, top_k=50) + return tokenizer.decode(out[0].tolist(), skip_special_tokens=True) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--ckpt", type=str, required=True) + ap.add_argument("--loop_grid", type=int, nargs="+", default=[1, 2, 4, 6, 8, 12, 16]) + ap.add_argument("--num_batches", type=int, default=50) + ap.add_argument("--batch_size", type=int, default=8) + ap.add_argument("--seq_len", type=int, default=1024) + ap.add_argument("--tokenizer", type=str, default="gpt2") + ap.add_argument("--output_json", type=str, default=None) + ap.add_argument("--sample_prompts", type=str, nargs="+", default=[ + "The main function of mitochondria is to", + "In physics, the second law of thermodynamics states that", + "A short guide to writing clear English:", + ]) + args = ap.parse_args() + + ckpt_path = Path(args.ckpt) + print(f"==> loading {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) + + mcfg_dict = ckpt["mcfg"] + mcfg = MythosConfig(**mcfg_dict) + print(f"==> model config: dim={mcfg.dim} n_experts={mcfg.n_experts} " + f"max_loop_iters={mcfg.max_loop_iters} attn_type={mcfg.attn_type}") + + model = OpenMythos(mcfg).to("cuda", torch.bfloat16) + model.load_state_dict(ckpt["model"]) + print(f"==> loaded step={ckpt['step']} tokens={ckpt['tokens_seen']/1e6:.1f}M") + + tok = get_tokenizer(args.tokenizer) + + # Build a validation loader that skips past the training window so + # evaluation never sees documents the model already trained on. + # With 500M tokens trained and ~500 tokens/doc average, train consumed + # ~1M docs; we skip 2M to be safe. + val_loader = build_loader( + tokenizer=tok, + seq_len=args.seq_len, + batch_size=args.batch_size, + dataset_name="HuggingFaceFW/fineweb-edu", + dataset_config="sample-10BT", + num_workers=1, + skip_docs=2_000_000, + ) + + # Cache a fixed evaluation set so each n_loops sees the same batches. 
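    # (Materializing the batches once means every point in the n_loops sweep is
    # scored on identical token sequences, and the streaming download +
    # tokenization cost is paid a single time rather than once per sweep value.)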
+ print(f"==> caching {args.num_batches} eval batches to memory...") + cached = [] + it = iter(val_loader) + for i in range(args.num_batches): + cached.append(next(it)) + if (i + 1) % 10 == 0: + print(f" cached {i+1}/{args.num_batches}") + + def cached_iter(): + for batch in cached: + yield batch + + print(f"\n==> loop grid sweep: {args.loop_grid}") + results = [] + for n_loops in args.loop_grid: + ppl, nll = compute_ppl( + model, cached_iter(), n_loops=n_loops, + num_batches=args.num_batches, vocab_size=mcfg.vocab_size, + ) + rho_A = model.recurrent.injection.get_A().max().item() + print(f"n_loops={n_loops:2d} ppl={ppl:7.3f} nll={nll:.4f} rho_A={rho_A:.4f}") + results.append({"n_loops": n_loops, "ppl": ppl, "nll": nll, "rho_A": rho_A}) + + print("\n==> generation samples (n_loops=trained / doubled)") + samples = {} + for n_loops in [mcfg.max_loop_iters, mcfg.max_loop_iters * 2]: + samples[n_loops] = {} + for prompt in args.sample_prompts: + gen = generate_sample(model, tok, prompt, n_loops=n_loops, max_new_tokens=48) + samples[n_loops][prompt] = gen + print(f"\n[n_loops={n_loops}] {prompt!r}\n -> {gen}") + + out_json = args.output_json or str(ckpt_path.parent / f"eval_{ckpt_path.stem}.json") + with open(out_json, "w") as f: + json.dump( + { + "ckpt": str(ckpt_path), + "step": ckpt["step"], + "tokens_seen": ckpt["tokens_seen"], + "trained_with_max_loop_iters": mcfg.max_loop_iters, + "loop_sweep": results, + "samples": samples, + }, + f, + indent=2, + ) + print(f"\n==> wrote {out_json}") + + +if __name__ == "__main__": + main() diff --git a/experiments/plot_results.py b/experiments/plot_results.py new file mode 100644 index 0000000..8e3153e --- /dev/null +++ b/experiments/plot_results.py @@ -0,0 +1,158 @@ +""" +Plot training curves + loop-scaling PPL sweep for OpenMythos experiments. 
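Each run's loop-scaling curve is overlaid with a dotted vertical line at its
training-time loop count, so depth extrapolation (PPL still falling past that
mark) is visible at a glance.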
+ +Reads: + /workspace/runs/looped_8/train.log (step loss rho_A etc) + /workspace/runs/baseline_1/train.log + /workspace/runs/looped_8/eval_ckpt_*.json + /workspace/runs/baseline_1/eval_ckpt_*.json + +Writes: + figs/training_loss.png + figs/rho_A.png + figs/loop_scaling.png + figs/summary.md +""" + +import argparse +import json +from pathlib import Path + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np + + +def load_log(path: Path): + rows = [] + with path.open() as f: + header = f.readline().strip().split("\t") + for line in f: + parts = line.strip().split("\t") + if len(parts) != len(header): + continue + row = {} + for k, v in zip(header, parts): + try: + row[k] = float(v) + except ValueError: + row[k] = v + rows.append(row) + return rows + + +def ema(xs, alpha=0.9): + out = [] + s = None + for x in xs: + s = x if s is None else alpha * s + (1 - alpha) * x + out.append(s) + return out + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--runs_dir", type=str, default="/workspace/runs") + ap.add_argument("--out_dir", type=str, default="/workspace/runs/figs") + args = ap.parse_args() + + runs_dir = Path(args.runs_dir) + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + runs = { + p.name: load_log(p / "train.log") + for p in sorted(runs_dir.glob("*")) + if (p / "train.log").exists() + } + print(f"loaded runs: {list(runs.keys())}") + + # Training loss + plt.figure(figsize=(7, 4.5)) + for name, rows in runs.items(): + steps = [r["step"] for r in rows] + losses = [r["loss"] for r in rows] + plt.plot(steps, losses, alpha=0.3, color="C0" if "looped" in name else "C1") + plt.plot(steps, ema(losses), label=name, color="C0" if "looped" in name else "C1", lw=2) + plt.xlabel("step") + plt.ylabel("train loss") + plt.title("OpenMythos training loss — looped vs baseline") + plt.legend() + plt.grid(alpha=0.3) + plt.tight_layout() + plt.savefig(out_dir / "training_loss.png", dpi=140) + plt.close() + + # rho_A + plt.figure(figsize=(7, 4.5)) + for name, rows in runs.items(): + steps = [r["step"] for r in rows] + rhos = [r["rho_A"] for r in rows] + plt.plot(steps, rhos, label=name) + plt.axhline(1.0, color="r", linestyle="--", alpha=0.5, label="instability bound") + plt.xlabel("step") + plt.ylabel(r"max element of $A_{\mathrm{discrete}}$ (= $\rho(A)$)") + plt.title("LTI injection spectral radius over training") + plt.legend() + plt.grid(alpha=0.3) + plt.tight_layout() + plt.savefig(out_dir / "rho_A.png", dpi=140) + plt.close() + + # Loop scaling sweep + evals = {} + for run_dir in runs_dir.iterdir(): + if not run_dir.is_dir(): + continue + for ev in sorted(run_dir.glob("eval_ckpt_*.json")): + with ev.open() as f: + data = json.load(f) + evals[run_dir.name] = data + + if evals: + plt.figure(figsize=(7, 4.5)) + for name, data in evals.items(): + sw = data["loop_sweep"] + xs = [r["n_loops"] for r in sw] + ys = [r["ppl"] for r in sw] + trained_at = data["trained_with_max_loop_iters"] + plt.plot(xs, ys, marker="o", label=f"{name} (trained loops={trained_at})") + plt.axvline(trained_at, color="gray", linestyle=":", alpha=0.4) + plt.xlabel("n_loops at inference") + plt.ylabel("validation perplexity") + plt.title("Test-time loop scaling — does more compute help?") + plt.legend() + plt.grid(alpha=0.3) + plt.tight_layout() + plt.savefig(out_dir / "loop_scaling.png", dpi=140) + plt.close() + + # Summary + lines = ["# OpenMythos Experiment Summary\n"] + for name, rows in runs.items(): + if not rows: + continue + last = rows[-1] + 
lines.append(f"## {name}\n") + lines.append(f"- final step: {int(last['step'])}") + lines.append(f"- final loss: {last['loss']:.3f}") + lines.append(f"- final rho(A): {last['rho_A']:.4f}") + lines.append(f"- tokens seen: {last['tokens']/1e6:.1f}M") + lines.append("") + if evals: + lines.append("## Loop-scaling sweep\n") + for name, data in evals.items(): + lines.append(f"### {name}\n") + lines.append("| n_loops | ppl | nll | rho_A |") + lines.append("|---|---|---|---|") + for r in data["loop_sweep"]: + lines.append(f"| {r['n_loops']} | {r['ppl']:.3f} | {r['nll']:.4f} | {r['rho_A']:.4f} |") + lines.append("") + + (out_dir / "summary.md").write_text("\n".join(lines)) + print(f"==> wrote figs to {out_dir}") + + +if __name__ == "__main__": + main() diff --git a/experiments/run_all.sh b/experiments/run_all.sh new file mode 100755 index 0000000..06c77ca --- /dev/null +++ b/experiments/run_all.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +# Orchestrate the full OpenMythos loop-scaling experiment on a single H100. +# +# Runs two training jobs sequentially: +# 1. looped_8 (max_loop_iters=8, the OpenMythos architecture) +# 2. baseline_1 (max_loop_iters=1, equivalent to a plain transformer) +# Then evaluates both with varying inference-time loops and plots everything. +# +# Expected wall-clock on H100 PCIe: ~4h train + ~4h train + ~30m eval = ~8.5h +set -e + +cd /workspace/OpenMythos/experiments + +MAX_STEPS=${MAX_STEPS:-15000} +BATCH_SIZE=${BATCH_SIZE:-32} +GRAD_ACCUM=${GRAD_ACCUM:-1} +RUNS_DIR=/workspace/runs +mkdir -p "$RUNS_DIR" + +echo "============================================================" +echo "TRAIN 1/2: looped_8 (max_loop_iters=8)" +echo "============================================================" +python train.py \ + --run_name looped_8 \ + --max_loop_iters 8 \ + --max_steps $MAX_STEPS \ + --batch_size $BATCH_SIZE \ + --grad_accum_steps $GRAD_ACCUM \ + --output_dir $RUNS_DIR \ + 2>&1 | tee "$RUNS_DIR/looped_8.stdout.log" + +echo "============================================================" +echo "TRAIN 2/2: baseline_1 (max_loop_iters=1, plain transformer)" +echo "============================================================" +python train.py \ + --run_name baseline_1 \ + --max_loop_iters 1 \ + --max_steps $MAX_STEPS \ + --batch_size $BATCH_SIZE \ + --grad_accum_steps $GRAD_ACCUM \ + --output_dir $RUNS_DIR \ + 2>&1 | tee "$RUNS_DIR/baseline_1.stdout.log" + +echo "============================================================" +echo "EVAL: loop-scaling sweep" +echo "============================================================" +LOOPED_CKPT=$(ls -t $RUNS_DIR/looped_8/ckpt_*.pt | head -1) +BASELINE_CKPT=$(ls -t $RUNS_DIR/baseline_1/ckpt_*.pt | head -1) + +python evaluate.py --ckpt "$LOOPED_CKPT" \ + --loop_grid 1 2 4 6 8 12 16 \ + 2>&1 | tee "$RUNS_DIR/looped_8.eval.log" + +python evaluate.py --ckpt "$BASELINE_CKPT" \ + --loop_grid 1 \ + 2>&1 | tee "$RUNS_DIR/baseline_1.eval.log" + +python plot_results.py --runs_dir $RUNS_DIR --out_dir $RUNS_DIR/figs + +echo "============================================================" +echo "DONE. See $RUNS_DIR/figs/ for plots and summary." +echo "============================================================" +ls -la $RUNS_DIR/figs/ diff --git a/experiments/train.py b/experiments/train.py new file mode 100644 index 0000000..26b0745 --- /dev/null +++ b/experiments/train.py @@ -0,0 +1,222 @@ +""" +OpenMythos loop-scaling validation — training script. 
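Each optimizer step consumes batch_size * grad_accum_steps * seq_len tokens;
with the CLI defaults below (8 * 4 * 1024) that is 32,768 tokens per step,
matching the 32 * 1024 tokens/step noted in config.py.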
+ +Usage: + python train.py --run_name looped_8 --max_loop_iters 8 --max_steps 30000 + python train.py --run_name baseline_1 --max_loop_iters 1 --max_steps 30000 + +Writes: + /workspace/runs//train.log (plaintext per-step metrics) + /workspace/runs//config.json + /workspace/runs//ckpt_*.pt +""" + +import argparse +import json +import math +import os +import random +import time +from pathlib import Path + +import torch +import torch.nn.functional as F + +from open_mythos.main import OpenMythos + +from config import TrainConfig, mythos_150m +from data import build_loader, get_tokenizer + + +def cosine_lr(step: int, cfg: TrainConfig) -> float: + if step < cfg.warmup_steps: + return cfg.learning_rate * (step + 1) / cfg.warmup_steps + progress = (step - cfg.warmup_steps) / max(1, cfg.max_steps - cfg.warmup_steps) + progress = min(1.0, progress) + coeff = 0.5 * (1.0 + math.cos(math.pi * progress)) + return cfg.min_lr + (cfg.learning_rate - cfg.min_lr) * coeff + + +def count_params(model): + total = sum(p.numel() for p in model.parameters()) + trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + return total, trainable + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--run_name", type=str, required=True) + ap.add_argument("--max_loop_iters", type=int, default=8) + ap.add_argument("--max_steps", type=int, default=30000) + ap.add_argument("--batch_size", type=int, default=8) + ap.add_argument("--grad_accum_steps", type=int, default=4) + ap.add_argument("--seq_len", type=int, default=1024) + ap.add_argument("--lr", type=float, default=3e-4) + ap.add_argument("--log_every", type=int, default=20) + ap.add_argument("--ckpt_every", type=int, default=5000) + ap.add_argument("--output_dir", type=str, default="/workspace/runs") + ap.add_argument( + "--loop_sample_mode", + type=str, + default="fixed", + choices=["fixed", "random_set"], + help="fixed: always use --max_loop_iters. 
random_set: uniformly sample each step from --loop_choices.", + ) + ap.add_argument( + "--loop_choices", + type=int, + nargs="+", + default=[4, 6, 8, 12, 16], + help="Set of n_loops values to uniformly sample from (only used in random_set mode).", + ) + args = ap.parse_args() + + cfg = TrainConfig( + run_name=args.run_name, + max_loop_iters=args.max_loop_iters, + max_steps=args.max_steps, + batch_size=args.batch_size, + grad_accum_steps=args.grad_accum_steps, + seq_len=args.seq_len, + learning_rate=args.lr, + log_every=args.log_every, + ckpt_every=args.ckpt_every, + output_dir=args.output_dir, + ) + + out_dir = Path(cfg.output_dir) / cfg.run_name + out_dir.mkdir(parents=True, exist_ok=True) + with (out_dir / "config.json").open("w") as f: + json.dump(cfg.to_dict(), f, indent=2) + + device = "cuda" + dtype = torch.bfloat16 + + torch.manual_seed(0) + + loop_mode = args.loop_sample_mode + loop_choices = args.loop_choices if loop_mode == "random_set" else [cfg.max_loop_iters] + print(f"==> run_name={cfg.run_name} max_loop_iters={cfg.max_loop_iters} " + f"loop_mode={loop_mode} loop_choices={loop_choices}") + print(f"==> output_dir={out_dir}") + + tok = get_tokenizer(cfg.tokenizer) + vocab_size = tok.vocab_size + + mcfg = mythos_150m(max_loop_iters=cfg.max_loop_iters) + mcfg.vocab_size = vocab_size + mcfg.max_seq_len = cfg.seq_len + model = OpenMythos(mcfg).to(device=device, dtype=dtype) + + total, trainable = count_params(model) + print(f"==> params total={total/1e6:.1f}M trainable={trainable/1e6:.1f}M") + + opt = torch.optim.AdamW( + model.parameters(), + lr=cfg.learning_rate, + betas=(cfg.beta1, cfg.beta2), + weight_decay=cfg.weight_decay, + ) + + loader = build_loader( + tokenizer=tok, + seq_len=cfg.seq_len, + batch_size=cfg.batch_size, + dataset_name=cfg.dataset_name, + dataset_config=cfg.dataset_config, + num_workers=2, + ) + loader_iter = iter(loader) + + log_path = out_dir / "train.log" + log_f = log_path.open("w", buffering=1) + log_f.write("step\ttokens\tn_loops\tlr\tloss\tgrad_norm\trho_A\tstep_s\ttok_per_s\tgpu_mem_gb\n") + + rng = random.Random(0) + + model.train() + tokens_seen = 0 + ema_loss = None + t_start = time.time() + step_start = time.time() + accum_loss = 0.0 + + for step in range(cfg.max_steps): + lr = cosine_lr(step, cfg) + for g in opt.param_groups: + g["lr"] = lr + + if loop_mode == "random_set": + n_loops_step = rng.choice(loop_choices) + else: + n_loops_step = cfg.max_loop_iters + + opt.zero_grad(set_to_none=True) + accum_loss = 0.0 + + for _ in range(cfg.grad_accum_steps): + try: + x, y = next(loader_iter) + except StopIteration: + loader_iter = iter(loader) + x, y = next(loader_iter) + x = x.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + + logits = model(x, n_loops=n_loops_step) + loss = F.cross_entropy( + logits.float().view(-1, mcfg.vocab_size), + y.view(-1), + ) + (loss / cfg.grad_accum_steps).backward() + accum_loss += loss.item() / cfg.grad_accum_steps + + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip).item() + opt.step() + + tokens_seen += cfg.batch_size * cfg.seq_len * cfg.grad_accum_steps + ema_loss = accum_loss if ema_loss is None else 0.98 * ema_loss + 0.02 * accum_loss + + if (step + 1) % cfg.log_every == 0 or step == 0: + torch.cuda.synchronize() + dt = time.time() - step_start + tok_per_s = cfg.batch_size * cfg.seq_len * cfg.grad_accum_steps * cfg.log_every / dt if step > 0 else 0 + rho_A = model.recurrent.injection.get_A().max().item() + mem = torch.cuda.max_memory_allocated() / 1e9 + line = ( + 
f"{step+1}\t{tokens_seen}\t{n_loops_step}\t{lr:.2e}\t{accum_loss:.4f}\t{grad_norm:.3f}\t" + f"{rho_A:.4f}\t{dt/max(1,cfg.log_every):.3f}\t{tok_per_s:.0f}\t{mem:.1f}" + ) + log_f.write(line + "\n") + print( + f"step {step+1}/{cfg.max_steps} " + f"n_loops={n_loops_step} " + f"loss={accum_loss:.3f} ema={ema_loss:.3f} " + f"lr={lr:.2e} gnorm={grad_norm:.2f} " + f"rho={rho_A:.3f} " + f"tok/s={tok_per_s:.0f} mem={mem:.1f}GB" + ) + step_start = time.time() + + if (step + 1) % cfg.ckpt_every == 0 or step + 1 == cfg.max_steps: + ckpt_path = out_dir / f"ckpt_{step+1}.pt" + torch.save( + { + "step": step + 1, + "model": model.state_dict(), + "opt": opt.state_dict(), + "cfg": cfg.to_dict(), + "mcfg": mcfg.__dict__, + "tokens_seen": tokens_seen, + }, + ckpt_path, + ) + print(f"==> saved ckpt {ckpt_path} tokens={tokens_seen/1e6:.1f}M") + + log_f.close() + total_time = time.time() - t_start + print(f"==> done in {total_time/3600:.2f}h tokens={tokens_seen/1e9:.2f}B") + + +if __name__ == "__main__": + main()