diff --git a/fme/ace/__init__.py b/fme/ace/__init__.py index f99f968a0..144808c0e 100644 --- a/fme/ace/__init__.py +++ b/fme/ace/__init__.py @@ -42,7 +42,7 @@ from fme.ace.registry.land_net import LandNetBuilder from fme.ace.registry.m2lines import FloeNetBuilder, SamudraBuilder from fme.ace.registry.sfno import SFNO_V0_1_0, SphericalFourierNeuralOperatorBuilder -from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNO +from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNOBuilder from fme.ace.stepper import DerivedForcingsConfig, StepperOverrideConfig from fme.ace.stepper.insolation.config import InsolationConfig, NameConfig, ValueConfig from fme.ace.stepper.parameter_init import ( diff --git a/fme/ace/registry/stochastic_sfno.py b/fme/ace/registry/stochastic_sfno.py index 8be3f0193..8a2c9de1e 100644 --- a/fme/ace/registry/stochastic_sfno.py +++ b/fme/ace/registry/stochastic_sfno.py @@ -1,127 +1,10 @@ import dataclasses -import math -from collections.abc import Callable from typing import Literal -import torch - from fme.ace.registry.registry import ModuleConfig, ModuleSelector from fme.core.dataset_info import DatasetInfo -from fme.core.models.conditional_sfno.sfnonet import ( - Context, - ContextConfig, - get_lat_lon_sfnonet, -) -from fme.core.models.conditional_sfno.sfnonet import ( - SphericalFourierNeuralOperatorNet as ConditionalSFNO, -) - - -def isotropic_noise( - leading_shape: tuple[int, ...], - lmax: int, # length of the ℓ axis expected by isht - mmax: int, # length of the m axis expected by isht - isht: Callable[[torch.Tensor], torch.Tensor], - device: torch.device, -) -> torch.Tensor: - # --- draw independent N(0,1) parts -------------------------------------- - coeff_shape = (*leading_shape, lmax, mmax) - real = torch.randn(coeff_shape, dtype=torch.float32, device=device) - imag = torch.randn(coeff_shape, dtype=torch.float32, device=device) - imag[..., :, 0] = 0.0 # m = 0 ⇒ purely real - - # m > 0: make Re and Im each N(0,½) → |a_{ℓ m}|² has variance 1 - sqrt2 = math.sqrt(2.0) - real[..., :, 1:] /= sqrt2 - imag[..., :, 1:] /= sqrt2 - - # --- global scale that makes Var[T(θ,φ)] = 1 --------------------------- - scale = math.sqrt(4.0 * math.pi) / lmax # (Unsöld theorem ⇒ L = lmax) - alm = (real + 1j * imag) * scale - - return isht(alm) - - -class NoiseConditionedSFNO(torch.nn.Module): - def __init__( - self, - conditional_model: ConditionalSFNO, - img_shape: tuple[int, int], - noise_type: Literal["isotropic", "gaussian"] = "gaussian", - embed_dim_noise: int = 256, - embed_dim_pos: int = 0, - embed_dim_labels: int = 0, - ): - super().__init__() - self.conditional_model = conditional_model - self.embed_dim = embed_dim_noise - self.noise_type = noise_type - self.label_pos_embed: torch.nn.Parameter | None = None - # register pos embed if pos_embed_dim != 0 - if embed_dim_pos != 0: - self.pos_embed = torch.nn.Parameter( - torch.zeros( - 1, embed_dim_pos, img_shape[0], img_shape[1], requires_grad=True - ) - ) - # initialize pos embed with std=0.02 - torch.nn.init.trunc_normal_(self.pos_embed, std=0.02) - if embed_dim_labels > 0: - self.label_pos_embed = torch.nn.Parameter( - torch.zeros( - embed_dim_labels, - embed_dim_pos, - img_shape[0], - img_shape[1], - requires_grad=True, - ) - ) - torch.nn.init.trunc_normal_(self.label_pos_embed, std=0.02) - else: - self.pos_embed = None - - def forward( - self, x: torch.Tensor, labels: torch.Tensor | None = None - ) -> torch.Tensor: - x = x.reshape(-1, *x.shape[-3:]) - if self.noise_type == "isotropic": - lmax = 
self.conditional_model.itrans_up.lmax - mmax = self.conditional_model.itrans_up.mmax - noise = isotropic_noise( - (x.shape[0], self.embed_dim), - lmax, - mmax, - self.conditional_model.itrans_up, - device=x.device, - ) - elif self.noise_type == "gaussian": - noise = torch.randn( - [x.shape[0], self.embed_dim, *x.shape[-2:]], - device=x.device, - dtype=x.dtype, - ) - else: - raise ValueError(f"Invalid noise type: {self.noise_type}") - - if self.pos_embed is not None: - embedding_pos = self.pos_embed.repeat(noise.shape[0], 1, 1, 1) - if self.label_pos_embed is not None and labels is not None: - label_embedding_pos = torch.einsum( - "bl, lpxy -> bpxy", labels, self.label_pos_embed - ) - embedding_pos = embedding_pos + label_embedding_pos - else: - embedding_pos = None - - return self.conditional_model( - x, - Context( - embedding_scalar=None, - embedding_pos=embedding_pos, - labels=labels, - noise=noise, - ), - ) +from fme.core.models.conditional_sfno.v0.stochastic_sfno import build as build_v0 +from fme.core.models.conditional_sfno.v1.stochastic_sfno import build as build_v1 # this is based on the call signature of SphericalFourierNeuralOperatorNet at @@ -135,6 +18,7 @@ class NoiseConditionedSFNOBuilder(ModuleConfig): Noise is provided as conditioning input to conditional layer normalization. Attributes: + version: Version of the model. spectral_transform: Type of spherical transform to use. Kept for backwards compatibility. filter_type: Type of filter to use. @@ -186,6 +70,7 @@ class NoiseConditionedSFNOBuilder(ModuleConfig): Defaults to spectral_lora_rank. """ + version: Literal["v0", "v1", "latest"] = "v0" spectral_transform: Literal["sht"] = "sht" filter_type: Literal["linear", "makani-linear"] = "linear" operator_type: Literal["dhconv"] = "dhconv" @@ -236,6 +121,10 @@ def __post_init__(self): "Only 'dhconv' operator_type is supported for " "NoiseConditionedSFNO models." ) + if self.version == "latest": + # must replace as eventual newer versions break backwards compatibility + # v1 is not stable yet, keep using v0 as default for now + self.version = "v0" def build( self, @@ -243,23 +132,19 @@ def build( n_out_channels: int, dataset_info: DatasetInfo, ): - sfno_net = get_lat_lon_sfnonet( - params=self, - in_chans=n_in_channels, - out_chans=n_out_channels, - img_shape=dataset_info.img_shape, - context_config=ContextConfig( - embed_dim_scalar=0, - embed_dim_pos=self.context_pos_embed_dim, - embed_dim_noise=self.noise_embed_dim, - embed_dim_labels=len(dataset_info.all_labels), - ), - ) - return NoiseConditionedSFNO( - sfno_net, - noise_type=self.noise_type, - embed_dim_noise=self.noise_embed_dim, - embed_dim_pos=self.context_pos_embed_dim, - embed_dim_labels=len(dataset_info.all_labels), - img_shape=dataset_info.img_shape, - ) + if self.version == "v0": + return build_v0( + self, + n_in_channels=n_in_channels, + n_out_channels=n_out_channels, + dataset_info=dataset_info, + ) + elif self.version == "v1": + return build_v1( + self, + n_in_channels=n_in_channels, + n_out_channels=n_out_channels, + dataset_info=dataset_info, + ) + else: + raise ValueError(f"Unsupported version: {self.version}") diff --git a/fme/core/models/__init__.py b/fme/core/models/__init__.py new file mode 100644 index 000000000..585a9f834 --- /dev/null +++ b/fme/core/models/__init__.py @@ -0,0 +1 @@ +from . 
import conditional_sfno, mlp diff --git a/fme/core/models/conditional_sfno/__init__.py b/fme/core/models/conditional_sfno/__init__.py index e69de29bb..c82496585 100644 --- a/fme/core/models/conditional_sfno/__init__.py +++ b/fme/core/models/conditional_sfno/__init__.py @@ -0,0 +1 @@ +from . import v0 diff --git a/fme/core/models/conditional_sfno/v0/__init__.py b/fme/core/models/conditional_sfno/v0/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fme/core/models/conditional_sfno/activations.py b/fme/core/models/conditional_sfno/v0/activations.py similarity index 100% rename from fme/core/models/conditional_sfno/activations.py rename to fme/core/models/conditional_sfno/v0/activations.py diff --git a/fme/core/models/conditional_sfno/contractions.py b/fme/core/models/conditional_sfno/v0/contractions.py similarity index 100% rename from fme/core/models/conditional_sfno/contractions.py rename to fme/core/models/conditional_sfno/v0/contractions.py diff --git a/fme/core/models/conditional_sfno/initialization.py b/fme/core/models/conditional_sfno/v0/initialization.py similarity index 100% rename from fme/core/models/conditional_sfno/initialization.py rename to fme/core/models/conditional_sfno/v0/initialization.py diff --git a/fme/core/models/conditional_sfno/layers.py b/fme/core/models/conditional_sfno/v0/layers.py similarity index 99% rename from fme/core/models/conditional_sfno/layers.py rename to fme/core/models/conditional_sfno/v0/layers.py index 5f6dcbe8e..74dad079d 100644 --- a/fme/core/models/conditional_sfno/layers.py +++ b/fme/core/models/conditional_sfno/v0/layers.py @@ -21,11 +21,10 @@ import torch import torch.fft import torch.nn as nn -import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from fme.core.benchmark.timer import Timer, NullTimer -from fme.core.models.conditional_sfno.lora import LoRAConv2d +from .lora import LoRAConv2d from .activations import ComplexReLU from .contractions import compl_mul2d_fwd, compl_muladd2d_fwd diff --git a/fme/core/models/conditional_sfno/lora.py b/fme/core/models/conditional_sfno/v0/lora.py similarity index 100% rename from fme/core/models/conditional_sfno/lora.py rename to fme/core/models/conditional_sfno/v0/lora.py diff --git a/fme/core/models/conditional_sfno/makani/__init__.py b/fme/core/models/conditional_sfno/v0/makani/__init__.py similarity index 100% rename from fme/core/models/conditional_sfno/makani/__init__.py rename to fme/core/models/conditional_sfno/v0/makani/__init__.py diff --git a/fme/core/models/conditional_sfno/makani/contractions.py b/fme/core/models/conditional_sfno/v0/makani/contractions.py similarity index 100% rename from fme/core/models/conditional_sfno/makani/contractions.py rename to fme/core/models/conditional_sfno/v0/makani/contractions.py diff --git a/fme/core/models/conditional_sfno/makani/factorizations.py b/fme/core/models/conditional_sfno/v0/makani/factorizations.py similarity index 100% rename from fme/core/models/conditional_sfno/makani/factorizations.py rename to fme/core/models/conditional_sfno/v0/makani/factorizations.py diff --git a/fme/core/models/conditional_sfno/makani/spectral_convolution.py b/fme/core/models/conditional_sfno/v0/makani/spectral_convolution.py similarity index 100% rename from fme/core/models/conditional_sfno/makani/spectral_convolution.py rename to fme/core/models/conditional_sfno/v0/makani/spectral_convolution.py diff --git a/fme/core/models/conditional_sfno/s2convolutions.py b/fme/core/models/conditional_sfno/v0/s2convolutions.py similarity 
index 100% rename from fme/core/models/conditional_sfno/s2convolutions.py rename to fme/core/models/conditional_sfno/v0/s2convolutions.py diff --git a/fme/core/models/conditional_sfno/sfnonet.py b/fme/core/models/conditional_sfno/v0/sfnonet.py similarity index 100% rename from fme/core/models/conditional_sfno/sfnonet.py rename to fme/core/models/conditional_sfno/v0/sfnonet.py diff --git a/fme/core/models/conditional_sfno/sht.py b/fme/core/models/conditional_sfno/v0/sht.py similarity index 100% rename from fme/core/models/conditional_sfno/sht.py rename to fme/core/models/conditional_sfno/v0/sht.py diff --git a/fme/core/models/conditional_sfno/v0/stochastic_sfno.py b/fme/core/models/conditional_sfno/v0/stochastic_sfno.py new file mode 100644 index 000000000..f82e3fc86 --- /dev/null +++ b/fme/core/models/conditional_sfno/v0/stochastic_sfno.py @@ -0,0 +1,145 @@ +import math +from collections.abc import Callable +from typing import Literal + +import torch + +from fme.core.dataset_info import DatasetInfo + +from .sfnonet import Context, ContextConfig, get_lat_lon_sfnonet +from .sfnonet import SphericalFourierNeuralOperatorNet as ConditionalSFNO + + +def isotropic_noise( + leading_shape: tuple[int, ...], + lmax: int, # length of the ℓ axis expected by isht + mmax: int, # length of the m axis expected by isht + isht: Callable[[torch.Tensor], torch.Tensor], + device: torch.device, +) -> torch.Tensor: + # --- draw independent N(0,1) parts -------------------------------------- + coeff_shape = (*leading_shape, lmax, mmax) + real = torch.randn(coeff_shape, dtype=torch.float32, device=device) + imag = torch.randn(coeff_shape, dtype=torch.float32, device=device) + imag[..., :, 0] = 0.0 # m = 0 ⇒ purely real + + # m > 0: make Re and Im each N(0,½) → |a_{ℓ m}|² has variance 1 + sqrt2 = math.sqrt(2.0) + real[..., :, 1:] /= sqrt2 + imag[..., :, 1:] /= sqrt2 + + # --- global scale that makes Var[T(θ,φ)] = 1 --------------------------- + scale = math.sqrt(4.0 * math.pi) / lmax # (Unsöld theorem ⇒ L = lmax) + alm = (real + 1j * imag) * scale + + return isht(alm) + + +class NoiseConditionedSFNO(torch.nn.Module): + def __init__( + self, + conditional_model: ConditionalSFNO, + img_shape: tuple[int, int], + noise_type: Literal["isotropic", "gaussian"] = "gaussian", + embed_dim_noise: int = 256, + embed_dim_pos: int = 0, + embed_dim_labels: int = 0, + ): + super().__init__() + self.conditional_model = conditional_model + self.embed_dim = embed_dim_noise + self.noise_type = noise_type + self.label_pos_embed: torch.nn.Parameter | None = None + # register pos embed if pos_embed_dim != 0 + if embed_dim_pos != 0: + self.pos_embed = torch.nn.Parameter( + torch.zeros( + 1, embed_dim_pos, img_shape[0], img_shape[1], requires_grad=True + ) + ) + # initialize pos embed with std=0.02 + torch.nn.init.trunc_normal_(self.pos_embed, std=0.02) + if embed_dim_labels > 0: + self.label_pos_embed = torch.nn.Parameter( + torch.zeros( + embed_dim_labels, + embed_dim_pos, + img_shape[0], + img_shape[1], + requires_grad=True, + ) + ) + torch.nn.init.trunc_normal_(self.label_pos_embed, std=0.02) + else: + self.pos_embed = None + + def forward( + self, x: torch.Tensor, labels: torch.Tensor | None = None + ) -> torch.Tensor: + x = x.reshape(-1, *x.shape[-3:]) + if self.noise_type == "isotropic": + lmax = self.conditional_model.itrans_up.lmax + mmax = self.conditional_model.itrans_up.mmax + noise = isotropic_noise( + (x.shape[0], self.embed_dim), + lmax, + mmax, + self.conditional_model.itrans_up, + device=x.device, + ) + elif self.noise_type 
== "gaussian": + noise = torch.randn( + [x.shape[0], self.embed_dim, *x.shape[-2:]], + device=x.device, + dtype=x.dtype, + ) + else: + raise ValueError(f"Invalid noise type: {self.noise_type}") + + if self.pos_embed is not None: + embedding_pos = self.pos_embed.repeat(noise.shape[0], 1, 1, 1) + if self.label_pos_embed is not None and labels is not None: + label_embedding_pos = torch.einsum( + "bl, lpxy -> bpxy", labels, self.label_pos_embed + ) + embedding_pos = embedding_pos + label_embedding_pos + else: + embedding_pos = None + + return self.conditional_model( + x, + Context( + embedding_scalar=None, + embedding_pos=embedding_pos, + labels=labels, + noise=noise, + ), + ) + + +def build( + params, + n_in_channels: int, + n_out_channels: int, + dataset_info: DatasetInfo, +): + sfno_net = get_lat_lon_sfnonet( + params=params, + in_chans=n_in_channels, + out_chans=n_out_channels, + img_shape=dataset_info.img_shape, + context_config=ContextConfig( + embed_dim_scalar=0, + embed_dim_pos=params.context_pos_embed_dim, + embed_dim_noise=params.noise_embed_dim, + embed_dim_labels=len(dataset_info.all_labels), + ), + ) + return NoiseConditionedSFNO( + sfno_net, + noise_type=params.noise_type, + embed_dim_noise=params.noise_embed_dim, + embed_dim_pos=params.context_pos_embed_dim, + embed_dim_labels=len(dataset_info.all_labels), + img_shape=dataset_info.img_shape, + ) diff --git a/fme/core/models/conditional_sfno/test_layers.py b/fme/core/models/conditional_sfno/v0/test_layers.py similarity index 100% rename from fme/core/models/conditional_sfno/test_layers.py rename to fme/core/models/conditional_sfno/v0/test_layers.py diff --git a/fme/core/models/conditional_sfno/test_lora.py b/fme/core/models/conditional_sfno/v0/test_lora.py similarity index 89% rename from fme/core/models/conditional_sfno/test_lora.py rename to fme/core/models/conditional_sfno/v0/test_lora.py index d880098d5..e08a5a0d3 100644 --- a/fme/core/models/conditional_sfno/test_lora.py +++ b/fme/core/models/conditional_sfno/v0/test_lora.py @@ -1,7 +1,7 @@ import torch from torch import nn -from fme.core.models.conditional_sfno.lora import LoRAConv2d +from .lora import LoRAConv2d def test_lora_conv2d_load_conv2d_checkpoint(): diff --git a/fme/core/models/conditional_sfno/test_s2convolutions.py b/fme/core/models/conditional_sfno/v0/test_s2convolutions.py similarity index 96% rename from fme/core/models/conditional_sfno/test_s2convolutions.py rename to fme/core/models/conditional_sfno/v0/test_s2convolutions.py index cc6fcd1d6..f1f992b49 100644 --- a/fme/core/models/conditional_sfno/test_s2convolutions.py +++ b/fme/core/models/conditional_sfno/v0/test_s2convolutions.py @@ -5,9 +5,8 @@ from fme.core.device import get_device from fme.core.gridded_ops import LatLonOperations -from fme.core.models.conditional_sfno.s2convolutions import SpectralConvS2 -from .s2convolutions import _contract_dhconv +from .s2convolutions import SpectralConvS2, _contract_dhconv @dataclasses.dataclass diff --git a/fme/core/models/conditional_sfno/test_sfnonet.py b/fme/core/models/conditional_sfno/v0/test_sfnonet.py similarity index 100% rename from fme/core/models/conditional_sfno/test_sfnonet.py rename to fme/core/models/conditional_sfno/v0/test_sfnonet.py diff --git a/fme/ace/registry/test_stochastic_sfno.py b/fme/core/models/conditional_sfno/v0/test_stochastic_sfno.py similarity index 95% rename from fme/ace/registry/test_stochastic_sfno.py rename to fme/core/models/conditional_sfno/v0/test_stochastic_sfno.py index 191cc72d7..dbabd76f5 100644 --- 
a/fme/ace/registry/test_stochastic_sfno.py +++ b/fme/core/models/conditional_sfno/v0/test_stochastic_sfno.py @@ -4,13 +4,10 @@ import torch from torch_harmonics import InverseRealSHT -from fme.ace.registry.stochastic_sfno import ( - Context, - NoiseConditionedSFNO, - isotropic_noise, -) from fme.core.device import get_device +from .stochastic_sfno import Context, NoiseConditionedSFNO, isotropic_noise + @pytest.mark.parametrize("nlat, nlon", [(8, 16), (64, 128)]) def test_isotropic_noise(nlat: int, nlon: int): diff --git a/fme/core/models/conditional_sfno/testdata/test_sfnonet_output_is_unchanged.pt b/fme/core/models/conditional_sfno/v0/testdata/test_sfnonet_output_is_unchanged.pt similarity index 100% rename from fme/core/models/conditional_sfno/testdata/test_sfnonet_output_is_unchanged.pt rename to fme/core/models/conditional_sfno/v0/testdata/test_sfnonet_output_is_unchanged.pt diff --git a/fme/core/models/conditional_sfno/v1/__init__.py b/fme/core/models/conditional_sfno/v1/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fme/core/models/conditional_sfno/v1/activations.py b/fme/core/models/conditional_sfno/v1/activations.py new file mode 100644 index 000000000..071ce6dd8 --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/activations.py @@ -0,0 +1,116 @@ +# flake8: noqa +# Copied from https://github.com/ai2cm/modulus/commit/22df4a9427f5f12ff6ac891083220e7f2f54d229 +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
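For reference, a minimal sketch (not part of the diff) of how the relocated v0 `isotropic_noise` helper can be exercised with `torch_harmonics`, mirroring the moved `test_stochastic_sfno.py` above; the grid choice, shapes, and CPU device here are illustrative assumptions only.

```python
import torch
from torch_harmonics import InverseRealSHT

from fme.core.models.conditional_sfno.v0.stochastic_sfno import isotropic_noise

nlat, nlon = 64, 128
# the inverse SHT exposes lmax/mmax, just like conditional_model.itrans_up in forward()
isht = InverseRealSHT(nlat, nlon, grid="equiangular")
noise = isotropic_noise(
    leading_shape=(4, 8),  # (batch, embed_dim_noise), reduced for illustration
    lmax=isht.lmax,
    mmax=isht.mmax,
    isht=isht,
    device=torch.device("cpu"),
)
assert noise.shape == (4, 8, nlat, nlon)
print(noise.var())  # the spectral scaling aims for roughly unit variance on the grid
```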
+ +import torch +from torch import nn + + +class ComplexReLU(nn.Module): + """ + Complex-valued variants of the ReLU activation function + """ + + def __init__(self, negative_slope=0.0, mode="real", bias_shape=None, scale=1.0): + super(ComplexReLU, self).__init__() + + # store parameters + self.mode = mode + if self.mode in ["modulus", "halfplane"]: + if bias_shape is not None: + self.bias = nn.Parameter( + scale * torch.ones(bias_shape, dtype=torch.float32) + ) + else: + self.bias = nn.Parameter(scale * torch.ones((1), dtype=torch.float32)) + else: + self.bias = 0 + + self.negative_slope = negative_slope + self.act = nn.LeakyReLU(negative_slope=negative_slope) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + if self.mode == "cartesian": + zr = torch.view_as_real(z) + za = self.act(zr) + out = torch.view_as_complex(za) + + elif self.mode == "modulus": + zabs = torch.sqrt(torch.square(z.real) + torch.square(z.imag)) + out = torch.where(zabs + self.bias > 0, (zabs + self.bias) * z / zabs, 0.0) + # out = self.act(zabs - self.bias) * torch.exp(1.j * z.angle()) + + elif self.mode == "halfplane": + # bias is an angle parameter in this case + modified_angle = torch.angle(z) - self.bias + condition = torch.logical_and( + (0.0 <= modified_angle), (modified_angle < torch.pi / 2.0) + ) + out = torch.where(condition, z, self.negative_slope * z) + + elif self.mode == "real": + zr = torch.view_as_real(z) + outr = zr.clone() + outr[..., 0] = self.act(zr[..., 0]) + out = torch.view_as_complex(outr) + + else: + raise NotImplementedError + + return out + + +class ComplexActivation(nn.Module): + """ + A module implementing complex-valued activation functions. + The module supports different modes of operation, depending on how + the complex numbers are treated for the activation function: + - "cartesian": the activation function is applied separately to the + real and imaginary parts of the complex input. + - "modulus": the activation function is applied to the modulus of + the complex input, after adding a learnable bias. + - any other mode: the complex input is returned as-is (identity operation). + """ + + def __init__(self, activation, mode="cartesian", bias_shape=None): + super(ComplexActivation, self).__init__() + + # store parameters + self.mode = mode + if self.mode == "modulus": + if bias_shape is not None: + self.bias = nn.Parameter(torch.zeros(bias_shape, dtype=torch.float32)) + else: + self.bias = nn.Parameter(torch.zeros((1), dtype=torch.float32)) + else: + bias = torch.zeros((1), dtype=torch.float32) + self.register_buffer("bias", bias) + + # real valued activation + self.act = activation + + def forward(self, z: torch.Tensor) -> torch.Tensor: + if self.mode == "cartesian": + zr = torch.view_as_real(z) + za = self.act(zr) + out = torch.view_as_complex(za) + elif self.mode == "modulus": + zabs = torch.sqrt(torch.square(z.real) + torch.square(z.imag)) + out = self.act(zabs + self.bias) * torch.exp(1.0j * z.angle()) + else: + # identity + out = z + + return out diff --git a/fme/core/models/conditional_sfno/v1/contractions.py b/fme/core/models/conditional_sfno/v1/contractions.py new file mode 100644 index 000000000..add13299e --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/contractions.py @@ -0,0 +1,223 @@ +# flake8: noqa +# Copied from https://github.com/ai2cm/modulus/commit/22df4a9427f5f12ff6ac891083220e7f2f54d229 +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
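As an aside (not part of the diff), a small sketch of what the "cartesian" mode of the copied `ComplexReLU` above does: the real-valued activation is applied independently to the real and imaginary parts.

```python
import torch

from fme.core.models.conditional_sfno.v1.activations import ComplexReLU

act = ComplexReLU(negative_slope=0.0, mode="cartesian")
z = torch.randn(4, 8, dtype=torch.complex64)
out = act(z)
# with negative_slope=0 the internal LeakyReLU reduces to ReLU on each part
torch.testing.assert_close(out.real, torch.relu(z.real))
torch.testing.assert_close(out.imag, torch.relu(z.imag))
```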
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +@torch.jit.script +def compl_mul1d_fwd( + a: torch.Tensor, b: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs a complex-valued multiplication operation between two 1-dimensional + tensors. + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bix,io->box", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script +def compl_muladd1d_fwd( + a: torch.Tensor, b: torch.Tensor, c: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs complex multiplication of two 1-dimensional tensors 'a' and 'b', and then + adds a third tensor 'c'. + """ + tmpcc = torch.view_as_complex(compl_mul1d_fwd(a, b)) + cc = torch.view_as_complex(c) + return torch.view_as_real(tmpcc + cc) + + +@torch.jit.script +def compl_mul2d_fwd( + a: torch.Tensor, b: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs a complex-valued multiplication operation between two 2-dimensional + tensors. + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,io->boxy", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script +def compl_muladd2d_fwd( + a: torch.Tensor, b: torch.Tensor, c: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs complex multiplication of two 2-dimensional tensors 'a' and 'b', and then + adds a third tensor 'c'. + """ + tmpcc = torch.view_as_complex(compl_mul2d_fwd(a, b)) + cc = torch.view_as_complex(c) + return torch.view_as_real(tmpcc + cc) + + +@torch.jit.script # TODO remove +def _contract_localconv_fwd( + a: torch.Tensor, b: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs a complex local convolution operation between two tensors 'a' and 'b'. + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,iox->boxy", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script # TODO remove +def _contract_blockconv_fwd( + a: torch.Tensor, b: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs a complex block convolution operation between two tensors 'a' and 'b'. + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bim,imn->bin", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script # TODO remove +def _contractadd_blockconv_fwd( + a: torch.Tensor, b: torch.Tensor, c: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs a complex block convolution operation between two tensors 'a' and 'b', and + then adds a third tensor 'c'. 
+ """ + tmpcc = torch.view_as_complex(_contract_blockconv_fwd(a, b)) + cc = torch.view_as_complex(c) + return torch.view_as_real(tmpcc + cc) + + +# for the experimental layer +@torch.jit.script # TODO remove +def compl_exp_mul2d_fwd( + a: torch.Tensor, b: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs a 2D complex multiplication operation between two tensors 'a' and 'b'. + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,xio->boxy", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script +def compl_exp_muladd2d_fwd( # TODO remove + a: torch.Tensor, b: torch.Tensor, c: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs a 2D complex multiplication operation between two tensors 'a' and 'b', + and then adds a third tensor 'c'. + """ + tmpcc = torch.view_as_complex(compl_exp_mul2d_fwd(a, b)) + cc = torch.view_as_complex(c) + return torch.view_as_real(tmpcc + cc) + + +@torch.jit.script +def real_mul2d_fwd( + a: torch.Tensor, b: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs a 2D real multiplication operation between two tensors 'a' and 'b'. + """ + res = torch.einsum("bixy,io->boxy", a, b) + return res + + +@torch.jit.script +def real_muladd2d_fwd( + a: torch.Tensor, b: torch.Tensor, c: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs a 2D real multiplication operation between two tensors 'a' and 'b', and + then adds a third tensor 'c'. + """ + res = real_mul2d_fwd(a, b) + c + return res + + +# new contractions set to replace older ones. We use complex +@torch.jit.script +def _contract_diagonal( + a: torch.Tensor, b: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs a complex diagonal operation between two tensors 'a' and 'b'. + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,ioxy->boxy", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script +def _contract_dhconv( + a: torch.Tensor, b: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs a complex Driscoll-Healy style convolution operation between two tensors + 'a' and 'b'. + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,iox->boxy", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script +def _contract_sep_diagonal( + a: torch.Tensor, b: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs a complex convolution operation between two tensors 'a' and 'b' + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,ixy->boxy", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script +def _contract_sep_dhconv( + a: torch.Tensor, b: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs a complex convolution operation between two tensors 'a' and 'b' + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,ix->boxy", ac, bc) + res = torch.view_as_real(resc) + return res diff --git a/fme/core/models/conditional_sfno/v1/initialization.py b/fme/core/models/conditional_sfno/v1/initialization.py new file mode 100644 index 000000000..f8910c359 --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/initialization.py @@ -0,0 +1,75 @@ +# flake8: noqa +# Copied from https://github.com/ai2cm/modulus/commit/22df4a9427f5f12ff6ac891083220e7f2f54d229 +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings + +import torch + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/fme/core/models/conditional_sfno/v1/layers.py b/fme/core/models/conditional_sfno/v1/layers.py new file mode 100644 index 000000000..74dad079d --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/layers.py @@ -0,0 +1,708 @@ +# flake8: noqa +# Copied from https://github.com/ai2cm/modulus/commit/22df4a9427f5f12ff6ac891083220e7f2f54d229 +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
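A short sketch (illustrative only) of the real-pair convention used by the contractions above: inputs are `view_as_real` representations of complex tensors, and `compl_mul2d_fwd` matches a complex einsum over the channel dimension.

```python
import torch

from fme.core.models.conditional_sfno.v1.contractions import compl_mul2d_fwd

B, I, O, H, W = 2, 3, 5, 4, 4  # hypothetical sizes
ac = torch.randn(B, I, H, W, dtype=torch.complex64)
bc = torch.randn(I, O, dtype=torch.complex64)
# inputs/outputs carry a trailing dimension of size 2 (real, imag)
out = compl_mul2d_fwd(torch.view_as_real(ac), torch.view_as_real(bc))  # (B, O, H, W, 2)
expected = torch.view_as_real(torch.einsum("bixy,io->boxy", ac, bc))
torch.testing.assert_close(out, expected)
```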
+# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import math +from typing import Tuple + +import torch +import torch.fft +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from fme.core.benchmark.timer import Timer, NullTimer +from .lora import LoRAConv2d + +from .activations import ComplexReLU +from .contractions import compl_mul2d_fwd, compl_muladd2d_fwd + + +@dataclasses.dataclass +class ContextConfig: + """ + Configuration for the context. + """ + + embed_dim_scalar: int + embed_dim_labels: int + embed_dim_noise: int + embed_dim_pos: int + + +@dataclasses.dataclass +class Context: + """ + Context for the conditional layer normalization. + + Parameters: + embedding_scalar: The scalar embedding to condition on. The + last dimension is the channel dimension. + embedding_pos: The positional embedding to condition on. The last + three dimensions are (channels, height, width). + labels: The labels to condition on, of shape (batch_size, n_labels). + noise: The 2D noise embedding to condition on. The last + three dimensions are (channels, height, width). + """ + + embedding_scalar: torch.Tensor | None + embedding_pos: torch.Tensor | None + labels: torch.Tensor | None + noise: torch.Tensor | None + + def __post_init__(self): + if ( + self.embedding_scalar is not None + and self.noise is not None + and self.noise.ndim != self.embedding_scalar.ndim + 2 + ): + raise ValueError("noise must have 2 more dimensions than embedding_scalar") + if self.labels is not None and self.labels.ndim != 2: + raise ValueError("labels must have 2 dimensions") + + +class ChannelLayerNorm(nn.Module): + """ + Layer Normalization over third-last channel dimension. + """ + + def __init__( + self, n_channels: int, eps: float = 1e-5, elementwise_affine: bool = False + ): + super(ChannelLayerNorm, self).__init__() + self.n_channels = n_channels + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.ones(n_channels)) + self.bias = nn.Parameter(torch.zeros(n_channels)) + else: + self.weight = None + self.bias = None + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + torch.nn.init.constant_(self.weight, 1.0) + torch.nn.init.constant_(self.bias, 0.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.dim() < 2: + raise ValueError( + f"Expected at least 3D input with channel at dim=-3, got shape {tuple(x.shape)}" + ) + if x.size(-3) != self.n_channels: + raise ValueError( + f"Channel dimension mismatch: got C={x.size(-3)}, expected {self.n_channels}" + ) + + # Compute per-pixel mean/var across channels without transposing + mean = x.mean(dim=-3, keepdim=True) + var = x.var(dim=-3, keepdim=True, unbiased=False) + inv_std = torch.rsqrt(var + self.eps) + y = (x - mean) * inv_std + + if self.weight is not None and self.bias is not None: + # Broadcast [C] over [N, C, *spatial] + shape = [1, -1] + [1] * (x.dim() - 2) + y = y * self.weight.view(*shape) + self.bias.view(*shape) + return y + + +class ConditionalLayerNorm(nn.Module): + """ + Conditional Layer Normalization as described in "AdaSpeech: Adaptive + Text to Speech for Custom Voice" https://arxiv.org/abs/2103.00993. + + Assumes that the input has shape (batch_size, channels, height, width). 
+ """ + + def __init__( + self, + n_channels: int, + img_shape: Tuple[int, int], + context_config: ContextConfig, + global_layer_norm: bool = False, + epsilon: float = 1e-5, + elementwise_affine: bool = False, + ): + super(ConditionalLayerNorm, self).__init__() + self.n_channels = n_channels + self.embed_dim_scalar = context_config.embed_dim_scalar + self.embed_dim_labels = context_config.embed_dim_labels + self.embed_dim_pos = context_config.embed_dim_pos + self.embed_dim_noise = context_config.embed_dim_noise + self.epsilon = epsilon + if self.embed_dim_scalar > 0: + self.W_scale: nn.Linear | None = nn.Linear( + self.embed_dim_scalar, self.n_channels + ) + self.W_bias: nn.Linear | None = nn.Linear( + self.embed_dim_scalar, self.n_channels + ) + else: + self.W_scale = None + self.W_bias = None + if self.embed_dim_labels > 0: + self.W_scale_labels = nn.Linear(self.embed_dim_labels, self.n_channels) + self.W_bias_labels = nn.Linear(self.embed_dim_labels, self.n_channels) + else: + self.W_scale_labels = None + self.W_bias_labels = None + if self.embed_dim_noise > 0: + # no bias as it is already handled in the non-2d layers + self.W_scale_2d = nn.Conv2d( + self.embed_dim_noise, self.n_channels, kernel_size=1, bias=False + ) + self.W_bias_2d = nn.Conv2d( + self.embed_dim_noise, self.n_channels, kernel_size=1, bias=False + ) + else: + self.W_scale_2d = None + self.W_bias_2d = None + if self.embed_dim_pos > 0: + # no bias as it is already handled in the non-2d layers + self.W_scale_pos = nn.Conv2d( + self.embed_dim_pos, self.n_channels, kernel_size=1, bias=False + ) + self.W_bias_pos = nn.Conv2d( + self.embed_dim_pos, self.n_channels, kernel_size=1, bias=False + ) + else: + self.W_scale_pos = None + self.W_bias_pos = None + if global_layer_norm: + self.norm = nn.LayerNorm( + (self.n_channels, img_shape[0], img_shape[1]), + eps=epsilon, + elementwise_affine=elementwise_affine, + ) + else: + self.norm = ChannelLayerNorm( + self.n_channels, + eps=epsilon, + elementwise_affine=elementwise_affine, + ) + self._global_layer_norm = global_layer_norm + self.reset_parameters() + + def reset_parameters(self): + if self.W_scale is not None: + torch.nn.init.constant_(self.W_scale.weight, 0.0) + torch.nn.init.constant_(self.W_scale.bias, 1.0) + if self.W_bias is not None: + torch.nn.init.constant_(self.W_bias.weight, 0.0) + torch.nn.init.constant_(self.W_bias.bias, 0.0) + if self.W_scale_labels is not None: + torch.nn.init.constant_(self.W_scale_labels.weight, 0.0) + # bias starts at 1 for the first scale, we don't want to add more. + torch.nn.init.constant_(self.W_scale_labels.bias, 0.0) + if self.W_bias_labels is not None: + torch.nn.init.constant_(self.W_bias_labels.weight, 0.0) + torch.nn.init.constant_(self.W_bias_labels.bias, 0.0) + # no bias on 2d layers as it is already handled in the non-2d layers + if self.W_scale_2d is not None: + torch.nn.init.constant_(self.W_scale_2d.weight, 0.0) + if self.W_bias_2d is not None: + torch.nn.init.constant_(self.W_bias_2d.weight, 0.0) + if self.W_scale_pos is not None: + torch.nn.init.constant_(self.W_scale_pos.weight, 0.0) + if self.W_bias_pos is not None: + torch.nn.init.constant_(self.W_bias_pos.weight, 0.0) + # no bias on 2d layers as it is already handled in the non-2d layers + + def forward( + self, + x: torch.Tensor, + context: Context, + timer: Timer = NullTimer(), + ) -> torch.Tensor: + """ + Conditional Layer Normalization + + This is a modified version of LayerNorm that allows the scale and bias to be + conditioned on a context embedding. 
+ + Args: + x: The input tensor to normalize, of shape + (batch_size, channels, height, width). + context: The context to condition on. + + Returns: + The normalized tensor, of shape (batch_size, channels, height, width). + """ + if context.labels is None and ( + self.W_scale_labels is not None or self.W_bias_labels is not None + ): + raise ValueError("labels must be provided") + with timer.child("compute_scaling_and_bias"): + if self.W_scale is not None: + if context.embedding_scalar is None: + raise ValueError("embedding_scalar must be provided") + scale: torch.Tensor = ( + self.W_scale(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) + ) + else: + scale = torch.ones( + list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype + ) + + if self.W_scale_2d is not None: + if context.noise is None: + raise ValueError("embedding_2d must be provided") + scale = scale + self.W_scale_2d(context.noise) + if self.W_bias is not None: + if context.embedding_scalar is None: + raise ValueError("embedding_scalar must be provided") + bias: torch.Tensor = ( + self.W_bias(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) + ) + else: + bias = torch.zeros( + list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype + ) + + if self.W_scale_labels is not None: + scale = scale + self.W_scale_labels(context.labels).unsqueeze( + -1 + ).unsqueeze(-1) + if self.W_bias_labels is not None: + bias = bias + self.W_bias_labels(context.labels).unsqueeze( + -1 + ).unsqueeze(-1) + if self.W_bias_2d is not None: + if context.noise is None: + raise ValueError("embedding_2d must be provided") + bias = bias + self.W_bias_2d(context.noise) + if self.W_scale_pos is not None: + if context.embedding_pos is None: + raise ValueError("embedding_pos must be provided") + scale = scale + self.W_scale_pos(context.embedding_pos) + if self.W_bias_pos is not None: + if context.embedding_pos is None: + raise ValueError("embedding_pos must be provided") + bias = bias + self.W_bias_pos(context.embedding_pos) + with timer.child("normalize"): + x_norm: torch.Tensor = self.norm(x) + with timer.child("apply_scaling_and_bias"): + return_value = x_norm * scale + bias + return return_value + + +@torch.jit.script +def drop_path( + x: torch.Tensor, drop_prob: float = 0.0, training: bool = False +) -> torch.Tensor: # pragma: no cover + """Drop paths (Stochastic Depth) per sample (when applied in main path of + residual blocks). + This is the same as the DropConnect impl for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in + a separate paper. See discussion: + https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 + We've opted for changing the layer and argument names to 'drop path' rather than + mix DropConnect as a layer name and use 'survival rate' as the argument. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1.0 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2d ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual + blocks). 
+ """ + + def __init__(self, drop_prob=None): # pragma: no cover + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): # pragma: no cover + return drop_path(x, self.drop_prob, self.training) + + +class PatchEmbed(nn.Module): + """ + Divides the input image into patches and embeds them into a specified dimension + using a convolutional layer. + """ + + def __init__( + self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768 + ): # pragma: no cover + super(PatchEmbed, self).__init__() + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, x): # pragma: no cover + # gather input + B, C, H, W = x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + # new: B, C, H*W + x = self.proj(x).flatten(2) + return x + + +class MLP(nn.Module): + """ + Basic CNN with support for gradient checkpointing + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + output_bias=True, + drop_rate=0.0, + checkpointing=0, + lora_rank: int = 0, + lora_alpha: float | None = None, + ): # pragma: no cover + super(MLP, self).__init__() + self.checkpointing = checkpointing + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + fc1 = LoRAConv2d( + in_features, + hidden_features, + 1, + bias=True, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + ) + act = act_layer() + fc2 = LoRAConv2d( + hidden_features, + out_features, + 1, + bias=output_bias, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + ) + if drop_rate > 0.0: + drop = nn.Dropout(drop_rate) + self.fwd = nn.Sequential(fc1, act, drop, fc2, drop) + else: + self.fwd = nn.Sequential(fc1, act, fc2) + + # by default, all weights are shared + + @torch.jit.ignore + def checkpoint_forward(self, x): # pragma: no cover + """Forward method with support for gradient checkpointing""" + return checkpoint(self.fwd, x) + + def forward(self, x): # pragma: no cover + if self.checkpointing >= 2: + return self.checkpoint_forward(x) + else: + return self.fwd(x) + + +class RealFFT2(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None): # pragma: no cover + super(RealFFT2, self).__init__() + + # use local FFT here + self.fft_handle = torch.fft.rfft2 + + self.nlat = nlat + self.nlon = nlon + self.lmax = lmax or self.nlat + self.mmax = mmax or self.nlon // 2 + 1 + + self.truncate = True + if (self.lmax == self.nlat) and (self.mmax == (self.nlon // 2 + 1)): + self.truncate = False + + # self.num_batches = 1 + assert self.lmax % 2 == 0 + + def forward(self, x): # pragma: no cover + y = self.fft_handle(x, (self.nlat, self.nlon), (-2, -1), "ortho") + + if self.truncate: + y = torch.cat( + ( + y[..., : math.ceil(self.lmax / 2), : self.mmax], + y[..., -math.floor(self.lmax / 2) :, : self.mmax], + ), + dim=-2, + ) + + return y + + +class InverseRealFFT2(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None): # pragma: no cover + super(InverseRealFFT2, self).__init__() + + # use local FFT here + self.ifft_handle = torch.fft.irfft2 + + self.nlat = nlat + 
self.nlon = nlon + self.lmax = lmax or self.nlat + self.mmax = mmax or self.nlon // 2 + 1 + + def forward(self, x): # pragma: no cover + out = self.ifft_handle(x, (self.nlat, self.nlon), (-2, -1), "ortho") + + return out + + +class SpectralAttention2d(nn.Module): + """ + 2d Spectral Attention layer + """ + + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + sparsity_threshold=0.0, + hidden_size_factor=2, + use_complex_network=True, + use_complex_kernels=False, + complex_activation="real", + bias=False, + spectral_layers=1, + drop_rate=0.0, + ): # pragma: no cover + super(SpectralAttention2d, self).__init__() + + self.embed_dim = embed_dim + self.sparsity_threshold = sparsity_threshold + self.hidden_size = int(hidden_size_factor * self.embed_dim) + self.scale = 0.02 + self.spectral_layers = spectral_layers + if use_complex_kernels: + raise NotImplementedError("complex kernels not supported") + self.mul_add_handle = compl_muladd2d_fwd + self.mul_handle = compl_mul2d_fwd + + self.modes_lat = forward_transform.lmax + self.modes_lon = forward_transform.mmax + + # only storing the forward handle to be able to call it + self.forward_transform = forward_transform.forward + self.inverse_transform = inverse_transform.forward + + assert inverse_transform.lmax == self.modes_lat + assert inverse_transform.mmax == self.modes_lon + + # weights + w = [self.scale * torch.randn(self.embed_dim, self.hidden_size, 2)] + # w = [self.scale * torch.randn(self.embed_dim + 2*self.embed_freqs, self.hidden_size, 2)] + # w = [self.scale * torch.randn(self.embed_dim + 4*self.embed_freqs, self.hidden_size, 2)] + for l in range(1, self.spectral_layers): + w.append(self.scale * torch.randn(self.hidden_size, self.hidden_size, 2)) + self.w = nn.ParameterList(w) + + if bias: + self.b = nn.ParameterList( + [ + self.scale * torch.randn(self.hidden_size, 1, 2) + for _ in range(self.spectral_layers) + ] + ) + + self.wout = nn.Parameter( + self.scale * torch.randn(self.hidden_size, self.embed_dim, 2) + ) + + self.drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity() + + self.activation = ComplexReLU( + mode=complex_activation, bias_shape=(self.hidden_size, 1, 1) + ) + + def forward_mlp(self, xr): # pragma: no cover + """forward method for the MLP part of the network""" + for l in range(self.spectral_layers): + if hasattr(self, "b"): + xr = self.mul_add_handle( + xr, self.w[l].to(xr.dtype), self.b[l].to(xr.dtype) + ) + else: + xr = self.mul_handle(xr, self.w[l].to(xr.dtype)) + xr = torch.view_as_complex(xr) + xr = self.activation(xr) + xr = self.drop(xr) + xr = torch.view_as_real(xr) + + xr = self.mul_handle(xr, self.wout) + + return xr + + def forward(self, x): # pragma: no cover + dtype = x.dtype + # x = x.to(torch.float32) + + # FWD transform + with torch.amp.autocast("cuda", enabled=False): + x = x.to(torch.float32) + x = x.contiguous() + x = self.forward_transform(x) + x = torch.view_as_real(x) + + # MLP + x = self.forward_mlp(x) + + # BWD transform + with torch.amp.autocast("cuda", enabled=False): + x = torch.view_as_complex(x) + x = x.contiguous() + x = self.inverse_transform(x) + x = x.to(dtype) + + return x + + +class SpectralAttentionS2(nn.Module): + """ + geometrical Spectral Attention layer + """ + + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + sparsity_threshold=0.0, + hidden_size_factor=2, + use_complex_network=True, + use_complex_kernels=False, + complex_activation="real", + bias=False, + spectral_layers=1, + drop_rate=0.0, + ): # pragma: no 
cover + super(SpectralAttentionS2, self).__init__() + + self.embed_dim = embed_dim + self.sparsity_threshold = sparsity_threshold + self.hidden_size = int(hidden_size_factor * self.embed_dim) + self.scale = 0.02 + if use_complex_kernels: + raise NotImplementedError("complex kernels not supported") + # self.mul_add_handle = compl_muladd1d_fwd_c if use_complex_kernels else compl_muladd1d_fwd + self.mul_add_handle = compl_muladd2d_fwd + # self.mul_handle = compl_mul1d_fwd_c if use_complex_kernels else compl_mul1d_fwd + self.mul_handle = compl_mul2d_fwd + self.spectral_layers = spectral_layers + + self.modes_lat = forward_transform.lmax + self.modes_lon = forward_transform.mmax + + # only storing the forward handle to be able to call it + self.forward_transform = forward_transform.forward + self.inverse_transform = inverse_transform.forward + + assert inverse_transform.lmax == self.modes_lat + assert inverse_transform.mmax == self.modes_lon + + # weights + w = [self.scale * torch.randn(self.embed_dim, self.hidden_size, 2)] + # w = [self.scale * torch.randn(self.embed_dim + 4*self.embed_freqs, self.hidden_size, 2)] + for l in range(1, self.spectral_layers): + w.append(self.scale * torch.randn(self.hidden_size, self.hidden_size, 2)) + self.w = nn.ParameterList(w) + + if bias: + self.b = nn.ParameterList( + [ + self.scale * torch.randn(2 * self.hidden_size, 1, 1, 2) + for _ in range(self.spectral_layers) + ] + ) + + self.wout = nn.Parameter( + self.scale * torch.randn(self.hidden_size, self.embed_dim, 2) + ) + + self.drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity() + + self.activation = ComplexReLU( + mode=complex_activation, bias_shape=(self.hidden_size, 1, 1) + ) + + def forward_mlp(self, xr): # pragma: no cover + """forward method for the MLP part of the network""" + for l in range(self.spectral_layers): + if hasattr(self, "b"): + xr = self.mul_add_handle( + xr, self.w[l].to(xr.dtype), self.b[l].to(xr.dtype) + ) + else: + xr = self.mul_handle(xr, self.w[l].to(xr.dtype)) + xr = torch.view_as_complex(xr) + xr = self.activation(xr) + xr = self.drop(xr) + xr = torch.view_as_real(xr) + + # final MLP + xr = self.mul_handle(xr, self.wout) + + return xr + + def forward(self, x): # pragma: no cover + dtype = x.dtype + # x = x.to(torch.float32) + + # FWD transform + with torch.amp.autocast("cuda", enabled=False): + x = x.to(torch.float32) + x = x.contiguous() + x = self.forward_transform(x) + x = torch.view_as_real(x) + + # MLP + x = self.forward_mlp(x) + + # BWD transform + with torch.amp.autocast("cuda", enabled=False): + x = torch.view_as_complex(x) + x = x.contiguous() + x = self.inverse_transform(x) + x = x.to(dtype) + + return x diff --git a/fme/core/models/conditional_sfno/v1/lora.py b/fme/core/models/conditional_sfno/v1/lora.py new file mode 100644 index 000000000..ea548ef32 --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/lora.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import math + +import torch +import torch.nn as nn + + +class LoRAConv2d(nn.Conv2d): + """ + Drop-in Conv2d with optional LoRA. + + - API matches torch.nn.Conv2d, with extra args: + lora_rank: int = 0 (0 disables LoRA) + lora_alpha: float = None (defaults to lora_rank) + lora_dropout: float = 0.0 + + - Can load a checkpoint saved from nn.Conv2d even when lora_rank > 0 + (i.e., state_dict only has "weight"/"bias"). 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int | tuple[int, int] = 1, + padding: int | tuple[int, int] = 0, + dilation: int | tuple[int, int] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + *, + lora_rank: int = 0, + lora_alpha: float | None = None, + lora_dropout: float = 0.0, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + self.lora_down: nn.Conv2d | None = None + self.lora_up: nn.Conv2d | None = None + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + **factory_kwargs, + ) + + if lora_rank < 0: + raise ValueError(f"lora_rank must be >= 0, got {lora_rank}") + if lora_dropout < 0.0: + raise ValueError(f"lora_dropout must be >= 0, got {lora_dropout}") + + self.lora_rank = int(lora_rank) + self.lora_alpha = ( + float(lora_alpha) if lora_alpha is not None else float(lora_rank) + ) + self.lora_dropout_p = float(lora_dropout) + + self._lora_merged = False + + if self.lora_rank > 0: + # Group-compatible LoRA via two convs: + # down: 1x1 grouped conv: in_channels -> (groups * r), groups=groups + # up: kxk grouped conv: (groups * r) -> out_channels, groups=groups + # This produces a delta with the same grouped structure as the base conv. + mid_channels = self.groups * self.lora_rank + + self.lora_down = nn.Conv2d( + in_channels=self.in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=self.groups, + bias=False, + **factory_kwargs, + ) + self.lora_up = nn.Conv2d( + in_channels=mid_channels, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + bias=False, + padding_mode=self.padding_mode, + **factory_kwargs, + ) + + self.lora_dropout = ( + nn.Dropout(p=self.lora_dropout_p) + if self.lora_dropout_p > 0 + else nn.Identity() + ) + + # Scaling as in LoRA: alpha / r + self.lora_scaling = self.lora_alpha / float(self.lora_rank) + else: + self.lora_dropout = nn.Identity() + self.lora_scaling = 0.0 + self.reset_lora_parameters() # base parameters already reset in super init + + def reset_parameters(self) -> None: + super().reset_parameters() + self.reset_lora_parameters() + + def reset_lora_parameters(self): + # Init: down ~ Kaiming, up = 0 so the module starts + # identical to base Conv2d. 
+ if self.lora_down is not None: + nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + if self.lora_up is not None: + nn.init.zeros_(self.lora_up.weight) + + def extra_repr(self) -> str: + base = super().extra_repr() + if self.lora_rank > 0: + return ( + f"{base}, lora_rank={self.lora_rank}, lora_alpha={self.lora_alpha}, " + f"lora_dropout={self.lora_dropout_p}, lora_merged={self._lora_merged}" + ) + return f"{base}, lora_rank=0" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = super().forward(x) + if self.lora_rank == 0 or self._lora_merged: + return y + assert self.lora_down is not None and self.lora_up is not None + return ( + y + self.lora_up(self.lora_down(self.lora_dropout(x))) * self.lora_scaling + ) diff --git a/fme/core/models/conditional_sfno/v1/makani/__init__.py b/fme/core/models/conditional_sfno/v1/makani/__init__.py new file mode 100644 index 000000000..270cdfa96 --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/makani/__init__.py @@ -0,0 +1 @@ +from .spectral_convolution import SpectralConv # noqa: F401 diff --git a/fme/core/models/conditional_sfno/v1/makani/contractions.py b/fme/core/models/conditional_sfno/v1/makani/contractions.py new file mode 100644 index 000000000..06225ef0c --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/makani/contractions.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +# for the factorized spectral convolution +@torch.jit.script +def _contract_rank( + xc: torch.Tensor, wc: torch.Tensor, ac: torch.Tensor, bc: torch.Tensor +) -> torch.Tensor: + resc = torch.einsum("bixy,ior,xr,yr->boxy", xc, wc, ac, bc) + return resc + + +# new contractions set to replace older ones. 
We use complex + + +@torch.jit.script +def _contract_lmwise(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + resc = torch.einsum("bgixy,gioxy->bgoxy", ac, bc) + return resc + + +@torch.jit.script +def _contract_lwise(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + resc = torch.einsum("bgixy,giox->bgoxy", ac, bc) + return resc + + +@torch.jit.script +def _contract_mwise(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + resc = torch.einsum("bgixy,gioy->bgoxy", ac, bc) + return resc + + +@torch.jit.script +def _contract_sep_lmwise(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + resc = torch.einsum("bgixy,gixy->bgoxy", ac, bc) + return resc + + +@torch.jit.script +def _contract_sep_lwise(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + resc = torch.einsum("bgixy,gix->bgoxy", ac, bc) + return resc + + +@torch.jit.script +def _contract_lmwise_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + res = torch.einsum("bgixys,gioxy->bgoxys", a, b).contiguous() + return res + + +@torch.jit.script +def _contract_lwise_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + res = torch.einsum("bgixys,giox->bgoxys", a, b).contiguous() + return res + + +@torch.jit.script +def _contract_sep_lmwise_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + res = torch.einsum("bgixys,gixy->bgoxys", a, b).contiguous() + return res + + +@torch.jit.script +def _contract_sep_lwise_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + res = torch.einsum("bgixys,gix->bgoxys", a, b).contiguous() + return res diff --git a/fme/core/models/conditional_sfno/v1/makani/factorizations.py b/fme/core/models/conditional_sfno/v1/makani/factorizations.py new file mode 100644 index 000000000..214d0ebdb --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/makani/factorizations.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
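+
+# This module maps (operator_type, separable, complex) onto the torch.jit.script
+# einsum kernels defined in contractions.py:
+#   * "diagonal" applies an independent weight to every (l, m) coefficient,
+#   * "dhconv" applies one weight per degree l, shared across orders m
+#     (a Driscoll-Healy style convolution).
+# A usage sketch (shapes illustrative only, names as defined below):
+#   contract = get_contract_fun(weight, implementation="factorized",
+#                               separable=False, operator_type="dhconv",
+#                               complex=True)
+#   y = contract(x, weight)  # x: (b, g, i, l, m) complex, weight: (g, i, o, l)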
+
+from functools import partial
+
+import torch
+
+from .contractions import (
+    _contract_lmwise,
+    _contract_lmwise_real,
+    _contract_lwise,
+    _contract_lwise_real,
+    _contract_sep_lmwise,
+    _contract_sep_lmwise_real,
+    _contract_sep_lwise,
+    _contract_sep_lwise_real,
+)
+
+einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+
+
+# jitted PyTorch contractions:
+def _contract_dense_pytorch(
+    x, weight, separable=False, operator_type="diagonal", complex=True
+):
+    # make sure input is contig
+    x = x.contiguous()
+
+    if separable:
+        if operator_type == "diagonal":
+            if complex:
+                x = _contract_sep_lmwise(x, weight)
+            else:
+                x = _contract_sep_lmwise_real(x, weight)
+        elif operator_type == "dhconv":
+            if complex:
+                x = _contract_sep_lwise(x, weight)
+            else:
+                x = _contract_sep_lwise_real(x, weight)
+        else:
+            raise ValueError(f"Unknown operator type {operator_type}")
+    else:
+        if operator_type == "diagonal":
+            if complex:
+                x = _contract_lmwise(x, weight)
+            else:
+                x = _contract_lmwise_real(x, weight)
+        elif operator_type == "dhconv":
+            if complex:
+                x = _contract_lwise(x, weight)
+            else:
+                x = _contract_lwise_real(x, weight)
+        else:
+            raise ValueError(f"Unknown operator type {operator_type}")
+
+    # make contiguous
+    x = x.contiguous()
+    return x
+
+
+def _contract_dense_reconstruct(
+    x, weight, separable=False, operator_type="diagonal", complex=True
+):
+    """Contraction for dense tensors, factorized or not."""
+    if not torch.is_tensor(weight):
+        weight = weight.to_tensor()
+        # weight = torch.view_as_real(weight)
+
+    return _contract_dense_pytorch(
+        x, weight, separable=separable, operator_type=operator_type, complex=complex
+    )
+
+
+def get_contract_fun(
+    weight,
+    implementation="reconstructed",
+    separable=False,
+    operator_type="diagonal",
+    complex=True,
+):
+    """
+    Generic ND implementation of Fourier Spectral Conv contraction.
+
+    Parameters
+    ----------
+    weight : torch.Tensor
+    implementation : {'reconstructed', 'factorized'}, default is 'reconstructed'
+        whether to reconstruct the weight and do a forward pass (reconstructed)
+        or contract directly the factors of the factorized weight with the input
+        (factorized)
+
+    Returns
+    -------
+    function : (x, weight) -> x * weight in Fourier space
+    """
+    if implementation == "reconstructed":
+        handle = partial(
+            _contract_dense_reconstruct,
+            separable=separable,
+            complex=complex,
+            operator_type=operator_type,
+        )
+        return handle
+    elif implementation == "factorized":
+        handle = partial(
+            _contract_dense_pytorch,
+            separable=separable,
+            complex=complex,
+            operator_type=operator_type,
+        )
+        return handle
+    else:
+        raise ValueError(
+            f'Got {implementation=}, expected "reconstructed" or "factorized"'
+        )
diff --git a/fme/core/models/conditional_sfno/v1/makani/spectral_convolution.py b/fme/core/models/conditional_sfno/v1/makani/spectral_convolution.py
new file mode 100644
index 000000000..e99a7f5e1
--- /dev/null
+++ b/fme/core/models/conditional_sfno/v1/makani/spectral_convolution.py
@@ -0,0 +1,158 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn as nn +from torch import amp + +from fme.core.benchmark.timer import NullTimer, Timer + +# import convenience functions for factorized tensors +from .factorizations import get_contract_fun + + +class SpectralConv(nn.Module): + """ + Spectral Convolution implemented via SHT or FFT. Designed for convolutions on the + two-sphere S2 + using the Spherical Harmonic Transforms in torch-harmonics, but supports + convolutions on the periodic + domain via the RealFFT2 and InverseRealFFT2 wrappers. + """ + + def __init__( + self, + forward_transform, + inverse_transform, + in_channels, + out_channels, + num_groups=1, + operator_type="dhconv", + separable=False, + bias=False, + gain=1.0, + ): + super().__init__() + + assert in_channels % num_groups == 0 + assert out_channels % num_groups == 0 + + self.forward_transform = forward_transform + self.inverse_transform = inverse_transform + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_groups = num_groups + + self.modes_lat = self.inverse_transform.lmax + self.modes_lon = self.inverse_transform.mmax + + self.scale_residual = ( + self.forward_transform.nlat != self.inverse_transform.nlat + ) or (self.forward_transform.nlon != self.inverse_transform.nlon) + if hasattr(self.forward_transform, "grid"): + self.scale_residual = self.scale_residual or ( + self.forward_transform.grid != self.inverse_transform.grid + ) + + # remember factorization details + self.operator_type = operator_type + self.separable = separable + + assert self.inverse_transform.lmax == self.modes_lat + assert self.inverse_transform.mmax == self.modes_lon + + weight_shape = [num_groups, in_channels // num_groups] + + if not self.separable: + weight_shape += [out_channels // num_groups] + + self.modes_lat_local = self.modes_lat + self.modes_lon_local = self.modes_lon + self.nlat_local = self.inverse_transform.nlat + self.nlon_local = self.inverse_transform.nlon + + # unpadded weights + if self.operator_type == "diagonal": + weight_shape += [self.modes_lat_local, self.modes_lon_local] + elif self.operator_type == "dhconv": + weight_shape += [self.modes_lat_local] + else: + raise ValueError(f"Unsupported operator type f{self.operator_type}") + + # Compute scaling factor for correct initialization + scale = math.sqrt(gain / (in_channels // num_groups)) * torch.ones( + self.modes_lat_local, dtype=torch.complex64 + ) + # seemingly the first weight is not really complex, so we need to + # account for that + scale[0] *= math.sqrt(2.0) + init = scale * torch.randn(*weight_shape, dtype=torch.complex64) + self.weight = nn.Parameter(init) + + if self.operator_type == "dhconv": + self.weight.is_shared_mp = ["matmul", "w"] + self.weight.sharded_dims_mp = [None for _ in weight_shape] + self.weight.sharded_dims_mp[-1] = "h" + else: + self.weight.is_shared_mp = ["matmul"] + self.weight.sharded_dims_mp = [None for _ in weight_shape] + self.weight.sharded_dims_mp[-1] = "w" + self.weight.sharded_dims_mp[-2] = "h" + + # get the contraction handle. 
This should return a pyTorch contraction + self._contract = get_contract_fun( + self.weight, + implementation="factorized", + separable=separable, + complex=True, + operator_type=operator_type, + ) + + if bias: + self.bias = nn.Parameter(torch.zeros(1, self.out_channels, 1, 1)) + + def forward(self, x, timer: Timer = NullTimer()): + dtype = x.dtype + residual = x + x = x.float() + + with amp.autocast(device_type="cuda", enabled=False): + x = self.forward_transform(x).contiguous() + if self.scale_residual: + residual = self.inverse_transform(x) + residual = residual.to(dtype) + + B, C, H, W = x.shape + x = x.reshape(B, self.num_groups, C // self.num_groups, H, W) + xp = self._contract( + x, + self.weight, + separable=self.separable, + operator_type=self.operator_type, + ) + x = xp.reshape(B, self.out_channels, H, W).contiguous() + + with amp.autocast(device_type="cuda", enabled=False): + x = self.inverse_transform(x) + + if hasattr(self, "bias"): + x = x + self.bias + + x = x.to(dtype=dtype) + + return x, residual diff --git a/fme/core/models/conditional_sfno/v1/s2convolutions.py b/fme/core/models/conditional_sfno/v1/s2convolutions.py new file mode 100644 index 000000000..b138dd442 --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/s2convolutions.py @@ -0,0 +1,655 @@ +# flake8: noqa +# Copied from https://github.com/ai2cm/modulus/commit/22df4a9427f5f12ff6ac891083220e7f2f54d229 +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# import FactorizedTensor from tensorly for tensorized operations +import math +import torch +import torch.nn as nn + +import torch_harmonics as th +import torch_harmonics.distributed as thd + +from fme.core.benchmark.timer import NullTimer, Timer + +# import convenience functions for factorized tensors +from .activations import ComplexReLU + +# for the experimental module +from .contractions import ( + _contract_localconv_fwd, + compl_exp_mul2d_fwd, + compl_exp_muladd2d_fwd, + compl_mul2d_fwd, + compl_muladd2d_fwd, + real_mul2d_fwd, + real_muladd2d_fwd, +) + + +@torch.jit.script +def _contract_lora( + lora_A: torch.Tensor, + lora_B: torch.Tensor, + x: torch.Tensor, +): + """ + Performs LoRA update contraction. + + Args: + lora_A: LoRA A matrix of shape (group, in_channels, rank, nlat, 2) + lora_B: LoRA B matrix of shape (group, rank, out_channels, nlat, 2) + x: Complex input tensor of shape + (batch_size, group, in_channels, nlat, nlon) + + Returns: + Complex output tensor of shape (batch_size, group, out_channels, nlat, nlon) + """ + lora_A = torch.view_as_complex(lora_A) + lora_B = torch.view_as_complex(lora_B) + return torch.einsum("girx,grox,bgixy->bgoxy", lora_A, lora_B, x) + + +@torch.jit.script +def _contract_dhconv( + xc: torch.Tensor, weight: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs a complex Driscoll-Healy style convolution operation between two tensors + 'a' and 'b'. 
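+
+    The weight is stored as a real tensor with a trailing dimension of size 2
+    and viewed as complex here; an independent (in, out) channel mixing is
+    applied at every degree l (axis x) and shared across all orders m (axis y),
+    which is what the einsum "bgixy,giox->bgoxy" below expresses.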
+ + Args: + xc: Complex input tensor of shape (batch_size, group, in_channels, nlat, nlon) + weight: Weight tensor of shape (group, in_channels, out_channels, nlat, 2) + + Returns: + Complex output tensor of shape (batch_size, group, out_channels, nlat, nlon) + """ + wc = torch.view_as_complex(weight) + return torch.einsum("bgixy,giox->bgoxy", xc, wc) + + +class SpectralConvS2(nn.Module): + """ + Spectral Convolution according to Driscoll & Healy. Designed for convolutions on + the two-sphere S2 using the Spherical Harmonic Transforms in torch-harmonics, but + supports convolutions on the periodic domain via the RealFFT2 and InverseRealFFT2 + wrappers. + """ + + def __init__( + self, + forward_transform, + inverse_transform, + in_channels, + out_channels, + num_groups: int = 1, + scale="auto", + operator_type="diagonal", + rank=0.2, + factorization=None, + separable=False, + decomposition_kwargs=dict(), + bias=False, + use_tensorly=True, + filter_residual: bool = False, + lora_rank: int = 0, + lora_alpha: float | None = None, + ): # pragma: no cover + super(SpectralConvS2, self).__init__() + if operator_type != "dhconv": + raise NotImplementedError( + "Only 'dhconv' operator type is currently supported." + ) + if factorization is not None: + raise NotImplementedError( + "Factorizations other than None are not currently supported." + ) + if use_tensorly: + raise NotImplementedError( + "Tensorly-based implementation is not currently supported." + ) + if separable: + raise NotImplementedError( + "Separable convolutions are not currently supported." + ) + + if in_channels != out_channels: + raise NotImplementedError( + "Currently only in_channels == out_channels is supported." + ) + + assert in_channels % num_groups == 0 + assert out_channels % num_groups == 0 + self.num_groups = num_groups + + if in_channels != out_channels: + raise NotImplementedError( + "Currently only in_channels == out_channels is supported." 
+ ) + + self.forward_transform = forward_transform + self.inverse_transform = inverse_transform + + self.modes_lat = self.inverse_transform.lmax + self.modes_lon = self.inverse_transform.mmax + + self._round_trip_residual = filter_residual or ( + (self.forward_transform.nlat != self.inverse_transform.nlat) + or (self.forward_transform.nlon != self.inverse_transform.nlon) + or (self.forward_transform.grid != self.inverse_transform.grid) + ) + # Make sure we are using a Complex Factorized Tensor + if factorization is None: + factorization = "Dense" # No factorization + + if not factorization.lower().startswith("complex"): + factorization = f"Complex{factorization}" + + # remember factorization details + self.operator_type = operator_type + self.rank = rank + self.factorization = factorization + self.separable = separable + + assert self.inverse_transform.lmax == self.modes_lat + assert self.inverse_transform.mmax == self.modes_lon + + if isinstance(self.inverse_transform, thd.DistributedInverseRealSHT): + self.modes_lat_local = self.inverse_transform.lmax_local + self.modes_lon_local = self.inverse_transform.mmax_local + self.lpad_local = self.inverse_transform.lpad_local + self.mpad_local = self.inverse_transform.mpad_local + else: + self.modes_lat_local = self.modes_lat + self.modes_lon_local = self.modes_lon + self.lpad = 0 + self.mpad = 0 + + if scale == "auto": + scale = math.sqrt(1 / (in_channels)) * torch.ones(self.modes_lat_local, 2) + # seemingly the first weight is not really complex, so we need to account for that + scale[0, :] *= math.sqrt(2.0) + + weight_shape = [ + num_groups, + in_channels // num_groups, + out_channels // num_groups, + self.modes_lat_local, + ] + + assert factorization == "ComplexDense" + self.weight = nn.Parameter(scale * torch.randn(*weight_shape, 2)) + self.weight.is_shared_mp = ["matmul", "w"] + + if lora_rank > 0: + self.lora_A = nn.Parameter( + scale + * torch.randn( + num_groups, + in_channels // num_groups, + lora_rank, + self.modes_lat_local, + 2, + ) + ) + self.lora_B = nn.Parameter( + torch.zeros( + num_groups, + lora_rank, + out_channels // num_groups, + self.modes_lat_local, + 2, + ) + ) + self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank + self.lora_scaling = self.lora_alpha / lora_rank + else: + self.lora_A = None + self.lora_B = None + self.lora_scaling = 0.0 + + if bias: + self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) + self.out_channels = out_channels + + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover + dtype = x.dtype + residual = x + x = x.float() + + with torch.amp.autocast("cuda", enabled=False): + with timer.child("forward_transform"): + x = self.forward_transform(x.float()) + if self._round_trip_residual: + with timer.child("round_trip_residual"): + x = x.contiguous() + residual = self.inverse_transform(x) + residual = residual.to(dtype) + + B, C, H, W = x.shape + assert C % self.num_groups == 0 + x = x.reshape(B, self.num_groups, C // self.num_groups, H, W) + + if self.lora_A is not None and self.lora_B is not None: + with timer.child("lora_update"): + lora_update = _contract_lora( + self.lora_A, + self.lora_B, + x[..., : self.modes_lat_local, : self.modes_lon_local], + ) + else: + lora_update = 0.0 + + with timer.child("dhconv"): + xp = torch.zeros_like(x) + xp[..., : self.modes_lat_local, : self.modes_lon_local] = _contract_dhconv( + x[..., : self.modes_lat_local, : self.modes_lon_local], + self.weight, + ) + xp = xp + self.lora_scaling * lora_update + xp = xp.reshape(B, 
self.out_channels, H, W) + x = xp.contiguous() + + with torch.amp.autocast("cuda", enabled=False): + with timer.child("inverse_transform"): + x = self.inverse_transform(x) + + if hasattr(self, "bias"): + with timer.child("add_bias"): + x = x + self.bias + + x = x.type(dtype) + + return x, residual + + +class LocalConvS2(nn.Module): + """ + S2 Convolution according to Driscoll & Healy + """ + + def __init__( + self, + forward_transform, + inverse_transform, + in_channels, + out_channels, + nradius=120, + scale="auto", + bias=False, + ): # pragma: no cover + super(LocalConvS2, self).__init__() + + if scale == "auto": + scale = 1 / (in_channels * out_channels) + + self.in_channels = in_channels + self.out_channels = out_channels + self.nradius = nradius + + self.forward_transform = forward_transform + self.zonal_transform = th.RealSHT( + forward_transform.nlat, + 1, + lmax=forward_transform.lmax, + mmax=1, + grid=forward_transform.grid, + ).float() + self.inverse_transform = inverse_transform + + self.modes_lat = self.inverse_transform.lmax + self.modes_lon = self.inverse_transform.mmax + self.output_dims = (self.inverse_transform.nlat, self.inverse_transform.nlon) + + assert self.inverse_transform.lmax == self.modes_lat + assert self.inverse_transform.mmax == self.modes_lon + + self.weight = nn.Parameter( + scale * torch.randn(in_channels, out_channels, nradius, 1) + ) + + self._contract = _contract_localconv_fwd + + if bias: + self.bias = nn.Parameter( + scale * torch.randn(1, out_channels, *self.output_dims) + ) + + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover + dtype = x.dtype + x = x.float() + B, C, H, W = x.shape + + with torch.amp.autocast("cuda", enabled=False): + f = torch.zeros( + (self.in_channels, self.out_channels, H, 1), + dtype=x.dtype, + device=x.device, + ) + f[..., : self.nradius, :] = self.weight + + x = self.forward_transform(x) + f = self.zonal_transform(f)[..., :, 0] + + x = torch.view_as_real(x) + f = torch.view_as_real(f) + + x = self._contract(x, f) + x = x.contiguous() + + x = torch.view_as_complex(x) + + with torch.amp.autocast("cuda", enabled=False): + x = self.inverse_transform(x) + + if hasattr(self, "bias"): + x = x + self.bias + + x = x.type(dtype) + + return x + + +class SpectralAttentionS2(nn.Module): + """ + Spherical non-linear FNO layer + """ + + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + operator_type="diagonal", + sparsity_threshold=0.0, + hidden_size_factor=2, + complex_activation="real", + scale="auto", + bias=False, + spectral_layers=1, + drop_rate=0.0, + ): # pragma: no cover + super(SpectralAttentionS2, self).__init__() + + self.embed_dim = embed_dim + self.sparsity_threshold = sparsity_threshold + self.operator_type = operator_type + self.spectral_layers = spectral_layers + + if scale == "auto": + self.scale = 1 / (embed_dim * embed_dim) + + self.modes_lat = forward_transform.lmax + self.modes_lon = forward_transform.mmax + + # only storing the forward handle to be able to call it + self.forward_transform = forward_transform + self.inverse_transform = inverse_transform + + self.scale_residual = ( + self.forward_transform.nlat != self.inverse_transform.nlat + ) or (self.forward_transform.nlon != self.inverse_transform.nlon) + + assert inverse_transform.lmax == self.modes_lat + assert inverse_transform.mmax == self.modes_lon + + hidden_size = int(hidden_size_factor * self.embed_dim) + + if operator_type == "diagonal": + self.mul_add_handle = compl_muladd2d_fwd + self.mul_handle = 
compl_mul2d_fwd + + # weights + w = [self.scale * torch.randn(self.embed_dim, hidden_size, 2)] + for l in range(1, self.spectral_layers): + w.append(self.scale * torch.randn(hidden_size, hidden_size, 2)) + self.w = nn.ParameterList(w) + + self.wout = nn.Parameter( + self.scale * torch.randn(hidden_size, self.embed_dim, 2) + ) + + if bias: + self.b = nn.ParameterList( + [ + self.scale * torch.randn(hidden_size, 1, 1, 2) + for _ in range(self.spectral_layers) + ] + ) + + self.activations = nn.ModuleList([]) + for l in range(0, self.spectral_layers): + self.activations.append( + ComplexReLU( + mode=complex_activation, + bias_shape=(hidden_size, 1, 1), + scale=self.scale, + ) + ) + + elif operator_type == "l-dependant": + self.mul_add_handle = compl_exp_muladd2d_fwd + self.mul_handle = compl_exp_mul2d_fwd + + # weights + w = [ + self.scale * torch.randn(self.modes_lat, self.embed_dim, hidden_size, 2) + ] + for l in range(1, self.spectral_layers): + w.append( + self.scale + * torch.randn(self.modes_lat, hidden_size, hidden_size, 2) + ) + self.w = nn.ParameterList(w) + + if bias: + self.b = nn.ParameterList( + [ + self.scale * torch.randn(hidden_size, 1, 1, 2) + for _ in range(self.spectral_layers) + ] + ) + + self.wout = nn.Parameter( + self.scale * torch.randn(self.modes_lat, hidden_size, self.embed_dim, 2) + ) + + self.activations = nn.ModuleList([]) + for l in range(0, self.spectral_layers): + self.activations.append( + ComplexReLU( + mode=complex_activation, + bias_shape=(hidden_size, 1, 1), + scale=self.scale, + ) + ) + + else: + raise ValueError("Unknown operator type") + + self.drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity() + + def forward_mlp(self, x): # pragma: no cover + """forward pass of the MLP""" + B, C, H, W = x.shape + + if self.operator_type == "block-separable": + x = x.permute(0, 3, 1, 2) + + xr = torch.view_as_real(x) + + for l in range(self.spectral_layers): + if hasattr(self, "b"): + xr = self.mul_add_handle(xr, self.w[l], self.b[l]) + else: + xr = self.mul_handle(xr, self.w[l]) + xr = torch.view_as_complex(xr) + xr = self.activations[l](xr) + xr = self.drop(xr) + xr = torch.view_as_real(xr) + + # final MLP + x = self.mul_handle(xr, self.wout) + + x = torch.view_as_complex(x) + + if self.operator_type == "block-separable": + x = x.permute(0, 2, 3, 1) + + return x + + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover + dtype = x.dtype + residual = x + x = x.to(torch.float32) + + # FWD transform + with torch.amp.autocast("cuda", enabled=False): + x = self.forward_transform(x) + if self.scale_residual: + x = x.contiguous() + residual = self.inverse_transform(x) + residual = residual.to(dtype) + + # MLP + x = self.forward_mlp(x) + + # BWD transform + x = x.contiguous() + with torch.amp.autocast("cuda", enabled=False): + x = self.inverse_transform(x) + + # cast back to initial precision + x = x.to(dtype) + + return x, residual + + +class RealSpectralAttentionS2(nn.Module): + """ + Non-linear SFNO layer using a real-valued NN instead of a complex one + """ + + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + operator_type="diagonal", + sparsity_threshold=0.0, + hidden_size_factor=2, + complex_activation="real", + scale="auto", + bias=False, + spectral_layers=1, + drop_rate=0.0, + ): # pragma: no cover + super(RealSpectralAttentionS2, self).__init__() + + self.embed_dim = embed_dim + self.sparsity_threshold = sparsity_threshold + self.operator_type = operator_type + self.spectral_layers = spectral_layers + + if 
scale == "auto": + self.scale = 1 / (embed_dim * embed_dim) + + self.modes_lat = forward_transform.lmax + self.modes_lon = forward_transform.mmax + + # only storing the forward handle to be able to call it + self.forward_transform = forward_transform + self.inverse_transform = inverse_transform + + self.scale_residual = ( + self.forward_transform.nlat != self.inverse_transform.nlat + ) or (self.forward_transform.nlon != self.inverse_transform.nlon) + + assert inverse_transform.lmax == self.modes_lat + assert inverse_transform.mmax == self.modes_lon + + hidden_size = int(hidden_size_factor * self.embed_dim * 2) + + self.mul_add_handle = real_muladd2d_fwd + self.mul_handle = real_mul2d_fwd + + # weights + w = [self.scale * torch.randn(2 * self.embed_dim, hidden_size)] + for l in range(1, self.spectral_layers): + w.append(self.scale * torch.randn(hidden_size, hidden_size)) + self.w = nn.ParameterList(w) + + self.wout = nn.Parameter( + self.scale * torch.randn(hidden_size, 2 * self.embed_dim) + ) + + if bias: + self.b = nn.ParameterList( + [ + self.scale * torch.randn(hidden_size, 1, 1) + for _ in range(self.spectral_layers) + ] + ) + + self.activations = nn.ModuleList([]) + for l in range(0, self.spectral_layers): + self.activations.append(nn.ReLU()) + + self.drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity() + + def forward_mlp(self, x): # pragma: no cover + """forward pass of the MLP""" + B, C, H, W = x.shape + + xr = torch.view_as_real(x) + xr = xr.permute(0, 1, 4, 2, 3).reshape(B, C * 2, H, W) + + for l in range(self.spectral_layers): + if hasattr(self, "b"): + xr = self.mul_add_handle(xr, self.w[l], self.b[l]) + else: + xr = self.mul_handle(xr, self.w[l]) + xr = self.activations[l](xr) + xr = self.drop(xr) + + # final MLP + xr = self.mul_handle(xr, self.wout) + + xr = xr.reshape(B, C, 2, H, W).permute(0, 1, 3, 4, 2) + + x = torch.view_as_complex(xr) + + return x + + def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover + dtype = x.dtype + x = x.to(torch.float32) + + # FWD transform + with torch.amp.autocast("cuda", enabled=False): + x = self.forward_transform(x) + + # MLP + x = self.forward_mlp(x) + + # BWD transform + with torch.amp.autocast("cuda", enabled=False): + x = self.inverse_transform(x) + + # cast back to initial precision + x = x.to(dtype) + + return x diff --git a/fme/core/models/conditional_sfno/v1/sfnonet.py b/fme/core/models/conditional_sfno/v1/sfnonet.py new file mode 100644 index 000000000..29eb986f0 --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/sfnonet.py @@ -0,0 +1,900 @@ +# flake8: noqa +# Copied from https://github.com/ai2cm/modulus/commit/22df4a9427f5f12ff6ac891083220e7f2f54d229 +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
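+
+# This variant of the conditional SFNO threads low-rank adaptation and grouped
+# spectral filters through the network: the pointwise encoder, decoder and skip
+# convolutions are LoRAConv2d modules; the spectral filter is either the dhconv
+# SpectralConvS2 (optionally with its own spectral LoRA factors), the grouped
+# makani SpectralConv ("makani-linear"), or a local discrete-continuous
+# convolution; and each block is normalized with ConditionalLayerNorm driven by
+# the Context embeddings (scalar, positional, label and noise).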
+ +import math +from typing import Any, Callable, List, Optional, Tuple + +import torch +import torch.nn as nn + +# get spectral transforms from torch_harmonics +import torch_harmonics as th +from torch.utils.checkpoint import checkpoint + +from fme.core.benchmark.timer import Timer, NullTimer + +from .initialization import trunc_normal_ + +# wrap fft, to unify interface to spectral transforms +# import global convolution and non-linear spectral layers +# helpers +from .layers import ( + MLP, + ConditionalLayerNorm, + Context, + ContextConfig, + DropPath, +) +from .lora import LoRAConv2d +from .s2convolutions import SpectralAttentionS2, SpectralConvS2 +from .makani.spectral_convolution import SpectralConv + + +# heuristic for finding theta_cutoff +def _compute_cutoff_radius(nlat, kernel_shape, basis_type): + theta_cutoff_factor = { + "piecewise linear": 0.5, + "morlet": 0.5, + "zernike": math.sqrt(2.0), + } + + return ( + (kernel_shape[0] + 1) + * theta_cutoff_factor[basis_type] + * math.pi + / float(nlat - 1) + ) + + +class DiscreteContinuousConvS2(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = th.DiscreteContinuousConvS2(*args, **kwargs) + + def forward(self, x, timer: Timer = NullTimer()): + return self.conv(x), x + + +class SpectralFilterLayer(nn.Module): + """Spectral filter layer""" + + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + filter_type="linear", + operator_type="block-diagonal", + sparsity_threshold=0.0, + use_complex_kernels=True, + hidden_size_factor=1, + rank=1.0, + factorization=None, + separable=False, + complex_network=True, + complex_activation="real", + spectral_layers=1, + drop_rate=0.0, + num_groups=1, + filter_residual=False, + lora_rank: int = 0, + lora_alpha: float | None = None, + ): + super(SpectralFilterLayer, self).__init__() + + if lora_rank != 0 and filter_type != "linear": + raise NotImplementedError("LoRA is only supported for linear filter type.") + + if filter_type == "non-linear": + raise NotImplementedError("Non-linear spectral filters are not supported.") + + # spectral transform is passed to the module + elif filter_type == "linear": + self.filter = SpectralConvS2( + forward_transform, + inverse_transform, + embed_dim, + embed_dim, + operator_type=operator_type, + rank=rank, + factorization=factorization, + separable=separable, + bias=True, + use_tensorly=False if factorization is None else True, + filter_residual=filter_residual, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + num_groups=num_groups, + ) + elif filter_type == "makani-linear": + self.filter = SpectralConv( + forward_transform, + inverse_transform, + embed_dim, + embed_dim, + operator_type="dhconv", + num_groups=num_groups, + bias=False, + gain=1.0, + ) + + elif filter_type == "local": + # heuristic for finding theta_cutoff + theta_cutoff = 2 * _compute_cutoff_radius( + nlat=forward_transform.nlat, + kernel_shape=(3, 3), + basis_type="morlet", + ) + self.filter = DiscreteContinuousConvS2( + embed_dim, + embed_dim, + in_shape=(forward_transform.nlat, forward_transform.nlon), + out_shape=(inverse_transform.nlat, inverse_transform.nlon), + kernel_shape=(3, 3), + basis_type="morlet", + basis_norm_mode="mean", + groups=1, + grid_in=forward_transform.grid, + grid_out=inverse_transform.grid, + bias=False, + theta_cutoff=theta_cutoff, + ) + else: + raise (NotImplementedError) + + def forward(self, x, timer: Timer = NullTimer()): + return self.filter(x, timer=timer) + + +class FourierNeuralOperatorBlock(nn.Module): + 
"""Fourier Neural Operator Block""" + + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + img_shape: Tuple[int, int], + context_config: ContextConfig, + filter_type="linear", + operator_type="diagonal", + global_layer_norm: bool = False, + mlp_ratio=2.0, + drop_rate=0.0, + drop_path=0.0, + act_layer=nn.GELU, + sparsity_threshold=0.0, + use_complex_kernels=True, + rank=1.0, + factorization=None, + separable=False, + inner_skip="linear", + outer_skip=None, # None, nn.linear or nn.Identity + concat_skip=False, + use_mlp=False, + complex_network=True, + complex_activation="real", + spectral_layers=1, + checkpointing=0, + filter_residual=False, + affine_norms=False, + filter_num_groups: int = 1, + lora_rank: int = 0, + lora_alpha: float | None = None, + spectral_lora_rank: int = 0, + spectral_lora_alpha: float | None = None, + ): + super(FourierNeuralOperatorBlock, self).__init__() + + self.input_shape_loc = img_shape + self.output_shape_loc = img_shape + + # norm layer + self.norm0 = ConditionalLayerNorm( + embed_dim, + img_shape=self.input_shape_loc, + global_layer_norm=global_layer_norm, + context_config=context_config, + elementwise_affine=affine_norms, + ) + + # convolution layer + self.filter = SpectralFilterLayer( + forward_transform, + inverse_transform, + embed_dim, + filter_type, + operator_type, + sparsity_threshold, + use_complex_kernels=use_complex_kernels, + hidden_size_factor=mlp_ratio, + rank=rank, + factorization=factorization, + separable=separable, + complex_network=complex_network, + complex_activation=complex_activation, + spectral_layers=spectral_layers, + drop_rate=drop_rate, + filter_residual=filter_residual, + num_groups=filter_num_groups, + lora_rank=spectral_lora_rank, + lora_alpha=spectral_lora_alpha, + ) + + if inner_skip == "linear": + self.inner_skip = LoRAConv2d( + embed_dim, embed_dim, 1, 1, lora_rank=lora_rank, lora_alpha=lora_alpha + ) + elif inner_skip == "identity": + self.inner_skip = nn.Identity() + + self.concat_skip = concat_skip + + if concat_skip and inner_skip is not None: + self.inner_skip_conv = LoRAConv2d( + 2 * embed_dim, + embed_dim, + 1, + bias=False, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + ) + + if filter_type == "linear" or filter_type == "real linear": + self.act_layer = act_layer() + + # dropout + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + # norm layer + self.norm1 = ConditionalLayerNorm( + embed_dim, + img_shape=self.output_shape_loc, + global_layer_norm=global_layer_norm, + context_config=context_config, + elementwise_affine=affine_norms, + ) + + if use_mlp == True: + mlp_hidden_dim = int(embed_dim * mlp_ratio) + self.mlp = MLP( + in_features=embed_dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop_rate=drop_rate, + checkpointing=checkpointing, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + ) + + if outer_skip == "linear": + self.outer_skip = LoRAConv2d( + embed_dim, embed_dim, 1, 1, lora_rank=lora_rank, lora_alpha=lora_alpha + ) + elif outer_skip == "identity": + self.outer_skip = nn.Identity() + + if concat_skip and outer_skip is not None: + self.outer_skip_conv = LoRAConv2d( + 2 * embed_dim, + embed_dim, + 1, + bias=False, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + ) + + def forward(self, x, context_embedding, timer: Timer = NullTimer()): + with timer.child("norm0") as norm0_timer: + x_norm = torch.zeros_like(x) + x_norm[..., : self.input_shape_loc[0], : self.input_shape_loc[1]] = ( + self.norm0( + x[..., : self.input_shape_loc[0], : 
self.input_shape_loc[1]], + context_embedding, + timer=norm0_timer, + ) + ) + with timer.child("filter") as filter_timer: + x, residual = self.filter(x_norm, timer=filter_timer) + if hasattr(self, "inner_skip"): + with timer.child("inner_skip"): + if self.concat_skip: + x = torch.cat((x, self.inner_skip(residual)), dim=1) + x = self.inner_skip_conv(x) + else: + x = x + self.inner_skip(residual) + + if hasattr(self, "act_layer"): + with timer.child("activation"): + x = self.act_layer(x) + + with timer.child("norm1") as norm1_timer: + x_norm = torch.zeros_like(x) + x_norm[..., : self.output_shape_loc[0], : self.output_shape_loc[1]] = ( + self.norm1( + x[..., : self.output_shape_loc[0], : self.output_shape_loc[1]], + context_embedding, + timer=norm1_timer, + ) + ) + x = x_norm + + if hasattr(self, "mlp"): + with timer.child("mlp"): + x = self.mlp(x) + + x = self.drop_path(x) + + if hasattr(self, "outer_skip"): + with timer.child("outer_skip"): + if self.concat_skip: + x = torch.cat((x, self.outer_skip(residual)), dim=1) + x = self.outer_skip_conv(x) + else: + x = x + self.outer_skip(residual) + + return x + + +class NoLayerNorm(nn.Module): + def forward(self, x, context: Context): + return x + + +def get_lat_lon_sfnonet( + params, + in_chans: int, + out_chans: int, + img_shape: Tuple[int, int], + context_config: ContextConfig = ContextConfig( + embed_dim_scalar=0, + embed_dim_noise=0, + embed_dim_labels=0, + embed_dim_pos=0, + ), +) -> "SphericalFourierNeuralOperatorNet": + h, w = img_shape + hard_thresholding_fraction = ( + params.hard_thresholding_fraction + if hasattr(params, "hard_thresholding_fraction") + else 1.0 + ) + modes_lat = int(h * hard_thresholding_fraction) + modes_lon = int((w // 2 + 1) * hard_thresholding_fraction) + data_grid = params.data_grid if hasattr(params, "data_grid") else "equiangular" + trans_down = th.RealSHT( + *img_shape, lmax=modes_lat, mmax=modes_lon, grid=data_grid + ).float() + itrans_up = th.InverseRealSHT( + *img_shape, lmax=modes_lat, mmax=modes_lon, grid=data_grid + ).float() + trans = th.RealSHT( + *img_shape, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss" + ).float() + itrans = th.InverseRealSHT( + h, w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss" + ).float() + + def get_pos_embed(): + pos_embed = nn.Parameter(torch.zeros(1, params.embed_dim, h, w)) + pos_embed.is_shared_mp = ["matmul"] + trunc_normal_(pos_embed, std=0.02) + return pos_embed + + return SphericalFourierNeuralOperatorNet( + params, + img_shape=img_shape, + in_chans=in_chans, + out_chans=out_chans, + context_config=context_config, + trans_down=trans_down, + itrans_up=itrans_up, + trans=trans, + itrans=itrans, + get_pos_embed=get_pos_embed, + ) + + +class SphericalFourierNeuralOperatorNet(torch.nn.Module): + """ + Spherical Fourier Neural Operator Network + + Parameters + ---------- + params : dict + Dictionary of parameters + img_shape : tuple + Shape of the input channels, by default (721, 1440) + get_pos_embed : Callable + Function to get the positional embedding + trans_down : nn.Module + Transform from input space to spectral space + itrans_up : nn.Module + Transform from spectral space to output space + trans : nn.Module + Transform from intermediate data space to spectral space + itrans : nn.Module + Transform from spectral space to intermediate data space + filter_type : str, optional + Type of filter to use ('linear', 'non-linear'), by default "non-linear" + operator_type : str, optional + Type of operator to use ('diaginal', 'dhconv'), by default "diagonal" + 
scale_factor : int, optional + Scale factor to use, by default 16 + in_chans : int, optional + Number of input channels, by default 2 + out_chans : int, optional + Number of output channels, by default 2 + embed_dim : int, optional + Dimension of the embeddings, by default 256 + context_config : ContextConfig, optional + Context configuration, by default + ContextConfig(embed_dim_scalar=0, embed_dim_2d=0) + num_layers : int, optional + Number of layers in the network, by default 12 + use_mlp : int, optional + Whether to use MLP, by default True + mlp_ratio : int, optional + Ratio of MLP to use, by default 2.0 + activation_function : str, optional + Activation function to use, by default "gelu" + encoder_layers : int, optional + Number of layers in the encoder, by default 1 + pos_embed : bool, optional + Whether to use positional embedding, by default True + drop_rate : float, optional + Dropout rate, by default 0.0 + drop_path_rate : float, optional + Dropout path rate, by default 0.0 + num_blocks : int, optional + Number of blocks in the network, by default 16 + sparsity_threshold : float, optional + Threshold for sparsity, by default 0.0 + hard_thresholding_fraction : float, optional + Fraction of hard thresholding to apply, by default 1.0 + use_complex_kernels : bool, optional + Whether to use complex kernels, by default True + big_skip : bool, optional + Whether to use big skip connections, by default True + rank : float, optional + Rank of the approximation, by default 1.0 + factorization : Any, optional + Type of factorization to use, by default None + separable : bool, optional + Whether to use separable convolutions, by default False + complex_network : bool, optional + Whether to use a complex network architecture, by default True + complex_activation : str, optional + Type of complex activation function to use, by default "real" + spectral_layers : int, optional + Number of spectral layers, by default 3 + checkpointing : int, optional + Number of checkpointing segments, by default 0 + local_blocks: List[int], optional + List of blocks to use local filters, by default [] + normalize_big_skip: bool, optional + Whether to normalize the big_skip connection, by default False + affine_norms: bool, optional + Whether to use element-wise affine parameters in the normalization layers, + by default False. + + Example: + -------- + >>> from modulus.models.sfno.sfnonet import SphericalFourierNeuralOperatorNet as SFNO + >>> model = SFNO( + ... params={}, + ... img_shape=(8, 16), + ... scale_factor=4, + ... in_chans=2, + ... out_chans=2, + ... embed_dim=16, + ... num_layers=2, + ... encoder_layers=1, + ... num_blocks=4, + ... spectral_layers=2, + ... 
use_mlp=True,) + >>> model(torch.randn(1, 2, 8, 16)).shape + torch.Size([1, 2, 8, 16]) + """ + + def __init__( + self, + params, + img_shape: Tuple[int, int], + get_pos_embed: Callable[[], nn.Parameter], + trans_down: nn.Module, + itrans_up: nn.Module, + trans: nn.Module, + itrans: nn.Module, + filter_type: str = "linear", + operator_type: str = "diagonal", + scale_factor: int = 1, + in_chans: int = 2, + out_chans: int = 2, + embed_dim: int = 256, + context_config: ContextConfig = ContextConfig( + embed_dim_scalar=0, + embed_dim_labels=0, + embed_dim_noise=0, + embed_dim_pos=0, + ), + global_layer_norm: bool = False, + num_layers: int = 12, + use_mlp: int = True, + mlp_ratio: float = 2.0, + activation_function: str = "gelu", + encoder_layers: int = 1, + pos_embed: bool = True, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + num_blocks: int = 16, + sparsity_threshold: float = 0.0, + hard_thresholding_fraction: float = 1.0, + use_complex_kernels: bool = True, + big_skip: bool = True, + rank: float = 1.0, + factorization: Any = None, + separable: bool = False, + complex_network: bool = True, + complex_activation: str = "real", + spectral_layers: int = 3, + checkpointing: int = 0, + filter_num_groups: int = 1, + filter_residual: bool = False, + filter_output: bool = False, + local_blocks: Optional[List[int]] = None, + normalize_big_skip: bool = False, + affine_norms: bool = False, + lora_rank: int = 0, + lora_alpha: float | None = None, + spectral_lora_rank: int = 0, + spectral_lora_alpha: float | None = None, + ): + super(SphericalFourierNeuralOperatorNet, self).__init__() + + self.params = params + self.filter_type = ( + params.filter_type if hasattr(params, "filter_type") else filter_type + ) + self.filter_residual = ( + params.filter_residual + if hasattr(params, "filter_residual") + else filter_residual + ) + self.filter_output = ( + params.filter_output if hasattr(params, "filter_output") else filter_output + ) + self.mlp_ratio = params.mlp_ratio if hasattr(params, "mlp_ratio") else mlp_ratio + self.operator_type = ( + params.operator_type if hasattr(params, "operator_type") else operator_type + ) + self.img_shape = ( + (params.img_shape_x, params.img_shape_y) + if hasattr(params, "img_shape_x") and hasattr(params, "img_shape_y") + else img_shape + ) + self.scale_factor = ( + params.scale_factor if hasattr(params, "scale_factor") else scale_factor + ) + if self.scale_factor != 1: + raise NotImplementedError( + "scale factor must be 1 as it is not implemented for " + "conditional layer normalization" + ) + self.global_layer_norm = ( + params.global_layer_norm + if hasattr(params, "global_layer_norm") + else global_layer_norm + ) + self.in_chans = ( + params.N_in_channels if hasattr(params, "N_in_channels") else in_chans + ) + self.out_chans = ( + params.N_out_channels if hasattr(params, "N_out_channels") else out_chans + ) + self.embed_dim = self.num_features = ( + params.embed_dim if hasattr(params, "embed_dim") else embed_dim + ) + self.num_layers = ( + params.num_layers if hasattr(params, "num_layers") else num_layers + ) + self.num_blocks = ( + params.num_blocks if hasattr(params, "num_blocks") else num_blocks + ) + self.hard_thresholding_fraction = ( + params.hard_thresholding_fraction + if hasattr(params, "hard_thresholding_fraction") + else hard_thresholding_fraction + ) + self.use_mlp = params.use_mlp if hasattr(params, "use_mlp") else use_mlp + self.activation_function = ( + params.activation_function + if hasattr(params, "activation_function") + else 
activation_function + ) + self.encoder_layers = ( + params.encoder_layers + if hasattr(params, "encoder_layers") + else encoder_layers + ) + self.pos_embed = params.pos_embed if hasattr(params, "pos_embed") else pos_embed + self.big_skip = params.big_skip if hasattr(params, "big_skip") else big_skip + self.rank = params.rank if hasattr(params, "rank") else rank + self.factorization = ( + params.factorization if hasattr(params, "factorization") else factorization + ) + self.separable = params.separable if hasattr(params, "separable") else separable + self.complex_network = ( + params.complex_network + if hasattr(params, "complex_network") + else complex_network + ) + self.complex_activation = ( + params.complex_activation + if hasattr(params, "complex_activation") + else complex_activation + ) + self.spectral_layers = ( + params.spectral_layers + if hasattr(params, "spectral_layers") + else spectral_layers + ) + self.checkpointing = ( + params.checkpointing if hasattr(params, "checkpointing") else checkpointing + ) + local_blocks = ( + params.local_blocks if hasattr(params, "local_blocks") else local_blocks + ) + if local_blocks is not None: + self.local_blocks = [i for i in range(self.num_layers) if i in local_blocks] + else: + self.local_blocks = [] + normalize_big_skip = ( + params.normalize_big_skip + if hasattr(params, "normalize_big_skip") + else normalize_big_skip + ) + self.affine_norms = ( + params.affine_norms if hasattr(params, "affine_norms") else affine_norms + ) + self.filter_num_groups = ( + params.filter_num_groups + if hasattr(params, "filter_num_groups") + else filter_num_groups + ) + self.lora_rank = params.lora_rank if hasattr(params, "lora_rank") else lora_rank + self.lora_alpha = ( + params.lora_alpha if hasattr(params, "lora_alpha") else lora_alpha + ) + self.spectral_lora_rank = ( + params.spectral_lora_rank + if hasattr(params, "spectral_lora_rank") + else spectral_lora_rank + ) + self.spectral_lora_alpha = ( + params.spectral_lora_alpha + if hasattr(params, "spectral_lora_alpha") + else spectral_lora_alpha + ) + + # no global padding because we removed the horizontal distributed code + self.padding = (0, 0) + + self.trans_down = trans_down + self.itrans_up = itrans_up + self.trans = trans + self.itrans = itrans + + if self.filter_residual: + self.residual_filter_down = self.trans_down + self.residual_filter_up = self.itrans_up + else: + self.residual_filter_down = nn.Identity() + self.residual_filter_up = nn.Identity() + + if self.filter_output: + self.filter_output_down = self.trans_down + self.filter_output_up = self.itrans_up + else: + self.filter_output_down = nn.Identity() + self.filter_output_up = nn.Identity() + + # determine activation function + if self.activation_function == "relu": + self.activation_function = nn.ReLU + elif self.activation_function == "gelu": + self.activation_function = nn.GELU + elif self.activation_function == "silu": + self.activation_function = nn.SiLU + else: + raise ValueError(f"Unknown activation function {self.activation_function}") + + # encoder + encoder_hidden_dim = self.embed_dim + current_dim = self.in_chans + encoder_modules = [] + for i in range(self.encoder_layers): + encoder_modules.append( + LoRAConv2d( + current_dim, + encoder_hidden_dim, + 1, + bias=True, + lora_rank=self.lora_rank, + lora_alpha=self.lora_alpha, + ) + ) + encoder_modules.append(self.activation_function()) + current_dim = encoder_hidden_dim + encoder_modules.append( + LoRAConv2d( + current_dim, + self.embed_dim, + 1, + bias=False, + 
lora_rank=self.lora_rank, + lora_alpha=self.lora_alpha, + ) + ) + self.encoder = nn.Sequential(*encoder_modules) + + # dropout + self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)] + + # FNO blocks + self.blocks = nn.ModuleList([]) + for i in range(self.num_layers): + if i in self.local_blocks: + block_filter_type = "local" + else: + block_filter_type = self.filter_type + + first_layer = i == 0 + last_layer = i == self.num_layers - 1 + + forward_transform = self.trans_down if first_layer else self.trans + inverse_transform = self.itrans_up if last_layer else self.itrans + + inner_skip = "linear" + outer_skip = "identity" + + block = FourierNeuralOperatorBlock( + forward_transform, + inverse_transform, + self.embed_dim, + img_shape=self.img_shape, + context_config=context_config, + filter_type=block_filter_type, + operator_type=self.operator_type, + mlp_ratio=self.mlp_ratio, + drop_rate=drop_rate, + drop_path=dpr[i], + act_layer=self.activation_function, + sparsity_threshold=sparsity_threshold, + global_layer_norm=self.global_layer_norm, + use_complex_kernels=use_complex_kernels, + inner_skip=inner_skip, + outer_skip=outer_skip, + use_mlp=self.use_mlp, + rank=self.rank, + factorization=self.factorization, + separable=self.separable, + complex_network=self.complex_network, + complex_activation=self.complex_activation, + spectral_layers=self.spectral_layers, + checkpointing=self.checkpointing, + filter_residual=self.filter_residual, + affine_norms=self.affine_norms, + filter_num_groups=self.filter_num_groups, + lora_rank=self.lora_rank, + lora_alpha=self.lora_alpha, + spectral_lora_rank=self.spectral_lora_rank, + spectral_lora_alpha=self.spectral_lora_alpha, + ) + + self.blocks.append(block) + + # decoder + decoder_hidden_dim = self.embed_dim + current_dim = self.embed_dim + self.big_skip * self.in_chans + decoder_modules = [] + for i in range(self.encoder_layers): + decoder_modules.append( + LoRAConv2d( + current_dim, + decoder_hidden_dim, + 1, + bias=True, + lora_rank=self.lora_rank, + lora_alpha=self.lora_alpha, + ) + ) + decoder_modules.append(self.activation_function()) + current_dim = decoder_hidden_dim + decoder_modules.append( + LoRAConv2d( + current_dim, + self.out_chans, + 1, + bias=False, + lora_rank=self.lora_rank, + lora_alpha=self.lora_alpha, + ) + ) + self.decoder = nn.Sequential(*decoder_modules) + + # learned position embedding + if self.pos_embed: + self.pos_embed = get_pos_embed() + + if normalize_big_skip: + self.norm_big_skip = ConditionalLayerNorm( + in_chans, + img_shape=self.img_shape, + global_layer_norm=self.global_layer_norm, + context_config=context_config, + elementwise_affine=self.affine_norms, + ) + else: + self.norm_big_skip = NoLayerNorm() + + @torch.jit.ignore + def no_weight_decay(self): # pragma: no cover + """Helper""" + return {"pos_embed", "cls_token"} + + def _forward_features(self, x: torch.Tensor, context: Context): + for blk in self.blocks: + if self.checkpointing >= 3: + x = checkpoint(blk, x, context) + else: + x = blk(x, context) + + return x + + def forward(self, x: torch.Tensor, context: Context): + # save big skip + if self.big_skip: + residual = self.residual_filter_up(self.residual_filter_down(x)) + residual = self.norm_big_skip(residual, context=context) + + if self.checkpointing >= 1: + x = checkpoint(self.encoder, x) + else: + x = self.encoder(x) + + if hasattr(self, "pos_embed"): + # old way of treating unequally shaped weights + x = x 
+ self.pos_embed + + # maybe clean the padding just in case + + x = self.pos_drop(x) + + x = self._forward_features(x, context) + + if self.big_skip: + x = torch.cat((x, residual), dim=1) + + if self.checkpointing >= 1: + x = checkpoint(self.decoder, x) + else: + x = self.decoder(x) + + x = self.filter_output_up(self.filter_output_down(x)) + + return x diff --git a/fme/core/models/conditional_sfno/v1/sht.py b/fme/core/models/conditional_sfno/v1/sht.py new file mode 100644 index 000000000..dd9f8fc02 --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/sht.py @@ -0,0 +1,225 @@ +# flake8: noqa +# fmt: off +# isort: skip_file + +""" +This file contains a fix that we needed to get the SFNO to work on multiple +unroll steps in multiprocessing (e.g. multi-GPU mode.) We forked this code from +the torch harmonics sht.py file [*]. + +[*] https://github.com/NVIDIA/torch-harmonics/blob/17eefa53468d1a885d72087918eba905fa53e10a/torch_harmonics/sht.py +""" + + +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import torch +import torch.nn as nn +import torch.fft + +from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights +from torch_harmonics.legendre import _precompute_legpoly + +from fme.core.device import get_device +from fme.core.benchmark.timer import Timer, NullTimer + + +class RealSHT(nn.Module): + """ + Defines a module for computing the forward (real-valued) SHT. + Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. + The SHT is applied to the last two dimensions of the input + + [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. + [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. 
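+
+    With the defaults, lmax is inferred from nlat (nlat for the "legendre-gauss"
+    and "equiangular" grids, nlat - 1 for "lobatto") and mmax defaults to
+    nlon // 2 + 1; the quadrature weights are precomputed once in the
+    constructor and kept in float32.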
+ """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True): + """ + Initializes the SHT Layer, precomputing the necessary quadrature weights + + Parameters: + nlat: input grid resolution in the latitudinal direction + nlon: input grid resolution in the longitudinal direction + grid: grid in the latitude direction (for now only tensor product grids are supported) + """ + + super().__init__() + + self.nlat = nlat + self.nlon = nlon + self.grid = grid + self.norm = norm + self.csphase = csphase + + # TODO: include assertions regarding the dimensions + + # compute quadrature points + if self.grid == "legendre-gauss": + cost, w = legendre_gauss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "lobatto": + cost, w = lobatto_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat-1 + elif self.grid == "equiangular": + cost, w = clenshaw_curtiss_weights(nlat, -1, 1) + # cost, w = fejer2_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "healpix": + raise(NotImplementedError("'healpix' grid not supported by InverseRealVectorSHT")) + else: + raise(ValueError("Unknown quadrature mode")) + + # apply cosine transform and flip them + tq = torch.flip(torch.arccos(cost), dims=(0,)) + + # determine the dimensions + self.mmax = mmax or self.nlon // 2 + 1 + + # combine quadrature weights with the legendre weights + pct = torch.as_tensor(_precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)) + weights = torch.einsum('mlk,k->mlk', pct, w) + + # remember quadrature weights + self.weights = weights.float().to(get_device()) + + def extra_repr(self): + """ + Pretty print module + """ + return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' + + def forward(self, x: torch.Tensor, timer: Timer = NullTimer()): + + assert(x.shape[-2] == self.nlat) + assert(x.shape[-1] == self.nlon) + with torch.autocast("cuda", enabled=False): + with timer.child("rfft"): + # rfft and view_as_complex don't support BF16, see https://github.com/pytorch/pytorch/issues/117844 + x = x.float() + + # apply real fft in the longitudinal direction + x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward") + + with timer.child("contraction"): + # do the Legendre-Gauss quadrature + x = torch.view_as_real(x) + + # distributed contraction: fork + out_shape = list(x.size()) + out_shape[-3] = self.lmax + out_shape[-2] = self.mmax + xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) + + # contraction + weights = self.weights.to(x.device).to(x.dtype) + xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 0], weights) + xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 1], weights) + x = torch.view_as_complex(xout) + + return x + +class InverseRealSHT(nn.Module): + """ + Defines a module for computing the inverse (real-valued) SHT. + Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. + nlat, nlon: Output dimensions + lmax, mmax: Input dimensions (spherical coefficients). For convenience, these are inferred from the output dimensions + + [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. + [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. 
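+
+    A minimal round-trip sketch (shapes only, assuming the "equiangular" grid;
+    not part of the upstream torch-harmonics code):
+
+        sht = RealSHT(8, 16, grid="equiangular")
+        isht = InverseRealSHT(8, 16, grid="equiangular")
+        coeffs = sht(torch.randn(1, 2, 8, 16))  # (1, 2, lmax=8, mmax=9) complex
+        field = isht(coeffs)                    # (1, 2, 8, 16) real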
+ """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True): + + super().__init__() + + self.nlat = nlat + self.nlon = nlon + self.grid = grid + self.norm = norm + self.csphase = csphase + + # compute quadrature points + if self.grid == "legendre-gauss": + cost, _ = legendre_gauss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "lobatto": + cost, _ = lobatto_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat-1 + elif self.grid == "equiangular": + cost, _ = clenshaw_curtiss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "healpix": + raise(NotImplementedError("'healpix' grid not supported by RealVectorSHT")) + else: + raise(ValueError("Unknown quadrature mode")) + + # apply cosine transform and flip them + t = torch.flip(torch.arccos(cost), dims=(0,)) + + # determine the dimensions + self.mmax = mmax or self.nlon // 2 + 1 + + pct = torch.as_tensor(_precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)) + + # register buffer + self.pct = pct.float().to(get_device()) + + def extra_repr(self): + """ + Pretty print module + """ + return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' + + def forward(self, x: torch.Tensor, timer: Timer = NullTimer()): + + assert(x.shape[-2] == self.lmax) + assert(x.shape[-1] == self.mmax) + + with torch.autocast("cuda", enabled=False): + with timer.child("contraction"): + # irfft and view_as_complex don't support BF16, see https://github.com/pytorch/pytorch/issues/117844 + # Evaluate associated Legendre functions on the output nodes + x = torch.view_as_real(x).float() + + pct = self.pct.to(x.device).to(x.dtype) + rl = torch.einsum('...lm, mlk->...km', x[..., 0], pct ) + im = torch.einsum('...lm, mlk->...km', x[..., 1], pct ) + xs = torch.stack((rl, im), -1) + + # apply the inverse (real) FFT + x = torch.view_as_complex(xs) + with timer.child("irfft"): + x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") + + return x diff --git a/fme/core/models/conditional_sfno/v1/stochastic_sfno.py b/fme/core/models/conditional_sfno/v1/stochastic_sfno.py new file mode 100644 index 000000000..f82e3fc86 --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/stochastic_sfno.py @@ -0,0 +1,145 @@ +import math +from collections.abc import Callable +from typing import Literal + +import torch + +from fme.core.dataset_info import DatasetInfo + +from .sfnonet import Context, ContextConfig, get_lat_lon_sfnonet +from .sfnonet import SphericalFourierNeuralOperatorNet as ConditionalSFNO + + +def isotropic_noise( + leading_shape: tuple[int, ...], + lmax: int, # length of the ℓ axis expected by isht + mmax: int, # length of the m axis expected by isht + isht: Callable[[torch.Tensor], torch.Tensor], + device: torch.device, +) -> torch.Tensor: + # --- draw independent N(0,1) parts -------------------------------------- + coeff_shape = (*leading_shape, lmax, mmax) + real = torch.randn(coeff_shape, dtype=torch.float32, device=device) + imag = torch.randn(coeff_shape, dtype=torch.float32, device=device) + imag[..., :, 0] = 0.0 # m = 0 ⇒ purely real + + # m > 0: make Re and Im each N(0,½) → |a_{ℓ m}|² has variance 1 + sqrt2 = math.sqrt(2.0) + real[..., :, 1:] /= sqrt2 + imag[..., :, 1:] /= sqrt2 + + # --- global scale that makes Var[T(θ,φ)] = 1 --------------------------- + scale = math.sqrt(4.0 * math.pi) / lmax # (Unsöld theorem ⇒ L = lmax) + alm = (real + 1j * imag) * scale + + 
return isht(alm) + + +class NoiseConditionedSFNO(torch.nn.Module): + def __init__( + self, + conditional_model: ConditionalSFNO, + img_shape: tuple[int, int], + noise_type: Literal["isotropic", "gaussian"] = "gaussian", + embed_dim_noise: int = 256, + embed_dim_pos: int = 0, + embed_dim_labels: int = 0, + ): + super().__init__() + self.conditional_model = conditional_model + self.embed_dim = embed_dim_noise + self.noise_type = noise_type + self.label_pos_embed: torch.nn.Parameter | None = None + # register pos embed if pos_embed_dim != 0 + if embed_dim_pos != 0: + self.pos_embed = torch.nn.Parameter( + torch.zeros( + 1, embed_dim_pos, img_shape[0], img_shape[1], requires_grad=True + ) + ) + # initialize pos embed with std=0.02 + torch.nn.init.trunc_normal_(self.pos_embed, std=0.02) + if embed_dim_labels > 0: + self.label_pos_embed = torch.nn.Parameter( + torch.zeros( + embed_dim_labels, + embed_dim_pos, + img_shape[0], + img_shape[1], + requires_grad=True, + ) + ) + torch.nn.init.trunc_normal_(self.label_pos_embed, std=0.02) + else: + self.pos_embed = None + + def forward( + self, x: torch.Tensor, labels: torch.Tensor | None = None + ) -> torch.Tensor: + x = x.reshape(-1, *x.shape[-3:]) + if self.noise_type == "isotropic": + lmax = self.conditional_model.itrans_up.lmax + mmax = self.conditional_model.itrans_up.mmax + noise = isotropic_noise( + (x.shape[0], self.embed_dim), + lmax, + mmax, + self.conditional_model.itrans_up, + device=x.device, + ) + elif self.noise_type == "gaussian": + noise = torch.randn( + [x.shape[0], self.embed_dim, *x.shape[-2:]], + device=x.device, + dtype=x.dtype, + ) + else: + raise ValueError(f"Invalid noise type: {self.noise_type}") + + if self.pos_embed is not None: + embedding_pos = self.pos_embed.repeat(noise.shape[0], 1, 1, 1) + if self.label_pos_embed is not None and labels is not None: + label_embedding_pos = torch.einsum( + "bl, lpxy -> bpxy", labels, self.label_pos_embed + ) + embedding_pos = embedding_pos + label_embedding_pos + else: + embedding_pos = None + + return self.conditional_model( + x, + Context( + embedding_scalar=None, + embedding_pos=embedding_pos, + labels=labels, + noise=noise, + ), + ) + + +def build( + params, + n_in_channels: int, + n_out_channels: int, + dataset_info: DatasetInfo, +): + sfno_net = get_lat_lon_sfnonet( + params=params, + in_chans=n_in_channels, + out_chans=n_out_channels, + img_shape=dataset_info.img_shape, + context_config=ContextConfig( + embed_dim_scalar=0, + embed_dim_pos=params.context_pos_embed_dim, + embed_dim_noise=params.noise_embed_dim, + embed_dim_labels=len(dataset_info.all_labels), + ), + ) + return NoiseConditionedSFNO( + sfno_net, + noise_type=params.noise_type, + embed_dim_noise=params.noise_embed_dim, + embed_dim_pos=params.context_pos_embed_dim, + embed_dim_labels=len(dataset_info.all_labels), + img_shape=dataset_info.img_shape, + ) diff --git a/fme/core/models/conditional_sfno/v1/test_layers.py b/fme/core/models/conditional_sfno/v1/test_layers.py new file mode 100644 index 000000000..b3448a666 --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/test_layers.py @@ -0,0 +1,66 @@ +import pytest +import torch + +from fme.core.device import get_device + +from .layers import ConditionalLayerNorm, Context, ContextConfig + + +@pytest.mark.parametrize("global_layer_norm", [True, False]) +@pytest.mark.parametrize("n_channels", [32]) +@pytest.mark.parametrize("embed_dim_scalar", [9, 0]) +@pytest.mark.parametrize("embed_dim_noise", [10, 0]) +@pytest.mark.parametrize("embed_dim_labels", [11, 0]) 
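+# the zero entries cover configurations where that conditioning input is absent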
+@pytest.mark.parametrize("embed_dim_pos", [18, 0]) +@pytest.mark.parametrize("img_shape", [(8, 16)]) +def test_conditional_layer_norm( + n_channels: int, + img_shape: tuple[int, int], + global_layer_norm: bool, + embed_dim_scalar: int, + embed_dim_labels: int, + embed_dim_noise: int, + embed_dim_pos: int, +): + epsilon = 1e-6 + device = get_device() + conditional_layer_norm = ConditionalLayerNorm( + n_channels, + img_shape, + context_config=ContextConfig( + embed_dim_scalar=embed_dim_scalar, + embed_dim_labels=embed_dim_labels, + embed_dim_noise=embed_dim_noise, + embed_dim_pos=embed_dim_pos, + ), + global_layer_norm=global_layer_norm, + epsilon=epsilon, + ).to(device) + x = torch.randn(1, n_channels, *img_shape, device=device) * 5 + 2 + context_embedding_scalar = torch.randn(1, embed_dim_scalar, device=device) + context_embedding_labels = torch.randn(1, embed_dim_labels, device=device) + context_embedding_noise = torch.randn(1, embed_dim_noise, *img_shape, device=device) + context_embedding_pos = torch.randn(1, embed_dim_pos, *img_shape, device=device) + context = Context( + embedding_scalar=context_embedding_scalar, + noise=context_embedding_noise, + labels=context_embedding_labels, + embedding_pos=context_embedding_pos, + ) + output = conditional_layer_norm(x, context) + assert output.shape == x.shape + torch.testing.assert_close( + output.mean(), torch.tensor(0.0, device=device), atol=1e-3, rtol=0 + ) + torch.testing.assert_close( + output.std(), torch.tensor(1.0, device=device), atol=1e-3, rtol=0 + ) + if not global_layer_norm: + zero = torch.zeros(1, *img_shape, device=device) + torch.testing.assert_close(output.mean(dim=1), zero, atol=1e-3, rtol=0) + torch.testing.assert_close( + (((n_channels - 1) / n_channels) ** 0.5 * output.std(dim=1) - 1), + zero, + atol=1e-3, + rtol=0, + ) diff --git a/fme/core/models/conditional_sfno/v1/test_lora.py b/fme/core/models/conditional_sfno/v1/test_lora.py new file mode 100644 index 000000000..e08a5a0d3 --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/test_lora.py @@ -0,0 +1,23 @@ +import torch +from torch import nn + +from .lora import LoRAConv2d + + +def test_lora_conv2d_load_conv2d_checkpoint(): + conv = nn.Conv2d(8, 16, 3, padding=1) + lora = LoRAConv2d(8, 16, 3, padding=1) # default should not use/require lora + + lora.load_state_dict(conv.state_dict(), strict=True) + + x = torch.randn(2, 8, 32, 32) + with torch.no_grad(): + y0 = conv(x) + y1 = lora(x) + torch.testing.assert_close( + y0, + y1, + atol=1e-6, + rtol=0, + msg="Outputs do not match after loading Conv2d checkpoint", + ) diff --git a/fme/core/models/conditional_sfno/v1/test_s2convolutions.py b/fme/core/models/conditional_sfno/v1/test_s2convolutions.py new file mode 100644 index 000000000..f1f992b49 --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/test_s2convolutions.py @@ -0,0 +1,133 @@ +import dataclasses + +import pytest +import torch + +from fme.core.device import get_device +from fme.core.gridded_ops import LatLonOperations + +from .s2convolutions import SpectralConvS2, _contract_dhconv + + +@dataclasses.dataclass +class BenchmarkResult: + ms_total: float + ms_per: float + max_alloc: int + max_reserved: int + y_shape: tuple + y_dtype: torch.dtype + + +def benchmark(fn, iters=10, warmup=1) -> BenchmarkResult: + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + torch.cuda.reset_peak_memory_stats() + starter = torch.cuda.Event(enable_timing=True) + ender = torch.cuda.Event(enable_timing=True) + + starter.record() + for _ in range(iters): + y = fn() + 
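+ # only the last iteration's y is kept; it is reported below via y_shape / y_dtype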
ender.record() + torch.cuda.synchronize() + + ms = starter.elapsed_time(ender) + return BenchmarkResult( + ms_total=ms, + ms_per=ms / iters, + max_alloc=torch.cuda.max_memory_allocated(), + max_reserved=torch.cuda.max_memory_reserved(), + y_shape=tuple(y.shape), + y_dtype=y.dtype, + ) + + +@pytest.mark.skipif( + get_device().type != "cuda", + reason=( + "This test is only relevant for CUDA since " + "it's testing speed of DHConv groups on GPU." + ), +) # noqa: E501 +def test_contract_dhconv_groups_are_faster(): + B = 2 + C = 512 + H = 180 + L = 360 + G = 8 + x = torch.randn(B, 1, C, H, L, dtype=torch.complex64, device=get_device()) + w = torch.randn(1, C, C, H, 2, dtype=torch.float32, device=get_device()) + + def contract_ungrouped(): + return _contract_dhconv(x, w) + + ungrouped_result = benchmark(contract_ungrouped) + + x_grouped = x.reshape(B, G, C // G, H, L) + w_grouped = torch.randn( + G, C // G, C // G, H, 2, dtype=torch.float32, device=get_device() + ) + + def contract_grouped(): + return _contract_dhconv(x_grouped, w_grouped) + + grouped_result = benchmark(contract_grouped) + + assert grouped_result.ms_per < 2 / G * ungrouped_result.ms_per, ( + "Expected grouped DHConv to be faster than ungrouped, but got " + f"{grouped_result.ms_per:.6f} seconds for grouped and " + f"{ungrouped_result.ms_per:.6f} seconds for ungrouped." + ) + assert grouped_result.max_alloc < ungrouped_result.max_alloc, ( + "Expected grouped DHConv to use less memory than ungrouped, but got " + f"{grouped_result.max_alloc/1024/1024:.2f} MB for grouped and " + f"{ungrouped_result.max_alloc/1024/1024:.2f} MB for ungrouped." + ) + + +def test_spectral_conv_s2_lora(): + in_channels = 8 + out_channels = in_channels + n_lat = 12 + n_lon = 24 + operations = LatLonOperations( + area_weights=torch.ones(n_lat, n_lon), + grid="legendre-gauss", + ) + sht = operations.get_real_sht() + isht = operations.get_real_isht() + + conv1 = SpectralConvS2( + forward_transform=sht, + inverse_transform=isht, + in_channels=in_channels, + out_channels=out_channels, + operator_type="dhconv", + use_tensorly=False, + ) + assert conv1.lora_A is None + assert conv1.lora_B is None + conv2 = SpectralConvS2( + forward_transform=sht, + inverse_transform=isht, + in_channels=in_channels, + out_channels=out_channels, + operator_type="dhconv", + use_tensorly=False, + lora_rank=4, + lora_alpha=8, + ) + assert conv2.lora_A is not None + assert conv2.lora_B is not None + + conv2.load_state_dict(conv1.state_dict(), strict=False) + x = torch.randn(2, in_channels, n_lat, n_lon) + y1, residual1 = conv1(x) + y2, residual2 = conv2(x) + + # initial outputs should be identical since LoRA starts at 0 + assert torch.allclose(y1, y2, atol=1e-6) + assert torch.allclose(residual1, residual2, atol=1e-6) diff --git a/fme/core/models/conditional_sfno/v1/test_sfnonet.py b/fme/core/models/conditional_sfno/v1/test_sfnonet.py new file mode 100644 index 000000000..3230d7c87 --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/test_sfnonet.py @@ -0,0 +1,223 @@ +import os +from types import SimpleNamespace + +import pytest +import torch +from torch import nn + +from fme.core.device import get_device +from fme.core.testing.regression import validate_tensor + +from .layers import Context, ContextConfig +from .sfnonet import get_lat_lon_sfnonet + +DIR = os.path.abspath(os.path.dirname(__file__)) + + +@pytest.mark.parametrize( + "conditional_embed_dim_scalar, conditional_embed_dim_labels, " + "conditional_embed_dim_noise, " + "conditional_embed_dim_pos, residual_filter_factor", + 
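+ # tuple order follows the argnames string above: (scalar, labels, noise, pos, residual_filter_factor)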
[ + (0, 0, 0, 0, 1), + (16, 8, 0, 0, 1), + (16, 0, 16, 0, 1), + (16, 15, 14, 13, 1), + (0, 0, 0, 16, 1), + (0, 0, 16, 0, 1), + (16, 0, 0, 0, 4), + ], +) +def test_can_call_sfnonet( + conditional_embed_dim_scalar: int, + conditional_embed_dim_labels: int, + conditional_embed_dim_noise: int, + conditional_embed_dim_pos: int, + residual_filter_factor: int, +): + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 4 + device = get_device() + params = SimpleNamespace( + embed_dim=16, + num_layers=2, + residual_filter_factor=residual_filter_factor, + filter_type="makani-linear", + ) + model = get_lat_lon_sfnonet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + context_config=ContextConfig( + embed_dim_scalar=conditional_embed_dim_scalar, + embed_dim_labels=conditional_embed_dim_labels, + embed_dim_noise=conditional_embed_dim_noise, + embed_dim_pos=conditional_embed_dim_pos, + ), + ).to(device) + x = torch.randn(n_samples, input_channels, *img_shape, device=device) + context_embedding = torch.randn( + n_samples, conditional_embed_dim_scalar, device=device + ) + context_embedding_labels = torch.randn( + n_samples, conditional_embed_dim_labels, device=device + ) + context_embedding_noise = torch.randn( + n_samples, conditional_embed_dim_noise, *img_shape, device=device + ) + context_embedding_pos = torch.randn( + n_samples, conditional_embed_dim_pos, *img_shape, device=device + ) + context = Context( + embedding_scalar=context_embedding, + labels=context_embedding_labels, + noise=context_embedding_noise, + embedding_pos=context_embedding_pos, + ) + output = model(x, context) + assert output.shape == (n_samples, output_channels, *img_shape) + + +def test_scale_factor_not_implemented(): + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + device = get_device() + params = SimpleNamespace(embed_dim=16, num_layers=2, scale_factor=2) + with pytest.raises(NotImplementedError): + # if this ever gets implemented, we need to instead test that the scale factor + # is used to determine the nlat/nlon of the image in the network + get_lat_lon_sfnonet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + context_config=ContextConfig( + embed_dim_scalar=0, + embed_dim_noise=0, + embed_dim_labels=0, + embed_dim_pos=0, + ), + ).to(device) + + +def test_sfnonet_output_is_unchanged(): + torch.manual_seed(0) + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 4 + conditional_embed_dim_scalar = 8 + conditional_embed_dim_labels = 4 + conditional_embed_dim_noise = 16 + conditional_embed_dim_pos = 0 + device = get_device() + params = SimpleNamespace( + embed_dim=16, num_layers=2, filter_type="linear", operator_type="dhconv" + ) + model = get_lat_lon_sfnonet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + context_config=ContextConfig( + embed_dim_scalar=conditional_embed_dim_scalar, + embed_dim_labels=conditional_embed_dim_labels, + embed_dim_noise=conditional_embed_dim_noise, + embed_dim_pos=conditional_embed_dim_pos, + ), + ).to(device) + # must initialize on CPU to get the same results on GPU + x = torch.randn(n_samples, input_channels, *img_shape).to(device) + context_embedding = torch.randn(n_samples, conditional_embed_dim_scalar).to(device) + context_embedding_labels = torch.randn( + n_samples, conditional_embed_dim_labels, device=device + ) + context_embedding_noise = torch.randn( + n_samples, 
conditional_embed_dim_noise, *img_shape, device=device + ).to(device) + context_embedding_pos = None + context = Context( + embedding_scalar=context_embedding, + labels=context_embedding_labels, + noise=context_embedding_noise, + embedding_pos=context_embedding_pos, + ) + with torch.no_grad(): + output = model(x, context) + validate_tensor( + output, + os.path.join(DIR, "testdata/test_sfnonet_output_is_unchanged.pt"), + ) + + +@pytest.mark.parametrize("normalize_big_skip", [True, False]) +def test_all_inputs_get_layer_normed(normalize_big_skip: bool): + torch.manual_seed(0) + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 4 + conditional_embed_dim_scalar = 8 + conditional_embed_dim_noise = 16 + conditional_embed_dim_labels = 3 + conditional_embed_dim_pos = 12 + device = get_device() + + class SetToZero(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, x): + return torch.zeros_like(x) + + original_layer_norm = nn.LayerNorm + try: + nn.LayerNorm = SetToZero + params = SimpleNamespace( + embed_dim=16, + num_layers=2, + normalize_big_skip=normalize_big_skip, + global_layer_norm=True, # so it uses nn.LayerNorm + operator_type="dhconv", + ) + model = get_lat_lon_sfnonet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + context_config=ContextConfig( + embed_dim_scalar=conditional_embed_dim_scalar, + embed_dim_noise=conditional_embed_dim_noise, + embed_dim_labels=conditional_embed_dim_labels, + embed_dim_pos=conditional_embed_dim_pos, + ), + ).to(device) + finally: + nn.LayerNorm = original_layer_norm + x = torch.full((n_samples, input_channels, *img_shape), torch.nan).to(device) + context_embedding = torch.randn(n_samples, conditional_embed_dim_scalar).to(device) + context_embedding_noise = torch.randn( + n_samples, conditional_embed_dim_noise, *img_shape + ).to(device) + context_embedding_labels = torch.randn(n_samples, conditional_embed_dim_labels).to( + device + ) + context_embedding_pos = torch.randn( + n_samples, conditional_embed_dim_pos, *img_shape + ).to(device) + context = Context( + embedding_scalar=context_embedding, + embedding_pos=context_embedding_pos, + noise=context_embedding_noise, + labels=context_embedding_labels, + ) + with torch.no_grad(): + output = model(x, context) + if normalize_big_skip: + assert not torch.isnan(output).any() + else: + assert torch.isnan(output).any() diff --git a/fme/core/models/conditional_sfno/v1/test_stochastic_sfno.py b/fme/core/models/conditional_sfno/v1/test_stochastic_sfno.py new file mode 100644 index 000000000..dbabd76f5 --- /dev/null +++ b/fme/core/models/conditional_sfno/v1/test_stochastic_sfno.py @@ -0,0 +1,67 @@ +import unittest.mock + +import pytest +import torch +from torch_harmonics import InverseRealSHT + +from fme.core.device import get_device + +from .stochastic_sfno import Context, NoiseConditionedSFNO, isotropic_noise + + +@pytest.mark.parametrize("nlat, nlon", [(8, 16), (64, 128)]) +def test_isotropic_noise(nlat: int, nlon: int): + torch.manual_seed(0) + n_batch = 1000 + embed_dim = 4 + leading_shape = (n_batch, embed_dim) + isht = InverseRealSHT(nlat, nlon, grid="legendre-gauss") + lmax = isht.lmax + mmax = isht.mmax + noise = isotropic_noise(leading_shape, lmax, mmax, isht, device=get_device()) + assert noise.shape == (n_batch, embed_dim, nlat, nlon) + assert noise.dtype == torch.float32 + torch.testing.assert_close( + noise.mean(), torch.tensor(0.0, device=noise.device), atol=2e-3, rtol=0.0 + ) + 
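+ # std should be ~1 by construction of the sqrt(4*pi)/lmax scaling in isotropic_noise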
torch.testing.assert_close( + noise.std(), torch.tensor(1.0, device=noise.device), atol=5e-3, rtol=0.0 + ) + + +def test_noise_conditioned_sfno_conditioning(): + mock_sfno = unittest.mock.MagicMock() + img_shape = (32, 64) + n_noise = 16 + n_pos = 8 + n_labels = 4 + model = NoiseConditionedSFNO( + conditional_model=mock_sfno, + img_shape=img_shape, + noise_type="gaussian", # needed so we don't need a SHT in this test + embed_dim_noise=n_noise, + embed_dim_pos=n_pos, + embed_dim_labels=n_labels, + ) + batch_size = 2 + x = torch.randn(batch_size, 3, img_shape[0], img_shape[1]) + labels = torch.randn(batch_size, 4) + _ = model(x, labels=labels) + mock_sfno.assert_called() + args, _ = mock_sfno.call_args + conditioned_x = args[0] + assert conditioned_x.shape == (batch_size, 3, img_shape[0], img_shape[1]) + context = args[1] + assert isinstance(context, Context) + assert context.embedding_scalar is None + assert context.embedding_pos is not None + assert context.labels is not None + assert context.noise is not None + assert context.embedding_pos.shape == ( + batch_size, + n_pos, + img_shape[0], + img_shape[1], + ) + assert context.labels.shape == (batch_size, n_labels) + assert context.noise.shape == (batch_size, n_noise, img_shape[0], img_shape[1]) diff --git a/fme/core/models/conditional_sfno/v1/testdata/test_sfnonet_output_is_unchanged.pt b/fme/core/models/conditional_sfno/v1/testdata/test_sfnonet_output_is_unchanged.pt new file mode 100644 index 000000000..95ea10fc6 Binary files /dev/null and b/fme/core/models/conditional_sfno/v1/testdata/test_sfnonet_output_is_unchanged.pt differ diff --git a/fme/core/step/test_step.py b/fme/core/step/test_step.py index d1d0b33c7..1de124d88 100644 --- a/fme/core/step/test_step.py +++ b/fme/core/step/test_step.py @@ -156,6 +156,7 @@ def get_single_module_noise_conditioned_selector( type="NoiseConditionedSFNO", config=dataclasses.asdict( NoiseConditionedSFNOBuilder( + version="v1", embed_dim=4, noise_embed_dim=4, noise_type="isotropic", diff --git a/fme/diffusion/registry/sfno.py b/fme/diffusion/registry/sfno.py index 0fea77310..208ef6077 100644 --- a/fme/diffusion/registry/sfno.py +++ b/fme/diffusion/registry/sfno.py @@ -1,8 +1,11 @@ import dataclasses from typing import Literal -from fme.core.models.conditional_sfno.sfnonet import ContextConfig, get_lat_lon_sfnonet -from fme.core.models.conditional_sfno.sfnonet import ( +from fme.core.models.conditional_sfno.v1.sfnonet import ( + ContextConfig, + get_lat_lon_sfnonet, +) +from fme.core.models.conditional_sfno.v1.sfnonet import ( SphericalFourierNeuralOperatorNet as ConditionalSFNO, ) diff --git a/fme/diffusion/stepper.py b/fme/diffusion/stepper.py index c3a541fb0..6563754d6 100644 --- a/fme/diffusion/stepper.py +++ b/fme/diffusion/stepper.py @@ -34,7 +34,7 @@ from fme.core.generics.optimization import OptimizationABC from fme.core.generics.train_stepper import TrainStepperABC from fme.core.gridded_ops import GriddedOperations, LatLonOperations -from fme.core.models.conditional_sfno.layers import Context +from fme.core.models.conditional_sfno.v1.layers import Context from fme.core.normalizer import NormalizationConfig, StandardNormalizer from fme.core.ocean import Ocean, OceanConfig from fme.core.optimization import NullOptimization