From e4ce92a828fd3d4752053ec999d79255b6b9e1b0 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Sat, 28 Mar 2026 04:18:12 +0000 Subject: [PATCH] Add challenge 88: Prefix-Cached Attention (Medium) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Chunked-prefill attention where new query tokens attend to a full KV cache prefix plus causally to each other — the core operation in LLM inference systems like vLLM and TensorRT-LLM. Co-Authored-By: Claude Sonnet 4.6 --- .../88_prefix_cached_attention/challenge.html | 194 +++++++++++++++++ .../88_prefix_cached_attention/challenge.py | 197 ++++++++++++++++++ .../starter/starter.cu | 5 + .../starter/starter.cute.py | 17 ++ .../starter/starter.jax.py | 17 ++ .../starter/starter.mojo | 16 ++ .../starter/starter.pytorch.py | 15 ++ .../starter/starter.triton.py | 17 ++ 8 files changed, 478 insertions(+) create mode 100644 challenges/medium/88_prefix_cached_attention/challenge.html create mode 100644 challenges/medium/88_prefix_cached_attention/challenge.py create mode 100644 challenges/medium/88_prefix_cached_attention/starter/starter.cu create mode 100644 challenges/medium/88_prefix_cached_attention/starter/starter.cute.py create mode 100644 challenges/medium/88_prefix_cached_attention/starter/starter.jax.py create mode 100644 challenges/medium/88_prefix_cached_attention/starter/starter.mojo create mode 100644 challenges/medium/88_prefix_cached_attention/starter/starter.pytorch.py create mode 100644 challenges/medium/88_prefix_cached_attention/starter/starter.triton.py diff --git a/challenges/medium/88_prefix_cached_attention/challenge.html b/challenges/medium/88_prefix_cached_attention/challenge.html new file mode 100644 index 0000000..1f3de33 --- /dev/null +++ b/challenges/medium/88_prefix_cached_attention/challenge.html @@ -0,0 +1,194 @@ +

+Implement prefix-cached attention, the attention pattern used during chunked prefill in +LLM inference systems such as vLLM and TensorRT-LLM. Given query tensors for a chunk of +new_len tokens and packed key/value tensors containing both a cached prefix of +cache_len tokens and the new tokens themselves, compute scaled dot-product attention +where each new query token attends to all cached tokens (full access) and causally to the new +tokens (lower-triangular access). Q has shape (num_heads, new_len, head_dim); K and V have +shape (num_heads, cache_len + new_len, head_dim), with the cached rows stored first. All +tensors use float32.

+ + + + + + Attention mask (cache_len=4, new_len=4, total_len=8) + + + K cache (j=0..3) + K new (j=4..7) + + + Q[0] + Q[1] + Q[2] + Q[3] + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + attend (cache) + + attend (causal) + + masked out + + mask: j <= cache_len + i + scale = 1 / sqrt(head_dim) + output = softmax(scores) @ V + + +

Implementation Requirements

+ + +

Example

+

+ With num_heads = 2, cache_len = 2, new_len = 2, + head_dim = 4 (total_len = 4): +

+

