From 01801406e1c542652943a8b54d50dc99ff2b7b59 Mon Sep 17 00:00:00 2001
From: daubners
Date: Fri, 9 Jan 2026 12:13:34 +0000
Subject: [PATCH] track gpu

---
 evoxels/profiler.py | 37 +++++++++++++++++++++++++++----------
 1 file changed, 27 insertions(+), 10 deletions(-)

diff --git a/evoxels/profiler.py b/evoxels/profiler.py
index be0bc49..2320563 100644
--- a/evoxels/profiler.py
+++ b/evoxels/profiler.py
@@ -3,12 +3,20 @@ import os
 import subprocess
 import tracemalloc
+import shutil
 from abc import ABC, abstractmethod
 
 
 class MemoryProfiler(ABC):
     """Base interface for tracking host and device memory usage."""
+    def __init__(self):
+        self.max_used_cpu = 0.0
+        self.max_used_gpu = 0.0
+        self.track_gpu = False  # subclasses set this
+
     def get_cuda_memory_from_nvidia_smi(self):
         """Return currently used CUDA memory in megabytes."""
+        if shutil.which("nvidia-smi") is None:
+            return None
         try:
             output = subprocess.check_output(
                 ['nvidia-smi', '--query-gpu=memory.used',
@@ -23,8 +31,10 @@ def update_memory_stats(self):
         process = psutil.Process(os.getpid())
         used_cpu = process.memory_info().rss / 1024**2
         self.max_used_cpu = np.max((self.max_used_cpu, used_cpu))
-        used = self.get_cuda_memory_from_nvidia_smi()
-        self.max_used_gpu = np.max((self.max_used_gpu, used))
+        if self.track_gpu:
+            used = self.get_cuda_memory_from_nvidia_smi()
+            if used is not None:
+                self.max_used_gpu = np.max((self.max_used_gpu, used))
 
     @abstractmethod
     def print_memory_stats(self, start: float, end: float, iters: int):
@@ -35,13 +45,14 @@ class TorchMemoryProfiler(MemoryProfiler):
     def __init__(self, device):
         """Initialize the profiler for a given torch device."""
         import torch
+        super().__init__()
         self.torch = torch
         self.device = device
+        self.track_gpu = (device.type == 'cuda')
+
         tracemalloc.start()
-        if device.type == 'cuda':
+        if self.track_gpu:
             torch.cuda.reset_peak_memory_stats(device=device)
-        self.max_used_gpu = 0
-        self.max_used_cpu = 0
 
     def print_memory_stats(self, start, end, iters):
         """Print usage statistics for the Torch backend."""
@@ -60,7 +71,10 @@ def print_memory_stats(self, start, end, iters):
         elif self.device.type == 'cuda':
             self.update_memory_stats()
             used = self.get_cuda_memory_from_nvidia_smi()
-            print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
+            if used is None:
+                print("GPU-RAM (nvidia-smi) unavailable.")
+            else:
+                print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
             print(f"GPU-RAM (torch) current: "
                   f"{self.torch.cuda.memory_allocated(self.device) / 1024**2:.2f} MB "
                   f"({self.torch.cuda.max_memory_allocated(self.device) / 1024**2:.2f} MB max, "
@@ -70,9 +84,9 @@ class JAXMemoryProfiler(MemoryProfiler):
     def __init__(self):
         """Initialize the profiler for JAX."""
        import jax
+        super().__init__()
         self.jax = jax
-        self.max_used_gpu = 0
-        self.max_used_cpu = 0
+        self.track_gpu = any(d.platform == "gpu" for d in jax.devices())
         tracemalloc.start()
 
     def print_memory_stats(self, start, end, iters):
@@ -88,7 +102,10 @@ def print_memory_stats(self, start, end, iters):
         current = process.memory_info().rss / 1024**2
         print(f"CPU-RAM (psutil) current: {current:.2f} MB ({self.max_used_cpu:.2f} MB max)")
 
-        if self.jax.default_backend() == 'gpu':
+        if self.track_gpu:
             self.update_memory_stats()
             used = self.get_cuda_memory_from_nvidia_smi()
-            print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
\ No newline at end of file
+            if used is None:
+                print("GPU-RAM (nvidia-smi) unavailable.")
+            else:
+                print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
\ No newline at end of file
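
Usage note (illustrative, not part of the patch): with these changes GPU polling is opt-in per backend via track_gpu, and nvidia-smi is skipped when the binary is absent. A rough sketch of how the Torch profiler might be driven after this patch; only the class and method names come from the diff above, the workload and timing loop are assumed:

    import time
    import torch
    from evoxels.profiler import TorchMemoryProfiler

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    profiler = TorchMemoryProfiler(device)  # track_gpu is derived from device.type

    start = time.time()
    for _ in range(100):
        # stand-in workload; replace with the actual solver step
        x = torch.randn(256, 256, device=device) @ torch.randn(256, 256, device=device)
        profiler.update_memory_stats()  # always records peak CPU RAM; GPU RAM only if track_gpu
    end = time.time()

    profiler.print_memory_stats(start, end, 100)

On a machine without nvidia-smi the run now reports "GPU-RAM (nvidia-smi) unavailable." instead of failing or printing None.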