Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
<p>
Implement the token verification step of speculative decoding. A draft model proposes \(T\) tokens;
the target model evaluates them in one forward pass and accepts or rejects each. Given \(B\)
sequences, produce the verified output tokens. Probability tensors are <code>float32</code>;
token tensors are <code>int32</code>.
</p>

<p>
Notation for each sequence \(b\), at each draft position \(i = 0, \ldots, T{-}1\):
</p>
<ul>
<li>\(t_i = \texttt{draft_tokens}[b, i]\) &mdash; the token proposed by the draft model</li>
<li>\(p_i(v) = \texttt{draft_probs}[b, i, v]\) &mdash; draft model's probability for token \(v\)</li>
<li>\(q_i(v) = \texttt{target_probs}[b, i, v]\) &mdash; target model's probability for token \(v\)</li>
<li>\(u_i = \texttt{uniform_samples}[b, i]\) &mdash; pre-generated \(U[0,1)\) sample for position \(i\)</li>
</ul>

<svg width="660" height="310" viewBox="0 0 660 310" xmlns="http://www.w3.org/2000/svg"
style="display:block; margin:20px auto; font-family:monospace;">
<rect width="660" height="310" fill="#222" rx="8"/>

<!-- Column headers -->
<text x="108" y="18" fill="#666" font-size="9" text-anchor="middle">pos 0</text>
<text x="248" y="18" fill="#666" font-size="9" text-anchor="middle">pos 1</text>
<text x="388" y="18" fill="#666" font-size="9" text-anchor="middle">pos 2</text>
<text x="528" y="18" fill="#666" font-size="9" text-anchor="middle">pos 3</text>

<!-- Row 1: Draft tokens -->
<text x="16" y="42" fill="#888" font-size="10">draft</text>
<rect x="56" y="28" width="104" height="24" rx="4" fill="#1e3a5f" stroke="#4477bb" stroke-width="1.5"/>
<text x="108" y="45" text-anchor="middle" fill="#8ec4f0" font-size="11">t&#x2080;</text>
<rect x="196" y="28" width="104" height="24" rx="4" fill="#1e3a5f" stroke="#4477bb" stroke-width="1.5"/>
<text x="248" y="45" text-anchor="middle" fill="#8ec4f0" font-size="11">t&#x2081;</text>
<rect x="336" y="28" width="104" height="24" rx="4" fill="#1e3a5f" stroke="#4477bb" stroke-width="1.5"/>
<text x="388" y="45" text-anchor="middle" fill="#8ec4f0" font-size="11">t&#x2082;</text>
<rect x="476" y="28" width="104" height="24" rx="4" fill="#1e3a5f" stroke="#4477bb" stroke-width="1.5"/>
<text x="528" y="45" text-anchor="middle" fill="#8ec4f0" font-size="11">t&#x2083;</text>

