From f75325160f33fef106ff3405452dcb1c6fe818db Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 04:16:00 +0000 Subject: [PATCH 1/2] Add challenge 86: Paged KV-Cache Attention (Medium) Implements decode-phase attention over a non-contiguous paged KV cache, modeled on the vLLM paged attention architecture. Teaches block-table indirection, online softmax across scattered memory pages, and the memory access patterns central to LLM serving workloads. Co-Authored-By: Claude Sonnet 4.6 --- .../medium/86_paged_attention/challenge.html | 180 ++++++++++++++ .../medium/86_paged_attention/challenge.py | 231 ++++++++++++++++++ .../86_paged_attention/starter/starter.cu | 7 + .../starter/starter.cute.py | 20 ++ .../86_paged_attention/starter/starter.jax.py | 20 ++ .../86_paged_attention/starter/starter.mojo | 21 ++ .../starter/starter.pytorch.py | 18 ++ .../starter/starter.triton.py | 20 ++ 8 files changed, 517 insertions(+) create mode 100644 challenges/medium/86_paged_attention/challenge.html create mode 100644 challenges/medium/86_paged_attention/challenge.py create mode 100644 challenges/medium/86_paged_attention/starter/starter.cu create mode 100644 challenges/medium/86_paged_attention/starter/starter.cute.py create mode 100644 challenges/medium/86_paged_attention/starter/starter.jax.py create mode 100644 challenges/medium/86_paged_attention/starter/starter.mojo create mode 100644 challenges/medium/86_paged_attention/starter/starter.pytorch.py create mode 100644 challenges/medium/86_paged_attention/starter/starter.triton.py diff --git a/challenges/medium/86_paged_attention/challenge.html b/challenges/medium/86_paged_attention/challenge.html new file mode 100644 index 00000000..c1d75301 --- /dev/null +++ b/challenges/medium/86_paged_attention/challenge.html @@ -0,0 +1,180 @@ +

+ Implement decode-phase attention over a paged KV cache. In LLM serving systems (e.g., vLLM), + the key and value tensors for each sequence are stored in fixed-size memory blocks (pages) that + may be scattered non-contiguously across a shared GPU memory pool. A block_table maps each + sequence's logical block indices to physical block indices in the cache pool. Given a single query vector + per sequence (one new token being generated), compute the attention output by gathering the relevant + K/V blocks via the block table and computing scaled dot-product attention over the full context. +

+ + + + + + Paged KV-Cache: block_table lookup + + + + seq 0: [3, 7, —] + + + + seq 1: [1, 5, 9] + + block_table + + + phys block 3 + phys block 7 + phys block 1 + phys block 5 + phys block 9 + + + + + + + + + + + + + + + + + + + + K_cache / V_cache pool + + + + blk 0 + + + + blk 1 + + + blk 2 + + + + blk 3 + + + blk 4 + + + + blk 5 + + + blk 6 + + + + blk 7 + + + blk 8 + + + + + + + + + + + + Attention Computation + scores = Q · K_gatheredᵀ/ √d + weights = softmax(scores) + output = weights · V_gathered + K/V tokens gathered from + non-contiguous pool blocks + via block_table indirection + + +

Implementation Requirements

+

+ Implement the function solve(Q, K_cache, V_cache, block_table, context_lens, output, batch_size, num_heads, head_dim, block_size, max_blocks_per_seq) + that computes paged decode-phase attention: +

+ +

+ For each sequence s and each attention head h, compute: +

+
    +
  1. + Gather the context_lens[s] key and value vectors from the paged cache using block_table[s]. + Token at logical position t lives in physical block block_table[s, t / block_size] + at offset t % block_size within that block. +
  2. +
  3. + Compute scaled dot-product attention: + scores[t] = Q[s, h] · K[s, h, t] / √head_dim +
  4. +
  5. + Apply softmax over all context_lens[s] positions to get attention weights. +
  6. +
  7. + Compute output[s, h] = ∑t weights[t] × V[s, h, t]. +
  8. +
+

+ Do not use external libraries beyond the framework you select. Keep the function signature unchanged. + Write results directly into output. +

+ +

Example

+

+ With batch_size = 1, num_heads = 1, head_dim = 4, + block_size = 2, context_lens = [2], block_table = [[0]]: +

