
Commit 51c78b9

[V1] implement tree sampler for draft token acceptance
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
1 parent 3beadc2 commit 51c78b9

File tree: 11 files changed (+1219 −178 lines)


tests/v1/e2e/test_spec_decode.py

Lines changed: 81 additions & 3 deletions
@@ -159,10 +159,8 @@ def test_eagle_correctness(
     attn_backend: str,
 ):
     if attn_backend == "TREE_ATTN":
-        # TODO: Fix this flaky test
         pytest.skip(
-            "TREE_ATTN is flaky in the test; disable for now until it can be "
-            "resolved (see https://github.com/vllm-project/vllm/issues/22922)")
+            "TREE_ATTN is tested separately in test_tree_eagle_correctness.")
 
     # Generate test prompts inside the function instead of using fixture
     test_prompts = get_test_prompts(mm_enabled)
@@ -223,3 +221,83 @@ def test_eagle_correctness(
     del spec_llm
     torch.cuda.empty_cache()
     cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize("model_setup", [
+    ("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+     "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
+    ("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+     "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
+],
+                         ids=[
+                             "llama3_eagle",
+                             "llama3_eagle3",
+                         ])
+@pytest.mark.parametrize(
+    "spec_token_tree",
+    [
+        [(0, )],  # A single token
+        [(0, ), (0, 0), (0, 0, 0)],  # Chain
+        [(0, ), (1, ), (2, )],  # Parallel
+        [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0),
+         (2, 1)],  # Tree
+    ])
+def test_tree_eagle_correctness(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_setup: tuple[str, str, str, int],
+    spec_token_tree: list[tuple[int, ...]],
+):
+    # Generate test prompts inside the function instead of using fixture
+    test_prompts = get_test_prompts(False)
+    '''
+    Compare the outputs of an original LLM and a speculative LLM;
+    they should be the same when using EAGLE speculative decoding.
+    model_setup: (method, model_name, eagle_model_name, tp_size)
+    '''
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_ATTENTION_BACKEND", "TREE_ATTN")
+        method, model_name, spec_model_name, tp_size = model_setup
+
+        ref_llm = LLM(model=model_name,
+                      max_model_len=2048,
+                      tensor_parallel_size=tp_size)
+        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+        del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+        spec_llm = LLM(
+            model=model_name,
+            trust_remote_code=True,
+            tensor_parallel_size=tp_size,
+            speculative_config={
+                "method": method,
+                "model": spec_model_name,
+                "num_speculative_tokens": len(spec_token_tree),
+                "spec_token_tree": str(spec_token_tree),
+                "max_model_len": 2048,
+            },
+            max_model_len=2048,
+        )
+        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+        matches = 0
+        misses = 0
+        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect at least 50% of the prompts to match exactly.
+        # Upon failure, inspect the outputs to check for inaccuracy. This
+        # threshold is lower than in the other tests because the tree
+        # attention backend uses Triton kernels, which seem to introduce more
+        # floating-point non-determinism than FA3.
+        assert matches > int(0.50 * len(ref_outputs))
+        del spec_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
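
A note on the spec_token_tree parametrization above: the list-of-tuples can be read as a set of paths from the root of the draft tree, where each tuple lists the child index taken at every level (so (2, 1) is the second child of the root's third child). This reading is an assumption inferred from the "Chain", "Parallel", and "Tree" cases in the test, not something stated by the commit; the small sketch below, with a hypothetical helper tree_parents, derives each node's parent index purely to make those shapes easier to visualize.

# Illustrative sketch only; assumes each tuple is a root-to-node path of
# child indices and that parents are listed before their children, as in
# the parametrized cases of test_tree_eagle_correctness.
def tree_parents(spec_token_tree: list[tuple[int, ...]]) -> list[int]:
    """Return the parent index of each node (-1 for children of the root)."""
    index = {path: i for i, path in enumerate(spec_token_tree)}
    parents = []
    for path in spec_token_tree:
        if len(path) == 1:
            parents.append(-1)  # direct child of the last sampled token
        else:
            parents.append(index[path[:-1]])  # drop the last hop to get the parent
    return parents

# The "Tree" case from the parametrization above:
tree = [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
print(tree_parents(tree))  # [-1, -1, -1, 0, 0, 1, 1, 2, 2]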
