diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 56aa68e18..03f1d06a2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: args: [--maxkb=250] exclude: | (?x)^( - fme/ace/aggregator/inference/testdata/.*-regression.pt + fme/ace/aggregator/inference/testdata/.*-regression\.pt | )$ - id: trailing-whitespace - id: file-contents-sorter diff --git a/conftest.py b/conftest.py index 84c62901a..930e9fc68 100644 --- a/conftest.py +++ b/conftest.py @@ -1,3 +1,7 @@ +import os + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # required for determinism + import gc import signal from unittest import mock @@ -5,6 +9,15 @@ import pytest import torch +from fme.core.rand import set_seed + + +@pytest.fixture(autouse=True, scope="session") +def deterministic_pytorch(): + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.benchmark = False + set_seed(0) + def pytest_addoption(parser): parser.addoption( diff --git a/fme/core/__init__.py b/fme/core/__init__.py index 67658ccfa..a2650da09 100644 --- a/fme/core/__init__.py +++ b/fme/core/__init__.py @@ -1,3 +1,4 @@ +from . import models as _ # to trigger registrations from .atmosphere_data import AtmosphereData from .device import get_device, using_gpu from .gridded_ops import GriddedOperations @@ -14,6 +15,8 @@ from .rand import set_seed from .registry import Registry +del _ + __all__ = [ "spherical_area_weights", "weighted_mean", diff --git a/fme/core/benchmark/.gitignore b/fme/core/benchmark/.gitignore new file mode 100644 index 000000000..1a06816d8 --- /dev/null +++ b/fme/core/benchmark/.gitignore @@ -0,0 +1 @@ +results diff --git a/fme/core/benchmark/__init__.py b/fme/core/benchmark/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fme/core/benchmark/benchmark.py b/fme/core/benchmark/benchmark.py new file mode 100644 index 000000000..c227264bf --- /dev/null +++ b/fme/core/benchmark/benchmark.py @@ -0,0 +1,305 @@ +import abc +import dataclasses +import pathlib +from collections.abc import Callable +from typing import Self, TypeVar + +import dacite +import matplotlib.pyplot as plt +import torch + +from fme.core.benchmark.memory import MemoryResult, benchmark_memory +from fme.core.benchmark.timer import CUDATimer, NullTimer, Timer, TimerResult +from fme.core.typing_ import TensorDict + + +@dataclasses.dataclass +class BenchmarkResult: + memory: MemoryResult + timer: TimerResult + + def __repr__(self) -> str: + return f"BenchmarkResult(memory={self.memory}, timer={self.timer})" + + def asdict(self) -> dict: + return dataclasses.asdict(self) + + @classmethod + def from_dict(cls, d: dict) -> "BenchmarkResult": + return dacite.from_dict(cls, d, config=dacite.Config(strict=True)) + + def assert_close( + self, other: "BenchmarkResult", rtol=0.02, children_rtol=0.02 + ) -> None: + try: + self.timer.assert_close(other.timer, rtol=rtol, children_rtol=children_rtol) + except AssertionError as e: + raise AssertionError(f"Timer results differ: {e}") from e + try: + self.memory.assert_close(other.memory, rtol=rtol) + except AssertionError as e: + raise AssertionError(f"Memory results differ: {e}") from e + + def to_png( + self, path: str | pathlib.Path, label: str, child: str | None = None + ) -> None: + # note this function was generated with AI + def avg_time(t: TimerResult) -> float: + return float(t.avg_time) + + def self_time(t: TimerResult) -> float: + t_avg = avg_time(t) + c_avg = sum(avg_time(c) for c in t.children.values()) + return max(t_avg - c_avg, 0.0) + + def fmt_time(ms: float) -> str: + if ms >= 1000.0: + return f"{ms/1000.0:.2f}s" + if ms >= 10.0: + return f"{ms:.1f}ms" + return f"{ms:.2f}ms" + + def label_ok(name: str, ms: float, frac_of_root: float) -> bool: + if not name: + return False + return frac_of_root >= 0.05 + + def sorted_children(t: TimerResult) -> list[tuple[str, TimerResult]]: + return sorted( + t.children.items(), key=lambda kv: avg_time(kv[1]), reverse=True + ) + + def blend_with_white( + rgb: tuple[float, float, float], amount: float + ) -> tuple[float, float, float]: + # amount in [0,1]: 0 -> original, 1 -> white + return ( + rgb[0] + (1.0 - rgb[0]) * amount, + rgb[1] + (1.0 - rgb[1]) * amount, + rgb[2] + (1.0 - rgb[2]) * amount, + ) + + root = self.timer + if child is not None: + for part in child.split("."): + if part not in root.children: + raise ValueError(f"Child '{child}' not found in timer results.") + root = root.children[part] + root_avg = avg_time(root) + + max_alloc_mb = self.memory.max_alloc / (1024.0 * 1024.0) + + fig = plt.figure(figsize=(8, 6), constrained_layout=True) + if root_avg <= 0.0: + fig.suptitle( + f"Benchmark for {label}\ntotal=0.00s, max_alloc={max_alloc_mb:.1f} MB", + fontsize=14, + ) + ax0 = fig.add_subplot(1, 1, 1) + ax0.text(0.5, 0.5, "No timing data", ha="center", va="center") + ax0.axis("off") + fig.savefig(path, dpi=200) + plt.close(fig) + return + + fig.suptitle( + f"Benchmark for {label}\ntotal={fmt_time(root_avg)}, " + f"max_alloc={max_alloc_mb:.1f} MB", + fontsize=14, + ) + + ax = fig.add_subplot(1, 1, 1) + ax.set_xlim(0, 2) + ax.set_ylim(0, root_avg) + ax.set_xticks([0.5, 1.5]) + ax.set_xticklabels(["Level 1", "Level 2"]) + ax.set_ylabel("Avg time") + ax.set_yticks([]) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + gray = (0.85, 0.85, 0.85, 1.0) + cmap = plt.get_cmap("tab20") + + lvl1 = sorted_children(root) + lvl1_names = [n for n, _ in lvl1] + lvl1_index = {n: i for i, n in enumerate(lvl1_names)} + + # Level 1 stack (root children + root self in gray, unlabeled) + lvl1_segments: list[tuple[str, float, tuple[float, float, float, float]]] = [] + for n1, t1 in lvl1: + base = cmap(lvl1_index[n1] % cmap.N) + lvl1_segments.append((n1, avg_time(t1), base)) + r_self = self_time(root) + if r_self > 0.0: + lvl1_segments.append(("", r_self, gray)) + + def draw_stack( + x_center: float, + segments: list[tuple[str, float, tuple[float, float, float, float]]], + ) -> None: + width = 0.86 + y = 0.0 + for name, sec, color in segments: + if sec <= 0.0: + continue + ax.bar( + x_center, + sec, + bottom=y, + width=width, + align="center", + color=color, + edgecolor="white", + linewidth=1.0, + ) + frac = sec / root_avg + if label_ok(name, sec, frac): + ax.text( + x_center, + y + sec / 2.0, + f"{name}\n{fmt_time(sec)}", + ha="center", + va="center", + fontsize=9, + rotation=0, # keep horizontal to avoid cross-column overlap + clip_on=True, + ) + y += sec + if y < root_avg: + ax.bar( + x_center, + root_avg - y, + bottom=y, + width=width, + align="center", + color=gray, + edgecolor="white", + linewidth=1.0, + ) + + draw_stack(0.5, lvl1_segments) + + # Level 2 stack: + # For each level-1 slice, stack its children + # (colored as parent hue variants) + self in gray. + lvl2_segments: list[tuple[str, float, tuple[float, float, float, float]]] = [] + for n1, t1 in lvl1: + parent_rgba = cmap(lvl1_index[n1] % cmap.N) + parent_rgb = (parent_rgba[0], parent_rgba[1], parent_rgba[2]) + + children = sorted_children(t1) + k = len(children) + for i, (n2, t2) in enumerate(children): + # Same “type” of color as parent: lighten progressively per child. + # First child is closest to parent; later children are lighter. + lighten = 0.10 + (0.55 * (i / max(k - 1, 1))) + rgb = blend_with_white(parent_rgb, lighten) + lvl2_segments.append((n2, avg_time(t2), (rgb[0], rgb[1], rgb[2], 1.0))) + + s1 = self_time(t1) + if s1 > 0.0: + lvl2_segments.append(("", s1, gray)) + + draw_stack(1.5, lvl2_segments) + + fig.tight_layout(rect=(0.02, 0.02, 0.98, 0.98)) + fig.savefig(path, dpi=200, bbox_inches="tight") + plt.close(fig) + + +T = TypeVar("T") + + +class BenchmarkABC(abc.ABC): + @classmethod + def new_from_fn( + cls, + fn: Callable[[Timer], TensorDict], + ) -> "BenchmarkABC": + class FnBenchmark(BenchmarkABC): + @classmethod + def new(cls) -> "FnBenchmark": + return FnBenchmark() + + def run_instance(self, timer: Timer) -> TensorDict: + return fn(timer) + + return FnBenchmark() + + @classmethod + @abc.abstractmethod + def new(cls: type[Self]) -> Self: + """ + Initialize any state needed for the benchmark. + This will be called once before the benchmark is run. + """ + pass + + @classmethod + def new_for_regression(cls: type[Self]) -> Self | None: + """ + Initialize any state needed for regression testing. + This will be called once before regression tests are run. + + If regression testing is not needed, this can return None, + and regression testing will not be run. + + This exists as a separate method from new so that it can + use small data sizes more conducive to storing regression targets in git. + """ + return None + + @classmethod + def run_benchmark(cls, iters=10, warmup=1) -> BenchmarkResult: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available, cannot run benchmark.") + null_timer = NullTimer() + benchmark = cls.new() + for _ in range(warmup): + benchmark.run_instance(null_timer) + timer = CUDATimer() + with benchmark_memory() as bm: + for _ in range(iters): + with timer: + benchmark.run_instance(timer) + return BenchmarkResult( + timer=timer.result, + memory=bm.result, + ) + + @classmethod + def run_regression(cls) -> TensorDict | None: + benchmark = cls.new_for_regression() + if benchmark is None: + return None + null_timer = NullTimer() + return benchmark.run_instance(null_timer) + + @abc.abstractmethod + def run_instance(self: Self, timer: Timer) -> TensorDict: + """ + Run the benchmark. This will be called multiple times, + and should return a TensorDict of results. + + This must not mutate any state on self, since the same instance may be + used across multiple iterations. + """ + pass + + +_BENCHMARKS: dict[str, type[BenchmarkABC]] = {} + + +def register_benchmark(name: str) -> Callable[[type[BenchmarkABC]], type[BenchmarkABC]]: + def _register(fn: type[BenchmarkABC]) -> type[BenchmarkABC]: + if name in _BENCHMARKS: + raise ValueError(f"Benchmark with name '{name}' is already registered.") + _BENCHMARKS[name] = fn + return fn + + return _register + + +def get_benchmarks() -> dict[str, type[BenchmarkABC]]: + return _BENCHMARKS.copy() diff --git a/fme/core/benchmark/memory.py b/fme/core/benchmark/memory.py new file mode 100644 index 000000000..d104b0035 --- /dev/null +++ b/fme/core/benchmark/memory.py @@ -0,0 +1,86 @@ +import dataclasses +from typing import Literal + +import torch + +_benchmark_memory_started = False + + +@dataclasses.dataclass +class MemoryResult: + max_alloc: int + max_reserved: int + + def assert_close(self, other: "MemoryResult", rtol=0.02) -> None: + if not torch.isclose( + torch.tensor(self.max_alloc, dtype=torch.float64), + torch.tensor(other.max_alloc, dtype=torch.float64), + rtol=rtol, + ): + raise AssertionError( + f"max_alloc differs: {self.max_alloc} vs " + f"{other.max_alloc} given rtol={rtol}" + ) + if not torch.isclose( + torch.tensor(self.max_reserved, dtype=torch.float64), + torch.tensor(other.max_reserved, dtype=torch.float64), + rtol=rtol, + ): + raise AssertionError( + f"max_reserved differs: {self.max_reserved} vs " + f"{other.max_reserved} given rtol={rtol}" + ) + + +class MemoryBenchmark: + def __init__(self): + self._started = False + self._ended = False + + def __enter__(self) -> "MemoryBenchmark": + global _benchmark_memory_started + if _benchmark_memory_started: + raise RuntimeError( + "benchmark_memory cannot be nested due to its use of globals" + ) + _benchmark_memory_started = True + if self._started: + raise RuntimeError( + "MemoryBenchmark cannot be nested due to its use of globals" + ) + if self._ended: + raise RuntimeError("MemoryBenchmark cannot be reused after it has ended.") + self._started = True + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + self._max_alloc = 0 + self._max_reserved = 0 + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]: + torch.cuda.synchronize() + global _benchmark_memory_started + _benchmark_memory_started = False + self._started = False + self._ended = True + self._max_alloc = torch.cuda.max_memory_allocated() + self._max_reserved = torch.cuda.max_memory_reserved() + return False # Don't suppress exceptions + + @property + def result(self) -> MemoryResult: + if self._started: + raise RuntimeError( + "MemoryBenchmark is still running. " + "Please exit the context before getting results." + ) + if not self._ended: + raise RuntimeError( + "MemoryBenchmark has not been run yet. " + "Please enter and exit the context before getting results." + ) + return MemoryResult(max_alloc=self._max_alloc, max_reserved=self._max_reserved) + + +def benchmark_memory() -> MemoryBenchmark: + return MemoryBenchmark() diff --git a/fme/core/benchmark/run.py b/fme/core/benchmark/run.py new file mode 100644 index 000000000..b5a2988cb --- /dev/null +++ b/fme/core/benchmark/run.py @@ -0,0 +1,107 @@ +import argparse +import os +import pathlib +import subprocess + +import torch + +from fme.core.benchmark.benchmark import get_benchmarks + +RESULTS_PATH = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "results" + +_GIT_COMMIT: str | None = None + + +def get_git_commit() -> str: + global _GIT_COMMIT + if _GIT_COMMIT is None: + args = ["git", "rev-parse", "--short", "HEAD"] + _GIT_COMMIT = ( + subprocess.check_output(args, stderr=subprocess.DEVNULL).decode().strip() + ) + return _GIT_COMMIT + + +def get_device_name() -> str: + if torch.cuda.is_available(): + return torch.cuda.get_device_properties(0).name + else: + return "CPU" + + +def main(names: list[str] | None, iters: int, child: str | None = None) -> None: + RESULTS_PATH.mkdir(exist_ok=True) + device_name = get_device_name() + + print(f"Running benchmarks on device: {device_name}") + benchmarks = get_benchmarks() + if names is not None: + if any(name not in benchmarks for name in names): + print("Some specified benchmarks not found. Available benchmarks:") + for name in benchmarks: + print(f" - {name}") + return + benchmarks_to_run = {name: benchmarks[name] for name in names} + else: + benchmarks_to_run = benchmarks + + def get_label(name): + return f"{name} on {device_name} at commit {get_git_commit()}" + + def get_filename(name) -> pathlib.Path: + safe_name = name.replace("/", "_").replace(".", "_").lower() + safe_device_name = device_name.replace(" ", "_").replace("/", "_").lower() + return RESULTS_PATH / f"{safe_name}_{safe_device_name}_{get_git_commit()}.png" + + for name, cls in benchmarks_to_run.items(): + print(f"Running benchmark: {name}") + result = cls.run_benchmark(iters=iters) + result.to_png(get_filename(name), label=get_label(name)) + if child is not None: + child_name = f"{name}.{child}" + child_label = get_label(child_name) + print(f" Generating report for child timer: {child_label}") + result.to_png(get_filename(child_name), label=child_label, child=child) + print(f" Result: {result}") + + +def get_benchmark_label(name): + device_name = get_device_name() + return f"{name} on {device_name} at commit {get_git_commit()}" + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run registered benchmarks.") + parser.add_argument( + "benchmark", + type=str, + nargs="?", + default=None, + help=( + "Name of the benchmark to run. If not provided, " + "all benchmarks will be run." + ), + ) + parser.add_argument( + "--child", + type=str, + default=None, + help=( + "If provided, the child timer to generate a report for. " + "This should be a dot-separated path to a child timer, " + "e.g. 'forward' or 'forward.linear'." + ), + ) + parser.add_argument( + "--iters", + type=int, + default=10, + help="Number of iterations to run each benchmark for.", + ) + args = parser.parse_args() + + main( + names=[args.benchmark] if args.benchmark else None, + iters=args.iters, + child=args.child, + ) diff --git a/fme/core/benchmark/test_benchmark.py b/fme/core/benchmark/test_benchmark.py new file mode 100644 index 000000000..c7c95629e --- /dev/null +++ b/fme/core/benchmark/test_benchmark.py @@ -0,0 +1,54 @@ +import os + +import pytest +import torch + +import fme # to trigger registration of benchmarks +from fme.core.benchmark.benchmark import BenchmarkABC, get_benchmarks +from fme.core.rand import set_seed +from fme.core.testing.regression import validate_tensor_dict + +del fme + +DIR = os.path.abspath(os.path.dirname(__file__)) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_run_benchmark(): + def benchmark_fn(timer): + torch.cuda._sleep(100_000_000) + + benchmark = BenchmarkABC.new_from_fn(benchmark_fn) + + first_result = benchmark.run_benchmark(iters=15, warmup=1) + assert first_result.timer.total_runs == 15 + second_result = benchmark.run_benchmark(iters=20, warmup=1) + assert second_result.timer.total_runs == 20 + torch.testing.assert_close( + first_result.timer.avg_time, second_result.timer.avg_time, rtol=0.2, atol=0 + ) + + +def test_benchmarks_are_not_empty(): + assert ( + len(get_benchmarks()) > 0 + ), "No benchmarks were registered, but at least one was expected." + + +BENCHMARKS = get_benchmarks() + + +@pytest.mark.parametrize("benchmark_name", BENCHMARKS.keys()) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_regression(benchmark_name: str): + set_seed(0) + benchmark_cls = BENCHMARKS[benchmark_name] + regression_result = benchmark_cls.run_regression() + if regression_result is None: + pytest.skip("Benchmark does not have regression targets.") + # If run_regression returns something, we expect it to be a TensorDict of results + assert isinstance(regression_result, dict) + validate_tensor_dict( + regression_result, + os.path.join(DIR, "testdata", f"{benchmark_name}-regression.pt"), + ) diff --git a/fme/core/benchmark/test_memory.py b/fme/core/benchmark/test_memory.py new file mode 100644 index 000000000..fc986712b --- /dev/null +++ b/fme/core/benchmark/test_memory.py @@ -0,0 +1,29 @@ +import pytest +import torch + +from fme.core.benchmark.memory import benchmark_memory + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_cannot_nest_benchmark(): + with benchmark_memory(): + with pytest.raises(RuntimeError, match="benchmark_memory cannot be nested"): + with benchmark_memory(): + pass + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_cannot_get_result_before_end(): + with benchmark_memory() as bm: + with pytest.raises(RuntimeError, match="MemoryBenchmark is still running"): + bm.result + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_larger_array_uses_larger_memory(): + with benchmark_memory() as bm1: + _ = torch.randn(100, 100, device="cuda") + with benchmark_memory() as bm2: + _ = torch.randn(200, 200, device="cuda") + + assert bm2.result.max_alloc > bm1.result.max_alloc diff --git a/fme/core/benchmark/test_timer.py b/fme/core/benchmark/test_timer.py new file mode 100644 index 000000000..dd9cd5624 --- /dev/null +++ b/fme/core/benchmark/test_timer.py @@ -0,0 +1,36 @@ +from unittest.mock import patch + +import pytest +import torch + +from fme.core.benchmark.timer import CUDATimer + + +@pytest.mark.parametrize("is_available", [True, False]) +def test_new_if_available(is_available: bool): + from fme.core.benchmark.timer import CUDATimer, NullTimer + + with patch("torch.cuda.is_available", return_value=is_available): + timer = CUDATimer.new_if_available() + if is_available: + assert isinstance(timer, CUDATimer) + else: + assert isinstance(timer, NullTimer) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA is not available, skipping CUDATimer tests.", +) +def test_timer_with_child(): + timer = CUDATimer() + with timer: + # get cuda to wait + torch.cuda._sleep(100_000) + with timer.child("child"): + torch.cuda._sleep(100_000) + result = timer.result + assert "child" in result.children + # parent time should include the child time, so it should be + # at least 2x the child time (since we sleep for the same amount of time in both) + assert result.avg_time >= 2.0 * result.children["child"].avg_time diff --git a/fme/core/benchmark/testdata/csfno_block-regression.pt b/fme/core/benchmark/testdata/csfno_block-regression.pt new file mode 100644 index 000000000..3ae9270b3 Binary files /dev/null and b/fme/core/benchmark/testdata/csfno_block-regression.pt differ diff --git a/fme/core/benchmark/testdata/csfno_block_8_groups-regression.pt b/fme/core/benchmark/testdata/csfno_block_8_groups-regression.pt new file mode 100644 index 000000000..8d4430508 Binary files /dev/null and b/fme/core/benchmark/testdata/csfno_block_8_groups-regression.pt differ diff --git a/fme/core/benchmark/timer.py b/fme/core/benchmark/timer.py new file mode 100644 index 000000000..230fe4e72 --- /dev/null +++ b/fme/core/benchmark/timer.py @@ -0,0 +1,171 @@ +import collections +import contextlib +import dataclasses +from typing import Literal, Protocol, Self + +import torch + + +@dataclasses.dataclass +class TimerResult: + total_runs: int + avg_time: float + children: dict[str, "TimerResult"] + + def assert_close(self, other: "TimerResult", rtol=0.02, children_rtol=0.02) -> None: + if self.total_runs != other.total_runs: + raise AssertionError( + f"total_runs differ: {self.total_runs} vs {other.total_runs}" + ) + if not torch.isclose( + torch.tensor(self.avg_time), torch.tensor(other.avg_time), rtol=rtol + ): + raise AssertionError( + f"avg_time differ: {self.avg_time} vs " + f"{other.avg_time} given rtol={rtol}" + ) + if self.children.keys() != other.children.keys(): + raise AssertionError( + f"children keys differ: {self.children.keys()} vs " + f"{other.children.keys()}" + ) + for key in self.children.keys(): + try: + self.children[key].assert_close( + other.children[key], rtol=children_rtol, children_rtol=children_rtol + ) + except AssertionError as e: + raise AssertionError(f"child '{key}' differ: {e}") from e + + +class Timer(Protocol): + def child(self, name: str) -> Self: ... + def __enter__(self) -> Self: ... + def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]: ... + + +class NullTimer: + def context(self, name: str) -> contextlib.nullcontext: + return contextlib.nullcontext() + + def child(self, name: str) -> "Self": + return self + + def __enter__(self) -> "Self": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]: + return False + + def report(self) -> TimerResult: + return TimerResult(total_runs=0, avg_time=0.0, children={}) + + +_: Timer = NullTimer() +del _ + + +class EventPair: + def __init__(self): + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + self._stream = None + self._start_recorded = False + self._end_recorded = False + + def record_start(self): + if self._start_recorded: + raise RuntimeError( + "record_start has already been called on this EventPair." + ) + self._stream = torch.cuda.current_stream() + self.start.record(self._stream) + self._start_recorded = True + + def record_end(self): + if not self._start_recorded: + raise RuntimeError("record_start must be called before record_end") + if self._end_recorded: + raise RuntimeError("record_end has already been called on this EventPair.") + if self._stream is None: + raise RuntimeError("record_start must be called before record_end") + self.end.record(self._stream) + self._end_recorded = True + + def elapsed_time_ms(self) -> float: + if not self._start_recorded or not self._end_recorded: + raise RuntimeError( + "Both record_start and record_end must be called " + "before elapsed_time_ms can be called." + ) + return self.start.elapsed_time(self.end) + + +class CUDATimer: + def __init__(self): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available, cannot use CUDATimer.") + self._children: collections.defaultdict[str, CUDATimer] = ( + collections.defaultdict(CUDATimer) + ) + self._event_pairs: list[EventPair] = [] + self._entered = False + self._result: TimerResult | None = None + + @classmethod + def new_if_available(cls) -> "CUDATimer | NullTimer": + if torch.cuda.is_available(): + return cls() + else: + return NullTimer() + + def __enter__(self): + if self._entered: + raise RuntimeError("CUDATimer is already entered.") + self._entered = True + self._event_pairs.append(EventPair()) + self._event_pairs[-1].record_start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self._event_pairs: + raise RuntimeError("CUDATimer context was not properly entered.") + self._event_pairs[-1].record_end() + self._entered = False + return False + + def child(self, name: str) -> "CUDATimer": + if not self._entered: + raise RuntimeError( + "CUDATimer child cannot be used before entering the timer." + ) + return self._children[name] + + @property + def _avg_time(self) -> float: + if len(self._event_pairs) == 0: + raise RuntimeError( + "CUDATimer report cannot be generated before entering the timer." + ) + total_time = sum( + event_pair.elapsed_time_ms() for event_pair in self._event_pairs + ) + return total_time / len(self._event_pairs) + + def _child_reports(self) -> dict[str, TimerResult]: + return {name: child.result for name, child in self._children.items()} + + @property + def result(self) -> TimerResult: + if self._result is None: + torch.cuda.synchronize() + self._result = TimerResult( + total_runs=len(self._event_pairs), + avg_time=self._avg_time, + children=self._child_reports(), + ) + return self._result + + +__: type[Timer] = CUDATimer +del __ diff --git a/fme/core/models/__init__.py b/fme/core/models/__init__.py new file mode 100644 index 000000000..4ce828d37 --- /dev/null +++ b/fme/core/models/__init__.py @@ -0,0 +1,3 @@ +from . import conditional_sfno as _ # to trigger registrations + +del _ diff --git a/fme/core/models/conditional_sfno/__init__.py b/fme/core/models/conditional_sfno/__init__.py index e69de29bb..d37a22639 100644 --- a/fme/core/models/conditional_sfno/__init__.py +++ b/fme/core/models/conditional_sfno/__init__.py @@ -0,0 +1,3 @@ +from . import benchmark as _ # to trigger registrations + +del _ diff --git a/fme/core/models/conditional_sfno/benchmark.py b/fme/core/models/conditional_sfno/benchmark.py new file mode 100644 index 000000000..62b4ef992 --- /dev/null +++ b/fme/core/models/conditional_sfno/benchmark.py @@ -0,0 +1,126 @@ +from typing import Self + +import torch + +from fme.core.benchmark.benchmark import BenchmarkABC, register_benchmark +from fme.core.benchmark.timer import Timer +from fme.core.device import get_device +from fme.core.models.conditional_sfno.layers import Context, ContextConfig +from fme.core.models.conditional_sfno.sfnonet import FourierNeuralOperatorBlock +from fme.core.models.conditional_sfno.sht import InverseRealSHT, RealSHT +from fme.core.typing_ import TensorDict + + +def get_block_benchmark(filter_num_groups: int) -> type[BenchmarkABC]: + class BlockBenchmark(BenchmarkABC): + def __init__( + self, block: FourierNeuralOperatorBlock, x: torch.Tensor, context: Context + ): + self.block = block + self.x = x + self.context = context + + def run_instance(self, timer: Timer) -> TensorDict: + result = self.block(self.x, self.context, timer=timer) + return {"output": result.detach()} + + @classmethod + def new(cls) -> Self: + B = 2 + C = 512 + H = 180 + L = 360 + G = filter_num_groups + conditional_embed_dim_noise = 64 + conditional_embed_dim_labels = 3 + conditional_embed_dim_pos = 32 + return cls._new_with_params( + B=B, + C=C, + H=H, + L=L, + G=G, + conditional_embed_dim_noise=conditional_embed_dim_noise, + conditional_embed_dim_labels=conditional_embed_dim_labels, + conditional_embed_dim_pos=conditional_embed_dim_pos, + ) + + @classmethod + def _new_with_params( + cls, + B: int, + C: int, + H: int, + L: int, + G: int, + conditional_embed_dim_noise: int, + conditional_embed_dim_labels: int, + conditional_embed_dim_pos: int, + ) -> Self: + G = filter_num_groups + device = get_device() + conditional_embed_dim_scalar = 0 + embedding_scalar = None + context_embedding_noise = torch.randn( + B, conditional_embed_dim_noise, H, L + ).to(device) + context_embedding_labels = torch.randn(B, conditional_embed_dim_labels).to( + device + ) + context_embedding_pos = torch.randn(B, conditional_embed_dim_pos, H, L).to( + device + ) + context = Context( + embedding_scalar=embedding_scalar, + embedding_pos=context_embedding_pos, + noise=context_embedding_noise, + labels=context_embedding_labels, + ) + x = torch.randn(B, C, H, L, device=get_device()) + forward = RealSHT(nlat=H, nlon=L) + inverse = InverseRealSHT(nlat=H, nlon=L) + context_config = ContextConfig( + embed_dim_scalar=conditional_embed_dim_scalar, + embed_dim_noise=conditional_embed_dim_noise, + embed_dim_labels=conditional_embed_dim_labels, + embed_dim_pos=conditional_embed_dim_pos, + ) + block = FourierNeuralOperatorBlock( + forward_transform=forward, + inverse_transform=inverse, + img_shape=(H, L), + embed_dim=C, + filter_type="linear", + operator_type="dhconv", + use_mlp=True, + context_config=context_config, + filter_num_groups=G, + ).to(device) + return cls(block=block, x=x, context=context) + + @classmethod + def new_for_regression(cls): + B = 1 + C = 16 + H = 9 + L = 18 + G = 2 + conditional_embed_dim_noise = 4 + conditional_embed_dim_labels = 3 + conditional_embed_dim_pos = 2 + return cls._new_with_params( + B=B, + C=C, + H=H, + L=L, + G=G, + conditional_embed_dim_noise=conditional_embed_dim_noise, + conditional_embed_dim_labels=conditional_embed_dim_labels, + conditional_embed_dim_pos=conditional_embed_dim_pos, + ) + + return BlockBenchmark + + +register_benchmark("csfno_block")(get_block_benchmark(filter_num_groups=1)) +register_benchmark("csfno_block_8_groups")(get_block_benchmark(filter_num_groups=8)) diff --git a/fme/core/models/conditional_sfno/layers.py b/fme/core/models/conditional_sfno/layers.py index 47648d781..5f6dcbe8e 100644 --- a/fme/core/models/conditional_sfno/layers.py +++ b/fme/core/models/conditional_sfno/layers.py @@ -24,6 +24,7 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint +from fme.core.benchmark.timer import Timer, NullTimer from fme.core.models.conditional_sfno.lora import LoRAConv2d from .activations import ComplexReLU @@ -223,7 +224,12 @@ def reset_parameters(self): torch.nn.init.constant_(self.W_bias_pos.weight, 0.0) # no bias on 2d layers as it is already handled in the non-2d layers - def forward(self, x: torch.Tensor, context: Context) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + context: Context, + timer: Timer = NullTimer(), + ) -> torch.Tensor: """ Conditional Layer Normalization @@ -242,52 +248,58 @@ def forward(self, x: torch.Tensor, context: Context) -> torch.Tensor: self.W_scale_labels is not None or self.W_bias_labels is not None ): raise ValueError("labels must be provided") - if self.W_scale is not None: - if context.embedding_scalar is None: - raise ValueError("embedding_scalar must be provided") - scale: torch.Tensor = ( - self.W_scale(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) - ) - else: - scale = torch.ones( - list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype - ) + with timer.child("compute_scaling_and_bias"): + if self.W_scale is not None: + if context.embedding_scalar is None: + raise ValueError("embedding_scalar must be provided") + scale: torch.Tensor = ( + self.W_scale(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) + ) + else: + scale = torch.ones( + list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype + ) - if self.W_scale_2d is not None: - if context.noise is None: - raise ValueError("embedding_2d must be provided") - scale = scale + self.W_scale_2d(context.noise) - if self.W_bias is not None: - if context.embedding_scalar is None: - raise ValueError("embedding_scalar must be provided") - bias: torch.Tensor = ( - self.W_bias(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) - ) - else: - bias = torch.zeros( - list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype - ) + if self.W_scale_2d is not None: + if context.noise is None: + raise ValueError("embedding_2d must be provided") + scale = scale + self.W_scale_2d(context.noise) + if self.W_bias is not None: + if context.embedding_scalar is None: + raise ValueError("embedding_scalar must be provided") + bias: torch.Tensor = ( + self.W_bias(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) + ) + else: + bias = torch.zeros( + list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype + ) - if self.W_scale_labels is not None: - scale = scale + self.W_scale_labels(context.labels).unsqueeze(-1).unsqueeze( - -1 - ) - if self.W_bias_labels is not None: - bias = bias + self.W_bias_labels(context.labels).unsqueeze(-1).unsqueeze(-1) - if self.W_bias_2d is not None: - if context.noise is None: - raise ValueError("embedding_2d must be provided") - bias = bias + self.W_bias_2d(context.noise) - if self.W_scale_pos is not None: - if context.embedding_pos is None: - raise ValueError("embedding_pos must be provided") - scale = scale + self.W_scale_pos(context.embedding_pos) - if self.W_bias_pos is not None: - if context.embedding_pos is None: - raise ValueError("embedding_pos must be provided") - bias = bias + self.W_bias_pos(context.embedding_pos) - x_norm: torch.Tensor = self.norm(x) - return x_norm * scale + bias + if self.W_scale_labels is not None: + scale = scale + self.W_scale_labels(context.labels).unsqueeze( + -1 + ).unsqueeze(-1) + if self.W_bias_labels is not None: + bias = bias + self.W_bias_labels(context.labels).unsqueeze( + -1 + ).unsqueeze(-1) + if self.W_bias_2d is not None: + if context.noise is None: + raise ValueError("embedding_2d must be provided") + bias = bias + self.W_bias_2d(context.noise) + if self.W_scale_pos is not None: + if context.embedding_pos is None: + raise ValueError("embedding_pos must be provided") + scale = scale + self.W_scale_pos(context.embedding_pos) + if self.W_bias_pos is not None: + if context.embedding_pos is None: + raise ValueError("embedding_pos must be provided") + bias = bias + self.W_bias_pos(context.embedding_pos) + with timer.child("normalize"): + x_norm: torch.Tensor = self.norm(x) + with timer.child("apply_scaling_and_bias"): + return_value = x_norm * scale + bias + return return_value @torch.jit.script diff --git a/fme/core/models/conditional_sfno/makani/spectral_convolution.py b/fme/core/models/conditional_sfno/makani/spectral_convolution.py index f38894c83..e99a7f5e1 100644 --- a/fme/core/models/conditional_sfno/makani/spectral_convolution.py +++ b/fme/core/models/conditional_sfno/makani/spectral_convolution.py @@ -19,6 +19,8 @@ import torch.nn as nn from torch import amp +from fme.core.benchmark.timer import NullTimer, Timer + # import convenience functions for factorized tensors from .factorizations import get_contract_fun @@ -124,7 +126,7 @@ def __init__( if bias: self.bias = nn.Parameter(torch.zeros(1, self.out_channels, 1, 1)) - def forward(self, x): + def forward(self, x, timer: Timer = NullTimer()): dtype = x.dtype residual = x x = x.float() @@ -138,7 +140,10 @@ def forward(self, x): B, C, H, W = x.shape x = x.reshape(B, self.num_groups, C // self.num_groups, H, W) xp = self._contract( - x, self.weight, separable=self.separable, operator_type=self.operator_type + x, + self.weight, + separable=self.separable, + operator_type=self.operator_type, ) x = xp.reshape(B, self.out_channels, H, W).contiguous() diff --git a/fme/core/models/conditional_sfno/s2convolutions.py b/fme/core/models/conditional_sfno/s2convolutions.py index 93299256a..b138dd442 100644 --- a/fme/core/models/conditional_sfno/s2convolutions.py +++ b/fme/core/models/conditional_sfno/s2convolutions.py @@ -22,6 +22,8 @@ import torch_harmonics as th import torch_harmonics.distributed as thd +from fme.core.benchmark.timer import NullTimer, Timer + # import convenience functions for factorized tensors from .activations import ComplexReLU @@ -223,45 +225,51 @@ def __init__( self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) self.out_channels = out_channels - def forward(self, x): # pragma: no cover + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover dtype = x.dtype residual = x x = x.float() with torch.amp.autocast("cuda", enabled=False): - x = self.forward_transform(x.float()) + with timer.child("forward_transform"): + x = self.forward_transform(x.float()) if self._round_trip_residual: - x = x.contiguous() - residual = self.inverse_transform(x) - residual = residual.to(dtype) + with timer.child("round_trip_residual"): + x = x.contiguous() + residual = self.inverse_transform(x) + residual = residual.to(dtype) B, C, H, W = x.shape assert C % self.num_groups == 0 x = x.reshape(B, self.num_groups, C // self.num_groups, H, W) if self.lora_A is not None and self.lora_B is not None: - lora_update = _contract_lora( - self.lora_A, - self.lora_B, - x[..., : self.modes_lat_local, : self.modes_lon_local], - ) + with timer.child("lora_update"): + lora_update = _contract_lora( + self.lora_A, + self.lora_B, + x[..., : self.modes_lat_local, : self.modes_lon_local], + ) else: lora_update = 0.0 - xp = torch.zeros_like(x) - xp[..., : self.modes_lat_local, : self.modes_lon_local] = _contract_dhconv( - x[..., : self.modes_lat_local, : self.modes_lon_local], - self.weight, - ) - xp = xp + self.lora_scaling * lora_update - xp = xp.reshape(B, self.out_channels, H, W) - x = xp.contiguous() + with timer.child("dhconv"): + xp = torch.zeros_like(x) + xp[..., : self.modes_lat_local, : self.modes_lon_local] = _contract_dhconv( + x[..., : self.modes_lat_local, : self.modes_lon_local], + self.weight, + ) + xp = xp + self.lora_scaling * lora_update + xp = xp.reshape(B, self.out_channels, H, W) + x = xp.contiguous() with torch.amp.autocast("cuda", enabled=False): - x = self.inverse_transform(x) + with timer.child("inverse_transform"): + x = self.inverse_transform(x) if hasattr(self, "bias"): - x = x + self.bias + with timer.child("add_bias"): + x = x + self.bias x = x.type(dtype) @@ -320,7 +328,7 @@ def __init__( scale * torch.randn(1, out_channels, *self.output_dims) ) - def forward(self, x): # pragma: no cover + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover dtype = x.dtype x = x.float() B, C, H, W = x.shape @@ -503,7 +511,7 @@ def forward_mlp(self, x): # pragma: no cover return x - def forward(self, x): # pragma: no cover + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover dtype = x.dtype residual = x x = x.to(torch.float32) @@ -626,7 +634,7 @@ def forward_mlp(self, x): # pragma: no cover return x - def forward(self, x): # pragma: no cover + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover dtype = x.dtype x = x.to(torch.float32) diff --git a/fme/core/models/conditional_sfno/sfnonet.py b/fme/core/models/conditional_sfno/sfnonet.py index 61d35ca27..29eb986f0 100644 --- a/fme/core/models/conditional_sfno/sfnonet.py +++ b/fme/core/models/conditional_sfno/sfnonet.py @@ -24,6 +24,8 @@ import torch_harmonics as th from torch.utils.checkpoint import checkpoint +from fme.core.benchmark.timer import Timer, NullTimer + from .initialization import trunc_normal_ # wrap fft, to unify interface to spectral transforms @@ -62,7 +64,7 @@ def __init__(self, *args, **kwargs): super().__init__() self.conv = th.DiscreteContinuousConvS2(*args, **kwargs) - def forward(self, x): + def forward(self, x, timer: Timer = NullTimer()): return self.conv(x), x @@ -153,8 +155,8 @@ def __init__( else: raise (NotImplementedError) - def forward(self, x): - return self.filter(x) + def forward(self, x, timer: Timer = NullTimer()): + return self.filter(x, timer=timer) class FourierNeuralOperatorBlock(nn.Module): @@ -295,44 +297,54 @@ def __init__( lora_alpha=lora_alpha, ) - def forward(self, x, context_embedding): - x_norm = torch.zeros_like(x) - x_norm[..., : self.input_shape_loc[0], : self.input_shape_loc[1]] = self.norm0( - x[..., : self.input_shape_loc[0], : self.input_shape_loc[1]], - context_embedding, - ) - x, residual = self.filter(x_norm) - + def forward(self, x, context_embedding, timer: Timer = NullTimer()): + with timer.child("norm0") as norm0_timer: + x_norm = torch.zeros_like(x) + x_norm[..., : self.input_shape_loc[0], : self.input_shape_loc[1]] = ( + self.norm0( + x[..., : self.input_shape_loc[0], : self.input_shape_loc[1]], + context_embedding, + timer=norm0_timer, + ) + ) + with timer.child("filter") as filter_timer: + x, residual = self.filter(x_norm, timer=filter_timer) if hasattr(self, "inner_skip"): - if self.concat_skip: - x = torch.cat((x, self.inner_skip(residual)), dim=1) - x = self.inner_skip_conv(x) - else: - x = x + self.inner_skip(residual) + with timer.child("inner_skip"): + if self.concat_skip: + x = torch.cat((x, self.inner_skip(residual)), dim=1) + x = self.inner_skip_conv(x) + else: + x = x + self.inner_skip(residual) if hasattr(self, "act_layer"): - x = self.act_layer(x) - - x_norm = torch.zeros_like(x) - x_norm[..., : self.output_shape_loc[0], : self.output_shape_loc[1]] = ( - self.norm1( - x[..., : self.output_shape_loc[0], : self.output_shape_loc[1]], - context_embedding, + with timer.child("activation"): + x = self.act_layer(x) + + with timer.child("norm1") as norm1_timer: + x_norm = torch.zeros_like(x) + x_norm[..., : self.output_shape_loc[0], : self.output_shape_loc[1]] = ( + self.norm1( + x[..., : self.output_shape_loc[0], : self.output_shape_loc[1]], + context_embedding, + timer=norm1_timer, + ) ) - ) - x = x_norm + x = x_norm if hasattr(self, "mlp"): - x = self.mlp(x) + with timer.child("mlp"): + x = self.mlp(x) x = self.drop_path(x) if hasattr(self, "outer_skip"): - if self.concat_skip: - x = torch.cat((x, self.outer_skip(residual)), dim=1) - x = self.outer_skip_conv(x) - else: - x = x + self.outer_skip(residual) + with timer.child("outer_skip"): + if self.concat_skip: + x = torch.cat((x, self.outer_skip(residual)), dim=1) + x = self.outer_skip_conv(x) + else: + x = x + self.outer_skip(residual) return x diff --git a/fme/core/models/conditional_sfno/sht.py b/fme/core/models/conditional_sfno/sht.py new file mode 100644 index 000000000..dd9f8fc02 --- /dev/null +++ b/fme/core/models/conditional_sfno/sht.py @@ -0,0 +1,225 @@ +# flake8: noqa +# fmt: off +# isort: skip_file + +""" +This file contains a fix that we needed to get the SFNO to work on multiple +unroll steps in multiprocessing (e.g. multi-GPU mode.) We forked this code from +the torch harmonics sht.py file [*]. + +[*] https://github.com/NVIDIA/torch-harmonics/blob/17eefa53468d1a885d72087918eba905fa53e10a/torch_harmonics/sht.py +""" + + +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import torch +import torch.nn as nn +import torch.fft + +from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights +from torch_harmonics.legendre import _precompute_legpoly + +from fme.core.device import get_device +from fme.core.benchmark.timer import Timer, NullTimer + + +class RealSHT(nn.Module): + """ + Defines a module for computing the forward (real-valued) SHT. + Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. + The SHT is applied to the last two dimensions of the input + + [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. + [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True): + """ + Initializes the SHT Layer, precomputing the necessary quadrature weights + + Parameters: + nlat: input grid resolution in the latitudinal direction + nlon: input grid resolution in the longitudinal direction + grid: grid in the latitude direction (for now only tensor product grids are supported) + """ + + super().__init__() + + self.nlat = nlat + self.nlon = nlon + self.grid = grid + self.norm = norm + self.csphase = csphase + + # TODO: include assertions regarding the dimensions + + # compute quadrature points + if self.grid == "legendre-gauss": + cost, w = legendre_gauss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "lobatto": + cost, w = lobatto_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat-1 + elif self.grid == "equiangular": + cost, w = clenshaw_curtiss_weights(nlat, -1, 1) + # cost, w = fejer2_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "healpix": + raise(NotImplementedError("'healpix' grid not supported by InverseRealVectorSHT")) + else: + raise(ValueError("Unknown quadrature mode")) + + # apply cosine transform and flip them + tq = torch.flip(torch.arccos(cost), dims=(0,)) + + # determine the dimensions + self.mmax = mmax or self.nlon // 2 + 1 + + # combine quadrature weights with the legendre weights + pct = torch.as_tensor(_precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)) + weights = torch.einsum('mlk,k->mlk', pct, w) + + # remember quadrature weights + self.weights = weights.float().to(get_device()) + + def extra_repr(self): + """ + Pretty print module + """ + return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' + + def forward(self, x: torch.Tensor, timer: Timer = NullTimer()): + + assert(x.shape[-2] == self.nlat) + assert(x.shape[-1] == self.nlon) + with torch.autocast("cuda", enabled=False): + with timer.child("rfft"): + # rfft and view_as_complex don't support BF16, see https://github.com/pytorch/pytorch/issues/117844 + x = x.float() + + # apply real fft in the longitudinal direction + x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward") + + with timer.child("contraction"): + # do the Legendre-Gauss quadrature + x = torch.view_as_real(x) + + # distributed contraction: fork + out_shape = list(x.size()) + out_shape[-3] = self.lmax + out_shape[-2] = self.mmax + xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) + + # contraction + weights = self.weights.to(x.device).to(x.dtype) + xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 0], weights) + xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 1], weights) + x = torch.view_as_complex(xout) + + return x + +class InverseRealSHT(nn.Module): + """ + Defines a module for computing the inverse (real-valued) SHT. + Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. + nlat, nlon: Output dimensions + lmax, mmax: Input dimensions (spherical coefficients). For convenience, these are inferred from the output dimensions + + [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. + [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True): + + super().__init__() + + self.nlat = nlat + self.nlon = nlon + self.grid = grid + self.norm = norm + self.csphase = csphase + + # compute quadrature points + if self.grid == "legendre-gauss": + cost, _ = legendre_gauss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "lobatto": + cost, _ = lobatto_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat-1 + elif self.grid == "equiangular": + cost, _ = clenshaw_curtiss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "healpix": + raise(NotImplementedError("'healpix' grid not supported by RealVectorSHT")) + else: + raise(ValueError("Unknown quadrature mode")) + + # apply cosine transform and flip them + t = torch.flip(torch.arccos(cost), dims=(0,)) + + # determine the dimensions + self.mmax = mmax or self.nlon // 2 + 1 + + pct = torch.as_tensor(_precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)) + + # register buffer + self.pct = pct.float().to(get_device()) + + def extra_repr(self): + """ + Pretty print module + """ + return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' + + def forward(self, x: torch.Tensor, timer: Timer = NullTimer()): + + assert(x.shape[-2] == self.lmax) + assert(x.shape[-1] == self.mmax) + + with torch.autocast("cuda", enabled=False): + with timer.child("contraction"): + # irfft and view_as_complex don't support BF16, see https://github.com/pytorch/pytorch/issues/117844 + # Evaluate associated Legendre functions on the output nodes + x = torch.view_as_real(x).float() + + pct = self.pct.to(x.device).to(x.dtype) + rl = torch.einsum('...lm, mlk->...km', x[..., 0], pct ) + im = torch.einsum('...lm, mlk->...km', x[..., 1], pct ) + xs = torch.stack((rl, im), -1) + + # apply the inverse (real) FFT + x = torch.view_as_complex(xs) + with timer.child("irfft"): + x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") + + return x diff --git a/fme/core/models/conditional_sfno/test_sfnonet.py b/fme/core/models/conditional_sfno/test_sfnonet.py index 3230d7c87..eb01658f4 100644 --- a/fme/core/models/conditional_sfno/test_sfnonet.py +++ b/fme/core/models/conditional_sfno/test_sfnonet.py @@ -1,3 +1,4 @@ +import dataclasses import os from types import SimpleNamespace @@ -6,6 +7,7 @@ from torch import nn from fme.core.device import get_device +from fme.core.models.conditional_sfno.benchmark import get_block_benchmark from fme.core.testing.regression import validate_tensor from .layers import Context, ContextConfig @@ -221,3 +223,63 @@ def forward(self, x): assert not torch.isnan(output).any() else: assert torch.isnan(output).any() + + +@dataclasses.dataclass +class BenchmarkResult: + ms_total: float + ms_per: float + max_alloc: int + max_reserved: int + y_shape: tuple + y_dtype: torch.dtype + + +def benchmark(fn, iters=10, warmup=1) -> BenchmarkResult: + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + torch.cuda.reset_peak_memory_stats() + starter = torch.cuda.Event(enable_timing=True) + ender = torch.cuda.Event(enable_timing=True) + + starter.record() + for _ in range(iters): + y = fn() + ender.record() + torch.cuda.synchronize() + + ms = starter.elapsed_time(ender) + return BenchmarkResult( + ms_total=ms, + ms_per=ms / iters, + max_alloc=torch.cuda.max_memory_allocated(), + max_reserved=torch.cuda.max_memory_reserved(), + y_shape=tuple(y.shape), + y_dtype=y.dtype, + ) + + +@pytest.mark.skipif( + get_device().type != "cuda", + reason=( + "This test is only relevant for CUDA since " + "it's testing speed of SFNO blocks on GPU." + ), +) # noqa: E501 +def test_block_speed(): + ungrouped = get_block_benchmark(filter_num_groups=1).run_benchmark( + iters=5, warmup=1 + ) + grouped = get_block_benchmark(filter_num_groups=8).run_benchmark(iters=5, warmup=1) + assert grouped.timer.avg_time < ungrouped.timer.avg_time, ( + "Expected grouped DHConv to be faster than ungrouped, but got " + f"{grouped.timer.avg_time:.6f} ms for grouped and " + f"{ungrouped.timer.avg_time:.6f} ms for ungrouped." + ) + assert grouped.memory.max_alloc < ungrouped.memory.max_alloc, ( + "Expected grouped DHConv to use less memory than ungrouped, but got " + f"{grouped.memory.max_alloc / 1e6:.2f} MB for grouped and " + f"{ungrouped.memory.max_alloc / 1e6:.2f} MB for ungrouped." + ) diff --git a/fme/core/registry/registry.py b/fme/core/registry/registry.py index 98c00d6f2..4358447d9 100644 --- a/fme/core/registry/registry.py +++ b/fme/core/registry/registry.py @@ -57,3 +57,6 @@ def register_func(cls: type[T]) -> type[T]: def get(self, type_name: str, config: Mapping[str, Any]) -> T: cls = self._types[type_name] return cls.from_state(config) + + def get_all(self) -> dict[str, type[T]]: + return self._types.copy()