+
+Q[0, 0]            = [1.0, 1.0, 0.0, 0.0]
+
+K_cache[0, 0, 0]   = [1.0, 0.0, 0.0, 0.0]   # block 0, token 0
+K_cache[0, 1, 0]   = [0.0, 1.0, 0.0, 0.0]   # block 0, token 1
+
+V_cache[0, 0, 0]   = [2.0, 0.0, 0.0, 0.0]
+V_cache[0, 1, 0]   = [0.0, 4.0, 0.0, 0.0]
+
+

+ Scores (before softmax): +

+
+score[0] = (1·1 + 1·0 + 0·0 + 0·0) / √4 = 0.5
+score[1] = (1·0 + 1·1 + 0·0 + 0·0) / √4 = 0.5
+
+

+ Attention weights = softmax([0.5, 0.5]) = [0.5, 0.5] +

+
+output[0, 0] = 0.5 × [2, 0, 0, 0] + 0.5 × [0, 4, 0, 0] = [1.0, 2.0, 0.0, 0.0]
+
+ +

Constraints

+ diff --git a/challenges/medium/86_paged_attention/challenge.py b/challenges/medium/86_paged_attention/challenge.py new file mode 100644 index 00000000..824f869b --- /dev/null +++ b/challenges/medium/86_paged_attention/challenge.py @@ -0,0 +1,231 @@ +import ctypes +import math +from typing import Any, Dict, List + +import torch +from core.challenge_base import ChallengeBase + + +class Challenge(ChallengeBase): + def __init__(self): + super().__init__( + name="Paged KV-Cache Attention", + atol=1e-04, + rtol=1e-04, + num_gpus=1, + access_tier="free", + ) + + def reference_impl( + self, + Q: torch.Tensor, + K_cache: torch.Tensor, + V_cache: torch.Tensor, + block_table: torch.Tensor, + context_lens: torch.Tensor, + output: torch.Tensor, + batch_size: int, + num_heads: int, + head_dim: int, + block_size: int, + max_blocks_per_seq: int, + ): + assert Q.shape == (batch_size, num_heads, head_dim) + assert K_cache.shape[1] == block_size + assert K_cache.shape[2] == num_heads + assert K_cache.shape[3] == head_dim + assert V_cache.shape == K_cache.shape + assert block_table.shape == (batch_size, max_blocks_per_seq) + assert context_lens.shape == (batch_size,) + assert output.shape == (batch_size, num_heads, head_dim) + assert Q.dtype == K_cache.dtype == V_cache.dtype == output.dtype == torch.float32 + assert block_table.dtype == context_lens.dtype == torch.int32 + assert Q.device.type == "cuda" + assert K_cache.device.type == "cuda" + assert V_cache.device.type == "cuda" + assert block_table.device.type == "cuda" + assert context_lens.device.type == "cuda" + assert output.device.type == "cuda" + + scale = 1.0 / math.sqrt(head_dim) + + for s in range(batch_size): + ctx_len = context_lens[s].item() + n_blocks = (ctx_len + block_size - 1) // block_size + + # Gather the physical blocks assigned to this sequence + phys_blocks = block_table[s, :n_blocks].long() # (n_blocks,) + + # Gather K and V: (n_blocks, block_size, num_heads, head_dim) + K_blocks = K_cache[phys_blocks] + 
V_blocks = V_cache[phys_blocks] + + # Flatten to (n_blocks * block_size, num_heads, head_dim) and trim + K_seq = K_blocks.reshape(-1, num_heads, head_dim)[ + :ctx_len + ] # (ctx_len, num_heads, head_dim) + V_seq = V_blocks.reshape(-1, num_heads, head_dim)[:ctx_len] + + # Transpose to (num_heads, ctx_len, head_dim) + K_seq = K_seq.transpose(0, 1).contiguous() + V_seq = V_seq.transpose(0, 1).contiguous() + + # Q[s]: (num_heads, head_dim) -> (num_heads, 1, head_dim) + q = Q[s].unsqueeze(1) + + # Scaled dot-product: (num_heads, 1, ctx_len) + scores = torch.bmm(q, K_seq.transpose(1, 2)) * scale + attn_weights = torch.softmax(scores, dim=-1) + + # Weighted sum: (num_heads, 1, head_dim) -> (num_heads, head_dim) + out = torch.bmm(attn_weights, V_seq).squeeze(1) + output[s].copy_(out) + + def get_solve_signature(self) -> Dict[str, tuple]: + return { + "Q": (ctypes.POINTER(ctypes.c_float), "in"), + "K_cache": (ctypes.POINTER(ctypes.c_float), "in"), + "V_cache": (ctypes.POINTER(ctypes.c_float), "in"), + "block_table": (ctypes.POINTER(ctypes.c_int), "in"), + "context_lens": (ctypes.POINTER(ctypes.c_int), "in"), + "output": (ctypes.POINTER(ctypes.c_float), "out"), + "batch_size": (ctypes.c_int, "in"), + "num_heads": (ctypes.c_int, "in"), + "head_dim": (ctypes.c_int, "in"), + "block_size": (ctypes.c_int, "in"), + "max_blocks_per_seq": (ctypes.c_int, "in"), + } + + def _make_test_case( + self, batch_size, num_heads, head_dim, block_size, context_lens, zero_q=False + ): + if isinstance(context_lens, int): + context_lens = [context_lens] * batch_size + + max_ctx = max(context_lens) + max_blocks_per_seq = (max_ctx + block_size - 1) // block_size + + # Allocate exactly the blocks needed, assigned sequentially + total_blocks = sum((cl + block_size - 1) // block_size for cl in context_lens) + + device = "cuda" + dtype = torch.float32 + + if zero_q: + Q = torch.zeros(batch_size, num_heads, head_dim, device=device, dtype=dtype) + else: + Q = torch.randn(batch_size, num_heads, head_dim, 
device=device, dtype=dtype) + + K_cache = torch.randn( + total_blocks, block_size, num_heads, head_dim, device=device, dtype=dtype + ) + V_cache = torch.randn( + total_blocks, block_size, num_heads, head_dim, device=device, dtype=dtype + ) + + block_table = torch.zeros(batch_size, max_blocks_per_seq, device=device, dtype=torch.int32) + ctx_lens_tensor = torch.tensor(context_lens, device=device, dtype=torch.int32) + + # Assign physical blocks sequentially per sequence + block_idx = 0 + for s in range(batch_size): + n_blocks = (context_lens[s] + block_size - 1) // block_size + for b in range(n_blocks): + block_table[s, b] = block_idx + block_idx += 1 + + output = torch.zeros(batch_size, num_heads, head_dim, device=device, dtype=dtype) + + return { + "Q": Q, + "K_cache": K_cache, + "V_cache": V_cache, + "block_table": block_table, + "context_lens": ctx_lens_tensor, + "output": output, + "batch_size": batch_size, + "num_heads": num_heads, + "head_dim": head_dim, + "block_size": block_size, + "max_blocks_per_seq": max_blocks_per_seq, + } + + def generate_example_test(self) -> Dict[str, Any]: + device = "cuda" + dtype = torch.float32 + + # batch=1, heads=1, head_dim=4, block_size=2, ctx_len=2 + # Q · K / sqrt(4): [1,1,0,0]·[1,0,0,0]/2 = 0.5, [1,1,0,0]·[0,1,0,0]/2 = 0.5 + # attn = softmax([0.5, 0.5]) = [0.5, 0.5] + # output = 0.5*[2,0,0,0] + 0.5*[0,4,0,0] = [1, 2, 0, 0] + Q = torch.tensor([[[1.0, 1.0, 0.0, 0.0]]], device=device, dtype=dtype) # (1, 1, 4) + K_cache = torch.tensor( + [[[[1.0, 0.0, 0.0, 0.0]], [[0.0, 1.0, 0.0, 0.0]]]], + device=device, + dtype=dtype, + ) # (1 block, block_size=2, 1 head, head_dim=4) + V_cache = torch.tensor( + [[[[2.0, 0.0, 0.0, 0.0]], [[0.0, 4.0, 0.0, 0.0]]]], + device=device, + dtype=dtype, + ) + block_table = torch.tensor( + [[0]], device=device, dtype=torch.int32 + ) # seq 0 -> physical block 0 + context_lens = torch.tensor([2], device=device, dtype=torch.int32) + output = torch.zeros(1, 1, 4, device=device, dtype=dtype) + + return { + 
"Q": Q, + "K_cache": K_cache, + "V_cache": V_cache, + "block_table": block_table, + "context_lens": context_lens, + "output": output, + "batch_size": 1, + "num_heads": 1, + "head_dim": 4, + "block_size": 2, + "max_blocks_per_seq": 1, + } + + def generate_functional_test(self) -> List[Dict[str, Any]]: + torch.manual_seed(42) + tests = [] + + # Edge case: single KV token + tests.append(self._make_test_case(1, 1, 4, 2, 1)) + + # Edge case: ctx_len equals block_size exactly + tests.append(self._make_test_case(1, 2, 8, 4, 4)) + + # Zero query: softmax is uniform, output is mean of V + tests.append(self._make_test_case(2, 2, 8, 4, 8, zero_q=True)) + + # Variable context lengths within a batch + tests.append(self._make_test_case(4, 4, 32, 16, [16, 32, 48, 64])) + + # Power-of-2 context lengths + tests.append(self._make_test_case(4, 4, 32, 16, 32)) + + # Power-of-2, larger + tests.append(self._make_test_case(4, 8, 64, 16, 128)) + + # Non-power-of-2 context length + tests.append(self._make_test_case(2, 4, 32, 16, 30)) + + # Non-power-of-2, straddles multiple blocks + tests.append(self._make_test_case(4, 4, 64, 16, 100)) + + # Mixed variable lengths with non-power-of-2 + tests.append(self._make_test_case(4, 8, 64, 16, [50, 100, 150, 200])) + + # Realistic: LLaMA-3 8B style (8 Q heads), shorter context + tests.append(self._make_test_case(4, 8, 128, 16, 256)) + + return tests + + def generate_performance_test(self) -> Dict[str, Any]: + torch.manual_seed(0) + # Realistic LLM decode: batch=8, 32 heads, head_dim=128, block_size=16, ctx_len=2048 + return self._make_test_case(8, 32, 128, 16, 2048) diff --git a/challenges/medium/86_paged_attention/starter/starter.cu b/challenges/medium/86_paged_attention/starter/starter.cu new file mode 100644 index 00000000..b72d1b9c --- /dev/null +++ b/challenges/medium/86_paged_attention/starter/starter.cu @@ -0,0 +1,7 @@ +#include + +// Q, K_cache, V_cache, block_table, context_lens, output are device pointers +extern "C" void solve(const float* 
Q, const float* K_cache, const float* V_cache, + const int* block_table, const int* context_lens, float* output, + int batch_size, int num_heads, int head_dim, int block_size, + int max_blocks_per_seq) {} diff --git a/challenges/medium/86_paged_attention/starter/starter.cute.py b/challenges/medium/86_paged_attention/starter/starter.cute.py new file mode 100644 index 00000000..d703ed65 --- /dev/null +++ b/challenges/medium/86_paged_attention/starter/starter.cute.py @@ -0,0 +1,20 @@ +import cutlass +import cutlass.cute as cute + + +# Q, K_cache, V_cache, block_table, context_lens, output are tensors on the GPU +@cute.jit +def solve( + Q: cute.Tensor, + K_cache: cute.Tensor, + V_cache: cute.Tensor, + block_table: cute.Tensor, + context_lens: cute.Tensor, + output: cute.Tensor, + batch_size: cute.Int32, + num_heads: cute.Int32, + head_dim: cute.Int32, + block_size: cute.Int32, + max_blocks_per_seq: cute.Int32, +): + pass diff --git a/challenges/medium/86_paged_attention/starter/starter.jax.py b/challenges/medium/86_paged_attention/starter/starter.jax.py new file mode 100644 index 00000000..cd82ce9b --- /dev/null +++ b/challenges/medium/86_paged_attention/starter/starter.jax.py @@ -0,0 +1,20 @@ +import jax +import jax.numpy as jnp + + +# Q, K_cache, V_cache, block_table, context_lens are tensors on GPU +@jax.jit +def solve( + Q: jax.Array, + K_cache: jax.Array, + V_cache: jax.Array, + block_table: jax.Array, + context_lens: jax.Array, + batch_size: int, + num_heads: int, + head_dim: int, + block_size: int, + max_blocks_per_seq: int, +) -> jax.Array: + # return output tensor directly + pass diff --git a/challenges/medium/86_paged_attention/starter/starter.mojo b/challenges/medium/86_paged_attention/starter/starter.mojo new file mode 100644 index 00000000..ce8b7e21 --- /dev/null +++ b/challenges/medium/86_paged_attention/starter/starter.mojo @@ -0,0 +1,21 @@ +from gpu.host import DeviceContext +from gpu.id import block_dim, block_idx, thread_idx +from memory import 
UnsafePointer +from math import ceildiv + +# Q, K_cache, V_cache, block_table, context_lens, output are device pointers +@export +def solve( + Q: UnsafePointer[Float32], + K_cache: UnsafePointer[Float32], + V_cache: UnsafePointer[Float32], + block_table: UnsafePointer[Int32], + context_lens: UnsafePointer[Int32], + output: UnsafePointer[Float32], + batch_size: Int32, + num_heads: Int32, + head_dim: Int32, + block_size: Int32, + max_blocks_per_seq: Int32, +): + pass diff --git a/challenges/medium/86_paged_attention/starter/starter.pytorch.py b/challenges/medium/86_paged_attention/starter/starter.pytorch.py new file mode 100644 index 00000000..aeb42ce3 --- /dev/null +++ b/challenges/medium/86_paged_attention/starter/starter.pytorch.py @@ -0,0 +1,18 @@ +import torch + + +# Q, K_cache, V_cache, block_table, context_lens, output are tensors on the GPU +def solve( + Q: torch.Tensor, + K_cache: torch.Tensor, + V_cache: torch.Tensor, + block_table: torch.Tensor, + context_lens: torch.Tensor, + output: torch.Tensor, + batch_size: int, + num_heads: int, + head_dim: int, + block_size: int, + max_blocks_per_seq: int, +): + pass diff --git a/challenges/medium/86_paged_attention/starter/starter.triton.py b/challenges/medium/86_paged_attention/starter/starter.triton.py new file mode 100644 index 00000000..7c392628 --- /dev/null +++ b/challenges/medium/86_paged_attention/starter/starter.triton.py @@ -0,0 +1,20 @@ +import torch +import triton +import triton.language as tl + + +# Q, K_cache, V_cache, block_table, context_lens, output are tensors on the GPU +def solve( + Q: torch.Tensor, + K_cache: torch.Tensor, + V_cache: torch.Tensor, + block_table: torch.Tensor, + context_lens: torch.Tensor, + output: torch.Tensor, + batch_size: int, + num_heads: int, + head_dim: int, + block_size: int, + max_blocks_per_seq: int, +): + pass From bb42a8eb74777b25acd4c0f2f910d96726a250db Mon Sep 17 00:00:00 2001 From: James Song Date: Thu, 26 Mar 2026 22:10:18 -0400 Subject: [PATCH 2/2] Improve 
paged attention HTML: SVG diagram, LaTeX example and formulas Redesign SVG: block_table as a proper table with column headers, cache pool as horizontal memory strip with color-coded blocks and sequence labels. Convert example and computation steps from HTML entities to LaTeX math notation. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../medium/86_paged_attention/challenge.html | 249 ++++++++++-------- 1 file changed, 135 insertions(+), 114 deletions(-) diff --git a/challenges/medium/86_paged_attention/challenge.html b/challenges/medium/86_paged_attention/challenge.html index c1d75301..9eb99eb7 100644 --- a/challenges/medium/86_paged_attention/challenge.html +++ b/challenges/medium/86_paged_attention/challenge.html @@ -7,98 +7,113 @@ K/V blocks via the block table and computing scaled dot-product attention over the full context.

