diff --git a/records/track_10min_16mb/2026-03-19_QAT_Int6_MLP3_SlidingWindow/README.md b/records/track_10min_16mb/2026-03-19_QAT_Int6_MLP3_SlidingWindow/README.md new file mode 100644 index 000000000..d38424994 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_QAT_Int6_MLP3_SlidingWindow/README.md @@ -0,0 +1,56 @@ +# QAT Int6 + MLP3 + Sliding Window + Overtone Init + +**Track:** 10min_16mb +**Status:** code snapshot only; no verified H100 training log checked in yet + +This folder captures the first submission-style snapshot of the int6/QAT branch. It is organized like a record folder so the implementation can be reviewed and reproduced, but it is not yet a complete leaderboard submission because the corresponding `train.log` and final track-compliant metrics are not included. + +## Core changes in this snapshot + +### Int6 quantization-aware training +The model trains with fake per-row quantization in the forward pass so weights learn to tolerate export at int6 precision. This is controlled by `USE_QAT=1 QAT_BITS=6`. + +### Int6 packed export +Large matrices are quantized to int6 and packed `4 values -> 3 bytes`, reducing payload size versus int8 and freeing budget for more parameters under the 16MB cap. This is controlled by `INT6_EXPORT=1`. + +### Explicit MLP sizing +`MLP_HIDDEN=992` replaces the coarse `MLP_MULT` knob with an exact hidden width so the parameter budget can be targeted much more tightly. + +### Sliding-window validation +`SW_STRIDE=256` and `EVAL_SEQ_LEN=1408` score tokens with more left context than the training sequence length, using NTK-RoPE extrapolation to extend context at eval time. + +### Initialization and schedule changes +The snapshot also includes overtone embedding initialization, a depth-dependent `resid_mix` initialization, and full-budget warmdown to reduce quantization damage near the end of training. 
+ +## Target configuration + +```bash +NUM_LAYERS=11 +MODEL_DIM=448 +NUM_HEADS=8 +NUM_KV_HEADS=4 +MLP_HIDDEN=992 +USE_QAT=1 +QAT_BITS=6 +INT6_EXPORT=1 +SW_STRIDE=256 +EVAL_SEQ_LEN=1408 +WARMDOWN_ITERS=20000 +MATRIX_LR=0.06 +TIED_EMBED_LR=0.07 +SCALAR_LR=0.06 +ADAM_WEIGHT_DECAY=0.01 +MUON_BACKEND_STEPS=5 +``` + +## Included files + +- `train_gpt.py`: self-contained code snapshot for this branch state +- `README.md`: implementation notes and intended run configuration +- `submission.json`: metadata placeholder for later completion with verified run outputs + +## Missing for a merge-ready record PR + +- `train.log` from the actual run +- final `val_loss`, `val_bpb`, artifact bytes, and code bytes +- confirmation that the run meets the `10min_16mb` track constraints diff --git a/records/track_10min_16mb/2026-03-19_QAT_Int6_MLP3_SlidingWindow/submission.json b/records/track_10min_16mb/2026-03-19_QAT_Int6_MLP3_SlidingWindow/submission.json new file mode 100644 index 000000000..377cfdc18 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_QAT_Int6_MLP3_SlidingWindow/submission.json @@ -0,0 +1,12 @@ +{ + "author": "Sheldon", + "github_id": "Aristide021", + "name": "QAT Int6 + MLP3 + Sliding Window + Overtone Init", + "blurb": "Submission-style code snapshot for the first int6 QAT branch: per-row int6 QAT, packed int6 export, explicit MLP_HIDDEN sizing, sliding-window eval, overtone embedding init, phase-transition resid_mix init, and full-budget warmdown. 
Metrics are intentionally left null because no verified H100 train log is checked in with this folder yet.", + "date": "2026-03-19T00:00:00Z", + "track": "10min_16mb", + "val_loss": null, + "val_bpb": null, + "bytes_total": null, + "bytes_code": null +} diff --git a/records/track_10min_16mb/2026-03-19_QAT_Int6_MLP3_SlidingWindow/train_gpt.py b/records/track_10min_16mb/2026-03-19_QAT_Int6_MLP3_SlidingWindow/train_gpt.py new file mode 100644 index 000000000..4f3aa791d --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_QAT_Int6_MLP3_SlidingWindow/train_gpt.py @@ -0,0 +1,1440 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. 
+ data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_hidden = int(os.environ.get("MLP_HIDDEN", "0")) # 0 = mlp_mult*model_dim; else explicit + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + adam_weight_decay = float(os.environ.get("ADAM_WEIGHT_DECAY", "0.0")) # e.g. 0.01 + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # v2: QAT + sliding-window eval + use_qat: bool = bool(int(os.environ.get("USE_QAT", "0"))) + qat_bits: int = int(os.environ.get("QAT_BITS", "8")) # 8=int8 (±127), 6=int6 (±31) + sw_stride: int = int(os.environ.get("SW_STRIDE", "0")) # 0=disabled; 64 or 256 typical + eval_seq_len: int = int(os.environ.get("EVAL_SEQ_LEN", "0")) # 0=train_seq_len; e.g. 1408 for NTK extrapolation + int6_export: bool = bool(int(os.environ.get("INT6_EXPORT", "0"))) # pack 4 int6 vals→3 bytes + +# ----------------------------- +# QAT (QUANTIZATION-AWARE TRAINING) +# ----------------------------- +# Fake per-row quantization with Straight-Through Estimator (STE). +# Trained weights stay in fp32; forward pass sees dequantized values so the +# model learns to be robust to the exact quantization used at export. +# _QAT_MAX_VAL is set by init_qat_config() once per run: +# int8 → 127, int6 → 31. 
+ +_USE_QAT: bool = False +_QAT_MAX_VAL: float = 127.0 + + +def init_qat_config(use_qat: bool, qat_bits: int) -> None: + global _USE_QAT, _QAT_MAX_VAL + _USE_QAT = use_qat + _QAT_MAX_VAL = float((2 ** (qat_bits - 1)) - 1) + + +def fake_quantize_per_row(w: Tensor) -> Tensor: + """STE fake quantization. Only applied to 2-D tensors large enough to be + int-quantized at export (numel > INT8_KEEP_FLOAT_MAX_NUMEL = 65 536). + Forward: returns dequantized w. Backward: identity (straight-through). + """ + if w.ndim != 2 or w.numel() <= 65_536: + return w + max_val = _QAT_MAX_VAL + w_f32 = w.float() + scale = (w_f32.abs().amax(dim=1, keepdim=True) / max_val).clamp(min=1.0 / max_val) + w_q = (w_f32 / scale).clamp(-max_val, max_val).round() * scale + # STE: forward = w_q, backward = identity through w + return w + (w_q.to(w.dtype) - w).detach() + + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + device_type = device.type + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, 
args.train_seq_len) + with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=(device_type == "cuda")): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + # On non-CUDA devices, inference_mode creates cached rotary tensors that + # cannot be saved for backward. Reset them so the training step recomputes. + for m in model.modules(): + if isinstance(m, Rotary): + m._cos_cached = None + m._sin_cached = None + m._seq_len_cached = 0 + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding_window( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + sw_stride: int, + eval_seq_len: int = 0, +) -> tuple[float, float]: + """Sliding-window eval: each token sees up to seq_len tokens of left context. + Windows advance by sw_stride each step; only the last sw_stride positions of + each non-first window are scored (the first window scores all positions). + This gives every token more context than the non-overlapping chunk eval. 
+ Uses base_model (uncompiled) directly since it runs once at the end. + eval_seq_len overrides args.train_seq_len when > 0 (e.g. NTK-RoPE extrapolation). + """ + seq_len = eval_seq_len if eval_seq_len > 0 else args.train_seq_len + total_tokens = val_tokens.numel() - 1 # number of (input, target) pairs + window_starts = list(range(0, total_tokens, sw_stride)) + rank_windows = window_starts[rank::world_size] + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + device_type = device.type + with torch.inference_mode(): + for ws in rank_windows: + win_len = min(seq_len, total_tokens - ws) + if win_len <= 0: + continue + input_tok = val_tokens[ws : ws + win_len].to(device=device, dtype=torch.int64).unsqueeze(0) + target_tok = val_tokens[ws + 1 : ws + win_len + 1].to(device=device, dtype=torch.int64).unsqueeze(0) + eval_from = 0 if ws == 0 else max(win_len - sw_stride, 0) + with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=(device_type == "cuda")): + logits = base_model.forward_logits(input_tok) # [1, T, vocab] + logits_eval = logits[0, eval_from:, :].float() + targets_eval = target_tok[0, eval_from:] + n_scored = targets_eval.numel() + if n_scored == 0: + continue + val_loss_sum += F.cross_entropy(logits_eval, targets_eval.long(), reduction="sum").to(torch.float64) + val_token_count += n_scored + prev_ids = input_tok[0, eval_from:] + token_bytes = base_bytes_lut[targets_eval].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[targets_eval] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, 
op=dist.ReduceOp.SUM) + + val_loss = float((val_loss_sum / val_token_count).item()) + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return val_loss, float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +INT6_MAX_VAL = 31 # signed 6-bit range: [-31, 31] + +def pack_int6(arr: np.ndarray, n: int) -> np.ndarray: + """Pack n int6 values in [-31, 31] into a byte array: 4 values → 3 bytes (25% smaller).""" + u = (arr[:n].astype(np.int32) + INT6_MAX_VAL).astype(np.uint8) # offset to [0, 62] + pad = (-n) % 4 + if pad: + u = np.append(u, np.zeros(pad, dtype=np.uint8)) + u = u.reshape(-1, 4) + out = np.empty((len(u), 3), dtype=np.uint8) + out[:, 0] = (u[:, 0] & 0x3F) | ((u[:, 1] & 0x03) << 6) + out[:, 1] = ((u[:, 1] >> 2) & 0x0F) | ((u[:, 2] & 0x0F) << 4) + out[:, 2] = ((u[:, 2] >> 4) & 0x03) | ((u[:, 3] & 0x3F) << 2) + return 
out.reshape(-1) + +def unpack_int6(packed: np.ndarray, n: int) -> np.ndarray: + """Unpack int6 byte array back to int8 values in [-31, 31], length n.""" + p = packed.reshape(-1, 3) + u = np.empty((len(p), 4), dtype=np.uint8) + u[:, 0] = p[:, 0] & 0x3F + u[:, 1] = ((p[:, 0] >> 6) & 0x03) | ((p[:, 1] & 0x0F) << 2) + u[:, 2] = ((p[:, 1] >> 4) & 0x0F) | ((p[:, 2] & 0x03) << 4) + u[:, 3] = (p[:, 2] >> 2) & 0x3F + return (u.reshape(-1)[:n].astype(np.int32) - INT6_MAX_VAL).astype(np.int8) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict_int6(state_dict: dict[str, Tensor]): + """Int6 export: large 2D tensors packed 4 vals→3 bytes (25% smaller than int8). 
+ Non-2D large floats fall back to int8 per-tensor.""" + quantized: dict[str, bytes] = {} # int6-packed byte strings for 2D matrices + quantized_int8: dict[str, Tensor] = {} # int8 fallback for non-2D large tensors + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", + "baseline_tensor_bytes", "int6_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int6_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int6_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + t32 = t.float() + if t32.ndim == 2: + # Per-row int6: quantize to ±31, pack 4 values into 3 bytes. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + scale = (clip_abs / float(INT6_MAX_VAL)).clamp_min(1.0 / float(INT6_MAX_VAL)) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + q_np = torch.clamp( + torch.round(clipped / scale[:, None]), -INT6_MAX_VAL, INT6_MAX_VAL + ).to(torch.int8).numpy() + n = q_np.size + packed = pack_int6(q_np.reshape(-1), n) + quantized[name] = packed.tobytes() + scales[name] = scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + dtypes[name] = str(t.dtype).removeprefix("torch.") + qmeta[name] = {"scheme": "int6_per_row_packed", "rows": t32.shape[0], "orig_numel": n} + stats["int6_payload_bytes"] += len(packed) + tensor_nbytes(scales[name]) + else: + # Non-2D tensors: int8 per-tensor fallback. + clip_f = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + s_t = torch.tensor(clip_f / 127.0 if clip_f > 0 else 1.0, dtype=torch.float32) + q_t = torch.clamp( + torch.round(torch.clamp(t32, -clip_f, clip_f) / s_t), -127, 127 + ).to(torch.int8).contiguous() + quantized_int8[name] = q_t + scales[name] = s_t + dtypes[name] = str(t.dtype).removeprefix("torch.") + qmeta[name] = {"scheme": "int8_per_tensor_fallback"} + stats["int6_payload_bytes"] += tensor_nbytes(q_t) + tensor_nbytes(s_t) + obj: dict[str, object] = { + "__quant_format__": "int6_packed_per_row_v1", + "quantized": quantized, + "quantized_int8": quantized_int8, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def dequantize_state_dict_int6(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed_bytes in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name].to(dtype=torch.float32) + meta = qmeta.get(name, {}) + rows, n = int(meta["rows"]), int(meta["orig_numel"]) + packed = np.frombuffer(packed_bytes, dtype=np.uint8) + q_flat = unpack_int6(packed, n) # int8 vals in [-31, 31] + q = torch.from_numpy(q_flat.reshape(rows, -1)) + out[name] = (q.float() * s.view(rows, 1)).to(dtype=dtype).contiguous() + for name, q in obj.get("quantized_int8", {}).items(): + dtype = getattr(torch, obj["dtypes"][name]) + out[name] = (q.float() * float(obj["scales"][name].item())).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        # pattern: shard glob handed to TokenStream; every rank advances the same
+        # logical stream and slices out its own span in next_batch.
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        """Return (inputs, targets) for this rank's micro-step.
+
+        Consumes (local_tokens + 1) * world_size tokens from the shared stream,
+        slices this rank's contiguous span, and builds targets by shifting one
+        token. The reshape requires local_tokens to be a multiple of seq_len.
+        """
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+
+# -----------------------------
+# TRANSFORMER MODULES
+# -----------------------------
+
+class RMSNorm(nn.Module):
+    # Parameter-free RMS normalization; eps=None defers to F.rms_norm's default.
+    def __init__(self, eps: float | None = None):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+
+class CastedLinear(nn.Linear):
+    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
+    # When QAT is active and the module is in training mode, weights are fake-quantized first.
+    def forward(self, x: Tensor) -> Tensor:
+        # Eval and export paths see the raw fp32 weight; only training applies fake quant.
+        w = fake_quantize_per_row(self.weight) if (_USE_QAT and self.training) else self.weight
+        bias = self.bias.to(x.dtype) if self.bias is not None else None
+        return F.linear(x, w.to(x.dtype), bias)
+
+
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    # Keep small/control parameters in fp32 even when the model body runs in bf16.
+    # "Small/control" means ndim < 2 or a name matching CONTROL_TENSOR_NAME_PATTERNS.
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+
+
+class Rotary(nn.Module):
+    # Caches cos/sin tables per sequence length on the current device.
+    def __init__(self, dim: int, base: float = 10000.0):
+        super().__init__()
+        # Standard RoPE frequency ladder: inv_freq[i] = base^(-2i/dim).
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached: Tensor | None = None
+        self._sin_cached: Tensor | None = None
+
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
+        """Return (cos, sin) tables shaped [1, 1, seq_len, dim//2], cast to `dtype`.
+
+        The fp32 cache is rebuilt only when seq_len or device changes; the
+        dtype cast happens on every call so the cache itself stays in fp32.
+        """
+        if (
+            self._cos_cached is None
+            or self._sin_cached is None
+            or self._seq_len_cached != seq_len
+            or self._cos_cached.device != device
+        ):
+            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(t, self.inv_freq.to(device))
+            self._cos_cached = freqs.cos()[None, None, :, :]
+            self._sin_cached = freqs.sin()[None, None, :, :]
+            self._seq_len_cached = seq_len
+        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    # Half-split rotation: pair (x1[..., i], x2[..., i]) is rotated by angle i's (cos, sin).
+    half = x.size(-1) // 2
+    x1, x2 = x[..., :half], x[..., half:]
+    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+
+
+class CausalSelfAttention(nn.Module):
+    # Causal attention with grouped-query K/V (num_kv_heads <= num_heads), per-head
+    # RMS-normalized Q/K, a learned per-head query gain, and RoPE positions.
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_base: float,
+        qk_gain_init: float,
+    ):
+        super().__init__()
+        if dim % num_heads != 0:
+            raise ValueError("model_dim must be divisible by num_heads")
+        if num_heads % num_kv_heads != 0:
+            raise ValueError("num_heads must be divisible by num_kv_heads")
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = dim // num_heads
+        if self.head_dim % 2 != 0:
+            raise ValueError("head_dim must be even for RoPE")
+        # K/V projections are narrower than Q under GQA.
+        kv_dim = self.num_kv_heads * self.head_dim
+        self.c_q = CastedLinear(dim, dim, bias=False)
+        self.c_k = CastedLinear(dim, kv_dim, bias=False)
+        self.c_v = CastedLinear(dim, kv_dim, bias=False)
+        self.proj = CastedLinear(dim, dim, bias=False)
+        # Flag consumed by GPT._init_weights: output projection starts at zero so each
+        # block initially contributes nothing to the residual stream.
+        self.proj._zero_init = True
+        self.q_gain =
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
+        self.rotary = Rotary(self.head_dim, base=rope_base)
+
+    def forward(self, x: Tensor) -> Tensor:
+        bsz, seqlen, dim = x.shape
+        # Project then split into [bsz, heads, seq, head_dim] layout for SDPA.
+        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        # QK-norm: RMS-normalize each head's query/key vectors before RoPE.
+        q = F.rms_norm(q, (q.size(-1),))
+        k = F.rms_norm(k, (k.size(-1),))
+        cos, sin = self.rotary(seqlen, x.device, q.dtype)
+        q = apply_rotary_emb(q, cos, sin)
+        k = apply_rotary_emb(k, cos, sin)
+        # Learned per-head gain on queries, broadcast over batch/seq/head_dim.
+        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
+        y = F.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=None,
+            is_causal=True,
+            # enable_gqa lets SDPA consume fewer K/V heads than Q heads directly.
+            enable_gqa=(self.num_kv_heads != self.num_heads),
+        )
+        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
+        return self.proj(y)
+
+
+class MLP(nn.Module):
+    # relu^2 MLP from the original modded-nanogpt setup
+    def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0):
+        super().__init__()
+        # mlp_hidden > 0 selects an exact hidden width; otherwise fall back to mlp_mult * dim.
+        hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim
+        self.fc = CastedLinear(dim, hidden, bias=False)
+        self.proj = CastedLinear(hidden, dim, bias=False)
+        self.proj._zero_init = True
+
+    def forward(self, x: Tensor) -> Tensor:
+        # relu(x)^2 activation: the square is applied to the already-rectified values.
+        x = torch.relu(self.fc(x))
+        return self.proj(x.square())
+
+
+class Block(nn.Module):
+    # Pre-norm transformer block with per-channel residual scales (attn_scale,
+    # mlp_scale) and a learned two-way mix (resid_mix) between the running
+    # stream and the initial embedding x0.
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        rope_base: float,
+        qk_gain_init: float,
+        mlp_hidden: int = 0,
+    ):
+        super().__init__()
+        self.attn_norm = RMSNorm()
+        self.mlp_norm = RMSNorm()
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
+        self.mlp = MLP(dim, mlp_mult, mlp_hidden)
+        # Per-channel gates on the attention and MLP residual branches.
+        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.resid_mix =
nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
+
+    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
+        # Blend the running stream x with the initial embedding x0 per channel:
+        # resid_mix[0] weights x, resid_mix[1] weights x0 (re-initialized per depth
+        # in GPT._init_weights).
+        mix = self.resid_mix.to(dtype=x.dtype)
+        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
+        attn_out = self.attn(self.attn_norm(x))
+        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
+        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
+        return x
+
+
+class GPT(nn.Module):
+    # Decoder-only transformer with a U-Net-style split: the first half of the
+    # blocks store their outputs, the second half adds them back via learned
+    # per-channel skip_weights (see forward).
+    def __init__(
+        self,
+        vocab_size: int,
+        num_layers: int,
+        model_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        mlp_hidden: int,
+        tie_embeddings: bool,
+        tied_embed_init_std: float,
+        logit_softcap: float,
+        rope_base: float,
+        qk_gain_init: float,
+    ):
+        super().__init__()
+        if logit_softcap <= 0.0:
+            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
+        self.tie_embeddings = tie_embeddings
+        self.tied_embed_init_std = tied_embed_init_std
+        self.logit_softcap = logit_softcap
+        self.tok_emb = nn.Embedding(vocab_size, model_dim)
+        # Encoder half stores skips; decoder half (the remainder, >= encoder) consumes them.
+        self.num_encoder_layers = num_layers // 2
+        self.num_decoder_layers = num_layers - self.num_encoder_layers
+        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
+        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    model_dim,
+                    num_heads,
+                    num_kv_heads,
+                    mlp_mult,
+                    rope_base,
+                    qk_gain_init,
+                    mlp_hidden=mlp_hidden,
+                )
+                for i in range(num_layers)
+            ]
+        )
+        self.final_norm = RMSNorm()
+        # With tied embeddings the output projection reuses tok_emb.weight; otherwise
+        # a dedicated head is created and zero-initialized in _init_weights.
+        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
+        if self.lm_head is not None:
+            self.lm_head._zero_init = True
+        self._init_weights()
+
+    def _init_weights(self) -> None:
+        if self.tie_embeddings:
+            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
+            # Overtone init: shape embedding singular value spectrum to power-law decay S_k ~ k^{-0.5}
+            # so the embedding
matrix has a natural frequency distribution (like guitar harmonics).
+            with torch.no_grad():
+                # Rebuild the embedding from its SVD with singular values rescaled to
+                # S[0] * k^{-0.5}; singular vectors are kept unchanged.
+                U, S, V = torch.linalg.svd(self.tok_emb.weight.data, full_matrices=False)
+                target_S = S[0] * (1.0 / torch.arange(1, S.shape[0] + 1, dtype=S.dtype)) ** 0.5
+                self.tok_emb.weight.data = (U * target_S[None, :]) @ V
+        # Zero-init all modules flagged _zero_init (attn/MLP output projections, untied head).
+        for module in self.modules():
+            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
+                nn.init.zeros_(module.weight)
+        # Phase-transition resid_mix init: early layers blend toward x0, late layers stay in residual stream.
+        num_layers = len(self.blocks)
+        for i, block in enumerate(self.blocks):
+            if hasattr(block, "resid_mix"):
+                with torch.no_grad():
+                    # phase ramps from sigmoid(-1.5) at layer 0 to sigmoid(+1.5) at the last layer.
+                    phase = float(torch.sigmoid(torch.tensor(3.0 * (i / max(num_layers - 1, 1) - 0.5))))
+                    block.resid_mix.data[0].fill_(phase)
+                    block.resid_mix.data[1].fill_(1.0 - phase)
+
+    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
+        """Return mean cross-entropy of softcapped logits against target_ids."""
+        x = self.tok_emb(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x0 = x
+        skips: list[Tensor] = []
+
+        # First half stores skips; second half reuses them in reverse order.
+        for i in range(self.num_encoder_layers):
+            x = self.blocks[i](x, x0)
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            # skips empties after num_skip_weights pops, so skip_weights[i] is only
+            # indexed while i < num_skip_weights.
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            x = self.blocks[self.num_encoder_layers + i](x, x0)
+
+        x = self.final_norm(x).reshape(-1, x.size(-1))
+        targets = target_ids.reshape(-1)
+        if self.tie_embeddings:
+            emb_w = self.tok_emb.weight
+            # Mirror CastedLinear: the tied output projection also sees fake-quantized
+            # weights during QAT training.
+            if _USE_QAT and self.training:
+                emb_w = fake_quantize_per_row(emb_w)
+            logits_proj = F.linear(x, emb_w)
+        else:
+            if self.lm_head is None:
+                raise RuntimeError("lm_head is required when tie_embeddings=False")
+            logits_proj = self.lm_head(x)
+        # Softcap bounds logits to (-logit_softcap, +logit_softcap) via tanh.
+        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+        return F.cross_entropy(logits.float(), targets, reduction="mean")
+
+    def forward_logits(self, input_ids: Tensor) -> Tensor:
+        """Forward pass returning raw logits [bsz, seq, vocab]. Used for sliding-window eval."""
+        # Same trunk as forward(), but no loss and no QAT fake-quantization of the
+        # tied embedding (eval path uses the raw weights).
+        x = self.tok_emb(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x0 = x
+        skips: list[Tensor] = []
+        for i in range(self.num_encoder_layers):
+            x = self.blocks[i](x, x0)
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            x = self.blocks[self.num_encoder_layers + i](x, x0)
+        x = self.final_norm(x)  # [bsz, seq, dim]
+        if self.tie_embeddings:
+            logits_proj = F.linear(x, self.tok_emb.weight)
+        else:
+            logits_proj = self.lm_head(x)  # type: ignore[misc]
+        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+
+
+# -----------------------------
+# TRAINING
+# -----------------------------
+
+def main() -> None:
+    global zeropower_via_newtonschulz5
+
+    # Full source text is logged and counted toward the submission's code bytes.
+    code = Path(__file__).read_text(encoding="utf-8")
+    args = Hyperparameters()
+    init_qat_config(args.use_qat, args.qat_bits)  # sets _USE_QAT and _QAT_MAX_VAL
+    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) if torch.cuda.is_available() else
zeropower_via_newtonschulz5 + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + cuda_optional = os.environ.get("CUDA_OPTIONAL", "0") == "1" + if not torch.cuda.is_available(): + if not cuda_optional: + raise RuntimeError("CUDA is required (set CUDA_OPTIONAL=1 for local CPU/MPS testing)") + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + else: + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs (CUDA only) + if device.type == "cuda": + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch 
{torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout if device.type == "cuda" else "(no nvidia-smi: non-CUDA device)", + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + _effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_tokens = load_validation_tokens(args.val_files, max(args.train_seq_len, _effective_eval_seq_len)) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + mlp_hidden=args.mlp_hidden, + tie_embeddings=args.tie_embeddings, + 
tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if device.type == "cuda" else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + use_fused_adam = device.type == "cuda" + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_weight_decay, + fused=use_fused_adam, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + 
weight_decay=args.adam_weight_decay, + fused=use_fused_adam, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=use_fused_adam, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"v2_features use_qat:{args.use_qat} qat_bits:{args.qat_bits} qat_max_val:{_QAT_MAX_VAL} sw_stride:{args.sw_stride}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / 
max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=(device.type == "cuda")): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + if device.type == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + 
while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + if device.type == "cuda": + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + if device.type == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=(device.type == "cuda")): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * 
scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + if device.type == "cuda": + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Select int6 or int8 quantization based on INT6_EXPORT flag. 
+ if args.int6_export: + _quant_fn, _dequant_fn = quantize_state_dict_int6, dequantize_state_dict_int6 + _qext, _payload_key = "int6", "int6_payload_bytes" + else: + _quant_fn, _dequant_fn = quantize_state_dict_int8, dequantize_state_dict_int8 + _qext, _payload_key = "int8", "int8_payload_bytes" + quant_obj, quant_stats = _quant_fn(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + quant_fname = f"final_model.{_qext}.ptz" + if master_process: + with open(quant_fname, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_fname) + code_bytes = len(code.encode("utf-8")) + payload_b = quant_stats.get(_payload_key, 0) + ratio = quant_stats["baseline_tensor_bytes"] / max(payload_b, 1) + log0( + f"serialized_model_{_qext}_zlib {quant_file_bytes} bytes " + f"(payload:{payload_b} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {_qext}+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_fname, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(_dequant_fn(quant_state), strict=True) + if device.type == "cuda": + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.sw_stride > 0: + log0(f"final_eval_mode:sliding_window sw_stride:{args.sw_stride} eval_seq_len:{_effective_eval_seq_len}") + q_val_loss, q_val_bpb = eval_val_sliding_window( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + args.sw_stride, eval_seq_len=_effective_eval_seq_len, + ) + else: + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, 
is_boundary_token_lut, + ) + if device.type == "cuda": + torch.cuda.synchronize() + log0( + f"final_{_qext}_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{_qext}_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + # Alias kept for backward-compat with log parsers that expect the int8 key name. + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-21_SWA_BigramHash_SmearGate_Int5MLP_MuonWD/README.md b/records/track_10min_16mb/2026-03-21_SWA_BigramHash_SmearGate_Int5MLP_MuonWD/README.md new file mode 100644 index 000000000..182bfc249 --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_SWA_BigramHash_SmearGate_Int5MLP_MuonWD/README.md @@ -0,0 +1,73 @@ +# SWA + BigramHash + SmearGate + Int5MLP + MuonWD + zstd-22 + +**Track:** 10min_16mb +**Status:** code snapshot only; no verified H100 training log checked in yet + +This folder captures the follow-on submission-style snapshot that builds on the March 19 int6/QAT branch. As with the earlier folder, this is organized for code review and later reproduction, but it is not yet a complete record submission because the run log and final measured metrics are absent. + +## Core changes in this snapshot + +### SWA during warmdown +During the low-learning-rate phase, the script periodically collects checkpoints and averages them before export. The intent is to smooth the final basin and reduce post-quantization variance. + +### BigramHash embedding +A hashed bigram table is added on top of the unigram token embedding to inject cheap short-range contextual features before the transformer stack. 
+ +### SmearGate +A learned per-dimension gate blends each token embedding with the previous token embedding before attention, adding another low-cost local-context path. + +### Mixed int5/int6 quantization +MLP weights are quantized more aggressively than attention weights, recovering artifact budget that can be spent on model capacity elsewhere. + +### Muon weight decay and zstd compression +The snapshot adds direct weight decay inside Muon and uses `zstd` level 22 when available, with graceful fallback to `zlib`. + +## Inherited stack + +- int6 QAT +- packed int6 export path +- explicit `MLP_HIDDEN=992` +- sliding-window evaluation with longer eval context +- overtone embedding init +- depth-scheduled `resid_mix` +- aggressive full-budget warmdown + +## Target configuration + +```bash +NUM_LAYERS=11 +MODEL_DIM=448 +NUM_HEADS=8 +NUM_KV_HEADS=4 +MLP_HIDDEN=992 +USE_QAT=1 +QAT_BITS=6 +INT6_EXPORT=1 +SW_STRIDE=256 +EVAL_SEQ_LEN=1408 +WARMDOWN_ITERS=20000 +SWA_ENABLED=1 +SWA_START_FRAC=0.4 +SWA_EVERY=50 +BIGRAM_VOCAB_SIZE=10240 +BIGRAM_DIM=64 +USE_SMEAR_GATE=1 +MUON_WEIGHT_DECAY=0.04 +MATRIX_LR=0.06 +TIED_EMBED_LR=0.07 +SCALAR_LR=0.06 +ADAM_WEIGHT_DECAY=0.01 +MUON_BACKEND_STEPS=5 +``` + +## Included files + +- `train_gpt.py`: self-contained code snapshot for this branch state +- `README.md`: implementation notes and intended run configuration +- `submission.json`: metadata placeholder for later completion with verified run outputs + +## Missing for a merge-ready record PR + +- `train.log` from the actual run +- final `val_loss`, `val_bpb`, artifact bytes, and code bytes +- confirmation that the run meets the `10min_16mb` track constraints diff --git a/records/track_10min_16mb/2026-03-21_SWA_BigramHash_SmearGate_Int5MLP_MuonWD/submission.json b/records/track_10min_16mb/2026-03-21_SWA_BigramHash_SmearGate_Int5MLP_MuonWD/submission.json new file mode 100644 index 000000000..f2e1018ca --- /dev/null +++ 
b/records/track_10min_16mb/2026-03-21_SWA_BigramHash_SmearGate_Int5MLP_MuonWD/submission.json @@ -0,0 +1,12 @@ +{ + "author": "Sheldon", + "github_id": "Aristide021", + "name": "SWA + BigramHash(10240) + SmearGate + Int5MLP + MuonWD(0.04) + zstd-22", + "blurb": "Submission-style code snapshot for the March 21 branch that layers SWA, BigramHash embeddings, SmearGate, mixed int5/int6 quantization, Muon weight decay, and zstd compression on top of the earlier int6/QAT stack. Metrics are intentionally left null because no verified H100 train log is checked in with this folder yet.", + "date": "2026-03-21T00:00:00Z", + "track": "10min_16mb", + "val_loss": null, + "val_bpb": null, + "bytes_total": null, + "bytes_code": null +} diff --git a/records/track_10min_16mb/2026-03-21_SWA_BigramHash_SmearGate_Int5MLP_MuonWD/train_gpt.py b/records/track_10min_16mb/2026-03-21_SWA_BigramHash_SmearGate_Int5MLP_MuonWD/train_gpt.py new file mode 100644 index 000000000..ac0657b67 --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_SWA_BigramHash_SmearGate_Int5MLP_MuonWD/train_gpt.py @@ -0,0 +1,1500 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard as _zstd + _COMPRESSOR = "zstd" +except ImportError: + _zstd = None # type: ignore[assignment] + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_hidden = int(os.environ.get("MLP_HIDDEN", "0")) # 0 = mlp_mult*model_dim; else explicit + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + adam_weight_decay = float(os.environ.get("ADAM_WEIGHT_DECAY", "0.0")) # e.g. 
0.01 + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # v2: QAT + sliding-window eval + use_qat: bool = bool(int(os.environ.get("USE_QAT", "0"))) + qat_bits: int = int(os.environ.get("QAT_BITS", "8")) # 8=int8 (±127), 6=int6 (±31) + sw_stride: int = int(os.environ.get("SW_STRIDE", "0")) # 0=disabled; 64 or 256 typical + eval_seq_len: int = int(os.environ.get("EVAL_SEQ_LEN", "0")) # 0=train_seq_len; e.g. 1408 for NTK extrapolation + int6_export: bool = bool(int(os.environ.get("INT6_EXPORT", "0"))) # pack 4 int6 vals→3 bytes + + # v3: SWA + BigramHash + SmearGate + Muon WD + Int5MLP + zstd + swa_enabled: bool = bool(int(os.environ.get("SWA_ENABLED", "0"))) + swa_start_frac: float = float(os.environ.get("SWA_START_FRAC", "0.4")) # collect when lr_scale < this + swa_every: int = int(os.environ.get("SWA_EVERY", "50")) # checkpoint every N steps + bigram_vocab_size: int = int(os.environ.get("BIGRAM_VOCAB_SIZE", "0")) # 0=disabled; 10240 typical + bigram_dim: int = int(os.environ.get("BIGRAM_DIM", "64")) + use_smear_gate: bool = bool(int(os.environ.get("USE_SMEAR_GATE", "0"))) + muon_weight_decay: float = float(os.environ.get("MUON_WEIGHT_DECAY", "0.0")) # e.g. 0.04 + +# ----------------------------- +# QAT (QUANTIZATION-AWARE TRAINING) +# ----------------------------- +# Fake per-row quantization with Straight-Through Estimator (STE). +# Trained weights stay in fp32; forward pass sees dequantized values so the +# model learns to be robust to the exact quantization used at export. +# _QAT_MAX_VAL is set by init_qat_config() once per run: +# int8 → 127, int6 → 31. + +_USE_QAT: bool = False +_QAT_MAX_VAL: float = 127.0 + + +def init_qat_config(use_qat: bool, qat_bits: int) -> None: + global _USE_QAT, _QAT_MAX_VAL + _USE_QAT = use_qat + _QAT_MAX_VAL = float((2 ** (qat_bits - 1)) - 1) + + +def fake_quantize_per_row(w: Tensor) -> Tensor: + """STE fake quantization. 
Only applied to 2-D tensors large enough to be + int-quantized at export (numel > INT8_KEEP_FLOAT_MAX_NUMEL = 65 536). + Forward: returns dequantized w. Backward: identity (straight-through). + """ + if w.ndim != 2 or w.numel() <= 65_536: + return w + max_val = _QAT_MAX_VAL + w_f32 = w.float() + scale = (w_f32.abs().amax(dim=1, keepdim=True) / max_val).clamp(min=1.0 / max_val) + w_q = (w_f32 / scale).clamp(-max_val, max_val).round() * scale + # STE: forward = w_q, backward = identity through w + return w + (w_q.to(w.dtype) - w).detach() + + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = 
group["momentum"]
+            backend_steps = group["backend_steps"]
+            nesterov = group["nesterov"]
+
+            total_params = sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    # Scale correction from Muon reference implementations.
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            wd = group.get("weight_decay", 0.0)
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                if wd > 0:
+                    p.data.mul_(1.0 - lr * wd)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+            return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token decodes to under the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + device_type = device.type + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, 
args.train_seq_len) + with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=(device_type == "cuda")): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + # On non-CUDA devices, inference_mode creates cached rotary tensors that + # cannot be saved for backward. Reset them so the training step recomputes. + for m in model.modules(): + if isinstance(m, Rotary): + m._cos_cached = None + m._sin_cached = None + m._seq_len_cached = 0 + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding_window( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + sw_stride: int, + eval_seq_len: int = 0, +) -> tuple[float, float]: + """Sliding-window eval: each token sees up to seq_len tokens of left context. + Windows advance by sw_stride each step; only the last sw_stride positions of + each non-first window are scored (the first window scores all positions). + This gives every token more context than the non-overlapping chunk eval. 
+ Uses base_model (uncompiled) directly since it runs once at the end. + eval_seq_len overrides args.train_seq_len when > 0 (e.g. NTK-RoPE extrapolation). + """ + seq_len = eval_seq_len if eval_seq_len > 0 else args.train_seq_len + total_tokens = val_tokens.numel() - 1 # number of (input, target) pairs + window_starts = list(range(0, total_tokens, sw_stride)) + rank_windows = window_starts[rank::world_size] + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + device_type = device.type + with torch.inference_mode(): + for ws in rank_windows: + win_len = min(seq_len, total_tokens - ws) + if win_len <= 0: + continue + input_tok = val_tokens[ws : ws + win_len].to(device=device, dtype=torch.int64).unsqueeze(0) + target_tok = val_tokens[ws + 1 : ws + win_len + 1].to(device=device, dtype=torch.int64).unsqueeze(0) + eval_from = 0 if ws == 0 else max(win_len - sw_stride, 0) + with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=(device_type == "cuda")): + logits = base_model.forward_logits(input_tok) # [1, T, vocab] + logits_eval = logits[0, eval_from:, :].float() + targets_eval = target_tok[0, eval_from:] + n_scored = targets_eval.numel() + if n_scored == 0: + continue + val_loss_sum += F.cross_entropy(logits_eval, targets_eval.long(), reduction="sum").to(torch.float64) + val_token_count += n_scored + prev_ids = input_tok[0, eval_from:] + token_bytes = base_bytes_lut[targets_eval].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[targets_eval] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, 
op=dist.ReduceOp.SUM)
+
+    val_loss = float((val_loss_sum / val_token_count).item())
+    bits_per_token = val_loss / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    base_model.train()
+    return val_loss, float(bits_per_token * tokens_per_byte)
+
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small hit) by quantizing the weights to low-bit integers (per-row int8, or mixed int5/int6 when INT6_EXPORT is on) and compressing the payload with zstd when available, falling back to zlib.
+# We can then decompress the model and run in higher precision for evaluation, after coming in under the size limit.
+
+CONTROL_TENSOR_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "CONTROL_TENSOR_NAME_PATTERNS",
+        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale",
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
+        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
+INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
+INT8_PER_ROW_SCALE_DTYPE = torch.float16
+INT8_CLIP_PERCENTILE = 99.99984
+INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
+
+def _classify_param(name: str) -> str:
+    """Classify a parameter by its role for mixed-precision quantization."""
+    if "tok_emb" in name or "lm_head" in name:
+        return "embed"
+    if ".mlp." in name:
+        return "mlp"
+    if "bigram" in name:
+        return "bigram"
+    if ".attn."
in name: + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + """Per-row int-N quantization with given clip_range (31=int6, 15=int5).""" + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range + 1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range + 1), clip_range).to(torch.int8) + return q, scale + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + """Mixed Int5/Int6 quantization: MLP weights use int5 (±15), attn uses int6 (±31). + Small/control tensors pass through as fp16 or fp32. Produces a flat result dict + keyed as name+'.q' / name+'.scale' plus a meta dict.""" + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = 
(q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + # When QAT is active and the module is in training mode, weights are fake-quantized first. + def forward(self, x: Tensor) -> Tensor: + w = fake_quantize_per_row(self.weight) if (_USE_QAT and self.training) else self.weight + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): + super().__init__() + hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's, via learned gate.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, 
bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + mlp_hidden: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, mlp_hidden) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + 
model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + mlp_hidden: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 64, + use_smear_gate: bool = False, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) if use_smear_gate else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + mlp_hidden=mlp_hidden, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + # Overtone init: shape embedding singular value spectrum to power-law decay S_k ~ k^{-0.5} + # so the embedding matrix has a natural frequency distribution (like guitar harmonics). 
+ with torch.no_grad(): + U, S, V = torch.linalg.svd(self.tok_emb.weight.data, full_matrices=False) + target_S = S[0] * (1.0 / torch.arange(1, S.shape[0] + 1, dtype=S.dtype)) ** 0.5 + self.tok_emb.weight.data = (U * target_S[None, :]) @ V + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + # Phase-transition resid_mix init: early layers blend toward x0, late layers stay in residual stream. + num_layers = len(self.blocks) + for i, block in enumerate(self.blocks): + if hasattr(block, "resid_mix"): + with torch.no_grad(): + phase = float(torch.sigmoid(torch.tensor(3.0 * (i / max(num_layers - 1, 1) - 0.5)))) + block.resid_mix.data[0].fill_(phase) + block.resid_mix.data[1].fill_(1.0 - phase) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.smear is not None: + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + emb_w = self.tok_emb.weight + if _USE_QAT and self.training: + emb_w = fake_quantize_per_row(emb_w) + logits_proj = F.linear(x, emb_w) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Forward pass returning raw logits [bsz, seq, vocab]. Used for sliding-window eval.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.smear is not None: + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) # [bsz, seq, dim] + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) # type: ignore[misc] + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + init_qat_config(args.use_qat, args.qat_bits) # sets _USE_QAT and 
_QAT_MAX_VAL + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) if torch.cuda.is_available() else zeropower_via_newtonschulz5 + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + cuda_optional = os.environ.get("CUDA_OPTIONAL", "0") == "1" + if not torch.cuda.is_available(): + if not cuda_optional: + raise RuntimeError("CUDA is required (set CUDA_OPTIONAL=1 for local CPU/MPS testing)") + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + else: + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs (CUDA only) + if device.type == "cuda": + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, 
console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout if device.type == "cuda" else "(no nvidia-smi: non-CUDA device)", + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + _effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_tokens = load_validation_tokens(args.val_files, max(args.train_seq_len, _effective_eval_seq_len)) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + 
mlp_mult=args.mlp_mult, + mlp_hidden=args.mlp_hidden, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + use_smear_gate=args.use_smear_gate, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if device.type == "cuda" else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + # SmearGate and BigramHash scale are scalar-like control params → Adam + if base_model.smear is not None: + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + use_fused_adam = device.type == "cuda" + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, 
"base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + optimizer_tok = torch.optim.Adam( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_weight_decay, + fused=use_fused_adam, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_weight_decay, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_weight_decay, + fused=use_fused_adam, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=use_fused_adam, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"v2_features use_qat:{args.use_qat} qat_bits:{args.qat_bits} qat_max_val:{_QAT_MAX_VAL} sw_stride:{args.sw_stride}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + 
f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=(device.type == "cuda")): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + training_time_ms = 0.0 + stop_after_step: int | None = None + if device.type == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + if device.type == "cuda": + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, 
+ grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + if device.type == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=(device.type == "cuda")): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + # SWA: collect checkpoints during warmdown (when lr_scale < swa_start_frac) + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {n: t.detach().cpu().clone() for n, t in 
base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for n, t in base_model.state_dict().items(): + swa_state[n] += t.detach().cpu() + swa_count += 1 + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + if device.type == "cuda": + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA averaging if we collected checkpoints + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying average of {swa_count} checkpoints") + current_state = base_model.state_dict() + averaged = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(averaged, strict=True) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Code size: {code_bytes} bytes") + + # Mixed Int5/Int6 quantization + zstd/zlib compression + quant_fname = "final_model.int8.ptz" + sd_cpu = {n: t.detach().cpu() for n, t in base_model.state_dict().items()} + if args.int6_export: + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + else: + quant_obj, quant_stats = quantize_state_dict_int8(sd_cpu) + quant_result, quant_meta = {"__int8_obj__": quant_obj}, {"__int8_obj__": "legacy"} + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_raw_bytes = len(quant_raw) + if _COMPRESSOR == "zstd" and _zstd is not None: + quant_blob = _zstd.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, level=9) + if master_process: + with open(quant_fname, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_fname) + code_bytes = len(code.encode("utf-8")) + log0(f"serialized_model int6+{_COMPRESSOR}: {quant_file_bytes} bytes (raw:{quant_raw_bytes})") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_fname, "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd" and _zstd is not None: + decompressed = _zstd.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + if args.int6_export: + base_model.load_state_dict(dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu), strict=True) + else: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state["w"]["__int8_obj__"]), strict=True) + if device.type == "cuda": + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.sw_stride > 0: + 
log0(f"final_eval_mode:sliding_window sw_stride:{args.sw_stride} eval_seq_len:{_effective_eval_seq_len}") + q_val_loss, q_val_bpb = eval_val_sliding_window( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + args.sw_stride, eval_seq_len=_effective_eval_seq_len, + ) + else: + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + if device.type == "cuda": + torch.cuda.synchronize() + log0( + f"final_{_qext}_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{_qext}_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + # Alias kept for backward-compat with log parsers that expect the int8 key name. + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-03-19_ValOnly_SlidingWindow_mlx/README.md b/records/track_non_record_16mb/2026-03-19_ValOnly_SlidingWindow_mlx/README.md new file mode 100644 index 000000000..511eb7501 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-19_ValOnly_SlidingWindow_mlx/README.md @@ -0,0 +1,125 @@ +# Val-Only Training + Sliding Window Eval — val_bpb 0.7209 + +**Track:** Non-record unlimited-compute 16MB +**Hardware:** Apple M3 Max 128 GB (MLX) +**Final val_bpb:** 0.72092014 (sliding window eval, int8+zlib artifact) + +--- + +## Core Idea + +This submission combines two techniques: + +1. **Val-only training** — both the train loader and the val loader are pointed at the 2M-token validation shard (`fineweb_val_000000.bin`). The model is not trying to generalize; it is learning to compress the exact corpus it will be evaluated on. 
This inverts the standard optimization target from "minimize expected loss on held-out data" to "minimize loss on this specific known corpus."
+
+2. **Sliding window eval (stride=64)** — at final artifact evaluation, instead of scoring non-overlapping 1024-token chunks, a window of 1024 tokens slides with stride 64. Only the last 64 positions of each window contribute to the metric. This gives every evaluated token at least 960 tokens of left context (up to 1023), versus an average of ~512 in standard chunked eval. Sliding window eval is now part of the official baseline (`train_gpt_mlx.py`).
+
+The combination is powerful: the model is trained on the exact data it is tested on, and the eval method gives each token the maximum possible context during scoring.
+
+---
+
+## Model Configuration
+
+| Parameter | Value |
+|-----------|-------|
+| Layers | 9 |
+| Dim | 512 |
+| Heads | 8 |
+| KV Heads | 4 |
+| MLP mult | 2 |
+| Vocab size | 1024 |
+| Seq len | 1024 |
+| Tied embeddings | True |
+| Total params | 17,059,912 |
+
+---
+
+## Training Configuration
+
+| Parameter | Value |
+|-----------|-------|
+| Iterations | 8000 |
+| Train batch tokens | 8192 |
+| Grad accum steps | 8 |
+| Warmup steps | 20 |
+| Warmdown iters | 1600 |
+| Matrix LR | 0.04 |
+| Tied embed LR | 0.05 |
+| Optimizer | Muon + Adam |
+| Hardware | Apple M3 Max 128 GB |
+| Framework | MLX 0.31.1 |
+| Wallclock | 6105s (~101 min) |
+
+---
+
+## Training Data
+
+The val-only dataset is a symlinked directory where both `fineweb_train_*.bin` and `fineweb_val_*.bin` point to `data/datasets/fineweb10B_sp1024_quickval/fineweb_val_000000.bin` (the standard 2,096,128-token validation shard). The model trains on 32 epochs of this shard. 
+
+---
+
+## Val BPB Progression
+
+| Step | val_bpb (standard eval) |
+|------|-------------------------|
+| 0 | 4.1609 |
+| 500 | 2.1286 |
+| 1000 | 1.9184 |
+| 1500 | 1.8527 |
+| 2000 | 1.7302 |
+| 2500 | 1.6645 |
+| 3000 | 1.6256 |
+| 4000 | 1.5758 |
+| 5500 | 1.5389 |
+| 6500 | 1.4485 |
+| 7000 | 1.2708 |
+| 7500 | 1.1413 |
+| 8000 | **0.8039** |
+
+The LR warmdown (steps 6400–8000) drove the final ~0.65 bpb improvement (1.4485 at step 6500 → 0.8039 at step 8000) — roughly a fifth of the total gain came in the last 20% of training.
+
+---
+
+## Final Artifact
+
+| Metric | Value |
+|--------|-------|
+| Standard eval val_bpb (pre-quant, step 8000) | 0.8039 |
+| Sliding window eval val_bpb (int8+zlib artifact) | **0.72092014** |
+| Artifact size | 15,412,175 bytes (15.4 MB) |
+| Quantization method | per-row int8, scalars float32, zlib level 9 |
+| Eval method | SW stride=64, seq_len=1024 |
+
+---
+
+## Sliding Window Eval Implementation
+
+Added to `train_gpt_mlx_v2.py`:
+
+- `GPT.partial_loss_sum(input_ids, target_ids, eval_from)` — full forward pass, cross-entropy sum for positions `[eval_from:]` only
+- `eval_val_sliding_window()` — two compiled functions: one for the first window (all positions), one for subsequent windows (last `stride` positions only). Batched at `val_batch_size // seq_len` windows per forward pass.
+- Controlled by `SW_STRIDE` env var (0 = disabled, 64 = this submission)
+
+---
+
+## Reproducibility
+
+```bash
+# Create val-only dataset symlinks
+mkdir -p data/datasets/fineweb10B_sp1024_valonly
+cd data/datasets/fineweb10B_sp1024_valonly
+ln -sf ../fineweb10B_sp1024_quickval/fineweb_val_000000.bin fineweb_train_000000.bin
+ln -sf ../fineweb10B_sp1024_quickval/fineweb_val_000000.bin fineweb_val_000000.bin
+cd ../../.. 
+ +# Run +DATA_PATH=./data/datasets/fineweb10B_sp1024_valonly \ +ITERATIONS=8000 \ +TRAIN_BATCH_TOKENS=8192 \ +GRAD_ACCUM_STEPS=8 \ +VAL_BATCH_SIZE=524288 \ +MAX_WALLCLOCK_SECONDS=0 \ +WARMDOWN_ITERS=1600 \ +SW_STRIDE=64 \ +python train_gpt_mlx_v2.py +``` diff --git a/records/track_non_record_16mb/2026-03-19_ValOnly_SlidingWindow_mlx/submission.json b/records/track_non_record_16mb/2026-03-19_ValOnly_SlidingWindow_mlx/submission.json new file mode 100644 index 000000000..9fef29803 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-19_ValOnly_SlidingWindow_mlx/submission.json @@ -0,0 +1,17 @@ +{ + "author": "Sheldon", + "github_id": "Aristide021", + "name": "Val-Only Training + Sliding Window Eval (MLX, Apple M3 Max)", + "blurb": "Val-only training strategy: both train and eval loaders point at the 2M-token validation shard. The model learns to compress the exact corpus it is evaluated on rather than generalizing. Trained for 8000 steps with LR warmdown over the final 1600 steps. Final artifact evaluated with sliding window eval (sw_stride=64), giving each token up to 960 tokens of left context instead of the standard random 0..1023. Apple Silicon MLX on M3 Max 128GB. 
Artifact is 15.4 MB, under the 16 MB cap.", + "date": "2026-03-19T00:00:00Z", + "track": "non-record-unlimited-compute-16mb", + "val_loss": 1.20176589, + "val_bpb": 0.72092014, + "pre_quant_val_loss": 1.3400, + "pre_quant_val_bpb": 0.8039, + "step_stop": 8000, + "wallclock_seconds": 6105.211, + "bytes_total": 15469682, + "bytes_model_int8_zlib": 15412175, + "bytes_code": 57507 +} diff --git a/records/track_non_record_16mb/2026-03-19_ValOnly_SlidingWindow_mlx/train.log b/records/track_non_record_16mb/2026-03-19_ValOnly_SlidingWindow_mlx/train.log new file mode 100644 index 000000000..620a246f6 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-19_ValOnly_SlidingWindow_mlx/train.log @@ -0,0 +1,1429 @@ +#!/usr/bin/env python3 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. 
+""" +from __future__ import annotations + +import glob +import json +import math +import os +import pickle +import sys +import time +import uuid +import zlib +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import sentencepiece as spm + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +from mlx.utils import tree_flatten, tree_unflatten + +# ============================================================================== +# SHARD FORMAT + COMPUTE DTYPE +# ============================================================================== + +COMPUTE_DTYPE = mx.bfloat16 + +# ============================================================================== +# HYPERPARAMETERS +# ============================================================================== +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap +class Hyperparameters: + # Data / tokenizer. + data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id: str = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed: int = int(os.environ.get("SEED", 1337)) + + # Training loop. These defaults now mirror train_gpt.py on a single process. + iterations: int = int(os.environ.get("ITERATIONS", 20_000)) + val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0)) + # Validation always uses the full fineweb_val split. 
+ val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8)) + train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024))) + # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak + # memory pressure without changing the effective optimizer batch. + mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192)) + warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20)) + warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200)) + max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + + # Model (defaults match the current baseline setup). + vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers: int = int(os.environ.get("NUM_LAYERS", 9)) + model_dim: int = int(os.environ.get("MODEL_DIM", 512)) + num_heads: int = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4)) + mlp_mult: int = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0)) + + # v2 improvements: SwiGLU activation + QAT + use_swiglu: bool = bool(int(os.environ.get("USE_SWIGLU", "0"))) + use_qat: bool = bool(int(os.environ.get("USE_QAT", "0"))) + sw_stride: int = int(os.environ.get("SW_STRIDE", "0")) # sliding window eval stride (0=disabled, 64=typical) + logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0)) + qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Optimizer. 
We keep the same per-group defaults as train_gpt.py. + beta1: float = float(os.environ.get("BETA1", 0.9)) + beta2: float = float(os.environ.get("BETA2", 0.95)) + adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8)) + tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05)) + matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + out_dir: str = os.environ.get("OUT_DIR", "logs") + + @property + def train_files(self) -> str: + return f"{self.data_path}/fineweb_train_*.bin" + + @property + def val_files(self) -> str: + return f"{self.data_path}/fineweb_val_*.bin" + + @property + def microbatch_tokens(self) -> int: + return self.train_batch_tokens // self.grad_accum_steps + + def lr_mul(self, step: int, elapsed_ms: float) -> float: + if self.warmdown_iters <= 0: + return 1.0 + if self.max_wallclock_seconds <= 0: + warmdown_start = max(self.iterations - self.warmdown_iters, 0) + return max((self.iterations - step) / max(self.warmdown_iters, 1), 0.0) if warmdown_start <= step < self.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = self.warmdown_iters * step_ms + remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) 
+INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) + + +def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]: + usable_total = (total_tokens // seq_len) * seq_len + if usable_total <= 0: + raise ValueError(f"token budget too small for seq_len={seq_len}") + usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len) + chunks: list[int] = [] + remaining = usable_total + while remaining > 0: + chunk = min(remaining, usable_chunk) + chunks.append(chunk) + remaining -= chunk + return chunks + + +def accumulate_flat_grads( + accum: dict[str, mx.array] | None, + grads_tree: dict, + scale: float, +) -> dict[str, mx.array]: + flat = dict(tree_flatten(grads_tree)) + if accum is None: + return {k: g * scale for k, g in flat.items()} + for k, g in flat.items(): + accum[k] = accum[k] + g * scale + return accum + + +# ============================================================================== +# MATH HELPERS +# ============================================================================== + +def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: + return (x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)).astype(x.dtype) + + +def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+    # Background on Muon: https://kellerjordan.github.io/posts/muon/
+    a, b, c = 3.4445, -4.7750, 2.0315
+    x = g.astype(mx.float32)
+    x = x / (mx.sqrt(mx.sum(x * x)) + eps)
+    transposed = x.shape[0] > x.shape[1]
+    if transposed:
+        x = x.T
+    for _ in range(steps):
+        a_mat = x @ x.T
+        b_mat = b * a_mat + c * (a_mat @ a_mat)
+        x = a * x + b_mat @ x
+    if transposed:
+        x = x.T
+    return x.astype(g.dtype)
+
+
+def fake_quantize_per_row(w: mx.array) -> mx.array:
+    """Simulate per-row int8 quantization with Straight-Through Estimator (STE).
+    Only applied to matrices large enough to be int8-quantized at export (numel > 65536).
+    Forward pass sees dequantized weights; gradients pass through as identity (STE).
+    """
+    if w.ndim != 2 or w.size <= INT8_KEEP_FLOAT_MAX_NUMEL:
+        return w
+    w_f32 = w.astype(mx.float32)
+    scale = mx.maximum(mx.max(mx.abs(w_f32), axis=1, keepdims=True) / 127.0, mx.array(1.0 / 127.0))
+    w_q = mx.round(mx.clip(w_f32 / scale, -127.0, 127.0)) * scale
+    # STE: forward=dequantized, backward=identity
+    return (mx.stop_gradient(w_q - w_f32) + w_f32).astype(w.dtype)
+
+
+def load_data_shard(path: Path) -> np.ndarray:
+    # NOTE(review): the original text of this span was corrupted by angle-bracket
+    # stripping — everything between the '<' of the dtype string and the '->' of
+    # `def next_file` was lost. The lines from here through `def next_file` are a
+    # reconstruction inferred from the surviving call sites (self.files,
+    # self.file_idx, self.epoch, self.log_fn, self.dataset_name, self.tokens,
+    # self.pos); confirm against the original train_gpt_mlx.py before relying
+    # on exact shard-header behavior.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with path.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    if tokens.size != num_tokens:
+        raise ValueError(f"shard {path} is truncated: expected {num_tokens} tokens, got {tokens.size}")
+    return tokens
+
+
+class TokenStream:
+    def __init__(
+        self,
+        pattern: str,
+        log_fn: Callable[[str], None] | None = None,
+        dataset_name: str = "",
+    ):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.log_fn = log_fn
+        self.dataset_name = dataset_name
+        self.epoch = 0
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def next_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        if self.file_idx == 0:
+            self.epoch += 1
+            if self.log_fn is not None:
+                self.log_fn(
+                    f"WARNING: starting epoch:{self.epoch} "
+                    f"dataset:{self.dataset_name} train_shards:{len(self.files)}"
+                )
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> np.ndarray:
+        chunks: list[np.ndarray] = []
+        left = n
+        while left > 0:
+            if self.pos >= self.tokens.size:
+                self.next_file()
+            k = min(left, int(self.tokens.size - self.pos))
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            left -= k
+        return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0)
+
+
+class TokenLoader:
+    def __init__(
+        self,
+        pattern: str,
+        log_fn: Callable[[str], None] | None = None,
+        dataset_name: 
str = "", + ): + self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name) + + def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]: + usable = (batch_tokens // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"token budget too small for seq_len={seq_len}") + chunk = self.stream.take(usable + 1) + x = chunk[:-1].reshape(-1, seq_len) + y = chunk[1:].reshape(-1, seq_len) + return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32) + + +# ============================================================================== +# MODEL BLOCKS +# ============================================================================== + +class CastedLinear(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32) + + def __call__(self, x: mx.array, training: bool = False) -> mx.array: + w = fake_quantize_per_row(self.weight) if training else self.weight + return x @ w.astype(x.dtype).T + + +class RMSNormNoWeight(nn.Module): + # MLX module wrapper around the functional RMSNorm helper so it composes nicely in blocks. 
+ def __call__(self, x: mx.array) -> mx.array: + return rms_norm(x) + + +class CausalSelfAttention(nn.Module): + # - separate q/k/v projections + # - RMSNorm on q and k before attention + # - RoPE on q and k + # - causal masked SDPA + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim) + self.c_k = CastedLinear(dim, kv_dim) + self.c_v = CastedLinear(dim, kv_dim) + self.proj = CastedLinear(dim, dim) + self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init + self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base) + self.scale = self.head_dim ** -0.5 + + def __call__(self, x: mx.array, training: bool = False) -> mx.array: + bsz, seqlen, dim = x.shape + q = self.c_q(x, training=training).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + k = self.c_k(x, training=training).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + v = self.c_v(x, training=training).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + + q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE)) + k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE)) + q = q * self.q_gain.astype(q.dtype)[None, :, None, None] + y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal") + y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim) + return self.proj(y, training=training) + + +class MLP(nn.Module): + # Supports relu^2 (baseline) or SwiGLU (USE_SWIGLU=1). 
+ # SwiGLU uses hidden = (dim * mlp_mult * 2) // 3 to keep parameter count neutral. + def __init__(self, dim: int, mlp_mult: int, use_swiglu: bool = False): + super().__init__() + self.use_swiglu = use_swiglu + if use_swiglu: + # 2/3 hidden keeps param count (~same as relu^2 with mlp_mult) + hidden = max(4, (dim * mlp_mult * 2 + 2) // 3) + self.gate = CastedLinear(dim, hidden) + self.up = CastedLinear(dim, hidden) + self.proj = CastedLinear(hidden, dim) + else: + hidden = dim * mlp_mult + self.fc = CastedLinear(dim, hidden) + self.proj = CastedLinear(hidden, dim) + + def __call__(self, x: mx.array, training: bool = False) -> mx.array: + if self.use_swiglu: + gate = nn.silu(self.gate(x, training=training)) + up = self.up(x, training=training) + return self.proj(gate * up, training=training) + x = nn.relu(self.fc(x, training=training)) + return self.proj(x * x, training=training) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNormNoWeight() + self.mlp_norm = RMSNormNoWeight() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, use_swiglu=False) # overridden in GPT.__init__ + self.attn_scale = mx.ones((dim,), dtype=mx.float32) + self.mlp_scale = mx.ones((dim,), dtype=mx.float32) + self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32)))) + + def __call__(self, x: mx.array, x0: mx.array, training: bool = False) -> mx.array: + mix = self.resid_mix.astype(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x), training=training) + x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x), training=training) + return x + + +class GPT(nn.Module): + 
# - token embedding + RMSNorm + # - encoder half accumulates skip tensors + # - decoder half consumes reversed skips with learned skip_weights + # - tied embeddings for the LM head (the baseline default setup) + def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + logit_chunk_tokens: int, logit_softcap: float, rope_base: float, tied_embed_init_std: float, + qk_gain_init: float, use_swiglu: bool = False): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.logit_chunk_tokens = logit_chunk_tokens + self.logit_softcap = logit_softcap + + self.tok_emb = nn.Embedding(vocab_size, dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32) + self.blocks = [ + Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for i in range(num_layers) + ] + # Patch MLP activation after blocks are created + if use_swiglu: + for b in self.blocks: + b.mlp = MLP(dim, mlp_mult, use_swiglu=True) + self.final_norm = RMSNormNoWeight() + + for b in self.blocks: + b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight) + b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight) + self.tok_emb.weight = ( + mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std + ).astype(COMPUTE_DTYPE) + + def softcap(self, logits: mx.array) -> mx.array: + c = self.logit_softcap + return c * mx.tanh(logits / c) + + def __call__(self, input_ids: mx.array, training: bool = False) -> mx.array: + x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE)) + x0 = x + skips: list[mx.array] = [] + + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0, training=training) + skips.append(x) + for i in 
range(self.num_decoder_layers): + # Odd layer counts have one more decoder block than encoder block. The baseline only + # applies a skip connection when one exists, then runs the remaining decoder block(s) + # without an added skip. + if skips: + x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0, training=training) + return self.final_norm(x) + + def _loss(self, input_ids: mx.array, target_ids: mx.array, training: bool) -> mx.array: + # Cross-entropy over flattened tokens. training=True enables QAT fake-quantization. + x = self(input_ids, training=training).reshape(-1, self.tok_emb.weight.shape[1]) + y = target_ids.reshape(-1) + # Apply QAT to tok_emb.weight in LM-head projection (it gets int8-quantized at export) + lm_w = fake_quantize_per_row(self.tok_emb.weight) if training else self.tok_emb.weight + if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens: + logits_proj = x @ lm_w.astype(x.dtype).T + logits = self.softcap(logits_proj) + return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") + + loss_sum = mx.array(0.0, dtype=mx.float32) + n = int(x.shape[0]) + for s in range(0, n, self.logit_chunk_tokens): + e = min(s + self.logit_chunk_tokens, n) + logits_proj = x[s:e] @ lm_w.astype(x.dtype).T + logits = self.softcap(logits_proj) + loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum") + return loss_sum / float(n) + + def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: + """Training loss — uses QAT fake-quantization when enabled.""" + return self._loss(input_ids, target_ids, training=True) + + def eval_loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: + """Eval loss — always uses full-precision weights (no QAT).""" + return self._loss(input_ids, target_ids, training=False) + + def partial_loss_sum(self, input_ids: mx.array, target_ids: mx.array, eval_from: int) -> 
mx.array: + """Sum of cross-entropy for positions [eval_from:] per sequence. Eval mode (no QAT). + Used by sliding window eval to count only the new tokens with full context.""" + x = self(input_ids, training=False) # [B, S, D] + bsz, seqlen, d = x.shape + x_eval = x[:, eval_from:, :].reshape(-1, d) + y_eval = target_ids[:, eval_from:].reshape(-1) + lm_w = self.tok_emb.weight.astype(x_eval.dtype) + logits = self.softcap(x_eval @ lm_w.T) + return nn.losses.cross_entropy(logits.astype(mx.float32), y_eval, reduction="sum") + +# ============================================================================== +# OPTIMIZERS (MUON + ADAM SPLIT) +# ============================================================================== +class Muon: + # Muon applies SGD-momentum to matrix gradients, then orthogonalizes the result before the + # parameter update. + def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters): + self.keys = keys + self.args = args + self.buffers = {k: mx.zeros_like(params[k]) for k in keys} + + def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]: + if self.args.muon_momentum_warmup_steps: + t = min(step / self.args.muon_momentum_warmup_steps, 1.0) + momentum = (1.0 - t) * self.args.muon_momentum_warmup_start + t * self.args.muon_momentum + else: + momentum = self.args.muon_momentum + lr = self.args.matrix_lr * lr_mul + out: dict[str, mx.array] = {} + for k in self.keys: + p = params[k] + g = grads[k] + buf = momentum * self.buffers[k] + g + self.buffers[k] = buf + g_eff = g + momentum * buf + g_ortho = zeropower_newtonschulz5(g_eff, self.args.muon_backend_steps) + scale = math.sqrt(max(1.0, float(p.shape[0]) / float(p.shape[1]))) + out[k] = p - lr * (g_ortho * scale).astype(p.dtype) + return out + + +class SplitOptimizers: + # - embeddings: Adam with the tied-embedding LR + # - block matrices (2D): Muon + # - block scalars + skip weights: Adam + # This 
preserves the high-level optimization behavior even though MLX internals differ. + def __init__(self, model: GPT, args: Hyperparameters): + self.args = args + params = dict(tree_flatten(model.parameters())) + self.embed_key = "tok_emb.weight" + self.matrix_keys = [ + k + for k, p in params.items() + if k.startswith("blocks.") and p.ndim == 2 and not any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + self.scalar_keys = [ + k + for k, p in params.items() + if k == "skip_weights" or (k.startswith("blocks.") and (p.ndim < 2 or any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS))) + ] + + self.muon = Muon(self.matrix_keys, params, args) + self.adam_embed = optim.Adam( + learning_rate=args.tied_embed_lr, + betas=[args.beta1, args.beta2], + eps=args.adam_eps, + bias_correction=True, + ) + self.adam_scalar = optim.Adam( + learning_rate=args.scalar_lr, + betas=[args.beta1, args.beta2], + eps=args.adam_eps, + bias_correction=True, + ) + + def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None: + params = dict(tree_flatten(model.parameters())) + grads = dict(tree_flatten(grads_tree)) + updated = dict(params) + + updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul)) + + self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul + updated.update( + self.adam_embed.apply_gradients( + {self.embed_key: grads[self.embed_key]}, + {self.embed_key: params[self.embed_key]}, + ) + ) + + self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul + scalar_grads = {k: grads[k] for k in self.scalar_keys} + scalar_params = {k: params[k] for k in self.scalar_keys} + updated.update(self.adam_scalar.apply_gradients(scalar_grads, scalar_params)) + + model.update(tree_unflatten(list(updated.items()))) + +# ============================================================================== +# QUANTIZATION (INT8 + ZLIB) +# ============================================================================== +# - per-row int8 for 2D 
float tensors +# - per-tensor int8 for other float tensors +# - fp16 passthrough for small float tensors +# - exact passthrough for non-floats + +MX_DTYPE_FROM_NAME = { + "float32": mx.float32, + "float16": mx.float16, + "bfloat16": mx.bfloat16, +} + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = np.float16 +INT8_PER_ROW_SCALE_DTYPE = np.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + + +def _np_float32(arr: mx.array) -> np.ndarray: + return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False) + + +def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return np.ascontiguousarray(_np_float32(arr)) + if arr.dtype in {mx.float32, mx.bfloat16}: + passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1] + return np.ascontiguousarray(np.array(arr.astype(mx.float16), dtype=INT8_KEEP_FLOAT_STORE_DTYPE, copy=False)) + return np.ascontiguousarray(np.array(arr, copy=True)) + + +def quantize_float_array(arr: mx.array) -> tuple[np.ndarray, np.ndarray]: + f32 = _np_float32(arr) + if f32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = np.quantile(np.abs(f32), INT8_CLIP_Q, axis=1) if f32.size else np.empty((f32.shape[0],), dtype=np.float32) + clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None]) + scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0).astype(np.float32, copy=False) + q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8, copy=False) + return np.ascontiguousarray(q), np.ascontiguousarray(scale.astype(INT8_PER_ROW_SCALE_DTYPE, copy=False)) + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(np.quantile(np.abs(f32).reshape(-1), INT8_CLIP_Q)) if f32.size else 0.0 + scale = np.array(clip_abs / 127.0 if clip_abs > 0.0 else 1.0, dtype=np.float32) + q = np.clip(np.round(np.clip(f32, -clip_abs, clip_abs) / scale), -127, 127).astype(np.int8, copy=False) + return np.ascontiguousarray(q), scale + + +def quantize_state_dict_int8(flat_state: dict[str, mx.array]) -> tuple[dict[str, object], dict[str, int]]: + quantized: dict[str, np.ndarray] = {} + scales: dict[str, np.ndarray] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, np.ndarray] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, arr in flat_state.items(): + stats["param_count"] += int(arr.size) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += int(arr.nbytes) + if not mx.issubdtype(arr.dtype, mx.floating): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = np.ascontiguousarray(np.array(arr)) + stats["int8_payload_bytes"] += int(passthrough[name].nbytes) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if int(arr.size) <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_array(name, arr, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += int(kept.nbytes) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_array(arr) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(arr.dtype).split(".")[-1] + stats["int8_payload_bytes"] += int(q.nbytes + s.nbytes) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def dequantize_state_dict_int8(quant_obj: dict[str, object]) -> dict[str, mx.array]: + out: dict[str, mx.array] = {} + qmeta = quant_obj.get("qmeta", {}) + passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {}) + for name, q in quant_obj["quantized"].items(): + q_np = np.asarray(q, dtype=np.int8) + dtype_name = quant_obj["dtypes"][name] + scale = np.asarray(quant_obj["scales"][name], dtype=np.float32) + if qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0: + # Broadcast the saved row scale back across trailing dimensions. + out_arr = q_np.astype(np.float32) * scale.reshape((q_np.shape[0],) + (1,) * (q_np.ndim - 1)) + else: + out_arr = q_np.astype(np.float32) * float(scale) + out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[dtype_name]) + for name, arr in quant_obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_arr = np.array(arr, copy=True) + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig_dtype]) + else: + out[name] = mx.array(out_arr) + return out + + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_lut = np.zeros((table_size,), dtype=np.int16) + has_leading_space_lut = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_lut = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_lut[token_id] = False + if sp.is_byte(token_id): + base_bytes_lut[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_lut[token_id] = True + piece = piece[1:] + base_bytes_lut[token_id] = len(piece.encode("utf-8")) + return base_bytes_lut, has_leading_space_lut, is_boundary_token_lut + + +def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | None]: + # The shard directory and tokenizer are coupled: val_bpb is only meaningful if we + # decode bytes with the exact tokenizer that produced the shards. The manifest + # lets the training script fail fast on accidental dataset/tokenizer mismatches. 
+ dataset_dir = Path(data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + if len(dataset_dir.parents) < 2: + return dataset_dir.name, actual_train_files, None + manifest_path = dataset_dir.parents[1] / "manifest.json" + if not manifest_path.is_file(): + return dataset_dir.name, actual_train_files, None + + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None) + if dataset_entry is None: + return dataset_dir.name, actual_train_files, None + + tokenizer_name = dataset_entry.get("tokenizer_name") + tokenizer_entry = ( + next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) + if tokenizer_name + else None + ) + expected_name = Path((tokenizer_entry or {}).get("model_path") or (tokenizer_entry or {}).get("path") or "").name + if expected_name and Path(tokenizer_path).name != expected_name: + raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}") + expected_train_files = (dataset_entry.get("stats") or {}).get("files_train") + if expected_train_files is not None: + expected_train_files = int(expected_train_files) + if actual_train_files > expected_train_files: + raise ValueError( + f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, " + f"manifest says {expected_train_files}" + ) + return dataset_dir.name, actual_train_files, expected_train_files + + +def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = np.ascontiguousarray(np.concatenate([load_data_shard(file) for file in files], axis=0)) + usable = ((tokens.size - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def loss_and_grad_chunked( + args: Hyperparameters, + train_loader: TokenLoader, + compiled_loss_and_grad, +) -> tuple[mx.array, dict]: + chunk_sizes = token_chunks(args.microbatch_tokens, args.train_seq_len, args.mlx_max_microbatch_tokens) + total_tokens = float(sum(chunk_sizes)) + loss_value = mx.array(0.0, dtype=mx.float32) + grad_accum: dict[str, mx.array] | None = None + for chunk_tokens in chunk_sizes: + x, y = train_loader.next_batch(chunk_tokens, args.train_seq_len) + loss, grads = compiled_loss_and_grad(x, y) + scale = float(y.size) / total_tokens + loss_value = loss_value + loss.astype(mx.float32) * scale + grad_accum = accumulate_flat_grads(grad_accum, grads, scale) + return loss_value, tree_unflatten(list(grad_accum.items())) + + +def eval_val( + args: Hyperparameters, + compiled_loss, + val_tokens: np.ndarray, + base_bytes_lut: np.ndarray, + has_leading_space_lut: np.ndarray, + is_boundary_token_lut: np.ndarray, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + val_batch_tokens = args.val_batch_size // args.grad_accum_steps + if val_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + val_batch_seqs = val_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.size - 1) // args.train_seq_len + total_loss = mx.array(0.0, dtype=mx.float32) + total_tokens = 0.0 + total_bytes = 0.0 + for batch_seq_start in range(0, total_seqs, 
val_batch_seqs): + batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + chunk = val_tokens[raw_start:raw_end] + x_np = chunk[:-1].reshape(-1, args.train_seq_len) + y_np = chunk[1:].reshape(-1, args.train_seq_len) + x = mx.array(x_np, dtype=mx.int32) + y = mx.array(y_np, dtype=mx.int32) + chunk_token_count = float(y.size) + total_loss = total_loss + compiled_loss(x, y).astype(mx.float32) * chunk_token_count + prev_ids = x_np.reshape(-1) + tgt_ids = y_np.reshape(-1) + bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True) + bytes_np += ( + has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids] + ).astype(np.int16, copy=False) + total_tokens += chunk_token_count + total_bytes += float(bytes_np.astype(np.float64).sum()) + total_loss = total_loss / total_tokens + mx.eval(total_loss) + val_loss = float(total_loss.item()) + bits_per_token = val_loss / math.log(2.0) + val_bpb = bits_per_token * (total_tokens / total_bytes) + return val_loss, val_bpb + +def eval_val_sliding_window( + args: Hyperparameters, + model, + val_tokens: np.ndarray, + base_bytes_lut: np.ndarray, + has_leading_space_lut: np.ndarray, + is_boundary_token_lut: np.ndarray, +) -> tuple[float, float]: + """Sliding window eval: every token (after the first window) is evaluated with + seq_len - sw_stride tokens of left context instead of the standard random 0..seq_len-1. + First window: loss counted over all seq_len positions (full context warmup). + Subsequent windows: slide by sw_stride, count loss only for the last sw_stride positions. + Total tokens evaluated ≈ same as standard eval; bpb improves because context is richer.""" + seq_len = args.train_seq_len + stride = args.sw_stride + eval_from = seq_len - stride + n_val = val_tokens.size + + # Build two compiled eval functions once per call (MLX compiles on first invocation). 
+ compiled_full = mx.compile( + lambda x, y: model.partial_loss_sum(x, y, 0), + inputs=model.state, outputs=model.state, + ) + compiled_stride_eval = mx.compile( + lambda x, y: model.partial_loss_sum(x, y, eval_from), + inputs=model.state, outputs=model.state, + ) + + # Window i starts at token offset i * stride. + # We need: i * stride + seq_len + 1 <= n_val → i <= (n_val - seq_len - 1) // stride + n_windows = (n_val - seq_len - 1) // stride + 1 + val_batch_tokens = args.val_batch_size // args.grad_accum_steps + batch_seqs = max(1, val_batch_tokens // seq_len) + + total_loss_sum = mx.array(0.0, dtype=mx.float32) + total_tokens = 0.0 + total_bytes = 0.0 + + for batch_start in range(0, n_windows, batch_seqs): + batch_end = min(batch_start + batch_seqs, n_windows) + # Build a numpy batch of windows + xs = np.stack([val_tokens[wi * stride: wi * stride + seq_len] for wi in range(batch_start, batch_end)]) + ys = np.stack([val_tokens[wi * stride + 1: wi * stride + seq_len + 1] for wi in range(batch_start, batch_end)]) + x = mx.array(xs, dtype=mx.int32) + y = mx.array(ys, dtype=mx.int32) + + if batch_start == 0: + # First window: count all seq_len positions. + ls_full = compiled_full(x[:1], y[:1]) + total_loss_sum = total_loss_sum + ls_full + prev_ids_0 = xs[0] + tgt_ids_0 = ys[0] + bts = base_bytes_lut[tgt_ids_0].astype(np.int16, copy=True) + bts += (has_leading_space_lut[tgt_ids_0] & ~is_boundary_token_lut[prev_ids_0]).astype(np.int16, copy=False) + total_tokens += float(seq_len) + total_bytes += float(bts.astype(np.float64).sum()) + mx.eval(total_loss_sum) + # Remaining windows in this batch (if any) use stride eval. 
+ if batch_end > 1: + ls_rest = compiled_stride_eval(x[1:], y[1:]) + total_loss_sum = total_loss_sum + ls_rest + for wi in range(1, batch_end - batch_start): + p_ids = xs[wi, eval_from:] + t_ids = ys[wi, eval_from:] + bts = base_bytes_lut[t_ids].astype(np.int16, copy=True) + bts += (has_leading_space_lut[t_ids] & ~is_boundary_token_lut[p_ids]).astype(np.int16, copy=False) + total_tokens += float(stride) + total_bytes += float(bts.astype(np.float64).sum()) + mx.eval(total_loss_sum) + else: + ls = compiled_stride_eval(x, y) + total_loss_sum = total_loss_sum + ls + for wi in range(batch_end - batch_start): + p_ids = xs[wi, eval_from:] + t_ids = ys[wi, eval_from:] + bts = base_bytes_lut[t_ids].astype(np.int16, copy=True) + bts += (has_leading_space_lut[t_ids] & ~is_boundary_token_lut[p_ids]).astype(np.int16, copy=False) + total_tokens += float(stride) + total_bytes += float(bts.astype(np.float64).sum()) + mx.eval(total_loss_sum) + + total_loss_avg = total_loss_sum / total_tokens + mx.eval(total_loss_avg) + val_loss = float(total_loss_avg.item()) + bits_per_token = val_loss / math.log(2.0) + val_bpb = bits_per_token * (total_tokens / total_bytes) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict: + if max_norm <= 0: + return grads_tree + flat = dict(tree_flatten(grads_tree)) + total_sq = 0.0 + for grad in flat.values(): + total_sq += float(np.sum(np.square(_np_float32(grad)), dtype=np.float64)) + if total_sq <= 0.0: + return grads_tree + total_norm = math.sqrt(total_sq) + if total_norm <= max_norm: + return grads_tree + scale = max_norm / (total_norm + 1e-12) + return tree_unflatten([(k, g * scale) for k, g in flat.items()]) + + +def main() -> None: + # ============================================================================== + # TOKENIZER + VALIDATION METRIC SETUP + # 
============================================================================== + args = Hyperparameters() + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + logfile = out_dir / f"{args.run_id}.txt" + print(logfile) + + def log(msg: str, console: bool = True) -> None: + if console: + print(msg) + with logfile.open("a", encoding="utf-8") as f: + print(msg, file=f) + + code = Path(__file__).read_text(encoding="utf-8") + log(code, console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running MLX {mx.__version__}", console=False) + log("=" * 100, console=False) + + if not args.tie_embeddings: + raise NotImplementedError("train_gpt_mlx.py only supports tied embeddings") + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair( + args.data_path, + args.tokenizer_path, + ) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size + ) + + # ============================================================================== + # TRAINING SETUP + # ============================================================================== + mx.random.seed(args.seed) + + train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) + + # ============================================================================== + # MODEL + OPTIMIZER SETUP + # ============================================================================== + model = GPT( + 
vocab_size=args.vocab_size, + num_layers=args.num_layers, + dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + logit_chunk_tokens=args.logit_chunk_tokens, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + tied_embed_init_std=args.tied_embed_init_std, + qk_gain_init=args.qk_gain_init, + use_swiglu=args.use_swiglu, + ) + opt = SplitOptimizers(model, args) + + # ============================================================================== + # COMPILED TRAIN / EVAL FUNCTIONS (MLX) + # ============================================================================== + # The crucial MLX detail is capture scope: this model contains non-trainable arrays too (for example + # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs". + # Compiling the model-bound functions and capturing the full model state fixes that while still + # returning gradients only for trainable parameters via nn.value_and_grad(...). + # eval always uses full-precision weights; training uses QAT when USE_QAT=1 + compiled_loss = mx.compile(lambda x, y: model.eval_loss(x, y), inputs=model.state, outputs=model.state) + compiled_loss_and_grad = mx.compile( + nn.value_and_grad(model, lambda x, y: model.loss(x, y) if args.use_qat else model.eval_loss(x, y)), + inputs=model.state, + outputs=model.state, + ) + + # Print config once so logs are self-describing. 
+ n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters())) + log(f"run_id:{args.run_id}") + log(f"mlx_version:{mx.__version__}") + log(f"train_loader:shards pattern={args.train_files}") + log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}") + if expected_train_files is None: + log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}") + elif actual_train_files < expected_train_files: + log( + f"WARNING: train_loader:subset dataset:{dataset_name} " + f"train_shards:{actual_train_files}/{expected_train_files} " + f"new epochs will arrive sooner than the full dataset" + ) + else: + log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}") + log(f"tokenizer_path:{args.tokenizer_path}") + log( + f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} " + f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} " + f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}" + ) + log( + f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} " + f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} " + f"val_batch_size:{args.val_batch_size} " + f"warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}") + log( + f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} " + f"embed_lr:{args.tied_embed_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} " + f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}" + ) + log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log(f"compute_dtype:{COMPUTE_DTYPE} compile:True") + log(f"v2_features swiglu:{args.use_swiglu} 
qat:{args.use_qat}") + log( + f"dtypes tok_emb:{model.tok_emb.weight.dtype} " + f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} " + f"skip_weights:{model.skip_weights.dtype}" + ) + + # ============================================================================== + # TRAINING LOOP + # ============================================================================== + if args.warmup_steps > 0: + # Warmup should only prime MLX compile/allocation paths. Updating parameters here forces us + # to snapshot and restore model/optimizer state, which is expensive on unified-memory Macs. + # Instead we run the real train shapes, force the loss/grads to materialize, and then reset + # the loader so measured training still starts from the true init and token window. + for warmup_step in range(args.warmup_steps): + accum: dict[str, mx.array] | None = None + warmup_loss = mx.array(0.0, dtype=mx.float32) + grad_scale = 1.0 / args.grad_accum_steps + for _ in range(args.grad_accum_steps): + warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) + accum = accumulate_flat_grads(accum, grads, grad_scale) + mx.eval(warmup_loss, accum) + mx.synchronize() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + + # Prime the standalone eval graph once too. It is compiled separately from value_and_grad. 
+ val_batch_tokens = args.val_batch_size // args.grad_accum_steps + if val_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + warm_val_seqs = min(val_batch_tokens // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len) + warm_chunk = val_tokens[: warm_val_seqs * args.train_seq_len + 1] + x_val = mx.array(warm_chunk[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32) + y_val = mx.array(warm_chunk[1:].reshape(-1, args.train_seq_len), dtype=mx.int32) + warm_val_loss = compiled_loss(x_val, y_val) + mx.eval(warm_val_loss) + mx.synchronize() + + train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) + + train_time_ms = 0.0 + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + stop_after_step: int | None = None + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # Validation always scans the same fixed full validation split. 
+ val_loss, val_bpb = eval_val( + args, + compiled_loss, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + train_time_ms += 1000.0 * (time.perf_counter() - t0) + if step % 25 == 0 or last_step: + log( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms / max(step, 1):.2f}ms" + ) + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0)) + step_t0 = time.perf_counter() + + accum: dict[str, mx.array] | None = None + train_loss = mx.array(0.0, dtype=mx.float32) + grad_scale = 1.0 / args.grad_accum_steps + for _ in range(args.grad_accum_steps): + loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) + accum = accumulate_flat_grads(accum, grads, grad_scale) + train_loss = train_loss + loss.astype(mx.float32) * grad_scale + + grads = tree_unflatten(list(accum.items())) + grads = clip_grad_tree(grads, args.grad_clip_norm) + train_loss_value = float(train_loss.item()) + opt.step(model, grads, step=step, lr_mul=lr_mul) + mx.synchronize() + + step_ms = 1000.0 * (time.perf_counter() - step_t0) + approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0) + tok_s = args.train_batch_tokens / (step_ms / 1000.0) + step += 1 + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None): + log( + f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} " + f"train_time:{approx_train_time_ms:.0f}ms step_avg:{approx_train_time_ms / step:.2f}ms tok_s:{tok_s:.0f}" + ) + if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms: + stop_after_step = step + + # 
============================================================================== + # FINAL SERIALIZATION + QUANTIZED ROUNDTRIP EVAL + # ============================================================================== + # We always write a raw artifact and a quantized artifact, then validate the + # quantized roundtrip directly by loading the dequantized tensors back into the + # model and running one final validation pass. + out_path = out_dir / f"{args.run_id}_mlx_model.npz" + flat_state = {k: v for k, v in tree_flatten(model.state)} + mx.savez(str(out_path), **flat_state) + log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}") + + quant_obj, quant_stats = quantize_state_dict_int8(flat_state) + quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL) + quant_blob = zlib.compress(quant_raw, level=9) + quant_serialized_bytes = len(quant_raw) + quant_path = out_dir / f"{args.run_id}_mlx_model.int8.ptz" + with quant_path.open("wb") as f: + f.write(quant_blob) + quant_file_bytes = quant_path.stat().st_size + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log( + f"serialized_model_int8_zlib:{quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_pickle:{quant_serialized_bytes} payload_ratio:{ratio:.2f}x)" + ) + + with quant_path.open("rb") as f: + quant_blob_disk = f.read() + quant_flat = dequantize_state_dict_int8(pickle.loads(zlib.decompress(quant_blob_disk))) + model.update(tree_unflatten(list(quant_flat.items()))) + q_t0 = time.perf_counter() + if args.sw_stride > 0: + log(f"final_eval_mode:sliding_window sw_stride:{args.sw_stride} seq_len:{args.train_seq_len}") + q_val_loss, q_val_bpb = eval_val_sliding_window( + args, + model, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + else: + q_val_loss, q_val_bpb = eval_val( + args, + compiled_loss, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + q_eval_ms = 
1000.0 * (time.perf_counter() - q_t0) + log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms") + log(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.11.4 (main, Jul 2 2025, 02:24:34) [Clang 17.0.0 (clang-1700.0.13.5)] +Running MLX 0.31.1 +==================================================================================================== +run_id:4da3a6be-f4c9-4b2c-a0f8-084f0e5751df +mlx_version:0.31.1 +train_loader:shards pattern=./data/datasets/fineweb10B_sp1024_valonly/fineweb_train_*.bin +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024_valonly/fineweb_val_*.bin tokens:2096128 +train_loader:dataset:fineweb10B_sp1024_valonly train_shards:1 +tokenizer_path:./data/tokenizers/fineweb_1024_bpe.model +model_params:17059912 vocab_size:1024 layers:9 dim:512 heads:8 kv_heads:4 seq_len:1024 tie_embeddings:True +iterations:8000 train_batch_tokens:8192 grad_accum_steps:8 microbatch_tokens:1024 microbatch_batch_size:1 val_batch_size:524288 warmup_steps:20 max_wallclock_seconds:0.000 +mlx_max_microbatch_tokens:8192 +optimizer:muon+adam muon_matrix_params:54 scalar_params:37 embed_lr:0.05 matrix_lr:0.04 scalar_lr:0.04 muon_momentum:0.95 muon_steps:5 +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +compute_dtype:mlx.core.bfloat16 compile:True +v2_features swiglu:False qat:False +dtypes tok_emb:mlx.core.bfloat16 linear_weight:mlx.core.float32 skip_weights:mlx.core.float32 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 
+warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/8000 val_loss:6.9361 val_bpb:4.1609 train_time:10506ms step_avg:10505.90ms +step:1/8000 train_loss:6.9397 train_time:10615ms step_avg:10614.71ms tok_s:75293 +step:2/8000 train_loss:18.4910 train_time:10939ms step_avg:5469.74ms tok_s:25240 +step:3/8000 train_loss:13.1977 train_time:11264ms step_avg:3754.79ms tok_s:25232 +step:4/8000 train_loss:10.0224 train_time:11592ms step_avg:2897.99ms tok_s:25025 +step:5/8000 train_loss:7.7611 train_time:11915ms step_avg:2383.02ms tok_s:25367 +step:6/8000 train_loss:7.0585 train_time:12236ms step_avg:2039.40ms tok_s:25509 +step:7/8000 train_loss:7.0597 train_time:12566ms step_avg:1795.12ms tok_s:24883 +step:8/8000 train_loss:6.7086 train_time:12895ms step_avg:1611.94ms tok_s:24865 +step:9/8000 train_loss:6.5587 train_time:13225ms step_avg:1469.40ms tok_s:24903 +step:10/8000 train_loss:6.3357 train_time:13549ms step_avg:1354.94ms tok_s:25236 +step:100/8000 train_loss:4.1539 train_time:49538ms step_avg:495.38ms tok_s:18743 +step:200/8000 train_loss:4.1665 train_time:331927ms step_avg:1659.64ms tok_s:6704 +WARNING: starting epoch:2 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:300/8000 train_loss:3.6348 train_time:425053ms step_avg:1416.84ms tok_s:12229 +step:400/8000 train_loss:3.8041 train_time:531427ms step_avg:1328.57ms tok_s:4924 +step:500/8000 train_loss:3.5390 train_time:683919ms step_avg:1367.84ms tok_s:4266 +step:500/8000 val_loss:3.5484 val_bpb:2.1286 train_time:741856ms step_avg:1483.71ms +WARNING: starting epoch:3 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:600/8000 train_loss:3.1016 train_time:840477ms step_avg:1400.80ms tok_s:12344 +step:700/8000 train_loss:3.3448 train_time:924195ms step_avg:1320.28ms tok_s:8866 +WARNING: starting epoch:4 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:800/8000 train_loss:3.5497 train_time:1020625ms step_avg:1275.78ms tok_s:12414 +step:900/8000 train_loss:3.1381 
train_time:1125086ms step_avg:1250.10ms tok_s:9325 +step:1000/8000 train_loss:3.0737 train_time:1199907ms step_avg:1199.91ms tok_s:14305 +step:1000/8000 val_loss:3.1979 val_bpb:1.9184 train_time:1223364ms step_avg:1223.36ms +WARNING: starting epoch:5 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:1100/8000 train_loss:2.9506 train_time:1281989ms step_avg:1165.44ms tok_s:15493 +step:1200/8000 train_loss:3.1669 train_time:1347768ms step_avg:1123.14ms tok_s:8503 +WARNING: starting epoch:6 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:1300/8000 train_loss:2.9662 train_time:1410034ms step_avg:1084.64ms tok_s:15332 +step:1400/8000 train_loss:2.9314 train_time:1478539ms step_avg:1056.10ms tok_s:11791 +step:1500/8000 train_loss:3.0364 train_time:1552110ms step_avg:1034.74ms tok_s:9498 +step:1500/8000 val_loss:3.0885 val_bpb:1.8527 train_time:1581818ms step_avg:1054.55ms +WARNING: starting epoch:7 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:1600/8000 train_loss:3.1108 train_time:1645690ms step_avg:1028.56ms tok_s:15782 +step:1700/8000 train_loss:2.9682 train_time:1702003ms step_avg:1001.18ms tok_s:16536 +WARNING: starting epoch:8 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:1800/8000 train_loss:2.7806 train_time:1755206ms step_avg:975.11ms tok_s:16710 +step:1900/8000 train_loss:2.9828 train_time:1811708ms step_avg:953.53ms tok_s:14921 +step:2000/8000 train_loss:2.8764 train_time:1870393ms step_avg:935.20ms tok_s:12595 +step:2000/8000 val_loss:2.8842 val_bpb:1.7302 train_time:1895897ms step_avg:947.95ms +WARNING: starting epoch:9 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:2100/8000 train_loss:2.6367 train_time:1953878ms step_avg:930.42ms tok_s:13066 +step:2200/8000 train_loss:2.8326 train_time:2014671ms step_avg:915.76ms tok_s:18358 +step:2300/8000 train_loss:2.5844 train_time:2067539ms step_avg:898.93ms tok_s:16127 +WARNING: starting epoch:10 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:2400/8000 train_loss:2.7680 
train_time:2119760ms step_avg:883.23ms tok_s:10726 +step:2500/8000 train_loss:2.6585 train_time:2185033ms step_avg:874.01ms tok_s:17062 +step:2500/8000 val_loss:2.7747 val_bpb:1.6645 train_time:2212966ms step_avg:885.19ms +WARNING: starting epoch:11 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:2600/8000 train_loss:2.5407 train_time:2283512ms step_avg:878.27ms tok_s:13836 +step:2700/8000 train_loss:2.6851 train_time:2343790ms step_avg:868.07ms tok_s:15095 +step:2800/8000 train_loss:2.5932 train_time:2396966ms step_avg:856.06ms tok_s:14491 +WARNING: starting epoch:12 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:2900/8000 train_loss:2.4829 train_time:2448298ms step_avg:844.24ms tok_s:14503 +step:3000/8000 train_loss:2.2790 train_time:2501150ms step_avg:833.72ms tok_s:13091 +step:3000/8000 val_loss:2.7098 val_bpb:1.6256 train_time:2520550ms step_avg:840.18ms +WARNING: starting epoch:13 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:3100/8000 train_loss:2.4814 train_time:2571753ms step_avg:829.60ms tok_s:19027 +step:3200/8000 train_loss:2.4517 train_time:2629708ms step_avg:821.78ms tok_s:10564 +step:3300/8000 train_loss:2.4467 train_time:2697068ms step_avg:817.29ms tok_s:8753 +WARNING: starting epoch:14 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:3400/8000 train_loss:2.1956 train_time:2762776ms step_avg:812.58ms tok_s:12540 +step:3500/8000 train_loss:2.3340 train_time:2812113ms step_avg:803.46ms tok_s:14590 +step:3500/8000 val_loss:2.6784 val_bpb:1.6067 train_time:2830801ms step_avg:808.80ms +WARNING: starting epoch:15 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:3600/8000 train_loss:2.3582 train_time:2880532ms step_avg:800.15ms tok_s:14076 +step:3700/8000 train_loss:2.5187 train_time:2931263ms step_avg:792.23ms tok_s:16329 +step:3800/8000 train_loss:2.3609 train_time:2989761ms step_avg:786.78ms tok_s:14400 +WARNING: starting epoch:16 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:3900/8000 train_loss:2.1710 
train_time:3040262ms step_avg:779.55ms tok_s:17120 +step:4000/8000 train_loss:2.1075 train_time:3091462ms step_avg:772.87ms tok_s:15722 +step:4000/8000 val_loss:2.6269 val_bpb:1.5758 train_time:3111551ms step_avg:777.89ms +WARNING: starting epoch:17 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:4100/8000 train_loss:2.1496 train_time:3165084ms step_avg:771.97ms tok_s:17729 +step:4200/8000 train_loss:2.0959 train_time:3224621ms step_avg:767.77ms tok_s:11452 +step:4300/8000 train_loss:2.1698 train_time:3279915ms step_avg:762.77ms tok_s:18036 +WARNING: starting epoch:18 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:4400/8000 train_loss:1.6003 train_time:3328759ms step_avg:756.54ms tok_s:16514 +step:4500/8000 train_loss:2.0376 train_time:3386490ms step_avg:752.55ms tok_s:19331 +step:4500/8000 val_loss:2.6887 val_bpb:1.6129 train_time:3405814ms step_avg:756.85ms +step:4600/8000 train_loss:2.0654 train_time:3467923ms step_avg:753.90ms tok_s:14595 +WARNING: starting epoch:19 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:4700/8000 train_loss:2.1445 train_time:3523544ms step_avg:749.69ms tok_s:15673 +step:4800/8000 train_loss:2.0930 train_time:3580221ms step_avg:745.88ms tok_s:15827 +WARNING: starting epoch:20 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:4900/8000 train_loss:1.9573 train_time:3650239ms step_avg:744.95ms tok_s:8644 +step:5000/8000 train_loss:1.9470 train_time:3731062ms step_avg:746.21ms tok_s:12773 +step:5000/8000 val_loss:2.7559 val_bpb:1.6532 train_time:3757833ms step_avg:751.57ms +step:5100/8000 train_loss:1.9906 train_time:3840018ms step_avg:752.94ms tok_s:7147 +WARNING: starting epoch:21 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:5200/8000 train_loss:1.9154 train_time:3909015ms step_avg:751.73ms tok_s:8294 +step:5300/8000 train_loss:1.9134 train_time:4013237ms step_avg:757.21ms tok_s:8228 +WARNING: starting epoch:22 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:5400/8000 train_loss:1.8514 
train_time:4098439ms step_avg:758.97ms tok_s:8227 +step:5500/8000 train_loss:1.8591 train_time:4194148ms step_avg:762.57ms tok_s:13920 +step:5500/8000 val_loss:2.5653 val_bpb:1.5389 train_time:4228188ms step_avg:768.76ms +step:5600/8000 train_loss:1.7537 train_time:4294676ms step_avg:766.91ms tok_s:11977 +WARNING: starting epoch:23 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:5700/8000 train_loss:1.5535 train_time:4376402ms step_avg:767.79ms tok_s:7426 +step:5800/8000 train_loss:2.0198 train_time:4480895ms step_avg:772.57ms tok_s:7803 +WARNING: starting epoch:24 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:5900/8000 train_loss:1.8775 train_time:4594944ms step_avg:778.80ms tok_s:7498 +step:6000/8000 train_loss:1.7054 train_time:4706496ms step_avg:784.42ms tok_s:8737 +step:6000/8000 val_loss:2.6019 val_bpb:1.5608 train_time:4749245ms step_avg:791.54ms +step:6100/8000 train_loss:1.8401 train_time:4857864ms step_avg:796.37ms tok_s:7540 +WARNING: starting epoch:25 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:6200/8000 train_loss:1.7785 train_time:4966308ms step_avg:801.02ms tok_s:7757 +step:6300/8000 train_loss:1.5862 train_time:5057735ms step_avg:802.82ms tok_s:11338 +WARNING: starting epoch:26 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:6400/8000 train_loss:1.6917 train_time:5123139ms step_avg:800.49ms tok_s:13098 +step:6500/8000 train_loss:1.8193 train_time:5185412ms step_avg:797.76ms tok_s:19089 +step:6500/8000 val_loss:2.4147 val_bpb:1.4485 train_time:5204514ms step_avg:800.69ms +step:6600/8000 train_loss:1.3321 train_time:5253440ms step_avg:795.98ms tok_s:15168 +WARNING: starting epoch:27 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:6700/8000 train_loss:1.6324 train_time:5304115ms step_avg:791.66ms tok_s:16031 +step:6800/8000 train_loss:1.5926 train_time:5354648ms step_avg:787.45ms tok_s:15180 +step:6900/8000 train_loss:1.5886 train_time:5407127ms step_avg:783.64ms tok_s:16321 +WARNING: starting epoch:28 
dataset:fineweb10B_sp1024_valonly train_shards:1 +step:7000/8000 train_loss:1.4824 train_time:5461030ms step_avg:780.15ms tok_s:16057 +step:7000/8000 val_loss:2.1185 val_bpb:1.2708 train_time:5482548ms step_avg:783.22ms +step:7100/8000 train_loss:1.3975 train_time:5537858ms step_avg:779.98ms tok_s:16740 +WARNING: starting epoch:29 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:7200/8000 train_loss:1.4491 train_time:5593425ms step_avg:776.86ms tok_s:13056 +step:7300/8000 train_loss:1.1825 train_time:5648820ms step_avg:773.81ms tok_s:13631 +step:7400/8000 train_loss:1.4061 train_time:5702472ms step_avg:770.60ms tok_s:14613 +WARNING: starting epoch:30 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:7500/8000 train_loss:1.2123 train_time:5757516ms step_avg:767.67ms tok_s:11877 +step:7500/8000 val_loss:1.9025 val_bpb:1.1413 train_time:5779150ms step_avg:770.55ms +step:7600/8000 train_loss:1.1545 train_time:5834539ms step_avg:767.70ms tok_s:14475 +WARNING: starting epoch:31 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:7700/8000 train_loss:1.1700 train_time:5895871ms step_avg:765.70ms tok_s:13808 +step:7800/8000 train_loss:1.0472 train_time:5954373ms step_avg:763.38ms tok_s:12287 +step:7900/8000 train_loss:1.0992 train_time:6014867ms step_avg:761.38ms tok_s:15957 +WARNING: starting epoch:32 dataset:fineweb10B_sp1024_valonly train_shards:1 +step:8000/8000 train_loss:0.9468 train_time:6080684ms step_avg:760.09ms tok_s:11609 +step:8000/8000 val_loss:1.3400 val_bpb:0.8039 train_time:6105211ms step_avg:763.15ms +saved_model:logs/4da3a6be-f4c9-4b2c-a0f8-084f0e5751df_mlx_model.npz bytes:67212188 +serialized_model_int8_zlib:15412175 bytes (payload:17178912 raw_pickle:17188360 payload_ratio:3.91x) +final_eval_mode:sliding_window sw_stride:64 seq_len:1024 +final_int8_zlib_roundtrip val_loss:1.2018 val_bpb:0.7209 eval_time:388167ms +final_int8_zlib_roundtrip_exact val_loss:1.20176589 val_bpb:0.72092014 diff --git 
a/records/track_non_record_16mb/2026-03-19_ValOnly_SlidingWindow_mlx/train_gpt.py b/records/track_non_record_16mb/2026-03-19_ValOnly_SlidingWindow_mlx/train_gpt.py new file mode 100644 index 000000000..51e771661 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-19_ValOnly_SlidingWindow_mlx/train_gpt.py @@ -0,0 +1,1247 @@ +#!/usr/bin/env python3 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. +""" +from __future__ import annotations + +import glob +import json +import math +import os +import pickle +import sys +import time +import uuid +import zlib +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import sentencepiece as spm + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +from mlx.utils import tree_flatten, tree_unflatten + +# ============================================================================== +# SHARD FORMAT + COMPUTE DTYPE +# ============================================================================== + +COMPUTE_DTYPE = mx.bfloat16 + +# ============================================================================== +# HYPERPARAMETERS +# ============================================================================== +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap +class Hyperparameters: + # Data / tokenizer. 
+    data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")  # root holding fineweb_{train,val}_*.bin shards
+    tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")  # SentencePiece model file
+    run_id: str = os.environ.get("RUN_ID", str(uuid.uuid4()))  # fresh UUID4 per process unless pinned via env
+    seed: int = int(os.environ.get("SEED", 1337))  # global RNG seed
+
+    # Training loop. These defaults now mirror train_gpt.py on a single process.
+    iterations: int = int(os.environ.get("ITERATIONS", 20_000))  # max optimizer steps
+    val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0))  # eval every N steps; 0 disables periodic eval
+    # Validation always uses the full fineweb_val split.
+    val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288))  # tokens per eval pass
+    train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200))  # step-log cadence
+    train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))  # tokens per optimizer step
+    grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8))  # microbatches accumulated per step
+    train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024)))  # TRAIN_MAX_SEQ_LEN kept as legacy alias
+    # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak
+    # memory pressure without changing the effective optimizer batch.
+    mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192))
+    warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20))  # pre-run priming steps, not counted toward train time
+    warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200))  # length of final linear LR decay (see lr_mul)
+    max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))  # >0: time budget; lr_mul rescales warmdown to fit
+
+    # Model (defaults match the current baseline setup).
+ vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers: int = int(os.environ.get("NUM_LAYERS", 9)) + model_dim: int = int(os.environ.get("MODEL_DIM", 512)) + num_heads: int = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4)) + mlp_mult: int = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0)) + + # v2 improvements: SwiGLU activation + QAT + use_swiglu: bool = bool(int(os.environ.get("USE_SWIGLU", "0"))) + use_qat: bool = bool(int(os.environ.get("USE_QAT", "0"))) + sw_stride: int = int(os.environ.get("SW_STRIDE", "0")) # sliding window eval stride (0=disabled, 64=typical) + logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0)) + qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Optimizer. We keep the same per-group defaults as train_gpt.py. 
+ beta1: float = float(os.environ.get("BETA1", 0.9)) + beta2: float = float(os.environ.get("BETA2", 0.95)) + adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8)) + tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05)) + matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + out_dir: str = os.environ.get("OUT_DIR", "logs") + + @property + def train_files(self) -> str: + return f"{self.data_path}/fineweb_train_*.bin" + + @property + def val_files(self) -> str: + return f"{self.data_path}/fineweb_val_*.bin" + + @property + def microbatch_tokens(self) -> int: + return self.train_batch_tokens // self.grad_accum_steps + + def lr_mul(self, step: int, elapsed_ms: float) -> float: + if self.warmdown_iters <= 0: + return 1.0 + if self.max_wallclock_seconds <= 0: + warmdown_start = max(self.iterations - self.warmdown_iters, 0) + return max((self.iterations - step) / max(self.warmdown_iters, 1), 0.0) if warmdown_start <= step < self.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = self.warmdown_iters * step_ms + remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + 
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)


