Commit 44a761a

[V1] implement tree sampler for draft token acceptance
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
Parent: 51c78b9

File tree

5 files changed (+144 −110 lines)

tests/v1/spec_decode/test_eagle.py

Lines changed: 4 additions & 2 deletions

@@ -125,7 +125,7 @@ def test_prepare_inputs():
     proposer = _create_proposer("eagle", 1)
 
     updated_metadata, token_indices = proposer.prepare_inputs(
-        common_attn_metadata, num_rejected_tokens.cpu())
+        common_attn_metadata, num_rejected_tokens.cpu(), [], [])
 
     assert torch.equal(updated_metadata.query_start_loc,
                        expected_cu_num_tokens)
@@ -405,7 +405,9 @@ def create_deterministic_logits(token_ids):
     [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0),
      (2, 1)],  # Tree
 ])
-def test_propose_tree(spec_token_tree):
+def test_propose_tree(spec_token_tree, monkeypatch):
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TREE_ATTN")
+
     # Get GPU device.
     device = torch.device(current_platform.device_type)
 
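Both updated tests pin the attention backend to TREE_ATTN through pytest's monkeypatch fixture rather than mutating os.environ directly, so the override stays scoped to a single test. A minimal sketch of that pattern (illustrative only, not part of this commit):

import os


def test_backend_env_is_scoped(monkeypatch):
    # monkeypatch.setenv applies only for the duration of this test and is
    # automatically undone afterwards, so forcing TREE_ATTN here cannot leak
    # into other tests that expect the default backend.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TREE_ATTN")
    assert os.environ["VLLM_ATTENTION_BACKEND"] == "TREE_ATTN"
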
tests/v1/spec_decode/test_tree_attention.py

Lines changed: 3 additions & 1 deletion

@@ -120,10 +120,12 @@ def forward_attention(
     )
 
 
-def test_tree_attn_correctness() -> None:
+def test_tree_attn_correctness(monkeypatch) -> None:
     torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
 
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TREE_ATTN")
+
     device = "cuda"
     tree_attn_masks = {
         # Chain.

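The tree masks used throughout this change (the test's tree_attn_masks and the expanded_tree_mask matrices shown in the sampler comments below) are ancestor masks: row i marks node i plus every ancestor up to the root. A small sketch of how such a mask can be derived from a parent-index list; the helper name and the toy trees are assumptions for illustration, not code from this commit:

import torch


def ancestor_mask(parents: list[int]) -> torch.Tensor:
    """Row i is True at node i and at every ancestor of i (root has parent -1)."""
    n = len(parents)
    mask = torch.zeros((n, n), dtype=torch.bool)
    for i in range(n):
        node = i
        while node != -1:
            mask[i, node] = True
            node = parents[node]
    return mask


# The 7-node tree from the sampler's comments: root 0 with children 1 and 2,
# node 1 with children 3 and 4, node 2 with children 5 and 6.
print(ancestor_mask([-1, 0, 0, 1, 1, 2, 2]).int())
# A chain degenerates to a lower-triangular matrix.
print(ancestor_mask([-1, 0, 1, 2]).int())
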
vllm/v1/sample/tree_rejection_sampler.py

Lines changed: 134 additions & 104 deletions

@@ -65,9 +65,8 @@ def __init__(
 
         # Precompute indices for logits corresponding to tree-internal
         # tokens across batches.
-        num_tree_internal_tokens = self.cu_tokens_per_level[-2]
-        self.tree_internal_index_offsets = torch.arange(
-            num_tree_internal_tokens, device=device)
+        self.tree_internal_size = self.cu_tokens_per_level[-2]
+        self.tree_index_offsets = torch.arange(self.tree_size, device=device)
 
     def forward(
         self,
@@ -113,101 +112,23 @@ def forward(
         #   3 4 5 6     level 2
 
         device = target_logits.device
+        draft_tree_size = self.tree_size - 1
+        tree_internal_index_offsets = (
+            self.tree_index_offsets[:self.tree_internal_size])
+        draft_index_offsets = self.tree_index_offsets[:draft_tree_size]
+
         num_reqs = len(metadata.num_draft_tokens)
-        # [8, 8, 0, 0, 0, 0, 0, 0]
+        # [1, 8, 8, 0, 0, 0, 0, 0]
         num_draft_tokens = torch.tensor(metadata.num_draft_tokens,
                                         device=device)
-        draft_tree_size = self.tree_size - 1
-        # [1, 1, 0, 0, 0, 0, 0, 0]
-        tree_decode_mask = num_draft_tokens == draft_tree_size
+        # [0, 1, 1, 0, 0, 0, 0, 0]
+        is_tree_decode = num_draft_tokens == draft_tree_size
         # [0, 1, 2, 3, 4, 5, 6, 7]
         start_indices = torch.arange(num_reqs, device=device)
-        # [0, 9, 18, 19, 20, 21, 22, 23]
+        # [0, 2, 11, 20, 21, 22, 23, 24]
         start_indices[1:] += metadata.cu_num_draft_tokens[:-1]
 
-        # Compute target probabilities for all logits corresponding to internal
-        # nodes in the tree.
-        vocab_size = target_logits.shape[-1]
-        # [0, 9]
-        tree_decode_start_indices = start_indices[tree_decode_mask]
-        # [[0, 1, 2],
-        #  [9, 10, 11]]
-        tree_internal_indices = tree_decode_start_indices.unsqueeze(
-            1) + self.tree_internal_index_offsets
-        num_tree_decodes, num_logits_per_batch = tree_internal_indices.shape
-        tree_internal_logits = target_logits[tree_internal_indices.flatten()]
-        target_probs = self.compute_probs(
-            tree_internal_logits,
-            num_logits_per_batch,
-            sampling_metadata,
-        ).view(num_tree_decodes, -1, vocab_size)
-
-        # Sample tokens from the target probabilities.
-        # TODO(TheEpicDolphin): Add support for probabilistic-style rejection
-        # sampling, as used in EAGLE.
-        target_token_ids = target_probs.argmax(dim=-1).cpu()
-
-        # Reshape the draft token ids to [num_tree_decodes, draft_tree_size].
-        draft_token_ids = metadata.draft_token_ids.view(num_tree_decodes, -1)
-
-        # Move sampled target and draft token tensors to CPU.
-        # [[311, 6435, 96618],
-        #  [279, 11, 15861]]
-        target_token_ids_cpu = target_token_ids.cpu()
-        # [[311, 8844, 2349, 387, 4732, 96618, 311, 334],
-        #  [3634, 279, 323, 11, 438, 15861, 3634, 7016]]
-        draft_token_ids_cpu = draft_token_ids.cpu()
-
-        # For each batch, find longest path from the root node.
-        path_lengths = torch.zeros(
-            # +1 for the root token.
-            (num_tree_decodes, draft_tree_size + 1),
-            dtype=torch.int32)
-        path_lengths[:, 0] = 1
-        for level in range(1, self.tree_depth):
-            # level 2:
-            # (3, 9)
-            start, end = self.draft_slices[level]
-            # [1, 1, 1, 2, 2, 2]
-            parent_indices = self.parent_indices[level]
-            # [[0, 0, 0, 0, 0, 0],
-            #  [0, 1, 0, 1, 0, 0]]
-            sample_match = draft_token_ids_cpu[:, start - 1:end -
-                                               1] == target_token_ids_cpu[:,
-                                                                          parent_indices]
-            nonzero_length = path_lengths[:, parent_indices] > 0
-            # [[1, 2, 0, 0, 0, 0, 0, 0, 0], -> [[1, 2, 0, 0, 0, 0, 0, 0, 0],
-            #  [1, 0, 2, 0, 0, 0, 0, 0, 0]]     [1, 0, 2, 0, 0, 0, 3, 0, 0]]
-            path_lengths[:,
-                         start:end].masked_fill_(sample_match & nonzero_length,
-                                                 level + 1)
-        # [1, 6]
-        accepted_token_index_offsets = path_lengths.argmax(dim=-1).to(device)
-
-        # Get boolean masks for the paths to the accepted tokens.
-        # [0, 1]
-        tree_batch_indices = self.batch_indices[:num_tree_decodes]
-        # [[[1, 0, 0, 0, 0, 0, 0],  <- batch 0
-        #   [1, 1, 0, 0, 0, 0, 0],
-        #   [1, 0, 1, 0, 0, 0, 0],
-        #   [1, 1, 0, 1, 0, 0, 0],
-        #   [1, 1, 0, 0, 1, 0, 0],
-        #   [1, 0, 1, 0, 0, 1, 0],
-        #   [1, 0, 1, 0, 0, 0, 1]],
-        #  [[1, 0, 0, 0, 0, 0, 0],  <- batch 1
-        #   [1, 1, 0, 0, 0, 0, 0],
-        #   [1, 0, 1, 0, 0, 0, 0],
-        #   [1, 1, 0, 1, 0, 0, 0],
-        #   [1, 1, 0, 0, 1, 0, 0],
-        #   [1, 0, 1, 0, 0, 1, 0],
-        #   [1, 0, 1, 0, 0, 0, 1]]]
-        tree_mask = self.expanded_tree_mask[:num_tree_decodes]
-        # [1, 6] => [[1, 0, 0, 0, 0, 0],  <- batch 0
-        #            [0, 1, 0, 0, 0, 1]]  <- batch 1
-        path_masks = tree_mask[tree_batch_indices,
-                               accepted_token_index_offsets]
-
-        # Create output buffer.
+        # Create output token ids buffer.
         output_token_ids = torch.empty(
             # +1 for the bonus token.
             (num_reqs, draft_tree_size + 1),
@@ -217,15 +138,116 @@ def forward(
         )
         output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
 
-        # Set accepted draft tokens.
-        accepted_draft_tokens = draft_token_ids[path_masks]
-        scatter_mask = torch.zeros_like(output_token_ids, dtype=torch.bool)
-        scatter_mask[tree_decode_mask, :-1] = path_masks
-        output_token_ids.masked_scatter_(scatter_mask, accepted_draft_tokens)
+        # [0, 0, 0, 0, 0, 0, 0, 0]
+        accepted_index_offsets = torch.zeros_like(is_tree_decode,
+                                                  dtype=torch.int32)
+
+        num_tree_decodes = is_tree_decode.sum()
+        if num_tree_decodes > 0:
+            # Compute target probabilities for all logits corresponding to
+            # internal nodes in the tree.
+            vocab_size = target_logits.shape[-1]
+            # [0, 9]
+            tree_decode_start_indices = start_indices[is_tree_decode]
+            # [[0, 1, 2],
+            #  [9, 10, 11]]
+            tree_internal_indices = (tree_decode_start_indices.unsqueeze(1) +
+                                     tree_internal_index_offsets)
+            tree_internal_logits = target_logits[
+                tree_internal_indices.flatten()]
+            target_probs = self.compute_tree_target_probs(
+                tree_internal_logits,
+                is_tree_decode,
+                num_tree_decodes,
+                sampling_metadata,
+            ).view(num_tree_decodes, -1, vocab_size)
+
+            # Sample tokens from the target probabilities.
+            # TODO(TheEpicDolphin): Add support for probabilistic-style
+            # rejection sampling, as used in EAGLE.
+            target_token_ids = target_probs.argmax(dim=-1)
+
+            # Get the draft token ids for batches with full draft trees.
+            # [0, 0]
+            draft_start_indices = torch.zeros(num_tree_decodes,
+                                              device=device,
+                                              dtype=torch.int32)
+            # [0, 8]
+            draft_start_indices[1:] = (
+                metadata.cu_num_draft_tokens[is_tree_decode][:-1])
+            # [[0, 1, 2, ... , 7]
+            #  [8, 9, 10, ... , 15]]
+            tree_draft_indices = (draft_start_indices.unsqueeze(1) +
+                                  draft_index_offsets)
+            draft_token_ids = metadata.draft_token_ids[tree_draft_indices]
+
+            # Move sampled target and draft token tensors to CPU.
+            # [[311, 6435, 96618],
+            #  [279, 11, 15861]]
+            target_token_ids_cpu = target_token_ids.cpu()
+            # [[311, 8844, 2349, 387, 4732, 96618, 311, 334],
+            #  [3634, 279, 323, 11, 438, 15861, 3634, 7016]]
+            draft_token_ids_cpu = draft_token_ids.cpu()
+
+            # For each tree decode batch, find longest path from the root node.
+            path_lengths_cpu = torch.zeros(
+                # +1 for the root token.
+                (num_tree_decodes, draft_tree_size + 1),
+                dtype=torch.int32,
+                device="cpu")
+            path_lengths_cpu[:, 0] = 1
+            for level in range(1, self.tree_depth):
+                # level 2:
+                # (3, 9)
+                start, end = self.draft_slices[level]
+                # [1, 1, 1, 2, 2, 2]
+                parent_indices = self.parent_indices[level]
+                # [[0, 0, 0, 0, 0, 0],
+                #  [0, 1, 0, 1, 0, 0]]
+                sample_match = (draft_token_ids_cpu[:, start - 1:end - 1] ==
+                                target_token_ids_cpu[:, parent_indices])
+                nonzero_length = path_lengths_cpu[:, parent_indices] > 0
+                # [[1, 2, 0, 0, 0, 0, 0, 0, 0],-> [[1, 2, 0, 0, 0, 0, 0, 0, 0],
+                #  [1, 0, 2, 0, 0, 0, 0, 0, 0]]    [1, 0, 2, 0, 0, 0, 3, 0, 0]]
+                path_lengths_cpu[:, start:end].masked_fill_(
+                    sample_match & nonzero_length, level + 1)
+            # [1, 6, 0, 0, 0, 0, 0, 0]
+            path_lengths = path_lengths_cpu.argmax(dim=-1).to(
+                device, dtype=torch.int32)
+            accepted_index_offsets[is_tree_decode] = path_lengths
+
+            # Get boolean masks for the paths to the accepted tokens.
+            # [0, 1]
+            tree_batch_indices = self.batch_indices[:num_tree_decodes]
+            # [[[1, 0, 0, 0, 0, 0, 0],  <- batch 0
+            #   [1, 1, 0, 0, 0, 0, 0],
+            #   [1, 0, 1, 0, 0, 0, 0],
+            #   [1, 1, 0, 1, 0, 0, 0],
+            #   [1, 1, 0, 0, 1, 0, 0],
+            #   [1, 0, 1, 0, 0, 1, 0],
+            #   [1, 0, 1, 0, 0, 0, 1]],
+            #  [[1, 0, 0, 0, 0, 0, 0],  <- batch 1
+            #   [1, 1, 0, 0, 0, 0, 0],
+            #   [1, 0, 1, 0, 0, 0, 0],
+            #   [1, 1, 0, 1, 0, 0, 0],
+            #   [1, 1, 0, 0, 1, 0, 0],
+            #   [1, 0, 1, 0, 0, 1, 0],
+            #   [1, 0, 1, 0, 0, 0, 1]]]
+            tree_mask = self.expanded_tree_mask[:num_tree_decodes]
+            # [1, 6] => [[1, 0, 0, 0, 0, 0],  <- batch 0
+            #            [0, 1, 0, 0, 0, 1]]  <- batch 1
+            path_masks = tree_mask[tree_batch_indices, path_lengths]
+
+            # Set accepted draft tokens.
+            accepted_draft_tokens = draft_token_ids[path_masks]
+            scatter_mask = torch.zeros_like(output_token_ids, dtype=torch.bool)
+            scatter_mask[is_tree_decode, :-1] = path_masks
+            output_token_ids.masked_scatter_(scatter_mask,
+                                             accepted_draft_tokens)
 
         # Sample and add a bonus token to the accepted paths.
-        bonus_token_indices = start_indices
-        bonus_token_indices[tree_decode_mask] += accepted_token_index_offsets
+        # [0, 2 + 1, 11 + 6, 20, 21, 22, 23, 24]
+        bonus_token_indices = start_indices + accepted_index_offsets
         bonus_sampler_output = self.main_sampler(
             logits=target_logits[bonus_token_indices],
             sampling_metadata=sampling_metadata,
@@ -234,22 +256,30 @@ def forward(
                          -1] = bonus_sampler_output.sampled_token_ids.view(-1)
         return output_token_ids
 
-    def compute_probs(self, logits: torch.Tensor, logits_per_batch: int,
-                      sampling_metadata: SamplingMetadata):
+    def compute_tree_target_probs(self, logits: torch.Tensor,
+                                  is_tree_decode: torch.Tensor,
+                                  num_tree_decodes: int,
+                                  sampling_metadata: SamplingMetadata):
         if sampling_metadata.all_greedy:
             return logits
 
+        # How many times to repeat the temperature, top-k, and top-p
+        # for each tree-decode batch.
+        num_repeats = logits.shape[0] // num_tree_decodes
+
         assert sampling_metadata.temperature is not None
-        temperature = sampling_metadata.temperature.repeat_interleave(
-            logits_per_batch)
+        temperature = sampling_metadata.temperature[is_tree_decode]
+        temperature = temperature.repeat_interleave(num_repeats)
         logits.div_(temperature.view(-1, 1))
 
         top_k = None
         if sampling_metadata.top_k is not None:
-            top_k = sampling_metadata.top_k.repeat_interleave(logits_per_batch)
+            top_k = sampling_metadata.top_k[is_tree_decode]
+            top_k = top_k.repeat_interleave(num_repeats)
         top_p = None
        if sampling_metadata.top_p is not None:
-            top_p = sampling_metadata.top_p.repeat_interleave(logits_per_batch)
+            top_p = sampling_metadata.top_p[is_tree_decode]
+            top_p = top_p.repeat_interleave(num_repeats)
         logits = apply_top_k_top_p(logits, top_k, top_p)
         output_probs = logits.softmax(dim=-1, dtype=torch.float32)
         return output_probs
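For context on what the new `if num_tree_decodes > 0` block computes: for every request that carries a full draft tree, the level-by-level loop keeps a per-node path length that is non-zero only if the node's draft token matches the target model's greedy token at its parent node and the parent itself lies on an accepted path; the argmax then picks the deepest accepted node, which selects both the accepted path mask and where the bonus token is sampled. A standalone pure-Python sketch of that acceptance rule (the function name and the toy tree are assumptions for illustration, not code from this commit):

def longest_accepted_node(draft_tokens: list[int],
                          target_tokens: list[int],
                          parents: list[int]) -> int:
    # draft_tokens[i]  : token drafted at tree node i + 1 (node 0 is the
    #                    already-verified root token).
    # target_tokens[j] : greedy target-model token at internal node j.
    # parents[i]       : parent node of tree node i + 1.
    num_nodes = len(draft_tokens) + 1
    path_lengths = [0] * num_nodes
    path_lengths[0] = 1  # The root is always accepted.
    for i, (tok, parent) in enumerate(zip(draft_tokens, parents)):
        node = i + 1
        # A node is accepted only if its parent is accepted and its draft
        # token equals the target token sampled at the parent.
        if path_lengths[parent] > 0 and tok == target_tokens[parent]:
            path_lengths[node] = path_lengths[parent] + 1
    # Deepest accepted node; ties resolve toward the lowest index, matching
    # argmax semantics.
    return max(range(num_nodes), key=lambda n: (path_lengths[n], -n))


# Toy tree shaped like the one in the sampler's comments: nodes 1-2 on level 1,
# nodes 3-6 on level 2 with parents [1, 1, 2, 2]. Draft node 2 matches the
# root's target token and draft node 6 matches node 2's target token.
parents = [0, 0, 1, 1, 2, 2]
draft = [100, 200, 101, 102, 103, 201]
target = [200, 300, 201]  # greedy target tokens at internal nodes 0, 1, 2
print(longest_accepted_node(draft, target, parents))  # -> 6
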

vllm/v1/tree_spec_decode/utils.py

Lines changed: 2 additions & 1 deletion

@@ -39,8 +39,9 @@ def apply_draft_offsets(
         req_idx = input_batch.req_id_to_index[req_id]
         start = query_start_loc_np[req_idx]
         end = query_start_loc_np[req_idx + 1]
+        num_drafts = end - start - 1
         token_positions_np[start + 1:end] = (token_positions_np[start] +
-                                             draft_token_offsets)
+                                             draft_token_offsets[:num_drafts])
 
 
 def apply_accepted_draft_indices(
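The new slice guards against requests that hold fewer draft tokens than the full tree (for example the dummy sampler run below, which now uses a single draft token per request): the right-hand side must be truncated to the request's actual draft count or the NumPy assignment would fail with a shape mismatch. A toy illustration with assumed shapes, not the real runner state:

import numpy as np

draft_token_offsets = np.array([1, 1, 1, 2, 2, 2, 2])  # offsets for a full 7-draft tree
token_positions_np = np.array([10, 0, 0, 0])           # one request: root position + 3 drafts
start, end = 0, 4
num_drafts = end - start - 1
# Only the first num_drafts offsets apply to this request.
token_positions_np[start + 1:end] = (token_positions_np[start] +
                                     draft_token_offsets[:num_drafts])
print(token_positions_np)  # [10 11 11 11]
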

vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 2 deletions

@@ -2899,8 +2899,7 @@ def _dummy_sampler_run(
         else:
             raise e
         if self.speculative_config:
-            num_spec_tokens = self.speculative_config.num_speculative_tokens
-            draft_token_ids = [[0] * num_spec_tokens for _ in range(num_reqs)]
+            draft_token_ids = [[0] for _ in range(num_reqs)]
             dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
                 draft_token_ids, self.device)
