diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/README.md b/records/track_10min_16mb/2026-03-25_RecurLayers/README.md new file mode 100644 index 000000000..da3e00f33 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_RecurLayers/README.md @@ -0,0 +1,78 @@ +# Depth Recurrence (layers 4,5) + +## Score: mean val_bpb = 1.1184 (3 seeds: 1.1179, 1.1191, 1.1183) + +Trained on 8xH100 SXM in ~600 seconds. ~15.9MB artifact (int6+lzma). + +## Motivation + +I explored both width scaling (MODEL_DIM=576) and depth scaling (adding layers) and found that depth consistently wins over width in this regime. A full independent 12-layer model at dim=512 outperformed a wider 11-layer model at dim=576, despite the wider model having more parameters. However, adding independent layers pushes the model over the 16MB artifact budget. Depth recurrence solves this: by re-executing mid-network layers with independent block scalars, I get the depth benefit without the parameter/size cost. Dual recurrence on layers 4 and 5 gives 13 virtual layers from 11 physical, staying just under the budget at ~15.9MB. + +## Approach + +Depth recurrence applied to layers 4 and 5, creating 13 virtual layers from 11 physical layers while keeping parameter count at ~27M. Combined with test-time training (TTT) for additional evaluation-time adaptation. + +### Dual Depth Recurrence (layers 4,5) +Layers 4 and 5 are each executed twice in sequence (pattern: 0,1,2,3,4,5,4,5,6,7,8,9,10), producing 13 virtual layers from 11 physical layers. Each recurrent pass uses independent learnable block scalars, so the model can modulate how the repeated layers behave on their second pass. This adds depth without increasing model size or artifact bytes — only the small block scalar parameters are added (~2K params). + +Everything else (TTT, int6 quantization, SWA, bigram embeddings, value embeddings, Muon optimizer, etc.) is inherited from [PR #549](https://github.com/openai/parameter-golf/pull/549).
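The execution order described above can be sketched as follows. This is a minimal illustration only: the helper name `virtual_layer_schedule` is hypothetical (in `train_gpt.py` the order is derived from the `RECUR_LAYERS` env var), and the sketch assumes the recurrent layers form a contiguous span, as in the 4,5 configuration used here.

```python
# Build the virtual execution order for depth recurrence: the layers in
# recur_layers are re-run once, immediately after their first pass, so
# recur_layers=[4, 5] turns 11 physical layers into 13 virtual steps.
# NOTE: hypothetical helper for illustration; assumes a contiguous span.
def virtual_layer_schedule(num_layers: int, recur_layers: list[int]) -> list[int]:
    order = list(range(num_layers))
    if recur_layers:
        hi = max(recur_layers)  # insert the second pass right after the first
        order = order[: hi + 1] + sorted(recur_layers) + order[hi + 1 :]
    return order

print(virtual_layer_schedule(11, [4, 5]))
# [0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
```

Each index in the schedule reuses the same physical weights, but each pass gets its own block scalars, which is where the ~2K extra parameters come from.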
+ +## Hyperparameters + +| Parameter | Value | +|-----------|-------| +| num_layers | 11 (physical) / 13 (virtual) | +| model_dim | 512 | +| mlp_mult | 3.0 (hidden=1536) | +| recur_layers | 4, 5 | +| train_seq_len | 2048 | +| train_batch_tokens | 786,432 | +| warmdown_iters | 3500 | +| matrix_lr | 0.025 | +| scalar_lr | 0.025 | +| tied_embed_lr | 0.035 | +| muon_momentum | 0.99 (warmup from 0.92 over 1500 steps) | +| muon_weight_decay | 0.04 | +| adam_weight_decay | 0.04 | +| grad_clip_norm | 0.3 | +| eval_stride | 64 | +| swa_every | 50 | +| ttt_lr | 0.002 | +| ttt_epochs | 3 | +| ttt_chunk_tokens | 32768 | +| ttt_freeze_blocks | 2 | + +## Key Metrics + +- **Mean val_bpb: 1.11840** (std: 0.00049) +- Training: ~6,100 steps in ~600s +- Model params: ~27M +- Mean total submission size: 15,931,152 bytes (~15.9MB, int6+lzma) + +## Reproducibility + +Three independent training runs with different random seeds: + +| Seed | val_loss | val_bpb | total_bytes | +|------|----------|---------|-------------| +| 1337 | 1.88749538 | 1.11788404 | 15,928,948 | +| 2025 | 1.88948575 | 1.11906285 | 15,934,932 | +| 2024 | 1.88811812 | 1.11825287 | 15,929,576 | +| **Mean** | **1.88836642** | **1.11839992** | **15,931,152** | +| **Std** | **0.00083132** | **0.00049235** | | + +## Run Commands + +```bash +# Seed 1337 (default) +ITERATIONS=9000 RECUR_LAYERS=4,5 TTT_ENABLED=1 TTT_UNTIE=0 \ + torchrun --nproc_per_node=8 train_gpt.py + +# Seed 2025 +ITERATIONS=9000 RECUR_LAYERS=4,5 TTT_ENABLED=1 TTT_UNTIE=0 SEED=2025 \ + torchrun --nproc_per_node=8 train_gpt.py + +# Seed 2024 +ITERATIONS=9000 RECUR_LAYERS=4,5 TTT_ENABLED=1 TTT_UNTIE=0 SEED=2024 \ + torchrun --nproc_per_node=8 train_gpt.py +``` diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/submission.json b/records/track_10min_16mb/2026-03-25_RecurLayers/submission.json new file mode 100644 index 000000000..e6ca363ea --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_RecurLayers/submission.json @@ -0,0 +1,22 @@ +{ + 
"author": "Marko Sisovic", + "github_id": "msisovic", + "name": "Depth Recurrence (layers 4,5) + TTT", + "blurb": "Dual depth recurrence on layers 4 and 5 (11 physical -> 13 virtual layers) with tied test-time training. Reuses layer weights to add depth without increasing model size, keeping the artifact under 16MB with int6+lzma compression. Combined with TTT, SWA, bigram embeddings, value embeddings, and Muon optimizer with weight decay.", + "date": "2026-03-25T00:00:00Z", + "val_loss": 1.88836642, + "val_bpb": 1.11839992, + "val_loss_std": 0.00083132, + "val_bpb_std": 0.00049235, + "seeds": [1337, 2025, 2024], + "seed_results": { + "1337": {"val_loss": 1.88749538, "val_bpb": 1.11788404}, + "2025": {"val_loss": 1.88948575, "val_bpb": 1.11906285}, + "2024": {"val_loss": 1.88811812, "val_bpb": 1.11825287} + }, + "step_stop": 6100, + "wallclock_seconds": 600.0, + "eval_time_seconds": 475, + "bytes_total": 15928948, + "bytes_code": 95036 +} diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/train_gpt.py b/records/track_10min_16mb/2026-03-25_RecurLayers/train_gpt.py new file mode 100644 index 000000000..942f03ad3 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_RecurLayers/train_gpt.py @@ -0,0 +1,1981 @@ +from __future__ import annotations +import copy +import glob +import io +import lzma +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from flash_attn_interface import flash_attn_func as flash_attn_3_func +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, 
"fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + 
muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0"))) + lawa_k = int(os.environ.get("LAWA_K", 10)) + lawa_freq = int(os.environ.get("LAWA_FREQ", 100)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0"))) + value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0"))) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 
32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + recur_layer = int(os.environ.get("RECUR_LAYER", -1)) # single layer compat + recur_layers_str = os.environ.get("RECUR_LAYERS", "") # comma-separated, e.g. "4,5" + eval_only = bool(int(os.environ.get("EVAL_ONLY", "0"))) + +# --- Batched Newton-Schulz orthogonalization --- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +# --- Parallel Muon optimizer --- + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. 
Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. 
Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + 
p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + +# --- Tokenizer evaluation helpers --- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + 
rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += 
token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# --- Quantization helpers --- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = 
torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + 
"quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +# --- Data loading --- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype("<i4").itemsize + with file.open("rb") as f: + f.seek(header_bytes) + data = np.frombuffer(f.read(), dtype=np.uint16) + return torch.from_numpy(data.astype(np.int32)) +class TokenStream: + def __init__(self, pattern: str): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[0]) + self.pos = 0 + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, 
global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# --- Transformer modules --- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + 
self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + 
self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + # No CastedLinear -- weights come from banks + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + # Gated attention and value residual (non-banked small params) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. 
H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout + gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) + y = y * gate + y = y.reshape(bsz, seqlen, dim) + return F.linear(y, out_w.to(x.dtype)), raw_v + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + 
return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. 
+ Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + # No CastedLinear -- weights come from banks + def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: + x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5) + return F.linear(x.square(), down_w.to(x.dtype)) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + gated_attention=gated_attention, value_residual=value_residual) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 
2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed, v0=v0) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, raw_v + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + gated_attention: bool = False, + value_residual: bool = False, + recur_layer: int = -1, + recur_layers: list[int] | None = None, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.value_residual = value_residual + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = 
mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + # Depth recurrence: duplicate layers virtually for near-zero param overhead + # Normalize: recur_layers is the canonical list; recur_layer is single-layer compat + if recur_layers is None and recur_layer >= 0: + recur_layers = [recur_layer] + self.recur_layers = sorted(recur_layers) if recur_layers else [] + self.num_physical_layers = num_layers + if self.recur_layers: + for rl in self.recur_layers: + assert 0 <= rl < num_layers, f"recur_layer={rl} out of range [0, {num_layers})" + cutoff = max(self.recur_layers) + 1 + self.v2p = list(range(cutoff)) + self.recur_layers + list(range(cutoff, num_layers)) + virtual_num_layers = num_layers + len(self.recur_layers) + else: + virtual_num_layers = num_layers + self.v2p = list(range(num_layers)) + self.virtual_num_layers = virtual_num_layers + self.num_encoder_layers = virtual_num_layers // 2 + self.num_decoder_layers = virtual_num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Parameter banks: contiguous 3D tensors for batched optimizer (physical layer count) + head_dim = model_dim // num_heads + kv_dim = num_kv_heads * head_dim + mlp_dim = int(mlp_mult * model_dim) + self.num_layers = num_layers # physical, used for bank indexing + self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) + # Blocks: one per virtual layer (each has own scalar params) + self.blocks = nn.ModuleList( + 
[ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + gated_attention=gated_attention, + value_residual=value_residual, + ) + for i in range(virtual_num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim_ve = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, virtual_num_layers - xsa_last_n), virtual_num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + # Init banks: orthogonal, with proj layers scaled down and out/down zero-init + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q + nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K + 
nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up + nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) + # Scale proj layers (out_proj and mlp_down are "proj" layers) + self.qo_bank.data[n + i].mul_(proj_scale) + self.mlp_down_bank.data[i].mul_(proj_scale) + # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + n = self.num_layers # physical layer count for bank offset + v2p = self.v2p + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + pi = v2p[i] + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], + self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = 
raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + pi = v2p[bi] + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], + self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + n = self.num_layers # physical layer count for bank offset + v2p = self.v2p + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 
= x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + pi = v2p[i] + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], + self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + pi = v2p[bi] + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[pi], self.kv_bank[pi], self.kv_bank[n + pi], + self.qo_bank[n + pi], self.mlp_up_bank[pi], self.mlp_down_bank[pi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def untie_recurrence(self): + """Expand banks so duplicated layers get their own independent weights. 
+ Called before TTT so SGD can update each virtual layer independently.""" + if not self.recur_layers: + return + n = self.num_layers # physical count before expansion + # Build list of cloned rows to insert after max(recur_layers) + insert_after = max(self.recur_layers) + clones = sorted(self.recur_layers) # rows to duplicate + + def _expand(bank: Tensor) -> Tensor: + """Insert clones of recur rows after insert_after position.""" + parts = [bank[:insert_after + 1]] + for rl in clones: + parts.append(bank[rl:rl + 1].clone()) + parts.append(bank[insert_after + 1:]) + return torch.cat(parts, dim=0) + + # qo_bank: [2*n, dim, dim] -> first n are Q, last n are O + q_part = _expand(self.qo_bank.data[:n]) + o_part = _expand(self.qo_bank.data[n:]) + self.qo_bank = nn.Parameter(torch.cat([q_part, o_part], dim=0)) + + # kv_bank: [2*n, kv_dim, dim] -> first n are K, last n are V + k_part = _expand(self.kv_bank.data[:n]) + v_part = _expand(self.kv_bank.data[n:]) + self.kv_bank = nn.Parameter(torch.cat([k_part, v_part], dim=0)) + + # mlp banks: [n, ...] 
+ self.mlp_up_bank = nn.Parameter(_expand(self.mlp_up_bank.data)) + self.mlp_down_bank = nn.Parameter(_expand(self.mlp_down_bank.data)) + + # Update to identity mapping + new_n = n + len(clones) + self.num_layers = new_n + self.num_physical_layers = new_n + self.v2p = list(range(new_n)) + self.recur_layers = [] + +# --- Sliding window evaluation --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + 
y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, + then train on it. 
Every token scored BEFORE any update that could use it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = 
(loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + + # --- GPTQ-lite int6 quantization --- + + def _classify_param(name: str) -> str: + """Bucket a parameter name into a quantization category: embed, mlp, attn, or other.""" + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + """Symmetric 6-bit quantization: codes in [-clip_range, clip_range], stored as int8. + For 2D tensors, per-row fp16 scales are chosen by searching clip percentiles for minimum MSE; + other shapes fall back to a single absmax scale.""" + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: + """Convert 3D bank tensors into individual 2D tensors with standard names.""" + out: dict[str, Tensor] = {} + n = num_layers + for name, tensor in sd.items(): + if name == "qo_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] + out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] + elif name == "kv_bank": + for i in range(n): +
out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] + out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] + elif name == "mlp_up_bank": + for i in range(n): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "mlp_down_bank": + for i in range(n): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Convert individual 2D tensors back into 3D bank tensors.""" + out: dict[str, Tensor] = {} + n = num_layers + # Reconstruct banks from individual weight keys + qo_slices = [None] * (2 * n) + kv_slices = [None] * (2 * n) + up_slices = [None] * n + down_slices = [None] * n + consumed = set() + for i in range(n): + qk = f"blocks.{i}.attn.c_q.weight" + if qk in sd: + qo_slices[i] = sd[qk] + consumed.add(qk) + ok = f"blocks.{i}.attn.proj.weight" + if ok in sd: + qo_slices[n + i] = sd[ok] + consumed.add(ok) + kk = f"blocks.{i}.attn.c_k.weight" + if kk in sd: + kv_slices[i] = sd[kk] + consumed.add(kk) + vk = f"blocks.{i}.attn.c_v.weight" + if vk in sd: + kv_slices[n + i] = sd[vk] + consumed.add(vk) + fk = f"blocks.{i}.mlp.fc.weight" + if fk in sd: + up_slices[i] = sd[fk] + consumed.add(fk) + dk = f"blocks.{i}.mlp.proj.weight" + if dk in sd: + down_slices[i] = sd[dk] + consumed.add(dk) + out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) + out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) + out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) + out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + 
late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Invert mixed_quantize_int6: rebuild floating tensors from stored codes and scales, + casting back to the dtypes recorded in template_sd.""" + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + +# --- Training --- + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank =
int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise 
ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + gated_attention=args.gated_attention, + value_residual=args.value_residual, + recur_layer=args.recur_layer, + recur_layers=([int(x) for x in args.recur_layers_str.split(",") if x.strip()] or None), + ).to(device).bfloat16() + # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward + base_model.qo_bank.data = base_model.qo_bank.data.float() + base_model.kv_bank.data = 
base_model.kv_bank.data.float() + base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() + base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, + # and non-bank grads are manually all-reduced before Adam steps. + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = compiled_model + + if not args.eval_only: + # - 4 parameter banks -> Muon (batched Newton-Schulz) + # - token embedding -> Adam + # - scalars/control tensors -> Adam + # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking) + matrix_params = [ + base_model.qo_bank, base_model.kv_bank, + base_model.mlp_up_bank, base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + 
scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + # Non-bank params that need manual all-reduce (replicated across GPUs) + replicated_params = list(optimizer_tok.param_groups[0]["params"]) + for pg in optimizer_tok.param_groups[1:]: + replicated_params.extend(pg["params"]) + replicated_params.extend(scalar_params) + + optimizer_head = None + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + replicated_params.append(base_model.lm_head.weight) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if optimizer_head is not None: + optimizers.append(optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + if base_model.recur_layers: + log0(f"recurrence:layers={base_model.recur_layers} physical_layers={args.num_layers} virtual_layers={base_model.virtual_num_layers}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if 
b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x, y = 
train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + # All-reduce all grads for warmup (simple, not optimized) + if distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + from collections import deque + lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + 
torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + # === 3-phase overlapped optimizer step === + # Phase 1: Launch async reduce-scatter for banks (biggest first) + optimizer_muon.launch_reduce_scatters() + # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + optimizer_tok.step() + optimizer_scalar.step() + if optimizer_head is not None: + optimizer_head.step() + # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) + 
optimizer_muon.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + if args.lawa_enabled and step % args.lawa_freq == 0: + lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply weight averaging + if args.lawa_enabled and len(lawa_queue) > 1: + log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") + current_state = base_model.state_dict() + avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in 
current_state.items()} + for snap in lawa_queue: + for name in avg_state: + avg_state[name] += snap[name].float() + for name in avg_state: + avg_state[name] /= len(lawa_queue) + avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) + base_model.load_state_dict(avg_state, strict=True) + else: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + # Unbank 3D tensors into individual 2D tensors for quantization + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) + quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) + if master_process: + with 
open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") + log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + gated_attention=args.gated_attention, value_residual=args.value_residual, + recur_layer=args.recur_layer, + recur_layers=([int(x) for x in args.recur_layers_str.split(",") if x.strip()] or None), + ).to(device).bfloat16() + eval_model.qo_bank.data = eval_model.qo_bank.data.float() + eval_model.kv_bank.data = eval_model.kv_bank.data.float() + eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() + eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk)), + map_location="cpu", + ) + # Use eval_model's own state as dequant template + if not args.eval_only: + template_sd, template_unbanked = sd_cpu, unbanked_sd + else: + template_sd = {k: v.detach().cpu() for k, v in eval_model.state_dict().items()} + 
template_unbanked = _unbank_state_dict(template_sd, args.num_layers) + deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], template_unbanked) + # Re-bank the dequantized tensors + deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, template_sd) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, 
is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + ttt_untie = bool(int(os.environ.get("TTT_UNTIE", "1"))) + if args.ttt_enabled: + if eval_model.recur_layers and ttt_untie: + log0(f"ttt:untying recurrence at layers {eval_model.recur_layers}") + eval_model.untie_recurrence() + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed1337.log b/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed1337.log new file mode 100644 index 000000000..038211e79 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed1337.log @@ -0,0 +1,274 @@ +W0325 03:04:09.626000 9338 torch/distributed/run.py:803] +W0325 03:04:09.626000 9338 torch/distributed/run.py:803] ***************************************** +W0325 03:04:09.626000 9338 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, 
please further tune the variable for optimal performance in your application as needed. +W0325 03:04:09.626000 9338 torch/distributed/run.py:803] ***************************************** +logs/0b5022bd-4335-4cd0-b1f6-10ed1d9c971d.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26998380 +recurrence:layers=[4, 5] physical_layers=11 virtual_layers=13 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[9, 10, 11, 12] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/9000 val_loss:6.9309 val_bpb:4.1049 train_time:0ms step_avg:0.01ms +step:1/9000 train_loss:6.9317 train_time:145ms step_avg:145.35ms +step:2/9000 train_loss:8.5645 train_time:192ms step_avg:95.83ms +step:3/9000 train_loss:7.6147 train_time:286ms step_avg:95.17ms +step:4/9000 train_loss:7.3037 train_time:381ms step_avg:95.32ms +step:5/9000 train_loss:7.2074 train_time:476ms step_avg:95.21ms +step:6/9000 train_loss:7.1029 train_time:572ms step_avg:95.29ms +step:7/9000 train_loss:6.9765 train_time:667ms step_avg:95.32ms +step:8/9000 train_loss:6.9345 train_time:763ms step_avg:95.32ms +step:9/9000 
train_loss:6.5741 train_time:857ms step_avg:95.25ms +step:10/9000 train_loss:6.1479 train_time:952ms step_avg:95.19ms +step:500/9000 train_loss:2.3847 train_time:48843ms step_avg:97.69ms +step:1000/9000 train_loss:2.2560 train_time:98080ms step_avg:98.08ms +step:1500/9000 train_loss:2.1966 train_time:147290ms step_avg:98.19ms +step:2000/9000 train_loss:2.0422 train_time:196499ms step_avg:98.25ms +step:2500/9000 train_loss:2.1505 train_time:245737ms step_avg:98.29ms +step:3000/9000 train_loss:2.1321 train_time:294888ms step_avg:98.30ms +step:3500/9000 train_loss:2.1387 train_time:344021ms step_avg:98.29ms +step:4000/9000 train_loss:1.9297 train_time:393161ms step_avg:98.29ms +step:4000/9000 val_loss:2.0194 val_bpb:1.1960 train_time:393212ms step_avg:98.30ms +step:4500/9000 train_loss:2.0776 train_time:442263ms step_avg:98.28ms +step:5000/9000 train_loss:2.0527 train_time:491349ms step_avg:98.27ms +swa:start step:5450 +step:5500/9000 train_loss:1.9627 train_time:540544ms step_avg:98.28ms +late_qat:enabled step:5579 scale:0.1500 +step:6000/9000 train_loss:1.8879 train_time:590216ms step_avg:98.37ms +step:6100/9000 val_loss:1.9185 val_bpb:1.1362 train_time:600188ms step_avg:98.39ms +stopping_early: wallclock_cap train_time:600188ms step:6100/9000 +peak memory allocated: 25295 MiB reserved: 26088 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9170 val_bpb:1.1353 eval_time:2344ms +Serialized model: 106179454 bytes +Code size: 95036 bytes +Serialized model int6+lzma: 15833912 bytes +Total submission size int6+lzma: 15928948 bytes +final_int6_roundtrip val_loss:1.9321 val_bpb:1.1443 eval_time:21468ms +final_int6_roundtrip_exact val_loss:1.93214110 val_bpb:1.14432279 +final_int6_sliding_window val_loss:1.8925 val_bpb:1.1208 stride:64 eval_time:110318ms +final_int6_sliding_window_exact val_loss:1.89245481 val_bpb:1.12082130 +final_int8_zlib_roundtrip_exact val_loss:1.89245481 val_bpb:1.12082130 +ttt_sliding:start chunks=1893 chunk_tokens=32768 
total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=2 +ttt_sliding:params unfrozen=26994268 frozen=4112 + ttt_chunk [1/1893] bpb=1.149595 time=0.5s + ttt_chunk [11/1893] bpb=1.144860 time=3.0s + ttt_chunk [21/1893] bpb=1.130659 time=5.5s + ttt_chunk [31/1893] bpb=1.128493 time=8.0s + ttt_chunk [41/1893] bpb=1.114966 time=10.6s + ttt_chunk [51/1893] bpb=1.109232 time=13.1s + ttt_chunk [61/1893] bpb=1.115662 time=15.6s + ttt_chunk [71/1893] bpb=1.114420 time=18.1s + ttt_chunk [81/1893] bpb=1.113703 time=20.6s + ttt_chunk [91/1893] bpb=1.114592 time=23.2s + ttt_chunk [101/1893] bpb=1.118197 time=25.7s + ttt_chunk [111/1893] bpb=1.120626 time=28.2s + ttt_chunk [121/1893] bpb=1.113734 time=30.7s + ttt_chunk [131/1893] bpb=1.113868 time=33.2s + ttt_chunk [141/1893] bpb=1.119649 time=35.7s + ttt_chunk [151/1893] bpb=1.121408 time=38.2s + ttt_chunk [161/1893] bpb=1.120830 time=40.8s + ttt_chunk [171/1893] bpb=1.125186 time=43.3s + ttt_chunk [181/1893] bpb=1.127255 time=45.8s + ttt_chunk [191/1893] bpb=1.134413 time=48.3s + ttt_chunk [201/1893] bpb=1.133158 time=50.8s + ttt_chunk [211/1893] bpb=1.130866 time=53.3s + ttt_chunk [221/1893] bpb=1.132428 time=55.9s + ttt_chunk [231/1893] bpb=1.131007 time=58.4s + ttt_chunk [241/1893] bpb=1.131338 time=60.9s + ttt_chunk [251/1893] bpb=1.130756 time=63.4s + ttt_chunk [261/1893] bpb=1.127867 time=66.0s + ttt_chunk [271/1893] bpb=1.126701 time=68.5s + ttt_chunk [281/1893] bpb=1.128083 time=71.0s + ttt_chunk [291/1893] bpb=1.129891 time=73.5s + ttt_chunk [301/1893] bpb=1.130617 time=76.0s + ttt_chunk [311/1893] bpb=1.132800 time=78.5s + ttt_chunk [321/1893] bpb=1.134770 time=81.0s + ttt_chunk [331/1893] bpb=1.134598 time=83.5s + ttt_chunk [341/1893] bpb=1.133566 time=86.0s + ttt_chunk [351/1893] bpb=1.135875 time=88.5s + ttt_chunk [361/1893] bpb=1.136122 time=91.0s + ttt_chunk [371/1893] bpb=1.135433 time=93.5s + ttt_chunk [381/1893] bpb=1.135618 time=96.0s + ttt_chunk [391/1893] bpb=1.135421 time=98.5s + ttt_chunk 
[401/1893] bpb=1.133384 time=101.0s + ttt_chunk [411/1893] bpb=1.132187 time=103.5s + ttt_chunk [421/1893] bpb=1.131277 time=106.0s + ttt_chunk [431/1893] bpb=1.131156 time=108.6s + ttt_chunk [441/1893] bpb=1.131604 time=111.0s + ttt_chunk [451/1893] bpb=1.131891 time=113.5s + ttt_chunk [461/1893] bpb=1.130814 time=116.1s + ttt_chunk [471/1893] bpb=1.131442 time=118.6s + ttt_chunk [481/1893] bpb=1.131098 time=121.1s + ttt_chunk [491/1893] bpb=1.129967 time=123.6s + ttt_chunk [501/1893] bpb=1.129522 time=126.1s + ttt_chunk [511/1893] bpb=1.128859 time=128.6s + ttt_chunk [521/1893] bpb=1.126487 time=131.1s + ttt_chunk [531/1893] bpb=1.127705 time=133.6s + ttt_chunk [541/1893] bpb=1.128083 time=136.1s + ttt_chunk [551/1893] bpb=1.127075 time=138.6s + ttt_chunk [561/1893] bpb=1.127641 time=141.1s + ttt_chunk [571/1893] bpb=1.126636 time=143.6s + ttt_chunk [581/1893] bpb=1.125882 time=146.1s + ttt_chunk [591/1893] bpb=1.125267 time=148.6s + ttt_chunk [601/1893] bpb=1.125751 time=151.1s + ttt_chunk [611/1893] bpb=1.125689 time=153.6s + ttt_chunk [621/1893] bpb=1.125524 time=156.1s + ttt_chunk [631/1893] bpb=1.126261 time=158.7s + ttt_chunk [641/1893] bpb=1.126015 time=161.2s + ttt_chunk [651/1893] bpb=1.126099 time=163.7s + ttt_chunk [661/1893] bpb=1.125578 time=166.2s + ttt_chunk [671/1893] bpb=1.125936 time=168.7s + ttt_chunk [681/1893] bpb=1.126666 time=171.2s + ttt_chunk [691/1893] bpb=1.127662 time=173.7s + ttt_chunk [701/1893] bpb=1.127101 time=176.2s + ttt_chunk [711/1893] bpb=1.127076 time=178.7s + ttt_chunk [721/1893] bpb=1.126723 time=181.2s + ttt_chunk [731/1893] bpb=1.126778 time=183.8s + ttt_chunk [741/1893] bpb=1.126872 time=186.3s + ttt_chunk [751/1893] bpb=1.126710 time=188.8s + ttt_chunk [761/1893] bpb=1.126628 time=191.4s + ttt_chunk [771/1893] bpb=1.126293 time=193.9s + ttt_chunk [781/1893] bpb=1.127027 time=196.4s + ttt_chunk [791/1893] bpb=1.126598 time=198.9s + ttt_chunk [801/1893] bpb=1.126939 time=201.4s + ttt_chunk [811/1893] bpb=1.126684 
time=203.9s + ttt_chunk [821/1893] bpb=1.126459 time=206.4s + ttt_chunk [831/1893] bpb=1.126271 time=208.9s + ttt_chunk [841/1893] bpb=1.125602 time=211.4s + ttt_chunk [851/1893] bpb=1.125312 time=213.9s + ttt_chunk [861/1893] bpb=1.125070 time=216.4s + ttt_chunk [871/1893] bpb=1.125326 time=218.9s + ttt_chunk [881/1893] bpb=1.125508 time=221.5s + ttt_chunk [891/1893] bpb=1.125071 time=224.0s + ttt_chunk [901/1893] bpb=1.124822 time=226.5s + ttt_chunk [911/1893] bpb=1.124940 time=229.0s + ttt_chunk [921/1893] bpb=1.125432 time=231.5s + ttt_chunk [931/1893] bpb=1.125398 time=234.0s + ttt_chunk [941/1893] bpb=1.125069 time=236.5s + ttt_chunk [951/1893] bpb=1.125495 time=239.0s + ttt_chunk [961/1893] bpb=1.125568 time=241.5s + ttt_chunk [971/1893] bpb=1.126425 time=244.1s + ttt_chunk [981/1893] bpb=1.126515 time=246.6s + ttt_chunk [991/1893] bpb=1.126548 time=249.1s + ttt_chunk [1001/1893] bpb=1.126519 time=251.6s + ttt_chunk [1011/1893] bpb=1.126294 time=254.1s + ttt_chunk [1021/1893] bpb=1.126616 time=256.6s + ttt_chunk [1031/1893] bpb=1.127071 time=259.1s + ttt_chunk [1041/1893] bpb=1.126714 time=261.6s + ttt_chunk [1051/1893] bpb=1.126465 time=264.1s + ttt_chunk [1061/1893] bpb=1.126523 time=266.6s + ttt_chunk [1071/1893] bpb=1.127116 time=269.1s + ttt_chunk [1081/1893] bpb=1.127398 time=271.6s + ttt_chunk [1091/1893] bpb=1.128136 time=274.2s + ttt_chunk [1101/1893] bpb=1.128158 time=276.6s + ttt_chunk [1111/1893] bpb=1.128007 time=279.2s + ttt_chunk [1121/1893] bpb=1.127817 time=281.7s + ttt_chunk [1131/1893] bpb=1.127688 time=284.2s + ttt_chunk [1141/1893] bpb=1.127398 time=286.7s + ttt_chunk [1151/1893] bpb=1.127419 time=289.2s + ttt_chunk [1161/1893] bpb=1.127044 time=291.7s + ttt_chunk [1171/1893] bpb=1.127350 time=294.2s + ttt_chunk [1181/1893] bpb=1.126595 time=296.7s + ttt_chunk [1191/1893] bpb=1.126455 time=299.2s + ttt_chunk [1201/1893] bpb=1.126858 time=301.7s + ttt_chunk [1211/1893] bpb=1.126384 time=304.2s + ttt_chunk [1221/1893] bpb=1.126089 
time=306.7s + ttt_chunk [1231/1893] bpb=1.125807 time=309.2s + ttt_chunk [1241/1893] bpb=1.125451 time=311.7s + ttt_chunk [1251/1893] bpb=1.124854 time=314.2s + ttt_chunk [1261/1893] bpb=1.124831 time=316.7s + ttt_chunk [1271/1893] bpb=1.124454 time=319.2s + ttt_chunk [1281/1893] bpb=1.124276 time=321.6s + ttt_chunk [1291/1893] bpb=1.124034 time=324.2s + ttt_chunk [1301/1893] bpb=1.123472 time=326.7s + ttt_chunk [1311/1893] bpb=1.123101 time=329.1s + ttt_chunk [1321/1893] bpb=1.122783 time=331.6s + ttt_chunk [1331/1893] bpb=1.122728 time=334.1s + ttt_chunk [1341/1893] bpb=1.122603 time=336.6s + ttt_chunk [1351/1893] bpb=1.122539 time=339.1s + ttt_chunk [1361/1893] bpb=1.122574 time=341.6s + ttt_chunk [1371/1893] bpb=1.122431 time=344.1s + ttt_chunk [1381/1893] bpb=1.122414 time=346.6s + ttt_chunk [1391/1893] bpb=1.122020 time=349.2s + ttt_chunk [1401/1893] bpb=1.121968 time=351.7s + ttt_chunk [1411/1893] bpb=1.122084 time=354.2s + ttt_chunk [1421/1893] bpb=1.122329 time=356.7s + ttt_chunk [1431/1893] bpb=1.122018 time=359.2s + ttt_chunk [1441/1893] bpb=1.122520 time=361.6s + ttt_chunk [1451/1893] bpb=1.122848 time=364.2s + ttt_chunk [1461/1893] bpb=1.122393 time=366.7s + ttt_chunk [1471/1893] bpb=1.123433 time=369.2s + ttt_chunk [1481/1893] bpb=1.123004 time=371.7s + ttt_chunk [1491/1893] bpb=1.122822 time=374.2s + ttt_chunk [1501/1893] bpb=1.122721 time=376.7s + ttt_chunk [1511/1893] bpb=1.122739 time=379.1s + ttt_chunk [1521/1893] bpb=1.122758 time=381.6s + ttt_chunk [1531/1893] bpb=1.122243 time=384.1s + ttt_chunk [1541/1893] bpb=1.122105 time=386.6s + ttt_chunk [1551/1893] bpb=1.122420 time=389.1s + ttt_chunk [1561/1893] bpb=1.122437 time=391.6s + ttt_chunk [1571/1893] bpb=1.122273 time=394.2s + ttt_chunk [1581/1893] bpb=1.122387 time=396.7s + ttt_chunk [1591/1893] bpb=1.122241 time=399.1s + ttt_chunk [1601/1893] bpb=1.122411 time=401.6s + ttt_chunk [1611/1893] bpb=1.122349 time=404.2s + ttt_chunk [1621/1893] bpb=1.121958 time=406.7s + ttt_chunk [1631/1893] 
bpb=1.122271 time=409.2s + ttt_chunk [1641/1893] bpb=1.122288 time=411.7s + ttt_chunk [1651/1893] bpb=1.122242 time=414.2s + ttt_chunk [1661/1893] bpb=1.122123 time=416.7s + ttt_chunk [1671/1893] bpb=1.122594 time=419.2s + ttt_chunk [1681/1893] bpb=1.122739 time=421.7s + ttt_chunk [1691/1893] bpb=1.122583 time=424.2s + ttt_chunk [1701/1893] bpb=1.122750 time=426.7s + ttt_chunk [1711/1893] bpb=1.122758 time=429.2s + ttt_chunk [1721/1893] bpb=1.122755 time=431.7s + ttt_chunk [1731/1893] bpb=1.122630 time=434.2s + ttt_chunk [1741/1893] bpb=1.122445 time=436.7s + ttt_chunk [1751/1893] bpb=1.122287 time=439.2s + ttt_chunk [1761/1893] bpb=1.122426 time=441.7s + ttt_chunk [1771/1893] bpb=1.122334 time=444.2s + ttt_chunk [1781/1893] bpb=1.122368 time=446.7s + ttt_chunk [1791/1893] bpb=1.121959 time=449.2s + ttt_chunk [1801/1893] bpb=1.121851 time=451.7s + ttt_chunk [1811/1893] bpb=1.121747 time=454.2s + ttt_chunk [1821/1893] bpb=1.121805 time=456.7s + ttt_chunk [1831/1893] bpb=1.121219 time=459.2s + ttt_chunk [1841/1893] bpb=1.121157 time=461.7s + ttt_chunk [1851/1893] bpb=1.120944 time=464.2s + ttt_chunk [1861/1893] bpb=1.120574 time=466.7s + ttt_chunk [1871/1893] bpb=1.120562 time=469.2s + ttt_chunk [1881/1893] bpb=1.120122 time=471.7s + ttt_chunk [1891/1893] bpb=1.119889 time=474.2s + ttt_chunk [1893/1893] bpb=1.119934 time=474.6s +ttt_sliding:done val_loss=1.887495 val_bpb=1.117884 elapsed=474.6s +legal_ttt val_loss:1.8875 val_bpb:1.1179 eval_time:475147ms +legal_ttt_exact val_loss:1.88749538 val_bpb:1.11788404 diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2024.log b/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2024.log new file mode 100644 index 000000000..34cfd816b --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2024.log @@ -0,0 +1,274 @@ +W0325 11:30:23.920000 12574 torch/distributed/run.py:803] +W0325 11:30:23.920000 12574 torch/distributed/run.py:803] ***************************************** 
+W0325 11:30:23.920000 12574 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0325 11:30:23.920000 12574 torch/distributed/run.py:803] *****************************************
+logs/f5675aa4-2278-4d41-8bfa-b0ded28ea1c7.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:26998380
+recurrence:layers=[4, 5] physical_layers=11 virtual_layers=13
+mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
+XSA:last_4 active_layers:[9, 10, 11, 12]
+world_size:8 grad_accum_steps:1
+sdp_backends:cudnn=False flash=True mem_efficient=False math=False
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
+train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000
+seed:2024
+warmup_step:1/20
+warmup_step:2/20
+warmup_step:3/20
+warmup_step:4/20
+warmup_step:5/20
+warmup_step:6/20
+warmup_step:7/20
+warmup_step:8/20
+warmup_step:9/20
+warmup_step:10/20
+warmup_step:11/20
+warmup_step:12/20
+warmup_step:13/20
+warmup_step:14/20
+warmup_step:15/20
+warmup_step:16/20
+warmup_step:17/20
+warmup_step:18/20
+warmup_step:19/20
+warmup_step:20/20
+step:0/9000 val_loss:6.9327 val_bpb:4.1059 train_time:0ms step_avg:0.01ms
+step:1/9000 train_loss:6.9341 train_time:173ms step_avg:172.60ms
+step:2/9000 train_loss:8.6442 train_time:218ms step_avg:109.24ms
+step:3/9000 train_loss:7.6536 train_time:313ms step_avg:104.44ms
+step:4/9000 train_loss:7.2292 train_time:409ms step_avg:102.28ms
+step:5/9000 train_loss:7.1066 train_time:503ms step_avg:100.65ms
+step:6/9000 train_loss:7.0691 train_time:599ms step_avg:99.77ms
+step:7/9000 train_loss:6.9931 train_time:694ms step_avg:99.16ms
+step:8/9000 train_loss:6.8960 train_time:790ms step_avg:98.74ms
+step:9/9000 train_loss:6.5644 train_time:886ms step_avg:98.39ms
+step:10/9000 train_loss:6.2023 train_time:981ms step_avg:98.07ms
+step:500/9000 train_loss:2.3940 train_time:48782ms step_avg:97.56ms
+step:1000/9000 train_loss:2.2600 train_time:97666ms step_avg:97.67ms
+step:1500/9000 train_loss:2.2016 train_time:146580ms step_avg:97.72ms
+step:2000/9000 train_loss:2.0487 train_time:195581ms step_avg:97.79ms
+step:2500/9000 train_loss:2.1484 train_time:244522ms step_avg:97.81ms
+step:3000/9000 train_loss:2.1327 train_time:293569ms step_avg:97.86ms
+step:3500/9000 train_loss:2.1421 train_time:342629ms step_avg:97.89ms
+step:4000/9000 train_loss:1.9307 train_time:391586ms step_avg:97.90ms
+step:4000/9000 val_loss:2.0208 val_bpb:1.1968 train_time:391638ms step_avg:97.91ms
+step:4500/9000 train_loss:2.0783 train_time:440574ms step_avg:97.91ms
+step:5000/9000 train_loss:2.0532 train_time:489649ms step_avg:97.93ms
+swa:start step:5450
+step:5500/9000 train_loss:1.9651 train_time:538840ms step_avg:97.97ms
+late_qat:enabled step:5598 scale:0.1498
+step:6000/9000 train_loss:1.8884 train_time:588478ms step_avg:98.08ms
+step:6116/9000 val_loss:1.9192 val_bpb:1.1367 train_time:600071ms step_avg:98.11ms
+stopping_early: wallclock_cap train_time:600071ms step:6116/9000
+peak memory allocated: 25295 MiB reserved: 26088 MiB
+ema:applying EMA weights
+DIAGNOSTIC post_ema val_loss:1.9177 val_bpb:1.1358 eval_time:2335ms
+Serialized model: 106179454 bytes
+Code size: 95036 bytes
+Serialized model int6+lzma: 15834540 bytes
+Total submission size int6+lzma: 15929576 bytes
+final_int6_roundtrip val_loss:1.9316 val_bpb:1.1440 eval_time:24477ms
+final_int6_roundtrip_exact val_loss:1.93157593 val_bpb:1.14398806
+final_int6_sliding_window val_loss:1.8920 val_bpb:1.1205 stride:64 eval_time:116093ms
+final_int6_sliding_window_exact val_loss:1.89197002 val_bpb:1.12053418
+final_int8_zlib_roundtrip_exact val_loss:1.89197002 val_bpb:1.12053418
+ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=2
+ttt_sliding:params unfrozen=26994268 frozen=4112
+ ttt_chunk [1/1893] bpb=1.157235 time=0.9s
+ ttt_chunk [11/1893] bpb=1.145437 time=3.6s
+ ttt_chunk [21/1893] bpb=1.130624 time=6.3s
+ ttt_chunk [31/1893] bpb=1.128957 time=9.0s
+ ttt_chunk [41/1893] bpb=1.115633 time=11.6s
+ ttt_chunk [51/1893] bpb=1.109563 time=14.3s
+ ttt_chunk [61/1893] bpb=1.115821 time=16.7s
+ ttt_chunk [71/1893] bpb=1.114674 time=19.2s
+ ttt_chunk [81/1893] bpb=1.113366 time=21.8s
+ ttt_chunk [91/1893] bpb=1.114258 time=24.4s
+ ttt_chunk [101/1893] bpb=1.117719 time=27.0s
+ ttt_chunk [111/1893] bpb=1.120297 time=29.6s
+ ttt_chunk [121/1893] bpb=1.113485 time=32.2s
+ ttt_chunk [131/1893] bpb=1.113538 time=34.8s
+ ttt_chunk [141/1893] bpb=1.119008 time=37.2s
+ ttt_chunk [151/1893] bpb=1.120764 time=39.5s
+ ttt_chunk [161/1893] bpb=1.120335 time=41.8s
+ ttt_chunk [171/1893] bpb=1.124695 time=44.1s
+ ttt_chunk [181/1893] bpb=1.126860 time=46.4s
+ ttt_chunk [191/1893] bpb=1.134234 time=48.6s
+ ttt_chunk [201/1893] bpb=1.132865 time=50.9s
+ ttt_chunk [211/1893] bpb=1.130650 time=53.1s
+ ttt_chunk [221/1893] bpb=1.132106 time=55.4s
+ ttt_chunk [231/1893] bpb=1.130797 time=57.6s
+ ttt_chunk [241/1893] bpb=1.131151 time=59.9s
+ ttt_chunk [251/1893] bpb=1.130670 time=62.1s
+ ttt_chunk [261/1893] bpb=1.127722 time=64.4s
+ ttt_chunk [271/1893] bpb=1.126606 time=66.6s
+ ttt_chunk [281/1893] bpb=1.127953 time=68.9s
+ ttt_chunk [291/1893] bpb=1.129728 time=71.1s
+ ttt_chunk [301/1893] bpb=1.130487 time=73.4s
+ ttt_chunk [311/1893] bpb=1.132594 time=75.6s
+ ttt_chunk [321/1893] bpb=1.134641 time=78.0s
+ ttt_chunk [331/1893] bpb=1.134524 time=80.4s
+ ttt_chunk [341/1893] bpb=1.133536 time=82.8s
+ ttt_chunk [351/1893] bpb=1.135880 time=85.2s
+ ttt_chunk [361/1893] bpb=1.136115 time=87.5s
+ ttt_chunk [371/1893] bpb=1.135469 time=89.9s
+ ttt_chunk [381/1893] bpb=1.135609 time=92.2s
+ ttt_chunk [391/1893] bpb=1.135480 time=94.5s
+ ttt_chunk [401/1893] bpb=1.133385 time=96.7s
+ ttt_chunk [411/1893] bpb=1.132205 time=99.0s
+ ttt_chunk [421/1893] bpb=1.131322 time=101.2s
+ ttt_chunk [431/1893] bpb=1.131216 time=103.5s
+ ttt_chunk [441/1893] bpb=1.131554 time=105.7s
+ ttt_chunk [451/1893] bpb=1.131902 time=108.0s
+ ttt_chunk [461/1893] bpb=1.130875 time=110.3s
+ ttt_chunk [471/1893] bpb=1.131539 time=112.5s
+ ttt_chunk [481/1893] bpb=1.131143 time=114.8s
+ ttt_chunk [491/1893] bpb=1.130071 time=117.1s
+ ttt_chunk [501/1893] bpb=1.129645 time=119.3s
+ ttt_chunk [511/1893] bpb=1.128984 time=121.6s
+ ttt_chunk [521/1893] bpb=1.126643 time=123.8s
+ ttt_chunk [531/1893] bpb=1.127851 time=126.1s
+ ttt_chunk [541/1893] bpb=1.128224 time=128.3s
+ ttt_chunk [551/1893] bpb=1.127218 time=130.6s
+ ttt_chunk [561/1893] bpb=1.127746 time=132.9s
+ ttt_chunk [571/1893] bpb=1.126747 time=135.1s
+ ttt_chunk [581/1893] bpb=1.125979 time=137.4s
+ ttt_chunk [591/1893] bpb=1.125397 time=139.6s
+ ttt_chunk [601/1893] bpb=1.125898 time=141.9s
+ ttt_chunk [611/1893] bpb=1.125859 time=144.2s
+ ttt_chunk [621/1893] bpb=1.125791 time=146.4s
+ ttt_chunk [631/1893] bpb=1.126510 time=148.7s
+ ttt_chunk [641/1893] bpb=1.126293 time=151.1s
+ ttt_chunk [651/1893] bpb=1.126428 time=153.5s
+ ttt_chunk [661/1893] bpb=1.125903 time=155.8s
+ ttt_chunk [671/1893] bpb=1.126267 time=158.2s
+ ttt_chunk [681/1893] bpb=1.126983 time=160.6s
+ ttt_chunk [691/1893] bpb=1.128003 time=163.0s
+ ttt_chunk [701/1893] bpb=1.127435 time=165.2s
+ ttt_chunk [711/1893] bpb=1.127429 time=167.5s
+ ttt_chunk [721/1893] bpb=1.127089 time=169.8s
+ ttt_chunk [731/1893] bpb=1.127105 time=172.0s
+ ttt_chunk [741/1893] bpb=1.127177 time=174.3s
+ ttt_chunk [751/1893] bpb=1.127015 time=176.5s
+ ttt_chunk [761/1893] bpb=1.126947 time=178.8s
+ ttt_chunk [771/1893] bpb=1.126646 time=181.0s
+ ttt_chunk [781/1893] bpb=1.127394 time=183.3s
+ ttt_chunk [791/1893] bpb=1.126967 time=185.5s
+ ttt_chunk [801/1893] bpb=1.127291 time=187.8s
+ ttt_chunk [811/1893] bpb=1.127037 time=190.1s
+ ttt_chunk [821/1893] bpb=1.126818 time=192.3s
+ ttt_chunk [831/1893] bpb=1.126631 time=194.6s
+ ttt_chunk [841/1893] bpb=1.125964 time=196.8s
+ ttt_chunk [851/1893] bpb=1.125690 time=199.1s
+ ttt_chunk [861/1893] bpb=1.125444 time=201.3s
+ ttt_chunk [871/1893] bpb=1.125696 time=203.6s
+ ttt_chunk [881/1893] bpb=1.125899 time=205.9s
+ ttt_chunk [891/1893] bpb=1.125483 time=208.2s
+ ttt_chunk [901/1893] bpb=1.125209 time=210.5s
+ ttt_chunk [911/1893] bpb=1.125311 time=212.7s
+ ttt_chunk [921/1893] bpb=1.125781 time=215.0s
+ ttt_chunk [931/1893] bpb=1.125751 time=217.3s
+ ttt_chunk [941/1893] bpb=1.125410 time=219.5s
+ ttt_chunk [951/1893] bpb=1.125797 time=221.8s
+ ttt_chunk [961/1893] bpb=1.125866 time=224.0s
+ ttt_chunk [971/1893] bpb=1.126712 time=226.3s
+ ttt_chunk [981/1893] bpb=1.126805 time=228.5s
+ ttt_chunk [991/1893] bpb=1.126801 time=230.8s
+ ttt_chunk [1001/1893] bpb=1.126733 time=233.0s
+ ttt_chunk [1011/1893] bpb=1.126526 time=235.3s
+ ttt_chunk [1021/1893] bpb=1.126842 time=237.5s
+ ttt_chunk [1031/1893] bpb=1.127309 time=239.8s
+ ttt_chunk [1041/1893] bpb=1.126969 time=242.0s
+ ttt_chunk [1051/1893] bpb=1.126730 time=244.3s
+ ttt_chunk [1061/1893] bpb=1.126774 time=246.7s
+ ttt_chunk [1071/1893] bpb=1.127356 time=249.3s
+ ttt_chunk [1081/1893] bpb=1.127653 time=251.8s
+ ttt_chunk [1091/1893] bpb=1.128419 time=254.3s
+ ttt_chunk [1101/1893] bpb=1.128422 time=256.6s
+ ttt_chunk [1111/1893] bpb=1.128265 time=258.9s
+ ttt_chunk [1121/1893] bpb=1.128077 time=261.2s
+ ttt_chunk [1131/1893] bpb=1.127957 time=263.7s
+ ttt_chunk [1141/1893] bpb=1.127664 time=266.3s
+ ttt_chunk [1151/1893] bpb=1.127668 time=268.8s
+ ttt_chunk [1161/1893] bpb=1.127274 time=271.1s
+ ttt_chunk [1171/1893] bpb=1.127594 time=273.3s
+ ttt_chunk [1181/1893] bpb=1.126854 time=275.6s
+ ttt_chunk [1191/1893] bpb=1.126731 time=277.9s
+ ttt_chunk [1201/1893] bpb=1.127156 time=280.1s
+ ttt_chunk [1211/1893] bpb=1.126694 time=282.6s
+ ttt_chunk [1221/1893] bpb=1.126393 time=285.1s
+ ttt_chunk [1231/1893] bpb=1.126114 time=287.7s
+ ttt_chunk [1241/1893] bpb=1.125781 time=290.2s
+ ttt_chunk [1251/1893] bpb=1.125195 time=292.8s
+ ttt_chunk [1261/1893] bpb=1.125177 time=295.3s
+ ttt_chunk [1271/1893] bpb=1.124812 time=297.8s
+ ttt_chunk [1281/1893] bpb=1.124630 time=300.1s
+ ttt_chunk [1291/1893] bpb=1.124399 time=302.4s
+ ttt_chunk [1301/1893] bpb=1.123817 time=304.6s
+ ttt_chunk [1311/1893] bpb=1.123435 time=306.9s
+ ttt_chunk [1321/1893] bpb=1.123118 time=309.1s
+ ttt_chunk [1331/1893] bpb=1.123075 time=311.4s
+ ttt_chunk [1341/1893] bpb=1.122952 time=313.6s
+ ttt_chunk [1351/1893] bpb=1.122890 time=315.9s
+ ttt_chunk [1361/1893] bpb=1.122934 time=318.1s
+ ttt_chunk [1371/1893] bpb=1.122812 time=320.4s
+ ttt_chunk [1381/1893] bpb=1.122799 time=322.6s
+ ttt_chunk [1391/1893] bpb=1.122409 time=324.8s
+ ttt_chunk [1401/1893] bpb=1.122373 time=327.1s
+ ttt_chunk [1411/1893] bpb=1.122481 time=329.3s
+ ttt_chunk [1421/1893] bpb=1.122736 time=331.6s
+ ttt_chunk [1431/1893] bpb=1.122443 time=333.8s
+ ttt_chunk [1441/1893] bpb=1.122945 time=336.1s
+ ttt_chunk [1451/1893] bpb=1.123286 time=338.4s
+ ttt_chunk [1461/1893] bpb=1.122835 time=340.6s
+ ttt_chunk [1471/1893] bpb=1.123891 time=342.9s
+ ttt_chunk [1481/1893] bpb=1.123440 time=345.1s
+ ttt_chunk [1491/1893] bpb=1.123271 time=347.4s
+ ttt_chunk [1501/1893] bpb=1.123187 time=349.6s
+ ttt_chunk [1511/1893] bpb=1.123210 time=351.9s
+ ttt_chunk [1521/1893] bpb=1.123226 time=354.1s
+ ttt_chunk [1531/1893] bpb=1.122703 time=356.3s
+ ttt_chunk [1541/1893] bpb=1.122582 time=358.8s
+ ttt_chunk [1551/1893] bpb=1.122886 time=361.1s
+ ttt_chunk [1561/1893] bpb=1.122896 time=363.3s
+ ttt_chunk [1571/1893] bpb=1.122735 time=365.6s
+ ttt_chunk [1581/1893] bpb=1.122840 time=367.8s
+ ttt_chunk [1591/1893] bpb=1.122693 time=370.1s
+ ttt_chunk [1601/1893] bpb=1.122859 time=372.5s
+ ttt_chunk [1611/1893] bpb=1.122799 time=374.8s
+ ttt_chunk [1621/1893] bpb=1.122396 time=377.2s
+ ttt_chunk [1631/1893] bpb=1.122711 time=379.5s
+ ttt_chunk [1641/1893] bpb=1.122729 time=381.9s
+ ttt_chunk [1651/1893] bpb=1.122697 time=384.2s
+ ttt_chunk [1661/1893] bpb=1.122572 time=386.6s
+ ttt_chunk [1671/1893] bpb=1.123045 time=389.0s
+ ttt_chunk [1681/1893] bpb=1.123201 time=391.5s
+ ttt_chunk [1691/1893] bpb=1.123036 time=393.9s
+ ttt_chunk [1701/1893] bpb=1.123203 time=396.3s
+ ttt_chunk [1711/1893] bpb=1.123207 time=398.5s
+ ttt_chunk [1721/1893] bpb=1.123209 time=400.8s
+ ttt_chunk [1731/1893] bpb=1.123097 time=403.1s
+ ttt_chunk [1741/1893] bpb=1.122899 time=405.3s
+ ttt_chunk [1751/1893] bpb=1.122737 time=407.6s
+ ttt_chunk [1761/1893] bpb=1.122879 time=409.8s
+ ttt_chunk [1771/1893] bpb=1.122783 time=412.1s
+ ttt_chunk [1781/1893] bpb=1.122812 time=414.3s
+ ttt_chunk [1791/1893] bpb=1.122421 time=416.6s
+ ttt_chunk [1801/1893] bpb=1.122302 time=418.8s
+ ttt_chunk [1811/1893] bpb=1.122203 time=421.1s
+ ttt_chunk [1821/1893] bpb=1.122273 time=423.3s
+ ttt_chunk [1831/1893] bpb=1.121673 time=425.6s
+ ttt_chunk [1841/1893] bpb=1.121646 time=427.8s
+ ttt_chunk [1851/1893] bpb=1.121450 time=430.1s
+ ttt_chunk [1861/1893] bpb=1.121083 time=432.3s
+ ttt_chunk [1871/1893] bpb=1.121072 time=434.6s
+ ttt_chunk [1881/1893] bpb=1.120618 time=436.9s
+ ttt_chunk [1891/1893] bpb=1.120396 time=439.3s
+ ttt_chunk [1893/1893] bpb=1.120441 time=439.6s
+ttt_sliding:done val_loss=1.888118 val_bpb=1.118253 elapsed=439.6s
+legal_ttt val_loss:1.8881 val_bpb:1.1183 eval_time:440025ms
+legal_ttt_exact val_loss:1.88811812 val_bpb:1.11825287
diff --git a/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2025.log b/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2025.log
new file mode 100644
index 000000000..24de5a5db
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-25_RecurLayers/train_seed2025.log
@@ -0,0 +1,274 @@
+W0325 05:05:06.263000 8081 torch/distributed/run.py:803]
+W0325 05:05:06.263000 8081 torch/distributed/run.py:803] *****************************************
+W0325 05:05:06.263000 8081 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0325 05:05:06.263000 8081 torch/distributed/run.py:803] *****************************************
+logs/c7954938-1ada-4e4c-882a-aaecc4b07794.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:26998380
+recurrence:layers=[4, 5] physical_layers=11 virtual_layers=13
+mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
+XSA:last_4 active_layers:[9, 10, 11, 12]
+world_size:8 grad_accum_steps:1
+sdp_backends:cudnn=False flash=True mem_efficient=False math=False
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
+train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000
+seed:2025
+warmup_step:1/20
+warmup_step:2/20
+warmup_step:3/20
+warmup_step:4/20
+warmup_step:5/20
+warmup_step:6/20
+warmup_step:7/20
+warmup_step:8/20
+warmup_step:9/20
+warmup_step:10/20
+warmup_step:11/20
+warmup_step:12/20
+warmup_step:13/20
+warmup_step:14/20
+warmup_step:15/20
+warmup_step:16/20
+warmup_step:17/20
+warmup_step:18/20
+warmup_step:19/20
+warmup_step:20/20
+step:0/9000 val_loss:6.9277 val_bpb:4.1030 train_time:0ms step_avg:0.01ms
+step:1/9000 train_loss:6.9281 train_time:145ms step_avg:145.42ms
+step:2/9000 train_loss:8.5172 train_time:193ms step_avg:96.30ms
+step:3/9000 train_loss:7.6279 train_time:287ms step_avg:95.63ms
+step:4/9000 train_loss:7.3505 train_time:382ms step_avg:95.58ms
+step:5/9000 train_loss:7.1826 train_time:477ms step_avg:95.45ms
+step:6/9000 train_loss:7.0688 train_time:571ms step_avg:95.23ms
+step:7/9000 train_loss:6.9227 train_time:667ms step_avg:95.25ms
+step:8/9000 train_loss:6.8939 train_time:762ms step_avg:95.22ms
+step:9/9000 train_loss:6.5512 train_time:857ms step_avg:95.22ms
+step:10/9000 train_loss:6.1678 train_time:951ms step_avg:95.12ms
+step:500/9000 train_loss:2.3824 train_time:48648ms step_avg:97.30ms
+step:1000/9000 train_loss:2.2574 train_time:97714ms step_avg:97.71ms
+step:1500/9000 train_loss:2.2027 train_time:146827ms step_avg:97.88ms
+step:2000/9000 train_loss:2.0445 train_time:196024ms step_avg:98.01ms
+step:2500/9000 train_loss:2.1514 train_time:245140ms step_avg:98.06ms
+step:3000/9000 train_loss:2.1320 train_time:294203ms step_avg:98.07ms
+step:3500/9000 train_loss:2.1407 train_time:343243ms step_avg:98.07ms
+step:4000/9000 train_loss:1.9321 train_time:392296ms step_avg:98.07ms
+step:4000/9000 val_loss:2.0220 val_bpb:1.1976 train_time:392346ms step_avg:98.09ms
+step:4500/9000 train_loss:2.0800 train_time:441368ms step_avg:98.08ms
+step:5000/9000 train_loss:2.0584 train_time:490429ms step_avg:98.09ms
+swa:start step:5450
+step:5500/9000 train_loss:1.9663 train_time:539572ms step_avg:98.10ms
+late_qat:enabled step:5590 scale:0.1500
+step:6000/9000 train_loss:1.8896 train_time:589238ms step_avg:98.21ms
+step:6108/9000 val_loss:1.9204 val_bpb:1.1374 train_time:600085ms step_avg:98.25ms
+stopping_early: wallclock_cap train_time:600085ms step:6108/9000
+peak memory allocated: 25295 MiB reserved: 26088 MiB
+ema:applying EMA weights
+DIAGNOSTIC post_ema val_loss:1.9189 val_bpb:1.1365 eval_time:2339ms
+Serialized model: 106179454 bytes
+Code size: 95036 bytes
+Serialized model int6+lzma: 15839896 bytes
+Total submission size int6+lzma: 15934932 bytes
+final_int6_roundtrip val_loss:1.9334 val_bpb:1.1450 eval_time:21441ms
+final_int6_roundtrip_exact val_loss:1.93335053 val_bpb:1.14503908
+final_int6_sliding_window val_loss:1.8938 val_bpb:1.1216 stride:64 eval_time:110549ms
+final_int6_sliding_window_exact val_loss:1.89375476 val_bpb:1.12159121
+final_int8_zlib_roundtrip_exact val_loss:1.89375476 val_bpb:1.12159121
+ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=2
+ttt_sliding:params unfrozen=26994268 frozen=4112
+ ttt_chunk [1/1893] bpb=1.162210 time=0.5s
+ ttt_chunk [11/1893] bpb=1.145781 time=3.0s
+ ttt_chunk [21/1893] bpb=1.130322 time=5.5s
+ ttt_chunk [31/1893] bpb=1.129152 time=8.1s
+ ttt_chunk [41/1893] bpb=1.116044 time=10.6s
+ ttt_chunk [51/1893] bpb=1.110693 time=13.1s
+ ttt_chunk [61/1893] bpb=1.117044 time=15.6s
+ ttt_chunk [71/1893] bpb=1.115827 time=18.1s
+ ttt_chunk [81/1893] bpb=1.115163 time=20.6s
+ ttt_chunk [91/1893] bpb=1.115961 time=23.1s
+ ttt_chunk [101/1893] bpb=1.119242 time=25.6s
+ ttt_chunk [111/1893] bpb=1.121548 time=28.1s
+ ttt_chunk [121/1893] bpb=1.114806 time=30.6s
+ ttt_chunk [131/1893] bpb=1.114754 time=33.1s
+ ttt_chunk [141/1893] bpb=1.120362 time=35.6s
+ ttt_chunk [151/1893] bpb=1.122303 time=38.1s
+ ttt_chunk [161/1893] bpb=1.121662 time=40.6s
+ ttt_chunk [171/1893] bpb=1.125906 time=43.1s
+ ttt_chunk [181/1893] bpb=1.128185 time=45.6s
+ ttt_chunk [191/1893] bpb=1.135549 time=48.1s
+ ttt_chunk [201/1893] bpb=1.134305 time=50.6s
+ ttt_chunk [211/1893] bpb=1.132100 time=53.1s
+ ttt_chunk [221/1893] bpb=1.133617 time=55.6s
+ ttt_chunk [231/1893] bpb=1.132254 time=58.1s
+ ttt_chunk [241/1893] bpb=1.132617 time=60.7s
+ ttt_chunk [251/1893] bpb=1.132233 time=63.2s
+ ttt_chunk [261/1893] bpb=1.129325 time=65.7s
+ ttt_chunk [271/1893] bpb=1.128204 time=68.2s
+ ttt_chunk [281/1893] bpb=1.129517 time=70.7s
+ ttt_chunk [291/1893] bpb=1.131399 time=73.2s
+ ttt_chunk [301/1893] bpb=1.132069 time=75.7s
+ ttt_chunk [311/1893] bpb=1.134154 time=78.2s
+ ttt_chunk [321/1893] bpb=1.136082 time=80.7s
+ ttt_chunk [331/1893] bpb=1.135933 time=83.2s
+ ttt_chunk [341/1893] bpb=1.134982 time=85.8s
+ ttt_chunk [351/1893] bpb=1.137301 time=88.3s
+ ttt_chunk [361/1893] bpb=1.137582 time=90.7s
+ ttt_chunk [371/1893] bpb=1.136861 time=93.2s
+ ttt_chunk [381/1893] bpb=1.137043 time=95.7s
+ ttt_chunk [391/1893] bpb=1.136862 time=98.2s
+ ttt_chunk [401/1893] bpb=1.134839 time=100.8s
+ ttt_chunk [411/1893] bpb=1.133707 time=103.2s
+ ttt_chunk [421/1893] bpb=1.132831 time=105.8s
+ ttt_chunk [431/1893] bpb=1.132741 time=108.3s
+ ttt_chunk [441/1893] bpb=1.133128 time=110.8s
+ ttt_chunk [451/1893] bpb=1.133348 time=113.3s
+ ttt_chunk [461/1893] bpb=1.132251 time=115.8s
+ ttt_chunk [471/1893] bpb=1.132836 time=118.2s
+ ttt_chunk [481/1893] bpb=1.132438 time=120.8s
+ ttt_chunk [491/1893] bpb=1.131381 time=123.3s
+ ttt_chunk [501/1893] bpb=1.130900 time=125.8s
+ ttt_chunk [511/1893] bpb=1.130280 time=128.3s
+ ttt_chunk [521/1893] bpb=1.127914 time=130.8s
+ ttt_chunk [531/1893] bpb=1.129097 time=133.3s
+ ttt_chunk [541/1893] bpb=1.129478 time=135.8s
+ ttt_chunk [551/1893] bpb=1.128459 time=138.3s
+ ttt_chunk [561/1893] bpb=1.128997 time=140.8s
+ ttt_chunk [571/1893] bpb=1.127988 time=143.3s
+ ttt_chunk [581/1893] bpb=1.127180 time=145.9s
+ ttt_chunk [591/1893] bpb=1.126527 time=148.4s
+ ttt_chunk [601/1893] bpb=1.127034 time=150.9s
+ ttt_chunk [611/1893] bpb=1.126999 time=153.4s
+ ttt_chunk [621/1893] bpb=1.126844 time=155.9s
+ ttt_chunk [631/1893] bpb=1.127588 time=158.4s
+ ttt_chunk [641/1893] bpb=1.127333 time=160.9s
+ ttt_chunk [651/1893] bpb=1.127450 time=163.5s
+ ttt_chunk [661/1893] bpb=1.126922 time=165.9s
+ ttt_chunk [671/1893] bpb=1.127297 time=168.4s
+ ttt_chunk [681/1893] bpb=1.127980 time=170.9s
+ ttt_chunk [691/1893] bpb=1.128981 time=173.4s
+ ttt_chunk [701/1893] bpb=1.128408 time=175.9s
+ ttt_chunk [711/1893] bpb=1.128368 time=178.4s
+ ttt_chunk [721/1893] bpb=1.128037 time=180.9s
+ ttt_chunk [731/1893] bpb=1.128080 time=183.4s
+ ttt_chunk [741/1893] bpb=1.128164 time=185.9s
+ ttt_chunk [751/1893] bpb=1.128023 time=188.4s
+ ttt_chunk [761/1893] bpb=1.127942 time=190.9s
+ ttt_chunk [771/1893] bpb=1.127617 time=193.5s
+ ttt_chunk [781/1893] bpb=1.128346 time=196.0s
+ ttt_chunk [791/1893] bpb=1.127906 time=198.5s
+ ttt_chunk [801/1893] bpb=1.128239 time=201.0s
+ ttt_chunk [811/1893] bpb=1.127962 time=203.6s
+ ttt_chunk [821/1893] bpb=1.127753 time=206.1s
+ ttt_chunk [831/1893] bpb=1.127571 time=208.6s
+ ttt_chunk [841/1893] bpb=1.126914 time=211.0s
+ ttt_chunk [851/1893] bpb=1.126640 time=213.6s
+ ttt_chunk [861/1893] bpb=1.126382 time=216.1s
+ ttt_chunk [871/1893] bpb=1.126626 time=218.5s
+ ttt_chunk [881/1893] bpb=1.126788 time=221.1s
+ ttt_chunk [891/1893] bpb=1.126361 time=223.5s
+ ttt_chunk [901/1893] bpb=1.126072 time=226.1s
+ ttt_chunk [911/1893] bpb=1.126172 time=228.6s
+ ttt_chunk [921/1893] bpb=1.126657 time=231.1s
+ ttt_chunk [931/1893] bpb=1.126618 time=233.6s
+ ttt_chunk [941/1893] bpb=1.126286 time=236.1s
+ ttt_chunk [951/1893] bpb=1.126675 time=238.6s
+ ttt_chunk [961/1893] bpb=1.126781 time=241.1s
+ ttt_chunk [971/1893] bpb=1.127633 time=243.6s
+ ttt_chunk [981/1893] bpb=1.127709 time=246.1s
+ ttt_chunk [991/1893] bpb=1.127749 time=248.6s
+ ttt_chunk [1001/1893] bpb=1.127703 time=251.1s
+ ttt_chunk [1011/1893] bpb=1.127487 time=253.6s
+ ttt_chunk [1021/1893] bpb=1.127835 time=256.1s
+ ttt_chunk [1031/1893] bpb=1.128318 time=258.6s
+ ttt_chunk [1041/1893] bpb=1.127965 time=261.1s
+ ttt_chunk [1051/1893] bpb=1.127732 time=263.6s
+ ttt_chunk [1061/1893] bpb=1.127778 time=266.1s
+ ttt_chunk [1071/1893] bpb=1.128390 time=268.6s
+ ttt_chunk [1081/1893] bpb=1.128659 time=271.1s
+ ttt_chunk [1091/1893] bpb=1.129394 time=273.6s
+ ttt_chunk [1101/1893] bpb=1.129401 time=276.1s
+ ttt_chunk [1111/1893] bpb=1.129251 time=278.6s
+ ttt_chunk [1121/1893] bpb=1.129077 time=281.1s
+ ttt_chunk [1131/1893] bpb=1.128942 time=283.6s
+ ttt_chunk [1141/1893] bpb=1.128641 time=286.2s
+ ttt_chunk [1151/1893] bpb=1.128669 time=288.7s
+ ttt_chunk [1161/1893] bpb=1.128280 time=291.2s
+ ttt_chunk [1171/1893] bpb=1.128600 time=293.7s
+ ttt_chunk [1181/1893] bpb=1.127839 time=296.2s
+ ttt_chunk [1191/1893] bpb=1.127713 time=298.7s
+ ttt_chunk [1201/1893] bpb=1.128090 time=301.2s
+ ttt_chunk [1211/1893] bpb=1.127629 time=303.6s
+ ttt_chunk [1221/1893] bpb=1.127339 time=306.2s
+ ttt_chunk [1231/1893] bpb=1.127053 time=308.7s
+ ttt_chunk [1241/1893] bpb=1.126716 time=311.2s
+ ttt_chunk [1251/1893] bpb=1.126136 time=313.7s
+ ttt_chunk [1261/1893] bpb=1.126119 time=316.2s
+ ttt_chunk [1271/1893] bpb=1.125744 time=318.7s
+ ttt_chunk [1281/1893] bpb=1.125561 time=321.2s
+ ttt_chunk [1291/1893] bpb=1.125331 time=323.7s
+ ttt_chunk [1301/1893] bpb=1.124734 time=326.2s
+ ttt_chunk [1311/1893] bpb=1.124350 time=328.7s
+ ttt_chunk [1321/1893] bpb=1.124044 time=331.2s
+ ttt_chunk [1331/1893] bpb=1.123993 time=333.7s
+ ttt_chunk [1341/1893] bpb=1.123860 time=336.2s
+ ttt_chunk [1351/1893] bpb=1.123796 time=338.7s
+ ttt_chunk [1361/1893] bpb=1.123841 time=341.2s
+ ttt_chunk [1371/1893] bpb=1.123723 time=343.7s
+ ttt_chunk [1381/1893] bpb=1.123699 time=346.2s
+ ttt_chunk [1391/1893] bpb=1.123309 time=348.8s
+ ttt_chunk [1401/1893] bpb=1.123276 time=351.3s
+ ttt_chunk [1411/1893] bpb=1.123374 time=353.8s
+ ttt_chunk [1421/1893] bpb=1.123616 time=356.3s
+ ttt_chunk [1431/1893] bpb=1.123319 time=358.8s
+ ttt_chunk [1441/1893] bpb=1.123826 time=361.3s
+ ttt_chunk [1451/1893] bpb=1.124157 time=363.9s
+ ttt_chunk [1461/1893] bpb=1.123698 time=366.4s
+ ttt_chunk [1471/1893] bpb=1.124741 time=368.9s
+ ttt_chunk [1481/1893] bpb=1.124276 time=371.4s
+ ttt_chunk [1491/1893] bpb=1.124080 time=373.9s
+ ttt_chunk [1501/1893] bpb=1.123990 time=376.4s
+ ttt_chunk [1511/1893] bpb=1.124002 time=378.9s
+ ttt_chunk [1521/1893] bpb=1.124037 time=381.4s
+ ttt_chunk [1531/1893] bpb=1.123507 time=383.9s
+ ttt_chunk [1541/1893] bpb=1.123343 time=386.4s
+ ttt_chunk [1551/1893] bpb=1.123654 time=389.0s
+ ttt_chunk [1561/1893] bpb=1.123668 time=391.5s
+ ttt_chunk [1571/1893] bpb=1.123492 time=394.0s
+ ttt_chunk [1581/1893] bpb=1.123611 time=396.5s
+ ttt_chunk [1591/1893] bpb=1.123470 time=399.0s
+ ttt_chunk [1601/1893] bpb=1.123644 time=401.5s
+ ttt_chunk [1611/1893] bpb=1.123583 time=404.0s
+ ttt_chunk [1621/1893] bpb=1.123181 time=406.5s
+ ttt_chunk [1631/1893] bpb=1.123506 time=409.0s
+ ttt_chunk [1641/1893] bpb=1.123512 time=411.5s
+ ttt_chunk [1651/1893] bpb=1.123468 time=414.1s
+ ttt_chunk [1661/1893] bpb=1.123346 time=416.6s
+ ttt_chunk [1671/1893] bpb=1.123816 time=419.1s
+ ttt_chunk [1681/1893] bpb=1.123963 time=421.6s
+ ttt_chunk [1691/1893] bpb=1.123804 time=424.1s
+ ttt_chunk [1701/1893] bpb=1.123954 time=426.6s
+ ttt_chunk [1711/1893] bpb=1.123942 time=429.2s
+ ttt_chunk [1721/1893] bpb=1.123947 time=431.7s
+ ttt_chunk [1731/1893] bpb=1.123836 time=434.2s
+ ttt_chunk [1741/1893] bpb=1.123626 time=436.7s
+ ttt_chunk [1751/1893] bpb=1.123475 time=439.2s
+ ttt_chunk [1761/1893] bpb=1.123631 time=441.7s
+ ttt_chunk [1771/1893] bpb=1.123539 time=444.2s
+ ttt_chunk [1781/1893] bpb=1.123568 time=446.7s
+ ttt_chunk [1791/1893] bpb=1.123158 time=449.2s
+ ttt_chunk [1801/1893] bpb=1.123045 time=451.7s
+ ttt_chunk [1811/1893] bpb=1.122955 time=454.2s
+ ttt_chunk [1821/1893] bpb=1.123013 time=456.7s
+ ttt_chunk [1831/1893] bpb=1.122421 time=459.2s
+ ttt_chunk [1841/1893] bpb=1.122384 time=461.7s
+ ttt_chunk [1851/1893] bpb=1.122180 time=464.2s
+ ttt_chunk [1861/1893] bpb=1.121829 time=466.7s
+ ttt_chunk [1871/1893] bpb=1.121826 time=469.2s
+ ttt_chunk [1881/1893] bpb=1.121367 time=471.7s
+ ttt_chunk [1891/1893] bpb=1.121141 time=474.2s
+ ttt_chunk [1893/1893] bpb=1.121187 time=474.5s
+ttt_sliding:done val_loss=1.889486 val_bpb=1.119063 elapsed=474.6s
+legal_ttt val_loss:1.8895 val_bpb:1.1191 eval_time:475115ms
+legal_ttt_exact val_loss:1.88948575 val_bpb:1.11906285