diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 52539728215b..c56261994f71 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -192,6 +192,7 @@ steps:
   # test with internal dp
   - python3 ../examples/offline_inference/data_parallel.py --enforce-eager
   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
+  - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
   - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py
   - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
@@ -1080,6 +1081,7 @@ steps:
   # https://github.com/NVIDIA/nccl/issues/1838
   - export NCCL_CUMEM_HOST_ENABLE=0
   - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
+  - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
   - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
   - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
   - pytest -v -s entrypoints/llm/test_collective_rpc.py
diff --git a/tests/v1/distributed/test_eagle_dp.py b/tests/v1/distributed/test_eagle_dp.py
new file mode 100644
index 000000000000..9f6a6614fc1f
--- /dev/null
+++ b/tests/v1/distributed/test_eagle_dp.py
@@ -0,0 +1,77 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
+import os
+from contextlib import AsyncExitStack
+from dataclasses import replace
+
+import pytest
+
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.sampling_params import RequestOutputKind
+from vllm.v1.engine.async_llm import AsyncLLM
+
+DP_SIZE = int(os.getenv("DP_SIZE", 2))
+
+
+@pytest.mark.asyncio
+async def test_run_eagle_dp():
+    target_model = "meta-llama/Llama-3.1-8B-Instruct"
+    draft_model = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
+
+    engine_args = AsyncEngineArgs(
+        model=target_model,
+        tokenizer_mode="auto",
+        enforce_eager=False,
+        tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
+        data_parallel_size=DP_SIZE,
+        data_parallel_backend="mp",  # ray takes more time
+        trust_remote_code=True,
+        max_model_len=16384,
+    )
+
+    eagle_engine_args = replace(
+        engine_args,
+        speculative_config={
+            "model": draft_model,
+            "method": "eagle",
+            "num_speculative_tokens": 3,
+        },
+    )
+
+    prompt = "This is a test of data parallel with eagle"
+    num_expected_tokens = 100
+    sampling_params = SamplingParams(
+        min_tokens=num_expected_tokens,
+        max_tokens=num_expected_tokens,
+        ignore_eos=True,
+        output_kind=RequestOutputKind.FINAL_ONLY,
+        temperature=0,
+    )
+
+    async def generate_with_timeout(given_engine: AsyncLLM):
+        async for out in given_engine.generate(
+            request_id="test-eagle-dp", prompt=prompt, sampling_params=sampling_params
+        ):
+            token_ids = out.outputs[0].token_ids
+            assert len(token_ids) == num_expected_tokens
+            return token_ids
+
+    async def engine_create_and_generate(engine_args: AsyncEngineArgs):
+        async with AsyncExitStack() as after:
+            engine = AsyncLLM.from_engine_args(engine_args)
+            after.callback(engine.shutdown)
+
+            token_ids = await asyncio.wait_for(
+                generate_with_timeout(engine), timeout=30
+            )
+
+            assert not engine.output_processor.has_unfinished_requests()
+            return token_ids
+
+    token_ids_with_eagle = await engine_create_and_generate(eagle_engine_args)
+    token_ids_no_eagle = await engine_create_and_generate(engine_args)
+
+    # Test for correctness
+    assert token_ids_with_eagle == token_ids_no_eagle
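
Note (not part of the patch): the pipeline entries above run this test with TP_SIZE and DP_SIZE taken from the environment, e.g. TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py from the tests directory. For a quick manual check outside pytest, a minimal standalone sketch along the same lines could look like the following; it assumes at least two GPUs and access to the Llama 3.1 / EAGLE checkpoints, and only uses APIs that already appear in the test above.

# Hypothetical standalone reproduction sketch (not part of the patch):
# brings up AsyncLLM with DP=2 and the EAGLE drafter, then runs one request.
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM


async def main() -> None:
    engine_args = AsyncEngineArgs(
        model="meta-llama/Llama-3.1-8B-Instruct",
        data_parallel_size=2,
        data_parallel_backend="mp",
        max_model_len=16384,
        speculative_config={
            "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
            "method": "eagle",
            "num_speculative_tokens": 3,
        },
    )
    engine = AsyncLLM.from_engine_args(engine_args)
    try:
        sampling_params = SamplingParams(
            max_tokens=100,
            temperature=0,
            output_kind=RequestOutputKind.FINAL_ONLY,
        )
        # FINAL_ONLY yields a single, final output for the request.
        async for out in engine.generate(
            request_id="eagle-dp-smoke",
            prompt="This is a test of data parallel with eagle",
            sampling_params=sampling_params,
        ):
            print(out.outputs[0].text)
    finally:
        engine.shutdown()


if __name__ == "__main__":
    asyncio.run(main())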
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index beef5203e039..25d18ba78b85 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -40,6 +40,7 @@
 from vllm.v1.sample.sampler import _SAMPLING_EPS
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 logger = init_logger(__name__)
@@ -272,6 +273,13 @@ def propose(
             cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
             num_input_tokens = num_tokens
+
+        num_input_tokens, num_tokens_across_dp = self._pad_batch_across_dp(
+            num_tokens_unpadded=num_tokens,
+            num_tokens_padded=num_input_tokens,
+            allow_dp_padding=cudagraph_runtime_mode != CUDAGraphMode.NONE,
+        )
+
         # copy inputs to buffer for cudagraph
         self._set_positions(num_tokens, target_positions)
         self.hidden_states[:num_tokens] = target_hidden_states
@@ -295,6 +303,7 @@ def propose(
             per_layer_attn_metadata,
             self.vllm_config,
             num_tokens=num_input_tokens,
+            num_tokens_across_dp=num_tokens_across_dp,
             cudagraph_runtime_mode=cudagraph_runtime_mode,
         ):
             ret_hidden_states = self.model(
@@ -364,6 +373,12 @@ def propose(
             input_batch_size = batch_size
             cudagraph_runtime_mode = CUDAGraphMode.NONE
 
+        input_batch_size, batch_size_across_dp = self._pad_batch_across_dp(
+            num_tokens_unpadded=batch_size,
+            num_tokens_padded=input_batch_size,
+            allow_dp_padding=cudagraph_runtime_mode != CUDAGraphMode.NONE,
+        )
+
         common_attn_metadata.num_actual_tokens = batch_size
         common_attn_metadata.max_query_len = 1
         common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
@@ -460,6 +475,7 @@ def propose(
                 per_layer_attn_metadata,
                 self.vllm_config,
                 num_tokens=input_batch_size,
+                num_tokens_across_dp=batch_size_across_dp,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
             ):
                 ret_hidden_states = self.model(
@@ -1055,33 +1071,51 @@ def dummy_run(
         self,
         num_tokens: int,
         use_cudagraphs=True,
+        is_graph_capturing=False,
     ) -> None:
         # Determine if CUDA graphs should be used for this run.
         cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
         if cudagraphs_enabled and num_tokens <= self.cudagraph_batch_sizes[-1]:
-            num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+            num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
+        else:
+            num_tokens_padded = num_tokens
 
-        with set_forward_context(
-            None,
-            self.vllm_config,
-            num_tokens=num_tokens,
-            cudagraph_runtime_mode=(
-                CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE
-            ),
+        # FIXME: when using tree-based specdec, adjust number of forward-passes
+        # according to the depth of the tree.
+        for fwd_idx in range(
+            self.num_speculative_tokens if not is_graph_capturing else 1
         ):
-            if self.supports_mm_inputs:
-                input_ids = None
-                inputs_embeds = self.inputs_embeds[:num_tokens]
-            else:
-                input_ids = self.input_ids[:num_tokens]
-                inputs_embeds = None
+            if fwd_idx <= 1:
+                num_input_tokens, num_tokens_across_dp = self._pad_batch_across_dp(
+                    num_tokens_unpadded=num_tokens,
+                    num_tokens_padded=num_tokens_padded,
+                    # We don't use 'cudagraphs_enabled' because a dummy batch runs
+                    # in eager mode although we want it to get padded to match the
+                    # cudagraph-enabled non-dummy dp batch size
+                    allow_dp_padding=self.use_cuda_graph,
+                )
+            with set_forward_context(
+                None,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                num_tokens_across_dp=num_tokens_across_dp,
+                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
+                if cudagraphs_enabled
+                else CUDAGraphMode.NONE,
+            ):
+                if self.supports_mm_inputs:
+                    input_ids = None
+                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
+                else:
+                    input_ids = self.input_ids[:num_input_tokens]
+                    inputs_embeds = None
 
-            self.model(
-                input_ids=input_ids,
-                positions=self._get_positions(num_tokens),
-                hidden_states=self.hidden_states[:num_tokens],
-                inputs_embeds=inputs_embeds,
-            )
+                self.model(
+                    input_ids=input_ids,
+                    positions=self._get_positions(num_input_tokens),
+                    hidden_states=self.hidden_states[:num_input_tokens],
+                    inputs_embeds=inputs_embeds,
+                )
 
     def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
         """Find and return the attention metadata builders for EAGLE layers.
@@ -1131,6 +1165,29 @@ def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
             == 1
         ), "All eagle layers should belong to the same kv cache group"
 
+    def _pad_batch_across_dp(
+        self,
+        num_tokens_unpadded: int,
+        num_tokens_padded: int,
+        allow_dp_padding: bool,
+    ) -> tuple[int, torch.Tensor]:
+        # TODO(Flechman): support DBO ubatching
+        ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp(
+            num_tokens_unpadded=num_tokens_unpadded,
+            parallel_config=self.vllm_config.parallel_config,
+            allow_microbatching=False,
+            allow_dp_padding=allow_dp_padding,
+            num_tokens_padded=num_tokens_padded,
+            uniform_decode=None,
+            num_scheduled_tokens_per_request=None,
+        )
+        assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"
+
+        dp_rank = self.vllm_config.parallel_config.data_parallel_rank
+        if num_toks_across_dp is not None:
+            num_tokens_padded = int(num_toks_across_dp[dp_rank].item())
+        return num_tokens_padded, num_toks_across_dp
+
 
 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
 # to sample the draft tokens. We will use this after we find a way to manage
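
Note (not part of the patch): with data parallelism, every DP rank has to join the same sequence of collective operations, so the drafter's batches are now coordinated across ranks. _pad_batch_across_dp wraps coordinate_batch_across_dp to obtain the per-rank token counts and, when CUDA graphs may be used, a padded local batch size; this is also why dummy_run now issues one forward pass per speculative token (except during graph capture), keeping idle ranks in lockstep with ranks doing real drafting. The sketch below is illustrative only: it is not coordinate_batch_across_dp, and the pad-to-the-largest-count policy is an assumption made for demonstration; it only shows how a per-rank token-count tensor and the padded local size consumed by the helper relate.

# Illustrative-only sketch, NOT vLLM's coordinate_batch_across_dp; the
# "pad every rank to the largest count" policy is assumed for demonstration.
import torch


def pad_across_dp_sketch(
    tokens_per_rank: list[int],  # in reality gathered across DP ranks
    dp_rank: int,
    allow_dp_padding: bool,
) -> tuple[int, torch.Tensor]:
    num_toks_across_dp = torch.tensor(tokens_per_rank, dtype=torch.int32)
    if allow_dp_padding:
        # Pad every rank to a common size so CUDA-graph replays and collective
        # ops see matching batch shapes on all ranks.
        num_toks_across_dp[:] = int(num_toks_across_dp.max())
    # Mirrors how the new helper consumes the result:
    # num_tokens_padded = int(num_toks_across_dp[dp_rank].item())
    return int(num_toks_across_dp[dp_rank].item()), num_toks_across_dp


# Ranks with 7 and 12 tokens both run 12-token drafter batches when padding
# is allowed; without padding each rank keeps its own size.
print(pad_across_dp_sketch([7, 12], dp_rank=0, allow_dp_padding=True))
print(pad_across_dp_sketch([7, 12], dp_rank=0, allow_dp_padding=False))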
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index d0f7f3a501f5..ce6663876a60 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3423,6 +3423,7 @@ def _dummy_run(
         create_mixed_batch: bool = False,
         remove_lora: bool = True,
         activate_lora: bool = False,
+        is_graph_capturing: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Run a dummy forward pass to warm up/profile run or capture the
@@ -3656,7 +3657,7 @@ def _dummy_run(
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             use_cudagraphs = (
-                cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
+                cudagraph_runtime_mode.has_mode(CUDAGraphMode.PIECEWISE)
                 and not self.speculative_config.enforce_eager
             )
 
@@ -3670,6 +3671,7 @@ def _dummy_run(
             self.drafter.dummy_run(
                 num_tokens,
                 use_cudagraphs=use_cudagraphs,
+                is_graph_capturing=is_graph_capturing,
             )
 
         # This is necessary to avoid blocking DP.
@@ -4101,6 +4103,7 @@ def _capture_cudagraphs(
                 skip_eplb=True,
                 remove_lora=False,
                 activate_lora=activate_lora,
+                is_graph_capturing=True,
             )
 
         self.maybe_remove_all_loras(self.lora_config)
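
Note (not part of the patch): _capture_cudagraphs now passes is_graph_capturing=True so the drafter's dummy run performs a single forward pass per captured size, while ordinary dummy runs (warm-up and DP synchronization) repeat the pass once per speculative token to mirror a real drafting step. The use_cudagraphs check also switches from an equality test to has_mode(CUDAGraphMode.PIECEWISE), presumably so combined runtime modes (such as FULL_AND_PIECEWISE) still enable the drafter's piecewise graphs. A minimal sketch of the pass-count rule, with a hypothetical helper name:

# Hypothetical helper, not part of the patch: mirrors the loop bound
# range(self.num_speculative_tokens if not is_graph_capturing else 1).
def drafter_dummy_passes(num_speculative_tokens: int, is_graph_capturing: bool) -> int:
    # Graph capture only needs one pass per bucket; DP-sync/warm-up dummy runs
    # repeat the pass once per speculative token to mirror a real drafting step.
    return 1 if is_graph_capturing else num_speculative_tokens


assert drafter_dummy_passes(3, is_graph_capturing=True) == 1
assert drafter_dummy_passes(3, is_graph_capturing=False) == 3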