From 61943672e92451c3f51652152050bd6d33df2179 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Thu, 26 Mar 2026 04:47:51 +0000 Subject: [PATCH 1/2] Add challenge 87: Speculative Decoding Verification (Medium) Implements the token acceptance/rejection step from speculative decoding: given B draft sequences with T candidate tokens each, determine which tokens to accept (based on min(1, q/p) acceptance probability), resample a replacement from the adjusted distribution clamp(q-p, 0) on the first rejection, or sample a bonus token if all T draft tokens are accepted. Key GPU learning moments: - Sequential acceptance chain with inherent data dependency across positions - Parallel reduction to find first rejection across the batch dimension - O(V) inverse-CDF sampling via prefix sum over vocabulary Performance test: B=64, T=8, V=32,768 (Mistral/LLaMA-2 vocab size) Co-Authored-By: Claude Sonnet 4.6 --- .../challenge.html | 148 +++++++++ .../challenge.py | 300 ++++++++++++++++++ .../starter/starter.cu | 5 + .../starter/starter.cute.py | 17 + .../starter/starter.jax.py | 17 + .../starter/starter.mojo | 16 + .../starter/starter.pytorch.py | 15 + .../starter/starter.triton.py | 17 + 8 files changed, 535 insertions(+) create mode 100644 challenges/medium/87_speculative_decoding_verification/challenge.html create mode 100644 challenges/medium/87_speculative_decoding_verification/challenge.py create mode 100644 challenges/medium/87_speculative_decoding_verification/starter/starter.cu create mode 100644 challenges/medium/87_speculative_decoding_verification/starter/starter.cute.py create mode 100644 challenges/medium/87_speculative_decoding_verification/starter/starter.jax.py create mode 100644 challenges/medium/87_speculative_decoding_verification/starter/starter.mojo create mode 100644 challenges/medium/87_speculative_decoding_verification/starter/starter.pytorch.py create mode 100644 
challenges/medium/87_speculative_decoding_verification/starter/starter.triton.py diff --git a/challenges/medium/87_speculative_decoding_verification/challenge.html b/challenges/medium/87_speculative_decoding_verification/challenge.html new file mode 100644 index 00000000..da02f983 --- /dev/null +++ b/challenges/medium/87_speculative_decoding_verification/challenge.html @@ -0,0 +1,148 @@ +

+ Implement the token verification step of speculative decoding. In speculative decoding, a small + draft model proposes T tokens; the target model then evaluates all of them in a + single forward pass and accepts or rejects each token. Given a batch of B draft + sequences, the per-position acceptance probabilities, and pre-generated uniform samples, produce + the final output token sequence for each sequence in the batch. All probability tensors use + float32; token index tensors use int32. +

+ + + + + + + + + + + Speculative Decoding Verification (T=4 draft tokens) + + + draft + + t₀ + + t₁ + + t₂ + + t₃ + + + + α₀=0.9 + + α₁=0.7 + + α₂=0.3 + + + test + + u=0.1✓ + + u=0.4✓ + + u=0.8✗ + + + + reject → resample + from clamp(q−p,0) + + + output + + t₀ ✓ + + t₁ ✓ + + t₂′ + + 0 + + + + + + + + (skipped) + + +

+ For each sequence b in the batch, process draft tokens left-to-right. For position + i, compute the acceptance probability + \(\alpha_i = \min\!\left(1,\, \frac{q_i(t_i)}{p_i(t_i)}\right)\) + where \(p_i(t_i)\) is the draft model probability and \(q_i(t_i)\) is the target model + probability for the draft token \(t_i\). Accept token \(t_i\) if + \(\texttt{uniform\_samples}[b, i] < \alpha_i\); otherwise reject and stop. On rejection at + position i, sample a replacement token from the adjusted distribution + \(\text{adj}(v) \propto \max(0,\, q_i(v) - p_i(v))\) + using \(\texttt{uniform\_samples}[b, T]\) via inverse CDF. If all T draft tokens + are accepted, sample one bonus token from \(q_{T-1}\) (the target distribution at the last + draft position) using \(\texttt{uniform\_samples}[b, T]\). Write the result into + output_tokens[b, :]: valid tokens fill positions \(0\) through + accepted_count (inclusive), remaining positions are zero. +

