From d71f5da72fd7419f729265b4242a9652d17ffc33 Mon Sep 17 00:00:00 2001 From: xexyz Date: Wed, 25 Mar 2026 01:05:38 -0600 Subject: [PATCH 1/3] Add PR #414 + 30-epoch cosine TTT submission (1.0988 BPB) 30-epoch cosine pre-eval Test-Time Training on PR #414 consensus stack. Adapts quantized model on validation data before sliding-window eval. - Pre-TTT post-quant: 1.1594 BPB - Post-TTT sliding (stride=64): 1.0988 BPB - Total artifact: 15,900,191 bytes (under 16MB) - 5434 training steps + 30ep TTT + sliding eval on 8xH100 Built on PR #414 by @signalrush. TTT recipe from PR #518/@sofiabod, PR #672/@andrewbaggio1. --- .../README.md | 54 + .../submission.json | 11 + .../train.log | 97 ++ .../train_gpt.py | 1480 +++++++++++++++++ 4 files changed, 1642 insertions(+) create mode 100644 records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/README.md create mode 100644 records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/submission.json create mode 100644 records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train.log create mode 100644 records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train_gpt.py diff --git a/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/README.md b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/README.md new file mode 100644 index 000000000..f5dd4434b --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/README.md @@ -0,0 +1,54 @@ +# PR #414 Stack + 30-Epoch Cosine TTT + +**val_bpb: 1.0988** (8xH100 SXM, seed=1337, stride=64 sliding window eval) + +## Summary + +Adds 30-epoch cosine pre-eval Test-Time Training (TTT) on top of the PR #414 consensus stack. TTT adapts the quantized model on validation data before the final sliding-window eval, recovering quantization loss and further improving BPB through domain adaptation. 
+ +## Key Addition: Cosine Pre-Eval TTT + +After int6 quantization and roundtrip eval, the model is fine-tuned on validation data for 30 epochs with cosine LR decay before the final sliding-window eval: + +- AdamW optimizer, base LR=0.0005, weight_decay=0.0 +- Per-layer LR groups: `mlp.proj` 3x, `mlp.fc` 0.5x, others 1x +- Cosine LR schedule across all TTT steps +- DDP gradient sync (all_reduce AVG) +- Batch size: 32 sequences per rank +- Gradient clipping: 1.0 + +TTT runs within the 10-minute eval budget (~8 min TTT + ~2 min sliding eval). + +## Architecture (PR #414 stack) + +- 11 layers, 512d, 8H, 4KV (GQA) +- 3x MLP with relu² +- SmearGate + BigramHash (2048 buckets) +- XSA on last 4 layers +- Partial RoPE (16/64 dims), LN Scale +- VE128 on layers 9-10 +- EMA(0.997) + Tight SWA(50) +- GPTQ-lite int6 + zstd-22 +- Late QAT @ threshold 0.15 +- OrthoInit + muP-scaled output projections +- Sliding window eval (stride=64) + +## Training + +- Muon: lr=0.025, momentum=0.99 (warmup 0.92->0.99 over 1500 steps), WD=0.04 +- AdamW: embed_lr=0.035, scalar_lr=0.025, WD=0.04 +- Batch: 786,432 tokens/step, seq_len=2048 +- Warmdown: 3500 iterations +- Gradient clip: 0.3 + +## Run Command + +```bash +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Credits + +- Base model and training recipe: PR #414 by @signalrush +- TTT technique: PR #518 by @sofiabod, PR #672 by @andrewbaggio1 +- SDPA fallback for non-FA3 environments diff --git a/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/submission.json b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/submission.json new file mode 100644 index 000000000..85717ed58 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/submission.json @@ -0,0 +1,11 @@ +{ + "author": "ajh", + "github_id": "xexyz", + "name": "PR #414 Stack + 30-Epoch Cosine TTT", + "blurb": "30-epoch cosine pre-eval TTT with per-layer LR groups on PR #414 consensus stack (11L, XSA4, EMA, GPTQ-lite, 
int6+zstd-22, sliding eval stride=64).",
+  "date": "2026-03-25",
+  "val_loss": 1.8553,
+  "val_bpb": 1.0988,
+  "bytes_total": 15900191,
+  "bytes_code": 71596
+}
diff --git a/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train.log b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train.log
new file mode 100644
index 000000000..312afe972
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train.log
@@ -0,0 +1,97 @@
+W0325 06:32:47.904000 124131012002432 torch/distributed/run.py:779] 
+W0325 06:32:47.904000 124131012002432 torch/distributed/run.py:779] *****************************************
+W0325 06:32:47.904000 124131012002432 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0325 06:32:47.904000 124131012002432 torch/distributed/run.py:779] *****************************************
+logs/ttt30ep_v4.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:26993756
+mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
+XSA:last_4 active_layers:[7, 8, 9, 10]
+world_size:8 grad_accum_steps:1
+sdp_backends:cudnn=False flash=True mem_efficient=False math=False
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
+seed:1337
+warmup_step:1/20
+warmup_step:2/20
+warmup_step:3/20
+warmup_step:4/20
+warmup_step:5/20
+warmup_step:6/20
+warmup_step:7/20
+warmup_step:8/20
+warmup_step:9/20
+warmup_step:10/20
+warmup_step:11/20
+warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9271 val_bpb:4.1026 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9300 train_time:156ms step_avg:155.53ms +step:2/20000 train_loss:8.3504 train_time:260ms step_avg:130.12ms +step:3/20000 train_loss:7.5984 train_time:369ms step_avg:123.07ms +step:4/20000 train_loss:8.1415 train_time:478ms step_avg:119.58ms +step:5/20000 train_loss:8.3255 train_time:587ms step_avg:117.37ms +step:6/20000 train_loss:8.0507 train_time:696ms step_avg:115.93ms +step:7/20000 train_loss:7.4918 train_time:804ms step_avg:114.93ms +step:8/20000 train_loss:7.0536 train_time:913ms step_avg:114.17ms +step:9/20000 train_loss:6.6796 train_time:1022ms step_avg:113.59ms +step:10/20000 train_loss:6.3579 train_time:1131ms step_avg:113.11ms +step:500/20000 train_loss:2.4268 train_time:55122ms step_avg:110.24ms +step:1000/20000 train_loss:2.2851 train_time:110454ms step_avg:110.45ms +step:1500/20000 train_loss:2.2234 train_time:165697ms step_avg:110.46ms +step:2000/20000 train_loss:2.0627 train_time:220878ms step_avg:110.44ms +step:2500/20000 train_loss:2.1547 train_time:276031ms step_avg:110.41ms +step:3000/20000 train_loss:2.1344 train_time:331221ms step_avg:110.41ms +step:3500/20000 train_loss:2.1404 train_time:386331ms step_avg:110.38ms +step:4000/20000 train_loss:1.9294 train_time:441438ms step_avg:110.36ms +step:4000/20000 val_loss:2.0184 val_bpb:1.1954 train_time:441443ms step_avg:110.36ms +step:4500/20000 train_loss:2.0700 train_time:496572ms step_avg:110.35ms +swa:start step:4750 +late_qat:enabled step:4912 scale:0.1497 +step:5000/20000 train_loss:2.0422 train_time:551903ms step_avg:110.38ms +step:5434/20000 val_loss:1.9437 val_bpb:1.1512 train_time:600046ms step_avg:110.42ms +stopping_early: wallclock_cap train_time:600046ms step:5434/20000 +peak memory allocated: 26037 MiB reserved: 26214 
MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9427 val_bpb:1.1506 eval_time:2338ms +Serialized model: 106178100 bytes +Code size: 71596 bytes +Serialized model int6+zstd: 15828595 bytes +Total submission size int6+zstd: 15900191 bytes +Total submission size int8+zlib: 15900191 bytes +/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load( +/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. 
We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load( +/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load( +/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load( +/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load( +/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+ quant_state = torch.load( +/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load( +/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+ quant_state = torch.load( +final_int6_roundtrip val_loss:1.9577 val_bpb:1.1594 eval_time:46334ms +final_int6_roundtrip_exact val_loss:1.95765651 val_bpb:1.15943445 +ttt:start epochs:30 lr:0.0005 +ttt:epoch:10/30 step:1190/3540 +ttt:epoch:20/30 step:2380/3540 +ttt:epoch:30/30 step:3570/3540 +ttt:done time:835732ms steps:3570 +final_int6_sliding_window val_loss:1.8553 val_bpb:1.0988 stride:64 eval_time:114021ms +final_int6_sliding_window_exact val_loss:1.85525056 val_bpb:1.09878679 +final_int8_zlib_roundtrip_exact val_loss:1.85525056 val_bpb:1.09878679 diff --git a/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train_gpt.py b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train_gpt.py new file mode 100644 index 000000000..99cfcb9ae --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train_gpt.py @@ -0,0 +1,1480 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + HAS_FA3 = True +except ImportError: + HAS_FA3 = False +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = 
int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) 
+ grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last 4 layers (0 = disabled) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with 
torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + 
base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, 
dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 
+INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in 
state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, 
orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, 
def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    """Promote every sub-2-D parameter (and any control tensor matched by
    CONTROL_TENSOR_NAME_PATTERNS) back to float32 in place."""
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    """Rotary position embedding tables with a cos/sin cache.

    Supports partial RoPE (rotate only the first `rope_dims` of the head
    dimension) and NTK-style base rescaling when asked for sequences longer
    than `train_seq_len`.
    """

    def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0):
        super().__init__()
        self.dim = dim
        self.base = base
        self.train_seq_len = train_seq_len
        # 0 means "full RoPE": rotate the whole head dimension
        self.rope_dims = rope_dims if rope_dims > 0 else dim
        inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        """Return (cos, sin) shaped [1, seq_len, 1, rope_dims // 2] in `dtype`."""
        # Cache is invalidated on seq_len or device change; dtype is handled
        # by the cast at return, so it is deliberately not part of the key.
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            rd = self.rope_dims
            if seq_len > self.train_seq_len:
                # NTK-aware extrapolation: inflate the base so the longest
                # wavelength still covers the extended context.
                scale = seq_len / self.train_seq_len
                new_base = self.base * (scale ** (rd / (rd - 2)))
                inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd))
            else:
                inv_freq = self.inv_freq.to(device)
            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            freqs = torch.outer(t, inv_freq)
            # shaped for broadcasting against [B, T, H, D] activations
            self._cos_cached = freqs.cos()[None, :, None, :]
            self._sin_cached = freqs.sin()[None, :, None, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor:
    """Apply rotary embedding to x's last dim (split-halves convention).

    With 0 < rope_dims < head_dim, only the first `rope_dims` channels are
    rotated; the remainder passes through unchanged (partial RoPE).
    """
    if rope_dims > 0 and rope_dims < x.size(-1):
        x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:]
        half = rope_dims // 2
        x1, x2 = x_rope[..., :half], x_rope[..., half:]
        x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
        return torch.cat((x_rope, x_pass), dim=-1)
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
class CausalSelfAttention(nn.Module):
    """Causal GQA attention with QK RMS-norm, per-head query gains, partial
    RoPE, optional value-embedding injection, and optional XSA correction."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True  # residual-branch output starts as identity
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rope_dims = 0  # set by GPT.__init__ for partial RoPE
        self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024)
        self.use_xsa = False  # set by GPT.__init__ for deep layers only

    def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor:
        """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave).
        y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv."""
        B, T, H, D = y.shape
        Hkv = v.size(-2)
        group = H // Hkv
        y_g = y.reshape(B, T, Hkv, group, D)  # [B, T, Hkv, group, D]
        vn = F.normalize(v, dim=-1).unsqueeze(-2)  # [B, T, Hkv, 1, D] — broadcast ready
        # remove each output's component along its own (unit-norm) value vector
        proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn
        return (y_g - proj).reshape(B, T, H, D)

    def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
        v = self.c_v(x)
        if v_embed is not None:
            # value-embedding reinjection: added in flat kv_dim space
            v = v + v_embed
        v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
        # QK-norm before RoPE stabilizes attention logits
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin, self.rope_dims)
        k = apply_rotary_emb(k, cos, sin, self.rope_dims)
        q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None]
        if HAS_FA3:
            y = flash_attn_3_func(q, k, v, causal=True)
        else:
            # SDPA fallback expects [B, H, T, D]; expand KV heads manually for GQA
            q_sdpa = q.transpose(1, 2)
            k_sdpa = k.transpose(1, 2)
            v_sdpa = v.transpose(1, 2)
            n_rep = self.num_heads // self.num_kv_heads
            if n_rep > 1:
                k_sdpa = k_sdpa.repeat_interleave(n_rep, dim=1)
                v_sdpa = v_sdpa.repeat_interleave(n_rep, dim=1)
            y = F.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, is_causal=True)
            y = y.transpose(1, 2).contiguous()
        if self.use_xsa:
            y = self._xsa_efficient(y, v)
        y = y.reshape(bsz, seqlen, dim)
        return self.proj(y)


class SmearGate(nn.Module):
    """Learned per-channel blend of each position with its predecessor."""

    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        # position 0 has no predecessor — blends with zeros
        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - g) * x + g * x_prev


class BigramHashEmbedding(nn.Module):
    """Hashed (prev, cur) bigram embedding added to the token embedding."""

    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
        super().__init__()
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        nn.init.zeros_(self.embed.weight)  # starts as a no-op contribution
        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    def bigram_hash(self, tokens: Tensor) -> Tensor:
        """Hash each (prev, cur) pair into a bucket index.

        Positions 1+ land in [0, bigram_vocab_size - 1); position 0 (no
        predecessor) uses the reserved top bucket `mod`.
        """
        t = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        out = torch.empty_like(t)
        out[..., 0] = mod
        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
        return out.long()

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(self.bigram_hash(token_ids))
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)


class ValueEmbedding(nn.Module):
    """Reinject token identity into attention values at specific layers.
    Each table maps vocab tokens to a low-dim embedding, projected to model_dim."""

    def __init__(self, vocab_size: int, ve_dim: int, model_dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, ve_dim)
        nn.init.normal_(self.embed.weight, std=0.01)
        self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32))

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(token_ids)
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)


class MLP(nn.Module):
    """Feed-forward block with relu^2 activation (relu then square)."""

    def __init__(self, dim: int, mlp_mult: int):
        super().__init__()
        hidden = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True

    def forward(self, x: Tensor) -> Tensor:
        x = torch.relu(self.fc(x))
        return self.proj(x.square())
class Block(nn.Module):
    """Pre-norm transformer block with learned residual scales, embedding
    (x0) residual mixing, optional depth-scaled LN, and an optional detached
    token gate (dtg) that interpolates between block input and output."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
        layer_idx: int = 0,
        ln_scale: bool = False,
        dtg: bool = False,
    ):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # resid_mix[0] weights the running stream, resid_mix[1] the embedding x0
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
        # 1/sqrt(depth) shrink of normed inputs when ln_scale is on
        self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0
        if dtg:
            self.dtg_gate = nn.Linear(dim, 1, bias=True)
            nn.init.zeros_(self.dtg_gate.weight)
            nn.init.constant_(self.dtg_gate.bias, 2.0)  # sigmoid(2) ≈ 0.88 — gate starts mostly open
        else:
            self.dtg_gate = None

    def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed)
        x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out
        x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor)
        if self.dtg_gate is not None:
            # gate is computed from the detached block input, so it shapes the
            # output without feeding gradients back through x_in
            gate = torch.sigmoid(self.dtg_gate(x_in.detach()))
            x_out = x_in + gate * (x_out - x_in)
        return x_out


