144 changes: 144 additions & 0 deletions challenges/medium/90_causal_depthwise_conv1d/challenge.html
@@ -0,0 +1,144 @@
<p>
Implement a <strong>causal depthwise 1D convolution</strong> over a batched sequence tensor
<code>x</code> of shape <code>(B, L, D)</code>, producing an output of the same shape.
In a depthwise convolution, each channel <code>d</code> is convolved independently using its
own kernel <code>weight[d, :]</code> — there is no mixing across channels.
The convolution is <strong>causal</strong>: output position <code>l</code> may only depend on
input positions <code>0, 1, &hellip;, l</code> (past and present), never future positions.
This operation is a key component of state-space models such as Mamba, where it is applied
before the selective scan to mix local context within each feature channel.
</p>

<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 480 260" width="480" height="260" style="display:block; margin:20px auto;">
<defs>
<marker id="ah" viewBox="0 0 10 10" refX="9" refY="5" markerWidth="6" markerHeight="6" orient="auto-start-reverse">
<path d="M0 0L10 5L0 10z" fill="#999"/>
</marker>
</defs>

<!-- Background -->
<rect width="480" height="260" fill="#222" rx="8"/>

<!-- Title -->
<text x="240" y="22" text-anchor="middle" fill="#ccc" font-size="13" font-family="sans-serif" font-weight="bold">Causal Depthwise Conv1d (K=3, one channel shown)</text>

<!-- Input row label -->
<text x="14" y="68" fill="#aaa" font-size="11" font-family="monospace">x[d]</text>

<!-- Input cells: positions 0..5 -->
<rect x="52" y="52" width="40" height="28" fill="#2a3a55" stroke="#4477bb" stroke-width="1.2" rx="3"/>
<text x="72" y="71" text-anchor="middle" fill="#aaccee" font-size="12" font-family="monospace">x₀</text>

<rect x="96" y="52" width="40" height="28" fill="#2a3a55" stroke="#4477bb" stroke-width="1.2" rx="3"/>
<text x="116" y="71" text-anchor="middle" fill="#aaccee" font-size="12" font-family="monospace">x₁</text>

<rect x="140" y="52" width="40" height="28" fill="#2a3a55" stroke="#4477bb" stroke-width="1.2" rx="3"/>
<text x="160" y="71" text-anchor="middle" fill="#aaccee" font-size="12" font-family="monospace">x₂</text>

<rect x="184" y="52" width="40" height="28" fill="#2a3a55" stroke="#4477bb" stroke-width="1.2" rx="3"/>
<text x="204" y="71" text-anchor="middle" fill="#aaccee" font-size="12" font-family="monospace">x₃</text>

<rect x="228" y="52" width="40" height="28" fill="#2a3a55" stroke="#4477bb" stroke-width="1.2" rx="3"/>
<text x="248" y="71" text-anchor="middle" fill="#aaccee" font-size="12" font-family="monospace">x₄</text>

<rect x="272" y="52" width="40" height="28" fill="#2a3a55" stroke="#4477bb" stroke-width="1.2" rx="3"/>
<text x="292" y="71" text-anchor="middle" fill="#aaccee" font-size="12" font-family="monospace">x₅</text>

<!-- Kernel box -->
<text x="14" y="138" fill="#aaa" font-size="11" font-family="monospace">w[d]</text>
<rect x="140" y="118" width="40" height="28" fill="#1e3d2d" stroke="#44aa66" stroke-width="1.5" rx="3"/>
<text x="160" y="137" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="monospace">w₀</text>
<rect x="184" y="118" width="40" height="28" fill="#1e3d2d" stroke="#44aa66" stroke-width="1.5" rx="3"/>
<text x="204" y="137" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="monospace">w₁</text>
<rect x="228" y="118" width="40" height="28" fill="#1e3d2d" stroke="#44aa66" stroke-width="1.5" rx="3"/>
<text x="248" y="137" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="monospace">w₂</text>

<!-- Annotation: kernel aligned at l=4 -->
<text x="190" y="155" text-anchor="middle" fill="#888" font-size="10" font-family="sans-serif">kernel at l=4: reads x₂,x₃,x₄</text>

<!-- Arrow from kernel region to output -->
<line x1="204" y1="146" x2="248" y2="186" stroke="#999" stroke-width="1.2" marker-end="url(#ah)"/>

<!-- Output row label -->
<text x="14" y="208" fill="#aaa" font-size="11" font-family="monospace">y[d]</text>

<!-- Output cells -->
<rect x="52" y="192" width="40" height="28" fill="#3a2a2a" stroke="#884444" stroke-width="1.2" rx="3"/>
<text x="72" y="211" text-anchor="middle" fill="#eeccaa" font-size="11" font-family="monospace">y₀</text>

<rect x="96" y="192" width="40" height="28" fill="#3a2a2a" stroke="#884444" stroke-width="1.2" rx="3"/>
<text x="116" y="211" text-anchor="middle" fill="#eeccaa" font-size="11" font-family="monospace">y₁</text>

<rect x="140" y="192" width="40" height="28" fill="#3a2a2a" stroke="#884444" stroke-width="1.2" rx="3"/>
<text x="160" y="211" text-anchor="middle" fill="#eeccaa" font-size="11" font-family="monospace">y₂</text>

<rect x="184" y="192" width="40" height="28" fill="#3a2a2a" stroke="#884444" stroke-width="1.2" rx="3"/>
<text x="204" y="211" text-anchor="middle" fill="#eeccaa" font-size="11" font-family="monospace">y₃</text>

<rect x="228" y="192" width="40" height="28" fill="#3a2a2a" stroke="#cc6644" stroke-width="2" rx="3"/>
<text x="248" y="211" text-anchor="middle" fill="#ffddaa" font-size="11" font-family="monospace" font-weight="bold">y₄</text>

<rect x="272" y="192" width="40" height="28" fill="#3a2a2a" stroke="#884444" stroke-width="1.2" rx="3"/>
<text x="292" y="211" text-anchor="middle" fill="#eeccaa" font-size="11" font-family="monospace">y₅</text>

<!-- Equation at bottom -->
<text x="240" y="246" text-anchor="middle" fill="#888" font-size="11" font-family="sans-serif">
y[d,l] = bias[d] + Σ w[d,k] · x[d, l−k] (x[d,l−k] = 0 if l−k &lt; 0)
</text>
</svg>

<p>
Formally, for each batch element <code>b</code>, sequence position <code>l</code>, and channel <code>d</code>:
</p>

\[
\text{output}[b,\, l,\, d]
= \text{bias}[d]
+ \sum_{k=0}^{K-1} \text{weight}[d,\, k] \cdot x[b,\, l - k,\, d]
\]

<p>
where positions <code>l &minus; k &lt; 0</code> are treated as zero (zero-pad the left boundary).
The tensor layout is <strong>channels-last</strong>: <code>x[b, l, d]</code> is stored at offset
<code>b &times; L &times; D + l &times; D + d</code>.
</p>
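As a cross-check of the definition and the flat channels-last layout above, here is a minimal pure-Python sketch. It walks flat buffers with the stated offset `b*L*D + l*D + d` and a flattened `weight` of shape `(D, K)`; it is illustrative only, not a GPU submission:

```python
def causal_depthwise_conv1d(x, weight, bias, B, L, D, K):
    """Causal depthwise conv1d over flat channels-last buffers.

    x:      flat list of length B*L*D, element (b, l, d) at b*L*D + l*D + d
    weight: flat list of length D*K,   element (d, k) at d*K + k
    bias:   list of length D
    """
    out = [0.0] * (B * L * D)
    for b in range(B):
        for l in range(L):
            for d in range(D):
                acc = bias[d]
                for k in range(K):
                    # positions before the start of the sequence are zero
                    if l - k >= 0:
                        acc += weight[d * K + k] * x[b * L * D + (l - k) * D + d]
                out[b * L * D + l * D + d] = acc
    return out
```

Running this on the worked example below (B=1, L=4, D=2, K=3) reproduces the expected output row by row.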

<h2>Implementation Requirements</h2>
<ul>
<li>The <code>solve</code> function signature must remain unchanged</li>
<li>The result must be written into the <code>output</code> tensor</li>
<li>Use only native features (external libraries are not permitted)</li>
<li>Input positions before the start of the sequence (i.e. indices <code>l &minus; k &lt; 0</code>) must be treated as zero</li>
</ul>

<h2>Example</h2>

<p>With <code>B</code> = 1, <code>L</code> = 4, <code>D</code> = 2, <code>K</code> = 3:</p>

<pre>
x = [[[1.0, 2.0],   # l=0
      [3.0, 4.0],   # l=1
      [5.0, 6.0],   # l=2
      [7.0, 8.0]]]  # l=3            shape (1, 4, 2)

weight = [[ 1.0, 0.0, -1.0],  # channel d=0
          [ 1.0, 1.0,  1.0]]  # channel d=1    shape (2, 3)

bias = [0.0, 0.0]

output = [[[1.0,  2.0],   # l=0: d0: 1*1=1               d1: 1*2=2
           [3.0,  6.0],   # l=1: d0: 3*1+1*0=3           d1: 4*1+2*1=6
           [4.0, 12.0],   # l=2: d0: 5*1+3*0+1*(-1)=4    d1: 6+4+2=12
           [4.0, 18.0]]]  # l=3: d0: 7*1+5*0+3*(-1)=4    d1: 8+6+4=18
</pre>
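The reference implementation in challenge.py realizes this causal convolution with a cross-correlation primitive: left-pad the sequence with K-1 zeros and flip the kernel. A single-channel pure-Python sketch of why that equivalence holds (toy values, not part of the test harness):

```python
def causal_direct(x, w):
    # y[l] = sum_k w[k] * x[l-k], with out-of-range positions treated as zero
    return [sum(w[k] * x[l - k] for k in range(len(w)) if l - k >= 0)
            for l in range(len(x))]

def causal_via_xcorr(x, w):
    # Left-pad by K-1 zeros and flip the kernel, then cross-correlate --
    # the same recipe the reference uses with F.conv1d (which does not
    # flip the kernel itself).
    K = len(w)
    xp = [0.0] * (K - 1) + list(x)  # padded input, length len(x) + K - 1
    wf = w[::-1]                    # flipped kernel: wf[j] = w[K-1-j]
    return [sum(wf[j] * xp[l + j] for j in range(K)) for l in range(len(x))]
```

Both functions produce identical results; the flip compensates for cross-correlation reading the window in increasing index order.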

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>B</code> &le; 16 (batch size)</li>
<li>1 &le; <code>L</code> &le; 8,192 (sequence length)</li>
<li>1 &le; <code>D</code> &le; 8,192 (number of channels)</li>
<li>1 &le; <code>K</code> &le; 8 (kernel size; typically 3 or 4 in practice)</li>
<li>All tensors use 32-bit floating point</li>
<li>Tensor <code>x</code> and <code>output</code> use channels-last layout: shape <code>(B, L, D)</code></li>
<li>Performance is measured with <code>B</code> = 8, <code>L</code> = 2,048, <code>D</code> = 4,096, <code>K</code> = 4</li>
</ul>
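Since each output element needs only K multiply-adds, the performance configuration is likely memory-bound rather than compute-bound. A rough back-of-envelope, assuming each element of <code>x</code> is read once and each element of <code>output</code> written once (weight and bias are negligible by comparison):

```python
# Memory traffic estimate for the performance config (float32 throughout)
B, L, D, K = 8, 2048, 4096, 4
bytes_per_elem = 4

x_bytes = B * L * D * bytes_per_elem       # 268,435,456 bytes = 256 MiB
out_bytes = x_bytes                        # output has the same shape as x
total_gib = (x_bytes + out_bytes) / 2**30  # ~0.5 GiB moved per call
```

Dividing roughly 0.5 GiB by a GPU's sustained memory bandwidth gives a lower bound on runtime, which suggests optimizing for coalesced loads and input reuse across the K-wide window.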
187 changes: 187 additions & 0 deletions challenges/medium/90_causal_depthwise_conv1d/challenge.py
@@ -0,0 +1,187 @@
import ctypes
from typing import Any, Dict, List

import torch
import torch.nn.functional as F
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    def __init__(self):
        super().__init__(
            name="Causal Depthwise Conv1d",
            atol=1e-04,
            rtol=1e-04,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        output: torch.Tensor,
        B: int,
        L: int,
        D: int,
        K: int,
    ):
        assert x.shape == (B, L, D)
        assert weight.shape == (D, K)
        assert bias.shape == (D,)
        assert output.shape == (B, L, D)
        assert x.dtype == weight.dtype == bias.dtype == output.dtype == torch.float32
        assert x.device.type == "cuda"
        assert weight.device.type == "cuda"
        assert bias.device.type == "cuda"
        assert output.device.type == "cuda"

        # Permute to (B, D, L), the layout F.conv1d expects
        x_t = x.permute(0, 2, 1).contiguous()  # (B, D, L)

        # Causal padding: pad K-1 zeros on the left so each output position
        # only sees current and past input positions
        x_padded = F.pad(x_t, (K - 1, 0))  # (B, D, L + K - 1)

        # Depthwise conv: weight (D, K) -> (D, 1, K), groups=D.
        # Flip the kernel so weight[d, 0] applies to the current position (l-0)
        # and weight[d, K-1] applies to the oldest position (l-(K-1)).
        # F.conv1d computes cross-correlation (no implicit flip), so we flip explicitly.
        w = weight.flip(1).unsqueeze(1)  # (D, 1, K)
        result = F.conv1d(x_padded, w, bias=bias, groups=D)  # (B, D, L)

        output.copy_(result.permute(0, 2, 1))  # (B, L, D)

    def get_solve_signature(self) -> Dict[str, tuple]:
        return {
            "x": (ctypes.POINTER(ctypes.c_float), "in"),
            "weight": (ctypes.POINTER(ctypes.c_float), "in"),
            "bias": (ctypes.POINTER(ctypes.c_float), "in"),
            "output": (ctypes.POINTER(ctypes.c_float), "out"),
            "B": (ctypes.c_int, "in"),
            "L": (ctypes.c_int, "in"),
            "D": (ctypes.c_int, "in"),
            "K": (ctypes.c_int, "in"),
        }

    def generate_example_test(self) -> Dict[str, Any]:
        B, L, D, K = 1, 4, 2, 3
        x = torch.tensor(
            [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]],
            device="cuda",
            dtype=torch.float32,
        )
        weight = torch.tensor(
            [[1.0, 0.0, -1.0], [1.0, 1.0, 1.0]], device="cuda", dtype=torch.float32
        )
        bias = torch.zeros(D, device="cuda", dtype=torch.float32)
        output = torch.empty(B, L, D, device="cuda", dtype=torch.float32)
        return {
            "x": x,
            "weight": weight,
            "bias": bias,
            "output": output,
            "B": B,
            "L": L,
            "D": D,
            "K": K,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        dtype = torch.float32
        test_cases = []

        def make_case(B, L, D, K, x_vals=None, w_vals=None, b_vals=None):
            if x_vals is not None:
                x = torch.tensor(x_vals, device="cuda", dtype=dtype)
            else:
                x = torch.randn(B, L, D, device="cuda", dtype=dtype)
            if w_vals is not None:
                weight = torch.tensor(w_vals, device="cuda", dtype=dtype)
            else:
                weight = torch.randn(D, K, device="cuda", dtype=dtype)
            if b_vals is not None:
                bias = torch.tensor(b_vals, device="cuda", dtype=dtype)
            else:
                bias = torch.randn(D, device="cuda", dtype=dtype)
            output = torch.empty(B, L, D, device="cuda", dtype=dtype)
            return {
                "x": x,
                "weight": weight,
                "bias": bias,
                "output": output,
                "B": B,
                "L": L,
                "D": D,
                "K": K,
            }

        # Example test (matches generate_example_test)
        test_cases.append(
            make_case(
                1,
                4,
                2,
                3,
                x_vals=[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]],
                w_vals=[[1.0, 0.0, -1.0], [1.0, 1.0, 1.0]],
                b_vals=[0.0, 0.0],
            )
        )

        # Edge cases: minimal sizes
        test_cases.append(make_case(1, 1, 1, 1))  # single element, kernel=1
        test_cases.append(make_case(1, 2, 1, 2))  # K == L: early outputs rely on left zero-padding
        test_cases.append(make_case(2, 3, 4, 3))  # small batch, B=2

        # Zero inputs (output should equal the bias everywhere)
        x_zero = torch.zeros(1, 8, 4, device="cuda", dtype=dtype)
        w_zero = torch.randn(4, 3, device="cuda", dtype=dtype)
        b_zero = torch.randn(4, device="cuda", dtype=dtype)
        test_cases.append(
            {
                "x": x_zero,
                "weight": w_zero,
                "bias": b_zero,
                "output": torch.empty(1, 8, 4, device="cuda", dtype=dtype),
                "B": 1,
                "L": 8,
                "D": 4,
                "K": 3,
            }
        )

        # Negative values (randn draws both signs)
        test_cases.append(make_case(1, 16, 8, 4))

        # Power-of-2 sizes
        test_cases.append(make_case(2, 32, 16, 4))
        test_cases.append(make_case(4, 64, 32, 4))

        # Non-power-of-2 sizes
        test_cases.append(make_case(3, 30, 12, 3))
        test_cases.append(make_case(2, 100, 24, 4))

        # Realistic inference size (Mamba-like, small)
        test_cases.append(make_case(2, 256, 128, 4))

        return test_cases

    def generate_performance_test(self) -> Dict[str, Any]:
        B, L, D, K = 8, 2048, 4096, 4
        dtype = torch.float32
        x = torch.randn(B, L, D, device="cuda", dtype=dtype)
        weight = torch.randn(D, K, device="cuda", dtype=dtype)
        bias = torch.randn(D, device="cuda", dtype=dtype)
        output = torch.empty(B, L, D, device="cuda", dtype=dtype)
        return {
            "x": x,
            "weight": weight,
            "bias": bias,
            "output": output,
            "B": B,
            "L": L,
            "D": D,
            "K": K,
        }
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// x, weight, bias, output are device pointers
extern "C" void solve(const float* x, const float* weight, const float* bias, float* output, int B,
                      int L, int D, int K) {}
@@ -0,0 +1,17 @@
import cutlass
import cutlass.cute as cute


# x, weight, bias, output are tensors on the GPU
@cute.jit
def solve(
    x: cute.Tensor,
    weight: cute.Tensor,
    bias: cute.Tensor,
    output: cute.Tensor,
    B: cute.Int32,
    L: cute.Int32,
    D: cute.Int32,
    K: cute.Int32,
):
    pass
@@ -0,0 +1,11 @@
import jax
import jax.numpy as jnp


# x, weight, bias are tensors on GPU
@jax.jit
def solve(
    x: jax.Array, weight: jax.Array, bias: jax.Array, B: int, L: int, D: int, K: int
) -> jax.Array:
    # return output tensor directly
    pass
@@ -0,0 +1,16 @@
from gpu.host import DeviceContext
from memory import UnsafePointer

# x, weight, bias, output are device pointers
@export
def solve(
    x: UnsafePointer[Float32],
    weight: UnsafePointer[Float32],
    bias: UnsafePointer[Float32],
    output: UnsafePointer[Float32],
    B: Int32,
    L: Int32,
    D: Int32,
    K: Int32,
):
    pass