+ +

Implementation Requirements

+ + +

Example

+

+ With B = 1, T = 3, V = 4: +

+
+draft_tokens     = [[1, 2, 0]]
+
+draft_probs  pos 0: [0.10, 0.60, 0.20, 0.10]   # p(t=1) = 0.60
+             pos 1: [0.10, 0.20, 0.50, 0.20]   # p(t=2) = 0.50
+             pos 2: [0.40, 0.20, 0.20, 0.20]   # (not reached)
+
+target_probs pos 0: [0.10, 0.50, 0.20, 0.20]   # q(t=1) = 0.50
+             pos 1: [0.30, 0.20, 0.20, 0.30]   # q(t=2) = 0.20
+             pos 2: [0.30, 0.20, 0.30, 0.20]   # (not reached)
+
+uniform_samples = [[0.50, 0.70, 0.30, 0.90]]
+                     ^     ^     ^     ^ uniform_samples[0, T=3] for resampling
+
+

+ Position 0: \(\alpha_0 = \min(1,\, 0.50/0.60) \approx 0.833\). Since \(0.50 < 0.833\), accept + token 1.
+ Position 1: \(\alpha_1 = \min(1,\, 0.20/0.50) = 0.40\). Since \(0.70 \ge 0.40\), reject. + Adjusted distribution: \(\max(0,\,[0.30\!-\!0.10,\,0.20\!-\!0.20,\,0.20\!-\!0.50,\,0.30\!-\!0.20]) + = [0.20,\,0,\,0,\,0.10]\), normalized \(= [2/3,\,0,\,0,\,1/3]\). + CDF \(= [0.667,\,0.667,\,0.667,\,1.0]\). With \(r = 0.90\), inverse CDF gives token + 3. +

+
+output_tokens = [[1, 3, 0, 0]]
+
+ +

Constraints