+ Input:
+ \(Q_0\) (2×4): + \[ + \begin{bmatrix} + 1 & 0 & 0 & 1 \\ + 0 & 1 & 1 & 0 + \end{bmatrix} + \] + \(Q_1\) (2×4): + \[ + \begin{bmatrix} + 0 & 1 & 0 & 1 \\ + 1 & 0 & 1 & 0 + \end{bmatrix} + \] + \(K_0\) (4×4, cache rows first): + \[ + \begin{bmatrix} + 1 & 0 & 1 & 0 \\ + 0 & 1 & 0 & 1 \\ + 1 & 1 & 0 & 0 \\ + 0 & 0 & 1 & 1 + \end{bmatrix} + \] + \(K_1\) (4×4): + \[ + \begin{bmatrix} + 0 & 1 & 0 & -1 \\ + -1 & 0 & 1 & 0 \\ + 1 & 0 & -1 & 0 \\ + 0 & 1 & 0 & 1 + \end{bmatrix} + \] + \(V_0\) (4×4): + \[ + \begin{bmatrix} + 1 & 2 & 3 & 4 \\ + 5 & 6 & 7 & 8 \\ + 9 & 10 & 11 & 12 \\ + 13 & 14 & 15 & 16 + \end{bmatrix} + \] + \(V_1\) (4×4): + \[ + \begin{bmatrix} + -1 & -2 & -3 & -4 \\ + 2 & 3 & 4 & 5 \\ + 6 & 7 & 8 & 9 \\ + -2 & -3 & -4 & -5 + \end{bmatrix} + \] + \(\text{cache\_len} = 2\), \(\text{new\_len} = 2\).
+ Query token 0 (absolute position 2) attends to \(K[\,:\,,\,0{:}3,\,:\,]\); token 1 attends to all four keys. The inputs are chosen so that every unmasked score within a row is equal, so the softmax weights are uniform and each output row is simply the average of the V rows visible to that query. +

+

+ Output (values rounded to 2 decimal places):
+ \(\text{output}_0\) (2×4): + \[ + \begin{bmatrix} + 5.00 & 6.00 & 7.00 & 8.00 \\ + 7.00 & 8.00 & 9.00 & 10.00 + \end{bmatrix} + \] + \(\text{output}_1\) (2×4): + \[ + \begin{bmatrix} + 2.33 & 2.67 & 3.00 & 3.33 \\ + 1.25 & 1.25 & 1.25 & 1.25 + \end{bmatrix} + \] +

+ +

Constraints

