From 17497f58efd1ba458724814367cc5d75d1cb8fa7 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Tue, 31 Mar 2026 04:35:20 +0000 Subject: [PATCH] Add challenge 89: Flash Attention Forward (Medium) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces a Flash Attention forward-pass challenge teaching the online-softmax tiling algorithm that avoids materializing the full seq_len × seq_len attention matrix in global memory. Co-Authored-By: Claude Sonnet 4.6 --- .../medium/89_flash_attention/challenge.html | 118 ++++++++++++++++ .../medium/89_flash_attention/challenge.py | 133 ++++++++++++++++++ .../89_flash_attention/starter/starter.cu | 5 + .../starter/starter.cute.py | 16 +++ .../89_flash_attention/starter/starter.jax.py | 16 +++ .../89_flash_attention/starter/starter.mojo | 15 ++ .../starter/starter.pytorch.py | 14 ++ .../starter/starter.triton.py | 16 +++ 8 files changed, 333 insertions(+) create mode 100644 challenges/medium/89_flash_attention/challenge.html create mode 100644 challenges/medium/89_flash_attention/challenge.py create mode 100644 challenges/medium/89_flash_attention/starter/starter.cu create mode 100644 challenges/medium/89_flash_attention/starter/starter.cute.py create mode 100644 challenges/medium/89_flash_attention/starter/starter.jax.py create mode 100644 challenges/medium/89_flash_attention/starter/starter.mojo create mode 100644 challenges/medium/89_flash_attention/starter/starter.pytorch.py create mode 100644 challenges/medium/89_flash_attention/starter/starter.triton.py diff --git a/challenges/medium/89_flash_attention/challenge.html b/challenges/medium/89_flash_attention/challenge.html new file mode 100644 index 00000000..d9f7c73b --- /dev/null +++ b/challenges/medium/89_flash_attention/challenge.html @@ -0,0 +1,118 @@ +

+Implement the Flash Attention forward pass: given query, key, and value tensors, compute scaled +dot-product attention using the online softmax algorithm so that the full +seq_len × seq_len attention matrix is never materialized in global memory. Each +head attends to all positions; no causal mask is applied. All tensors use float32.

+ + + + Flash Attention: tiled online-softmax avoids materializing the S×S matrix + + + HBM + + + Q [S×D] + + + K [S×D] + + + V [S×D] + + + O [S×D] + + + SRAM (tile) + + + Q_tile + + + K_tile + + + V_tile + + + + + + + + + S_ij = (Q_tile @ K_tile^T) / √d + update m, ℓ, O (online softmax) + + + + + + For each tile j of K, V: + m_new = max(m_old, rowmax(S_ij)) ← running max + ℓ_new = exp(m_old-m_new)·ℓ_old + rowsum(exp(S_ij-m_new)) ← running sum + O_new = diag(exp(m_old-m_new))·O_old + exp(S_ij-m_new)·V_tile + + +

Implementation Requirements

+ + +

Example

+

+ With num_heads = 1, seq_len = 3, head_dim = 4: +

+

+ Input:
+ \(Q\) (3×4): + \[ + \begin{bmatrix} + 1 & 0 & 0 & 0 \\ + 0 & 1 & 0 & 0 \\ + 0 & 0 & 1 & 0 + \end{bmatrix} + \] + \(K\) (3×4): + \[ + \begin{bmatrix} + 1 & 0 & 0 & 0 \\ + 0 & 1 & 0 & 0 \\ + 0 & 0 & 1 & 0 + \end{bmatrix} + \] + \(V\) (3×4): + \[ + \begin{bmatrix} + 1 & 2 & 3 & 4 \\ + 5 & 6 & 7 & 8 \\ + 9 & 10 & 11 & 12 + \end{bmatrix} + \] +

+

+ Output (values rounded to 2 decimal places):
+ \(\text{output}\) (3×4): + \[ + \begin{bmatrix} + 4.29 & 5.29 & 6.29 & 7.29 \\ + 5.00 & 6.00 & 7.00 & 8.00 \\ + 5.71 & 6.71 & 7.71 & 8.71 + \end{bmatrix} + \] +

+ +

Constraints

+ diff --git a/challenges/medium/89_flash_attention/challenge.py b/challenges/medium/89_flash_attention/challenge.py new file mode 100644 index 00000000..b2cd3200 --- /dev/null +++ b/challenges/medium/89_flash_attention/challenge.py @@ -0,0 +1,133 @@ +import ctypes +import math +from typing import Any, Dict, List + +import torch +from core.challenge_base import ChallengeBase + + +class Challenge(ChallengeBase): + def __init__(self): + super().__init__( + name="Flash Attention Forward", + atol=1e-03, + rtol=1e-03, + num_gpus=1, + access_tier="free", + ) + + def reference_impl( + self, + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + output: torch.Tensor, + num_heads: int, + seq_len: int, + head_dim: int, + ): + assert Q.shape == (num_heads, seq_len, head_dim) + assert K.shape == (num_heads, seq_len, head_dim) + assert V.shape == (num_heads, seq_len, head_dim) + assert output.shape == (num_heads, seq_len, head_dim) + assert Q.dtype == K.dtype == V.dtype == output.dtype == torch.float32 + assert Q.device.type == "cuda" + assert K.device.type == "cuda" + assert V.device.type == "cuda" + assert output.device.type == "cuda" + + scale = 1.0 / math.sqrt(head_dim) + # scores: (num_heads, seq_len, seq_len) + scores = torch.bmm(Q, K.transpose(1, 2)) * scale + attn_weights = torch.softmax(scores, dim=-1) + output.copy_(torch.bmm(attn_weights, V)) + + def get_solve_signature(self) -> Dict[str, tuple]: + return { + "Q": (ctypes.POINTER(ctypes.c_float), "in"), + "K": (ctypes.POINTER(ctypes.c_float), "in"), + "V": (ctypes.POINTER(ctypes.c_float), "in"), + "output": (ctypes.POINTER(ctypes.c_float), "out"), + "num_heads": (ctypes.c_int, "in"), + "seq_len": (ctypes.c_int, "in"), + "head_dim": (ctypes.c_int, "in"), + } + + def _make_test_case(self, num_heads, seq_len, head_dim, zero_inputs=False): + device = "cuda" + dtype = torch.float32 + if zero_inputs: + Q = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype) + K = torch.zeros(num_heads, seq_len, 
head_dim, device=device, dtype=dtype) + V = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype) + else: + Q = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype) + K = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype) + V = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype) + output = torch.empty(num_heads, seq_len, head_dim, device=device, dtype=dtype) + return { + "Q": Q, + "K": K, + "V": V, + "output": output, + "num_heads": num_heads, + "seq_len": seq_len, + "head_dim": head_dim, + } + + def generate_example_test(self) -> Dict[str, Any]: + device = "cuda" + dtype = torch.float32 + Q = torch.tensor( + [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]], + device=device, + dtype=dtype, + ) + K = torch.tensor( + [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]], + device=device, + dtype=dtype, + ) + V = torch.tensor( + [[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]], + device=device, + dtype=dtype, + ) + output = torch.empty(1, 3, 4, device=device, dtype=dtype) + return { + "Q": Q, + "K": K, + "V": V, + "output": output, + "num_heads": 1, + "seq_len": 3, + "head_dim": 4, + } + + def generate_functional_test(self) -> List[Dict[str, Any]]: + tests = [] + + # Edge cases: tiny sequences + tests.append(self._make_test_case(1, 1, 8)) + tests.append(self._make_test_case(2, 2, 8, zero_inputs=True)) + + # Edge cases: small sequences, multiple heads + tests.append(self._make_test_case(4, 3, 16)) + + # Power-of-2 sizes + tests.append(self._make_test_case(1, 16, 32)) + tests.append(self._make_test_case(4, 64, 32)) + tests.append(self._make_test_case(8, 128, 64)) + + # Non-power-of-2 sequences + tests.append(self._make_test_case(2, 30, 32)) + tests.append(self._make_test_case(4, 100, 64)) + tests.append(self._make_test_case(2, 255, 32)) + + # Realistic size + tests.append(self._make_test_case(8, 512, 64)) + + return tests + + def 
generate_performance_test(self) -> Dict[str, Any]: + return self._make_test_case(16, 4096, 64) diff --git a/challenges/medium/89_flash_attention/starter/starter.cu b/challenges/medium/89_flash_attention/starter/starter.cu new file mode 100644 index 00000000..707af5d8 --- /dev/null +++ b/challenges/medium/89_flash_attention/starter/starter.cu @@ -0,0 +1,5 @@ +#include <cuda_runtime.h> + +// Q, K, V, output are device pointers +extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int num_heads, + int seq_len, int head_dim) {} diff --git a/challenges/medium/89_flash_attention/starter/starter.cute.py b/challenges/medium/89_flash_attention/starter/starter.cute.py new file mode 100644 index 00000000..06da950a --- /dev/null +++ b/challenges/medium/89_flash_attention/starter/starter.cute.py @@ -0,0 +1,16 @@ +import cutlass +import cutlass.cute as cute + + +# Q, K, V, output are tensors on the GPU +@cute.jit +def solve( + Q: cute.Tensor, + K: cute.Tensor, + V: cute.Tensor, + output: cute.Tensor, + num_heads: cute.Int32, + seq_len: cute.Int32, + head_dim: cute.Int32, +): + pass diff --git a/challenges/medium/89_flash_attention/starter/starter.jax.py b/challenges/medium/89_flash_attention/starter/starter.jax.py new file mode 100644 index 00000000..e3fe5114 --- /dev/null +++ b/challenges/medium/89_flash_attention/starter/starter.jax.py @@ -0,0 +1,16 @@ +import jax +import jax.numpy as jnp + + +# Q, K, V are tensors on GPU +@jax.jit +def solve( + Q: jax.Array, + K: jax.Array, + V: jax.Array, + num_heads: int, + seq_len: int, + head_dim: int, +) -> jax.Array: + # return output tensor directly + pass diff --git a/challenges/medium/89_flash_attention/starter/starter.mojo b/challenges/medium/89_flash_attention/starter/starter.mojo new file mode 100644 index 00000000..a2db4bd8 --- /dev/null +++ b/challenges/medium/89_flash_attention/starter/starter.mojo @@ -0,0 +1,15 @@ +from gpu.host import DeviceContext +from memory import UnsafePointer + +# Q, K, V, output are 
device pointers +@export +def solve( + Q: UnsafePointer[Float32], + K: UnsafePointer[Float32], + V: UnsafePointer[Float32], + output: UnsafePointer[Float32], + num_heads: Int32, + seq_len: Int32, + head_dim: Int32, +): + pass diff --git a/challenges/medium/89_flash_attention/starter/starter.pytorch.py b/challenges/medium/89_flash_attention/starter/starter.pytorch.py new file mode 100644 index 00000000..7ae6982d --- /dev/null +++ b/challenges/medium/89_flash_attention/starter/starter.pytorch.py @@ -0,0 +1,14 @@ +import torch + + +# Q, K, V, output are tensors on the GPU +def solve( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + output: torch.Tensor, + num_heads: int, + seq_len: int, + head_dim: int, +): + pass diff --git a/challenges/medium/89_flash_attention/starter/starter.triton.py b/challenges/medium/89_flash_attention/starter/starter.triton.py new file mode 100644 index 00000000..b0e09f23 --- /dev/null +++ b/challenges/medium/89_flash_attention/starter/starter.triton.py @@ -0,0 +1,16 @@ +import torch +import triton +import triton.language as tl + + +# Q, K, V, output are tensors on the GPU +def solve( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + output: torch.Tensor, + num_heads: int, + seq_len: int, + head_dim: int, +): + pass