diff --git a/fme/core/models/conditional_sfno/layers.py b/fme/core/models/conditional_sfno/layers.py index 47648d781..6873c3e9b 100644 --- a/fme/core/models/conditional_sfno/layers.py +++ b/fme/core/models/conditional_sfno/layers.py @@ -25,6 +25,7 @@ from torch.utils.checkpoint import checkpoint from fme.core.models.conditional_sfno.lora import LoRAConv2d +from fme.core.models.conditional_sfno.timer import CUDATimer, NullTimer from .activations import ComplexReLU from .contractions import compl_mul2d_fwd, compl_muladd2d_fwd @@ -163,35 +164,33 @@ def __init__( self.W_bias_labels = None if self.embed_dim_noise > 0: # no bias as it is already handled in the non-2d layers - self.W_scale_2d = nn.Conv2d( - self.embed_dim_noise, self.n_channels, kernel_size=1, bias=False + self.W_scale_2d = nn.Linear( + self.embed_dim_noise, self.n_channels, bias=False ) - self.W_bias_2d = nn.Conv2d( - self.embed_dim_noise, self.n_channels, kernel_size=1, bias=False + self.W_bias_2d = nn.Linear( + self.embed_dim_noise, self.n_channels, bias=False ) else: self.W_scale_2d = None self.W_bias_2d = None if self.embed_dim_pos > 0: # no bias as it is already handled in the non-2d layers - self.W_scale_pos = nn.Conv2d( - self.embed_dim_pos, self.n_channels, kernel_size=1, bias=False - ) - self.W_bias_pos = nn.Conv2d( - self.embed_dim_pos, self.n_channels, kernel_size=1, bias=False + self.W_scale_pos = nn.Linear( + self.embed_dim_pos, self.n_channels, bias=False ) + self.W_bias_pos = nn.Linear(self.embed_dim_pos, self.n_channels, bias=False) else: self.W_scale_pos = None self.W_bias_pos = None if global_layer_norm: self.norm = nn.LayerNorm( - (self.n_channels, img_shape[0], img_shape[1]), + (img_shape[1], img_shape[0], self.n_channels), eps=epsilon, elementwise_affine=elementwise_affine, ) else: - self.norm = ChannelLayerNorm( - self.n_channels, + self.norm = nn.LayerNorm( + (self.n_channels,), eps=epsilon, elementwise_affine=elementwise_affine, ) @@ -223,7 +222,12 @@ def reset_parameters(self): torch.nn.init.constant_(self.W_bias_pos.weight, 0.0) # no bias on 2d layers as it is already handled in the non-2d layers - def forward(self, x: torch.Tensor, context: Context) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + context: Context, + timer: CUDATimer | NullTimer | None = None, + ) -> torch.Tensor: """ Conditional Layer Normalization @@ -232,62 +236,74 @@ def forward(self, x: torch.Tensor, context: Context) -> torch.Tensor: Args: x: The input tensor to normalize, of shape - (batch_size, channels, height, width). + (batch_size, width, height, channels). context: The context to condition on. Returns: - The normalized tensor, of shape (batch_size, channels, height, width). + The normalized tensor, of shape (batch_size, width, height, channels). 
""" + if timer is None: + timer = NullTimer() if context.labels is None and ( self.W_scale_labels is not None or self.W_bias_labels is not None ): raise ValueError("labels must be provided") - if self.W_scale is not None: - if context.embedding_scalar is None: - raise ValueError("embedding_scalar must be provided") - scale: torch.Tensor = ( - self.W_scale(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) - ) - else: - scale = torch.ones( - list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype - ) + with timer.context("layer_norm_compute_scaling_and_bias"): + if self.W_scale is not None: + if context.embedding_scalar is None: + raise ValueError("embedding_scalar must be provided") + scale: torch.Tensor = ( + self.W_scale(context.embedding_scalar).unsqueeze(-2).unsqueeze(-2) + ) + else: + scale = torch.ones( + list(x.shape[:-3]) + [1, 1, x.shape[-1]], + device=x.device, + dtype=x.dtype, + ) - if self.W_scale_2d is not None: - if context.noise is None: - raise ValueError("embedding_2d must be provided") - scale = scale + self.W_scale_2d(context.noise) - if self.W_bias is not None: - if context.embedding_scalar is None: - raise ValueError("embedding_scalar must be provided") - bias: torch.Tensor = ( - self.W_bias(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1) - ) - else: - bias = torch.zeros( - list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype - ) + if self.W_scale_2d is not None: + if context.noise is None: + raise ValueError("embedding_2d must be provided") + scale = scale + self.W_scale_2d(context.noise) + if self.W_bias is not None: + if context.embedding_scalar is None: + raise ValueError("embedding_scalar must be provided") + bias: torch.Tensor = ( + self.W_bias(context.embedding_scalar).unsqueeze(-2).unsqueeze(-2) + ) + else: + bias = torch.zeros( + list(x.shape[:-3]) + [1, 1, x.shape[-1]], + device=x.device, + dtype=x.dtype, + ) - if self.W_scale_labels is not None: - scale = scale + self.W_scale_labels(context.labels).unsqueeze(-1).unsqueeze( - -1 - ) - if self.W_bias_labels is not None: - bias = bias + self.W_bias_labels(context.labels).unsqueeze(-1).unsqueeze(-1) - if self.W_bias_2d is not None: - if context.noise is None: - raise ValueError("embedding_2d must be provided") - bias = bias + self.W_bias_2d(context.noise) - if self.W_scale_pos is not None: - if context.embedding_pos is None: - raise ValueError("embedding_pos must be provided") - scale = scale + self.W_scale_pos(context.embedding_pos) - if self.W_bias_pos is not None: - if context.embedding_pos is None: - raise ValueError("embedding_pos must be provided") - bias = bias + self.W_bias_pos(context.embedding_pos) - x_norm: torch.Tensor = self.norm(x) - return x_norm * scale + bias + if self.W_scale_labels is not None: + scale = scale + self.W_scale_labels(context.labels).unsqueeze( + -2 + ).unsqueeze(-2) + if self.W_bias_labels is not None: + bias = bias + self.W_bias_labels(context.labels).unsqueeze( + -2 + ).unsqueeze(-2) + if self.W_bias_2d is not None: + if context.noise is None: + raise ValueError("embedding_2d must be provided") + bias = bias + self.W_bias_2d(context.noise) + if self.W_scale_pos is not None: + if context.embedding_pos is None: + raise ValueError("embedding_pos must be provided") + scale = scale + self.W_scale_pos(context.embedding_pos) + if self.W_bias_pos is not None: + if context.embedding_pos is None: + raise ValueError("embedding_pos must be provided") + bias = bias + self.W_bias_pos(context.embedding_pos) + with timer.context("layer_norm_normalize"): + x_norm: 
torch.Tensor = self.norm(x) + with timer.context("layer_norm_apply_scaling_and_bias"): + return_value = x_norm * scale + bias + return return_value @torch.jit.script diff --git a/fme/core/models/conditional_sfno/lora.py b/fme/core/models/conditional_sfno/lora.py index ea548ef32..952bab2a7 100644 --- a/fme/core/models/conditional_sfno/lora.py +++ b/fme/core/models/conditional_sfno/lora.py @@ -1,12 +1,10 @@ from __future__ import annotations -import math - import torch import torch.nn as nn -class LoRAConv2d(nn.Conv2d): +class LoRAConv2d(nn.Module): """ Drop-in Conv2d with optional LoRA. @@ -37,105 +35,149 @@ def __init__( lora_alpha: float | None = None, lora_dropout: float = 0.0, ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - self.lora_down: nn.Conv2d | None = None - self.lora_up: nn.Conv2d | None = None - super().__init__( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, + super().__init__() + self._linear = nn.Linear( + in_channels, + out_channels, bias=bias, - padding_mode=padding_mode, - **factory_kwargs, + device=device, + dtype=dtype, ) - if lora_rank < 0: - raise ValueError(f"lora_rank must be >= 0, got {lora_rank}") - if lora_dropout < 0.0: - raise ValueError(f"lora_dropout must be >= 0, got {lora_dropout}") - - self.lora_rank = int(lora_rank) - self.lora_alpha = ( - float(lora_alpha) if lora_alpha is not None else float(lora_rank) - ) - self.lora_dropout_p = float(lora_dropout) - - self._lora_merged = False - - if self.lora_rank > 0: - # Group-compatible LoRA via two convs: - # down: 1x1 grouped conv: in_channels -> (groups * r), groups=groups - # up: kxk grouped conv: (groups * r) -> out_channels, groups=groups - # This produces a delta with the same grouped structure as the base conv. - mid_channels = self.groups * self.lora_rank - - self.lora_down = nn.Conv2d( - in_channels=self.in_channels, - out_channels=mid_channels, - kernel_size=1, - stride=1, - padding=0, - dilation=1, - groups=self.groups, - bias=False, - **factory_kwargs, - ) - self.lora_up = nn.Conv2d( - in_channels=mid_channels, - out_channels=self.out_channels, - kernel_size=self.kernel_size, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - groups=self.groups, - bias=False, - padding_mode=self.padding_mode, - **factory_kwargs, - ) - - self.lora_dropout = ( - nn.Dropout(p=self.lora_dropout_p) - if self.lora_dropout_p > 0 - else nn.Identity() - ) - - # Scaling as in LoRA: alpha / r - self.lora_scaling = self.lora_alpha / float(self.lora_rank) - else: - self.lora_dropout = nn.Identity() - self.lora_scaling = 0.0 - self.reset_lora_parameters() # base parameters already reset in super init - - def reset_parameters(self) -> None: - super().reset_parameters() - self.reset_lora_parameters() - - def reset_lora_parameters(self): - # Init: down ~ Kaiming, up = 0 so the module starts - # identical to base Conv2d. 
- if self.lora_down is not None: - nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) - if self.lora_up is not None: - nn.init.zeros_(self.lora_up.weight) - - def extra_repr(self) -> str: - base = super().extra_repr() - if self.lora_rank > 0: - return ( - f"{base}, lora_rank={self.lora_rank}, lora_alpha={self.lora_alpha}, " - f"lora_dropout={self.lora_dropout_p}, lora_merged={self._lora_merged}" - ) - return f"{base}, lora_rank=0" - def forward(self, x: torch.Tensor) -> torch.Tensor: - y = super().forward(x) - if self.lora_rank == 0 or self._lora_merged: - return y - assert self.lora_down is not None and self.lora_up is not None - return ( - y + self.lora_up(self.lora_down(self.lora_dropout(x))) * self.lora_scaling - ) + return self._linear(x) + + +# class LoRAConv2d(nn.Conv2d): +# """ +# Drop-in Conv2d with optional LoRA. + +# - API matches torch.nn.Conv2d, with extra args: +# lora_rank: int = 0 (0 disables LoRA) +# lora_alpha: float = None (defaults to lora_rank) +# lora_dropout: float = 0.0 + +# - Can load a checkpoint saved from nn.Conv2d even when lora_rank > 0 +# (i.e., state_dict only has "weight"/"bias"). +# """ + +# def __init__( +# self, +# in_channels: int, +# out_channels: int, +# kernel_size: int | tuple[int, int], +# stride: int | tuple[int, int] = 1, +# padding: int | tuple[int, int] = 0, +# dilation: int | tuple[int, int] = 1, +# groups: int = 1, +# bias: bool = True, +# padding_mode: str = "zeros", +# device=None, +# dtype=None, +# *, +# lora_rank: int = 0, +# lora_alpha: float | None = None, +# lora_dropout: float = 0.0, +# ) -> None: +# factory_kwargs = {"device": device, "dtype": dtype} +# self.lora_down: nn.Conv2d | None = None +# self.lora_up: nn.Conv2d | None = None +# super().__init__( +# in_channels=in_channels, +# out_channels=out_channels, +# kernel_size=kernel_size, +# stride=stride, +# padding=padding, +# dilation=dilation, +# groups=groups, +# bias=bias, +# padding_mode=padding_mode, +# **factory_kwargs, +# ) + +# if lora_rank < 0: +# raise ValueError(f"lora_rank must be >= 0, got {lora_rank}") +# if lora_dropout < 0.0: +# raise ValueError(f"lora_dropout must be >= 0, got {lora_dropout}") + +# self.lora_rank = int(lora_rank) +# self.lora_alpha = ( +# float(lora_alpha) if lora_alpha is not None else float(lora_rank) +# ) +# self.lora_dropout_p = float(lora_dropout) + +# self._lora_merged = False + +# if self.lora_rank > 0: +# # Group-compatible LoRA via two convs: +# # down: 1x1 grouped conv: in_channels -> (groups * r), groups=groups +# # up: kxk grouped conv: (groups * r) -> out_channels, groups=groups +# # This produces a delta with the same grouped structure as the base conv. 
+# mid_channels = self.groups * self.lora_rank + +# self.lora_down = nn.Conv2d( +# in_channels=self.in_channels, +# out_channels=mid_channels, +# kernel_size=1, +# stride=1, +# padding=0, +# dilation=1, +# groups=self.groups, +# bias=False, +# **factory_kwargs, +# ) +# self.lora_up = nn.Conv2d( +# in_channels=mid_channels, +# out_channels=self.out_channels, +# kernel_size=self.kernel_size, +# stride=self.stride, +# padding=self.padding, +# dilation=self.dilation, +# groups=self.groups, +# bias=False, +# padding_mode=self.padding_mode, +# **factory_kwargs, +# ) + +# self.lora_dropout = ( +# nn.Dropout(p=self.lora_dropout_p) +# if self.lora_dropout_p > 0 +# else nn.Identity() +# ) + +# # Scaling as in LoRA: alpha / r +# self.lora_scaling = self.lora_alpha / float(self.lora_rank) +# else: +# self.lora_dropout = nn.Identity() +# self.lora_scaling = 0.0 +# self.reset_lora_parameters() # base parameters already reset in super init + +# def reset_parameters(self) -> None: +# super().reset_parameters() +# self.reset_lora_parameters() + +# def reset_lora_parameters(self): +# # Init: down ~ Kaiming, up = 0 so the module starts +# # identical to base Conv2d. +# if self.lora_down is not None: +# nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) +# if self.lora_up is not None: +# nn.init.zeros_(self.lora_up.weight) + +# def extra_repr(self) -> str: +# base = super().extra_repr() +# if self.lora_rank > 0: +# return ( +# f"{base}, lora_rank={self.lora_rank}, lora_alpha={self.lora_alpha}, " +# f"lora_dropout={self.lora_dropout_p}, lora_merged={self._lora_merged}" +# ) +# return f"{base}, lora_rank=0" + +# def forward(self, x: torch.Tensor) -> torch.Tensor: +# y = super().forward(x) +# if self.lora_rank == 0 or self._lora_merged: +# return y +# assert self.lora_down is not None and self.lora_up is not None +# return ( +# y + self.lora_up(self.lora_down(self.lora_dropout(x))) * self.lora_scaling +# ) diff --git a/fme/core/models/conditional_sfno/s2convolutions.py b/fme/core/models/conditional_sfno/s2convolutions.py index 9d3960b56..fc9f786a1 100644 --- a/fme/core/models/conditional_sfno/s2convolutions.py +++ b/fme/core/models/conditional_sfno/s2convolutions.py @@ -18,11 +18,12 @@ import math import torch import torch.nn as nn -import torch.nn.functional as F import torch_harmonics as th import torch_harmonics.distributed as thd +from fme.core.models.conditional_sfno.timer import CUDATimer, NullTimer + # import convenience functions for factorized tensors from .activations import ComplexReLU @@ -48,6 +49,7 @@ def _contract_lora( Performs LoRA update contraction. Args: + lora_A: LoRA A matrix of shape (in_channels, rank, nlat, 2) lora_B: LoRA B matrix of shape (rank, out_channels, nlat, 2) x: Complex input tensor of shape @@ -55,10 +57,9 @@ def _contract_lora( """ lora_A = torch.view_as_complex(lora_A) lora_B = torch.view_as_complex(lora_B) - return torch.einsum("irx,rox,bixy->boxy", lora_A, lora_B, x) + return torch.einsum("girx,grox,bgixy->bgoxy", lora_A, lora_B, x) -@torch.jit.script def _contract_dhconv( xc: torch.Tensor, weight: torch.Tensor ) -> torch.Tensor: # pragma: no cover @@ -67,11 +68,11 @@ def _contract_dhconv( 'a' and 'b'. 
Args: - xc: Complex input tensor of shape (batch_size, in_channels, nlat, nlon) - weight: Weight tensor of shape (in_channels, out_channels, nlat, 2) + xc: Complex input tensor of shape (batch_size, group, in_channels, nlat, nlon) + weight: Weight tensor of shape (group, in_channels, out_channels, nlat, 2) """ wc = torch.view_as_complex(weight) - return torch.einsum("bixy,iox->boxy", xc, wc) + return torch.einsum("byxgi,gxoi->byxgo", xc, wc) class SpectralConvS2(nn.Module): @@ -88,6 +89,7 @@ def __init__( inverse_transform, in_channels, out_channels, + num_groups: int = 1, scale="auto", operator_type="diagonal", rank=0.2, @@ -123,6 +125,10 @@ def __init__( "Currently only in_channels == out_channels is supported." ) + assert in_channels % num_groups == 0 + assert out_channels % num_groups == 0 + self.num_groups = num_groups + if in_channels != out_channels: raise NotImplementedError( "Currently only in_channels == out_channels is supported." @@ -167,35 +173,42 @@ def __init__( self.mpad = 0 if scale == "auto": - scale = math.sqrt(1 / (in_channels)) * torch.ones(self.modes_lat_local, 2) + scale = math.sqrt(1 / (in_channels)) * torch.ones( + self.modes_lat_local, 1, 1, 2 + ) # seemingly the first weight is not really complex, so we need to account for that - scale[0, :] *= math.sqrt(2.0) + scale[0, :, :, :] *= math.sqrt(2.0) - weight_shape = [in_channels, out_channels, self.modes_lat_local] + weight_shape = [ + num_groups, + self.modes_lat_local, + out_channels // num_groups, + in_channels // num_groups, + ] assert factorization == "ComplexDense" self.weight = nn.Parameter(scale * torch.randn(*weight_shape, 2)) self.weight.is_shared_mp = ["matmul", "w"] if lora_rank > 0: - if self.weight.shape != ( - in_channels, - out_channels, - self.modes_lat_local, - 2, - ): - raise NotImplementedError( - "LoRA is only implemented for dhconv with unpadded weights." - ) - if use_tensorly: - raise NotImplementedError( - "LoRA is not implemented for tensorly factorized weights." 
- ) self.lora_A = nn.Parameter( - scale * torch.randn(in_channels, lora_rank, self.modes_lat_local, 2) + scale + * torch.randn( + num_groups, + in_channels // num_groups, + lora_rank, + self.modes_lat_local, + 2, + ) ) self.lora_B = nn.Parameter( - torch.zeros(lora_rank, out_channels, self.modes_lat_local, 2) + torch.zeros( + num_groups, + lora_rank, + out_channels // num_groups, + self.modes_lat_local, + 2, + ) ) self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank self.lora_scaling = self.lora_alpha / lora_rank @@ -205,44 +218,57 @@ def __init__( self.lora_scaling = 0.0 if bias: - self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) + self.bias = nn.Parameter(torch.zeros(1, 1, 1, out_channels)) + self.out_channels = out_channels - def forward(self, x): # pragma: no cover + def forward( + self, x, timer: CUDATimer | NullTimer | None = None + ): # pragma: no cover + if timer is None: + timer = NullTimer() dtype = x.dtype residual = x x = x.float() - B, C, H, W = x.shape with torch.amp.autocast("cuda", enabled=False): - x = self.forward_transform(x.float()) + with timer.context("forward_transform"): + x = self.forward_transform(x.float(), timer=timer).contiguous() if self._round_trip_residual: - x = x.contiguous() - residual = self.inverse_transform(x) - residual = residual.to(dtype) + with timer.context("round_trip_residual"): + residual = self.inverse_transform(x).contiguous() + residual = residual.to(dtype) + + B, W, H, C = x.shape + assert C % self.num_groups == 0 + x = x.reshape(B, W, H, self.num_groups, C // self.num_groups) if self.lora_A is not None and self.lora_B is not None: - lora_update = _contract_lora( - self.lora_A, - self.lora_B, - x[..., : self.modes_lat_local, : self.modes_lon_local], - ) + with timer.context("lora_update"): + lora_update = _contract_lora( + self.lora_A, + self.lora_B, + x, + ) else: lora_update = 0.0 - # approach with unpadded weights - xp = torch.zeros_like(x) - xp[..., : self.modes_lat_local, : self.modes_lon_local] = _contract_dhconv( - x[..., : self.modes_lat_local, : self.modes_lon_local], - self.weight, - ) - xp = xp + self.lora_scaling * lora_update - x = xp.contiguous() + with timer.context("dhconv"): + xp = torch.zeros_like(x) + xp[:] = _contract_dhconv( + x, + self.weight, + ) + xp = xp + self.lora_scaling * lora_update + xp = xp.reshape(B, W, H, self.out_channels) + x = xp.contiguous() with torch.amp.autocast("cuda", enabled=False): - x = self.inverse_transform(x) + with timer.context("inverse_transform"): + x = self.inverse_transform(x, timer=timer).contiguous() if hasattr(self, "bias"): - x = x + self.bias + with timer.context("add_bias"): + x = x + self.bias x = x.type(dtype) diff --git a/fme/core/models/conditional_sfno/sfnonet.py b/fme/core/models/conditional_sfno/sfnonet.py index d52e9ac57..16fac3b45 100644 --- a/fme/core/models/conditional_sfno/sfnonet.py +++ b/fme/core/models/conditional_sfno/sfnonet.py @@ -24,6 +24,8 @@ import torch_harmonics as th from torch.utils.checkpoint import checkpoint +from fme.core.models.conditional_sfno.timer import CUDATimer, NullTimer + from .initialization import trunc_normal_ # wrap fft, to unify interface to spectral transforms @@ -115,6 +117,7 @@ def __init__( filter_residual=filter_residual, lora_rank=lora_rank, lora_alpha=lora_alpha, + num_groups=num_groups, ) elif filter_type == "makani-linear": self.filter = SpectralConv( @@ -152,8 +155,8 @@ def __init__( else: raise (NotImplementedError) - def forward(self, x): - return self.filter(x) + def forward(self, x, timer: 
CUDATimer | None = None): + return self.filter(x, timer=timer) class FourierNeuralOperatorBlock(nn.Module): @@ -294,44 +297,52 @@ def __init__( lora_alpha=lora_alpha, ) - def forward(self, x, context_embedding): - x_norm = torch.zeros_like(x) - x_norm[..., : self.input_shape_loc[0], : self.input_shape_loc[1]] = self.norm0( - x[..., : self.input_shape_loc[0], : self.input_shape_loc[1]], - context_embedding, - ) - x, residual = self.filter(x_norm) + def forward(self, x, context_embedding, timer: CUDATimer | NullTimer | None = None): + if timer is None: + timer = NullTimer() + with timer.context("norm0"): + x_norm = self.norm0( + x, + context_embedding, + timer=timer, + ) + with timer.context("filter"): + x, residual = self.filter(x_norm, timer=timer) if hasattr(self, "inner_skip"): - if self.concat_skip: - x = torch.cat((x, self.inner_skip(residual)), dim=1) - x = self.inner_skip_conv(x) - else: - x = x + self.inner_skip(residual) + with timer.context("inner_skip"): + if self.concat_skip: + x = torch.cat((x, self.inner_skip(residual)), dim=1) + x = self.inner_skip_conv(x) + else: + x = x + self.inner_skip(residual) if hasattr(self, "act_layer"): - x = self.act_layer(x) + with timer.context("activation"): + x = self.act_layer(x) - x_norm = torch.zeros_like(x) - x_norm[..., : self.output_shape_loc[0], : self.output_shape_loc[1]] = ( - self.norm1( - x[..., : self.output_shape_loc[0], : self.output_shape_loc[1]], + with timer.context("norm1"): + x_norm = self.norm1( + x, context_embedding, + timer=timer, ) - ) - x = x_norm + x = x_norm if hasattr(self, "mlp"): - x = self.mlp(x) + with timer.context("mlp"): + x = self.mlp(x) - x = self.drop_path(x) + with timer.context("drop_path"): + x = self.drop_path(x) if hasattr(self, "outer_skip"): - if self.concat_skip: - x = torch.cat((x, self.outer_skip(residual)), dim=1) - x = self.outer_skip_conv(x) - else: - x = x + self.outer_skip(residual) + with timer.context("outer_skip"): + if self.concat_skip: + x = torch.cat((x, self.outer_skip(residual)), dim=1) + x = self.outer_skip_conv(x) + else: + x = x + self.outer_skip(residual) return x diff --git a/fme/core/models/conditional_sfno/sht.py b/fme/core/models/conditional_sfno/sht.py new file mode 100644 index 000000000..c121811da --- /dev/null +++ b/fme/core/models/conditional_sfno/sht.py @@ -0,0 +1,223 @@ +# flake8: noqa +# fmt: off +# isort: skip_file + +""" +This file contains a fix that we needed to get the SFNO to work on multiple +unroll steps in multiprocessing (e.g. multi-GPU mode.) We forked this code from +the torch harmonics sht.py file [*]. + +[*] https://github.com/NVIDIA/torch-harmonics/blob/17eefa53468d1a885d72087918eba905fa53e10a/torch_harmonics/sht.py +""" + + +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import torch +import torch.nn as nn +import torch.fft + +from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights +from torch_harmonics.legendre import _precompute_legpoly + +from fme.core.device import get_device +from fme.core.models.conditional_sfno.timer import CUDATimer, NullTimer + + +class RealSHT(nn.Module): + """ + Defines a module for computing the forward (real-valued) SHT. + Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. + The SHT is applied to the last two dimensions of the input + + [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. + [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. 
+ """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True): + """ + Initializes the SHT Layer, precomputing the necessary quadrature weights + + Parameters: + nlat: input grid resolution in the latitudinal direction + nlon: input grid resolution in the longitudinal direction + grid: grid in the latitude direction (for now only tensor product grids are supported) + """ + + super().__init__() + + self.nlat = nlat + self.nlon = nlon + self.grid = grid + self.norm = norm + self.csphase = csphase + + # TODO: include assertions regarding the dimensions + + # compute quadrature points + if self.grid == "legendre-gauss": + cost, w = legendre_gauss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "lobatto": + cost, w = lobatto_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat-1 + elif self.grid == "equiangular": + cost, w = clenshaw_curtiss_weights(nlat, -1, 1) + # cost, w = fejer2_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "healpix": + raise(NotImplementedError("'healpix' grid not supported by InverseRealVectorSHT")) + else: + raise(ValueError("Unknown quadrature mode")) + + # apply cosine transform and flip them + tq = torch.flip(torch.arccos(cost), dims=(0,)) + + # determine the dimensions + self.mmax = mmax or self.nlon // 2 + 1 + + # combine quadrature weights with the legendre weights + pct = torch.as_tensor(_precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)) + weights = torch.einsum('mlk,k->mlk', pct, w).contiguous() + + # remember quadrature weights + self.weights = weights.float().to(get_device()) + + def extra_repr(self): + """ + Pretty print module + """ + return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' + + def forward(self, x: torch.Tensor, timer: CUDATimer | NullTimer = NullTimer()): + # last dims [w, h, c] + assert(x.shape[-2] == self.nlat) + assert(x.shape[-3] == self.nlon) + with torch.autocast("cuda", enabled=False): + # rfft and view_as_complex don't support BF16, see https://github.com/pytorch/pytorch/issues/117844 + if x.dtype != torch.float32: + with timer.context("forward_transform_cast_input"): + x = x.float() + + # apply real fft in the longitudinal direction + with timer.context("forward_transform_rfft"): + x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-3, norm="forward") + x = x.contiguous() + # do the Legendre-Gauss quadrature + x = torch.view_as_real(x) + + # contraction + weights = self.weights.to(x.device).to(x.dtype) + with timer.context("forward_transform_quadrature"): + rl = torch.einsum('...mkc, mlk->...mlc', x[..., 0], weights) + im = torch.einsum('...mkc, mlk->...mlc', x[..., 1], weights) + xout = torch.stack((rl, im), -1) + x = torch.view_as_complex(xout) + + return x + + +class InverseRealSHT(nn.Module): + """ + Defines a module for computing the inverse (real-valued) SHT. + Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. + nlat, nlon: Output dimensions + lmax, mmax: Input dimensions (spherical coefficients). For convenience, these are inferred from the output dimensions + + [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. + [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. 
+ """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True): + + super().__init__() + + self.nlat = nlat + self.nlon = nlon + self.grid = grid + self.norm = norm + self.csphase = csphase + + # compute quadrature points + if self.grid == "legendre-gauss": + cost, _ = legendre_gauss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "lobatto": + cost, _ = lobatto_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat-1 + elif self.grid == "equiangular": + cost, _ = clenshaw_curtiss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "healpix": + raise(NotImplementedError("'healpix' grid not supported by RealVectorSHT")) + else: + raise(ValueError("Unknown quadrature mode")) + + # apply cosine transform and flip them + t = torch.flip(torch.arccos(cost), dims=(0,)) + + # determine the dimensions + self.mmax = mmax or self.nlon // 2 + 1 + + pct = torch.as_tensor(_precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)) + + # register buffer + self.pct = pct.float().to(get_device()) + + def extra_repr(self): + """ + Pretty print module + """ + return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' + + def forward(self, x: torch.Tensor, timer: CUDATimer | NullTimer = NullTimer()): + + assert(x.shape[-2] == self.lmax) + assert(x.shape[-3] == self.mmax) + + with torch.autocast("cuda", enabled=False): + # irfft and view_as_complex don't support BF16, see https://github.com/pytorch/pytorch/issues/117844 + # Evaluate associated Legendre functions on the output nodes + x = torch.view_as_real(x).float() + + pct = self.pct.to(x.device).to(x.dtype) + with timer.context("inverse_transform_quadrature"): + rl = torch.einsum('...mlc, mlk->...mkc', x[..., 0], pct ) + im = torch.einsum('...mlc, mlk->...mkc', x[..., 1], pct ) + xs = torch.stack((rl, im), -1) + + with timer.context("inverse_transform_irfft"): + # apply the inverse (real) FFT + x = torch.view_as_complex(xs) + x = torch.fft.irfft(x, n=self.nlon, dim=-3, norm="forward") + + return x diff --git a/fme/core/models/conditional_sfno/test_s2convolutions.py b/fme/core/models/conditional_sfno/test_s2convolutions.py index be102d628..cc6fcd1d6 100644 --- a/fme/core/models/conditional_sfno/test_s2convolutions.py +++ b/fme/core/models/conditional_sfno/test_s2convolutions.py @@ -1,8 +1,93 @@ +import dataclasses + +import pytest import torch +from fme.core.device import get_device from fme.core.gridded_ops import LatLonOperations from fme.core.models.conditional_sfno.s2convolutions import SpectralConvS2 +from .s2convolutions import _contract_dhconv + + +@dataclasses.dataclass +class BenchmarkResult: + ms_total: float + ms_per: float + max_alloc: int + max_reserved: int + y_shape: tuple + y_dtype: torch.dtype + + +def benchmark(fn, iters=10, warmup=1) -> BenchmarkResult: + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + torch.cuda.reset_peak_memory_stats() + starter = torch.cuda.Event(enable_timing=True) + ender = torch.cuda.Event(enable_timing=True) + + starter.record() + for _ in range(iters): + y = fn() + ender.record() + torch.cuda.synchronize() + + ms = starter.elapsed_time(ender) + return BenchmarkResult( + ms_total=ms, + ms_per=ms / iters, + max_alloc=torch.cuda.max_memory_allocated(), + max_reserved=torch.cuda.max_memory_reserved(), + y_shape=tuple(y.shape), + y_dtype=y.dtype, + ) + + +@pytest.mark.skipif( + get_device().type != "cuda", + 
reason=( + "This test is only relevant for CUDA since " + "it's testing speed of DHConv groups on GPU." + ), +) # noqa: E501 +def test_contract_dhconv_groups_are_faster(): + B = 2 + C = 512 + H = 180 + L = 360 + G = 8 + x = torch.randn(B, 1, C, H, L, dtype=torch.complex64, device=get_device()) + w = torch.randn(1, C, C, H, 2, dtype=torch.float32, device=get_device()) + + def contract_ungrouped(): + return _contract_dhconv(x, w) + + ungrouped_result = benchmark(contract_ungrouped) + + x_grouped = x.reshape(B, G, C // G, H, L) + w_grouped = torch.randn( + G, C // G, C // G, H, 2, dtype=torch.float32, device=get_device() + ) + + def contract_grouped(): + return _contract_dhconv(x_grouped, w_grouped) + + grouped_result = benchmark(contract_grouped) + + assert grouped_result.ms_per < 2 / G * ungrouped_result.ms_per, ( + "Expected grouped DHConv to be faster than ungrouped, but got " + f"{grouped_result.ms_per:.6f} seconds for grouped and " + f"{ungrouped_result.ms_per:.6f} seconds for ungrouped." + ) + assert grouped_result.max_alloc < ungrouped_result.max_alloc, ( + "Expected grouped DHConv to use less memory than ungrouped, but got " + f"{grouped_result.max_alloc/1024/1024:.2f} MB for grouped and " + f"{ungrouped_result.max_alloc/1024/1024:.2f} MB for ungrouped." + ) + def test_spectral_conv_s2_lora(): in_channels = 8 diff --git a/fme/core/models/conditional_sfno/test_sfnonet.py b/fme/core/models/conditional_sfno/test_sfnonet.py index 3230d7c87..6a8881e37 100644 --- a/fme/core/models/conditional_sfno/test_sfnonet.py +++ b/fme/core/models/conditional_sfno/test_sfnonet.py @@ -1,3 +1,4 @@ +import dataclasses import os from types import SimpleNamespace @@ -9,7 +10,9 @@ from fme.core.testing.regression import validate_tensor from .layers import Context, ContextConfig -from .sfnonet import get_lat_lon_sfnonet +from .sfnonet import FourierNeuralOperatorBlock, get_lat_lon_sfnonet +from .sht import InverseRealSHT, RealSHT +from .timer import CUDATimer DIR = os.path.abspath(os.path.dirname(__file__)) @@ -221,3 +224,136 @@ def forward(self, x): assert not torch.isnan(output).any() else: assert torch.isnan(output).any() + + +@dataclasses.dataclass +class BenchmarkResult: + ms_total: float + ms_per: float + max_alloc: int + max_reserved: int + y_shape: tuple + y_dtype: torch.dtype + + +def benchmark(fn, iters=10, warmup=1) -> BenchmarkResult: + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + torch.cuda.reset_peak_memory_stats() + starter = torch.cuda.Event(enable_timing=True) + ender = torch.cuda.Event(enable_timing=True) + + starter.record() + for _ in range(iters): + y = fn() + ender.record() + torch.cuda.synchronize() + + ms = starter.elapsed_time(ender) + return BenchmarkResult( + ms_total=ms, + ms_per=ms / iters, + max_alloc=torch.cuda.max_memory_allocated(), + max_reserved=torch.cuda.max_memory_reserved(), + y_shape=tuple(y.shape), + y_dtype=y.dtype, + ) + + +@pytest.mark.skipif( + get_device().type != "cuda", + reason=( + "This test is only relevant for CUDA since " + "it's testing speed of SFNO blocks on GPU." 
+ ), +) # noqa: E501 +def test_block_speed(): + B = 2 + C = 512 + H = 180 + L = 360 + G = 8 + device = get_device() + conditional_embed_dim_scalar = 0 + conditional_embed_dim_noise = 64 + conditional_embed_dim_labels = 3 + conditional_embed_dim_pos = 32 + embedding_scalar = None + context_embedding_noise = torch.randn(B, L, H, conditional_embed_dim_noise).to( + device + ) + context_embedding_labels = torch.randn(B, conditional_embed_dim_labels).to(device) + context_embedding_pos = torch.randn(B, L, H, conditional_embed_dim_pos).to(device) + context = Context( + embedding_scalar=embedding_scalar, + embedding_pos=context_embedding_pos, + noise=context_embedding_noise, + labels=context_embedding_labels, + ) + x = torch.randn(B, L, H, C, device=get_device()) + forward = RealSHT(nlat=H, nlon=L) + inverse = InverseRealSHT(nlat=H, nlon=L) + context_config = ContextConfig( + embed_dim_scalar=conditional_embed_dim_scalar, + embed_dim_noise=conditional_embed_dim_noise, + embed_dim_labels=conditional_embed_dim_labels, + embed_dim_pos=conditional_embed_dim_pos, + ) + block = FourierNeuralOperatorBlock( + forward_transform=forward, + inverse_transform=inverse, + embed_dim=C, + img_shape=(H, L), + filter_type="linear", + operator_type="dhconv", + use_mlp=True, + context_config=context_config, + ).to(device) + timer = CUDATimer() + grouped_block = FourierNeuralOperatorBlock( + forward_transform=forward, + inverse_transform=inverse, + embed_dim=C, + img_shape=(H, L), + filter_type="linear", + operator_type="dhconv", + use_mlp=True, + context_config=context_config, + filter_num_groups=G, + ).to(device) + grouped_timer = CUDATimer() + + def call_block(): + return block(x, context, timer=timer) + + def call_grouped_block(): + return grouped_block(x, context, timer=grouped_timer) + + for _ in range(10): + block(x, context) + ungrouped = benchmark(call_block, warmup=0, iters=10) + for _ in range(10): + grouped_block(x, context) + grouped = benchmark(call_grouped_block, warmup=0, iters=10) + + print( + "ungrouped timers: " + + " | ".join(f"{k}: {v:.2f} ms" for k, v in timer.report().items()) + ) + print( + "grouped timers: " + + " | ".join(f"{k}: {v:.2f} ms" for k, v in grouped_timer.report().items()) + ) + + assert grouped.ms_per < 2 / G * ungrouped.ms_per, ( + "Expected grouped DHConv to be faster than ungrouped, but got " + f"{grouped.ms_per:.6f} ms for grouped and " + f"{ungrouped.ms_per:.6f} ms for ungrouped." + ) + assert grouped.max_alloc < ungrouped.max_alloc, ( + "Expected grouped DHConv to use less memory than ungrouped, but got " + f"{grouped.max_alloc/1024/1024:.2f} MB for grouped and " + f"{ungrouped.max_alloc/1024/1024:.2f} MB for ungrouped." 
+ ) diff --git a/fme/core/models/conditional_sfno/timer.py b/fme/core/models/conditional_sfno/timer.py new file mode 100644 index 000000000..43b01d5b2 --- /dev/null +++ b/fme/core/models/conditional_sfno/timer.py @@ -0,0 +1,41 @@ +import collections +import contextlib + +import torch + + +class CUDATimer: + def __init__(self): + self._starters = [] + self._enders = [] + self._names = [] + + @contextlib.contextmanager + def context(self, name: str): + starter = torch.cuda.Event(enable_timing=True) + ender = torch.cuda.Event(enable_timing=True) + self._starters.append(starter) + self._enders.append(ender) + self._names.append(name) + stream = torch.cuda.current_stream() + starter.record(stream) + try: + yield + finally: + ender.record(stream) + return + + def report(self): + torch.cuda.synchronize() + total_times: dict[str, float] = collections.defaultdict(float) + counts: dict[str, int] = collections.defaultdict(int) + for starter, ender, name in zip(self._starters, self._enders, self._names): + total_times[name] += starter.elapsed_time(ender) + counts[name] += 1 + avg_times = {name: total / counts[name] for name, total in total_times.items()} + return avg_times + + +class NullTimer: + def context(self, name: str) -> contextlib.nullcontext: + return contextlib.nullcontext()
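
A usage note on the timer utilities added above: `CUDATimer.context` records a pair of CUDA events per named section, and `report()` synchronizes and returns the average elapsed milliseconds per occurrence, while `NullTimer.context` is a no-op so instrumented code can run untimed. A minimal sketch (requires a CUDA device; the section name and sizes are illustrative, not from this PR):

import torch

from fme.core.models.conditional_sfno.timer import CUDATimer, NullTimer


def run(timer: CUDATimer | NullTimer) -> torch.Tensor:
    # each context() call records a start/stop CUDA event pair under its name
    with timer.context("matmul"):
        a = torch.randn(1024, 1024, device="cuda")
        b = a @ a
    return b


timer = CUDATimer()
for _ in range(3):
    run(timer)
print(timer.report())  # e.g. {"matmul": <average milliseconds per call>}

run(NullTimer())  # same code path, nothing recorded

Passing `timer=None` to the instrumented forwards has the same effect, since they fall back to a `NullTimer` internally.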
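Several hunks above (the conditioning projections in `ConditionalLayerNorm`, the `LoRAConv2d` stand-in in `lora.py`) replace 1x1 `nn.Conv2d` layers with `nn.Linear` applied to channels-last activations. A minimal sketch, with made-up sizes, of the equivalence this relies on, plus the (B, 1, 1, C) broadcasting used for the per-sample scale and bias:

import torch
import torch.nn as nn

conv = nn.Conv2d(8, 16, kernel_size=1, bias=False)
linear = nn.Linear(8, 16, bias=False)
with torch.no_grad():
    # Conv2d weight is (out, in, 1, 1); Linear weight is (out, in)
    linear.weight.copy_(conv.weight[:, :, 0, 0])

x = torch.randn(2, 8, 12, 24)              # (B, C, H, W), channels-first
y_conv = conv(x)                           # (B, 16, H, W)
y_lin = linear(x.permute(0, 3, 2, 1))      # (B, W, H, C) -> (B, W, H, 16)
torch.testing.assert_close(y_conv, y_lin.permute(0, 3, 2, 1))

# per-sample conditioning vectors broadcast over the new (B, W, H, C) layout via
# unsqueeze(-2).unsqueeze(-2), i.e. (B, C) -> (B, 1, 1, C)
scale = torch.randn(2, 16).unsqueeze(-2).unsqueeze(-2)
assert scale.shape == (2, 1, 1, 16)
assert (y_lin * scale).shape == (2, 24, 12, 16)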
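In the non-global branch, `ConditionalLayerNorm` now uses a plain `nn.LayerNorm((n_channels,))`. On a channels-last tensor this normalizes each spatial position over its channel vector, which is what the block relies on after the layout change; the global branch likewise reorders `normalized_shape` to the new trailing (width, height, channels) dims. A small check of the per-position behaviour, with illustrative sizes:

import torch
import torch.nn as nn

B, W, H, C = 2, 24, 12, 16
x = torch.randn(B, W, H, C)
norm = nn.LayerNorm((C,), eps=1e-6, elementwise_affine=False)

# LayerNorm over the trailing channel dim normalizes each (b, w, h) position
# independently across its C channels
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + 1e-6)
torch.testing.assert_close(norm(x), manual)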
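The grouped `_contract_dhconv` einsum ("byxgi,gxoi->byxgo") only mixes channels within a group, which is the same as an ungrouped contraction with a block-diagonal weight and cuts the per-mode work from C^2 to C^2 / G multiply-adds; that scaling is what the `2 / G * ungrouped.ms_per` thresholds in the new benchmarks test against. A sketch with small, made-up sizes (the module itself stores weights with a trailing real/imaginary dim and applies `torch.view_as_complex` before this contraction):

import torch

B, W, H, G, Cg = 2, 16, 9, 4, 6            # Cg = channels per group, C = G * Cg
x = torch.randn(B, W, H, G, Cg, dtype=torch.complex64)
w = torch.randn(G, H, Cg, Cg, dtype=torch.complex64)   # (group, lat mode, out/g, in/g)

grouped = torch.einsum("byxgi,gxoi->byxgo", x, w)

# same result from an ungrouped contraction with a block-diagonal weight
w_full = torch.zeros(H, G * Cg, G * Cg, dtype=torch.complex64)
for g in range(G):
    w_full[:, g * Cg:(g + 1) * Cg, g * Cg:(g + 1) * Cg] = w[g]
ungrouped = torch.einsum("byxi,xoi->byxo", x.reshape(B, W, H, G * Cg), w_full)

torch.testing.assert_close(grouped.reshape(B, W, H, G * Cg), ungrouped)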
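The forked `RealSHT` / `InverseRealSHT` transform dims -3 and -2 of a channels-last tensor, i.e. they expect (..., nlon, nlat, channels) instead of the (..., nlat, nlon) layout of upstream torch-harmonics, and both forwards accept the `timer` keyword used elsewhere in this PR. A shape-only sketch (grid choice and sizes are illustrative assumptions):

import torch

from fme.core.models.conditional_sfno.sht import InverseRealSHT, RealSHT

nlat, nlon, channels = 32, 64, 8
sht = RealSHT(nlat=nlat, nlon=nlon, grid="equiangular")
isht = InverseRealSHT(nlat=nlat, nlon=nlon, grid="equiangular")

x = torch.randn(2, nlon, nlat, channels)   # (B, nlon, nlat, C), channels-last
coeffs = sht(x)                            # complex spectral coefficients
y = isht(coeffs)                           # back to the (B, nlon, nlat, C) grid

assert coeffs.shape == (2, sht.mmax, sht.lmax, channels)   # (2, 33, 32, 8) here
assert y.shape == x.shape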