Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added fme/core/benchmark/__init__.py
Empty file.
42 changes: 42 additions & 0 deletions fme/core/benchmark/test_timer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from unittest.mock import patch

import pytest
import torch

from fme.core.benchmark.timer import CUDATimer


@pytest.mark.parametrize("is_available", [True, False])
def test_new_if_available(is_available: bool):
    """new_if_available returns a CUDATimer iff CUDA is reported available.

    CUDA availability is mocked so both branches run on any host.
    """
    # CUDATimer is already imported at module scope; only NullTimer needs a
    # local import here (the original redundantly re-imported both).
    from fme.core.benchmark.timer import NullTimer

    with patch("torch.cuda.is_available", return_value=is_available):
        timer = CUDATimer.new_if_available()
    if is_available:
        assert isinstance(timer, CUDATimer)
    else:
        assert isinstance(timer, NullTimer)


@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="CUDA is not available, skipping CUDATimer tests.",
)
def test_timer_with_child():
    """A parent section's time should include its child timer's sections."""
    parent = CUDATimer()
    with parent.context("parent"):
        # Give the GPU measurable work so the recorded intervals are nonzero.
        torch.cuda._sleep(100_000)
        with parent.child("child").context("child_time"):
            torch.cuda._sleep(100_000)
    result = parent.report()
    assert "parent" in result.average_time_seconds
    assert "child" in result.children
    child_report = result.children["child"]
    assert "child_time" in child_report.average_time_seconds
    # The parent section wraps its own sleep plus the child's equally long
    # sleep, so the parent time should be at least twice the child time.
    assert (
        result.average_time_seconds["parent"]
        >= 2.0 * child_report.average_time_seconds["child_time"]
    )
85 changes: 85 additions & 0 deletions fme/core/benchmark/timer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import collections
import contextlib
import dataclasses
from typing import Protocol, Self

import torch


@dataclasses.dataclass
class TimerReport:
    """Aggregated timing results for one timer and, recursively, its children."""

    # Mean duration per section name, averaged over every context() entry
    # recorded under that name. Intended unit is seconds (per the field name).
    average_time_seconds: dict[str, float]
    # Reports for child timers, keyed by the name passed to child().
    children: dict[str, "TimerReport"]


class Timer(Protocol):
    """Structural interface satisfied by both CUDATimer and NullTimer."""

    def context(self, name: str) -> contextlib.AbstractContextManager[None]:
        """Return a context manager timing the enclosed code under *name*."""
        ...

    def child(self, name: str) -> Self:
        """Return a nested timer identified by *name*."""
        ...


class NullTimer:
def context(self, name: str) -> contextlib.nullcontext:
return contextlib.nullcontext()

def child(self, name: str) -> "Self":
return self


# Compile-time and import-time check that NullTimer satisfies the Timer
# protocol; the throwaway binding is deleted so it does not leak from the module.
_: Timer = NullTimer()
del _


class CUDATimer:
    """Timer measuring GPU time between CUDA events.

    Sections are recorded asynchronously via CUDA events on the current
    stream and resolved (with a device synchronization) only when report()
    is called.

    Raises:
        RuntimeError: on construction if CUDA is not available.
    """

    def __init__(self):
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available, cannot use CUDATimer.")
        # Parallel lists with one (start event, end event, section name)
        # triple per context() invocation.
        self._starters: list[torch.cuda.Event] = []
        self._enders: list[torch.cuda.Event] = []
        self._names: list[str] = []
        # Child timers are created lazily on first access by name.
        self._children: collections.defaultdict[str, CUDATimer] = (
            collections.defaultdict(CUDATimer)
        )

    @classmethod
    def new_if_available(cls) -> "CUDATimer | NullTimer":
        """Return a CUDATimer if CUDA is available, otherwise a NullTimer.

        The return annotation must be one quoted string: the original
        `"CUDATimer" | NullTimer` raises TypeError when the class body is
        executed, since PEP 604 `|` does not accept forward-reference strings.
        """
        if torch.cuda.is_available():
            return cls()
        else:
            return NullTimer()

    @contextlib.contextmanager
    def context(self, name: str):
        """Record GPU time spent in the enclosed block under *name*.

        The end event is recorded even if the block raises.
        """
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        self._starters.append(starter)
        self._enders.append(ender)
        self._names.append(name)
        stream = torch.cuda.current_stream()
        starter.record(stream)
        try:
            yield
        finally:
            ender.record(stream)

    def child(self, name: str) -> "CUDATimer":
        """Return the child timer for *name*, creating it on first use."""
        return self._children[name]

    def report(self) -> TimerReport:
        """Synchronize the device and return averaged timings in seconds.

        Includes a recursive report for every child timer.
        """
        torch.cuda.synchronize()
        total_time_seconds: dict[str, float] = collections.defaultdict(float)
        counts: dict[str, int] = collections.defaultdict(int)
        for starter, ender, name in zip(self._starters, self._enders, self._names):
            # Event.elapsed_time returns milliseconds; convert to seconds so
            # the stored values match the TimerReport field name.
            total_time_seconds[name] += starter.elapsed_time(ender) / 1000.0
            counts[name] += 1
        average_time_seconds = {
            name: total / counts[name] for name, total in total_time_seconds.items()
        }
        children = {name: child.report() for name, child in self._children.items()}
        return TimerReport(average_time_seconds=average_time_seconds, children=children)


__: Timer = CUDATimer()
del __
Loading