1515from flax import nnx
1616from jax .experimental import mesh_utils
1717from jax .sharding import NamedSharding , PartitionSpec
18- from torchax .ops .mappings import j2t_dtype
18+ from torchax .ops .mappings import j2t , j2t_dtype
1919from vllm .config import VllmConfig
2020from vllm .distributed .kv_transfer import (get_kv_transfer_group ,
2121 has_kv_transfer_group )
2828from vllm .v1 .kv_cache_interface import KVCacheConfig
2929from vllm .v1 .outputs import (EMPTY_MODEL_RUNNER_OUTPUT , AsyncModelRunnerOutput ,
3030 DraftTokenIds , KVConnectorOutput , LogprobsLists ,
31- ModelRunnerOutput )
31+ LogprobsTensors , ModelRunnerOutput )
3232from vllm .v1 .request import Request
3333from vllm .v1 .spec_decode .ngram_proposer import NgramProposer
3434from vllm .v1 .worker .kv_connector_model_runner_mixin import \
@@ -122,9 +122,10 @@ def get_output(self) -> ModelRunnerOutput:
122122 next_tokens_cpu = next_tokens_cpu [self .logits_indices_selector ]
123123 selected_token_ids = np .expand_dims (next_tokens_cpu [:self ._num_reqs ],
124124 1 )
125- valid_sampled_token_ids = selected_token_ids .tolist ()
125+
126+ valid_sampled_token_ids = [token_id for token_id in selected_token_ids ]
126127 for i in self ._discard_sampled_tokens_req_indices :
127- valid_sampled_token_ids [i ]. clear ( )
128+ valid_sampled_token_ids [i ] = np . array ([] )
128129 self ._model_runner_output .sampled_token_ids = valid_sampled_token_ids
129130 return self ._model_runner_output
130131
@@ -190,7 +191,8 @@ def _substitute_placeholder_token(
190191 return input_ids .at [token_in_tpu_cur_input_indices ].set (update_values )
191192
192193
193- def _reorder_logits_indices (logprobs_lists , logits_indices_selector ):
194+ def _reorder_logits_indices (logprobs_lists : LogprobsLists ,
195+ logits_indices_selector : List [int ]):
194196 return LogprobsLists (
195197 logprob_token_ids = [
196198 logprobs_lists .logprob_token_ids [i ]
@@ -595,11 +597,11 @@ def _modify_prev_results(self):
595597 next_tokens_cpu = next_tokens_cpu [pre_logits_indices_selector ]
596598 selected_token_ids = np .expand_dims (next_tokens_cpu [:len (pre_req_ids )],
597599 1 )
598- valid_sampled_token_ids = selected_token_ids . tolist ()
600+ valid_sampled_token_ids = [ token_id for token_id in selected_token_ids ]
599601
600602 # Mask out the sampled tokens that should not be sampled.
601603 for i in pre_discard_sampled_tokens_req_indices :
602- valid_sampled_token_ids [i ]. clear ( )
604+ valid_sampled_token_ids [i ] = np . array ([] )
603605 # Append sampled tokens
604606 for pre_req_idx , req_state , _ in pre_request_seq_lens :
605607 sampled_ids = valid_sampled_token_ids [pre_req_idx ]
@@ -804,6 +806,8 @@ def _sample_from_logits(
804806 if tpu_sampling_metadata .logprobs :
805807 logprobs = self ._compute_and_gather_logprobs (
806808 logits , next_tokens , self .model_config .max_logprobs )
809+ logprobs_lists = jax .tree .map (lambda x : j2t (x .astype (jnp .float32 )),
810+ logprobs ).tolists ()
807811 else :
808812 logprobs = None
809813
@@ -856,7 +860,6 @@ def _sample_from_logits(
856860
857861 if logprobs is not None :
858862 # Map logprobs back to the pre-dp shuffling order
859- logprobs_lists = logprobs .tolists ()
860863 if logits_indices_selector is not None :
861864 logprobs_lists = _reorder_logits_indices (
862865 logprobs_lists , logits_indices_selector )
@@ -898,7 +901,9 @@ def _sample_from_logits(
898901 if logits_indices_selector is not None :
899902 next_tokens = next_tokens [logits_indices_selector ]
900903 selected_token_ids = np .expand_dims (next_tokens [:num_reqs ], 1 )
901- valid_sampled_token_ids = selected_token_ids .tolist ()
904+ valid_sampled_token_ids = [
905+ token_id for token_id in selected_token_ids
906+ ]
902907 else :
903908 valid_sampled_token_ids = self .rejection_sampler .parse_output (
904909 next_tokens , self .input_batch .vocab_size ,
@@ -907,7 +912,7 @@ def _sample_from_logits(
907912
908913 # Mask out the sampled tokens that should not be sampled.
909914 for i in discard_sampled_tokens_req_indices :
910- valid_sampled_token_ids [i ]. clear ( )
915+ valid_sampled_token_ids [i ] = np . array ([] )
911916 # Append sampled tokens
912917 for req_idx , req_state , _ in request_seq_lens :
913918 sampled_ids = valid_sampled_token_ids [req_idx ]
@@ -929,7 +934,6 @@ def _sample_from_logits(
929934
930935 if logprobs is not None :
931936 # Map logprobs back to the pre-dp shuffling order
932- logprobs_lists = logprobs .tolists ()
933937 if logits_indices_selector is not None :
934938 logprobs_lists = _reorder_logits_indices (
935939 logprobs_lists , logits_indices_selector )
@@ -976,7 +980,8 @@ def select_local_fn(local_array, local_indices):
976980
977981 @staticmethod
978982 @functools .partial (jax .jit , static_argnames = ("max_logprobs" , ))
979- def _compute_and_gather_logprobs (logits , next_tokens , max_logprobs ):
983+ def _compute_and_gather_logprobs (logits , next_tokens ,
984+ max_logprobs ) -> LogprobsTensors :
980985 logprobs = compute_logprobs (logits )
981986 return gather_logprobs (logprobs , next_tokens , max_logprobs )
982987
0 commit comments