From 509bf82215e3f51b6096158282d20922d28cb0ad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Delacourt?=
Date: Thu, 2 Oct 2025 10:38:54 +0000
Subject: [PATCH 01/11] Add eagle dp>1
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Rémi Delacourt
---
 tests/v1/distributed/test_eagle_dp.py | 65 +++++++++++++++++++++++++++
 vllm/v1/spec_decode/eagle.py          | 30 +++++++------
 2 files changed, 81 insertions(+), 14 deletions(-)
 create mode 100644 tests/v1/distributed/test_eagle_dp.py

diff --git a/tests/v1/distributed/test_eagle_dp.py b/tests/v1/distributed/test_eagle_dp.py
new file mode 100644
index 000000000000..5350816bd2e8
--- /dev/null
+++ b/tests/v1/distributed/test_eagle_dp.py
@@ -0,0 +1,65 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
+import os
+from contextlib import ExitStack
+
+import pytest
+
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.platforms import current_platform
+from vllm.sampling_params import RequestOutputKind
+from vllm.v1.engine.async_llm import AsyncLLM
+
+DP_SIZE = int(os.getenv("DP_SIZE", 2))
+
+
+@pytest.fixture
+def use_vllm_v1(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+
+
+@pytest.mark.asyncio
+async def test_run_eagle_dp(use_vllm_v1):
+    target_model = "meta-llama/Llama-3.1-8B-Instruct"
+    draft_model = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
+
+    engine_args = AsyncEngineArgs(
+        model=target_model,
+        tokenizer_mode="auto",
+        enforce_eager=True,
+        tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
+        data_parallel_size=DP_SIZE,
+        data_parallel_backend="mp",  # ray takes more time
+        trust_remote_code=True,
+        speculative_config={
+            "model": draft_model,
+            "method": "eagle",
+            "num_speculative_tokens": 3,
+        })
+
+    if not current_platform.supports_v1(engine_args.create_model_config()):
+        pytest.skip(reason="Requires V1-supporting platform.",
+                    allow_module_level=True)
+
+    with ExitStack() as after:
+        engine = AsyncLLM.from_engine_args(engine_args)
+        after.callback(engine.shutdown)
+
+        prompt = "This is a test of data parallel with eagle"
+        num_expected_tokens = 100
+        sampling_params = SamplingParams(
+            min_tokens=num_expected_tokens,
+            max_tokens=num_expected_tokens,
+            ignore_eos=True,
+            output_kind=RequestOutputKind.FINAL_ONLY,
+            temperature=0)
+        async with asyncio.timeout(30):
+            async for out in engine.generate(request_id="eagle-dp",
+                                             prompt=prompt,
+                                             sampling_params=sampling_params):
+                num_tokens = len(out.outputs[0].token_ids)
+                assert num_tokens == num_expected_tokens
+
+        assert not engine.output_processor.has_unfinished_requests()
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index dc6db0138806..eddacf42c5fd 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -980,21 +980,23 @@ def dummy_run(
         self,
         num_tokens: int,
     ) -> None:
-        with set_forward_context(None, self.vllm_config,
-                                 num_tokens=num_tokens):
-            if self.supports_mm_inputs:
-                input_ids = None
-                inputs_embeds = self.inputs_embeds[:num_tokens]
-            else:
-                input_ids = self.input_ids[:num_tokens]
-                inputs_embeds = None
+        for _ in range(self.num_speculative_tokens):
+            with set_forward_context(None,
+                                     self.vllm_config,
+                                     num_tokens=num_tokens):
+                if self.supports_mm_inputs:
+                    input_ids = None
+                    inputs_embeds = self.inputs_embeds[:num_tokens]
+                else:
+                    input_ids = self.input_ids[:num_tokens]
+                    inputs_embeds = None
 
-            self.model(
-                input_ids=input_ids,
-                positions=self._get_positions(num_tokens),
-                hidden_states=self.hidden_states[:num_tokens],
-                inputs_embeds=inputs_embeds,
-            )
+                self.model(
+                    input_ids=input_ids,
+                    positions=self._get_positions(num_tokens),
+                    hidden_states=self.hidden_states[:num_tokens],
+                    inputs_embeds=inputs_embeds,
+                )
 
     def _get_attention_metadata_builder(
             self) -> list[AttentionMetadataBuilder]:

From 80d5fafe5aec2aab7514dd4d847fbe234f311bb8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Delacourt?=
Date: Thu, 2 Oct 2025 13:32:02 +0000
Subject: [PATCH 02/11] python 3.9 doesn't have asyncio.timeout
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Rémi Delacourt
---
 tests/v1/distributed/test_eagle_dp.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/tests/v1/distributed/test_eagle_dp.py b/tests/v1/distributed/test_eagle_dp.py
index 5350816bd2e8..e363361702f4 100644
--- a/tests/v1/distributed/test_eagle_dp.py
+++ b/tests/v1/distributed/test_eagle_dp.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
 import os
-from contextlib import ExitStack
+from contextlib import AsyncExitStack
 
 import pytest
 
@@ -43,7 +43,7 @@ async def test_run_eagle_dp(use_vllm_v1):
         pytest.skip(reason="Requires V1-supporting platform.",
                     allow_module_level=True)
 
-    with ExitStack() as after:
+    async with AsyncExitStack() as after:
         engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
 
@@ -55,11 +55,14 @@ async def test_run_eagle_dp(use_vllm_v1):
             ignore_eos=True,
             output_kind=RequestOutputKind.FINAL_ONLY,
             temperature=0)
-        async with asyncio.timeout(30):
+
+        async def generate_with_timeout():
             async for out in engine.generate(request_id="eagle-dp",
                                              prompt=prompt,
                                              sampling_params=sampling_params):
                 num_tokens = len(out.outputs[0].token_ids)
                 assert num_tokens == num_expected_tokens
 
+        await asyncio.wait_for(generate_with_timeout(), timeout=30)
+
         assert not engine.output_processor.has_unfinished_requests()

From 9c74e82404c3c44a17b490d77347b57af7900d8c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Delacourt?=
Date: Fri, 3 Oct 2025 12:09:56 +0000
Subject: [PATCH 03/11] Add correctness test
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Rémi Delacourt
---
 .buildkite/test-pipeline.yaml         |  2 +
 tests/v1/distributed/test_eagle_dp.py | 71 ++++++++++++++++-----------
 2 files changed, 44 insertions(+), 29 deletions(-)

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index c131192c56fc..ae6946056f6c 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -187,6 +187,7 @@ steps:
   # test with internal dp
   - python3 ../examples/offline_inference/data_parallel.py --enforce-eager
   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
+  - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
   - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py
   - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
@@ -929,6 +930,7 @@ steps:
   - tests/v1/worker/test_worker_memory_snapshot.py
   commands:
   - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
+  - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
   - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
   - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
   - pytest -v -s entrypoints/llm/test_collective_rpc.py
diff --git a/tests/v1/distributed/test_eagle_dp.py b/tests/v1/distributed/test_eagle_dp.py
index e363361702f4..fced198a8921 100644
--- a/tests/v1/distributed/test_eagle_dp.py
+++ b/tests/v1/distributed/test_eagle_dp.py
@@ -33,36 +33,49 @@ async def test_run_eagle_dp(use_vllm_v1):
         data_parallel_size=DP_SIZE,
         data_parallel_backend="mp",  # ray takes more time
         trust_remote_code=True,
-        speculative_config={
-            "model": draft_model,
-            "method": "eagle",
-            "num_speculative_tokens": 3,
-        })
+    )
 
-    if not current_platform.supports_v1(engine_args.create_model_config()):
+    eagle_engine_args = engine_args.replace(speculative_config={
+        "model": draft_model,
+        "method": "eagle",
+        "num_speculative_tokens": 3,
+    })
+
+    if not current_platform.supports_v1(
+            eagle_engine_args.create_model_config()):
         pytest.skip(reason="Requires V1-supporting platform.",
                     allow_module_level=True)
 
-    async with AsyncExitStack() as after:
-        engine = AsyncLLM.from_engine_args(engine_args)
-        after.callback(engine.shutdown)
-
-        prompt = "This is a test of data parallel with eagle"
-        num_expected_tokens = 100
-        sampling_params = SamplingParams(
-            min_tokens=num_expected_tokens,
-            max_tokens=num_expected_tokens,
-            ignore_eos=True,
-            output_kind=RequestOutputKind.FINAL_ONLY,
-            temperature=0)
-
-        async def generate_with_timeout():
-            async for out in engine.generate(request_id="eagle-dp",
-                                             prompt=prompt,
-                                             sampling_params=sampling_params):
-                num_tokens = len(out.outputs[0].token_ids)
-                assert num_tokens == num_expected_tokens
-
-        await asyncio.wait_for(generate_with_timeout(), timeout=30)
-
-        assert not engine.output_processor.has_unfinished_requests()
+    prompt = "This is a test of data parallel with eagle"
+    num_expected_tokens = 100
+    sampling_params = SamplingParams(min_tokens=num_expected_tokens,
+                                     max_tokens=num_expected_tokens,
+                                     ignore_eos=True,
+                                     output_kind=RequestOutputKind.FINAL_ONLY,
+                                     temperature=0)
+
+    async def generate_with_timeout(given_engine: AsyncLLM):
+        async for out in given_engine.generate(
+                request_id="test-eagle-dp",
+                prompt=prompt,
+                sampling_params=sampling_params):
+            token_ids = out.outputs[0].token_ids
+            assert len(token_ids) == num_expected_tokens
+        return token_ids
+
+    async def engine_create_and_generate(engine_args: AsyncEngineArgs):
+        async with AsyncExitStack() as after:
+            engine = AsyncLLM.from_engine_args(engine_args)
+            after.callback(engine.shutdown)
+
+            token_ids = await asyncio.wait_for(generate_with_timeout(engine),
+                                               timeout=30)
+
+            assert not engine.output_processor.has_unfinished_requests()
+        return token_ids
+
+    token_ids_with_eagle = await engine_create_and_generate(eagle_engine_args)
+    token_ids_no_eagle = await engine_create_and_generate(engine_args)
+
+    # Test for correctness
+    assert token_ids_with_eagle == token_ids_no_eagle

From 97b1280a518e3cc877cd5b31b69ebb3b453c2513 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Delacourt?=
Date: Mon, 6 Oct 2025 08:51:33 +0000
Subject: [PATCH 04/11] pre-commit
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Rémi Delacourt
---
 tests/v1/distributed/test_eagle_dp.py | 40 ++++++++++++++-------------
 1 file changed, 21 insertions(+), 19 deletions(-)

diff --git a/tests/v1/distributed/test_eagle_dp.py b/tests/v1/distributed/test_eagle_dp.py
index fced198a8921..341fbcd06ed6 100644
--- a/tests/v1/distributed/test_eagle_dp.py
+++ b/tests/v1/distributed/test_eagle_dp.py
@@ -35,30 +35,31 @@ async def test_run_eagle_dp(use_vllm_v1):
         trust_remote_code=True,
     )
 
-    eagle_engine_args = engine_args.replace(speculative_config={
-        "model": draft_model,
-        "method": "eagle",
-        "num_speculative_tokens": 3,
-    })
+    eagle_engine_args = engine_args.replace(
+        speculative_config={
+            "model": draft_model,
+            "method": "eagle",
+            "num_speculative_tokens": 3,
+        }
+    )
 
-    if not current_platform.supports_v1(
-            eagle_engine_args.create_model_config()):
-        pytest.skip(reason="Requires V1-supporting platform.",
-                    allow_module_level=True)
+    if not current_platform.supports_v1(eagle_engine_args.create_model_config()):
+        pytest.skip(reason="Requires V1-supporting platform.", allow_module_level=True)
 
     prompt = "This is a test of data parallel with eagle"
     num_expected_tokens = 100
-    sampling_params = SamplingParams(min_tokens=num_expected_tokens,
-                                     max_tokens=num_expected_tokens,
-                                     ignore_eos=True,
-                                     output_kind=RequestOutputKind.FINAL_ONLY,
-                                     temperature=0)
+    sampling_params = SamplingParams(
+        min_tokens=num_expected_tokens,
+        max_tokens=num_expected_tokens,
+        ignore_eos=True,
+        output_kind=RequestOutputKind.FINAL_ONLY,
+        temperature=0,
+    )
 
     async def generate_with_timeout(given_engine: AsyncLLM):
         async for out in given_engine.generate(
-                request_id="test-eagle-dp",
-                prompt=prompt,
-                sampling_params=sampling_params):
+            request_id="test-eagle-dp", prompt=prompt, sampling_params=sampling_params
+        ):
             token_ids = out.outputs[0].token_ids
             assert len(token_ids) == num_expected_tokens
         return token_ids
@@ -68,8 +69,9 @@ async def engine_create_and_generate(engine_args: AsyncEngineArgs):
             engine = AsyncLLM.from_engine_args(engine_args)
             after.callback(engine.shutdown)
 
-            token_ids = await asyncio.wait_for(generate_with_timeout(engine),
-                                               timeout=30)
+            token_ids = await asyncio.wait_for(
+                generate_with_timeout(engine), timeout=30
+            )
 
             assert not engine.output_processor.has_unfinished_requests()
             return token_ids

From 4fcf30415b4e4491018d4d47d3310cf05e62150b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Delacourt?=
Date: Tue, 7 Oct 2025 08:25:49 +0000
Subject: [PATCH 05/11] Add FIXME for tree-based specdec
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Rémi Delacourt
---
 vllm/v1/spec_decode/eagle.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 7b336e8a0c12..3ed3dd5c4333 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -1018,6 +1018,8 @@ def dummy_run(
         self,
         num_tokens: int,
     ) -> None:
+        # FIXME: when using tree-based specdec, adjust number of forward-passes
+        # according to the depth of the tree.
         for _ in range(self.num_speculative_tokens):
             with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
                 if self.supports_mm_inputs:

From 2fbcb3aa63ed5a4e7ffc7daa5795fad6c521c550 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Delacourt?=
Date: Mon, 13 Oct 2025 14:23:24 +0000
Subject: [PATCH 06/11] Fix indentation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Rémi Delacourt
---
 vllm/v1/spec_decode/eagle.py | 48 ++++++++++++++++++------------------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index cecacf9cc795..4ddc3dc7861f 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -1051,30 +1051,30 @@ def dummy_run(
     ) -> None:
         if use_cudagraphs and num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
-            # FIXME: when using tree-based specdec, adjust number of forward-passes
-            # according to the depth of the tree.
-            for _ in range(self.num_speculative_tokens):
-                with set_forward_context(
-                    None,
-                    self.vllm_config,
-                    num_tokens=num_tokens,
-                    cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
-                    if use_cudagraphs
-                    else CUDAGraphMode.NONE,
-                ):
-                    if self.supports_mm_inputs:
-                        input_ids = None
-                        inputs_embeds = self.inputs_embeds[:num_tokens]
-                    else:
-                        input_ids = self.input_ids[:num_tokens]
-                        inputs_embeds = None
-
-                    self.model(
-                        input_ids=input_ids,
-                        positions=self._get_positions(num_tokens),
-                        hidden_states=self.hidden_states[:num_tokens],
-                        inputs_embeds=inputs_embeds,
-                    )
+        # FIXME: when using tree-based specdec, adjust number of forward-passes
+        # according to the depth of the tree.
+        for _ in range(self.num_speculative_tokens):
+            with set_forward_context(
+                None,
+                self.vllm_config,
+                num_tokens=num_tokens,
+                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
+                if use_cudagraphs
+                else CUDAGraphMode.NONE,
+            ):
+                if self.supports_mm_inputs:
+                    input_ids = None
+                    inputs_embeds = self.inputs_embeds[:num_tokens]
+                else:
+                    input_ids = self.input_ids[:num_tokens]
+                    inputs_embeds = None
+
+                self.model(
+                    input_ids=input_ids,
+                    positions=self._get_positions(num_tokens),
+                    hidden_states=self.hidden_states[:num_tokens],
+                    inputs_embeds=inputs_embeds,
+                )
 
     def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
         """Find and return the attention metadata builders for EAGLE layers.
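Reviewer note (illustrative sketch, hypothetical names; not vLLM's API): the dummy_run loop that the patches above shape exists because, under data parallelism, the draft model's forward enters collective ops, so a rank running a padding-only dummy batch must issue exactly as many draft forward passes as the ranks doing real speculation, one per speculative token; otherwise the DP group deadlocks. A minimal standalone sketch of that lockstep requirement:

    # lockstep_sketch.py: why dummy_run loops num_speculative_tokens times.
    # All names here are illustrative stand-ins, not vLLM identifiers.

    def draft_forward(rank: int, step: int) -> None:
        # Stand-in for the EAGLE draft forward; in vLLM this joins
        # collectives that every DP rank must enter the same number of times.
        print(f"rank {rank}: draft forward pass {step}")

    def dummy_run(rank: int, num_speculative_tokens: int) -> None:
        # Mirrors the real drafting loop: one forward per speculative token,
        # keeping idle (padding) ranks in lockstep with busy ranks.
        for step in range(num_speculative_tokens):
            draft_forward(rank, step)

    if __name__ == "__main__":
        for rank in range(2):  # DP_SIZE = 2
            dummy_run(rank, num_speculative_tokens=3)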
From febe6b964a1dd8e09998860b6e6608a4ce8d9d7e Mon Sep 17 00:00:00 2001
From: remi
Date: Sat, 1 Nov 2025 11:08:00 +0000
Subject: [PATCH 07/11] Fix

Signed-off-by: remi
---
 tests/v1/distributed/test_eagle_dp.py | 17 +++++------------
 1 file changed, 5 insertions(+), 12 deletions(-)

diff --git a/tests/v1/distributed/test_eagle_dp.py b/tests/v1/distributed/test_eagle_dp.py
index 341fbcd06ed6..d8a416af7bc0 100644
--- a/tests/v1/distributed/test_eagle_dp.py
+++ b/tests/v1/distributed/test_eagle_dp.py
@@ -3,25 +3,20 @@
 import asyncio
 import os
 from contextlib import AsyncExitStack
+from dataclasses import replace
 
 import pytest
 
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.platforms import current_platform
 from vllm.sampling_params import RequestOutputKind
 from vllm.v1.engine.async_llm import AsyncLLM
 
 DP_SIZE = int(os.getenv("DP_SIZE", 2))
 
 
-@pytest.fixture
-def use_vllm_v1(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-
-
 @pytest.mark.asyncio
-async def test_run_eagle_dp(use_vllm_v1):
+async def test_run_eagle_dp():
     target_model = "meta-llama/Llama-3.1-8B-Instruct"
     draft_model = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
 
@@ -35,17 +30,15 @@ async def test_run_eagle_dp():
         trust_remote_code=True,
     )
 
-    eagle_engine_args = engine_args.replace(
+    eagle_engine_args = replace(
+        engine_args,
         speculative_config={
             "model": draft_model,
             "method": "eagle",
             "num_speculative_tokens": 3,
-        }
+        },
     )
 
-    if not current_platform.supports_v1(eagle_engine_args.create_model_config()):
-        pytest.skip(reason="Requires V1-supporting platform.", allow_module_level=True)
-
     prompt = "This is a test of data parallel with eagle"
     num_expected_tokens = 100
     sampling_params = SamplingParams(

From 9c53ba62ea6b02db9cf0d4f71b888601ffc28471 Mon Sep 17 00:00:00 2001
From: remi
Date: Sun, 2 Nov 2025 23:19:29 +0000
Subject: [PATCH 08/11] Enforce eager False

Signed-off-by: remi
---
 tests/v1/distributed/test_eagle_dp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/v1/distributed/test_eagle_dp.py b/tests/v1/distributed/test_eagle_dp.py
index d8a416af7bc0..417bcb409a00 100644
--- a/tests/v1/distributed/test_eagle_dp.py
+++ b/tests/v1/distributed/test_eagle_dp.py
@@ -23,7 +23,7 @@ async def test_run_eagle_dp():
     engine_args = AsyncEngineArgs(
         model=target_model,
         tokenizer_mode="auto",
-        enforce_eager=True,
+        enforce_eager=False,
         tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
         data_parallel_size=DP_SIZE,
         data_parallel_backend="mp",  # ray takes more time

From 0862aa1e4a99b9ff59fe3c90391e9bbe047f6e88 Mon Sep 17 00:00:00 2001
From: remi
Date: Fri, 14 Nov 2025 11:16:05 +0000
Subject: [PATCH 09/11] Update for cudagraphs compat

Signed-off-by: remi
---
 vllm/v1/spec_decode/eagle.py       | 105 ++++++++++++++++++++++-------
 vllm/v1/worker/gpu_model_runner.py |   5 +-
 2 files changed, 84 insertions(+), 26 deletions(-)

diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 739ed9e0ff9a..8effc82687b6 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -40,6 +40,7 @@
 from vllm.v1.sample.sampler import _SAMPLING_EPS
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 logger = init_logger(__name__)
@@ -272,6 +273,13 @@ def propose(
             cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
             num_input_tokens = num_tokens
+
+        num_input_tokens, num_tokens_across_dp = self._pad_batch_across_dp(
+            num_tokens_unpadded=num_tokens,
+            num_tokens_padded=num_input_tokens,
+            allow_dp_padding=cudagraph_runtime_mode != CUDAGraphMode.NONE,
+        )
+
         # copy inputs to buffer for cudagraph
         self._set_positions(num_tokens, target_positions)
         self.hidden_states[:num_tokens] = target_hidden_states
@@ -295,6 +303,7 @@ def propose(
             per_layer_attn_metadata,
             self.vllm_config,
             num_tokens=num_input_tokens,
+            num_tokens_across_dp=num_tokens_across_dp,
             cudagraph_runtime_mode=cudagraph_runtime_mode,
         ):
             ret_hidden_states = self.model(
@@ -364,6 +373,12 @@ def propose(
             input_batch_size = batch_size
             cudagraph_runtime_mode = CUDAGraphMode.NONE
 
+        input_batch_size, batch_size_across_dp = self._pad_batch_across_dp(
+            num_tokens_unpadded=batch_size,
+            num_tokens_padded=input_batch_size,
+            allow_dp_padding=cudagraph_runtime_mode != CUDAGraphMode.NONE,
+        )
+
         common_attn_metadata.num_actual_tokens = batch_size
         common_attn_metadata.max_query_len = 1
         common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
@@ -460,6 +475,7 @@ def propose(
             per_layer_attn_metadata,
             self.vllm_config,
             num_tokens=input_batch_size,
+            num_tokens_across_dp=batch_size_across_dp,
             cudagraph_runtime_mode=cudagraph_runtime_mode,
         ):
             ret_hidden_states = self.model(
@@ -1055,35 +1071,51 @@ def dummy_run(
         self,
         num_tokens: int,
         use_cudagraphs=True,
+        is_graph_capturing=False,
     ) -> None:
         # Determine if CUDA graphs should be used for this run.
         cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
         if cudagraphs_enabled and num_tokens <= self.cudagraph_batch_sizes[-1]:
-            num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
-        # FIXME: when using tree-based specdec, adjust number of forward-passes
-        # according to the depth of the tree.
-        for _ in range(self.num_speculative_tokens):
-            with set_forward_context(
-                None,
-                self.vllm_config,
-                num_tokens=num_tokens,
-                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
-                if cudagraphs_enabled
-                else CUDAGraphMode.NONE,
-            ):
-                if self.supports_mm_inputs:
-                    input_ids = None
-                    inputs_embeds = self.inputs_embeds[:num_tokens]
-                else:
-                    input_ids = self.input_ids[:num_tokens]
-                    inputs_embeds = None
-
-                self.model(
-                    input_ids=input_ids,
-                    positions=self._get_positions(num_tokens),
-                    hidden_states=self.hidden_states[:num_tokens],
-                    inputs_embeds=inputs_embeds,
-                )
+            num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
+        else:
+            num_tokens_padded = num_tokens
+
+        # FIXME: when using tree-based specdec, adjust number of forward-passes
+        # according to the depth of the tree.
+        for fwd_idx in range(
+            self.num_speculative_tokens if not is_graph_capturing else 1
+        ):
+            if fwd_idx <= 1:
+                num_input_tokens, num_tokens_across_dp = self._pad_batch_across_dp(
+                    num_tokens_unpadded=num_tokens,
+                    num_tokens_padded=num_tokens_padded,
+                    # We don't use 'cudagraphs_enabled' because a dummy batch runs
+                    # in eager mode although we want it to get padded to match the
+                    # cudagraph-enabled non-dummy dp batch size
+                    allow_dp_padding=self.use_cuda_graph,
+                )
+            with set_forward_context(
+                None,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                num_tokens_across_dp=num_tokens_across_dp,
+                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
+                if cudagraphs_enabled
+                else CUDAGraphMode.NONE,
+            ):
+                if self.supports_mm_inputs:
+                    input_ids = None
+                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
+                else:
+                    input_ids = self.input_ids[:num_input_tokens]
+                    inputs_embeds = None
+
+                self.model(
+                    input_ids=input_ids,
+                    positions=self._get_positions(num_input_tokens),
+                    hidden_states=self.hidden_states[:num_input_tokens],
+                    inputs_embeds=inputs_embeds,
+                )
 
     def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
         """Find and return the attention metadata builders for EAGLE layers.
@@ -1133,6 +1165,29 @@ def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
             == 1
         ), "All eagle layers should belong to the same kv cache group"
 
+    def _pad_batch_across_dp(
+        self,
+        num_tokens_unpadded: int,
+        num_tokens_padded: int,
+        allow_dp_padding: bool,
+    ) -> tuple[int, torch.Tensor]:
+        # TODO(Flechman): support DBO ubatching
+        ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp(
+            num_tokens_unpadded=num_tokens_unpadded,
+            parallel_config=self.vllm_config.parallel_config,
+            allow_microbatching=False,
+            allow_dp_padding=allow_dp_padding,
+            num_tokens_padded=num_tokens_padded,
+            uniform_decode=None,
+            num_scheduled_tokens_per_request=None,
+        )
+        assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"
+
+        dp_rank = self.vllm_config.parallel_config.data_parallel_rank
+        if num_toks_across_dp is not None:
+            num_tokens_padded = int(num_toks_across_dp[dp_rank].item())
+        return num_tokens_padded, num_toks_across_dp
+
 
 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
 # to sample the draft tokens. We will use this after we find a way to manage
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index d0f7f3a501f5..ce6663876a60 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3423,6 +3423,7 @@ def _dummy_run(
         create_mixed_batch: bool = False,
         remove_lora: bool = True,
         activate_lora: bool = False,
+        is_graph_capturing: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Run a dummy forward pass to warm up/profile run or capture the
@@ -3656,7 +3657,7 @@ def _dummy_run(
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             use_cudagraphs = (
-                cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
+                cudagraph_runtime_mode.has_mode(CUDAGraphMode.PIECEWISE)
                 and not self.speculative_config.enforce_eager
             )
 
@@ -3670,6 +3671,7 @@ def _dummy_run(
             self.drafter.dummy_run(
                 num_tokens,
                 use_cudagraphs=use_cudagraphs,
+                is_graph_capturing=is_graph_capturing,
             )
 
         # This is necessary to avoid blocking DP.
@@ -4101,6 +4103,7 @@ def _capture_cudagraphs(
                     skip_eplb=True,
                     remove_lora=False,
                     activate_lora=activate_lora,
+                    is_graph_capturing=True,
                 )
         self.maybe_remove_all_loras(self.lora_config)
 

From 3971e3a79516bf9891000e2866998d372e2c7059 Mon Sep 17 00:00:00 2001
From: remi
Date: Fri, 14 Nov 2025 13:26:52 +0000
Subject: [PATCH 10/11] pre-commit

Signed-off-by: remi
---
 vllm/v1/spec_decode/eagle.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 8effc82687b6..25d18ba78b85 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -1089,8 +1089,8 @@ def dummy_run(
                 num_input_tokens, num_tokens_across_dp = self._pad_batch_across_dp(
                     num_tokens_unpadded=num_tokens,
                     num_tokens_padded=num_tokens_padded,
-                    # We don't use 'cudagraphs_enabled' because a dummy batch runs 
-                    # in eager mode although we want it to get padded to match the 
+                    # We don't use 'cudagraphs_enabled' because a dummy batch runs
+                    # in eager mode although we want it to get padded to match the
                     # cudagraph-enabled non-dummy dp batch size
                     allow_dp_padding=self.use_cuda_graph,
                 )

From e28e234ac24e9185594986b9b5c3ad11e360169c Mon Sep 17 00:00:00 2001
From: remi
Date: Fri, 14 Nov 2025 14:28:58 +0000
Subject: [PATCH 11/11] decrease max_model_len

Signed-off-by: remi
---
 tests/v1/distributed/test_eagle_dp.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/v1/distributed/test_eagle_dp.py b/tests/v1/distributed/test_eagle_dp.py
index 417bcb409a00..9f6a6614fc1f 100644
--- a/tests/v1/distributed/test_eagle_dp.py
+++ b/tests/v1/distributed/test_eagle_dp.py
@@ -28,6 +28,7 @@ async def test_run_eagle_dp():
         data_parallel_size=DP_SIZE,
         data_parallel_backend="mp",  # ray takes more time
         trust_remote_code=True,
+        max_model_len=16384,
     )
 
     eagle_engine_args = replace(
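Reviewer note: the pipeline entries added in patch 03 double as local repro commands. On a host with TP_SIZE x DP_SIZE GPUs, quoting the .buildkite/test-pipeline.yaml additions (run from the tests/ directory):

    TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
    TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py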