Merged
102 changes: 57 additions & 45 deletions fme/core/models/conditional_sfno/layers.py
@@ -24,6 +24,7 @@
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from fme.core.benchmark.timer import Timer, NullTimer
from fme.core.models.conditional_sfno.lora import LoRAConv2d

from .activations import ComplexReLU
@@ -223,7 +224,12 @@ def reset_parameters(self):
torch.nn.init.constant_(self.W_bias_pos.weight, 0.0)
# no bias on 2d layers as it is already handled in the non-2d layers

def forward(self, x: torch.Tensor, context: Context) -> torch.Tensor:
def forward(
self,
x: torch.Tensor,
context: Context,
timer: Timer = NullTimer(),
) -> torch.Tensor:
"""
Conditional Layer Normalization

@@ -242,52 +248,58 @@ def forward(self, x: torch.Tensor, context: Context) -> torch.Tensor:
self.W_scale_labels is not None or self.W_bias_labels is not None
):
raise ValueError("labels must be provided")
if self.W_scale is not None:
if context.embedding_scalar is None:
raise ValueError("embedding_scalar must be provided")
scale: torch.Tensor = (
self.W_scale(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1)
)
else:
scale = torch.ones(
list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype
)
with timer.child("compute_scaling_and_bias"):
if self.W_scale is not None:
if context.embedding_scalar is None:
raise ValueError("embedding_scalar must be provided")
scale: torch.Tensor = (
self.W_scale(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1)
)
else:
scale = torch.ones(
list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype
)

if self.W_scale_2d is not None:
if context.noise is None:
raise ValueError("embedding_2d must be provided")
scale = scale + self.W_scale_2d(context.noise)
if self.W_bias is not None:
if context.embedding_scalar is None:
raise ValueError("embedding_scalar must be provided")
bias: torch.Tensor = (
self.W_bias(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1)
)
else:
bias = torch.zeros(
list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype
)
if self.W_scale_2d is not None:
if context.noise is None:
raise ValueError("embedding_2d must be provided")
scale = scale + self.W_scale_2d(context.noise)
if self.W_bias is not None:
if context.embedding_scalar is None:
raise ValueError("embedding_scalar must be provided")
bias: torch.Tensor = (
self.W_bias(context.embedding_scalar).unsqueeze(-1).unsqueeze(-1)
)
else:
bias = torch.zeros(
list(x.shape[:-2]) + [1, 1], device=x.device, dtype=x.dtype
)

if self.W_scale_labels is not None:
scale = scale + self.W_scale_labels(context.labels).unsqueeze(-1).unsqueeze(
-1
)
if self.W_bias_labels is not None:
bias = bias + self.W_bias_labels(context.labels).unsqueeze(-1).unsqueeze(-1)
if self.W_bias_2d is not None:
if context.noise is None:
raise ValueError("embedding_2d must be provided")
bias = bias + self.W_bias_2d(context.noise)
if self.W_scale_pos is not None:
if context.embedding_pos is None:
raise ValueError("embedding_pos must be provided")
scale = scale + self.W_scale_pos(context.embedding_pos)
if self.W_bias_pos is not None:
if context.embedding_pos is None:
raise ValueError("embedding_pos must be provided")
bias = bias + self.W_bias_pos(context.embedding_pos)
x_norm: torch.Tensor = self.norm(x)
return x_norm * scale + bias
if self.W_scale_labels is not None:
scale = scale + self.W_scale_labels(context.labels).unsqueeze(
-1
).unsqueeze(-1)
if self.W_bias_labels is not None:
bias = bias + self.W_bias_labels(context.labels).unsqueeze(
-1
).unsqueeze(-1)
if self.W_bias_2d is not None:
if context.noise is None:
raise ValueError("embedding_2d must be provided")
bias = bias + self.W_bias_2d(context.noise)
if self.W_scale_pos is not None:
if context.embedding_pos is None:
raise ValueError("embedding_pos must be provided")
scale = scale + self.W_scale_pos(context.embedding_pos)
if self.W_bias_pos is not None:
if context.embedding_pos is None:
raise ValueError("embedding_pos must be provided")
bias = bias + self.W_bias_pos(context.embedding_pos)
with timer.child("normalize"):
x_norm: torch.Tensor = self.norm(x)
with timer.child("apply_scaling_and_bias"):
return_value = x_norm * scale + bias
return return_value


@torch.jit.script
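The new `timer: Timer = NullTimer()` arguments introduced above rely on a small interface from `fme.core.benchmark.timer`: `timer.child(name)` is used as a context manager that yields another timer, which can be passed further down the call stack, and `NullTimer` is a do-nothing default. That module is not part of this diff, so the sketch below is only an illustration of an interface consistent with these call sites, not the repository's actual implementation; a real GPU benchmark timer would likely also synchronize CUDA before reading the clock so asynchronous kernels are attributed to the right block.

```python
# Illustrative sketch only -- the real fme.core.benchmark.timer is not shown in this
# diff. Names and behavior are inferred from the call sites (timer.child("name")
# used as a context manager yielding a timer that can be passed to sub-modules).
import contextlib
import time
from typing import Dict, Iterator


class Timer:
    """Accumulates wall-clock durations for named, nested code blocks."""

    def __init__(self) -> None:
        self.durations: Dict[str, float] = {}
        self.children: Dict[str, "Timer"] = {}

    @contextlib.contextmanager
    def child(self, name: str) -> Iterator["Timer"]:
        # Create (or reuse) a nested timer and time the enclosed block.
        child_timer = self.children.setdefault(name, Timer())
        start = time.perf_counter()
        try:
            # Yield the child so call sites can thread it downward, e.g.
            # `with timer.child("norm0") as norm0_timer: self.norm0(..., timer=norm0_timer)`.
            yield child_timer
        finally:
            self.durations[name] = (
                self.durations.get(name, 0.0) + time.perf_counter() - start
            )


class NullTimer(Timer):
    """No-op default so instrumented forward() methods add negligible overhead."""

    @contextlib.contextmanager
    def child(self, name: str) -> Iterator["Timer"]:
        yield self  # nested child() calls still work; nothing is recorded
```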
@@ -19,6 +19,8 @@
import torch.nn as nn
from torch import amp

from fme.core.benchmark.timer import NullTimer, Timer

# import convenience functions for factorized tensors
from .factorizations import get_contract_fun

@@ -124,7 +126,7 @@ def __init__(
if bias:
self.bias = nn.Parameter(torch.zeros(1, self.out_channels, 1, 1))

def forward(self, x):
def forward(self, x, timer: Timer = NullTimer()):
dtype = x.dtype
residual = x
x = x.float()
@@ -138,7 +140,10 @@ def forward(self, x):
B, C, H, W = x.shape
x = x.reshape(B, self.num_groups, C // self.num_groups, H, W)
xp = self._contract(
x, self.weight, separable=self.separable, operator_type=self.operator_type
x,
self.weight,
separable=self.separable,
operator_type=self.operator_type,
)
x = xp.reshape(B, self.out_channels, H, W).contiguous()

54 changes: 31 additions & 23 deletions fme/core/models/conditional_sfno/s2convolutions.py
@@ -22,6 +22,8 @@
import torch_harmonics as th
import torch_harmonics.distributed as thd

from fme.core.benchmark.timer import NullTimer, Timer

# import convenience functions for factorized tensors
from .activations import ComplexReLU

@@ -223,45 +225,51 @@ def __init__(
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
self.out_channels = out_channels

def forward(self, x): # pragma: no cover
def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover
dtype = x.dtype
residual = x
x = x.float()

with torch.amp.autocast("cuda", enabled=False):
x = self.forward_transform(x.float())
with timer.child("forward_transform"):
x = self.forward_transform(x.float())
if self._round_trip_residual:
x = x.contiguous()
residual = self.inverse_transform(x)
residual = residual.to(dtype)
with timer.child("round_trip_residual"):
x = x.contiguous()
residual = self.inverse_transform(x)
residual = residual.to(dtype)

B, C, H, W = x.shape
assert C % self.num_groups == 0
x = x.reshape(B, self.num_groups, C // self.num_groups, H, W)

if self.lora_A is not None and self.lora_B is not None:
lora_update = _contract_lora(
self.lora_A,
self.lora_B,
x[..., : self.modes_lat_local, : self.modes_lon_local],
)
with timer.child("lora_update"):
lora_update = _contract_lora(
self.lora_A,
self.lora_B,
x[..., : self.modes_lat_local, : self.modes_lon_local],
)
else:
lora_update = 0.0

xp = torch.zeros_like(x)
xp[..., : self.modes_lat_local, : self.modes_lon_local] = _contract_dhconv(
x[..., : self.modes_lat_local, : self.modes_lon_local],
self.weight,
)
xp = xp + self.lora_scaling * lora_update
xp = xp.reshape(B, self.out_channels, H, W)
x = xp.contiguous()
with timer.child("dhconv"):
xp = torch.zeros_like(x)
xp[..., : self.modes_lat_local, : self.modes_lon_local] = _contract_dhconv(
x[..., : self.modes_lat_local, : self.modes_lon_local],
self.weight,
)
xp = xp + self.lora_scaling * lora_update
xp = xp.reshape(B, self.out_channels, H, W)
x = xp.contiguous()

with torch.amp.autocast("cuda", enabled=False):
x = self.inverse_transform(x)
with timer.child("inverse_transform"):
x = self.inverse_transform(x)

if hasattr(self, "bias"):
x = x + self.bias
with timer.child("add_bias"):
x = x + self.bias

x = x.type(dtype)

@@ -320,7 +328,7 @@ def __init__(
scale * torch.randn(1, out_channels, *self.output_dims)
)

def forward(self, x): # pragma: no cover
def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover
dtype = x.dtype
x = x.float()
B, C, H, W = x.shape
@@ -503,7 +511,7 @@ def forward_mlp(self, x): # pragma: no cover

return x

def forward(self, x): # pragma: no cover
def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover
dtype = x.dtype
residual = x
x = x.to(torch.float32)
@@ -626,7 +634,7 @@ def forward_mlp(self, x): # pragma: no cover

return x

def forward(self, x): # pragma: no cover
def forward(self, x, timer: Timer = NullTimer()): # pragma: no cover
dtype = x.dtype
x = x.to(torch.float32)

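For reference, the newly timed `lora_update` branch above adds a low-rank correction to the spectral convolution: `lora_A` projects the retained spherical-harmonic coefficients down to a small rank, `lora_B` projects back up, and the result is multiplied by `lora_scaling` (conventionally `alpha / rank`) before being added to the dense `_contract_dhconv` output. The repository's `_contract_lora` and the exact factor shapes are not shown in this diff, so the einsum sketch below uses assumed shapes purely to illustrate the pattern.

```python
# Assumed shapes for illustration; the real _contract_lora / lora_A / lora_B in this
# repository may differ. This only demonstrates the low-rank channel-mixing update
# applied pointwise over the retained (modes_lat, modes_lon) coefficients.
import torch

batch, groups, c_in, c_out = 2, 1, 8, 8
modes_lat, modes_lon, rank, alpha = 6, 4, 2, 4.0

x_modes = torch.randn(batch, groups, c_in, modes_lat, modes_lon, dtype=torch.cfloat)
lora_A = torch.randn(groups, rank, c_in, dtype=torch.cfloat)   # down-projection
lora_B = torch.randn(groups, c_out, rank, dtype=torch.cfloat)  # up-projection
lora_scaling = alpha / rank

low = torch.einsum("grc,bgchw->bgrhw", lora_A, x_modes)      # rank-r bottleneck
lora_update = torch.einsum("gor,bgrhw->bgohw", lora_B, low)  # back to c_out channels
xp = lora_scaling * lora_update  # added to the dense spectral convolution result
```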
74 changes: 43 additions & 31 deletions fme/core/models/conditional_sfno/sfnonet.py
@@ -24,6 +24,8 @@
import torch_harmonics as th
from torch.utils.checkpoint import checkpoint

from fme.core.benchmark.timer import Timer, NullTimer

from .initialization import trunc_normal_

# wrap fft, to unify interface to spectral transforms
@@ -62,7 +64,7 @@ def __init__(self, *args, **kwargs):
super().__init__()
self.conv = th.DiscreteContinuousConvS2(*args, **kwargs)

def forward(self, x):
def forward(self, x, timer: Timer = NullTimer()):
return self.conv(x), x


@@ -153,8 +155,8 @@ def __init__(
else:
raise (NotImplementedError)

def forward(self, x):
return self.filter(x)
def forward(self, x, timer: Timer = NullTimer()):
return self.filter(x, timer=timer)


class FourierNeuralOperatorBlock(nn.Module):
@@ -295,44 +297,54 @@ def __init__(
lora_alpha=lora_alpha,
)

def forward(self, x, context_embedding):
x_norm = torch.zeros_like(x)
x_norm[..., : self.input_shape_loc[0], : self.input_shape_loc[1]] = self.norm0(
x[..., : self.input_shape_loc[0], : self.input_shape_loc[1]],
context_embedding,
)
x, residual = self.filter(x_norm)

def forward(self, x, context_embedding, timer: Timer = NullTimer()):
with timer.child("norm0") as norm0_timer:
x_norm = torch.zeros_like(x)
x_norm[..., : self.input_shape_loc[0], : self.input_shape_loc[1]] = (
self.norm0(
x[..., : self.input_shape_loc[0], : self.input_shape_loc[1]],
context_embedding,
timer=norm0_timer,
)
)
with timer.child("filter") as filter_timer:
x, residual = self.filter(x_norm, timer=filter_timer)
if hasattr(self, "inner_skip"):
if self.concat_skip:
x = torch.cat((x, self.inner_skip(residual)), dim=1)
x = self.inner_skip_conv(x)
else:
x = x + self.inner_skip(residual)
with timer.child("inner_skip"):
if self.concat_skip:
x = torch.cat((x, self.inner_skip(residual)), dim=1)
x = self.inner_skip_conv(x)
else:
x = x + self.inner_skip(residual)

if hasattr(self, "act_layer"):
x = self.act_layer(x)

x_norm = torch.zeros_like(x)
x_norm[..., : self.output_shape_loc[0], : self.output_shape_loc[1]] = (
self.norm1(
x[..., : self.output_shape_loc[0], : self.output_shape_loc[1]],
context_embedding,
with timer.child("activation"):
x = self.act_layer(x)

with timer.child("norm1") as norm1_timer:
x_norm = torch.zeros_like(x)
x_norm[..., : self.output_shape_loc[0], : self.output_shape_loc[1]] = (
self.norm1(
x[..., : self.output_shape_loc[0], : self.output_shape_loc[1]],
context_embedding,
timer=norm1_timer,
)
)
)
x = x_norm
x = x_norm

if hasattr(self, "mlp"):
x = self.mlp(x)
with timer.child("mlp"):
x = self.mlp(x)

x = self.drop_path(x)

if hasattr(self, "outer_skip"):
if self.concat_skip:
x = torch.cat((x, self.outer_skip(residual)), dim=1)
x = self.outer_skip_conv(x)
else:
x = x + self.outer_skip(residual)
with timer.child("outer_skip"):
if self.concat_skip:
x = torch.cat((x, self.outer_skip(residual)), dim=1)
x = self.outer_skip_conv(x)
else:
x = x + self.outer_skip(residual)

return x

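A minimal example of how a caller might thread these timers through a forward pass, reusing the `Timer`/`NullTimer` sketch above; `ToyBlock` is a stand-in module invented for the example, not a class from this repository.

```python
# Hypothetical usage; ToyBlock stands in for the instrumented SFNO modules above.
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(16)
        self.mlp = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor, timer: Timer = NullTimer()) -> torch.Tensor:
        with timer.child("norm"):
            x = self.norm(x)
        with timer.child("mlp"):
            x = self.mlp(x)
        return x


block = ToyBlock()
timer = Timer()
with timer.child("block") as block_timer:
    _ = block(torch.randn(4, 16), timer=block_timer)

# Nested durations mirror the child() labels, e.g. {"norm": ..., "mlp": ...}.
print(timer.children["block"].durations)
```

Because `NullTimer` is the default argument, call sites that never pass a timer keep working unchanged, and profiling can be enabled selectively by supplying a real timer from the top of the stack.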