Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions challenges/medium/73_bilateral_filter/challenge.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
<p>
Implement a bilateral filter on a 2D image of 32-bit floating point values.
The bilateral filter is an edge-preserving smoothing filter: for each pixel it computes a
weighted average of its neighbors, where the weight combines a spatial Gaussian (favoring
nearby pixels) and a range Gaussian (favoring pixels with similar intensity), so that sharp
edges are preserved while flat regions are smoothed.
The image is stored in row-major order and boundary pixels are handled by clamping to the
nearest valid index (border replication).
</p>

<h2>Implementation Requirements</h2>
<ul>
<li>Use only native features (external libraries are not permitted)</li>
<li>The <code>solve</code> function signature must remain unchanged</li>
<li>The final result must be stored in <code>output</code></li>
<li>
For each output pixel at position \((i, j)\), compute:
\[
\text{output}[i,j] = \frac{\displaystyle\sum_{dy=-r}^{r}\sum_{dx=-r}^{r} w_s(dy,dx)\, w_r\bigl(\text{image}[i',j'] - \text{image}[i,j]\bigr)\, \text{image}[i',j']}
{\displaystyle\sum_{dy=-r}^{r}\sum_{dx=-r}^{r} w_s(dy,dx)\, w_r\bigl(\text{image}[i',j'] - \text{image}[i,j]\bigr)}
\]
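where \(r\) = <code>radius</code>, \(\sigma_s\) = <code>spatial_sigma</code>,
\(\sigma_r\) = <code>range_sigma</code>,
\(i' = \text{clamp}(i+dy, 0, H-1)\), \(j' = \text{clamp}(j+dx, 0, W-1)\),
\(w_s(dy,dx) = \exp\!\left(-\tfrac{dy^2+dx^2}{2\sigma_s^2}\right)\), and
\(w_r(\delta) = \exp\!\left(-\tfrac{\delta^2}{2\sigma_r^2}\right)\)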
</li>
</ul>

<h2>Example</h2>
<p>
Input image (\(H=3, W=3\)), <code>spatial_sigma</code> = 1.0,
<code>range_sigma</code> = 0.5, <code>radius</code> = 1:
\[
\begin{bmatrix}
1.0 & 1.0 & 1.0 \\
1.0 & 0.0 & 1.0 \\
1.0 & 1.0 & 1.0
\end{bmatrix}
\]
Output:
\[
\begin{bmatrix}
0.9891 & 0.9812 & 0.9891 \\
0.9812 & 0.3453 & 0.9812 \\
0.9891 & 0.9812 & 0.9891
\end{bmatrix}
\]
The center pixel (value 0.0) is surrounded by neighbors all equal to 1.0. Because the
intensity difference of 1.0 is large relative to <code>range_sigma</code> = 0.5, the range
weights strongly suppress those neighbors, so the output at the center (0.3453) is far lower
than a plain Gaussian blur would produce (&approx; 0.7958). The outer pixels remain close to
1.0 because all but one of the samples in each of their windows have the value 1.0.
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>H</code>, <code>W</code> &le; 8192</li>
<li>1 &le; <code>radius</code> &le; 16</li>
<li>Image values are 32-bit floats (arbitrary range)</li>
<li>0.1 &le; <code>spatial_sigma</code>, <code>range_sigma</code> &le; 10.0</li>
<li>
Performance is measured with <code>H</code> = 2,048, <code>W</code> = 2,048,
<code>radius</code> = 5, <code>spatial_sigma</code> = 3.0, <code>range_sigma</code> = 0.1
</li>
</ul>
273 changes: 273 additions & 0 deletions challenges/medium/73_bilateral_filter/challenge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
import ctypes
from typing import Any, Dict, List

import torch
import torch.nn.functional as F
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    """Bilateral-filter challenge: reference implementation plus test-case generators."""

    def __init__(self):
        """Register the challenge name, comparison tolerances, GPU count and access tier."""
        super().__init__(
            name="Bilateral Filter",
            atol=1e-04,
            rtol=1e-04,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        image: torch.Tensor,
        output: torch.Tensor,
        H: int,
        W: int,
        spatial_sigma: float,
        range_sigma: float,
        radius: int,
    ):
        """Compute the reference bilateral filter and write it into ``output``.

        Args:
            image: Flat float32 CUDA tensor of ``H * W`` values, viewed as ``(H, W)``.
            output: Flat float32 CUDA tensor of ``H * W`` values; overwritten in place.
            H: Image height.
            W: Image width.
            spatial_sigma: Standard deviation of the spatial Gaussian.
            range_sigma: Standard deviation of the range (intensity) Gaussian.
            radius: Half-width of the square window; the window is ``2 * radius + 1`` wide.

        Raises:
            AssertionError: If a tensor has the wrong shape, dtype or device type.
        """
        flat_shape = (H * W,)
        assert image.shape == flat_shape
        assert output.shape == flat_shape
        assert image.dtype == torch.float32
        assert output.dtype == torch.float32
        assert image.device.type == "cuda"
        assert output.device.type == "cuda"

        device = image.device
        reach = int(radius)
        window = 2 * reach + 1
        pixels = image.view(H, W)

        # Spatial Gaussian: depends only on the window offset, so build it once.
        offsets = torch.arange(-reach, reach + 1, device=device, dtype=torch.float32)
        off_y, off_x = torch.meshgrid(offsets, offsets, indexing="ij")
        spatial_kernel = torch.exp(-(off_y**2 + off_x**2) / (2.0 * float(spatial_sigma) ** 2))

        # Replicate-pad by `reach` on every side so each clamped neighbour lookup
        # becomes a plain (H, W) slice of the padded image.
        replicated = F.pad(
            pixels[None, None], (reach, reach, reach, reach), mode="replicate"
        )[0, 0]

        weighted_sum = torch.zeros(H, W, device=device, dtype=torch.float32)
        weight_total = torch.zeros(H, W, device=device, dtype=torch.float32)
        range_scale = 1.0 / (2.0 * float(range_sigma) ** 2)

        # Accumulate numerator and denominator one window offset at a time.
        for row_shift in range(window):
            rows = slice(row_shift, row_shift + H)
            for col_shift in range(window):
                shifted = replicated[rows, col_shift : col_shift + W]
                similarity = torch.exp(-((shifted - pixels) ** 2) * range_scale)
                combined = spatial_kernel[row_shift, col_shift] * similarity
                weighted_sum += combined * shifted
                weight_total += combined

        output.copy_(weighted_sum.view(-1) / weight_total.view(-1))

    def get_solve_signature(self) -> Dict[str, tuple]:
        """Return the ctypes type and direction ("in"/"out") of each ``solve`` argument."""
        float_ptr = ctypes.POINTER(ctypes.c_float)
        return {
            "image": (float_ptr, "in"),
            "output": (float_ptr, "out"),
            "H": (ctypes.c_int, "in"),
            "W": (ctypes.c_int, "in"),
            "spatial_sigma": (ctypes.c_float, "in"),
            "range_sigma": (ctypes.c_float, "in"),
            "radius": (ctypes.c_int, "in"),
        }

    def generate_example_test(self) -> Dict[str, Any]:
        """Return the 3x3 example case: a ring of ones around a single zero."""
        H, W = 3, 3
        ring_values = [1.0] * 4 + [0.0] + [1.0] * 4
        return {
            "image": torch.tensor(ring_values, device="cuda", dtype=torch.float32),
            "output": torch.zeros(H * W, device="cuda", dtype=torch.float32),
            "H": H,
            "W": W,
            "spatial_sigma": 1.0,
            "range_sigma": 0.5,
            "radius": 1,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        """Return the functional test cases, from a single pixel up to 1000x1000."""
        dtype = torch.float32
        device = "cuda"

        def literal_image(values):
            # Image given explicitly as a flat list of pixel values.
            return torch.tensor(values, device=device, dtype=dtype)

        def uniform_image(H, W, low, high):
            # Flat image of H * W values filled in place by uniform_(low, high).
            return torch.empty(H * W, device=device, dtype=dtype).uniform_(low, high)

        def make_case(image, H, W, spatial_sigma, range_sigma, radius):
            # The image is built by the caller first; the zeroed output buffer follows.
            return {
                "image": image,
                "output": torch.zeros(H * W, device=device, dtype=dtype),
                "H": H,
                "W": W,
                "spatial_sigma": spatial_sigma,
                "range_sigma": range_sigma,
                "radius": radius,
            }

        return [
            # single_pixel
            make_case(literal_image([0.5]), 1, 1, 1.0, 0.5, 1),
            # two_by_two_zeros
            make_case(torch.zeros(2 * 2, device=device, dtype=dtype), 2, 2, 1.0, 0.5, 1),
            # three_by_three_ring (matches example)
            make_case(literal_image([1.0] * 4 + [0.0] + [1.0] * 4), 3, 3, 1.0, 0.5, 1),
            # four_by_four_negatives
            make_case(literal_image([-1.0, -1.0, 1.0, 1.0] * 4), 4, 4, 1.5, 0.8, 1),
            # power_of_two_16x16
            make_case(uniform_image(16, 16, 0.0, 1.0), 16, 16, 1.5, 0.3, 2),
            # power_of_two_64x64_mixed
            make_case(uniform_image(64, 64, -1.0, 1.0), 64, 64, 2.0, 0.5, 2),
            # non_power_of_two_100x100
            make_case(uniform_image(100, 100, 0.0, 1.0), 100, 100, 2.0, 0.3, 3),
            # non_power_of_two_255x255_mixed
            make_case(uniform_image(255, 255, -1.0, 1.0), 255, 255, 1.5, 0.4, 2),
            # realistic_512x512
            make_case(uniform_image(512, 512, 0.0, 1.0), 512, 512, 2.0, 0.2, 3),
            # realistic_1000x1000
            make_case(uniform_image(1000, 1000, 0.0, 1.0), 1000, 1000, 3.0, 0.1, 5),
        ]

    def generate_performance_test(self) -> Dict[str, Any]:
        """Return the 2048x2048 benchmark case (radius 5, sigmas 3.0 and 0.1)."""
        H, W = 2048, 2048
        pixel_count = H * W
        noise = torch.empty(pixel_count, device="cuda", dtype=torch.float32).uniform_(0.0, 1.0)
        return {
            "image": noise,
            "output": torch.zeros(pixel_count, device="cuda", dtype=torch.float32),
            "H": H,
            "W": W,
            "spatial_sigma": 3.0,
            "range_sigma": 0.1,
            "radius": 5,
        }
5 changes: 5 additions & 0 deletions challenges/medium/73_bilateral_filter/starter/starter.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// Starter stub for the bilateral filter challenge; the body is intentionally empty and is
// where the solution goes. The challenge statement requires this signature to stay unchanged.
// image, output are device pointers
// Per the challenge statement, image holds an H x W row-major float image and the filtered
// result must be stored in output. radius is the window half-width r, and spatial_sigma /
// range_sigma are the standard deviations of the spatial and range Gaussians.
extern "C" void solve(const float* image, float* output, int H, int W, float spatial_sigma,
float range_sigma, int radius) {}
16 changes: 16 additions & 0 deletions challenges/medium/73_bilateral_filter/starter/starter.cute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import cutlass
import cutlass.cute as cute


# Starter stub for the bilateral filter challenge; replace `pass` with the solution.
# The challenge statement requires this signature to stay unchanged.
# image, output are tensors on the GPU
# Per the challenge statement, image holds an H x W row-major float32 image and the filtered
# result must be stored in output. radius is the window half-width r, and spatial_sigma /
# range_sigma are the standard deviations of the spatial and range Gaussians.
@cute.jit
def solve(
image: cute.Tensor,
output: cute.Tensor,
H: cute.Int32,
W: cute.Int32,
spatial_sigma: cute.Float32,
range_sigma: cute.Float32,
radius: cute.Int32,
):
pass
11 changes: 11 additions & 0 deletions challenges/medium/73_bilateral_filter/starter/starter.jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import jax
import jax.numpy as jnp


# Starter stub for the bilateral filter challenge; replace `pass` with the solution.
# The challenge statement requires this signature to stay unchanged.
# image is a tensor on GPU
# Unlike the other starters there is no `output` parameter: the filtered image is the
# return value. Per the challenge statement, image holds an H x W row-major float32 image,
# radius is the window half-width r, and spatial_sigma / range_sigma are the standard
# deviations of the spatial and range Gaussians.
@jax.jit
def solve(
image: jax.Array, H: int, W: int, spatial_sigma: float, range_sigma: float, radius: int
) -> jax.Array:
# return output tensor directly
pass
18 changes: 18 additions & 0 deletions challenges/medium/73_bilateral_filter/starter/starter.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from gpu.host import DeviceContext
from gpu.id import block_dim, block_idx, thread_idx
from memory import UnsafePointer
from math import ceildiv


# Starter stub for the bilateral filter challenge; replace `pass` with the solution.
# The challenge statement requires this signature to stay unchanged.
# image, output are device pointers
# Per the challenge statement, image holds an H x W row-major float32 image and the filtered
# result must be stored in output. radius is the window half-width r, and spatial_sigma /
# range_sigma are the standard deviations of the spatial and range Gaussians.
@export
def solve(
image: UnsafePointer[Float32],
output: UnsafePointer[Float32],
H: Int32,
W: Int32,
spatial_sigma: Float32,
range_sigma: Float32,
radius: Int32,
):
pass
Loading