forked from karpathy/autoresearch
Experiment: Batch size and compilation optimization #8
Labels: experiment (Hyperparameter or architecture experiment), priority: high (High impact, run first), size: M (Medium, 5-15 experiments or 1-3 hours)
Description
Objective
Optimize batch size and torch.compile settings to maximize learning per 5-minute budget.
Key Insight
In a fixed time budget, batch size controls the tradeoff between step count and gradient quality. A smaller batch yields more optimizer steps (more learning signal) but noisier gradients. The optimum depends on the model's critical batch size.
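A back-of-envelope illustration of the tradeoff (the throughput figure is a hypothetical placeholder, and it assumes tokens/sec stays constant across batch sizes, ignoring per-step overhead):

```python
# In a fixed time budget, halving TOTAL_BATCH_SIZE roughly doubles the
# optimizer step count (idealized: constant tokens/sec, no per-step overhead).
BUDGET_SEC = 5 * 60        # the 5-minute training budget
TOKENS_PER_SEC = 500_000   # hypothetical sustained throughput

for total_batch in (2**17, 2**18, 2**19, 2**20):
    steps = int(BUDGET_SEC * TOKENS_PER_SEC / total_batch)
    print(f"batch {total_batch:>9,} tokens -> ~{steps:,} steps")
```

Whether the extra steps beat the noisier gradients is exactly what BS-1/BS-2/BS-3 below measure.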
Current Config
| Parameter | Value |
|---|---|
| TOTAL_BATCH_SIZE | 2^19 (524,288 tokens) |
| DEVICE_BATCH_SIZE | 128 |
| grad_accum_steps | 2 |
| torch.compile | dynamic=False |
| matmul_precision | "high" |
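The same settings as a config fragment (variable names taken from the table; the sequence length of 2048 is inferred from the BS-2 note below):

```python
TOTAL_BATCH_SIZE = 2**19    # 524,288 tokens per optimizer step
DEVICE_BATCH_SIZE = 128     # sequences per micro-batch
SEQ_LEN = 2048              # inferred from the BS-2 note (64 * 2048 = 2^17)

# grad_accum_steps is not a free knob -- it falls out of the other three:
grad_accum_steps = TOTAL_BATCH_SIZE // (DEVICE_BATCH_SIZE * SEQ_LEN)
assert grad_accum_steps == 2  # matches the table
```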
Batch Size Experiments (HIGH priority)
| ID | TOTAL_BATCH_SIZE | grad_accum | Priority | Rationale |
|---|---|---|---|---|
| BS-1 | 2^18 (262K) | 1 | HIGH | 2x more steps, no accumulation overhead |
| BS-2 | 2^17 (131K) | 1 (DEVICE=64) | HIGH | 4x more steps, needs DEVICE_BATCH_SIZE=64 |
| BS-3 | 2^20 (1M) | 4 | MEDIUM | Tests opposite direction — likely worse |
Run BS-1 first — eliminates grad accumulation entirely, giving both more steps AND less overhead.
Important Note for BS-2

```python
TOTAL_BATCH_SIZE = 2**17
DEVICE_BATCH_SIZE = 64  # Must reduce: 64 * 2048 = 131072 = 2^17
```

DEVICE_BATCH_SIZE Experiments (MEDIUM priority)
| ID | DEVICE_BATCH_SIZE | Effect | Priority |
|---|---|---|---|
| DB-1 | 64 | More grad accumulation, less VRAM | MEDIUM |
| DB-2 | 256 | No grad accumulation, may OOM | MEDIUM |
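The "Effect" column falls out of simple arithmetic (seq_len of 2048 assumed, per the BS-2 note):

```python
SEQ_LEN = 2048  # assumed sequence length, per the BS-2 note

def grad_accum_steps(total_batch_tokens: int, device_batch_seqs: int) -> int:
    """Micro-batches accumulated per optimizer step."""
    tokens_per_micro_batch = device_batch_seqs * SEQ_LEN
    assert total_batch_tokens % tokens_per_micro_batch == 0
    return total_batch_tokens // tokens_per_micro_batch

print(grad_accum_steps(2**19, 64))   # DB-1 -> 4 (more accumulation, less VRAM)
print(grad_accum_steps(2**19, 256))  # DB-2 -> 1 (no accumulation, may OOM)
```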
Compilation Experiments (MEDIUM priority)
| ID | Change | Priority | Expected |
|---|---|---|---|
| CMP-1 | fullgraph=True | MEDIUM | 5-15% throughput gain if graph breaks existed |
| CMP-2 | mode="max-autotune" | MEDIUM | 10-20% steady-state gain, 30-60s warmup cost |
| CMP-3 | Both combined | MEDIUM | Only if CMP-1 or CMP-2 help individually |
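One way to stage the three variants is as kwargs for the existing torch.compile call (the baseline dynamic=False comes from the current config; how the training script selects a variant is an assumption):

```python
# Sketch: per-experiment kwargs for torch.compile. Wiring is assumed to
# look like `model = torch.compile(model, **COMPILE_VARIANTS[name])`.
COMPILE_VARIANTS = {
    "baseline": dict(dynamic=False),
    "CMP-1": dict(dynamic=False, fullgraph=True),       # error out on graph breaks
    "CMP-2": dict(dynamic=False, mode="max-autotune"),  # slower warmup, faster steady state
    "CMP-3": dict(dynamic=False, fullgraph=True, mode="max-autotune"),
}
```

Note that fullgraph=True fails loudly if the model has graph breaks, which is itself useful diagnostic signal even if CMP-1 shows no speedup.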
Other Throughput Experiments (LOW priority)
| ID | Change | Priority |
|---|---|---|
| MP-1 | set_float32_matmul_precision("medium") | LOW |
| GC-1 | Remove all GC collection (currently every 5000 steps) | LOW |
| CKPT-1 | Gradient checkpointing (enabler for larger models) | MEDIUM |
Gradient checkpointing trades ~33% throughput for ~30-50% VRAM savings. Only valuable if it enables a larger model that compensates.
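A quick breakeven check using the throughput figure quoted above:

```python
# CKPT-1 breakeven: with ~33% of throughput lost to activation
# recomputation, the larger model enabled by the freed VRAM must learn
# about 1.5x more per token just to break even.
THROUGHPUT_COST = 0.33  # fraction of tokens/sec lost, per the note above

breakeven_gain = 1 / (1 - THROUGHPUT_COST)
print(f"required per-token learning gain: {breakeven_gain:.2f}x")
```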
Execution Order
- BS-1 (highest expected value)
- BS-2 (if BS-1 improves, test further)
- BS-3 (confirm direction)
- CMP-1 and CMP-2
- DB-2 (throughput check)
- Lower priority items
Decision Gate
After batch size winner is known, re-test compilation settings (shapes change → different compile behavior).
🤖 Generated with Claude Code