+ diff --git a/challenges/medium/87_speculative_decoding_verification/challenge.py b/challenges/medium/87_speculative_decoding_verification/challenge.py new file mode 100644 index 00000000..363c849a --- /dev/null +++ b/challenges/medium/87_speculative_decoding_verification/challenge.py @@ -0,0 +1,300 @@ +import ctypes +from typing import Any, Dict, List + +import torch +from core.challenge_base import ChallengeBase + + +class Challenge(ChallengeBase): + def __init__(self): + super().__init__( + name="Speculative Decoding Verification", + atol=1e-05, + rtol=1e-05, + num_gpus=1, + access_tier="free", + ) + + def reference_impl( + self, + draft_tokens: torch.Tensor, + draft_probs: torch.Tensor, + target_probs: torch.Tensor, + uniform_samples: torch.Tensor, + output_tokens: torch.Tensor, + B: int, + T: int, + V: int, + ): + assert draft_tokens.shape == (B, T) + assert draft_probs.shape == (B, T, V) + assert target_probs.shape == (B, T, V) + assert uniform_samples.shape == (B, T + 1) + assert output_tokens.shape == (B, T + 1) + assert draft_tokens.dtype == torch.int32 + assert draft_probs.dtype == torch.float32 + assert target_probs.dtype == torch.float32 + assert uniform_samples.dtype == torch.float32 + assert output_tokens.dtype == torch.int32 + assert draft_tokens.device.type == "cuda" + assert draft_probs.device.type == "cuda" + assert target_probs.device.type == "cuda" + assert uniform_samples.device.type == "cuda" + assert output_tokens.device.type == "cuda" + + output_tokens.fill_(0) + + for b in range(B): + for i in range(T): + tok = int(draft_tokens[b, i].item()) + p = draft_probs[b, i, tok].item() + q = target_probs[b, i, tok].item() + alpha = min(1.0, q / p) + + if uniform_samples[b, i].item() < alpha: + output_tokens[b, i] = tok + else: + adjusted = torch.clamp(target_probs[b, i] - draft_probs[b, i], min=0.0) + total = adjusted.sum().item() + if total > 0.0: + adjusted = adjusted / total + else: + adjusted = ( + torch.ones(V, dtype=torch.float32, 
device=draft_tokens.device) / V + ) + cdf = torch.cumsum(adjusted, dim=0) + r = float(uniform_samples[b, T].item()) + new_tok = int(torch.searchsorted(cdf.contiguous(), r).item()) + output_tokens[b, i] = min(new_tok, V - 1) + break + else: + cdf = torch.cumsum(target_probs[b, T - 1], dim=0) + r = float(uniform_samples[b, T].item()) + bonus_tok = int(torch.searchsorted(cdf.contiguous(), r).item()) + output_tokens[b, T] = min(bonus_tok, V - 1) + + def get_solve_signature(self) -> Dict[str, tuple]: + return { + "draft_tokens": (ctypes.POINTER(ctypes.c_int), "in"), + "draft_probs": (ctypes.POINTER(ctypes.c_float), "in"), + "target_probs": (ctypes.POINTER(ctypes.c_float), "in"), + "uniform_samples": (ctypes.POINTER(ctypes.c_float), "in"), + "output_tokens": (ctypes.POINTER(ctypes.c_int), "out"), + "B": (ctypes.c_int, "in"), + "T": (ctypes.c_int, "in"), + "V": (ctypes.c_int, "in"), + } + + def _make_sparse_probs(self, B, T, V, K, device): + """Generate sparse probability distributions: only K tokens have nonzero probability. + + Using sparse distributions ensures that the adjusted distribution clamp(q-p, 0) + has at most 2K nonzero entries, making CDF summation numerically exact regardless + of summation order. This prevents floating-point sensitivity for large V. 
+ """ + K = min(K, V) + flat = B * T + # For each (b, i), sample K distinct token indices + idx = torch.stack([torch.randperm(V, device=device)[:K] for _ in range(flat)]) + idx = idx.view(B, T, K) + # Random weights summing to 1 + weights = torch.rand(B, T, K, device=device) + weights = weights / weights.sum(dim=-1, keepdim=True) + # Scatter into full V-dimensional probability vector + probs = torch.zeros(B, T, V, device=device) + probs.scatter_(2, idx, weights) + return probs, idx + + def _make_test_case(self, B, T, V, seed=42): + torch.manual_seed(seed) + device = "cuda" + + # K=64 active tokens per position: enough diversity while keeping the adjusted + # distribution sparse (at most 128 nonzero entries), ensuring CDF sums are + # independent of floating-point summation order. + K = min(64, V) + draft_probs, draft_idx = self._make_sparse_probs(B, T, V, K, device) + target_probs, _ = self._make_sparse_probs(B, T, V, K, device) + + # Sample draft tokens from the active K tokens + weights = draft_probs.gather(2, draft_idx) # (B, T, K) + flat_w = weights.view(B * T, K) + chosen = torch.multinomial(flat_w, 1).view(B, T) # index within the K tokens + draft_tokens = draft_idx.gather(2, chosen.unsqueeze(-1)).squeeze(-1).to(torch.int32) + + uniform_samples = torch.rand(B, T + 1, device=device) + output_tokens = torch.zeros(B, T + 1, device=device, dtype=torch.int32) + + return { + "draft_tokens": draft_tokens, + "draft_probs": draft_probs, + "target_probs": target_probs, + "uniform_samples": uniform_samples, + "output_tokens": output_tokens, + "B": B, + "T": T, + "V": V, + } + + def _make_accept_all_case(self, B, T, V, seed=42): + """All draft tokens accepted: target_probs == draft_probs so alpha == 1 everywhere.""" + torch.manual_seed(seed) + device = "cuda" + + K = min(64, V) + draft_probs, draft_idx = self._make_sparse_probs(B, T, V, K, device) + target_probs = draft_probs.clone() # alpha = min(1, q/p) = 1 → always accept + + weights = draft_probs.gather(2, draft_idx) 
+ flat_w = weights.view(B * T, K) + chosen = torch.multinomial(flat_w, 1).view(B, T) + draft_tokens = draft_idx.gather(2, chosen.unsqueeze(-1)).squeeze(-1).to(torch.int32) + + # All acceptance samples set to 0 (< 1.0 = alpha) to guarantee acceptance + uniform_samples = torch.zeros(B, T + 1, device=device) + uniform_samples[:, T] = torch.rand(B, device=device) # bonus sampling sample + + output_tokens = torch.zeros(B, T + 1, device=device, dtype=torch.int32) + + return { + "draft_tokens": draft_tokens, + "draft_probs": draft_probs, + "target_probs": target_probs, + "uniform_samples": uniform_samples, + "output_tokens": output_tokens, + "B": B, + "T": T, + "V": V, + } + + def _make_reject_first_case(self, B, T, V, seed=42): + """First draft token always rejected: draft_probs high, target low for that token.""" + torch.manual_seed(seed) + device = "cuda" + + draft_probs = torch.softmax(torch.randn(B, T, V, device=device), dim=-1) + target_probs = torch.softmax(torch.randn(B, T, V, device=device), dim=-1) + + flat = draft_probs.view(B * T, V) + draft_tokens = torch.multinomial(flat, 1).view(B, T).to(torch.int32) + + # Force rejection at position 0 for every sequence: + # set alpha[b,0] very small and uniform_sample[b,0] high enough to reject + for b in range(B): + tok = int(draft_tokens[b, 0].item()) + # Make draft prob ~0.9 for the chosen token (high p) + draft_probs[b, 0] = torch.full((V,), 0.1 / max(V - 1, 1), device=device) + draft_probs[b, 0, tok] = 0.9 + draft_probs[b, 0] = draft_probs[b, 0] / draft_probs[b, 0].sum() + # Make target prob ~1/V for the same token (low q) + target_probs[b, 0] = torch.ones(V, device=device) / V + + uniform_samples = torch.rand(B, T + 1, device=device) + # Force uniform[b, 0] = 0.99 > alpha (which is ~1/V / 0.9 ≈ small) + uniform_samples[:, 0] = 0.99 + + output_tokens = torch.zeros(B, T + 1, device=device, dtype=torch.int32) + + return { + "draft_tokens": draft_tokens, + "draft_probs": draft_probs, + "target_probs": target_probs, + 
"uniform_samples": uniform_samples, + "output_tokens": output_tokens, + "B": B, + "T": T, + "V": V, + } + + def generate_example_test(self) -> Dict[str, Any]: + device = "cuda" + + # B=1, T=3, V=4: position 0 accepted, position 1 rejected, token resampled + draft_tokens = torch.tensor([[1, 2, 0]], device=device, dtype=torch.int32) + + draft_probs = torch.tensor( + [ + [ + [0.10, 0.60, 0.20, 0.10], # pos 0: draft_tokens[0,0]=1, p=0.60 + [0.10, 0.20, 0.50, 0.20], # pos 1: draft_tokens[0,1]=2, p=0.50 + [0.40, 0.20, 0.20, 0.20], # pos 2: draft_tokens[0,2]=0, p=0.40 + ] + ], + device=device, + dtype=torch.float32, + ) + + target_probs = torch.tensor( + [ + [ + [0.10, 0.50, 0.20, 0.20], # pos 0: q=0.50, alpha=min(1,0.50/0.60)=0.833 + [0.30, 0.20, 0.20, 0.30], # pos 1: q=0.20, alpha=min(1,0.20/0.50)=0.400 + [0.30, 0.20, 0.30, 0.20], # pos 2: not reached + ] + ], + device=device, + dtype=torch.float32, + ) + + # uniform_samples[0, 0]=0.50 < 0.833 → ACCEPT token 1 + # uniform_samples[0, 1]=0.70 > 0.400 → REJECT token 2 + # adjusted = clamp([0.20, 0, -0.30, 0.10], min=0) = [0.20, 0, 0, 0.10] + # normalized CDF = [0.667, 0.667, 0.667, 1.0] + # uniform_samples[0, T=3]=0.90 → searchsorted → token 3 + # output_tokens[0] = [1, 3, 0, 0] + uniform_samples = torch.tensor( + [[0.50, 0.70, 0.30, 0.90]], device=device, dtype=torch.float32 + ) + + output_tokens = torch.zeros(1, 4, device=device, dtype=torch.int32) + + return { + "draft_tokens": draft_tokens, + "draft_probs": draft_probs, + "target_probs": target_probs, + "uniform_samples": uniform_samples, + "output_tokens": output_tokens, + "B": 1, + "T": 3, + "V": 4, + } + + def generate_functional_test(self) -> List[Dict[str, Any]]: + tests = [] + + # Edge: T=1, rejected immediately + tests.append(self._make_reject_first_case(1, 1, 4, seed=1)) + + # Edge: T=1, all accepted (bonus token sampled) + tests.append(self._make_accept_all_case(1, 1, 4, seed=2)) + + # Edge: T=2, first rejected + tests.append(self._make_reject_first_case(1, 2, 
8, seed=3)) + + # Edge: T=4, all accepted + tests.append(self._make_accept_all_case(2, 4, 8, seed=4)) + + # Multi-sequence batch: uniform_samples[:, 0] = 0.99 forces rejection at pos 0 + tests.append(self._make_reject_first_case(4, 4, 16, seed=5)) + + # Power-of-2 vocab, mixed acceptance + tests.append(self._make_test_case(4, 8, 64, seed=10)) + + # Larger vocab, mixed acceptance + tests.append(self._make_test_case(8, 8, 256, seed=20)) + + # Non-power-of-2 vocab + tests.append(self._make_test_case(4, 6, 30, seed=30)) + + # All sequences accept all tokens (bonus sampling) + tests.append(self._make_accept_all_case(8, 8, 128, seed=40)) + + # Realistic small batch + tests.append(self._make_test_case(16, 8, 1000, seed=50)) + + return tests + + def generate_performance_test(self) -> Dict[str, Any]: + torch.manual_seed(0) + # B=64 sequences, T=8 draft tokens, V=32768 (Mistral/LLaMA-2 vocab size) + return self._make_test_case(64, 8, 32768, seed=0) diff --git a/challenges/medium/87_speculative_decoding_verification/starter/starter.cu b/challenges/medium/87_speculative_decoding_verification/starter/starter.cu new file mode 100644 index 00000000..6a9731a8 --- /dev/null +++ b/challenges/medium/87_speculative_decoding_verification/starter/starter.cu @@ -0,0 +1,5 @@ +#include + +// draft_tokens, draft_probs, target_probs, uniform_samples, output_tokens are device pointers +extern "C" void solve(const int* draft_tokens, const float* draft_probs, const float* target_probs, + const float* uniform_samples, int* output_tokens, int B, int T, int V) {} diff --git a/challenges/medium/87_speculative_decoding_verification/starter/starter.cute.py b/challenges/medium/87_speculative_decoding_verification/starter/starter.cute.py new file mode 100644 index 00000000..320555b2 --- /dev/null +++ b/challenges/medium/87_speculative_decoding_verification/starter/starter.cute.py @@ -0,0 +1,17 @@ +import cutlass +import cutlass.cute as cute + + +# draft_tokens, draft_probs, target_probs, 
uniform_samples, output_tokens are tensors on the GPU +@cute.jit +def solve( + draft_tokens: cute.Tensor, + draft_probs: cute.Tensor, + target_probs: cute.Tensor, + uniform_samples: cute.Tensor, + output_tokens: cute.Tensor, + B: cute.Int32, + T: cute.Int32, + V: cute.Int32, +): + pass diff --git a/challenges/medium/87_speculative_decoding_verification/starter/starter.jax.py b/challenges/medium/87_speculative_decoding_verification/starter/starter.jax.py new file mode 100644 index 00000000..bd7e8e18 --- /dev/null +++ b/challenges/medium/87_speculative_decoding_verification/starter/starter.jax.py @@ -0,0 +1,17 @@ +import jax +import jax.numpy as jnp + + +# draft_tokens, draft_probs, target_probs, uniform_samples are tensors on GPU +@jax.jit +def solve( + draft_tokens: jax.Array, + draft_probs: jax.Array, + target_probs: jax.Array, + uniform_samples: jax.Array, + B: int, + T: int, + V: int, +) -> jax.Array: + # return output tensor directly + pass diff --git a/challenges/medium/87_speculative_decoding_verification/starter/starter.mojo b/challenges/medium/87_speculative_decoding_verification/starter/starter.mojo new file mode 100644 index 00000000..89c9fca2 --- /dev/null +++ b/challenges/medium/87_speculative_decoding_verification/starter/starter.mojo @@ -0,0 +1,16 @@ +from gpu.host import DeviceContext +from memory import UnsafePointer + +# draft_tokens, draft_probs, target_probs, uniform_samples, output_tokens are device pointers +@export +def solve( + draft_tokens: UnsafePointer[Int32], + draft_probs: UnsafePointer[Float32], + target_probs: UnsafePointer[Float32], + uniform_samples: UnsafePointer[Float32], + output_tokens: UnsafePointer[Int32], + B: Int32, + T: Int32, + V: Int32, +): + pass diff --git a/challenges/medium/87_speculative_decoding_verification/starter/starter.pytorch.py b/challenges/medium/87_speculative_decoding_verification/starter/starter.pytorch.py new file mode 100644 index 00000000..3cce7fae --- /dev/null +++ 
b/challenges/medium/87_speculative_decoding_verification/starter/starter.pytorch.py @@ -0,0 +1,15 @@ +import torch + + +# draft_tokens, draft_probs, target_probs, uniform_samples, output_tokens are tensors on the GPU +def solve( + draft_tokens: torch.Tensor, + draft_probs: torch.Tensor, + target_probs: torch.Tensor, + uniform_samples: torch.Tensor, + output_tokens: torch.Tensor, + B: int, + T: int, + V: int, +): + pass diff --git a/challenges/medium/87_speculative_decoding_verification/starter/starter.triton.py b/challenges/medium/87_speculative_decoding_verification/starter/starter.triton.py new file mode 100644 index 00000000..1347ec9c --- /dev/null +++ b/challenges/medium/87_speculative_decoding_verification/starter/starter.triton.py @@ -0,0 +1,17 @@ +import torch +import triton +import triton.language as tl + + +# draft_tokens, draft_probs, target_probs, uniform_samples, output_tokens are tensors on the GPU +def solve( + draft_tokens: torch.Tensor, + draft_probs: torch.Tensor, + target_probs: torch.Tensor, + uniform_samples: torch.Tensor, + output_tokens: torch.Tensor, + B: int, + T: int, + V: int, +): + pass From 5543c1824d90145febbdfeb0db1a71342c669e51 Mon Sep 17 00:00:00 2001 From: James Song Date: Thu, 26 Mar 2026 23:08:42 -0400 Subject: [PATCH 2/2] Improve speculative decoding HTML: SVG, LaTeX, concise description MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Redesign SVG with wider boxes to prevent text overflow, add probs row showing p(t) and q(t) values per position. Convert description to notation list + numbered algorithm steps with equations. Convert example from
<pre> to LaTeX. Fix \texttt{uniform\_samples} → plain underscore.

