16 changes: 10 additions & 6 deletions torch_harmonics/sht.py
@@ -121,6 +121,8 @@ def forward(self, x: torch.Tensor):
 
         # apply real fft in the longitudinal direction
         x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")
+        # [..., H, W] -> [..., W, H]
+        x = x.transpose(-1, -2).contiguous()
 
         # do the Legendre-Gauss quadrature
         x = torch.view_as_real(x)
@@ -132,8 +134,8 @@ def forward(self, x: torch.Tensor):
         xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
 
         # contraction
-        xout[..., 0] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 0], self.weights.to(x.dtype))
-        xout[..., 1] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 1], self.weights.to(x.dtype))
+        xout[..., 0] = torch.einsum("...mk,mlk->...lm", x[..., : self.mmax, :, 0], self.weights.to(x.dtype))
+        xout[..., 1] = torch.einsum("...mk,mlk->...lm", x[..., : self.mmax, :, 1], self.weights.to(x.dtype))
         x = torch.view_as_complex(xout)
 
         return x
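For readers following the shape bookkeeping in the forward-transform hunks above, here is a minimal standalone sketch of the intended data flow after this change. The sizes are hypothetical and `weights` stands in for the module's precomputed quadrature/Legendre tensor (assumed here to be indexed `[m, l, k]`, matching the einsum subscripts); it is an illustration, not code from the PR.

```python
import torch

# hypothetical sizes: nlat (H), nlon (W), retained modes mmax (M), output degrees lmax (L)
nlat, nlon, mmax, lmax = 8, 16, 5, 6
x = torch.randn(2, 3, nlat, nlon)              # [..., H, W] real field on the grid
weights = torch.randn(mmax, lmax, nlat)        # assumed quadrature * Legendre tensor, indexed [m, l, k]

# longitudinal FFT, then move the mode axis in front of latitude (as in the PR)
x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")   # [..., H, W//2+1] complex
x = x.transpose(-1, -2).contiguous()                             # [..., W//2+1, H]
x = torch.view_as_real(x)                                        # [..., W//2+1, H, 2]

# Legendre-Gauss quadrature: contract the latitude axis k for each retained mode m
xout = torch.zeros(x.shape[:-3] + (lmax, mmax, 2), dtype=x.dtype, device=x.device)
xout[..., 0] = torch.einsum("...mk,mlk->...lm", x[..., :mmax, :, 0], weights)
xout[..., 1] = torch.einsum("...mk,mlk->...lm", x[..., :mmax, :, 1], weights)
x = torch.view_as_complex(xout)                                  # [..., L, M] spectral coefficients
print(x.shape)                                                   # torch.Size([2, 3, 6, 5])
```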
@@ -219,18 +221,20 @@ def forward(self, x: torch.Tensor):
         assert x.shape[-2] == self.lmax
         assert x.shape[-1] == self.mmax
 
-        # Evaluate associated Legendre functions on the output nodes
-        x = torch.view_as_real(x)
 
         # prepare output
         out_shape = list(x.size())
         out_shape[-3] = self.nlat
         out_shape[-2] = self.mmax
         xs = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
 
+        # Evaluate associated Legendre functions on the output nodes
+        x = torch.view_as_real(x)
+        x = x.transpose(-1, -2).contiguous()
+
Collaborator:
I am wondering if this is correct. If we first do `view_as_real`, the last dim should be re/im. This means that this transpose does the following:

Input: B, C, H, M, 2 -> B, C, H, 2, M

And then in the contraction below, we contract the 2-index with L. I think the `view_as_real` should go after the transpose.
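For illustration (hypothetical shapes, using `L`/`M` for the last two axes that the asserts check; not code from the PR), the two orderings give:

```python
import torch

B, C, L, M = 2, 3, 6, 4
x = torch.randn(B, C, L, M, dtype=torch.complex64)         # spectral coefficients [..., L, M]

# view_as_real first: the new last dim is re/im, so the transpose swaps M with the 2-axis
a = torch.view_as_real(x)                                  # [B, C, L, M, 2]
a = a.transpose(-1, -2).contiguous()                       # [B, C, L, 2, M]  <- 2 lands where the contraction expects L/M

# transpose first: re/im stays last, which is what the einsum below assumes
b = torch.view_as_real(x.transpose(-1, -2).contiguous())   # [B, C, M, L, 2]

print(a.shape, b.shape)   # torch.Size([2, 3, 6, 2, 4]) torch.Size([2, 3, 4, 6, 2])
```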

Author:
Yes, you're correct. This snuck through since I wasn't able to install/test the code.

I'll work on a version, perhaps today, that I can set up as a standalone script with timing for demonstration.
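In that spirit, a minimal sketch of what the corrected ordering could look like for the inverse-transform contraction (shapes and names such as `lmax`, `mmax`, `nlat`, `pct` mirror the surrounding diff, but this is only an illustration under those assumptions, not the final patch):

```python
import torch

# hypothetical sizes; in the module these would come from self.lmax, self.mmax, self.nlat
B, C, lmax, mmax, nlat = 2, 3, 6, 4, 8
x = torch.randn(B, C, lmax, mmax, dtype=torch.complex64)   # spectral coefficients [..., L, M]
pct = torch.randn(mmax, lmax, nlat)                        # assumed Legendre values, indexed [m, l, k]

# transpose L and M first, then split into real/imag so re/im stays in the last dim
xr = torch.view_as_real(x.transpose(-1, -2).contiguous())  # [..., M, L, 2]

# output on the nlat latitude nodes: [..., K, M, 2]
out_shape = list(xr.shape)
out_shape[-3] = nlat
out_shape[-2] = mmax
xs = torch.zeros(out_shape, dtype=xr.dtype, device=xr.device)

# legendre transformation: contract l for each m, producing the [..., k, m] layout
xs[..., 0] = torch.einsum("...ml,mlk->...km", xr[..., 0], pct)
xs[..., 1] = torch.einsum("...ml,mlk->...km", xr[..., 1], pct)

x_out = torch.view_as_complex(xs)                          # [..., K, M] complex, ready for the inverse rfft
print(x_out.shape)                                         # torch.Size([2, 3, 8, 4])
```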


         # legendre transformation
-        xs[..., 0] = torch.einsum("...lm,mlk->...km", x[..., 0], self.pct.to(x.dtype))
-        xs[..., 1] = torch.einsum("...lm,mlk->...km", x[..., 1], self.pct.to(x.dtype))
+        xs[..., 0] = torch.einsum("...ml,mlk->...km", x[..., 0], self.pct.to(x.dtype))
+        xs[..., 1] = torch.einsum("...ml,mlk->...km", x[..., 1], self.pct.to(x.dtype))
 
         # apply the inverse (real) FFT
         x = torch.view_as_complex(xs)