@@ -436,6 +436,9 @@ def run_rejection_sampler_test(
436436 batch_size = len (num_draft_tokens ),
437437 padded_tokens_length = int (sum (num_draft_tokens )))
438438
439+ # Convert numpy arrays to lists for comparison
440+ parsed_output = [x .tolist () for x in parsed_output ]
441+
439442 assert parsed_output == test_case .expected , \
440443 f"Test '{ test_case .name } ': Expected { test_case .expected } , got { parsed_output } "
441444
@@ -512,6 +515,9 @@ def test_parse_output_basic(self, rejection_sampler):
512515 batch_size = len (num_draft_tokens ),
513516 padded_tokens_length = int (sum (num_draft_tokens )))
514517
518+ # Convert numpy arrays to lists for comparison
519+ parsed_output = [x .tolist () for x in parsed_output ]
520+
515521 expected = [[10 , 20 , 30 , 40 ], [50 , 60 , 70 ]]
516522 assert parsed_output == expected , f"Expected { expected } , got { parsed_output } "
517523
@@ -535,6 +541,9 @@ def test_parse_output_with_placeholders(self, rejection_sampler):
535541 batch_size = len (num_draft_tokens ),
536542 padded_tokens_length = int (sum (num_draft_tokens )))
537543
544+ # Convert numpy arrays to lists for comparison
545+ parsed_output = [x .tolist () for x in parsed_output ]
546+
538547 expected = [[10 ], [20 , 30 , 40 ]]
539548 assert parsed_output == expected , f"Expected { expected } , got { parsed_output } "
540549
@@ -556,6 +565,9 @@ def test_parse_output_invalid_tokens(self, rejection_sampler):
556565 batch_size = len (num_draft_tokens ),
557566 padded_tokens_length = int (sum (num_draft_tokens )))
558567
568+ # Convert numpy arrays to lists for comparison
569+ parsed_output = [x .tolist () for x in parsed_output ]
570+
559571 expected = [[10 , 20 ]] # Invalid tokens filtered out
560572 assert parsed_output == expected , f"Expected { expected } , got { parsed_output } "
561573
@@ -577,6 +589,9 @@ def test_parse_output_empty_sequences(self, rejection_sampler):
577589 batch_size = len (num_draft_tokens ),
578590 padded_tokens_length = int (sum (num_draft_tokens )))
579591
592+ # Convert numpy arrays to lists for comparison
593+ parsed_output = [x .tolist () for x in parsed_output ]
594+
580595 expected = [[50 ], [60 ]] # Only bonus tokens
581596 assert parsed_output == expected , f"Expected { expected } , got { parsed_output } "
582597
@@ -632,6 +647,9 @@ def test_extreme_padding(self, rejection_sampler, test_helper):
632647 batch_size = len (num_draft_tokens ),
633648 padded_tokens_length = int (sum (num_draft_tokens )))
634649
650+ # Convert numpy arrays to lists for comparison
651+ parsed_output = [x .tolist () for x in parsed_output ]
652+
635653 expected = [[1 , 5 ]] # Should ignore all padding
636654 assert parsed_output == expected , f"Expected { expected } , got { parsed_output } "
637655
@@ -777,6 +795,9 @@ def test_single_long_sequence(self, rejection_sampler, test_helper):
777795 batch_size = len (num_draft_tokens ),
778796 padded_tokens_length = int (sum (num_draft_tokens )))
779797
798+ # Convert numpy arrays to lists for comparison
799+ parsed_output = [x .tolist () for x in parsed_output ]
800+
780801 expected = [list (range (1 , 28 )) + [99 ]] # Tokens up to mismatch point
781802 assert parsed_output == expected , f"Expected { expected } , got { parsed_output } "
782803
@@ -884,6 +905,9 @@ def test_non_greedy_deterministic_with_seed(self, rejection_sampler,
884905 num_draft_tokens_cpu = np .asarray (num_draft_tokens ),
885906 batch_size = 1 ,
886907 padded_tokens_length = 4 )
908+
909+ # Convert numpy arrays to lists for comparison
910+ parsed_output = [x .tolist () for x in parsed_output ]
887911 outputs .append (parsed_output )
888912
889913 # All outputs should be identical with same seed
@@ -1064,6 +1088,9 @@ def test_non_greedy_empty_sequence(self, rejection_sampler, test_helper):
10641088 batch_size = 2 ,
10651089 padded_tokens_length = 0 )
10661090
1091+ # Convert numpy arrays to lists for comparison
1092+ parsed_output = [x .tolist () for x in parsed_output ]
1093+
10671094 # Should get bonus tokens for empty sequences
10681095 expected = [[77 ], [88 ]]
10691096 assert parsed_output == expected , f"Expected { expected } , got { parsed_output } "
@@ -1152,6 +1179,10 @@ def test_non_greedy_vs_greedy_same_perfect_case(self, rejection_sampler,
11521179 non_greedy_parsed = rejection_sampler .parse_output (
11531180 non_greedy_output , VOCAB_SIZE , np .asarray (num_draft_tokens ), 1 , 3 )
11541181
1182+ # Convert numpy arrays to lists for comparison
1183+ greedy_parsed = [x .tolist () for x in greedy_parsed ]
1184+ non_greedy_parsed = [x .tolist () for x in non_greedy_parsed ]
1185+
11551186 # For perfect match, greedy should have all tokens + bonus
11561187 assert greedy_parsed == [[5 , 15 , 25 , 99 ]]
11571188
0 commit comments