Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fme/ace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from fme.ace.registry.land_net import LandNetBuilder
from fme.ace.registry.m2lines import FloeNetBuilder, SamudraBuilder
from fme.ace.registry.sfno import SFNO_V0_1_0, SphericalFourierNeuralOperatorBuilder
from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNO
from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNOBuilder
from fme.ace.stepper import DerivedForcingsConfig, StepperOverrideConfig
from fme.ace.stepper.insolation.config import InsolationConfig, NameConfig, ValueConfig
from fme.ace.stepper.parameter_init import (
Expand Down
163 changes: 24 additions & 139 deletions fme/ace/registry/stochastic_sfno.py
Original file line number Diff line number Diff line change
@@ -1,127 +1,10 @@
import dataclasses
import math
from collections.abc import Callable
from typing import Literal

import torch

from fme.ace.registry.registry import ModuleConfig, ModuleSelector
from fme.core.dataset_info import DatasetInfo
from fme.core.models.conditional_sfno.sfnonet import (
Context,
ContextConfig,
get_lat_lon_sfnonet,
)
from fme.core.models.conditional_sfno.sfnonet import (
SphericalFourierNeuralOperatorNet as ConditionalSFNO,
)


def isotropic_noise(
    leading_shape: tuple[int, ...],
    lmax: int,  # length of the ℓ axis expected by isht
    mmax: int,  # length of the m axis expected by isht
    isht: Callable[[torch.Tensor], torch.Tensor],
    device: torch.device,
) -> torch.Tensor:
    """Sample rotation-invariant noise on the sphere with unit pointwise variance.

    Draws complex spherical-harmonic coefficients whose variances are chosen
    so the inverse transform ``isht`` yields a field with Var ≈ 1 everywhere,
    then returns ``isht`` applied to those coefficients.
    """
    coeff_shape = (*leading_shape, lmax, mmax)
    re = torch.randn(coeff_shape, dtype=torch.float32, device=device)
    im = torch.randn(coeff_shape, dtype=torch.float32, device=device)

    # m = 0 coefficients of a real field are purely real.
    im[..., :, 0] = 0.0

    # For m > 0, give Re and Im variance 1/2 each so |a_{ℓm}|² has variance 1.
    sqrt2 = math.sqrt(2.0)
    re[..., :, 1:] /= sqrt2
    im[..., :, 1:] /= sqrt2

    # Global normalization making Var[T(θ,φ)] = 1 (Unsöld theorem with L = lmax).
    scale = math.sqrt(4.0 * math.pi) / lmax
    return isht(scale * (re + 1j * im))


class NoiseConditionedSFNO(torch.nn.Module):
    """Wrap a conditional SFNO so each forward pass is conditioned on fresh noise.

    On every call a noise field is sampled — either isotropic on the sphere
    (via the wrapped model's inverse spherical transform) or i.i.d. Gaussian
    per grid point — and handed to the wrapped model through its ``Context``,
    together with optional learned positional / label embeddings.
    """

    def __init__(
        self,
        conditional_model: ConditionalSFNO,
        img_shape: tuple[int, int],
        noise_type: Literal["isotropic", "gaussian"] = "gaussian",
        embed_dim_noise: int = 256,
        embed_dim_pos: int = 0,
        embed_dim_labels: int = 0,
    ):
        """
        Args:
            conditional_model: SFNO that accepts a ``Context`` carrying noise.
            img_shape: (lat, lon) grid shape used to size positional embeddings.
            noise_type: "isotropic" draws spherical-harmonic noise through the
                model's ``itrans_up``; "gaussian" draws i.i.d. grid-point noise.
            embed_dim_noise: channel dimension of the sampled noise field.
            embed_dim_pos: channels of the learned positional embedding;
                0 disables it (and with it the label positional embedding).
            embed_dim_labels: number of labels; if > 0 (and embed_dim_pos != 0)
                a per-label positional embedding is learned as well.
        """
        super().__init__()
        self.conditional_model = conditional_model
        self.embed_dim = embed_dim_noise
        self.noise_type = noise_type
        self.label_pos_embed: torch.nn.Parameter | None = None
        # register pos embed if pos_embed_dim != 0
        if embed_dim_pos != 0:
            self.pos_embed = torch.nn.Parameter(
                torch.zeros(
                    1, embed_dim_pos, img_shape[0], img_shape[1], requires_grad=True
                )
            )
            # initialize pos embed with std=0.02
            torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
            if embed_dim_labels > 0:
                # one positional embedding per label, mixed by label weights
                # in forward via einsum
                self.label_pos_embed = torch.nn.Parameter(
                    torch.zeros(
                        embed_dim_labels,
                        embed_dim_pos,
                        img_shape[0],
                        img_shape[1],
                        requires_grad=True,
                    )
                )
                torch.nn.init.trunc_normal_(self.label_pos_embed, std=0.02)
        else:
            self.pos_embed = None

    def forward(
        self, x: torch.Tensor, labels: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run the wrapped model on ``x`` with freshly sampled noise conditioning.

        ``x`` is flattened to (batch, channels, lat, lon) over any leading
        dims. ``labels``, if given, appears to be per-sample label weights of
        shape (batch, n_labels) used to mix the per-label positional
        embeddings — TODO confirm against callers.
        """
        x = x.reshape(-1, *x.shape[-3:])
        if self.noise_type == "isotropic":
            # sample in spectral space sized to the model's inverse transform
            lmax = self.conditional_model.itrans_up.lmax
            mmax = self.conditional_model.itrans_up.mmax
            noise = isotropic_noise(
                (x.shape[0], self.embed_dim),
                lmax,
                mmax,
                self.conditional_model.itrans_up,
                device=x.device,
            )
        elif self.noise_type == "gaussian":
            noise = torch.randn(
                [x.shape[0], self.embed_dim, *x.shape[-2:]],
                device=x.device,
                dtype=x.dtype,
            )
        else:
            raise ValueError(f"Invalid noise type: {self.noise_type}")

        if self.pos_embed is not None:
            # broadcast the shared positional embedding across the batch
            embedding_pos = self.pos_embed.repeat(noise.shape[0], 1, 1, 1)
            if self.label_pos_embed is not None and labels is not None:
                # weighted sum of per-label embeddings: (b,l) x (l,p,x,y) -> (b,p,x,y)
                label_embedding_pos = torch.einsum(
                    "bl, lpxy -> bpxy", labels, self.label_pos_embed
                )
                embedding_pos = embedding_pos + label_embedding_pos
        else:
            embedding_pos = None

        return self.conditional_model(
            x,
            Context(
                embedding_scalar=None,
                embedding_pos=embedding_pos,
                labels=labels,
                noise=noise,
            ),
        )
from fme.core.models.conditional_sfno.v0.stochastic_sfno import build as build_v0
from fme.core.models.conditional_sfno.v1.stochastic_sfno import build as build_v1


# this is based on the call signature of SphericalFourierNeuralOperatorNet at
Expand All @@ -135,6 +18,7 @@ class NoiseConditionedSFNOBuilder(ModuleConfig):
Noise is provided as conditioning input to conditional layer normalization.

Attributes:
version: Version of the model.
spectral_transform: Type of spherical transform to use.
Kept for backwards compatibility.
filter_type: Type of filter to use.
Expand Down Expand Up @@ -186,6 +70,7 @@ class NoiseConditionedSFNOBuilder(ModuleConfig):
Defaults to spectral_lora_rank.
"""

version: Literal["v0", "v1", "latest"] = "v0"
spectral_transform: Literal["sht"] = "sht"
filter_type: Literal["linear", "makani-linear"] = "linear"
operator_type: Literal["dhconv"] = "dhconv"
Expand Down Expand Up @@ -236,30 +121,30 @@ def __post_init__(self):
"Only 'dhconv' operator_type is supported for "
"NoiseConditionedSFNO models."
)
if self.version == "latest":
# must replace as eventual newer versions break backwards compatibility
# v1 is not stable yet, keep using v0 as default for now
self.version = "v0"

def build(
    self,
    n_in_channels: int,
    n_out_channels: int,
    dataset_info: DatasetInfo,
):
    """Build the stochastic SFNO module for the configured ``version``.

    NOTE(review): the pasted diff interleaved the removed pre-refactor body
    with the added version-dispatch body; this is the post-change dispatch
    only. ``__post_init__`` resolves version "latest" to a concrete version
    before this runs.

    Args:
        n_in_channels: number of input channels.
        n_out_channels: number of output channels.
        dataset_info: dataset metadata (grid shape, labels) forwarded to the
            version-specific builder.

    Raises:
        ValueError: if ``self.version`` is not a supported version.
    """
    if self.version == "v0":
        return build_v0(
            self,
            n_in_channels=n_in_channels,
            n_out_channels=n_out_channels,
            dataset_info=dataset_info,
        )
    elif self.version == "v1":
        return build_v1(
            self,
            n_in_channels=n_in_channels,
            n_out_channels=n_out_channels,
            dataset_info=dataset_info,
        )
    else:
        raise ValueError(f"Unsupported version: {self.version}")
1 change: 1 addition & 0 deletions fme/core/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import conditional_sfno, mlp
1 change: 1 addition & 0 deletions fme/core/models/conditional_sfno/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import v0
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from fme.core.benchmark.timer import Timer, NullTimer
from fme.core.models.conditional_sfno.lora import LoRAConv2d
from .lora import LoRAConv2d

from .activations import ComplexReLU
from .contractions import compl_mul2d_fwd, compl_muladd2d_fwd
Expand Down
145 changes: 145 additions & 0 deletions fme/core/models/conditional_sfno/v0/stochastic_sfno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import math
from collections.abc import Callable
from typing import Literal

import torch

from fme.core.dataset_info import DatasetInfo

from .sfnonet import Context, ContextConfig, get_lat_lon_sfnonet
from .sfnonet import SphericalFourierNeuralOperatorNet as ConditionalSFNO


def isotropic_noise(
    leading_shape: tuple[int, ...],
    lmax: int,  # length of the ℓ axis expected by isht
    mmax: int,  # length of the m axis expected by isht
    isht: Callable[[torch.Tensor], torch.Tensor],
    device: torch.device,
) -> torch.Tensor:
    """Sample isotropic (rotation-invariant) spherical noise with unit variance.

    Generates complex spherical-harmonic coefficients whose per-mode variances
    are chosen so the inverse transform has pointwise variance ≈ 1, then maps
    them to grid space with ``isht``.
    """
    shape = (*leading_shape, lmax, mmax)
    re_part = torch.randn(shape, dtype=torch.float32, device=device)
    im_part = torch.randn(shape, dtype=torch.float32, device=device)

    # The m = 0 coefficients of a real-valued field are purely real.
    im_part[..., :, 0] = 0.0

    # m > 0: halve the variance of Re and Im so |a_{ℓm}|² has variance 1.
    sqrt2 = math.sqrt(2.0)
    re_part[..., :, 1:] /= sqrt2
    im_part[..., :, 1:] /= sqrt2

    # Global scale so that Var[T(θ,φ)] = 1 (Unsöld theorem ⇒ L = lmax).
    scale = math.sqrt(4.0 * math.pi) / lmax
    return isht(scale * (re_part + 1j * im_part))


class NoiseConditionedSFNO(torch.nn.Module):
    """Wrap a conditional SFNO so each forward pass is conditioned on fresh noise.

    On every call a noise field is sampled — either isotropic on the sphere
    (via the wrapped model's inverse spherical transform) or i.i.d. Gaussian
    per grid point — and handed to the wrapped model through its ``Context``,
    together with optional learned positional / label embeddings.
    """

    def __init__(
        self,
        conditional_model: ConditionalSFNO,
        img_shape: tuple[int, int],
        noise_type: Literal["isotropic", "gaussian"] = "gaussian",
        embed_dim_noise: int = 256,
        embed_dim_pos: int = 0,
        embed_dim_labels: int = 0,
    ):
        """
        Args:
            conditional_model: SFNO that accepts a ``Context`` carrying noise.
            img_shape: (lat, lon) grid shape used to size positional embeddings.
            noise_type: "isotropic" draws spherical-harmonic noise through the
                model's ``itrans_up``; "gaussian" draws i.i.d. grid-point noise.
            embed_dim_noise: channel dimension of the sampled noise field.
            embed_dim_pos: channels of the learned positional embedding;
                0 disables it (and with it the label positional embedding).
            embed_dim_labels: number of labels; if > 0 (and embed_dim_pos != 0)
                a per-label positional embedding is learned as well.
        """
        super().__init__()
        self.conditional_model = conditional_model
        self.embed_dim = embed_dim_noise
        self.noise_type = noise_type
        self.label_pos_embed: torch.nn.Parameter | None = None
        # register pos embed if pos_embed_dim != 0
        if embed_dim_pos != 0:
            self.pos_embed = torch.nn.Parameter(
                torch.zeros(
                    1, embed_dim_pos, img_shape[0], img_shape[1], requires_grad=True
                )
            )
            # initialize pos embed with std=0.02
            torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
            if embed_dim_labels > 0:
                # one positional embedding per label, mixed by label weights
                # in forward via einsum
                self.label_pos_embed = torch.nn.Parameter(
                    torch.zeros(
                        embed_dim_labels,
                        embed_dim_pos,
                        img_shape[0],
                        img_shape[1],
                        requires_grad=True,
                    )
                )
                torch.nn.init.trunc_normal_(self.label_pos_embed, std=0.02)
        else:
            self.pos_embed = None

    def forward(
        self, x: torch.Tensor, labels: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run the wrapped model on ``x`` with freshly sampled noise conditioning.

        ``x`` is flattened to (batch, channels, lat, lon) over any leading
        dims. ``labels``, if given, appears to be per-sample label weights of
        shape (batch, n_labels) used to mix the per-label positional
        embeddings — TODO confirm against callers.
        """
        x = x.reshape(-1, *x.shape[-3:])
        if self.noise_type == "isotropic":
            # sample in spectral space sized to the model's inverse transform
            lmax = self.conditional_model.itrans_up.lmax
            mmax = self.conditional_model.itrans_up.mmax
            noise = isotropic_noise(
                (x.shape[0], self.embed_dim),
                lmax,
                mmax,
                self.conditional_model.itrans_up,
                device=x.device,
            )
        elif self.noise_type == "gaussian":
            noise = torch.randn(
                [x.shape[0], self.embed_dim, *x.shape[-2:]],
                device=x.device,
                dtype=x.dtype,
            )
        else:
            raise ValueError(f"Invalid noise type: {self.noise_type}")

        if self.pos_embed is not None:
            # broadcast the shared positional embedding across the batch
            embedding_pos = self.pos_embed.repeat(noise.shape[0], 1, 1, 1)
            if self.label_pos_embed is not None and labels is not None:
                # weighted sum of per-label embeddings: (b,l) x (l,p,x,y) -> (b,p,x,y)
                label_embedding_pos = torch.einsum(
                    "bl, lpxy -> bpxy", labels, self.label_pos_embed
                )
                embedding_pos = embedding_pos + label_embedding_pos
        else:
            embedding_pos = None

        return self.conditional_model(
            x,
            Context(
                embedding_scalar=None,
                embedding_pos=embedding_pos,
                labels=labels,
                noise=noise,
            ),
        )


def build(
    params,
    n_in_channels: int,
    n_out_channels: int,
    dataset_info: DatasetInfo,
):
    """Construct a ``NoiseConditionedSFNO`` from builder params and dataset info.

    Creates the underlying conditional SFNO sized to the dataset grid and
    label set, then wraps it so forward passes receive sampled noise.
    """
    n_labels = len(dataset_info.all_labels)
    context_config = ContextConfig(
        embed_dim_scalar=0,
        embed_dim_pos=params.context_pos_embed_dim,
        embed_dim_noise=params.noise_embed_dim,
        embed_dim_labels=n_labels,
    )
    inner_net = get_lat_lon_sfnonet(
        params=params,
        in_chans=n_in_channels,
        out_chans=n_out_channels,
        img_shape=dataset_info.img_shape,
        context_config=context_config,
    )
    return NoiseConditionedSFNO(
        inner_net,
        noise_type=params.noise_type,
        embed_dim_noise=params.noise_embed_dim,
        embed_dim_pos=params.context_pos_embed_dim,
        embed_dim_labels=n_labels,
        img_shape=dataset_info.img_shape,
    )
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from torch import nn

from fme.core.models.conditional_sfno.lora import LoRAConv2d
from .lora import LoRAConv2d


def test_lora_conv2d_load_conv2d_checkpoint():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

from fme.core.device import get_device
from fme.core.gridded_ops import LatLonOperations
from fme.core.models.conditional_sfno.s2convolutions import SpectralConvS2

from .s2convolutions import _contract_dhconv
from .s2convolutions import SpectralConvS2, _contract_dhconv


@dataclasses.dataclass
Expand Down
Loading