File tree Expand file tree Collapse file tree 1 file changed +14
-6
lines changed Expand file tree Collapse file tree 1 file changed +14
-6
lines changed Original file line number Diff line number Diff line change @@ -38,16 +38,24 @@ def propose(
3838 self ,
3939 target_hidden_states : torch .Tensor ,
4040 sampling_metadata : SamplingMetadata ,
41- ) -> list [ list [ int ]] :
41+ ) -> torch . Tensor :
4242 # Generate blocks and compute logits
4343 blocks = self .model (target_hidden_states )
4444 logits = self .model .compute_logits (blocks )
4545
46- # Get draft tokens and transpose the result
47- # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
48- # synchronization.
49- draft_tokens = [logit .argmax (dim = - 1 ).tolist () for logit in logits ]
50- return [list (row ) for row in zip (* draft_tokens )]
46+ # Compute argmax for each Medusa head and stack into a single tensor
47+ # Shape: [batch_size, num_heads]
48+ draft_tokens = torch .stack ([logit .argmax (dim = - 1 ) for logit in logits ], dim = 1 )
49+
50+ # Sanity check to catch any unexpected shape mismatch early
51+ batch_size = target_hidden_states .shape [0 ]
52+ num_heads = len (logits )
53+ assert draft_tokens .shape == (
54+ batch_size ,
55+ num_heads ,
56+ ), f"Expected shape ({ batch_size } , { num_heads } ), got { draft_tokens .shape } "
57+
58+ return draft_tokens
5159
5260 def load_model (self , target_model : nn .Module ) -> None :
5361 from vllm .compilation .backends import set_model_tag
You can’t perform that action at this time.
0 commit comments