 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -422,6 +422,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
+        self.aux_hidden_state_layers = tuple[int, ...]()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -443,13 +444,25 @@ def forward(
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+
+        aux_hidden_states = []
+        for idx, layer in enumerate(
+            islice(self.layers, self.start_layer, self.end_layer)
+        ):
+            if idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
                 {"hidden_states": hidden_states, "residual": residual}
             )
+
         hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
+
         return hidden_states
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
@@ -606,7 +619,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         return loaded_params
 
 
-class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts):
+class Qwen3MoeForCausalLM(
+    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
+):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -669,6 +684,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_routed_experts = example_layer.n_routed_experts
         self.num_redundant_experts = example_layer.n_redundant_experts
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def set_eplb_state(
         self,
         expert_load_view: torch.Tensor,
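Not part of the PR: a minimal, self-contained sketch of the aux-hidden-state pattern the forward() change introduces. The toy layer, hidden size, and layer count are made up for illustration; only the control flow (tap hidden_states + residual before the selected layers, return the extra states alongside the final hidden states) mirrors the diff. In the real model, a caller would register the tapped indices via set_aux_hidden_state_layers(), typically using get_eagle3_aux_hidden_state_layers(); here the attribute is set directly.

import torch
import torch.nn as nn
from itertools import islice


class ToyLayer(nn.Module):
    def __init__(self, hidden: int):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, hidden_states, residual):
        # Mirror the (hidden_states, residual) contract of the decoder layers:
        # return a new hidden state and the carried residual stream.
        return self.proj(hidden_states), residual + hidden_states


class ToyModel(nn.Module):
    def __init__(self, hidden: int = 8, num_layers: int = 6):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer(hidden) for _ in range(num_layers))
        self.start_layer, self.end_layer = 0, num_layers
        # Empty by default, so forward() behaves exactly as before the change.
        self.aux_hidden_state_layers = tuple[int, ...]()

    def forward(self, hidden_states):
        residual = torch.zeros_like(hidden_states)
        aux_hidden_states = []
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            if idx in self.aux_hidden_state_layers:
                # Tap the pre-layer state, as the diff does.
                aux_hidden_states.append(hidden_states + residual)
            hidden_states, residual = layer(hidden_states, residual)
        if len(aux_hidden_states) > 0:
            # Callers must unpack (hidden_states, aux_hidden_states) here.
            return hidden_states, aux_hidden_states
        return hidden_states


model = ToyModel()
model.aux_hidden_state_layers = (2, 3)  # analogous to set_aux_hidden_state_layers()
out, aux = model(torch.randn(1, 8))
print(out.shape, [a.shape for a in aux])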