118 changes: 118 additions & 0 deletions challenges/medium/89_flash_attention/challenge.html
@@ -0,0 +1,118 @@
<p>
Implement the Flash Attention forward pass: given query, key, and value tensors, compute scaled
dot-product attention using the online softmax algorithm so that the full
<code>seq_len &times; seq_len</code> attention matrix is never materialized in global memory. There
is no causal mask: every query position attends to every key position. All tensors use <code>float32</code>.
</p>

<svg width="700" height="300" viewBox="0 0 700 300" xmlns="http://www.w3.org/2000/svg" style="display:block; margin:20px auto;">
<rect width="700" height="300" fill="#222" rx="10"/>
<text x="350" y="26" fill="#ccc" font-family="monospace" font-size="13" text-anchor="middle">Flash Attention: tiled online-softmax avoids materializing the S×S matrix</text>

<!-- HBM label -->
<text x="30" y="58" fill="#aaa" font-family="monospace" font-size="11">HBM</text>
<!-- Q block -->
<rect x="20" y="65" width="80" height="30" fill="#2563eb" rx="3"/>
<text x="60" y="85" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">Q [S×D]</text>
<!-- K block -->
<rect x="120" y="65" width="80" height="30" fill="#7c3aed" rx="3"/>
<text x="160" y="85" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">K [S×D]</text>
<!-- V block -->
<rect x="220" y="65" width="80" height="30" fill="#065f46" rx="3"/>
<text x="260" y="85" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">V [S×D]</text>
<!-- O block -->
<rect x="340" y="65" width="80" height="30" fill="#92400e" rx="3"/>
<text x="380" y="85" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">O [S×D]</text>

<!-- SRAM label -->
<text x="30" y="155" fill="#aaa" font-family="monospace" font-size="11">SRAM (tile)</text>
<!-- Q tile -->
<rect x="20" y="163" width="55" height="28" fill="#3b82f6" rx="3"/>
<text x="47" y="181" fill="#fff" font-family="monospace" font-size="11" text-anchor="middle">Q_tile</text>
<!-- K tile -->
<rect x="90" y="163" width="55" height="28" fill="#8b5cf6" rx="3"/>
<text x="117" y="181" fill="#fff" font-family="monospace" font-size="11" text-anchor="middle">K_tile</text>
<!-- V tile -->
<rect x="160" y="163" width="55" height="28" fill="#10b981" rx="3"/>
<text x="187" y="181" fill="#fff" font-family="monospace" font-size="11" text-anchor="middle">V_tile</text>

<!-- Arrows HBM → SRAM -->
<line x1="60" y1="95" x2="47" y2="163" stroke="#60a5fa" stroke-width="1.5" stroke-dasharray="4,2"/>
<line x1="160" y1="95" x2="117" y2="163" stroke="#a78bfa" stroke-width="1.5" stroke-dasharray="4,2"/>
<line x1="260" y1="95" x2="187" y2="163" stroke="#6ee7b7" stroke-width="1.5" stroke-dasharray="4,2"/>

<!-- Compute box -->
<rect x="250" y="155" width="200" height="46" fill="#1e293b" rx="4" stroke="#475569" stroke-width="1"/>
<text x="350" y="175" fill="#4ade80" font-family="monospace" font-size="11" text-anchor="middle">S_ij = Q_tile @ K_tile^T</text>
<text x="350" y="192" fill="#4ade80" font-family="monospace" font-size="11" text-anchor="middle">update m, ℓ, O (online softmax)</text>

<!-- Arrow back to O -->
<line x1="380" y1="201" x2="380" y2="95" stroke="#fb923c" stroke-width="1.5" stroke-dasharray="4,2"/>

<!-- Algorithm steps -->
<text x="30" y="245" fill="#ccc" font-family="monospace" font-size="11">For each tile j of K, V:</text>
<text x="30" y="262" fill="#94a3b8" font-family="monospace" font-size="11"> m_new = max(m_old, rowmax(S_ij)) ← running max</text>
<text x="30" y="279" fill="#94a3b8" font-family="monospace" font-size="11"> ℓ_new = exp(m_old-m_new)·ℓ_old + rowsum(exp(S_ij-m_new)) ← running sum</text>
<text x="30" y="296" fill="#94a3b8" font-family="monospace" font-size="11"> O_new = diag(exp(m_old-m_new))·O_old + exp(S_ij-m_new)·V_tile</text>
</svg>
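The tiled recurrence in the diagram can be sketched in plain NumPy for a single head (a reference sketch only, not a GPU kernel; the tile size and function name are illustrative):

```python
import numpy as np

def flash_attention_forward(Q, K, V, tile=2):
    """Tiled online-softmax attention for one head (NumPy sketch)."""
    S, D = Q.shape
    scale = 1.0 / np.sqrt(D)
    O = np.zeros((S, D))
    m = np.full(S, -np.inf)   # running row max
    l = np.zeros(S)           # running row sum of exp
    for j in range(0, S, tile):
        Kj, Vj = K[j:j + tile], V[j:j + tile]
        S_ij = (Q @ Kj.T) * scale                 # (S, tile) score tile
        m_new = np.maximum(m, S_ij.max(axis=1))   # running max
        P = np.exp(S_ij - m_new[:, None])
        alpha = np.exp(m - m_new)                 # rescale factor for old stats
        l = alpha * l + P.sum(axis=1)             # running sum
        O = alpha[:, None] * O + P @ Vj           # unnormalized output
        m = m_new
    return O / l[:, None]                         # final normalization by ℓ
```

Note the final division by <code>ℓ</code>, which the per-tile updates defer until all key/value tiles have been consumed; only <code>O</code>, <code>m</code>, and <code>ℓ</code> (each O(S·D) or O(S)) ever live in memory, never the S&times;S score matrix.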

<h2>Implementation Requirements</h2>
<ul>
<li>Implement the function <code>solve(Q, K, V, output, num_heads, seq_len, head_dim)</code>.</li>
<li>Do not change the function signature or use external libraries beyond the standard GPU frameworks.</li>
<li>Write the result into the provided <code>output</code> buffer.</li>
<li>Use scale factor <code>1 / sqrt(head_dim)</code> and a softmax over the key (last) dimension.</li>
<li>No causal mask — every query position attends to all key positions.</li>
</ul>

<h2>Example</h2>
<p>
With <code>num_heads</code> = 1, <code>seq_len</code> = 3, <code>head_dim</code> = 4:
</p>
<p>
<strong>Input:</strong><br>
\(Q\) (3&times;4):
\[
\begin{bmatrix}
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0 \\
0 & 0 & 1 & 0
\end{bmatrix}
\]
\(K\) (3&times;4):
\[
\begin{bmatrix}
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0 \\
0 & 0 & 1 & 0
\end{bmatrix}
\]
\(V\) (3&times;4):
\[
\begin{bmatrix}
1 & 2 & 3 & 4 \\
5 & 6 & 7 & 8 \\
9 & 10 & 11 & 12
\end{bmatrix}
\]
</p>
<p>
<strong>Output</strong> (values rounded to 2 decimal places):<br>
\(\text{output}\) (3&times;4):
\[
\begin{bmatrix}
4.29 & 5.29 & 6.29 & 7.29 \\
5.00 & 6.00 & 7.00 & 8.00 \\
5.71 & 6.71 & 7.71 & 8.71
\end{bmatrix}
\]
</p>
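The example output can be reproduced with a few lines of NumPy (a quick sanity check, not part of the challenge files):

```python
import numpy as np

Q = np.eye(3, 4)                        # the example Q (K is identical)
V = np.arange(1.0, 13.0).reshape(3, 4)  # rows 1..4, 5..8, 9..12
scale = 1.0 / np.sqrt(4)

scores = Q @ Q.T * scale                # K == Q in this example
w = np.exp(scores - scores.max(axis=1, keepdims=True))
w /= w.sum(axis=1, keepdims=True)       # softmax over the key dimension
out = w @ V
print(out.round(2))                     # first row: [4.29 5.29 6.29 7.29]
```

Each row of <code>scores</code> has a single 0.5 on the diagonal and zeros elsewhere, so the softmax weights are roughly (0.45, 0.27, 0.27) permuted per row, which yields the slightly diagonal-biased averages of V shown above.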

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>num_heads</code> &le; 64</li>
<li>1 &le; <code>seq_len</code> &le; 8,192</li>
<li>8 &le; <code>head_dim</code> &le; 128; <code>head_dim</code> is a multiple of 8</li>
<li>All tensor values are <code>float32</code></li>
<li>Performance is measured with <code>num_heads</code> = 16, <code>seq_len</code> = 4,096, <code>head_dim</code> = 64</li>
</ul>
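The constraints make the memory argument concrete: at the performance-test size, materializing the full attention matrix would dwarf the inputs themselves (quick arithmetic only; nothing here is part of the harness):

```python
num_heads, seq_len, head_dim = 16, 4096, 64
bytes_f32 = 4

# Full S×S score matrix for all heads vs. the Q, K, V inputs combined
scores = num_heads * seq_len * seq_len * bytes_f32
qkv = 3 * num_heads * seq_len * head_dim * bytes_f32

print(scores // 2**20, "MiB for scores")  # 1024 MiB
print(qkv // 2**20, "MiB for Q, K, V")    # 48 MiB
```

A tiled kernel keeps only O(tile &times; tile) scores in on-chip SRAM, so global-memory traffic stays proportional to the inputs and output rather than to <code>seq_len</code> squared.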
133 changes: 133 additions & 0 deletions challenges/medium/89_flash_attention/challenge.py
@@ -0,0 +1,133 @@
import ctypes
import math
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
def __init__(self):
super().__init__(
name="Flash Attention Forward",
atol=1e-03,
rtol=1e-03,
num_gpus=1,
access_tier="free",
)

def reference_impl(
self,
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
output: torch.Tensor,
num_heads: int,
seq_len: int,
head_dim: int,
):
assert Q.shape == (num_heads, seq_len, head_dim)
assert K.shape == (num_heads, seq_len, head_dim)
assert V.shape == (num_heads, seq_len, head_dim)
assert output.shape == (num_heads, seq_len, head_dim)
assert Q.dtype == K.dtype == V.dtype == output.dtype == torch.float32
assert Q.device.type == "cuda"
assert K.device.type == "cuda"
assert V.device.type == "cuda"
assert output.device.type == "cuda"

scale = 1.0 / math.sqrt(head_dim)
# scores: (num_heads, seq_len, seq_len)
scores = torch.bmm(Q, K.transpose(1, 2)) * scale
attn_weights = torch.softmax(scores, dim=-1)
output.copy_(torch.bmm(attn_weights, V))

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"Q": (ctypes.POINTER(ctypes.c_float), "in"),
"K": (ctypes.POINTER(ctypes.c_float), "in"),
"V": (ctypes.POINTER(ctypes.c_float), "in"),
"output": (ctypes.POINTER(ctypes.c_float), "out"),
"num_heads": (ctypes.c_int, "in"),
"seq_len": (ctypes.c_int, "in"),
"head_dim": (ctypes.c_int, "in"),
}

def _make_test_case(self, num_heads, seq_len, head_dim, zero_inputs=False):
device = "cuda"
dtype = torch.float32
if zero_inputs:
Q = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
K = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
V = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
else:
Q = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype)
K = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype)
V = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype)
output = torch.empty(num_heads, seq_len, head_dim, device=device, dtype=dtype)
return {
"Q": Q,
"K": K,
"V": V,
"output": output,
"num_heads": num_heads,
"seq_len": seq_len,
"head_dim": head_dim,
}

def generate_example_test(self) -> Dict[str, Any]:
device = "cuda"
dtype = torch.float32
Q = torch.tensor(
[[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]],
device=device,
dtype=dtype,
)
K = torch.tensor(
[[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]],
device=device,
dtype=dtype,
)
V = torch.tensor(
[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]],
device=device,
dtype=dtype,
)
output = torch.empty(1, 3, 4, device=device, dtype=dtype)
return {
"Q": Q,
"K": K,
"V": V,
"output": output,
"num_heads": 1,
"seq_len": 3,
"head_dim": 4,
}

def generate_functional_test(self) -> List[Dict[str, Any]]:
tests = []

# Edge cases: tiny sequences
tests.append(self._make_test_case(1, 1, 8))
tests.append(self._make_test_case(2, 2, 8, zero_inputs=True))

# Edge cases: small sequences, multiple heads
tests.append(self._make_test_case(4, 3, 16))

# Power-of-2 sizes
tests.append(self._make_test_case(1, 16, 32))
tests.append(self._make_test_case(4, 64, 32))
tests.append(self._make_test_case(8, 128, 64))

# Non-power-of-2 sequences
tests.append(self._make_test_case(2, 30, 32))
tests.append(self._make_test_case(4, 100, 64))
tests.append(self._make_test_case(2, 255, 32))

# Realistic size
tests.append(self._make_test_case(8, 512, 64))

return tests

def generate_performance_test(self) -> Dict[str, Any]:
return self._make_test_case(16, 4096, 64)
5 changes: 5 additions & 0 deletions challenges/medium/89_flash_attention/starter/starter.cu
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// Q, K, V, output are device pointers
extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int num_heads,
int seq_len, int head_dim) {}
16 changes: 16 additions & 0 deletions challenges/medium/89_flash_attention/starter/starter.cute.py
@@ -0,0 +1,16 @@
import cutlass
import cutlass.cute as cute


# Q, K, V, output are tensors on the GPU
@cute.jit
def solve(
Q: cute.Tensor,
K: cute.Tensor,
V: cute.Tensor,
output: cute.Tensor,
num_heads: cute.Int32,
seq_len: cute.Int32,
head_dim: cute.Int32,
):
pass
16 changes: 16 additions & 0 deletions challenges/medium/89_flash_attention/starter/starter.jax.py
@@ -0,0 +1,16 @@
import jax
import jax.numpy as jnp


# Q, K, V are tensors on GPU
@jax.jit
def solve(
Q: jax.Array,
K: jax.Array,
V: jax.Array,
num_heads: int,
seq_len: int,
head_dim: int,
) -> jax.Array:
# return output tensor directly
pass
15 changes: 15 additions & 0 deletions challenges/medium/89_flash_attention/starter/starter.mojo
@@ -0,0 +1,15 @@
from gpu.host import DeviceContext
from memory import UnsafePointer

# Q, K, V, output are device pointers
@export
def solve(
Q: UnsafePointer[Float32],
K: UnsafePointer[Float32],
V: UnsafePointer[Float32],
output: UnsafePointer[Float32],
num_heads: Int32,
seq_len: Int32,
head_dim: Int32,
):
pass
14 changes: 14 additions & 0 deletions challenges/medium/89_flash_attention/starter/starter.pytorch.py
@@ -0,0 +1,14 @@
import torch


# Q, K, V, output are tensors on the GPU
def solve(
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
output: torch.Tensor,
num_heads: int,
seq_len: int,
head_dim: int,
):
pass
16 changes: 16 additions & 0 deletions challenges/medium/89_flash_attention/starter/starter.triton.py
@@ -0,0 +1,16 @@
import torch
import triton
import triton.language as tl


# Q, K, V, output are tensors on the GPU
def solve(
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
output: torch.Tensor,
num_heads: int,
seq_len: int,
head_dim: int,
):
pass