Skip to content

Commit 9dbeb64

Browse files
authored
Revert previous API changes due to upstream change. (#1155)
Signed-off-by: Lihao Ran <imlihao.ran@gmail.com>
1 parent 5047569 commit 9dbeb64

File tree

5 files changed

+14
-49
lines changed

5 files changed

+14
-49
lines changed

tests/layers/jax/sample/test_rejection_sampler.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -436,9 +436,6 @@ def run_rejection_sampler_test(
436436
batch_size=len(num_draft_tokens),
437437
padded_tokens_length=int(sum(num_draft_tokens)))
438438

439-
# Convert numpy arrays to lists for comparison
440-
parsed_output = [x.tolist() for x in parsed_output]
441-
442439
assert parsed_output == test_case.expected, \
443440
f"Test '{test_case.name}': Expected {test_case.expected}, got {parsed_output}"
444441

@@ -515,9 +512,6 @@ def test_parse_output_basic(self, rejection_sampler):
515512
batch_size=len(num_draft_tokens),
516513
padded_tokens_length=int(sum(num_draft_tokens)))
517514

518-
# Convert numpy arrays to lists for comparison
519-
parsed_output = [x.tolist() for x in parsed_output]
520-
521515
expected = [[10, 20, 30, 40], [50, 60, 70]]
522516
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
523517

@@ -541,9 +535,6 @@ def test_parse_output_with_placeholders(self, rejection_sampler):
541535
batch_size=len(num_draft_tokens),
542536
padded_tokens_length=int(sum(num_draft_tokens)))
543537

544-
# Convert numpy arrays to lists for comparison
545-
parsed_output = [x.tolist() for x in parsed_output]
546-
547538
expected = [[10], [20, 30, 40]]
548539
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
549540

@@ -565,9 +556,6 @@ def test_parse_output_invalid_tokens(self, rejection_sampler):
565556
batch_size=len(num_draft_tokens),
566557
padded_tokens_length=int(sum(num_draft_tokens)))
567558

568-
# Convert numpy arrays to lists for comparison
569-
parsed_output = [x.tolist() for x in parsed_output]
570-
571559
expected = [[10, 20]] # Invalid tokens filtered out
572560
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
573561

@@ -589,9 +577,6 @@ def test_parse_output_empty_sequences(self, rejection_sampler):
589577
batch_size=len(num_draft_tokens),
590578
padded_tokens_length=int(sum(num_draft_tokens)))
591579

592-
# Convert numpy arrays to lists for comparison
593-
parsed_output = [x.tolist() for x in parsed_output]
594-
595580
expected = [[50], [60]] # Only bonus tokens
596581
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
597582

@@ -647,9 +632,6 @@ def test_extreme_padding(self, rejection_sampler, test_helper):
647632
batch_size=len(num_draft_tokens),
648633
padded_tokens_length=int(sum(num_draft_tokens)))
649634

650-
# Convert numpy arrays to lists for comparison
651-
parsed_output = [x.tolist() for x in parsed_output]
652-
653635
expected = [[1, 5]] # Should ignore all padding
654636
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
655637

@@ -795,9 +777,6 @@ def test_single_long_sequence(self, rejection_sampler, test_helper):
795777
batch_size=len(num_draft_tokens),
796778
padded_tokens_length=int(sum(num_draft_tokens)))
797779

798-
# Convert numpy arrays to lists for comparison
799-
parsed_output = [x.tolist() for x in parsed_output]
800-
801780
expected = [list(range(1, 28)) + [99]] # Tokens up to mismatch point
802781
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
803782

@@ -905,9 +884,6 @@ def test_non_greedy_deterministic_with_seed(self, rejection_sampler,
905884
num_draft_tokens_cpu=np.asarray(num_draft_tokens),
906885
batch_size=1,
907886
padded_tokens_length=4)
908-
909-
# Convert numpy arrays to lists for comparison
910-
parsed_output = [x.tolist() for x in parsed_output]
911887
outputs.append(parsed_output)
912888

913889
# All outputs should be identical with same seed
@@ -1088,9 +1064,6 @@ def test_non_greedy_empty_sequence(self, rejection_sampler, test_helper):
10881064
batch_size=2,
10891065
padded_tokens_length=0)
10901066

