forked from karpathy/autoresearch
Add device abstraction for multi-platform training (MPS, CUDA, CPU) #10
Open
Summary
train.py and prepare.py are hardcoded to CUDA in ~7 locations, making it impossible to run experiments locally on Apple Silicon (MPS) or on CPU. This prevents:
- Local dev iteration on Mac without a GPU
- Quick smoke-testing of code changes before pushing to remote
- Running the monitor dashboard alongside a local training run
Proposal
Introduce a device.py module that resolves the target device at import time via the SKYLAB_DEVICE environment variable (default: auto-detect). Replace all hardcoded CUDA references in train.py and prepare.py with imports from this module.
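A minimal sketch of what the proposed `device.py` could look like, assuming PyTorch. The function names beyond those listed under "Key changes" (`resolve_device`, the `DEVICE` constant) are illustrative, not confirmed API:

```python
# device.py (sketch) -- resolve the target device once at import time.
# SKYLAB_DEVICE comes from the proposal; everything else is an assumption.
import os

import torch


def resolve_device() -> torch.device:
    """Resolve the target device from SKYLAB_DEVICE (default: auto-detect)."""
    requested = os.environ.get("SKYLAB_DEVICE", "auto")
    if requested != "auto":
        return torch.device(requested)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


DEVICE = resolve_device()  # resolved at import time, per the proposal


def synchronize() -> None:
    """Block until all queued kernels on DEVICE have finished (no-op on CPU)."""
    if DEVICE.type == "cuda":
        torch.cuda.synchronize()
    elif DEVICE.type == "mps":
        torch.mps.synchronize()


def max_memory_mb() -> float:
    """Peak/current allocated device memory in MB (0.0 on CPU)."""
    if DEVICE.type == "cuda":
        return torch.cuda.max_memory_allocated() / 1e6
    if DEVICE.type == "mps":
        # MPS exposes current, not peak, allocation.
        return torch.mps.current_allocated_memory() / 1e6
    return 0.0
```

Resolving at import time keeps call sites in train.py down to plain imports (`from device import DEVICE, synchronize`), at the cost that SKYLAB_DEVICE must be set before the first import.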
Targets to support
| Target | How | Status |
|---|---|---|
| Local CUDA (single GPU) | `uv run train.py` | Works today |
| Remote single GPU | `bash remote_run.sh` | Works today |
| Remote multi-GPU | `num_gpus=N` in remote.toml | Works today |
| Mac MPS | `SKYLAB_DEVICE=mps uv run train.py` | New |
| CPU | `SKYLAB_DEVICE=cpu uv run train.py` | New |
Key changes
- New `device.py` (~50 lines) — device resolution, `synchronize()`, `max_memory_mb()`, `seed()` helpers
- train.py — replace 7 CUDA-specific call sites with device.py imports
- prepare.py — parameterize 3 hardcoded CUDA references (pinned memory, GPU buffer, eval device)
- FA3 fallback to `F.scaled_dot_product_attention` on non-CUDA backends
- Makefile — `make train-mps`, `make train-cpu` convenience targets
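The FA3 fallback item above could be sketched roughly as follows. The `flash_attn_func` import path and the surrounding function are assumptions about this repo, not confirmed code; only the SDPA fallback is taken directly from the proposal:

```python
# Sketch of the proposed attention fallback. Assumes (batch, heads, seq, dim)
# layout; flash_attn_func and the attention() wrapper are hypothetical names.
import torch
import torch.nn.functional as F


def attention(q, k, v, device_type: str) -> torch.Tensor:
    """Use FA3 on CUDA; fall back to native SDPA elsewhere (no sliding window)."""
    if device_type == "cuda":
        try:
            from flash_attn_interface import flash_attn_func  # FA3 kernel
            return flash_attn_func(q, k, v, causal=True)
        except ImportError:
            pass  # FA3 not installed; fall through to the portable path
    # Portable path: runs on MPS and CPU, but without sliding-window support,
    # matching the limitation noted below.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

Keeping the fallback inside one wrapper means train.py call sites stay backend-agnostic; only this function needs to know which kernel is available.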
MPS limitations
- No Flash Attention 3 (falls back to native SDPA, no sliding window)
- No pinned memory
- No multi-device distributed
- Results will differ from CUDA — MPS is for dev iteration, not experiments of record
Detailed Plan
See ~/.planning/plans/2026-03-28-skylab-device-abstraction.md
🤖 Generated with Claude Code