diff --git a/challenges/medium/87_speculative_decoding_verification/challenge.html b/challenges/medium/87_speculative_decoding_verification/challenge.html
new file mode 100644
index 00000000..74e3ff19
--- /dev/null
+++ b/challenges/medium/87_speculative_decoding_verification/challenge.html
@@ -0,0 +1,187 @@
+
+ Implement the token verification step of speculative decoding. A draft model proposes \(T\) tokens;
+ the target model evaluates them in one forward pass and accepts or rejects each. Given \(B\)
+ sequences, produce the verified output tokens. Probability tensors are float32;
+ token tensors are int32.
+
+
+
+ Notation for each sequence \(b\), at each draft position \(i = 0, \ldots, T{-}1\):
+
+
+ \(t_i = \texttt{draft_tokens}[b, i]\) — the token proposed by the draft model
+ \(p_i(v) = \texttt{draft_probs}[b, i, v]\) — draft model's probability for token \(v\)
+ \(q_i(v) = \texttt{target_probs}[b, i, v]\) — target model's probability for token \(v\)
+ \(u_i = \texttt{uniform_samples}[b, i]\) — pre-generated \(U[0,1)\) sample for position \(i\)
+
+
+
+
+
+
+ pos 0
+ pos 1
+ pos 2
+ pos 3
+
+
+ draft
+
+ t₀
+
+ t₁
+
+ t₂
+
+ t₃
+
+
+ probs
+
+ p(t₀) = 0.60
+ q(t₀) = 0.50
+
+
+ p(t₁) = 0.50
+ q(t₁) = 0.20
+
+
+ not reached
+
+
+ not reached
+
+
+ α, test
+
+ α = .50/.60 = .83
+ u=0.1 < .83 ✓
+
+
+ α = .20/.50 = .40
+ u=0.7 ≥ .40 ✗
+
+
+ skipped
+
+
+ skipped
+
+
+
+ reject at pos 1 → stop, resample from adj(v) = max(0, q(v) − p(v))
+ normalize adj, inverse-CDF sample using u[b, T] → replacement token t₁′
+
+
+ output
+
+ t₀
+
+ t₁′
+
+ 0
+
+ 0
+
+
+ p = draft prob
+ q = target prob
+ α = min(1, q/p)
+ ■ accepted
+ ■ resampled
+ ■ pad
+
+
+ If all T tokens accepted: sample bonus token from q at last position using u[b, T]
+
+
+
+ For each sequence \(b\), process positions \(i = 0, 1, \ldots, T{-}1\) left-to-right:
+
+
+ Compute acceptance probability: \(\displaystyle \alpha_i = \min\!\left(1,\; \frac{q_i(t_i)}{p_i(t_i)}\right)\)
+ If \(u_i < \alpha_i\): accept \(t_i\), continue to position \(i{+}1\).
+ If \(u_i \ge \alpha_i\): reject and stop. Sample replacement from:
+ \[\text{adj}(v) = \frac{\max(0,\; q_i(v) - p_i(v))}{\sum_{v'} \max(0,\; q_i(v') - p_i(v'))}\]
+ using inverse CDF with \(r = \texttt{uniform_samples}[b, T]\). If \(\text{adj}\) is all zeros, use uniform \(1/V\).
+
+ If all \(T\) tokens accepted: sample a bonus token from \(q_{T-1}\) using \(\texttt{uniform_samples}[b, T]\).
+
+
+ Write results into output_tokens[b, :] (shape \([B, T{+}1]\)): accepted/resampled tokens
+ fill positions \(0\) through the accepted count (inclusive), remaining positions are zero.
+
+
+Implementation Requirements
+
+ Implement solve(draft_tokens, draft_probs, target_probs, uniform_samples, output_tokens, B, T, V).
+ Do not change the function signature or use external libraries beyond the standard GPU frameworks.
+ Write results into the provided output_tokens buffer (shape [B, T+1], int32).
+ Memory layout is row-major: draft_probs[b, i, v] is at offset b*T*V + i*V + v.
+
+ Inverse CDF sampling: given distribution \(\text{adj}\) (already normalized), find the
+ smallest index \(k\) where \(\sum_{v=0}^{k} \text{adj}(v) \ge r\), where
+ \(r = \texttt{uniform_samples}[b, T]\). Clamp the result to \([0, V-1]\).
+
+
+ If the adjusted distribution is all zeros (i.e., \(q_i \le p_i\) everywhere), fall back to
+ the uniform distribution over \(V\) tokens.
+
+
+
+Example
+
+ Input: \(B = 1,\; T = 3,\; V = 4\)
+
+
+ \(\text{draft_tokens} = [1, 2, 0]\)
+
+
+ Draft probabilities \(p_i\) and target probabilities \(q_i\) per position:
+ \[
+ p_0 = \begin{bmatrix} 0.10 & 0.60 & 0.20 & 0.10 \end{bmatrix}, \quad
+ q_0 = \begin{bmatrix} 0.10 & 0.50 & 0.20 & 0.20 \end{bmatrix}
+ \]
+ \[
+ p_1 = \begin{bmatrix} 0.10 & 0.20 & 0.50 & 0.20 \end{bmatrix}, \quad
+ q_1 = \begin{bmatrix} 0.30 & 0.20 & 0.20 & 0.30 \end{bmatrix}
+ \]
+ \[
+ \text{uniform_samples} = \begin{bmatrix} 0.50 & 0.70 & 0.30 & 0.90 \end{bmatrix}
+ \]
+
+
+ Position 0 (draft token = 1):
+ \(\alpha_0 = \min\!\left(1,\, \frac{q_0(1)}{p_0(1)}\right) = \min\!\left(1,\, \frac{0.50}{0.60}\right) \approx 0.833\).
+ Since \(u_0 = 0.50 < 0.833\), accept token 1.
+
+
+ Position 1 (draft token = 2):
+ \(\alpha_1 = \min\!\left(1,\, \frac{q_1(2)}{p_1(2)}\right) = \min\!\left(1,\, \frac{0.20}{0.50}\right) = 0.40\).
+ Since \(u_1 = 0.70 \ge 0.40\), reject. Resample from the adjusted distribution:
+ \[
+ \text{adj}(v) = \max(0,\, q_1(v) - p_1(v)) = [0.20,\, 0,\, 0,\, 0.10]
+ \]
+ \[
+ \text{normalized} = \left[\tfrac{2}{3},\, 0,\, 0,\, \tfrac{1}{3}\right], \quad
+ \text{CDF} = [0.667,\, 0.667,\, 0.667,\, 1.0]
+ \]
+ With \(r = \text{uniform_samples}[0, T] = 0.90\), inverse CDF gives token 3.
+
+
+ Output:
+ \[\text{output_tokens} = \begin{bmatrix} 1 & 3 & 0 & 0 \end{bmatrix}\]
+
+
+Constraints
+
+ 1 ≤ B ≤ 256
+ 1 ≤ T ≤ 16
+ 2 ≤ V ≤ 131,072
+ draft_probs[b, i, :] and target_probs[b, i, :] are valid probability distributions (non-negative, sum to 1)
+ draft_probs[b, i, draft_tokens[b, i]] > 0 for all b, i
+ uniform_samples values are in \([0, 1)\)
+ All floating-point tensors use float32; token tensors use int32
+ Performance is measured with B = 64, T = 8, V = 32,768
+
diff --git a/challenges/medium/87_speculative_decoding_verification/challenge.py b/challenges/medium/87_speculative_decoding_verification/challenge.py
new file mode 100644
index 00000000..363c849a
--- /dev/null
+++ b/challenges/medium/87_speculative_decoding_verification/challenge.py
@@ -0,0 +1,300 @@
+import ctypes
+from typing import Any, Dict, List
+
+import torch
+from core.challenge_base import ChallengeBase
+
+
+class Challenge(ChallengeBase):
+ def __init__(self):
+ super().__init__(
+ name="Speculative Decoding Verification",
+ atol=1e-05,
+ rtol=1e-05,
+ num_gpus=1,
+ access_tier="free",
+ )
+
+ def reference_impl(
+ self,
+ draft_tokens: torch.Tensor,
+ draft_probs: torch.Tensor,
+ target_probs: torch.Tensor,
+ uniform_samples: torch.Tensor,
+ output_tokens: torch.Tensor,
+ B: int,
+ T: int,
+ V: int,
+ ):
+ assert draft_tokens.shape == (B, T)
+ assert draft_probs.shape == (B, T, V)
+ assert target_probs.shape == (B, T, V)
+ assert uniform_samples.shape == (B, T + 1)
+ assert output_tokens.shape == (B, T + 1)
+ assert draft_tokens.dtype == torch.int32
+ assert draft_probs.dtype == torch.float32
+ assert target_probs.dtype == torch.float32
+ assert uniform_samples.dtype == torch.float32
+ assert output_tokens.dtype == torch.int32
+ assert draft_tokens.device.type == "cuda"
+ assert draft_probs.device.type == "cuda"
+ assert target_probs.device.type == "cuda"
+ assert uniform_samples.device.type == "cuda"
+ assert output_tokens.device.type == "cuda"
+
+ output_tokens.fill_(0)
+
+ for b in range(B):
+ for i in range(T):
+ tok = int(draft_tokens[b, i].item())
+ p = draft_probs[b, i, tok].item()
+ q = target_probs[b, i, tok].item()
+ alpha = min(1.0, q / p)
+
+ if uniform_samples[b, i].item() < alpha:
+ output_tokens[b, i] = tok
+ else:
+ adjusted = torch.clamp(target_probs[b, i] - draft_probs[b, i], min=0.0)
+ total = adjusted.sum().item()
+ if total > 0.0:
+ adjusted = adjusted / total
+ else:
+ adjusted = (
+ torch.ones(V, dtype=torch.float32, device=draft_tokens.device) / V
+ )
+ cdf = torch.cumsum(adjusted, dim=0)
+ r = float(uniform_samples[b, T].item())
+ new_tok = int(torch.searchsorted(cdf.contiguous(), r).item())
+ output_tokens[b, i] = min(new_tok, V - 1)
+ break
+ else:
+ cdf = torch.cumsum(target_probs[b, T - 1], dim=0)
+ r = float(uniform_samples[b, T].item())
+ bonus_tok = int(torch.searchsorted(cdf.contiguous(), r).item())
+ output_tokens[b, T] = min(bonus_tok, V - 1)
+
+ def get_solve_signature(self) -> Dict[str, tuple]:
+ return {
+ "draft_tokens": (ctypes.POINTER(ctypes.c_int), "in"),
+ "draft_probs": (ctypes.POINTER(ctypes.c_float), "in"),
+ "target_probs": (ctypes.POINTER(ctypes.c_float), "in"),
+ "uniform_samples": (ctypes.POINTER(ctypes.c_float), "in"),
+ "output_tokens": (ctypes.POINTER(ctypes.c_int), "out"),
+ "B": (ctypes.c_int, "in"),
+ "T": (ctypes.c_int, "in"),
+ "V": (ctypes.c_int, "in"),
+ }
+
+ def _make_sparse_probs(self, B, T, V, K, device):
+ """Generate sparse probability distributions: only K tokens have nonzero probability.
+
+        Using sparse distributions ensures that the adjusted distribution clamp(q-p, 0)
+        has at most K nonzero entries (nonzero only where q > p >= 0, i.e. within q's
+        support), making CDF summation numerically exact regardless of summation order.
+ """
+ K = min(K, V)
+ flat = B * T
+ # For each (b, i), sample K distinct token indices
+ idx = torch.stack([torch.randperm(V, device=device)[:K] for _ in range(flat)])
+ idx = idx.view(B, T, K)
+ # Random weights summing to 1
+ weights = torch.rand(B, T, K, device=device)
+ weights = weights / weights.sum(dim=-1, keepdim=True)
+ # Scatter into full V-dimensional probability vector
+ probs = torch.zeros(B, T, V, device=device)
+ probs.scatter_(2, idx, weights)
+ return probs, idx
+
+ def _make_test_case(self, B, T, V, seed=42):
+ torch.manual_seed(seed)
+ device = "cuda"
+
+        # K=64 active tokens per position: enough diversity while keeping the adjusted
+        # distribution sparse (at most 64 nonzero entries, within q's support), ensuring
+        # CDF sums are independent of floating-point summation order.
+ K = min(64, V)
+ draft_probs, draft_idx = self._make_sparse_probs(B, T, V, K, device)
+ target_probs, _ = self._make_sparse_probs(B, T, V, K, device)
+
+ # Sample draft tokens from the active K tokens
+ weights = draft_probs.gather(2, draft_idx) # (B, T, K)
+ flat_w = weights.view(B * T, K)
+ chosen = torch.multinomial(flat_w, 1).view(B, T) # index within the K tokens
+ draft_tokens = draft_idx.gather(2, chosen.unsqueeze(-1)).squeeze(-1).to(torch.int32)
+
+ uniform_samples = torch.rand(B, T + 1, device=device)
+ output_tokens = torch.zeros(B, T + 1, device=device, dtype=torch.int32)
+
+ return {
+ "draft_tokens": draft_tokens,
+ "draft_probs": draft_probs,
+ "target_probs": target_probs,
+ "uniform_samples": uniform_samples,
+ "output_tokens": output_tokens,
+ "B": B,
+ "T": T,
+ "V": V,
+ }
+
+ def _make_accept_all_case(self, B, T, V, seed=42):
+ """All draft tokens accepted: target_probs == draft_probs so alpha == 1 everywhere."""
+ torch.manual_seed(seed)
+ device = "cuda"
+
+ K = min(64, V)
+ draft_probs, draft_idx = self._make_sparse_probs(B, T, V, K, device)
+ target_probs = draft_probs.clone() # alpha = min(1, q/p) = 1 → always accept
+
+ weights = draft_probs.gather(2, draft_idx)
+ flat_w = weights.view(B * T, K)
+ chosen = torch.multinomial(flat_w, 1).view(B, T)
+ draft_tokens = draft_idx.gather(2, chosen.unsqueeze(-1)).squeeze(-1).to(torch.int32)
+
+ # All acceptance samples set to 0 (< 1.0 = alpha) to guarantee acceptance
+ uniform_samples = torch.zeros(B, T + 1, device=device)
+ uniform_samples[:, T] = torch.rand(B, device=device) # bonus sampling sample
+
+ output_tokens = torch.zeros(B, T + 1, device=device, dtype=torch.int32)
+
+ return {
+ "draft_tokens": draft_tokens,
+ "draft_probs": draft_probs,
+ "target_probs": target_probs,
+ "uniform_samples": uniform_samples,
+ "output_tokens": output_tokens,
+ "B": B,
+ "T": T,
+ "V": V,
+ }
+
+ def _make_reject_first_case(self, B, T, V, seed=42):
+ """First draft token always rejected: draft_probs high, target low for that token."""
+ torch.manual_seed(seed)
+ device = "cuda"
+
+ draft_probs = torch.softmax(torch.randn(B, T, V, device=device), dim=-1)
+ target_probs = torch.softmax(torch.randn(B, T, V, device=device), dim=-1)
+
+ flat = draft_probs.view(B * T, V)
+ draft_tokens = torch.multinomial(flat, 1).view(B, T).to(torch.int32)
+
+ # Force rejection at position 0 for every sequence:
+ # set alpha[b,0] very small and uniform_sample[b,0] high enough to reject
+ for b in range(B):
+ tok = int(draft_tokens[b, 0].item())
+ # Make draft prob ~0.9 for the chosen token (high p)
+ draft_probs[b, 0] = torch.full((V,), 0.1 / max(V - 1, 1), device=device)
+ draft_probs[b, 0, tok] = 0.9
+ draft_probs[b, 0] = draft_probs[b, 0] / draft_probs[b, 0].sum()
+ # Make target prob ~1/V for the same token (low q)
+ target_probs[b, 0] = torch.ones(V, device=device) / V
+
+ uniform_samples = torch.rand(B, T + 1, device=device)
+ # Force uniform[b, 0] = 0.99 > alpha (which is ~1/V / 0.9 ≈ small)
+ uniform_samples[:, 0] = 0.99
+
+ output_tokens = torch.zeros(B, T + 1, device=device, dtype=torch.int32)
+
+ return {
+ "draft_tokens": draft_tokens,
+ "draft_probs": draft_probs,
+ "target_probs": target_probs,
+ "uniform_samples": uniform_samples,
+ "output_tokens": output_tokens,
+ "B": B,
+ "T": T,
+ "V": V,
+ }
+
+ def generate_example_test(self) -> Dict[str, Any]:
+ device = "cuda"
+
+ # B=1, T=3, V=4: position 0 accepted, position 1 rejected, token resampled
+ draft_tokens = torch.tensor([[1, 2, 0]], device=device, dtype=torch.int32)
+
+ draft_probs = torch.tensor(
+ [
+ [
+ [0.10, 0.60, 0.20, 0.10], # pos 0: draft_tokens[0,0]=1, p=0.60
+ [0.10, 0.20, 0.50, 0.20], # pos 1: draft_tokens[0,1]=2, p=0.50
+ [0.40, 0.20, 0.20, 0.20], # pos 2: draft_tokens[0,2]=0, p=0.40
+ ]
+ ],
+ device=device,
+ dtype=torch.float32,
+ )
+
+ target_probs = torch.tensor(
+ [
+ [
+ [0.10, 0.50, 0.20, 0.20], # pos 0: q=0.50, alpha=min(1,0.50/0.60)=0.833
+ [0.30, 0.20, 0.20, 0.30], # pos 1: q=0.20, alpha=min(1,0.20/0.50)=0.400
+ [0.30, 0.20, 0.30, 0.20], # pos 2: not reached
+ ]
+ ],
+ device=device,
+ dtype=torch.float32,
+ )
+
+ # uniform_samples[0, 0]=0.50 < 0.833 → ACCEPT token 1
+ # uniform_samples[0, 1]=0.70 > 0.400 → REJECT token 2
+ # adjusted = clamp([0.20, 0, -0.30, 0.10], min=0) = [0.20, 0, 0, 0.10]
+ # normalized CDF = [0.667, 0.667, 0.667, 1.0]
+ # uniform_samples[0, T=3]=0.90 → searchsorted → token 3
+ # output_tokens[0] = [1, 3, 0, 0]
+ uniform_samples = torch.tensor(
+ [[0.50, 0.70, 0.30, 0.90]], device=device, dtype=torch.float32
+ )
+
+ output_tokens = torch.zeros(1, 4, device=device, dtype=torch.int32)
+
+ return {
+ "draft_tokens": draft_tokens,
+ "draft_probs": draft_probs,
+ "target_probs": target_probs,
+ "uniform_samples": uniform_samples,
+ "output_tokens": output_tokens,
+ "B": 1,
+ "T": 3,
+ "V": 4,
+ }
+
+ def generate_functional_test(self) -> List[Dict[str, Any]]:
+ tests = []
+
+ # Edge: T=1, rejected immediately
+ tests.append(self._make_reject_first_case(1, 1, 4, seed=1))
+
+ # Edge: T=1, all accepted (bonus token sampled)
+ tests.append(self._make_accept_all_case(1, 1, 4, seed=2))
+
+ # Edge: T=2, first rejected
+ tests.append(self._make_reject_first_case(1, 2, 8, seed=3))
+
+ # Edge: T=4, all accepted
+ tests.append(self._make_accept_all_case(2, 4, 8, seed=4))
+
+        # Larger batch: rejection forced at pos 0 (uniform[b, 0] = 0.99 exceeds the tiny alpha)
+ tests.append(self._make_reject_first_case(4, 4, 16, seed=5))
+
+ # Power-of-2 vocab, mixed acceptance
+ tests.append(self._make_test_case(4, 8, 64, seed=10))
+
+ # Larger vocab, mixed acceptance
+ tests.append(self._make_test_case(8, 8, 256, seed=20))
+
+ # Non-power-of-2 vocab
+ tests.append(self._make_test_case(4, 6, 30, seed=30))
+
+ # All sequences accept all tokens (bonus sampling)
+ tests.append(self._make_accept_all_case(8, 8, 128, seed=40))
+
+ # Realistic small batch
+ tests.append(self._make_test_case(16, 8, 1000, seed=50))
+
+ return tests
+
+ def generate_performance_test(self) -> Dict[str, Any]:
+ torch.manual_seed(0)
+        # B=64 sequences, T=8 draft tokens, V=32768 (power-of-two, LLM-scale vocab size)
+ return self._make_test_case(64, 8, 32768, seed=0)
diff --git a/challenges/medium/87_speculative_decoding_verification/starter/starter.cu b/challenges/medium/87_speculative_decoding_verification/starter/starter.cu
new file mode 100644
index 00000000..6a9731a8
--- /dev/null
+++ b/challenges/medium/87_speculative_decoding_verification/starter/starter.cu
@@ -0,0 +1,5 @@
+#include <cuda_runtime.h>
+
+// draft_tokens, draft_probs, target_probs, uniform_samples, output_tokens are device pointers
+extern "C" void solve(const int* draft_tokens, const float* draft_probs, const float* target_probs,
+ const float* uniform_samples, int* output_tokens, int B, int T, int V) {}
diff --git a/challenges/medium/87_speculative_decoding_verification/starter/starter.cute.py b/challenges/medium/87_speculative_decoding_verification/starter/starter.cute.py
new file mode 100644
index 00000000..320555b2
--- /dev/null
+++ b/challenges/medium/87_speculative_decoding_verification/starter/starter.cute.py
@@ -0,0 +1,17 @@
+import cutlass
+import cutlass.cute as cute
+
+
+# draft_tokens, draft_probs, target_probs, uniform_samples, output_tokens are tensors on the GPU
+@cute.jit
+def solve(
+ draft_tokens: cute.Tensor,
+ draft_probs: cute.Tensor,
+ target_probs: cute.Tensor,
+ uniform_samples: cute.Tensor,
+ output_tokens: cute.Tensor,
+ B: cute.Int32,
+ T: cute.Int32,
+ V: cute.Int32,
+):
+ pass
diff --git a/challenges/medium/87_speculative_decoding_verification/starter/starter.jax.py b/challenges/medium/87_speculative_decoding_verification/starter/starter.jax.py
new file mode 100644
index 00000000..bd7e8e18
--- /dev/null
+++ b/challenges/medium/87_speculative_decoding_verification/starter/starter.jax.py
@@ -0,0 +1,17 @@
+import jax
+import jax.numpy as jnp
+
+
+# draft_tokens, draft_probs, target_probs, uniform_samples are tensors on GPU
+@jax.jit
+def solve(
+ draft_tokens: jax.Array,
+ draft_probs: jax.Array,
+ target_probs: jax.Array,
+ uniform_samples: jax.Array,
+ B: int,
+ T: int,
+ V: int,
+) -> jax.Array:
+ # return output tensor directly
+ pass
diff --git a/challenges/medium/87_speculative_decoding_verification/starter/starter.mojo b/challenges/medium/87_speculative_decoding_verification/starter/starter.mojo
new file mode 100644
index 00000000..89c9fca2
--- /dev/null
+++ b/challenges/medium/87_speculative_decoding_verification/starter/starter.mojo
@@ -0,0 +1,16 @@
+from gpu.host import DeviceContext
+from memory import UnsafePointer
+
+# draft_tokens, draft_probs, target_probs, uniform_samples, output_tokens are device pointers
+@export
+def solve(
+ draft_tokens: UnsafePointer[Int32],
+ draft_probs: UnsafePointer[Float32],
+ target_probs: UnsafePointer[Float32],
+ uniform_samples: UnsafePointer[Float32],
+ output_tokens: UnsafePointer[Int32],
+ B: Int32,
+ T: Int32,
+ V: Int32,
+):
+ pass
diff --git a/challenges/medium/87_speculative_decoding_verification/starter/starter.pytorch.py b/challenges/medium/87_speculative_decoding_verification/starter/starter.pytorch.py
new file mode 100644
index 00000000..3cce7fae
--- /dev/null
+++ b/challenges/medium/87_speculative_decoding_verification/starter/starter.pytorch.py
@@ -0,0 +1,15 @@
+import torch
+
+
+# draft_tokens, draft_probs, target_probs, uniform_samples, output_tokens are tensors on the GPU
+def solve(
+ draft_tokens: torch.Tensor,
+ draft_probs: torch.Tensor,
+ target_probs: torch.Tensor,
+ uniform_samples: torch.Tensor,
+ output_tokens: torch.Tensor,
+ B: int,
+ T: int,
+ V: int,
+):
+ pass
diff --git a/challenges/medium/87_speculative_decoding_verification/starter/starter.triton.py b/challenges/medium/87_speculative_decoding_verification/starter/starter.triton.py
new file mode 100644
index 00000000..1347ec9c
--- /dev/null
+++ b/challenges/medium/87_speculative_decoding_verification/starter/starter.triton.py
@@ -0,0 +1,17 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# draft_tokens, draft_probs, target_probs, uniform_samples, output_tokens are tensors on the GPU
+def solve(
+ draft_tokens: torch.Tensor,
+ draft_probs: torch.Tensor,
+ target_probs: torch.Tensor,
+ uniform_samples: torch.Tensor,
+ output_tokens: torch.Tensor,
+ B: int,
+ T: int,
+ V: int,
+):
+ pass