diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index dc6db0138806..743265b2e50d 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -585,6 +585,16 @@ def prepare_inputs_padded(self,
         return spec_common_attn_metadata, token_indices, token_indices_to_sample
 
+    def _is_tree_attention(self) -> bool:
+        if not hasattr(self.runner,
+                       "attn_groups") or not self.runner.attn_groups:
+            return False
+
+        tree_attn_metadata_builder = \
+            self.runner.attn_groups[0][0].get_metadata_builder()
+        return isinstance(tree_attn_metadata_builder,
+                          TreeAttentionMetadataBuilder)
+
     def propose_tree(
         self,
         batch_size: int,
@@ -980,21 +990,26 @@ def dummy_run(
         self,
         num_tokens: int,
     ) -> None:
-        with set_forward_context(None, self.vllm_config,
-                                 num_tokens=num_tokens):
-            if self.supports_mm_inputs:
-                input_ids = None
-                inputs_embeds = self.inputs_embeds[:num_tokens]
-            else:
-                input_ids = self.input_ids[:num_tokens]
-                inputs_embeds = None
+        assert not self._is_tree_attention(
+        ), "Dummy run for tree attention not implemented"
 
-            self.model(
-                input_ids=input_ids,
-                positions=self._get_positions(num_tokens),
-                hidden_states=self.hidden_states[:num_tokens],
-                inputs_embeds=inputs_embeds,
-            )
+        for _ in range(self.num_speculative_tokens):
+            with set_forward_context(None,
+                                     self.vllm_config,
+                                     num_tokens=num_tokens):
+                if self.supports_mm_inputs:
+                    input_ids = None
+                    inputs_embeds = self.inputs_embeds[:num_tokens]
+                else:
+                    input_ids = self.input_ids[:num_tokens]
+                    inputs_embeds = None
+
+                self.model(
+                    input_ids=input_ids,
+                    positions=self._get_positions(num_tokens),
+                    hidden_states=self.hidden_states[:num_tokens],
+                    inputs_embeds=inputs_embeds,
+                )
 
     def _get_attention_metadata_builder(
             self) -> list[AttentionMetadataBuilder]: