Description
Hi, I am using s2fft in a script that works with HEALPix maps. Everything works perfectly on the CPU. However, I now need to run it on the GPU, and there it does not work.
I believe the problem lies in the combination of CUDA + JAX + HEALPix sampling. If I use CUDA and JAX with a sampling other than 'healpix' (for example sampling='mw', method='jax_cuda'), it works. Using sampling='healpix' with method='jax_healpy' on the CPU (without CUDA) also works.
I installed s2fft from source to get CUDA extension support, as detailed here. I am using CUDA 12.3, JAX 0.5.3, jaxlib 0.5.3, healpy 1.16.6 and CMake 3.26.5. As far as I can tell, all of these versions are compatible with s2fft.
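For reports like this one, it can help to paste the exact installed versions of the Python packages involved. The sketch below is a minimal way to collect them using only the standard library (`importlib.metadata`). It assumes the PyPI distribution names `s2fft`, `jax`, `jaxlib` and `healpy`. The CUDA toolkit (12.3 here) and CMake are not Python packages, so their versions are not covered.

```python
from importlib import metadata

# Assumed PyPI distribution names; adjust if your environment uses others
# (for example, extra CUDA plugin packages that ship alongside jaxlib).
PACKAGES = ("s2fft", "jax", "jaxlib", "healpy")


def installed_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in the current environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```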
This is the error I get. It appears to come from the s2fft cuFFT callbacks:
ERROR: CUFFT call "cufftXtSetCallback(forwardPlan, getfftShiftFloat(equator), CUFFT_CB_ST_COMPLEX, (void **)&params_dev)" in line 284 of file /tmp/pip-req-build-trgn26kr/lib/src/s2fft_callbacks.cu failed with code (5).
ERROR: CUFFT call "cufftXtSetCallback(backwardPlan, getfftNormFloat(equator, false), CUFFT_CB_ST_COMPLEX, (void **)&params_dev)" in line 287 of file /tmp/pip-req-build-trgn26kr/lib/src/s2fft_callbacks.cu failed with code (5).
After that, these two error messages keep repeating. Does anyone know what is happening? Thank you in advance.
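The numeric code in these messages is a `cufftResult` value. The sketch below parses a log line of the exact form shown above and maps the code to its symbolic name. It assumes the `cufftResult` enumeration as defined in `cufft.h` for CUDA 12, where code 5 is `CUFFT_INTERNAL_ERROR`. The regular expression is tailored to this s2fft message format and is not a general cuFFT log parser.

```python
import re

# cufftResult values as defined in cufft.h (CUDA 12).
CUFFT_RESULT_NAMES = {
    0: "CUFFT_SUCCESS",
    1: "CUFFT_INVALID_PLAN",
    2: "CUFFT_ALLOC_FAILED",
    3: "CUFFT_INVALID_TYPE",
    4: "CUFFT_INVALID_VALUE",
    5: "CUFFT_INTERNAL_ERROR",
    6: "CUFFT_EXEC_FAILED",
    7: "CUFFT_SETUP_FAILED",
    8: "CUFFT_INVALID_SIZE",
    9: "CUFFT_UNALIGNED_DATA",
    10: "CUFFT_INCOMPLETE_PARAMETER_LIST",
    11: "CUFFT_INVALID_DEVICE",
    12: "CUFFT_PARSE_ERROR",
    13: "CUFFT_NO_WORKSPACE",
    14: "CUFFT_NOT_IMPLEMENTED",
    15: "CUFFT_LICENSE_ERROR",
    16: "CUFFT_NOT_SUPPORTED",
}

# Matches: ERROR: CUFFT call "<call>" in line <n> of file <path> failed with code (<code>).
ERROR_LINE = re.compile(
    r'ERROR: CUFFT call "(?P<call>.*)" in line (?P<line>\d+) '
    r"of file (?P<file>\S+) failed with code \((?P<code>\d+)\)\."
)


def decode_cufft_error(log_line):
    """Return the parsed fields of a cuFFT error line, or None if it does not match."""
    match = ERROR_LINE.search(log_line)
    if match is None:
        return None
    code = int(match.group("code"))
    return {
        "call": match.group("call"),
        "line": int(match.group("line")),
        "file": match.group("file"),
        "code": code,
        "name": CUFFT_RESULT_NAMES.get(code, "UNKNOWN"),
    }
```

Applied to the first message above, `decode_cufft_error` reports `CUFFT_INTERNAL_ERROR` for the `cufftXtSetCallback` call at line 284 of `s2fft_callbacks.cu`.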