diff --git a/challenges/medium/88_prefix_cached_attention/challenge.html b/challenges/medium/88_prefix_cached_attention/challenge.html
new file mode 100644
index 0000000..1f3de33
--- /dev/null
+++ b/challenges/medium/88_prefix_cached_attention/challenge.html
@@ -0,0 +1,194 @@
+
+Implement prefix-cached attention, the attention pattern used during chunked prefill in
+LLM inference systems such as vLLM and TensorRT-LLM. Given query tensors for a chunk of
+new_len tokens and packed key/value tensors containing both a cached prefix of
+cache_len tokens and the new tokens themselves, compute scaled dot-product attention
+where each new query token attends to all cached tokens (full access) and causally to the new
+tokens (lower-triangular access). All tensors use float32.
+
+
+
+
+
+
+ Attention mask (cache_len=4, new_len=4, total_len=8)
+
+
+ K cache (j=0..3)
+ K new (j=4..7)
+
+
+ Q[0]
+ Q[1]
+ Q[2]
+ Q[3]
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ attend (cache)
+
+ attend (causal)
+
+ masked out
+
+ mask: j <= cache_len + i
+ scale = 1 / sqrt(head_dim)
+ output = softmax(scores) @ V
+
+
+Implementation Requirements
+
+ Implement the function solve(Q, K, V, output, num_heads, cache_len, new_len, head_dim).
+ Do not change the function signature or use external libraries beyond the standard GPU frameworks.
+ Write the result into the provided output buffer.
+ Use scaled dot-product attention with scale factor 1 / sqrt(head_dim).
+
+ Apply the causal mask: query token i (at absolute sequence position
+ cache_len + i) attends to key token j if and only if
+ j ≤ cache_len + i. Masked positions receive -inf before
+ softmax.
+
+
+ K and V are packed buffers of shape
+ (num_heads, cache_len + new_len, head_dim); the first cache_len
+ positions along the sequence dimension are the cached prefix.
+
+
+
+Example
+
+ With num_heads = 2, cache_len = 2, new_len = 2,
+ head_dim = 4 (total_len = 4):
+
+
+ Input:
+ \(Q_0\) (2×4):
+ \[
+ \begin{bmatrix}
+ 1 & 0 & 0 & 1 \\
+ 0 & 1 & 1 & 0
+ \end{bmatrix}
+ \]
+ \(Q_1\) (2×4):
+ \[
+ \begin{bmatrix}
+ 0 & 1 & 0 & 1 \\
+ 1 & 0 & 1 & 0
+ \end{bmatrix}
+ \]
+ \(K_0\) (4×4, cache rows first):
+ \[
+ \begin{bmatrix}
+ 1 & 0 & 1 & 0 \\
+ 0 & 1 & 0 & 1 \\
+ 1 & 1 & 0 & 0 \\
+ 0 & 0 & 1 & 1
+ \end{bmatrix}
+ \]
+ \(K_1\) (4×4):
+ \[
+ \begin{bmatrix}
+ 0 & 1 & 0 & -1 \\
+ -1 & 0 & 1 & 0 \\
+ 1 & 0 & -1 & 0 \\
+ 0 & 1 & 0 & 1
+ \end{bmatrix}
+ \]
+ \(V_0\) (4×4):
+ \[
+ \begin{bmatrix}
+ 1 & 2 & 3 & 4 \\
+ 5 & 6 & 7 & 8 \\
+ 9 & 10 & 11 & 12 \\
+ 13 & 14 & 15 & 16
+ \end{bmatrix}
+ \]
+ \(V_1\) (4×4):
+ \[
+ \begin{bmatrix}
+ -1 & -2 & -3 & -4 \\
+ 2 & 3 & 4 & 5 \\
+ 6 & 7 & 8 & 9 \\
+ -2 & -3 & -4 & -5
+ \end{bmatrix}
+ \]
+ \(\text{cache\_len} = 2\), \(\text{new\_len} = 2\).
+ Query token 0 (absolute position 2) attends to \(K[\,:\,,\,0{:}3,\,:\,]\); token 1 attends to all four keys.
+
+
+ Output (values rounded to 2 decimal places):
+ \(\text{output}_0\) (2×4):
+ \[
+ \begin{bmatrix}
+ 5.00 & 6.00 & 7.00 & 8.00 \\
+ 7.00 & 8.00 & 9.00 & 10.00
+ \end{bmatrix}
+ \]
+ \(\text{output}_1\) (2×4):
+ \[
+ \begin{bmatrix}
+ 2.33 & 2.67 & 3.00 & 3.33 \\
+ 1.25 & 1.25 & 1.25 & 1.25
+ \end{bmatrix}
+ \]
+
+
+Constraints
+
+ 1 ≤ num_heads ≤ 64
+ 0 ≤ cache_len ≤ 4,096
+ 1 ≤ new_len ≤ 1,024
+ 8 ≤ head_dim ≤ 256; head_dim is a multiple of 8
+ All tensor values are float32
+
+ Performance is measured with num_heads = 32, cache_len = 1,024,
+ new_len = 512, head_dim = 128
+
+
diff --git a/challenges/medium/88_prefix_cached_attention/challenge.py b/challenges/medium/88_prefix_cached_attention/challenge.py
new file mode 100644
index 0000000..12930f3
--- /dev/null
+++ b/challenges/medium/88_prefix_cached_attention/challenge.py
@@ -0,0 +1,197 @@
+import ctypes
+import math
+from typing import Any, Dict, List
+
+import torch
+from core.challenge_base import ChallengeBase
+
+
class Challenge(ChallengeBase):
    """Prefix-cached attention (chunked-prefill attention).

    Each of the ``new_len`` query tokens attends fully to a ``cache_len``-token
    cached KV prefix and causally to the new tokens themselves: query i (at
    absolute position ``cache_len + i``) attends to key j iff
    ``j <= cache_len + i``. K and V are packed buffers with the cached prefix
    occupying the first ``cache_len`` sequence positions.
    """

    def __init__(self):
        super().__init__(
            name="Prefix-Cached Attention",
            atol=1e-04,
            rtol=1e-04,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        output: torch.Tensor,
        num_heads: int,
        cache_len: int,
        new_len: int,
        head_dim: int,
    ) -> None:
        """Ground-truth scaled dot-product attention with the prefix-cache mask.

        Args:
            Q: (num_heads, new_len, head_dim) float32 CUDA tensor of queries.
            K: (num_heads, cache_len + new_len, head_dim) packed keys,
                cached prefix first.
            V: (num_heads, cache_len + new_len, head_dim) packed values,
                cached prefix first.
            output: (num_heads, new_len, head_dim) buffer the result is
                written into in place.
            num_heads, cache_len, new_len, head_dim: problem dimensions.

        NOTE: this is the numerical reference the tolerance check compares
        against; keep the exact operation order (bmm -> scale -> mask ->
        softmax -> bmm) stable.
        """
        total_len = cache_len + new_len
        assert Q.shape == (num_heads, new_len, head_dim)
        assert K.shape == (num_heads, total_len, head_dim)
        assert V.shape == (num_heads, total_len, head_dim)
        assert output.shape == (num_heads, new_len, head_dim)
        assert Q.dtype == K.dtype == V.dtype == output.dtype == torch.float32
        assert Q.device.type == "cuda"
        assert K.device.type == "cuda"
        assert V.device.type == "cuda"
        assert output.device.type == "cuda"

        scale = 1.0 / math.sqrt(head_dim)

        # scores: (num_heads, new_len, total_len)
        scores = torch.bmm(Q, K.transpose(1, 2)) * scale

        # Causal mask: query token i (at absolute position cache_len+i) attends to
        # key token j iff j <= cache_len + i.
        # This gives full access to the KV cache and causal access within new tokens.
        i_idx = torch.arange(new_len, device=Q.device).unsqueeze(1)  # (new_len, 1)
        j_idx = torch.arange(total_len, device=Q.device).unsqueeze(0)  # (1, total_len)
        mask = j_idx <= cache_len + i_idx  # (new_len, total_len)

        # Disallowed positions get -inf so softmax assigns them zero weight.
        scores = scores.masked_fill(~mask.unsqueeze(0), float("-inf"))
        attn_weights = torch.softmax(scores, dim=-1)
        # Write into the caller-provided buffer rather than returning a new tensor.
        output.copy_(torch.bmm(attn_weights, V))

    def get_solve_signature(self) -> Dict[str, tuple]:
        """Map each solve() parameter to its (ctypes type, direction) pair for the C ABI."""
        return {
            "Q": (ctypes.POINTER(ctypes.c_float), "in"),
            "K": (ctypes.POINTER(ctypes.c_float), "in"),
            "V": (ctypes.POINTER(ctypes.c_float), "in"),
            "output": (ctypes.POINTER(ctypes.c_float), "out"),
            "num_heads": (ctypes.c_int, "in"),
            "cache_len": (ctypes.c_int, "in"),
            "new_len": (ctypes.c_int, "in"),
            "head_dim": (ctypes.c_int, "in"),
        }

    def _make_test_case(
        self,
        num_heads: int,
        cache_len: int,
        new_len: int,
        head_dim: int,
        zero_inputs: bool = False,
    ) -> Dict[str, Any]:
        """Build one test case with random (or all-zero) float32 CUDA tensors.

        ``zero_inputs=True`` exercises the degenerate case where every score is
        zero and softmax becomes uniform over the unmasked positions.
        """
        total_len = cache_len + new_len
        dtype = torch.float32
        device = "cuda"
        if zero_inputs:
            Q = torch.zeros(num_heads, new_len, head_dim, device=device, dtype=dtype)
            K = torch.zeros(num_heads, total_len, head_dim, device=device, dtype=dtype)
            V = torch.zeros(num_heads, total_len, head_dim, device=device, dtype=dtype)
        else:
            Q = torch.randn(num_heads, new_len, head_dim, device=device, dtype=dtype)
            K = torch.randn(num_heads, total_len, head_dim, device=device, dtype=dtype)
            V = torch.randn(num_heads, total_len, head_dim, device=device, dtype=dtype)
        output = torch.zeros(num_heads, new_len, head_dim, device=device, dtype=dtype)
        return {
            "Q": Q,
            "K": K,
            "V": V,
            "output": output,
            "num_heads": num_heads,
            "cache_len": cache_len,
            "new_len": new_len,
            "head_dim": head_dim,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        """Build the worked example from challenge.html (2 heads, 2 cached + 2 new tokens)."""
        num_heads = 2
        cache_len = 2
        new_len = 2
        head_dim = 4
        device = "cuda"
        dtype = torch.float32

        Q = torch.tensor(
            [
                [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]],
                [[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0]],
            ],
            device=device,
            dtype=dtype,
        )
        # Rows 0..1 are the cached prefix; rows 2..3 are the new tokens.
        K = torch.tensor(
            [
                [
                    [1.0, 0.0, 1.0, 0.0],
                    [0.0, 1.0, 0.0, 1.0],
                    [1.0, 1.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0, 1.0],
                ],
                [
                    [0.0, 1.0, 0.0, -1.0],
                    [-1.0, 0.0, 1.0, 0.0],
                    [1.0, 0.0, -1.0, 0.0],
                    [0.0, 1.0, 0.0, 1.0],
                ],
            ],
            device=device,
            dtype=dtype,
        )
        V = torch.tensor(
            [
                [
                    [1.0, 2.0, 3.0, 4.0],
                    [5.0, 6.0, 7.0, 8.0],
                    [9.0, 10.0, 11.0, 12.0],
                    [13.0, 14.0, 15.0, 16.0],
                ],
                [
                    [-1.0, -2.0, -3.0, -4.0],
                    [2.0, 3.0, 4.0, 5.0],
                    [6.0, 7.0, 8.0, 9.0],
                    [-2.0, -3.0, -4.0, -5.0],
                ],
            ],
            device=device,
            dtype=dtype,
        )
        output = torch.zeros(num_heads, new_len, head_dim, device=device, dtype=dtype)
        return {
            "Q": Q,
            "K": K,
            "V": V,
            "output": output,
            "num_heads": num_heads,
            "cache_len": cache_len,
            "new_len": new_len,
            "head_dim": head_dim,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        """Return the deterministic suite of correctness cases (seeded RNG)."""
        torch.manual_seed(42)
        tests = []

        # Edge case: single decode step against a single cached token
        tests.append(self._make_test_case(1, 1, 1, 4))

        # Edge case: zero inputs
        tests.append(self._make_test_case(2, 2, 2, 4, zero_inputs=True))

        # cache_len=0: pure causal self-attention over new tokens
        tests.append(self._make_test_case(2, 0, 4, 8))

        # Single decode step (new_len=1) — typical autoregressive generation
        tests.append(self._make_test_case(4, 16, 1, 32))

        # Power-of-2 sizes
        tests.append(self._make_test_case(4, 32, 16, 32))

        # Larger power-of-2
        tests.append(self._make_test_case(8, 64, 32, 64))

        # Non-power-of-2 sizes
        tests.append(self._make_test_case(4, 30, 15, 32))

        # Non-power-of-2 with more heads
        tests.append(self._make_test_case(6, 100, 50, 32))

        # Long cache, short new chunk
        tests.append(self._make_test_case(8, 255, 3, 64))

        # Realistic dimensions (LLaMA-style), short chunk
        tests.append(self._make_test_case(16, 128, 64, 64))

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        """Return the single timed case; sizes match the published constraints."""
        torch.manual_seed(0)
        # LLaMA-3 8B style: 32 heads, head_dim=128
        # cache_len=1024 (prior context), new_len=512 (chunk being prefilled)
        return self._make_test_case(32, 1024, 512, 128)
diff --git a/challenges/medium/88_prefix_cached_attention/starter/starter.cu b/challenges/medium/88_prefix_cached_attention/starter/starter.cu
new file mode 100644
index 0000000..0d61d0a
--- /dev/null
+++ b/challenges/medium/88_prefix_cached_attention/starter/starter.cu
@@ -0,0 +1,5 @@
+#include
+
+// Q, K, V, output are device pointers
+extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int num_heads,
+ int cache_len, int new_len, int head_dim) {}
diff --git a/challenges/medium/88_prefix_cached_attention/starter/starter.cute.py b/challenges/medium/88_prefix_cached_attention/starter/starter.cute.py
new file mode 100644
index 0000000..2c30634
--- /dev/null
+++ b/challenges/medium/88_prefix_cached_attention/starter/starter.cute.py
@@ -0,0 +1,17 @@
+import cutlass
+import cutlass.cute as cute
+
+
# Q, K, V, output are tensors on the GPU
@cute.jit
def solve(
    Q: cute.Tensor,
    K: cute.Tensor,
    V: cute.Tensor,
    output: cute.Tensor,
    num_heads: cute.Int32,
    cache_len: cute.Int32,
    new_len: cute.Int32,
    head_dim: cute.Int32,
):
    """Compute prefix-cached attention and write the result into ``output``.

    K/V pack the cached prefix (first ``cache_len`` positions) and the new
    tokens; query i attends to key j iff j <= cache_len + i, with scale
    1/sqrt(head_dim). Starter stub — implement me.
    """
    pass