<!-- Row 2: Probabilities -->
<text x="16" y="76" fill="#888" font-size="10">probs</text>
<rect x="56" y="62" width="104" height="34" rx="4" fill="#1a1a1a" stroke="#666" stroke-width="1"/>
<text x="108" y="76" text-anchor="middle" fill="#c060e0" font-size="9">p(t&#x2080;) = 0.60</text>
<text x="108" y="89" text-anchor="middle" fill="#e0a040" font-size="9">q(t&#x2080;) = 0.50</text>

<rect x="196" y="62" width="104" height="34" rx="4" fill="#1a1a1a" stroke="#666" stroke-width="1"/>
<text x="248" y="76" text-anchor="middle" fill="#c060e0" font-size="9">p(t&#x2081;) = 0.50</text>
<text x="248" y="89" text-anchor="middle" fill="#e0a040" font-size="9">q(t&#x2081;) = 0.20</text>

<rect x="336" y="62" width="104" height="34" rx="4" fill="#2a2a2a" stroke="#555" stroke-width="1"/>
<text x="388" y="80" text-anchor="middle" fill="#555" font-size="9">not reached</text>

<rect x="476" y="62" width="104" height="34" rx="4" fill="#2a2a2a" stroke="#555" stroke-width="1"/>
<text x="528" y="80" text-anchor="middle" fill="#555" font-size="9">not reached</text>

<!-- Row 3: Alpha + accept/reject -->
<text x="16" y="124" fill="#888" font-size="10">&#x3b1;, test</text>
<rect x="56" y="108" width="104" height="40" rx="4" fill="#1a3a1a" stroke="#44aa66" stroke-width="1.5"/>
<text x="108" y="124" text-anchor="middle" fill="#aaa" font-size="9">&#x3b1; = .50/.60 = .83</text>
<text x="108" y="140" text-anchor="middle" fill="#44aa66" font-size="9">u=0.1 &lt; .83 &#x2713;</text>

<rect x="196" y="108" width="104" height="40" rx="4" fill="#4a1a1a" stroke="#e06060" stroke-width="1.5"/>
<text x="248" y="124" text-anchor="middle" fill="#aaa" font-size="9">&#x3b1; = .20/.50 = .40</text>
<text x="248" y="140" text-anchor="middle" fill="#e06060" font-size="9">u=0.7 &#x2265; .40 &#x2717;</text>

<rect x="336" y="108" width="104" height="40" rx="4" fill="#2a2a2a" stroke="#555" stroke-width="1"/>
<text x="388" y="132" text-anchor="middle" fill="#555" font-size="9">skipped</text>

<rect x="476" y="108" width="104" height="40" rx="4" fill="#2a2a2a" stroke="#555" stroke-width="1"/>
<text x="528" y="132" text-anchor="middle" fill="#555" font-size="9">skipped</text>

<!-- Row 4: Resample box -->
<rect x="56" y="164" width="524" height="38" rx="5" fill="#1a1a1a" stroke="#e06060" stroke-width="1"/>
<text x="318" y="180" text-anchor="middle" fill="#e06060" font-size="10">reject at pos 1 &#x2192; stop, resample from adj(v) = max(0, q(v) &#x2212; p(v))</text>
<text x="318" y="194" text-anchor="middle" fill="#aaa" font-size="9">normalize adj, inverse-CDF sample using u[b, T] &#x2192; replacement token t&#x2081;&#x2032;</text>

<!-- Row 5: Output tokens -->
<text x="16" y="224" fill="#888" font-size="10">output</text>
<rect x="56" y="212" width="104" height="24" rx="4" fill="#1e3a5f" stroke="#4477bb" stroke-width="1.5"/>
<text x="108" y="229" text-anchor="middle" fill="#8ec4f0" font-size="11">t&#x2080;</text>
<rect x="196" y="212" width="104" height="24" rx="4" fill="#3a2010" stroke="#e0a040" stroke-width="1.5"/>
<text x="248" y="229" text-anchor="middle" fill="#f0b060" font-size="11">t&#x2081;&#x2032;</text>
<rect x="336" y="212" width="104" height="24" rx="4" fill="#2a2a2a" stroke="#555" stroke-width="1"/>
<text x="388" y="229" text-anchor="middle" fill="#555" font-size="11">0</text>
<rect x="476" y="212" width="104" height="24" rx="4" fill="#2a2a2a" stroke="#555" stroke-width="1"/>
<text x="528" y="229" text-anchor="middle" fill="#555" font-size="11">0</text>

<!-- Legend -->
<text x="16" y="260" fill="#c060e0" font-size="9">p = draft prob</text>
<text x="130" y="260" fill="#e0a040" font-size="9">q = target prob</text>
<text x="260" y="260" fill="#aaa" font-size="9">&#x3b1; = min(1, q/p)</text>
<text x="400" y="260" fill="#44aa66" font-size="9">&#x25a0; accepted</text>
<text x="490" y="260" fill="#e0a040" font-size="9">&#x25a0; resampled</text>
<text x="590" y="260" fill="#555" font-size="9">&#x25a0; pad</text>

<!-- All-accept note -->
<text x="330" y="290" text-anchor="middle" fill="#888" font-size="9">If all T tokens accepted: sample bonus token from q at last position using u[b, T]</text>
</svg>

<p>
For each sequence \(b\), process positions \(i = 0, 1, \ldots, T{-}1\) left-to-right:
</p>
<ol>
<li>Compute acceptance probability: \(\displaystyle \alpha_i = \min\!\left(1,\; \frac{q_i(t_i)}{p_i(t_i)}\right)\)</li>
<li>If \(u_i < \alpha_i\): <strong>accept</strong> \(t_i\), continue to position \(i{+}1\).</li>
<li>If \(u_i \ge \alpha_i\): <strong>reject</strong>, stop. Sample replacement from:
\[\text{adj}(v) = \frac{\max(0,\; q_i(v) - p_i(v))}{\sum_{v'} \max(0,\; q_i(v') - p_i(v'))}\]
using inverse CDF with \(r = \texttt{uniform_samples}[b, T]\). If \(\text{adj}\) is all zeros, use uniform \(1/V\).
</li>
<li>If all \(T\) tokens accepted: sample a <strong>bonus token</strong> from \(q_{T-1}\) using \(\texttt{uniform_samples}[b, T]\).</li>
</ol>
<p>
Write results into <code>output_tokens[b, :]</code> (shape \([B, T{+}1]\)): accepted/resampled tokens
fill positions \(0\) through the accepted count (inclusive), remaining positions are zero.
</p>

<h2>Implementation Requirements</h2>
<ul>
<li>Implement <code>solve(draft_tokens, draft_probs, target_probs, uniform_samples, output_tokens, B, T, V)</code>.</li>
<li>Do not change the function signature or use external libraries beyond the standard GPU frameworks.</li>
<li>Write results into the provided <code>output_tokens</code> buffer (shape <code>[B, T+1]</code>, <code>int32</code>).</li>
<li>Memory layout is row-major: <code>draft_probs[b, i, v]</code> is at offset <code>b*T*V + i*V + v</code>.</li>
<li>
Inverse CDF sampling: given distribution \(\text{adj}\) (already normalized), find the
smallest index \(k\) where \(\sum_{v=0}^{k} \text{adj}(v) \ge r\), where
\(r = \texttt{uniform_samples}[b, T]\). Clamp the result to \([0, V-1]\).
</li>
<li>
If the adjusted distribution is all zeros (i.e., \(q_i \le p_i\) everywhere), fall back to
the uniform distribution over \(V\) tokens.
</li>
</ul>

<h2>Example</h2>
<p>
Input: \(B = 1,\; T = 3,\; V = 4\)
</p>
<p>
\(\text{draft_tokens} = [1, 2, 0]\)
</p>
<p>
Draft probabilities \(p_i\) and target probabilities \(q_i\) per position:
\[
p_0 = \begin{bmatrix} 0.10 & 0.60 & 0.20 & 0.10 \end{bmatrix}, \quad
q_0 = \begin{bmatrix} 0.10 & 0.50 & 0.20 & 0.20 \end{bmatrix}
\]
\[
p_1 = \begin{bmatrix} 0.10 & 0.20 & 0.50 & 0.20 \end{bmatrix}, \quad
q_1 = \begin{bmatrix} 0.30 & 0.20 & 0.20 & 0.30 \end{bmatrix}
\]
\[
\text{uniform_samples} = \begin{bmatrix} 0.50 & 0.70 & 0.30 & 0.90 \end{bmatrix}
\]
</p>
<p>
<strong>Position 0</strong> (draft token = 1):
\(\alpha_0 = \min\!\left(1,\, \frac{q_0(1)}{p_0(1)}\right) = \min\!\left(1,\, \frac{0.50}{0.60}\right) \approx 0.833\).
Since \(u_0 = 0.50 < 0.833\), <strong>accept</strong> token 1.
</p>
<p>
<strong>Position 1</strong> (draft token = 2):
\(\alpha_1 = \min\!\left(1,\, \frac{q_1(2)}{p_1(2)}\right) = \min\!\left(1,\, \frac{0.20}{0.50}\right) = 0.40\).
Since \(u_1 = 0.70 \ge 0.40\), <strong>reject</strong>. Resample from adjusted distribution:
\[
\text{adj}(v) = \max(0,\, q_1(v) - p_1(v)) = [0.20,\, 0,\, 0,\, 0.10]
\]
\[
\text{normalized} = \left[\tfrac{2}{3},\, 0,\, 0,\, \tfrac{1}{3}\right], \quad
\text{CDF} = [0.667,\, 0.667,\, 0.667,\, 1.0]
\]
With \(r = \text{uniform_samples}[0, T] = 0.90\), inverse CDF gives token <strong>3</strong>.
</p>
<p>
Output:
\[\text{output_tokens} = \begin{bmatrix} 1 & 3 & 0 & 0 \end{bmatrix}\]
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>B</code> &le; 256</li>
<li>1 &le; <code>T</code> &le; 16</li>
<li>2 &le; <code>V</code> &le; 131,072</li>
<li><code>draft_probs[b, i, :]</code> and <code>target_probs[b, i, :]</code> are valid probability distributions (non-negative, sum to 1)</li>
<li><code>draft_probs[b, i, draft_tokens[b, i]]</code> &gt; 0 for all <code>b</code>, <code>i</code></li>
<li><code>uniform_samples</code> values are in \([0, 1)\)</li>
<li>All floating-point tensors use <code>float32</code>; token tensors use <code>int32</code></li>
<li>Performance is measured with <code>B</code> = 64, <code>T</code> = 8, <code>V</code> = 32,768</li>
</ul>
Loading
Loading