Co-Authored-By: Claude Opus 4.6 (1M context) 
---
 .../challenge.html                            | 253 ++++++++++--------
 1 file changed, 146 insertions(+), 107 deletions(-)

diff --git a/challenges/medium/87_speculative_decoding_verification/challenge.html b/challenges/medium/87_speculative_decoding_verification/challenge.html
index da02f983..74e3ff19 100644
--- a/challenges/medium/87_speculative_decoding_verification/challenge.html
+++ b/challenges/medium/87_speculative_decoding_verification/challenge.html
@@ -1,90 +1,116 @@
 

- Implement the token verification step of speculative decoding. In speculative decoding, a small - draft model proposes T tokens; the target model then evaluates all of them in a - single forward pass and accepts or rejects each token. Given a batch of B draft - sequences, the per-position acceptance probabilities, and pre-generated uniform samples, produce - the final output token sequence for each sequence in the batch. All probability tensors use - float32; token index tensors use int32. + Implement the token verification step of speculative decoding. A draft model proposes \(T\) tokens; + the target model evaluates them in one forward pass and accepts or rejects each. Given \(B\) + sequences, produce the verified output tokens. Probability tensors are float32; + token tensors are int32.

- - - - - - - - - - Speculative Decoding Verification (T=4 draft tokens) - - - draft - - t₀ - - t₁ - - t₂ - - t₃ - - - - α₀=0.9 - - α₁=0.7 - - α₂=0.3 - - - test - - u=0.1✓ - - u=0.4✓ - - u=0.8✗ - - - - reject → resample - from clamp(q−p,0) - - - output - - t₀ ✓ - - t₁ ✓ - - t₂′ - - 0 - - - - - - - - (skipped) +