diff --git a/challenges/medium/88_prefix_cached_attention/starter/starter.jax.py b/challenges/medium/88_prefix_cached_attention/starter/starter.jax.py
new file mode 100644
index 0000000..89fb6e5
--- /dev/null
+++ b/challenges/medium/88_prefix_cached_attention/starter/starter.jax.py
@@ -0,0 +1,17 @@
+import jax
+import jax.numpy as jnp
+
+
# Q, K, V are tensors on GPU
@jax.jit
def solve(
    Q: jax.Array,
    K: jax.Array,
    V: jax.Array,
    num_heads: int,
    cache_len: int,
    new_len: int,
    head_dim: int,
) -> jax.Array:
    """Compute prefix-cached attention and return the output array.

    Unlike the other starters, JAX is functional: there is no ``output``
    buffer — return a (num_heads, new_len, head_dim) array. K/V pack the
    cached prefix first; query i attends to key j iff j <= cache_len + i,
    with scale 1/sqrt(head_dim). Starter stub — implement me.
    """
    # return output tensor directly
    pass
diff --git a/challenges/medium/88_prefix_cached_attention/starter/starter.mojo b/challenges/medium/88_prefix_cached_attention/starter/starter.mojo
new file mode 100644
index 0000000..bca0a7d
--- /dev/null
+++ b/challenges/medium/88_prefix_cached_attention/starter/starter.mojo
@@ -0,0 +1,16 @@
+from gpu.host import DeviceContext
+from memory import UnsafePointer
+
# Q, K, V, output are device pointers
@export
def solve(
    Q: UnsafePointer[Float32],
    K: UnsafePointer[Float32],
    V: UnsafePointer[Float32],
    output: UnsafePointer[Float32],
    num_heads: Int32,
    cache_len: Int32,
    new_len: Int32,
    head_dim: Int32,
):
    """Compute prefix-cached attention and write the result into `output`.

    K/V pack the cached prefix (first `cache_len` sequence positions) and the
    new tokens; query i attends to key j iff j <= cache_len + i, with scale
    1/sqrt(head_dim). Starter stub — implement me.
    """
    pass
