Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions device_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""Centralized cross-platform device selection for CorridorKey."""

import json
import logging
import os
import subprocess
import sys
from dataclasses import dataclass

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -110,6 +114,159 @@ def resolve_device(requested: str | None = None) -> str:
return device


@dataclass
class GPUInfo:
    """Information about a single GPU, as reported by one of the
    enumeration helpers (nvidia-smi, amd-smi/rocm-smi, registry, or torch)."""

    # Zero-based device index as reported by the probing tool.
    index: int
    # Human-readable device name (market name or driver description).
    name: str
    # Total VRAM in GiB (probes divide MiB values by 1024).
    vram_total_gb: float
    # Free VRAM in GiB; probes that cannot measure free memory report the total here.
    vram_free_gb: float


def _enumerate_nvidia() -> list[GPUInfo] | None:
    """Enumerate NVIDIA GPUs via nvidia-smi.

    Returns:
        A (possibly empty) list of GPUInfo parsed from nvidia-smi CSV output,
        or None when nvidia-smi is missing, hangs, or exits non-zero.
    """
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=index,name,memory.total,memory.free", "--format=csv,nounits,noheader"],
            capture_output=True,
            text=True,
            timeout=5,
        )
    except (subprocess.TimeoutExpired, OSError):
        # Not installed (FileNotFoundError is an OSError), not executable, or hung.
        return None
    if result.returncode != 0:
        return None
    gpus: list[GPUInfo] = []
    for line in result.stdout.strip().split("\n"):
        parts = [p.strip() for p in line.split(",")]
        if len(parts) < 4:
            continue
        try:
            gpus.append(
                GPUInfo(
                    index=int(parts[0]),
                    name=parts[1],
                    # With --format=csv,nounits memory is reported in MiB; convert to GiB.
                    vram_total_gb=float(parts[2]) / 1024,
                    vram_free_gb=float(parts[3]) / 1024,
                )
            )
        except ValueError:
            # nvidia-smi can emit "[N/A]" / "[Not Supported]" for a field;
            # skip that line instead of aborting the whole enumeration.
            continue
    return gpus


def _enumerate_amd() -> list[GPUInfo] | None:
    """Enumerate AMD GPUs via amd-smi or rocm-smi. Returns None if unavailable.

    Probe order:
      1. ``amd-smi static --json`` (ROCm 6.0+).
      2. ``rocm-smi`` CSV output (legacy ROCm).
      3. Windows display-adapter registry entries (no ROCm tooling needed).

    Probes 1 and 3 have no free-memory reading, so they report total VRAM
    as free.
    """
    # Try amd-smi (ROCm 6.0+)
    try:
        result = subprocess.run(["amd-smi", "static", "--json"], capture_output=True, text=True, timeout=10)
        if result.returncode == 0:
            data = json.loads(result.stdout)
            gpus: list[GPUInfo] = []
            # NOTE(review): assumes the top-level JSON is a list of per-GPU
            # objects; confirm against the amd-smi version actually deployed.
            for i, gpu in enumerate(data):
                try:
                    name = gpu.get("asic", {}).get("market_name", f"AMD GPU {i}")
                    vram_info = gpu.get("vram", {})
                    total_mb = vram_info.get("size", {}).get("value", 0)
                    # Presumably the size value is in MiB (divided by 1024 to
                    # get GiB) — verify against amd-smi output docs.
                    total_gb = float(total_mb) / 1024 if total_mb else 0
                    # No free-memory field in static output; report total as free.
                    gpus.append(GPUInfo(index=i, name=name, vram_total_gb=total_gb, vram_free_gb=total_gb))
                except (KeyError, TypeError, ValueError):
                    # Skip malformed per-GPU entries rather than failing the probe.
                    pass
            if gpus:
                return gpus
    except (FileNotFoundError, subprocess.TimeoutExpired, json.JSONDecodeError):
        pass

    # Fallback: rocm-smi (legacy)
    try:
        result = subprocess.run(
            ["rocm-smi", "--showid", "--showmeminfo", "vram", "--csv"], capture_output=True, text=True, timeout=10
        )
        if result.returncode == 0 and result.stdout.strip():
            gpus = []
            # First CSV row is the header; data rows follow.
            for line in result.stdout.strip().split("\n")[1:]:
                parts = [p.strip() for p in line.split(",")]
                if len(parts) >= 3:
                    # Non-numeric fields fall back to a positional index / zero
                    # instead of raising.
                    idx = int(parts[0]) if parts[0].isdigit() else len(gpus)
                    total_b = int(parts[1]) if parts[1].isdigit() else 0
                    used_b = int(parts[2]) if parts[2].isdigit() else 0
                    gpus.append(
                        GPUInfo(
                            index=idx,
                            # rocm-smi CSV has no device name column.
                            name=f"AMD GPU {idx}",
                            # NOTE(review): values treated as bytes here — confirm
                            # rocm-smi reports VRAM in bytes on the target version.
                            vram_total_gb=total_b / (1024**3),
                            vram_free_gb=(total_b - used_b) / (1024**3),
                        )
                    )
            if gpus:
                return gpus
    except (FileNotFoundError, subprocess.TimeoutExpired):
        pass

    # Windows: fall back to registry
    if sys.platform == "win32":
        try:
            import winreg

            gpus = []
            # GUID {4d36e968-...} is the display-adapter device class.
            base_key = r"SYSTEM\CurrentControlSet\Control\Class\{4d36e968-e325-11ce-bfc1-08002be10318}"
            key = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, base_key)
            # Scan the first 20 four-digit adapter subkeys (0000, 0001, ...).
            for i in range(20):
                try:
                    subkey = winreg.OpenKey(key, f"{i:04d}")
                    provider, _ = winreg.QueryValueEx(subkey, "ProviderName")
                    # Keep only AMD/ATI adapters.
                    if "AMD" not in provider.upper() and "ATI" not in provider.upper():
                        continue
                    desc, _ = winreg.QueryValueEx(subkey, "DriverDesc")
                    total_gb = 0.0
                    # Prefer the 64-bit qwMemorySize value; fall back to the
                    # legacy 32-bit MemorySize value.
                    for reg_name in ("HardwareInformation.qwMemorySize", "HardwareInformation.MemorySize"):
                        try:
                            mem_bytes, _ = winreg.QueryValueEx(subkey, reg_name)
                            total_gb = float(mem_bytes) / (1024**3)
                            if total_gb > 0:
                                break
                        except OSError:
                            continue
                    # Registry has no free-memory figure; report total as free.
                    gpus.append(GPUInfo(index=len(gpus), name=desc, vram_total_gb=total_gb, vram_free_gb=total_gb))
                except OSError:
                    # Missing subkey or value: move on to the next adapter slot.
                    continue
            if gpus:
                return gpus
        except Exception:
            # Registry probing is strictly best-effort; never let it crash enumeration.
            pass

    return None


def enumerate_gpus() -> list[GPUInfo]:
    """List all available GPUs with VRAM info.

    Tries nvidia-smi (NVIDIA), then amd-smi/rocm-smi (AMD ROCm),
    then falls back to torch.cuda API.
    Returns an empty list on non-GPU systems.

    Returns:
        A list of GPUInfo entries; empty when no GPU is detected or no
        probing mechanism (vendor tools, torch) is available.
    """
    gpus = _enumerate_nvidia()
    if gpus is not None:
        return gpus

    gpus = _enumerate_amd()
    if gpus is not None:
        return gpus

    # Fallback to torch (works for both NVIDIA and ROCm via HIP)
    try:
        import torch

        if torch.cuda.is_available():
            fallback: list[GPUInfo] = []
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                total = props.total_memory / (1024**3)
                # torch exposes no free-memory figure here; report total as free.
                fallback.append(GPUInfo(index=i, name=props.name, vram_total_gb=total, vram_free_gb=total))
            return fallback
    except ImportError:
        # torch is optional: without it, honor the documented contract and
        # report no GPUs instead of raising ModuleNotFoundError.
        logger.debug("torch not installed; skipping torch.cuda fallback")
    except RuntimeError:
        logger.debug("torch.cuda init failed, falling through", exc_info=True)

    return []


def clear_device_cache(device) -> None:
"""Clear GPU memory cache if applicable (no-op for CPU)."""
import torch
Expand Down