+ Notation for each sequence \(b\), at each draft position \(i = 0, \ldots, T{-}1\): +

+
    +
  • \(t_i = \texttt{draft_tokens}[b, i]\) — the token proposed by the draft model
  • +
  • \(p_i(v) = \texttt{draft_probs}[b, i, v]\) — draft model's probability for token \(v\)
  • +
  • \(q_i(v) = \texttt{target_probs}[b, i, v]\) — target model's probability for token \(v\)
  • +
  • \(u_i = \texttt{uniform_samples}[b, i]\) — pre-generated \(U[0,1)\) sample for position \(i\)
  • +
+ + + + + + pos 0 + pos 1 + pos 2 + pos 3 + + + draft + + t₀ + + t₁ + + t₂ + + t₃ + + + probs + + p(t₀) = 0.60 + q(t₀) = 0.50 + + + p(t₁) = 0.50 + q(t₁) = 0.20 + + + not reached + + + not reached + + + α, test + + α = .50/.60 = .83 + u=0.1 < .83 ✓ + + + α = .20/.50 = .40 + u=0.7 ≥ .40 ✗ + + + skipped + + + skipped + + + + reject at pos 1 → stop, resample from adj(v) = max(0, q(v) − p(v)) + normalize adj, inverse-CDF sample using u[b, T] → replacement token t₁′ + + + output + + t₀ + + t₁′ + + 0 + + 0 + + + p = draft prob + q = target prob + α = min(1, q/p) + ■ accepted + ■ resampled + ■ pad + + + If all T tokens accepted: sample bonus token from q at last position using u[b, T]