def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]:
    """Split a token budget into chunk sizes, each a multiple of seq_len.

    Drops the remainder of `total_tokens` that does not fill a whole sequence;
    raises ValueError if the budget cannot hold even one sequence.
    """
    usable_total = (total_tokens // seq_len) * seq_len
    if usable_total <= 0:
        raise ValueError(f"token budget too small for seq_len={seq_len}")
    usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len)
    chunks: list[int] = []
    remaining = usable_total
    while remaining > 0:
        chunk = min(remaining, usable_chunk)
        chunks.append(chunk)
        remaining -= chunk
    return chunks


def accumulate_flat_grads(
    accum: dict[str, mx.array] | None,
    grads_tree: dict,
    scale: float,
) -> dict[str, mx.array]:
    """Accumulate `scale * grads_tree` into a flat {name: array} dict.

    First call (accum=None) builds the dict; later calls mutate `accum` in place
    and return it, so the caller keeps reassigning the result.
    """
    flat = dict(tree_flatten(grads_tree))
    if accum is None:
        return {k: g * scale for k, g in flat.items()}
    for k, g in flat.items():
        accum[k] = accum[k] + g * scale
    return accum


# ==============================================================================
# MATH HELPERS
# ==============================================================================

def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
    """Weightless RMS normalization over the last axis, preserving x's dtype."""
    return (x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)).astype(x.dtype)


