Conversation
implemented a faster version of KDE which is up to 2 times faster than your original version
Have you actually timed your code? I use this to time CUDA code: And then you can do: Which on my machine gives:
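The timing snippets themselves aren't shown here, but a typical CUDA timing helper along these lines (a sketch with a hypothetical `cuda_time` name, not necessarily the exact code referred to above) uses `torch.cuda.Event`, which measures device time correctly despite asynchronous kernel launches:

```python
import time
import torch

def cuda_time(fn, *args, warmup=3, iters=10):
    """Return mean milliseconds per call of fn(*args).

    Uses CUDA events on GPU (kernel launches are asynchronous, so a plain
    wall clock would under-measure without synchronization); falls back to
    time.perf_counter on CPU.
    """
    for _ in range(warmup):  # warm up caches, cuBLAS handles, autotuning
        fn(*args)
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            fn(*args)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters  # elapsed_time is in ms
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - t0) * 1e3 / iters
```

For example, `cuda_time(torch.cdist, x1, x2)` gives a per-call average in milliseconds.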
I'll hold off on merging this, as it seems unclear whether it actually reduces compute. IMO the current implementation is dumb: it should be done on the grid or with fewer points, but typically one doesn't need that many correspondences.
I have timed it; on my data my version was usually twice as fast as the original approach, and the difference in accuracy is infinitesimal. Unfortunately I cannot share my data, as it is part of the test dataset of a Kaggle competition and I can't extract it. Switching to float32 would probably improve the accuracy, but the memory cost isn't justifiable given the negligible gain.

It's worth noting that my approach, which leverages matrix multiplication, shines when the feature dimension D is reasonably large. In your case, with D=4, the matrix multiply ([20000, 4] @ [4, 20000]) doesn't provide enough parallel work to fully utilize the tensor cores. Meanwhile, the default cdist implementation uses simple broadcasted pointwise ops that are cheap when D is small and may run faster due to lower overhead in that regime. This is likely why you're seeing a performance gap in favor of the original version.

Also, I have just found that torch.cdist's API does actually support my approach. This approach is better for data with a large enough D: the usual cdist path, which computes ||xi - xj|| explicitly, launches many kernels for the pointwise ops (sub, square, sum, sqrt), and broadcasting x1 and x2 creates a lot of memory bandwidth and cache pressure. My approach computes ||xi||^2 + ||xj||^2 - 2(xi·xj) instead, which exploits fast matrix multiplication: cuBLAS GEMM or other optimized matmul kernels are usually used behind the scenes. If this version turns out not to be fast enough on your data, I'll implement it manually, since after all we can't be 100% sure this is how PyTorch implements the matrix multiplication.

Another must is avoiding the multiple kernel launches by leveraging fusion, which uses registers/shared memory instead of VRAM for faster memory access and lower VRAM usage, and combines these multiple operations into one kernel launch. For this we will use torch.compile.
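The identity above can be sketched as follows (`pairwise_dist_mm` is a hypothetical name, not the PR's actual function; `compute_mode="use_mm_for_euclid_dist"` is the real torch.cdist argument that selects the same GEMM-based path):

```python
import torch

def pairwise_dist_mm(x1, x2):
    """Euclidean distances via ||xi||^2 + ||xj||^2 - 2*(xi . xj).

    The cross term is a single [N, D] @ [D, M] matmul, which maps to one
    cuBLAS GEMM on GPU instead of many broadcasted pointwise kernels.
    """
    sq1 = (x1 * x1).sum(dim=1, keepdim=True)   # [N, 1] squared norms
    sq2 = (x2 * x2).sum(dim=1, keepdim=True).T # [1, M] squared norms
    d2 = sq1 + sq2 - 2.0 * (x1 @ x2.T)         # [N, M] squared distances
    # Floating-point rounding can push tiny squared distances below zero,
    # so clamp before the square root.
    return d2.clamp_min_(0).sqrt_()

x1 = torch.randn(1000, 128, dtype=torch.float64)
x2 = torch.randn(800, 128, dtype=torch.float64)
d_mm = pairwise_dist_mm(x1, x2)
# torch.cdist can be told to take the same GEMM path explicitly:
d_ref = torch.cdist(x1, x2, compute_mode="use_mm_for_euclid_dist")
assert torch.allclose(d_mm, d_ref, atol=1e-6)
```

Wrapping the function in `torch.compile(pairwise_dist_mm)` would then fuse the pointwise pieces (the squared-norm reductions, the clamp, and the sqrt) around the GEMM, cutting the extra kernel launches.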
Because, as I mentioned, my approach isn't suitable for all data (it happened to suit mine), I encourage you to test these two approaches on your data: first: second:
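The two snippets themselves aren't shown here, but a comparison along these lines can be sketched with torch.cdist's real `compute_mode` argument, which switches between the broadcasted pointwise path and the matmul-based path (`bench` is a hypothetical helper name):

```python
import time
import torch

def bench(fn, warmup=1, iters=10):
    """Crude mean-milliseconds timer; synchronizes around GPU work."""
    for _ in range(warmup):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1e3 / iters

device = "cuda" if torch.cuda.is_available() else "cpu"
x1 = torch.randn(500, 64, device=device)
x2 = torch.randn(500, 64, device=device)

# first: force the broadcasted pointwise path (explicit ||xi - xj||)
t_pointwise = bench(
    lambda: torch.cdist(x1, x2, compute_mode="donot_use_mm_for_euclid_dist"))
# second: force the matmul-based path (||xi||^2 + ||xj||^2 - 2*xi.xj)
t_gemm = bench(
    lambda: torch.cdist(x1, x2, compute_mode="use_mm_for_euclid_dist"))
print(f"pointwise: {t_pointwise:.3f} ms, matmul: {t_gemm:.3f} ms")
```

Which one wins depends on N and D, which is the point of running it on your own data.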
@GuniDH I don't understand how your
implemented a faster version of KDE which is up to 2 times faster than your original version