- For each sequence b in the batch, process draft tokens left-to-right. For position - i, compute the acceptance probability - \(\alpha_i = \min\!\left(1,\, \frac{q_i(t_i)}{p_i(t_i)}\right)\) - where \(p_i(t_i)\) is the draft model probability and \(q_i(t_i)\) is the target model - probability for the draft token \(t_i\). Accept token \(t_i\) if - \(\texttt{uniform\_samples}[b, i] < \alpha_i\); otherwise reject and stop. On rejection at - position i, sample a replacement token from the adjusted distribution - \(\text{adj}(v) \propto \max(0,\, q_i(v) - p_i(v))\) - using \(\texttt{uniform\_samples}[b, T]\) via inverse CDF. If all T draft tokens - are accepted, sample one bonus token from \(q_{T-1}\) (the target distribution at the last - draft position) using \(\texttt{uniform\_samples}[b, T]\). Write the result into - output_tokens[b, :]: valid tokens fill positions \(0\) through - accepted_count (inclusive), remaining positions are zero. + For each sequence \(b\), process positions \(i = 0, 1, \ldots, T{-}1\) left-to-right: +

+
    +
  1. Compute acceptance probability: \(\displaystyle \alpha_i = \min\!\left(1,\; \frac{q_i(t_i)}{p_i(t_i)}\right)\)
  2. +
  3. If \(u_i < \alpha_i\): accept \(t_i\), continue to position \(i{+}1\).
  4. +
  5. If \(u_i \ge \alpha_i\): reject, stop. Sample replacement from: + \[\text{adj}(v) = \frac{\max(0,\; q_i(v) - p_i(v))}{\sum_{v'} \max(0,\; q_i(v') - p_i(v'))}\] + using inverse CDF with \(r = \texttt{uniform_samples}[b, T]\). If \(\text{adj}\) is all zeros, use uniform \(1/V\). +
  6. +
  7. If all \(T\) tokens accepted: sample a bonus token from \(q_{T-1}\) using \(\texttt{uniform_samples}[b, T]\).
  8. +
+

+ Write results into output_tokens[b, :] (shape \([B, T{+}1]\)): accepted/resampled tokens + fill positions \(0\) through the accepted count (inclusive), remaining positions are zero.

Implementation Requirements

@@ -96,7 +122,7 @@

Implementation Requirements

  • Inverse CDF sampling: given distribution \(\text{adj}\) (already normalized), find the smallest index \(k\) where \(\sum_{v=0}^{k} \text{adj}(v) \ge r\), where - \(r = \texttt{uniform\_samples}[b, T]\). Clamp the result to \([0, V-1]\). + \(r = \texttt{uniform_samples}[b, T]\). Clamp the result to \([0, V-1]\).
  • If the adjusted distribution is all zeros (i.e., \(q_i \le p_i\) everywhere), fall back to @@ -106,34 +132,47 @@

    Implementation Requirements

    Example

    - With B = 1, T = 3, V = 4: + Input: \(B = 1,\; T = 3,\; V = 4\) +

    +

    + \(\text{draft_tokens} = [1, 2, 0]\) +

    +

    + Draft probabilities \(p_i\) and target probabilities \(q_i\) per position: + \[ + p_0 = \begin{bmatrix} 0.10 & 0.60 & 0.20 & 0.10 \end{bmatrix}, \quad + q_0 = \begin{bmatrix} 0.10 & 0.50 & 0.20 & 0.20 \end{bmatrix} + \] + \[ + p_1 = \begin{bmatrix} 0.10 & 0.20 & 0.50 & 0.20 \end{bmatrix}, \quad + q_1 = \begin{bmatrix} 0.30 & 0.20 & 0.20 & 0.30 \end{bmatrix} + \] + \[ + \text{uniform_samples} = \begin{bmatrix} 0.50 & 0.70 & 0.30 & 0.90 \end{bmatrix} + \] +

    +

    + Position 0 (draft token = 1): + \(\alpha_0 = \min\!\left(1,\, \frac{q_0(1)}{p_0(1)}\right) = \min\!\left(1,\, \frac{0.50}{0.60}\right) \approx 0.833\). + Since \(u_0 = 0.50 < 0.833\), accept token 1. +

    +

    + Position 1 (draft token = 2): + \(\alpha_1 = \min\!\left(1,\, \frac{q_1(2)}{p_1(2)}\right) = \min\!\left(1,\, \frac{0.20}{0.50}\right) = 0.40\). + Since \(u_1 = 0.70 \ge 0.40\), reject. Resample from adjusted distribution: + \[ + \text{adj}(v) = \max(0,\, q_1(v) - p_1(v)) = [0.20,\, 0,\, 0,\, 0.10] + \] + \[ + \text{normalized} = \left[\tfrac{2}{3},\, 0,\, 0,\, \tfrac{1}{3}\right], \quad + \text{CDF} = [0.667,\, 0.667,\, 0.667,\, 1.0] + \] + With \(r = \text{uniform_samples}[0, T] = 0.90\), inverse CDF gives token 3.

    -
    -draft_tokens     = [[1, 2, 0]]
    -
    -draft_probs  pos 0: [0.10, 0.60, 0.20, 0.10]   # p(t=1) = 0.60
    -             pos 1: [0.10, 0.20, 0.50, 0.20]   # p(t=2) = 0.50
    -             pos 2: [0.40, 0.20, 0.20, 0.20]   # (not reached)
    -
    -target_probs pos 0: [0.10, 0.50, 0.20, 0.20]   # q(t=1) = 0.50
    -             pos 1: [0.30, 0.20, 0.20, 0.30]   # q(t=2) = 0.20
    -             pos 2: [0.30, 0.20, 0.30, 0.20]   # (not reached)
    -
    -uniform_samples = [[0.50, 0.70, 0.30, 0.90]]
    -                     ^     ^     ^     ^ uniform_samples[0, T=3] for resampling
    -

    - Position 0: \(\alpha_0 = \min(1,\, 0.50/0.60) \approx 0.833\). Since \(0.50 < 0.833\), accept - token 1.
    - Position 1: \(\alpha_1 = \min(1,\, 0.20/0.50) = 0.40\). Since \(0.70 \ge 0.40\), reject. - Adjusted distribution: \(\max(0,\,[0.30\!-\!0.10,\,0.20\!-\!0.20,\,0.20\!-\!0.50,\,0.30\!-\!0.20]) - = [0.20,\,0,\,0,\,0.10]\), normalized \(= [2/3,\,0,\,0,\,1/3]\). - CDF \(= [0.667,\,0.667,\,0.667,\,1.0]\). With \(r = 0.90\), inverse CDF gives token - 3. + Output: + \[\text{output_tokens} = \begin{bmatrix} 1 & 3 & 0 & 0 \end{bmatrix}\]

    -
    -output_tokens = [[1, 3, 0, 0]]
    -

    Constraints