# Record: 5-expert Hedge Mixer + CROWN-Q + stride=64 (val_bpb=1.0541)

**val_bpb: 1.0541** (3-seed mean) | **~15.7 MB** | 8xH100 SXM

## Results (8xH100 80GB SXM)

| Seed | step_avg | steps | Pre-TTT bpb | **Post-TTT bpb** | TTT gain | Eval time | Artifact |
|------|----------|-------|-------------|-----------------|----------|-----------|----------|
| 1337 | 98.1ms | 5,935 | 1.1251 | **1.0473** | -0.0778 | 336s | 15.89 MB |
| 42 | 97.9ms | 5,947 | 1.1264 | **1.0686** | -0.0578 | 336s | 15.69 MB |
| 7 | 98.0ms | 5,940 | 1.1246 | **1.0465** | -0.0781 | 336s | 15.66 MB |
| **Mean** | | | 1.1254 | **1.0541** | -0.0713 | 336s | ~15.75 MB |

## Contributions

### 1. CROWN-Q Training Penalty (training-time)
Added a quantization-aware penalty during warmdown that discourages weights that are sensitive to quantization error:
```
crown_q_loss = lambda * mean(w^2 * delta^2 / 12)
```
where `delta = row_max / clip_range` is the per-row quantization step size. This nudges weights toward quantization-friendly values, reducing post-quantization degradation. Set via `CROWN_Q_LAMBDA=0.01`.
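A minimal sketch of the penalty (the function name and `clip_range=15.0` for symmetric int5 are assumptions here, not the exact code in `train_gpt.py`):

```python
import torch

def crown_q_penalty(w: torch.Tensor, clip_range: float = 15.0,
                    lam: float = 0.01) -> torch.Tensor:
    """Quantization-aware penalty: lambda * mean(w^2 * delta^2 / 12)."""
    # Per-row quantization step size, delta = row_max / clip_range.
    delta = w.abs().amax(dim=1, keepdim=True) / clip_range
    # delta^2 / 12 is the variance of uniform quantization noise, so the
    # penalty is largest where big weights meet coarse quantization steps.
    return lam * (w.pow(2) * delta.pow(2) / 12).mean()
```

Adding this term to the training loss during warmdown shrinks exactly the weights that would suffer most under int5 rounding.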

**Effect**: Slightly better compression (artifact ~200KB smaller) and more robust quantization.

### 2. Eval stride 32 -> 64 (eval-time)
Changed the sliding-window stride from 32 to 64 during evaluation. Experiments showed identical BPB at 2x faster scoring, freeing ~100s of eval budget for more TTT epochs.
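The speedup is just window-count arithmetic: each sliding-window forward pass scores `stride` fresh tokens, so doubling the stride halves the number of passes. A toy check against the log's numbers (the helper name is illustrative):

```python
def num_windows(total_tokens: int, stride: int) -> int:
    # Each sliding-window forward pass scores `stride` fresh tokens,
    # so the pass count is ceil(total_tokens / stride).
    return -(-total_tokens // stride)

# The log reports windows=969088 at stride=64 on the 62,021,632-token
# validation set; stride=32 would need twice as many passes.
print(num_windows(62_021_632, 64))  # 969088
print(num_windows(62_021_632, 32))  # 1938176
```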

### 3. TTT Epochs 3 -> 4 (eval-time)
Increased test-time training from 3 to 4 epochs per chunk, using the time freed by stride=64. Each additional epoch adapts the model further to the scored data. Eight epochs were also tested but overfit (1.0735 vs. 1.0473 for 4 epochs).
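A sketch of the score-first TTT loop (the model interface and loss handling are simplified assumptions; the real recipe also uses Polyak averaging, adaptive LR, and frozen early blocks):

```python
import math
import torch

def score_first_ttt(model, opt, chunks, epochs=4):
    # Compliance-critical ordering: each chunk is scored under
    # inference_mode() BEFORE the model ever trains on it, so no
    # chunk's score benefits from its own data.
    total_nll, total_tokens = 0.0, 0
    for chunk in chunks:
        with torch.inference_mode():
            nll = model(chunk, targets=chunk)        # mean NLL in nats
        total_nll += nll.item() * chunk.numel()
        total_tokens += chunk.numel()
        for _ in range(epochs):                      # then adapt on it
            loss = model(chunk, targets=chunk)
            loss.backward()
            opt.step()
            opt.zero_grad(set_to_none=True)
    # Bits per token; true bpb additionally needs the token-to-byte ratio.
    return total_nll / total_tokens / math.log(2)
```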

### Combined Effect
- stride=64 saves ~100s of eval time
- 4th TTT epoch uses ~85s of the saved time
- Net eval time: ~336s (down from ~562s), well within 600s budget
- BPB improvement: 1.0745 -> 1.0541 (-0.0204)

## Architecture

| Component | Setting |
|-----------|---------|
| Layers | 11 (512d, 8H, 8KV) |
| MLP | 3.5x with LeakyReLU(0.5)^2 |
| BigramHash | 6144 (dim=128) |
| XSA | All 11 layers (ws=8) |
| VE128 | Layers 9-10 |
| Quantization | Full GPTQ int5 + zstd level 22 |
| Pruning | 3% magnitude |
| TTT | AdamW lr=0.0001, **4 epochs**, 131K chunks, Polyak 0.998 |
| Mixer | 5-expert Hedge (neural, unigram, bigram, trigram, entropy) |
| Training reserve | 18s (for EMA + calibration + quantization) |
| Early warmdown | LR schedule targets 582s |
| **CROWN-Q** | lambda=0.01 during warmdown |
| **Eval stride** | 64 (was 32) |
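The 5-expert mixer can be illustrated with a minimal Hedge (multiplicative-weights) sketch, assuming each expert emits a next-token probability distribution; names and shapes here are hypothetical, and the actual implementation is the logistic context mixer with `MIXER_ETA=0.1`:

```python
import numpy as np

def hedge_mix(expert_probs: np.ndarray, targets: np.ndarray, eta: float = 0.1):
    # expert_probs: (K experts, T steps, V vocab) next-token distributions.
    # After each observed token, expert k's weight is multiplied by
    # exp(-eta * log_loss_k) and renormalized (the Hedge update), so
    # experts that predict the stream well gain influence over time.
    K, T, V = expert_probs.shape
    w = np.full(K, 1.0 / K)
    mixed_nll = 0.0
    for t in range(T):
        p_mix = w @ expert_probs[:, t, :]            # mixture distribution
        mixed_nll += -np.log(p_mix[targets[t]])
        losses = -np.log(expert_probs[:, t, targets[t]] + 1e-12)
        w *= np.exp(-eta * losses)
        w /= w.sum()
    return mixed_nll / T, w
```

With experts as different as a neural LM and an n-gram counter, the online reweighting lets whichever expert suits the current text dominate the mixture.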

## Reproduction

```bash
DATA_PATH=../data/datasets/fineweb10B_sp1024 \
TOKENIZER_PATH=../data/tokenizers/fineweb_1024_bpe.model \
SEED=1337 MAX_WALLCLOCK_SECONDS=600 \
USE_MIXER=1 MIXER_ETA=0.1 \
TTT_EPOCHS=4 TTT_FREEZE_BLOCKS=2 \
TTT_LR=0.0001 TTT_CHUNK_TOKENS=131072 \
ADAPTIVE_LR=1 ADAPTIVE_LR_MAX=3.0 \
EVAL_STRIDE=64 \
CROWN_Q_LAMBDA=0.01 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```


## Compliance

| Constraint | Limit | Actual | Status |
|-----------|-------|--------|--------|
| Train time | 600s | 582s | Pass |
| Eval time | 600s | 336s | Pass |
| Artifact size | 16,000,000 bytes | 15,892,040 bytes (worst seed) | Pass |
| No pre-scoring training | — | Score-first TTT: each chunk scored under `inference_mode()` before any training on it | Pass |
| GPTQ calibration in training budget | — | Runs within 18s training reserve (1.9s actual) | Pass |

## Credits

- Base model: PR #414 by @signalrush
- TTT recipe: PR #461 by @Christopher-Lee-McClendon
- CROWN-Q concept: PR #693 by @EthanYangTW
- 5-expert Hedge mixer: PR #688
## Training Log (seed 1337)
W0325 08:52:23.247000 708816 torch/distributed/run.py:852]
W0325 08:52:23.247000 708816 torch/distributed/run.py:852] *****************************************
W0325 08:52:23.247000 708816 torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0325 08:52:23.247000 708816 torch/distributed/run.py:852] *****************************************
logs/ed001136-906c-44c4-b9a3-2cb83dec0dc5.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=../data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=../data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
mixed_precision: 68 int5 layers, 0 int6 layers (last 0 blocks)
model_params:33317980
XSA:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] ws:8 gqa:8/8
lr:embed=0.035 matrix=0.025 scalar=0.025 batch:786432 wall:600s seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9285 val_bpb:4.1034 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9305 train_time:151ms step_avg:151.37ms
step:2/20000 train_loss:8.6412 train_time:241ms step_avg:120.48ms
step:3/20000 train_loss:7.7278 train_time:336ms step_avg:111.86ms
step:4/20000 train_loss:7.2812 train_time:430ms step_avg:107.50ms
step:5/20000 train_loss:7.0672 train_time:525ms step_avg:104.91ms
step:6/20000 train_loss:6.9648 train_time:619ms step_avg:103.21ms
step:7/20000 train_loss:6.8519 train_time:714ms step_avg:102.00ms
step:8/20000 train_loss:6.7092 train_time:809ms step_avg:101.12ms
step:9/20000 train_loss:6.3651 train_time:904ms step_avg:100.42ms
step:10/20000 train_loss:6.0329 train_time:999ms step_avg:99.87ms
step:500/20000 train_loss:2.3632 train_time:48334ms step_avg:96.67ms
step:1000/20000 train_loss:2.2407 train_time:96735ms step_avg:96.74ms
step:1500/20000 train_loss:2.1858 train_time:145207ms step_avg:96.80ms
step:2000/20000 train_loss:2.0310 train_time:193752ms step_avg:96.88ms
step:2500/20000 train_loss:2.1364 train_time:242327ms step_avg:96.93ms
step:3000/20000 train_loss:2.1144 train_time:290910ms step_avg:96.97ms
step:3500/20000 train_loss:2.1219 train_time:339498ms step_avg:97.00ms
step:4000/20000 train_loss:1.9090 train_time:388081ms step_avg:97.02ms
step:4000/20000 val_loss:2.0001 val_bpb:1.1846 train_time:388086ms step_avg:97.02ms
late_qat:enabled step:4249 scale:0.4998
step:4500/20000 train_loss:2.0545 train_time:437520ms step_avg:97.23ms
step:5000/20000 train_loss:2.0327 train_time:487703ms step_avg:97.54ms
swa:start step:5300
step:5500/20000 train_loss:1.9408 train_time:538082ms step_avg:97.83ms
step:5935/20000 val_loss:1.9027 val_bpb:1.1269 train_time:582082ms step_avg:98.08ms
stopping_early: wallclock_cap train_time:582082ms step:5935/20000
peak memory allocated: 26200 MiB reserved: 26810 MiB
ema:applying EMA weights (skipping diagnostic evals)
gptq:calibrating with training data...
gptq:calibrated 68 layers in 1.8s
Serialized model: 130432585 bytes
Code size: 97234 bytes
pruning:3.0% magnitude pruning applied
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
gptq_quantize: 66 GPTQ layers, 0 naive layers
mixed_precision: 33161216 int5 params, 0 int6 params
Serialized model int6+zstd: 15794806 bytes
Total submission size int6+zstd: 15892040 bytes
ttt: pre-compiling forward+backward kernels...
ttt: pre-compile done
final_int6_sliding_window val_loss:1.8996 val_bpb:1.1251 stride:64 eval_time:93488ms
final_int6_sliding_window_exact val_loss:1.89959945 val_bpb:1.12505277
TTT: epochs=4 lr=0.0001 freeze_first=2 chunk=131072 opt=adamw
TTT temperature: 0.98
PPM alpha: 0.85, Byte-weighted TTT: True
Logistic context mixer enabled: eta=0.1
Adaptive LR enabled: max_mult=3.0
ttt:start chunks=474 chunk_tokens=131072 windows=969088 stride=64 lr=0.0001 epochs=4 opt=adamw freeze_first=2
ttt:params unfrozen=5780500 frozen=27537480
Polyak averaging enabled: decay=0.998
ttt_train [1] seqs=64 start_train...
ttt_train [1] epoch=1/4 batches=8 ...
step done ep=1 bs=0 loss=2.3520
ttt_train [1] epoch=2/4 batches=8 ...
step done ep=2 bs=0 loss=2.3190
ttt_train [1] epoch=3/4 batches=8 ...
step done ep=3 bs=0 loss=2.3168
ttt_train [1] epoch=4/4 batches=8 ...
step done ep=4 bs=0 loss=2.3064
ttt_chunk [1/474] bpb=1.205733 time=0.8s
ttt_train [2] seqs=64 start_train...
ttt_train [2] epoch=1/4 batches=8 ...
step done ep=1 bs=0 loss=2.1268
ttt_train [2] epoch=2/4 batches=8 ...
step done ep=2 bs=0 loss=2.1215
ttt_train [2] epoch=3/4 batches=8 ...
step done ep=3 bs=0 loss=2.1144
ttt_train [2] epoch=4/4 batches=8 ...
step done ep=4 bs=0 loss=2.1137
ttt_chunk [2/474] bpb=1.145205 time=1.5s
ttt_train [3] seqs=64 start_train...
ttt_train [3] epoch=1/4 batches=8 ...
step done ep=1 bs=0 loss=2.0650
ttt_train [3] epoch=2/4 batches=8 ...
step done ep=2 bs=0 loss=2.0640
ttt_train [3] epoch=3/4 batches=8 ...
step done ep=3 bs=0 loss=2.0610
ttt_train [3] epoch=4/4 batches=8 ...
step done ep=4 bs=0 loss=2.0570
ttt_chunk [3/474] bpb=1.093631 time=2.2s
ttt_chunk [4/474] bpb=1.086978 time=2.9s
ttt_chunk [5/474] bpb=1.075246 time=3.6s
ttt_chunk [11/474] bpb=1.037795 time=7.9s
ttt_chunk [21/474] bpb=1.025422 time=15.0s
ttt_chunk [31/474] bpb=1.021364 time=22.1s
ttt_chunk [41/474] bpb=1.028090 time=29.2s
ttt_chunk [51/474] bpb=1.033988 time=36.3s
ttt_chunk [61/474] bpb=1.031631 time=43.4s
ttt_chunk [71/474] bpb=1.032527 time=50.5s
ttt_chunk [81/474] bpb=1.033201 time=57.6s
ttt_chunk [91/474] bpb=1.034949 time=64.7s
ttt_chunk [101/474] bpb=1.031552 time=71.8s
ttt_chunk [111/474] bpb=1.031581 time=78.9s
ttt_chunk [121/474] bpb=1.034658 time=86.0s
ttt_chunk [131/474] bpb=1.034982 time=93.1s
ttt_chunk [141/474] bpb=1.033970 time=100.2s
ttt_chunk [151/474] bpb=1.031617 time=107.3s
ttt_chunk [161/474] bpb=1.031627 time=114.4s
ttt_chunk [171/474] bpb=1.029770 time=121.5s
ttt_chunk [181/474] bpb=1.030056 time=128.6s
ttt_chunk [191/474] bpb=1.028446 time=135.7s
ttt_chunk [201/474] bpb=1.027227 time=142.8s
ttt_chunk [211/474] bpb=1.025596 time=149.9s
ttt_chunk [221/474] bpb=1.025510 time=157.0s
ttt_chunk [231/474] bpb=1.024716 time=164.1s
ttt_chunk [241/474] bpb=1.023499 time=171.2s
ttt_chunk [251/474] bpb=1.024507 time=178.3s
ttt_chunk [261/474] bpb=1.025349 time=185.4s
ttt_chunk [271/474] bpb=1.024312 time=192.5s
ttt_chunk [281/474] bpb=1.024426 time=199.6s
ttt_chunk [291/474] bpb=1.023800 time=206.7s
ttt_chunk [301/474] bpb=1.025070 time=213.8s
ttt_chunk [311/474] bpb=1.025766 time=220.9s
ttt_chunk [321/474] bpb=1.025697 time=228.0s
ttt_chunk [331/474] bpb=1.026198 time=235.1s
ttt_chunk [341/474] bpb=1.027086 time=242.2s
ttt_chunk [351/474] bpb=1.027310 time=249.3s
ttt_chunk [361/474] bpb=1.029683 time=256.4s
ttt_chunk [371/474] bpb=1.031113 time=263.5s
ttt_chunk [381/474] bpb=1.033803 time=270.6s
ttt_chunk [391/474] bpb=1.036776 time=277.7s
ttt_chunk [401/474] bpb=1.039255 time=284.8s
ttt_chunk [411/474] bpb=1.041506 time=291.9s
ttt_chunk [421/474] bpb=1.044761 time=299.0s
ttt_chunk [431/474] bpb=1.045237 time=306.1s
ttt_chunk [441/474] bpb=1.046638 time=313.2s
ttt_chunk [451/474] bpb=1.047686 time=320.3s
ttt_chunk [461/474] bpb=1.049616 time=327.4s
ttt_chunk [471/474] bpb=1.051310 time=334.5s
ttt_chunk [474/474] bpb=1.051535 time=336.0s
ttt:done val_loss=1.768268 val_bpb=1.047270 elapsed=336.0s
final_int6_ttt val_loss:1.7683 val_bpb:1.0473 stride:64 eval_time:336499ms
final_int6_ttt_exact val_loss:1.76826750 val_bpb:1.04727039