110 changes: 110 additions & 0 deletions records/track_10min_16mb/2026-03-24_HedgeMixer_TTT/README.md
# 5-expert Hedge Mixer + TTT

**val_bpb: 1.0745** (3-seed mean) | **<15.5 MB** | 8xH100 SXM

## Results (8xH100 80GB SXM)

| Seed | steps | step_avg | Pre-TTT bpb | **Post-TTT bpb** | TTT gain | Eval time | Artifact |
|------|-------|----------|-------------|-----------------|----------|-----------|----------|
| 1337 | 5,997 | 97.1ms | 1.1248 | **1.0560** | -0.0688 | 563s | 15.48 MB |
| 42 | 5,997 | 97.1ms | 1.1257 | **1.0970** | -0.0287 | 563s | 15.41 MB |
| 7 | 5,983 | 97.3ms | 1.1251 | **1.0704** | -0.0547 | 561s | 15.43 MB |
| **Mean** | | | **1.1252** | **1.0745** | **-0.0507** | | |

## Key Contribution: 5-expert Logistic Context Mixer

GPU-vectorized online context mixing using the Hedge/multiplicative-weights algorithm. Five experts blend predictions in log-probability space:

| Expert | Source | Description |
|--------|--------|-------------|
| 0 | Neural | Base model log-softmax |
| 1 | Unigram | Token frequency from scored tokens |
| 2 | Bigram | P(next \| prev) from scored tokens |
| 3 | Trigram | Hashed P(next \| prev2, prev1) with 64K buckets |
| 4 | Entropy | Neural model entropy as confidence regularizer |

Expert weights are updated online via Hedge: `log_w -= eta * loss`, followed by renormalization. N-gram tables are built incrementally from already-scored tokens only, so scoring never uses information from tokens that have not yet been scored (which keeps the procedure legal).
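
The blend-and-update step described above can be sketched as follows; a minimal NumPy sketch, assuming each expert emits a full log-probability vector over the vocabulary (`hedge_mix_step` and the toy shapes are illustrative, not the repo's actual API):

```python
import numpy as np

def hedge_mix_step(log_w, expert_logps, target, eta=0.1):
    """One mixer step: blend expert log-probs, then update weights via Hedge.

    log_w        : (K,) normalized log expert weights
    expert_logps : (K, V) per-expert log-probabilities over the vocab
    target       : index of the token actually observed
    Returns (mixed log-prob of the target, updated log_w).
    """
    # Blend in log-probability space: log sum_k w_k * p_k(v)
    mixed = np.logaddexp.reduce(log_w[:, None] + expert_logps, axis=0)
    # Per-expert loss = negative log-likelihood of the observed token
    losses = -expert_logps[:, target]
    # Multiplicative-weights (Hedge) update, then renormalize
    new_log_w = log_w - eta * losses
    new_log_w -= np.logaddexp.reduce(new_log_w)
    return mixed[target], new_log_w

# toy check: 2 experts over a 3-token vocab; expert 0 is the better predictor
log_w = np.log(np.array([0.5, 0.5]))
expert_logps = np.log(np.array([[0.8, 0.1, 0.1],
                                [0.2, 0.4, 0.4]]))
score, log_w = hedge_mix_step(log_w, expert_logps, target=0)
```

After the update, the expert that assigned higher probability to the observed token gains weight, which is exactly the behavior that lets the n-gram experts take over on locally repetitive text.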

## Architecture

PR #606 base with the following additions:

| Component | Setting |
|-----------|---------|
| Layers | 11 (512d, 8H, 8KV) |
| MLP | 3x with **LeakyReLU(0.5)^2** |
| BigramHash | 6144 (dim=128) |
| XSA | All 11 layers (ws=8) |
| RoPE | Partial (16/64 dims) |
| LN Scale | 1/sqrt(layer+1) |
| VE128 | Layers 9-10 |
| Weight avg | EMA(0.997) |
| Quantization | Full GPTQ int5 + zstd (level 22) |
| Pruning | 3% magnitude |
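
One plausible reading of the `LeakyReLU(0.5)^2` MLP activation in the table is an elementwise square applied after a LeakyReLU with negative slope 0.5; a minimal NumPy sketch under that assumption (the function name is illustrative):

```python
import numpy as np

def leaky_relu_squared(x, slope=0.5):
    # LeakyReLU with negative slope `slope`, then elementwise square.
    # Note: the square discards the sign of negative inputs, so the leak
    # only shapes their magnitude (0.5x squared -> 0.25 * x^2).
    return np.where(x >= 0, x, slope * x) ** 2

y = leaky_relu_squared(np.array([-2.0, 0.0, 3.0]))
```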

## Legal Score-First TTT

Backward-looking adaptation with the GPTQ-calibrated model:

1. Validation tokens are split into 474 chunks of 131,072 tokens each
2. For each chunk:
- **SCORE**: Sliding window eval (stride=32, seq_len=2048) with 5-expert mixer blending
- **TRAIN**: AdamW(lr=0.0001) on already-scored chunk. 3 epochs, last 2 blocks unfrozen + norms + lm_head, cosine LR decay, Polyak averaging
3. Last chunk scored but never trained on
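
The score-first loop above can be sketched as follows; a hedged Python sketch where `score_fn` and `train_fn` are hypothetical stand-ins for the sliding-window eval and the AdamW adaptation step:

```python
def score_first_ttt(chunks, model, score_fn, train_fn):
    """Score each chunk with the current weights, then adapt on it.

    Every chunk is scored BEFORE the model trains on it, so the reported
    bpb never uses information from the chunk being scored; the final
    chunk is scored but never trained on.
    """
    total_bits, total_tokens = 0.0, 0
    for i, chunk in enumerate(chunks):
        total_bits += score_fn(model, chunk)   # SCORE with current weights
        total_tokens += len(chunk)
        if i < len(chunks) - 1:                # last chunk: never trained on
            train_fn(model, chunk)             # TRAIN on the scored chunk
    return total_bits / total_tokens           # aggregate bpb

# toy check: 3 chunks, trivial scorer/trainer
model = {"adapted": 0}
trace = []
def score_fn(m, chunk):
    trace.append(("score", m["adapted"]))      # record adaptation state at scoring time
    return float(len(chunk))                   # pretend each token costs 1 bit
def train_fn(m, chunk):
    m["adapted"] += 1
bpb = score_first_ttt([[1, 2], [3, 4], [5, 6]], model, score_fn, train_fn)
```

The trace confirms the ordering: chunk 1 is scored by the unadapted model, chunk 2 by a once-adapted model, and the final chunk is scored but triggers no training.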

### TTT Hyperparameters

| Parameter | Value |
|-----------|-------|
| Chunk size | 131,072 tokens |
| Optimizer | AdamW (lr=0.0001) |
| Epochs per chunk | 3 |
| Frozen blocks | First 9 (last 2 + norms + head unfrozen) |
| Polyak decay | 0.998 |
| Adaptive LR | max_mult=3.0 |
| Mixer eta | 0.1 |
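
The Polyak row can be read as a parameter-wise exponential moving average with decay 0.998, maintained alongside the TTT optimizer steps; a minimal sketch over a flat parameter dict (a hypothetical shape, not the repo's state-dict layout):

```python
def polyak_update(avg, params, decay=0.998):
    """One EMA step: avg <- decay * avg + (1 - decay) * params, per parameter."""
    for k in avg:
        avg[k] = decay * avg[k] + (1.0 - decay) * params[k]
    return avg

# toy check: averaged weight drifts slowly toward the live weight
avg = polyak_update({"w": 0.0}, {"w": 1.0})
```

A high decay like 0.998 keeps the averaged weights close to a long-run mean of the adaptation trajectory, damping the noise of per-chunk updates.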

### Training Budget

GPTQ calibration runs within the 600s training budget (18s is reserved from the training loop for EMA selection, calibration, and quantization).

| Phase | Time |
|-------|------|
| Training loop | 582s |
| EMA + GPTQ calibration + quantization | ~18s |
| **Total training** | **~600s** |
| Sliding-window eval (pre-TTT baseline) | ~165s |
| TTT eval with mixer | ~562s |
| **Total eval (TTT pass)** | **~562s** |

## Reproduction

```bash
# Install dependencies
pip install -r requirements.txt
# Build FA3 Hopper kernels (required)
cd /tmp && git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention/hopper && python setup.py install

# Run training + eval (single seed)
DATA_PATH=./data/datasets/fineweb10B_sp1024 \
TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \
SEED=1337 MAX_WALLCLOCK_SECONDS=600 \
USE_MIXER=1 TTT_LR=0.0001 TTT_CHUNK_TOKENS=131072 \
torchrun --standalone --nproc_per_node=8 train_gpt.py

# Run all 3 seeds
for SEED in 1337 42 7; do
DATA_PATH=./data/datasets/fineweb10B_sp1024 \
TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \
SEED=$SEED MAX_WALLCLOCK_SECONDS=600 \
USE_MIXER=1 TTT_LR=0.0001 TTT_CHUNK_TOKENS=131072 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
done
```

## Credits

- **Base model**: PR #606 by @gowtham0992
- **TTT recipe**: PR #461 by @Christopher-Lee-McClendon
- **Mixer inspiration**: PAQ compression (context mixing) + Hedge algorithm
173 changes: 173 additions & 0 deletions records/track_10min_16mb/2026-03-24_HedgeMixer_TTT/log_seed1337.txt
W0325 00:33:17.141000 673732 torch/distributed/run.py:852]
W0325 00:33:17.141000 673732 torch/distributed/run.py:852] *****************************************
W0325 00:33:17.141000 673732 torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0325 00:33:17.141000 673732 torch/distributed/run.py:852] *****************************************
logs/31484e99-50c6-404c-8bb8-635f618baafa.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=../data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=../data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
mixed_precision: 68 int5 layers, 0 int6 layers (last 0 blocks)
model_params:33317980
XSA:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] ws:8 gqa:8/8
lr:embed=0.035 matrix=0.025 scalar=0.025 batch:786432 wall:600s seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9285 val_bpb:4.1034 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9305 train_time:148ms step_avg:147.68ms
step:2/20000 train_loss:8.6412 train_time:239ms step_avg:119.37ms
step:3/20000 train_loss:7.7278 train_time:333ms step_avg:111.04ms
step:4/20000 train_loss:7.2812 train_time:428ms step_avg:106.95ms
step:5/20000 train_loss:7.0672 train_time:524ms step_avg:104.81ms
step:6/20000 train_loss:6.9647 train_time:619ms step_avg:103.14ms
step:7/20000 train_loss:6.8519 train_time:714ms step_avg:102.01ms
step:8/20000 train_loss:6.7091 train_time:809ms step_avg:101.08ms
step:9/20000 train_loss:6.3640 train_time:903ms step_avg:100.36ms
step:10/20000 train_loss:6.0314 train_time:998ms step_avg:99.77ms
step:500/20000 train_loss:2.3594 train_time:48287ms step_avg:96.57ms
step:1000/20000 train_loss:2.2366 train_time:96650ms step_avg:96.65ms
step:1500/20000 train_loss:2.1871 train_time:145067ms step_avg:96.71ms
step:2000/20000 train_loss:2.0272 train_time:193559ms step_avg:96.78ms
step:2500/20000 train_loss:2.1332 train_time:242098ms step_avg:96.84ms
step:3000/20000 train_loss:2.1140 train_time:290638ms step_avg:96.88ms
step:3500/20000 train_loss:2.1198 train_time:339166ms step_avg:96.90ms
step:4000/20000 train_loss:1.9079 train_time:387684ms step_avg:96.92ms
step:4000/20000 val_loss:2.0000 val_bpb:1.1845 train_time:387689ms step_avg:96.92ms
late_qat:enabled step:4255 scale:0.4998
step:4500/20000 train_loss:2.0575 train_time:436219ms step_avg:96.94ms
step:5000/20000 train_loss:2.0307 train_time:484743ms step_avg:96.95ms
swa:start step:5350
step:5500/20000 train_loss:1.9404 train_time:533445ms step_avg:96.99ms
step:5997/20000 val_loss:1.9019 val_bpb:1.1264 train_time:582051ms step_avg:97.06ms
stopping_early: wallclock_cap train_time:582051ms step:5997/20000
peak memory allocated: 26200 MiB reserved: 26782 MiB
ema:applying EMA weights (skipping diagnostic evals)
gptq:calibrating with training data...
gptq:calibrated 68 layers in 1.9s
Serialized model: 130432585 bytes
Code size: 96428 bytes
pruning:3.0% magnitude pruning applied
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
Serialized model int6+zstd: 15387977 bytes
Total submission size int6+zstd: 15484405 bytes
ttt: pre-compiling forward+backward kernels...
ttt: pre-compile done
final_int6_sliding_window val_loss:1.8992 val_bpb:1.1248 stride:32 eval_time:164968ms
final_int6_sliding_window_exact val_loss:1.89920781 val_bpb:1.12482157
TTT: epochs=3 lr=0.0001 freeze_first=2 chunk=131072 opt=adamw
TTT temperature: 0.98
PPM alpha: 0.85, Byte-weighted TTT: True
Logistic context mixer enabled: eta=0.1
Adaptive LR enabled: max_mult=3.0
ttt:start chunks=474 chunk_tokens=131072 windows=1938176 stride=32 lr=0.0001 epochs=3 opt=adamw freeze_first=2
ttt:params unfrozen=5780500 frozen=27537480
Polyak averaging enabled: decay=0.998
ttt_train [1] seqs=64 start_train...
ttt_train [1] epoch=1/3 batches=8 ...
step done ep=1 bs=0 loss=2.3437
ttt_train [1] epoch=2/3 batches=8 ...
step done ep=2 bs=0 loss=2.3101
ttt_train [1] epoch=3/3 batches=8 ...
step done ep=3 bs=0 loss=2.3086
ttt_chunk [1/474] bpb=1.202687 time=1.3s
ttt_train [2] seqs=64 start_train...
ttt_train [2] epoch=1/3 batches=8 ...
step done ep=1 bs=0 loss=2.1361
ttt_train [2] epoch=2/3 batches=8 ...
step done ep=2 bs=0 loss=2.1334
ttt_train [2] epoch=3/3 batches=8 ...
step done ep=3 bs=0 loss=2.1283
ttt_chunk [2/474] bpb=1.128760 time=2.5s
ttt_train [3] seqs=64 start_train...
ttt_train [3] epoch=1/3 batches=8 ...
step done ep=1 bs=0 loss=2.0525
ttt_train [3] epoch=2/3 batches=8 ...
step done ep=2 bs=0 loss=2.0510
ttt_train [3] epoch=3/3 batches=8 ...
step done ep=3 bs=0 loss=2.0484
ttt_chunk [3/474] bpb=1.080018 time=3.7s
ttt_chunk [4/474] bpb=1.076079 time=4.8s
ttt_chunk [5/474] bpb=1.066728 time=6.0s
ttt_chunk [11/474] bpb=1.031430 time=13.1s
ttt_chunk [21/474] bpb=1.018884 time=25.0s
ttt_chunk [31/474] bpb=1.015939 time=36.9s
ttt_chunk [41/474] bpb=1.022519 time=48.7s
ttt_chunk [51/474] bpb=1.028238 time=60.6s
ttt_chunk [61/474] bpb=1.025430 time=72.5s
ttt_chunk [71/474] bpb=1.026528 time=84.3s
ttt_chunk [81/474] bpb=1.026915 time=96.2s
ttt_chunk [91/474] bpb=1.028643 time=108.1s
ttt_chunk [101/474] bpb=1.025107 time=119.9s
ttt_chunk [111/474] bpb=1.024927 time=131.8s
ttt_chunk [121/474] bpb=1.027728 time=143.7s
ttt_chunk [131/474] bpb=1.027818 time=155.5s
ttt_chunk [141/474] bpb=1.026649 time=167.4s
ttt_chunk [151/474] bpb=1.024273 time=179.3s
ttt_chunk [161/474] bpb=1.024436 time=191.2s
ttt_chunk [171/474] bpb=1.022907 time=203.0s
ttt_chunk [181/474] bpb=1.023651 time=214.9s
ttt_chunk [191/474] bpb=1.022518 time=226.8s
ttt_chunk [201/474] bpb=1.021847 time=238.6s
ttt_chunk [211/474] bpb=1.020784 time=250.5s
ttt_chunk [221/474] bpb=1.021277 time=262.4s
ttt_chunk [231/474] bpb=1.020969 time=274.2s
ttt_chunk [241/474] bpb=1.020429 time=286.1s
ttt_chunk [251/474] bpb=1.022389 time=298.0s
ttt_chunk [261/474] bpb=1.024492 time=309.8s
ttt_chunk [271/474] bpb=1.024714 time=321.7s
ttt_chunk [281/474] bpb=1.026263 time=333.6s
ttt_chunk [291/474] bpb=1.027062 time=345.4s
ttt_chunk [301/474] bpb=1.029658 time=357.3s
ttt_chunk [311/474] bpb=1.031598 time=369.2s
ttt_chunk [321/474] bpb=1.032553 time=381.0s
ttt_chunk [331/474] bpb=1.034020 time=392.9s
ttt_chunk [341/474] bpb=1.035727 time=404.8s
ttt_chunk [351/474] bpb=1.036601 time=416.6s
ttt_chunk [361/474] bpb=1.039560 time=428.5s
ttt_chunk [371/474] bpb=1.041485 time=440.4s
ttt_chunk [381/474] bpb=1.044560 time=452.3s
ttt_chunk [391/474] bpb=1.047764 time=464.1s
ttt_chunk [401/474] bpb=1.050371 time=476.0s
ttt_chunk [411/474] bpb=1.052690 time=487.9s
ttt_chunk [421/474] bpb=1.055963 time=499.8s
ttt_chunk [431/474] bpb=1.056364 time=511.6s
ttt_chunk [441/474] bpb=1.057641 time=523.5s
ttt_chunk [451/474] bpb=1.058576 time=535.4s
ttt_chunk [461/474] bpb=1.060302 time=547.2s
ttt_chunk [471/474] bpb=1.061844 time=559.1s
ttt_chunk [474/474] bpb=1.062011 time=561.7s
ttt:done val_loss=1.783077 val_bpb=1.056042 elapsed=561.7s
final_int6_ttt val_loss:1.7831 val_bpb:1.0560 stride:32 eval_time:562538ms
final_int6_ttt_exact val_loss:1.78307674 val_bpb:1.05604198