class GPT(nn.Module):
    """Decoder-only transformer with U-net style skip connections between the
    encoder half and decoder half, optional bigram-hash and value embeddings,
    logit soft-capping, and optional multi-token-prediction (MTP) aux heads."""

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        mtp_num_heads: int = 0,
        mtp_loss_weight: float = 0.1,
        bigram_vocab_size: int = 0,
        bigram_dim: int = 128,
        xsa_last_n: int = 0,
        rope_dims: int = 0,
        ln_scale: bool = False,
        dtg: bool = False,
        ve_enabled: bool = False,
        ve_dim: int = 128,
        ve_layers: str = "9,10",
    ):
        super().__init__()
        self._ve_target_dim = num_kv_heads * (model_dim // num_heads)  # kv_dim for value projection
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.mtp_num_heads = mtp_num_heads
        self.mtp_loss_weight = mtp_loss_weight
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
        self.smear = SmearGate(model_dim)
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        self.blocks = nn.ModuleList(
            [
                Block(
                    model_dim,
                    num_heads,
                    num_kv_heads,
                    mlp_mult,
                    rope_base,
                    qk_gain_init,
                    layer_idx=i,
                    ln_scale=ln_scale,
                    dtg=dtg,
                )
                for i in range(num_layers)
            ]
        )
        if rope_dims > 0:
            # partial RoPE: rebuild each layer's rotary restricted to rope_dims
            head_dim = model_dim // num_heads
            for block in self.blocks:
                block.attn.rope_dims = rope_dims
                block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims)
        self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else []
        kv_dim = self._ve_target_dim
        if self.ve_layer_indices:
            # one shared VE table; per-layer scalar scales differentiate layers
            self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim)
            self.ve_layer_scales = nn.ParameterList(
                [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices]
            )
        else:
            self.ve_shared = None
            self.ve_layer_scales = nn.ParameterList()
        self.value_embeds = nn.ModuleList()  # keep empty for compat
        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self.mtp_heads = nn.ModuleList(
            [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)]
        )
        for head in self.mtp_heads:
            head._zero_init = True
        if xsa_last_n > 0:
            # enable XSA only on the deepest xsa_last_n layers
            for i in range(max(0, num_layers - xsa_last_n), num_layers):
                self.blocks[i].attn.use_xsa = True
        self._init_weights()

    def _init_weights(self) -> None:
        """Zero-init flagged outputs, orthogonal-init large matrices, and shrink
        residual output projections by 1/sqrt(2 * num_layers) (muP-style)."""
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        num_layers = len(self.blocks)
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    nn.init.orthogonal_(module.weight, gain=1.0)
                    # NOTE(review): the nesting of this proj-rescale relative to
                    # the elif above is ambiguous in this view of the source;
                    # attn/mlp proj weights are zero-initialized, so the rescale
                    # is a no-op for them either way — confirm against original.
                    if ".proj." in name or name.endswith(".proj"):
                        with torch.no_grad():
                            module.weight.mul_(1.0 / math.sqrt(2 * num_layers))

    def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None:
        """Get value embedding for a specific layer using shared table + per-layer scale."""
        if self.ve_shared is None or layer_idx not in self.ve_layer_indices:
            return None
        # compute the shared table lookup once per forward, cache it
        if ve_cache is not None and 've' not in ve_cache:
            ve_cache['ve'] = self.ve_shared(input_ids)
        ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids)
        ve_idx = self.ve_layer_indices.index(layer_idx)
        return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Return mean cross-entropy loss (plus weighted MTP loss in training)."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x  # embedding stream reused by every block via resid_mix
        skips: list[Tensor] = []
        ve_cache: dict = {}
        for i in range(self.num_encoder_layers):
            ve = self._get_ve(i, input_ids, ve_cache)
            x = self.blocks[i](x, x0, v_embed=ve)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            bi = self.num_encoder_layers + i
            if skips:
                # U-net skip: deepest encoder output pairs with first decoder layer
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            ve = self._get_ve(bi, input_ids, ve_cache)
            x = self.blocks[bi](x, x0, v_embed=ve)
        x = self.final_norm(x)
        x_flat = x.reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x_flat, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x_flat)
        # soft-cap keeps logits in (-cap, cap), bounding loss spikes
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        main_loss = F.cross_entropy(logits.float(), targets, reduction="mean")
        if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0:
            # auxiliary heads predict tokens k+1 steps ahead from the same hidden state
            _, seqlen, dim = x.shape
            mtp_loss_sum = x.new_zeros(())
            mtp_loss_count = 0
            for k, mtp_head in enumerate(self.mtp_heads):
                valid_t = seqlen - (k + 1)
                if valid_t <= 0:
                    continue
                mtp_hidden = x[:, :valid_t, :].reshape(-1, dim)
                mtp_targets = target_ids[:, k + 1 :].reshape(-1)
                mtp_logits_proj = mtp_head(mtp_hidden)
                mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap)
                mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean")
                mtp_loss_count += 1
            if mtp_loss_count > 0:
                main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count)
        return main_loss

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Return logits (bsz, seq_len, vocab) without computing loss."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x
        skips: list[Tensor] = []
        ve_cache: dict = {}
        for i in range(self.num_encoder_layers):
            ve = self._get_ve(i, input_ids, ve_cache)
            x = self.blocks[i](x, x0, v_embed=ve)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            bi = self.num_encoder_layers + i
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            ve = self._get_ve(bi, input_ids, ve_cache)
            x = self.blocks[bi](x, x0, v_embed=ve)
        x = self.final_norm(x)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            # assumes lm_head exists when embeddings are untied (forward() checks)
            logits_proj = self.lm_head(x)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)


