43 changes: 29 additions & 14 deletions vllm/v1/spec_decode/eagle.py
@@ -585,6 +585,16 @@ def prepare_inputs_padded(self,
 
         return spec_common_attn_metadata, token_indices, token_indices_to_sample
 
+    def _is_tree_attention(self) -> bool:
+        if not hasattr(self.runner,
+                       "attn_groups") or not self.runner.attn_groups:
+            return False
Comment on lines +589 to +591
Contributor
critical

The current check for `self.runner.attn_groups` is not sufficient to prevent a potential `IndexError`. If `self.runner.attn_groups` is `[[]]`, the condition `not self.runner.attn_groups` will evaluate to `False`, but the subsequent access `self.runner.attn_groups[0][0]` on line 594 will raise an `IndexError` because the inner list is empty. This could lead to a server crash. To prevent this, you should also check if the inner list is empty.

Suggested change
-        if not hasattr(self.runner,
-                       "attn_groups") or not self.runner.attn_groups:
-            return False
+        if (not hasattr(self.runner, "attn_groups") or
+                not self.runner.attn_groups or not self.runner.attn_groups[0]):
+            return False
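
To make the failure mode concrete, here is a minimal standalone sketch; the local `attn_groups` list is a hypothetical stand-in for `self.runner.attn_groups`:

```python
# Minimal sketch of the failure mode: the outer list is truthy, so the
# existing guard does not return early, but indexing the empty inner
# list raises.
attn_groups = [[]]      # stand-in for self.runner.attn_groups

print(not attn_groups)  # False -> the "not self.runner.attn_groups" guard passes
attn_groups[0][0]       # IndexError: list index out of range
```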


+        tree_attn_metadata_builder = \
+            self.runner.attn_groups[0][0].get_metadata_builder()
+        return isinstance(tree_attn_metadata_builder,
+                          TreeAttentionMetadataBuilder)
+
     def propose_tree(
         self,
         batch_size: int,
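
For reference, a hedged sketch of how `_is_tree_attention` would read with the suggested guard folded in; this is simply the added method above combined with the reviewer's suggestion, as it would sit inside the proposer class (assuming `TreeAttentionMetadataBuilder` is already imported in the file), not code taken verbatim from the PR:

```python
    def _is_tree_attention(self) -> bool:
        # Guard against a missing attribute, an empty outer list, and an
        # empty inner group before indexing attn_groups[0][0].
        if (not hasattr(self.runner, "attn_groups") or
                not self.runner.attn_groups or not self.runner.attn_groups[0]):
            return False

        tree_attn_metadata_builder = \
            self.runner.attn_groups[0][0].get_metadata_builder()
        return isinstance(tree_attn_metadata_builder,
                          TreeAttentionMetadataBuilder)
```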
@@ -980,21 +990,26 @@ def dummy_run(
         self,
         num_tokens: int,
     ) -> None:
-        with set_forward_context(None, self.vllm_config,
-                                 num_tokens=num_tokens):
-            if self.supports_mm_inputs:
-                input_ids = None
-                inputs_embeds = self.inputs_embeds[:num_tokens]
-            else:
-                input_ids = self.input_ids[:num_tokens]
-                inputs_embeds = None
+        assert not self._is_tree_attention(
+        ), "Dummy run for tree attention not implemented"
 
-            self.model(
-                input_ids=input_ids,
-                positions=self._get_positions(num_tokens),
-                hidden_states=self.hidden_states[:num_tokens],
-                inputs_embeds=inputs_embeds,
-            )
+        for _ in range(self.num_speculative_tokens):
+            with set_forward_context(None,
+                                     self.vllm_config,
+                                     num_tokens=num_tokens):
+                if self.supports_mm_inputs:
+                    input_ids = None
+                    inputs_embeds = self.inputs_embeds[:num_tokens]
+                else:
+                    input_ids = self.input_ids[:num_tokens]
+                    inputs_embeds = None
+
+                self.model(
+                    input_ids=input_ids,
+                    positions=self._get_positions(num_tokens),
+                    hidden_states=self.hidden_states[:num_tokens],
+                    inputs_embeds=inputs_embeds,
+                )
 
     def _get_attention_metadata_builder(
             self) -> list[AttentionMetadataBuilder]:
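
One plausible reading of the `dummy_run` change (an assumption; the diff itself does not state the rationale): a real speculative-decoding step invokes the drafter once per speculative token, so looping `num_speculative_tokens` times makes the warm-up path issue the same number of forward passes as production traffic. A toy counter illustrating the before/after call counts, with `FakeDrafter` as a purely hypothetical stand-in:

```python
class FakeDrafter:
    """Hypothetical stand-in that only counts forward passes."""

    def __init__(self) -> None:
        self.calls = 0

    def forward(self) -> None:
        self.calls += 1


num_speculative_tokens = 3

old_run = FakeDrafter()
old_run.forward()                        # old dummy_run: a single pass

new_run = FakeDrafter()
for _ in range(num_speculative_tokens):  # new dummy_run: one pass per draft step
    new_run.forward()

print(old_run.calls, new_run.calls)      # 1 3
```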