37 changes: 27 additions & 10 deletions evoxels/profiler.py
@@ -3,12 +3,20 @@
 import os
 import subprocess
 import tracemalloc
+import shutil
 from abc import ABC, abstractmethod
 
 class MemoryProfiler(ABC):
     """Base interface for tracking host and device memory usage."""
+    def __init__(self):
+        self.max_used_cpu = 0.0
+        self.max_used_gpu = 0.0
+        self.track_gpu = False # subclasses set this
+
     def get_cuda_memory_from_nvidia_smi(self):
         """Return currently used CUDA memory in megabytes."""
+        if shutil.which("nvidia-smi") is None:
+            return None
         try:
             output = subprocess.check_output(
                 ['nvidia-smi', '--query-gpu=memory.used',
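The hunk ends before the helper's format flags, parsing, and except clause, so those lines are not visible here. For readers outside the repo, a minimal self-contained sketch of how such an nvidia-smi query is typically issued and parsed; the flags and parsing below are assumptions, not the file's actual code:

import shutil
import subprocess

def query_used_vram_mb():
    """Used VRAM of GPU 0 in MB, or None when nvidia-smi is missing or fails."""
    # Mirrors the early-out added in the hunk above.
    if shutil.which("nvidia-smi") is None:
        return None
    try:
        # Assumed flags: the diff only shows the --query-gpu argument.
        output = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=memory.used",
             "--format=csv,nounits,noheader"],
            encoding="utf-8",
        )
        return float(output.strip().splitlines()[0])
    except (subprocess.CalledProcessError, ValueError):
        return None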
@@ -23,8 +31,10 @@ def update_memory_stats(self):
         process = psutil.Process(os.getpid())
         used_cpu = process.memory_info().rss / 1024**2
         self.max_used_cpu = np.max((self.max_used_cpu, used_cpu))
-        used = self.get_cuda_memory_from_nvidia_smi()
-        self.max_used_gpu = np.max((self.max_used_gpu, used))
+        if self.track_gpu:
+            used = self.get_cuda_memory_from_nvidia_smi()
+            if used is not None:
+                self.max_used_gpu = np.max((self.max_used_gpu, used))
 
     @abstractmethod
     def print_memory_stats(self, start: float, end: float, iters: int):
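The `is not None` guard is the crux of this hunk: the helper can now return None (see the early return added above), and feeding None into np.max raises rather than degrading gracefully. A quick reproduction of the failure the guard prevents:

import numpy as np

print(np.max((0.0, 512.0)))   # fine: 512.0
try:
    np.max((0.0, None))       # what the unguarded path would now hit
except TypeError as exc:
    print(f"TypeError: {exc}")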
@@ -35,13 +45,14 @@ class TorchMemoryProfiler(MemoryProfiler):
     def __init__(self, device):
         """Initialize the profiler for a given torch device."""
         import torch
+        super().__init__()
         self.torch = torch
         self.device = device
+        self.track_gpu = (device.type == 'cuda')
+
         tracemalloc.start()
-        if device.type == 'cuda':
+        if self.track_gpu:
             torch.cuda.reset_peak_memory_stats(device=device)
-        self.max_used_gpu = 0
-        self.max_used_cpu = 0
 
     def print_memory_stats(self, start, end, iters):
         """Print usage statistics for the Torch backend."""
@@ -60,7 +71,10 @@ def print_memory_stats(self, start, end, iters):
         elif self.device.type == 'cuda':
             self.update_memory_stats()
             used = self.get_cuda_memory_from_nvidia_smi()
-            print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
+            if used is None:
+                print("GPU-RAM (nvidia-smi) unavailable.")
+            else:
+                print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
             print(f"GPU-RAM (torch) current: "
                   f"{self.torch.cuda.memory_allocated(self.device) / 1024**2:.2f} MB "
                   f"({self.torch.cuda.max_memory_allocated(self.device) / 1024**2:.2f} MB max, "
@@ -70,9 +84,9 @@ class JAXMemoryProfiler(MemoryProfiler):
     def __init__(self):
         """Initialize the profiler for JAX."""
         import jax
+        super().__init__()
         self.jax = jax
-        self.max_used_gpu = 0
-        self.max_used_cpu = 0
+        self.track_gpu = any(d.platform == "gpu" for d in jax.devices())
         tracemalloc.start()
 
     def print_memory_stats(self, start, end, iters):
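Here the GPU check moves from querying jax.default_backend() at print time to enumerating devices once at construction time, cached in track_gpu. Both forms use public JAX APIs:

import jax

print(jax.default_backend())                             # e.g. 'cpu' or 'gpu'
print([d.platform for d in jax.devices()])               # e.g. ['cpu']
print(any(d.platform == "gpu" for d in jax.devices()))   # the new check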
@@ -88,7 +102,10 @@ def print_memory_stats(self, start, end, iters):
         current = process.memory_info().rss / 1024**2
         print(f"CPU-RAM (psutil) current: {current:.2f} MB ({self.max_used_cpu:.2f} MB max)")
 
-        if self.jax.default_backend() == 'gpu':
+        if self.track_gpu:
             self.update_memory_stats()
             used = self.get_cuda_memory_from_nvidia_smi()
-            print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
+            if used is None:
+                print("GPU-RAM (nvidia-smi) unavailable.")
+            else:
+                print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")