def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    stride: int,
    batch_seqs: int = 32,
    eval_seq_len: int | None = None,
) -> tuple[float, float]:
    """Sliding window evaluation: each token scored with maximum context.

    Windows start every `stride` tokens; window 0 scores all its positions,
    later windows score only their last `stride` positions (the fresh ones).
    Returns (mean token NLL, bits-per-byte), reduced across ranks when DDP
    is initialized.
    """
    seq_len = eval_seq_len or args.train_seq_len
    total_tokens = val_tokens.numel() - 1
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= 1]
    total_windows = len(window_starts)
    # contiguous, near-even split of windows across ranks
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]
    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)
    base_model.eval()
    compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True)
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            # zero-padded to fixed seq_len so the compiled graph sees one shape
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = compiled_logits(x_batch)
            nll = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                y_batch.reshape(-1),
                reduction="none",
            ).reshape(bsz, seq_len)
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # first window scores everything; later windows only the new tail
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                # byte accounting: base bytes per token, plus one for a leading
                # space unless the previous token is a boundary token
                tb = base_bytes_lut[tgt].to(torch.float64)
                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                byte_count += tb.sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)
    val_loss = (loss_sum / token_count).item()
    bits_per_token = val_loss / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    base_model.train()
    return val_loss, bits_per_token * tokens_per_byte
1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return 
def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
                          template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
    """Rebuild a float state dict from the mixed int6/int8 artifact.

    `template_sd` supplies the target dtype for every entry; names absent
    from `meta` are silently skipped. Passthrough entries stored as fp16 are
    widened back to fp32/bf16 when the template asks for it; quantized
    entries are reconstructed as codes * scale (per-row vector scales
    broadcast over trailing dims, 0-dim scales apply globally).
    """
    passthrough_tags = ("passthrough", "passthrough_ctrl", "passthrough_fp16")
    restored: dict[str, Tensor] = {}
    for key, reference in template_sd.items():
        tag = meta.get(key)
        if tag is None:
            continue
        target_dtype = reference.dtype
        if tag in passthrough_tags:
            stored = result[key]
            # fp16-stored passthroughs get widened back to the original dtype
            if stored.dtype == torch.float16 and target_dtype in (torch.float32, torch.bfloat16):
                stored = stored.to(target_dtype)
            restored[key] = stored
        else:
            codes = result[key + ".q"].float()
            scale = result[key + ".scale"]
            if scale.ndim == 0:
                recon = codes * float(scale.item())
            else:
                bcast_shape = (codes.shape[0],) + (1,) * (codes.ndim - 1)
                recon = codes * scale.float().view(bcast_shape)
            restored[key] = recon.to(target_dtype)
    return restored
enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() 
- 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, 
"base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + 
log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = 
micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap 
train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] 
+= t.detach().cpu() + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply EMA weights (better than SWA alone per PR#401) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, 
"final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, # must match training model + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + 
ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + # --- Test-Time Training (TTT): adapt model on val data before final eval --- + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 30)) + ttt_lr = float(os.environ.get("TTT_LR", 0.0005)) + if ttt_epochs > 0: + log0(f"ttt:start epochs:{ttt_epochs} lr:{ttt_lr}") + t_ttt = time.perf_counter() + # Clear cached inference tensors (RoPE cos/sin, etc.) 
+ for m in eval_model.modules(): + if hasattr(m, '_cos_cached'): + m._cos_cached = None + m._sin_cached = None + proj_params, fc_params, other_params = [], [], [] + for name, p in eval_model.named_parameters(): + p.data = p.data.clone() + p.requires_grad_(True) + if "mlp.proj" in name: + proj_params.append(p) + elif "mlp.fc" in name: + fc_params.append(p) + else: + other_params.append(p) + ttt_opt = torch.optim.AdamW([ + {"params": proj_params, "lr": ttt_lr * 3.0, "initial_lr": ttt_lr * 3.0}, + {"params": fc_params, "lr": ttt_lr * 0.5, "initial_lr": ttt_lr * 0.5}, + {"params": other_params, "lr": ttt_lr, "initial_lr": ttt_lr}, + ], weight_decay=0.0) + ttt_batch = 32 + total_val = val_tokens.numel() + rank_start = rank * (total_val // world_size) + rank_end = rank_start + (total_val // world_size) + ttt_seq = effective_eval_seq_len + steps_per_epoch = max(1, (rank_end - rank_start - ttt_seq) // (ttt_batch * ttt_seq)) + total_ttt_steps = ttt_epochs * steps_per_epoch + global_step = 0 + eval_model.train() + for ep in range(ttt_epochs): + for bs_start in range(rank_start, rank_end - ttt_seq, ttt_batch * ttt_seq): + progress = global_step / max(total_ttt_steps, 1) + cos_mul = 0.5 * (1.0 + math.cos(math.pi * progress)) + for g in ttt_opt.param_groups: + g["lr"] = g["initial_lr"] * cos_mul + ttt_opt.zero_grad() + x_ttt = torch.stack([val_tokens[bs_start + i * ttt_seq : bs_start + i * ttt_seq + ttt_seq] + for i in range(min(ttt_batch, (rank_end - bs_start) // ttt_seq))]).to(device=device, dtype=torch.int64) + y_ttt = torch.stack([val_tokens[bs_start + i * ttt_seq + 1 : bs_start + i * ttt_seq + ttt_seq + 1] + for i in range(min(ttt_batch, (rank_end - bs_start) // ttt_seq))]).to(device=device, dtype=torch.int64) + with torch.autocast("cuda", dtype=torch.bfloat16): + loss_ttt = eval_model(x_ttt, y_ttt) + loss_ttt.backward() + if distributed: + for p in eval_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + 
torch.nn.utils.clip_grad_norm_(eval_model.parameters(), 1.0) + ttt_opt.step() + global_step += 1 + if (ep + 1) % 10 == 0: + log0(f"ttt:epoch:{ep+1}/{ttt_epochs} step:{global_step}/{total_ttt_steps}") + eval_model.eval() + del ttt_opt + torch.cuda.empty_cache() + log0(f"ttt:done time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms steps:{global_step}") + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() 
+if __name__ == "__main__": + main() From 0a2a1ac5f76d71c698b46d397da057349a40af58 Mon Sep 17 00:00:00 2001 From: xexyz Date: Wed, 25 Mar 2026 01:08:30 -0600 Subject: [PATCH 2/3] Fix bytes_code in submission.json to match actual file size Updated bytes_code from 71379 to 71596 to match train.log and actual wc -c. --- .../2026-03-25_PR414_CosineTTT30ep_1.0988/submission.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/submission.json b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/submission.json index 85717ed58..d3e85e9f7 100644 --- a/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/submission.json +++ b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/submission.json @@ -7,5 +7,5 @@ "val_loss": 1.8553, "val_bpb": 1.0988, "bytes_total": 15900191, - "bytes_code": 71379 + "bytes_code": 71596 } From 787610be3db0453b2a12a5a99eebb78932f9d9e0 Mon Sep 17 00:00:00 2001 From: xexyz Date: Wed, 25 Mar 2026 02:43:58 -0600 Subject: [PATCH 3/3] Switch to legal score-first TTT (1.1408 BPB) Replaced pre-eval TTT with legal chunk-based score-first protocol: - Score each 32K-token chunk first, then train on scored tokens - SGD lr=0.002, momentum=0.9, 3 epochs/chunk, cosine LR - Never trains on tokens before scoring them - Added FA3 fallback (flash_attn_interface -> flash_attn -> SDPA) - Fixed RoPE cache inference tensor issue between score/train phases Legal TTT val_bpb: 1.1408, eval time: 617s on 8xH100 (SDPA). Artifact size: 15,758,590 bytes (under 16MB). 
--- .../README.md | 35 ++- .../submission.json | 12 +- .../train.log | 293 ++++++++++++++---- .../train_gpt.py | 204 ++++++------ 4 files changed, 380 insertions(+), 164 deletions(-) diff --git a/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/README.md b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/README.md index f5dd4434b..16a99f213 100644 --- a/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/README.md +++ b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/README.md @@ -1,23 +1,24 @@ -# PR #414 Stack + 30-Epoch Cosine TTT +# PR #414 Stack + Legal Score-First TTT -**val_bpb: 1.0988** (8xH100 SXM, seed=1337, stride=64 sliding window eval) +**val_bpb: 1.1408** (8xH100 SXM, seed=1337, legal score-first TTT) ## Summary -Adds 30-epoch cosine pre-eval Test-Time Training (TTT) on top of the PR #414 consensus stack. TTT adapts the quantized model on validation data before the final sliding-window eval, recovering quantization loss and further improving BPB through domain adaptation. +Legal chunk-based score-first TTT on the PR #414 consensus stack. Each validation chunk is scored first (under inference_mode), then the model trains on already-scored tokens. Never trains on tokens before scoring them. -## Key Addition: Cosine Pre-Eval TTT +## Key Addition: Legal Score-First TTT -After int6 quantization and roundtrip eval, the model is fine-tuned on validation data for 30 epochs with cosine LR decay before the final sliding-window eval: +After int6 quantization, validation data is processed in 32K-token chunks: +1. **Score** chunk with sliding-window eval (inference_mode) +2. **Train** on scored chunk for 3 epochs (SGD, cosine LR) +3. 
Advance to next chunk — never training before scoring

-- AdamW optimizer, base LR=0.0005, weight_decay=0.0
-- Per-layer LR groups: `mlp.proj` 3x, `mlp.fc` 0.5x, others 1x
-- Cosine LR schedule across all TTT steps
+- SGD optimizer, base LR=0.002, momentum=0.9
+- Cosine LR decay across chunks
+- 3 epochs per chunk, 32768 tokens/chunk, 1893 chunks total
- DDP gradient sync (all_reduce AVG)
-- Batch size: 32 sequences per rank
- Gradient clipping: 1.0
-
-TTT runs within the 10-minute eval budget (~8 min TTT + ~2 min sliding eval).
+- Total eval time: ~617s on 8xH100 (SDPA backend)

## Architecture (PR #414 stack)

@@ -31,7 +32,6 @@ TTT runs within the 10-minute eval budget (~8 min TTT + ~2 min sliding eval).
- GPTQ-lite int6 + zstd-22
- Late QAT @ threshold 0.15
- OrthoInit + muP-scaled output projections
-- Sliding window eval (stride=64)

## Training

@@ -41,6 +41,14 @@ TTT runs within the 10-minute eval budget (~8 min TTT + ~2 min sliding eval).
- Warmdown: 3500 iterations
- Gradient clip: 0.3

+## Results
+
+| Stage | val_loss | val_bpb |
+|-------|----------|---------|
+| Post-EMA (float) | 1.9433 | 1.1509 |
+| Post-int6 roundtrip | 1.9570 | 1.1590 |
+| **Legal TTT (score-first)** | **1.9262** | **1.1408** |
+
## Run Command

```bash
@@ -50,5 +58,6 @@ torchrun --standalone --nproc_per_node=8 train_gpt.py

## Credits

- Base model and training recipe: PR #414 by @signalrush
-- TTT technique: PR #518 by @sofiabod, PR #672 by @andrewbaggio1
+- Legal TTT protocol: PR #549 by @alexkarp
+- TTT technique: PR #518 by @sofiabod
- SDPA fallback for non-FA3 environments
diff --git a/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/submission.json b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/submission.json
index d3e85e9f7..5dbb7baaf 100644
--- a/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/submission.json
+++ b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/submission.json
@@ -1,11 +1,11 @@
{
"author": "ajh",
"github_id": "xexyz", - "name": "PR #414 Stack + 30-Epoch Cosine TTT", - "blurb": "30-epoch cosine pre-eval TTT with per-layer LR groups on PR #414 consensus stack (11L, XSA4, EMA, GPTQ-lite, int6+zstd-22, sliding eval stride=64).", + "name": "PR #414 Stack + Legal Score-First TTT", + "blurb": "Legal chunk-based score-first TTT (3 epochs/chunk, SGD cosine, 32K chunks) on PR #414 consensus stack (11L, XSA4, EMA, GPTQ-lite, int6+zstd-22, stride=64).", "date": "2026-03-25", - "val_loss": 1.8553, - "val_bpb": 1.0988, - "bytes_total": 15900191, - "bytes_code": 71596 + "val_loss": 1.9262, + "val_bpb": 1.1408, + "bytes_total": 15758590, + "bytes_code": 73816 } diff --git a/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train.log b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train.log index 312afe972..461f84d02 100644 --- a/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train.log +++ b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train.log @@ -1,8 +1,8 @@ -W0325 06:32:47.904000 124131012002432 torch/distributed/run.py:779] -W0325 06:32:47.904000 124131012002432 torch/distributed/run.py:779] ***************************************** -W0325 06:32:47.904000 124131012002432 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
-W0325 06:32:47.904000 124131012002432 torch/distributed/run.py:779] ***************************************** -logs/ttt30ep_v4.txt +W0325 08:17:48.568000 136222819410560 torch/distributed/run.py:779] +W0325 08:17:48.568000 136222819410560 torch/distributed/run.py:779] ***************************************** +W0325 08:17:48.568000 136222819410560 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0325 08:17:48.568000 136222819410560 torch/distributed/run.py:779] ***************************************** +logs/legal_ttt_v3.txt val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model train_loader:dataset:fineweb10B_sp1024 train_shards:80 val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 @@ -36,62 +36,249 @@ warmup_step:18/20 warmup_step:19/20 warmup_step:20/20 step:0/20000 val_loss:6.9271 val_bpb:4.1026 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9300 train_time:156ms step_avg:155.53ms -step:2/20000 train_loss:8.3504 train_time:260ms step_avg:130.12ms -step:3/20000 train_loss:7.5984 train_time:369ms step_avg:123.07ms -step:4/20000 train_loss:8.1415 train_time:478ms step_avg:119.58ms -step:5/20000 train_loss:8.3255 train_time:587ms step_avg:117.37ms -step:6/20000 train_loss:8.0507 train_time:696ms step_avg:115.93ms -step:7/20000 train_loss:7.4918 train_time:804ms step_avg:114.93ms -step:8/20000 train_loss:7.0536 train_time:913ms step_avg:114.17ms -step:9/20000 train_loss:6.6796 train_time:1022ms step_avg:113.59ms -step:10/20000 train_loss:6.3579 train_time:1131ms step_avg:113.11ms -step:500/20000 train_loss:2.4268 train_time:55122ms step_avg:110.24ms -step:1000/20000 train_loss:2.2851 train_time:110454ms step_avg:110.45ms -step:1500/20000 train_loss:2.2234 train_time:165697ms 
step_avg:110.46ms -step:2000/20000 train_loss:2.0627 train_time:220878ms step_avg:110.44ms -step:2500/20000 train_loss:2.1547 train_time:276031ms step_avg:110.41ms -step:3000/20000 train_loss:2.1344 train_time:331221ms step_avg:110.41ms -step:3500/20000 train_loss:2.1404 train_time:386331ms step_avg:110.38ms -step:4000/20000 train_loss:1.9294 train_time:441438ms step_avg:110.36ms -step:4000/20000 val_loss:2.0184 val_bpb:1.1954 train_time:441443ms step_avg:110.36ms -step:4500/20000 train_loss:2.0700 train_time:496572ms step_avg:110.35ms +step:1/20000 train_loss:6.9300 train_time:154ms step_avg:154.33ms +step:2/20000 train_loss:8.3504 train_time:259ms step_avg:129.73ms +step:3/20000 train_loss:7.5985 train_time:368ms step_avg:122.81ms +step:4/20000 train_loss:8.1415 train_time:477ms step_avg:119.30ms +step:5/20000 train_loss:8.3257 train_time:586ms step_avg:117.16ms +step:6/20000 train_loss:8.0506 train_time:695ms step_avg:115.76ms +step:7/20000 train_loss:7.4914 train_time:803ms step_avg:114.75ms +step:8/20000 train_loss:7.0539 train_time:912ms step_avg:114.00ms +step:9/20000 train_loss:6.6802 train_time:1021ms step_avg:113.44ms +step:10/20000 train_loss:6.3554 train_time:1130ms step_avg:113.01ms +step:500/20000 train_loss:2.4233 train_time:55151ms step_avg:110.30ms +step:1000/20000 train_loss:2.2819 train_time:110465ms step_avg:110.46ms +step:1500/20000 train_loss:2.2245 train_time:165680ms step_avg:110.45ms +step:2000/20000 train_loss:2.0632 train_time:220848ms step_avg:110.42ms +step:2500/20000 train_loss:2.1556 train_time:275989ms step_avg:110.40ms +step:3000/20000 train_loss:2.1359 train_time:331096ms step_avg:110.37ms +step:3500/20000 train_loss:2.1379 train_time:386200ms step_avg:110.34ms +step:4000/20000 train_loss:1.9292 train_time:441292ms step_avg:110.32ms +step:4000/20000 val_loss:2.0184 val_bpb:1.1954 train_time:441297ms step_avg:110.32ms +step:4500/20000 train_loss:2.0696 train_time:496405ms step_avg:110.31ms swa:start step:4750 -late_qat:enabled 
step:4912 scale:0.1497 -step:5000/20000 train_loss:2.0422 train_time:551903ms step_avg:110.38ms -step:5434/20000 val_loss:1.9437 val_bpb:1.1512 train_time:600046ms step_avg:110.42ms -stopping_early: wallclock_cap train_time:600046ms step:5434/20000 -peak memory allocated: 26037 MiB reserved: 26214 MiB +late_qat:enabled step:4913 scale:0.1499 +step:5000/20000 train_loss:2.0438 train_time:551759ms step_avg:110.35ms +step:5436/20000 val_loss:1.9439 val_bpb:1.1513 train_time:600087ms step_avg:110.39ms +stopping_early: wallclock_cap train_time:600087ms step:5436/20000 +peak memory allocated: 26038 MiB reserved: 26268 MiB ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9427 val_bpb:1.1506 eval_time:2338ms +DIAGNOSTIC post_ema val_loss:1.9429 val_bpb:1.1507 eval_time:2333ms Serialized model: 106178100 bytes -Code size: 71596 bytes -Serialized model int6+zstd: 15828595 bytes -Total submission size int6+zstd: 15900191 bytes -Total submission size int8+zlib: 15900191 bytes -/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+Code size: 74001 bytes +Serialized model int6+zstd: 16045064 bytes +Total submission size int6+zstd: 16119065 bytes +Total submission size int8+zlib: 16119065 bytes +/workspace/train_gpt_legal_ttt.py:1450: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. quant_state = torch.load( -/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature. +/workspace/train_gpt_legal_ttt.py:1450: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. quant_state = torch.load( -/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+/workspace/train_gpt_legal_ttt.py:1450: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. quant_state = torch.load( -/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+/workspace/train_gpt_legal_ttt.py:1450: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. quant_state = torch.load( -/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+/workspace/train_gpt_legal_ttt.py:1450: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. quant_state = torch.load( -/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+/workspace/train_gpt_legal_ttt.py:1450: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. quant_state = torch.load( -/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+/workspace/train_gpt_legal_ttt.py:1450: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. quant_state = torch.load( -/workspace/train_gpt_ttt.py:1346: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+/workspace/train_gpt_legal_ttt.py:1450: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. quant_state = torch.load( -final_int6_roundtrip val_loss:1.9577 val_bpb:1.1594 eval_time:46334ms -final_int6_roundtrip_exact val_loss:1.95765651 val_bpb:1.15943445 -ttt:start epochs:30 lr:0.0005 -ttt:epoch:10/30 step:1190/3540 -ttt:epoch:20/30 step:2380/3540 -ttt:epoch:30/30 step:3570/3540 -ttt:done time:835732ms steps:3570 -final_int6_sliding_window val_loss:1.8553 val_bpb:1.0988 stride:64 eval_time:114021ms -final_int6_sliding_window_exact val_loss:1.85525056 val_bpb:1.09878679 -final_int8_zlib_roundtrip_exact val_loss:1.85525056 val_bpb:1.09878679 +final_int6_roundtrip val_loss:1.9570 val_bpb:1.1590 eval_time:45195ms +final_int6_roundtrip_exact val_loss:1.95700521 val_bpb:1.15904872 +ttt_sliding:start chunks=1893 chunk_tokens=32768 windows=969088 stride=64 lr=0.002 epochs=3 + ttt_chunk [1/1893] bpb=1.175285 time=0.5s + ttt_chunk [11/1893] bpb=1.159537 time=3.8s + ttt_chunk [21/1893] bpb=1.144819 time=7.1s + ttt_chunk [31/1893] bpb=1.142496 time=10.3s + ttt_chunk [41/1893] bpb=1.129114 time=13.6s + ttt_chunk [51/1893] bpb=1.123388 time=16.9s + ttt_chunk [61/1893] 
bpb=1.130280 time=20.1s + ttt_chunk [71/1893] bpb=1.128593 time=23.4s + ttt_chunk [81/1893] bpb=1.127975 time=26.7s + ttt_chunk [91/1893] bpb=1.129183 time=29.9s + ttt_chunk [101/1893] bpb=1.133065 time=33.2s + ttt_chunk [111/1893] bpb=1.135746 time=36.5s + ttt_chunk [121/1893] bpb=1.129344 time=39.7s + ttt_chunk [131/1893] bpb=1.129777 time=43.0s + ttt_chunk [141/1893] bpb=1.135387 time=46.3s + ttt_chunk [151/1893] bpb=1.137228 time=49.5s + ttt_chunk [161/1893] bpb=1.136805 time=52.8s + ttt_chunk [171/1893] bpb=1.141245 time=56.0s + ttt_chunk [181/1893] bpb=1.143808 time=59.3s + ttt_chunk [191/1893] bpb=1.151256 time=62.6s + ttt_chunk [201/1893] bpb=1.150205 time=65.9s + ttt_chunk [211/1893] bpb=1.148087 time=69.1s + ttt_chunk [221/1893] bpb=1.149645 time=72.4s + ttt_chunk [231/1893] bpb=1.148301 time=75.6s + ttt_chunk [241/1893] bpb=1.148764 time=78.9s + ttt_chunk [251/1893] bpb=1.148283 time=82.2s + ttt_chunk [261/1893] bpb=1.145461 time=85.5s + ttt_chunk [271/1893] bpb=1.144402 time=88.8s + ttt_chunk [281/1893] bpb=1.145916 time=92.0s + ttt_chunk [291/1893] bpb=1.147826 time=95.3s + ttt_chunk [301/1893] bpb=1.148625 time=98.6s + ttt_chunk [311/1893] bpb=1.150787 time=101.8s + ttt_chunk [321/1893] bpb=1.152855 time=105.1s + ttt_chunk [331/1893] bpb=1.152795 time=108.4s + ttt_chunk [341/1893] bpb=1.151903 time=111.7s + ttt_chunk [351/1893] bpb=1.154310 time=114.9s + ttt_chunk [361/1893] bpb=1.154664 time=118.2s + ttt_chunk [371/1893] bpb=1.154075 time=121.5s + ttt_chunk [381/1893] bpb=1.154306 time=124.7s + ttt_chunk [391/1893] bpb=1.154153 time=128.0s + ttt_chunk [401/1893] bpb=1.152129 time=131.2s + ttt_chunk [411/1893] bpb=1.151127 time=134.5s + ttt_chunk [421/1893] bpb=1.150298 time=137.8s + ttt_chunk [431/1893] bpb=1.150239 time=141.1s + ttt_chunk [441/1893] bpb=1.150663 time=144.3s + ttt_chunk [451/1893] bpb=1.151092 time=147.6s + ttt_chunk [461/1893] bpb=1.150068 time=150.9s + ttt_chunk [471/1893] bpb=1.150753 time=154.1s + ttt_chunk [481/1893] 
bpb=1.150409 time=157.4s + ttt_chunk [491/1893] bpb=1.149396 time=160.7s + ttt_chunk [501/1893] bpb=1.148977 time=163.9s + ttt_chunk [511/1893] bpb=1.148409 time=167.2s + ttt_chunk [521/1893] bpb=1.146200 time=170.4s + ttt_chunk [531/1893] bpb=1.147401 time=173.6s + ttt_chunk [541/1893] bpb=1.147734 time=176.9s + ttt_chunk [551/1893] bpb=1.146733 time=180.1s + ttt_chunk [561/1893] bpb=1.147301 time=183.4s + ttt_chunk [571/1893] bpb=1.146324 time=186.7s + ttt_chunk [581/1893] bpb=1.145563 time=189.9s + ttt_chunk [591/1893] bpb=1.144998 time=193.2s + ttt_chunk [601/1893] bpb=1.145535 time=196.4s + ttt_chunk [611/1893] bpb=1.145490 time=199.7s + ttt_chunk [621/1893] bpb=1.145429 time=202.9s + ttt_chunk [631/1893] bpb=1.146168 time=206.2s + ttt_chunk [641/1893] bpb=1.145984 time=209.5s + ttt_chunk [651/1893] bpb=1.146082 time=212.7s + ttt_chunk [661/1893] bpb=1.145567 time=216.0s + ttt_chunk [671/1893] bpb=1.146021 time=219.2s + ttt_chunk [681/1893] bpb=1.146792 time=222.5s + ttt_chunk [691/1893] bpb=1.147782 time=225.8s + ttt_chunk [701/1893] bpb=1.147260 time=229.0s + ttt_chunk [711/1893] bpb=1.147279 time=232.2s + ttt_chunk [721/1893] bpb=1.146924 time=235.5s + ttt_chunk [731/1893] bpb=1.146994 time=238.8s + ttt_chunk [741/1893] bpb=1.147183 time=242.0s + ttt_chunk [751/1893] bpb=1.147089 time=245.3s + ttt_chunk [761/1893] bpb=1.147046 time=248.5s + ttt_chunk [771/1893] bpb=1.146732 time=251.7s + ttt_chunk [781/1893] bpb=1.147505 time=255.0s + ttt_chunk [791/1893] bpb=1.147115 time=258.2s + ttt_chunk [801/1893] bpb=1.147473 time=261.5s + ttt_chunk [811/1893] bpb=1.147285 time=264.7s + ttt_chunk [821/1893] bpb=1.147146 time=268.0s + ttt_chunk [831/1893] bpb=1.147017 time=271.3s + ttt_chunk [841/1893] bpb=1.146412 time=274.5s + ttt_chunk [851/1893] bpb=1.146208 time=277.8s + ttt_chunk [861/1893] bpb=1.146001 time=281.0s + ttt_chunk [871/1893] bpb=1.146302 time=284.3s + ttt_chunk [881/1893] bpb=1.146548 time=287.5s + ttt_chunk [891/1893] bpb=1.146118 time=290.8s + 
ttt_chunk [901/1893] bpb=1.145879 time=294.0s + ttt_chunk [911/1893] bpb=1.146043 time=297.3s + ttt_chunk [921/1893] bpb=1.146567 time=300.5s + ttt_chunk [931/1893] bpb=1.146533 time=303.7s + ttt_chunk [941/1893] bpb=1.146236 time=307.0s + ttt_chunk [951/1893] bpb=1.146682 time=310.2s + ttt_chunk [961/1893] bpb=1.146786 time=313.5s + ttt_chunk [971/1893] bpb=1.147674 time=316.7s + ttt_chunk [981/1893] bpb=1.147779 time=320.0s + ttt_chunk [991/1893] bpb=1.147810 time=323.2s + ttt_chunk [1001/1893] bpb=1.147837 time=326.4s + ttt_chunk [1011/1893] bpb=1.147683 time=329.7s + ttt_chunk [1021/1893] bpb=1.148048 time=333.0s + ttt_chunk [1031/1893] bpb=1.148550 time=336.2s + ttt_chunk [1041/1893] bpb=1.148239 time=339.5s + ttt_chunk [1051/1893] bpb=1.148029 time=342.7s + ttt_chunk [1061/1893] bpb=1.148113 time=346.0s + ttt_chunk [1071/1893] bpb=1.148751 time=349.3s + ttt_chunk [1081/1893] bpb=1.149046 time=352.5s + ttt_chunk [1091/1893] bpb=1.149820 time=355.7s + ttt_chunk [1101/1893] bpb=1.149861 time=359.0s + ttt_chunk [1111/1893] bpb=1.149728 time=362.2s + ttt_chunk [1121/1893] bpb=1.149566 time=365.5s + ttt_chunk [1131/1893] bpb=1.149457 time=368.7s + ttt_chunk [1141/1893] bpb=1.149222 time=372.0s + ttt_chunk [1151/1893] bpb=1.149249 time=375.2s + ttt_chunk [1161/1893] bpb=1.148897 time=378.5s + ttt_chunk [1171/1893] bpb=1.149245 time=381.7s + ttt_chunk [1181/1893] bpb=1.148514 time=385.0s + ttt_chunk [1191/1893] bpb=1.148439 time=388.2s + ttt_chunk [1201/1893] bpb=1.148865 time=391.5s + ttt_chunk [1211/1893] bpb=1.148404 time=394.7s + ttt_chunk [1221/1893] bpb=1.148106 time=398.0s + ttt_chunk [1231/1893] bpb=1.147832 time=401.2s + ttt_chunk [1241/1893] bpb=1.147516 time=404.5s + ttt_chunk [1251/1893] bpb=1.146946 time=407.7s + ttt_chunk [1261/1893] bpb=1.146953 time=411.0s + ttt_chunk [1271/1893] bpb=1.146592 time=414.2s + ttt_chunk [1281/1893] bpb=1.146409 time=417.5s + ttt_chunk [1291/1893] bpb=1.146214 time=420.8s + ttt_chunk [1301/1893] bpb=1.145676 time=424.0s + 
ttt_chunk [1311/1893] bpb=1.145300 time=427.3s + ttt_chunk [1321/1893] bpb=1.144988 time=430.5s + ttt_chunk [1331/1893] bpb=1.144939 time=433.8s + ttt_chunk [1341/1893] bpb=1.144817 time=437.0s + ttt_chunk [1351/1893] bpb=1.144767 time=440.3s + ttt_chunk [1361/1893] bpb=1.144824 time=443.5s + ttt_chunk [1371/1893] bpb=1.144710 time=446.8s + ttt_chunk [1381/1893] bpb=1.144709 time=450.0s + ttt_chunk [1391/1893] bpb=1.144328 time=453.3s + ttt_chunk [1401/1893] bpb=1.144319 time=456.6s + ttt_chunk [1411/1893] bpb=1.144454 time=459.8s + ttt_chunk [1421/1893] bpb=1.144740 time=463.1s + ttt_chunk [1431/1893] bpb=1.144457 time=466.3s + ttt_chunk [1441/1893] bpb=1.145003 time=469.6s + ttt_chunk [1451/1893] bpb=1.145354 time=472.9s + ttt_chunk [1461/1893] bpb=1.144904 time=476.1s + ttt_chunk [1471/1893] bpb=1.145967 time=479.4s + ttt_chunk [1481/1893] bpb=1.145520 time=482.6s + ttt_chunk [1491/1893] bpb=1.145331 time=485.9s + ttt_chunk [1501/1893] bpb=1.145269 time=489.1s + ttt_chunk [1511/1893] bpb=1.145309 time=492.4s + ttt_chunk [1521/1893] bpb=1.145339 time=495.6s + ttt_chunk [1531/1893] bpb=1.144849 time=498.9s + ttt_chunk [1541/1893] bpb=1.144712 time=502.1s + ttt_chunk [1551/1893] bpb=1.145058 time=505.4s + ttt_chunk [1561/1893] bpb=1.145075 time=508.6s + ttt_chunk [1571/1893] bpb=1.144928 time=511.9s + ttt_chunk [1581/1893] bpb=1.145050 time=515.1s + ttt_chunk [1591/1893] bpb=1.144934 time=518.4s + ttt_chunk [1601/1893] bpb=1.145134 time=521.7s + ttt_chunk [1611/1893] bpb=1.145083 time=525.0s + ttt_chunk [1621/1893] bpb=1.144693 time=528.2s + ttt_chunk [1631/1893] bpb=1.145040 time=531.5s + ttt_chunk [1641/1893] bpb=1.145085 time=534.8s + ttt_chunk [1651/1893] bpb=1.145056 time=538.0s + ttt_chunk [1661/1893] bpb=1.144956 time=541.3s + ttt_chunk [1671/1893] bpb=1.145445 time=544.5s + ttt_chunk [1681/1893] bpb=1.145596 time=547.8s + ttt_chunk [1691/1893] bpb=1.145429 time=551.1s + ttt_chunk [1701/1893] bpb=1.145612 time=554.3s + ttt_chunk [1711/1893] bpb=1.145641 
time=557.6s + ttt_chunk [1721/1893] bpb=1.145659 time=560.9s + ttt_chunk [1731/1893] bpb=1.145554 time=564.1s + ttt_chunk [1741/1893] bpb=1.145380 time=567.4s + ttt_chunk [1751/1893] bpb=1.145230 time=570.6s + ttt_chunk [1761/1893] bpb=1.145384 time=573.9s + ttt_chunk [1771/1893] bpb=1.145299 time=577.1s + ttt_chunk [1781/1893] bpb=1.145346 time=580.4s + ttt_chunk [1791/1893] bpb=1.144942 time=583.7s + ttt_chunk [1801/1893] bpb=1.144822 time=586.9s + ttt_chunk [1811/1893] bpb=1.144737 time=590.2s + ttt_chunk [1821/1893] bpb=1.144807 time=593.4s + ttt_chunk [1831/1893] bpb=1.144205 time=596.7s + ttt_chunk [1841/1893] bpb=1.144262 time=600.0s + ttt_chunk [1851/1893] bpb=1.144060 time=603.2s + ttt_chunk [1861/1893] bpb=1.143705 time=606.5s + ttt_chunk [1871/1893] bpb=1.143690 time=609.7s + ttt_chunk [1881/1893] bpb=1.143244 time=613.0s + ttt_chunk [1891/1893] bpb=1.143018 time=616.3s + ttt_chunk [1893/1893] bpb=1.143066 time=616.7s +ttt_sliding:done val_loss=1.926234 val_bpb=1.140827 elapsed=617.0s +legal_ttt val_loss:1.9262 val_bpb:1.1408 eval_time:617367ms +final_int6_sliding_window_exact val_loss:1.92623397 val_bpb:1.14082728 diff --git a/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train_gpt.py b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train_gpt.py index 99cfcb9ae..f0a5596e4 100644 --- a/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train_gpt.py +++ b/records/track_10min_16mb/2026-03-25_PR414_CosineTTT30ep_1.0988/train_gpt.py @@ -23,11 +23,16 @@ import torch.nn.functional as F from torch import Tensor, nn from torch.nn.parallel import DistributedDataParallel as DDP +HAS_FA3 = False try: from flash_attn_interface import flash_attn_func as flash_attn_3_func HAS_FA3 = True except ImportError: - HAS_FA3 = False + try: + from flash_attn import flash_attn_func as flash_attn_3_func + HAS_FA3 = True + except ImportError: + pass class Hyperparameters: data_path = os.environ.get("DATA_PATH", 
"./data/datasets/fineweb10B_sp1024") train_files = os.path.join(data_path, "fineweb_train_*.bin") @@ -70,6 +75,13 @@ class Hyperparameters: adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) @@ -889,6 +901,98 @@ def eval_val_sliding( tokens_per_byte = token_count.item() / byte_count.item() base_model.train() return val_loss, bits_per_token * tokens_per_byte +def eval_val_sliding_ttt( + args, base_model, rank, world_size, device, val_tokens, base_bytes_lut, + has_leading_space_lut, is_boundary_token_lut, stride, batch_seqs=32, log0=print, +): + """Legal score-first TTT: score each chunk, then train on it.""" + seq_len, total_tokens, ttt_chunk = args.train_seq_len, val_tokens.numel() - 1, args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + chunk_windows[min((ws + s) // ttt_chunk, num_chunks - 1)].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} windows={len(window_starts)} stride={stride} lr={args.ttt_lr} 
epochs={args.ttt_epochs}") + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + [setattr(m, '_cos_cached', None) or setattr(m, '_sin_cached', None) for m in base_model.modules() if hasattr(m, '_cos_cached')] + frozen_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + p.data = p.data.clone() + freeze = any(f"blocks.{bi}." in name for bi in frozen_ids) + p.requires_grad_(not freeze) + if not freeze: ttt_params.append(p) + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: continue + chunk_start, chunk_end = ci * ttt_chunk, min((ci + 1) * ttt_chunk, total_tokens) + my_s, my_e = (len(windows) * rank) // world_size, (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_b = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_b = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens); wlen = end - ws; wlens.append(wlen) + tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_b[i, :wlen], y_b[i, :wlen] = tok[:-1], tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_b) + nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), y_b.reshape(-1), reduction="none").reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i]; s = 0 if ws == 0 else max(wlen - stride, 0) + scored = nll[i, s:wlen].to(torch.float64); 
loss_sum += scored.sum(); token_count += float(wlen - s) + tgt, prev = y_b[i, s:wlen], x_b[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if ci < num_chunks - 1 and args.ttt_epochs > 0: + [setattr(m, '_cos_cached', None) or setattr(m, '_sin_cached', None) for m in base_model.modules() if hasattr(m, '_cos_cached')] + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: pg['lr'] = cos_lr + my_seq_s, my_seq_e = (chunk_seqs * rank) // world_size, (chunk_seqs * (rank + 1)) // world_size + for _ep in range(args.ttt_epochs): + for bs in range(my_seq_s, my_seq_e, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_seq_e) + start_tok, end_tok = chunk_start + bs * seq_len, chunk_start + be * seq_len + 1 + if end_tok > val_tokens.numel(): continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x, y = local[:-1].reshape(-1, seq_len), local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM); 
dist.all_reduce(token_count, op=dist.ReduceOp.SUM); dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for p in base_model.parameters(): p.requires_grad_(True) + base_model.eval(); del optimizer; torch.cuda.empty_cache() + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb def _classify_param(name: str) -> str: if "tok_emb" in name or "lm_head" in name: return "embed" @@ -1378,102 +1482,18 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" ) log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - # --- Test-Time Training (TTT): adapt model on val data before final eval --- - ttt_epochs = int(os.environ.get("TTT_EPOCHS", 30)) - ttt_lr = float(os.environ.get("TTT_LR", 0.0005)) - if ttt_epochs > 0: - log0(f"ttt:start epochs:{ttt_epochs} lr:{ttt_lr}") - t_ttt = time.perf_counter() - # Clear cached inference tensors (RoPE cos/sin, etc.) 
- for m in eval_model.modules(): - if hasattr(m, '_cos_cached'): - m._cos_cached = None - m._sin_cached = None - proj_params, fc_params, other_params = [], [], [] - for name, p in eval_model.named_parameters(): - p.data = p.data.clone() - p.requires_grad_(True) - if "mlp.proj" in name: - proj_params.append(p) - elif "mlp.fc" in name: - fc_params.append(p) - else: - other_params.append(p) - ttt_opt = torch.optim.AdamW([ - {"params": proj_params, "lr": ttt_lr * 3.0, "initial_lr": ttt_lr * 3.0}, - {"params": fc_params, "lr": ttt_lr * 0.5, "initial_lr": ttt_lr * 0.5}, - {"params": other_params, "lr": ttt_lr, "initial_lr": ttt_lr}, - ], weight_decay=0.0) - ttt_batch = 32 - total_val = val_tokens.numel() - rank_start = rank * (total_val // world_size) - rank_end = rank_start + (total_val // world_size) - ttt_seq = effective_eval_seq_len - steps_per_epoch = max(1, (rank_end - rank_start - ttt_seq) // (ttt_batch * ttt_seq)) - total_ttt_steps = ttt_epochs * steps_per_epoch - global_step = 0 - eval_model.train() - for ep in range(ttt_epochs): - for bs_start in range(rank_start, rank_end - ttt_seq, ttt_batch * ttt_seq): - progress = global_step / max(total_ttt_steps, 1) - cos_mul = 0.5 * (1.0 + math.cos(math.pi * progress)) - for g in ttt_opt.param_groups: - g["lr"] = g["initial_lr"] * cos_mul - ttt_opt.zero_grad() - x_ttt = torch.stack([val_tokens[bs_start + i * ttt_seq : bs_start + i * ttt_seq + ttt_seq] - for i in range(min(ttt_batch, (rank_end - bs_start) // ttt_seq))]).to(device=device, dtype=torch.int64) - y_ttt = torch.stack([val_tokens[bs_start + i * ttt_seq + 1 : bs_start + i * ttt_seq + ttt_seq + 1] - for i in range(min(ttt_batch, (rank_end - bs_start) // ttt_seq))]).to(device=device, dtype=torch.int64) - with torch.autocast("cuda", dtype=torch.bfloat16): - loss_ttt = eval_model(x_ttt, y_ttt) - loss_ttt.backward() - if distributed: - for p in eval_model.parameters(): - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - 
torch.nn.utils.clip_grad_norm_(eval_model.parameters(), 1.0) - ttt_opt.step() - global_step += 1 - if (ep + 1) % 10 == 0: - log0(f"ttt:epoch:{ep+1}/{ttt_epochs} step:{global_step}/{total_ttt_steps}") - eval_model.eval() - del ttt_opt - torch.cuda.empty_cache() - log0(f"ttt:done time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms steps:{global_step}") - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + # --- Legal Score-First TTT --- + if args.ttt_epochs > 0: torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_sliding_ttt( args, eval_model, rank, world_size, device, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, + stride=args.eval_stride, batch_seqs=args.ttt_batch_seqs, log0=log0, ) torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact 
val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"legal_ttt val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"final_int6_sliding_window_exact val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f}") if distributed: dist.destroy_process_group() if __name__ == "__main__":