Draft

53 commits
56cc1cc
refactor MLP to LoRAMLP
mcgibbon Jan 13, 2026
5204ae1
add lora to Conv2d in conditional sfno
mcgibbon Jan 13, 2026
0a6ab69
add lora for spectral convolutions, add to csfno config
mcgibbon Jan 14, 2026
044168a
avoid crash on super init
mcgibbon Jan 15, 2026
3dd7aa9
fix change to model output
mcgibbon Jan 15, 2026
f6d62d4
Merge branch 'main' into feature/sfno_lora
mcgibbon Jan 15, 2026
d35f5f0
Merge branch 'main' into feature/sfno_lora
mcgibbon Jan 16, 2026
ca1c39a
Merge branch 'main' into feature/sfno_lora
mcgibbon Jan 22, 2026
e203eb9
Merge branch 'feature/sfno_lora' into feature/grouped_spectral_conv
mcgibbon Jan 22, 2026
7f6cc52
update regression target to use dhconv
mcgibbon Jan 22, 2026
71aac8b
use dhconv directly, disable other options
mcgibbon Jan 22, 2026
c09fd5a
delete unused code
mcgibbon Jan 22, 2026
2f7ca0f
enable grouped convolutions for linear filter type
mcgibbon Jan 22, 2026
bc6b789
restore MLP checkpointing
mcgibbon Jan 23, 2026
4ee5be4
Merge branch 'main' into feature/sfno_lora
mcgibbon Jan 26, 2026
5159ed8
Merge branch 'main' into feature/sfno_lora
mcgibbon Jan 29, 2026
3529e1e
Merge branch 'feature/sfno_lora' into feature/grouped_spectral_conv
mcgibbon Jan 29, 2026
9f21b52
Merge branch 'feature/grouped_spectral_conv' into feature/grouped_spe…
mcgibbon Jan 29, 2026
08323ba
enforce not implemented features at config level
mcgibbon Jan 29, 2026
c3f65fe
update sfno init to use updated makani scheme
mcgibbon Jan 29, 2026
29cd0ff
use correctly shaped scale
mcgibbon Jan 29, 2026
88c2e13
Merge branch 'main' into feature/grouped_spectral_conv
mcgibbon Feb 3, 2026
f460da7
Merge branch 'feature/grouped_spectral_conv' into feature/grouped_spe…
mcgibbon Feb 3, 2026
16813fc
default to linear filter type, disallow non-linear
mcgibbon Feb 3, 2026
30f3afb
Merge branch 'main' into feature/grouped_spectral_conv
mcgibbon Feb 3, 2026
15abf2c
allow makani-linear filter
mcgibbon Feb 3, 2026
5bef726
Merge branch 'feature/grouped_spectral_conv' of github.com:ai2cm/ace …
mcgibbon Feb 3, 2026
39dec0d
update sfnonet regression target to match primary code path
mcgibbon Feb 3, 2026
30d9a8a
Merge branch 'feature/update_reference' into feature/grouped_spectral…
mcgibbon Feb 3, 2026
2787426
Merge branch 'main' into feature/grouped_spectral_conv
mcgibbon Feb 3, 2026
f615a26
update diffusion regression test to latest settings
mcgibbon Feb 3, 2026
1147499
Merge branch 'feature/grouped_spectral_conv' of github.com:ai2cm/ace …
mcgibbon Feb 3, 2026
5acc370
Merge branch 'feature/grouped_spectral_conv' into feature/grouped_spe…
mcgibbon Feb 3, 2026
54fecc6
incorporate review comments
mcgibbon Feb 3, 2026
577dd2d
Merge branch 'main' into feature/grouped_spectral_conv
mcgibbon Feb 3, 2026
b93ab14
Merge branch 'main' into feature/grouped_spectral_conv
mcgibbon Feb 4, 2026
788aac0
Merge branch 'feature/grouped_spectral_conv' into feature/grouped_spe…
mcgibbon Feb 4, 2026
7acb412
Merge branch 'main' into feature/makani_sfno_init
mcgibbon Feb 4, 2026
f43e8af
remove overwrite of conv2d weights
mcgibbon Feb 4, 2026
cbbc0b6
use varname makani is using
mcgibbon Feb 4, 2026
965765e
update regression target
mcgibbon Feb 4, 2026
8c18465
update diffusion regression targets
mcgibbon Feb 4, 2026
aef7ce3
Merge branch 'main' into feature/grouped_spectral_conv_2
mcgibbon Feb 4, 2026
e0fc1b4
remove second copy of _contract_dhconv
mcgibbon Feb 4, 2026
4db57d3
Merge branch 'feature/makani_sfno_init' into feature/grouped_spectral…
mcgibbon Feb 6, 2026
aa03c6d
add unit test that dhconv is faster when using groups
mcgibbon Feb 6, 2026
fbe3f9d
Merge branch 'main' into feature/grouped_spectral_conv_2
mcgibbon Feb 6, 2026
460198f
move test to correct file
mcgibbon Feb 6, 2026
0bc701a
add test with profiling for sfno
mcgibbon Feb 6, 2026
e5c92ad
switch to [w, h, c] ordering for 10pct speedup
mcgibbon Feb 6, 2026
7cd13ec
clean up commented code, give avg times
mcgibbon Feb 6, 2026
1d3f385
make timings make sense with contiguous casts
mcgibbon Feb 6, 2026
6d8f983
save a bit more with contiguous, clearer times
mcgibbon Feb 6, 2026
134 changes: 75 additions & 59 deletions fme/core/models/conditional_sfno/layers.py
@@ -25,6 +25,7 @@
 from torch.utils.checkpoint import checkpoint
 
 from fme.core.models.conditional_sfno.lora import LoRAConv2d
+from fme.core.models.conditional_sfno.timer import CUDATimer, NullTimer
 
 from .activations import ComplexReLU
 from .contractions import compl_mul2d_fwd, compl_muladd2d_fwd
@@ -163,35 +164,33 @@ def __init__(
         self.W_bias_labels = None
         if self.embed_dim_noise > 0:
             # no bias as it is already handled in the non-2d layers
-            self.W_scale_2d = nn.Conv2d(
-                self.embed_dim_noise, self.n_channels, kernel_size=1, bias=False
+            self.W_scale_2d = nn.Linear(
+                self.embed_dim_noise, self.n_channels, bias=False
             )
-            self.W_bias_2d = nn.Conv2d(
-                self.embed_dim_noise, self.n_channels, kernel_size=1, bias=False
+            self.W_bias_2d = nn.Linear(
+                self.embed_dim_noise, self.n_channels, bias=False
             )
         else:
             self.W_scale_2d = None
             self.W_bias_2d = None
         if self.embed_dim_pos > 0:
             # no bias as it is already handled in the non-2d layers
-            self.W_scale_pos = nn.Conv2d(
-                self.embed_dim_pos, self.n_channels, kernel_size=1, bias=False
-            )
-            self.W_bias_pos = nn.Conv2d(
-                self.embed_dim_pos, self.n_channels, kernel_size=1, bias=False
+            self.W_scale_pos = nn.Linear(
+                self.embed_dim_pos, self.n_channels, bias=False
             )
+            self.W_bias_pos = nn.Linear(self.embed_dim_pos, self.n_channels, bias=False)
         else:
             self.W_scale_pos = None
             self.W_bias_pos = None
         if global_layer_norm:
             self.norm = nn.LayerNorm(
-                (self.n_channels, img_shape[0], img_shape[1]),
+                (img_shape[1], img_shape[0], self.n_channels),
                 eps=epsilon,
                 elementwise_affine=elementwise_affine,
             )
         else:
-            self.norm = ChannelLayerNorm(
-                self.n_channels,
+            self.norm = nn.LayerNorm(
+                (self.n_channels,),
                 eps=epsilon,
                 elementwise_affine=elementwise_affine,
             )
Expand Down Expand Up @@ -223,7 +222,12 @@ def reset_parameters(self):
torch.nn.init.constant_(self.W_bias_pos.weight, 0.0)
# no bias on 2d layers as it is already handled in the non-2d layers

def forward(self, x: torch.Tensor, context: Context) -> torch.Tensor:
def forward(
self,
x: torch.Tensor,
context: Context,
timer: CUDATimer | NullTimer | None = None,
) -> torch.Tensor:
"""
Conditional Layer Normalization

@@ -232,62 +236,74 @@ def forward(self, x: torch.Tensor, context: Context) -> torch.Tensor:
 
         Args:
             x: The input tensor to normalize, of shape
-                (batch_size, channels, height, width).
+                (batch_size, width, height, channels).
             context: The context to condition on.
 
         Returns:
-            The normalized tensor, of shape (batch_size, channels, height, width).
+            The normalized tensor, of shape (batch_size, width, height, channels).
         """
+        if timer is None:
+            timer = NullTimer()
         if context.labels is None and (
             self.W_scale_labels is not None or self.W_bias_labels is not None
         ):
             raise ValueError("labels must be provided")
-        if self.W_scale is not None:
-            if context.embedding_scalar is None:
-                raise ValueError("embedding_scalar must be provided")
-            scale: torch.Tensor = (
-                self.W_scale(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1)
-            )
-        else:
-            scale = torch.ones(
-                list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype
-            )
+        with timer.context("layer_norm_compute_scaling_and_bias"):
+            if self.W_scale is not None:
+                if context.embedding_scalar is None:
+                    raise ValueError("embedding_scalar must be provided")
+                scale: torch.Tensor = (
+                    self.W_scale(context.embedding_scalar).unsqueeze(-2).unsqueeze(-2)
+                )
+            else:
+                scale = torch.ones(
+                    list(x.shape[:-3]) + [1, 1, x.shape[-1]],
+                    device=x.device,
+                    dtype=x.dtype,
+                )
 
-        if self.W_scale_2d is not None:
-            if context.noise is None:
-                raise ValueError("embedding_2d must be provided")
-            scale = scale + self.W_scale_2d(context.noise)
-        if self.W_bias is not None:
-            if context.embedding_scalar is None:
-                raise ValueError("embedding_scalar must be provided")
-            bias: torch.Tensor = (
-                self.W_bias(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1)
-            )
-        else:
-            bias = torch.zeros(
-                list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype
-            )
+            if self.W_scale_2d is not None:
+                if context.noise is None:
+                    raise ValueError("embedding_2d must be provided")
+                scale = scale + self.W_scale_2d(context.noise)
+            if self.W_bias is not None:
+                if context.embedding_scalar is None:
+                    raise ValueError("embedding_scalar must be provided")
+                bias: torch.Tensor = (
+                    self.W_bias(context.embedding_scalar).unsqueeze(-2).unsqueeze(-2)
+                )
+            else:
+                bias = torch.zeros(
+                    list(x.shape[:-3]) + [1, 1, x.shape[-1]],
+                    device=x.device,
+                    dtype=x.dtype,
+                )
 
-        if self.W_scale_labels is not None:
-            scale = scale + self.W_scale_labels(context.labels).unsqueeze(-1).unsqueeze(
-                -1
-            )
-        if self.W_bias_labels is not None:
-            bias = bias + self.W_bias_labels(context.labels).unsqueeze(-1).unsqueeze(-1)
-        if self.W_bias_2d is not None:
-            if context.noise is None:
-                raise ValueError("embedding_2d must be provided")
-            bias = bias + self.W_bias_2d(context.noise)
-        if self.W_scale_pos is not None:
-            if context.embedding_pos is None:
-                raise ValueError("embedding_pos must be provided")
-            scale = scale + self.W_scale_pos(context.embedding_pos)
-        if self.W_bias_pos is not None:
-            if context.embedding_pos is None:
-                raise ValueError("embedding_pos must be provided")
-            bias = bias + self.W_bias_pos(context.embedding_pos)
-        x_norm: torch.Tensor = self.norm(x)
-        return x_norm * scale + bias
+            if self.W_scale_labels is not None:
+                scale = scale + self.W_scale_labels(context.labels).unsqueeze(
+                    -2
+                ).unsqueeze(-2)
+            if self.W_bias_labels is not None:
+                bias = bias + self.W_bias_labels(context.labels).unsqueeze(
+                    -2
+                ).unsqueeze(-2)
+            if self.W_bias_2d is not None:
+                if context.noise is None:
+                    raise ValueError("embedding_2d must be provided")
+                bias = bias + self.W_bias_2d(context.noise)
+            if self.W_scale_pos is not None:
+                if context.embedding_pos is None:
+                    raise ValueError("embedding_pos must be provided")
+                scale = scale + self.W_scale_pos(context.embedding_pos)
+            if self.W_bias_pos is not None:
+                if context.embedding_pos is None:
+                    raise ValueError("embedding_pos must be provided")
+                bias = bias + self.W_bias_pos(context.embedding_pos)
+        with timer.context("layer_norm_normalize"):
+            x_norm: torch.Tensor = self.norm(x)
+        with timer.context("layer_norm_apply_scaling_and_bias"):
+            return_value = x_norm * scale + bias
+            return return_value
 
 
 @torch.jit.script
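The layout change in this file moves activations from (batch_size, channels, height, width) to (batch_size, width, height, channels), which is why the 1x1 Conv2d conditioning layers become nn.Linear and the embedding outputs are unsqueezed at dim -2 instead of -1. A standalone sketch of the broadcasting, with illustrative shapes and layer names rather than the module's actual API:

import torch
import torch.nn as nn

# Illustrative shapes only; in the module these come from the Context object.
batch, width, height, channels, embed_dim = 2, 8, 16, 4, 3

x = torch.randn(batch, width, height, channels)   # channels-last layout
embedding_scalar = torch.randn(batch, embed_dim)  # one embedding vector per sample

w_scale = nn.Linear(embed_dim, channels, bias=False)
w_bias = nn.Linear(embed_dim, channels, bias=False)

# (batch, channels) -> (batch, 1, 1, channels): broadcasts over width and height,
# matching the unsqueeze(-2).unsqueeze(-2) used in the channels-last code path.
scale = w_scale(embedding_scalar).unsqueeze(-2).unsqueeze(-2)
bias = w_bias(embedding_scalar).unsqueeze(-2).unsqueeze(-2)

# LayerNorm over the trailing channel dimension, as in the non-global branch.
norm = nn.LayerNorm((channels,), elementwise_affine=False)
out = norm(x) * scale + bias
assert out.shape == (batch, width, height, channels)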