From fffc1c4e0559112a8de8b3407f19ad8c8133a98f Mon Sep 17 00:00:00 2001
From: seven-mile
Date: Sat, 4 Oct 2025 06:52:33 +0000
Subject: [PATCH] [Spec Decode] Fix DP hang when some ranks do dummy runs

In data-parallel runs, every rank must execute the same number of
forward passes so that collective ops stay in sync. A real speculative
decoding step runs the drafter once per speculative token, but
dummy_run only performed a single forward pass, so ranks padding with
dummy runs fell out of lockstep with the others and the engine hung.
Loop the dummy forward num_speculative_tokens times to match, and
assert that tree attention is not in use, since its dummy run is not
implemented.

Signed-off-by: seven-mile
---
 vllm/v1/spec_decode/eagle.py | 43 ++++++++++++++++++++++++------------
 1 file changed, 29 insertions(+), 14 deletions(-)

diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index dc6db0138806..743265b2e50d 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -585,6 +585,16 @@ def prepare_inputs_padded(self,
 
         return spec_common_attn_metadata, token_indices, token_indices_to_sample
 
+    def _is_tree_attention(self) -> bool:
+        if not hasattr(self.runner,
+                       "attn_groups") or not self.runner.attn_groups:
+            return False
+
+        tree_attn_metadata_builder = \
+            self.runner.attn_groups[0][0].get_metadata_builder()
+        return isinstance(tree_attn_metadata_builder,
+                          TreeAttentionMetadataBuilder)
+
     def propose_tree(
         self,
         batch_size: int,
@@ -980,21 +990,26 @@ def dummy_run(
         self,
         num_tokens: int,
     ) -> None:
-        with set_forward_context(None, self.vllm_config,
-                                 num_tokens=num_tokens):
-            if self.supports_mm_inputs:
-                input_ids = None
-                inputs_embeds = self.inputs_embeds[:num_tokens]
-            else:
-                input_ids = self.input_ids[:num_tokens]
-                inputs_embeds = None
+        assert not self._is_tree_attention(
+        ), "Dummy run for tree attention not implemented"
 
-            self.model(
-                input_ids=input_ids,
-                positions=self._get_positions(num_tokens),
-                hidden_states=self.hidden_states[:num_tokens],
-                inputs_embeds=inputs_embeds,
-            )
+        for _ in range(self.num_speculative_tokens):
+            with set_forward_context(None,
+                                     self.vllm_config,
+                                     num_tokens=num_tokens):
+                if self.supports_mm_inputs:
+                    input_ids = None
+                    inputs_embeds = self.inputs_embeds[:num_tokens]
+                else:
+                    input_ids = self.input_ids[:num_tokens]
+                    inputs_embeds = None
+
+                self.model(
+                    input_ids=input_ids,
+                    positions=self._get_positions(num_tokens),
+                    hidden_states=self.hidden_states[:num_tokens],
+                    inputs_embeds=inputs_embeds,
+                )
 
     def _get_attention_metadata_builder(
             self) -> list[AttentionMetadataBuilder]: