diff --git a/challenges/medium/89_flash_attention/challenge.html b/challenges/medium/89_flash_attention/challenge.html
new file mode 100644
index 0000000..d9f7c73
--- /dev/null
+++ b/challenges/medium/89_flash_attention/challenge.html
@@ -0,0 +1,118 @@
+
+Implement the Flash Attention forward pass: given query, key, and value tensors, compute scaled
+dot-product attention using the online softmax algorithm so that the full
+seq_len × seq_len attention matrix is never materialized in global memory. Every query
+position attends to all key positions (no causal mask). All tensors use float32.
+
+
+
+
+ Flash Attention: tiled online-softmax avoids materializing the S×S matrix
+
+
+ HBM
+
+
+ Q [S×D]
+
+
+ K [S×D]
+
+
+ V [S×D]
+
+
+ O [S×D]
+
+
+ SRAM (tile)
+
+
+ Q_tile
+
+
+ K_tile
+
+
+ V_tile
+
+
+
+
+
+
+
+
+ S_ij = Q_tile @ K_tile^T
+ update m, ℓ, O (online softmax)
+
+
+
+
+
+ For each tile j of K, V:
+ m_new = max(m_old, rowmax(S_ij)) ← running max
+ ℓ_new = exp(m_old-m_new)·ℓ_old + rowsum(exp(S_ij-m_new)) ← running sum
+ O_new = diag(exp(m_old-m_new))·O_old + exp(S_ij-m_new)·V_tile
+
+
+Implementation Requirements
+
+ Implement the function solve(Q, K, V, output, num_heads, seq_len, head_dim).
+ Do not change the function signature or use external libraries beyond the standard GPU frameworks.
+ Write the result into the provided output buffer.
+ Use scale factor 1 / sqrt(head_dim) and a softmax over the key (last) dimension.
+ No causal mask — every query position attends to all key positions.
+
+
+Example
+
+ With num_heads = 1, seq_len = 3, head_dim = 4:
+
+
+ Input:
+ \(Q\) (3×4):
+ \[
+ \begin{bmatrix}
+ 1 & 0 & 0 & 0 \\
+ 0 & 1 & 0 & 0 \\
+ 0 & 0 & 1 & 0
+ \end{bmatrix}
+ \]
+ \(K\) (3×4):
+ \[
+ \begin{bmatrix}
+ 1 & 0 & 0 & 0 \\
+ 0 & 1 & 0 & 0 \\
+ 0 & 0 & 1 & 0
+ \end{bmatrix}
+ \]
+ \(V\) (3×4):
+ \[
+ \begin{bmatrix}
+ 1 & 2 & 3 & 4 \\
+ 5 & 6 & 7 & 8 \\
+ 9 & 10 & 11 & 12
+ \end{bmatrix}
+ \]
+
+
+ Output (values rounded to 2 decimal places):
+ \(\text{output}\) (3×4):
+ \[
+ \begin{bmatrix}
+ 4.29 & 5.29 & 6.29 & 7.29 \\
+ 5.00 & 6.00 & 7.00 & 8.00 \\
+ 5.71 & 6.71 & 7.71 & 8.71
+ \end{bmatrix}
+ \]
+
+
+Constraints
+
+ 1 ≤ num_heads ≤ 64
+ 1 ≤ seq_len ≤ 8,192
+ 8 ≤ head_dim ≤ 128; head_dim is a multiple of 8
+ All tensor values are float32
+ Performance is measured with num_heads = 16, seq_len = 4,096, head_dim = 64
+
diff --git a/challenges/medium/89_flash_attention/challenge.py b/challenges/medium/89_flash_attention/challenge.py
new file mode 100644
index 0000000..b2cd320
--- /dev/null
+++ b/challenges/medium/89_flash_attention/challenge.py
@@ -0,0 +1,133 @@
+import ctypes
+import math
+from typing import Any, Dict, List
+
+import torch
+from core.challenge_base import ChallengeBase
+
+
+class Challenge(ChallengeBase):  # Flash-attention forward challenge: reference impl, solve() ABI, and test generators
+    def __init__(self):
+        super().__init__(
+            name="Flash Attention Forward",
+            atol=1e-03,
+            rtol=1e-03,
+            num_gpus=1,
+            access_tier="free",
+        )
+
+    def reference_impl(  # ground truth: exact (untiled) scaled dot-product attention
+        self,
+        Q: torch.Tensor,
+        K: torch.Tensor,
+        V: torch.Tensor,
+        output: torch.Tensor,
+        num_heads: int,
+        seq_len: int,
+        head_dim: int,
+    ):
+        assert Q.shape == (num_heads, seq_len, head_dim)  # all tensors are (heads, seq, dim)
+        assert K.shape == (num_heads, seq_len, head_dim)
+        assert V.shape == (num_heads, seq_len, head_dim)
+        assert output.shape == (num_heads, seq_len, head_dim)
+        assert Q.dtype == K.dtype == V.dtype == output.dtype == torch.float32
+        assert Q.device.type == "cuda"
+        assert K.device.type == "cuda"
+        assert V.device.type == "cuda"
+        assert output.device.type == "cuda"
+
+        scale = 1.0 / math.sqrt(head_dim)  # 1/sqrt(head_dim) softmax scaling required by the spec
+        # scores: (num_heads, seq_len, seq_len)
+        scores = torch.bmm(Q, K.transpose(1, 2)) * scale
+        attn_weights = torch.softmax(scores, dim=-1)  # softmax over the key (last) dimension; no causal mask
+        output.copy_(torch.bmm(attn_weights, V))  # write in place into the caller-provided buffer
+
+    def get_solve_signature(self) -> Dict[str, tuple]:  # ctypes signature of solve(); second tuple item flags in/out direction
+        return {
+            "Q": (ctypes.POINTER(ctypes.c_float), "in"),
+            "K": (ctypes.POINTER(ctypes.c_float), "in"),
+            "V": (ctypes.POINTER(ctypes.c_float), "in"),
+            "output": (ctypes.POINTER(ctypes.c_float), "out"),
+            "num_heads": (ctypes.c_int, "in"),
+            "seq_len": (ctypes.c_int, "in"),
+            "head_dim": (ctypes.c_int, "in"),
+        }
+
+    def _make_test_case(self, num_heads, seq_len, head_dim, zero_inputs=False):  # build one random (or all-zero) test-case dict
+        device = "cuda"
+        dtype = torch.float32
+        if zero_inputs:  # zero scores give uniform softmax rows — degenerate edge case
+            Q = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+            K = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+            V = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+        else:
+            Q = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+            K = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+            V = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+        output = torch.empty(num_heads, seq_len, head_dim, device=device, dtype=dtype)  # solver fills this
+        return {
+            "Q": Q,
+            "K": K,
+            "V": V,
+            "output": output,
+            "num_heads": num_heads,
+            "seq_len": seq_len,
+            "head_dim": head_dim,
+        }
+
+    def generate_example_test(self) -> Dict[str, Any]:  # the worked 1x3x4 example from the challenge description
+        device = "cuda"
+        dtype = torch.float32
+        Q = torch.tensor(
+            [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]],
+            device=device,
+            dtype=dtype,
+        )
+        K = torch.tensor(
+            [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]],
+            device=device,
+            dtype=dtype,
+        )
+        V = torch.tensor(
+            [[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]],
+            device=device,
+            dtype=dtype,
+        )
+        output = torch.empty(1, 3, 4, device=device, dtype=dtype)
+        return {
+            "Q": Q,
+            "K": K,
+            "V": V,
+            "output": output,
+            "num_heads": 1,
+            "seq_len": 3,
+            "head_dim": 4,
+        }
+
+    def generate_functional_test(self) -> List[Dict[str, Any]]:  # correctness suite: edge, power-of-2, ragged, realistic sizes
+        tests = []
+
+        # Edge cases: tiny sequences
+        tests.append(self._make_test_case(1, 1, 8))
+        tests.append(self._make_test_case(2, 2, 8, zero_inputs=True))
+
+        # Edge cases: small sequences, multiple heads
+        tests.append(self._make_test_case(4, 3, 16))
+
+        # Power-of-2 sizes
+        tests.append(self._make_test_case(1, 16, 32))
+        tests.append(self._make_test_case(4, 64, 32))
+        tests.append(self._make_test_case(8, 128, 64))
+
+        # Non-power-of-2 sequences
+        tests.append(self._make_test_case(2, 30, 32))
+        tests.append(self._make_test_case(4, 100, 64))
+        tests.append(self._make_test_case(2, 255, 32))
+
+        # Realistic size
+        tests.append(self._make_test_case(8, 512, 64))
+
+        return tests
+
+    def generate_performance_test(self) -> Dict[str, Any]:  # benchmark shape published in the description: 16 heads, 4096 seq, 64 dim
+        return self._make_test_case(16, 4096, 64)
diff --git a/challenges/medium/89_flash_attention/starter/starter.cu b/challenges/medium/89_flash_attention/starter/starter.cu
new file mode 100644
index 0000000..707af5d
--- /dev/null
+++ b/challenges/medium/89_flash_attention/starter/starter.cu
@@ -0,0 +1,5 @@
+#include <cuda_runtime.h>
+
+// Q, K, V, output are device pointers
+extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int num_heads,
+                      int seq_len, int head_dim) {}
diff --git a/challenges/medium/89_flash_attention/starter/starter.cute.py b/challenges/medium/89_flash_attention/starter/starter.cute.py
new file mode 100644
index 0000000..06da950
--- /dev/null
+++ b/challenges/medium/89_flash_attention/starter/starter.cute.py
@@ -0,0 +1,16 @@
+import cutlass
+import cutlass.cute as cute
+
+
+# Q, K, V, output are tensors on the GPU
+@cute.jit
+def solve(
+    Q: cute.Tensor,
+    K: cute.Tensor,
+    V: cute.Tensor,
+    output: cute.Tensor,
+    num_heads: cute.Int32,
+    seq_len: cute.Int32,
+    head_dim: cute.Int32,
+):
+    pass  # starter stub — write the attention result into `output`
diff --git a/challenges/medium/89_flash_attention/starter/starter.jax.py b/challenges/medium/89_flash_attention/starter/starter.jax.py
new file mode 100644
index 0000000..e3fe511
--- /dev/null
+++ b/challenges/medium/89_flash_attention/starter/starter.jax.py
@@ -0,0 +1,16 @@
+import jax
+import jax.numpy as jnp
+
+
+# Q, K, V are tensors on GPU
+@jax.jit
+def solve(
+    Q: jax.Array,
+    K: jax.Array,
+    V: jax.Array,
+    num_heads: int,
+    seq_len: int,
+    head_dim: int,
+) -> jax.Array:
+    # return output tensor directly
+    pass  # starter stub — note: no output buffer here; JAX solutions return the array
diff --git a/challenges/medium/89_flash_attention/starter/starter.mojo b/challenges/medium/89_flash_attention/starter/starter.mojo
new file mode 100644
index 0000000..a2db4bd
--- /dev/null
+++ b/challenges/medium/89_flash_attention/starter/starter.mojo
@@ -0,0 +1,15 @@
+from gpu.host import DeviceContext
+from memory import UnsafePointer
+
+# Q, K, V, output are device pointers
+@export
+def solve(
+    Q: UnsafePointer[Float32],
+    K: UnsafePointer[Float32],
+    V: UnsafePointer[Float32],
+    output: UnsafePointer[Float32],
+    num_heads: Int32,
+    seq_len: Int32,
+    head_dim: Int32,
+):
+    pass  # starter stub — write the attention result through `output`
diff --git a/challenges/medium/89_flash_attention/starter/starter.pytorch.py b/challenges/medium/89_flash_attention/starter/starter.pytorch.py
new file mode 100644
index 0000000..7ae6982
--- /dev/null
+++ b/challenges/medium/89_flash_attention/starter/starter.pytorch.py
@@ -0,0 +1,14 @@
+import torch
+
+
+# Q, K, V, output are tensors on the GPU
+def solve(
+    Q: torch.Tensor,
+    K: torch.Tensor,
+    V: torch.Tensor,
+    output: torch.Tensor,
+    num_heads: int,
+    seq_len: int,
+    head_dim: int,
+):
+    pass  # starter stub — write the attention result into `output` in place
diff --git a/challenges/medium/89_flash_attention/starter/starter.triton.py b/challenges/medium/89_flash_attention/starter/starter.triton.py
new file mode 100644
index 0000000..b0e09f2
--- /dev/null
+++ b/challenges/medium/89_flash_attention/starter/starter.triton.py
@@ -0,0 +1,16 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# Q, K, V, output are tensors on the GPU
+def solve(
+    Q: torch.Tensor,
+    K: torch.Tensor,
+    V: torch.Tensor,
+    output: torch.Tensor,
+    num_heads: int,
+    seq_len: int,
+    head_dim: int,
+):
+    pass  # starter stub — launch a Triton kernel that writes into `output`