- - - - - Paged KV-Cache: block_table lookup - - - - seq 0: [3, 7, —] - - - - seq 1: [1, 5, 9] - - block_table - - - phys block 3 - phys block 7 - phys block 1 - phys block 5 - phys block 9 - - - - - - - - - + + - - + + - - + + - - K_cache / V_cache pool - - - - blk 0 - - - - blk 1 - - - blk 2 - - - - blk 3 - - - blk 4 - - - - blk 5 - - - blk 6 - - - - blk 7 - - - blk 8 - - - - - - - - - - - - Attention Computation - scores = Q · K_gatheredᵀ/ √d - weights = softmax(scores) - output = weights · V_gathered - K/V tokens gathered from - non-contiguous pool blocks - via block_table indirection + + + + block_table + + + blk 0 + blk 1 + blk 2 + + + seq 0 + + 3 + + 7 + + + + + seq 1 + + 1 + + 5 + + 9 + + values = physical block indices in pool ↓ + + + + + K_cache / V_cache pool (GPU memory) + + + + + blk 0 + + + + blk 1 + seq1.0 + + + + blk 2 + + + + blk 3 + seq0.0 + + + + blk 4 + + + + blk 5 + seq1.1 + + + + blk 6 + + + + blk 7 + seq0.1 + + + + blk 8 + + + + blk 9 + seq1.2 + + + + + + Decode Attention (per sequence s, per head h) + + 1. + Gather K, V: token t is at pool[ block_table[s, t/B] ], offset t%B + + 2. + scores[t] = Q[s,h] · K[s,h,t] / √head_dim for t = 0 .. context_lens[s]-1 + + 3. + output[s,h] = ∑_t softmax(scores)[t] · V[s,h,t]

