Update kde.py #96 (Open)

GuniDH wants to merge 1 commit into Parskatt:main from GuniDH:kde-optimization-by-GuniDH
Conversation

@GuniDH GuniDH commented Mar 26, 2025

Implemented a faster version of KDE which is up to 2 times faster than your original version.

@VladsTheImplier

Have you actually timed your code?
For me it's slightly slower than the original, and with a significant floating point error when run with half=True.

I use this to time cuda code:

from typing import Any, Callable

import numpy as np
import torch

def time_cuda_function(function: Callable, *args, _iter=50, **kwargs) -> Any:
    a = torch.rand((3000, 3000), device='cuda')
    s = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    res = None
    times = []

    # Warm up the GPU so the first timed iterations aren't skewed
    _ = [torch.linalg.inv(a * _ @ a) for _ in range(1, 11)]
    for _ in range(_iter):
        s.record()
        res = function(*args, **kwargs)
        e.record()
        torch.cuda.synchronize()
        times.append(s.elapsed_time(e))

    print(f"Average time for {_iter} iterations: {np.mean(times):.2f}ms +- {np.std(times):.2f}ms  "
          f"[min: {np.min(times):.2f}ms, max: {np.max(times):.2f}ms, median: {np.median(times):.2f}ms]")

    return res

And then you can do:

kde_input = torch.rand((20_000, 4), device='cuda', dtype=torch.float16)

r1 = time_cuda_function(kde, kde_input)
r2 = time_cuda_function(kde_new, kde_input)
print(torch.abs(r1 - r2).mean().item())

Which on my machine gives:

Average time for 50 iterations: 49.80ms +- 0.35ms [min: 48.87ms, max: 51.01ms, median: 49.86ms]
Average time for 50 iterations: 56.51ms +- 0.33ms [min: 55.45ms, max: 57.10ms, median: 56.55ms]
0.1441650390625

@Parskatt (Owner)

I'll hold off on merging this, as it seems unclear whether it actually reduces compute. Imo the current implementation is dumb. It should be done on the grid or with fewer points, but typically one doesn't need that many correspondences.

@GuniDH (Author) commented May 11, 2025

(quoting @VladsTheImplier) Have you actually timed your code? For me it's slightly slower than the original, and with a significant floating point error when run with half=True. [...]

I have timed it. For my data, my version was usually twice as fast as the original approach, and in terms of accuracy the difference was negligible. Unfortunately I cannot share my data, as it's part of a test dataset of a Kaggle competition and I can't extract it. Switching to float32 would probably improve accuracy, but the memory cost isn't justifiable given the negligible gain.

It's worth noting that my approach, which leverages matrix multiplication, shines when the feature dimension D is reasonably large. In your case, with D=4, the matrix multiply ([20000, 4] @ [4, 20000]) doesn't provide enough parallel work to fully utilize the tensor cores. Meanwhile, the default cdist implementation uses simple broadcasted pointwise ops that are cheap when D is small, and may run faster due to lower overhead in that regime. This is likely why you're seeing a performance gap in favor of the original version.

Also, I have just found that torch.cdist's API actually supports my approach: all you need is to add the keyword argument compute_mode='use_mm_for_euclid_dist', and the matrix-multiplication version will be run.

Using this approach is better for data with a large enough D. The normal cdist path computes ||xi - xj|| explicitly, which launches many kernels for the pointwise ops (sub, square, sum, sqrt), and broadcasting x1 and x2 creates a lot of memory bandwidth and cache pressure. My approach instead computes ||xi||^2 + ||xj||^2 - 2(xi·xj), which exploits fast matrix multiplication: cuBLAS GEMM or optimized matmul kernels are usually used behind the scenes. If this version turns out not to be fast enough on your data, I'll implement it manually, since after all we can't be 100% sure how PyTorch implements the matrix multiplication. Another must is avoiding the multiple kernel launches by leveraging fusion, which uses registers/shared memory instead of VRAM for faster memory access and lower VRAM usage, and combines these multiple operations into one kernel launch. For that we use torch.compile.
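The expansion ||xi||^2 + ||xj||^2 - 2(xi·xj) can be sanity-checked in plain NumPy (a minimal sketch of the identity itself, not of PyTorch's kernels): the matmul-based form should match the naive broadcasted differences.

```python
import numpy as np

def sq_dists_naive(x1, x2):
    # Explicit broadcasted differences: ||xi - xj||^2, many pointwise ops
    diff = x1[:, None, :] - x2[None, :, :]
    return (diff ** 2).sum(-1)

def sq_dists_mm(x1, x2):
    # Expansion ||xi||^2 + ||xj||^2 - 2 xi.xj: one GEMM plus cheap norm terms
    sq1 = (x1 ** 2).sum(-1)[:, None]
    sq2 = (x2 ** 2).sum(-1)[None, :]
    return sq1 + sq2 - 2.0 * x1 @ x2.T

rng = np.random.default_rng(0)
x = rng.standard_normal((100, 4))
assert np.allclose(sq_dists_naive(x, x), sq_dists_mm(x, x))
```

In float64 the two forms agree to machine precision; in float16 the matmul form can lose precision through cancellation, which is consistent with the error reported above.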

Because, as I mentioned, my approach isn't suitable for all data (it was suitable for mine), I encourage you to test these two approaches on your data:
the first is fusion + always choosing the matrix multiplication version;
the second is fusion + letting torch choose between the approaches to compute the norm.

first:

import torch

# Inner function – only tensors and primitives to suit torch.compile
def _kde(x, std, down):
    if down is not None:
        scores = (-torch.cdist(x, x[::down], compute_mode='use_mm_for_euclid_dist')**2 / (2 * std**2)).exp()
    else:
        scores = (-torch.cdist(x, x, compute_mode='use_mm_for_euclid_dist')**2 / (2 * std**2)).exp()
    return scores.sum(dim=-1)

compiled_kde = torch.compile(_kde)

# Public function — pre-process and call compiled version
def kde(x, std=0.1, half=True, down=None):
    # use a gaussian kernel to estimate density
    if half:
        x = x.half()  # Do it in half precision TODO: remove hardcoding
    return compiled_kde(x, std, down)

second:

import torch

# Inner function – only tensors and primitives to suit torch.compile
def _kde(x, std, down):
    if down is not None:
        scores = (-torch.cdist(x, x[::down])**2 / (2 * std**2)).exp()
    else:
        scores = (-torch.cdist(x, x)**2 / (2 * std**2)).exp()
    return scores.sum(dim=-1)

compiled_kde = torch.compile(_kde)

# Public function — pre-process and call compiled version
def kde(x, std=0.1, half=True, down=None):
    # use a gaussian kernel to estimate density
    if half:
        x = x.half()  # Do it in half precision TODO: remove hardcoding
    return compiled_kde(x, std, down)

@VladsTheImplier

@GuniDH I don't understand how your D is anything but 4. KDE is only used for the sampling, where the dimensions are (n_points, 4)

3 participants