From 4d11b8820f63913d1ca7e4b092bd4c56589a8b18 Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Thu, 22 Jan 2026 07:20:36 -0800
Subject: [PATCH 01/13] fwd and bwd 16bit support

---
 tests/test_convolution.py                    |  60 +++++---
 torch_harmonics/disco/_disco_utils.py        |  66 +++++----
 torch_harmonics/disco/csrc/disco_cuda_bwd.cu | 139 +++++++++++--------
 torch_harmonics/disco/csrc/disco_cuda_fwd.cu | 126 +++++++++--------
 4 files changed, 228 insertions(+), 163 deletions(-)

diff --git a/tests/test_convolution.py b/tests/test_convolution.py
index 09bc7d74..ad3fdbbd 100644
--- a/tests/test_convolution.py
+++ b/tests/test_convolution.py
@@ -336,24 +336,39 @@ def test_sparse_against_dense(
 
     @parameterized.expand(
         [
+            # fp32 tests
             # regular convolution
-            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (21, 40), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (21, 40), (2, 2), "morlet", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (21, 40), (2, 1), "morlet", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (21, 40), (3), "zernike", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (41, 80), (21, 40), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (41, 80), (21, 40), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (41, 80), (21, 40), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (41, 80), (21, 40), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
             # transpose convolution
-            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (21, 40), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (21, 40), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (21, 40), (41, 80), (2, 1), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (21, 40), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (21, 40), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (21, 40), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (21, 40), (41, 80), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (21, 40), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            # fp16 tests
+            # regular convolution
+            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float16, False, 1e-2, 1e-2],
+            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float16, False, 1e-2, 1e-2],
+            # transpose convolution
+            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float16, True, 1e-2, 1e-2],
+            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float16, True, 1e-2, 1e-2],
+            # bf16 tests
+            # regular convolution
+            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.bfloat16, False, 1e-2, 1e-2],
+            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.bfloat16, False, 1e-2, 1e-2],
+            # transpose convolution
+            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.bfloat16, True, 1e-2, 1e-2],
+            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.bfloat16, True, 1e-2, 1e-2],
         ],
         skip_on_empty=True,
     )
@@ -370,16 +385,17 @@ def test_optimized_against_torch(
         basis_norm_mode,
         grid_in,
         grid_out,
+        amp_dtype,
         transpose,
         atol,
         rtol,
-        verbose=False,
+        verbose=True,
     ):
 
         if (self.device.type == "cuda") and (not cuda_kernels_is_available()):
             raise unittest.SkipTest("skipping test because CUDA kernels are not available")
 
-        if verbose:
-            print(f"Testing DISCO convolution on {in_shape[0]}x{in_shape[1]} {grid_in} grid to {out_shape[0]}x{out_shape[1]} {grid_out} grid on {self.device.type} device")
+        if (self.device.type == "cpu") and (amp_dtype != torch.float32):
+            raise unittest.SkipTest("skipping test because CPU does not support non-float32 autocast")
 
         set_seed(333)
 
@@ -434,14 +450,16 @@ def test_optimized_against_torch(
 
         # FWD and BWD pass
         inp.requires_grad = True
-        out_naive = conv_naive(inp)
+        with torch.amp.autocast(device_type=self.device.type, dtype=amp_dtype):
+            out_naive = conv_naive(inp)
         grad_input = torch.randn_like(out_naive)
         out_naive.backward(grad_input)
        inp_grad_naive = inp.grad.clone()
 
         # perform the reference computation
         inp.grad = None
-        out_opt = conv_opt(inp)
+        with torch.amp.autocast(device_type=self.device.type, dtype=amp_dtype):
+            out_opt = conv_opt(inp)
         out_opt.backward(grad_input)
         inp_grad_opt = inp.grad.clone()
 
diff --git a/torch_harmonics/disco/_disco_utils.py b/torch_harmonics/disco/_disco_utils.py
index 4ad1effb..c1aef364 100644
--- a/torch_harmonics/disco/_disco_utils.py
+++ b/torch_harmonics/disco/_disco_utils.py
@@ -60,10 +60,11 @@ def _disco_s2_contraction_optimized(
     inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor,
     col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
 
-    itype = inp.dtype
-    inp = inp.to(torch.float32).contiguous()
+    #itype = inp.dtype
+    #inp = inp.to(torch.float32).contiguous()
+    inp = inp.contiguous()
     out = disco_kernels.forward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
-    out = out.to(itype)
+    #out = out.to(itype)
 
     return out
 
 # transpose
@@ -72,10 +73,11 @@ def _disco_s2_transpose_contraction_optimized(
     inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor,
     col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
 
-    itype = inp.dtype
-    inp = inp.to(torch.float32).contiguous()
+    #itype = inp.dtype
+    #inp = inp.to(torch.float32).contiguous()
+    inp = inp.contiguous()
     out = disco_kernels.backward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
-    out = out.to(itype)
+    #out = out.to(itype)
 
     return out
 
 # forward fake
@@ -107,11 +109,12 @@ def _disco_s2_contraction_bwd_optimized(ctx, grad_output):
     roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
 
     if ctx.needs_input_grad[0]:
-        gtype = grad_output.dtype
-        grad_output = grad_output.to(torch.float32).contiguous()
+        #gtype = grad_output.dtype
+        #grad_output = grad_output.to(torch.float32).contiguous()
+        grad_output = grad_output.contiguous()
         grad_input = disco_kernels.backward.default(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals, ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
-        grad_input = grad_input.to(gtype)
+        #grad_input = grad_input.to(gtype)
     else:
         grad_input = None
 
@@ -126,11 +129,12 @@ def _disco_s2_transpose_contraction_bwd_optimized(ctx, grad_output):
     roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
 
     if ctx.needs_input_grad[0]:
-        gtype = grad_output.dtype
-        grad_output = grad_output.to(torch.float32).contiguous()
+        #gtype = grad_output.dtype
+        #grad_output = grad_output.to(torch.float32).contiguous()
+        grad_output = grad_output.contiguous()
         grad_input = disco_kernels.forward.default(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals, ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
-        grad_input = grad_input.to(gtype)
+        #grad_input = grad_input.to(gtype)
     else:
         grad_input = None
 
@@ -183,17 +187,22 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
     # add a dummy dimension for nkernel and move the batch and channel dims to the end
     x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1)
     x = x.expand(kernel_size, -1, -1, -1)
 
+    xtype = x.dtype
+    x = x.to(torch.float32).contiguous()
+
     y = torch.zeros(nlon_out, kernel_size, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype)
 
-    for pout in range(nlon_out):
-        # sparse contraction with psi
-        y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
-        # we need to repeatedly roll the input tensor to faciliate the shifted multiplication
-        x = torch.roll(x, -pscale, dims=2)
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
+        for pout in range(nlon_out):
+            # sparse contraction with psi
+            y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
+            # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
+            x = torch.roll(x, -pscale, dims=2)
+
+    y = y.to(xtype)
 
     # reshape y back to expose the correct dimensions
-    y = y.permute(3, 1, 2, 0).reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out)
+    y = y.permute(3, 1, 2, 0).reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out).contiguous()
 
     return y
 
@@ -218,18 +227,25 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
     # we only need to apply the nlon stride here, since nlat stride is taken care of by the kernel
     x_ext[:, :, ::pscale, :] = x[...]
 
+    xtype = x_ext.dtype
+    x_ext = x_ext.to(torch.float32).contiguous()
+
     # create output tensor
     y = torch.zeros(kernel_size, nlon_out, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype)
 
-    for pout in range(nlon_out):
-        # we need to repeatedly roll the input tensor to faciliate the shifted multiplication
-        # TODO: double-check why this has to happen first
-        x_ext = torch.roll(x_ext, -1, dims=2)
-        # sparse contraction with the modified psi
-        y[:, pout, :, :] = torch.bmm(psi, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1))
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
+        for pout in range(nlon_out):
+            # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
+            # TODO: double-check why this has to happen first
+            x_ext = torch.roll(x_ext, -1, dims=2)
+            # sparse contraction with the modified psi
+            y[:, pout, :, :] = torch.bmm(psi, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1))
 
     # sum over the kernel dimension and reshape to the correct output size
-    y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out).contiguous()
+    y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out)
+
+    # convert datatype back to input type
+    y = y.to(xtype).contiguous()
 
     return y
 
diff --git a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu
index 00f5e514..a23872cf 100644
--- a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu
+++ b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu
@@ -31,13 +31,16 @@
 #include "disco.h"
 #include "disco_cuda.cuh"
 
+#include <ATen/Dispatch.h>
+#include <ATen/OpMathType.h>
+
 namespace disco_kernels {
 
-template<int BDIM_X, int ELXTH, typename REAL_T>
+template<int BDIM_X, int ELXTH, typename STORAGE_T, typename COMPUTE_T>
 __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale,
                             const int64_t *__restrict__ roff, const int64_t *__restrict__ kers,
                             const int64_t *__restrict__ rows, const int64_t *__restrict__ cols,
-                            const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
+                            const COMPUTE_T *__restrict__ vals, const STORAGE_T *__restrict__ inp, COMPUTE_T *__restrict__ out)
 {
     const int tid = threadIdx.x;
 
@@ -55,15 +58,15 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
     out += bidy * Ho * Wo;
 
     // align to larger supported fp type
-    extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*(BDIM_X*ELXTH)*pscale]
+    extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // COMPUTE_T __sh[2*(BDIM_X*ELXTH)*pscale]
 
-    REAL_T(*__sh)[BDIM_X * ELXTH * 2] = reinterpret_cast<REAL_T(*)[BDIM_X * ELXTH * 2]>(__sh_ptr);
+    COMPUTE_T(*__sh)[BDIM_X * ELXTH * 2] = reinterpret_cast<COMPUTE_T(*)[BDIM_X * ELXTH * 2]>(__sh_ptr);
 
     // copy current inp row in regs
-    REAL_T __reg[ELXTH];
+    COMPUTE_T __reg[ELXTH];
 
 #pragma unroll
-    for (int i = 0; i < ELXTH; i++) { __reg[i] = (i * BDIM_X + tid < Wi) ? inp[i * BDIM_X + tid] : REAL_T(0); }
+    for (int i = 0; i < ELXTH; i++) { __reg[i] = (i * BDIM_X + tid < Wi) ? static_cast<COMPUTE_T>(inp[i * BDIM_X + tid]) : static_cast<COMPUTE_T>(0); }
 
     // reset shared row up to Wo+2, remaining
     // ppscale*(BDIM_X*ELXTH - Wo) locations
@@ -71,7 +74,7 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
     // global mem
     for (int i = 0; i < pscale; i++) {
 #pragma unroll
-        for (int j = 0; j < 2 * BDIM_X * ELXTH; j += BDIM_X) { __sh[i][j + tid] = 0; }
+        for (int j = 0; j < 2 * BDIM_X * ELXTH; j += BDIM_X) { __sh[i][j + tid] = static_cast<COMPUTE_T>(0); }
     }
     __syncthreads();
 
@@ -84,7 +87,7 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
     for (int64_t nz = soff; nz < eoff; nz++) {
 
         const int col = cols[nz];
-        const REAL_T val = vals[nz];
+        const COMPUTE_T val = vals[nz];
 
         // if we are processing a nz with a col value
         // leading to a new row of inp then copy it
@@ -96,12 +99,12 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
             for (int i = 0; i < pscale; i++) {
                 for (int j = tid; j < Wi; j += BDIM_X) {
 
-                    const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
+                    const COMPUTE_T v = __sh[i][j] + __sh[i][Wi + j];
 
                     atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
 
-                    __sh[i][j] = 0;
-                    __sh[i][Wi + j] = 0;
+                    __sh[i][j] = static_cast<COMPUTE_T>(0);
+                    __sh[i][Wi + j] = static_cast<COMPUTE_T>(0);
                 }
             }
             __syncthreads();
@@ -133,38 +136,38 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
 
         for (int j = tid; j < Wi; j += BDIM_X) {
 
-            const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
+            const COMPUTE_T v = __sh[i][j] + __sh[i][Wi + j];
 
             atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
         }
     }
 
     return;
 }
 
-template<int BDIM_X, int ELXTH, int PSCALE, typename REAL_T>
+template<int BDIM_X, int ELXTH, int PSCALE, typename STORAGE_T, typename COMPUTE_T>
 __global__ __launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho,
                                                           const int Wo, const int pscale,
                                                           const int64_t *__restrict__ roff, const int64_t *__restrict__ kers,
                                                           const int64_t *__restrict__ rows,
-                                                          const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
-                                                          const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
+                                                          const int64_t *__restrict__ cols, const COMPUTE_T *__restrict__ vals,
+                                                          const STORAGE_T *__restrict__ inp, COMPUTE_T *__restrict__ out)
 {
     if constexpr (PSCALE != 0) {
-        disco_bwd_d<BDIM_X, ELXTH, REAL_T>(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out);
+        disco_bwd_d<BDIM_X, ELXTH, STORAGE_T, COMPUTE_T>(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out);
     } else {
-        disco_bwd_d<BDIM_X, ELXTH, REAL_T>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
+        disco_bwd_d<BDIM_X, ELXTH, STORAGE_T, COMPUTE_T>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
     }
 
     return;
 }
 
-template<int NTH, int ELXTH, typename REAL_T>
+template<int NTH, int ELXTH, typename STORAGE_T, typename COMPUTE_T>
 static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t nrows, int64_t *roff_d, int64_t *ker_d,
-                          int64_t *row_d, int64_t *col_d, REAL_T *val_d, REAL_T *inp_d, REAL_T *out_d,
+                          int64_t *row_d, int64_t *col_d, COMPUTE_T *val_d, STORAGE_T *inp_d, COMPUTE_T *out_d,
                           cudaStream_t stream)
 {
-    static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);
+    static_assert(sizeof(STORAGE_T) == 2 || sizeof(STORAGE_T) == 4 || sizeof(STORAGE_T) == 8);
 
     if constexpr (ELXTH <= ELXTH_MAX) {
         if (NTH * ELXTH >= Wi) {
@@ -175,23 +178,23 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
             switch (pscale) {
             case 1:
-                disco_bwd_blk_k<NTH, ELXTH, 1, REAL_T><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
-                                                                                     row_d, col_d, val_d, inp_d, out_d);
+                disco_bwd_blk_k<NTH, ELXTH, 1, STORAGE_T, COMPUTE_T><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
+                                                                                                   row_d, col_d, val_d, inp_d, out_d);
                 break;
             case 2:
-                disco_bwd_blk_k<NTH, ELXTH, 2, REAL_T><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
-                                                                                     row_d, col_d, val_d, inp_d, out_d);
+                disco_bwd_blk_k<NTH, ELXTH, 2, STORAGE_T, COMPUTE_T><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
+                                                                                                   row_d, col_d, val_d, inp_d, out_d);
                 break;
             case 3:
-                disco_bwd_blk_k<NTH, ELXTH, 3, REAL_T><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
-                                                                                     row_d, col_d, val_d, inp_d, out_d);
+                disco_bwd_blk_k<NTH, ELXTH, 3, STORAGE_T, COMPUTE_T><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
+                                                                                                   row_d, col_d, val_d, inp_d, out_d);
                 break;
             default:
-                disco_bwd_blk_k<NTH, ELXTH, 0, REAL_T><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
-                                                                                     row_d, col_d, val_d, inp_d, out_d);
+                disco_bwd_blk_k<NTH, ELXTH, 0, STORAGE_T, COMPUTE_T><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
+                                                                                                   row_d, col_d, val_d, inp_d, out_d);
             }
         } else {
-            launch_kernel<NTH, ELXTH + 1, REAL_T>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
-                                                  out_d, stream);
+            launch_kernel<NTH, ELXTH + 1, STORAGE_T, COMPUTE_T>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
+                                                                out_d, stream);
         }
     }
 
@@ -220,7 +223,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
 
     // allocate output
     int64_t out_dims[] = {B, C, Ho, Wo};
-    auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype());
+    auto options = torch::TensorOptions().device(inp.device()).dtype(val.dtype());
     torch::Tensor out = torch::zeros(out_dims, options);
 
     // get stream
@@ -230,50 +233,64 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
     static_assert(0 == (ELXTH_MAX % 2));
 
     if (Wo <= 64 * ELXTH_MAX) {
-        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
-                                       launch_kernel<64, 1, scalar_t>(
-                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
-                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
-                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
-                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
-                                   }));
+        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_backward_cuda", ([&] {
+                                       using storage_t = scalar_t;
+                                       using compute_t = typename at::opmath_type<storage_t>;
+                                       launch_kernel<64, 1, storage_t, compute_t>(
+                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
+                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
+                                           col_idx.data_ptr<int64_t>(), val.data_ptr<compute_t>(),
+                                           inp.data_ptr<storage_t>(), out.data_ptr<compute_t>(), stream);
+                                   }));
     } else if (Wo <= 128 * ELXTH_MAX) {
-        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
-                                       launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
-                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
-                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
-                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
-                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
-                                   }));
+        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_backward_cuda", ([&] {
+                                       using storage_t = scalar_t;
+                                       using compute_t = typename at::opmath_type<storage_t>;
+                                       launch_kernel<128, (ELXTH_MAX / 2) + 1, storage_t, compute_t>(
+                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
+                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
+                                           col_idx.data_ptr<int64_t>(), val.data_ptr<compute_t>(),
+                                           inp.data_ptr<storage_t>(), out.data_ptr<compute_t>(), stream);
+                                   }));
     } else if (Wo <= 256 * ELXTH_MAX) {
-        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
-                                       launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
-                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
-                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
-                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
-                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
-                                   }));
+        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_backward_cuda", ([&] {
+                                       using storage_t = scalar_t;
+                                       using compute_t = typename at::opmath_type<storage_t>;
+                                       launch_kernel<256, (ELXTH_MAX / 2) + 1, storage_t, compute_t>(
+                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
+                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
+                                           col_idx.data_ptr<int64_t>(), val.data_ptr<compute_t>(),
+                                           inp.data_ptr<storage_t>(), out.data_ptr<compute_t>(), stream);
+                                   }));
     } else if (Wo <= 512 * ELXTH_MAX) {
-        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
-                                       launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
-                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
-                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
-                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
-                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
-                                   }));
+        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_backward_cuda", ([&] {
+                                       using storage_t = scalar_t;
+                                       using compute_t = typename at::opmath_type<storage_t>;
+                                       launch_kernel<512, (ELXTH_MAX / 2) + 1, storage_t, compute_t>(
+                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
+                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
+                                           col_idx.data_ptr<int64_t>(), val.data_ptr<compute_t>(),
+                                           inp.data_ptr<storage_t>(), out.data_ptr<compute_t>(), stream);
+                                   }));
     } else if (Wo <= 1024 * ELXTH_MAX) {
-        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
-                                       launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
-                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
-                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
-                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
-                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
-                                   }));
+        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_backward_cuda", ([&] {
+                                       using storage_t = scalar_t;
+                                       using compute_t = typename at::opmath_type<storage_t>;
+                                       launch_kernel<1024, (ELXTH_MAX / 2) + 1, storage_t, compute_t>(
+                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
+                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
+                                           col_idx.data_ptr<int64_t>(), val.data_ptr<compute_t>(),
+                                           inp.data_ptr<storage_t>(), out.data_ptr<compute_t>(), stream);
+                                   }));
     } else {
         fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
                 1024 * ELXTH_MAX);
         exit(EXIT_FAILURE);
     }
+
+    // convert type if requested
+    out = out.to(inp.dtype());
+
     return out;
 }
 
diff --git a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu
index 8482f76d..0f6d43cf 100644
--- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu
+++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu
@@ -31,13 +31,16 @@
 #include "disco.h"
 #include "disco_cuda.cuh"
 
+#include <ATen/Dispatch.h>
+#include <ATen/OpMathType.h>
+
 namespace disco_kernels {
 
-template<int BDIM_X, int ELXTH, typename REAL_T>
+template<int BDIM_X, int ELXTH, typename STORAGE_T, typename COMPUTE_T>
 __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale,
                             const int64_t *__restrict__ roff, const int64_t *__restrict__ kers,
                             const int64_t *__restrict__ rows, const int64_t *__restrict__ cols,
-                            const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
+                            const COMPUTE_T *__restrict__ vals, const STORAGE_T *__restrict__ inp, STORAGE_T *__restrict__ out)
 {
     const int tid = threadIdx.x;
 
@@ -54,12 +57,12 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
     inp += bidy * Hi * Wi;
     out += bidy * K * Ho * Wo + ker * Ho * Wo + row * Wo;
 
-    REAL_T __reg[ELXTH] = {0};
+    COMPUTE_T __reg[ELXTH] = {0};
 
     // align to larger supported fp type
     extern __shared__ __align__(
-        sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
-    REAL_T *__sh = reinterpret_cast<REAL_T *>(__sh_ptr);
+        sizeof(double)) unsigned char __sh_ptr[]; // STORAGE_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
+    STORAGE_T *__sh = reinterpret_cast<STORAGE_T *>(__sh_ptr);
 
     int col_prev = cols[soff];
 
@@ -68,7 +71,7 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
     // copy current inp row in shmem
     for (int i = tid; i < Wi; i += BDIM_X) {
-        const REAL_T v = inp[h_prev * Wi + i];
+        const STORAGE_T v = inp[h_prev * Wi + i];
         __sh[i] = v;
         __sh[Wi + i] = v;
     }
@@ -79,7 +82,7 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
     for (int64_t nz = soff; nz < eoff; nz++) {
 
         const int col = cols[nz];
-        const REAL_T val = vals[nz];
+        const COMPUTE_T val = vals[nz];
 
         // if we are processing a nz with a col value
         // leading to a new row of inp then copy it
@@ -94,7 +97,7 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
             __syncthreads();
 
             for (int i = tid; i < Wi; i += BDIM_X) {
-                const REAL_T v = inp[h_prev * Wi + i];
+                const STORAGE_T v = inp[h_prev * Wi + i];
                 __sh[i] = v;
                 __sh[Wi + i] = v;
             }
@@ -120,7 +123,7 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
         // also, to avoid the conditional, sh can be extended to
         // cover the maximum location accessed during this loop
         //
-        //   REAL_T __sh[2*Wi + ppscale*NUM_REM]
+        //   STORAGE_T __sh[2*Wi + ppscale*NUM_REM]
         //
        //   Wi + (Wi/Wo)*BDIM_X*ELXTH = (since BDIM_X*ELXTH >= Wo) =
         //   = Wi + (Wi/Wo)*(Wo + (BDIM_X*ELXTH - Wo)) =
@@ -130,7 +133,7 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
 
             const int wpp = w + pscale * pp;
 
-            __reg[i] += val * __sh[wpp];
+            __reg[i] += val * static_cast<COMPUTE_T>(__sh[wpp]);
         }
     }
 
@@ -140,33 +143,33 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
         const int pp = i * BDIM_X + tid;
         if (pp >= Wo) break;
 
-        out[pp] = __reg[i];
+        out[pp] = static_cast<STORAGE_T>(__reg[i]);
     }
 
     return;
 }
 
-template<int BDIM_X, int ELXTH, typename REAL_T>
+template<int BDIM_X, int ELXTH, typename STORAGE_T, typename COMPUTE_T>
 __global__ __launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho,
                                                           const int Wo, const int pscale,
                                                           const int64_t *__restrict__ roff, const int64_t *__restrict__ kers,
                                                           const int64_t *__restrict__ rows,
-                                                          const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
-                                                          const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
+                                                          const int64_t *__restrict__ cols, const COMPUTE_T *__restrict__ vals,
+                                                          const STORAGE_T *__restrict__ inp, STORAGE_T *__restrict__ out)
 {
-    disco_fwd_d<BDIM_X, ELXTH, REAL_T>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
+    disco_fwd_d<BDIM_X, ELXTH, STORAGE_T, COMPUTE_T>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
 
     return;
 }
 
-template<int NTH, int ELXTH, typename REAL_T>
+template<int NTH, int ELXTH, typename STORAGE_T, typename COMPUTE_T>
 static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t nrows, int64_t *roff_d, int64_t *ker_d,
-                          int64_t *row_d, int64_t *col_d, REAL_T *val_d, REAL_T *inp_d, REAL_T *out_d,
+                          int64_t *row_d, int64_t *col_d, COMPUTE_T *val_d, STORAGE_T *inp_d, STORAGE_T *out_d,
                           cudaStream_t stream)
 {
-    static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);
+    static_assert(sizeof(STORAGE_T) == 2 || sizeof(STORAGE_T) == 4 || sizeof(STORAGE_T) == 8);
 
     if constexpr (ELXTH <= ELXTH_MAX) {
         if (NTH * ELXTH >= Wo) {
@@ -175,11 +178,12 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
             const int pscale = Wi / Wo;
             size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo));
 
-            disco_fwd_blk_k<NTH, ELXTH, REAL_T><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d,
-                                                                              col_d, val_d, inp_d, out_d);
+            disco_fwd_blk_k<NTH, ELXTH, STORAGE_T, COMPUTE_T><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d,
+                                                                                            col_d, val_d, inp_d, out_d);
         } else {
-            launch_kernel<NTH, ELXTH + 1, REAL_T>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
-                                                  out_d, stream);
+            launch_kernel<NTH, ELXTH + 1, STORAGE_T, COMPUTE_T>(
+                BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d, stream
+            );
         }
     }
     return;
@@ -218,45 +222,55 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
 
     // pick the correct launch config
     if (Wo <= 64 * ELXTH_MAX) {
-        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
-                                       launch_kernel<64, 1, scalar_t>(
-                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
-                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
-                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
-                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
-                                   }));
+        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_forward_cuda", ([&] {
+                                       using storage_t = scalar_t;
+                                       using compute_t = typename at::opmath_type<storage_t>;
+                                       launch_kernel<64, 1, storage_t, compute_t>(
+                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
+                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
+                                           col_idx.data_ptr<int64_t>(), val.data_ptr<compute_t>(),
+                                           inp.data_ptr<storage_t>(), out.data_ptr<storage_t>(), stream);
+                                   }));
     } else if (Wo <= 128 * ELXTH_MAX) {
-        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
-                                       launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
-                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
-                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
-                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
-                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
-                                   }));
+        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_forward_cuda", ([&] {
+                                       using storage_t = scalar_t;
+                                       using compute_t = typename at::opmath_type<storage_t>;
+                                       launch_kernel<128, (ELXTH_MAX / 2) + 1, storage_t, compute_t>(
+                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
+                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
+                                           col_idx.data_ptr<int64_t>(), val.data_ptr<compute_t>(),
+                                           inp.data_ptr<storage_t>(), out.data_ptr<storage_t>(), stream);
+                                   }));
     } else if (Wo <= 256 * ELXTH_MAX) {
-        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
-                                       launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
-                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
-                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
-                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
-                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
-                                   }));
+        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_forward_cuda", ([&] {
+                                       using storage_t = scalar_t;
+                                       using compute_t = typename at::opmath_type<storage_t>;
+                                       launch_kernel<256, (ELXTH_MAX / 2) + 1, storage_t, compute_t>(
+                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
+                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
+                                           col_idx.data_ptr<int64_t>(), val.data_ptr<compute_t>(),
+                                           inp.data_ptr<storage_t>(), out.data_ptr<storage_t>(), stream);
+                                   }));
     } else if (Wo <= 512 * ELXTH_MAX) {
-        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
-                                       launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
-                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
-                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
-                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
-                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
-                                   }));
+        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_forward_cuda", ([&] {
+                                       using storage_t = scalar_t;
+                                       using compute_t = typename at::opmath_type<storage_t>;
+                                       launch_kernel<512, (ELXTH_MAX / 2) + 1, storage_t, compute_t>(
+                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
+                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
+                                           col_idx.data_ptr<int64_t>(), val.data_ptr<compute_t>(),
+                                           inp.data_ptr<storage_t>(), out.data_ptr<storage_t>(), stream);
+                                   }));
     } else if (Wo <= 1024 * ELXTH_MAX) {
-        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
-                                       launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
-                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
-                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
-                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
-                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
-                                   }));
+        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_forward_cuda", ([&] {
+                                       using storage_t = scalar_t;
+                                       using compute_t = typename at::opmath_type<storage_t>;
+                                       launch_kernel<1024, (ELXTH_MAX / 2) + 1, storage_t, compute_t>(
+                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
+                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
+                                           col_idx.data_ptr<int64_t>(), val.data_ptr<compute_t>(),
+                                           inp.data_ptr<storage_t>(), out.data_ptr<storage_t>(), stream);
+                                   }));
     } else {
         fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
                 1024 * ELXTH_MAX);

From 586cd4944469d4a3d6fd66b78ad64203629a8f4d Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Fri, 23 Jan 2026 00:19:50 -0800
Subject: [PATCH 02/13] relaxing allclose on tests

---
 Dockerfile.examples       |   6 +--
 tests/test_convolution.py | 103 ++++++++++++++++++++++----------------
 2 files changed, 63 insertions(+), 46 deletions(-)

diff --git a/Dockerfile.examples b/Dockerfile.examples
index a8951e8d..67a2ee1a 100644
--- a/Dockerfile.examples
+++ b/Dockerfile.examples
@@ -30,7 +30,7 @@
 # build after cloning in directory torch_harmonics via
 # docker build . -t torch_harmonics
 
-FROM nvcr.io/nvidia/pytorch:24.12-py3
+FROM nvcr.io/nvidia/pytorch:25.12-py3
 
 # we need this for tests
 RUN pip install parameterized
@@ -48,7 +48,7 @@ RUN pip install h5py
 RUN cd /opt && git clone https://github.com/SHI-Labs/NATTEN natten && \
     cd natten && \
     make WITH_CUDA=1 \
-         CUDA_ARCH="7.0;7.2;7.5;8.0;8.6;8.7;9.0" \
+         CUDA_ARCH="8.0;8.6;8.7;9.0;10.0" \
          WORKERS=4
 
 # install torch harmonics
@@ -57,5 +57,5 @@ COPY . /workspace/torch_harmonics
 # The custom CUDA extension does not support architectures < 7.0
 ENV FORCE_CUDA_EXTENSION=1
 ENV TORCH_HARMONICS_ENABLE_OPENMP=1
-ENV TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
+ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.7 9.0 10.0+PTX"
 RUN cd /workspace/torch_harmonics && pip install --no-build-isolation .

diff --git a/tests/test_convolution.py b/tests/test_convolution.py
index ad3fdbbd..3bbf69ec 100644
--- a/tests/test_convolution.py
+++ b/tests/test_convolution.py
@@ -182,32 +182,47 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
 
     @parameterized.expand(
         [
+            # fp32 tests
             # regular convolution
-            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (24, 48), (12, 24), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (24, 48), (12, 24), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (24, 48), (12, 24), (2, 1), "morlet", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (24, 48), (12, 24), (3), "zernike", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (16, 24), (8, 8), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (18, 36), (6, 12), (7), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, 1e-4],
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (24, 48), (12, 24), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (24, 48), (12, 24), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (24, 48), (12, 24), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (24, 48), (12, 24), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (16, 24), (8, 8), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (18, 36), (6, 12), (7), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", torch.float32, False, 5e-4, 1e-4],
             # transpose convolution
-            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (12, 24), (24, 48), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (12, 24), (24, 48), (2, 1), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (12, 24), (24, 48), (3), "zernike", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (8, 8), (16, 24), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (6, 12), (18, 36), (7), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, 1e-4],
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (12, 24), (24, 48), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (12, 24), (24, 48), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (12, 24), (24, 48), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (8, 8), (16, 24), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (6, 12), (18, 36), (7), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", torch.float32, True, 5e-4, 1e-4],
+            # fp16 tests
+            # regular convolution
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float16, False, 1e-2, 1e-2],
+            [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float16, False, 1e-2, 1e-2],
+            # transpose convolution
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float16, True, 1e-2, 1e-2],
+            [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float16, True, 1e-2, 1e-2],
+            # bf16 tests
+            # regular convolution
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.bfloat16, False, 1e-2, 1e-2],
+            [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.bfloat16, False, 1e-2, 1e-2],
+            # transpose convolution
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.bfloat16, True, 1e-2, 1e-2],
+            [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.bfloat16, True, 1e-2, 1e-2],
         ],
         skip_on_empty=True,
     )
@@ -223,14 +238,15 @@ def test_sparse_against_dense(
         basis_norm_mode,
         grid_in,
         grid_out,
+        amp_dtype,
         transpose,
         atol,
         rtol,
         verbose=False,
     ):
 
-        if verbose:
-            print(f"Testing DISCO convolution on {in_shape[0]}x{in_shape[1]} {grid_in} grid to {out_shape[0]}x{out_shape[1]} {grid_out} grid on {self.device.type} device")
+        if (self.device.type == "cpu") and (amp_dtype != torch.float32):
+            raise unittest.SkipTest("skipping test because CPU does not support non-float32 autocast")
 
         set_seed(333)
 
@@ -309,7 +325,8 @@ def test_sparse_against_dense(
 
         # FWD and BWD pass
         x.requires_grad = True
-        y = conv(x)
+        with torch.autocast(device_type=self.device.type, dtype=amp_dtype):
+            y = conv(x)
         grad_input = torch.randn_like(y)
         y.backward(grad_input)
         x_grad = x.grad.clone()
@@ -338,23 +355,23 @@ def test_sparse_against_dense(
         [
             # fp32 tests
             # regular convolution
-            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (21, 40), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (21, 40), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (21, 40), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (21, 40), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
+            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
+            [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
+            [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
+            [8, 4, 2, (41, 80), (21, 40), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
+            [8, 4, 2, (41, 80), (21, 40), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
+            [8, 4, 2, (41, 80), (21, 40), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
+            [8, 4, 2, (41, 80), (21, 40), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
             # transpose convolution
-            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
-            [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
-            [8, 4, 2, (21, 40), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
-            [8, 4, 2, (21, 40), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
-            [8, 4, 2, (21, 40), (41, 80), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
-            [8, 4, 2, (21, 40), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
+            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
+            [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
+            [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
+            [8, 4, 2, (21, 40), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
+            [8, 4, 2, (21, 40), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
+            [8, 4, 2, (21, 40), (41, 80), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
+            [8, 4, 2, (21, 40), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
             # fp16 tests
             # regular convolution
             [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float16, False, 1e-2, 1e-2],

From 06db0ade2c8be23c6ada9d825e42f93ea04b74f3 Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Fri, 23 Jan 2026 00:26:16 -0800
Subject: [PATCH 03/13] fixing tests

---
 tests/test_convolution.py | 36 ++++++++++++++++++------------------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/tests/test_convolution.py b/tests/test_convolution.py
index 3bbf69ec..bcc7db29 100644
--- a/tests/test_convolution.py
+++ b/tests/test_convolution.py
@@ -344,10 +344,10 @@ def test_sparse_against_dense(
         x_ref_grad = x_ref.grad.clone()
 
         # compare results
-        self.assertTrue(compare_tensors(f"output", y, y_ref, atol=atol, rtol=rtol, verbose=verbose))
+        self.assertTrue(compare_tensors(f"output", y.to(y_ref.dtype), y_ref, atol=atol, rtol=rtol, verbose=verbose))
 
         # compare
-        self.assertTrue(compare_tensors(f"input grad", x_grad, x_ref_grad, atol=atol, rtol=rtol, verbose=verbose))
+        self.assertTrue(compare_tensors(f"input grad", x_grad.to(x_ref_grad.dtype), x_ref_grad, atol=atol, rtol=rtol, verbose=verbose))
         self.assertTrue(compare_tensors(f"weight grad", conv.weight.grad, w_ref.grad, atol=atol, rtol=rtol, verbose=verbose))
 
@@ -355,23 +355,23 @@ def test_sparse_against_dense(
         [
             # fp32 tests
             # regular convolution
-            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
-            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
-            [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
-            [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
-            [8, 4, 2, (41, 80), (21, 40), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
-            [8, 4, 2, (41, 80), (21, 40), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
-            [8, 4, 2, (41, 80), (21, 40), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
-            [8, 4, 2, (41, 80), (21, 40), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 5e-4],
+            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (41, 80), (21, 40), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (41, 80), (21, 40), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (41, 80), (21, 40), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (41, 80), (21, 40), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
             # transpose convolution
-            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
-            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
-            [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
-            [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
-            [8, 4, 2, (21, 40), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
-            [8, 4, 2, (21, 40), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
-            [8, 4, 2, (21, 40), (41, 80), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
-            [8, 4, 2, (21, 40), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 5e-4],
+            [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (21, 40), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (21, 40), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (21, 40), (41, 80), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (21, 40), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
             # fp16 tests
             # regular convolution
             [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float16, False, 1e-2, 1e-2],

From 4c8ff7769e376efbf85b71446cca335578aca9f9 Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Fri, 23 Jan 2026 00:31:45 -0800
Subject: [PATCH 04/13] debugging tests

---
 tests/test_convolution.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_convolution.py b/tests/test_convolution.py
index bcc7db29..6467fa5e 100644
--- a/tests/test_convolution.py
+++ b/tests/test_convolution.py
@@ -242,7 +242,7 @@ def test_sparse_against_dense(
         transpose,
         atol,
         rtol,
-        verbose=False,
+        verbose=True,
     ):
 
         if (self.device.type == "cpu") and (amp_dtype != torch.float32):
@@ -366,7 +366,7 @@ def test_sparse_against_dense(
             # transpose convolution
             [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
             [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
-            [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 5e-4],
             [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
             [8, 4, 2, (21, 40), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
             [8, 4, 2, (21, 40), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],

From dfcac8c9d662a90a86e6ba73970910753ce370db Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Fri, 23 Jan 2026 00:36:08 -0800
Subject: [PATCH 05/13] adjusting test precisions

---
 tests/test_convolution.py | 48 +++++++++++++++++++--------------------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/tests/test_convolution.py b/tests/test_convolution.py
index 6467fa5e..bea58b04 100644
--- a/tests/test_convolution.py
+++ b/tests/test_convolution.py
@@ -184,31 +184,31 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
         [
             # fp32 tests
             # regular convolution
-            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
-            [8, 4, 2, (24, 48), (12, 24), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
-            [8, 4, 2, (24, 48), (12, 24), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
-            [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
-            [8, 4, 2, (24, 48), (12, 24), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
-            [8, 4, 2, (24, 48), (12, 24), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
-            [8, 4, 2, (16, 24), (8, 8), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
-            [8, 4, 2, (18, 36), (6, 12), (7), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", torch.float32, False, 5e-4, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", torch.float32, False, 5e-4, 1e-4],
-            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", torch.float32, False, 5e-4, 1e-4],
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
+            [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
+            [8, 4, 2, (24, 48), (12, 24), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
+            [8, 4, 2, (24, 48), (12, 24), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
+            [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
+            [8, 4, 2, (24, 48), (12, 24), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
+            [8, 4, 2, (24, 48), (12, 24), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
+            [8, 4, 2, (16, 24), (8, 8), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
+            [8, 4, 2, (18, 36), (6, 12), (7), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
+            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", torch.float32, False, 1e-3, 5e-4],
+            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", torch.float32, False, 1e-3, 5e-4],
+            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", torch.float32, False, 1e-3, 5e-4],
             # transpose convolution
-            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
-            [8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
-            [8, 4, 2, (12, 24), (24, 48), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
-            [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
-            [8, 4, 2, (12, 24), (24, 48), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
-            [8, 4, 2, (12, 24), (24, 48), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
-            [8, 4, 2, (8, 8), (16, 24), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
-            [8, 4, 2, (6, 12), (18, 36), (7), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", torch.float32, True, 5e-4, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", torch.float32, True, 5e-4, 1e-4],
-            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", torch.float32, True, 5e-4, 1e-4],
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
+            [8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
+            [8, 4, 2, (12, 24), (24, 48), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
+            [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
+            [8, 4, 2, (12, 24), (24, 48), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
+            [8, 4, 2, (12, 24), (24, 48), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
+            [8, 4, 2, (8, 8), (16, 24), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
+            [8, 4, 2, (6, 12), (18, 36), (7), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", torch.float32, True, 1e-3, 5e-4],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", torch.float32, True, 1e-3, 5e-4],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", torch.float32, True, 1e-3, 5e-4],
             # fp16 tests
             # regular convolution
             [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float16, False, 1e-2, 1e-2],

From 126412c91a5d2b9a50ab26e9d747911d2c16a0c1 Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Fri, 23 Jan 2026 01:31:54 -0800
Subject: [PATCH 06/13] fixed tests

---
 tests/test_convolution.py | 107 ++++++++++++++++++++++----------------
 1 file changed, 61 insertions(+), 46 deletions(-)

diff --git a/tests/test_convolution.py b/tests/test_convolution.py
index bea58b04..a3208333 100644
--- a/tests/test_convolution.py
+++ b/tests/test_convolution.py
@@ -184,31 +184,31 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
         [
             # fp32 tests
             # regular convolution
-            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
-            [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
-            [8, 4, 2, (24, 48), (12, 24), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
-            [8, 4, 2, (24, 48), (12, 24), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
-            [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
-            [8, 4, 2, (24, 48), (12, 24), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
-            [8, 4, 2, (24, 48), (12, 24), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
-            [8, 4, 2, (16, 24), (8, 8), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
-            [8, 4, 2, (18, 36), (6, 12), (7), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-3, 5e-4],
-            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", torch.float32, False, 1e-3, 5e-4],
-            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", torch.float32, False, 1e-3, 5e-4],
-            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", torch.float32, False, 1e-3, 5e-4],
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (24, 48), (12, 24), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (24, 48), (12, 24), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (24, 48), (12, 24), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (24, 48), (12, 24), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (16, 24), (8, 8), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (18, 36), (6, 12), (7), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", torch.float32, False, 1e-4, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", torch.float32, False, 1e-4, 1e-4],
             # transpose convolution
-            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
-            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
-            [8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
-            [8, 4, 2, (12, 24), (24, 48), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
-            [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
-            [8, 4, 2, (12, 24), (24, 48), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
-            [8, 4, 2, (12, 24), (24, 48), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
-            [8, 4, 2, (8, 8), (16, 24), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
-            [8, 4, 2, (6, 12), (18, 36), (7), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-3, 5e-4],
-            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", torch.float32, True, 1e-3, 5e-4],
-            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", torch.float32, True, 1e-3, 5e-4],
-            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", torch.float32, True, 1e-3, 5e-4],
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (12, 24), (24, 48), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (12, 24), (24, 48), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (12, 24), (24, 48), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (8, 8), (16, 24), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (6, 12), (18, 36), (7), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", torch.float32, True, 1e-4, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", torch.float32, True, 1e-4, 1e-4],
             # fp16 tests
             # regular convolution
             [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float16, False, 1e-2, 1e-2],
@@ -218,11 +218,11 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
             [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float16, True, 1e-2, 1e-2],
             # bf16 tests
             # regular convolution
-            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.bfloat16, False, 1e-2, 1e-2],
-            [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.bfloat16, False, 1e-2, 1e-2],
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.bfloat16, False, 5e-2, 1e-2],
+            [8, 4, 2, (24, 48), 
(12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.bfloat16, False, 5e-2, 1e-2], # transpose convolution - [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.bfloat16, True, 1e-2, 1e-2], - [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.bfloat16, True, 1e-2, 1e-2], + [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.bfloat16, True, 5e-2, 5e-2], + [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.bfloat16, True, 5e-2, 1e-2], ], skip_on_empty=True, ) @@ -242,14 +242,22 @@ def test_sparse_against_dense( transpose, atol, rtol, - verbose=True, + verbose=False, ): if (self.device.type == "cpu") and (amp_dtype != torch.float32): raise unittest.SkipTest("skipping test because CPU does not support non-float32 autocast") + # disable tf32 + torch.backends.cuda.matmul.fp32_precision = "ieee" + torch.backends.cudnn.fp32_precision = "ieee" + torch.backends.cudnn.conv.fp32_precision = "ieee" + torch.backends.cudnn.rnn.fp32_precision = "ieee" + + # set seed set_seed(333) + # use optimized kernels use_optimized_kernels = optimized_kernels_is_available() if (self.device.type == "cuda") and (not cuda_kernels_is_available()): use_optimized_kernels = False @@ -355,23 +363,23 @@ def test_sparse_against_dense( [ # fp32 tests # regular convolution - [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4], - [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4], - [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4], - [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4], - [8, 4, 2, (41, 80), (21, 40), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4], - [8, 4, 2, (41, 80), (21, 40), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4], - [8, 4, 2, (41, 80), (21, 40), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4], - [8, 4, 2, (41, 80), (21, 40), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 5e-4, 1e-4], + [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4], + [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4], + [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4], + [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4], + [8, 4, 2, (41, 80), (21, 40), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4], + [8, 4, 2, (41, 80), (21, 40), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4], + [8, 4, 2, (41, 80), (21, 40), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4], + [8, 4, 2, (41, 80), (21, 40), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, False, 1e-4, 1e-4], # transpose convolution - [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", 
torch.float32, True, 5e-4, 1e-4], - [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4], - [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 5e-4], - [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4], - [8, 4, 2, (21, 40), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4], - [8, 4, 2, (21, 40), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4], - [8, 4, 2, (21, 40), (41, 80), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4], - [8, 4, 2, (21, 40), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 5e-4, 1e-4], + [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4], + [8, 4, 2, (41, 80), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4], + [8, 4, 2, (41, 80), (41, 80), (2, 3), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4], + [8, 4, 2, (41, 80), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4], + [8, 4, 2, (21, 40), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4], + [8, 4, 2, (21, 40), (41, 80), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4], + [8, 4, 2, (21, 40), (41, 80), (2, 1), "morlet", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4], + [8, 4, 2, (21, 40), (41, 80), (3), "zernike", "mean", "equiangular", "equiangular", torch.float32, True, 1e-4, 1e-4], # fp16 tests # regular convolution [8, 4, 2, (41, 80), (41, 80), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float16, False, 1e-2, 1e-2], @@ -406,7 +414,7 @@ def test_optimized_against_torch( transpose, atol, rtol, - verbose=True, + verbose=False, ): if (self.device.type == "cuda") and (not cuda_kernels_is_available()): raise unittest.SkipTest("skipping test because CUDA kernels are not available") @@ -414,6 +422,13 @@ def test_optimized_against_torch( if (self.device.type == "cpu") and (amp_dtype != torch.float32): raise unittest.SkipTest("skipping test because CPU does not support non-float32 autocast") + # disable tf32 + torch.backends.cuda.matmul.fp32_precision = "ieee" + torch.backends.cudnn.fp32_precision = "ieee" + torch.backends.cudnn.conv.fp32_precision = "ieee" + torch.backends.cudnn.rnn.fp32_precision = "ieee" + + # set seed set_seed(333) nlat_in, nlon_in = in_shape From 5b79463f101163ab22d92e2e6ec303b596796603 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 26 Jan 2026 02:30:22 -0800 Subject: [PATCH 07/13] adding fp32 cast to disco because current fp16 and bf16 kernels are slower --- Dockerfile | 1 + torch_harmonics/disco/_disco_utils.py | 29 +++++++++++---------------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/Dockerfile b/Dockerfile index 2cb926a1..840f3e03 100644 --- a/Dockerfile +++ b/Dockerfile @@ -38,6 +38,7 @@ RUN pip install parameterized # The custom CUDA extension does not suppport architerctures < 7.0 ENV FORCE_CUDA_EXTENSION=1 ENV TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX" +ENV TORCH_HARMONICS_DEBUG=1 COPY . 
/workspace/torch_harmonics RUN cd /workspace/torch_harmonics && pip install --no-build-isolation . diff --git a/torch_harmonics/disco/_disco_utils.py b/torch_harmonics/disco/_disco_utils.py index c1aef364..adea7731 100644 --- a/torch_harmonics/disco/_disco_utils.py +++ b/torch_harmonics/disco/_disco_utils.py @@ -30,7 +30,6 @@ # from typing import Optional -import math import torch from disco_helpers import optimized_kernels_is_available @@ -60,11 +59,10 @@ def _disco_s2_contraction_optimized( inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - #itype = inp.dtype - #inp = inp.to(torch.float32).contiguous() - inp = inp.contiguous() + itype = inp.dtype + inp = inp.to(torch.float32).contiguous() out = disco_kernels.forward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out) - #out = out.to(itype) + out = out.to(itype) return out # transpose @@ -73,11 +71,10 @@ def _disco_s2_transpose_contraction_optimized( inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - #itype = inp.dtype - #inp = inp.to(torch.float32).contiguous() - inp = inp.contiguous() + itype = inp.dtype + inp = inp.to(torch.float32).contiguous() out = disco_kernels.backward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out) - #out = out.to(itype) + out = out.to(itype) return out # forward fake @@ -109,12 +106,11 @@ def _disco_s2_contraction_bwd_optimized(ctx, grad_output): roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors if ctx.needs_input_grad[0]: - #gtype = grad_output.dtype - #grad_output = grad_output.to(torch.float32).contiguous() - grad_output = grad_output.contiguous() + gtype = grad_output.dtype + grad_output = grad_output.to(torch.float32).contiguous() grad_input = disco_kernels.backward.default(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals, ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) - #grad_input = grad_input.to(gtype) + grad_input = grad_input.to(gtype) else: grad_input = None @@ -129,12 +125,11 @@ def _disco_s2_transpose_contraction_bwd_optimized(ctx, grad_output): roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors if ctx.needs_input_grad[0]: - #gtype = grad_output.dtype - #grad_output = grad_output.to(torch.float32).contiguous() - grad_output = grad_output.contiguous() + gtype = grad_output.dtype + grad_output = grad_output.to(torch.float32).contiguous() grad_input = disco_kernels.forward.default(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals, ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) - #grad_input = grad_input.to(gtype) + grad_input = grad_input.to(gtype) else: grad_input = None From 1f96e346138406179de5392d5bb5de603efe6b1c Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 26 Jan 2026 04:11:34 -0800 Subject: [PATCH 08/13] made disabling of tf32 conditional on cuda availability --- tests/test_convolution.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/test_convolution.py b/tests/test_convolution.py index a3208333..6751387d 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -249,10 +249,11 @@ def test_sparse_against_dense( raise unittest.SkipTest("skipping test because CPU does not support non-float32 autocast") # disable tf32 - 
torch.backends.cuda.matmul.fp32_precision = "ieee" - torch.backends.cudnn.fp32_precision = "ieee" - torch.backends.cudnn.conv.fp32_precision = "ieee" - torch.backends.cudnn.rnn.fp32_precision = "ieee" + if torch.cuda.is_available(): + torch.backends.cuda.matmul.fp32_precision = "ieee" + torch.backends.cudnn.fp32_precision = "ieee" + torch.backends.cudnn.conv.fp32_precision = "ieee" + torch.backends.cudnn.rnn.fp32_precision = "ieee" # set seed set_seed(333) @@ -423,10 +424,11 @@ def test_optimized_against_torch( raise unittest.SkipTest("skipping test because CPU does not support non-float32 autocast") # disable tf32 - torch.backends.cuda.matmul.fp32_precision = "ieee" - torch.backends.cudnn.fp32_precision = "ieee" - torch.backends.cudnn.conv.fp32_precision = "ieee" - torch.backends.cudnn.rnn.fp32_precision = "ieee" + if torch.cuda.is_available(): + torch.backends.cuda.matmul.fp32_precision = "ieee" + torch.backends.cudnn.fp32_precision = "ieee" + torch.backends.cudnn.conv.fp32_precision = "ieee" + torch.backends.cudnn.rnn.fp32_precision = "ieee" # set seed set_seed(333) From 551913c2f4c0d2a4ede027a4bf04123ce0605c1a Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 27 Jan 2026 00:05:10 -0800 Subject: [PATCH 09/13] decorating dynamo unsupported primitives with compiler disabling --- torch_harmonics/distributed/primitives.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch_harmonics/distributed/primitives.py b/torch_harmonics/distributed/primitives.py index 96d17908..51231fe0 100644 --- a/torch_harmonics/distributed/primitives.py +++ b/torch_harmonics/distributed/primitives.py @@ -442,9 +442,8 @@ def backward(ctx, grad_output): return _reduce_scatter(grad_output, ctx.dim, use_fp32=True, group=polar_group()), None, None else: return grad_output, None, None - - + def copy_to_polar_region(input_): return _CopyToPolarRegion.apply(input_) @@ -457,14 +456,18 @@ def reduce_from_polar_region(input_): def reduce_from_azimuth_region(input_): return _ReduceFromAzimuthRegion.apply(input_) +@torch.compiler.disable() def scatter_to_polar_region(input_, dim_): return _ScatterToPolarRegion.apply(input_, dim_) +@torch.compiler.disable() def gather_from_polar_region(input_, dim_, shapes_): return _GatherFromPolarRegion.apply(input_, dim_, shapes_) +@torch.compiler.disable() def reduce_from_scatter_to_polar_region(input_, dim_): return _ReduceFromScatterToPolarRegion.apply(input_, dim_) +@torch.compiler.disable() def gather_from_copy_to_polar_region(input_, dim_, shapes_): return _GatherFromCopyToPolarRegion.apply(input_, dim_, shapes_) From 1f557309bdae446a82351a94f4f3ed7377f5c968 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 27 Jan 2026 02:58:09 -0800 Subject: [PATCH 10/13] some small refactoring --- tests/test_convolution.py | 30 +++++++---------- tests/testutils.py | 16 ++++++++++ .../distributed/distributed_convolution.py | 10 +++--- .../distributed/distributed_resample.py | 8 ++--- .../distributed/distributed_sht.py | 32 +++++++++---------- torch_harmonics/distributed/primitives.py | 25 +++++++++------ 6 files changed, 68 insertions(+), 53 deletions(-) diff --git a/tests/test_convolution.py b/tests/test_convolution.py index 6751387d..fcbf4a6d 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -42,7 +42,7 @@ from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes from torch_harmonics.disco import cuda_kernels_is_available, optimized_kernels_is_available -from testutils import set_seed, compare_tensors +from 
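The fp32 casts reinstated in _disco_utils.py above follow a simple cast-and-restore pattern around the custom op, since the current fp16 and bf16 kernels are slower than the fp32 path. A minimal sketch of the pattern (run_op_in_fp32 and op are hypothetical stand-ins; the real wrappers call disco_kernels.forward.default and disco_kernels.backward.default):

import torch

def run_op_in_fp32(op, inp: torch.Tensor, *args) -> torch.Tensor:
    itype = inp.dtype                         # dtype handed in by autocast
    inp = inp.to(torch.float32).contiguous()  # kernels expect contiguous fp32
    out = op(inp, *args)
    return out.to(itype)                      # give fp16/bf16 back to the caller

The same pattern is applied in both backward wrappers, so gradients also pass through the fp32 kernels before being cast back to the autocast dtype.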
testutils import disable_tf32, set_seed, compare_tensors if not optimized_kernels_is_available(): print(f"Warning: Couldn't import optimized disco convolution kernels") @@ -211,18 +211,18 @@ class TestDiscreteContinuousConvolution(unittest.TestCase): [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", torch.float32, True, 1e-4, 1e-4], # fp16 tests # regular convolution - [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float16, False, 1e-2, 1e-2], - [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float16, False, 1e-2, 1e-2], + [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float16, False, 2e-2, 1e-2], + [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float16, False, 2e-2, 1e-2], # transpose convolution - [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float16, True, 1e-2, 1e-2], - [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float16, True, 1e-2, 1e-2], + [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.float16, True, 2e-2, 1e-2], + [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.float16, True, 2e-2, 1e-2], # bf16 tests # regular convolution - [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.bfloat16, False, 5e-2, 1e-2], - [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.bfloat16, False, 5e-2, 1e-2], + [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.bfloat16, False, 5e-2, 5e-2], + [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.bfloat16, False, 5e-2, 5e-2], # transpose convolution [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", torch.bfloat16, True, 5e-2, 5e-2], - [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.bfloat16, True, 5e-2, 1e-2], + [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", torch.bfloat16, True, 5e-2, 5e-2], ], skip_on_empty=True, ) @@ -242,18 +242,14 @@ def test_sparse_against_dense( transpose, atol, rtol, - verbose=False, + verbose=True, ): if (self.device.type == "cpu") and (amp_dtype != torch.float32): raise unittest.SkipTest("skipping test because CPU does not support non-float32 autocast") # disable tf32 - if torch.cuda.is_available(): - torch.backends.cuda.matmul.fp32_precision = "ieee" - torch.backends.cudnn.fp32_precision = "ieee" - torch.backends.cudnn.conv.fp32_precision = "ieee" - torch.backends.cudnn.rnn.fp32_precision = "ieee" + disable_tf32() # set seed set_seed(333) @@ -424,11 +420,7 @@ def test_optimized_against_torch( raise unittest.SkipTest("skipping test because CPU does not support non-float32 autocast") # disable tf32 - if torch.cuda.is_available(): - torch.backends.cuda.matmul.fp32_precision = "ieee" - torch.backends.cudnn.fp32_precision = "ieee" - torch.backends.cudnn.conv.fp32_precision = "ieee" - torch.backends.cudnn.rnn.fp32_precision = "ieee" + disable_tf32() # set seed set_seed(333) diff --git a/tests/testutils.py b/tests/testutils.py index b74f9cf9..ed522f1e 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ 
-30,6 +30,8 @@ # import os +from packaging import version + import torch import torch.distributed as dist import torch_harmonics.distributed as thd @@ -41,6 +43,20 @@ def set_seed(seed=333): return +def disable_tf32(): + # the api for this was changed lately in pytorch + if torch.cuda.is_available(): + if version.parse(torch.__version__) >= version.parse("2.9.0"): + torch.backends.cuda.matmul.fp32_precision = "ieee" + torch.backends.cudnn.fp32_precision = "ieee" + torch.backends.cudnn.conv.fp32_precision = "ieee" + torch.backends.cudnn.rnn.fp32_precision = "ieee" + else: + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + return + + def setup_distributed_context(ctx): ctx.world_rank = int(os.getenv("WORLD_RANK", 0)) ctx.grid_size_h = int(os.getenv("GRID_H", 1)) diff --git a/torch_harmonics/distributed/distributed_convolution.py b/torch_harmonics/distributed/distributed_convolution.py index 88fcb102..7d2bd568 100644 --- a/torch_harmonics/distributed/distributed_convolution.py +++ b/torch_harmonics/distributed/distributed_convolution.py @@ -49,7 +49,7 @@ # distirbuted stuff from torch_harmonics.distributed import polar_group_size, azimuth_group_size -from torch_harmonics.distributed import distributed_transpose_azimuth, distributed_transpose_polar +from torch_harmonics.distributed import distributed_transpose_azimuth from torch_harmonics.distributed import reduce_from_polar_region, scatter_to_polar_region, gather_from_polar_region, copy_to_polar_region from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim @@ -255,7 +255,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # h and w is split. First we make w local by transposing into channel dim if self.comm_size_azimuth > 1: - x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes) + x = distributed_transpose_azimuth(x, (1, -1), self.lon_in_shapes) if self.optimized_kernel: x = _disco_s2_contraction_optimized( @@ -271,7 +271,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # now we can transpose back the result, so that lon is split and channels are local if self.comm_size_azimuth > 1: chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth) - x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes) + x = distributed_transpose_azimuth(x, (-1, 1), chan_shapes) # extract shape B, C, K, H, W = x.shape @@ -440,7 +440,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # transpose such that lon is local, channels are split if self.comm_size_azimuth > 1: - x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes) + x = distributed_transpose_azimuth(x, (1, -1), self.lon_in_shapes) # gather input tensor and set up backward reduction hooks x = gather_from_polar_region(x, -2, self.lat_in_shapes) @@ -456,7 +456,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # now we can transpose back the result, so that lon is split and channels are local if self.comm_size_azimuth > 1: chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth) - out = distributed_transpose_azimuth.apply(out, (-1, 1), chan_shapes) + out = distributed_transpose_azimuth(out, (-1, 1), chan_shapes) if self.bias is not None: out = out + self.bias.reshape(1, -1, 1, 1) diff --git a/torch_harmonics/distributed/distributed_resample.py b/torch_harmonics/distributed/distributed_resample.py index 764bea85..114f57a0 100644 --- 
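With disable_tf32 factored into tests/testutils.py above, covering both the new fp32_precision API (PyTorch 2.9+) and the older allow_tf32 flags, a test only needs one call before seeding. A minimal usage sketch (PrecisionTest is a hypothetical test class, not one from this suite):

import unittest

from testutils import disable_tf32, set_seed

class PrecisionTest(unittest.TestCase):
    def setUp(self):
        # tf32 matmuls keep only 10 mantissa bits on Ampere and newer GPUs,
        # which would break the 1e-4 fp32 tolerances used in these tests
        disable_tf32()
        set_seed(333)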
a/torch_harmonics/distributed/distributed_resample.py +++ b/torch_harmonics/distributed/distributed_resample.py @@ -219,7 +219,7 @@ def forward(self, x: torch.Tensor): # h and w is split. First we make w local by transposing into channel dim if self.comm_size_polar > 1: channels_shapes = compute_split_shapes(num_chans, self.comm_size_polar) - x = distributed_transpose_polar.apply(x, (-3, -2), self.lat_in_shapes) + x = distributed_transpose_polar(x, (-3, -2), self.lat_in_shapes) # expand poles if requested if self.expand_poles: @@ -230,18 +230,18 @@ def forward(self, x: torch.Tensor): # now, transpose back if self.comm_size_polar > 1: - x = distributed_transpose_polar.apply(x, (-2, -3), channels_shapes) + x = distributed_transpose_polar(x, (-2, -3), channels_shapes) # now, transpose in w: if self.comm_size_azimuth > 1: channels_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth) - x = distributed_transpose_azimuth.apply(x, (-3, -1), self.lon_in_shapes) + x = distributed_transpose_azimuth(x, (-3, -1), self.lon_in_shapes) # upscale x = self._upscale_longitudes(x) # transpose back if self.comm_size_azimuth > 1: - x = distributed_transpose_azimuth.apply(x, (-1, -3), channels_shapes) + x = distributed_transpose_azimuth(x, (-1, -3), channels_shapes) return x diff --git a/torch_harmonics/distributed/distributed_sht.py b/torch_harmonics/distributed/distributed_sht.py index a0ac8acc..cf48aa91 100644 --- a/torch_harmonics/distributed/distributed_sht.py +++ b/torch_harmonics/distributed/distributed_sht.py @@ -141,7 +141,7 @@ def forward(self, x: torch.Tensor): # h and w is split. First we make w local by transposing into channel dim if self.comm_size_azimuth > 1: - x = distributed_transpose_azimuth.apply(x, (-3, -1), self.lon_shapes) + x = distributed_transpose_azimuth(x, (-3, -1), self.lon_shapes) # apply real fft in the longitudinal direction: make sure to truncate to nlon x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward") @@ -152,11 +152,11 @@ def forward(self, x: torch.Tensor): # transpose: after this, m is split and c is local if self.comm_size_azimuth > 1: chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth) - x = distributed_transpose_azimuth.apply(x, (-1, -3), chan_shapes) + x = distributed_transpose_azimuth(x, (-1, -3), chan_shapes) # transpose: after this, c is split and h is local if self.comm_size_polar > 1: - x = distributed_transpose_polar.apply(x, (-3, -2), self.lat_shapes) + x = distributed_transpose_polar(x, (-3, -2), self.lat_shapes) # do the Legendre-Gauss quadrature x = torch.view_as_real(x) @@ -175,7 +175,7 @@ def forward(self, x: torch.Tensor): # transpose: after this, l is split and c is local if self.comm_size_polar > 1: chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar) - x = distributed_transpose_polar.apply(x, (-2, -3), chan_shapes) + x = distributed_transpose_polar(x, (-2, -3), chan_shapes) return x @@ -274,7 +274,7 @@ def forward(self, x: torch.Tensor): # transpose: after that, channels are split, l is local: if self.comm_size_polar > 1: - x = distributed_transpose_polar.apply(x, (-3, -2), self.l_shapes) + x = distributed_transpose_polar(x, (-3, -2), self.l_shapes) # Evaluate associated Legendre functions on the output nodes x = torch.view_as_real(x) @@ -294,11 +294,11 @@ def forward(self, x: torch.Tensor): if self.comm_size_polar > 1: chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar) - x = distributed_transpose_polar.apply(x, (-2, -3), chan_shapes) + x = 
distributed_transpose_polar(x, (-2, -3), chan_shapes) # transpose: after this, channels are split and m is local if self.comm_size_azimuth > 1: - x = distributed_transpose_azimuth.apply(x, (-3, -1), self.m_shapes) + x = distributed_transpose_azimuth(x, (-3, -1), self.m_shapes) # set DCT and nyquist frequencies to 0: x[..., 0].imag = 0.0 @@ -311,7 +311,7 @@ def forward(self, x: torch.Tensor): # transpose: after this, m is split and channels are local if self.comm_size_azimuth > 1: chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth) - x = distributed_transpose_azimuth.apply(x, (-1, -3), chan_shapes) + x = distributed_transpose_azimuth(x, (-1, -3), chan_shapes) return x @@ -420,7 +420,7 @@ def forward(self, x: torch.Tensor): # h and w is split. First we make w local by transposing into channel dim if self.comm_size_azimuth > 1: - x = distributed_transpose_azimuth.apply(x, (-4, -1), self.lon_shapes) + x = distributed_transpose_azimuth(x, (-4, -1), self.lon_shapes) # apply real fft in the longitudinal direction: make sure to truncate to nlon x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward") @@ -431,11 +431,11 @@ def forward(self, x: torch.Tensor): # transpose: after this, m is split and c is local if self.comm_size_azimuth > 1: chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth) - x = distributed_transpose_azimuth.apply(x, (-1, -4), chan_shapes) + x = distributed_transpose_azimuth(x, (-1, -4), chan_shapes) # transpose: after this, c is split and h is local if self.comm_size_polar > 1: - x = distributed_transpose_polar.apply(x, (-4, -2), self.lat_shapes) + x = distributed_transpose_polar(x, (-4, -2), self.lat_shapes) # do the Legendre-Gauss quadrature x = torch.view_as_real(x) @@ -468,7 +468,7 @@ def forward(self, x: torch.Tensor): # transpose: after this, l is split and c is local if self.comm_size_polar > 1: chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar) - x = distributed_transpose_polar.apply(x, (-2, -4), chan_shapes) + x = distributed_transpose_polar(x, (-2, -4), chan_shapes) return x @@ -564,7 +564,7 @@ def forward(self, x: torch.Tensor): # transpose: after that, channels are split, l is local: if self.comm_size_polar > 1: - x = distributed_transpose_polar.apply(x, (-4, -2), self.l_shapes) + x = distributed_transpose_polar(x, (-4, -2), self.l_shapes) # Evaluate associated Legendre functions on the output nodes x = torch.view_as_real(x) @@ -595,11 +595,11 @@ def forward(self, x: torch.Tensor): if self.comm_size_polar > 1: chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar) - x = distributed_transpose_polar.apply(x, (-2, -4), chan_shapes) + x = distributed_transpose_polar(x, (-2, -4), chan_shapes) # transpose: after this, channels are split and m is local if self.comm_size_azimuth > 1: - x = distributed_transpose_azimuth.apply(x, (-4, -1), self.m_shapes) + x = distributed_transpose_azimuth(x, (-4, -1), self.m_shapes) # set DCT and nyquist frequencies to zero x[..., 0].imag = 0.0 @@ -612,6 +612,6 @@ def forward(self, x: torch.Tensor): # transpose: after this, m is split and channels are local if self.comm_size_azimuth > 1: chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth) - x = distributed_transpose_azimuth.apply(x, (-1, -4), chan_shapes) + x = distributed_transpose_azimuth(x, (-1, -4), chan_shapes) return x diff --git a/torch_harmonics/distributed/primitives.py b/torch_harmonics/distributed/primitives.py index 51231fe0..50b76ace 100644 --- 
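The call sites in distributed_sht.py and distributed_resample.py above switch from invoking the autograd.Function subclasses directly via .apply to thin module-level wrappers; the primitives.py diff below renames the classes and decorates those wrappers with torch.compiler.disable so that dynamo leaves the untraceable collectives in eager mode. A minimal sketch of the pattern with illustrative names (_MyTranspose and my_transpose are not the library's API):

import torch

class _MyTranspose(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, dims, split_sizes):
        ctx.dims = dims
        return x.transpose(dims[0], dims[1]).contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        dims = ctx.dims
        # mirror the forward transpose; no gradients for dims/split_sizes
        return grad_output.transpose(dims[1], dims[0]).contiguous(), None, None

@torch.compiler.disable()
def my_transpose(x, dims, split_sizes):
    # keep the communication primitive out of any torch.compile graph
    return _MyTranspose.apply(x, dims, split_sizes)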
a/torch_harmonics/distributed/primitives.py +++ b/torch_harmonics/distributed/primitives.py @@ -98,7 +98,7 @@ def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False) return x_recv, dim0_split_sizes, req -class distributed_transpose_azimuth(torch.autograd.Function): +class _DistributeTransposeAzimuth(torch.autograd.Function): @staticmethod @custom_fwd(device_type="cuda") @@ -124,16 +124,16 @@ def backward(ctx, go): return gi, None, None -class distributed_transpose_polar(torch.autograd.Function): +class _DistributeTransposePolar(torch.autograd.Function): @staticmethod @custom_fwd(device_type="cuda") - def forward(ctx, x, dim, dim1_split_sizes): + def forward(ctx, x, dims, dim1_split_sizes): # WAR for a potential contig check torch bug for channels last contig tensors - xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group()) - x = torch.cat(xlist, dim=dim[1]) - ctx.dim = dim + xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=polar_group()) + x = torch.cat(xlist, dim=dims[1]) + ctx.dims = dims ctx.dim0_split_sizes = dim0_split_sizes return x @@ -141,11 +141,11 @@ def forward(ctx, x, dim, dim1_split_sizes): @custom_bwd(device_type="cuda") def backward(ctx, go): - dim = ctx.dim + dims = ctx.dims dim0_split_sizes = ctx.dim0_split_sizes # WAR for a potential contig check torch bug for channels last contig tensors - gilist, _, _ = _transpose(go, dim[1], dim[0], dim0_split_sizes, group=polar_group()) - gi = torch.cat(gilist, dim=dim[0]) + gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=polar_group()) + gi = torch.cat(gilist, dim=dims[0]) return gi, None, None @@ -443,6 +443,13 @@ def backward(ctx, grad_output): else: return grad_output, None, None +@torch.compiler.disable() +def distributed_transpose_azimuth(input_, dims_, shapes_): + return _DistributeTransposeAzimuth.apply(input_, dims_, shapes_) + +@torch.compiler.disable() +def distributed_transpose_polar(input_, dims_, shapes_): + return _DistributeTransposePolar.apply(input_, dims_, shapes_) def copy_to_polar_region(input_): return _CopyToPolarRegion.apply(input_) From a5eb0370d8b14f0d6746a882d8ecc70dc88b8c67 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 28 Jan 2026 00:09:29 -0800 Subject: [PATCH 11/13] slight reformat --- Dockerfile.examples | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile.examples b/Dockerfile.examples index 67a2ee1a..c63726a9 100644 --- a/Dockerfile.examples +++ b/Dockerfile.examples @@ -48,8 +48,8 @@ RUN pip install h5py RUN cd /opt && git clone https://github.com/SHI-Labs/NATTEN natten && \ cd natten && \ make WITH_CUDA=1 \ - CUDA_ARCH="8.0;8.6;8.7;9.0;10.0" \ - WORKERS=4 + CUDA_ARCH="8.0;8.6;8.7;9.0;10.0" \ + WORKERS=4 # install torch harmonics COPY . /workspace/torch_harmonics From ad930c23099052d1afbcebdb11908c57b87a53b2 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 28 Jan 2026 02:11:04 -0800 Subject: [PATCH 12/13] updating main dockerfile --- Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 840f3e03..c4bf9dbb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -30,15 +30,15 @@ # build after cloning in directoy torch_harmonics via # docker build . 
-t torch_harmonics -FROM nvcr.io/nvidia/pytorch:24.12-py3 +FROM nvcr.io/nvidia/pytorch:25.12-py3 # we need this for tests RUN pip install parameterized # The custom CUDA extension does not suppport architerctures < 7.0 ENV FORCE_CUDA_EXTENSION=1 -ENV TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX" -ENV TORCH_HARMONICS_DEBUG=1 +ENV TORCH_HARMONICS_ENABLE_OPENMP=1 +ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.7 9.0 10.0+PTX" COPY . /workspace/torch_harmonics RUN cd /workspace/torch_harmonics && pip install --no-build-isolation . From 16f05855525547000c1f101ffb9758febfc97499 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 2 Feb 2026 00:08:56 -0800 Subject: [PATCH 13/13] removing compilation of 16 bit kernels since they are not used --- torch_harmonics/disco/csrc/disco_cuda_bwd.cu | 10 +++++----- torch_harmonics/disco/csrc/disco_cuda_fwd.cu | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu index a23872cf..4aaaf16c 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu @@ -233,7 +233,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t static_assert(0 == (ELXTH_MAX % 2)); if (Wo <= 64 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_backward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { using storage_t = scalar_t; using compute_t = typename at::opmath_type; launch_kernel<64, 1, storage_t, compute_t>( @@ -243,7 +243,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t inp.data_ptr(), out.data_ptr(), stream); })); } else if (Wo <= 128 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_backward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { using storage_t = scalar_t; using compute_t = typename at::opmath_type; launch_kernel<128, (ELXTH_MAX / 2) + 1, storage_t, compute_t>( @@ -253,7 +253,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t inp.data_ptr(), out.data_ptr(), stream); })); } else if (Wo <= 256 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_backward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { using storage_t = scalar_t; using compute_t = typename at::opmath_type; launch_kernel<256, (ELXTH_MAX / 2) + 1, storage_t, compute_t>( @@ -263,7 +263,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t inp.data_ptr(), out.data_ptr(), stream); })); } else if (Wo <= 512 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_backward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { using storage_t = scalar_t; using compute_t = typename at::opmath_type; launch_kernel<512, (ELXTH_MAX / 2) + 1, storage_t, compute_t>( @@ -273,7 +273,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t inp.data_ptr(), out.data_ptr(), stream); })); } else if (Wo <= 1024 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_backward_cuda", ([&] { + 
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { using storage_t = scalar_t; using compute_t = typename at::opmath_type; launch_kernel<1024, (ELXTH_MAX / 2) + 1, storage_t, compute_t>( diff --git a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu index 0f6d43cf..4b94a791 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu @@ -222,7 +222,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t // pick the correct launch config if (Wo <= 64 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_forward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { using storage_t = scalar_t; using compute_t = typename at::opmath_type; launch_kernel<64, 1, storage_t, compute_t>( @@ -232,7 +232,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t inp.data_ptr(), out.data_ptr(), stream); })); } else if (Wo <= 128 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_forward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { using storage_t = scalar_t; using compute_t = typename at::opmath_type; launch_kernel<128, (ELXTH_MAX / 2) + 1, storage_t, compute_t>( @@ -242,7 +242,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t inp.data_ptr(), out.data_ptr(), stream); })); } else if (Wo <= 256 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_forward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { using storage_t = scalar_t; using compute_t = typename at::opmath_type; launch_kernel<256, (ELXTH_MAX / 2) + 1, storage_t, compute_t>( @@ -252,7 +252,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t inp.data_ptr(), out.data_ptr(), stream); })); } else if (Wo <= 512 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_forward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { using storage_t = scalar_t; using compute_t = typename at::opmath_type; launch_kernel<512, (ELXTH_MAX / 2) + 1, storage_t, compute_t>( @@ -262,7 +262,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t inp.data_ptr(), out.data_ptr(), stream); })); } else if (Wo <= 1024 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, inp.scalar_type(), "disco_forward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { using storage_t = scalar_t; using compute_t = typename at::opmath_type; launch_kernel<1024, (ELXTH_MAX / 2) + 1, storage_t, compute_t>(
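With the dispatch sites narrowed from AT_DISPATCH_FLOATING_TYPES_AND2 to AT_DISPATCH_FLOATING_TYPES, the extension no longer instantiates half and bfloat16 kernels, since those are no longer reachable: the Python wrappers cast to fp32 around the custom op, so fp16 and bf16 inputs keep working end to end. A minimal end-to-end sketch under autocast (the constructor arguments follow the test parameterization and should be treated as an assumption here):

import torch

from torch_harmonics import DiscreteContinuousConvS2

conv = DiscreteContinuousConvS2(
    in_channels=4, out_channels=4,
    in_shape=(16, 32), out_shape=(16, 32),
    kernel_shape=3,
).to("cuda")

x = torch.randn(8, 4, 16, 32, device="cuda", requires_grad=True)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = conv(x)  # the DISCO wrappers cast to fp32 around the kernel and back
out.backward(torch.randn_like(out))  # backward is routed through fp32 kernels too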