From e9477e49abe3942f23123d6fe2cbd4698d174ec7 Mon Sep 17 00:00:00 2001 From: Marko Sisovic Date: Wed, 25 Mar 2026 01:43:33 +0000 Subject: [PATCH 1/8] Getting some good results --- experiment_results.md | 164 ++++ train_gpt.py | 1927 +++++++++++++++++++++++++++++------------ 2 files changed, 1555 insertions(+), 536 deletions(-) create mode 100644 experiment_results.md diff --git a/experiment_results.md b/experiment_results.md new file mode 100644 index 000000000..cf48237c3 --- /dev/null +++ b/experiment_results.md @@ -0,0 +1,164 @@ +# Experiment Results + +Baseline reference: PR #549 — val_bpb 1.1194 (3-seed mean, 8xH100, dim=512, 11 layers) + +## Experiment 1: MODEL_DIM=576 (all else PR #549 defaults) + +- **Date**: 2026-03-24 +- **Hardware**: 4xH100 80GB +- **Key changes**: MODEL_DIM=576 (up from 512), MAX_WALLCLOCK_SECONDS=1200 +- **Other params**: ITERATIONS=9000, all other hyperparams at code defaults (matching PR #549) +- **model_params**: 33,968,348 (~34M, over budget) +- **Steps completed**: 5,609 / 9000 (wallclock capped at 1200s) +- **Step avg**: ~214ms +- **Peak memory**: 24,301 MiB + +### Results +| Metric | Value | +|--------|-------| +| Pre-EMA val_bpb | 1.1284 | +| Post-EMA val_bpb | 1.1277 | +| **Final int6 quantized val_bpb** | **1.1359** | +| Submission size (int6+lzma) | 19,525,669 bytes (~19.5 MB) | + +### Notes +- Over param budget (~34M vs typical ~24M), so not submittable as-is. +- Quantization gap is large: 1.1277 -> 1.1359 (+0.0082), likely because bigger model loses more to int6. +- Only got ~5.6k steps due to slower step time at larger dim on 4 GPUs. +- No TTT was run (would need separate eval pass). 
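The quantization gap rows above come from rounding weights to 63 signed levels per row. As a reference for what "int6 quantized" means here, a minimal numpy sketch of symmetric per-row fake-quantization (the helper names and max-abs scaling are illustrative assumptions; the actual exporter's clipping scheme may differ):

```python
import numpy as np

def quantize_int6_per_row(w: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Symmetric per-row fake int6: 63 levels in [-31, 31], one scale per row."""
    scale = np.maximum(np.abs(w).max(axis=1) / 31.0, 1e-8)
    q = np.clip(np.round(w / scale[:, None]), -31, 31).astype(np.int8)
    return q, scale.astype(np.float32)

def dequantize_int6(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    return q.astype(np.float32) * scale[:, None]

rng = np.random.default_rng(0)
w = (rng.standard_normal((512, 512)) * 0.02).astype(np.float32)
q, s = quantize_int6_per_row(w)
err = float(np.abs(dequantize_int6(q, s) - w).mean())
print(f"mean abs reconstruction error: {err:.2e}")  # small but nonzero
```

Per-row scales track output-channel ranges well, but the rounding error still accumulates over every quantized parameter, which is consistent with the larger dim=576 model paying a bigger post-quantization bpb penalty.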
+ +--- + +## Experiment 2: NUM_LAYERS=12 (dim=512, all else PR #549 defaults) + +- **Date**: 2026-03-24 +- **Hardware**: 4xH100 80GB +- **Key changes**: NUM_LAYERS=12 (up from 11), MAX_WALLCLOCK_SECONDS=1200 +- **Other params**: ITERATIONS=9000, MODEL_DIM=512, all other hyperparams at code defaults +- **model_params**: 29,355,620 (~29M, over budget but closer than dim=576) +- **Steps completed**: 6,878 / 9000 (wallclock capped at 1200s) +- **Step avg**: ~174ms +- **Peak memory**: 23,484 MiB + +### Results +| Metric | Value | +|--------|-------| +| Pre-EMA val_bpb | 1.1317 | +| Post-EMA val_bpb | 1.1307 | +| Final int6 quantized val_bpb | 1.1390 | +| **Final int6 sliding window val_bpb** | **1.1153** | +| Submission size (int6+lzma) | 17,275,453 bytes (~17.3 MB) | + +### Notes +- Also over param budget (~29M) but less so than dim=576. +- Quantization gap: 1.1307 -> 1.1390 (+0.0083), similar to exp 1. +- Sliding window eval (stride 64) brings it to 1.1153 — better than PR #549 baseline (1.1194) pre-TTT. +- Got more steps (6,878 vs 5,609) due to faster per-step time at dim=512. +- No TTT was run. + +--- + +## Experiment 3: NUM_LAYERS=12 + TTT (dim=512, all else PR #549 defaults) + +- **Date**: 2026-03-24 +- **Hardware**: 4xH100 80GB +- **Key changes**: NUM_LAYERS=12, TTT_ENABLED=1, MAX_WALLCLOCK_SECONDS=1200 +- **Other params**: ITERATIONS=9000, MODEL_DIM=512, all other hyperparams at code defaults +- **model_params**: 29,355,620 (~29M) +- **Steps completed**: 6,879 / 9000 (wallclock capped at 1200s) +- **Step avg**: ~174ms +- **TTT time**: 620s (1893 chunks, 3 epochs each) + +### Results +| Metric | Value | +|--------|-------| +| Post-EMA val_bpb | 1.1304 | +| Final int6 quantized val_bpb | 1.1387 | +| Final int6 sliding window val_bpb | 1.1151 | +| **Post-TTT sliding window val_bpb** | **1.1126** | + +### Notes +- TTT gave a -0.0025 gain over sliding window (1.1151 -> 1.1126), similar to PR #549's TTT gain. 
+- **Beats PR #549's 1.1194 by 0.0068 BPB** — but still over param budget (~29M).
+- TTT took 620s (~10.3 min), which is borderline for a 10-min eval constraint.
+
+---
+
+## Experiment 4: RECUR_LAYER=5 + TTT (depth recurrence, 11 physical → 12 virtual layers)
+
+- **Date**: 2026-03-25
+- **Hardware**: 4xH100 80GB
+- **Key changes**: RECUR_LAYER=5 (layer 5 duplicated), TTT_ENABLED=1, MAX_WALLCLOCK_SECONDS=1200
+- **Other params**: ITERATIONS=9000, MODEL_DIM=512, NUM_LAYERS=11 (physical), all else defaults
+- **model_params**: 26,996,324 (~27M) — only ~2.7M over baseline from extra block scalars
+- **Steps completed**: 6,884 / 9000 (wallclock capped at 1200s)
+- **Step avg**: ~174ms (same as full 12-layer)
+- **TTT time**: 622s (recurrent weights untied into independent copies before TTT)
+- **Submission size**: 15,927,562 bytes (~15.9 MB, just under the 16MB budget)
+
+### Results
+| Metric | Value |
+|--------|-------|
+| Post-EMA val_bpb | 1.1354 |
+| Final int6 quantized val_bpb | 1.1440 |
+| Final int6 sliding window val_bpb | 1.1205 |
+| **Post-TTT sliding window val_bpb** | **1.1180** |
+
+### Comparison to full 12-layer (Exp 3)
+| Metric | Full 12L (Exp 3) | Recur L5 (Exp 4) | Delta |
+|--------|-----------------|-------------------|-------|
+| Params | 29.4M | 27.0M | -2.4M |
+| Submission size | 17.3 MB | 15.9 MB | -1.4 MB |
+| Sliding window val_bpb | 1.1151 | 1.1205 | +0.0054 |
+| Post-TTT val_bpb | 1.1126 | 1.1180 | +0.0054 |
+
+### Notes
+- Recurrence adds depth for free in compute, but shared weights cost ~0.005 BPB vs independent layers.
+- TTT (with untying) gave a -0.0025 gain (1.1205 -> 1.1180), same magnitude as with independent layers.
+- Submission size is much smaller (15.9 MB vs 17.3 MB) since the weight banks stay at 11-layer size.
+- Still beats PR #549 baseline (1.1194) by 0.0014 BPB, with a smaller model.
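The depth recurrence above is plain weight sharing in the execution order. A sketch of how the virtual layer sequence could be derived (`virtual_layer_order` is a hypothetical helper for illustration, not code from `train_gpt.py`):

```python
def virtual_layer_order(num_physical: int, recur_layers: tuple[int, ...]) -> list[int]:
    """Expand physical layer indices into the virtual execution order.
    The recurrent block is replayed once right after its first pass, reusing
    the same weights, so depth grows without adding parameters."""
    order: list[int] = []
    for i in range(num_physical):
        order.append(i)
        if recur_layers and i == recur_layers[-1]:
            order.extend(recur_layers)  # second pass through the shared block
    return order

# RECUR_LAYER=5: 12 virtual layers from 11 physical ones.
print(virtual_layer_order(11, (5,)))    # [0, 1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 10]
# RECUR_LAYERS=4,5: the 4,5,4,5 pattern, 13 virtual layers.
print(virtual_layer_order(11, (4, 5)))  # [0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
```

The replayed block reuses the same parameter tensors, so the checkpoint stays at 11 layers of weights while the forward pass runs 12 or 13 blocks.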
+ +--- + +## Experiment 4b: Tied TTT (same checkpoint as Exp 4, no untying) + +- **Post-TTT val_bpb**: **1.1179** (vs 1.1180 untied — negligible difference) +- Conclusion: untying doesn't help with 3-epoch TTT. Tied is fine. + +--- + +## Experiment 5: RECUR_LAYERS=4,5 + tied TTT (dual recurrence, 11 physical → 13 virtual layers) + +- **Date**: 2026-03-25 +- **Hardware**: 4xH100 80GB +- **Key changes**: RECUR_LAYERS=4,5 (4,5,4,5 pattern), TTT_ENABLED=1, TTT_UNTIE=0 +- **Other params**: ITERATIONS=9000, MODEL_DIM=512, NUM_LAYERS=11 (physical), all else defaults +- **model_params**: 26,998,380 (~27M) +- **Steps completed**: 6,389 / 9000 (wallclock capped at 1200s) +- **Step avg**: ~188ms (up from 174ms with single recurrence — extra virtual layer costs ~14ms/step) +- **TTT time**: 655s (tied) +- **Submission size**: 15,944,748 bytes (~15.9 MB) + +### Results +| Metric | Value | +|--------|-------| +| Post-EMA val_bpb | 1.1337 | +| Final int6 quantized val_bpb | 1.1421 | +| Final int6 sliding window val_bpb | 1.1187 | +| **Post-TTT sliding window val_bpb** | **1.1163** | + +### Full comparison +| | PR #549 | Recur L5 (Exp 4) | Recur L4,5 (Exp 5) | Full 12L (Exp 3) | +|---|---|---|---|---| +| Virtual depth | 11 | 12 | **13** | 12 | +| Params | ~24M | ~27M | ~27M | ~29M | +| Submission size | ~19.5 MB | 15.9 MB | 15.9 MB | 17.3 MB | +| Steps completed | ~7,180 | 6,884 | 6,389 | 6,879 | +| Post-TTT val_bpb | 1.1194 | 1.1179 | **1.1163** | **1.1126** | + +### Notes +- Dual recurrence (1.1163) beats single recurrence (1.1179) by 0.0016 BPB. +- Beats PR #549 by 0.0031 BPB, with ~27M params and ~16 MB submission. +- Gap to full independent 12-layer (1.1126) is 0.0037 — weight sharing costs more with 2 repeated layers. +- Step time increased to ~188ms (from 174ms), resulting in ~500 fewer steps in the wallclock budget. +- The extra virtual depth helps despite fewer training steps. 
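For context on the TTT rows: test-time training fine-tunes a copy of the final checkpoint on each evaluation chunk for a few epochs before scoring that chunk (TTT_EPOCHS=3, 1893 chunks in Exp 3's log), with each chunk assumed to restart from the frozen checkpoint. A toy sketch of the mechanism, with a linear least-squares model standing in for the transformer (the learning rate, shapes, and objective here are illustrative, not the script's values):

```python
import numpy as np

def ttt_adapt(W_frozen: np.ndarray, x: np.ndarray, y: np.ndarray,
              lr: float = 0.1, epochs: int = 3) -> np.ndarray:
    """Test-time training on one chunk: copy the frozen checkpoint, take a few
    full-batch gradient steps on the chunk's own (input, target) pairs, and
    return the adapted copy for scoring that chunk."""
    W = W_frozen.copy()               # the checkpoint itself is never modified
    n = len(x)
    for _ in range(epochs):
        grad = x.T @ (x @ W - y) / n  # gradient of the batch-averaged squared error
        W -= lr * grad
    return W

rng = np.random.default_rng(0)
x = rng.standard_normal((256, 8))
y = x @ rng.standard_normal((8, 4))   # targets from a hidden linear map
W0 = np.zeros((8, 4))
W1 = ttt_adapt(W0, x, y)
before = float(np.mean((x @ W0 - y) ** 2))
after = float(np.mean((x @ W1 - y) ** 2))
print(after < before)                 # adaptation lowers loss on its own chunk
```

In the tied-vs-untied comparison of Exp 4b, "tied" keeps the recurrent block's two passes sharing one tensor during this adaptation, while "untied" gives each pass its own copy before the steps run.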
diff --git a/train_gpt.py b/train_gpt.py index 651beb2b8..942f03ad3 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -1,14 +1,8 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - from __future__ import annotations - import copy import glob import io +import lzma import math import os import random @@ -18,7 +12,11 @@ import uuid import zlib from pathlib import Path - +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" import numpy as np import sentencepiece as spm import torch @@ -26,156 +24,251 @@ import torch.nn.functional as F from torch import Tensor, nn from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - +from flash_attn_interface import flash_attn_func as flash_attn_3_func class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. 
data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") train_files = os.path.join(data_path, "fineweb_train_*.bin") val_files = os.path.join(data_path, "fineweb_val_*.bin") tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. 
vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.6)) head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed 
from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0"))) + lawa_k = int(os.environ.get("LAWA_K", 10)) + lawa_freq = int(os.environ.get("LAWA_FREQ", 100)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0"))) + value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0"))) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = 
int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + recur_layer = int(os.environ.get("RECUR_LAYER", -1)) # single layer compat + recur_layers_str = os.environ.get("RECUR_LAYERS", "") # comma-separated, e.g. "4,5" + eval_only = bool(int(os.environ.get("EVAL_ONLY", "0"))) + +# --- Batched Newton-Schulz orthogonalization --- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) + transposed = X.size(-2) > X.size(-1) if transposed: - X = X.T + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A + A = X @ X.mT + B = b * A + c * (A @ A) X = a * X + B @ X - return X.T if transposed else X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X +# --- Parallel Muon optimizer --- class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. 
Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): super().__init__( params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. 
Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) @torch.no_grad() def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" loss = None if closure is not None: with torch.enable_grad(): loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 + if not self._built: + self._build() for group in self.param_groups: - params = group["params"] - if not params: - continue lr = group["lr"] momentum = group["momentum"] backend_steps = group["backend_steps"] nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() state = self.state[p] if "momentum_buffer" not in 
state: state["momentum_buffer"] = torch.zeros_like(g) buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures return loss - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. 
-# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. +# --- Tokenizer evaluation helpers --- def build_sentencepiece_luts( sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device @@ -193,7 +286,7 @@ def build_sentencepiece_luts( base_bytes_np[token_id] = 1 continue piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): + if piece.startswith("\u2581"): has_leading_space_np[token_id] = True piece = piece[1:] base_bytes_np[token_id] = len(piece.encode("utf-8")) @@ -202,20 +295,15 @@ def build_sentencepiece_luts( torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), ) - - def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: files = [Path(p) for p in sorted(glob.glob(pattern))] if not files: raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() usable = ((tokens.numel() - 1) // seq_len) * seq_len if usable <= 0: raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") return tokens[: usable + 1] - - def eval_val( args: Hyperparameters, model: nn.Module, @@ -227,34 +315,32 @@ def eval_val( base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, ) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge + seq_len = eval_seq_len or args.train_seq_len local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: + if local_batch_tokens < seq_len: raise ValueError( "VAL_BATCH_SIZE must provide at least one sequence per rank; " f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len seq_start = (total_seqs * rank) // world_size seq_end = (total_seqs * (rank + 1)) // world_size val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) val_token_count = torch.zeros((), device=device, dtype=torch.float64) val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() with torch.inference_mode(): for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 + raw_start = batch_seq_start * seq_len + 
raw_end = batch_seq_end * seq_len + 1 local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): batch_loss = model(x, y).detach() batch_token_count = float(y.numel()) @@ -265,31 +351,23 @@ def eval_val( token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count bits_per_token = val_loss.item() / math.log(2.0) tokens_per_byte = val_token_count.item() / val_byte_count.item() model.train() return float(val_loss.item()), float(bits_per_token * tokens_per_byte) -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
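`eval_val` above scores fixed non-overlapping windows; the `eval_stride` hyperparameter points at the sliding-window variant reported in `experiment_results.md`, where each window re-predicts only its trailing tokens so most tokens see nearly a full window of left context. A sketch of the usual stride-based span bookkeeping (assuming HF-style stride semantics; the script's actual slicing may differ):

```python
def sliding_window_spans(num_tokens: int, window: int, stride: int) -> list[tuple[int, int, int]]:
    """Plan overlapping eval windows. Each span is (start, end, n_scored):
    only the tokens past the previous span's end are scored, so every token
    is scored exactly once but sees up to `window - stride` extra left context."""
    spans, prev_end = [], 0
    for begin in range(0, num_tokens, stride):
        end = min(begin + window, num_tokens)
        spans.append((begin, end, end - prev_end))
        prev_end = end
        if end == num_tokens:
            break
    return spans

# Toy numbers; with window=2048 and stride=64 each 64-token block is scored
# with up to ~1984 tokens of preceding context.
spans = sliding_window_spans(num_tokens=10, window=4, stride=2)
print(spans)  # [(0, 4, 4), (2, 6, 2), (4, 8, 2), (6, 10, 2)]
```

The extra context, not a different metric, is where the sliding-window bpb improvement over the fixed-window numbers comes from.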
+# --- Quantization helpers --- CONTROL_TENSOR_NAME_PATTERNS = tuple( pattern for pattern in os.environ.get( "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", ).split(",") if pattern ) @@ -306,10 +384,8 @@ def eval_val( INT8_PER_ROW_SCALE_DTYPE = torch.float16 INT8_CLIP_PERCENTILE = 99.99984 INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - def tensor_nbytes(t: Tensor) -> int: return int(t.numel()) * int(t.element_size()) - def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): return t.float().contiguous() @@ -317,12 +393,9 @@ def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, s passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() return t - def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: t32 = t.float() if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. clip_abs = ( torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() @@ -332,19 +405,11 @@ def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. 
clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() return q, scale - def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes quantized: dict[str, Tensor] = {} scales: dict[str, Tensor] = {} dtypes: dict[str, str] = {} @@ -355,27 +420,21 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), 0, ) - for name, tensor in state_dict.items(): t = tensor.detach().to("cpu").contiguous() stats["param_count"] += int(t.numel()) stats["num_tensors"] += 1 stats["baseline_tensor_bytes"] += tensor_nbytes(t) - if not t.is_floating_point(): stats["num_nonfloat_tensors"] += 1 passthrough[name] = t stats["int8_payload_bytes"] += tensor_nbytes(t) continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: kept = keep_float_tensor(name, t, passthrough_orig_dtypes) passthrough[name] = kept stats["int8_payload_bytes"] += tensor_nbytes(kept) continue - stats["num_float_tensors"] += 1 q, s = quantize_float_tensor(t) if s.ndim > 0: @@ -384,7 +443,6 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): scales[name] = s dtypes[name] = str(t.dtype).removeprefix("torch.") stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - obj: dict[str, object] = { "__quant_format__": "int8_clean_per_row_v1", "quantized": quantized, @@ -397,7 +455,6 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): if passthrough_orig_dtypes: obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes return obj, stats - def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: out: dict[str, Tensor] = {} qmeta = obj.get("qmeta", {}) @@ -407,13 +464,11 @@ def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: s = obj["scales"][name] if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() else: scale = float(s.item()) out[name] = (q.float() * scale).to(dtype=dtype).contiguous() for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
out_t = t.detach().to("cpu").contiguous() orig_dtype = passthrough_orig_dtypes.get(name) if isinstance(orig_dtype, str): @@ -421,16 +476,12 @@ def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: out[name] = out_t return out - -# ----------------------------- -# DATA LOADING -# ----------------------------- +# --- Data loading --- def load_data_shard(file: Path) -> Tensor: header_bytes = 256 * np.dtype(" Tensor: if tokens_np.size != num_tokens: raise ValueError(f"Short read for {file}") return torch.from_numpy(tokens_np.astype(np.uint16, copy=False)) - - class TokenStream: - # Reads shards sequentially and wraps around forever. The training loop therefore - # has deterministic, simple streaming behavior with no sampling or workers. def __init__(self, pattern: str): self.files = [Path(p) for p in sorted(glob.glob(pattern))] if not self.files: @@ -453,12 +500,10 @@ def __init__(self, pattern: str): self.file_idx = 0 self.tokens = load_data_shard(self.files[0]) self.pos = 0 - def _advance_file(self) -> None: self.file_idx = (self.file_idx + 1) % len(self.files) self.tokens = load_data_shard(self.files[self.file_idx]) self.pos = 0 - def take(self, n: int) -> Tensor: chunks: list[Tensor] = [] remaining = n @@ -472,17 +517,12 @@ def take(self, n: int) -> Tensor: self.pos += k remaining -= k return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): self.rank = rank self.world_size = world_size self.device = device self.stream = TokenStream(pattern) - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: local_tokens = global_tokens // (self.world_size * grad_accum_steps) per_rank_span = local_tokens + 1 @@ -493,44 +533,44 @@ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> y = local[1:].reshape(-1, seq_len) return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- +# --- Transformer modules --- class RMSNorm(nn.Module): def __init__(self, eps: float | None = None): super().__init__() self.eps = eps - def forward(self, x: Tensor) -> Tensor: return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + _qat_enabled: bool = False def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - + return F.linear(x, w, bias) def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. 
with torch.no_grad(): for name, param in module.named_parameters(): if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: param.data = param.data.float() - - class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._cos_cached: Tensor | None = None self._sin_cached: Tensor | None = None - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: if ( self._cos_cached is None @@ -538,20 +578,30 @@ def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tup or self._seq_len_cached != seq_len or self._cos_cached.device != device ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, 
:] self._seq_len_cached = seq_len return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) half = x.size(-1) // 2 x1, x2 = x[..., :half], x[..., half:] return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - class CausalSelfAttention(nn.Module): def __init__( self, @@ -560,6 +610,8 @@ def __init__( num_kv_heads: int, rope_base: float, qk_gain_init: float, + gated_attention: bool = False, + value_residual: bool = False, ): super().__init__() if dim % num_heads != 0: @@ -571,51 +623,114 @@ def __init__( self.head_dim = dim // num_heads if self.head_dim % 2 != 0: raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True + # No CastedLinear -- weights come from banks self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + # Gated attention and value residual (non-banked small params) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, 
bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - 
q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout + gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) + y = y * gate + y = y.reshape(bsz, seqlen, dim) + return F.linear(y, out_w.to(x.dtype)), raw_v + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) 
+ return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. + Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup def __init__(self, dim: int, mlp_mult: int): super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - + # No CastedLinear -- weights come from banks + def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: + x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5) + return F.linear(x.square(), down_w.to(x.dtype)) class Block(nn.Module): def __init__( @@ -626,24 +741,38 @@ def __init__( mlp_mult: int, rope_base: float, qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + gated_attention: bool = False, + value_residual: bool = False, ): super().__init__() self.attn_norm = RMSNorm() self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + gated_attention=gated_attention, 
value_residual=value_residual) self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed, v0=v0) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, raw_v class GPT(nn.Module): def __init__( @@ -659,18 +788,65 @@ def __init__( logit_softcap: float, rope_base: float, qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool 
= False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + gated_attention: bool = False, + value_residual: bool = False, + recur_layer: int = -1, + recur_layers: list[int] | None = None, ): super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection if logit_softcap <= 0.0: raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") self.tie_embeddings = tie_embeddings self.tied_embed_init_std = tied_embed_init_std self.logit_softcap = logit_softcap + self.value_residual = value_residual + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + # Depth recurrence: duplicate layers virtually for near-zero param overhead + # Normalize: recur_layers is the canonical list; recur_layer is single-layer compat + if recur_layers is None and recur_layer >= 0: + recur_layers = [recur_layer] + self.recur_layers = sorted(recur_layers) if recur_layers else [] + self.num_physical_layers = num_layers + if self.recur_layers: + for rl in self.recur_layers: + assert 0 <= rl < num_layers, f"recur_layer={rl} out of range [0, {num_layers})" + cutoff = max(self.recur_layers) + 1 + self.v2p = list(range(cutoff)) + self.recur_layers + list(range(cutoff, num_layers)) + virtual_num_layers = num_layers + len(self.recur_layers) + else: + virtual_num_layers = num_layers + self.v2p = list(range(num_layers)) + self.virtual_num_layers = virtual_num_layers + self.num_encoder_layers = virtual_num_layers // 2 + self.num_decoder_layers = virtual_num_layers - self.num_encoder_layers self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) self.skip_weights = 
nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Parameter banks: contiguous 3D tensors for batched optimizer (physical layer count) + head_dim = model_dim // num_heads + kv_dim = num_kv_heads * head_dim + mlp_dim = int(mlp_mult * model_dim) + self.num_layers = num_layers # physical, used for bank indexing + self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) + # Blocks: one per virtual layer (each has own scalar params) self.blocks = nn.ModuleList( [ Block( @@ -680,65 +856,601 @@ def __init__( mlp_mult, rope_base, qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + gated_attention=gated_attention, + value_residual=value_residual, ) - for i in range(num_layers) + for i in range(virtual_num_layers) ] ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim_ve = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) if self.lm_head is not None: self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) 
for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, virtual_num_layers - xsa_last_n), virtual_num_layers): + self.blocks[i].attn.use_xsa = True self._init_weights() - def _init_weights(self) -> None: if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + # Init banks: orthogonal, with proj layers scaled down and out/down zero-init + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q + nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up + nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) + # Scale proj layers (out_proj and mlp_down are "proj" layers) + self.qo_bank.data[n + i].mul_(proj_scale) + self.mlp_down_bank.data[i].mul_(proj_scale) + # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache 
is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + n = self.num_layers # physical layer count for bank offset + v2p = self.v2p x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) x0 = x + v0 = None skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. + ve_cache: dict = {} for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) + pi = v2p[i] + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], + self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v skips.append(x) for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + pi = v2p[bi] if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], + self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) + logits_proj = F.linear(x_flat, self.tok_emb.weight) else: if self.lm_head is None: raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) + logits_proj = self.lm_head(x_flat) logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, 
reduction="mean") + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + n = self.num_layers # physical layer count for bank offset + v2p = self.v2p + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + pi = v2p[i] + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], + self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + pi = v2p[bi] + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], + self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], + 
v_embed=ve, v0=v0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + def untie_recurrence(self): + """Expand banks so duplicated layers get their own independent weights. + Called before TTT so SGD can update each virtual layer independently.""" + if not self.recur_layers: + return + n = self.num_layers # physical count before expansion + # Build list of cloned rows to insert after max(recur_layers) + insert_after = max(self.recur_layers) + clones = sorted(self.recur_layers) # rows to duplicate + + def _expand(bank: Tensor) -> Tensor: + """Insert clones of recur rows after insert_after position.""" + parts = [bank[:insert_after + 1]] + for rl in clones: + parts.append(bank[rl:rl + 1].clone()) + parts.append(bank[insert_after + 1:]) + return torch.cat(parts, dim=0) + + # qo_bank: [2*n, dim, dim] -> first n are Q, last n are O + q_part = _expand(self.qo_bank.data[:n]) + o_part = _expand(self.qo_bank.data[n:]) + self.qo_bank = nn.Parameter(torch.cat([q_part, o_part], dim=0)) + + # kv_bank: [2*n, kv_dim, dim] -> first n are K, last n are V + k_part = _expand(self.kv_bank.data[:n]) + v_part = _expand(self.kv_bank.data[n:]) + self.kv_bank = nn.Parameter(torch.cat([k_part, v_part], dim=0)) + + # mlp banks: [n, ...] 
+ self.mlp_up_bank = nn.Parameter(_expand(self.mlp_up_bank.data)) + self.mlp_down_bank = nn.Parameter(_expand(self.mlp_down_bank.data)) + + # Update to identity mapping + new_n = n + len(clones) + self.num_layers = new_n + self.num_physical_layers = new_n + self.v2p = list(range(new_n)) + self.recur_layers = [] + +# --- Sliding window evaluation --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + 
y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, + then train on it. 
Every token is scored BEFORE any update that could use it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}."
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) -# ----------------------------- -# TRAINING -# ----------------------------- + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") -def main() -> None: - global zeropower_via_newtonschulz5 + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else 
max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk 
[{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# --- GPTQ-lite int6 quantization --- - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, 
Tensor]: + """Convert 3D bank tensors into individual 2D tensors with standard names.""" + out: dict[str, Tensor] = {} + n = num_layers + for name, tensor in sd.items(): + if name == "qo_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] + out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] + elif name == "kv_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] + out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] + elif name == "mlp_up_bank": + for i in range(n): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "mlp_down_bank": + for i in range(n): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Convert individual 2D tensors back into 3D bank tensors.""" + out: dict[str, Tensor] = {} + n = num_layers + # Reconstruct banks from individual weight keys + qo_slices = [None] * (2 * n) + kv_slices = [None] * (2 * n) + up_slices = [None] * n + down_slices = [None] * n + consumed = set() + for i in range(n): + qk = f"blocks.{i}.attn.c_q.weight" + if qk in sd: + qo_slices[i] = sd[qk] + consumed.add(qk) + ok = f"blocks.{i}.attn.proj.weight" + if ok in sd: + qo_slices[n + i] = sd[ok] + consumed.add(ok) + kk = f"blocks.{i}.attn.c_k.weight" + if kk in sd: + kv_slices[i] = sd[kk] + consumed.add(kk) + vk = f"blocks.{i}.attn.c_v.weight" + if vk in sd: + kv_slices[n + i] = sd[vk] + consumed.add(vk) + fk = f"blocks.{i}.mlp.fc.weight" + if fk in sd: + up_slices[i] = sd[fk] + consumed.add(fk) + dk = f"blocks.{i}.mlp.proj.weight" + if dk in sd: + down_slices[i] = sd[dk] + consumed.add(dk) + out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) + out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) + out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) + out["mlp_down_bank"] = 
torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out +# --- Training --- + +def 
main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) @@ -757,23 +1469,18 @@ def main() -> None: dist.init_process_group(backend="nccl", device_id=device) dist.barrier() master_process = rank == 0 - - # Fast math knobs torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) enable_flash_sdp(True) enable_mem_efficient_sdp(False) enable_math_sdp(False) - logfile = None if master_process: os.makedirs("logs", exist_ok=True) logfile = f"logs/{args.run_id}.txt" print(logfile) - def log0(msg: str, console: bool = True) -> None: if not master_process: return @@ -782,7 +1489,6 @@ def log0(msg: str, console: bool = True) -> None: if logfile is not None: with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) - log0(code, console=False) log0("=" * 100, console=False) log0(f"Running Python {sys.version}", console=False) @@ -792,16 +1498,10 @@ def log0(msg: str, console: bool = True) -> None: console=False, ) log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) - if not args.tokenizer_path.endswith(".model"): raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) @@ -811,18 +1511,16 @@ def log0(msg: str, console: bool = True) -> None: ) dataset_dir = Path(args.data_path).resolve() actual_train_files = 
len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( sp, args.vocab_size, device ) log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - + CastedLinear._qat_enabled = args.qat_enabled base_model = GPT( vocab_size=args.vocab_size, num_layers=args.num_layers, @@ -835,292 +1533,449 @@ def log0(msg: str, console: bool = True) -> None: logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + gated_attention=args.gated_attention, + value_residual=args.value_residual, + recur_layer=args.recur_layer, + recur_layers=([int(x) for x in args.recur_layers_str.split(",") if x.strip()] or None), ).to(device).bfloat16() + # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward + base_model.qo_bank.data = base_model.qo_bank.data.float() + base_model.kv_bank.data = base_model.kv_bank.data.float() + base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() + base_model.mlp_down_bank.data = 
base_model.mlp_down_bank.data.float() for module in base_model.modules(): if isinstance(module, CastedLinear): module.float() restore_low_dim_params_to_fp32(base_model) + # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, + # and non-bank grads are manually all-reduced before Adam steps. compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - 
optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + model = compiled_model + + if not args.eval_only: + # - 4 parameter banks -> Muon (batched Newton-Schulz) + # - token embedding -> Adam + # - scalars/control tensors -> Adam + # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking) + matrix_params = [ + base_model.qo_bank, base_model.kv_bank, + base_model.mlp_up_bank, base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.adam_wd, fused=True, ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - 
log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. 
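The `lr_mul` schedule here keeps the LR multiplier at 1.0 until the projected warmdown duration (`warmdown_iters` × observed average step time) no longer fits in the remaining wallclock budget, then decays linearly with remaining time; without a wallclock cap it falls back to a plain step-indexed linear warmdown. A standalone sketch of the same logic (argument names are illustrative, not the patch's `Hyperparameters` fields):

```python
def lr_mul(step, elapsed_ms, iterations, warmdown_iters, max_wallclock_ms):
    """LR multiplier: 1.0 during steady state, linear-to-zero warmdown at the end.

    With a wallclock cap, warmdown triggers when the projected warmdown time
    (warmdown_iters * observed step time) exceeds the remaining budget.
    """
    if warmdown_iters <= 0:
        return 1.0
    if max_wallclock_ms is None:
        # Step-indexed warmdown over the last warmdown_iters steps.
        warmdown_start = max(iterations - warmdown_iters, 0)
        if warmdown_start <= step < iterations:
            return max((iterations - step) / max(warmdown_iters, 1), 0.0)
        return 1.0
    step_ms = elapsed_ms / max(step, 1)      # observed average step time
    warmdown_ms = warmdown_iters * step_ms   # projected warmdown duration
    remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
    return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0
```

With 1200 s caps (as in the experiments above), this is what lets runs that only reach ~5.6k–6.9k of 9000 steps still finish with a fully decayed LR.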
- if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + # Non-bank params that need manual all-reduce (replicated across GPUs) + replicated_params = list(optimizer_tok.param_groups[0]["params"]) + for pg in optimizer_tok.param_groups[1:]: + replicated_params.extend(pg["params"]) + replicated_params.extend(scalar_params) + + optimizer_head = None + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + replicated_params.append(base_model.lm_head.weight) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if optimizer_head is not None: + optimizers.append(optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + if base_model.recur_layers: + log0(f"recurrence:layers={base_model.recur_layers} physical_layers={args.num_layers} virtual_layers={base_model.virtual_num_layers}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in 
enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x, y = 
train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + # All-reduce all grads for warmup (simple, not optimized) + if distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + from collections import deque + lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + 
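The loop maintains two weight-averaging candidates over the raw training iterates: an EMA of the full state dict (decay 0.997, updated after every optimizer step) and a LAWA deque holding the last `lawa_k` snapshots, averaged uniformly at the end. A minimal dict-of-floats sketch of both update rules (helper names are illustrative):

```python
from collections import deque

def ema_step(ema, current, decay=0.997):
    """One EMA update per parameter: ema <- decay * ema + (1 - decay) * current."""
    for name, value in current.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value

def lawa_average(snapshots):
    """Uniform average of the snapshots held in a deque(maxlen=k) (LAWA)."""
    n = len(snapshots)
    return {name: sum(s[name] for s in snapshots) / n for name in snapshots[0]}
```

In the patch these operate on full `state_dict()` tensors in fp32 rather than scalars, but the arithmetic per parameter is the same; the experiment log's "Post-EMA val_bpb" rows come from loading the EMA state before the final eval.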
torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum for opt in optimizers: - opt.step() + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + # === 3-phase overlapped optimizer step === + # Phase 1: Launch async reduce-scatter for banks (biggest first) + optimizer_muon.launch_reduce_scatters() + # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + optimizer_tok.step() + 
optimizer_scalar.step() + if optimizer_head is not None: + optimizer_head.step() + # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) + optimizer_muon.step() zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += 
t.detach().cpu() + swa_count += 1 + if args.lawa_enabled and step % args.lawa_freq == 0: + lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: + if should_log_train: log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - 
torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. 
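The wallclock-cap check has to agree across ranks, so the code all-reduces a 0/1 flag with `ReduceOp.MAX`: if any single rank has crossed the cap, every rank records `stop_after_step` for the same step. A single-process stand-in for that consensus rule (function and argument names are illustrative):

```python
def reached_cap_consensus(elapsed_ms_by_rank, max_wallclock_ms):
    """Mirror of dist.all_reduce(flag, op=ReduceOp.MAX) on a 0/1 cap flag.

    Returns True iff at least one rank's elapsed time has hit the cap,
    so all ranks stop at the same step and collectives stay matched.
    """
    flags = [int(e >= max_wallclock_ms) for e in elapsed_ms_by_rank]
    return max(flags) == 1
```

Using MAX (any-rank) rather than each rank deciding locally is what prevents a faster rank from exiting the loop while slower ranks still expect its participation in the next all-reduce.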
- reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + # Apply weight averaging + if args.lawa_enabled and len(lawa_queue) > 1: + log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") + current_state = base_model.state_dict() + avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in 
current_state.items()} + for snap in lawa_queue: + for name in avg_state: + avg_state[name] += snap[name].float() + for name in avg_state: + avg_state[name] /= len(lawa_queue) + avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) + base_model.load_state_dict(avg_state, strict=True) + else: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + # Unbank 3D tensors into individual 2D tensors for quantization + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + unbanked_sd = 
_unbank_state_dict(sd_cpu, args.num_layers) + quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") + log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + gated_attention=args.gated_attention, value_residual=args.value_residual, + recur_layer=args.recur_layer, + recur_layers=([int(x) for x in args.recur_layers_str.split(",") if x.strip()] or None), + ).to(device).bfloat16() + eval_model.qo_bank.data = eval_model.qo_bank.data.float() + eval_model.kv_bank.data = eval_model.kv_bank.data.float() + eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() + eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + with open("final_model.int6.ptz", "rb") as f: quant_blob_disk = f.read() - 
quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk)), + map_location="cpu", + ) + # Use eval_model's own state as dequant template + if not args.eval_only: + template_sd, template_unbanked = sd_cpu, unbanked_sd + else: + template_sd = {k: v.detach().cpu() for k, v in eval_model.state_dict().items()} + template_unbanked = _unbank_state_dict(template_sd, args.num_layers) + deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], template_unbanked) + # Re-bank the dequantized tensors + deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, template_sd) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) torch.cuda.synchronize() t_qeval = time.perf_counter() q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, ) torch.cuda.synchronize() log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, 
world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + ttt_untie = bool(int(os.environ.get("TTT_UNTIE", "1"))) + if args.ttt_enabled: + if eval_model.recur_layers and ttt_untie: + log0(f"ttt:untying recurrence at layers {eval_model.recur_layers}") + eval_model.untie_recurrence() + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + 
log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") if distributed: dist.destroy_process_group() - - if __name__ == "__main__": - main() + main() \ No newline at end of file From 51984ca3f2fb4b160db1d15fc645fa142466d5cf Mon Sep 17 00:00:00 2001 From: Marko Sisovic Date: Wed, 25 Mar 2026 04:15:42 +0000 Subject: [PATCH 2/8] Add submission: Depth Recurrence (layers 4,5) + TTT Seed 1337 complete (val_bpb=1.1179). Seeds 42 and 2024 need rerun after GPU restart (stale CUDA contexts blocking clean runs). Co-Authored-By: Claude Opus 4.6 --- .../2026-03-25_RecurLayers/README.md | 88 + .../2026-03-25_RecurLayers/submission.json | 22 + .../2026-03-25_RecurLayers/train_gpt.py | 1981 +++++++++++++++++ .../2026-03-25_RecurLayers/train_seed1337.log | 274 +++ 4 files changed, 2365 insertions(+) create mode 100644 records/track_10min_16mb/2026-03-25_RecurLayers/README.md create mode 100644 records/track_10min_16mb/2026-03-25_RecurLayers/submission.json create mode 100644 records/track_10min_16mb/2026-03-25_RecurLayers/train_gpt.py create mode 100644 records/track_10min_16mb/2026-03-25_RecurLayers/train_seed1337.log diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/README.md b/records/track_10min_16mb/2026-03-25_RecurLayers/README.md new file mode 100644 index 000000000..2dc88633e --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_RecurLayers/README.md @@ -0,0 +1,88 @@ +# Depth Recurrence (layers 4,5) + +## Score: mean val_bpb = TODO (3 seeds: TODO, TODO, TODO) + +Trained on 8xH100 SXM in ~600 seconds. ~15.9MB artifact (int6+lzma). + +## Motivation + +We explored both width scaling (MODEL_DIM=576) and depth scaling (adding layers) and found that depth consistently wins over width in this regime. A full independent 12-layer model at dim=512 outperformed a wider 11-layer model at dim=576, despite the wider model having more parameters. However, adding independent layers pushes the model over the 16MB artifact budget. 
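As a concrete illustration of the weight-reuse idea, here is a minimal sketch of how re-running a span of layers turns 11 physical blocks into 13 virtual ones, with a learnable scalar per virtual slot. The names (`build_schedule`, `RecurrentStack`) are hypothetical and this is not the submission's `train_gpt.py` implementation; it assumes the recurrent span is contiguous and ends at the highest recurrent layer.

```python
import torch
from torch import nn

def build_schedule(num_layers: int, recur_layers: list[int]) -> list[int]:
    """Virtual execution order: run layers 0..max(recur_layers), re-run the
    recurrent span once, then continue. E.g. 11 layers with recur_layers=[4, 5]
    gives [0,1,2,3,4,5,4,5,6,7,8,9,10] -- 13 virtual layers."""
    hi = max(recur_layers)
    return list(range(hi + 1)) + sorted(recur_layers) + list(range(hi + 1, num_layers))

class RecurrentStack(nn.Module):
    """Toy residual stack: each *virtual* slot gets its own scalar, so the
    second pass through a reused block can behave differently."""
    def __init__(self, blocks: list[nn.Module], recur_layers: list[int]):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.schedule = build_schedule(len(blocks), recur_layers)
        # one scalar per virtual layer: ~len(schedule) extra params, negligible artifact cost
        self.pass_scales = nn.Parameter(torch.ones(len(self.schedule)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for slot, layer_idx in enumerate(self.schedule):
            x = x + self.pass_scales[slot] * self.blocks[layer_idx](x)
        return x
```

Because the reused blocks receive gradients from both passes, only the per-slot scalars distinguish pass one from pass two; that is what keeps the artifact size flat while adding depth.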
Depth recurrence solves this: by re-executing mid-network layers with independent block scalars, we get the depth benefit without the parameter/size cost. Dual recurrence on layers 4 and 5 gives us 13 virtual layers from 11 physical, staying well under budget at ~15.9MB.
+
+## Approach
+
+Depth recurrence is applied to layers 4 and 5, creating 13 virtual layers from 11 physical layers while keeping the parameter count at ~27M. It is combined with test-time training (TTT) for additional evaluation-time adaptation.
+
+### 1. Dual Depth Recurrence (layers 4,5)
+Layers 4 and 5 are each executed twice in sequence (pattern: 0,1,2,3,4,5,4,5,6,7,8,9,10), producing 13 virtual layers from 11 physical layers. Each recurrent pass uses independent learnable block scalars, so the model can modulate how the repeated layers behave on their second pass. This adds depth without increasing model size or artifact bytes; only the small block scalar parameters are added (~2K params).
+
+### 2. Test-Time Training (TTT)
+At evaluation time, the model adapts its weights to the validation data with a short fine-tuning pass: 3 epochs over the validation set at lr=0.002, chunked into 32K-token segments. The top 2 blocks are frozen during TTT to preserve the output head's calibration. Tied TTT (no weight untying) performs equivalently to untied.
+
+### 3. 
Inherited Techniques from Baseline +- **Int6 quantization + lzma compression**: Per-row int6 quantization on MLP/attention weights +- **3x MLP expansion**: Hidden dim 1536 (3x model dim) +- **Bigram hash embeddings**: 2048-bucket hash table (dim=128) +- **Value embeddings**: Learned value residuals on layers 9,10 +- **SWA**: Stochastic weight averaging every 50 steps +- **Muon optimizer**: With weight decay 0.04, momentum warmup 0.92->0.99 +- **Orthogonal initialization** + +## Hyperparameters + +| Parameter | Value | +|-----------|-------| +| num_layers | 11 (physical) / 13 (virtual) | +| model_dim | 512 | +| mlp_mult | 3.0 (hidden=1536) | +| recur_layers | 4, 5 | +| train_seq_len | 2048 | +| train_batch_tokens | 786,432 | +| warmdown_iters | 3500 | +| matrix_lr | 0.025 | +| scalar_lr | 0.025 | +| tied_embed_lr | 0.035 | +| muon_momentum | 0.99 (warmup from 0.92 over 1500 steps) | +| muon_weight_decay | 0.04 | +| adam_weight_decay | 0.04 | +| grad_clip_norm | 0.3 | +| eval_stride | 64 | +| swa_every | 50 | +| ttt_lr | 0.002 | +| ttt_epochs | 3 | +| ttt_chunk_tokens | 32768 | +| ttt_freeze_blocks | 2 | + +## Key Metrics + +- **Mean val_bpb: TODO** (std: TODO) +- Training: TODO steps in ~600s +- Model params: ~27M +- Artifact size: ~15.9MB (int6+lzma) + +## Reproducibility + +Three independent training runs with different random seeds: + +| Seed | val_loss | val_bpb | +|------|----------|---------| +| 1337 | TODO | TODO | +| 42 | TODO | TODO | +| 2024 | TODO | TODO | +| **Mean** | **TODO** | **TODO** | +| **Std** | **TODO** | **TODO** | + +## Run Commands + +```bash +# Seed 1337 (default) +ITERATIONS=9000 RECUR_LAYERS=4,5 TTT_ENABLED=1 TTT_UNTIE=0 \ + torchrun --nproc_per_node=8 train_gpt.py + +# Seed 42 +ITERATIONS=9000 RECUR_LAYERS=4,5 TTT_ENABLED=1 TTT_UNTIE=0 SEED=42 \ + torchrun --nproc_per_node=8 train_gpt.py + +# Seed 2024 +ITERATIONS=9000 RECUR_LAYERS=4,5 TTT_ENABLED=1 TTT_UNTIE=0 SEED=2024 \ + torchrun --nproc_per_node=8 train_gpt.py +``` diff --git 
a/records/track_10min_16mb/2026-03-25_RecurLayers/submission.json b/records/track_10min_16mb/2026-03-25_RecurLayers/submission.json new file mode 100644 index 000000000..191ee440b --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_RecurLayers/submission.json @@ -0,0 +1,22 @@ +{ + "author": "Marko Sisovic", + "github_id": "msisovic", + "name": "Depth Recurrence (layers 4,5) + TTT", + "blurb": "Dual depth recurrence on layers 4 and 5 (11 physical -> 13 virtual layers) with tied test-time training. Reuses layer weights to add depth without increasing model size, keeping the artifact under 16MB with int6+lzma compression. Combined with TTT, SWA, bigram embeddings, value embeddings, and Muon optimizer with weight decay.", + "date": "2026-03-25T00:00:00Z", + "val_loss": "TODO_MEAN", + "val_bpb": "TODO_MEAN", + "val_loss_std": "TODO", + "val_bpb_std": "TODO", + "seeds": [1337, 42, 2024], + "seed_results": { + "1337": {"val_loss": 1.88749538, "val_bpb": 1.11788404}, + "42": {"val_loss": "TODO_RERUN", "val_bpb": "TODO_RERUN"}, + "2024": {"val_loss": "TODO_RERUN", "val_bpb": "TODO_RERUN"} + }, + "step_stop": 6100, + "wallclock_seconds": 600.188, + "eval_time_seconds": "TODO", + "bytes_total": 15928948, + "bytes_code": 95036 +} diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/train_gpt.py b/records/track_10min_16mb/2026-03-25_RecurLayers/train_gpt.py new file mode 100644 index 000000000..942f03ad3 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_RecurLayers/train_gpt.py @@ -0,0 +1,1981 @@ +from __future__ import annotations +import copy +import glob +import io +import lzma +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn 
+from torch.nn.parallel import DistributedDataParallel as DDP +from flash_attn_interface import flash_attn_func as flash_attn_3_func +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 
0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0"))) + lawa_k = int(os.environ.get("LAWA_K", 10)) + lawa_freq = int(os.environ.get("LAWA_FREQ", 100)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0"))) + value_residual = 
bool(int(os.environ.get("VALUE_RESIDUAL", "0"))) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + recur_layer = int(os.environ.get("RECUR_LAYER", -1)) # single layer compat + recur_layers_str = os.environ.get("RECUR_LAYERS", "") # comma-separated, e.g. "4,5" + eval_only = bool(int(os.environ.get("EVAL_ONLY", "0"))) + +# --- Batched Newton-Schulz orthogonalization --- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +# --- Parallel Muon optimizer --- + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. 
Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. 
Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + 
p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + +# --- Tokenizer evaluation helpers --- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + 
rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += 
token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# --- Quantization helpers --- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = 
torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + 
"quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +# --- Data loading --- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, 
global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# --- Transformer modules --- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + 
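The Rotary module above derives per-pair inverse frequencies as `base ** (-2i/d)` and later applies a cos/sin rotation to the first `rope_dims` channels. A minimal numpy sketch of that rotation (function and variable names here are illustrative, not from `train_gpt.py`) shows the property the q/k normalization relies on: the rotation is norm-preserving, so position encoding never changes query/key magnitudes.

```python
import numpy as np

def rope_rotate(x, pos, base=10000.0):
    """Rotate channel pairs (x[..., :d/2], x[..., d/2:]) by position-dependent
    angles, matching the half-split layout of apply_rotary_emb above."""
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, d, 2) / d))  # (d/2,) per-pair frequencies
    ang = pos * inv_freq
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    return np.concatenate([x1 * cos + x2 * sin, -x1 * sin + x2 * cos], axis=-1)

x = np.random.default_rng(0).standard_normal(8)
y = rope_rotate(x, pos=5)
# Each pair is rotated by its own angle, so the overall norm is unchanged.
```

At position 0 all angles are zero and the rotation is the identity, which is why RoPE leaves the first token's q/k untouched.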
self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + 
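`Rotary.forward` above rescales the RoPE base by `scale ** (rd / (rd - 2))` when the evaluation sequence exceeds `train_seq_len` (the NTK-aware extension rule). A quick standalone check of that exponent, assuming the same convention: after rescaling, the slowest frequency pair's wavelength stretches by exactly `scale`, so the longest-range channel covers the longer context.

```python
import math

def ntk_scaled_base(base, train_len, seq_len, rope_dims):
    """NTK-aware base rescaling, mirroring the formula in Rotary.forward above."""
    if seq_len <= train_len:
        return base
    scale = seq_len / train_len
    return base * scale ** (rope_dims / (rope_dims - 2))

# Doubling the context with 64 rotary dims: the slowest inverse frequency,
# base ** (-(d - 2) / d), drops by exactly a factor of 2.
b = ntk_scaled_base(10000.0, 1024, 2048, 64)
```

The exponent `rd / (rd - 2)` is chosen precisely so that `(base * scale**(d/(d-2))) ** (-(d-2)/d) == base ** (-(d-2)/d) / scale`.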
self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + # No CastedLinear -- weights come from banks + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + # Gated attention and value residual (non-banked small params) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. 
H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout + gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) + y = y * gate + y = y.reshape(bsz, seqlen, dim) + return F.linear(y, out_w.to(x.dtype)), raw_v + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + 
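`_xsa_efficient` above subtracts the component of the attention output `y` along the normalized value direction, using a GQA-aware reshape instead of `repeat_interleave`. A numpy sketch of the core identity (illustrative names and shapes; broadcasting stands in for the `[B, T, Hkv, group, D]` reshape): the residual is exactly orthogonal to `v`, which is what "subtract self-value projection" means.

```python
import numpy as np

def remove_v_component(y, v):
    """Subtract the projection of y onto unit(v), as in _xsa_efficient above.
    y, v: (..., D); the last axis is the head dimension."""
    vn = v / np.linalg.norm(v, axis=-1, keepdims=True)       # unit value direction
    proj = (y * vn).sum(axis=-1, keepdims=True) * vn         # component of y along v
    return y - proj

rng = np.random.default_rng(1)
y = rng.standard_normal((2, 4, 8))
v = rng.standard_normal((2, 4, 8))
out = remove_v_component(y, v)
# out is orthogonal to v at every (batch, position), so applying the
# operation twice is a no-op.
```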
return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. 
+ Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + # No CastedLinear -- weights come from banks + def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: + x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5) + return F.linear(x.square(), down_w.to(x.dtype)) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + gated_attention=gated_attention, value_residual=value_residual) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 
2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed, v0=v0) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, raw_v + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + gated_attention: bool = False, + value_residual: bool = False, + recur_layer: int = -1, + recur_layers: list[int] | None = None, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.value_residual = value_residual + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = 
mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + # Depth recurrence: duplicate layers virtually for near-zero param overhead + # Normalize: recur_layers is the canonical list; recur_layer is single-layer compat + if recur_layers is None and recur_layer >= 0: + recur_layers = [recur_layer] + self.recur_layers = sorted(recur_layers) if recur_layers else [] + self.num_physical_layers = num_layers + if self.recur_layers: + for rl in self.recur_layers: + assert 0 <= rl < num_layers, f"recur_layer={rl} out of range [0, {num_layers})" + cutoff = max(self.recur_layers) + 1 + self.v2p = list(range(cutoff)) + self.recur_layers + list(range(cutoff, num_layers)) + virtual_num_layers = num_layers + len(self.recur_layers) + else: + virtual_num_layers = num_layers + self.v2p = list(range(num_layers)) + self.virtual_num_layers = virtual_num_layers + self.num_encoder_layers = virtual_num_layers // 2 + self.num_decoder_layers = virtual_num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Parameter banks: contiguous 3D tensors for batched optimizer (physical layer count) + head_dim = model_dim // num_heads + kv_dim = num_kv_heads * head_dim + mlp_dim = int(mlp_mult * model_dim) + self.num_layers = num_layers # physical, used for bank indexing + self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) + # Blocks: one per virtual layer (each has own scalar params) + self.blocks = nn.ModuleList( + 
[ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + gated_attention=gated_attention, + value_residual=value_residual, + ) + for i in range(virtual_num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim_ve = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, virtual_num_layers - xsa_last_n), virtual_num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + # Init banks: orthogonal, with proj layers scaled down and out/down zero-init + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q + nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K + 
nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up + nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) + # Scale proj layers (out_proj and mlp_down are "proj" layers) + self.qo_bank.data[n + i].mul_(proj_scale) + self.mlp_down_bank.data[i].mul_(proj_scale) + # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + n = self.num_layers # physical layer count for bank offset + v2p = self.v2p + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + pi = v2p[i] + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], + self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = 
raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + pi = v2p[bi] + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], + self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + n = self.num_layers # physical layer count for bank offset + v2p = self.v2p + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 
= x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + pi = v2p[i] + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], + self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + pi = v2p[bi] + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], + self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def untie_recurrence(self): + """Expand banks so duplicated layers get their own independent weights. 
+ Called before TTT so SGD can update each virtual layer independently.""" + if not self.recur_layers: + return + n = self.num_layers # physical count before expansion + # Build list of cloned rows to insert after max(recur_layers) + insert_after = max(self.recur_layers) + clones = sorted(self.recur_layers) # rows to duplicate + + def _expand(bank: Tensor) -> Tensor: + """Insert clones of recur rows after insert_after position.""" + parts = [bank[:insert_after + 1]] + for rl in clones: + parts.append(bank[rl:rl + 1].clone()) + parts.append(bank[insert_after + 1:]) + return torch.cat(parts, dim=0) + + # qo_bank: [2*n, dim, dim] -> first n are Q, last n are O + q_part = _expand(self.qo_bank.data[:n]) + o_part = _expand(self.qo_bank.data[n:]) + self.qo_bank = nn.Parameter(torch.cat([q_part, o_part], dim=0)) + + # kv_bank: [2*n, kv_dim, dim] -> first n are K, last n are V + k_part = _expand(self.kv_bank.data[:n]) + v_part = _expand(self.kv_bank.data[n:]) + self.kv_bank = nn.Parameter(torch.cat([k_part, v_part], dim=0)) + + # mlp banks: [n, ...] 
+ self.mlp_up_bank = nn.Parameter(_expand(self.mlp_up_bank.data)) + self.mlp_down_bank = nn.Parameter(_expand(self.mlp_down_bank.data)) + + # Update to identity mapping + new_n = n + len(clones) + self.num_layers = new_n + self.num_physical_layers = new_n + self.v2p = list(range(new_n)) + self.recur_layers = [] + +# --- Sliding window evaluation --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + 
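The sliding-window evaluator scores only the last `stride` tokens of each window (window 0 scores everything), via `s = 0 if ws == 0 else max(wlen - stride, 0)`. A pure-Python sketch of that bookkeeping (illustrative parameter values) checks the property the bpb accounting depends on: every token position is covered by some window, each at near-maximal context. Note that truncated tail windows can overlap, so coverage is at-least-once rather than exactly-once.

```python
def scored_spans(total_tokens, seq_len, stride):
    """Token positions each window scores, mirroring the
    s = max(wlen - stride, 0) rule in eval_val_sliding above."""
    spans = []
    for ws in range(0, total_tokens, stride):
        end = min(ws + seq_len, total_tokens)
        wlen = end - ws
        s = 0 if ws == 0 else max(wlen - stride, 0)
        spans.append(range(ws + s, end))
    return spans

spans = scored_spans(total_tokens=1000, seq_len=256, stride=64)
covered = sorted(t for sp in spans for t in sp)
# Window 0 scores [0, 256); each later full window scores its final
# `stride` tokens, tiling the stream with maximum left context.
```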
y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, + then train on it. 
Every token scored BEFORE any update that could use it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = 
(loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# --- GPTQ-lite int6 quantization --- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: + """Convert 3D bank tensors into individual 2D tensors with standard names.""" + out: dict[str, Tensor] = {} + n = num_layers + for name, tensor in sd.items(): + if name == "qo_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] + out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] + elif name == "kv_bank": + for i in range(n): + 
                out[f"blocks.{i}.attn.c_k.weight"] = tensor[i]
+                out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i]
+        elif name == "mlp_up_bank":
+            for i in range(n):
+                out[f"blocks.{i}.mlp.fc.weight"] = tensor[i]
+        elif name == "mlp_down_bank":
+            for i in range(n):
+                out[f"blocks.{i}.mlp.proj.weight"] = tensor[i]
+        else:
+            out[name] = tensor
+    return out
+
+def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
+    """Convert individual 2D tensors back into 3D bank tensors."""
+    out: dict[str, Tensor] = {}
+    n = num_layers
+    # Reconstruct banks from individual weight keys
+    qo_slices = [None] * (2 * n)
+    kv_slices = [None] * (2 * n)
+    up_slices = [None] * n
+    down_slices = [None] * n
+    consumed = set()
+    for i in range(n):
+        qk = f"blocks.{i}.attn.c_q.weight"
+        if qk in sd:
+            qo_slices[i] = sd[qk]
+            consumed.add(qk)
+        ok = f"blocks.{i}.attn.proj.weight"
+        if ok in sd:
+            qo_slices[n + i] = sd[ok]
+            consumed.add(ok)
+        kk = f"blocks.{i}.attn.c_k.weight"
+        if kk in sd:
+            kv_slices[i] = sd[kk]
+            consumed.add(kk)
+        vk = f"blocks.{i}.attn.c_v.weight"
+        if vk in sd:
+            kv_slices[n + i] = sd[vk]
+            consumed.add(vk)
+        fk = f"blocks.{i}.mlp.fc.weight"
+        if fk in sd:
+            up_slices[i] = sd[fk]
+            consumed.add(fk)
+        dk = f"blocks.{i}.mlp.proj.weight"
+        if dk in sd:
+            down_slices[i] = sd[dk]
+            consumed.add(dk)
+    out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype)
+    out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype)
+    out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype)
+    out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype)
+    for name, tensor in sd.items():
+        if name not in consumed:
+            out[name] = tensor
+    return out
+
+def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]):
+    num_layers_total = max(
+        (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")),
+        default=0,
+    ) + 1
+    late_k_layers = set(range(num_layers_total - 2, num_layers_total))
+    result: dict[str, Tensor] = {}
+    meta: dict[str, object] = {}
+    for name, tensor in state_dict.items():
+        t = tensor.detach().cpu().contiguous()
+        cat = _classify_param(name)
+        if not t.is_floating_point() or t.numel() <= 65536:
+            result[name] = t.to(torch.float16) if t.is_floating_point() else t
+            meta[name] = "passthrough"
+            continue
+        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
+            result[name] = t.float()
+            meta[name] = "passthrough_ctrl"
+            continue
+        if cat in int6_cats and t.ndim >= 1:
+            q, s = quantize_int6_per_row(t)
+            result[name + ".q"] = q
+            result[name + ".scale"] = s
+            meta[name] = {"type": "int6"}
+        else:
+            q, s = quantize_float_tensor(t)
+            result[name + ".q"] = q
+            result[name + ".scale"] = s
+            meta[name] = {"type": "int8"}
+    return result, meta
+
+def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
+                          template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    for name, orig in template_sd.items():
+        info = meta.get(name)
+        if info is None:
+            continue
+        orig_dtype = orig.dtype
+        if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
+            t = result[name]
+            if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
+                t = t.to(orig_dtype)
+            out[name] = t
+            continue
+        q, s = result[name + ".q"], result[name + ".scale"]
+        if s.ndim > 0:
+            out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
+        else:
+            out[name] = (q.float() * float(s.item())).to(orig_dtype)
+    return out
+
+# --- Training ---
+
+def main() -> None:
+    code = Path(__file__).read_text(encoding="utf-8")
+    args = Hyperparameters()
+    # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile
+    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
+    rank = int(os.environ.get("RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    local_rank = 
int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise 
ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + gated_attention=args.gated_attention, + value_residual=args.value_residual, + recur_layer=args.recur_layer, + recur_layers=([int(x) for x in args.recur_layers_str.split(",") if x.strip()] or None), + ).to(device).bfloat16() + # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward + base_model.qo_bank.data = base_model.qo_bank.data.float() + base_model.kv_bank.data = 
base_model.kv_bank.data.float() + base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() + base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, + # and non-bank grads are manually all-reduced before Adam steps. + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = compiled_model + + if not args.eval_only: + # - 4 parameter banks -> Muon (batched Newton-Schulz) + # - token embedding -> Adam + # - scalars/control tensors -> Adam + # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking) + matrix_params = [ + base_model.qo_bank, base_model.kv_bank, + base_model.mlp_up_bank, base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + 
scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + # Non-bank params that need manual all-reduce (replicated across GPUs) + replicated_params = list(optimizer_tok.param_groups[0]["params"]) + for pg in optimizer_tok.param_groups[1:]: + replicated_params.extend(pg["params"]) + replicated_params.extend(scalar_params) + + optimizer_head = None + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + replicated_params.append(base_model.lm_head.weight) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if optimizer_head is not None: + optimizers.append(optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + if base_model.recur_layers: + log0(f"recurrence:layers={base_model.recur_layers} physical_layers={args.num_layers} virtual_layers={base_model.virtual_num_layers}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if 
b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x, y = 
train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + # All-reduce all grads for warmup (simple, not optimized) + if distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + from collections import deque + lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + 
torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + # === 3-phase overlapped optimizer step === + # Phase 1: Launch async reduce-scatter for banks (biggest first) + optimizer_muon.launch_reduce_scatters() + # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + optimizer_tok.step() + optimizer_scalar.step() + if optimizer_head is not None: + optimizer_head.step() + # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) + 
optimizer_muon.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + if args.lawa_enabled and step % args.lawa_freq == 0: + lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply weight averaging + if args.lawa_enabled and len(lawa_queue) > 1: + log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") + current_state = base_model.state_dict() + avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in 
current_state.items()} + for snap in lawa_queue: + for name in avg_state: + avg_state[name] += snap[name].float() + for name in avg_state: + avg_state[name] /= len(lawa_queue) + avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) + base_model.load_state_dict(avg_state, strict=True) + else: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + # Unbank 3D tensors into individual 2D tensors for quantization + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) + quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) + if master_process: + with 
open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") + log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + gated_attention=args.gated_attention, value_residual=args.value_residual, + recur_layer=args.recur_layer, + recur_layers=([int(x) for x in args.recur_layers_str.split(",") if x.strip()] or None), + ).to(device).bfloat16() + eval_model.qo_bank.data = eval_model.qo_bank.data.float() + eval_model.kv_bank.data = eval_model.kv_bank.data.float() + eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() + eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk)), + map_location="cpu", + ) + # Use eval_model's own state as dequant template + if not args.eval_only: + template_sd, template_unbanked = sd_cpu, unbanked_sd + else: + template_sd = {k: v.detach().cpu() for k, v in eval_model.state_dict().items()} + 
template_unbanked = _unbank_state_dict(template_sd, args.num_layers) + deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], template_unbanked) + # Re-bank the dequantized tensors + deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, template_sd) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, 
is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + ttt_untie = bool(int(os.environ.get("TTT_UNTIE", "1"))) + if args.ttt_enabled: + if eval_model.recur_layers and ttt_untie: + log0(f"ttt:untying recurrence at layers {eval_model.recur_layers}") + eval_model.untie_recurrence() + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed1337.log b/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed1337.log new file mode 100644 index 000000000..038211e79 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed1337.log @@ -0,0 +1,274 @@ +W0325 03:04:09.626000 9338 torch/distributed/run.py:803] +W0325 03:04:09.626000 9338 torch/distributed/run.py:803] ***************************************** +W0325 03:04:09.626000 9338 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, 
please further tune the variable for optimal performance in your application as needed.
+W0325 03:04:09.626000 9338 torch/distributed/run.py:803] *****************************************
+logs/0b5022bd-4335-4cd0-b1f6-10ed1d9c971d.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:26998380
+recurrence:layers=[4, 5] physical_layers=11 virtual_layers=13
+mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
+XSA:last_4 active_layers:[9, 10, 11, 12]
+world_size:8 grad_accum_steps:1
+sdp_backends:cudnn=False flash=True mem_efficient=False math=False
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
+train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000
+seed:1337
+warmup_step:1/20
+warmup_step:2/20
+warmup_step:3/20
+warmup_step:4/20
+warmup_step:5/20
+warmup_step:6/20
+warmup_step:7/20
+warmup_step:8/20
+warmup_step:9/20
+warmup_step:10/20
+warmup_step:11/20
+warmup_step:12/20
+warmup_step:13/20
+warmup_step:14/20
+warmup_step:15/20
+warmup_step:16/20
+warmup_step:17/20
+warmup_step:18/20
+warmup_step:19/20
+warmup_step:20/20
+step:0/9000 val_loss:6.9309 val_bpb:4.1049 train_time:0ms step_avg:0.01ms
+step:1/9000 train_loss:6.9317 train_time:145ms step_avg:145.35ms
+step:2/9000 train_loss:8.5645 train_time:192ms step_avg:95.83ms
+step:3/9000 train_loss:7.6147 train_time:286ms step_avg:95.17ms
+step:4/9000 train_loss:7.3037 train_time:381ms step_avg:95.32ms
+step:5/9000 train_loss:7.2074 train_time:476ms step_avg:95.21ms
+step:6/9000 train_loss:7.1029 train_time:572ms step_avg:95.29ms
+step:7/9000 train_loss:6.9765 train_time:667ms step_avg:95.32ms
+step:8/9000 train_loss:6.9345 train_time:763ms step_avg:95.32ms
+step:9/9000 train_loss:6.5741 train_time:857ms step_avg:95.25ms
+step:10/9000 train_loss:6.1479 train_time:952ms step_avg:95.19ms
+step:500/9000 train_loss:2.3847 train_time:48843ms step_avg:97.69ms
+step:1000/9000 train_loss:2.2560 train_time:98080ms step_avg:98.08ms
+step:1500/9000 train_loss:2.1966 train_time:147290ms step_avg:98.19ms
+step:2000/9000 train_loss:2.0422 train_time:196499ms step_avg:98.25ms
+step:2500/9000 train_loss:2.1505 train_time:245737ms step_avg:98.29ms
+step:3000/9000 train_loss:2.1321 train_time:294888ms step_avg:98.30ms
+step:3500/9000 train_loss:2.1387 train_time:344021ms step_avg:98.29ms
+step:4000/9000 train_loss:1.9297 train_time:393161ms step_avg:98.29ms
+step:4000/9000 val_loss:2.0194 val_bpb:1.1960 train_time:393212ms step_avg:98.30ms
+step:4500/9000 train_loss:2.0776 train_time:442263ms step_avg:98.28ms
+step:5000/9000 train_loss:2.0527 train_time:491349ms step_avg:98.27ms
+swa:start step:5450
+step:5500/9000 train_loss:1.9627 train_time:540544ms step_avg:98.28ms
+late_qat:enabled step:5579 scale:0.1500
+step:6000/9000 train_loss:1.8879 train_time:590216ms step_avg:98.37ms
+step:6100/9000 val_loss:1.9185 val_bpb:1.1362 train_time:600188ms step_avg:98.39ms
+stopping_early: wallclock_cap train_time:600188ms step:6100/9000
+peak memory allocated: 25295 MiB reserved: 26088 MiB
+ema:applying EMA weights
+DIAGNOSTIC post_ema val_loss:1.9170 val_bpb:1.1353 eval_time:2344ms
+Serialized model: 106179454 bytes
+Code size: 95036 bytes
+Serialized model int6+lzma: 15833912 bytes
+Total submission size int6+lzma: 15928948 bytes
+final_int6_roundtrip val_loss:1.9321 val_bpb:1.1443 eval_time:21468ms
+final_int6_roundtrip_exact val_loss:1.93214110 val_bpb:1.14432279
+final_int6_sliding_window val_loss:1.8925 val_bpb:1.1208 stride:64 eval_time:110318ms
+final_int6_sliding_window_exact val_loss:1.89245481 val_bpb:1.12082130
+final_int8_zlib_roundtrip_exact val_loss:1.89245481 val_bpb:1.12082130
+ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=2
+ttt_sliding:params unfrozen=26994268 frozen=4112
+ ttt_chunk [1/1893] bpb=1.149595 time=0.5s
+ ttt_chunk [11/1893] bpb=1.144860 time=3.0s
+ ttt_chunk [21/1893] bpb=1.130659 time=5.5s
+ ttt_chunk [31/1893] bpb=1.128493 time=8.0s
+ ttt_chunk [41/1893] bpb=1.114966 time=10.6s
+ ttt_chunk [51/1893] bpb=1.109232 time=13.1s
+ ttt_chunk [61/1893] bpb=1.115662 time=15.6s
+ ttt_chunk [71/1893] bpb=1.114420 time=18.1s
+ ttt_chunk [81/1893] bpb=1.113703 time=20.6s
+ ttt_chunk [91/1893] bpb=1.114592 time=23.2s
+ ttt_chunk [101/1893] bpb=1.118197 time=25.7s
+ ttt_chunk [111/1893] bpb=1.120626 time=28.2s
+ ttt_chunk [121/1893] bpb=1.113734 time=30.7s
+ ttt_chunk [131/1893] bpb=1.113868 time=33.2s
+ ttt_chunk [141/1893] bpb=1.119649 time=35.7s
+ ttt_chunk [151/1893] bpb=1.121408 time=38.2s
+ ttt_chunk [161/1893] bpb=1.120830 time=40.8s
+ ttt_chunk [171/1893] bpb=1.125186 time=43.3s
+ ttt_chunk [181/1893] bpb=1.127255 time=45.8s
+ ttt_chunk [191/1893] bpb=1.134413 time=48.3s
+ ttt_chunk [201/1893] bpb=1.133158 time=50.8s
+ ttt_chunk [211/1893] bpb=1.130866 time=53.3s
+ ttt_chunk [221/1893] bpb=1.132428 time=55.9s
+ ttt_chunk [231/1893] bpb=1.131007 time=58.4s
+ ttt_chunk [241/1893] bpb=1.131338 time=60.9s
+ ttt_chunk [251/1893] bpb=1.130756 time=63.4s
+ ttt_chunk [261/1893] bpb=1.127867 time=66.0s
+ ttt_chunk [271/1893] bpb=1.126701 time=68.5s
+ ttt_chunk [281/1893] bpb=1.128083 time=71.0s
+ ttt_chunk [291/1893] bpb=1.129891 time=73.5s
+ ttt_chunk [301/1893] bpb=1.130617 time=76.0s
+ ttt_chunk [311/1893] bpb=1.132800 time=78.5s
+ ttt_chunk [321/1893] bpb=1.134770 time=81.0s
+ ttt_chunk [331/1893] bpb=1.134598 time=83.5s
+ ttt_chunk [341/1893] bpb=1.133566 time=86.0s
+ ttt_chunk [351/1893] bpb=1.135875 time=88.5s
+ ttt_chunk [361/1893] bpb=1.136122 time=91.0s
+ ttt_chunk [371/1893] bpb=1.135433 time=93.5s
+ ttt_chunk [381/1893] bpb=1.135618 time=96.0s
+ ttt_chunk [391/1893] bpb=1.135421 time=98.5s
+ ttt_chunk 
[401/1893] bpb=1.133384 time=101.0s + ttt_chunk [411/1893] bpb=1.132187 time=103.5s + ttt_chunk [421/1893] bpb=1.131277 time=106.0s + ttt_chunk [431/1893] bpb=1.131156 time=108.6s + ttt_chunk [441/1893] bpb=1.131604 time=111.0s + ttt_chunk [451/1893] bpb=1.131891 time=113.5s + ttt_chunk [461/1893] bpb=1.130814 time=116.1s + ttt_chunk [471/1893] bpb=1.131442 time=118.6s + ttt_chunk [481/1893] bpb=1.131098 time=121.1s + ttt_chunk [491/1893] bpb=1.129967 time=123.6s + ttt_chunk [501/1893] bpb=1.129522 time=126.1s + ttt_chunk [511/1893] bpb=1.128859 time=128.6s + ttt_chunk [521/1893] bpb=1.126487 time=131.1s + ttt_chunk [531/1893] bpb=1.127705 time=133.6s + ttt_chunk [541/1893] bpb=1.128083 time=136.1s + ttt_chunk [551/1893] bpb=1.127075 time=138.6s + ttt_chunk [561/1893] bpb=1.127641 time=141.1s + ttt_chunk [571/1893] bpb=1.126636 time=143.6s + ttt_chunk [581/1893] bpb=1.125882 time=146.1s + ttt_chunk [591/1893] bpb=1.125267 time=148.6s + ttt_chunk [601/1893] bpb=1.125751 time=151.1s + ttt_chunk [611/1893] bpb=1.125689 time=153.6s + ttt_chunk [621/1893] bpb=1.125524 time=156.1s + ttt_chunk [631/1893] bpb=1.126261 time=158.7s + ttt_chunk [641/1893] bpb=1.126015 time=161.2s + ttt_chunk [651/1893] bpb=1.126099 time=163.7s + ttt_chunk [661/1893] bpb=1.125578 time=166.2s + ttt_chunk [671/1893] bpb=1.125936 time=168.7s + ttt_chunk [681/1893] bpb=1.126666 time=171.2s + ttt_chunk [691/1893] bpb=1.127662 time=173.7s + ttt_chunk [701/1893] bpb=1.127101 time=176.2s + ttt_chunk [711/1893] bpb=1.127076 time=178.7s + ttt_chunk [721/1893] bpb=1.126723 time=181.2s + ttt_chunk [731/1893] bpb=1.126778 time=183.8s + ttt_chunk [741/1893] bpb=1.126872 time=186.3s + ttt_chunk [751/1893] bpb=1.126710 time=188.8s + ttt_chunk [761/1893] bpb=1.126628 time=191.4s + ttt_chunk [771/1893] bpb=1.126293 time=193.9s + ttt_chunk [781/1893] bpb=1.127027 time=196.4s + ttt_chunk [791/1893] bpb=1.126598 time=198.9s + ttt_chunk [801/1893] bpb=1.126939 time=201.4s + ttt_chunk [811/1893] bpb=1.126684 
time=203.9s + ttt_chunk [821/1893] bpb=1.126459 time=206.4s + ttt_chunk [831/1893] bpb=1.126271 time=208.9s + ttt_chunk [841/1893] bpb=1.125602 time=211.4s + ttt_chunk [851/1893] bpb=1.125312 time=213.9s + ttt_chunk [861/1893] bpb=1.125070 time=216.4s + ttt_chunk [871/1893] bpb=1.125326 time=218.9s + ttt_chunk [881/1893] bpb=1.125508 time=221.5s + ttt_chunk [891/1893] bpb=1.125071 time=224.0s + ttt_chunk [901/1893] bpb=1.124822 time=226.5s + ttt_chunk [911/1893] bpb=1.124940 time=229.0s + ttt_chunk [921/1893] bpb=1.125432 time=231.5s + ttt_chunk [931/1893] bpb=1.125398 time=234.0s + ttt_chunk [941/1893] bpb=1.125069 time=236.5s + ttt_chunk [951/1893] bpb=1.125495 time=239.0s + ttt_chunk [961/1893] bpb=1.125568 time=241.5s + ttt_chunk [971/1893] bpb=1.126425 time=244.1s + ttt_chunk [981/1893] bpb=1.126515 time=246.6s + ttt_chunk [991/1893] bpb=1.126548 time=249.1s + ttt_chunk [1001/1893] bpb=1.126519 time=251.6s + ttt_chunk [1011/1893] bpb=1.126294 time=254.1s + ttt_chunk [1021/1893] bpb=1.126616 time=256.6s + ttt_chunk [1031/1893] bpb=1.127071 time=259.1s + ttt_chunk [1041/1893] bpb=1.126714 time=261.6s + ttt_chunk [1051/1893] bpb=1.126465 time=264.1s + ttt_chunk [1061/1893] bpb=1.126523 time=266.6s + ttt_chunk [1071/1893] bpb=1.127116 time=269.1s + ttt_chunk [1081/1893] bpb=1.127398 time=271.6s + ttt_chunk [1091/1893] bpb=1.128136 time=274.2s + ttt_chunk [1101/1893] bpb=1.128158 time=276.6s + ttt_chunk [1111/1893] bpb=1.128007 time=279.2s + ttt_chunk [1121/1893] bpb=1.127817 time=281.7s + ttt_chunk [1131/1893] bpb=1.127688 time=284.2s + ttt_chunk [1141/1893] bpb=1.127398 time=286.7s + ttt_chunk [1151/1893] bpb=1.127419 time=289.2s + ttt_chunk [1161/1893] bpb=1.127044 time=291.7s + ttt_chunk [1171/1893] bpb=1.127350 time=294.2s + ttt_chunk [1181/1893] bpb=1.126595 time=296.7s + ttt_chunk [1191/1893] bpb=1.126455 time=299.2s + ttt_chunk [1201/1893] bpb=1.126858 time=301.7s + ttt_chunk [1211/1893] bpb=1.126384 time=304.2s + ttt_chunk [1221/1893] bpb=1.126089 
time=306.7s + ttt_chunk [1231/1893] bpb=1.125807 time=309.2s + ttt_chunk [1241/1893] bpb=1.125451 time=311.7s + ttt_chunk [1251/1893] bpb=1.124854 time=314.2s + ttt_chunk [1261/1893] bpb=1.124831 time=316.7s + ttt_chunk [1271/1893] bpb=1.124454 time=319.2s + ttt_chunk [1281/1893] bpb=1.124276 time=321.6s + ttt_chunk [1291/1893] bpb=1.124034 time=324.2s + ttt_chunk [1301/1893] bpb=1.123472 time=326.7s + ttt_chunk [1311/1893] bpb=1.123101 time=329.1s + ttt_chunk [1321/1893] bpb=1.122783 time=331.6s + ttt_chunk [1331/1893] bpb=1.122728 time=334.1s + ttt_chunk [1341/1893] bpb=1.122603 time=336.6s + ttt_chunk [1351/1893] bpb=1.122539 time=339.1s + ttt_chunk [1361/1893] bpb=1.122574 time=341.6s + ttt_chunk [1371/1893] bpb=1.122431 time=344.1s + ttt_chunk [1381/1893] bpb=1.122414 time=346.6s + ttt_chunk [1391/1893] bpb=1.122020 time=349.2s + ttt_chunk [1401/1893] bpb=1.121968 time=351.7s + ttt_chunk [1411/1893] bpb=1.122084 time=354.2s + ttt_chunk [1421/1893] bpb=1.122329 time=356.7s + ttt_chunk [1431/1893] bpb=1.122018 time=359.2s + ttt_chunk [1441/1893] bpb=1.122520 time=361.6s + ttt_chunk [1451/1893] bpb=1.122848 time=364.2s + ttt_chunk [1461/1893] bpb=1.122393 time=366.7s + ttt_chunk [1471/1893] bpb=1.123433 time=369.2s + ttt_chunk [1481/1893] bpb=1.123004 time=371.7s + ttt_chunk [1491/1893] bpb=1.122822 time=374.2s + ttt_chunk [1501/1893] bpb=1.122721 time=376.7s + ttt_chunk [1511/1893] bpb=1.122739 time=379.1s + ttt_chunk [1521/1893] bpb=1.122758 time=381.6s + ttt_chunk [1531/1893] bpb=1.122243 time=384.1s + ttt_chunk [1541/1893] bpb=1.122105 time=386.6s + ttt_chunk [1551/1893] bpb=1.122420 time=389.1s + ttt_chunk [1561/1893] bpb=1.122437 time=391.6s + ttt_chunk [1571/1893] bpb=1.122273 time=394.2s + ttt_chunk [1581/1893] bpb=1.122387 time=396.7s + ttt_chunk [1591/1893] bpb=1.122241 time=399.1s + ttt_chunk [1601/1893] bpb=1.122411 time=401.6s + ttt_chunk [1611/1893] bpb=1.122349 time=404.2s + ttt_chunk [1621/1893] bpb=1.121958 time=406.7s + ttt_chunk [1631/1893] 
bpb=1.122271 time=409.2s
+ ttt_chunk [1641/1893] bpb=1.122288 time=411.7s
+ ttt_chunk [1651/1893] bpb=1.122242 time=414.2s
+ ttt_chunk [1661/1893] bpb=1.122123 time=416.7s
+ ttt_chunk [1671/1893] bpb=1.122594 time=419.2s
+ ttt_chunk [1681/1893] bpb=1.122739 time=421.7s
+ ttt_chunk [1691/1893] bpb=1.122583 time=424.2s
+ ttt_chunk [1701/1893] bpb=1.122750 time=426.7s
+ ttt_chunk [1711/1893] bpb=1.122758 time=429.2s
+ ttt_chunk [1721/1893] bpb=1.122755 time=431.7s
+ ttt_chunk [1731/1893] bpb=1.122630 time=434.2s
+ ttt_chunk [1741/1893] bpb=1.122445 time=436.7s
+ ttt_chunk [1751/1893] bpb=1.122287 time=439.2s
+ ttt_chunk [1761/1893] bpb=1.122426 time=441.7s
+ ttt_chunk [1771/1893] bpb=1.122334 time=444.2s
+ ttt_chunk [1781/1893] bpb=1.122368 time=446.7s
+ ttt_chunk [1791/1893] bpb=1.121959 time=449.2s
+ ttt_chunk [1801/1893] bpb=1.121851 time=451.7s
+ ttt_chunk [1811/1893] bpb=1.121747 time=454.2s
+ ttt_chunk [1821/1893] bpb=1.121805 time=456.7s
+ ttt_chunk [1831/1893] bpb=1.121219 time=459.2s
+ ttt_chunk [1841/1893] bpb=1.121157 time=461.7s
+ ttt_chunk [1851/1893] bpb=1.120944 time=464.2s
+ ttt_chunk [1861/1893] bpb=1.120574 time=466.7s
+ ttt_chunk [1871/1893] bpb=1.120562 time=469.2s
+ ttt_chunk [1881/1893] bpb=1.120122 time=471.7s
+ ttt_chunk [1891/1893] bpb=1.119889 time=474.2s
+ ttt_chunk [1893/1893] bpb=1.119934 time=474.6s
+ttt_sliding:done val_loss=1.887495 val_bpb=1.117884 elapsed=474.6s
+legal_ttt val_loss:1.8875 val_bpb:1.1179 eval_time:475147ms
+legal_ttt_exact val_loss:1.88749538 val_bpb:1.11788404

From 9c2cad6262bacaffbc1dd8ffd17a42fcf75c80b8 Mon Sep 17 00:00:00 2001
From: Marko Sisovic
Date: Wed, 25 Mar 2026 05:50:59 +0000
Subject: [PATCH 3/8] Finalize 3-seed results: mean val_bpb=1.1182 (seeds 1337,2025,2024)

Co-Authored-By: Claude Opus 4.6
---
 .../2026-03-25_RecurLayers/README.md          | 20 +-
 .../2026-03-25_RecurLayers/submission.json    | 18 +-
 .../2026-03-25_RecurLayers/train_seed2024.log | 274 ++++++++++++++++++
 .../2026-03-25_RecurLayers/train_seed2025.log | 274 ++++++++++++++++++
 4 files changed, 567 insertions(+), 19 deletions(-)
 create mode 100644 records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2024.log
 create mode 100644 records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2025.log

diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/README.md b/records/track_10min_16mb/2026-03-25_RecurLayers/README.md
index 2dc88633e..53db1c8cc 100644
--- a/records/track_10min_16mb/2026-03-25_RecurLayers/README.md
+++ b/records/track_10min_16mb/2026-03-25_RecurLayers/README.md
@@ -1,6 +1,6 @@
 # Depth Recurrence (layers 4,5)
-## Score: mean val_bpb = TODO (3 seeds: TODO, TODO, TODO)
+## Score: mean val_bpb = 1.1182 (3 seeds: 1.1179, 1.1191, 1.1176)
 Trained on 8xH100 SXM in ~600 seconds. ~15.9MB artifact (int6+lzma).
@@ -54,8 +54,8 @@ At evaluation time, the model adapts its weights to the validation data using a
 ## Key Metrics
-- **Mean val_bpb: TODO** (std: TODO)
-- Training: TODO steps in ~600s
+- **Mean val_bpb: 1.11819** (std: 0.00076)
+- Training: ~6,100 steps in ~600s
 - Model params: ~27M
 - Artifact size: ~15.9MB (int6+lzma)
@@ -65,11 +65,11 @@ Three independent training runs with different random seeds:
 | Seed | val_loss | val_bpb |
 |------|----------|---------|
-| 1337 | TODO | TODO |
-| 42 | TODO | TODO |
-| 2024 | TODO | TODO |
-| **Mean** | **TODO** | **TODO** |
-| **Std** | **TODO** | **TODO** |
+| 1337 | 1.88749538 | 1.11788404 |
+| 2025 | 1.88948575 | 1.11906285 |
+| 2024 | 1.88706558 | 1.11762949 |
+| **Mean** | **1.88801557** | **1.11819213** |
+| **Std** | **0.00129122** | **0.00076473** |
 ## Run Commands
@@ -78,8 +78,8 @@
 ITERATIONS=9000 RECUR_LAYERS=4,5 TTT_ENABLED=1 TTT_UNTIE=0 \
 torchrun --nproc_per_node=8 train_gpt.py
-# Seed 42
-ITERATIONS=9000 RECUR_LAYERS=4,5 TTT_ENABLED=1 TTT_UNTIE=0 SEED=42 \
+# Seed 2025
+ITERATIONS=9000 RECUR_LAYERS=4,5 TTT_ENABLED=1 TTT_UNTIE=0 SEED=2025 \
 torchrun --nproc_per_node=8 train_gpt.py
 # Seed 2024
diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/submission.json b/records/track_10min_16mb/2026-03-25_RecurLayers/submission.json
index 191ee440b..c33dede75 100644
--- a/records/track_10min_16mb/2026-03-25_RecurLayers/submission.json
+++ b/records/track_10min_16mb/2026-03-25_RecurLayers/submission.json
@@ -4,19 +4,19 @@
   "name": "Depth Recurrence (layers 4,5) + TTT",
   "blurb": "Dual depth recurrence on layers 4 and 5 (11 physical -> 13 virtual layers) with tied test-time training. Reuses layer weights to add depth without increasing model size, keeping the artifact under 16MB with int6+lzma compression. Combined with TTT, SWA, bigram embeddings, value embeddings, and Muon optimizer with weight decay.",
   "date": "2026-03-25T00:00:00Z",
-  "val_loss": "TODO_MEAN",
-  "val_bpb": "TODO_MEAN",
-  "val_loss_std": "TODO",
-  "val_bpb_std": "TODO",
-  "seeds": [1337, 42, 2024],
+  "val_loss": 1.88801557,
+  "val_bpb": 1.11819213,
+  "val_loss_std": 0.00129122,
+  "val_bpb_std": 0.00076473,
+  "seeds": [1337, 2025, 2024],
   "seed_results": {
     "1337": {"val_loss": 1.88749538, "val_bpb": 1.11788404},
-    "42": {"val_loss": "TODO_RERUN", "val_bpb": "TODO_RERUN"},
-    "2024": {"val_loss": "TODO_RERUN", "val_bpb": "TODO_RERUN"}
+    "2025": {"val_loss": 1.88948575, "val_bpb": 1.11906285},
+    "2024": {"val_loss": 1.88706558, "val_bpb": 1.11762949}
   },
   "step_stop": 6100,
-  "wallclock_seconds": 600.188,
-  "eval_time_seconds": "TODO",
+  "wallclock_seconds": 600.0,
+  "eval_time_seconds": 475,
   "bytes_total": 15928948,
   "bytes_code": 95036
 }
diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2024.log b/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2024.log
new file mode 100644
index 000000000..1d3852ead
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2024.log
@@ -0,0 +1,274 @@
+W0325 05:28:06.586000 75757 torch/distributed/run.py:803]
+W0325 05:28:06.586000 75757 torch/distributed/run.py:803]
***************************************** +W0325 05:28:06.586000 75757 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0325 05:28:06.586000 75757 torch/distributed/run.py:803] ***************************************** +logs/388e32c2-50e0-4e35-88ac-51219634dd13.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26998380 +recurrence:layers=[4, 5] physical_layers=11 virtual_layers=13 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[9, 10, 11, 12] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:8000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:2024 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/8000 val_loss:6.9327 val_bpb:4.1059 train_time:0ms step_avg:0.02ms +step:1/8000 train_loss:6.9341 train_time:144ms step_avg:144.25ms +step:2/8000 train_loss:8.6442 train_time:190ms step_avg:95.07ms +step:3/8000 train_loss:7.6507 train_time:285ms step_avg:95.08ms +step:4/8000 train_loss:7.2426 train_time:380ms step_avg:94.93ms +step:5/8000 train_loss:7.1122 train_time:475ms 
step_avg:94.98ms +step:6/8000 train_loss:7.0615 train_time:570ms step_avg:94.98ms +step:7/8000 train_loss:6.9790 train_time:665ms step_avg:95.01ms +step:8/8000 train_loss:6.8907 train_time:759ms step_avg:94.92ms +step:9/8000 train_loss:6.5547 train_time:853ms step_avg:94.83ms +step:10/8000 train_loss:6.1912 train_time:949ms step_avg:94.86ms +step:500/8000 train_loss:2.3959 train_time:48743ms step_avg:97.49ms +step:1000/8000 train_loss:2.2556 train_time:97892ms step_avg:97.89ms +step:1500/8000 train_loss:2.2016 train_time:146957ms step_avg:97.97ms +step:2000/8000 train_loss:2.0479 train_time:196043ms step_avg:98.02ms +step:2500/8000 train_loss:2.1513 train_time:245114ms step_avg:98.05ms +step:3000/8000 train_loss:2.1306 train_time:294168ms step_avg:98.06ms +step:3500/8000 train_loss:2.1382 train_time:343183ms step_avg:98.05ms +step:4000/8000 train_loss:1.9288 train_time:392206ms step_avg:98.05ms +step:4000/8000 val_loss:2.0189 val_bpb:1.1957 train_time:392256ms step_avg:98.06ms +step:4500/8000 train_loss:2.0800 train_time:441225ms step_avg:98.05ms +step:5000/8000 train_loss:2.0518 train_time:490224ms step_avg:98.04ms +swa:start step:5450 +step:5500/8000 train_loss:1.9626 train_time:539319ms step_avg:98.06ms +late_qat:enabled step:5593 scale:0.1500 +step:6000/8000 train_loss:1.8847 train_time:588951ms step_avg:98.16ms +step:6111/8000 val_loss:1.9183 val_bpb:1.1361 train_time:600058ms step_avg:98.19ms +stopping_early: wallclock_cap train_time:600058ms step:6111/8000 +peak memory allocated: 25295 MiB reserved: 26058 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9167 val_bpb:1.1352 eval_time:2330ms +Serialized model: 106179454 bytes +Code size: 95036 bytes +Serialized model int6+lzma: 15825172 bytes +Total submission size int6+lzma: 15920208 bytes +final_int6_roundtrip val_loss:1.9308 val_bpb:1.1435 eval_time:6942ms +final_int6_roundtrip_exact val_loss:1.93077809 val_bpb:1.14351554 +final_int6_sliding_window val_loss:1.8911 val_bpb:1.1200 stride:64 
eval_time:86022ms +final_int6_sliding_window_exact val_loss:1.89110334 val_bpb:1.12002088 +final_int8_zlib_roundtrip_exact val_loss:1.89110334 val_bpb:1.12002088 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=2 +ttt_sliding:params unfrozen=26994268 frozen=4112 + ttt_chunk [1/1893] bpb=1.152698 time=0.5s + ttt_chunk [11/1893] bpb=1.142797 time=3.0s + ttt_chunk [21/1893] bpb=1.129555 time=5.4s + ttt_chunk [31/1893] bpb=1.127492 time=7.9s + ttt_chunk [41/1893] bpb=1.114271 time=10.4s + ttt_chunk [51/1893] bpb=1.108772 time=12.9s + ttt_chunk [61/1893] bpb=1.115238 time=15.3s + ttt_chunk [71/1893] bpb=1.113682 time=17.8s + ttt_chunk [81/1893] bpb=1.112830 time=20.3s + ttt_chunk [91/1893] bpb=1.113697 time=22.8s + ttt_chunk [101/1893] bpb=1.117243 time=25.2s + ttt_chunk [111/1893] bpb=1.119761 time=27.7s + ttt_chunk [121/1893] bpb=1.112954 time=30.2s + ttt_chunk [131/1893] bpb=1.113122 time=32.7s + ttt_chunk [141/1893] bpb=1.118719 time=35.1s + ttt_chunk [151/1893] bpb=1.120823 time=37.6s + ttt_chunk [161/1893] bpb=1.120211 time=40.1s + ttt_chunk [171/1893] bpb=1.124588 time=42.6s + ttt_chunk [181/1893] bpb=1.126885 time=45.0s + ttt_chunk [191/1893] bpb=1.134034 time=47.5s + ttt_chunk [201/1893] bpb=1.132635 time=50.0s + ttt_chunk [211/1893] bpb=1.130300 time=52.5s + ttt_chunk [221/1893] bpb=1.131792 time=55.0s + ttt_chunk [231/1893] bpb=1.130602 time=57.5s + ttt_chunk [241/1893] bpb=1.131019 time=60.0s + ttt_chunk [251/1893] bpb=1.130536 time=62.4s + ttt_chunk [261/1893] bpb=1.127636 time=64.9s + ttt_chunk [271/1893] bpb=1.126431 time=67.4s + ttt_chunk [281/1893] bpb=1.127757 time=69.8s + ttt_chunk [291/1893] bpb=1.129542 time=72.4s + ttt_chunk [301/1893] bpb=1.130241 time=74.8s + ttt_chunk [311/1893] bpb=1.132259 time=77.3s + ttt_chunk [321/1893] bpb=1.134252 time=79.8s + ttt_chunk [331/1893] bpb=1.134108 time=82.3s + ttt_chunk [341/1893] bpb=1.133064 time=84.8s + ttt_chunk [351/1893] 
bpb=1.135435 time=87.3s + ttt_chunk [361/1893] bpb=1.135684 time=89.8s + ttt_chunk [371/1893] bpb=1.135018 time=92.2s + ttt_chunk [381/1893] bpb=1.135189 time=94.7s + ttt_chunk [391/1893] bpb=1.135030 time=97.2s + ttt_chunk [401/1893] bpb=1.132961 time=99.7s + ttt_chunk [411/1893] bpb=1.131839 time=102.1s + ttt_chunk [421/1893] bpb=1.130917 time=104.6s + ttt_chunk [431/1893] bpb=1.130798 time=107.1s + ttt_chunk [441/1893] bpb=1.131204 time=109.6s + ttt_chunk [451/1893] bpb=1.131534 time=112.1s + ttt_chunk [461/1893] bpb=1.130458 time=114.5s + ttt_chunk [471/1893] bpb=1.131108 time=117.0s + ttt_chunk [481/1893] bpb=1.130715 time=119.5s + ttt_chunk [491/1893] bpb=1.129638 time=122.0s + ttt_chunk [501/1893] bpb=1.129226 time=124.5s + ttt_chunk [511/1893] bpb=1.128588 time=127.0s + ttt_chunk [521/1893] bpb=1.126247 time=129.4s + ttt_chunk [531/1893] bpb=1.127478 time=131.9s + ttt_chunk [541/1893] bpb=1.127861 time=134.4s + ttt_chunk [551/1893] bpb=1.126835 time=136.8s + ttt_chunk [561/1893] bpb=1.127383 time=139.3s + ttt_chunk [571/1893] bpb=1.126322 time=141.8s + ttt_chunk [581/1893] bpb=1.125534 time=144.3s + ttt_chunk [591/1893] bpb=1.124915 time=146.8s + ttt_chunk [601/1893] bpb=1.125413 time=149.3s + ttt_chunk [611/1893] bpb=1.125370 time=151.8s + ttt_chunk [621/1893] bpb=1.125213 time=154.3s + ttt_chunk [631/1893] bpb=1.125922 time=156.7s + ttt_chunk [641/1893] bpb=1.125684 time=159.2s + ttt_chunk [651/1893] bpb=1.125767 time=161.7s + ttt_chunk [661/1893] bpb=1.125207 time=164.2s + ttt_chunk [671/1893] bpb=1.125569 time=166.6s + ttt_chunk [681/1893] bpb=1.126268 time=169.1s + ttt_chunk [691/1893] bpb=1.127255 time=171.6s + ttt_chunk [701/1893] bpb=1.126684 time=174.1s + ttt_chunk [711/1893] bpb=1.126676 time=176.6s + ttt_chunk [721/1893] bpb=1.126337 time=179.0s + ttt_chunk [731/1893] bpb=1.126351 time=181.5s + ttt_chunk [741/1893] bpb=1.126446 time=184.0s + ttt_chunk [751/1893] bpb=1.126280 time=186.5s + ttt_chunk [761/1893] bpb=1.126209 time=189.0s + ttt_chunk 
[771/1893] bpb=1.125883 time=191.4s + ttt_chunk [781/1893] bpb=1.126612 time=193.9s + ttt_chunk [791/1893] bpb=1.126185 time=196.4s + ttt_chunk [801/1893] bpb=1.126506 time=198.8s + ttt_chunk [811/1893] bpb=1.126256 time=201.3s + ttt_chunk [821/1893] bpb=1.126017 time=203.8s + ttt_chunk [831/1893] bpb=1.125800 time=206.3s + ttt_chunk [841/1893] bpb=1.125135 time=208.7s + ttt_chunk [851/1893] bpb=1.124881 time=211.2s + ttt_chunk [861/1893] bpb=1.124615 time=213.7s + ttt_chunk [871/1893] bpb=1.124877 time=216.1s + ttt_chunk [881/1893] bpb=1.125051 time=218.7s + ttt_chunk [891/1893] bpb=1.124604 time=221.1s + ttt_chunk [901/1893] bpb=1.124350 time=223.6s + ttt_chunk [911/1893] bpb=1.124485 time=226.1s + ttt_chunk [921/1893] bpb=1.124957 time=228.5s + ttt_chunk [931/1893] bpb=1.124919 time=231.0s + ttt_chunk [941/1893] bpb=1.124607 time=233.5s + ttt_chunk [951/1893] bpb=1.125010 time=235.9s + ttt_chunk [961/1893] bpb=1.125091 time=238.4s + ttt_chunk [971/1893] bpb=1.125951 time=240.9s + ttt_chunk [981/1893] bpb=1.126037 time=243.3s + ttt_chunk [991/1893] bpb=1.126072 time=245.8s + ttt_chunk [1001/1893] bpb=1.126024 time=248.3s + ttt_chunk [1011/1893] bpb=1.125816 time=250.7s + ttt_chunk [1021/1893] bpb=1.126176 time=253.2s + ttt_chunk [1031/1893] bpb=1.126656 time=255.7s + ttt_chunk [1041/1893] bpb=1.126303 time=258.2s + ttt_chunk [1051/1893] bpb=1.126056 time=260.6s + ttt_chunk [1061/1893] bpb=1.126125 time=263.1s + ttt_chunk [1071/1893] bpb=1.126729 time=265.6s + ttt_chunk [1081/1893] bpb=1.127023 time=268.1s + ttt_chunk [1091/1893] bpb=1.127756 time=270.5s + ttt_chunk [1101/1893] bpb=1.127773 time=273.0s + ttt_chunk [1111/1893] bpb=1.127631 time=275.5s + ttt_chunk [1121/1893] bpb=1.127428 time=278.0s + ttt_chunk [1131/1893] bpb=1.127315 time=280.5s + ttt_chunk [1141/1893] bpb=1.127020 time=282.9s + ttt_chunk [1151/1893] bpb=1.127036 time=285.4s + ttt_chunk [1161/1893] bpb=1.126665 time=287.9s + ttt_chunk [1171/1893] bpb=1.127008 time=290.3s + ttt_chunk [1181/1893] 
bpb=1.126281 time=292.8s + ttt_chunk [1191/1893] bpb=1.126148 time=295.3s + ttt_chunk [1201/1893] bpb=1.126564 time=297.7s + ttt_chunk [1211/1893] bpb=1.126098 time=300.2s + ttt_chunk [1221/1893] bpb=1.125816 time=302.7s + ttt_chunk [1231/1893] bpb=1.125546 time=305.1s + ttt_chunk [1241/1893] bpb=1.125214 time=307.6s + ttt_chunk [1251/1893] bpb=1.124635 time=310.1s + ttt_chunk [1261/1893] bpb=1.124613 time=312.6s + ttt_chunk [1271/1893] bpb=1.124243 time=315.1s + ttt_chunk [1281/1893] bpb=1.124042 time=317.5s + ttt_chunk [1291/1893] bpb=1.123808 time=320.0s + ttt_chunk [1301/1893] bpb=1.123219 time=322.5s + ttt_chunk [1311/1893] bpb=1.122848 time=324.9s + ttt_chunk [1321/1893] bpb=1.122545 time=327.4s + ttt_chunk [1331/1893] bpb=1.122490 time=329.9s + ttt_chunk [1341/1893] bpb=1.122373 time=332.3s + ttt_chunk [1351/1893] bpb=1.122312 time=334.8s + ttt_chunk [1361/1893] bpb=1.122344 time=337.3s + ttt_chunk [1371/1893] bpb=1.122201 time=339.7s + ttt_chunk [1381/1893] bpb=1.122180 time=342.2s + ttt_chunk [1391/1893] bpb=1.121784 time=344.7s + ttt_chunk [1401/1893] bpb=1.121738 time=347.2s + ttt_chunk [1411/1893] bpb=1.121838 time=349.6s + ttt_chunk [1421/1893] bpb=1.122101 time=352.1s + ttt_chunk [1431/1893] bpb=1.121782 time=354.6s + ttt_chunk [1441/1893] bpb=1.122299 time=357.0s + ttt_chunk [1451/1893] bpb=1.122639 time=359.5s + ttt_chunk [1461/1893] bpb=1.122192 time=362.0s + ttt_chunk [1471/1893] bpb=1.123246 time=364.5s + ttt_chunk [1481/1893] bpb=1.122784 time=367.0s + ttt_chunk [1491/1893] bpb=1.122589 time=369.4s + ttt_chunk [1501/1893] bpb=1.122499 time=371.9s + ttt_chunk [1511/1893] bpb=1.122532 time=374.4s + ttt_chunk [1521/1893] bpb=1.122545 time=376.9s + ttt_chunk [1531/1893] bpb=1.122009 time=379.3s + ttt_chunk [1541/1893] bpb=1.121883 time=381.8s + ttt_chunk [1551/1893] bpb=1.122203 time=384.3s + ttt_chunk [1561/1893] bpb=1.122216 time=386.8s + ttt_chunk [1571/1893] bpb=1.122055 time=389.3s + ttt_chunk [1581/1893] bpb=1.122170 time=391.8s + ttt_chunk 
[1591/1893] bpb=1.122016 time=394.3s + ttt_chunk [1601/1893] bpb=1.122186 time=396.8s + ttt_chunk [1611/1893] bpb=1.122140 time=399.3s + ttt_chunk [1621/1893] bpb=1.121728 time=401.8s + ttt_chunk [1631/1893] bpb=1.122045 time=404.2s + ttt_chunk [1641/1893] bpb=1.122058 time=406.7s + ttt_chunk [1651/1893] bpb=1.122020 time=409.2s + ttt_chunk [1661/1893] bpb=1.121890 time=411.7s + ttt_chunk [1671/1893] bpb=1.122370 time=414.2s + ttt_chunk [1681/1893] bpb=1.122515 time=416.7s + ttt_chunk [1691/1893] bpb=1.122366 time=419.2s + ttt_chunk [1701/1893] bpb=1.122534 time=421.6s + ttt_chunk [1711/1893] bpb=1.122524 time=424.1s + ttt_chunk [1721/1893] bpb=1.122528 time=426.6s + ttt_chunk [1731/1893] bpb=1.122396 time=429.1s + ttt_chunk [1741/1893] bpb=1.122206 time=431.6s + ttt_chunk [1751/1893] bpb=1.122042 time=434.1s + ttt_chunk [1761/1893] bpb=1.122188 time=436.6s + ttt_chunk [1771/1893] bpb=1.122101 time=439.1s + ttt_chunk [1781/1893] bpb=1.122120 time=441.5s + ttt_chunk [1791/1893] bpb=1.121712 time=444.0s + ttt_chunk [1801/1893] bpb=1.121600 time=446.5s + ttt_chunk [1811/1893] bpb=1.121495 time=449.0s + ttt_chunk [1821/1893] bpb=1.121555 time=451.5s + ttt_chunk [1831/1893] bpb=1.120959 time=453.9s + ttt_chunk [1841/1893] bpb=1.120911 time=456.4s + ttt_chunk [1851/1893] bpb=1.120694 time=459.0s + ttt_chunk [1861/1893] bpb=1.120338 time=461.5s + ttt_chunk [1871/1893] bpb=1.120323 time=464.0s + ttt_chunk [1881/1893] bpb=1.119878 time=466.5s + ttt_chunk [1891/1893] bpb=1.119649 time=469.0s + ttt_chunk [1893/1893] bpb=1.119692 time=469.3s +ttt_sliding:done val_loss=1.887066 val_bpb=1.117629 elapsed=469.3s +legal_ttt val_loss:1.8871 val_bpb:1.1176 eval_time:469850ms +legal_ttt_exact val_loss:1.88706558 val_bpb:1.11762949 diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2025.log b/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2025.log new file mode 100644 index 000000000..24de5a5db --- /dev/null +++ 
b/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2025.log @@ -0,0 +1,274 @@ +W0325 05:05:06.263000 8081 torch/distributed/run.py:803] +W0325 05:05:06.263000 8081 torch/distributed/run.py:803] ***************************************** +W0325 05:05:06.263000 8081 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0325 05:05:06.263000 8081 torch/distributed/run.py:803] ***************************************** +logs/c7954938-1ada-4e4c-882a-aaecc4b07794.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26998380 +recurrence:layers=[4, 5] physical_layers=11 virtual_layers=13 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[9, 10, 11, 12] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:2025 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/9000 val_loss:6.9277 val_bpb:4.1030 train_time:0ms step_avg:0.01ms +step:1/9000 train_loss:6.9281 train_time:145ms step_avg:145.42ms +step:2/9000 train_loss:8.5172 train_time:193ms 
step_avg:96.30ms +step:3/9000 train_loss:7.6279 train_time:287ms step_avg:95.63ms +step:4/9000 train_loss:7.3505 train_time:382ms step_avg:95.58ms +step:5/9000 train_loss:7.1826 train_time:477ms step_avg:95.45ms +step:6/9000 train_loss:7.0688 train_time:571ms step_avg:95.23ms +step:7/9000 train_loss:6.9227 train_time:667ms step_avg:95.25ms +step:8/9000 train_loss:6.8939 train_time:762ms step_avg:95.22ms +step:9/9000 train_loss:6.5512 train_time:857ms step_avg:95.22ms +step:10/9000 train_loss:6.1678 train_time:951ms step_avg:95.12ms +step:500/9000 train_loss:2.3824 train_time:48648ms step_avg:97.30ms +step:1000/9000 train_loss:2.2574 train_time:97714ms step_avg:97.71ms +step:1500/9000 train_loss:2.2027 train_time:146827ms step_avg:97.88ms +step:2000/9000 train_loss:2.0445 train_time:196024ms step_avg:98.01ms +step:2500/9000 train_loss:2.1514 train_time:245140ms step_avg:98.06ms +step:3000/9000 train_loss:2.1320 train_time:294203ms step_avg:98.07ms +step:3500/9000 train_loss:2.1407 train_time:343243ms step_avg:98.07ms +step:4000/9000 train_loss:1.9321 train_time:392296ms step_avg:98.07ms +step:4000/9000 val_loss:2.0220 val_bpb:1.1976 train_time:392346ms step_avg:98.09ms +step:4500/9000 train_loss:2.0800 train_time:441368ms step_avg:98.08ms +step:5000/9000 train_loss:2.0584 train_time:490429ms step_avg:98.09ms +swa:start step:5450 +step:5500/9000 train_loss:1.9663 train_time:539572ms step_avg:98.10ms +late_qat:enabled step:5590 scale:0.1500 +step:6000/9000 train_loss:1.8896 train_time:589238ms step_avg:98.21ms +step:6108/9000 val_loss:1.9204 val_bpb:1.1374 train_time:600085ms step_avg:98.25ms +stopping_early: wallclock_cap train_time:600085ms step:6108/9000 +peak memory allocated: 25295 MiB reserved: 26088 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9189 val_bpb:1.1365 eval_time:2339ms +Serialized model: 106179454 bytes +Code size: 95036 bytes +Serialized model int6+lzma: 15839896 bytes +Total submission size int6+lzma: 15934932 bytes 
+final_int6_roundtrip val_loss:1.9334 val_bpb:1.1450 eval_time:21441ms
+final_int6_roundtrip_exact val_loss:1.93335053 val_bpb:1.14503908
+final_int6_sliding_window val_loss:1.8938 val_bpb:1.1216 stride:64 eval_time:110549ms
+final_int6_sliding_window_exact val_loss:1.89375476 val_bpb:1.12159121
+final_int8_zlib_roundtrip_exact val_loss:1.89375476 val_bpb:1.12159121
+ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=2
+ttt_sliding:params unfrozen=26994268 frozen=4112
+  ttt_chunk [1/1893] bpb=1.162210 time=0.5s
+  ttt_chunk [11/1893] bpb=1.145781 time=3.0s
+  ttt_chunk [21/1893] bpb=1.130322 time=5.5s
+  ttt_chunk [31/1893] bpb=1.129152 time=8.1s
+  ttt_chunk [41/1893] bpb=1.116044 time=10.6s
+  ttt_chunk [51/1893] bpb=1.110693 time=13.1s
+  ttt_chunk [61/1893] bpb=1.117044 time=15.6s
+  ttt_chunk [71/1893] bpb=1.115827 time=18.1s
+  ttt_chunk [81/1893] bpb=1.115163 time=20.6s
+  ttt_chunk [91/1893] bpb=1.115961 time=23.1s
+  ttt_chunk [101/1893] bpb=1.119242 time=25.6s
+  ttt_chunk [111/1893] bpb=1.121548 time=28.1s
+  ttt_chunk [121/1893] bpb=1.114806 time=30.6s
+  ttt_chunk [131/1893] bpb=1.114754 time=33.1s
+  ttt_chunk [141/1893] bpb=1.120362 time=35.6s
+  ttt_chunk [151/1893] bpb=1.122303 time=38.1s
+  ttt_chunk [161/1893] bpb=1.121662 time=40.6s
+  ttt_chunk [171/1893] bpb=1.125906 time=43.1s
+  ttt_chunk [181/1893] bpb=1.128185 time=45.6s
+  ttt_chunk [191/1893] bpb=1.135549 time=48.1s
+  ttt_chunk [201/1893] bpb=1.134305 time=50.6s
+  ttt_chunk [211/1893] bpb=1.132100 time=53.1s
+  ttt_chunk [221/1893] bpb=1.133617 time=55.6s
+  ttt_chunk [231/1893] bpb=1.132254 time=58.1s
+  ttt_chunk [241/1893] bpb=1.132617 time=60.7s
+  ttt_chunk [251/1893] bpb=1.132233 time=63.2s
+  ttt_chunk [261/1893] bpb=1.129325 time=65.7s
+  ttt_chunk [271/1893] bpb=1.128204 time=68.2s
+  ttt_chunk [281/1893] bpb=1.129517 time=70.7s
+  ttt_chunk [291/1893] bpb=1.131399 time=73.2s
+  ttt_chunk [301/1893] bpb=1.132069 time=75.7s
+  ttt_chunk [311/1893] bpb=1.134154 time=78.2s
+  ttt_chunk [321/1893] bpb=1.136082 time=80.7s
+  ttt_chunk [331/1893] bpb=1.135933 time=83.2s
+  ttt_chunk [341/1893] bpb=1.134982 time=85.8s
+  ttt_chunk [351/1893] bpb=1.137301 time=88.3s
+  ttt_chunk [361/1893] bpb=1.137582 time=90.7s
+  ttt_chunk [371/1893] bpb=1.136861 time=93.2s
+  ttt_chunk [381/1893] bpb=1.137043 time=95.7s
+  ttt_chunk [391/1893] bpb=1.136862 time=98.2s
+  ttt_chunk [401/1893] bpb=1.134839 time=100.8s
+  ttt_chunk [411/1893] bpb=1.133707 time=103.2s
+  ttt_chunk [421/1893] bpb=1.132831 time=105.8s
+  ttt_chunk [431/1893] bpb=1.132741 time=108.3s
+  ttt_chunk [441/1893] bpb=1.133128 time=110.8s
+  ttt_chunk [451/1893] bpb=1.133348 time=113.3s
+  ttt_chunk [461/1893] bpb=1.132251 time=115.8s
+  ttt_chunk [471/1893] bpb=1.132836 time=118.2s
+  ttt_chunk [481/1893] bpb=1.132438 time=120.8s
+  ttt_chunk [491/1893] bpb=1.131381 time=123.3s
+  ttt_chunk [501/1893] bpb=1.130900 time=125.8s
+  ttt_chunk [511/1893] bpb=1.130280 time=128.3s
+  ttt_chunk [521/1893] bpb=1.127914 time=130.8s
+  ttt_chunk [531/1893] bpb=1.129097 time=133.3s
+  ttt_chunk [541/1893] bpb=1.129478 time=135.8s
+  ttt_chunk [551/1893] bpb=1.128459 time=138.3s
+  ttt_chunk [561/1893] bpb=1.128997 time=140.8s
+  ttt_chunk [571/1893] bpb=1.127988 time=143.3s
+  ttt_chunk [581/1893] bpb=1.127180 time=145.9s
+  ttt_chunk [591/1893] bpb=1.126527 time=148.4s
+  ttt_chunk [601/1893] bpb=1.127034 time=150.9s
+  ttt_chunk [611/1893] bpb=1.126999 time=153.4s
+  ttt_chunk [621/1893] bpb=1.126844 time=155.9s
+  ttt_chunk [631/1893] bpb=1.127588 time=158.4s
+  ttt_chunk [641/1893] bpb=1.127333 time=160.9s
+  ttt_chunk [651/1893] bpb=1.127450 time=163.5s
+  ttt_chunk [661/1893] bpb=1.126922 time=165.9s
+  ttt_chunk [671/1893] bpb=1.127297 time=168.4s
+  ttt_chunk [681/1893] bpb=1.127980 time=170.9s
+  ttt_chunk [691/1893] bpb=1.128981 time=173.4s
+  ttt_chunk [701/1893] bpb=1.128408 time=175.9s
+  ttt_chunk [711/1893] bpb=1.128368 time=178.4s
+  ttt_chunk [721/1893] bpb=1.128037 time=180.9s
+  ttt_chunk [731/1893] bpb=1.128080 time=183.4s
+  ttt_chunk [741/1893] bpb=1.128164 time=185.9s
+  ttt_chunk [751/1893] bpb=1.128023 time=188.4s
+  ttt_chunk [761/1893] bpb=1.127942 time=190.9s
+  ttt_chunk [771/1893] bpb=1.127617 time=193.5s
+  ttt_chunk [781/1893] bpb=1.128346 time=196.0s
+  ttt_chunk [791/1893] bpb=1.127906 time=198.5s
+  ttt_chunk [801/1893] bpb=1.128239 time=201.0s
+  ttt_chunk [811/1893] bpb=1.127962 time=203.6s
+  ttt_chunk [821/1893] bpb=1.127753 time=206.1s
+  ttt_chunk [831/1893] bpb=1.127571 time=208.6s
+  ttt_chunk [841/1893] bpb=1.126914 time=211.0s
+  ttt_chunk [851/1893] bpb=1.126640 time=213.6s
+  ttt_chunk [861/1893] bpb=1.126382 time=216.1s
+  ttt_chunk [871/1893] bpb=1.126626 time=218.5s
+  ttt_chunk [881/1893] bpb=1.126788 time=221.1s
+  ttt_chunk [891/1893] bpb=1.126361 time=223.5s
+  ttt_chunk [901/1893] bpb=1.126072 time=226.1s
+  ttt_chunk [911/1893] bpb=1.126172 time=228.6s
+  ttt_chunk [921/1893] bpb=1.126657 time=231.1s
+  ttt_chunk [931/1893] bpb=1.126618 time=233.6s
+  ttt_chunk [941/1893] bpb=1.126286 time=236.1s
+  ttt_chunk [951/1893] bpb=1.126675 time=238.6s
+  ttt_chunk [961/1893] bpb=1.126781 time=241.1s
+  ttt_chunk [971/1893] bpb=1.127633 time=243.6s
+  ttt_chunk [981/1893] bpb=1.127709 time=246.1s
+  ttt_chunk [991/1893] bpb=1.127749 time=248.6s
+  ttt_chunk [1001/1893] bpb=1.127703 time=251.1s
+  ttt_chunk [1011/1893] bpb=1.127487 time=253.6s
+  ttt_chunk [1021/1893] bpb=1.127835 time=256.1s
+  ttt_chunk [1031/1893] bpb=1.128318 time=258.6s
+  ttt_chunk [1041/1893] bpb=1.127965 time=261.1s
+  ttt_chunk [1051/1893] bpb=1.127732 time=263.6s
+  ttt_chunk [1061/1893] bpb=1.127778 time=266.1s
+  ttt_chunk [1071/1893] bpb=1.128390 time=268.6s
+  ttt_chunk [1081/1893] bpb=1.128659 time=271.1s
+  ttt_chunk [1091/1893] bpb=1.129394 time=273.6s
+  ttt_chunk [1101/1893] bpb=1.129401 time=276.1s
+  ttt_chunk [1111/1893] bpb=1.129251 time=278.6s
+  ttt_chunk [1121/1893] bpb=1.129077 time=281.1s
+  ttt_chunk [1131/1893] bpb=1.128942 time=283.6s
+  ttt_chunk [1141/1893] bpb=1.128641 time=286.2s
+  ttt_chunk [1151/1893] bpb=1.128669 time=288.7s
+  ttt_chunk [1161/1893] bpb=1.128280 time=291.2s
+  ttt_chunk [1171/1893] bpb=1.128600 time=293.7s
+  ttt_chunk [1181/1893] bpb=1.127839 time=296.2s
+  ttt_chunk [1191/1893] bpb=1.127713 time=298.7s
+  ttt_chunk [1201/1893] bpb=1.128090 time=301.2s
+  ttt_chunk [1211/1893] bpb=1.127629 time=303.6s
+  ttt_chunk [1221/1893] bpb=1.127339 time=306.2s
+  ttt_chunk [1231/1893] bpb=1.127053 time=308.7s
+  ttt_chunk [1241/1893] bpb=1.126716 time=311.2s
+  ttt_chunk [1251/1893] bpb=1.126136 time=313.7s
+  ttt_chunk [1261/1893] bpb=1.126119 time=316.2s
+  ttt_chunk [1271/1893] bpb=1.125744 time=318.7s
+  ttt_chunk [1281/1893] bpb=1.125561 time=321.2s
+  ttt_chunk [1291/1893] bpb=1.125331 time=323.7s
+  ttt_chunk [1301/1893] bpb=1.124734 time=326.2s
+  ttt_chunk [1311/1893] bpb=1.124350 time=328.7s
+  ttt_chunk [1321/1893] bpb=1.124044 time=331.2s
+  ttt_chunk [1331/1893] bpb=1.123993 time=333.7s
+  ttt_chunk [1341/1893] bpb=1.123860 time=336.2s
+  ttt_chunk [1351/1893] bpb=1.123796 time=338.7s
+  ttt_chunk [1361/1893] bpb=1.123841 time=341.2s
+  ttt_chunk [1371/1893] bpb=1.123723 time=343.7s
+  ttt_chunk [1381/1893] bpb=1.123699 time=346.2s
+  ttt_chunk [1391/1893] bpb=1.123309 time=348.8s
+  ttt_chunk [1401/1893] bpb=1.123276 time=351.3s
+  ttt_chunk [1411/1893] bpb=1.123374 time=353.8s
+  ttt_chunk [1421/1893] bpb=1.123616 time=356.3s
+  ttt_chunk [1431/1893] bpb=1.123319 time=358.8s
+  ttt_chunk [1441/1893] bpb=1.123826 time=361.3s
+  ttt_chunk [1451/1893] bpb=1.124157 time=363.9s
+  ttt_chunk [1461/1893] bpb=1.123698 time=366.4s
+  ttt_chunk [1471/1893] bpb=1.124741 time=368.9s
+  ttt_chunk [1481/1893] bpb=1.124276 time=371.4s
+  ttt_chunk [1491/1893] bpb=1.124080 time=373.9s
+  ttt_chunk [1501/1893] bpb=1.123990 time=376.4s
+  ttt_chunk [1511/1893] bpb=1.124002 time=378.9s
+  ttt_chunk [1521/1893] bpb=1.124037 time=381.4s
+  ttt_chunk [1531/1893] bpb=1.123507 time=383.9s
+  ttt_chunk [1541/1893] bpb=1.123343 time=386.4s
+  ttt_chunk [1551/1893] bpb=1.123654 time=389.0s
+  ttt_chunk [1561/1893] bpb=1.123668 time=391.5s
+  ttt_chunk [1571/1893] bpb=1.123492 time=394.0s
+  ttt_chunk [1581/1893] bpb=1.123611 time=396.5s
+  ttt_chunk [1591/1893] bpb=1.123470 time=399.0s
+  ttt_chunk [1601/1893] bpb=1.123644 time=401.5s
+  ttt_chunk [1611/1893] bpb=1.123583 time=404.0s
+  ttt_chunk [1621/1893] bpb=1.123181 time=406.5s
+  ttt_chunk [1631/1893] bpb=1.123506 time=409.0s
+  ttt_chunk [1641/1893] bpb=1.123512 time=411.5s
+  ttt_chunk [1651/1893] bpb=1.123468 time=414.1s
+  ttt_chunk [1661/1893] bpb=1.123346 time=416.6s
+  ttt_chunk [1671/1893] bpb=1.123816 time=419.1s
+  ttt_chunk [1681/1893] bpb=1.123963 time=421.6s
+  ttt_chunk [1691/1893] bpb=1.123804 time=424.1s
+  ttt_chunk [1701/1893] bpb=1.123954 time=426.6s
+  ttt_chunk [1711/1893] bpb=1.123942 time=429.2s
+  ttt_chunk [1721/1893] bpb=1.123947 time=431.7s
+  ttt_chunk [1731/1893] bpb=1.123836 time=434.2s
+  ttt_chunk [1741/1893] bpb=1.123626 time=436.7s
+  ttt_chunk [1751/1893] bpb=1.123475 time=439.2s
+  ttt_chunk [1761/1893] bpb=1.123631 time=441.7s
+  ttt_chunk [1771/1893] bpb=1.123539 time=444.2s
+  ttt_chunk [1781/1893] bpb=1.123568 time=446.7s
+  ttt_chunk [1791/1893] bpb=1.123158 time=449.2s
+  ttt_chunk [1801/1893] bpb=1.123045 time=451.7s
+  ttt_chunk [1811/1893] bpb=1.122955 time=454.2s
+  ttt_chunk [1821/1893] bpb=1.123013 time=456.7s
+  ttt_chunk [1831/1893] bpb=1.122421 time=459.2s
+  ttt_chunk [1841/1893] bpb=1.122384 time=461.7s
+  ttt_chunk [1851/1893] bpb=1.122180 time=464.2s
+  ttt_chunk [1861/1893] bpb=1.121829 time=466.7s
+  ttt_chunk [1871/1893] bpb=1.121826 time=469.2s
+  ttt_chunk [1881/1893] bpb=1.121367 time=471.7s
+  ttt_chunk [1891/1893] bpb=1.121141 time=474.2s
+  ttt_chunk [1893/1893] bpb=1.121187 time=474.5s
+ttt_sliding:done val_loss=1.889486 val_bpb=1.119063 elapsed=474.6s
+legal_ttt val_loss:1.8895 val_bpb:1.1191 eval_time:475115ms
+legal_ttt_exact val_loss:1.88948575 val_bpb:1.11906285

From bf968877c2a3e403135946777efa3d1bd95795da Mon Sep 17 00:00:00 2001
From: Marko Sisovic
Date: Wed, 25 Mar 2026 05:52:59 +0000
Subject: [PATCH 4/8] Remove experiment_results.md from submission

Co-Authored-By: Claude Opus 4.6
---
 experiment_results.md | 164 ------------------------------------------
 1 file changed, 164 deletions(-)
 delete mode 100644 experiment_results.md

diff --git a/experiment_results.md b/experiment_results.md
deleted file mode 100644
index cf48237c3..000000000
--- a/experiment_results.md
+++ /dev/null
@@ -1,164 +0,0 @@
-# Experiment Results
-
-Baseline reference: PR #549 — val_bpb 1.1194 (3-seed mean, 8xH100, dim=512, 11 layers)
-
-## Experiment 1: MODEL_DIM=576 (all else PR #549 defaults)
-
-- **Date**: 2026-03-24
-- **Hardware**: 4xH100 80GB
-- **Key changes**: MODEL_DIM=576 (up from 512), MAX_WALLCLOCK_SECONDS=1200
-- **Other params**: ITERATIONS=9000, all other hyperparams at code defaults (matching PR #549)
-- **model_params**: 33,968,348 (~34M, over budget)
-- **Steps completed**: 5,609 / 9000 (wallclock capped at 1200s)
-- **Step avg**: ~214ms
-- **Peak memory**: 24,301 MiB
-
-### Results
-| Metric | Value |
-|--------|-------|
-| Pre-EMA val_bpb | 1.1284 |
-| Post-EMA val_bpb | 1.1277 |
-| **Final int6 quantized val_bpb** | **1.1359** |
-| Submission size (int6+lzma) | 19,525,669 bytes (~19.5 MB) |
-
-### Notes
-- Over param budget (~34M vs typical ~24M), so not submittable as-is.
-- Quantization gap is large: 1.1277 -> 1.1359 (+0.0082), likely because bigger model loses more to int6.
-- Only got ~5.6k steps due to slower step time at larger dim on 4 GPUs.
-- No TTT was run (would need separate eval pass).
-
----
-
-## Experiment 2: NUM_LAYERS=12 (dim=512, all else PR #549 defaults)
-
-- **Date**: 2026-03-24
-- **Hardware**: 4xH100 80GB
-- **Key changes**: NUM_LAYERS=12 (up from 11), MAX_WALLCLOCK_SECONDS=1200
-- **Other params**: ITERATIONS=9000, MODEL_DIM=512, all other hyperparams at code defaults
-- **model_params**: 29,355,620 (~29M, over budget but closer than dim=576)
-- **Steps completed**: 6,878 / 9000 (wallclock capped at 1200s)
-- **Step avg**: ~174ms
-- **Peak memory**: 23,484 MiB
-
-### Results
-| Metric | Value |
-|--------|-------|
-| Pre-EMA val_bpb | 1.1317 |
-| Post-EMA val_bpb | 1.1307 |
-| Final int6 quantized val_bpb | 1.1390 |
-| **Final int6 sliding window val_bpb** | **1.1153** |
-| Submission size (int6+lzma) | 17,275,453 bytes (~17.3 MB) |
-
-### Notes
-- Also over param budget (~29M) but less so than dim=576.
-- Quantization gap: 1.1307 -> 1.1390 (+0.0083), similar to exp 1.
-- Sliding window eval (stride 64) brings it to 1.1153 — better than PR #549 baseline (1.1194) pre-TTT.
-- Got more steps (6,878 vs 5,609) due to faster per-step time at dim=512.
-- No TTT was run.
-
----
-
-## Experiment 3: NUM_LAYERS=12 + TTT (dim=512, all else PR #549 defaults)
-
-- **Date**: 2026-03-24
-- **Hardware**: 4xH100 80GB
-- **Key changes**: NUM_LAYERS=12, TTT_ENABLED=1, MAX_WALLCLOCK_SECONDS=1200
-- **Other params**: ITERATIONS=9000, MODEL_DIM=512, all other hyperparams at code defaults
-- **model_params**: 29,355,620 (~29M)
-- **Steps completed**: 6,879 / 9000 (wallclock capped at 1200s)
-- **Step avg**: ~174ms
-- **TTT time**: 620s (1893 chunks, 3 epochs each)
-
-### Results
-| Metric | Value |
-|--------|-------|
-| Post-EMA val_bpb | 1.1304 |
-| Final int6 quantized val_bpb | 1.1387 |
-| Final int6 sliding window val_bpb | 1.1151 |
-| **Post-TTT sliding window val_bpb** | **1.1126** |
-
-### Notes
-- TTT gave a -0.0025 gain over sliding window (1.1151 -> 1.1126), similar to PR #549's TTT gain.
-- **Beats PR #549's 1.1194 by 0.0068 BPB** — but still over param budget (~29M).
-- TTT took 620s which would be within a 10-min eval constraint.
-
----
-
-## Experiment 4: RECUR_LAYER=5 + TTT (depth recurrence, 11 physical → 12 virtual layers)
-
-- **Date**: 2026-03-25
-- **Hardware**: 4xH100 80GB
-- **Key changes**: RECUR_LAYER=5 (layer 5 duplicated), TTT_ENABLED=1, MAX_WALLCLOCK_SECONDS=1200
-- **Other params**: ITERATIONS=9000, MODEL_DIM=512, NUM_LAYERS=11 (physical), all else defaults
-- **model_params**: 26,996,324 (~27M) — only ~2.7M over baseline from extra block scalars
-- **Steps completed**: 6,884 / 9000 (wallclock capped at 1200s)
-- **Step avg**: ~174ms (same as full 12-layer)
-- **TTT time**: 622s (untied recurrence before TTT)
-- **Submission size**: 15,927,562 bytes (~15.9 MB, well under 16MB budget!)
-
-### Results
-| Metric | Value |
-|--------|-------|
-| Post-EMA val_bpb | 1.1354 |
-| Final int6 quantized val_bpb | 1.1440 |
-| Final int6 sliding window val_bpb | 1.1205 |
-| **Post-TTT sliding window val_bpb** | **1.1180** |
-
-### Comparison to full 12-layer (Exp 3)
-| Metric | Full 12L (Exp 3) | Recur L5 (Exp 4) | Delta |
-|--------|-----------------|-------------------|-------|
-| Params | 29.4M | 27.0M | -2.4M |
-| Submission size | 17.3 MB | 15.9 MB | -1.4 MB |
-| Sliding window val_bpb | 1.1151 | 1.1205 | +0.0054 |
-| Post-TTT val_bpb | 1.1126 | 1.1180 | +0.0054 |
-
-### Notes
-- Recurrence adds depth for free in compute, but shared weights cost ~0.005 BPB vs independent layers.
-- TTT untying gave a -0.0025 gain (1.1205 -> 1.1180), same magnitude as independent layers.
-- Submission size is much smaller (15.9 MB vs 17.3 MB) since banks stay at 11-layer size.
-- Still beats PR #549 baseline (1.1194) by 0.0014 BPB, with a smaller model.
-
----
-
-## Experiment 4b: Tied TTT (same checkpoint as Exp 4, no untying)
-
-- **Post-TTT val_bpb**: **1.1179** (vs 1.1180 untied — negligible difference)
-- Conclusion: untying doesn't help with 3-epoch TTT. Tied is fine.
-
----
-
-## Experiment 5: RECUR_LAYERS=4,5 + tied TTT (dual recurrence, 11 physical → 13 virtual layers)
-
-- **Date**: 2026-03-25
-- **Hardware**: 4xH100 80GB
-- **Key changes**: RECUR_LAYERS=4,5 (4,5,4,5 pattern), TTT_ENABLED=1, TTT_UNTIE=0
-- **Other params**: ITERATIONS=9000, MODEL_DIM=512, NUM_LAYERS=11 (physical), all else defaults
-- **model_params**: 26,998,380 (~27M)
-- **Steps completed**: 6,389 / 9000 (wallclock capped at 1200s)
-- **Step avg**: ~188ms (up from 174ms with single recurrence — extra virtual layer costs ~14ms/step)
-- **TTT time**: 655s (tied)
-- **Submission size**: 15,944,748 bytes (~15.9 MB)
-
-### Results
-| Metric | Value |
-|--------|-------|
-| Post-EMA val_bpb | 1.1337 |
-| Final int6 quantized val_bpb | 1.1421 |
-| Final int6 sliding window val_bpb | 1.1187 |
-| **Post-TTT sliding window val_bpb** | **1.1163** |
-
-### Full comparison
-| | PR #549 | Recur L5 (Exp 4) | Recur L4,5 (Exp 5) | Full 12L (Exp 3) |
-|---|---|---|---|---|
-| Virtual depth | 11 | 12 | **13** | 12 |
-| Params | ~24M | ~27M | ~27M | ~29M |
-| Submission size | ~19.5 MB | 15.9 MB | 15.9 MB | 17.3 MB |
-| Steps completed | ~7,180 | 6,884 | 6,389 | 6,879 |
-| Post-TTT val_bpb | 1.1194 | 1.1179 | **1.1163** | **1.1126** |
-
-### Notes
-- Dual recurrence (1.1163) beats single recurrence (1.1179) by 0.0016 BPB.
-- Beats PR #549 by 0.0031 BPB, with ~27M params and ~16 MB submission.
-- Gap to full independent 12-layer (1.1126) is 0.0037 — weight sharing costs more with 2 repeated layers.
-- Step time increased to ~188ms (from 174ms), resulting in ~500 fewer steps in the wallclock budget.
-- The extra virtual depth helps despite fewer training steps.
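The stride-64 sliding-window evaluation used throughout these experiments re-scores the validation stream with overlapping windows, so every token is predicted with close to a full window of context. A toy sketch of the bookkeeping (the function name and the even-divisibility assumption are mine, not the repo's):

```python
def sliding_window_plan(n_tokens: int, window: int, stride: int):
    """Plan overlapping eval windows.

    The first window scores all of its tokens; every later window scores
    only its last `stride` tokens, so each position is scored exactly once
    but with up to `window` tokens of left context.
    Assumes (n_tokens - window) is a multiple of stride, for simplicity.
    """
    assert (n_tokens - window) % stride == 0
    plan = [(0, window, 0)]  # (start, end, first-scored offset within window)
    for start in range(stride, n_tokens - window + 1, stride):
        plan.append((start, start + window, window - stride))
    return plan


plan = sliding_window_plan(1024, 256, 64)
scored = [s + off + i for (s, e, off) in plan for i in range(e - s - off)]
print(len(plan), len(scored))  # → 13 1024
```

The cost is one forward pass per window instead of per sequence, which is why the logs show the sliding-window eval taking ~110s versus ~21s for the plain round-trip eval.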
From a9bb2ab0052a22a0489b3c3634f1a4085fc6b71e Mon Sep 17 00:00:00 2001
From: Marko Sisovic
Date: Wed, 25 Mar 2026 05:53:47 +0000
Subject: [PATCH 5/8] Restore train_gpt.py to match main branch

Co-Authored-By: Claude Opus 4.6
---
 train_gpt.py | 1927 ++++++++++++++------------------------------------
 1 file changed, 536 insertions(+), 1391 deletions(-)

diff --git a/train_gpt.py b/train_gpt.py
index 942f03ad3..651beb2b8 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -1,8 +1,14 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines.
+"""
+
 from __future__ import annotations
+
 import copy
 import glob
 import io
-import lzma
 import math
 import os
 import random
@@ -12,11 +18,7 @@
 import uuid
 import zlib
 from pathlib import Path
-try:
-    import zstandard
-    _COMPRESSOR = "zstd"
-except ImportError:
-    _COMPRESSOR = "zlib"
+
 import numpy as np
 import sentencepiece as spm
 import torch
@@ -24,251 +26,156 @@
 import torch.nn.functional as F
 from torch import Tensor, nn
 from torch.nn.parallel import DistributedDataParallel as DDP
-from flash_attn_interface import flash_attn_func as flash_attn_3_func
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Default Simple Baseline run:
+#   - 9 transformer blocks at width 512
+#   - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+#   - vocab size 1024, sequence length 1024, tied embeddings
+#   - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
+
 class Hyperparameters:
+    # Data paths are shard globs produced by the existing preprocessing pipeline.
     data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
     train_files = os.path.join(data_path, "fineweb_train_*.bin")
     val_files = os.path.join(data_path, "fineweb_val_*.bin")
     tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
     run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
     seed = int(os.environ.get("SEED", 1337))
+
+    # Validation cadence and batch size. Validation always uses the full fineweb_val split.
     val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
-    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000))
-    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500))
+    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
+    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200))
+
+    # Training length.
     iterations = int(os.environ.get("ITERATIONS", 20000))
-    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500))
+    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200))
     warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
-    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432))
-    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048))
-    eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048))
+    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))
+    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024))
     max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
     qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
+
+    # Model shape.
     vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
-    num_layers = int(os.environ.get("NUM_LAYERS", 11))
+    num_layers = int(os.environ.get("NUM_LAYERS", 9))
     num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
     model_dim = int(os.environ.get("MODEL_DIM", 512))
     num_heads = int(os.environ.get("NUM_HEADS", 8))
-    mlp_mult = float(os.environ.get("MLP_MULT", 3.0))
+    mlp_mult = int(os.environ.get("MLP_MULT", 2))
     tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
     rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
     logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))
+
+    # Optimizer hyperparameters.
     embed_lr = float(os.environ.get("EMBED_LR", 0.6))
     head_lr = float(os.environ.get("HEAD_LR", 0.008))
-    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035))
+    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05))
     tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
-    matrix_lr = float(os.environ.get("MATRIX_LR", 0.025))
-    scalar_lr = float(os.environ.get("SCALAR_LR", 0.025))
-    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99))
+    matrix_lr = float(os.environ.get("MATRIX_LR", 0.04))
+    scalar_lr = float(os.environ.get("SCALAR_LR", 0.04))
+    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
     muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
-    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92))
-    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500))
+    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
+    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
     beta1 = float(os.environ.get("BETA1", 0.9))
     beta2 = float(os.environ.get("BETA2", 0.95))
     adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
-    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3))
-    eval_stride = int(os.environ.get("EVAL_STRIDE", 64))
-    mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0))
-    mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2))
-    muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95))
-    swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1")))
-    swa_every = int(os.environ.get("SWA_EVERY", 50))
-    lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0")))
-    lawa_k = int(os.environ.get("LAWA_K", 10))
-    lawa_freq = int(os.environ.get("LAWA_FREQ", 100))
-    muon_wd = float(os.environ.get("MUON_WD", 0.04))
-    adam_wd = float(os.environ.get("ADAM_WD", 0.04))
-    qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0")))
-    bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048))
-    bigram_dim = int(os.environ.get("BIGRAM_DIM", 128))
-    xsa_last_n = int(os.environ.get("XSA_LAST_N", 4))
-    rope_dims = int(os.environ.get("ROPE_DIMS", 16))
-    ln_scale = bool(int(os.environ.get("LN_SCALE", "1")))
-    dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0")))
-    late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15))
-    ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1")))
-    ve_dim = int(os.environ.get("VE_DIM", 128))
-    ve_layers = os.environ.get("VE_LAYERS", "9,10")
-    gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0")))
-    value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0")))
-    ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0")))
-    ttt_lr = float(os.environ.get("TTT_LR", 0.002))
-    ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3))
-    ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768))
-    ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2))
-    ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9))
-    ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32))
-    ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0))
-    recur_layer = int(os.environ.get("RECUR_LAYER", -1))  # single layer compat
-    recur_layers_str = os.environ.get("RECUR_LAYERS", "")  # comma-separated, e.g. "4,5"
-    eval_only = bool(int(os.environ.get("EVAL_ONLY", "0")))
-
-# --- Batched Newton-Schulz orthogonalization ---
-
-def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor:
-    """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N)."""
+    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0))
+
+# -----------------------------
+# MUON OPTIMIZER
+# -----------------------------
+#
+# As borrowed from modded-nanogpt
+# Background on Muon: https://kellerjordan.github.io/posts/muon/
+
+def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
+    # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration.
+    # Muon uses this to normalize matrix-shaped gradients before applying them.
     a, b, c = (3.4445, -4.7750, 2.0315)
-    was_2d = G.ndim == 2
-    if was_2d:
-        G = G.unsqueeze(0)
     X = G.bfloat16()
-    transposed = X.size(-2) > X.size(-1)
+    X /= X.norm() + eps
+    transposed = G.size(0) > G.size(1)
    if transposed:
-        X = X.mT
-    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
+        X = X.T
    for _ in range(steps):
-        A = X @ X.mT
-        B = b * A + c * (A @ A)
+        A = X @ X.T
+        B = b * A + c * A @ A
        X = a * X + B @ X
-    if transposed:
-        X = X.mT
-    if was_2d:
-        X = X.squeeze(0)
-    return X
+    return X.T if transposed else X

-# --- Parallel Muon optimizer ---

 class Muon(torch.optim.Optimizer):
-    """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather.
-
-    No DDP for bank params. After backward, this optimizer:
-      1. Launches async reduce-scatter for all banks (biggest first)
-      2. Returns control so Adam can step on small params while RS is in-flight
-      3. Waits for each RS, runs local NS5 on the shard, launches async all-gather
-      4. Each all-gather overlaps with next bank's NS5
-    """
-    def __init__(self, params, lr: float, momentum: float, backend_steps: int,
-                 nesterov: bool = True, weight_decay: float = 0.0):
+    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
        super().__init__(
            params,
-            dict(lr=lr, momentum=momentum, backend_steps=backend_steps,
-                 nesterov=nesterov, weight_decay=weight_decay),
+            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov),
        )
-        self._built = False
-
-    def _build(self):
-        self._distributed = dist.is_available() and dist.is_initialized()
-        self._world_size = dist.get_world_size() if self._distributed else 1
-        self._rank = dist.get_rank() if self._distributed else 0
-        ws = self._world_size
-
-        self._bank_meta = []
-        for group in self.param_groups:
-            for p in group["params"]:
-                B = p.shape[0]
-                padded_B = ((B + ws - 1) // ws) * ws
-                shard_B = padded_B // ws
-                tail = p.shape[1:]
-                dev = p.device
-                self._bank_meta.append({
-                    'p': p,
-                    'B': B,
-                    'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16),
-                    'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16),
-                    'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16),
-                    'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16),
-                    'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5,
-                })
-        # Sort by size descending -- launch biggest reduce-scatters first
-        self._bank_meta.sort(key=lambda m: -m['p'].numel())
-        self._built = True
-
-    def launch_reduce_scatters(self):
-        """Phase 1: launch async reduce-scatter for all banks. Call right after backward."""
-        if not self._built:
-            self._build()
-        if not self._distributed:
-            return
-        self._rs_futures = []
-        for m in self._bank_meta:
-            p = m['p']
-            if p.grad is None:
-                self._rs_futures.append(None)
-                continue
-            pg = m['padded_grad']
-            pg[:m['B']].copy_(p.grad.bfloat16())
-            if pg.shape[0] > m['B']:
-                pg[m['B']:].zero_()
-            fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True)
-            self._rs_futures.append(fut)

    @torch.no_grad()
    def step(self, closure=None):
-        """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
-        if not self._built:
-            self._build()
+        distributed = dist.is_available() and dist.is_initialized()
+        world_size = dist.get_world_size() if distributed else 1
+        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
+            params = group["params"]
+            if not params:
+                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]
-            wd = group.get("weight_decay", 0.0)
-
-            prev_ag_handle = None
-            prev_m = None
-
-            sharded = self._distributed and hasattr(self, '_rs_futures')
-
-            for i, m in enumerate(self._bank_meta):
-                p = m['p']
-                if p.grad is None:
-                    continue
-
-                if prev_ag_handle is not None:
-                    prev_ag_handle.wait()
-                    pp = prev_m['p']
-                    upd = prev_m['full_update'][:prev_m['B']]
-                    if wd > 0.0:
-                        pp.data.mul_(1.0 - lr * wd)
-                    pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale'])
-
-                if sharded and self._rs_futures[i] is not None:
-                    self._rs_futures[i].wait()
-                    g = m['shard']
-                    buf = m['shard_mom']
-                else:
-                    g = p.grad.bfloat16()
+
+            total_params = sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    # Scale correction from Muon reference implementations.
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()

-                buf.mul_(momentum).add_(g)
-                if nesterov:
-                    update = g.add(buf, alpha=momentum)
-                else:
-                    update = buf
-
-                update = zeropower_via_newtonschulz5(update, steps=backend_steps)
-
-                if sharded:
-                    prev_ag_handle = dist.all_gather_into_tensor(
-                        m['full_update'], update, async_op=True)
-                    prev_m = m
-                else:
-                    if wd > 0.0:
-                        p.data.mul_(1.0 - lr * wd)
-                    p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale'])
-
-            if prev_ag_handle is not None:
-                prev_ag_handle.wait()
-                pp = prev_m['p']
-                upd = prev_m['full_update'][:prev_m['B']]
-                if wd > 0.0:
-                    pp.data.mul_(1.0 - lr * wd)
-                pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale'])
-
-        if hasattr(self, '_rs_futures'):
-            del self._rs_futures
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()

        return loss

-# --- Tokenizer evaluation helpers ---
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer.
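For intuition, the Newton-Schulz iteration added in this patch can be exercised in isolation. A NumPy sketch mirroring `zeropower_via_newtonschulz5` (minus the bf16 cast; illustrative, not the repo's code):

```python
import numpy as np


def newtonschulz5(G: np.ndarray, steps: int = 10, eps: float = 1e-7) -> np.ndarray:
    # Quintic Newton-Schulz iteration: after normalizing by the Frobenius
    # norm, repeatedly applying aX + (bA + cA^2)X with A = XX^T drives the
    # singular values of X toward 1, i.e. approximately orthogonalizes G.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + eps)
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X


rng = np.random.default_rng(0)
G = rng.normal(size=(64, 128))
X = newtonschulz5(G)
sv = np.linalg.svd(X, compute_uv=False)
print(sv.min(), sv.max())  # singular values cluster loosely around 1
```

The iteration only needs to be approximate: Muon works with singular values pushed into a band around 1, which is why bf16 and a small, fixed step count are enough.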
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.

 def build_sentencepiece_luts(
     sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
@@ -286,7 +193,7 @@ def build_sentencepiece_luts(
             base_bytes_np[token_id] = 1
             continue
         piece = sp.id_to_piece(token_id)
-        if piece.startswith("\u2581"):
+        if piece.startswith("▁"):
             has_leading_space_np[token_id] = True
             piece = piece[1:]
         base_bytes_np[token_id] = len(piece.encode("utf-8"))
@@ -295,15 +202,20 @@ def build_sentencepiece_luts(
         torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
         torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
     )
+
+
 def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
     files = [Path(p) for p in sorted(glob.glob(pattern))]
     if not files:
         raise FileNotFoundError(f"No files found for pattern: {pattern}")
+    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() usable = ((tokens.numel() - 1) // seq_len) * seq_len if usable <= 0: raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") return tokens[: usable + 1] + + def eval_val( args: Hyperparameters, model: nn.Module, @@ -315,32 +227,34 @@ def eval_val( base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, ) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: + if local_batch_tokens < args.train_seq_len: raise ValueError( "VAL_BATCH_SIZE must provide at least one sequence per rank; " f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len seq_start = (total_seqs * rank) // world_size seq_end = (total_seqs * (rank + 1)) // world_size val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) val_token_count = torch.zeros((), device=device, dtype=torch.float64) val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() with torch.inference_mode(): for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 + raw_start = batch_seq_start * args.train_seq_len + raw_end = 
batch_seq_end * args.train_seq_len + 1 local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): batch_loss = model(x, y).detach() batch_token_count = float(y.numel()) @@ -351,23 +265,31 @@ def eval_val( token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count bits_per_token = val_loss.item() / math.log(2.0) tokens_per_byte = val_token_count.item() / val_byte_count.item() model.train() return float(val_loss.item()), float(bits_per_token * tokens_per_byte) -# --- Quantization helpers --- +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small quality hit) by quantizing it to int8 and zlib-compressing the result. +# We can then decompress the model and run it at higher precision for evaluation, once the export fits under the size limit.
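The per-row scheme the comment above describes (one scale per matrix row, symmetric int8, dequantize by multiplying the scale back) can be illustrated with a minimal pure-Python sketch. `quantize_row`/`dequantize_row` are hypothetical names, and the row max here stands in for the high-percentile clip (`INT8_CLIP_PERCENTILE`) the real code uses:

```python
# Minimal sketch of the per-row int8 round trip. Hypothetical helpers;
# the real code clips at a high percentile rather than the row max.
def quantize_row(row, qmax=127):
    clip = max(abs(v) for v in row) or 1.0  # stand-in for the percentile clip
    scale = clip / qmax                     # one scale per row
    q = [max(-qmax, min(qmax, round(v / scale))) for v in row]
    return q, scale

def dequantize_row(q, scale):
    # Broadcast the row scale back over the quantized entries.
    return [v * scale for v in q]

row = [0.5, -1.0, 0.25, 0.75]
q, scale = quantize_row(row)
restored = dequantize_row(q, scale)  # each entry within ~scale/2 of the original
```

Rounding to the nearest of 255 levels bounds the per-entry error by half a scale step, which is why the bf16-trained weights survive the round trip with only a small bpb hit.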
CONTROL_TENSOR_NAME_PATTERNS = tuple( pattern for pattern in os.environ.get( "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", ).split(",") if pattern ) @@ -384,8 +306,10 @@ def eval_val( INT8_PER_ROW_SCALE_DTYPE = torch.float16 INT8_CLIP_PERCENTILE = 99.99984 INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + def tensor_nbytes(t: Tensor) -> int: return int(t.numel()) * int(t.element_size()) + def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): return t.float().contiguous() @@ -393,9 +317,12 @@ def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, s passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() return t + def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: t32 = t.float() if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. clip_abs = ( torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() @@ -405,11 +332,19 @@ def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() return q, scale + def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes quantized: dict[str, Tensor] = {} scales: dict[str, Tensor] = {} dtypes: dict[str, str] = {} @@ -420,21 +355,27 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), 0, ) + for name, tensor in state_dict.items(): t = tensor.detach().to("cpu").contiguous() stats["param_count"] += int(t.numel()) stats["num_tensors"] += 1 stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): stats["num_nonfloat_tensors"] += 1 passthrough[name] = t stats["int8_payload_bytes"] += tensor_nbytes(t) continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: kept = keep_float_tensor(name, t, passthrough_orig_dtypes) passthrough[name] = kept stats["int8_payload_bytes"] += tensor_nbytes(kept) continue + stats["num_float_tensors"] += 1 q, s = quantize_float_tensor(t) if s.ndim > 0: @@ -443,6 +384,7 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): scales[name] = s dtypes[name] = str(t.dtype).removeprefix("torch.") stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { "__quant_format__": "int8_clean_per_row_v1", "quantized": quantized, @@ -455,6 +397,7 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): if passthrough_orig_dtypes: obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes return obj, stats + def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: out: dict[str, Tensor] = {} qmeta = obj.get("qmeta", {}) @@ -464,11 +407,13 @@ def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: s = obj["scales"][name] if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() else: scale = float(s.item()) out[name] = (q.float() * scale).to(dtype=dtype).contiguous() for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
out_t = t.detach().to("cpu").contiguous() orig_dtype = passthrough_orig_dtypes.get(name) if isinstance(orig_dtype, str): @@ -476,12 +421,16 @@ def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: out[name] = out_t return out -# --- Data loading --- + +# ----------------------------- +# DATA LOADING +# ----------------------------- def load_data_shard(file: Path) -> Tensor: header_bytes = 256 * np.dtype("<u4").itemsize with file.open("rb") as f: header = np.frombuffer(f.read(header_bytes), dtype="<u4") num_tokens = int(header[2]) tokens_np = np.fromfile(f, dtype="<u2", count=num_tokens) if tokens_np.size != num_tokens: raise ValueError(f"Short read for {file}") return torch.from_numpy(tokens_np.astype(np.uint16, copy=False)) + + class TokenStream: + # Reads shards sequentially and wraps around forever. The training loop therefore + # has deterministic, simple streaming behavior with no sampling or workers. def __init__(self, pattern: str): self.files = [Path(p) for p in sorted(glob.glob(pattern))] if not self.files: @@ -500,10 +453,12 @@ def __init__(self, pattern: str): self.file_idx = 0 self.tokens = load_data_shard(self.files[0]) self.pos = 0 + def _advance_file(self) -> None: self.file_idx = (self.file_idx + 1) % len(self.files) self.tokens = load_data_shard(self.files[self.file_idx]) self.pos = 0 + def take(self, n: int) -> Tensor: chunks: list[Tensor] = [] remaining = n @@ -517,12 +472,17 @@ def take(self, n: int) -> Tensor: self.pos += k remaining -= k return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): self.rank = rank self.world_size = world_size self.device = device self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: local_tokens = global_tokens // (self.world_size * grad_accum_steps) per_rank_span = local_tokens + 1 @@ -533,44 +493,44 @@ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> y = local[1:].reshape(-1, seq_len) return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) -# --- Transformer modules --- +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- class RMSNorm(nn.Module): def __init__(self, eps: float | None = None): super().__init__() self.eps = eps + def forward(self, x: Tensor) -> Tensor: return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + class CastedLinear(nn.Linear): - _qat_enabled: bool = False + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) + return F.linear(x, self.weight.to(x.dtype), bias) + + def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. 
with torch.no_grad(): for name, param in module.named_parameters(): if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: param.data = param.data.float() + + class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - self.rope_dims = rope_dims if rope_dims > 0 else dim - inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._cos_cached: Tensor | None = None self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: if ( self._cos_cached is None @@ -578,30 +538,20 @@ def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tup or self._seq_len_cached != seq_len or self._cos_cached.device != device ): - rd = self.rope_dims - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (rd / (rd - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, 
:] self._seq_len_cached = seq_len return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: - if rope_dims > 0 and rope_dims < x.size(-1): - x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] - half = rope_dims // 2 - x1, x2 = x_rope[..., :half], x_rope[..., half:] - x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - return torch.cat((x_rope, x_pass), dim=-1) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: half = x.size(-1) // 2 x1, x2 = x[..., :half], x[..., half:] return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + class CausalSelfAttention(nn.Module): def __init__( self, @@ -610,8 +560,6 @@ def __init__( num_kv_heads: int, rope_base: float, qk_gain_init: float, - gated_attention: bool = False, - value_residual: bool = False, ): super().__init__() if dim % num_heads != 0: @@ -623,114 +571,51 @@ def __init__( self.head_dim = dim // num_heads if self.head_dim % 2 != 0: raise ValueError("head_dim must be even for RoPE") - # No CastedLinear -- weights come from banks + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rope_dims = 0 # set by GPT.__init__ for partial RoPE - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - self.use_xsa = False # set by GPT.__init__ for deep layers only - # Gated attention and value residual (non-banked small params) - self.gated_attention = gated_attention - if gated_attention: - self.attn_gate = nn.Linear(dim, num_heads, bias=True) - nn.init.zeros_(self.attn_gate.weight) - nn.init.constant_(self.attn_gate.bias, 4.0) - 
self.value_residual = value_residual - if value_residual: - self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). - y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] - vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: bsz, seqlen, dim = x.shape - q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = F.linear(x, v_w.to(x.dtype)) - if v_embed is not None: - v = v + v_embed - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - raw_v = v if self.value_residual else None - if self.value_residual and v0 is not None: - lam = self.vr_lambda.to(dtype=v.dtype) - v = lam[0] * v0 + lam[1] * v + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin, self.rope_dims) - k = apply_rotary_emb(k, cos, sin, self.rope_dims) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = 
flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - if self.gated_attention: - # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout - gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) - y = y * gate - y = y.reshape(bsz, seqlen, dim) - return F.linear(y, out_w.to(x.dtype)), raw_v - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = 
self.proj(h) - return h * self.scale.to(dtype=h.dtype) - -class ValueEmbedding(nn.Module): - """Reinject token identity into attention values at specific layers. - Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" - def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): - super().__init__() - self.embed = nn.Embedding(vocab_size, ve_dim) - nn.init.normal_(self.embed.weight, std=0.01) - self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(token_ids) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup def __init__(self, dim: int, mlp_mult: int): super().__init__() - # No CastedLinear -- weights come from banks - def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: - x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5) - return F.linear(x.square(), down_w.to(x.dtype)) + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + class Block(nn.Module): def __init__( @@ -741,38 +626,24 @@ def __init__( mlp_mult: int, rope_base: float, qk_gain_init: float, - layer_idx: int = 0, - ln_scale: bool = False, - dtg: bool = False, - gated_attention: bool = False, - value_residual: bool = False, ): super().__init__() self.attn_norm = RMSNorm() self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - gated_attention=gated_attention, value_residual=value_residual) + self.attn = CausalSelfAttention(dim, 
num_heads, num_kv_heads, rope_base, qk_gain_init) self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 - if dtg: - self.dtg_gate = nn.Linear(dim, 1, bias=True) - nn.init.zeros_(self.dtg_gate.weight) - nn.init.constant_(self.dtg_gate.bias, 2.0) - else: - self.dtg_gate = None - def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: mix = self.resid_mix.to(dtype=x.dtype) - x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed, v0=v0) - x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out - x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) - if self.dtg_gate is not None: - gate = torch.sigmoid(self.dtg_gate(x_in.detach())) - x_out = x_in + gate * (x_out - x_in) - return x_out, raw_v + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + class GPT(nn.Module): def __init__( @@ -788,65 +659,18 @@ def __init__( logit_softcap: float, rope_base: float, qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - rope_dims: int = 0, - ln_scale: bool = 
False, - dtg: bool = False, - ve_enabled: bool = False, - ve_dim: int = 128, - ve_layers: str = "9,10", - gated_attention: bool = False, - value_residual: bool = False, - recur_layer: int = -1, - recur_layers: list[int] | None = None, ): super().__init__() - self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection if logit_softcap <= 0.0: raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") self.tie_embeddings = tie_embeddings self.tied_embed_init_std = tied_embed_init_std self.logit_softcap = logit_softcap - self.value_residual = value_residual - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - # Depth recurrence: duplicate layers virtually for near-zero param overhead - # Normalize: recur_layers is the canonical list; recur_layer is single-layer compat - if recur_layers is None and recur_layer >= 0: - recur_layers = [recur_layer] - self.recur_layers = sorted(recur_layers) if recur_layers else [] - self.num_physical_layers = num_layers - if self.recur_layers: - for rl in self.recur_layers: - assert 0 <= rl < num_layers, f"recur_layer={rl} out of range [0, {num_layers})" - cutoff = max(self.recur_layers) + 1 - self.v2p = list(range(cutoff)) + self.recur_layers + list(range(cutoff, num_layers)) - virtual_num_layers = num_layers + len(self.recur_layers) - else: - virtual_num_layers = num_layers - self.v2p = list(range(num_layers)) - self.virtual_num_layers = virtual_num_layers - self.num_encoder_layers = virtual_num_layers // 2 - self.num_decoder_layers = virtual_num_layers - self.num_encoder_layers + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) 
self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - # Parameter banks: contiguous 3D tensors for batched optimizer (physical layer count) - head_dim = model_dim // num_heads - kv_dim = num_kv_heads * head_dim - mlp_dim = int(mlp_mult * model_dim) - self.num_layers = num_layers # physical, used for bank indexing - self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) - self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) - self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) - self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) - # Blocks: one per virtual layer (each has own scalar params) self.blocks = nn.ModuleList( [ Block( @@ -856,601 +680,65 @@ def __init__( mlp_mult, rope_base, qk_gain_init, - layer_idx=i, - ln_scale=ln_scale, - dtg=dtg, - gated_attention=gated_attention, - value_residual=value_residual, ) - for i in range(virtual_num_layers) + for i in range(num_layers) ] ) - if rope_dims > 0: - head_dim = model_dim // num_heads - for block in self.blocks: - block.attn.rope_dims = rope_dims - block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) - self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] - kv_dim_ve = self._ve_target_dim - if self.ve_layer_indices: - self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) - self.ve_layer_scales = nn.ParameterList( - [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] - ) - else: - self.ve_shared = None - self.ve_layer_scales = nn.ParameterList() - self.value_embeds = nn.ModuleList() # keep empty for compat self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) if self.lm_head is not None: self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, 
vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, virtual_num_layers - xsa_last_n), virtual_num_layers): - self.blocks[i].attn.use_xsa = True self._init_weights() + def _init_weights(self) -> None: if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - n = self.num_layers - proj_scale = 1.0 / math.sqrt(2 * n) - # Init banks: orthogonal, with proj layers scaled down and out/down zero-init - for i in range(n): - nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q - nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) - nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K - nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V - nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up - nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) - # Scale proj layers (out_proj and mlp_down are "proj" layers) - self.qo_bank.data[n + i].mul_(proj_scale) - self.mlp_down_bank.data[i].mul_(proj_scale) - # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: - """Get value embedding for a specific layer using shared table + per-layer scale.""" - if self.ve_shared is None or layer_idx not in self.ve_layer_indices: - return None - if ve_cache is not None and 've' not in ve_cache: - ve_cache['ve'] = self.ve_shared(input_ids) - ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) - ve_idx = self.ve_layer_indices.index(layer_idx) - return ve_base * 
self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - n = self.num_layers # physical layer count for bank offset - v2p = self.v2p x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) x0 = x - v0 = None skips: list[Tensor] = [] - ve_cache: dict = {} + + # First half stores skips; second half reuses them in reverse order. for i in range(self.num_encoder_layers): - pi = v2p[i] - ve = self._get_ve(i, input_ids, ve_cache) - x, raw_v = self.blocks[i](x, x0, - self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], - self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], - v_embed=ve, v0=v0) - if v0 is None and raw_v is not None: - v0 = raw_v + x = self.blocks[i](x, x0) skips.append(x) for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - pi = v2p[bi] if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - ve = self._get_ve(bi, input_ids, ve_cache) - x, _ = self.blocks[bi](x, x0, - self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], - self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], - v_embed=ve, v0=v0) - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) + logits_proj = F.linear(x, self.tok_emb.weight) else: if self.lm_head is None: raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), 
targets, reduction="mean") - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - return main_loss - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - n = self.num_layers # physical layer count for bank offset - v2p = self.v2p - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - v0 = None - skips: list[Tensor] = [] - ve_cache: dict = {} - for i in range(self.num_encoder_layers): - pi = v2p[i] - ve = self._get_ve(i, input_ids, ve_cache) - x, raw_v = self.blocks[i](x, x0, - self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], - self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], - v_embed=ve, v0=v0) - if v0 is None and raw_v is not None: - v0 = raw_v - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - pi = v2p[bi] - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - ve = self._get_ve(bi, input_ids, ve_cache) - x, _ = self.blocks[bi](x, x0, - self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], - self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], - v_embed=ve, v0=v0) - x = self.final_norm(x) - if 
self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - def untie_recurrence(self): - """Expand banks so duplicated layers get their own independent weights. - Called before TTT so SGD can update each virtual layer independently.""" - if not self.recur_layers: - return - n = self.num_layers # physical count before expansion - # Build list of cloned rows to insert after max(recur_layers) - insert_after = max(self.recur_layers) - clones = sorted(self.recur_layers) # rows to duplicate - - def _expand(bank: Tensor) -> Tensor: - """Insert clones of recur rows after insert_after position.""" - parts = [bank[:insert_after + 1]] - for rl in clones: - parts.append(bank[rl:rl + 1].clone()) - parts.append(bank[insert_after + 1:]) - return torch.cat(parts, dim=0) - - # qo_bank: [2*n, dim, dim] -> first n are Q, last n are O - q_part = _expand(self.qo_bank.data[:n]) - o_part = _expand(self.qo_bank.data[n:]) - self.qo_bank = nn.Parameter(torch.cat([q_part, o_part], dim=0)) - - # kv_bank: [2*n, kv_dim, dim] -> first n are K, last n are V - k_part = _expand(self.kv_bank.data[:n]) - v_part = _expand(self.kv_bank.data[n:]) - self.kv_bank = nn.Parameter(torch.cat([k_part, v_part], dim=0)) - - # mlp banks: [n, ...] 
- self.mlp_up_bank = nn.Parameter(_expand(self.mlp_up_bank.data)) - self.mlp_down_bank = nn.Parameter(_expand(self.mlp_down_bank.data)) - - # Update to identity mapping - new_n = n + len(clones) - self.num_layers = new_n - self.num_physical_layers = new_n - self.v2p = list(range(new_n)) - self.recur_layers = [] - -# --- Sliding window evaluation --- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - 
y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -def eval_val_sliding_ttt( - args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, - device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, - stride: int, batch_seqs: int = 32, log0=print, -) -> tuple[float, float]: - """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, - then train on it. 
Every token scored BEFORE any update that could use it.""" - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - ttt_chunk = args.ttt_chunk_tokens - - # Pre-compute all window starts - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - - # Assign each window to a chunk based on the first token it scores - num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk - chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] - for ws in window_starts: - end = min(ws + seq_len, total_tokens) - wlen = end - ws - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_start = ws + s - ci = min(scored_start // ttt_chunk, num_chunks - 1) - chunk_windows[ci].append(ws) - - log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " - f"total_windows={len(window_starts)} stride={stride} " - f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " - f"freeze_blocks={args.ttt_freeze_blocks}") - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - # Freeze first N blocks - frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) - ttt_params = [] - for name, p in base_model.named_parameters(): - freeze = False - for bi in frozen_block_ids: - if f"blocks.{bi}." 
in name: - freeze = True - break - if freeze: - p.requires_grad_(False) - else: - p.requires_grad_(True) - ttt_params.append(p) - - log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " - f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") - - optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) - t0 = time.perf_counter() - - for ci in range(num_chunks): - windows = chunk_windows[ci] - if not windows: - continue - chunk_start = ci * ttt_chunk - chunk_end = min((ci + 1) * ttt_chunk, total_tokens) - - # --- Phase 1: SCORE this chunk's windows (inference_mode) --- - my_s = (len(windows) * rank) // world_size - my_e = (len(windows) * (rank + 1)) // world_size - my_windows = windows[my_s:my_e] - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk_tok[:-1] - y_batch[i, :wlen] = chunk_tok[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - # --- Phase 2: TRAIN on this chunk (already scored = legal) --- - is_last_chunk = (ci == num_chunks - 1) - if not is_last_chunk and args.ttt_epochs > 0: - base_model.train() - chunk_seqs = (chunk_end - chunk_start) // seq_len - if chunk_seqs > 0: - cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) - for pg in optimizer.param_groups: - pg['lr'] = cos_lr - my_seq_s = (chunk_seqs * rank) // world_size - my_seq_e = (chunk_seqs * (rank + 1)) // world_size - my_chunk_seqs = my_seq_e - my_seq_s - for _ep in range(args.ttt_epochs): - for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): - be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) - actual_bs = my_seq_s + bs - start_tok = chunk_start + actual_bs * seq_len - end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 - if end_tok > val_tokens.numel(): - continue - local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - optimizer.zero_grad(set_to_none=True) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - loss = base_model(x, y) - loss.backward() - if world_size > 1: - for p in ttt_params: - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) - optimizer.step() - - if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): - elapsed = time.perf_counter() - t0 - rl = loss_sum.item() / max(token_count.item(), 1) - rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 - log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = 
(loss_sum / token_count).item() - val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) - - for p in base_model.parameters(): - p.requires_grad_(True) - base_model.eval() - - log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " - f"elapsed={time.perf_counter() - t0:.1f}s") - return val_loss, val_bpb - - -# --- GPTQ-lite int6 quantization --- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" -def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - best_q, best_s, best_err = None, None, float('inf') - for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: - if pct < 1.0: - row_clip = torch.quantile(t32.abs(), pct, dim=1) - else: - row_clip = t32.abs().amax(dim=1) - s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) - q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) - recon = q.float() * s.float()[:, None] - err = (t32 - recon).pow(2).mean().item() - if err < best_err: - best_q, best_s, best_err = q, s, err - return best_q, best_s - amax = t32.abs().max().item() - scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) - return q, scale - -def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: - """Convert 3D bank tensors into individual 2D tensors with standard names.""" - out: dict[str, Tensor] = {} - n = num_layers - for name, tensor in sd.items(): - if name == "qo_bank": - for i in range(n): - out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] - out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] - elif name == "kv_bank": - for i in range(n): - 
out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] - out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] - elif name == "mlp_up_bank": - for i in range(n): - out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] - elif name == "mlp_down_bank": - for i in range(n): - out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] - else: - out[name] = tensor - return out + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") -def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - """Convert individual 2D tensors back into 3D bank tensors.""" - out: dict[str, Tensor] = {} - n = num_layers - # Reconstruct banks from individual weight keys - qo_slices = [None] * (2 * n) - kv_slices = [None] * (2 * n) - up_slices = [None] * n - down_slices = [None] * n - consumed = set() - for i in range(n): - qk = f"blocks.{i}.attn.c_q.weight" - if qk in sd: - qo_slices[i] = sd[qk] - consumed.add(qk) - ok = f"blocks.{i}.attn.proj.weight" - if ok in sd: - qo_slices[n + i] = sd[ok] - consumed.add(ok) - kk = f"blocks.{i}.attn.c_k.weight" - if kk in sd: - kv_slices[i] = sd[kk] - consumed.add(kk) - vk = f"blocks.{i}.attn.c_v.weight" - if vk in sd: - kv_slices[n + i] = sd[vk] - consumed.add(vk) - fk = f"blocks.{i}.mlp.fc.weight" - if fk in sd: - up_slices[i] = sd[fk] - consumed.add(fk) - dk = f"blocks.{i}.mlp.proj.weight" - if dk in sd: - down_slices[i] = sd[dk] - consumed.add(dk) - out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) - out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) - out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) - out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) - for name, tensor in sd.items(): - if name not in consumed: - out[name] = tensor - return out -def mixed_quantize_int6(state_dict: dict[str, Tensor], 
int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - -# --- Training --- +# ----------------------------- +# TRAINING +# ----------------------------- def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") args = Hyperparameters() - # 
zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) @@ -1469,18 +757,23 @@ def main() -> None: dist.init_process_group(backend="nccl", device_id=device) dist.barrier() master_process = rank == 0 + + # Fast math knobs torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) enable_flash_sdp(True) enable_mem_efficient_sdp(False) enable_math_sdp(False) + logfile = None if master_process: os.makedirs("logs", exist_ok=True) logfile = f"logs/{args.run_id}.txt" print(logfile) + def log0(msg: str, console: bool = True) -> None: if not master_process: return @@ -1489,6 +782,7 @@ def log0(msg: str, console: bool = True) -> None: if logfile is not None: with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + log0(code, console=False) log0("=" * 100, console=False) log0(f"Running Python {sys.version}", console=False) @@ -1498,10 +792,16 @@ def log0(msg: str, console: bool = True) -> None: console=False, ) log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) @@ -1511,16 +811,18 @@ def log0(msg: str, console: bool = True) -> None: ) dataset_dir = 
Path(args.data_path).resolve() actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( sp, args.vocab_size, device ) log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - CastedLinear._qat_enabled = args.qat_enabled + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + base_model = GPT( vocab_size=args.vocab_size, num_layers=args.num_layers, @@ -1533,449 +835,292 @@ def log0(msg: str, console: bool = True) -> None: logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - dtg=args.dtg_enabled, - ve_enabled=args.ve_enabled, - ve_dim=args.ve_dim, - ve_layers=args.ve_layers, - gated_attention=args.gated_attention, - value_residual=args.value_residual, - recur_layer=args.recur_layer, - recur_layers=([int(x) for x in args.recur_layers_str.split(",") if x.strip()] or None), ).to(device).bfloat16() - # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward - base_model.qo_bank.data = base_model.qo_bank.data.float() - base_model.kv_bank.data = base_model.kv_bank.data.float() - base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() - 
base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() for module in base_model.modules(): if isinstance(module, CastedLinear): module.float() restore_low_dim_params_to_fp32(base_model) - # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, - # and non-bank grads are manually all-reduced before Adam steps. compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model = compiled_model - - if not args.eval_only: - # - 4 parameter banks -> Muon (batched Newton-Schulz) - # - token embedding -> Adam - # - scalars/control tensors -> Adam - # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking) - matrix_params = [ - base_model.qo_bank, base_model.kv_bank, - base_model.mlp_up_bank, base_model.mlp_down_bank, - ] - block_named_params = list(base_model.blocks.named_parameters()) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - scalar_params.append(base_model.bigram.proj.weight) - if base_model.ve_shared is not None: - tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.ve_shared.proj is not None: - scalar_params.append(base_model.ve_shared.proj.weight) - scalar_params.append(base_model.ve_shared.scale) - for s in base_model.ve_layer_scales: - scalar_params.append(s) 
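The removed optimizer-setup code above and the `+` replacement below both follow the same recipe: 2-D matrix weights inside the transformer blocks go to Muon, while embeddings, vectors/scalars, and name-matched control tensors go to Adam with their own learning rates. A minimal sketch of that partitioning, with a hypothetical toy model and name filter standing in for the real `CONTROL_TENSOR_NAME_PATTERNS` (the `Muon` step itself is replaced by plain SGD here, since Muon is defined elsewhere in this file):

```python
import torch
import torch.nn as nn

# Assumption: control tensors are identified by substrings of their names,
# mirroring the CONTROL_TENSOR_NAME_PATTERNS filter in the script.
CONTROL_PATTERNS = ("gate", "scale")

def split_param_groups(model: nn.Module):
    """Partition parameters: embeddings / 2-D matrices / everything else."""
    matrix, scalar, embed = [], [], []
    for name, p in model.named_parameters():
        if "emb" in name:
            embed.append(p)
        elif p.ndim >= 2 and not any(pat in name for pat in CONTROL_PATTERNS):
            matrix.append(p)
        else:
            scalar.append(p)
    return matrix, scalar, embed

# Hypothetical toy model: one embedding plus one "block" matrix with a bias.
model = nn.Sequential()
model.add_module("emb", nn.Embedding(16, 8))
model.add_module("block", nn.Linear(8, 8, bias=True))

matrix, scalar, embed = split_param_groups(model)
# Muon would take `matrix`; sketched here with momentum SGD as a stand-in.
opt_matrix = torch.optim.SGD(matrix, lr=0.02, momentum=0.95)
opt_rest = torch.optim.Adam([{"params": scalar, "lr": 3e-3},
                             {"params": embed, "lr": 1e-2}])
```

The ndim-based split is why biases and gains never reach the Newton-Schulz orthogonalization path: Muon's update only makes sense for matrices.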
- optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + 
optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, - weight_decay=args.adam_wd, fused=True, ) - # Non-bank params that need manual all-reduce (replicated across GPUs) - replicated_params = list(optimizer_tok.param_groups[0]["params"]) - for pg in optimizer_tok.param_groups[1:]: - replicated_params.extend(pg["params"]) - replicated_params.extend(scalar_params) - - optimizer_head = None - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - replicated_params.append(base_model.lm_head.weight) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if optimizer_head is not None: - optimizers.append(optimizer_head) - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - if base_model.recur_layers: - log0(f"recurrence:layers={base_model.recur_layers} physical_layers={args.num_layers} virtual_layers={base_model.virtual_num_layers}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] - log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - 
f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - # All-reduce all grads for warmup (simple, not optimized) - if distributed: - for p in base_model.parameters(): - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - for 
opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - from collections import deque - lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - ema_decay = 0.997 - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - if args.late_qat_threshold > 0 and scale < 
args.late_qat_threshold and not CastedLinear._qat_enabled: - CastedLinear._qat_enabled = True - log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we 
restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): zero_grad_all() - train_loss = torch.zeros((), device=device) for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - # === 3-phase overlapped optimizer step === - # Phase 1: Launch async reduce-scatter for banks (biggest first) - optimizer_muon.launch_reduce_scatters() - # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) - if distributed: - for p in replicated_params: - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - optimizer_tok.step() - optimizer_scalar.step() - if optimizer_head is not None: - optimizer_head.step() - # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) - optimizer_muon.step() + opt.step() 
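The warmup block above primes the compiled forward/backward/optimizer paths with throwaway steps, then restores the snapshotted weights and optimizer state so measured training starts from the true initialization. A self-contained sketch of that snapshot-warmup-restore pattern, using a hypothetical toy model in place of the compiled GPT:

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 4)  # stand-in for the compiled model
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

# Snapshot weights and optimizer state before any step runs.
init_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
init_opt_state = copy.deepcopy(opt.state_dict())

# Throwaway warmup steps: in the real script these trigger torch.compile
# and allocator warmup; here they just mutate the parameters.
for _ in range(3):
    opt.zero_grad(set_to_none=True)
    loss = model(torch.randn(2, 4)).pow(2).mean()
    loss.backward()
    opt.step()

# Restore, so the measured run begins from the true init.
model.load_state_dict(init_state, strict=True)
opt.load_state_dict(init_opt_state)
opt.zero_grad(set_to_none=True)
```

Restoring the optimizer state matters as much as the weights: otherwise Adam's warmed-up moment estimates would leak information from the throwaway steps into the timed run.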
zero_grad_all() - # EMA update - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - if args.lawa_enabled and step % args.lawa_freq == 0: - lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + 
model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, ) - if should_log_train: + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" ) - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - # Apply weight averaging - if args.lawa_enabled and len(lawa_queue) > 1: - log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") - current_state = base_model.state_dict() - avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in current_state.items()} - for snap in lawa_queue: - for name in avg_state: - avg_state[name] += snap[name].float() - for name in avg_state: - avg_state[name] /= len(lawa_queue) - avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) - base_model.load_state_dict(avg_state, strict=True) - else: - log0("ema:applying EMA weights") - current_state = base_model.state_dict() - avg_state = {name: 
t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} - base_model.load_state_dict(avg_state, strict=True) - torch.cuda.synchronize() - t_diag = time.perf_counter() - diag_val_loss, diag_val_bpb = eval_val( - args, compiled_model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) ) - torch.cuda.synchronize() + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed 
to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) log0( - f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + 
f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" ) - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - # Unbank 3D tensors into individual 2D tensors for quantization - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) - quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = lzma.compress(quant_raw, preset=6) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") - log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") - if distributed: - dist.barrier() - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, ln_scale=args.ln_scale, 
dtg=args.dtg_enabled, - ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, - gated_attention=args.gated_attention, value_residual=args.value_residual, - recur_layer=args.recur_layer, - recur_layers=([int(x) for x in args.recur_layers_str.split(",") if x.strip()] or None), - ).to(device).bfloat16() - eval_model.qo_bank.data = eval_model.qo_bank.data.float() - eval_model.kv_bank.data = eval_model.kv_bank.data.float() - eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() - eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - with open("final_model.int6.ptz", "rb") as f: + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(lzma.decompress(quant_blob_disk)), - map_location="cpu", - ) - # Use eval_model's own state as dequant template - if not args.eval_only: - template_sd, template_unbanked = sd_cpu, unbanked_sd - else: - template_sd = {k: v.detach().cpu() for k, v in eval_model.state_dict().items()} - template_unbanked = _unbank_state_dict(template_sd, args.num_layers) - deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], template_unbanked) - # Re-bank the dequantized tensors - deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, template_sd) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) torch.cuda.synchronize() t_qeval = time.perf_counter() q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, 
grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, ) torch.cuda.synchronize() log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact 
val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - # Legal score-first TTT (PR #461 recipe) - ttt_untie = bool(int(os.environ.get("TTT_UNTIE", "1"))) - if args.ttt_enabled: - if eval_model.recur_layers and ttt_untie: - log0(f"ttt:untying recurrence at layers {eval_model.recur_layers}") - eval_model.untie_recurrence() - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_loss, ttt_bpb = eval_val_sliding_ttt( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, log0=log0, - ) - torch.cuda.synchronize() - log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") - log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + if distributed: dist.destroy_process_group() + + if __name__ == "__main__": - main() \ No newline at end of file + main() From 174142c39b029b55031e4a60c155f1124571711f Mon Sep 17 00:00:00 2001 From: Marko Sisovic Date: Wed, 25 Mar 2026 05:57:56 +0000 Subject: [PATCH 6/8] Clean up submission README: first person, simplify approach section Co-Authored-By: Claude Opus 4.6 --- .../2026-03-25_RecurLayers/README.md | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/README.md b/records/track_10min_16mb/2026-03-25_RecurLayers/README.md index 53db1c8cc..f35f44ba5 100644 --- a/records/track_10min_16mb/2026-03-25_RecurLayers/README.md +++ b/records/track_10min_16mb/2026-03-25_RecurLayers/README.md @@ -6,26 +6,16 @@ Trained on 8xH100 SXM in ~600 seconds. ~15.9MB artifact (int6+lzma). 
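The record above ships a ~15.9MB int6+lzma artifact, and the README's inherited-techniques list names "per-row int6 quantization on MLP/attention weights". As a hedged sketch of what such a scheme can look like (the names `quantize_int6_per_row`/`dequantize_int6_per_row` and the exact rounding/clipping choices are my assumptions, not the repo's `mixed_quantize_int6`):

```python
import lzma

import numpy as np

def quantize_int6_per_row(w: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Symmetric per-row quantization to the signed 6-bit range [-31, 31]."""
    scale = np.abs(w).max(axis=1, keepdims=True) / 31.0
    scale[scale == 0] = 1.0  # guard all-zero rows against divide-by-zero
    q = np.clip(np.rint(w / scale), -31, 31).astype(np.int8)
    return q, scale.astype(np.float32)

def dequantize_int6_per_row(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # Reconstruction error per element is bounded by scale / 2.
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.standard_normal((1536, 512)).astype(np.float32)  # one MLP-sized matrix
q, scale = quantize_int6_per_row(w)
roundtrip_err = np.abs(w - dequantize_int6_per_row(q, scale)).max()
# int8 containers here; true 6-bit packing would shave a further 25% of the
# payload before compression (lzma recovers much of that redundancy anyway).
blob = lzma.compress(q.tobytes(), preset=6)
```

In the patch the quantized state dict is then serialized with `torch.save` before compression; this sketch only shows the per-row numeric scheme, not the state-dict bookkeeping.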
## Motivation -We explored both width scaling (MODEL_DIM=576) and depth scaling (adding layers) and found that depth consistently wins over width in this regime. A full independent 12-layer model at dim=512 outperformed a wider 11-layer model at dim=576, despite the wider model having more parameters. However, adding independent layers pushes the model over the 16MB artifact budget. Depth recurrence solves this: by re-executing mid-network layers with independent block scalars, we get the depth benefit without the parameter/size cost. Dual recurrence on layers 4 and 5 gives us 13 virtual layers from 11 physical, staying well under budget at ~15.9MB. +I explored both width scaling (MODEL_DIM=576) and depth scaling (adding layers) and found that depth consistently wins over width in this regime. A full independent 12-layer model at dim=512 outperformed a wider 11-layer model at dim=576, despite the wider model having more parameters. However, adding independent layers pushes the model over the 16MB artifact budget. Depth recurrence solves this: by re-executing mid-network layers with independent block scalars, I get the depth benefit without the parameter/size cost. Dual recurrence on layers 4 and 5 gives 13 virtual layers from 11 physical, staying well under budget at ~15.9MB. ## Approach Depth recurrence applied to layers 4 and 5, creating 13 virtual layers from 11 physical layers while keeping parameter count at ~27M. Combined with test-time training (TTT) for additional evaluation-time adaptation. -### 1. Dual Depth Recurrence (layers 4,5) +### Dual Depth Recurrence (layers 4,5) Layers 4 and 5 are each executed twice in sequence (pattern: 0,1,2,3,4,5,4,5,6,7,8,9,10), producing 13 virtual layers from 11 physical layers. Each recurrent pass uses independent learnable block scalars, so the model can modulate how the repeated layers behave on their second pass. 
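The execution pattern spelled out above (0,1,2,3,4,5,4,5,6,7,8,9,10) can be sketched as a tiny schedule helper; `virtual_layer_order` is an illustrative name of mine, not the repo's API:

```python
def virtual_layer_order(num_layers: int = 11, recur_layers: tuple[int, ...] = (4, 5)) -> list[int]:
    # Run all physical layers in order, and replay the recurrent span once
    # immediately after its last member.
    order: list[int] = []
    for i in range(num_layers):
        order.append(i)
        if i == max(recur_layers):
            order.extend(sorted(recur_layers))
    return order

print(virtual_layer_order())  # -> [0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
```

In the actual model, each pass through a repeated layer would additionally select its own learnable block scalars (the ~2K extra parameters mentioned below), which this schedule-only sketch omits.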
This adds depth without increasing model size or artifact bytes — only the small block scalar parameters are added (~2K params). -### 2. Test-Time Training (TTT) -At evaluation time, the model adapts its weights to the validation data using a short fine-tuning pass. 3 epochs over the validation set with lr=0.002, chunked into 32K-token segments. The top 2 blocks are frozen during TTT to preserve the output head's calibration. Tied TTT (no weight untying) performs equivalently to untied. - -### 3. Inherited Techniques from Baseline -- **Int6 quantization + lzma compression**: Per-row int6 quantization on MLP/attention weights -- **3x MLP expansion**: Hidden dim 1536 (3x model dim) -- **Bigram hash embeddings**: 2048-bucket hash table (dim=128) -- **Value embeddings**: Learned value residuals on layers 9,10 -- **SWA**: Stochastic weight averaging every 50 steps -- **Muon optimizer**: With weight decay 0.04, momentum warmup 0.92->0.99 -- **Orthogonal initialization** +Everything else (TTT, int6 quantization, SWA, bigram embeddings, value embeddings, Muon optimizer, etc.) is inherited from [PR #549](https://github.com/openai/parameter-golf/pull/549). ## Hyperparameters From f5b9631ef32865ba3bf7f9c1a770242c90816fce Mon Sep 17 00:00:00 2001 From: Marko Sisovic Date: Wed, 25 Mar 2026 11:54:42 +0000 Subject: [PATCH 7/8] Rerun seed 2024 with ITERATIONS=9000 for consistency Previous run accidentally used 8000 iterations. Reran with 9000 to match other seeds. Mean val_bpb: 1.1184 (was 1.1182), std: 0.00049 (was 0.00076). 
Co-Authored-By: Claude Opus 4.6 --- .../2026-03-25_RecurLayers/README.md | 10 +- .../2026-03-25_RecurLayers/submission.json | 10 +- .../2026-03-25_RecurLayers/train_seed2024.log | 472 +++++++++--------- 3 files changed, 246 insertions(+), 246 deletions(-) diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/README.md b/records/track_10min_16mb/2026-03-25_RecurLayers/README.md index f35f44ba5..b6a1951bb 100644 --- a/records/track_10min_16mb/2026-03-25_RecurLayers/README.md +++ b/records/track_10min_16mb/2026-03-25_RecurLayers/README.md @@ -1,6 +1,6 @@ # Depth Recurrence (layers 4,5) -## Score: mean val_bpb = 1.1182 (3 seeds: 1.1179, 1.1191, 1.1176) +## Score: mean val_bpb = 1.1184 (3 seeds: 1.1179, 1.1191, 1.1183) Trained on 8xH100 SXM in ~600 seconds. ~15.9MB artifact (int6+lzma). @@ -44,7 +44,7 @@ Everything else (TTT, int6 quantization, SWA, bigram embeddings, value embedding ## Key Metrics -- **Mean val_bpb: 1.11819** (std: 0.00076) +- **Mean val_bpb: 1.11840** (std: 0.00049) - Training: ~6,100 steps in ~600s - Model params: ~27M - Artifact size: ~15.9MB (int6+lzma) @@ -57,9 +57,9 @@ Three independent training runs with different random seeds: |------|----------|---------| | 1337 | 1.88749538 | 1.11788404 | | 2025 | 1.88948575 | 1.11906285 | -| 2024 | 1.88706558 | 1.11762949 | -| **Mean** | **1.88801557** | **1.11819213** | -| **Std** | **0.00129122** | **0.00076473** | +| 2024 | 1.88811812 | 1.11825287 | +| **Mean** | **1.88836642** | **1.11839992** | +| **Std** | **0.00083132** | **0.00049235** | ## Run Commands diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/submission.json b/records/track_10min_16mb/2026-03-25_RecurLayers/submission.json index c33dede75..e6ca363ea 100644 --- a/records/track_10min_16mb/2026-03-25_RecurLayers/submission.json +++ b/records/track_10min_16mb/2026-03-25_RecurLayers/submission.json @@ -4,15 +4,15 @@ "name": "Depth Recurrence (layers 4,5) + TTT", "blurb": "Dual depth recurrence on layers 4 and 5 (11 physical -> 
13 virtual layers) with tied test-time training. Reuses layer weights to add depth without increasing model size, keeping the artifact under 16MB with int6+lzma compression. Combined with TTT, SWA, bigram embeddings, value embeddings, and Muon optimizer with weight decay.", "date": "2026-03-25T00:00:00Z", - "val_loss": 1.88801557, - "val_bpb": 1.11819213, - "val_loss_std": 0.00129122, - "val_bpb_std": 0.00076473, + "val_loss": 1.88836642, + "val_bpb": 1.11839992, + "val_loss_std": 0.00083132, + "val_bpb_std": 0.00049235, "seeds": [1337, 2025, 2024], "seed_results": { "1337": {"val_loss": 1.88749538, "val_bpb": 1.11788404}, "2025": {"val_loss": 1.88948575, "val_bpb": 1.11906285}, - "2024": {"val_loss": 1.88706558, "val_bpb": 1.11762949} + "2024": {"val_loss": 1.88811812, "val_bpb": 1.11825287} }, "step_stop": 6100, "wallclock_seconds": 600.0, diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2024.log b/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2024.log index 1d3852ead..34cfd816b 100644 --- a/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2024.log +++ b/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2024.log @@ -1,8 +1,8 @@ -W0325 05:28:06.586000 75757 torch/distributed/run.py:803] -W0325 05:28:06.586000 75757 torch/distributed/run.py:803] ***************************************** -W0325 05:28:06.586000 75757 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
-W0325 05:28:06.586000 75757 torch/distributed/run.py:803] ***************************************** -logs/388e32c2-50e0-4e35-88ac-51219634dd13.txt +W0325 11:30:23.920000 12574 torch/distributed/run.py:803] +W0325 11:30:23.920000 12574 torch/distributed/run.py:803] ***************************************** +W0325 11:30:23.920000 12574 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0325 11:30:23.920000 12574 torch/distributed/run.py:803] ***************************************** +logs/f5675aa4-2278-4d41-8bfa-b0ded28ea1c7.txt val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model train_loader:dataset:fineweb10B_sp1024 train_shards:80 val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 @@ -14,7 +14,7 @@ world_size:8 grad_accum_steps:1 sdp_backends:cudnn=False flash=True mem_efficient=False math=False attention_mode:gqa num_heads:8 num_kv_heads:4 tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:8000 warmup_steps:20 max_wallclock_seconds:600.000 +train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 seed:2024 warmup_step:1/20 warmup_step:2/20 @@ -36,239 +36,239 @@ warmup_step:17/20 warmup_step:18/20 warmup_step:19/20 warmup_step:20/20 -step:0/8000 val_loss:6.9327 val_bpb:4.1059 train_time:0ms step_avg:0.02ms -step:1/8000 train_loss:6.9341 train_time:144ms step_avg:144.25ms -step:2/8000 train_loss:8.6442 train_time:190ms step_avg:95.07ms -step:3/8000 train_loss:7.6507 train_time:285ms step_avg:95.08ms -step:4/8000 train_loss:7.2426 train_time:380ms step_avg:94.93ms -step:5/8000 train_loss:7.1122 train_time:475ms step_avg:94.98ms -step:6/8000 
train_loss:7.0615 train_time:570ms step_avg:94.98ms -step:7/8000 train_loss:6.9790 train_time:665ms step_avg:95.01ms -step:8/8000 train_loss:6.8907 train_time:759ms step_avg:94.92ms -step:9/8000 train_loss:6.5547 train_time:853ms step_avg:94.83ms -step:10/8000 train_loss:6.1912 train_time:949ms step_avg:94.86ms -step:500/8000 train_loss:2.3959 train_time:48743ms step_avg:97.49ms -step:1000/8000 train_loss:2.2556 train_time:97892ms step_avg:97.89ms -step:1500/8000 train_loss:2.2016 train_time:146957ms step_avg:97.97ms -step:2000/8000 train_loss:2.0479 train_time:196043ms step_avg:98.02ms -step:2500/8000 train_loss:2.1513 train_time:245114ms step_avg:98.05ms -step:3000/8000 train_loss:2.1306 train_time:294168ms step_avg:98.06ms -step:3500/8000 train_loss:2.1382 train_time:343183ms step_avg:98.05ms -step:4000/8000 train_loss:1.9288 train_time:392206ms step_avg:98.05ms -step:4000/8000 val_loss:2.0189 val_bpb:1.1957 train_time:392256ms step_avg:98.06ms -step:4500/8000 train_loss:2.0800 train_time:441225ms step_avg:98.05ms -step:5000/8000 train_loss:2.0518 train_time:490224ms step_avg:98.04ms +step:0/9000 val_loss:6.9327 val_bpb:4.1059 train_time:0ms step_avg:0.01ms +step:1/9000 train_loss:6.9341 train_time:173ms step_avg:172.60ms +step:2/9000 train_loss:8.6442 train_time:218ms step_avg:109.24ms +step:3/9000 train_loss:7.6536 train_time:313ms step_avg:104.44ms +step:4/9000 train_loss:7.2292 train_time:409ms step_avg:102.28ms +step:5/9000 train_loss:7.1066 train_time:503ms step_avg:100.65ms +step:6/9000 train_loss:7.0691 train_time:599ms step_avg:99.77ms +step:7/9000 train_loss:6.9931 train_time:694ms step_avg:99.16ms +step:8/9000 train_loss:6.8960 train_time:790ms step_avg:98.74ms +step:9/9000 train_loss:6.5644 train_time:886ms step_avg:98.39ms +step:10/9000 train_loss:6.2023 train_time:981ms step_avg:98.07ms +step:500/9000 train_loss:2.3940 train_time:48782ms step_avg:97.56ms +step:1000/9000 train_loss:2.2600 train_time:97666ms step_avg:97.67ms +step:1500/9000 
train_loss:2.2016 train_time:146580ms step_avg:97.72ms +step:2000/9000 train_loss:2.0487 train_time:195581ms step_avg:97.79ms +step:2500/9000 train_loss:2.1484 train_time:244522ms step_avg:97.81ms +step:3000/9000 train_loss:2.1327 train_time:293569ms step_avg:97.86ms +step:3500/9000 train_loss:2.1421 train_time:342629ms step_avg:97.89ms +step:4000/9000 train_loss:1.9307 train_time:391586ms step_avg:97.90ms +step:4000/9000 val_loss:2.0208 val_bpb:1.1968 train_time:391638ms step_avg:97.91ms +step:4500/9000 train_loss:2.0783 train_time:440574ms step_avg:97.91ms +step:5000/9000 train_loss:2.0532 train_time:489649ms step_avg:97.93ms swa:start step:5450 -step:5500/8000 train_loss:1.9626 train_time:539319ms step_avg:98.06ms -late_qat:enabled step:5593 scale:0.1500 -step:6000/8000 train_loss:1.8847 train_time:588951ms step_avg:98.16ms -step:6111/8000 val_loss:1.9183 val_bpb:1.1361 train_time:600058ms step_avg:98.19ms -stopping_early: wallclock_cap train_time:600058ms step:6111/8000 -peak memory allocated: 25295 MiB reserved: 26058 MiB +step:5500/9000 train_loss:1.9651 train_time:538840ms step_avg:97.97ms +late_qat:enabled step:5598 scale:0.1498 +step:6000/9000 train_loss:1.8884 train_time:588478ms step_avg:98.08ms +step:6116/9000 val_loss:1.9192 val_bpb:1.1367 train_time:600071ms step_avg:98.11ms +stopping_early: wallclock_cap train_time:600071ms step:6116/9000 +peak memory allocated: 25295 MiB reserved: 26088 MiB ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9167 val_bpb:1.1352 eval_time:2330ms +DIAGNOSTIC post_ema val_loss:1.9177 val_bpb:1.1358 eval_time:2335ms Serialized model: 106179454 bytes Code size: 95036 bytes -Serialized model int6+lzma: 15825172 bytes -Total submission size int6+lzma: 15920208 bytes -final_int6_roundtrip val_loss:1.9308 val_bpb:1.1435 eval_time:6942ms -final_int6_roundtrip_exact val_loss:1.93077809 val_bpb:1.14351554 -final_int6_sliding_window val_loss:1.8911 val_bpb:1.1200 stride:64 eval_time:86022ms -final_int6_sliding_window_exact 
val_loss:1.89110334 val_bpb:1.12002088 -final_int8_zlib_roundtrip_exact val_loss:1.89110334 val_bpb:1.12002088 +Serialized model int6+lzma: 15834540 bytes +Total submission size int6+lzma: 15929576 bytes +final_int6_roundtrip val_loss:1.9316 val_bpb:1.1440 eval_time:24477ms +final_int6_roundtrip_exact val_loss:1.93157593 val_bpb:1.14398806 +final_int6_sliding_window val_loss:1.8920 val_bpb:1.1205 stride:64 eval_time:116093ms +final_int6_sliding_window_exact val_loss:1.89197002 val_bpb:1.12053418 +final_int8_zlib_roundtrip_exact val_loss:1.89197002 val_bpb:1.12053418 ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=2 ttt_sliding:params unfrozen=26994268 frozen=4112 - ttt_chunk [1/1893] bpb=1.152698 time=0.5s - ttt_chunk [11/1893] bpb=1.142797 time=3.0s - ttt_chunk [21/1893] bpb=1.129555 time=5.4s - ttt_chunk [31/1893] bpb=1.127492 time=7.9s - ttt_chunk [41/1893] bpb=1.114271 time=10.4s - ttt_chunk [51/1893] bpb=1.108772 time=12.9s - ttt_chunk [61/1893] bpb=1.115238 time=15.3s - ttt_chunk [71/1893] bpb=1.113682 time=17.8s - ttt_chunk [81/1893] bpb=1.112830 time=20.3s - ttt_chunk [91/1893] bpb=1.113697 time=22.8s - ttt_chunk [101/1893] bpb=1.117243 time=25.2s - ttt_chunk [111/1893] bpb=1.119761 time=27.7s - ttt_chunk [121/1893] bpb=1.112954 time=30.2s - ttt_chunk [131/1893] bpb=1.113122 time=32.7s - ttt_chunk [141/1893] bpb=1.118719 time=35.1s - ttt_chunk [151/1893] bpb=1.120823 time=37.6s - ttt_chunk [161/1893] bpb=1.120211 time=40.1s - ttt_chunk [171/1893] bpb=1.124588 time=42.6s - ttt_chunk [181/1893] bpb=1.126885 time=45.0s - ttt_chunk [191/1893] bpb=1.134034 time=47.5s - ttt_chunk [201/1893] bpb=1.132635 time=50.0s - ttt_chunk [211/1893] bpb=1.130300 time=52.5s - ttt_chunk [221/1893] bpb=1.131792 time=55.0s - ttt_chunk [231/1893] bpb=1.130602 time=57.5s - ttt_chunk [241/1893] bpb=1.131019 time=60.0s - ttt_chunk [251/1893] bpb=1.130536 time=62.4s - ttt_chunk [261/1893] bpb=1.127636 time=64.9s - 
[ttt_chunk trace abridged: roughly every 500th chunk shown; summary lines kept in full]
- ttt_chunk [271/1893] bpb=1.126431 time=67.4s
- ttt_chunk [501/1893] bpb=1.129226 time=124.5s
- ttt_chunk [1001/1893] bpb=1.126024 time=248.3s
- ttt_chunk [1501/1893] bpb=1.122499 time=371.9s
- ttt_chunk [1893/1893] bpb=1.119692 time=469.3s
-ttt_sliding:done val_loss=1.887066 val_bpb=1.117629 elapsed=469.3s
-legal_ttt val_loss:1.8871 val_bpb:1.1176 eval_time:469850ms
-legal_ttt_exact val_loss:1.88706558 val_bpb:1.11762949
+ ttt_chunk [1/1893] bpb=1.157235 time=0.9s
+ ttt_chunk [501/1893] bpb=1.129645 time=119.3s
+ ttt_chunk [1001/1893] bpb=1.126733 time=233.0s
+ ttt_chunk [1501/1893] bpb=1.123187 time=349.6s
+ ttt_chunk [1893/1893] bpb=1.120441 time=439.6s
+ttt_sliding:done val_loss=1.888118 val_bpb=1.118253 elapsed=439.6s
+legal_ttt val_loss:1.8881 val_bpb:1.1183 eval_time:440025ms
+legal_ttt_exact val_loss:1.88811812 val_bpb:1.11825287

From 3162c39ca919a9f8fc9057efc3380ed03aecd0d4 Mon Sep 17 00:00:00 2001
From: Marko Sisovic
Date: Wed, 25 Mar 2026 16:28:10 +0100
Subject: [PATCH 8/8] Polishing readme

---
 .../2026-03-25_RecurLayers/README.md | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/README.md b/records/track_10min_16mb/2026-03-25_RecurLayers/README.md
index b6a1951bb..da3e00f33 100644
--- a/records/track_10min_16mb/2026-03-25_RecurLayers/README.md
+++ b/records/track_10min_16mb/2026-03-25_RecurLayers/README.md
@@ -47,19 +47,19 @@ Everything else (TTT, int6 quantization, SWA, bigram embeddings, value embedding
 - **Mean val_bpb: 1.11840** (std: 0.00049)
 - Training: ~6,100 steps in ~600s
 - Model params: ~27M
-- Artifact size: ~15.9MB (int6+lzma)
+- Mean total submission size: 15,931,152 bytes (~15.9MB, int6+lzma)
 
 ## Reproducibility
 
 Three independent training runs with different random seeds:
 
-| Seed | val_loss | val_bpb |
-|------|----------|---------|
-| 1337 | 1.88749538 | 1.11788404 |
-| 2025 | 1.88948575 | 1.11906285 |
-| 2024 | 1.88811812 | 1.11825287 |
-| **Mean** | **1.88836642** | **1.11839992** |
-| **Std** | **0.00083132** | **0.00049235** |
+| Seed | val_loss | val_bpb | total_bytes |
+|------|----------|---------|-------------|
+| 1337 | 1.88749538 | 1.11788404 | 15,928,948 |
+| 2025 | 1.88948575 | 1.11906285 | 15,934,932 |
+| 2024 | 1.88811812 | 1.11825287 | 15,929,576 |
+| **Mean** | **1.88836642** | **1.11839992** | **15,931,152** |
+| **Std** | **0.00083132** | **0.00049235** | |
 
 ## Run Commands
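The Mean/Std rows in the reproducibility table can be sanity-checked directly from the three per-seed values; the reported stds match NumPy's default population formula (`ddof=0`), not the sample formula. A minimal check (it also works out the std the table's blank total_bytes cell would hold, about 2,685 bytes):

```python
import numpy as np

# Per-seed results copied from the reproducibility table.
# Columns: val_loss, val_bpb, total_bytes.
runs = {
    1337: (1.88749538, 1.11788404, 15_928_948),
    2025: (1.88948575, 1.11906285, 15_934_932),
    2024: (1.88811812, 1.11825287, 15_929_576),
}

# Transpose the per-seed tuples into one array per column.
loss, bpb, size = (np.array(col) for col in zip(*runs.values()))

# np.std defaults to the population formula (ddof=0), which is the
# convention the table's Std row follows (0.00083132 for val_loss).
for name, col in [("val_loss", loss), ("val_bpb", bpb), ("total_bytes", size)]:
    print(f"{name:>11}  mean={col.mean():.8g}  std={col.std():.6g}")
```

The seed-to-seed spread in total_bytes (~6 KB out of ~15.9 MB) comes from lzma compressing slightly differently across weight sets, so the ~16 MB budget should be judged against the worst seed, not the mean.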