1091-
# Convert numpy arrays to lists for comparison
1092-
parsed_output = [x.tolist() for x in parsed_output]
1093-
10941067
# Should get bonus tokens for empty sequences
10951068
expected = [[77], [88]]
10961069
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
@@ -1179,10 +1152,6 @@ def test_non_greedy_vs_greedy_same_perfect_case(self, rejection_sampler,
11791152
non_greedy_parsed = rejection_sampler.parse_output(
11801153
non_greedy_output, VOCAB_SIZE, np.asarray(num_draft_tokens), 1, 3)
11811154

1182-
# Convert numpy arrays to lists for comparison
1183-
greedy_parsed = [x.tolist() for x in greedy_parsed]
1184-
non_greedy_parsed = [x.tolist() for x in non_greedy_parsed]
1185-
11861155
# For perfect match, greedy should have all tokens + bonus
11871156
assert greedy_parsed == [[5, 15, 25, 99]]
11881157

tests/runner/test_speculative_decoding_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def test_propose_eagle3_draft_token_ids(self,
321321
)
322322

323323
# Inputs
324-
sampled_token_ids = [np.array([1]), np.array([2])]
324+
sampled_token_ids = [[1], [2]]
325325
aux_hidden_states = MagicMock()
326326
attn_metadata = MagicMock()
327327
attn_metadata.seq_lens.shape = [2]

tpu_inference/layers/jax/sample/rejection_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def parse_output(
128128
num_draft_tokens_cpu: np.ndarray,
129129
batch_size: int,
130130
padded_tokens_length: int,
131-
) -> list[np.ndarray]:
131+
) -> list[list[int]]:
132132
"""Parse the output of the rejection sampler.
133133
134134
Args:
@@ -177,7 +177,7 @@ def parse_output(
177177
else:
178178
seq_tokens = valid_main_tokens
179179

180-
outputs.append(seq_tokens)
180+
outputs.append(seq_tokens.tolist())
181181
start_idx = end_idx
182182

183183
return outputs

tpu_inference/runner/speculative_decoding_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def propose_draft_token_ids(
7878

7979
def propose_eagle3_draft_token_ids(
8080
self,
81-
sampled_token_ids: list[np.ndarray],
81+
sampled_token_ids: list[list[int]],
8282
aux_hidden_states: Optional[tuple[jnp.ndarray, ...]],
8383
attn_metadata: AttentionMetadata,
8484
spec_decode_metadata: Optional[SpecDecodeMetadata],
@@ -91,7 +91,7 @@ def propose_eagle3_draft_token_ids(
9191
req_ids = self.runner.input_batch.req_ids
9292
next_token_ids: list[int] = []
9393
for i, token_ids in enumerate(sampled_token_ids):
94-
if token_ids.size != 0:
94+
if token_ids:
9595
# Common case.
9696
next_token_id = token_ids[-1]
9797
else:

tpu_inference/runner/tpu_runner.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from vllm.v1.kv_cache_interface import KVCacheConfig
2929
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
3030
DraftTokenIds, KVConnectorOutput, LogprobsLists,
31-
LogprobsTensors, ModelRunnerOutput)
31+
ModelRunnerOutput)
3232
from vllm.v1.request import Request
3333
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
3434
from vllm.v1.worker.kv_connector_model_runner_mixin import \
@@ -122,10 +122,9 @@ def get_output(self) -> ModelRunnerOutput:
122122
next_tokens_cpu = next_tokens_cpu[self.logits_indices_selector]
123123
selected_token_ids = np.expand_dims(next_tokens_cpu[:self._num_reqs],
124124
1)
125-
126-
valid_sampled_token_ids = [token_id for token_id in selected_token_ids]
125+
valid_sampled_token_ids = selected_token_ids.tolist()
127126
for i in self._discard_sampled_tokens_req_indices:
128-
valid_sampled_token_ids[i] = np.array([])
127+
valid_sampled_token_ids[i].clear()
129128
self._model_runner_output.sampled_token_ids = valid_sampled_token_ids
130129
return self._model_runner_output
131130

@@ -614,11 +613,11 @@ def _modify_prev_results(self):
614613
next_tokens_cpu = next_tokens_cpu[pre_logits_indices_selector]
615614
selected_token_ids = np.expand_dims(next_tokens_cpu[:len(pre_req_ids)],
616615
1)
617-
valid_sampled_token_ids = [token_id for token_id in selected_token_ids]
616+
valid_sampled_token_ids = selected_token_ids.tolist()
618617

619618
# Mask out the sampled tokens that should not be sampled.
620619
for i in pre_discard_sampled_tokens_req_indices:
621-
valid_sampled_token_ids[i] = np.array([])
620+
valid_sampled_token_ids[i].clear()
622621
# Append sampled tokens
623622
for pre_req_idx, req_state, _ in pre_request_seq_lens:
624623
sampled_ids = valid_sampled_token_ids[pre_req_idx]
@@ -940,9 +939,7 @@ def _sample_from_logits(
940939
if logits_indices_selector is not None:
941940
next_tokens = next_tokens[logits_indices_selector]
942941
selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1)
943-
valid_sampled_token_ids = [
944-
token_id for token_id in selected_token_ids
945-
]
942+
valid_sampled_token_ids = selected_token_ids.tolist()
946943
else:
947944
valid_sampled_token_ids = self.rejection_sampler.parse_output(
948945
next_tokens, self.input_batch.vocab_size,
@@ -951,11 +948,11 @@ def _sample_from_logits(
951948

952949
# Mask out the sampled tokens that should not be sampled.
953950
for i in discard_sampled_tokens_req_indices:
954-
valid_sampled_token_ids[i] = np.array([])
951+
valid_sampled_token_ids[i].clear()
955952
# Append sampled tokens
956953
for req_idx, req_state, _ in request_seq_lens:
957954
sampled_ids = valid_sampled_token_ids[req_idx]
958-
if sampled_ids.size == 0:
955+
if not sampled_ids:
959956
continue
960957

961958
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
@@ -1018,8 +1015,7 @@ def select_local_fn(local_array, local_indices):
10181015

10191016
@staticmethod
10201017
@functools.partial(jax.jit, static_argnames=("max_logprobs", ))
1021-
def _compute_and_gather_logprobs(logits, next_tokens,
1022-
max_logprobs) -> LogprobsTensors:
1018+
def _compute_and_gather_logprobs(logits, next_tokens, max_logprobs):
10231019
logprobs = compute_logprobs(logits)
10241020
return gather_logprobs(logprobs, next_tokens, max_logprobs)
10251021

0 commit comments

Comments (0)