diff --git a/fme/core/models/conditional_sfno/layers.py b/fme/core/models/conditional_sfno/layers.py index 47648d781..5f6dcbe8e 100644 --- a/fme/core/models/conditional_sfno/layers.py +++ b/fme/core/models/conditional_sfno/layers.py @@ -24,6 +24,7 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint +from fme.core.benchmark.timer import Timer, NullTimer from fme.core.models.conditional_sfno.lora import LoRAConv2d from .activations import ComplexReLU @@ -223,7 +224,12 @@ def reset_parameters(self): torch.nn.init.constant_(self.W_bias_pos.weight, 0.0) # no bias on 2d layers as it is already handled in the non-2d layers - def forward(self, x: torch.Tensor, context: Context) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + context: Context, + timer: Timer = NullTimer(), + ) -> torch.Tensor: """ Conditional Layer Normalization @@ -242,52 +248,58 @@ def forward(self, x: torch.Tensor, context: Context) -> torch.Tensor: self.W_scale_labels is not None or self.W_bias_labels is not None ): raise ValueError("labels must be provided") - if self.W_scale is not None: - if context.embedding_scalar is None: - raise ValueError("embedding_scalar must be provided") - scale: torch.Tensor = ( - self.W_scale(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) - ) - else: - scale = torch.ones( - list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype - ) + with timer.child("compute_scaling_and_bias"): + if self.W_scale is not None: + if context.embedding_scalar is None: + raise ValueError("embedding_scalar must be provided") + scale: torch.Tensor = ( + self.W_scale(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) + ) + else: + scale = torch.ones( + list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype + ) - if self.W_scale_2d is not None: - if context.noise is None: - raise ValueError("embedding_2d must be provided") - scale = scale + self.W_scale_2d(context.noise) - if self.W_bias is not None: - if context.embedding_scalar is None: - raise ValueError("embedding_scalar must be provided") - bias: torch.Tensor = ( - self.W_bias(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) - ) - else: - bias = torch.zeros( - list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype - ) + if self.W_scale_2d is not None: + if context.noise is None: + raise ValueError("embedding_2d must be provided") + scale = scale + self.W_scale_2d(context.noise) + if self.W_bias is not None: + if context.embedding_scalar is None: + raise ValueError("embedding_scalar must be provided") + bias: torch.Tensor = ( + self.W_bias(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) + ) + else: + bias = torch.zeros( + list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype + ) - if self.W_scale_labels is not None: - scale = scale + self.W_scale_labels(context.labels).unsqueeze(-1).unsqueeze( - -1 - ) - if self.W_bias_labels is not None: - bias = bias + self.W_bias_labels(context.labels).unsqueeze(-1).unsqueeze(-1) - if self.W_bias_2d is not None: - if context.noise is None: - raise ValueError("embedding_2d must be provided") - bias = bias + self.W_bias_2d(context.noise) - if self.W_scale_pos is not None: - if context.embedding_pos is None: - raise ValueError("embedding_pos must be provided") - scale = scale + self.W_scale_pos(context.embedding_pos) - if self.W_bias_pos is not None: - if context.embedding_pos is None: - raise ValueError("embedding_pos must be provided") - bias = bias + self.W_bias_pos(context.embedding_pos) - x_norm: torch.Tensor = self.norm(x) - return x_norm * scale + bias + if self.W_scale_labels is not None: + scale = scale + self.W_scale_labels(context.labels).unsqueeze( + -1 + ).unsqueeze(-1) + if self.W_bias_labels is not None: + bias = bias + self.W_bias_labels(context.labels).unsqueeze( + -1 + ).unsqueeze(-1) + if self.W_bias_2d is not None: + if context.noise is None: + raise ValueError("embedding_2d must be provided") + bias = bias + self.W_bias_2d(context.noise) + if self.W_scale_pos is not None: + if context.embedding_pos is None: + raise ValueError("embedding_pos must be provided") + scale = scale + self.W_scale_pos(context.embedding_pos) + if self.W_bias_pos is not None: + if context.embedding_pos is None: + raise ValueError("embedding_pos must be provided") + bias = bias + self.W_bias_pos(context.embedding_pos) + with timer.child("normalize"): + x_norm: torch.Tensor = self.norm(x) + with timer.child("apply_scaling_and_bias"): + return_value = x_norm * scale + bias + return return_value @torch.jit.script diff --git a/fme/core/models/conditional_sfno/makani/spectral_convolution.py b/fme/core/models/conditional_sfno/makani/spectral_convolution.py index f38894c83..e99a7f5e1 100644 --- a/fme/core/models/conditional_sfno/makani/spectral_convolution.py +++ b/fme/core/models/conditional_sfno/makani/spectral_convolution.py @@ -19,6 +19,8 @@ import torch.nn as nn from torch import amp +from fme.core.benchmark.timer import NullTimer, Timer + # import convenience functions for factorized tensors from .factorizations import get_contract_fun @@ -124,7 +126,7 @@ def __init__( if bias: self.bias = nn.Parameter(torch.zeros(1, self.out_channels, 1, 1)) - def forward(self, x): + def forward(self, x, timer: Timer = NullTimer()): dtype = x.dtype residual = x x = x.float() @@ -138,7 +140,10 @@ def forward(self, x): B, C, H, W = x.shape x = x.reshape(B, self.num_groups, C // self.num_groups, H, W) xp = self._contract( - x, self.weight, separable=self.separable, operator_type=self.operator_type + x, + self.weight, + separable=self.separable, + operator_type=self.operator_type, ) x = xp.reshape(B, self.out_channels, H, W).contiguous() diff --git a/fme/core/models/conditional_sfno/s2convolutions.py b/fme/core/models/conditional_sfno/s2convolutions.py index 93299256a..b138dd442 100644 --- a/fme/core/models/conditional_sfno/s2convolutions.py +++ b/fme/core/models/conditional_sfno/s2convolutions.py @@ -22,6 +22,8 @@ import torch_harmonics as th import torch_harmonics.distributed as thd +from fme.core.benchmark.timer import NullTimer, Timer + # import convenience functions for factorized tensors from .activations import ComplexReLU @@ -223,45 +225,51 @@ def __init__( self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) self.out_channels = out_channels - def forward(self, x): # pragma: no cover + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover dtype = x.dtype residual = x x = x.float() with torch.amp.autocast("cuda", enabled=False): - x = self.forward_transform(x.float()) + with timer.child("forward_transform"): + x = self.forward_transform(x.float()) if self._round_trip_residual: - x = x.contiguous() - residual = self.inverse_transform(x) - residual = residual.to(dtype) + with timer.child("round_trip_residual"): + x = x.contiguous() + residual = self.inverse_transform(x) + residual = residual.to(dtype) B, C, H, W = x.shape assert C % self.num_groups == 0 x = x.reshape(B, self.num_groups, C // self.num_groups, H, W) if self.lora_A is not None and self.lora_B is not None: - lora_update = _contract_lora( - self.lora_A, - self.lora_B, - x[..., : self.modes_lat_local, : self.modes_lon_local], - ) + with timer.child("lora_update"): + lora_update = _contract_lora( + self.lora_A, + self.lora_B, + x[..., : self.modes_lat_local, : self.modes_lon_local], + ) else: lora_update = 0.0 - xp = torch.zeros_like(x) - xp[..., : self.modes_lat_local, : self.modes_lon_local] = _contract_dhconv( - x[..., : self.modes_lat_local, : self.modes_lon_local], - self.weight, - ) - xp = xp + self.lora_scaling * lora_update - xp = xp.reshape(B, self.out_channels, H, W) - x = xp.contiguous() + with timer.child("dhconv"): + xp = torch.zeros_like(x) + xp[..., : self.modes_lat_local, : self.modes_lon_local] = _contract_dhconv( + x[..., : self.modes_lat_local, : self.modes_lon_local], + self.weight, + ) + xp = xp + self.lora_scaling * lora_update + xp = xp.reshape(B, self.out_channels, H, W) + x = xp.contiguous() with torch.amp.autocast("cuda", enabled=False): - x = self.inverse_transform(x) + with timer.child("inverse_transform"): + x = self.inverse_transform(x) if hasattr(self, "bias"): - x = x + self.bias + with timer.child("add_bias"): + x = x + self.bias x = x.type(dtype) @@ -320,7 +328,7 @@ def __init__( scale * torch.randn(1, out_channels, *self.output_dims) ) - def forward(self, x): # pragma: no cover + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover dtype = x.dtype x = x.float() B, C, H, W = x.shape @@ -503,7 +511,7 @@ def forward_mlp(self, x): # pragma: no cover return x - def forward(self, x): # pragma: no cover + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover dtype = x.dtype residual = x x = x.to(torch.float32) @@ -626,7 +634,7 @@ def forward_mlp(self, x): # pragma: no cover return x - def forward(self, x): # pragma: no cover + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover dtype = x.dtype x = x.to(torch.float32) diff --git a/fme/core/models/conditional_sfno/sfnonet.py b/fme/core/models/conditional_sfno/sfnonet.py index 61d35ca27..29eb986f0 100644 --- a/fme/core/models/conditional_sfno/sfnonet.py +++ b/fme/core/models/conditional_sfno/sfnonet.py @@ -24,6 +24,8 @@ import torch_harmonics as th from torch.utils.checkpoint import checkpoint +from fme.core.benchmark.timer import Timer, NullTimer + from .initialization import trunc_normal_ # wrap fft, to unify interface to spectral transforms @@ -62,7 +64,7 @@ def __init__(self, *args, **kwargs): super().__init__() self.conv = th.DiscreteContinuousConvS2(*args, **kwargs) - def forward(self, x): + def forward(self, x, timer: Timer = NullTimer()): return self.conv(x), x @@ -153,8 +155,8 @@ def __init__( else: raise (NotImplementedError) - def forward(self, x): - return self.filter(x) + def forward(self, x, timer: Timer = NullTimer()): + return self.filter(x, timer=timer) class FourierNeuralOperatorBlock(nn.Module): @@ -295,44 +297,54 @@ def __init__( lora_alpha=lora_alpha, ) - def forward(self, x, context_embedding): - x_norm = torch.zeros_like(x) - x_norm[..., : self.input_shape_loc[0], : self.input_shape_loc[1]] = self.norm0( - x[..., : self.input_shape_loc[0], : self.input_shape_loc[1]], - context_embedding, - ) - x, residual = self.filter(x_norm) - + def forward(self, x, context_embedding, timer: Timer = NullTimer()): + with timer.child("norm0") as norm0_timer: + x_norm = torch.zeros_like(x) + x_norm[..., : self.input_shape_loc[0], : self.input_shape_loc[1]] = ( + self.norm0( + x[..., : self.input_shape_loc[0], : self.input_shape_loc[1]], + context_embedding, + timer=norm0_timer, + ) + ) + with timer.child("filter") as filter_timer: + x, residual = self.filter(x_norm, timer=filter_timer) if hasattr(self, "inner_skip"): - if self.concat_skip: - x = torch.cat((x, self.inner_skip(residual)), dim=1) - x = self.inner_skip_conv(x) - else: - x = x + self.inner_skip(residual) + with timer.child("inner_skip"): + if self.concat_skip: + x = torch.cat((x, self.inner_skip(residual)), dim=1) + x = self.inner_skip_conv(x) + else: + x = x + self.inner_skip(residual) if hasattr(self, "act_layer"): - x = self.act_layer(x) - - x_norm = torch.zeros_like(x) - x_norm[..., : self.output_shape_loc[0], : self.output_shape_loc[1]] = ( - self.norm1( - x[..., : self.output_shape_loc[0], : self.output_shape_loc[1]], - context_embedding, + with timer.child("activation"): + x = self.act_layer(x) + + with timer.child("norm1") as norm1_timer: + x_norm = torch.zeros_like(x) + x_norm[..., : self.output_shape_loc[0], : self.output_shape_loc[1]] = ( + self.norm1( + x[..., : self.output_shape_loc[0], : self.output_shape_loc[1]], + context_embedding, + timer=norm1_timer, + ) ) - ) - x = x_norm + x = x_norm if hasattr(self, "mlp"): - x = self.mlp(x) + with timer.child("mlp"): + x = self.mlp(x) x = self.drop_path(x) if hasattr(self, "outer_skip"): - if self.concat_skip: - x = torch.cat((x, self.outer_skip(residual)), dim=1) - x = self.outer_skip_conv(x) - else: - x = x + self.outer_skip(residual) + with timer.child("outer_skip"): + if self.concat_skip: + x = torch.cat((x, self.outer_skip(residual)), dim=1) + x = self.outer_skip_conv(x) + else: + x = x + self.outer_skip(residual) return x diff --git a/fme/core/models/conditional_sfno/sht.py b/fme/core/models/conditional_sfno/sht.py new file mode 100644 index 000000000..dd9f8fc02 --- /dev/null +++ b/fme/core/models/conditional_sfno/sht.py @@ -0,0 +1,225 @@ +# flake8: noqa +# fmt: off +# isort: skip_file + +""" +This file contains a fix that we needed to get the SFNO to work on multiple +unroll steps in multiprocessing (e.g. multi-GPU mode.) We forked this code from +the torch harmonics sht.py file [*]. + +[*] https://github.com/NVIDIA/torch-harmonics/blob/17eefa53468d1a885d72087918eba905fa53e10a/torch_harmonics/sht.py +""" + + +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import torch +import torch.nn as nn +import torch.fft + +from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights +from torch_harmonics.legendre import _precompute_legpoly + +from fme.core.device import get_device +from fme.core.benchmark.timer import Timer, NullTimer + + +class RealSHT(nn.Module): + """ + Defines a module for computing the forward (real-valued) SHT. + Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. + The SHT is applied to the last two dimensions of the input + + [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. + [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True): + """ + Initializes the SHT Layer, precomputing the necessary quadrature weights + + Parameters: + nlat: input grid resolution in the latitudinal direction + nlon: input grid resolution in the longitudinal direction + grid: grid in the latitude direction (for now only tensor product grids are supported) + """ + + super().__init__() + + self.nlat = nlat + self.nlon = nlon + self.grid = grid + self.norm = norm + self.csphase = csphase + + # TODO: include assertions regarding the dimensions + + # compute quadrature points + if self.grid == "legendre-gauss": + cost, w = legendre_gauss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "lobatto": + cost, w = lobatto_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat-1 + elif self.grid == "equiangular": + cost, w = clenshaw_curtiss_weights(nlat, -1, 1) + # cost, w = fejer2_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "healpix": + raise(NotImplementedError("'healpix' grid not supported by InverseRealVectorSHT")) + else: + raise(ValueError("Unknown quadrature mode")) + + # apply cosine transform and flip them + tq = torch.flip(torch.arccos(cost), dims=(0,)) + + # determine the dimensions + self.mmax = mmax or self.nlon // 2 + 1 + + # combine quadrature weights with the legendre weights + pct = torch.as_tensor(_precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)) + weights = torch.einsum('mlk,k->mlk', pct, w) + + # remember quadrature weights + self.weights = weights.float().to(get_device()) + + def extra_repr(self): + """ + Pretty print module + """ + return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' + + def forward(self, x: torch.Tensor, timer: Timer = NullTimer()): + + assert(x.shape[-2] == self.nlat) + assert(x.shape[-1] == self.nlon) + with torch.autocast("cuda", enabled=False): + with timer.child("rfft"): + # rfft and view_as_complex don't support BF16, see https://github.com/pytorch/pytorch/issues/117844 + x = x.float() + + # apply real fft in the longitudinal direction + x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward") + + with timer.child("contraction"): + # do the Legendre-Gauss quadrature + x = torch.view_as_real(x) + + # distributed contraction: fork + out_shape = list(x.size()) + out_shape[-3] = self.lmax + out_shape[-2] = self.mmax + xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) + + # contraction + weights = self.weights.to(x.device).to(x.dtype) + xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 0], weights) + xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 1], weights) + x = torch.view_as_complex(xout) + + return x + +class InverseRealSHT(nn.Module): + """ + Defines a module for computing the inverse (real-valued) SHT. + Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. + nlat, nlon: Output dimensions + lmax, mmax: Input dimensions (spherical coefficients). For convenience, these are inferred from the output dimensions + + [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. + [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True): + + super().__init__() + + self.nlat = nlat + self.nlon = nlon + self.grid = grid + self.norm = norm + self.csphase = csphase + + # compute quadrature points + if self.grid == "legendre-gauss": + cost, _ = legendre_gauss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "lobatto": + cost, _ = lobatto_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat-1 + elif self.grid == "equiangular": + cost, _ = clenshaw_curtiss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "healpix": + raise(NotImplementedError("'healpix' grid not supported by RealVectorSHT")) + else: + raise(ValueError("Unknown quadrature mode")) + + # apply cosine transform and flip them + t = torch.flip(torch.arccos(cost), dims=(0,)) + + # determine the dimensions + self.mmax = mmax or self.nlon // 2 + 1 + + pct = torch.as_tensor(_precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)) + + # register buffer + self.pct = pct.float().to(get_device()) + + def extra_repr(self): + """ + Pretty print module + """ + return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' + + def forward(self, x: torch.Tensor, timer: Timer = NullTimer()): + + assert(x.shape[-2] == self.lmax) + assert(x.shape[-1] == self.mmax) + + with torch.autocast("cuda", enabled=False): + with timer.child("contraction"): + # irfft and view_as_complex don't support BF16, see https://github.com/pytorch/pytorch/issues/117844 + # Evaluate associated Legendre functions on the output nodes + x = torch.view_as_real(x).float() + + pct = self.pct.to(x.device).to(x.dtype) + rl = torch.einsum('...lm, mlk->...km', x[..., 0], pct ) + im = torch.einsum('...lm, mlk->...km', x[..., 1], pct ) + xs = torch.stack((rl, im), -1) + + # apply the inverse (real) FFT + x = torch.view_as_complex(xs) + with timer.child("irfft"): + x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") + + return x