Conversation
Hello Jeremy, thanks for your effort. I have a comment and a general question: on what architecture and for which sizes does this outperform the old data ordering? While we have not really optimized the einsum, we chose a data layout that was OK for the average case. That said, I am happy to accept any changes which speed this up, but it would be good to know how much of this is architecture-dependent, i.e. whether the einsum is executed on the TC or not (either with bf16, fp16 or tf32).
However, I am happy to check it out. For the distributed version, we could bake this transpose into the all-to-all.
```python
# Evaluate associated Legendre functions on the output nodes
x = torch.view_as_real(x)
x = x.transpose(-1, -2).contiguous()
```
I am wondering if this is correct. If we first do view_as_real, the last dim should be re/im. This means that this transpose does the following:
Input: B, C, H, M, 2 -> B, C, H, 2, M
And then in the contraction below, we contract the 2-index with L. I think the view_as_real should go after the transpose.
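Roughly, with illustrative shapes (the sizes and names below are just for demonstration, not the actual variables in the code):

```python
import torch

# illustrative sizes only
B, C, H, M = 1, 2, 91, 181
x = torch.randn(B, C, H, M, dtype=torch.complex64)

# as written in the diff: the re/im axis lands where the contraction expects M
bad = torch.view_as_real(x)                  # [B, C, H, M, 2]
bad = bad.transpose(-1, -2).contiguous()     # [B, C, H, 2, M]

# suggested order: transpose the complex tensor first, then split into re/im
good = x.transpose(-1, -2).contiguous()      # [B, C, M, H]
good = torch.view_as_real(good)              # [B, C, M, H, 2]
```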
Yes, you're correct. This snuck through since I wasn't able to install/test the code.
I'll work on a version, perhaps today, that I can put in a standalone script with timing for demonstration.
I was doing this with an effective leading dimension of 1024 (combining the batch and channel dimensions) and a resolution of 1 degree (so an array of size [1024, 90, 180]). The ~50% speed-up is isolated to the SHT/iSHT calls, on a Tesla T4. I'm not sure what you mean by TC, but I thought this code does not run with bf16 or fp16? When I tried earlier this month I ran into issues with torch not supporting these types. I'll make a timing script you can use to run at your chosen resolutions on your hardware; I'm curious whether you also get the speed-up. The distributed case may require different optimization patterns, so I would benchmark it before applying the same changes. The data layout may matter a lot there, e.g. whether latitude bands are on single ranks and whether you choose the same layout in spectral space (which in the SFNO case would also affect the contraction).
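Something like the following is what I have in mind (a rough sketch that assumes the RealSHT/InverseRealSHT module interface and an equiangular grid; the sizes match the ones I used, so adjust them for your hardware):

```python
import torch
import torch_harmonics as th

nlat, nlon, batch = 90, 180, 1024        # 1-degree grid, batch*channel combined
device = torch.device("cuda")

sht = th.RealSHT(nlat, nlon, grid="equiangular").to(device)
isht = th.InverseRealSHT(nlat, nlon, grid="equiangular").to(device)

x = torch.randn(batch, nlat, nlon, device=device)

def time_fn(fn, arg, iters=50, warmup=5):
    # warm up first so we don't measure one-time setup or caching effects
    for _ in range(warmup):
        out = fn(arg)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        out = fn(arg)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters, out

t_sht, coeffs = time_fn(sht, x)
t_isht, _ = time_fn(isht, coeffs)
print(f"SHT: {t_sht:.3f} ms/iter, iSHT: {t_isht:.3f} ms/iter")
```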
It looks like my simpler example will have to wait at least another day; other tasks are piling up for today and I'm still having issues installing the package. For now I can point you to tests I'm running on our slightly different version of the SHT in the ace repo. The WIP results (which I haven't cleaned up yet, but which include per-commit performance charts) are in this branch. In theory, the changes I've made should apply regardless of weight shape so long as the GPU is occupied. The FFTs are not made any slower, as they stay in the fast dimension, and the performance of everything else scales linearly with size. The motivation also makes sense: you want the dimension you're applying a linear transformation over to be the fast dimension. But it would be nice to see you reproduce these findings on your hardware, and I'd like to confirm on my end with CPU timers and a benchmark isolated to just the SHT.
I am working on performance improvements for the SFNO and came across this small modification, which seems to reduce the runtime of the SHT and inverse SHT by ~50% (according to cuda.Event timers) by drastically reducing the time taken for the einsums, since the matrix multiplication is now done over the fast dimension. I am still investigating and haven't done a production runtime benchmark, but I'm wondering if you would be interested in benchmarking this purely on the SHT @bonevbs? Since it makes no API changes, it would be good to include in torch-harmonics if this is real.
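As a toy illustration of the layout effect (this is not the actual torch-harmonics einsum, just the general principle that contracting over the contiguous last dimension is friendlier to the underlying GEMM):

```python
import torch

B, H, L = 1024, 90, 91                      # illustrative sizes only
device = "cuda" if torch.cuda.is_available() else "cpu"
w = torch.randn(L, L, device=device)        # weight-like matrix

# layout A: the contracted dimension L is a strided middle dimension
xa = torch.randn(B, L, H, device=device)
ya = torch.einsum("blh,lk->bkh", xa, w)

# layout B: the contracted dimension L is the fast (last) dimension
xb = xa.transpose(-1, -2).contiguous()      # [B, H, L]
yb = torch.einsum("bhl,lk->bhk", xb, w)

# same maths, different memory layout (loose tolerance in case TF32 is used)
torch.testing.assert_close(ya.transpose(-1, -2), yb, rtol=1e-3, atol=1e-3)
```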
I was not able to pip install -e torch-harmonics because my environment isn't set up with the right dev headers, but I am 90% sure this code is right based on the tested changes in our fork.