From 133b0c1cffc637ceabed04138257f909e1b0e1fc Mon Sep 17 00:00:00 2001 From: amabito Date: Wed, 25 Mar 2026 11:15:23 +0900 Subject: [PATCH] Add TRN hybrid non-record submission (1.4942 bpb, 1x RTX 5090) Oscillatory recurrence + attention hybrid under 16 MB constraint. 10 layers (7 TRN + 3 Attn), int5 QAT, Kogge-Stone parallel scan. Int5 collapses at 20K steps due to oscillator projection phase drift. --- README.md | 1 + .../README.md | 471 ++++ .../submission.json | 25 + .../train.log | 77 + .../train_gpt_trn.py | 2222 +++++++++++++++++ 5 files changed, 2796 insertions(+) create mode 100644 records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/README.md create mode 100644 records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/submission.json create mode 100644 records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/train.log create mode 100644 records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/train_gpt_trn.py diff --git a/README.md b/README.md index 34e1b74d8..6b74160ab 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,7 @@ Happy training! | Run | Score | Author | Summary | Date | Info | |-----|------:|--------|---------|------|------| | 4-Hour Baseline | 1.2074 | Will DePue | Testing unlimited compute, 4 hours on 8xH100 | 2026-03-18 | [info](records/track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3/README.md) | +| TRN Hybrid Int5 | 1.4942 | amabito | TRN+Attention hybrid (7 TRN + 3 Attn), int5 QAT, Kogge-Stone scan, 1x RTX 5090 | 2026-03-25 | [info](records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/README.md) | ## Getting Started diff --git a/records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/README.md b/records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/README.md new file mode 100644 index 000000000..c3575b2cd --- /dev/null +++ b/records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/README.md @@ -0,0 +1,471 @@ +# Temporal Resonance Network (TRN) Hybrid for Parameter Golf + +## Central Question + +Can a linear recurrence architecture compete with attention under extreme parameter +constraints (16 MB, 10 minutes)? TRN offers O(n log n) training via parallel scan +and frequency-domain compression of sequential patterns -- but its selective recall +is weak (8.8% on synthetic copy tasks vs 96.2% for attention). This entry tests a hybrid: +7 TRN layers for pattern compression, 3 attention layers for exact retrieval, +quantized to int5 to fit 10 layers within the 16 MB budget. + +## Score + +val_bpb = 1.4942 (single seed, int5+zstd-22 roundtrip, 636 steps / 600s wallclock on single RTX 5090) +Artifact size: 15.28 MB (int5+zstd-22, code+model) + +Note: The 636-step score (wallclock-capped at 600s) is the best int5 roundtrip we +achieved. At 20K steps the fp32 model reaches 1.26 bpb, but int5 quantization +degrades it to 1.93 bpb (+0.67). The root cause is analyzed in the Quantization +section below. This is a non-record submission focused on the architecture and +the lessons learned. + +--- + +## Architecture + +10 layers: 7 TRN + 3 causal attention, interleaved. + +``` +[TRN][TRN][Attn][TRN][TRN][Attn][TRN][TRN][Attn][TRN] + 0 1 2 3 4 5 6 7 8 9 +``` + +- d_model = 512, K = 256 oscillators per TRN layer, vocab = 1024 +- SwiGLU FFN (mlp_mult=2), U-Net skip connections, tied embeddings +- Position-aware oscillators replace standard positional encoding -- + omega * log(1+t) is a deterministic phase offset (functionally a form of PE), + but with learned per-oscillator frequencies fused into the recurrence +- BigramHash(10240, dim=128): token-pair hash table added to embedding + +### Parameter Breakdown + +| Component | Params | % | +|-----------|--------|---| +| TRN oscillator projections (7L, d_model->6K) | 5.6M | 22% | +| TRN W_res (7L, 2K->d_model) | 1.8M | 7% | +| Attention QKV+out (3L, GQA 8/4 heads) | 2.4M | 9% | +| FFN SwiGLU (10L, 2x expansion) | 10.5M | 41% | +| BigramHash (10240 x 128 + projection) | 1.8M | 7% | +| Embeddings (1024 x 512, tied) | 0.5M | 2% | +| Other (norms, biases, skip weights) | 3.2M | 12% | +| **Total** | **25.8M** | **100%** | + +### TRN Recurrence + +Each TRN layer maintains K complex-valued oscillators: + +``` +drive_t = A_t * exp(j * (omega * log(1 + t) + phi_t)) +r_t = alpha_t * r_{t-1} + (1 - alpha_t) * drive_t +y_t = Re(r_t * exp(-j * omega * log(1 + t))) +``` + +A (amplitude), omega (frequency), phi (phase), alpha (decay gate) are projected +from each input token via a single linear layer (d_model -> 6K). alpha is +sigmoid-gated with multi-scale initialization (half-life ~1.4 / ~2.9 / ~33 steps). + +The log-time warp (log(1+t)) prevents frequency aliasing at long range. +Parallel training uses a Kogge-Stone prefix scan -- O(n log n), pure PyTorch. + +Related work: S4, S4D/DSS, LRU, Mamba, RWKV. Our variant combines complex +oscillatory state with learned per-oscillator frequency/phase and log-time +warping -- distinct from LRU's complex diagonal recurrence. + +### Why Hybrid + +TRN alone cannot reliably recall discrete tokens; attention alone cannot fit +competitive depth within the 16 MB budget at int8. + +Attention handles exact content-addressed retrieval -- locating a specific token +at an arbitrary position. TRN compresses sequential patterns into frequency-domain +state, which is efficient for repetitive structure and positional correlations but +weak at selective discrete recall (8.8% vs 96.2% on synthetic copy tasks; FineWeb +equivalent not directly measurable). Interleaving attention layers injects +exact-retrieval capacity periodically, preventing gradient starvation in deep +TRN stacks. + +--- + +## Training + +| Hyperparameter | Value | +|----------------|-------| +| Layers | 10 (7 TRN + 3 Attn) | +| d_model | 512 | +| Oscillators K | 256 | +| Vocab | 1024 | +| Heads / KV heads | 8 / 4 (GQA) | +| MLP mult | 2 (SwiGLU) | +| Seq len | 1024 | +| Batch tokens | 262144 | +| Iterations | 20000 | +| Warmup steps | 20 (torch.compile cache warmup, not LR warmup) | +| Warmdown iters | 1200 | +| Muon (matrices) | lr=0.04, momentum=0.95 | +| Adam (embeddings) | lr=0.05, beta=(0.9, 0.95) | +| Grad clip norm | 0.3 | +| Token shift | enabled (RWKV-6 style pre-resonance mixing) | +| Activation | LeakyReLU(0.5)^2 | +| PCG lambda | 0.5 | +| Weight decay | 0.04 (Muon matrices only) | +| EMA | decay=0.997, start at 50% of training | +| Compression | zstd level 22 | + +--- + +## Quantization + +### The int5 trade at 1000 steps + +**Depth-for-quantization trade:** int5 (vs int8) frees ~4 MB, allowing 10 layers +instead of 7 within the 16 MB budget. 10L int8 is 18.6 MB -- over the limit. + +Int5 per-row symmetric: weights mapped to [-15, 15], scale stored as fp32. +Applied to all matrix weights; embeddings remain fp16. +5 bits/weight, 0.625 bytes/weight. +Size: 25.8M * 0.625 = 16.1 MB raw + scales (~0.4 MB) + fp16 embeddings (~1 MB). +zstd-22 compresses the blob to 15.28 MB. + +QAT uses straight-through estimator (STE): forward pass quantizes weight.data, +backward pass updates the fp32 copy. + +| Quantization | bpb (10L, 1000 steps) | Degradation vs fp32 | Artifact | +|--------------|-----------------------|---------------------|----------| +| fp32 (train) | 1.4537 | -- | 103 MB | +| int8 + zstd-22 | 1.4554 | +0.002 | 18.6 MB (over limit) | +| int5 + zstd-22 (no QAT) | 1.5254 | +0.072 | 15.27 MB | +| int5 + zstd-22 (QAT frac=0.3) | 1.4942 | +0.041 | 15.28 MB | + +QAT reduces int5 degradation from +0.072 to +0.041 (a 0.031 bpb improvement). +The 7L int5+QAT model (1.4963) and 10L int5+QAT model (1.4942) achieve +similar bpb despite the depth difference -- at 1000 steps, the extra layers +have not yet contributed. Longer training is needed for the depth trade to pay off. + +### The int5 collapse at 20K steps + +At 20K steps, int5 shows severe degradation: + +| Quantization | bpb (10L, 20K steps) | Degradation | +|--------------|----------------------|-------------| +| fp32 | 1.2633 | -- | +| int8 + zstd-22 | 1.2711 | +0.008 | +| int5 + zstd-22 | 1.9321 | +0.669 | + +The int5 degradation grows from +0.041 at 1000 steps to +0.669 at 20K steps. +Int8 remains healthy (+0.008), ruling out a general quantization problem. + +**Root cause: TRN oscillator projection weights are structurally incompatible +with 5-bit quantization at convergence.** + +The oscillator projection (d_model -> 6K) encodes omega (frequency), phi (phase), +alpha (decay), and amplitude for each oscillator. These parameters control +`sin(omega * t + phi)` -- a 1-bit error in omega accumulates as O(t) phase drift +across the sequence. At 1000 steps, weight magnitudes are small and quantization +bins are close together, so errors stay bounded. At 20K steps, the distribution +sharpens: oscillator projections develop precise frequency-selective patterns +where small perturbations cascade through the recurrence. + +Supporting evidence: +- A parameter-matched 13L Transformer (25.8M params, same codebase) shows only + +0.016 bpb int5 degradation at 1000 steps -- Transformer weights do not encode + phase-sensitive recurrence parameters. +- The TRN hybrid's int8 artifact (20.1 MB) exceeds 16 MB, while the 13L + Transformer's int8 artifact (15.7 MB) fits. TRN oscillator weights have higher + entropy, resisting zstd compression (1.30x vs 1.66x for Transformer). +- QAT with 6000 STE steps did not recover the int5 quality, suggesting the + quantization error surface for oscillatory parameters is not smoothly navigable + by gradient descent through STE. + +This is an open problem. Possible directions include mixed-precision quantization +(int8 for oscillator projections, int4 for FFN), learned quantization grids +per component, or architectural changes to reduce phase sensitivity (e.g., +discretizing omega to a fixed grid before training). + +--- + +## Internal Ablation: TRN Hybrid vs Transformer + +### Short training (1000 steps, 600s wallclock, single RTX 5090) + +Both models trained within the same codebase (train_gpt_trn.py), same optimizer, +same BigramHash(10240), grad_clip=0.3, +int5 QAT, zstd-22. + +**Important caveats:** +- This is an internal ablation within one codebase, not a comparison against + the fully-optimized leaderboard baseline (train_gpt.py, 1.14-1.22 bpb). +- TRN hybrid has 27% more parameters (25.8M vs 20.3M). The bpb advantage is + partially attributable to this difference. +- Single run, no variance bars. Results are indicative, not statistically conclusive. +- Transformer step_avg includes a QAT compile spike at step 2 (49.8s); steady-state + is ~1443 ms/step. TRN steady-state is ~945 ms/step (1.53x faster). + +| Metric | Transformer 10L | TRN Hybrid 10L | Delta | +|--------|----------------|----------------|-------| +| Batch tokens/step | 262144 | 262144 | same | +| ms/step (steady-state) | ~1443 | ~945 | -35% | +| Steps in 600s | 379 | 636 | +68% | +| Total tokens seen | 199M | 334M | +68% | +| val_bpb (wallclock stop) | 1.5983 | 1.4715 | -0.127 | +| val_bpb (int8 roundtrip) | 1.6252 | 1.4755 | -0.150 | +| Params | 20.3M | 25.8M | +27% | +| Int5 artifact | 11.9 MB | 15.3 MB | | + +At the same batch size, TRN hybrid is 35% faster per step (steady-state) and sees +68% more tokens in the same 600-second budget. It achieves 0.127 bpb lower loss. + +### Parameter-matched comparison (1000 steps) + +A 13-layer Transformer (25,785,449 params) was trained under identical conditions +to provide a parameter-matched comparison. + +| Metric | Transformer 13L | TRN Hybrid 10L | +|--------|----------------|----------------| +| Params | 25.79M | 25.79M | +| ms/step (steady-state) | ~1131 | ~945 | +| fp32 val_bpb | 1.3527 | 1.4537 | +| int5 val_bpb | 1.3689 (+0.016) | 1.4942 (+0.041) | +| int5 artifact | 15.30 MB | 15.28 MB | +| VRAM | 21,574 MiB | ~19,500 MiB | + +At 1000 steps, the 13L Transformer achieves lower bpb than the 10L TRN hybrid +(-0.101). The Transformer's advantage here likely reflects two factors: +(1) 13 layers of attention vs 10 layers (3 attention + 7 TRN) provides more +capacity for the content-retrieval tasks dominant in early training, and +(2) the TRN's Kogge-Stone scan, while asymptotically efficient, has not yet +converged its oscillator parameters at 1000 steps. + +The TRN hybrid's per-step speed advantage (16% faster) partially offsets this, +yielding more steps per wallclock budget. Whether this advantage grows with +longer training is shown in the convergence section below. + +### Long training (20K steps, single RTX 5090) + +| Metric | TRN Hybrid 10L (20K) | +|--------|---------------------| +| fp32 val_bpb | 1.2633 | +| int8 val_bpb | 1.2711 (+0.008) | +| int5 val_bpb | 1.9321 (+0.669) | +| int8 artifact | 20.1 MB (over limit; 18.6 MB at 1000 steps -- weights compress less as they specialize) | +| int5 artifact | 15.4 MB | +| step_avg | ~870 ms | + +The fp32 model continues to improve (1.45 at 1K -> 1.26 at 20K), and int8 +faithfully tracks it. But int5 collapses, making the 20K model unusable under +the 16 MB constraint with uniform int5. This is the central failure of this +submission. + +--- + +## Convergence (10L, single RTX 5090, 20K steps) + +| Step | fp32 val_bpb | int8 val_bpb | int5 val_bpb | +|------|-------------|-------------|-------------| +| 636 | 1.4715 | 1.4755 | ~1.49 (QAT) | +| 2000 | 1.3874 | -- | -- | +| 6000 | 1.3307 | -- | -- | +| 10000 | 1.3167 | -- | -- | +| 14000 | 1.3100 | -- | -- | +| 18000 | 1.3044 | -- | -- | +| 20000 | 1.2633 | 1.2711 | 1.9321 | + +The fp32 trajectory shows steady improvement through 20K steps with no sign +of saturation. The gap between fp32 and int5 widens monotonically with training, +confirming that the quantization problem worsens as oscillator weights specialize. + +--- + +## Ablation: Layout (7 layers, 1000 steps) + +| Layout | attn_positions | val_bpb (int8) | +|--------|----------------|----------------| +| Front-loaded [AA TTTTT] | 0, 1 | 1.5411 | +| Interleaved [TT A TT A T] | 2, 5 | 1.5331 | + +Interleaved diverges after ~500 steps -- periodic exact-retrieval beats front-loading. + +--- + +## Ablation: Stacked Optimizations (7 layers, 1000 steps) + +Each row adds to the previous. All measured on single RTX 5090. + +| Config | int8 bpb | +|--------|----------| +| Interleaved baseline | 1.5331 | +| + BigramHash(10240) | 1.5303 | +| + grad_clip=0.3, token_shift, LeakyReLU^2, PCG=0.5 | 1.4963 | + +The last four changes were applied together (-0.034 bpb combined). +Individual contributions are not isolated. + +--- + +## Ablation: Depth + Quantization (1000 steps) + +| Config | Params | int5 bpb | Artifact | +|--------|--------|----------|----------| +| 7L int8 (no QAT) | 18.7M | -- (14.5 MB) | fits | +| 7L int5 + QAT | 18.7M | 1.4963 | 10.82 MB | +| 10L int5 + QAT | 25.8M | 1.4942 | 15.28 MB | + +The 10-layer int5 model matches the 7-layer int8 result at 1000 steps. + +--- + +## FLOP Analysis: TRN vs Transformer Per Layer + +| Component | Transformer | TRN Block | +|-----------|-------------|-----------| +| Projections (QKV / oscillator) | 0.40 GFLOP | 0.81 GFLOP | +| Attention / scan | 0.27 GFLOP | 0.002 GFLOP | +| Output proj (out / W_res) | 0.27 GFLOP | 0.54 GFLOP | +| FFN SwiGLU | 1.07 GFLOP | 1.21 GFLOP | +| Other (norms, skip, token_shift) | -- | ~0.12 GFLOP | +| **Total per layer** | **2.01 GFLOP** | **2.68 GFLOP** | + +TRN's overhead is not in the scan (0.3% of total FLOPs) but in the oscillator +projection (d_model -> 6K) and W_res output mapping (2K -> d_model). The scan +itself is cheaper than attention, but at seq_len=1024 this difference is small. + +Total model: 7 TRN + 3 Attn = 24.8 GFLOP vs 10 Attn = 20.1 GFLOP (1.23x). +Despite more FLOPs, TRN is faster per step because attention layers are +memory-bandwidth bound with synchronization points (softmax, causal mask), +while TRN's Kogge-Stone prefix scan in pure PyTorch streams with fewer +synchronization points. FLOPs alone do not predict wall-clock time for +memory-bound operations. + +--- + +## What We Tried (That Didn't Work) + +| Experiment | Result | Lesson | +|------------|--------|--------| +| Depth recurrence (TRN 2-pass weight reuse) | +0.06 bpb, 1.7x slower | omega*t phase is identical in both passes -- second pass recomputes the same function on refined input. Pass-dependent phase offset could help but was not attempted. | +| K=128 + MLP 3x expansion | +0.01 bpb vs K=256 + MLP 2x | Oscillator count matters more than FFN width for TRN | +| SWA/EMA (decay=0.997, 1000 steps) | +0.016 bpb | Warmdown too short at 1000 steps | +| TTT LoRA rank=16 | Same as rank=8 | 1 Adam step cannot move large LoRA; rank=8 is sufficient | +| EMA + QAT at 20K steps | int5 +0.669 bpb | EMA does not help int5; oscillator weights are the bottleneck | +| QAT with 6000 STE steps at 20K | int5 +0.669 bpb | STE cannot navigate the error surface of oscillatory parameters at convergence | + +--- + +## Reproducibility + +Single GPU (reproduces the submitted score -- runs 1000 iterations with 600s wallclock cap, stops at ~636 steps): + +```bash +INT5_QAT=1 INT5_QAT_START_FRAC=0.3 \ +USE_ZSTD=1 ZSTD_LEVEL=22 \ +ITERATIONS=1000 VAL_LOSS_EVERY=100 TRAIN_LOG_EVERY=50 \ +MODEL_TYPE=hybrid NUM_LAYERS=10 ATTN_POSITIONS=2,5,8 \ +BIGRAM_VOCAB_SIZE=10240 BIGRAM_DIM=128 \ +GRAD_CLIP_NORM=0.3 USE_TOKEN_SHIFT=1 \ +TRAIN_BATCH_TOKENS=262144 \ +SEED=42 \ +python train_gpt_trn.py +``` + +8xH100 (untested -- intended for future leaderboard runs if quantization is resolved): + +```bash +INT5_QAT=1 INT5_QAT_START_STEP=2000 \ +USE_ZSTD=1 ZSTD_LEVEL=22 \ +ITERATIONS=20000 \ +MODEL_TYPE=hybrid NUM_LAYERS=10 ATTN_POSITIONS=2,5,8 \ +BIGRAM_VOCAB_SIZE=10240 BIGRAM_DIM=128 \ +GRAD_CLIP_NORM=0.3 USE_TOKEN_SHIFT=1 \ +WEIGHT_DECAY=0.04 \ +TRAIN_BATCH_TOKENS=262144 \ +SEED=42 \ +torchrun --standalone --nproc_per_node=8 train_gpt_trn.py +``` + +The 20K run used TRAIN_BATCH_TOKENS=262144 explicitly via run_overnight.sh. + +Note: at 20K steps the fp32 model reaches 1.26 bpb, but int5 collapses to +1.93 bpb. The submitted score uses the 1000-step single-GPU command above. + +Unlisted env vars use code defaults: d_model=512, K=256, vocab=1024, +heads=8, kv_heads=4, mlp_mult=2, seq_len=1024, warmdown_iters=1200. + +Requires: `pip install zstandard sentencepiece` + +Data: `python3 data/cached_challenge_fineweb.py --variant sp1024` + +--- + +## Implementation Notes + +- Pure PyTorch -- no Triton, no custom CUDA extensions +- Kogge-Stone parallel prefix scan: log-depth tree reduction over complex recurrence +- torch.compile compatible (fullgraph=False; scan loop unrolls at fixed seq_len) +- Int5 bit packing in NumPy at save/load; inference uses fp32 dequantized weights +- All TRN modules inlined in train_gpt_trn.py -- zero external dependencies + +--- + +## Known Weaknesses and Open Problems + +1. **Int5 quantization collapse at convergence.** The central limitation. TRN oscillator + projections encode frequency/phase parameters where quantization errors accumulate + as O(t) phase drift. At 1000 steps the error is bounded (+0.041); at 20K steps + it is catastrophic (+0.669). This makes uniform int5 unsuitable for converged + TRN models under the 16 MB constraint. Possible mitigations (untested): + mixed-precision quantization, learned quantization grids, or omega discretization. + +2. **No int8 path within 16 MB.** TRN oscillator weights have high entropy, + resisting zstd compression (1.30x vs 1.66x for Transformer at int8). The 10L + hybrid int8 artifact is 20.1 MB -- 25% over the limit. Reducing model size + to fit int8 would sacrifice the depth advantage that motivated int5. + +3. **Parameter-matched comparison favors Transformer at 1000 steps.** The 13L + Transformer reaches 1.3689 int5 bpb vs the hybrid's 1.4942 at the same param + count and step count. TRN's per-step speed advantage (16%) partially compensates + under wallclock constraints, but the gap is real. + +4. **The four stacked optimizations (token_shift, LeakyReLU^2, PCG, grad_clip) are + not individually ablated.** Some may be inactive. + +5. **TRN's selective copy accuracy (8.8%) means the architecture depends on + attention layers for discrete recall.** The 3/10 attention ratio is a floor. + +--- + +## Ongoing Work + +The int5 quantization collapse is under active investigation. We are exploring: + +- **Mixed-precision quantization:** int8 for oscillator projections (5.6M params), + int4 for FFN (10.5M params), fp16 for embeddings. Preliminary size estimates + suggest this fits within 16 MB with careful compression tuning. +- **Early QAT:** Starting QAT from step 0 rather than the final phase, so the + model learns to maintain quantization-friendly weight distributions throughout + training. +- **Omega discretization:** Constraining oscillator frequencies to a fixed grid + during training, reducing the precision required to represent them post-quantization. + +If any of these pan out, we will update this submission. + +--- + +## Appendix: Test-Time Training (TTT) + +Not included in submission score. Per-document LoRA (rank=8, lr=0.01, chunk=256) +applied during eval. Score-then-train, causal, LoRA discarded between documents. + +Applied to the fp32 model at 1860 steps (7L, val_bpb 1.4119 before TTT). +Result: 1.4119 -> 1.3789 bpb (-0.033). Gain is smaller than Transformer TTT +(~-0.10) because TRN layers have fewer LoRA-targetable projections. + +--- + +## Acknowledgments + +Thanks to OpenAI for the challenge and compute credits. We spent most of our +time debugging why int5 kept collapsing on overnight runs -- two 5-hour training +runs wasted before we identified the oscillator projection sensitivity. We are +still working on it. The 16 MB constraint forced us into quantization territory +we had never touched before, and we plan to keep poking at the TRN + low-bit +quantization problem after the deadline. diff --git a/records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/submission.json b/records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/submission.json new file mode 100644 index 000000000..c69e67a6b --- /dev/null +++ b/records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/submission.json @@ -0,0 +1,25 @@ +{ + "author": "amabito", + "github_id": "amabito", + "name": "TRN Hybrid Int5 1x5090", + "blurb": "TRN+Attention hybrid (7 TRN + 3 Attn layers), int5 QAT zstd-22, Kogge-Stone parallel scan. 636 steps / 600s wallclock on 1x RTX 5090. Int5 collapses at 20K steps due to oscillator projection sensitivity.", + "date": "2026-03-25", + "track": "non-record-16mb", + "val_loss": 2.52288708, + "val_bpb": 1.4942, + "pre_quant_val_loss": 2.4846, + "pre_quant_val_bpb": 1.4715, + "step_stop": 636, + "wallclock_seconds": 600.674, + "bytes_total": 15359752, + "bytes_model_int5_zstd": 15277254, + "bytes_code": 82498, + "gpu": "1x RTX 5090", + "metadata": { + "tags": ["trn", "hybrid", "oscillatory-recurrence", "int5-qat", "kogge-stone-scan"], + "layers": 10, + "d_model": 512, + "params": "25.8M", + "quantization": "int5+zstd-22" + } +} diff --git a/records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/train.log b/records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/train.log new file mode 100644 index 000000000..f5bcb3632 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/train.log @@ -0,0 +1,77 @@ +logs/529d0ec9-c204-42e7-83ca-afa6048b676d.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:2 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024\fineweb_val_*.bin tokens:62021632 +model_type:hybrid num_layers:10 trn_K:256 hybrid_num_attn:2 +attn_positions:2,5,8 +layout:[TRN|TRN|Attn|TRN|TRN|Attn|TRN|TRN|Attn|TRN] +model_params:25789479 +bigram:vocab=10240 dim=128 params=1376257 +world_size:1 grad_accum_steps:8 +[W324 14:53:38.000000000 unwind.cpp:12] Warning: record_context_cpp is not support on non-linux non-x86_64 platforms (function operator ()) +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/1000 val_loss:6.9370 val_bpb:4.1085 train_time:0ms step_avg:0.01ms +step:1/1000 train_loss:6.9378 train_time:978ms step_avg:977.70ms +step:2/1000 train_loss:15.1706 train_time:1928ms step_avg:963.84ms +step:3/1000 train_loss:10.8297 train_time:2879ms step_avg:959.65ms +step:4/1000 train_loss:6.9361 train_time:3832ms step_avg:957.97ms +step:5/1000 train_loss:5.6420 train_time:4783ms step_avg:956.63ms +step:6/1000 train_loss:5.3310 train_time:5739ms step_avg:956.49ms +step:7/1000 train_loss:5.1623 train_time:6689ms step_avg:955.64ms +step:8/1000 train_loss:4.9910 train_time:7647ms step_avg:955.87ms +step:9/1000 train_loss:4.8279 train_time:8604ms step_avg:955.96ms +step:10/1000 train_loss:4.7205 train_time:9555ms step_avg:955.55ms +step:50/1000 train_loss:4.0292 train_time:47688ms step_avg:953.77ms +step:100/1000 train_loss:3.3645 train_time:95355ms step_avg:953.55ms +step:100/1000 val_loss:3.3550 val_bpb:1.9870 train_time:95356ms step_avg:953.56ms +step:150/1000 train_loss:3.1770 train_time:142399ms step_avg:949.32ms +step:200/1000 train_loss:2.9366 train_time:189430ms step_avg:947.15ms +step:200/1000 val_loss:2.9266 val_bpb:1.7333 train_time:189430ms step_avg:947.15ms +step:250/1000 train_loss:2.8042 train_time:236442ms step_avg:945.77ms +int5_qat:start step:275 scale:0.2999 +step:300/1000 train_loss:2.7486 train_time:283798ms step_avg:945.99ms +step:300/1000 val_loss:2.7174 val_bpb:1.6094 train_time:283799ms step_avg:946.00ms +step:350/1000 train_loss:2.5516 train_time:331044ms step_avg:945.84ms +step:400/1000 train_loss:2.5474 train_time:378631ms step_avg:946.58ms +step:400/1000 val_loss:2.6190 val_bpb:1.5511 train_time:378632ms step_avg:946.58ms +step:450/1000 train_loss:2.5509 train_time:425710ms step_avg:946.02ms +step:500/1000 train_loss:2.5613 train_time:472644ms step_avg:945.29ms +step:500/1000 val_loss:2.5422 val_bpb:1.5056 train_time:472645ms step_avg:945.29ms +step:550/1000 train_loss:2.4982 train_time:519767ms step_avg:945.03ms +step:600/1000 train_loss:2.4897 train_time:566767ms step_avg:944.61ms +step:600/1000 val_loss:2.4911 val_bpb:1.4754 train_time:566768ms step_avg:944.61ms +step:636/1000 val_loss:2.4846 val_bpb:1.4715 train_time:600674ms step_avg:944.46ms +stopping_early: wallclock_cap train_time:600674ms step:636/1000 +peak memory allocated: 19456 MiB reserved: 22512 MiB +int5_qat:hooks_removed before final eval +Serialized model: 62851631 bytes +Code size: 82498 bytes +Total submission size: 62934129 bytes +Serialized model int8+zstd-22: 14593028 bytes (payload:26047698 raw_torch:26108025 payload_ratio:2.41x) +Total submission size int8+zlib: 14675526 bytes +final_int8_zlib_roundtrip val_loss:2.4913 val_bpb:1.4755 eval_time:62475ms +final_int8_zlib_roundtrip_exact val_loss:2.49129179 val_bpb:1.47548333 +int5_roundtrip: quantizing to int5... +Serialized model int5+zstd-22: 15277254 bytes (15.28 MB) +Total submission size int5: 15359752 bytes +final_int5_roundtrip val_loss:2.5229 val_bpb:1.4942 eval_time:232061ms +final_int5_roundtrip_exact val_loss:2.52288708 val_bpb:1.49419584 diff --git a/records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/train_gpt_trn.py b/records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/train_gpt_trn.py new file mode 100644 index 000000000..659848e60 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-25_TRN_Hybrid_Int5_1x5090/train_gpt_trn.py @@ -0,0 +1,2222 @@ +""" +train_gpt_trn.py -- Hybrid TRN/Transformer experiment for parameter-golf. + +Supports 3 model variants via --model-type (or MODEL_TYPE env var): + transformer -- baseline GPT (identical to train_gpt.py) + trn -- all blocks replaced with TRN (no attention) + hybrid -- first 2 blocks are CausalAttn, remaining blocks are TRN + +TRN modules (oscillator, scan, resonance, block) are inlined here so this +file has zero external dependencies beyond train_gpt.py requirements. + +Architecture match at default config (model_dim=512, 9 layers): + transformer ~17.1M params + trn ~17.5M params (K=model_dim//2=256, 1.116x per layer) + hybrid ~17.3M params (2 attn + 7 trn) + +For ~same param budget, TRN uses K = model_dim // 2 by default. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import threading +import uuid +import zlib +from math import pi +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Optional Triton scan backend -- set TRITON_SCAN=1 to enable. +# Falls back to Kogge-Stone scan when Triton is unavailable or env var is unset. +_USE_TRITON_SCAN: bool = os.environ.get("TRITON_SCAN", "0") == "1" +_triton_resonance_scan = None +if _USE_TRITON_SCAN: + try: + from triton_scan import triton_resonance_scan as _triton_resonance_scan # type: ignore[assignment] + except Exception as _triton_import_err: # noqa: BLE001 + import warnings + warnings.warn( + f"TRITON_SCAN=1 set but triton_scan import failed: {_triton_import_err}. " + "Falling back to Kogge-Stone scan.", + stacklevel=1, + ) + _USE_TRITON_SCAN = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # TRN-specific: K = number of oscillators per layer. + # Default 0 = auto-computed to approximately match transformer param count. + # At d_model=512: K=120 gives TRN ~+1.1% vs transformer (within 5% target). + trn_n_oscillators = int(os.environ.get("TRN_N_OSCILLATORS", 0)) # 0 = auto + # Hybrid: how many leading blocks are attention (rest are TRN). + hybrid_num_attn = int(os.environ.get("HYBRID_NUM_ATTN", 2)) + # Interleaved: comma-separated 0-indexed positions for attention blocks. + # e.g. "2,5" means layers 2 and 5 are attention, rest TRN. + # Overrides hybrid_num_attn when set. + attn_positions_str = os.environ.get("ATTN_POSITIONS", "") + # Model variant: "transformer", "trn", "hybrid" + model_type = os.environ.get("MODEL_TYPE", "hybrid") + # Depth recurrence: how many times to run TRN blocks (attn blocks run once). + depth_recurrence = int(os.environ.get("DEPTH_RECURRENCE", 1)) + # Token shift: RWKV-6 style x_t / x_{t-1} interpolation before resonance + use_token_shift = bool(int(os.environ.get("USE_TOKEN_SHIFT", 0))) + # BigramHash embedding: 0 = disabled, >0 = bucket count + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 0)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + # Sliding window eval: stride for overlapping windows (0 = disabled) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + # SWA: stochastic weight averaging during warmdown + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", 0))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + swa_decay = float(os.environ.get("SWA_DECAY", 0.997)) + # EMA: exponential moving average, per-step, independent of LR schedule + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", 0))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + ema_start_frac = float(os.environ.get("EMA_START_FRAC", 0.5)) + # Int5 QAT: enable late quantization-aware training + int5_qat = bool(int(os.environ.get("INT5_QAT", 0))) + int5_qat_start_frac = float(os.environ.get("INT5_QAT_START_FRAC", 0.15)) + int5_qat_start_step = int(os.environ.get("INT5_QAT_START_STEP", 0)) # 0 = use frac + # Orthogonal init: apply nn.init.orthogonal_ to transformer (attention/FFN) linear weights. + # TRN layer linears (OscillatorProjection.proj, W_res, W_phase) are excluded -- they have + # domain-specific custom init and must not be overwritten. + ortho_init = bool(int(os.environ.get("ORTHO_INIT", 0))) + # Strategy 1: bake EMA shadow weights into live model when QAT starts, then disable EMA. + # Requires EMA_ENABLED=1 and INT5_QAT=1. EMA must have started before QAT trigger step. + # Final export uses live QAT weights (not EMA). Best int5 checkpoint saved every 200 QAT steps. + strategy1 = bool(int(os.environ.get("STRATEGY1", 0))) + # Periodic checkpoint (0 = disabled) + ckpt_every = int(os.environ.get("CKPT_EVERY", 0)) + # Compression: zstd (requires zstandard package) or zlib + use_zstd = bool(int(os.environ.get("USE_ZSTD", 0))) + zstd_level = int(os.environ.get("ZSTD_LEVEL", 22)) + + # Optimizer hyperparameters (identical to train_gpt.py). + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + wd = group["weight_decay"] + # Decoupled weight decay applied before the Newton-Schulz update. + if wd > 0: + for p in params: + p.mul_(1.0 - wd * lr) + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + """Sliding window evaluation -- overlapping windows for better context.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + # Only score the last `stride` tokens (except first window) + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights," + "res_scale,lambda_pcg,omega_base", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# INT5 QAT (QUANTIZATION-AWARE TRAINING) +# ----------------------------- + +_INT5_CLIP_Q = 0.9999984 # must match quantize_state_dict_int5 + + +def _int5_fake_quantize(w: Tensor) -> Tensor: + """Simulate int5 quantization: per-row percentile-clipped scale to [-15, 15].""" + if w.ndim < 2: + return w + w32 = w.float() + clip_abs = torch.quantile(w32.abs(), _INT5_CLIP_Q, dim=1).clamp_min(1e-8) + scale = (clip_abs / 15.0).unsqueeze(1) + clipped = w32.clamp(-clip_abs.unsqueeze(1), clip_abs.unsqueeze(1)) + w_q = (clipped / scale).round().clamp(-15, 15) + return (w_q * scale).to(dtype=w.dtype) + + +def apply_int5_ste_hooks(model: nn.Module) -> list: + """Register pre/post forward hooks for int5 fake-quantization. + + Pre-hook: save float weight, replace with quantized value. + Post-hook: restore float weight so optimizer steps on float values. + Gradient flows through the quantized forward via straight-through + (optimizer updates the float weight; next forward re-quantizes). + """ + hooks: list = [] + for name, module in model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + if module.weight.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + continue + + def _make_pre_hook() -> callable: + def hook(mod: nn.Module, args: tuple) -> None: + mod._weight_float = mod.weight.data.clone() + mod.weight.data = _int5_fake_quantize(mod.weight.data) + return hook + + def _make_post_hook() -> callable: + def hook(mod: nn.Module, args: tuple, output: object) -> None: + if hasattr(mod, "_weight_float"): + mod.weight.data = mod._weight_float + del mod._weight_float + return hook + + hooks.append(module.register_forward_pre_hook(_make_pre_hook())) + hooks.append(module.register_forward_hook(_make_post_hook())) + return hooks + + +def quantize_state_dict_int5(state_dict: dict[str, Tensor]) -> bytes: + """Quantize to int5 per-row and pack bits. Returns packed bytes (meta embedded).""" + packed_parts: list[bytes] = [] + byte_offset: int = 0 + meta: dict[str, object] = {"format": "int5_packed_v1", "entries": {}} + + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if not t.is_floating_point() or t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + # Keep as-is (passthrough) + buf = io.BytesIO() + torch.save({"t": t}, buf) + raw = buf.getvalue() + meta["entries"][name] = {"kind": "passthrough", "offset": byte_offset, "nbytes": len(raw)} + packed_parts.append(raw) + byte_offset += len(raw) + continue + + # Int5 quantize: per-row with percentile clipping (like int8) + t32 = t.float() + INT5_CLIP_Q = _INT5_CLIP_Q # shared with fake_quantize + if t32.ndim == 2: + clip_abs = torch.quantile(t32.abs(), INT5_CLIP_Q, dim=1).clamp_min(1e-8) + scale = clip_abs / 15.0 # (rows,) + clipped = torch.clamp(t32, -clip_abs[:, None], clip_abs[:, None]) + q = (clipped / scale[:, None]).round().clamp(-15, 15).to(torch.int8) + else: + scale_val = t32.abs().max().clamp_min(1e-8).item() / 15.0 + scale = torch.tensor([scale_val], dtype=torch.float32) + q = (t32 / scale_val).round().clamp(-15, 15).to(torch.int8) + + # Pack int5: shift to [0, 30], pack 8 values into 5 bytes + shifted = (q.flatten() + 15).to(torch.uint8) + packed = _pack_int5(shifted.numpy()) + + scale_buf = io.BytesIO() + torch.save({"s": scale.to(torch.float32), "shape": list(t.shape), "dtype": str(t.dtype).removeprefix("torch.")}, scale_buf) + scale_raw = scale_buf.getvalue() + + meta["entries"][name] = { + "kind": "int5", + "offset": byte_offset, + "packed_nbytes": len(packed), + "scale_nbytes": len(scale_raw), + "numel": t.numel(), + } + packed_parts.append(packed) + packed_parts.append(scale_raw) + byte_offset += len(packed) + len(scale_raw) + + payload = b"".join(packed_parts) + meta_buf = io.BytesIO() + torch.save(meta, meta_buf) + meta_raw = meta_buf.getvalue() + # Format: [4-byte meta_len][meta][payload] + import struct + result = struct.pack(" bytes: + """Pack uint8 values (0-30, 5 bits each) into bytes. 8 values -> 5 bytes.""" + n = len(values) + # Pad to multiple of 8 + pad = (8 - n % 8) % 8 + if pad > 0: + values = np.concatenate([values, np.zeros(pad, dtype=np.uint8)]) + values = values.reshape(-1, 8).astype(np.uint64) + # Pack 8 x 5-bit values into 40 bits (5 bytes) + packed_u64 = np.zeros(values.shape[0], dtype=np.uint64) + for i in range(8): + packed_u64 |= values[:, i].astype(np.uint64) << (i * 5) + # Extract 5 bytes from each uint64 + result = np.zeros((values.shape[0], 5), dtype=np.uint8) + for i in range(5): + result[:, i] = ((packed_u64 >> (i * 8)) & 0xFF).astype(np.uint8) + return result.tobytes() + + +def _unpack_int5(data: bytes, numel: int) -> np.ndarray: + """Unpack bytes into uint8 values (0-30, 5 bits each). Inverse of _pack_int5.""" + # Calculate padded length (multiple of 8) + padded = numel + (8 - numel % 8) % 8 + n_groups = padded // 8 + raw = np.frombuffer(data, dtype=np.uint8).reshape(n_groups, 5) + # Reconstruct uint64 from 5 bytes + packed_u64 = np.zeros(n_groups, dtype=np.uint64) + for i in range(5): + packed_u64 |= raw[:, i].astype(np.uint64) << (i * 8) + # Extract 8 x 5-bit values from each uint64 + result = np.zeros((n_groups, 8), dtype=np.uint8) + for i in range(8): + result[:, i] = ((packed_u64 >> (i * 5)) & 0x1F).astype(np.uint8) + return result.flatten()[:numel] + + +def dequantize_state_dict_int5(data: bytes) -> dict[str, Tensor]: + """Dequantize int5-packed state dict back to float tensors.""" + import struct + meta_len = struct.unpack(" Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# BASELINE TRANSFORMER MODULES +# (identical to train_gpt.py) +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), 0.5) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +# ----------------------------- +# TRN MODULES (inlined from Tri-Memory) +# ----------------------------- + +# --- scan.py --- + +class _SafeCumprod(torch.autograd.Function): + """torch.cumprod with NaN/Inf-safe backward.""" + + @staticmethod + def forward(ctx, alpha: Tensor, dim: int = 1) -> Tensor: + result = torch.cumprod(alpha, dim=dim) + ctx.save_for_backward(alpha, result) + ctx.dim = dim + return result + + @staticmethod + def backward(ctx, grad_output: Tensor) -> tuple[Tensor, None]: + alpha, result = ctx.saved_tensors + dim = ctx.dim + grad_weighted = grad_output * result + grad_sum = torch.flip( + torch.cumsum(torch.flip(grad_weighted, [dim]), dim=dim), [dim], + ) + grad_input = grad_sum / alpha.clamp(min=1e-6) + nonfinite_mask = ~torch.isfinite(grad_input) + grad_input = torch.where(nonfinite_mask, torch.zeros_like(grad_input), grad_input) + return grad_input, None + + +def _kogge_stone_scan(alpha: Tensor, drive: Tensor) -> Tensor: + """O(log n) parallel prefix scan for r_t = a_t * r_{t-1} + d_t.""" + B, n, K = alpha.shape + a = alpha + b = drive + offset = 1 + while offset < n: + # F.pad avoids new tensor alloc from torch.cat: pad left with identity values + a_left = F.pad(a[:, : n - offset], (0, 0, offset, 0), value=1.0) + b_left = F.pad(b[:, : n - offset], (0, 0, offset, 0), value=0.0) + b = b + a * b_left + a = a * a_left + offset *= 2 + return b + + +def _parallel_resonance_scan( + alpha: Tensor, + drive_r: Tensor, + drive_i: Tensor, +) -> tuple[Tensor, Tensor]: + """O(log n) parallel Kogge-Stone scan. No Triton/FLA dependency.""" + alpha = alpha.float() + drive_r = drive_r.float() + drive_i = drive_i.float() + return _kogge_stone_scan(alpha, drive_r), _kogge_stone_scan(alpha, drive_i) + + +def _triton_resonance_scan_fused( + alpha: Tensor, + drive_r: Tensor, + drive_i: Tensor, +) -> tuple[Tensor, Tensor]: + """Triton-backed fused scan for r_t = alpha_t * r_{t-1} + d_t. + + Fuses drive_r and drive_i into a single (B, n, 2K) call so that both + channels share the same alpha broadcast, halving kernel launch overhead. + + Requires TRITON_SCAN=1 and triton_scan module on PYTHONPATH. + """ + assert _triton_resonance_scan is not None, "Triton scan not available" + B, n, K = alpha.shape + alpha_f = alpha.float().contiguous() + # Stack real and imaginary drive along the K dimension: (B, n, 2K) + drive_fused = torch.cat([drive_r.float(), drive_i.float()], dim=-1) + # alpha: interleave K values for real+imag (avoids full repeat copy) + alpha_fused = torch.cat([alpha_f, alpha_f], dim=-1) + out_fused = _triton_resonance_scan(alpha_fused, drive_fused) # (B, n, 2K) + return out_fused[..., :K], out_fused[..., K:] + + +def _scan_chunk( + alpha_chunk: Tensor, + drive_chunk: Tensor, + r_prev: Tensor, +) -> tuple[Tensor, Tensor]: + alpha_cum = _SafeCumprod.apply(alpha_chunk, 1) + prev_contrib = alpha_cum * r_prev.unsqueeze(1) + safe_alpha_cum = alpha_cum.clamp(min=1e-6) + scaled_drive = drive_chunk / safe_alpha_cum + drive_contrib = torch.cumsum(scaled_drive, dim=1) * alpha_cum + small_alpha = alpha_cum < 1e-6 + drive_contrib = torch.where(small_alpha, drive_chunk * alpha_cum, drive_contrib) + chunk_out = prev_contrib + drive_contrib + return chunk_out, chunk_out[:, -1] + + +def _chunked_resonance_scan( + alpha: Tensor, + drive_r: Tensor, + drive_i: Tensor, + chunk_size: int = 64, +) -> tuple[Tensor, Tensor]: + """CPU-compatible fallback scan.""" + B, n, K = alpha.shape + r_r = alpha.new_zeros(B, K) + r_i = alpha.new_zeros(B, K) + out_r_chunks: list[Tensor] = [] + out_i_chunks: list[Tensor] = [] + for start in range(0, n, chunk_size): + end = min(start + chunk_size, n) + chunk_r, r_r = _scan_chunk(alpha[:, start:end], drive_r[:, start:end], r_r) + chunk_i, r_i = _scan_chunk(alpha[:, start:end], drive_i[:, start:end], r_i) + out_r_chunks.append(chunk_r) + out_i_chunks.append(chunk_i) + return torch.cat(out_r_chunks, dim=1), torch.cat(out_i_chunks, dim=1) + + +# --- oscillator.py --- + +class OscillatorProjection(nn.Module): + """Projects token embeddings to oscillator params (A, omega, phi, alpha, g_out, beta).""" + + def __init__( + self, + d_model: int, + K: int, + amplitude_max: float = 3.0, + gate_bias_init: float = 0.65, + max_seq_len: int = 2048, + ) -> None: + super().__init__() + self.K = K + self.amplitude_max = amplitude_max + self.proj = nn.Linear(d_model, 6 * K, bias=True) + self.proj._no_ortho_init = True # custom domain-specific init in _init_weights + + omega_init = torch.linspace( + math.log(2.0 * pi / max_seq_len), + math.log(pi * (1.0 - 1.0 / max(K, 2))), + K, + ).exp() + self.omega_base = nn.Parameter(omega_init) + + self._init_weights(gate_bias_init) + + def _init_weights(self, gate_bias_init: float = 0.65) -> None: + nn.init.normal_(self.proj.weight, std=self.proj.in_features ** -0.5) + nn.init.zeros_(self.proj.bias) + K = self.K + group_sizes = [K // 4, K // 2, K - K // 4 - K // 2] + alpha_centers = [0.30, 0.65, 0.97] + offset = 3 * K + for n_k, alpha_c in zip(group_sizes, alpha_centers): + alpha_c_clamped = max(0.01, min(0.99, alpha_c)) + bias_val = math.log(alpha_c_clamped / (1.0 - alpha_c_clamped)) + self.proj.bias.data[offset : offset + n_k].fill_(bias_val) + offset += n_k + gate_bias_init_clamped = max(0.01, min(0.99, gate_bias_init)) + g_out_bias = math.log(gate_bias_init_clamped / (1.0 - gate_bias_init_clamped)) + self.proj.bias.data[4 * K : 5 * K].fill_(g_out_bias) + self.proj.bias.data[5 * K :].fill_(-2.0) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + out = self.proj(x) + A_r, Om_r, Ph_r, Ga_r, Go_r, Be_r = out.chunk(6, dim=-1) + A = F.softplus(A_r).clamp(max=self.amplitude_max) + omega = (torch.sigmoid(Om_r) * pi + self.omega_base).clamp(min=1e-4, max=pi - 1e-4) + phi = torch.tanh(Ph_r) * pi + alpha = torch.sigmoid(Ga_r) + g_out = torch.sigmoid(Go_r) + beta = torch.sigmoid(Be_r) + return A, omega, phi, alpha, g_out, beta + + +# --- resonance.py (simplified: log phase_mode, SCPM on, no ATDR, no token_shift) --- + +class TemporalResonanceLayer(nn.Module): + """Core TRN layer with oscillatory recurrence. + + Uses log phase mode and SCPM (Spectral Cross-Product Mixing). + ATDR regularization is disabled for parameter-golf simplicity. + """ + + def __init__( + self, + d_model: int, + K: int, + amplitude_max: float = 3.0, + gate_bias_init: float = 0.65, + res_scale_init: float = 0.05, + state_norm: bool = True, + max_seq_len: int = 2048, + scan_chunk_size: int = 64, + ) -> None: + super().__init__() + self.K = K + self.state_norm_enabled = state_norm + self.scan_chunk_size = scan_chunk_size + + self.proj = OscillatorProjection( + d_model, K, + amplitude_max=amplitude_max, + gate_bias_init=gate_bias_init, + max_seq_len=max_seq_len, + ) + + # W_res: (4K-2) -> d_model, SCPM-augmented output projection + self.W_res = nn.Linear(4 * K - 2, d_model, bias=False) + nn.init.normal_(self.W_res.weight, std=2e-3) + self.W_res._no_ortho_init = True # custom small-std init for stable TRN output scale + + # PCG: phase-coupled gating + self.W_phase = nn.Linear(d_model, K, bias=True) + nn.init.normal_(self.W_phase.weight, std=1e-3) + nn.init.zeros_(self.W_phase.bias) + self.W_phase._no_ortho_init = True # custom small-std init for phase sensitivity + self.lambda_pcg = nn.Parameter(torch.tensor(0.5)) + + # Learnable output scale (small init for stable training start) + self.res_scale = nn.Parameter(torch.tensor(res_scale_init)) + self.register_buffer("_forward_count", torch.tensor(0, dtype=torch.long)) + # Cache for log-phase positions (invalidated on sequence length change) + self._cached_positions: torch.Tensor | None = None + self._cached_positions_n: int = -1 + + def _apply_state_norm(self, r_r: Tensor, r_i: Tensor) -> tuple[Tensor, Tensor]: + modulus = (r_r.pow(2) + r_i.pow(2) + 1e-8).sqrt() + scale = modulus.clamp(min=1.0) + return r_r / scale, r_i / scale + + def forward(self, x: Tensor) -> Tensor: + B, n, _ = x.shape + device = x.device + + A, omega, phi, alpha, g_out, beta = self.proj(x) + + # Log-phase time encoding (positions cached across calls for same n) + if self._cached_positions_n != n or self._cached_positions is None or self._cached_positions.device != device: + self._cached_positions = torch.log1p( + torch.arange(n, device=device, dtype=torch.float32) + ).view(1, n, 1) + self._cached_positions_n = n + positions = self._cached_positions + angle = omega * positions + phi + alpha_f = alpha.float() + drive_scale = 1.0 - alpha_f + cos_angle = torch.cos(angle).float() + sin_angle = torch.sin(angle).float() + A_f = A.float() + drive_r = drive_scale * A_f * cos_angle + drive_i = drive_scale * A_f * sin_angle + + device_type = "cuda" if x.is_cuda else "cpu" + with torch.amp.autocast(device_type, enabled=False): + if x.is_cuda and _USE_TRITON_SCAN: + r_r, r_i = _triton_resonance_scan_fused(alpha_f, drive_r, drive_i) + elif x.is_cuda: + r_r, r_i = _parallel_resonance_scan(alpha_f, drive_r, drive_i) + else: + r_r, r_i = _chunked_resonance_scan(alpha_f, drive_r, drive_i, self.scan_chunk_size) + + # Delta rule erase + beta_f = beta.float() + readout = r_r * cos_angle + r_i * sin_angle + r_r = r_r - beta_f * readout * cos_angle + r_i = r_i - beta_f * readout * sin_angle + + if self.state_norm_enabled: + r_r, r_i = self._apply_state_norm(r_r, r_i) + + r_r = r_r.to(x.dtype) + r_i = r_i.to(x.dtype) + cos_a = cos_angle.to(r_r.dtype) + sin_a = sin_angle.to(r_r.dtype) + rho_re = r_r * cos_a + r_i * sin_a + rho_im = -r_r * sin_a + r_i * cos_a + + # PCG: phase-coupled gate + rho_norm = (rho_re.pow(2) + rho_im.pow(2) + 1e-6).sqrt() + rho_re_n = rho_re / rho_norm + rho_im_n = rho_im / rho_norm + phase_query = self.W_phase(x) + cos_pq = torch.cos(phase_query) + sin_pq = torch.sin(phase_query) + phase_alignment = rho_re_n * cos_pq + rho_im_n * sin_pq + phase_gate = torch.sigmoid(self.lambda_pcg * phase_alignment) + rho_re = phase_gate * rho_re + rho_im = phase_gate * rho_im + + # SCPM: spectral cross-product mixing + xcross_re = rho_re[:, :, :-1] * rho_re[:, :, 1:] - rho_im[:, :, :-1] * rho_im[:, :, 1:] + xcross_im = rho_re[:, :, :-1] * rho_im[:, :, 1:] + rho_im[:, :, :-1] * rho_re[:, :, 1:] + + g = g_out.to(rho_re.dtype) + rho_base = torch.cat([rho_re, rho_im], dim=-1) + g_base = g.repeat(1, 1, 2) + g_cross = (g[:, :, :-1] + g[:, :, 1:]) / 2.0 + xcross = torch.cat([xcross_re, xcross_im], dim=-1) + g_cross_full = g_cross.repeat(1, 1, 2) + rho = torch.cat([g_base * rho_base, g_cross_full * xcross], dim=-1) # (B, n, 4K-2) + + # Smoothstep warmup for res_scale during training + # Note: with depth_recurrence>1, this increments multiple times per step. + # Use 2000 base steps to accommodate up to 2x recurrence. + if self.training: + self._forward_count += 1 + warmup_steps = 2000 + warmup_t = (self._forward_count.float() / warmup_steps).clamp(max=1.0) + warmup_factor: float | Tensor = warmup_t * warmup_t * (3.0 - 2.0 * warmup_t) + else: + warmup_factor = 1.0 + + return (self.res_scale * warmup_factor) * self.W_res(rho) + + +# --- TRNBlock --- + +class TRNBlock(nn.Module): + """TRN layer: pre-norm resonance + pre-norm SwiGLU FFN. + + Compatible with the GPT U-Net forward (takes x and x0). + """ + + def __init__( + self, + dim: int, + K: int, + mlp_mult: int, + max_seq_len: int = 2048, + use_token_shift: bool = False, + ) -> None: + super().__init__() + self.norm1 = RMSNorm() + self.norm2 = RMSNorm() + # Token shift: RWKV-6 style lerp between x_t and x_{t-1} + self.token_shift_mu = nn.Parameter(torch.zeros(dim)) if use_token_shift else None + self.resonance = TemporalResonanceLayer( + d_model=dim, + K=K, + max_seq_len=max_seq_len, + ) + # SwiGLU FFN (LLaMA-style, matches param count of relu^2 MLP at mlp_mult=2) + raw_hidden = int(2 / 3 * mlp_mult * dim) + d_ff_hidden = (raw_hidden + 255) // 256 * 256 + self.gate_up = nn.Linear(dim, d_ff_hidden * 2, bias=False) + self.down = nn.Linear(d_ff_hidden, dim, bias=False) + self.d_ff_hidden = d_ff_hidden + + # Control tensors (kept in fp32, matched to Block API) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + nn.init.normal_(self.gate_up.weight, std=dim ** -0.5) + nn.init.normal_(self.down.weight, std=d_ff_hidden ** -0.5) + + def _swiglu(self, x: Tensor) -> Tensor: + y = self.gate_up(x) + gate, up = y.split(self.d_ff_hidden, dim=-1) + return self.down(F.silu(gate) * up) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + h = self.norm1(x) + if self.token_shift_mu is not None: + mu = torch.sigmoid(self.token_shift_mu.to(dtype=h.dtype)) + h_shifted = torch.empty_like(h) + h_shifted[:, 0, :] = h[:, 0, :] + h_shifted[:, 1:, :] = mu * h[:, 1:, :] + (1.0 - mu) * h[:, :-1, :] + h = h_shifted + res_out = self.resonance(h) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * res_out + ffn_out = self._swiglu(self.norm2(x)) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * ffn_out + return x + + +# ----------------------------- +# BIGRAM HASH EMBEDDING +# ----------------------------- + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int) -> None: + super().__init__() + if bigram_vocab_size < 2: + raise ValueError(f"bigram_vocab_size must be >= 2, got {bigram_vocab_size}") + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +# ----------------------------- +# HYBRID GPT MODEL +# ----------------------------- + +class HybridGPT(nn.Module): + """GPT model supporting transformer, TRN-only, and hybrid block layouts. + + U-Net skip connections and resid_mix are preserved from the baseline. + """ + + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + model_type: str, + trn_K: int, + hybrid_num_attn: int, + attn_positions: set[int] | None = None, + depth_recurrence: int = 1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + use_token_shift: bool = False, + max_seq_len: int = 2048, + ortho_init: bool = False, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + if model_type not in ("transformer", "trn", "hybrid"): + raise ValueError(f"model_type must be 'transformer', 'trn', or 'hybrid', got {model_type!r}") + + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.model_type = model_type + self.ortho_init = ortho_init + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + attn_pos_set = attn_positions or set() + + def _make_block(i: int) -> nn.Module: + if model_type == "transformer": + return Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + if model_type == "trn": + return TRNBlock(model_dim, trn_K, mlp_mult, max_seq_len, use_token_shift=use_token_shift) + # hybrid: use attn_positions if set, else front-loaded + if attn_pos_set: + is_attn = i in attn_pos_set + else: + is_attn = i < hybrid_num_attn + if is_attn: + return Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + return TRNBlock(model_dim, trn_K, mlp_mult, max_seq_len, use_token_shift=use_token_shift) + + self.blocks = nn.ModuleList([_make_block(i) for i in range(num_layers)]) + # Track which blocks are TRN for depth recurrence + self.is_trn_block = [isinstance(b, TRNBlock) for b in self.blocks] + self.depth_recurrence = depth_recurrence + # Learnable pass embeddings for depth recurrence (pass 1+) + if depth_recurrence > 1: + self.pass_embeds = nn.ParameterList([ + nn.Parameter(torch.zeros(1, 1, model_dim)) + for _ in range(depth_recurrence - 1) + ]) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if not isinstance(module, nn.Linear): + continue + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif self.ortho_init and not getattr(module, "_no_ortho_init", False): + # Apply orthogonal init to transformer (attention/FFN) linears. + # TRN linears marked _no_ortho_init are skipped -- they have + # domain-specific custom init (OscillatorProjection.proj, + # TemporalResonanceLayer.W_res, W_phase). + nn.init.orthogonal_(module.weight) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # Pass 1: full U-Net (all blocks) + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + # Pass 2..N: re-run TRN blocks only (skip attn blocks) + for pass_idx in range(1, self.depth_recurrence): + x1 = x + self.pass_embeds[pass_idx - 1].to(dtype=x.dtype) + for i, block in enumerate(self.blocks): + if self.is_trn_block[i]: + x1 = block(x1, x) + x = x1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + target_ids.reshape(-1), + reduction="mean", + ) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + # Disable cpp codegen to avoid MSVC CP932/UTF-8 conflict on Japanese Windows + import torch._inductor.config as _inductor_cfg + _inductor_cfg.disable_cpp_codegen = True + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not bool(int(os.environ.get("SKIP_COMPILE", "0"))): + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(True) + enable_math_sdp(True) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # Resolve TRN K (number of oscillators). + # TRN_N_OSCILLATORS=256 gives ~1.44M params/layer vs 0.79M for attention. + # To match ~17M budget: 9*0.79M = 7.1M attn; 7*1.44M = 10.1M TRN. + # With MLP + embedding, 7-layer TRN is ~15.6M and 7-layer hybrid ~16.1M. + if args.trn_n_oscillators > 0: + trn_K = args.trn_n_oscillators + else: + trn_K = 256 # default K per instructions + + # Auto-select layer count by model_type to match ~17M param budget: + # transformer: 9L (baseline) + # trn: 7L (~15.6M with K=256) + # hybrid: 7L (2 attn + 5 TRN, ~16.1M) + num_layers = args.num_layers # override by NUM_LAYERS env if set explicitly + if os.environ.get("NUM_LAYERS") is None: + if args.model_type == "trn": + num_layers = 7 + elif args.model_type == "hybrid": + num_layers = args.hybrid_num_attn + 5 # 2+5=7 by default + + base_model = HybridGPT( + vocab_size=args.vocab_size, + num_layers=num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + model_type=args.model_type, + trn_K=trn_K, + hybrid_num_attn=args.hybrid_num_attn, + attn_positions={int(x.strip()) for x in args.attn_positions_str.split(",")} if args.attn_positions_str else None, + depth_recurrence=args.depth_recurrence, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + use_token_shift=args.use_token_shift, + max_seq_len=args.train_seq_len, + ortho_init=args.ortho_init, + ).to(device).bfloat16() + + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + skip_compile = bool(int(os.environ.get("SKIP_COMPILE", "0"))) + if skip_compile: + compiled_model = base_model + elif args.model_type == "transformer": + # Transformer: fullgraph=True for maximum fusion (same as train_gpt.py) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + else: + # TRN/hybrid: Python control flow in scan prevents fullgraph; use dynamic=True + compiled_model = torch.compile(base_model, dynamic=True, fullgraph=False) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split matching train_gpt.py: + # matrix_params (2D, not control) -> Muon + # scalar_params (1D or control) -> Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if hasattr(base_model, "pass_embeds"): + for pe in base_model.pass_embeds: + scalar_params.append(pe) + + # BigramHash params: scale -> scalar, proj -> matrix, embed -> token optimizer + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_param_list: list[dict] = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_param_list.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + optimizer_tok = torch.optim.Adam( + tok_param_list, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.weight_decay, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=0.0, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=0.0, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_type:{args.model_type} num_layers:{num_layers} trn_K:{trn_K} hybrid_num_attn:{args.hybrid_num_attn} ortho_init:{args.ortho_init}") + if args.attn_positions_str: + log0(f"attn_positions:{args.attn_positions_str}") + # Log layout + layout = [] + for blk in base_model.blocks: + layout.append("Attn" if isinstance(blk, Block) else "TRN") + log0(f"layout:[{'|'.join(layout)}]") + if args.depth_recurrence > 1: + log0(f"depth_recurrence:{args.depth_recurrence} effective_trn_depth:{sum(1 for t in base_model.is_trn_block if t) * args.depth_recurrence}") + log0(f"model_params:{n_params}") + if base_model.bigram is not None: + bigram_params = sum(p.numel() for p in base_model.bigram.parameters()) + log0(f"bigram:vocab={args.bigram_vocab_size} dim={args.bigram_dim} params={bigram_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"weight_decay:{args.weight_decay} (muon only) ema_enabled:{args.ema_enabled} ema_decay:{args.ema_decay} ema_start_frac:{args.ema_start_frac}") + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state: dict[str, Tensor] | None = None + ema_start_step = int(args.ema_start_frac * args.iterations) + if args.ema_enabled and args.swa_enabled: + log0("ema:warning both EMA_ENABLED and SWA_ENABLED set -- EMA takes priority at apply time") + qat_hooks: list | None = None + qat_active = False + strategy1_qat_active = False + best_qat_int5_bpb = float("inf") + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # Late QAT: activate int5 STE (step-based or frac-based) + qat_trigger = ( + (args.int5_qat_start_step > 0 and step >= args.int5_qat_start_step) + or (args.int5_qat_start_step == 0 and scale < args.int5_qat_start_frac) + ) + if args.int5_qat and not qat_active and qat_trigger: + if args.strategy1: + # Strategy 1: bake EMA shadow weights into live model, reset optimizer moments. + # This makes live weights equal to EMA's smoothed distribution before QAT starts. + # EMA updates are stopped from this point (strategy1_qat_active flag below). + if not args.ema_enabled: + raise RuntimeError("STRATEGY1=1 requires EMA_ENABLED=1") + if ema_state is None: + raise RuntimeError( + f"STRATEGY1=1: EMA shadow is empty at QAT start (step:{step}). " + f"EMA must begin before QAT trigger. " + f"Set EMA_START_FRAC lower than QAT trigger or use INT5_QAT_START_STEP." + ) + current_state = base_model.state_dict() + baked_state = { + name: tensor.to(dtype=current_state[name].dtype, device=current_state[name].device) + for name, tensor in ema_state.items() + } + base_model.load_state_dict(baked_state, strict=True) + # Reset first/second moments only -- keep param values and lr schedule intact. + for opt in optimizers: + for state in opt.state.values(): + if "exp_avg" in state: + state["exp_avg"].zero_() + if "exp_avg_sq" in state: + state["exp_avg_sq"].zero_() + if "max_exp_avg_sq" in state: + state["max_exp_avg_sq"].zero_() + if "momentum_buffer" in state: + state["momentum_buffer"].zero_() + # Disable weight decay during QAT -- WD shrinks weights toward zero, + # conflicting with QAT's quantization bin optimization. + for opt in optimizers: + for group in opt.param_groups: + if group.get("weight_decay", 0) > 0: + group["_orig_weight_decay"] = group["weight_decay"] + group["weight_decay"] = 0.0 + strategy1_qat_active = True + # Save pre-QAT checkpoint for rollback. + if master_process: + torch.save(base_model.state_dict(), "strategy1_baked_step{}.pt".format(step)) + log0(f"strategy1:bake_in step:{step} ema_shadow_applied:1 optimizer_moments_reset:1 weight_decay:0.0") + # Forward hooks mutate weight.data, which causes graph breaks. + # fullgraph=True does not allow graph breaks, so recompile with + # fullgraph=False before registering hooks. + if not skip_compile and args.model_type == "transformer": + compiled_model = torch.compile(base_model, dynamic=True, fullgraph=False) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + qat_hooks = apply_int5_ste_hooks(base_model) + qat_active = True + log0(f"int5_qat:start step:{step} scale:{scale:.4f}") + + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + lr_scale = scale + # Strategy 1: reduce LR to 0.3x during QAT phase for stable bin adaptation. + if strategy1_qat_active: + lr_scale = scale * 0.3 + group["lr"] = group["base_lr"] * lr_scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + # SWA/EMA: update every step during warmdown (decay is per-step) + if args.swa_enabled and scale < args.swa_start_frac: + decay = args.swa_decay + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"ema:start step:{step} decay:{decay}") + else: + for name, t in base_model.state_dict().items(): + if swa_state[name].is_floating_point(): + swa_state[name] = decay * swa_state[name] + (1.0 - decay) * t.detach().cpu() + else: + swa_state[name] = t.detach().cpu() + swa_count += 1 + + # EMA: per-step update based on step fraction (independent of LR schedule) + # Strategy 1: stop EMA once QAT has started (live weights are already baked-in EMA). + if args.ema_enabled and step >= ema_start_step and not strategy1_qat_active: + ema_decay_val = args.ema_decay + if ema_state is None: + ema_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + log0(f"ema:start step:{step} decay:{ema_decay_val}") + else: + for name, t in base_model.state_dict().items(): + if ema_state[name].is_floating_point(): + ema_state[name] = ema_decay_val * ema_state[name] + (1.0 - ema_decay_val) * t.detach().cpu() + else: + ema_state[name] = t.detach().cpu() + + step += 1 + + # Strategy 1: every 200 QAT steps, do actual int5 roundtrip eval and save best checkpoint. + # Temporarily remove hooks, quantize/dequantize, eval, restore live weights, re-attach hooks. + if strategy1_qat_active and step % 200 == 0: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + # Save live fp32 weights before quantizing. + live_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + # Detach QAT hooks so eval uses clean quantized weights (no double-quantization). + if qat_hooks is None: + raise RuntimeError("strategy1_qat_active but qat_hooks is None -- internal error") + for h in qat_hooks: + h.remove() + qat_hooks.clear() + # Full int5 roundtrip: quantize -> dequantize -> load. + int5_packed = quantize_state_dict_int5(live_state) + int5_state = dequantize_state_dict_int5(int5_packed) + base_model.load_state_dict(int5_state, strict=True) + torch.cuda.synchronize() + t_int5eval = time.perf_counter() + int5_val_loss, int5_val_bpb = eval_val( + args, + base_model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"step:{step}/{args.iterations} strategy1_int5_roundtrip " + f"val_loss:{int5_val_loss:.4f} val_bpb:{int5_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_int5eval):.0f}ms" + ) + if int5_val_bpb < best_qat_int5_bpb: + best_qat_int5_bpb = int5_val_bpb + if master_process: + torch.save(live_state, "best_qat_int5.pt") + log0(f"strategy1_int5_best step:{step} val_bpb:{int5_val_bpb:.4f} path:best_qat_int5.pt") + # Restore live fp32 weights and re-attach QAT hooks. + base_model.load_state_dict(live_state, strict=True) + qat_hooks = apply_int5_ste_hooks(base_model) + # Reset timing reference after eval. + t0 = time.perf_counter() + + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Periodic checkpoint + if args.ckpt_every > 0 and step > 0 and step % args.ckpt_every == 0 and master_process: + ckpt_path = f"ckpt_step{step}.pt" + torch.save(base_model.state_dict(), ckpt_path) + log0(f"checkpoint:{ckpt_path}") + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply EMA if collected (takes priority over SWA). + # Strategy 1: skip -- live model already carries baked-in EMA weights adapted by QAT. + if strategy1_qat_active: + log0( + f"strategy1:skip_ema_apply final export uses live QAT weights " + f"best_int5_bpb:{best_qat_int5_bpb:.4f}" + ) + elif args.ema_enabled and ema_state is not None: + log0(f"ema:applying decay={args.ema_decay}") + current_state = base_model.state_dict() + avg_state = {} + for name, tensor in ema_state.items(): + avg_state[name] = tensor.to(dtype=current_state[name].dtype, device=current_state[name].device) + base_model.load_state_dict(avg_state) + # Apply SWA/EMA if collected + elif args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"ema:applying decay={args.swa_decay} from {swa_count} updates") + current_state = base_model.state_dict() + avg_state = {} + for name, tensor in swa_state.items(): + avg_state[name] = tensor.to(dtype=current_state[name].dtype) + base_model.load_state_dict(avg_state) + + # Remove QAT hooks before eval to avoid double-quantization + if qat_active: + for h in qat_hooks: + h.remove() + qat_hooks.clear() + log0("int5_qat:hooks_removed before final eval") + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if args.use_zstd: + try: + import zstandard + except ImportError: + raise ImportError("USE_ZSTD=1 requires: pip install zstandard") + quant_blob = zstandard.ZstdCompressor(level=args.zstd_level).compress(quant_raw) + compress_label = f"zstd-{args.zstd_level}" + else: + quant_blob = zlib.compress(quant_raw, level=9) + compress_label = "zlib-9" + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+{compress_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if args.use_zstd: + try: + import zstandard + except ImportError: + raise ImportError("USE_ZSTD=1 requires: pip install zstandard") + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval on quantized model + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Int5 roundtrip eval (if QAT was used or for size comparison) + if args.int5_qat or os.environ.get("INT5_EVAL", "0") == "1": + log0("int5_roundtrip: quantizing to int5...") + # Reload fresh fp32 weights (int8 roundtrip may have overwritten) + base_model.load_state_dict(torch.load("final_model.pt", map_location="cpu")) + int5_data = quantize_state_dict_int5(base_model.state_dict()) + if args.use_zstd: + import zstandard + int5_blob = zstandard.ZstdCompressor(level=args.zstd_level).compress(int5_data) + int5_compress_label = f"zstd-{args.zstd_level}" + else: + int5_blob = zlib.compress(int5_data, level=9) + int5_compress_label = "zlib-9" + if master_process: + with open("final_model.int5.ptz", "wb") as f: + f.write(int5_blob) + int5_file_bytes = os.path.getsize("final_model.int5.ptz") + log0( + f"Serialized model int5+{int5_compress_label}: {int5_file_bytes} bytes " + f"({int5_file_bytes / 1_000_000:.2f} MB)" + ) + log0(f"Total submission size int5: {int5_file_bytes + code_bytes} bytes") + + # Barrier so non-master ranks wait for file write + if distributed: + dist.barrier() + + # Dequantize and eval + with open("final_model.int5.ptz", "rb") as f: + int5_blob_disk = f.read() + if args.use_zstd: + import zstandard + int5_decompressed = zstandard.ZstdDecompressor().decompress(int5_blob_disk) + else: + int5_decompressed = zlib.decompress(int5_blob_disk) + int5_state = dequantize_state_dict_int5(int5_decompressed) + base_model.load_state_dict(int5_state, strict=True) + base_model.to(device) + torch.cuda.synchronize() + t_int5eval = time.perf_counter() + int5_val_loss, int5_val_bpb = eval_val( + args, base_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int5_roundtrip val_loss:{int5_val_loss:.4f} val_bpb:{int5_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_int5eval):.0f}ms" + ) + log0(f"final_int5_roundtrip_exact val_loss:{int5_val_loss:.8f} val_bpb:{int5_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()