Skip to content

Commit 45edde6

Browse files
authored
[Bugfix] Fix error where vLLM expects numpy sampled token ids (#1119)
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent 01aeb6f commit 45edde6

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

tpu_inference/runner/tpu_runner.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from flax import nnx
1616
from jax.experimental import mesh_utils
1717
from jax.sharding import NamedSharding, PartitionSpec
18-
from torchax.ops.mappings import j2t_dtype
18+
from torchax.ops.mappings import j2t, j2t_dtype
1919
from vllm.config import VllmConfig
2020
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
2121
has_kv_transfer_group)
@@ -28,7 +28,7 @@
2828
from vllm.v1.kv_cache_interface import KVCacheConfig
2929
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
3030
DraftTokenIds, KVConnectorOutput, LogprobsLists,
31-
ModelRunnerOutput)
31+
LogprobsTensors, ModelRunnerOutput)
3232
from vllm.v1.request import Request
3333
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
3434
from vllm.v1.worker.kv_connector_model_runner_mixin import \
@@ -122,9 +122,10 @@ def get_output(self) -> ModelRunnerOutput:
122122
next_tokens_cpu = next_tokens_cpu[self.logits_indices_selector]
123123
selected_token_ids = np.expand_dims(next_tokens_cpu[:self._num_reqs],
124124
1)
125-
valid_sampled_token_ids = selected_token_ids.tolist()
125+
126+
valid_sampled_token_ids = [token_id for token_id in selected_token_ids]
126127
for i in self._discard_sampled_tokens_req_indices:
127-
valid_sampled_token_ids[i].clear()
128+
valid_sampled_token_ids[i] = np.array([])
128129
self._model_runner_output.sampled_token_ids = valid_sampled_token_ids
129130
return self._model_runner_output
130131

@@ -190,7 +191,8 @@ def _substitute_placeholder_token(
190191
return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
191192

192193

193-
def _reorder_logits_indices(logprobs_lists, logits_indices_selector):
194+
def _reorder_logits_indices(logprobs_lists: LogprobsLists,
195+
logits_indices_selector: List[int]):
194196
return LogprobsLists(
195197
logprob_token_ids=[
196198
logprobs_lists.logprob_token_ids[i]
@@ -595,11 +597,11 @@ def _modify_prev_results(self):
595597
next_tokens_cpu = next_tokens_cpu[pre_logits_indices_selector]
596598
selected_token_ids = np.expand_dims(next_tokens_cpu[:len(pre_req_ids)],
597599
1)
598-
valid_sampled_token_ids = selected_token_ids.tolist()
600+
valid_sampled_token_ids = [token_id for token_id in selected_token_ids]
599601

600602
# Mask out the sampled tokens that should not be sampled.
601603
for i in pre_discard_sampled_tokens_req_indices:
602-
valid_sampled_token_ids[i].clear()
604+
valid_sampled_token_ids[i] = np.array([])
603605
# Append sampled tokens
604606
for pre_req_idx, req_state, _ in pre_request_seq_lens:
605607
sampled_ids = valid_sampled_token_ids[pre_req_idx]
@@ -804,6 +806,8 @@ def _sample_from_logits(
804806
if tpu_sampling_metadata.logprobs:
805807
logprobs = self._compute_and_gather_logprobs(
806808
logits, next_tokens, self.model_config.max_logprobs)
809+
logprobs_lists = jax.tree.map(lambda x: j2t(x.astype(jnp.float32)),
810+
logprobs).tolists()
807811
else:
808812
logprobs = None
809813

@@ -856,7 +860,6 @@ def _sample_from_logits(
856860

857861
if logprobs is not None:
858862
# Map logprobs back to the pre-dp shuffling order
859-
logprobs_lists = logprobs.tolists()
860863
if logits_indices_selector is not None:
861864
logprobs_lists = _reorder_logits_indices(
862865
logprobs_lists, logits_indices_selector)
@@ -898,7 +901,9 @@ def _sample_from_logits(
898901
if logits_indices_selector is not None:
899902
next_tokens = next_tokens[logits_indices_selector]
900903
selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1)
901-
valid_sampled_token_ids = selected_token_ids.tolist()
904+
valid_sampled_token_ids = [
905+
token_id for token_id in selected_token_ids
906+
]
902907
else:
903908
valid_sampled_token_ids = self.rejection_sampler.parse_output(
904909
next_tokens, self.input_batch.vocab_size,
@@ -907,7 +912,7 @@ def _sample_from_logits(
907912

908913
# Mask out the sampled tokens that should not be sampled.
909914
for i in discard_sampled_tokens_req_indices:
910-
valid_sampled_token_ids[i].clear()
915+
valid_sampled_token_ids[i] = np.array([])
911916
# Append sampled tokens
912917
for req_idx, req_state, _ in request_seq_lens:
913918
sampled_ids = valid_sampled_token_ids[req_idx]
@@ -929,7 +934,6 @@ def _sample_from_logits(
929934

930935
if logprobs is not None:
931936
# Map logprobs back to the pre-dp shuffling order
932-
logprobs_lists = logprobs.tolists()
933937
if logits_indices_selector is not None:
934938
logprobs_lists = _reorder_logits_indices(
935939
logprobs_lists, logits_indices_selector)
@@ -976,7 +980,8 @@ def select_local_fn(local_array, local_indices):
976980

977981
@staticmethod
978982
@functools.partial(jax.jit, static_argnames=("max_logprobs", ))
979-
def _compute_and_gather_logprobs(logits, next_tokens, max_logprobs):
983+
def _compute_and_gather_logprobs(logits, next_tokens,
984+
max_logprobs) -> LogprobsTensors:
980985
logprobs = compute_logprobs(logits)
981986
return gather_logprobs(logprobs, next_tokens, max_logprobs)
982987

0 commit comments

Comments
 (0)