Skip to content

Commit 83249b2

Browse files
py4 and Pooya Moradi authored
[Spec Decoding] Fix API error caused by upstream. (#1133)
Co-authored-by: Pooya Moradi <pooyam@google.com>
1 parent c0c8192 commit 83249b2

File tree

5 files changed

+37
-6
lines changed

5 files changed

+37
-6
lines changed

tests/layers/jax/sample/test_rejection_sampler.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,9 @@ def run_rejection_sampler_test(
436436
batch_size=len(num_draft_tokens),
437437
padded_tokens_length=int(sum(num_draft_tokens)))
438438

439+
# Convert numpy arrays to lists for comparison
440+
parsed_output = [x.tolist() for x in parsed_output]
441+
439442
assert parsed_output == test_case.expected, \
440443
f"Test '{test_case.name}': Expected {test_case.expected}, got {parsed_output}"
441444

@@ -512,6 +515,9 @@ def test_parse_output_basic(self, rejection_sampler):
512515
batch_size=len(num_draft_tokens),
513516
padded_tokens_length=int(sum(num_draft_tokens)))
514517

518+
# Convert numpy arrays to lists for comparison
519+
parsed_output = [x.tolist() for x in parsed_output]
520+
515521
expected = [[10, 20, 30, 40], [50, 60, 70]]
516522
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
517523

@@ -535,6 +541,9 @@ def test_parse_output_with_placeholders(self, rejection_sampler):
535541
batch_size=len(num_draft_tokens),
536542
padded_tokens_length=int(sum(num_draft_tokens)))
537543

544+
# Convert numpy arrays to lists for comparison
545+
parsed_output = [x.tolist() for x in parsed_output]
546+
538547
expected = [[10], [20, 30, 40]]
539548
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
540549

@@ -556,6 +565,9 @@ def test_parse_output_invalid_tokens(self, rejection_sampler):
556565
batch_size=len(num_draft_tokens),
557566
padded_tokens_length=int(sum(num_draft_tokens)))
558567

568+
# Convert numpy arrays to lists for comparison
569+
parsed_output = [x.tolist() for x in parsed_output]
570+
559571
expected = [[10, 20]] # Invalid tokens filtered out
560572
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
561573

@@ -577,6 +589,9 @@ def test_parse_output_empty_sequences(self, rejection_sampler):
577589
batch_size=len(num_draft_tokens),
578590
padded_tokens_length=int(sum(num_draft_tokens)))
579591

592+
# Convert numpy arrays to lists for comparison
593+
parsed_output = [x.tolist() for x in parsed_output]
594+
580595
expected = [[50], [60]] # Only bonus tokens
581596
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
582597

@@ -632,6 +647,9 @@ def test_extreme_padding(self, rejection_sampler, test_helper):
632647
batch_size=len(num_draft_tokens),
633648
padded_tokens_length=int(sum(num_draft_tokens)))
634649

650+
# Convert numpy arrays to lists for comparison
651+
parsed_output = [x.tolist() for x in parsed_output]
652+
635653
expected = [[1, 5]] # Should ignore all padding
636654
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
637655

@@ -777,6 +795,9 @@ def test_single_long_sequence(self, rejection_sampler, test_helper):
777795
batch_size=len(num_draft_tokens),
778796
padded_tokens_length=int(sum(num_draft_tokens)))
779797

798+
# Convert numpy arrays to lists for comparison
799+
parsed_output = [x.tolist() for x in parsed_output]
800+
780801
expected = [list(range(1, 28)) + [99]] # Tokens up to mismatch point
781802
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
782803

@@ -884,6 +905,9 @@ def test_non_greedy_deterministic_with_seed(self, rejection_sampler,
884905
num_draft_tokens_cpu=np.asarray(num_draft_tokens),
885906
batch_size=1,
886907
padded_tokens_length=4)
908+
909+
# Convert numpy arrays to lists for comparison
910+
parsed_output = [x.tolist() for x in parsed_output]
887911
outputs.append(parsed_output)
888912

889913
# All outputs should be identical with same seed
@@ -1064,6 +1088,9 @@ def test_non_greedy_empty_sequence(self, rejection_sampler, test_helper):
10641088
batch_size=2,
10651089
padded_tokens_length=0)
10661090

1091+
# Convert numpy arrays to lists for comparison
1092+
parsed_output = [x.tolist() for x in parsed_output]
1093+
10671094
# Should get bonus tokens for empty sequences
10681095
expected = [[77], [88]]
10691096
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
@@ -1152,6 +1179,10 @@ def test_non_greedy_vs_greedy_same_perfect_case(self, rejection_sampler,
11521179
non_greedy_parsed = rejection_sampler.parse_output(
11531180
non_greedy_output, VOCAB_SIZE, np.asarray(num_draft_tokens), 1, 3)
11541181

1182+
# Convert numpy arrays to lists for comparison
1183+
greedy_parsed = [x.tolist() for x in greedy_parsed]
1184+
non_greedy_parsed = [x.tolist() for x in non_greedy_parsed]
1185+
11551186
# For perfect match, greedy should have all tokens + bonus
11561187
assert greedy_parsed == [[5, 15, 25, 99]]
11571188

tests/runner/test_speculative_decoding_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def test_propose_eagle3_draft_token_ids(self,
321321
)
322322

323323
# Inputs
324-
sampled_token_ids = [[1], [2]]
324+
sampled_token_ids = [np.array([1]), np.array([2])]
325325
aux_hidden_states = MagicMock()
326326
attn_metadata = MagicMock()
327327
attn_metadata.seq_lens.shape = [2]

tpu_inference/layers/jax/sample/rejection_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def parse_output(
128128
num_draft_tokens_cpu: np.ndarray,
129129
batch_size: int,
130130
padded_tokens_length: int,
131-
) -> list[list[int]]:
131+
) -> list[np.ndarray]:
132132
"""Parse the output of the rejection sampler.
133133
134134
Args:
@@ -177,7 +177,7 @@ def parse_output(
177177
else:
178178
seq_tokens = valid_main_tokens
179179

180-
outputs.append(seq_tokens.tolist())
180+
outputs.append(seq_tokens)
181181
start_idx = end_idx
182182

183183
return outputs

tpu_inference/runner/speculative_decoding_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def propose_draft_token_ids(
7878

7979
def propose_eagle3_draft_token_ids(
8080
self,
81-
sampled_token_ids: list[list[int]],
81+
sampled_token_ids: list[np.ndarray],
8282
aux_hidden_states: Optional[tuple[jnp.ndarray, ...]],
8383
attn_metadata: AttentionMetadata,
8484
spec_decode_metadata: Optional[SpecDecodeMetadata],
@@ -91,7 +91,7 @@ def propose_eagle3_draft_token_ids(
9191
req_ids = self.runner.input_batch.req_ids
9292
next_token_ids: list[int] = []
9393
for i, token_ids in enumerate(sampled_token_ids):
94-
if token_ids:
94+
if token_ids.size != 0:
9595
# Common case.
9696
next_token_id = token_ids[-1]
9797
else:

tpu_inference/runner/tpu_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -935,7 +935,7 @@ def _sample_from_logits(
935935
# Append sampled tokens
936936
for req_idx, req_state, _ in request_seq_lens:
937937
sampled_ids = valid_sampled_token_ids[req_idx]
938-
if not sampled_ids:
938+
if sampled_ids.size == 0:
939939
continue
940940

941941
start_idx = self.input_batch.num_tokens_no_spec[req_idx]

0 commit comments

Comments (0)