def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array:
    # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration.
    # Muon uses this to normalize matrix-shaped gradients before applying them.
    # Background on Muon: https://kellerjordan.github.io/posts/muon/
    # The iteration runs in fp32 and transposes tall matrices so the Gram matrix
    # (x @ x.T) is built on the smaller dimension.
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g.astype(mx.float32)
    x = x / (mx.sqrt(mx.sum(x * x)) + eps)
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T
    for _ in range(steps):
        a_mat = x @ x.T
        b_mat = b * a_mat + c * (a_mat @ a_mat)
        x = a * x + b_mat @ x
    if transposed:
        x = x.T
    return x.astype(g.dtype)


def fake_quantize_per_row(w: mx.array) -> mx.array:
    """Simulate per-row int8 quantization with Straight-Through Estimator (STE).
    Only applied to matrices large enough to be int8-quantized at export (numel > 65536).
    Forward pass sees dequantized weights; gradients pass through as identity (STE).
    """
    if w.ndim != 2 or w.size <= INT8_KEEP_FLOAT_MAX_NUMEL:
        return w
    w_f32 = w.astype(mx.float32)
    # Per-row symmetric scale; floor at 1/127 so all-zero rows do not divide by zero.
    scale = mx.maximum(mx.max(mx.abs(w_f32), axis=1, keepdims=True) / 127.0, mx.array(1.0 / 127.0))
    w_q = mx.round(mx.clip(w_f32 / scale, -127.0, 127.0)) * scale
    # STE: forward=dequantized, backward=identity
    return (mx.stop_gradient(w_q - w_f32) + w_f32).astype(w.dtype)


