Merged

Commits (29):
4dec08a  add CUDATimer and NullTimer (mcgibbon, Feb 10, 2026)
3460bea  test assert_close (mcgibbon, Feb 10, 2026)
6c1c29c  copy-paste sht_fix.py (mcgibbon, Feb 10, 2026)
332540a  update conditional SFNO to pass timers for profiling (mcgibbon, Feb 10, 2026)
50369f0  add benchmarks with gpu regression testing (mcgibbon, Feb 10, 2026)
a92561d  remove unused typevar (mcgibbon, Feb 10, 2026)
e53bff5  use optional argument for optional argument (mcgibbon, Feb 10, 2026)
aeda7ee  codify single name (mcgibbon, Feb 10, 2026)
a55f908  Merge branch 'main' into feature/cuda_timer (mcgibbon, Feb 10, 2026)
2c2518d  Merge branch 'feature/cuda_timer' into feature/sfno_timers (mcgibbon, Feb 10, 2026)
0551dc8  Merge branch 'feature/sfno_timers' into feature/benchmarking (mcgibbon, Feb 10, 2026)
d17f3f1  incorporate review comments (mcgibbon, Feb 10, 2026)
43a00c3  fix test (mcgibbon, Feb 10, 2026)
c6f633e  Merge branch 'feature/cuda_timer' into feature/sfno_timers (mcgibbon, Feb 10, 2026)
ce9ae4f  Merge branch 'feature/sfno_timers' into feature/benchmarking (mcgibbon, Feb 10, 2026)
c613b73  delete dead code (mcgibbon, Feb 10, 2026)
36f9cd8  add force_cpu (mcgibbon, Feb 10, 2026)
17554a8  use cpu for regression test (mcgibbon, Feb 10, 2026)
f011630  Merge branch 'main' into feature/benchmarking (mcgibbon, Feb 11, 2026)
2354ef8  use logging.info, save json as well (mcgibbon, Feb 11, 2026)
1c75607  revert changes to conftest (mcgibbon, Feb 11, 2026)
cec4913  maintain insertion (runtime) order (mcgibbon, Feb 11, 2026)
fb647d4  simpler imports (mcgibbon, Feb 11, 2026)
f0d6f9b  add dirty label (mcgibbon, Feb 11, 2026)
4738d9e  Merge branch 'main' into feature/benchmarking (mcgibbon, Feb 12, 2026)
4f407ba  add logging, arg for out dir, basic test (mcgibbon, Feb 12, 2026)
1de2576  Merge branch 'main' into feature/benchmarking (mcgibbon, Feb 12, 2026)
de0bef6  Merge branch 'main' into feature/benchmarking (mcgibbon, Feb 12, 2026)
bb51f35  skip test on non-gpu (mcgibbon, Feb 12, 2026)
4 changes: 4 additions & 0 deletions fme/core/__init__.py
@@ -1,3 +1,4 @@
from . import models as _ # to trigger registrations
from .atmosphere_data import AtmosphereData
from .device import get_device, using_gpu
from .gridded_ops import GriddedOperations
@@ -14,11 +15,14 @@
from .rand import set_seed
from .registry import Registry

del _

__all__ = [
"spherical_area_weights",
"weighted_mean",
"weighted_mean_bias",
"weighted_nanmean",
"weighted_sum",
"root_mean_squared_error",
"get_device",
"using_gpu",
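The `from . import models as _` line followed by `del _` is the import-for-side-effects pattern: importing the subpackage executes its module-level registration decorators, and the throwaway alias is then deleted so it does not leak into the package namespace. A minimal sketch of the idea, with hypothetical module and registry names not taken from this PR:

# registry.py (hypothetical)
REGISTRY: dict[str, type] = {}

def register(name: str):
    def _register(cls: type) -> type:
        REGISTRY[name] = cls  # runs when the defining module is imported
        return cls
    return _register

# models/some_model.py (hypothetical)
#     @register("some_model")
#     class SomeModel: ...

# package __init__.py
# from . import models as _  # importing runs the @register decorators above
# del _                      # drop the placeholder name from the namespace
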
1 change: 1 addition & 0 deletions fme/core/benchmark/.gitignore
@@ -0,0 +1 @@
results
300 changes: 300 additions & 0 deletions fme/core/benchmark/benchmark.py
@@ -0,0 +1,300 @@
import abc
import dataclasses
import pathlib
from collections.abc import Callable
from typing import Self

import dacite
import matplotlib.pyplot as plt
import torch

from fme.core.benchmark.memory import MemoryResult, benchmark_memory
from fme.core.benchmark.timer import CUDATimer, NullTimer, Timer, TimerResult
from fme.core.typing_ import TensorDict


@dataclasses.dataclass
class BenchmarkResult:
memory: MemoryResult
timer: TimerResult

def __repr__(self) -> str:
return f"BenchmarkResult(memory={self.memory}, timer={self.timer})"

def asdict(self) -> dict:
return dataclasses.asdict(self)

@classmethod
def from_dict(cls, d: dict) -> "BenchmarkResult":
return dacite.from_dict(cls, d, config=dacite.Config(strict=True))

def assert_close(
self, other: "BenchmarkResult", rtol=0.02, children_rtol=0.02
) -> None:
try:
self.timer.assert_close(other.timer, rtol=rtol, children_rtol=children_rtol)
except AssertionError as e:
raise AssertionError(f"Timer results differ: {e}") from e
try:
self.memory.assert_close(other.memory, rtol=rtol)
except AssertionError as e:
raise AssertionError(f"Memory results differ: {e}") from e

def to_png(
self, path: str | pathlib.Path, label: str, child: str | None = None
) -> None:
# note this function was generated with AI
def avg_time(t: TimerResult) -> float:
return float(t.avg_time)

def self_time(t: TimerResult) -> float:
t_avg = avg_time(t)
c_avg = sum(avg_time(c) for c in t.children.values())
return max(t_avg - c_avg, 0.0)

def fmt_time(ms: float) -> str:
if ms >= 1000.0:
return f"{ms/1000.0:.2f}s"
if ms >= 10.0:
return f"{ms:.1f}ms"
return f"{ms:.2f}ms"

def label_ok(name: str, ms: float, frac_of_root: float) -> bool:
if not name:
return False
return frac_of_root >= 0.05

def ordered_children(t: TimerResult) -> list[tuple[str, TimerResult]]:
return list(t.children.items()) # maintain dict order (insertion order)

def blend_with_white(
rgb: tuple[float, float, float], amount: float
) -> tuple[float, float, float]:
# amount in [0,1]: 0 -> original, 1 -> white
return (
rgb[0] + (1.0 - rgb[0]) * amount,
rgb[1] + (1.0 - rgb[1]) * amount,
rgb[2] + (1.0 - rgb[2]) * amount,
)

root = self.timer
if child is not None:
for part in child.split("."):
if part not in root.children:
raise ValueError(f"Child '{child}' not found in timer results.")
root = root.children[part]
root_avg = avg_time(root)

max_alloc_mb = self.memory.max_alloc / (1024.0 * 1024.0)

fig = plt.figure(figsize=(8, 6), constrained_layout=True)
if root_avg <= 0.0:
fig.suptitle(
f"Benchmark for {label}\ntotal=0.00s, max_alloc={max_alloc_mb:.1f} MB",
fontsize=14,
)
ax0 = fig.add_subplot(1, 1, 1)
ax0.text(0.5, 0.5, "No timing data", ha="center", va="center")
ax0.axis("off")
fig.savefig(path, dpi=200)
plt.close(fig)
return

fig.suptitle(
f"Benchmark for {label}\ntotal={fmt_time(root_avg)}, "
f"max_alloc={max_alloc_mb:.1f} MB",
fontsize=14,
)

ax = fig.add_subplot(1, 1, 1)
ax.set_xlim(0, 2)
ax.set_ylim(0, root_avg)
ax.set_xticks([0.5, 1.5])
ax.set_xticklabels(["Level 1", "Level 2"])
ax.set_ylabel("Avg time")
ax.set_yticks([])
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

gray = (0.85, 0.85, 0.85, 1.0)
cmap = plt.get_cmap("tab20")

lvl1 = ordered_children(root)
lvl1_names = [n for n, _ in lvl1]
lvl1_index = {n: i for i, n in enumerate(lvl1_names)}

# Level 1 stack (root children + root self in gray, unlabeled)
lvl1_segments: list[tuple[str, float, tuple[float, float, float, float]]] = []
for n1, t1 in lvl1:
base = cmap(lvl1_index[n1] % cmap.N)
lvl1_segments.append((n1, avg_time(t1), base))
r_self = self_time(root)
if r_self > 0.0:
lvl1_segments.append(("", r_self, gray))

def draw_stack(
x_center: float,
segments: list[tuple[str, float, tuple[float, float, float, float]]],
) -> None:
width = 0.86
y = 0.0
for name, sec, color in segments:
if sec <= 0.0:
continue
ax.bar(
x_center,
sec,
bottom=y,
width=width,
align="center",
color=color,
edgecolor="white",
linewidth=1.0,
)
frac = sec / root_avg
if label_ok(name, sec, frac):
ax.text(
x_center,
y + sec / 2.0,
f"{name}\n{fmt_time(sec)}",
ha="center",
va="center",
fontsize=9,
rotation=0, # keep horizontal to avoid cross-column overlap
clip_on=True,
)
y += sec
if y < root_avg:
ax.bar(
x_center,
root_avg - y,
bottom=y,
width=width,
align="center",
color=gray,
edgecolor="white",
linewidth=1.0,
)

draw_stack(0.5, lvl1_segments)

# Level 2 stack:
# For each level-1 slice, stack its children
# (colored as parent hue variants) + self in gray.
lvl2_segments: list[tuple[str, float, tuple[float, float, float, float]]] = []
for n1, t1 in lvl1:
parent_rgba = cmap(lvl1_index[n1] % cmap.N)
parent_rgb = (parent_rgba[0], parent_rgba[1], parent_rgba[2])

children = ordered_children(t1)
k = len(children)
for i, (n2, t2) in enumerate(children):
# Same “type” of color as parent: lighten progressively per child.
# First child is closest to parent; later children are lighter.
lighten = 0.10 + (0.55 * (i / max(k - 1, 1)))
rgb = blend_with_white(parent_rgb, lighten)
lvl2_segments.append((n2, avg_time(t2), (rgb[0], rgb[1], rgb[2], 1.0)))

s1 = self_time(t1)
if s1 > 0.0:
lvl2_segments.append(("", s1, gray))

draw_stack(1.5, lvl2_segments)

fig.tight_layout(rect=(0.02, 0.02, 0.98, 0.98))
fig.savefig(path, dpi=200, bbox_inches="tight")
plt.close(fig)


class BenchmarkABC(abc.ABC):
@classmethod
def new_from_fn(
cls,
fn: Callable[[Timer], TensorDict],
) -> "BenchmarkABC":
class FnBenchmark(BenchmarkABC):
@classmethod
def new(cls) -> "FnBenchmark":
return FnBenchmark()

def run_instance(self, timer: Timer) -> TensorDict:
return fn(timer)

return FnBenchmark()

@classmethod
@abc.abstractmethod
def new(cls: type[Self]) -> Self:
"""
Initialize any state needed for the benchmark.
This will be called once before the benchmark is run.
"""
pass

@classmethod
def new_for_regression(cls: type[Self]) -> Self | None:
"""
Initialize any state needed for regression testing.
This will be called once before regression tests are run.

If regression testing is not needed, this can return None,
and regression testing will not be run.

This exists as a separate method from new so that it can
use small data sizes more conducive to storing regression targets in git.
"""
return None

@classmethod
def run_benchmark(cls, iters=10, warmup=1) -> BenchmarkResult:
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available, cannot run benchmark.")
null_timer = NullTimer()
benchmark = cls.new()
for _ in range(warmup):
benchmark.run_instance(null_timer)
timer = CUDATimer()
with benchmark_memory() as bm:
for _ in range(iters):
with timer:
benchmark.run_instance(timer)
return BenchmarkResult(
timer=timer.result,
memory=bm.result,
)

@classmethod
def run_regression(cls) -> TensorDict | None:
benchmark = cls.new_for_regression()
if benchmark is None:
return None
null_timer = NullTimer()
return benchmark.run_instance(null_timer)

@abc.abstractmethod
def run_instance(self: Self, timer: Timer) -> TensorDict:
"""
Run the benchmark. This will be called multiple times,
and should return a TensorDict of results.

This must not mutate any state on self, since the same instance may be
used across multiple iterations.
"""
pass


_BENCHMARKS: dict[str, type[BenchmarkABC]] = {}


def register_benchmark(name: str) -> Callable[[type[BenchmarkABC]], type[BenchmarkABC]]:
def _register(fn: type[BenchmarkABC]) -> type[BenchmarkABC]:
if name in _BENCHMARKS:
raise ValueError(f"Benchmark with name '{name}' is already registered.")
_BENCHMARKS[name] = fn
return fn

return _register


def get_benchmarks() -> dict[str, type[BenchmarkABC]]:
return _BENCHMARKS.copy()
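
To make the new API concrete, here is a minimal usage sketch (not part of this PR's diff; the benchmark name, tensor shapes, and output path are illustrative assumptions). It registers a trivial benchmark, runs it on a CUDA device, renders the timing/memory breakdown to a PNG, and checks a round-tripped copy of the result with assert_close:

import torch

from fme.core.benchmark.benchmark import (
    BenchmarkABC,
    BenchmarkResult,
    get_benchmarks,
    register_benchmark,
)
from fme.core.benchmark.timer import Timer
from fme.core.typing_ import TensorDict


@register_benchmark("example_matmul")  # hypothetical name, for illustration only
class MatmulBenchmark(BenchmarkABC):
    @classmethod
    def new(cls) -> "MatmulBenchmark":
        # one-time setup; a real benchmark would construct models/data here
        return cls()

    def run_instance(self, timer: Timer) -> TensorDict:
        # must not mutate self: the same instance is reused across iterations.
        # The timer argument can be used for finer-grained timing; its
        # interface lives in fme.core.benchmark.timer (not shown in this diff).
        a = torch.randn(1024, 1024, device="cuda")
        b = torch.randn(1024, 1024, device="cuda")
        return {"out": a @ b}


if __name__ == "__main__":
    # run_benchmark raises RuntimeError unless a CUDA device is available
    benchmark_cls = get_benchmarks()["example_matmul"]
    result = benchmark_cls.run_benchmark(iters=10, warmup=1)
    result.to_png("example_matmul.png", label="example_matmul")

    # round-trip through a plain dict, as one might when loading a saved
    # baseline, then compare timings and memory within a relative tolerance
    baseline = BenchmarkResult.from_dict(result.asdict())
    result.assert_close(baseline, rtol=0.02)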