diff --git a/challenges/medium/88_prefix_cached_attention/starter/starter.pytorch.py b/challenges/medium/88_prefix_cached_attention/starter/starter.pytorch.py
new file mode 100644
index 0000000..9b5fdd3
--- /dev/null
+++ b/challenges/medium/88_prefix_cached_attention/starter/starter.pytorch.py
@@ -0,0 +1,15 @@
+import torch
+
+
# Q, K, V, output are tensors on the GPU
def solve(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    output: torch.Tensor,
    num_heads: int,
    cache_len: int,
    new_len: int,
    head_dim: int,
):
    """Compute prefix-cached attention and write the result into ``output``.

    Q: (num_heads, new_len, head_dim); K/V: (num_heads, cache_len + new_len,
    head_dim) with the cached prefix first. Query i attends to key j iff
    j <= cache_len + i, with scale 1/sqrt(head_dim). Starter stub —
    implement me.
    """
    pass
diff --git a/challenges/medium/88_prefix_cached_attention/starter/starter.triton.py b/challenges/medium/88_prefix_cached_attention/starter/starter.triton.py
new file mode 100644
index 0000000..aea43c7
--- /dev/null
+++ b/challenges/medium/88_prefix_cached_attention/starter/starter.triton.py
@@ -0,0 +1,17 @@
+import torch
+import triton
+import triton.language as tl
+
+
# Q, K, V, output are tensors on the GPU
def solve(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    output: torch.Tensor,
    num_heads: int,
    cache_len: int,
    new_len: int,
    head_dim: int,
):
    """Compute prefix-cached attention with a Triton kernel; write into ``output``.

    Q: (num_heads, new_len, head_dim); K/V: (num_heads, cache_len + new_len,
    head_dim) with the cached prefix first. Query i attends to key j iff
    j <= cache_len + i, with scale 1/sqrt(head_dim). Starter stub —
    implement me.
    """
    pass