From 9c829943b18f58e05fa8cb4073f3d3dfa8e3d070 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 11 Feb 2026 19:26:50 +0000 Subject: [PATCH 1/2] speed up SHT --- torch_harmonics/sht.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/torch_harmonics/sht.py b/torch_harmonics/sht.py index 2af8dba3..51d19c2a 100644 --- a/torch_harmonics/sht.py +++ b/torch_harmonics/sht.py @@ -121,6 +121,8 @@ def forward(self, x: torch.Tensor): # apply real fft in the longitudinal direction x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward") + # [..., H, W] -> [..., W, H] + x = x.transpose(-1, -2).contiguous() # do the Legendre-Gauss quadrature x = torch.view_as_real(x) @@ -132,8 +134,8 @@ def forward(self, x: torch.Tensor): xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) # contraction - xout[..., 0] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 0], self.weights.to(x.dtype)) - xout[..., 1] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 1], self.weights.to(x.dtype)) + xout[..., 0] = torch.einsum("...mk,mlk->...lm", x[..., : self.mmax, 0], self.weights.to(x.dtype)) + xout[..., 1] = torch.einsum("...mk,mlk->...lm", x[..., : self.mmax, 1], self.weights.to(x.dtype)) x = torch.view_as_complex(xout) return x @@ -219,8 +221,6 @@ def forward(self, x: torch.Tensor): assert x.shape[-2] == self.lmax assert x.shape[-1] == self.mmax - # Evaluate associated Legendre functions on the output nodes - x = torch.view_as_real(x) # prepare output out_shape = list(x.size()) @@ -228,9 +228,13 @@ def forward(self, x: torch.Tensor): out_shape[-2] = self.mmax xs = torch.zeros(out_shape, dtype=x.dtype, device=x.device) + # Evaluate associated Legendre functions on the output nodes + x = torch.view_as_real(x) + x = x.transpose(-1, -2).contiguous() + # legendre transformation - xs[..., 0] = torch.einsum("...lm,mlk->...km", x[..., 0], self.pct.to(x.dtype)) - xs[..., 1] = torch.einsum("...lm,mlk->...km", x[..., 1], self.pct.to(x.dtype)) + xs[..., 0] = torch.einsum("...ml,mlk->...km", x[..., 0], self.pct.to(x.dtype)) + xs[..., 1] = torch.einsum("...ml,mlk->...km", x[..., 1], self.pct.to(x.dtype)) # apply the inverse (real) FFT x = torch.view_as_complex(xs) From e060933465e3b51c3c807ec339598e6c291f36be Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 11 Feb 2026 19:29:52 +0000 Subject: [PATCH 2/2] take mmax slice on correct dimension --- torch_harmonics/sht.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_harmonics/sht.py b/torch_harmonics/sht.py index 51d19c2a..2d7badf0 100644 --- a/torch_harmonics/sht.py +++ b/torch_harmonics/sht.py @@ -134,8 +134,8 @@ def forward(self, x: torch.Tensor): xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) # contraction - xout[..., 0] = torch.einsum("...mk,mlk->...lm", x[..., : self.mmax, 0], self.weights.to(x.dtype)) - xout[..., 1] = torch.einsum("...mk,mlk->...lm", x[..., : self.mmax, 1], self.weights.to(x.dtype)) + xout[..., 0] = torch.einsum("...mk,mlk->...lm", x[..., : self.mmax, :, 0], self.weights.to(x.dtype)) + xout[..., 1] = torch.einsum("...mk,mlk->...lm", x[..., : self.mmax, :, 1], self.weights.to(x.dtype)) x = torch.view_as_complex(xout) return x