diff --git a/device_utils.py b/device_utils.py index 58125b9d..1894d3ed 100644 --- a/device_utils.py +++ b/device_utils.py @@ -1,7 +1,11 @@ """Centralized cross-platform device selection for CorridorKey.""" +import json import logging import os +import subprocess +import sys +from dataclasses import dataclass logger = logging.getLogger(__name__) @@ -110,6 +114,159 @@ def resolve_device(requested: str | None = None) -> str: return device +@dataclass +class GPUInfo: + """Information about a single GPU.""" + + index: int + name: str + vram_total_gb: float + vram_free_gb: float + + +def _enumerate_nvidia() -> list[GPUInfo] | None: + """Enumerate NVIDIA GPUs via nvidia-smi. Returns None if unavailable.""" + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=index,name,memory.total,memory.free", "--format=csv,nounits,noheader"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode != 0: + return None + gpus: list[GPUInfo] = [] + for line in result.stdout.strip().split("\n"): + parts = [p.strip() for p in line.split(",")] + if len(parts) >= 4: + gpus.append( + GPUInfo( + index=int(parts[0]), + name=parts[1], + vram_total_gb=float(parts[2]) / 1024, + vram_free_gb=float(parts[3]) / 1024, + ) + ) + return gpus + except (FileNotFoundError, subprocess.TimeoutExpired): + return None + + +def _enumerate_amd() -> list[GPUInfo] | None: + """Enumerate AMD GPUs via amd-smi or rocm-smi. Returns None if unavailable.""" + # Try amd-smi (ROCm 6.0+) + try: + result = subprocess.run(["amd-smi", "static", "--json"], capture_output=True, text=True, timeout=10) + if result.returncode == 0: + data = json.loads(result.stdout) + gpus: list[GPUInfo] = [] + for i, gpu in enumerate(data): + try: + name = gpu.get("asic", {}).get("market_name", f"AMD GPU {i}") + vram_info = gpu.get("vram", {}) + total_mb = vram_info.get("size", {}).get("value", 0) + total_gb = float(total_mb) / 1024 if total_mb else 0 + gpus.append(GPUInfo(index=i, name=name, vram_total_gb=total_gb, vram_free_gb=total_gb)) + except (KeyError, TypeError, ValueError): + pass + if gpus: + return gpus + except (FileNotFoundError, subprocess.TimeoutExpired, json.JSONDecodeError): + pass + + # Fallback: rocm-smi (legacy) + try: + result = subprocess.run( + ["rocm-smi", "--showid", "--showmeminfo", "vram", "--csv"], capture_output=True, text=True, timeout=10 + ) + if result.returncode == 0 and result.stdout.strip(): + gpus = [] + for line in result.stdout.strip().split("\n")[1:]: + parts = [p.strip() for p in line.split(",")] + if len(parts) >= 3: + idx = int(parts[0]) if parts[0].isdigit() else len(gpus) + total_b = int(parts[1]) if parts[1].isdigit() else 0 + used_b = int(parts[2]) if parts[2].isdigit() else 0 + gpus.append( + GPUInfo( + index=idx, + name=f"AMD GPU {idx}", + vram_total_gb=total_b / (1024**3), + vram_free_gb=(total_b - used_b) / (1024**3), + ) + ) + if gpus: + return gpus + except (FileNotFoundError, subprocess.TimeoutExpired): + pass + + # Windows: fall back to registry + if sys.platform == "win32": + try: + import winreg + + gpus = [] + base_key = r"SYSTEM\CurrentControlSet\Control\Class\{4d36e968-e325-11ce-bfc1-08002be10318}" + key = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, base_key) + for i in range(20): + try: + subkey = winreg.OpenKey(key, f"{i:04d}") + provider, _ = winreg.QueryValueEx(subkey, "ProviderName") + if "AMD" not in provider.upper() and "ATI" not in provider.upper(): + continue + desc, _ = winreg.QueryValueEx(subkey, "DriverDesc") + total_gb = 0.0 + for reg_name in ("HardwareInformation.qwMemorySize", "HardwareInformation.MemorySize"): + try: + mem_bytes, _ = winreg.QueryValueEx(subkey, reg_name) + total_gb = float(mem_bytes) / (1024**3) + if total_gb > 0: + break + except OSError: + continue + gpus.append(GPUInfo(index=len(gpus), name=desc, vram_total_gb=total_gb, vram_free_gb=total_gb)) + except OSError: + continue + if gpus: + return gpus + except Exception: + pass + + return None + + +def enumerate_gpus() -> list[GPUInfo]: + """List all available GPUs with VRAM info. + + Tries nvidia-smi (NVIDIA), then amd-smi/rocm-smi (AMD ROCm), + then falls back to torch.cuda API. + Returns an empty list on non-GPU systems. + """ + gpus = _enumerate_nvidia() + if gpus is not None: + return gpus + + gpus = _enumerate_amd() + if gpus is not None: + return gpus + + # Fallback to torch (works for both NVIDIA and ROCm via HIP) + try: + import torch + + if torch.cuda.is_available(): + fallback: list[GPUInfo] = [] + for i in range(torch.cuda.device_count()): + props = torch.cuda.get_device_properties(i) + total = props.total_memory / (1024**3) + fallback.append(GPUInfo(index=i, name=props.name, vram_total_gb=total, vram_free_gb=total)) + return fallback + except RuntimeError: + logger.debug("torch.cuda init failed, falling through", exc_info=True) + + return [] + + def clear_device_cache(device) -> None: """Clear GPU memory cache if applicable (no-op for CPU).""" import torch