Add device abstraction for multi-platform training (MPS, CUDA, CPU) #10

@Jason-Adam

Description

Summary

train.py and prepare.py are hardcoded to CUDA in ~7 locations, making it impossible to run experiments locally on Apple Silicon (MPS) or CPU. This prevents:

  • Local dev iteration on Mac without a GPU
  • Quick smoke-testing of code changes before pushing to remote
  • Running the monitor dashboard alongside a local training run

Proposal

Introduce a device.py module that resolves the target device at import time via the SKYLAB_DEVICE environment variable (default: auto-detect). Replace all hardcoded CUDA references in train.py and prepare.py with imports from this module.
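A minimal sketch of what device.py could look like, assuming PyTorch's standard device and backend APIs (the exact helper signatures here are assumptions beyond the names listed in this issue):

```python
import os

import torch


def resolve_device() -> torch.device:
    """Resolve the target device from SKYLAB_DEVICE, else auto-detect."""
    requested = os.environ.get("SKYLAB_DEVICE", "").lower()
    if requested:
        return torch.device(requested)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# Resolved once at import time, as proposed above.
DEVICE = resolve_device()


def synchronize() -> None:
    """Block until all queued kernels on DEVICE have finished (no-op on CPU)."""
    if DEVICE.type == "cuda":
        torch.cuda.synchronize()
    elif DEVICE.type == "mps":
        torch.mps.synchronize()


def max_memory_mb() -> float:
    """Peak (CUDA) or current (MPS) memory allocated on DEVICE, in MB."""
    if DEVICE.type == "cuda":
        return torch.cuda.max_memory_allocated() / 1e6
    if DEVICE.type == "mps":
        return torch.mps.current_allocated_memory() / 1e6
    return 0.0


def seed(n: int) -> None:
    """Seed the RNGs on every relevant backend."""
    torch.manual_seed(n)
    if DEVICE.type == "cuda":
        torch.cuda.manual_seed_all(n)
```

Call sites in train.py would then do `from device import DEVICE, synchronize` instead of referencing `torch.cuda` directly.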

Targets to support

| Target | How | Status |
| --- | --- | --- |
| Local CUDA (single GPU) | uv run train.py | Works today |
| Remote single GPU | bash remote_run.sh | Works today |
| Remote multi-GPU | num_gpus=N in remote.toml | Works today |
| Mac MPS | SKYLAB_DEVICE=mps uv run train.py | New |
| CPU | SKYLAB_DEVICE=cpu uv run train.py | New |

Key changes

  • New device.py (~50 lines) — device resolution, synchronize(), max_memory_mb(), seed() helpers
  • train.py — replace 7 CUDA-specific call sites with device.py imports
  • prepare.py — parameterize 3 hardcoded CUDA references (pinned memory, GPU buffer, eval device)
  • FA3 fallback to F.scaled_dot_product_attention on non-CUDA backends
  • Makefile — make train-mps, make train-cpu convenience targets
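The FA3 fallback above might be shaped roughly like this (a sketch, not the implementation: the `flash_attn_func` import path and the wrapper name are assumptions about how FA3 is vendored in this repo):

```python
import torch
import torch.nn.functional as F


def attention(q, k, v, is_causal=True):
    """Use Flash Attention 3 on CUDA; fall back to native SDPA elsewhere.

    The SDPA path has no sliding-window support, matching the MPS
    limitation noted below. Inputs are (batch, heads, seq, head_dim).
    """
    if q.device.type == "cuda":
        try:
            # Hypothetical FA3 entry point; depends on how FA3 is packaged.
            from flash_attn_interface import flash_attn_func
            return flash_attn_func(q, k, v, causal=is_causal)
        except ImportError:
            pass
    # PyTorch's built-in scaled dot-product attention runs on CUDA,
    # MPS, and CPU alike.
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```

Keeping the branch inside a single `attention()` wrapper means train.py never needs to know which backend it is running on.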

MPS limitations

  • No Flash Attention 3 (falls back to native SDPA, no sliding window)
  • No pinned memory
  • No multi-device distributed
  • Results will differ from CUDA — MPS is for dev iteration, not experiments of record

Detailed Plan

See ~/.planning/plans/2026-03-28-skylab-device-abstraction.md

🤖 Generated with Claude Code
