diff --git a/fme/core/__init__.py b/fme/core/__init__.py index 67658ccfa..5c2f47b01 100644 --- a/fme/core/__init__.py +++ b/fme/core/__init__.py @@ -1,3 +1,4 @@ +from . import models as _ # to trigger registrations from .atmosphere_data import AtmosphereData from .device import get_device, using_gpu from .gridded_ops import GriddedOperations @@ -14,11 +15,14 @@ from .rand import set_seed from .registry import Registry +del _ + __all__ = [ "spherical_area_weights", "weighted_mean", "weighted_mean_bias", "weighted_nanmean", + "weighted_sum", "root_mean_squared_error", "get_device", "using_gpu", diff --git a/fme/core/benchmark/.gitignore b/fme/core/benchmark/.gitignore new file mode 100644 index 000000000..1a06816d8 --- /dev/null +++ b/fme/core/benchmark/.gitignore @@ -0,0 +1 @@ +results diff --git a/fme/core/benchmark/benchmark.py b/fme/core/benchmark/benchmark.py new file mode 100644 index 000000000..47457f8d5 --- /dev/null +++ b/fme/core/benchmark/benchmark.py @@ -0,0 +1,300 @@ +import abc +import dataclasses +import pathlib +from collections.abc import Callable +from typing import Self + +import dacite +import matplotlib.pyplot as plt +import torch + +from fme.core.benchmark.memory import MemoryResult, benchmark_memory +from fme.core.benchmark.timer import CUDATimer, NullTimer, Timer, TimerResult +from fme.core.typing_ import TensorDict + + +@dataclasses.dataclass +class BenchmarkResult: + memory: MemoryResult + timer: TimerResult + + def __repr__(self) -> str: + return f"BenchmarkResult(memory={self.memory}, timer={self.timer})" + + def asdict(self) -> dict: + return dataclasses.asdict(self) + + @classmethod + def from_dict(cls, d: dict) -> "BenchmarkResult": + return dacite.from_dict(cls, d, config=dacite.Config(strict=True)) + + def assert_close( + self, other: "BenchmarkResult", rtol=0.02, children_rtol=0.02 + ) -> None: + try: + self.timer.assert_close(other.timer, rtol=rtol, children_rtol=children_rtol) + except AssertionError as e: + raise AssertionError(f"Timer results differ: {e}") from e + try: + self.memory.assert_close(other.memory, rtol=rtol) + except AssertionError as e: + raise AssertionError(f"Memory results differ: {e}") from e + + def to_png( + self, path: str | pathlib.Path, label: str, child: str | None = None + ) -> None: + # note this function was generated with AI + def avg_time(t: TimerResult) -> float: + return float(t.avg_time) + + def self_time(t: TimerResult) -> float: + t_avg = avg_time(t) + c_avg = sum(avg_time(c) for c in t.children.values()) + return max(t_avg - c_avg, 0.0) + + def fmt_time(ms: float) -> str: + if ms >= 1000.0: + return f"{ms/1000.0:.2f}s" + if ms >= 10.0: + return f"{ms:.1f}ms" + return f"{ms:.2f}ms" + + def label_ok(name: str, ms: float, frac_of_root: float) -> bool: + if not name: + return False + return frac_of_root >= 0.05 + + def ordered_children(t: TimerResult) -> list[tuple[str, TimerResult]]: + return list(t.children.items()) # maintain dict order (insertion order) + + def blend_with_white( + rgb: tuple[float, float, float], amount: float + ) -> tuple[float, float, float]: + # amount in [0,1]: 0 -> original, 1 -> white + return ( + rgb[0] + (1.0 - rgb[0]) * amount, + rgb[1] + (1.0 - rgb[1]) * amount, + rgb[2] + (1.0 - rgb[2]) * amount, + ) + + root = self.timer + if child is not None: + for part in child.split("."): + if part not in root.children: + raise ValueError(f"Child '{child}' not found in timer results.") + root = root.children[part] + root_avg = avg_time(root) + + max_alloc_mb = self.memory.max_alloc / (1024.0 * 1024.0) + 
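+ # Two-column stacked bar chart: level 1 stacks the root timer's direct
+ # children (unattributed self time in gray); level 2 stacks each child's
+ # own children, colored as lightened variants of the parent hue.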
+ fig = plt.figure(figsize=(8, 6), constrained_layout=True) + if root_avg <= 0.0: + fig.suptitle( + f"Benchmark for {label}\ntotal=0.00s, max_alloc={max_alloc_mb:.1f} MB", + fontsize=14, + ) + ax0 = fig.add_subplot(1, 1, 1) + ax0.text(0.5, 0.5, "No timing data", ha="center", va="center") + ax0.axis("off") + fig.savefig(path, dpi=200) + plt.close(fig) + return + + fig.suptitle( + f"Benchmark for {label}\ntotal={fmt_time(root_avg)}, " + f"max_alloc={max_alloc_mb:.1f} MB", + fontsize=14, + ) + + ax = fig.add_subplot(1, 1, 1) + ax.set_xlim(0, 2) + ax.set_ylim(0, root_avg) + ax.set_xticks([0.5, 1.5]) + ax.set_xticklabels(["Level 1", "Level 2"]) + ax.set_ylabel("Avg time") + ax.set_yticks([]) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + gray = (0.85, 0.85, 0.85, 1.0) + cmap = plt.get_cmap("tab20") + + lvl1 = ordered_children(root) + lvl1_names = [n for n, _ in lvl1] + lvl1_index = {n: i for i, n in enumerate(lvl1_names)} + + # Level 1 stack (root children + root self in gray, unlabeled) + lvl1_segments: list[tuple[str, float, tuple[float, float, float, float]]] = [] + for n1, t1 in lvl1: + base = cmap(lvl1_index[n1] % cmap.N) + lvl1_segments.append((n1, avg_time(t1), base)) + r_self = self_time(root) + if r_self > 0.0: + lvl1_segments.append(("", r_self, gray)) + + def draw_stack( + x_center: float, + segments: list[tuple[str, float, tuple[float, float, float, float]]], + ) -> None: + width = 0.86 + y = 0.0 + for name, sec, color in segments: + if sec <= 0.0: + continue + ax.bar( + x_center, + sec, + bottom=y, + width=width, + align="center", + color=color, + edgecolor="white", + linewidth=1.0, + ) + frac = sec / root_avg + if label_ok(name, sec, frac): + ax.text( + x_center, + y + sec / 2.0, + f"{name}\n{fmt_time(sec)}", + ha="center", + va="center", + fontsize=9, + rotation=0, # keep horizontal to avoid cross-column overlap + clip_on=True, + ) + y += sec + if y < root_avg: + ax.bar( + x_center, + root_avg - y, + bottom=y, + width=width, + align="center", + color=gray, + edgecolor="white", + linewidth=1.0, + ) + + draw_stack(0.5, lvl1_segments) + + # Level 2 stack: + # For each level-1 slice, stack its children + # (colored as parent hue variants) + self in gray. + lvl2_segments: list[tuple[str, float, tuple[float, float, float, float]]] = [] + for n1, t1 in lvl1: + parent_rgba = cmap(lvl1_index[n1] % cmap.N) + parent_rgb = (parent_rgba[0], parent_rgba[1], parent_rgba[2]) + + children = ordered_children(t1) + k = len(children) + for i, (n2, t2) in enumerate(children): + # Same “type” of color as parent: lighten progressively per child. + # First child is closest to parent; later children are lighter. + lighten = 0.10 + (0.55 * (i / max(k - 1, 1))) + rgb = blend_with_white(parent_rgb, lighten) + lvl2_segments.append((n2, avg_time(t2), (rgb[0], rgb[1], rgb[2], 1.0))) + + s1 = self_time(t1) + if s1 > 0.0: + lvl2_segments.append(("", s1, gray)) + + draw_stack(1.5, lvl2_segments) + + fig.tight_layout(rect=(0.02, 0.02, 0.98, 0.98)) + fig.savefig(path, dpi=200, bbox_inches="tight") + plt.close(fig) + + +class BenchmarkABC(abc.ABC): + @classmethod + def new_from_fn( + cls, + fn: Callable[[Timer], TensorDict], + ) -> "BenchmarkABC": + class FnBenchmark(BenchmarkABC): + @classmethod + def new(cls) -> "FnBenchmark": + return FnBenchmark() + + def run_instance(self, timer: Timer) -> TensorDict: + return fn(timer) + + return FnBenchmark() + + @classmethod + @abc.abstractmethod + def new(cls: type[Self]) -> Self: + """ + Initialize any state needed for the benchmark. 
+ This will be called once before the benchmark is run. + """ + pass + + @classmethod + def new_for_regression(cls: type[Self]) -> Self | None: + """ + Initialize any state needed for regression testing. + This will be called once before regression tests are run. + + If regression testing is not needed, this can return None, + and regression testing will not be run. + + This exists as a separate method from new so that it can + use small data sizes more conducive to storing regression targets in git. + """ + return None + + @classmethod + def run_benchmark(cls, iters=10, warmup=1) -> BenchmarkResult: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available, cannot run benchmark.") + null_timer = NullTimer() + benchmark = cls.new() + for _ in range(warmup): + benchmark.run_instance(null_timer) + timer = CUDATimer() + with benchmark_memory() as bm: + for _ in range(iters): + with timer: + benchmark.run_instance(timer) + return BenchmarkResult( + timer=timer.result, + memory=bm.result, + ) + + @classmethod + def run_regression(cls) -> TensorDict | None: + benchmark = cls.new_for_regression() + if benchmark is None: + return None + null_timer = NullTimer() + return benchmark.run_instance(null_timer) + + @abc.abstractmethod + def run_instance(self: Self, timer: Timer) -> TensorDict: + """ + Run the benchmark. This will be called multiple times, + and should return a TensorDict of results. + + This must not mutate any state on self, since the same instance may be + used across multiple iterations. + """ + pass + + +_BENCHMARKS: dict[str, type[BenchmarkABC]] = {} + + +def register_benchmark(name: str) -> Callable[[type[BenchmarkABC]], type[BenchmarkABC]]: + def _register(fn: type[BenchmarkABC]) -> type[BenchmarkABC]: + if name in _BENCHMARKS: + raise ValueError(f"Benchmark with name '{name}' is already registered.") + _BENCHMARKS[name] = fn + return fn + + return _register + + +def get_benchmarks() -> dict[str, type[BenchmarkABC]]: + return _BENCHMARKS.copy() diff --git a/fme/core/benchmark/run.py b/fme/core/benchmark/run.py new file mode 100644 index 000000000..f77486a6e --- /dev/null +++ b/fme/core/benchmark/run.py @@ -0,0 +1,164 @@ +import argparse +import dataclasses +import json +import logging +import os +import pathlib +import subprocess +import sys + +import torch + +from fme.core.benchmark.benchmark import get_benchmarks + +RESULTS_PATH = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "results" + +_GIT_COMMIT: str | None = None + + +def get_git_commit() -> str: + global _GIT_COMMIT + if _GIT_COMMIT is None: + commit = ( + subprocess.check_output( + ["git", "rev-parse", "--short", "HEAD"], + stderr=subprocess.DEVNULL, + ) + .decode() + .strip() + ) + + # Non-empty output means repo is dirty + dirty = ( + subprocess.check_output( + ["git", "status", "--porcelain"], + stderr=subprocess.DEVNULL, + ) + .decode() + .strip() + ) + + if dirty: + commit = f"{commit}-dirty" + + _GIT_COMMIT = commit + + return _GIT_COMMIT + + +def get_device_name() -> str: + if torch.cuda.is_available(): + return torch.cuda.get_device_properties(0).name + else: + return "CPU" + + +def main( + name: str | None, iters: int, output_dir: pathlib.Path, child: str | None = None +) -> int: + output_dir.mkdir(exist_ok=True) + device_name = get_device_name() + + logging.info(f"Running benchmarks on device: {device_name}") + benchmarks = get_benchmarks() + if name is not None: + if name not in benchmarks: + logging.error( + f"Specified benchmark {name} not found. 
" + f"Available benchmarks: {', '.join(benchmarks.keys())}" + ) + return 1 + benchmarks_to_run = {name: benchmarks[name]} + else: + benchmarks_to_run = benchmarks + + def get_label(name): + return f"{name} on {device_name} at commit {get_git_commit()}" + + def get_filename(name, extension) -> pathlib.Path: + safe_name = name.replace("/", "_").replace(".", "_").lower() + safe_device_name = device_name.replace(" ", "_").replace("/", "_").lower() + return ( + output_dir + / f"{safe_name}_{safe_device_name}_{get_git_commit()}.{extension}" + ) + + for name, cls in benchmarks_to_run.items(): + logging.info(f"Running benchmark: {name}") + result = cls.run_benchmark(iters=iters) + png_filename = get_filename(name, "png") + logging.info(f"Saving result image to {png_filename}") + result.to_png(png_filename, label=get_label(name)) + result_data = json.dumps(dataclasses.asdict(result), indent=2) + logging.info(f"Result: {result_data}") + with open(get_filename(name, "json"), "w") as f: + logging.info(f"Saving result json to {f.name}") + f.write(result_data) + if child is not None: + child_name = f"{name}.{child}" + child_label = get_label(child_name) + logging.info(f"Generating benchmark result for child timer: {child_label}") + png_filename = get_filename(child_name, "png") + logging.info(f"Saving child result image to {png_filename}") + result.to_png(png_filename, label=child_label, child=child) + return 0 + + +def get_benchmark_label(name): + device_name = get_device_name() + return f"{name} on {device_name} at commit {get_git_commit()}" + + +if __name__ == "__main__": + logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" + ) + parser = argparse.ArgumentParser(description="Run registered benchmarks.") + parser.add_argument( + "--name", + type=str, + default=None, + help=( + "Name of the benchmark to run. If not provided, " + "all benchmarks will be run." + ), + ) + parser.add_argument( + "--child", + type=str, + default=None, + help=( + "If provided, the child timer to generate a report for. " + "This should be a dot-separated path to a child timer, " + "e.g. 'forward' or 'forward.linear'." + ), + ) + parser.add_argument( + "--iters", + type=int, + default=10, + help="Number of iterations to run each benchmark for.", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help=( + "Directory to save benchmark results in. If not provided, " + "results will be saved in a 'results' directory next to this script." 
+ ), + ) + args = parser.parse_args() + if args.output_dir is not None: + output_dir = pathlib.Path(args.output_dir) + else: + output_dir = RESULTS_PATH + + sys.exit( + main( + name=args.name, + iters=args.iters, + child=args.child, + output_dir=output_dir, + ) + ) diff --git a/fme/core/benchmark/test_benchmark.py b/fme/core/benchmark/test_benchmark.py new file mode 100644 index 000000000..e1de577de --- /dev/null +++ b/fme/core/benchmark/test_benchmark.py @@ -0,0 +1,56 @@ +import os + +import pytest +import torch + +import fme # to trigger registration of benchmarks +from fme.core.benchmark.benchmark import BenchmarkABC, get_benchmarks +from fme.core.device import force_cpu +from fme.core.rand import set_seed +from fme.core.testing.regression import validate_tensor_dict + +del fme + +DIR = os.path.abspath(os.path.dirname(__file__)) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_run_benchmark(): + def benchmark_fn(timer): + torch.cuda._sleep(100_000_000) + + benchmark = BenchmarkABC.new_from_fn(benchmark_fn) + + first_result = benchmark.run_benchmark(iters=15, warmup=1) + assert first_result.timer.count == 15 + second_result = benchmark.run_benchmark(iters=20, warmup=1) + assert second_result.timer.count == 20 + torch.testing.assert_close( + first_result.timer.avg_time, second_result.timer.avg_time, rtol=0.2, atol=0 + ) + + +def test_benchmarks_are_not_empty(): + assert ( + len(get_benchmarks()) > 0 + ), "No benchmarks were registered, but at least one was expected." + + +BENCHMARKS = get_benchmarks() + + +@pytest.mark.parametrize("benchmark_name", BENCHMARKS.keys()) +def test_regression(benchmark_name: str): + with force_cpu(): + set_seed(0) + benchmark_cls = BENCHMARKS[benchmark_name] + regression_result = benchmark_cls.run_regression() + if regression_result is None: + pytest.skip("Benchmark does not have regression targets.") + # If run_regression returns something, + # we expect it to be a TensorDict of results + assert isinstance(regression_result, dict) + validate_tensor_dict( + regression_result, + os.path.join(DIR, "testdata", f"{benchmark_name}-regression.pt"), + ) diff --git a/fme/core/benchmark/test_run.py b/fme/core/benchmark/test_run.py new file mode 100644 index 000000000..62e4907d9 --- /dev/null +++ b/fme/core/benchmark/test_run.py @@ -0,0 +1,21 @@ +import pathlib +import tempfile + +import pytest +import torch + +from fme.core.benchmark.run import main + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_run(): + # Just test that the main function runs without error on a simple benchmark + # We don't care about the output here, just that it completes successfully + with tempfile.TemporaryDirectory() as tmpdir: + output_dir = pathlib.Path(tmpdir) + main( + name="csfno_block", # just one for speed + iters=1, + output_dir=output_dir, + child=None, + ) diff --git a/fme/core/benchmark/testdata/csfno_block-regression.pt b/fme/core/benchmark/testdata/csfno_block-regression.pt new file mode 100644 index 000000000..85a5aa015 Binary files /dev/null and b/fme/core/benchmark/testdata/csfno_block-regression.pt differ diff --git a/fme/core/benchmark/testdata/csfno_block_8_groups-regression.pt b/fme/core/benchmark/testdata/csfno_block_8_groups-regression.pt new file mode 100644 index 000000000..8ceeacbe6 Binary files /dev/null and b/fme/core/benchmark/testdata/csfno_block_8_groups-regression.pt differ diff --git a/fme/core/device.py b/fme/core/device.py index dc70a33cc..044ab41bc 100644 --- 
a/fme/core/device.py +++ b/fme/core/device.py @@ -1,9 +1,31 @@ +import contextlib import os +from collections.abc import Generator import torch from .typing_ import TensorDict, TensorMapping +_FORCE_CPU: bool = False + + +@contextlib.contextmanager +def force_cpu(force: bool = True) -> Generator[None, None, None]: + """Force the use of CPU even if a GPU is available. This is useful for + testing and debugging. + + Args: + force: If True, force the use of CPU. If False, allow the use of GPU if + available. + """ + global _FORCE_CPU + previous = _FORCE_CPU + try: + _FORCE_CPU = force + yield + finally: + _FORCE_CPU = previous + def using_gpu() -> bool: return get_device().type == "cuda" @@ -20,6 +42,8 @@ def get_device() -> torch.device: """If CUDA is available, return a CUDA device. Otherwise, return a CPU device unless FME_USE_MPS is set, in which case return an MPS device if available. """ + if _FORCE_CPU: + return torch.device("cpu") if torch.cuda.is_available(): return torch.device("cuda", torch.cuda.current_device()) else: diff --git a/fme/core/models/__init__.py b/fme/core/models/__init__.py new file mode 100644 index 000000000..ae3c6d041 --- /dev/null +++ b/fme/core/models/__init__.py @@ -0,0 +1 @@ +from . import conditional_sfno diff --git a/fme/core/models/conditional_sfno/__init__.py b/fme/core/models/conditional_sfno/__init__.py index e69de29bb..6e628e84f 100644 --- a/fme/core/models/conditional_sfno/__init__.py +++ b/fme/core/models/conditional_sfno/__init__.py @@ -0,0 +1 @@ +from . import benchmark diff --git a/fme/core/models/conditional_sfno/benchmark.py b/fme/core/models/conditional_sfno/benchmark.py new file mode 100644 index 000000000..62b4ef992 --- /dev/null +++ b/fme/core/models/conditional_sfno/benchmark.py @@ -0,0 +1,126 @@ +from typing import Self + +import torch + +from fme.core.benchmark.benchmark import BenchmarkABC, register_benchmark +from fme.core.benchmark.timer import Timer +from fme.core.device import get_device +from fme.core.models.conditional_sfno.layers import Context, ContextConfig +from fme.core.models.conditional_sfno.sfnonet import FourierNeuralOperatorBlock +from fme.core.models.conditional_sfno.sht import InverseRealSHT, RealSHT +from fme.core.typing_ import TensorDict + + +def get_block_benchmark(filter_num_groups: int) -> type[BenchmarkABC]: + class BlockBenchmark(BenchmarkABC): + def __init__( + self, block: FourierNeuralOperatorBlock, x: torch.Tensor, context: Context + ): + self.block = block + self.x = x + self.context = context + + def run_instance(self, timer: Timer) -> TensorDict: + result = self.block(self.x, self.context, timer=timer) + return {"output": result.detach()} + + @classmethod + def new(cls) -> Self: + B = 2 + C = 512 + H = 180 + L = 360 + G = filter_num_groups + conditional_embed_dim_noise = 64 + conditional_embed_dim_labels = 3 + conditional_embed_dim_pos = 32 + return cls._new_with_params( + B=B, + C=C, + H=H, + L=L, + G=G, + conditional_embed_dim_noise=conditional_embed_dim_noise, + conditional_embed_dim_labels=conditional_embed_dim_labels, + conditional_embed_dim_pos=conditional_embed_dim_pos, + ) + + @classmethod + def _new_with_params( + cls, + B: int, + C: int, + H: int, + L: int, + G: int, + conditional_embed_dim_noise: int, + conditional_embed_dim_labels: int, + conditional_embed_dim_pos: int, + ) -> Self: + G = filter_num_groups + device = get_device() + conditional_embed_dim_scalar = 0 + embedding_scalar = None + context_embedding_noise = torch.randn( + B, conditional_embed_dim_noise, H, L + ).to(device) + 
context_embedding_labels = torch.randn(B, conditional_embed_dim_labels).to( + device + ) + context_embedding_pos = torch.randn(B, conditional_embed_dim_pos, H, L).to( + device + ) + context = Context( + embedding_scalar=embedding_scalar, + embedding_pos=context_embedding_pos, + noise=context_embedding_noise, + labels=context_embedding_labels, + ) + x = torch.randn(B, C, H, L, device=get_device()) + forward = RealSHT(nlat=H, nlon=L) + inverse = InverseRealSHT(nlat=H, nlon=L) + context_config = ContextConfig( + embed_dim_scalar=conditional_embed_dim_scalar, + embed_dim_noise=conditional_embed_dim_noise, + embed_dim_labels=conditional_embed_dim_labels, + embed_dim_pos=conditional_embed_dim_pos, + ) + block = FourierNeuralOperatorBlock( + forward_transform=forward, + inverse_transform=inverse, + img_shape=(H, L), + embed_dim=C, + filter_type="linear", + operator_type="dhconv", + use_mlp=True, + context_config=context_config, + filter_num_groups=G, + ).to(device) + return cls(block=block, x=x, context=context) + + @classmethod + def new_for_regression(cls): + B = 1 + C = 16 + H = 9 + L = 18 + G = 2 + conditional_embed_dim_noise = 4 + conditional_embed_dim_labels = 3 + conditional_embed_dim_pos = 2 + return cls._new_with_params( + B=B, + C=C, + H=H, + L=L, + G=G, + conditional_embed_dim_noise=conditional_embed_dim_noise, + conditional_embed_dim_labels=conditional_embed_dim_labels, + conditional_embed_dim_pos=conditional_embed_dim_pos, + ) + + return BlockBenchmark + + +register_benchmark("csfno_block")(get_block_benchmark(filter_num_groups=1)) +register_benchmark("csfno_block_8_groups")(get_block_benchmark(filter_num_groups=8)) diff --git a/fme/core/models/conditional_sfno/test_sfnonet.py b/fme/core/models/conditional_sfno/test_sfnonet.py index 3230d7c87..54cb735e1 100644 --- a/fme/core/models/conditional_sfno/test_sfnonet.py +++ b/fme/core/models/conditional_sfno/test_sfnonet.py @@ -6,6 +6,7 @@ from torch import nn from fme.core.device import get_device +from fme.core.models.conditional_sfno.benchmark import get_block_benchmark from fme.core.testing.regression import validate_tensor from .layers import Context, ContextConfig @@ -221,3 +222,27 @@ def forward(self, x): assert not torch.isnan(output).any() else: assert torch.isnan(output).any() + + +@pytest.mark.skipif( + get_device().type != "cuda", + reason=( + "This test is only relevant for CUDA since " + "it's testing speed of SFNO blocks on GPU." + ), +) # noqa: E501 +def test_block_speed(): + ungrouped = get_block_benchmark(filter_num_groups=1).run_benchmark( + iters=5, warmup=1 + ) + grouped = get_block_benchmark(filter_num_groups=8).run_benchmark(iters=5, warmup=1) + assert grouped.timer.avg_time < ungrouped.timer.avg_time, ( + "Expected grouped DHConv to be faster than ungrouped, but got " + f"{grouped.timer.avg_time:.6f} ms for grouped and " + f"{ungrouped.timer.avg_time:.6f} ms for ungrouped." + ) + assert grouped.memory.max_alloc < ungrouped.memory.max_alloc, ( + "Expected grouped DHConv to use less memory than ungrouped, but got " + f"{grouped.memory.max_alloc / 1e6:.2f} MB for grouped and " + f"{ungrouped.memory.max_alloc / 1e6:.2f} MB for ungrouped." 
+ ) diff --git a/fme/core/test_device.py b/fme/core/test_device.py index 80037c99c..5a0f9d13f 100644 --- a/fme/core/test_device.py +++ b/fme/core/test_device.py @@ -1,7 +1,20 @@ +import pytest import torch import fme +from fme.core.device import force_cpu, get_device def test_device_is_defined(): assert isinstance(fme.get_device(), torch.device) + + +def test_force_cpu(): + device_before = get_device() + if device_before.type == "cpu": + pytest.skip("Device is already CPU, cannot test force_cpu.") + with force_cpu(): + device = get_device() + assert device.type == "cpu" + device_after = get_device() + assert device_after.type == device_before.type
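
For reference, a minimal sketch of how a new benchmark plugs into the registry introduced in this change; the "my_matmul" name, the MatmulBenchmark class, and the matmul workload are illustrative assumptions, not part of this diff:

    import torch

    from fme.core.benchmark.benchmark import BenchmarkABC, register_benchmark
    from fme.core.benchmark.timer import Timer
    from fme.core.device import get_device
    from fme.core.typing_ import TensorDict


    @register_benchmark("my_matmul")  # hypothetical benchmark name
    class MatmulBenchmark(BenchmarkABC):
        def __init__(self, x: torch.Tensor):
            self.x = x

        @classmethod
        def new(cls) -> "MatmulBenchmark":
            # Called once before the benchmark is run; build inputs on the device.
            return cls(torch.randn(2048, 2048, device=get_device()))

        def run_instance(self, timer: Timer) -> TensorDict:
            # Must not mutate state on self; return a TensorDict of results.
            return {"output": (self.x @ self.x).detach()}

    # Running it requires CUDA (see BenchmarkABC.run_benchmark):
    # result = MatmulBenchmark.run_benchmark(iters=10, warmup=1)
    # result.to_png("my_matmul.png", label="my_matmul")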