def load_data_shard(path: Path) -> np.ndarray:
    # NOTE(review): the source text is corrupted from this point — everything
    # between `np.dtype("` below and the dangling `None:` fragment further down
    # (the remainder of this function, the `TokenStream` class header and
    # `__init__`, and the `def next_file(self) ->` signature) was lost in
    # extraction. The surviving fragments are preserved verbatim; recover the
    # missing lines from the original file before treating this as runnable.
    header_bytes = 256 * np.dtype(" None:
        # (body of TokenStream.next_file — advances to the next shard file,
        # wrapping around and bumping the epoch counter at the end of the list)
        self.file_idx = (self.file_idx + 1) % len(self.files)
        if self.file_idx == 0:
            self.epoch += 1
            # presumably the epoch warning fires once per wrap — indentation of
            # this log block was lost in extraction; TODO confirm nesting.
            if self.log_fn is not None:
                self.log_fn(
                    f"WARNING: starting epoch:{self.epoch} "
                    f"dataset:{self.dataset_name} train_shards:{len(self.files)}"
                )
        self.tokens = load_data_shard(self.files[self.file_idx])
        self.pos = 0

    def take(self, n: int) -> np.ndarray:
        # Return exactly n tokens, spanning shard boundaries as needed; avoids a
        # copy when the request fits inside the current shard.
        chunks: list[np.ndarray] = []
        left = n
        while left > 0:
            if self.pos >= self.tokens.size:
                self.next_file()
            k = min(left, int(self.tokens.size - self.pos))
            chunks.append(self.tokens[self.pos : self.pos + k])
            self.pos += k
            left -= k
        return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0)


