From 4dec08a57a7b26054ba51a2927732e12433f9dcc Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 10 Feb 2026 19:00:54 +0000 Subject: [PATCH 1/6] add CUDATimer and NullTimer --- fme/core/benchmark/test_timer.py | 36 +++++++ fme/core/benchmark/timer.py | 171 +++++++++++++++++++++++++++++++ 2 files changed, 207 insertions(+) create mode 100644 fme/core/benchmark/test_timer.py create mode 100644 fme/core/benchmark/timer.py diff --git a/fme/core/benchmark/test_timer.py b/fme/core/benchmark/test_timer.py new file mode 100644 index 000000000..dd9cd5624 --- /dev/null +++ b/fme/core/benchmark/test_timer.py @@ -0,0 +1,36 @@ +from unittest.mock import patch + +import pytest +import torch + +from fme.core.benchmark.timer import CUDATimer + + +@pytest.mark.parametrize("is_available", [True, False]) +def test_new_if_available(is_available: bool): + from fme.core.benchmark.timer import CUDATimer, NullTimer + + with patch("torch.cuda.is_available", return_value=is_available): + timer = CUDATimer.new_if_available() + if is_available: + assert isinstance(timer, CUDATimer) + else: + assert isinstance(timer, NullTimer) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA is not available, skipping CUDATimer tests.", +) +def test_timer_with_child(): + timer = CUDATimer() + with timer: + # get cuda to wait + torch.cuda._sleep(100_000) + with timer.child("child"): + torch.cuda._sleep(100_000) + result = timer.result + assert "child" in result.children + # parent time should include the child time, so it should be + # at least 2x the child time (since we sleep for the same amount of time in both) + assert result.avg_time >= 2.0 * result.children["child"].avg_time diff --git a/fme/core/benchmark/timer.py b/fme/core/benchmark/timer.py new file mode 100644 index 000000000..230fe4e72 --- /dev/null +++ b/fme/core/benchmark/timer.py @@ -0,0 +1,171 @@ +import collections +import contextlib +import dataclasses +from typing import Literal, Protocol, Self + +import torch + + +@dataclasses.dataclass +class TimerResult: + total_runs: int + avg_time: float + children: dict[str, "TimerResult"] + + def assert_close(self, other: "TimerResult", rtol=0.02, children_rtol=0.02) -> None: + if self.total_runs != other.total_runs: + raise AssertionError( + f"total_runs differ: {self.total_runs} vs {other.total_runs}" + ) + if not torch.isclose( + torch.tensor(self.avg_time), torch.tensor(other.avg_time), rtol=rtol + ): + raise AssertionError( + f"avg_time differ: {self.avg_time} vs " + f"{other.avg_time} given rtol={rtol}" + ) + if self.children.keys() != other.children.keys(): + raise AssertionError( + f"children keys differ: {self.children.keys()} vs " + f"{other.children.keys()}" + ) + for key in self.children.keys(): + try: + self.children[key].assert_close( + other.children[key], rtol=children_rtol, children_rtol=children_rtol + ) + except AssertionError as e: + raise AssertionError(f"child '{key}' differ: {e}") from e + + +class Timer(Protocol): + def child(self, name: str) -> Self: ... + def __enter__(self) -> Self: ... + def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]: ... 
+ + +class NullTimer: + def context(self, name: str) -> contextlib.nullcontext: + return contextlib.nullcontext() + + def child(self, name: str) -> "Self": + return self + + def __enter__(self) -> "Self": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]: + return False + + def report(self) -> TimerResult: + return TimerResult(total_runs=0, avg_time=0.0, children={}) + + +_: Timer = NullTimer() +del _ + + +class EventPair: + def __init__(self): + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + self._stream = None + self._start_recorded = False + self._end_recorded = False + + def record_start(self): + if self._start_recorded: + raise RuntimeError( + "record_start has already been called on this EventPair." + ) + self._stream = torch.cuda.current_stream() + self.start.record(self._stream) + self._start_recorded = True + + def record_end(self): + if not self._start_recorded: + raise RuntimeError("record_start must be called before record_end") + if self._end_recorded: + raise RuntimeError("record_end has already been called on this EventPair.") + if self._stream is None: + raise RuntimeError("record_start must be called before record_end") + self.end.record(self._stream) + self._end_recorded = True + + def elapsed_time_ms(self) -> float: + if not self._start_recorded or not self._end_recorded: + raise RuntimeError( + "Both record_start and record_end must be called " + "before elapsed_time_ms can be called." + ) + return self.start.elapsed_time(self.end) + + +class CUDATimer: + def __init__(self): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available, cannot use CUDATimer.") + self._children: collections.defaultdict[str, CUDATimer] = ( + collections.defaultdict(CUDATimer) + ) + self._event_pairs: list[EventPair] = [] + self._entered = False + self._result: TimerResult | None = None + + @classmethod + def new_if_available(cls) -> "CUDATimer | NullTimer": + if torch.cuda.is_available(): + return cls() + else: + return NullTimer() + + def __enter__(self): + if self._entered: + raise RuntimeError("CUDATimer is already entered.") + self._entered = True + self._event_pairs.append(EventPair()) + self._event_pairs[-1].record_start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self._event_pairs: + raise RuntimeError("CUDATimer context was not properly entered.") + self._event_pairs[-1].record_end() + self._entered = False + return False + + def child(self, name: str) -> "CUDATimer": + if not self._entered: + raise RuntimeError( + "CUDATimer child cannot be used before entering the timer." + ) + return self._children[name] + + @property + def _avg_time(self) -> float: + if len(self._event_pairs) == 0: + raise RuntimeError( + "CUDATimer report cannot be generated before entering the timer." 
+ ) + total_time = sum( + event_pair.elapsed_time_ms() for event_pair in self._event_pairs + ) + return total_time / len(self._event_pairs) + + def _child_reports(self) -> dict[str, TimerResult]: + return {name: child.result for name, child in self._children.items()} + + @property + def result(self) -> TimerResult: + if self._result is None: + torch.cuda.synchronize() + self._result = TimerResult( + total_runs=len(self._event_pairs), + avg_time=self._avg_time, + children=self._child_reports(), + ) + return self._result + + +__: type[Timer] = CUDATimer +del __ From 3460bea8cf28d1b6e002a74976ad773f63973798 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 10 Feb 2026 19:22:51 +0000 Subject: [PATCH 2/6] test assert_close --- fme/core/benchmark/test_timer.py | 80 +++++++++++++++++++++++++++++++- 1 file changed, 79 insertions(+), 1 deletion(-) diff --git a/fme/core/benchmark/test_timer.py b/fme/core/benchmark/test_timer.py index dd9cd5624..c8f97e6b8 100644 --- a/fme/core/benchmark/test_timer.py +++ b/fme/core/benchmark/test_timer.py @@ -1,9 +1,10 @@ +from typing import Literal from unittest.mock import patch import pytest import torch -from fme.core.benchmark.timer import CUDATimer +from fme.core.benchmark.timer import CUDATimer, TimerResult @pytest.mark.parametrize("is_available", [True, False]) @@ -34,3 +35,80 @@ def test_timer_with_child(): # parent time should include the child time, so it should be # at least 2x the child time (since we sleep for the same amount of time in both) assert result.avg_time >= 2.0 * result.children["child"].avg_time + + +def _create_parent_result(avg_time: float) -> TimerResult: + return TimerResult(total_runs=2, avg_time=avg_time, children={}) + + +def _create_child_result(avg_time: float) -> TimerResult: + return TimerResult( + total_runs=2, + avg_time=1.0, + children={"child": TimerResult(total_runs=2, avg_time=avg_time, children={})}, + ) + + +@pytest.mark.parametrize( + "v1, v2, rtol, expect_raise", + [ + (100, 101, 0.02, False), # within 2% + (100, 103, 0.02, True), # outside 2% + (100, 102, 0.02, False), # exactly 2% is considered inside + (10000, 10201, 0.02, True), # more than 2% is considered outside + (100, 102, 0.03, False), # exactly 2% is within 3% + ], +) +@pytest.mark.parametrize("kind", ["parent", "child"]) +def test_assert_close( + v1: int, v2: int, rtol: float, kind: Literal["parent", "child"], expect_raise: bool +): + if kind == "child": + result1 = _create_child_result(avg_time=v1) + result2 = _create_child_result(avg_time=v2) + else: + result1 = _create_parent_result(avg_time=v1) + result2 = _create_parent_result(avg_time=v2) + if expect_raise: + with pytest.raises(AssertionError): + result2.assert_close(result1, rtol=rtol) + else: + result2.assert_close(result1, rtol=rtol) + + +def test_assert_close_different_total_runs(): + # different total runs should raise regardless of rtol + result1 = TimerResult(total_runs=100, avg_time=100.0, children={}) + result2 = TimerResult(total_runs=101, avg_time=100.0, children={}) + with pytest.raises(AssertionError): + result2.assert_close(result1, rtol=0.5) + + +def test_assert_close_children_rtol(): + # test that children_rtol is used for child comparisons + result1 = TimerResult( + total_runs=2, + avg_time=100.0, + children={"child": TimerResult(total_runs=2, avg_time=100.0, children={})}, + ) + result2 = TimerResult( + total_runs=2, + avg_time=110.0, + children={"child": TimerResult(total_runs=2, avg_time=103.0, children={})}, + ) + result2.assert_close(result1, rtol=0.2, children_rtol=0.05) + 
+ +def test_assert_close_children_rtol_raises(): + # test that children_rtol is used for child comparisons + result1 = TimerResult( + total_runs=2, + avg_time=100.0, + children={"child": TimerResult(total_runs=2, avg_time=100.0, children={})}, + ) + result2 = TimerResult( + total_runs=2, + avg_time=110.0, + children={"child": TimerResult(total_runs=2, avg_time=103.0, children={})}, + ) + result2.assert_close(result1, rtol=0.5, children_rtol=0.2) From 6c1c29c60dbd5341233a6e9528b698a9faf83586 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 10 Feb 2026 19:25:49 +0000 Subject: [PATCH 3/6] copy-paste sht_fix.py --- fme/core/models/conditional_sfno/sht.py | 223 ++++++++++++++++++++++++ 1 file changed, 223 insertions(+) create mode 100644 fme/core/models/conditional_sfno/sht.py diff --git a/fme/core/models/conditional_sfno/sht.py b/fme/core/models/conditional_sfno/sht.py new file mode 100644 index 000000000..b4127c5ec --- /dev/null +++ b/fme/core/models/conditional_sfno/sht.py @@ -0,0 +1,223 @@ +# flake8: noqa +# fmt: off +# isort: skip_file + +""" +This file contains a fix that we needed to get the SFNO to work on multiple +unroll steps in multiprocessing (e.g. multi-GPU mode.) We forked this code from +the torch harmonics sht.py file [*]. + +[*] https://github.com/NVIDIA/torch-harmonics/blob/17eefa53468d1a885d72087918eba905fa53e10a/torch_harmonics/sht.py +""" + + +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import torch +import torch.nn as nn +import torch.fft + +from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights +from torch_harmonics.legendre import _precompute_legpoly +import torch_harmonics + +from fme.core.device import get_device + +class RealSHT(nn.Module): + """ + Defines a module for computing the forward (real-valued) SHT. + Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. 
+ The SHT is applied to the last two dimensions of the input + + [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. + [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True): + """ + Initializes the SHT Layer, precomputing the necessary quadrature weights + + Parameters: + nlat: input grid resolution in the latitudinal direction + nlon: input grid resolution in the longitudinal direction + grid: grid in the latitude direction (for now only tensor product grids are supported) + """ + + super().__init__() + + self.nlat = nlat + self.nlon = nlon + self.grid = grid + self.norm = norm + self.csphase = csphase + + # TODO: include assertions regarding the dimensions + + # compute quadrature points + if self.grid == "legendre-gauss": + cost, w = legendre_gauss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "lobatto": + cost, w = lobatto_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat-1 + elif self.grid == "equiangular": + cost, w = clenshaw_curtiss_weights(nlat, -1, 1) + # cost, w = fejer2_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "healpix": + raise(NotImplementedError("'healpix' grid not supported by InverseRealVectorSHT")) + else: + raise(ValueError("Unknown quadrature mode")) + + # apply cosine transform and flip them + tq = torch.flip(torch.arccos(cost), dims=(0,)) + + # determine the dimensions + self.mmax = mmax or self.nlon // 2 + 1 + + # combine quadrature weights with the legendre weights + pct = torch.as_tensor(_precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)) + weights = torch.einsum('mlk,k->mlk', pct, w) + + # remember quadrature weights + self.weights = weights.float().to(get_device()) + + def extra_repr(self): + """ + Pretty print module + """ + return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' + + def forward(self, x: torch.Tensor): + + assert(x.shape[-2] == self.nlat) + assert(x.shape[-1] == self.nlon) + with torch.autocast("cuda", enabled=False): + # rfft and view_as_complex don't support BF16, see https://github.com/pytorch/pytorch/issues/117844 + x = x.float() + + # apply real fft in the longitudinal direction + x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward") + + # do the Legendre-Gauss quadrature + x = torch.view_as_real(x) + + # distributed contraction: fork + out_shape = list(x.size()) + out_shape[-3] = self.lmax + out_shape[-2] = self.mmax + xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) + + # contraction + weights = self.weights.to(x.device).to(x.dtype) + xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 0], weights) + xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 1], weights) + x = torch.view_as_complex(xout) + + return x + +class InverseRealSHT(nn.Module): + """ + Defines a module for computing the inverse (real-valued) SHT. + Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. + nlat, nlon: Output dimensions + lmax, mmax: Input dimensions (spherical coefficients). For convenience, these are inferred from the output dimensions + + [1] Schaeffer, N. 
Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. + [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True): + + super().__init__() + + self.nlat = nlat + self.nlon = nlon + self.grid = grid + self.norm = norm + self.csphase = csphase + + # compute quadrature points + if self.grid == "legendre-gauss": + cost, _ = legendre_gauss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "lobatto": + cost, _ = lobatto_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat-1 + elif self.grid == "equiangular": + cost, _ = clenshaw_curtiss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "healpix": + raise(NotImplementedError("'healpix' grid not supported by RealVectorSHT")) + else: + raise(ValueError("Unknown quadrature mode")) + + # apply cosine transform and flip them + t = torch.flip(torch.arccos(cost), dims=(0,)) + + # determine the dimensions + self.mmax = mmax or self.nlon // 2 + 1 + + pct = torch.as_tensor(_precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)) + + # register buffer + self.pct = pct.float().to(get_device()) + + def extra_repr(self): + """ + Pretty print module + """ + return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' + + def forward(self, x: torch.Tensor): + + assert(x.shape[-2] == self.lmax) + assert(x.shape[-1] == self.mmax) + + with torch.autocast("cuda", enabled=False): + # irfft and view_as_complex don't support BF16, see https://github.com/pytorch/pytorch/issues/117844 + # Evaluate associated Legendre functions on the output nodes + x = torch.view_as_real(x).float() + + pct = self.pct.to(x.device).to(x.dtype) + rl = torch.einsum('...lm, mlk->...km', x[..., 0], pct ) + im = torch.einsum('...lm, mlk->...km', x[..., 1], pct ) + xs = torch.stack((rl, im), -1) + + # apply the inverse (real) FFT + x = torch.view_as_complex(xs) + x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") + + return x + +torch_harmonics.RealSHT = RealSHT +torch_harmonics.InverseRealSHT = InverseRealSHT From 332540a6058fa0d6b206554ba02c3ed5694c60b5 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 10 Feb 2026 19:27:48 +0000 Subject: [PATCH 4/6] update conditional SFNO to pass timers for profiling --- fme/core/models/conditional_sfno/layers.py | 102 ++++++++++-------- .../makani/spectral_convolution.py | 9 +- .../models/conditional_sfno/s2convolutions.py | 54 ++++++---- fme/core/models/conditional_sfno/sfnonet.py | 74 +++++++------ fme/core/models/conditional_sfno/sht.py | 66 ++++++------ 5 files changed, 172 insertions(+), 133 deletions(-) diff --git a/fme/core/models/conditional_sfno/layers.py b/fme/core/models/conditional_sfno/layers.py index 47648d781..5f6dcbe8e 100644 --- a/fme/core/models/conditional_sfno/layers.py +++ b/fme/core/models/conditional_sfno/layers.py @@ -24,6 +24,7 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint +from fme.core.benchmark.timer import Timer, NullTimer from fme.core.models.conditional_sfno.lora import LoRAConv2d from .activations import ComplexReLU @@ -223,7 +224,12 @@ def reset_parameters(self): torch.nn.init.constant_(self.W_bias_pos.weight, 0.0) # no bias on 2d layers as it is already 
handled in the non-2d layers - def forward(self, x: torch.Tensor, context: Context) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + context: Context, + timer: Timer = NullTimer(), + ) -> torch.Tensor: """ Conditional Layer Normalization @@ -242,52 +248,58 @@ def forward(self, x: torch.Tensor, context: Context) -> torch.Tensor: self.W_scale_labels is not None or self.W_bias_labels is not None ): raise ValueError("labels must be provided") - if self.W_scale is not None: - if context.embedding_scalar is None: - raise ValueError("embedding_scalar must be provided") - scale: torch.Tensor = ( - self.W_scale(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) - ) - else: - scale = torch.ones( - list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype - ) + with timer.child("compute_scaling_and_bias"): + if self.W_scale is not None: + if context.embedding_scalar is None: + raise ValueError("embedding_scalar must be provided") + scale: torch.Tensor = ( + self.W_scale(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) + ) + else: + scale = torch.ones( + list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype + ) - if self.W_scale_2d is not None: - if context.noise is None: - raise ValueError("embedding_2d must be provided") - scale = scale + self.W_scale_2d(context.noise) - if self.W_bias is not None: - if context.embedding_scalar is None: - raise ValueError("embedding_scalar must be provided") - bias: torch.Tensor = ( - self.W_bias(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) - ) - else: - bias = torch.zeros( - list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype - ) + if self.W_scale_2d is not None: + if context.noise is None: + raise ValueError("embedding_2d must be provided") + scale = scale + self.W_scale_2d(context.noise) + if self.W_bias is not None: + if context.embedding_scalar is None: + raise ValueError("embedding_scalar must be provided") + bias: torch.Tensor = ( + self.W_bias(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) + ) + else: + bias = torch.zeros( + list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype + ) - if self.W_scale_labels is not None: - scale = scale + self.W_scale_labels(context.labels).unsqueeze(-1).unsqueeze( - -1 - ) - if self.W_bias_labels is not None: - bias = bias + self.W_bias_labels(context.labels).unsqueeze(-1).unsqueeze(-1) - if self.W_bias_2d is not None: - if context.noise is None: - raise ValueError("embedding_2d must be provided") - bias = bias + self.W_bias_2d(context.noise) - if self.W_scale_pos is not None: - if context.embedding_pos is None: - raise ValueError("embedding_pos must be provided") - scale = scale + self.W_scale_pos(context.embedding_pos) - if self.W_bias_pos is not None: - if context.embedding_pos is None: - raise ValueError("embedding_pos must be provided") - bias = bias + self.W_bias_pos(context.embedding_pos) - x_norm: torch.Tensor = self.norm(x) - return x_norm * scale + bias + if self.W_scale_labels is not None: + scale = scale + self.W_scale_labels(context.labels).unsqueeze( + -1 + ).unsqueeze(-1) + if self.W_bias_labels is not None: + bias = bias + self.W_bias_labels(context.labels).unsqueeze( + -1 + ).unsqueeze(-1) + if self.W_bias_2d is not None: + if context.noise is None: + raise ValueError("embedding_2d must be provided") + bias = bias + self.W_bias_2d(context.noise) + if self.W_scale_pos is not None: + if context.embedding_pos is None: + raise ValueError("embedding_pos must be provided") + scale = scale + self.W_scale_pos(context.embedding_pos) + if self.W_bias_pos is not 
None: + if context.embedding_pos is None: + raise ValueError("embedding_pos must be provided") + bias = bias + self.W_bias_pos(context.embedding_pos) + with timer.child("normalize"): + x_norm: torch.Tensor = self.norm(x) + with timer.child("apply_scaling_and_bias"): + return_value = x_norm * scale + bias + return return_value @torch.jit.script diff --git a/fme/core/models/conditional_sfno/makani/spectral_convolution.py b/fme/core/models/conditional_sfno/makani/spectral_convolution.py index f38894c83..e99a7f5e1 100644 --- a/fme/core/models/conditional_sfno/makani/spectral_convolution.py +++ b/fme/core/models/conditional_sfno/makani/spectral_convolution.py @@ -19,6 +19,8 @@ import torch.nn as nn from torch import amp +from fme.core.benchmark.timer import NullTimer, Timer + # import convenience functions for factorized tensors from .factorizations import get_contract_fun @@ -124,7 +126,7 @@ def __init__( if bias: self.bias = nn.Parameter(torch.zeros(1, self.out_channels, 1, 1)) - def forward(self, x): + def forward(self, x, timer: Timer = NullTimer()): dtype = x.dtype residual = x x = x.float() @@ -138,7 +140,10 @@ def forward(self, x): B, C, H, W = x.shape x = x.reshape(B, self.num_groups, C // self.num_groups, H, W) xp = self._contract( - x, self.weight, separable=self.separable, operator_type=self.operator_type + x, + self.weight, + separable=self.separable, + operator_type=self.operator_type, ) x = xp.reshape(B, self.out_channels, H, W).contiguous() diff --git a/fme/core/models/conditional_sfno/s2convolutions.py b/fme/core/models/conditional_sfno/s2convolutions.py index 93299256a..b138dd442 100644 --- a/fme/core/models/conditional_sfno/s2convolutions.py +++ b/fme/core/models/conditional_sfno/s2convolutions.py @@ -22,6 +22,8 @@ import torch_harmonics as th import torch_harmonics.distributed as thd +from fme.core.benchmark.timer import NullTimer, Timer + # import convenience functions for factorized tensors from .activations import ComplexReLU @@ -223,45 +225,51 @@ def __init__( self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) self.out_channels = out_channels - def forward(self, x): # pragma: no cover + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover dtype = x.dtype residual = x x = x.float() with torch.amp.autocast("cuda", enabled=False): - x = self.forward_transform(x.float()) + with timer.child("forward_transform"): + x = self.forward_transform(x.float()) if self._round_trip_residual: - x = x.contiguous() - residual = self.inverse_transform(x) - residual = residual.to(dtype) + with timer.child("round_trip_residual"): + x = x.contiguous() + residual = self.inverse_transform(x) + residual = residual.to(dtype) B, C, H, W = x.shape assert C % self.num_groups == 0 x = x.reshape(B, self.num_groups, C // self.num_groups, H, W) if self.lora_A is not None and self.lora_B is not None: - lora_update = _contract_lora( - self.lora_A, - self.lora_B, - x[..., : self.modes_lat_local, : self.modes_lon_local], - ) + with timer.child("lora_update"): + lora_update = _contract_lora( + self.lora_A, + self.lora_B, + x[..., : self.modes_lat_local, : self.modes_lon_local], + ) else: lora_update = 0.0 - xp = torch.zeros_like(x) - xp[..., : self.modes_lat_local, : self.modes_lon_local] = _contract_dhconv( - x[..., : self.modes_lat_local, : self.modes_lon_local], - self.weight, - ) - xp = xp + self.lora_scaling * lora_update - xp = xp.reshape(B, self.out_channels, H, W) - x = xp.contiguous() + with timer.child("dhconv"): + xp = torch.zeros_like(x) + xp[..., : self.modes_lat_local, : 
self.modes_lon_local] = _contract_dhconv( + x[..., : self.modes_lat_local, : self.modes_lon_local], + self.weight, + ) + xp = xp + self.lora_scaling * lora_update + xp = xp.reshape(B, self.out_channels, H, W) + x = xp.contiguous() with torch.amp.autocast("cuda", enabled=False): - x = self.inverse_transform(x) + with timer.child("inverse_transform"): + x = self.inverse_transform(x) if hasattr(self, "bias"): - x = x + self.bias + with timer.child("add_bias"): + x = x + self.bias x = x.type(dtype) @@ -320,7 +328,7 @@ def __init__( scale * torch.randn(1, out_channels, *self.output_dims) ) - def forward(self, x): # pragma: no cover + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover dtype = x.dtype x = x.float() B, C, H, W = x.shape @@ -503,7 +511,7 @@ def forward_mlp(self, x): # pragma: no cover return x - def forward(self, x): # pragma: no cover + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover dtype = x.dtype residual = x x = x.to(torch.float32) @@ -626,7 +634,7 @@ def forward_mlp(self, x): # pragma: no cover return x - def forward(self, x): # pragma: no cover + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover dtype = x.dtype x = x.to(torch.float32) diff --git a/fme/core/models/conditional_sfno/sfnonet.py b/fme/core/models/conditional_sfno/sfnonet.py index 61d35ca27..29eb986f0 100644 --- a/fme/core/models/conditional_sfno/sfnonet.py +++ b/fme/core/models/conditional_sfno/sfnonet.py @@ -24,6 +24,8 @@ import torch_harmonics as th from torch.utils.checkpoint import checkpoint +from fme.core.benchmark.timer import Timer, NullTimer + from .initialization import trunc_normal_ # wrap fft, to unify interface to spectral transforms @@ -62,7 +64,7 @@ def __init__(self, *args, **kwargs): super().__init__() self.conv = th.DiscreteContinuousConvS2(*args, **kwargs) - def forward(self, x): + def forward(self, x, timer: Timer = NullTimer()): return self.conv(x), x @@ -153,8 +155,8 @@ def __init__( else: raise (NotImplementedError) - def forward(self, x): - return self.filter(x) + def forward(self, x, timer: Timer = NullTimer()): + return self.filter(x, timer=timer) class FourierNeuralOperatorBlock(nn.Module): @@ -295,44 +297,54 @@ def __init__( lora_alpha=lora_alpha, ) - def forward(self, x, context_embedding): - x_norm = torch.zeros_like(x) - x_norm[..., : self.input_shape_loc[0], : self.input_shape_loc[1]] = self.norm0( - x[..., : self.input_shape_loc[0], : self.input_shape_loc[1]], - context_embedding, - ) - x, residual = self.filter(x_norm) - + def forward(self, x, context_embedding, timer: Timer = NullTimer()): + with timer.child("norm0") as norm0_timer: + x_norm = torch.zeros_like(x) + x_norm[..., : self.input_shape_loc[0], : self.input_shape_loc[1]] = ( + self.norm0( + x[..., : self.input_shape_loc[0], : self.input_shape_loc[1]], + context_embedding, + timer=norm0_timer, + ) + ) + with timer.child("filter") as filter_timer: + x, residual = self.filter(x_norm, timer=filter_timer) if hasattr(self, "inner_skip"): - if self.concat_skip: - x = torch.cat((x, self.inner_skip(residual)), dim=1) - x = self.inner_skip_conv(x) - else: - x = x + self.inner_skip(residual) + with timer.child("inner_skip"): + if self.concat_skip: + x = torch.cat((x, self.inner_skip(residual)), dim=1) + x = self.inner_skip_conv(x) + else: + x = x + self.inner_skip(residual) if hasattr(self, "act_layer"): - x = self.act_layer(x) - - x_norm = torch.zeros_like(x) - x_norm[..., : self.output_shape_loc[0], : self.output_shape_loc[1]] = ( - self.norm1( - x[..., : 
self.output_shape_loc[0], : self.output_shape_loc[1]], - context_embedding, + with timer.child("activation"): + x = self.act_layer(x) + + with timer.child("norm1") as norm1_timer: + x_norm = torch.zeros_like(x) + x_norm[..., : self.output_shape_loc[0], : self.output_shape_loc[1]] = ( + self.norm1( + x[..., : self.output_shape_loc[0], : self.output_shape_loc[1]], + context_embedding, + timer=norm1_timer, + ) ) - ) - x = x_norm + x = x_norm if hasattr(self, "mlp"): - x = self.mlp(x) + with timer.child("mlp"): + x = self.mlp(x) x = self.drop_path(x) if hasattr(self, "outer_skip"): - if self.concat_skip: - x = torch.cat((x, self.outer_skip(residual)), dim=1) - x = self.outer_skip_conv(x) - else: - x = x + self.outer_skip(residual) + with timer.child("outer_skip"): + if self.concat_skip: + x = torch.cat((x, self.outer_skip(residual)), dim=1) + x = self.outer_skip_conv(x) + else: + x = x + self.outer_skip(residual) return x diff --git a/fme/core/models/conditional_sfno/sht.py b/fme/core/models/conditional_sfno/sht.py index b4127c5ec..dd9f8fc02 100644 --- a/fme/core/models/conditional_sfno/sht.py +++ b/fme/core/models/conditional_sfno/sht.py @@ -48,9 +48,10 @@ from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights from torch_harmonics.legendre import _precompute_legpoly -import torch_harmonics from fme.core.device import get_device +from fme.core.benchmark.timer import Timer, NullTimer + class RealSHT(nn.Module): """ @@ -117,31 +118,33 @@ def extra_repr(self): """ return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, timer: Timer = NullTimer()): assert(x.shape[-2] == self.nlat) assert(x.shape[-1] == self.nlon) with torch.autocast("cuda", enabled=False): - # rfft and view_as_complex don't support BF16, see https://github.com/pytorch/pytorch/issues/117844 - x = x.float() + with timer.child("rfft"): + # rfft and view_as_complex don't support BF16, see https://github.com/pytorch/pytorch/issues/117844 + x = x.float() - # apply real fft in the longitudinal direction - x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward") + # apply real fft in the longitudinal direction + x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward") - # do the Legendre-Gauss quadrature - x = torch.view_as_real(x) + with timer.child("contraction"): + # do the Legendre-Gauss quadrature + x = torch.view_as_real(x) - # distributed contraction: fork - out_shape = list(x.size()) - out_shape[-3] = self.lmax - out_shape[-2] = self.mmax - xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) + # distributed contraction: fork + out_shape = list(x.size()) + out_shape[-3] = self.lmax + out_shape[-2] = self.mmax + xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) - # contraction - weights = self.weights.to(x.device).to(x.dtype) - xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 0], weights) - xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 1], weights) - x = torch.view_as_complex(xout) + # contraction + weights = self.weights.to(x.device).to(x.dtype) + xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 0], weights) + xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 1], weights) + x = torch.view_as_complex(xout) return x @@ -198,26 +201,25 @@ def extra_repr(self): """ return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, 
mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, timer: Timer = NullTimer()): assert(x.shape[-2] == self.lmax) assert(x.shape[-1] == self.mmax) with torch.autocast("cuda", enabled=False): - # irfft and view_as_complex don't support BF16, see https://github.com/pytorch/pytorch/issues/117844 - # Evaluate associated Legendre functions on the output nodes - x = torch.view_as_real(x).float() + with timer.child("contraction"): + # irfft and view_as_complex don't support BF16, see https://github.com/pytorch/pytorch/issues/117844 + # Evaluate associated Legendre functions on the output nodes + x = torch.view_as_real(x).float() - pct = self.pct.to(x.device).to(x.dtype) - rl = torch.einsum('...lm, mlk->...km', x[..., 0], pct ) - im = torch.einsum('...lm, mlk->...km', x[..., 1], pct ) - xs = torch.stack((rl, im), -1) + pct = self.pct.to(x.device).to(x.dtype) + rl = torch.einsum('...lm, mlk->...km', x[..., 0], pct ) + im = torch.einsum('...lm, mlk->...km', x[..., 1], pct ) + xs = torch.stack((rl, im), -1) - # apply the inverse (real) FFT - x = torch.view_as_complex(xs) - x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") + # apply the inverse (real) FFT + x = torch.view_as_complex(xs) + with timer.child("irfft"): + x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") return x - -torch_harmonics.RealSHT = RealSHT -torch_harmonics.InverseRealSHT = InverseRealSHT From d17f3f15901300bcb4ecc24b052b261e700f74f3 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 10 Feb 2026 21:21:23 +0000 Subject: [PATCH 5/6] incorporate review comments --- fme/core/benchmark/test_timer.py | 30 +++++++++++++++--------------- fme/core/benchmark/timer.py | 17 ++++------------- 2 files changed, 19 insertions(+), 28 deletions(-) diff --git a/fme/core/benchmark/test_timer.py b/fme/core/benchmark/test_timer.py index c8f97e6b8..47c527d56 100644 --- a/fme/core/benchmark/test_timer.py +++ b/fme/core/benchmark/test_timer.py @@ -38,14 +38,14 @@ def test_timer_with_child(): def _create_parent_result(avg_time: float) -> TimerResult: - return TimerResult(total_runs=2, avg_time=avg_time, children={}) + return TimerResult(count=2, avg_time=avg_time, children={}) def _create_child_result(avg_time: float) -> TimerResult: return TimerResult( - total_runs=2, + count=2, avg_time=1.0, - children={"child": TimerResult(total_runs=2, avg_time=avg_time, children={})}, + children={"child": TimerResult(count=2, avg_time=avg_time, children={})}, ) @@ -76,10 +76,10 @@ def test_assert_close( result2.assert_close(result1, rtol=rtol) -def test_assert_close_different_total_runs(): - # different total runs should raise regardless of rtol - result1 = TimerResult(total_runs=100, avg_time=100.0, children={}) - result2 = TimerResult(total_runs=101, avg_time=100.0, children={}) +def test_assert_close_different_count(): + # different count should raise regardless of rtol + result1 = TimerResult(count=100, avg_time=100.0, children={}) + result2 = TimerResult(count=101, avg_time=100.0, children={}) with pytest.raises(AssertionError): result2.assert_close(result1, rtol=0.5) @@ -87,14 +87,14 @@ def test_assert_close_different_total_runs(): def test_assert_close_children_rtol(): # test that children_rtol is used for child comparisons result1 = TimerResult( - total_runs=2, + count=2, avg_time=100.0, - children={"child": TimerResult(total_runs=2, avg_time=100.0, children={})}, + children={"child": TimerResult(count=2, avg_time=100.0, children={})}, ) result2 = 
TimerResult( - total_runs=2, + count=2, avg_time=110.0, - children={"child": TimerResult(total_runs=2, avg_time=103.0, children={})}, + children={"child": TimerResult(count=2, avg_time=103.0, children={})}, ) result2.assert_close(result1, rtol=0.2, children_rtol=0.05) @@ -102,13 +102,13 @@ def test_assert_close_children_rtol(): def test_assert_close_children_rtol_raises(): # test that children_rtol is used for child comparisons result1 = TimerResult( - total_runs=2, + count=2, avg_time=100.0, - children={"child": TimerResult(total_runs=2, avg_time=100.0, children={})}, + children={"child": TimerResult(count=2, avg_time=100.0, children={})}, ) result2 = TimerResult( - total_runs=2, + count=2, avg_time=110.0, - children={"child": TimerResult(total_runs=2, avg_time=103.0, children={})}, + children={"child": TimerResult(count=2, avg_time=103.0, children={})}, ) result2.assert_close(result1, rtol=0.5, children_rtol=0.2) diff --git a/fme/core/benchmark/timer.py b/fme/core/benchmark/timer.py index 230fe4e72..943d034ff 100644 --- a/fme/core/benchmark/timer.py +++ b/fme/core/benchmark/timer.py @@ -1,5 +1,4 @@ import collections -import contextlib import dataclasses from typing import Literal, Protocol, Self @@ -8,15 +7,13 @@ @dataclasses.dataclass class TimerResult: - total_runs: int + count: int avg_time: float children: dict[str, "TimerResult"] def assert_close(self, other: "TimerResult", rtol=0.02, children_rtol=0.02) -> None: - if self.total_runs != other.total_runs: - raise AssertionError( - f"total_runs differ: {self.total_runs} vs {other.total_runs}" - ) + if self.count != other.count: + raise AssertionError(f"count differ: {self.count} vs {other.count}") if not torch.isclose( torch.tensor(self.avg_time), torch.tensor(other.avg_time), rtol=rtol ): @@ -45,9 +42,6 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]: ... class NullTimer: - def context(self, name: str) -> contextlib.nullcontext: - return contextlib.nullcontext() - def child(self, name: str) -> "Self": return self @@ -57,9 +51,6 @@ def __enter__(self) -> "Self": def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]: return False - def report(self) -> TimerResult: - return TimerResult(total_runs=0, avg_time=0.0, children={}) - _: Timer = NullTimer() del _ @@ -160,7 +151,7 @@ def result(self) -> TimerResult: if self._result is None: torch.cuda.synchronize() self._result = TimerResult( - total_runs=len(self._event_pairs), + count=len(self._event_pairs), avg_time=self._avg_time, children=self._child_reports(), ) From 43a00c3692d5790340e8a03f966d5a2c5fdf1f7a Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 10 Feb 2026 21:23:24 +0000 Subject: [PATCH 6/6] fix test --- fme/core/benchmark/test_timer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fme/core/benchmark/test_timer.py b/fme/core/benchmark/test_timer.py index 47c527d56..809894b66 100644 --- a/fme/core/benchmark/test_timer.py +++ b/fme/core/benchmark/test_timer.py @@ -111,4 +111,5 @@ def test_assert_close_children_rtol_raises(): avg_time=110.0, children={"child": TimerResult(count=2, avg_time=103.0, children={})}, ) - result2.assert_close(result1, rtol=0.5, children_rtol=0.2) + with pytest.raises(AssertionError): + result2.assert_close(result1, rtol=0.05, children_rtol=0.2)
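
For readers following the series: below is a minimal usage sketch of the timer API added in patch 1/6 (and revised in patch 5/6). It is not taken from the patches themselves; the toy Linear module, tensor shapes, and step count are illustrative assumptions.

import torch

from fme.core.benchmark.timer import CUDATimer

def timed_step(module, x, timer):
    # each pass through the outer context records one CUDA event pair,
    # so TimerResult.avg_time is averaged over the number of calls
    with timer:
        with timer.child("forward"):
            y = module(x)
        with timer.child("reduce"):
            out = y.mean()
    return out

device = "cuda" if torch.cuda.is_available() else "cpu"
module = torch.nn.Linear(8, 8).to(device)
timer = CUDATimer.new_if_available()  # NullTimer (a no-op) without a GPU
for _ in range(4):
    timed_step(module, torch.randn(2, 8, device=device), timer)

if isinstance(timer, CUDATimer):
    result = timer.result  # synchronizes, then averages the recorded events
    print(result.count, result.avg_time, sorted(result.children))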