diff --git a/fme/core/benchmark/__init__.py b/fme/core/benchmark/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/fme/core/benchmark/test_timer.py b/fme/core/benchmark/test_timer.py
new file mode 100644
index 000000000..9e22caa63
--- /dev/null
+++ b/fme/core/benchmark/test_timer.py
@@ -0,0 +1,40 @@
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from fme.core.benchmark.timer import CUDATimer, NullTimer
+
+
+@pytest.mark.parametrize("is_available", [True, False])
+def test_new_if_available(is_available: bool):
+    with patch("torch.cuda.is_available", return_value=is_available):
+        timer = CUDATimer.new_if_available()
+        if is_available:
+            assert isinstance(timer, CUDATimer)
+        else:
+            assert isinstance(timer, NullTimer)
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="CUDA is not available, skipping CUDATimer tests.",
+)
+def test_timer_with_child():
+    timer = CUDATimer()
+    # get cuda to wait
+    with timer.context("parent"):
+        torch.cuda._sleep(100_000)
+        child_timer = timer.child("child")
+        with child_timer.context("child_time"):
+            torch.cuda._sleep(100_000)
+    report = timer.report()
+    assert "parent" in report.average_time_seconds
+    assert "child" in report.children
+    assert "child_time" in report.children["child"].average_time_seconds
+    # parent time should include the child time, so it should be
+    # at least 2x the child time (since we sleep for the same amount of time in both)
+    assert (
+        report.average_time_seconds["parent"]
+        >= 2.0 * report.children["child"].average_time_seconds["child_time"]
+    )
diff --git a/fme/core/benchmark/timer.py b/fme/core/benchmark/timer.py
new file mode 100644
index 000000000..75c605470
--- /dev/null
+++ b/fme/core/benchmark/timer.py
@@ -0,0 +1,97 @@
+import collections
+import contextlib
+import dataclasses
+from typing import TYPE_CHECKING, Protocol, Self
+
+import torch
+
+
+@dataclasses.dataclass
+class TimerReport:
+    """Aggregated timing results for one timer and its child timers."""
+
+    average_time_seconds: dict[str, float]
+    children: dict[str, "TimerReport"]
+
+
+class Timer(Protocol):
+    """Minimal timing interface shared by CUDATimer and NullTimer."""
+
+    def context(self, name: str) -> contextlib.AbstractContextManager[None]: ...
+
+    def child(self, name: str) -> Self: ...
+
+
+class NullTimer:
+    """No-op Timer used when CUDA is unavailable."""
+
+    def context(self, name: str) -> contextlib.nullcontext:
+        return contextlib.nullcontext()
+
+    def child(self, name: str) -> "Self":
+        return self
+
+
+class CUDATimer:
+    """Times CUDA work using CUDA events, with support for nested child timers."""
+
+    def __init__(self):
+        if not torch.cuda.is_available():
+            raise RuntimeError("CUDA is not available, cannot use CUDATimer.")
+        self._starters = []
+        self._enders = []
+        self._names = []
+        self._children: collections.defaultdict[str, CUDATimer] = (
+            collections.defaultdict(CUDATimer)
+        )
+
+    @classmethod
+    def new_if_available(cls) -> "CUDATimer | NullTimer":
+        """Return a CUDATimer if CUDA is available, otherwise a NullTimer."""
+        if torch.cuda.is_available():
+            return cls()
+        else:
+            return NullTimer()
+
+    @contextlib.contextmanager
+    def context(self, name: str):
+        """Record the CUDA time spent inside the with-block under ``name``."""
+        starter = torch.cuda.Event(enable_timing=True)
+        ender = torch.cuda.Event(enable_timing=True)
+        self._starters.append(starter)
+        self._enders.append(ender)
+        self._names.append(name)
+        stream = torch.cuda.current_stream()
+        starter.record(stream)
+        try:
+            yield
+        finally:
+            ender.record(stream)
+
+    def child(self, name: str) -> "CUDATimer":
+        """Return (creating it if needed) the child timer registered under ``name``."""
+        return self._children[name]
+
+    def report(self) -> TimerReport:
+        """Synchronize CUDA and aggregate recorded events into a TimerReport."""
+        torch.cuda.synchronize()
+        total_time_seconds: dict[str, float] = collections.defaultdict(float)
+        counts: dict[str, int] = collections.defaultdict(int)
+        for starter, ender, name in zip(self._starters, self._enders, self._names):
+            # Event.elapsed_time returns milliseconds; convert to seconds.
+            total_time_seconds[name] += starter.elapsed_time(ender) / 1000.0
+            counts[name] += 1
+        average_time_seconds = {
+            name: total / counts[name] for name, total in total_time_seconds.items()
+        }
+        children = {}
+        for name, child in self._children.items():
+            children[name] = child.report()
+        return TimerReport(average_time_seconds=average_time_seconds, children=children)
+
+
+if TYPE_CHECKING:
+    # Static-only checks that both classes satisfy the Timer protocol.
+    # These must not execute at import time: CUDATimer() raises RuntimeError
+    # on machines without CUDA, which would break importing this module.
+    _null_timer_check: Timer = NullTimer()
+    _cuda_timer_check: Timer = CUDATimer()