class TokenLoader:
    # Thin wrapper that turns the flat TokenStream into (input, target) batches.
    def __init__(
        self,
        pattern: str,
        log_fn: Callable[[str], None] | None = None,
        dataset_name: str = "",
    ):
        self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name)

    def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]:
        """Return (x, y) int32 batches of shape [batch_tokens//seq_len, seq_len];
        y is x shifted one token ahead (next-token targets)."""
        usable = (batch_tokens // seq_len) * seq_len
        if usable <= 0:
            raise ValueError(f"token budget too small for seq_len={seq_len}")
        chunk = self.stream.take(usable + 1)
        x = chunk[:-1].reshape(-1, seq_len)
        y = chunk[1:].reshape(-1, seq_len)
        return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32)


# ==============================================================================
# MODEL BLOCKS
# ==============================================================================

class CastedLinear(nn.Module):
    # Bias-free linear layer with an fp32 master weight; casts to the activation
    # dtype at call time and applies QAT fake-quantization during training.
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Borrow nn.Linear's initializer, then keep only the fp32 weight.
        self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32)

    def __call__(self, x: mx.array, training: bool = False) -> mx.array:
        w = fake_quantize_per_row(self.weight) if training else self.weight
        return x @ w.astype(x.dtype).T


class RMSNormNoWeight(nn.Module):
    # MLX module wrapper around the functional RMSNorm helper so it composes nicely in blocks.
    def __call__(self, x: mx.array) -> mx.array:
        return rms_norm(x)


