Speed up SHT #150

Open

mcgibbon wants to merge 2 commits into NVIDIA:main from mcgibbon:feature/fast_sht

Conversation

@mcgibbon

I am working on performance improvements for the SFNO and came across this small modification, which seems to reduce the runtime of the SHT and inverse SHT by ~50% (according to cuda.Event timers). The gain comes from a drastic reduction in einsum time, since the matrix multiplication is now done over the fast dimension. I am still investigating and haven't done a production runtime benchmark, but I'm wondering if you would be interested in benchmarking this purely on the SHT @bonevbs? Since it makes no API changes, it would be good to include in torch-harmonics if this is real.

I was not able to pip install -e torch-harmonics because my environment isn't set up with the right dev headers, but I am 90% sure this code is right based on the tested changes in our fork.
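For illustration, here is a minimal sketch of the kind of reordering described above. The shapes and einsum strings are hypothetical stand-ins, not torch-harmonics' actual internals:

```python
import torch

# Hypothetical sizes: flattened batch*channels N, latitudes H, longitudes W,
# spectral modes L. Placeholder Legendre weights, real-valued for simplicity.
N, H, W, L = 64, 90, 180, 91
x = torch.randn(N, H, W, device="cuda")

coeffs = torch.fft.rfft(x, dim=-1).real            # (N, H, M), M = W // 2 + 1
legendre = torch.randn(L, H, coeffs.shape[-1], device="cuda")

# Original-style contraction: sums over the latitude dim H, which is not the
# innermost (fast) dimension of `coeffs`.
out_slow = torch.einsum("nhm,lhm->nlm", coeffs, legendre)

# Proposed-style contraction: transpose first so the summed dim is contiguous.
coeffs_t = coeffs.transpose(-1, -2).contiguous()   # (N, M, H)
legendre_t = legendre.transpose(-1, -2).contiguous()
out_fast = torch.einsum("nmh,lmh->nml", coeffs_t, legendre_t)
```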

@azrael417 (Collaborator) left a comment


Hello Jeremy, thanks for your effort. I have a comment and a general question: on what architecture and for which sizes does this outperform the old data ordering? While we have not really optimized the einsum, we have chosen a data layout which was OK for the average case. However, I am happy to accept any changes which speed this up, but it would be good to know how much of this is architecture dependent, i.e. whether the einsum is executed on the TC or not (either with bf16, fp16 or tf32).

However, I am happy to check it out. For the distributed version, we could bake this transpose into the all-to-all.
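(For reference, the architecture question above can be probed with standard PyTorch flags; a sketch, not part of this PR:)

```python
import torch

# TF32 only exists on Ampere (compute capability 8.0) and newer; toggling it
# helps separate tensor-core effects from the data-layout change itself.
torch.backends.cuda.matmul.allow_tf32 = True  # let matmul/einsum use TF32
torch.backends.cudnn.allow_tf32 = True

print(torch.cuda.get_device_name(0))
print(torch.cuda.get_device_capability(0))    # (8, 0) or higher is TF32-capable
```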


# Evaluate associated Legendre functions on the output nodes
x = torch.view_as_real(x)
x = x.transpose(-1, -2).contiguous()
Collaborator


I am wondering if this is correct. If we first do view_as_real, the last dim should be re/im. This means that this transpose does the following:

Input: B, C, H, M, 2 -> B, C, H, 2, M

And then in the contraction below, we contract the 2-index with L. I think the view_as_real should go after the transpose.
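(A quick standalone shape check illustrating the point:)

```python
import torch

x = torch.randn(2, 3, 4, 5, dtype=torch.complex64)      # (B, C, H, M)

# view_as_real first: the trailing re/im pair ends up in the contraction slot.
wrong = torch.view_as_real(x).transpose(-1, -2)
print(wrong.shape)   # torch.Size([2, 3, 4, 2, 5]) -> (B, C, H, 2, M)

# transpose first, then view_as_real: re/im stays innermost, as intended.
right = torch.view_as_real(x.transpose(-1, -2).contiguous())
print(right.shape)   # torch.Size([2, 3, 5, 4, 2]) -> (B, C, M, H, 2)
```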

Author


Yes, you're correct. This snuck through since I wasn't able to install/test the code.

I'll try to put together a standalone script, perhaps today, with timing for demonstration.

@mcgibbon (Author) commented Feb 12, 2026

> Hello Jeremy, thanks for your effort. I have a comment and a general question: on what architecture and for which sizes does this outperform the old data ordering? While we have not really optimized the einsum, we have chosen a data layout which was OK for the average case. However, I am happy to accept any changes which speed this up, but it would be good to know how much of this is architecture dependent, i.e. whether the einsum is executed on the TC or not (either with bf16, fp16 or tf32).
>
> However, I am happy to check it out. For the distributed version, we could bake this transpose into the all-to-all.

I was doing this with an effective leading dimension of 1024 (combining the batch and channel dimensions), and resolution of 1-degree (so an array of size [1024, 90, 180]). The ~50% speed-up is isolated to the SHT/iSHT calls, on a Tesla T4.

I'm not sure what you mean by TC, but I thought this code does not run with bf16 or fp16? When I tried earlier this month I ran into issues with torch not supporting these types in rfft/irfft.

I'll make a timing script you can use to run at your chosen resolutions on your hardware; I'm curious whether you also get the speed-up. The distributed case may require different optimization patterns, so I would benchmark it before applying the same changes. The data layout may matter a lot there, e.g. whether latitude bands are on single ranks and whether you choose the same layout in spectral space (which in the SFNO case would also affect the contraction).
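A minimal sketch of what such an isolated SHT benchmark could look like (using the documented torch_harmonics.RealSHT entry point and the sizes quoted above; treat it as a starting point rather than the promised script):

```python
import torch
import torch_harmonics as th

nlat, nlon = 90, 180
x = torch.randn(1024, nlat, nlon, device="cuda")
sht = th.RealSHT(nlat, nlon, grid="equiangular").to("cuda")

# Warm up first, then time with CUDA events as in the measurement above.
for _ in range(10):
    sht(x)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    sht(x)
end.record()
torch.cuda.synchronize()
print(f"mean SHT time: {start.elapsed_time(end) / 100:.3f} ms")
```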

@mcgibbon
Copy link
Author

mcgibbon commented Feb 12, 2026

It looks like my simpler example will have to wait at least another day; other tasks are piling up for today and I'm still having issues installing the package. For now I can point you to tests I'm running on our slightly different version of the SHT in the ace repo. The WIP results (not yet cleaned up, but with per-commit performance charts) are in this branch.

In theory, the changes I've made should apply regardless of weight shape so long as the GPU is occupied. The FFTs are not made any slower, as they stay in the fast dimension, and the performance of everything else scales linearly with size. The motivation also makes sense: you want the dimension you're applying a linear transformation over to be the fast dimension. But it would be nice to see you reproduce these findings on your hardware, and I'd like to confirm on my end with CPU timers and a benchmark isolated to just the SHT.
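The fast-dimension claim is easy to probe in isolation. A micro-benchmark sketch (dimension names and sizes are illustrative):

```python
import torch

def time_einsum(eq, a, b, iters=100):
    """Mean runtime of one einsum pattern, measured with CUDA events."""
    for _ in range(10):
        torch.einsum(eq, a, b)  # warm-up
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        torch.einsum(eq, a, b)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

n, h, m, l = 1024, 90, 91, 91
x = torch.randn(n, h, m, device="cuda")
w = torch.randn(l, h, m, device="cuda")
print("contract over slow dim:", time_einsum("nhm,lhm->nlm", x, w), "ms")

xt, wt = x.transpose(-1, -2).contiguous(), w.transpose(-1, -2).contiguous()
print("contract over fast dim:", time_einsum("nmh,lmh->nml", xt, wt), "ms")
```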
