## EMA-GPU + Multi-Order N-gram Backoff + Pre-Enrichment + XSA

**val_bpb: 0.2995** (3-seed mean, std 0.0016) | 14.94 MB | 8xH100 SXM, 600s

---

### 3-Seed Results

| Seed | Steps | Sliding BPB | val_bpb | Artifact |
|---|---|---|---|---|
| 1337 | 9,268 | 1.1478 | 0.3001 | 14,942,971 |
| 42 | 9,318 | 1.1468 | 0.2977 | 14,922,769 |
| 3011 | 9,322 | 1.1463 | 0.3008 | 14,939,305 |
| **Mean** | — | **1.1470** | **0.2995** | — |
| **Std** | — | — | **0.0016** | — |

---

### Architecture

10L/512d U-Net, 25.25M params. GQA 8H/4KV, MLP 3x (1536 hidden), tied embeddings, logit softcap=30.0.

- **GELU Pre-Enrichment** (512→768→512): Wider nonlinear transformation before transformer blocks. Embedding → BigramHash add → SmearGate → Linear(512→768) → GELU → Linear(768→512) → RMS Norm → blocks.
- **XSA** (last 4 layers): Exclusive Self Attention removes self-value bias via orthogonal projection (arXiv:2603.09078, GQA-aware implementation from PR #265 @unnir). Zero parameters.
- **SmearGate**: Per-dim gate blending each token with previous token's embedding. F.pad for efficiency.
- **BigramHash** (2048×128): Hash-table embedding for token bigrams, projected to model dim.
- **U-Net skip connections**: Encoder-decoder with learnable skip weights.
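The stem described above (SmearGate plus GELU pre-enrichment) can be sketched roughly as follows. This is a minimal sketch in plain PyTorch; the gate initialization and the exact blend form are assumptions, and BigramHash / XSA are omitted:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SmearGate(nn.Module):
    """Per-dim sigmoid gate blending each token with the previous token's embedding."""
    def __init__(self, dim):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(dim))  # init is an assumption

    def forward(self, x):                      # x: (B, T, D)
        prev = F.pad(x, (0, 0, 1, 0))[:, :-1]  # shift right: position t sees token t-1
        return x + torch.sigmoid(self.gate) * prev  # additive blend (exact form assumed)

class PreEnrichment(nn.Module):
    """Wider GELU MLP (512 -> 768 -> 512) applied once, before the transformer blocks."""
    def __init__(self, dim=512, hidden=768):
        super().__init__()
        self.up = nn.Linear(dim, hidden)
        self.down = nn.Linear(hidden, dim)

    def forward(self, x):
        return self.down(F.gelu(self.up(x)))

def stem(x, smear, enrich, norm):
    """Embedding (+ BigramHash add, omitted) -> SmearGate -> enrichment -> RMS norm."""
    return norm(enrich(smear(x)))
```

The `F.pad`-based shift avoids materializing a rolled copy of the sequence, matching the "F.pad for efficiency" note above.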

Training: Muon+AdamW, WD=0.04, matrix_lr=0.025, scalar_lr=0.025, warmdown=3500 iters, batch=524K tokens, seq=2048. EMA decay=0.997. Int6 QAT + lzma (preset=6).

---

### EMA on GPU (36% lower step time) — novel contribution

The EMA state stays on the GPU for the whole run, replacing the synchronous GPU→CPU copy every step; it is moved to the CPU only once, at the end, for serialization. To my knowledge, this optimization is not used in other submissions.

Step time: **64.4ms** (vs 101ms before), enabling **9,312 steps** in 600s vs ~5,900 before: 57% more gradient updates from the same training time.
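A minimal sketch of the update, assuming a flat list of parameter tensors (all names here are illustrative):

```python
import torch

@torch.no_grad()
def ema_update(ema_params, model_params, decay=0.997):
    """EMA update entirely on the GPU: ema <- decay * ema + (1 - decay) * param.
    No per-step GPU->CPU transfer; all state stays device-resident."""
    for e, p in zip(ema_params, model_params):
        e.mul_(decay).add_(p, alpha=1.0 - decay)

def ema_to_cpu(ema_params):
    """Single transfer at the end of training, for serialization only."""
    return [e.detach().cpu() for e in ema_params]
```

The per-step work is a handful of fused in-place kernels instead of a blocking device-to-host copy, which is where the step-time saving comes from.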

---

### Multi-Order N-gram Backoff (score-first, backward-looking)

Multi-order n-gram backoff with an entropy-adaptive alpha, applied during sliding-window eval. Concept from @deanbrr (PR #659), further developed in PR #706 (@newjordan) and PR #727 (@Asukabot0).

**Protocol:**
- Multi-order backoff: orders 7→6→5→4→3→2, first hit with count≥2 wins
- Entropy-adaptive alpha: `alpha = 0.05 + 0.55 * sigmoid(2 * (H - 4.0))`
- High model entropy → trust n-gram more; low entropy → trust model
- Cache built from already-scored tokens only (backward-looking)
- Score-first: cache updated AFTER segment scoring
- Dual-array hash scheme: separate context count and pair count arrays per order (4M buckets each)
- Per-GPU independent cache, no cross-GPU sync
- Hash tables precomputed for all orders in single pass
- Integrated into sliding window eval (single pass)
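The backoff and mixing steps above can be sketched in plain Python. This is a simplified dict-based sketch; the actual submission uses fixed-size hash arrays (4M buckets per order), and all names here are illustrative:

```python
import math
from collections import defaultdict

def build_counts(tokens, max_order=7):
    """Context and pair counts per order, built only from already-scored
    tokens (score-first, backward-looking)."""
    ctx = {n: defaultdict(int) for n in range(2, max_order + 1)}
    pair = {n: defaultdict(int) for n in range(2, max_order + 1)}
    for n in range(2, max_order + 1):
        for i in range(n - 1, len(tokens)):
            c = tuple(tokens[i - n + 1:i])
            ctx[n][c] += 1
            pair[n][c + (tokens[i],)] += 1
    return ctx, pair

def backoff_prob(ctx, pair, history, token, max_order=7, min_count=2):
    """Try orders 7 -> 2; the first order whose pair count >= min_count wins."""
    for n in range(max_order, 1, -1):
        if len(history) < n - 1:
            continue
        c = tuple(history[-(n - 1):])
        if pair[n][c + (token,)] >= min_count:
            return pair[n][c + (token,)] / ctx[n][c]
    return None  # no order fired: fall back to the model alone

def adaptive_alpha(entropy_bits):
    """alpha = 0.05 + 0.55 * sigmoid(2 * (H - 4.0)): high model entropy -> trust n-gram."""
    return 0.05 + 0.55 / (1.0 + math.exp(-2.0 * (entropy_bits - 4.0)))

def mix(p_model, p_ngram, alpha):
    return p_model if p_ngram is None else (1 - alpha) * p_model + alpha * p_ngram
```

At H = 4 bits the sigmoid is at its midpoint, so alpha = 0.05 + 0.55/2 = 0.325; it saturates toward 0.05 for confident predictions and 0.60 for high-entropy ones.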

**Compliance:**
- Score-first, backward-looking: n-gram counts built from previously scored tokens only
- No oracle selection: alpha depends solely on model's own entropy, never on ground-truth
- No cross-GPU sync: each GPU maintains its own independent cache

**Improvement (seed 1337):** sliding-window BPB 1.1478 → 0.3001, a **0.848 BPB** reduction.

#### Pre-Enrichment Confidence Modulation

Uses the magnitude of the pre-enrichment layer's transformation (the delta between its input and output) as a confidence signal: a high delta means the model is uncertain about this context, so trust the n-gram more; a low delta means the model is confident, so trust the model. The entropy-adaptive alpha is scaled by `(0.5 + 1.0 * pe_conf)`.
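In sketch form (the normalization of `pe_conf` to [0, 1] is an assumption; the `(0.5 + 1.0 * pe_conf)` factor is from the description above):

```python
import math

def modulated_alpha(entropy_bits, pe_conf):
    """Scale the entropy-adaptive alpha by the pre-enrichment confidence signal.

    pe_conf in [0, 1]: normalized magnitude of the pre-enrichment delta
    ||enrich(x) - x|| (normalization scheme assumed). Large delta -> uncertain
    context -> larger alpha -> trust the n-gram cache more."""
    base = 0.05 + 0.55 / (1.0 + math.exp(-2.0 * (entropy_bits - 4.0)))
    return base * (0.5 + 1.0 * pe_conf)
```

At `pe_conf = 0.5` the modulation factor is 1.0 and the base entropy-adaptive alpha passes through unchanged; the factor ranges from 0.5× to 1.5×.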

---

### Toggleable Features (default OFF, not used in this submission)

- `VALUE_RESIDUAL=1` — Layer-0 V mixed into all subsequent layers via learned sigmoid gates
- `GATED_ATTN=1` — Per-head sigmoid gates on attention output

---

### Reproduce

```bash
python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

All defaults baked in. No env vars needed. 8xH100 SXM, 600s training + ~204s eval.

---

### Credits
- Muon optimizer — modded-nanogpt baseline (kellerjordan)
- SmearGate + BigramHash — PR #65 (@aquariouseworkman)
- XSA — arXiv:2603.09078; GQA-aware PR #265 (@unnir)
- EMA + GPTQ-lite + warmdown tuning — PR #414 (@signalrush)
- N-gram eval cache — concept PR #659 (@deanbrr); fixed 5-gram PR #706 (@newjordan); multi-order entropy-adaptive PR #727 (@Asukabot0)
- Shared GPU n-gram cache — PR #796 (@Robby955); chunk-synchronized PR #800 (@newjordan); PR #809 (@AayushBaniya2006)
- Per-order adaptive alpha — PR #798 (@travispchen); Cubric scaling PR #800 (@newjordan)
- Overtone init — modded-nanogpt baseline
- GELU Pre-Enrichment — original to this submission
- EMA on GPU — original to this submission
- Pre-Enrichment Confidence Modulation — original to this submission

### Included Files

- `train_gpt.py` — standalone training script with all modifications
- `train.log` — full 8xH100 training + eval log (seed 1337)
- `submission.json` — leaderboard metadata
- `README.md` — this file
{
"author": "Idanr",
"github_id": "idan3011",
"name": "Two-Phase Shared N-gram Cache + EMA-GPU + Pre-Enrichment + XSA",
"blurb": "Two-phase eval: distributed GPU forward + global sequential n-gram cache (orders 2-11). Per-order adaptive alpha + PE confidence modulation. EMA on GPU (64.7ms/step). 10L 512d. 3-seed mean 0.2995 (std 0.0016).",
"date": "2026-03-26T07:45:00Z",
"val_loss": 1.93793804,
"val_bpb": 0.29953465,
"val_bpb_seeds": [0.30008709, 0.29765751, 0.30079936],
"val_bpb_std": 0.0016,
"pre_quant_val_loss": 1.9663,
"pre_quant_val_bpb": 1.1646,
"step_stop": 9268,
"wallclock_seconds": 600.031,
"eval_time_seconds": 204.138,
"bytes_total": 14942971,
"bytes_model_int6_lzma": 14878748,
"bytes_code": 66581
}
W0326 02:39:19.172000 34413 torch/distributed/run.py:803]
W0326 02:39:19.172000 34413 torch/distributed/run.py:803] *****************************************
W0326 02:39:19.172000 34413 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0326 02:39:19.172000 34413 torch/distributed/run.py:803] *****************************************
logs/0d771539-26db-4427-b5a8-0a4c24bd56ad.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:25254992
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=True flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9319 val_bpb:4.1055 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9318 train_time:62ms step_avg:61.75ms
step:2/20000 train_loss:7.1516 train_time:121ms step_avg:60.53ms
step:3/20000 train_loss:6.1791 train_time:185ms step_avg:61.59ms
step:4/20000 train_loss:6.4189 train_time:249ms step_avg:62.18ms
step:5/20000 train_loss:6.5862 train_time:313ms step_avg:62.55ms
step:6/20000 train_loss:6.2277 train_time:377ms step_avg:62.78ms
step:7/20000 train_loss:5.4960 train_time:441ms step_avg:62.97ms
step:8/20000 train_loss:5.2973 train_time:505ms step_avg:63.10ms
step:9/20000 train_loss:5.0005 train_time:569ms step_avg:63.20ms
step:10/20000 train_loss:4.8514 train_time:633ms step_avg:63.30ms
step:200/20000 train_loss:2.7511 train_time:12872ms step_avg:64.36ms
step:400/20000 train_loss:2.2579 train_time:25781ms step_avg:64.45ms
step:600/20000 train_loss:2.4713 train_time:38736ms step_avg:64.56ms
step:800/20000 train_loss:2.2316 train_time:51722ms step_avg:64.65ms
step:1000/20000 train_loss:2.3340 train_time:64727ms step_avg:64.73ms
step:1000/20000 val_loss:2.2855 val_bpb:1.3536 train_time:64739ms step_avg:64.74ms
step:1200/20000 train_loss:2.3620 train_time:77744ms step_avg:64.79ms
step:1400/20000 train_loss:2.3964 train_time:90750ms step_avg:64.82ms
step:1600/20000 train_loss:2.0689 train_time:103750ms step_avg:64.84ms
step:1800/20000 train_loss:2.1729 train_time:116742ms step_avg:64.86ms
step:2000/20000 train_loss:2.2158 train_time:129716ms step_avg:64.86ms
step:2000/20000 val_loss:2.1975 val_bpb:1.3015 train_time:129728ms step_avg:64.86ms
step:2200/20000 train_loss:2.0324 train_time:142686ms step_avg:64.86ms
step:2400/20000 train_loss:2.1624 train_time:155641ms step_avg:64.85ms
step:2600/20000 train_loss:2.3841 train_time:168596ms step_avg:64.84ms
step:2800/20000 train_loss:2.2002 train_time:181543ms step_avg:64.84ms
step:3000/20000 train_loss:2.1908 train_time:194474ms step_avg:64.82ms
step:3000/20000 val_loss:2.1539 val_bpb:1.2757 train_time:194486ms step_avg:64.83ms
step:3200/20000 train_loss:2.1563 train_time:207406ms step_avg:64.81ms
step:3400/20000 train_loss:2.1250 train_time:220338ms step_avg:64.81ms
step:3600/20000 train_loss:2.0721 train_time:233268ms step_avg:64.80ms
step:3800/20000 train_loss:2.1786 train_time:246196ms step_avg:64.79ms
step:4000/20000 train_loss:2.1419 train_time:259115ms step_avg:64.78ms
step:4000/20000 val_loss:2.1367 val_bpb:1.2655 train_time:259127ms step_avg:64.78ms
step:4200/20000 train_loss:2.1372 train_time:272101ms step_avg:64.79ms
step:4400/20000 train_loss:2.0839 train_time:285022ms step_avg:64.78ms
step:4600/20000 train_loss:1.9446 train_time:297946ms step_avg:64.77ms
step:4800/20000 train_loss:2.2371 train_time:310856ms step_avg:64.76ms
step:5000/20000 train_loss:1.9905 train_time:323763ms step_avg:64.75ms
step:5000/20000 val_loss:2.1285 val_bpb:1.2606 train_time:323775ms step_avg:64.76ms
step:5200/20000 train_loss:2.1516 train_time:336678ms step_avg:64.75ms
step:5400/20000 train_loss:2.1670 train_time:349585ms step_avg:64.74ms
step:5600/20000 train_loss:2.1609 train_time:362500ms step_avg:64.73ms
step:5800/20000 train_loss:2.1178 train_time:375416ms step_avg:64.73ms
step:6000/20000 train_loss:2.1963 train_time:388331ms step_avg:64.72ms
step:6000/20000 val_loss:2.1194 val_bpb:1.2552 train_time:388343ms step_avg:64.72ms
step:6200/20000 train_loss:2.0618 train_time:401239ms step_avg:64.72ms
step:6400/20000 train_loss:2.1328 train_time:414152ms step_avg:64.71ms
step:6600/20000 train_loss:2.0839 train_time:427067ms step_avg:64.71ms
step:6800/20000 train_loss:2.1327 train_time:439971ms step_avg:64.70ms
step:7000/20000 train_loss:2.1739 train_time:452890ms step_avg:64.70ms
step:7000/20000 val_loss:2.0766 val_bpb:1.2299 train_time:452903ms step_avg:64.70ms
step:7200/20000 train_loss:2.1442 train_time:465802ms step_avg:64.69ms
step:7400/20000 train_loss:2.0575 train_time:478715ms step_avg:64.69ms
step:7600/20000 train_loss:1.9264 train_time:491637ms step_avg:64.69ms
step:7800/20000 train_loss:2.0683 train_time:504556ms step_avg:64.69ms
step:8000/20000 train_loss:2.0304 train_time:517550ms step_avg:64.69ms
step:8000/20000 val_loss:2.0324 val_bpb:1.2037 train_time:517563ms step_avg:64.70ms
step:8200/20000 train_loss:2.1001 train_time:530461ms step_avg:64.69ms
step:8400/20000 train_loss:2.0298 train_time:543436ms step_avg:64.69ms
step:8600/20000 train_loss:2.0308 train_time:556429ms step_avg:64.70ms
step:8800/20000 train_loss:1.9809 train_time:569549ms step_avg:64.72ms
step:9000/20000 train_loss:1.8848 train_time:582572ms step_avg:64.73ms
step:9000/20000 val_loss:1.9773 val_bpb:1.1711 train_time:582573ms step_avg:64.73ms
step:9200/20000 train_loss:1.9494 train_time:595634ms step_avg:64.74ms
step:9268/20000 val_loss:1.9663 val_bpb:1.1646 train_time:600031ms step_avg:64.74ms
stopping_early: wallclock_cap train_time:600031ms step:9268/20000
peak memory allocated: 13058 MiB reserved: 13280 MiB
swa: averaging 14 checkpoints on top of EMA
ema: loading weights
Serialized model: 99486509 bytes
Code size: 64223 bytes
Total submission size: 99550732 bytes
Serialized model int6+lzma: 14878748 bytes (payload:25993024 raw_torch:26045291 payload_ratio:3.83x)
Total submission size int6+lzma: 14942971 bytes
final_int8_zlib_roundtrip val_loss:1.9738 val_bpb:1.1690 eval_time:2054ms
final_int8_zlib_roundtrip_exact val_loss:1.97382834 val_bpb:1.16901232
final_sliding_window sliding_bpb:1.1478 val_bpb:0.3001 eval_time:204138ms
final_sliding_window_exact sliding_bpb:1.14775606 val_bpb:0.30008709