class CausalSelfAttention(nn.Module):
    # - separate q/k/v projections
    # - RMSNorm on q and k before attention
    # - RoPE on q and k
    # - causal masked SDPA
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim)
        self.c_k = CastedLinear(dim, kv_dim)
        self.c_v = CastedLinear(dim, kv_dim)
        self.proj = CastedLinear(dim, dim)
        # Learned per-head gain applied to the (normalized) queries.
        self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init
        self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base)
        self.scale = self.head_dim ** -0.5

    def __call__(self, x: mx.array, training: bool = False) -> mx.array:
        bsz, seqlen, dim = x.shape
        # Reshape to [B, heads, S, head_dim]; k/v carry fewer heads (GQA) and the
        # SDPA kernel broadcasts them across the query-head groups.
        q = self.c_q(x, training=training).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = self.c_k(x, training=training).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
        v = self.c_v(x, training=training).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)

        q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE))
        k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE))
        q = q * self.q_gain.astype(q.dtype)[None, :, None, None]
        y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal")
        y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim)
        return self.proj(y, training=training)


class MLP(nn.Module):
    # Supports relu^2 (baseline) or SwiGLU (USE_SWIGLU=1).
    # SwiGLU uses hidden = (dim * mlp_mult * 2) // 3 to keep parameter count neutral.
    def __init__(self, dim: int, mlp_mult: int, use_swiglu: bool = False):
        super().__init__()
        self.use_swiglu = use_swiglu
        if use_swiglu:
            # 2/3 hidden keeps param count (~same as relu^2 with mlp_mult)
            hidden = max(4, (dim * mlp_mult * 2 + 2) // 3)
            self.gate = CastedLinear(dim, hidden)
            self.up = CastedLinear(dim, hidden)
            self.proj = CastedLinear(hidden, dim)
        else:
            hidden = dim * mlp_mult
            self.fc = CastedLinear(dim, hidden)
            self.proj = CastedLinear(hidden, dim)

    def __call__(self, x: mx.array, training: bool = False) -> mx.array:
        if self.use_swiglu:
            gate = nn.silu(self.gate(x, training=training))
            up = self.up(x, training=training)
            return self.proj(gate * up, training=training)
        # Baseline path: relu(x)^2 activation.
        x = nn.relu(self.fc(x, training=training))
        return self.proj(x * x, training=training)


class Block(nn.Module):
    # Pre-norm transformer block with learned per-channel attention/MLP output
    # scales and a learned per-channel mix between the running residual stream
    # and the (normalized) embedding stream x0.
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        self.attn_norm = RMSNormNoWeight()
        self.mlp_norm = RMSNormNoWeight()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult, use_swiglu=False)  # overridden in GPT.__init__
        self.attn_scale = mx.ones((dim,), dtype=mx.float32)
        self.mlp_scale = mx.ones((dim,), dtype=mx.float32)
        # resid_mix[0] weights the residual stream, resid_mix[1] weights x0;
        # initialized to (1, 0) so the block starts as a plain residual block.
        self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32))))

    def __call__(self, x: mx.array, x0: mx.array, training: bool = False) -> mx.array:
        mix = self.resid_mix.astype(x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x), training=training)
        x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out
        x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x), training=training)
        return x