Implementation Requirements

@@ -115,23 +130,24 @@

Implementation Requirements

  • output: result of shape (batch_size, num_heads, head_dim), dtype float32
  • - For each sequence s and each attention head h, compute: + For each sequence \(s\) and each attention head \(h\), compute:

    1. - Gather the context_lens[s] key and value vectors from the paged cache using block_table[s]. - Token at logical position t lives in physical block block_table[s, t / block_size] - at offset t % block_size within that block. + Gather the \(\text{context\_lens}[s]\) key and value vectors from the paged cache using + \(\text{block\_table}[s]\). Token at logical position \(t\) lives in physical block + \(\text{block\_table}[s,\;\lfloor t / B \rfloor]\) at offset \(t \bmod B\) within that block, + where \(B = \text{block\_size}\).
    2. Compute scaled dot-product attention: - scores[t] = Q[s, h] · K[s, h, t] / √head_dim + \[\text{scores}[t] = \frac{Q[s, h] \cdot K[s, h, t]}{\sqrt{\text{head\_dim}}}\]
    3. - Apply softmax over all context_lens[s] positions to get attention weights. + Apply softmax over all \(\text{context\_lens}[s]\) positions to get attention weights.
    4. - Compute output[s, h] = ∑t weights[t] × V[s, h, t]. + Compute: \(\displaystyle \text{output}[s, h] = \sum_{t} \text{softmax}(\text{scores})[t] \cdot V[s, h, t]\)

    @@ -141,31 +157,36 @@

    Implementation Requirements

    Example

    - With batch_size = 1, num_heads = 1, head_dim = 4, - block_size = 2, context_lens = [2], block_table = [[0]]: + Input: batch_size = 1, num_heads = 1, head_dim = 4, + block_size = 2, context_lens = [2], block_table = [[0]]

    -
    -Q[0, 0]            = [1.0, 1.0, 0.0, 0.0]
    -
    -K_cache[0, 0, 0]   = [1.0, 0.0, 0.0, 0.0]   # block 0, token 0
    -K_cache[0, 1, 0]   = [0.0, 1.0, 0.0, 0.0]   # block 0, token 1
    -
    -V_cache[0, 0, 0]   = [2.0, 0.0, 0.0, 0.0]
    -V_cache[0, 1, 0]   = [0.0, 4.0, 0.0, 0.0]
    -

    - Scores (before softmax): + \(Q[0, 0] = \begin{bmatrix} 1.0 & 1.0 & 0.0 & 0.0 \end{bmatrix}\)

    -
    -score[0] = (1·1 + 1·0 + 0·0 + 0·0) / √4 = 0.5
    -score[1] = (1·0 + 1·1 + 0·0 + 0·0) / √4 = 0.5
    -

    - Attention weights = softmax([0.5, 0.5]) = [0.5, 0.5] + Keys gathered from block 0 (2 tokens): + \[ + K_0 = \begin{bmatrix} 1.0 & 0.0 & 0.0 & 0.0 \end{bmatrix}, \quad + K_1 = \begin{bmatrix} 0.0 & 1.0 & 0.0 & 0.0 \end{bmatrix} + \] + Values gathered from block 0: + \[ + V_0 = \begin{bmatrix} 2.0 & 0.0 & 0.0 & 0.0 \end{bmatrix}, \quad + V_1 = \begin{bmatrix} 0.0 & 4.0 & 0.0 & 0.0 \end{bmatrix} + \] +

    +

    + Scores (before softmax): + \[ + s_0 = \frac{Q \cdot K_0}{\sqrt{4}} = \frac{1}{2} = 0.5, \quad + s_1 = \frac{Q \cdot K_1}{\sqrt{4}} = \frac{1}{2} = 0.5 + \] + Attention weights: \(\text{softmax}([0.5, 0.5]) = [0.5, 0.5]\) + \[ + \text{output}[0, 0] = 0.5 \cdot V_0 + 0.5 \cdot V_1 = + \begin{bmatrix} 1.0 & 2.0 & 0.0 & 0.0 \end{bmatrix} + \]

    -
    -output[0, 0] = 0.5 × [2, 0, 0, 0] + 0.5 × [0, 4, 0, 0] = [1.0, 2.0, 0.0, 0.0]
    -

    Constraints