Commit cb27cef

[SpecDecode] Support EAGLE for Qwen3 MoE

Signed-off-by: seven-mile <i@7li.moe>
1 parent c50901f
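
This commit wires Qwen3 MoE into EAGLE3 speculative decoding by letting the model record auxiliary hidden states for the drafter. As a rough usage sketch (the target-model name, draft-head path, and speculative_config keys below are illustrative assumptions, not taken from this commit), serving with an EAGLE3 head might look like:

from vllm import LLM, SamplingParams

# Illustrative names only: substitute a real Qwen3 MoE checkpoint and a
# matching EAGLE3 draft head trained for it.
llm = LLM(
    model="Qwen/Qwen3-30B-A3B",
    speculative_config={
        "method": "eagle3",
        "model": "path/to/qwen3-moe-eagle3-head",
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(["Speculative decoding works by"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)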


vllm/model_executor/models/qwen3_moe.py

Lines changed: 25 additions & 3 deletions
@@ -64,7 +64,7 @@
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -422,6 +422,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
+        self.aux_hidden_state_layers = tuple[int, ...]()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -443,13 +444,25 @@ def forward(
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+
+        aux_hidden_states = []
+        for idx, layer in enumerate(
+            islice(self.layers, self.start_layer, self.end_layer)
+        ):
+            if idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
                 {"hidden_states": hidden_states, "residual": residual}
             )
+
         hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
+
         return hidden_states
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
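
The loop above taps the pre-norm activation (hidden_states + residual) at the configured layer indices and, when any were captured, returns them alongside the final hidden states. A minimal standalone sketch of that capture pattern (the toy layer and shapes below are stand-ins, not vLLM objects):

import torch

# Toy stand-in for a decoder layer that returns (hidden_states, residual).
def toy_layer(hidden_states, residual):
    return hidden_states + residual, hidden_states

def run_layers(num_layers, hidden_states, residual, aux_layer_ids):
    # Mirrors the added loop: capture hidden_states + residual before chosen layers.
    aux_hidden_states = []
    for idx in range(num_layers):
        if idx in aux_layer_ids:
            aux_hidden_states.append(hidden_states + residual)
        hidden_states, residual = toy_layer(hidden_states, residual)
    return hidden_states, aux_hidden_states

h, aux = run_layers(8, torch.zeros(1, 4), torch.ones(1, 4), aux_layer_ids=(2, 4))
print(len(aux))  # two captured auxiliary hidden states
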
@@ -606,7 +619,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         return loaded_params
 
 
-class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts):
+class Qwen3MoeForCausalLM(
+    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
+):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -669,6 +684,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_routed_experts = example_layer.n_routed_experts
         self.num_redundant_experts = example_layer.n_redundant_experts
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def set_eplb_state(
         self,
         expert_load_view: torch.Tensor,
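
The new get_eagle3_aux_hidden_state_layers default taps one early, one middle, and one late layer, and set_aux_hidden_state_layers lets a caller override that selection. As a worked example (the 48-layer count is illustrative, not read from this diff):

num_layers = 48  # illustrative; the real value is len(self.model.layers)
print((2, num_layers // 2, num_layers - 3))  # -> (2, 24, 45)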