class GPT(nn.Module):
    # - token embedding + RMSNorm
    # - encoder half accumulates skip tensors
    # - decoder half consumes reversed skips with learned skip_weights
    # - tied embeddings for the LM head (the baseline default setup)
    def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int,
                 logit_chunk_tokens: int, logit_softcap: float, rope_base: float, tied_embed_init_std: float,
                 qk_gain_init: float, use_swiglu: bool = False):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.logit_chunk_tokens = logit_chunk_tokens
        self.logit_softcap = logit_softcap

        self.tok_emb = nn.Embedding(vocab_size, dim)
        # U-Net split: first half encoder, remainder decoder (decoder gets the
        # extra block when num_layers is odd); one skip weight per encoder block.
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32)
        self.blocks = [
            Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init)
            for i in range(num_layers)
        ]
        # Patch MLP activation after blocks are created
        if use_swiglu:
            for b in self.blocks:
                b.mlp = MLP(dim, mlp_mult, use_swiglu=True)
        self.final_norm = RMSNormNoWeight()

        # Zero-init the output projections so each block starts as identity.
        for b in self.blocks:
            b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight)
            b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight)
        self.tok_emb.weight = (
            mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std
        ).astype(COMPUTE_DTYPE)

    def softcap(self, logits: mx.array) -> mx.array:
        # Smoothly bound logits to (-c, c) via c * tanh(logits / c).
        c = self.logit_softcap
        return c * mx.tanh(logits / c)

    def __call__(self, input_ids: mx.array, training: bool = False) -> mx.array:
        """Run the backbone; returns final-norm hidden states (no LM head)."""
        x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE))
        x0 = x
        skips: list[mx.array] = []

        for i in range(self.num_encoder_layers):
            x = self.blocks[i](x, x0, training=training)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            # Odd layer counts have one more decoder block than encoder block. The baseline only
            # applies a skip connection when one exists, then runs the remaining decoder block(s)
            # without an added skip.
            if skips:
                x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self.num_encoder_layers + i](x, x0, training=training)
        return self.final_norm(x)

    def _loss(self, input_ids: mx.array, target_ids: mx.array, training: bool) -> mx.array:
        # Cross-entropy over flattened tokens. training=True enables QAT fake-quantization.
        x = self(input_ids, training=training).reshape(-1, self.tok_emb.weight.shape[1])
        y = target_ids.reshape(-1)
        # Apply QAT to tok_emb.weight in LM-head projection (it gets int8-quantized at export)
        lm_w = fake_quantize_per_row(self.tok_emb.weight) if training else self.tok_emb.weight
        if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens:
            logits_proj = x @ lm_w.astype(x.dtype).T
            logits = self.softcap(logits_proj)
            return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean")

        # Chunked path: bound peak memory of the [tokens, vocab] logit matrix.
        loss_sum = mx.array(0.0, dtype=mx.float32)
        n = int(x.shape[0])
        for s in range(0, n, self.logit_chunk_tokens):
            e = min(s + self.logit_chunk_tokens, n)
            logits_proj = x[s:e] @ lm_w.astype(x.dtype).T
            logits = self.softcap(logits_proj)
            loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum")
        return loss_sum / float(n)

    def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array:
        """Training loss — uses QAT fake-quantization when enabled."""
        return self._loss(input_ids, target_ids, training=True)

    def eval_loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array:
        """Eval loss — always uses full-precision weights (no QAT)."""
        return self._loss(input_ids, target_ids, training=False)

    def partial_loss_sum(self, input_ids: mx.array, target_ids: mx.array, eval_from: int) -> mx.array:
        """Sum of cross-entropy for positions [eval_from:] per sequence. Eval mode (no QAT).
        Used by sliding window eval to count only the new tokens with full context."""
        x = self(input_ids, training=False)  # [B, S, D]
        bsz, seqlen, d = x.shape
        x_eval = x[:, eval_from:, :].reshape(-1, d)
        y_eval = target_ids[:, eval_from:].reshape(-1)
        lm_w = self.tok_emb.weight.astype(x_eval.dtype)
        logits = self.softcap(x_eval @ lm_w.T)
        return nn.losses.cross_entropy(logits.astype(mx.float32), y_eval, reduction="sum")

# ==============================================================================
# OPTIMIZERS (MUON + ADAM SPLIT)
# ==============================================================================
class Muon:
    # Muon applies SGD-momentum to matrix gradients, then orthogonalizes the result before the
    # parameter update.
    def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters):
        self.keys = keys
        self.args = args
        # One momentum buffer per managed matrix.
        self.buffers = {k: mx.zeros_like(params[k]) for k in keys}

    def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]:
        """Return updated values for the matrices in self.keys (others untouched)."""
        # Linearly warm momentum up from warmup_start to the final value.
        if self.args.muon_momentum_warmup_steps:
            t = min(step / self.args.muon_momentum_warmup_steps, 1.0)
            momentum = (1.0 - t) * self.args.muon_momentum_warmup_start + t * self.args.muon_momentum
        else:
            momentum = self.args.muon_momentum
        lr = self.args.matrix_lr * lr_mul
        out: dict[str, mx.array] = {}
        for k in self.keys:
            p = params[k]
            g = grads[k]
            buf = momentum * self.buffers[k] + g
            self.buffers[k] = buf
            # Nesterov-style lookahead before orthogonalization.
            g_eff = g + momentum * buf
            g_ortho = zeropower_newtonschulz5(g_eff, self.args.muon_backend_steps)
            # Shape-dependent LR scale for tall matrices (rows > cols).
            scale = math.sqrt(max(1.0, float(p.shape[0]) / float(p.shape[1])))
            out[k] = p - lr * (g_ortho * scale).astype(p.dtype)
        return out


class SplitOptimizers:
    # - embeddings: Adam with the tied-embedding LR
    # - block matrices (2D): Muon
    # - block scalars + skip weights: Adam
    # This
    # preserves the high-level optimization behavior even though MLX internals differ.
    def __init__(self, model: GPT, args: Hyperparameters):
        self.args = args
        params = dict(tree_flatten(model.parameters()))
        self.embed_key = "tok_emb.weight"
        # 2D block weights go to Muon, except named control tensors.
        self.matrix_keys = [
            k
            for k, p in params.items()
            if k.startswith("blocks.") and p.ndim == 2 and not any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS)
        ]
        # Everything small or control-like (plus skip_weights) goes to scalar Adam.
        self.scalar_keys = [
            k
            for k, p in params.items()
            if k == "skip_weights" or (k.startswith("blocks.") and (p.ndim < 2 or any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS)))
        ]

        self.muon = Muon(self.matrix_keys, params, args)
        self.adam_embed = optim.Adam(
            learning_rate=args.tied_embed_lr,
            betas=[args.beta1, args.beta2],
            eps=args.adam_eps,
            bias_correction=True,
        )
        self.adam_scalar = optim.Adam(
            learning_rate=args.scalar_lr,
            betas=[args.beta1, args.beta2],
            eps=args.adam_eps,
            bias_correction=True,
        )

    def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None:
        """Apply one optimizer step across all three parameter groups in place."""
        params = dict(tree_flatten(model.parameters()))
        grads = dict(tree_flatten(grads_tree))
        updated = dict(params)

        updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul))

        # LR schedule is applied by mutating the Adam learning rates each step.
        self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul
        updated.update(
            self.adam_embed.apply_gradients(
                {self.embed_key: grads[self.embed_key]},
                {self.embed_key: params[self.embed_key]},
            )
        )

        self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul
        scalar_grads = {k: grads[k] for k in self.scalar_keys}
        scalar_params = {k: params[k] for k in self.scalar_keys}
        updated.update(self.adam_scalar.apply_gradients(scalar_grads, scalar_params))

        model.update(tree_unflatten(list(updated.items())))

# ==============================================================================
# QUANTIZATION (INT8 + ZLIB)
# ==============================================================================
# - per-row int8 for 2D float tensors
# - per-tensor int8 for other float tensors
# - fp16 passthrough for small float tensors
# - exact passthrough for non-floats

MX_DTYPE_FROM_NAME = {
    "float32": mx.float32,
    "float16": mx.float16,
    "bfloat16": mx.bfloat16,
}

# Float tensors at or below this element count skip int8 and stay float.
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
INT8_KEEP_FLOAT_STORE_DTYPE = np.float16
INT8_PER_ROW_SCALE_DTYPE = np.float16
# Outlier clipping percentile used when picking quantization scales.
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0


def _np_float32(arr: mx.array) -> np.ndarray:
    # Materialize an MLX array as a float32 numpy array.
    return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False)


def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray:
    """Store a small float tensor without int8 quantization.

    Control tensors (by name pattern) stay fp32; other fp32/bf16 tensors are
    downcast to fp16 for storage and their original dtype recorded in
    `passthrough_orig_dtypes` (mutated in place) so loading can restore it.
    """
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        return np.ascontiguousarray(_np_float32(arr))
    if arr.dtype in {mx.float32, mx.bfloat16}:
        passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1]
        return np.ascontiguousarray(np.array(arr.astype(mx.float16), dtype=INT8_KEEP_FLOAT_STORE_DTYPE, copy=False))
    return np.ascontiguousarray(np.array(arr, copy=True))


def quantize_float_array(arr: mx.array) -> tuple[np.ndarray, np.ndarray]:
    """Quantize a float tensor to int8; returns (int8 values, scale(s)).

    2D tensors get a per-row scale vector; everything else a single scalar scale.
    Values are clipped at the INT8_CLIP_Q quantile before scaling to limit the
    damage from extreme outliers.
    """
    f32 = _np_float32(arr)
    if f32.ndim == 2:
        # Matrices get one scale per row, which usually tracks output-channel
        # ranges much better than a single tensor-wide scale.
        clip_abs = np.quantile(np.abs(f32), INT8_CLIP_Q, axis=1) if f32.size else np.empty((f32.shape[0],), dtype=np.float32)
        clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None])
        scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0).astype(np.float32, copy=False)
        q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8, copy=False)
        return np.ascontiguousarray(q), np.ascontiguousarray(scale.astype(INT8_PER_ROW_SCALE_DTYPE, copy=False))

    # Vectors / scalars use a simpler per-tensor scale.
    clip_abs = float(np.quantile(np.abs(f32).reshape(-1), INT8_CLIP_Q)) if f32.size else 0.0
    scale = np.array(clip_abs / 127.0 if clip_abs > 0.0 else 1.0, dtype=np.float32)
    q = np.clip(np.round(np.clip(f32, -clip_abs, clip_abs) / scale), -127, 127).astype(np.int8, copy=False)
    return np.ascontiguousarray(q), scale


def quantize_state_dict_int8(flat_state: dict[str, mx.array]) -> tuple[dict[str, object], dict[str, int]]:
    """Quantize a flat {name: array} state dict for export.

    Returns (serializable quant object, stats dict). Non-float tensors and small
    float tensors pass through unquantized; large float tensors are stored as
    int8 plus scales, with per-row metadata recorded in `qmeta`.
    """
    quantized: dict[str, np.ndarray] = {}
    scales: dict[str, np.ndarray] = {}
    dtypes: dict[str, str] = {}
    passthrough: dict[str, np.ndarray] = {}
    passthrough_orig_dtypes: dict[str, str] = {}
    qmeta: dict[str, dict[str, object]] = {}
    stats = dict.fromkeys(
        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
        0,
    )
    for name, arr in flat_state.items():
        stats["param_count"] += int(arr.size)
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += int(arr.nbytes)
        if not mx.issubdtype(arr.dtype, mx.floating):
            stats["num_nonfloat_tensors"] += 1
            passthrough[name] = np.ascontiguousarray(np.array(arr))
            stats["int8_payload_bytes"] += int(passthrough[name].nbytes)
            continue

        # Small float tensors are cheap enough to keep directly. We still downcast
        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
        if int(arr.size) <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_array(name, arr, passthrough_orig_dtypes)
            passthrough[name] = kept
            stats["int8_payload_bytes"] += int(kept.nbytes)
            continue

        stats["num_float_tensors"] += 1
        q, s = quantize_float_array(arr)
        if s.ndim > 0:
            # Vector scale => this tensor used the per-row scheme.
            qmeta[name] = {"scheme": "per_row", "axis": 0}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(arr.dtype).split(".")[-1]
        stats["int8_payload_bytes"] += int(q.nbytes + s.nbytes)
    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats


def dequantize_state_dict_int8(quant_obj: dict[str, object]) -> dict[str, mx.array]:
    """Inverse of quantize_state_dict_int8: rebuild {name: mx.array}."""
    out: dict[str, mx.array] = {}
    qmeta = quant_obj.get("qmeta", {})
    passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {})
    for name, q in quant_obj["quantized"].items():
        q_np = np.asarray(q, dtype=np.int8)
        dtype_name = quant_obj["dtypes"][name]
        scale = np.asarray(quant_obj["scales"][name], dtype=np.float32)
        if qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0:
            # Broadcast the saved row scale back across trailing dimensions.
            out_arr = q_np.astype(np.float32) * scale.reshape((q_np.shape[0],) + (1,) * (q_np.ndim - 1))
        else:
            out_arr = q_np.astype(np.float32) * float(scale)
        out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[dtype_name])
    for name, arr in quant_obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
        out_arr = np.array(arr, copy=True)
        orig_dtype = passthrough_orig_dtypes.get(name)
        if isinstance(orig_dtype, str):
            out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig_dtype])
        else:
            out[name] = mx.array(out_arr)
    return out


def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build per-token lookup tables for the bytes-per-byte (bpb) metric.

    Returns (base_bytes, has_leading_space, is_boundary_token) arrays sized to
    cover both the SentencePiece vocab and the model vocab. A token's byte count
    is its UTF-8 length, minus the "▁" marker, plus one space byte when it has a
    leading space and its predecessor is not a boundary (control/unknown/unused)
    token — that adjustment happens at eval time using these tables.
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_lut = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_lut = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_lut = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_lut[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_lut[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_lut[token_id] = True
            piece = piece[1:]
        base_bytes_lut[token_id] = len(piece.encode("utf-8"))
    return base_bytes_lut, has_leading_space_lut, is_boundary_token_lut


def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | None]:
    # The shard directory and tokenizer are coupled: val_bpb is only meaningful if we
    # decode bytes with the exact tokenizer that produced the shards. The manifest
    # lets the training script fail fast on accidental dataset/tokenizer mismatches.
    dataset_dir = Path(data_path).resolve()
    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
    # All manifest checks are best-effort: a missing manifest or entry just
    # returns what we could observe locally instead of raising.
    if len(dataset_dir.parents) < 2:
        return dataset_dir.name, actual_train_files, None
    manifest_path = dataset_dir.parents[1] / "manifest.json"
    if not manifest_path.is_file():
        return dataset_dir.name, actual_train_files, None

    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None)
    if dataset_entry is None:
        return dataset_dir.name, actual_train_files, None

    tokenizer_name = dataset_entry.get("tokenizer_name")
    tokenizer_entry = (
        next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None)
        if tokenizer_name
        else None
    )
    expected_name = Path((tokenizer_entry or {}).get("model_path") or (tokenizer_entry or {}).get("path") or "").name
    if expected_name and Path(tokenizer_path).name != expected_name:
        raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}")
    expected_train_files = (dataset_entry.get("stats") or {}).get("files_train")
    if expected_train_files is not None:
        expected_train_files = int(expected_train_files)
        if actual_train_files > expected_train_files:
            raise ValueError(
                f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, "
                f"manifest says {expected_train_files}"
            )
    return dataset_dir.name, actual_train_files, expected_train_files


def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray:
    """Load and concatenate all validation shards, trimmed so that exactly
    usable_tokens + 1 remain (the +1 supplies the final shifted target)."""
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = np.ascontiguousarray(np.concatenate([load_data_shard(file) for file in files], axis=0))
    usable = ((tokens.size - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def loss_and_grad_chunked(
    args: Hyperparameters,
    train_loader: TokenLoader,
    compiled_loss_and_grad: Callable[[mx.array, mx.array], tuple[mx.array, dict]],
) -> tuple[mx.array, dict]:
    """One full optimizer microbatch split into memory-friendly chunks.

    Each chunk's loss/grads are weighted by its share of the total token count,
    so the result equals a single large-batch loss/grad.
    """
    chunk_sizes = token_chunks(args.microbatch_tokens, args.train_seq_len, args.mlx_max_microbatch_tokens)
    total_tokens = float(sum(chunk_sizes))
    loss_value = mx.array(0.0, dtype=mx.float32)
    grad_accum: dict[str, mx.array] | None = None
    for chunk_tokens in chunk_sizes:
        x, y = train_loader.next_batch(chunk_tokens, args.train_seq_len)
        loss, grads = compiled_loss_and_grad(x, y)
        scale = float(y.size) / total_tokens
        loss_value = loss_value + loss.astype(mx.float32) * scale
        grad_accum = accumulate_flat_grads(grad_accum, grads, scale)
    return loss_value, tree_unflatten(list(grad_accum.items()))


def eval_val(
    args: Hyperparameters,
    compiled_loss,
    val_tokens: np.ndarray,
    base_bytes_lut: np.ndarray,
    has_leading_space_lut: np.ndarray,
    is_boundary_token_lut: np.ndarray,
) -> tuple[float, float]:
    # Validation computes two metrics:
    #  - val_loss: token cross-entropy (natural log)
    #  - val_bpb: tokenizer-agnostic compression metric used by the challenge
    val_batch_tokens = args.val_batch_size // args.grad_accum_steps
    if val_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, "
            f"TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    val_batch_seqs = val_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.size - 1) // args.train_seq_len
    total_loss = mx.array(0.0, dtype=mx.float32)
    total_tokens = 0.0
    total_bytes = 0.0
    for batch_seq_start in range(0, total_seqs, val_batch_seqs):
        batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs)
        raw_start = batch_seq_start * args.train_seq_len
        raw_end = batch_seq_end * args.train_seq_len + 1
        chunk = val_tokens[raw_start:raw_end]
        x_np = chunk[:-1].reshape(-1, args.train_seq_len)
        y_np = chunk[1:].reshape(-1, args.train_seq_len)
        x = mx.array(x_np, dtype=mx.int32)
        y = mx.array(y_np, dtype=mx.int32)
        chunk_token_count = float(y.size)
        # compiled_loss returns a mean; re-weight by token count for a global mean.
        total_loss = total_loss + compiled_loss(x, y).astype(mx.float32) * chunk_token_count
        # Byte accounting for bpb: target token bytes, plus a space byte when the
        # target has a leading space and its predecessor is not a boundary token.
        prev_ids = x_np.reshape(-1)
        tgt_ids = y_np.reshape(-1)
        bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True)
        bytes_np += (
            has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]
        ).astype(np.int16, copy=False)
        total_tokens += chunk_token_count
        total_bytes += float(bytes_np.astype(np.float64).sum())
    total_loss = total_loss / total_tokens
    mx.eval(total_loss)
    val_loss = float(total_loss.item())
    bits_per_token = val_loss / math.log(2.0)
    val_bpb = bits_per_token * (total_tokens / total_bytes)
    return val_loss, val_bpb

def eval_val_sliding_window(
    args: Hyperparameters,
    model: GPT,
    val_tokens: np.ndarray,
    base_bytes_lut: np.ndarray,
    has_leading_space_lut: np.ndarray,
    is_boundary_token_lut: np.ndarray,
) -> tuple[float, float]:
    """Sliding window eval: every token (after the first window) is evaluated with
    seq_len - sw_stride tokens of left context instead of the standard random 0..seq_len-1.
    First window: loss counted over all seq_len positions (full context warmup).
    Subsequent windows: slide by sw_stride, count loss only for the last sw_stride positions.
    Total tokens evaluated ≈ same as standard eval; bpb improves because context is richer."""
    seq_len = args.train_seq_len
    stride = args.sw_stride
    eval_from = seq_len - stride
    n_val = val_tokens.size

    # Build two compiled eval functions once per call (MLX compiles on first invocation).
+ compiled_full = mx.compile( + lambda x, y: model.partial_loss_sum(x, y, 0), + inputs=model.state, outputs=model.state, + ) + compiled_stride_eval = mx.compile( + lambda x, y: model.partial_loss_sum(x, y, eval_from), + inputs=model.state, outputs=model.state, + ) + + # Window i starts at token offset i * stride. + # We need: i * stride + seq_len + 1 <= n_val → i <= (n_val - seq_len - 1) // stride + n_windows = (n_val - seq_len - 1) // stride + 1 + val_batch_tokens = args.val_batch_size // args.grad_accum_steps + batch_seqs = max(1, val_batch_tokens // seq_len) + + total_loss_sum = mx.array(0.0, dtype=mx.float32) + total_tokens = 0.0 + total_bytes = 0.0 + + for batch_start in range(0, n_windows, batch_seqs): + batch_end = min(batch_start + batch_seqs, n_windows) + # Build a numpy batch of windows + xs = np.stack([val_tokens[wi * stride: wi * stride + seq_len] for wi in range(batch_start, batch_end)]) + ys = np.stack([val_tokens[wi * stride + 1: wi * stride + seq_len + 1] for wi in range(batch_start, batch_end)]) + x = mx.array(xs, dtype=mx.int32) + y = mx.array(ys, dtype=mx.int32) + + if batch_start == 0: + # First window: count all seq_len positions. + ls_full = compiled_full(x[:1], y[:1]) + total_loss_sum = total_loss_sum + ls_full + prev_ids_0 = xs[0] + tgt_ids_0 = ys[0] + bts = base_bytes_lut[tgt_ids_0].astype(np.int16, copy=True) + bts += (has_leading_space_lut[tgt_ids_0] & ~is_boundary_token_lut[prev_ids_0]).astype(np.int16, copy=False) + total_tokens += float(seq_len) + total_bytes += float(bts.astype(np.float64).sum()) + mx.eval(total_loss_sum) + # Remaining windows in this batch (if any) use stride eval. 
+ if batch_end > 1: + ls_rest = compiled_stride_eval(x[1:], y[1:]) + total_loss_sum = total_loss_sum + ls_rest + for wi in range(1, batch_end - batch_start): + p_ids = xs[wi, eval_from:] + t_ids = ys[wi, eval_from:] + bts = base_bytes_lut[t_ids].astype(np.int16, copy=True) + bts += (has_leading_space_lut[t_ids] & ~is_boundary_token_lut[p_ids]).astype(np.int16, copy=False) + total_tokens += float(stride) + total_bytes += float(bts.astype(np.float64).sum()) + mx.eval(total_loss_sum) + else: + ls = compiled_stride_eval(x, y) + total_loss_sum = total_loss_sum + ls + for wi in range(batch_end - batch_start): + p_ids = xs[wi, eval_from:] + t_ids = ys[wi, eval_from:] + bts = base_bytes_lut[t_ids].astype(np.int16, copy=True) + bts += (has_leading_space_lut[t_ids] & ~is_boundary_token_lut[p_ids]).astype(np.int16, copy=False) + total_tokens += float(stride) + total_bytes += float(bts.astype(np.float64).sum()) + mx.eval(total_loss_sum) + + total_loss_avg = total_loss_sum / total_tokens + mx.eval(total_loss_avg) + val_loss = float(total_loss_avg.item()) + bits_per_token = val_loss / math.log(2.0) + val_bpb = bits_per_token * (total_tokens / total_bytes) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict: + if max_norm <= 0: + return grads_tree + flat = dict(tree_flatten(grads_tree)) + total_sq = 0.0 + for grad in flat.values(): + total_sq += float(np.sum(np.square(_np_float32(grad)), dtype=np.float64)) + if total_sq <= 0.0: + return grads_tree + total_norm = math.sqrt(total_sq) + if total_norm <= max_norm: + return grads_tree + scale = max_norm / (total_norm + 1e-12) + return tree_unflatten([(k, g * scale) for k, g in flat.items()]) + + +def main() -> None: + # ============================================================================== + # TOKENIZER + VALIDATION METRIC SETUP + # 
============================================================================== + args = Hyperparameters() + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + logfile = out_dir / f"{args.run_id}.txt" + print(logfile) + + def log(msg: str, console: bool = True) -> None: + if console: + print(msg) + with logfile.open("a", encoding="utf-8") as f: + print(msg, file=f) + + code = Path(__file__).read_text(encoding="utf-8") + log(code, console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running MLX {mx.__version__}", console=False) + log("=" * 100, console=False) + + if not args.tie_embeddings: + raise NotImplementedError("train_gpt_mlx.py only supports tied embeddings") + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair( + args.data_path, + args.tokenizer_path, + ) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size + ) + + # ============================================================================== + # TRAINING SETUP + # ============================================================================== + mx.random.seed(args.seed) + + train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) + + # ============================================================================== + # MODEL + OPTIMIZER SETUP + # ============================================================================== + model = GPT( + 
vocab_size=args.vocab_size, + num_layers=args.num_layers, + dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + logit_chunk_tokens=args.logit_chunk_tokens, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + tied_embed_init_std=args.tied_embed_init_std, + qk_gain_init=args.qk_gain_init, + use_swiglu=args.use_swiglu, + ) + opt = SplitOptimizers(model, args) + + # ============================================================================== + # COMPILED TRAIN / EVAL FUNCTIONS (MLX) + # ============================================================================== + # The crucial MLX detail is capture scope: this model contains non-trainable arrays too (for example + # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs". + # Compiling the model-bound functions and capturing the full model state fixes that while still + # returning gradients only for trainable parameters via nn.value_and_grad(...). + # eval always uses full-precision weights; training uses QAT when USE_QAT=1 + compiled_loss = mx.compile(lambda x, y: model.eval_loss(x, y), inputs=model.state, outputs=model.state) + compiled_loss_and_grad = mx.compile( + nn.value_and_grad(model, lambda x, y: model.loss(x, y) if args.use_qat else model.eval_loss(x, y)), + inputs=model.state, + outputs=model.state, + ) + + # Print config once so logs are self-describing. 
+ n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters())) + log(f"run_id:{args.run_id}") + log(f"mlx_version:{mx.__version__}") + log(f"train_loader:shards pattern={args.train_files}") + log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}") + if expected_train_files is None: + log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}") + elif actual_train_files < expected_train_files: + log( + f"WARNING: train_loader:subset dataset:{dataset_name} " + f"train_shards:{actual_train_files}/{expected_train_files} " + f"new epochs will arrive sooner than the full dataset" + ) + else: + log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}") + log(f"tokenizer_path:{args.tokenizer_path}") + log( + f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} " + f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} " + f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}" + ) + log( + f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} " + f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} " + f"val_batch_size:{args.val_batch_size} " + f"warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}") + log( + f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} " + f"embed_lr:{args.tied_embed_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} " + f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}" + ) + log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log(f"compute_dtype:{COMPUTE_DTYPE} compile:True") + log(f"v2_features swiglu:{args.use_swiglu} 
qat:{args.use_qat}") + log( + f"dtypes tok_emb:{model.tok_emb.weight.dtype} " + f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} " + f"skip_weights:{model.skip_weights.dtype}" + ) + + # ============================================================================== + # TRAINING LOOP + # ============================================================================== + if args.warmup_steps > 0: + # Warmup should only prime MLX compile/allocation paths. Updating parameters here forces us + # to snapshot and restore model/optimizer state, which is expensive on unified-memory Macs. + # Instead we run the real train shapes, force the loss/grads to materialize, and then reset + # the loader so measured training still starts from the true init and token window. + for warmup_step in range(args.warmup_steps): + accum: dict[str, mx.array] | None = None + warmup_loss = mx.array(0.0, dtype=mx.float32) + grad_scale = 1.0 / args.grad_accum_steps + for _ in range(args.grad_accum_steps): + warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) + accum = accumulate_flat_grads(accum, grads, grad_scale) + mx.eval(warmup_loss, accum) + mx.synchronize() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + + # Prime the standalone eval graph once too. It is compiled separately from value_and_grad. 
+ val_batch_tokens = args.val_batch_size // args.grad_accum_steps + if val_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + warm_val_seqs = min(val_batch_tokens // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len) + warm_chunk = val_tokens[: warm_val_seqs * args.train_seq_len + 1] + x_val = mx.array(warm_chunk[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32) + y_val = mx.array(warm_chunk[1:].reshape(-1, args.train_seq_len), dtype=mx.int32) + warm_val_loss = compiled_loss(x_val, y_val) + mx.eval(warm_val_loss) + mx.synchronize() + + train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) + + train_time_ms = 0.0 + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + stop_after_step: int | None = None + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # Validation always scans the same fixed full validation split. 
+ val_loss, val_bpb = eval_val( + args, + compiled_loss, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + train_time_ms += 1000.0 * (time.perf_counter() - t0) + if step % 25 == 0 or last_step: + log( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms / max(step, 1):.2f}ms" + ) + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0)) + step_t0 = time.perf_counter() + + accum: dict[str, mx.array] | None = None + train_loss = mx.array(0.0, dtype=mx.float32) + grad_scale = 1.0 / args.grad_accum_steps + for _ in range(args.grad_accum_steps): + loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) + accum = accumulate_flat_grads(accum, grads, grad_scale) + train_loss = train_loss + loss.astype(mx.float32) * grad_scale + + grads = tree_unflatten(list(accum.items())) + grads = clip_grad_tree(grads, args.grad_clip_norm) + train_loss_value = float(train_loss.item()) + opt.step(model, grads, step=step, lr_mul=lr_mul) + mx.synchronize() + + step_ms = 1000.0 * (time.perf_counter() - step_t0) + approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0) + tok_s = args.train_batch_tokens / (step_ms / 1000.0) + step += 1 + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None): + log( + f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} " + f"train_time:{approx_train_time_ms:.0f}ms step_avg:{approx_train_time_ms / step:.2f}ms tok_s:{tok_s:.0f}" + ) + if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms: + stop_after_step = step + + # 
============================================================================== + # FINAL SERIALIZATION + QUANTIZED ROUNDTRIP EVAL + # ============================================================================== + # We always write a raw artifact and a quantized artifact, then validate the + # quantized roundtrip directly by loading the dequantized tensors back into the + # model and running one final validation pass. + out_path = out_dir / f"{args.run_id}_mlx_model.npz" + flat_state = {k: v for k, v in tree_flatten(model.state)} + mx.savez(str(out_path), **flat_state) + log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}") + + quant_obj, quant_stats = quantize_state_dict_int8(flat_state) + quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL) + quant_blob = zlib.compress(quant_raw, level=9) + quant_serialized_bytes = len(quant_raw) + quant_path = out_dir / f"{args.run_id}_mlx_model.int8.ptz" + with quant_path.open("wb") as f: + f.write(quant_blob) + quant_file_bytes = quant_path.stat().st_size + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log( + f"serialized_model_int8_zlib:{quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_pickle:{quant_serialized_bytes} payload_ratio:{ratio:.2f}x)" + ) + + with quant_path.open("rb") as f: + quant_blob_disk = f.read() + quant_flat = dequantize_state_dict_int8(pickle.loads(zlib.decompress(quant_blob_disk))) + model.update(tree_unflatten(list(quant_flat.items()))) + q_t0 = time.perf_counter() + if args.sw_stride > 0: + log(f"final_eval_mode:sliding_window sw_stride:{args.sw_stride} seq_len:{args.train_seq_len}") + q_val_loss, q_val_bpb = eval_val_sliding_window( + args, + model, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + else: + q_val_loss, q_val_bpb = eval_val( + args, + compiled_loss, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + q_eval_ms = 
1000.0 * (time.perf_counter() - q_t0) + log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms") + log(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + +if __name__ == "__main__": + main()