+ diff --git a/challenges/medium/88_prefix_cached_attention/challenge.py b/challenges/medium/88_prefix_cached_attention/challenge.py new file mode 100644 index 0000000..12930f3 --- /dev/null +++ b/challenges/medium/88_prefix_cached_attention/challenge.py @@ -0,0 +1,197 @@ +import ctypes +import math +from typing import Any, Dict, List + +import torch +from core.challenge_base import ChallengeBase + + +class Challenge(ChallengeBase): + def __init__(self): + super().__init__( + name="Prefix-Cached Attention", + atol=1e-04, + rtol=1e-04, + num_gpus=1, + access_tier="free", + ) + + def reference_impl( + self, + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + output: torch.Tensor, + num_heads: int, + cache_len: int, + new_len: int, + head_dim: int, + ): + total_len = cache_len + new_len + assert Q.shape == (num_heads, new_len, head_dim) + assert K.shape == (num_heads, total_len, head_dim) + assert V.shape == (num_heads, total_len, head_dim) + assert output.shape == (num_heads, new_len, head_dim) + assert Q.dtype == K.dtype == V.dtype == output.dtype == torch.float32 + assert Q.device.type == "cuda" + assert K.device.type == "cuda" + assert V.device.type == "cuda" + assert output.device.type == "cuda" + + scale = 1.0 / math.sqrt(head_dim) + + # scores: (num_heads, new_len, total_len) + scores = torch.bmm(Q, K.transpose(1, 2)) * scale + + # Causal mask: query token i (at absolute position cache_len+i) attends to + # key token j iff j <= cache_len + i. + # This gives full access to the KV cache and causal access within new tokens. 
+ i_idx = torch.arange(new_len, device=Q.device).unsqueeze(1) # (new_len, 1) + j_idx = torch.arange(total_len, device=Q.device).unsqueeze(0) # (1, total_len) + mask = j_idx <= cache_len + i_idx # (new_len, total_len) + + scores = scores.masked_fill(~mask.unsqueeze(0), float("-inf")) + attn_weights = torch.softmax(scores, dim=-1) + output.copy_(torch.bmm(attn_weights, V)) + + def get_solve_signature(self) -> Dict[str, tuple]: + return { + "Q": (ctypes.POINTER(ctypes.c_float), "in"), + "K": (ctypes.POINTER(ctypes.c_float), "in"), + "V": (ctypes.POINTER(ctypes.c_float), "in"), + "output": (ctypes.POINTER(ctypes.c_float), "out"), + "num_heads": (ctypes.c_int, "in"), + "cache_len": (ctypes.c_int, "in"), + "new_len": (ctypes.c_int, "in"), + "head_dim": (ctypes.c_int, "in"), + } + + def _make_test_case(self, num_heads, cache_len, new_len, head_dim, zero_inputs=False): + total_len = cache_len + new_len + dtype = torch.float32 + device = "cuda" + if zero_inputs: + Q = torch.zeros(num_heads, new_len, head_dim, device=device, dtype=dtype) + K = torch.zeros(num_heads, total_len, head_dim, device=device, dtype=dtype) + V = torch.zeros(num_heads, total_len, head_dim, device=device, dtype=dtype) + else: + Q = torch.randn(num_heads, new_len, head_dim, device=device, dtype=dtype) + K = torch.randn(num_heads, total_len, head_dim, device=device, dtype=dtype) + V = torch.randn(num_heads, total_len, head_dim, device=device, dtype=dtype) + output = torch.zeros(num_heads, new_len, head_dim, device=device, dtype=dtype) + return { + "Q": Q, + "K": K, + "V": V, + "output": output, + "num_heads": num_heads, + "cache_len": cache_len, + "new_len": new_len, + "head_dim": head_dim, + } + + def generate_example_test(self) -> Dict[str, Any]: + num_heads = 2 + cache_len = 2 + new_len = 2 + head_dim = 4 + device = "cuda" + dtype = torch.float32 + + Q = torch.tensor( + [ + [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]], + [[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0]], + ], + device=device, + 
dtype=dtype, + ) + K = torch.tensor( + [ + [ + [1.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0], + [1.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0], + ], + [ + [0.0, 1.0, 0.0, -1.0], + [-1.0, 0.0, 1.0, 0.0], + [1.0, 0.0, -1.0, 0.0], + [0.0, 1.0, 0.0, 1.0], + ], + ], + device=device, + dtype=dtype, + ) + V = torch.tensor( + [ + [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + [13.0, 14.0, 15.0, 16.0], + ], + [ + [-1.0, -2.0, -3.0, -4.0], + [2.0, 3.0, 4.0, 5.0], + [6.0, 7.0, 8.0, 9.0], + [-2.0, -3.0, -4.0, -5.0], + ], + ], + device=device, + dtype=dtype, + ) + output = torch.zeros(num_heads, new_len, head_dim, device=device, dtype=dtype) + return { + "Q": Q, + "K": K, + "V": V, + "output": output, + "num_heads": num_heads, + "cache_len": cache_len, + "new_len": new_len, + "head_dim": head_dim, + } + + def generate_functional_test(self) -> List[Dict[str, Any]]: + torch.manual_seed(42) + tests = [] + + # Edge case: single decode step against a single cached token + tests.append(self._make_test_case(1, 1, 1, 4)) + + # Edge case: zero inputs + tests.append(self._make_test_case(2, 2, 2, 4, zero_inputs=True)) + + # cache_len=0: pure causal self-attention over new tokens + tests.append(self._make_test_case(2, 0, 4, 8)) + + # Single decode step (new_len=1) — typical autoregressive generation + tests.append(self._make_test_case(4, 16, 1, 32)) + + # Power-of-2 sizes + tests.append(self._make_test_case(4, 32, 16, 32)) + + # Larger power-of-2 + tests.append(self._make_test_case(8, 64, 32, 64)) + + # Non-power-of-2 sizes + tests.append(self._make_test_case(4, 30, 15, 32)) + + # Non-power-of-2 with more heads + tests.append(self._make_test_case(6, 100, 50, 32)) + + # Long cache, short new chunk + tests.append(self._make_test_case(8, 255, 3, 64)) + + # Realistic dimensions (LLaMA-style), short chunk + tests.append(self._make_test_case(16, 128, 64, 64)) + + return tests + + def generate_performance_test(self) -> Dict[str, Any]: + torch.manual_seed(0) + # 
LLaMA-3 8B style: 32 heads, head_dim=128 + # cache_len=1024 (prior context), new_len=512 (chunk being prefilled) + return self._make_test_case(32, 1024, 512, 128) diff --git a/challenges/medium/88_prefix_cached_attention/starter/starter.cu b/challenges/medium/88_prefix_cached_attention/starter/starter.cu new file mode 100644 index 0000000..0d61d0a --- /dev/null +++ b/challenges/medium/88_prefix_cached_attention/starter/starter.cu @@ -0,0 +1,5 @@ +#include + +// Q, K, V, output are device pointers +extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int num_heads, + int cache_len, int new_len, int head_dim) {} diff --git a/challenges/medium/88_prefix_cached_attention/starter/starter.cute.py b/challenges/medium/88_prefix_cached_attention/starter/starter.cute.py new file mode 100644 index 0000000..2c30634 --- /dev/null +++ b/challenges/medium/88_prefix_cached_attention/starter/starter.cute.py @@ -0,0 +1,17 @@ +import cutlass +import cutlass.cute as cute + + +# Q, K, V, output are tensors on the GPU +@cute.jit +def solve( + Q: cute.Tensor, + K: cute.Tensor, + V: cute.Tensor, + output: cute.Tensor, + num_heads: cute.Int32, + cache_len: cute.Int32, + new_len: cute.Int32, + head_dim: cute.Int32, +): + pass diff --git a/challenges/medium/88_prefix_cached_attention/starter/starter.jax.py b/challenges/medium/88_prefix_cached_attention/starter/starter.jax.py new file mode 100644 index 0000000..89fb6e5 --- /dev/null +++ b/challenges/medium/88_prefix_cached_attention/starter/starter.jax.py @@ -0,0 +1,17 @@ +import jax +import jax.numpy as jnp + + +# Q, K, V are tensors on GPU +@jax.jit +def solve( + Q: jax.Array, + K: jax.Array, + V: jax.Array, + num_heads: int, + cache_len: int, + new_len: int, + head_dim: int, +) -> jax.Array: + # return output tensor directly + pass diff --git a/challenges/medium/88_prefix_cached_attention/starter/starter.mojo b/challenges/medium/88_prefix_cached_attention/starter/starter.mojo new file mode 100644 index 
0000000..bca0a7d --- /dev/null +++ b/challenges/medium/88_prefix_cached_attention/starter/starter.mojo @@ -0,0 +1,16 @@ +from gpu.host import DeviceContext +from memory import UnsafePointer + +# Q, K, V, output are device pointers +@export +def solve( + Q: UnsafePointer[Float32], + K: UnsafePointer[Float32], + V: UnsafePointer[Float32], + output: UnsafePointer[Float32], + num_heads: Int32, + cache_len: Int32, + new_len: Int32, + head_dim: Int32, +): + pass diff --git a/challenges/medium/88_prefix_cached_attention/starter/starter.pytorch.py b/challenges/medium/88_prefix_cached_attention/starter/starter.pytorch.py new file mode 100644 index 0000000..9b5fdd3 --- /dev/null +++ b/challenges/medium/88_prefix_cached_attention/starter/starter.pytorch.py @@ -0,0 +1,15 @@ +import torch + + +# Q, K, V, output are tensors on the GPU +def solve( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + output: torch.Tensor, + num_heads: int, + cache_len: int, + new_len: int, + head_dim: int, +): + pass diff --git a/challenges/medium/88_prefix_cached_attention/starter/starter.triton.py b/challenges/medium/88_prefix_cached_attention/starter/starter.triton.py new file mode 100644 index 0000000..aea43c7 --- /dev/null +++ b/challenges/medium/88_prefix_cached_attention/starter/starter.triton.py @@ -0,0 +1,17 @@ +import torch +import triton +import triton.language as tl + + +# Q, K, V, output are tensors on the GPU +def solve( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + output: torch.Tensor, + num_heads: int, + cache_len: int, + new_len: int, + head_dim: int, +): + pass