2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -192,6 +192,7 @@ steps:
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
@@ -1080,6 +1081,7 @@ steps:
# https://github.com/NVIDIA/nccl/issues/1838
- export NCCL_CUMEM_HOST_ENABLE=0
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py
77 changes: 77 additions & 0 deletions tests/v1/distributed/test_eagle_dp.py
@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
from contextlib import AsyncExitStack
from dataclasses import replace

import pytest

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM

DP_SIZE = int(os.getenv("DP_SIZE", 2))


@pytest.mark.asyncio
async def test_run_eagle_dp():
target_model = "meta-llama/Llama-3.1-8B-Instruct"
draft_model = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"

engine_args = AsyncEngineArgs(
model=target_model,
tokenizer_mode="auto",
enforce_eager=False,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=DP_SIZE,
data_parallel_backend="mp", # ray takes more time
trust_remote_code=True,
max_model_len=16384,
)

eagle_engine_args = replace(
engine_args,
speculative_config={
"model": draft_model,
"method": "eagle",
"num_speculative_tokens": 3,
},
)

prompt = "This is a test of data parallel with eagle"
num_expected_tokens = 100
sampling_params = SamplingParams(
min_tokens=num_expected_tokens,
max_tokens=num_expected_tokens,
ignore_eos=True,
output_kind=RequestOutputKind.FINAL_ONLY,
temperature=0,
)

async def generate_with_timeout(given_engine: AsyncLLM):
async for out in given_engine.generate(
request_id="test-eagle-dp", prompt=prompt, sampling_params=sampling_params
):
token_ids = out.outputs[0].token_ids
assert len(token_ids) == num_expected_tokens
return token_ids

async def engine_create_and_generate(engine_args: AsyncEngineArgs):
async with AsyncExitStack() as after:
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)

token_ids = await asyncio.wait_for(
generate_with_timeout(engine), timeout=30
)

assert not engine.output_processor.has_unfinished_requests()
return token_ids

token_ids_with_eagle = await engine_create_and_generate(eagle_engine_args)
token_ids_no_eagle = await engine_create_and_generate(engine_args)

    # With greedy sampling, EAGLE must not change the generated tokens.
assert token_ids_with_eagle == token_ids_no_eagle
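
For local reproduction (not part of the PR), the new test can be run the same way the Buildkite entries above run it. The snippet below is a minimal sketch that assumes a machine with at least two GPUs; it sets the same TP_SIZE/DP_SIZE environment variables the test reads via os.getenv.

import os

import pytest

# Mirror the pipeline invocation: 1-way TP, 2-way DP (needs at least 2 GPUs).
os.environ.setdefault("TP_SIZE", "1")
os.environ.setdefault("DP_SIZE", "2")

# Equivalent to: TP_SIZE=1 DP_SIZE=2 pytest -v -s tests/v1/distributed/test_eagle_dp.py
pytest.main(["-v", "-s", "tests/v1/distributed/test_eagle_dp.py"])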
97 changes: 77 additions & 20 deletions vllm/v1/spec_decode/eagle.py
@@ -40,6 +40,7 @@
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

logger = init_logger(__name__)
@@ -272,6 +273,13 @@ def propose(
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
num_input_tokens = num_tokens

num_input_tokens, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens,
num_tokens_padded=num_input_tokens,
allow_dp_padding=cudagraph_runtime_mode != CUDAGraphMode.NONE,
)

# copy inputs to buffer for cudagraph
self._set_positions(num_tokens, target_positions)
self.hidden_states[:num_tokens] = target_hidden_states
@@ -295,6 +303,7 @@
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
):
ret_hidden_states = self.model(
@@ -364,6 +373,12 @@ def propose(
input_batch_size = batch_size
cudagraph_runtime_mode = CUDAGraphMode.NONE

input_batch_size, batch_size_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=batch_size,
num_tokens_padded=input_batch_size,
allow_dp_padding=cudagraph_runtime_mode != CUDAGraphMode.NONE,
)

common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
@@ -460,6 +475,7 @@ def propose(
per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size,
num_tokens_across_dp=batch_size_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
):
ret_hidden_states = self.model(
@@ -1055,33 +1071,51 @@ def dummy_run(
         self,
         num_tokens: int,
         use_cudagraphs=True,
+        is_graph_capturing=False,
     ) -> None:
         # Determine if CUDA graphs should be used for this run.
         cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
         if cudagraphs_enabled and num_tokens <= self.cudagraph_batch_sizes[-1]:
-            num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+            num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
+        else:
+            num_tokens_padded = num_tokens

-        with set_forward_context(
-            None,
-            self.vllm_config,
-            num_tokens=num_tokens,
-            cudagraph_runtime_mode=(
-                CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE
-            ),
+        # FIXME: when using tree-based specdec, adjust number of forward-passes
+        # according to the depth of the tree.
+        for fwd_idx in range(
+            self.num_speculative_tokens if not is_graph_capturing else 1
         ):
-            if self.supports_mm_inputs:
-                input_ids = None
-                inputs_embeds = self.inputs_embeds[:num_tokens]
-            else:
-                input_ids = self.input_ids[:num_tokens]
-                inputs_embeds = None
+            if fwd_idx <= 1:
+                num_input_tokens, num_tokens_across_dp = self._pad_batch_across_dp(
+                    num_tokens_unpadded=num_tokens,
+                    num_tokens_padded=num_tokens_padded,
+                    # We don't use 'cudagraphs_enabled' because a dummy batch runs
+                    # in eager mode although we want it to get padded to match the
+                    # cudagraph-enabled non-dummy dp batch size
+                    allow_dp_padding=self.use_cuda_graph,
+                )
+            with set_forward_context(
+                None,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                num_tokens_across_dp=num_tokens_across_dp,
+                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
+                if cudagraphs_enabled
+                else CUDAGraphMode.NONE,
+            ):
+                if self.supports_mm_inputs:
+                    input_ids = None
+                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
+                else:
+                    input_ids = self.input_ids[:num_input_tokens]
+                    inputs_embeds = None

-            self.model(
-                input_ids=input_ids,
-                positions=self._get_positions(num_tokens),
-                hidden_states=self.hidden_states[:num_tokens],
-                inputs_embeds=inputs_embeds,
-            )
+                self.model(
+                    input_ids=input_ids,
+                    positions=self._get_positions(num_input_tokens),
+                    hidden_states=self.hidden_states[:num_input_tokens],
+                    inputs_embeds=inputs_embeds,
+                )

     def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
         """Find and return the attention metadata builders for EAGLE layers.
@@ -1131,6 +1165,29 @@ def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
== 1
), "All eagle layers should belong to the same kv cache group"

def _pad_batch_across_dp(
self,
num_tokens_unpadded: int,
num_tokens_padded: int,
allow_dp_padding: bool,
) -> tuple[int, torch.Tensor]:
# TODO(Flechman): support DBO ubatching
ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens_unpadded,
parallel_config=self.vllm_config.parallel_config,
allow_microbatching=False,
allow_dp_padding=allow_dp_padding,
num_tokens_padded=num_tokens_padded,
uniform_decode=None,
num_scheduled_tokens_per_request=None,
)
assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"

dp_rank = self.vllm_config.parallel_config.data_parallel_rank
if num_toks_across_dp is not None:
num_tokens_padded = int(num_toks_across_dp[dp_rank].item())
return num_tokens_padded, num_toks_across_dp


# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
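The _pad_batch_across_dp helper added above delegates to coordinate_batch_across_dp, which synchronizes token counts across data-parallel ranks. The toy sketch below only illustrates the padding idea under assumed semantics, not the actual vLLM implementation: when DP padding is allowed, each rank rounds its local token count up to a CUDA-graph capture size, and every rank learns the (padded) count of every peer so that the ranks stay in lockstep.

from bisect import bisect_left


def toy_coordinate_across_dp(
    dp_rank: int,
    gathered_unpadded: list[int],  # token counts gathered from all DP ranks
    cudagraph_sizes: list[int],  # sorted capture sizes, e.g. [4, 8, 16, 32]
    allow_dp_padding: bool,
) -> tuple[int, list[int]]:
    def pad(n: int) -> int:
        if not allow_dp_padding or n > cudagraph_sizes[-1]:
            return n
        # Round up to the nearest captured graph size.
        return cudagraph_sizes[bisect_left(cudagraph_sizes, n)]

    padded_across_dp = [pad(n) for n in gathered_unpadded]
    return padded_across_dp[dp_rank], padded_across_dp


# Rank 0 drafted 3 tokens, rank 1 drafted 17; both pad up to a capture size.
print(toy_coordinate_across_dp(0, [3, 17], [4, 8, 16, 32], True))  # (4, [4, 32])
print(toy_coordinate_across_dp(0, [3, 17], [4, 8, 16, 32], False))  # (3, [3, 17])

The real helper also has to consider micro-batching (DBO), which the PR explicitly leaves unimplemented for EAGLE via the "ubatch_slices is None" assertion above.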
5 changes: 4 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
@@ -3423,6 +3423,7 @@ def _dummy_run(
create_mixed_batch: bool = False,
remove_lora: bool = True,
activate_lora: bool = False,
is_graph_capturing: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run a dummy forward pass to warm up/profile run or capture the
@@ -3656,7 +3657,7 @@ def _dummy_run(
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
use_cudagraphs = (
-                    cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
+                    cudagraph_runtime_mode.has_mode(CUDAGraphMode.PIECEWISE)
and not self.speculative_config.enforce_eager
)

@@ -3670,6 +3671,7 @@
self.drafter.dummy_run(
num_tokens,
use_cudagraphs=use_cudagraphs,
is_graph_capturing=is_graph_capturing,
)

# This is necessary to avoid blocking DP.
@@ -4101,6 +4103,7 @@ def _capture_cudagraphs(
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
is_graph_capturing=True,
)
self.maybe_remove_all_loras(self.lora_config)

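Taken together, these runner changes tell the drafter whether a dummy run happens during CUDA-graph capture. As an illustrative restatement of the loop added to dummy_run in eagle.py above (not vLLM code), the number of drafter forward passes a dummy run performs works out to:

def drafter_dummy_passes(num_speculative_tokens: int, is_graph_capturing: bool) -> int:
    # Mirrors `self.num_speculative_tokens if not is_graph_capturing else 1`:
    # a single pass is enough while capturing graphs, whereas a regular dummy
    # run (e.g. keeping an idle DP rank in lockstep with busy ranks) repeats
    # the drafter forward once per speculative token.
    return 1 if is_graph_capturing else num_speculative_tokens


assert drafter_dummy_passes(3, is_graph_capturing=True) == 1
assert drafter_dummy_passes(3, is_graph_capturing=False) == 3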