diff --git a/torch_harmonics/sht.py b/torch_harmonics/sht.py index 2af8dba3..2d7badf0 100644 --- a/torch_harmonics/sht.py +++ b/torch_harmonics/sht.py @@ -121,6 +121,8 @@ def forward(self, x: torch.Tensor): # apply real fft in the longitudinal direction x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward") + # [..., H, W] -> [..., W, H] + x = x.transpose(-1, -2).contiguous() # do the Legendre-Gauss quadrature x = torch.view_as_real(x) @@ -132,8 +134,8 @@ def forward(self, x: torch.Tensor): xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) # contraction - xout[..., 0] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 0], self.weights.to(x.dtype)) - xout[..., 1] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 1], self.weights.to(x.dtype)) + xout[..., 0] = torch.einsum("...mk,mlk->...lm", x[..., : self.mmax, :, 0], self.weights.to(x.dtype)) + xout[..., 1] = torch.einsum("...mk,mlk->...lm", x[..., : self.mmax, :, 1], self.weights.to(x.dtype)) x = torch.view_as_complex(xout) return x @@ -219,8 +221,6 @@ def forward(self, x: torch.Tensor): assert x.shape[-2] == self.lmax assert x.shape[-1] == self.mmax - # Evaluate associated Legendre functions on the output nodes - x = torch.view_as_real(x) # prepare output out_shape = list(x.size()) @@ -228,9 +228,13 @@ def forward(self, x: torch.Tensor): out_shape[-2] = self.mmax xs = torch.zeros(out_shape, dtype=x.dtype, device=x.device) + # Evaluate associated Legendre functions on the output nodes + x = torch.view_as_real(x) + x = x.transpose(-1, -2).contiguous() + # legendre transformation - xs[..., 0] = torch.einsum("...lm,mlk->...km", x[..., 0], self.pct.to(x.dtype)) - xs[..., 1] = torch.einsum("...lm,mlk->...km", x[..., 1], self.pct.to(x.dtype)) + xs[..., 0] = torch.einsum("...ml,mlk->...km", x[..., 0], self.pct.to(x.dtype)) + xs[..., 1] = torch.einsum("...ml,mlk->...km", x[..., 1], self.pct.to(x.dtype)) # apply the inverse (real) FFT x = torch.view_as_complex(xs)