@@ -436,9 +436,6 @@ def run_rejection_sampler_test(
436436 batch_size = len (num_draft_tokens ),
437437 padded_tokens_length = int (sum (num_draft_tokens )))
438438
439- # Convert numpy arrays to lists for comparison
440- parsed_output = [x .tolist () for x in parsed_output ]
441-
442439 assert parsed_output == test_case .expected , \
443440 f"Test '{ test_case .name } ': Expected { test_case .expected } , got { parsed_output } "
444441
@@ -515,9 +512,6 @@ def test_parse_output_basic(self, rejection_sampler):
515512 batch_size = len (num_draft_tokens ),
516513 padded_tokens_length = int (sum (num_draft_tokens )))
517514
518- # Convert numpy arrays to lists for comparison
519- parsed_output = [x .tolist () for x in parsed_output ]
520-
521515 expected = [[10 , 20 , 30 , 40 ], [50 , 60 , 70 ]]
522516 assert parsed_output == expected , f"Expected { expected } , got { parsed_output } "
523517
@@ -541,9 +535,6 @@ def test_parse_output_with_placeholders(self, rejection_sampler):
541535 batch_size = len (num_draft_tokens ),
542536 padded_tokens_length = int (sum (num_draft_tokens )))
543537
544- # Convert numpy arrays to lists for comparison
545- parsed_output = [x .tolist () for x in parsed_output ]
546-
547538 expected = [[10 ], [20 , 30 , 40 ]]
548539 assert parsed_output == expected , f"Expected { expected } , got { parsed_output } "
549540
@@ -565,9 +556,6 @@ def test_parse_output_invalid_tokens(self, rejection_sampler):
565556 batch_size = len (num_draft_tokens ),
566557 padded_tokens_length = int (sum (num_draft_tokens )))
567558
568- # Convert numpy arrays to lists for comparison
569- parsed_output = [x .tolist () for x in parsed_output ]
570-
571559 expected = [[10 , 20 ]] # Invalid tokens filtered out
572560 assert parsed_output == expected , f"Expected { expected } , got { parsed_output } "
573561
@@ -589,9 +577,6 @@ def test_parse_output_empty_sequences(self, rejection_sampler):
589577 batch_size = len (num_draft_tokens ),
590578 padded_tokens_length = int (sum (num_draft_tokens )))
591579
592- # Convert numpy arrays to lists for comparison
593- parsed_output = [x .tolist () for x in parsed_output ]
594-
595580 expected = [[50 ], [60 ]] # Only bonus tokens
596581 assert parsed_output == expected , f"Expected { expected } , got { parsed_output } "
597582
@@ -647,9 +632,6 @@ def test_extreme_padding(self, rejection_sampler, test_helper):
647632 batch_size = len (num_draft_tokens ),
648633 padded_tokens_length = int (sum (num_draft_tokens )))
649634
650- # Convert numpy arrays to lists for comparison
651- parsed_output = [x .tolist () for x in parsed_output ]
652-
653635 expected = [[1 , 5 ]] # Should ignore all padding
654636 assert parsed_output == expected , f"Expected { expected } , got { parsed_output } "
655637
@@ -795,9 +777,6 @@ def test_single_long_sequence(self, rejection_sampler, test_helper):
795777 batch_size = len (num_draft_tokens ),
796778 padded_tokens_length = int (sum (num_draft_tokens )))
797779
798- # Convert numpy arrays to lists for comparison
799- parsed_output = [x .tolist () for x in parsed_output ]
800-
801780 expected = [list (range (1 , 28 )) + [99 ]] # Tokens up to mismatch point
802781 assert parsed_output == expected , f"Expected { expected } , got { parsed_output } "
803782
@@ -905,9 +884,6 @@ def test_non_greedy_deterministic_with_seed(self, rejection_sampler,
905884 num_draft_tokens_cpu = np .asarray (num_draft_tokens ),
906885 batch_size = 1 ,
907886 padded_tokens_length = 4 )
908-
909- # Convert numpy arrays to lists for comparison
910- parsed_output = [x .tolist () for x in parsed_output ]
911887 outputs .append (parsed_output )
912888
913889 # All outputs should be identical with same seed
@@ -1088,9 +1064,6 @@ def test_non_greedy_empty_sequence(self, rejection_sampler, test_helper):
10881064 batch_size = 2 ,
10891065 padded_tokens_length = 0 )
10901066
1091- # Convert numpy arrays to lists for comparison
1092- parsed_output = [x .tolist () for x in parsed_output ]
1093-
10941067 # Should get bonus tokens for empty sequences
10951068 expected = [[77 ], [88 ]]
10961069 assert parsed_output == expected , f"Expected { expected } , got { parsed_output } "
@@ -1179,10 +1152,6 @@ def test_non_greedy_vs_greedy_same_perfect_case(self, rejection_sampler,
11791152 non_greedy_parsed = rejection_sampler .parse_output (
11801153 non_greedy_output , VOCAB_SIZE , np .asarray (num_draft_tokens ), 1 , 3 )
11811154
1182- # Convert numpy arrays to lists for comparison
1183- greedy_parsed = [x .tolist () for x in greedy_parsed ]
1184- non_greedy_parsed = [x .tolist () for x in non_greedy_parsed ]
1185-
11861155 # For perfect match, greedy should have all tokens + bonus
11871156 assert greedy_parsed == [[5 , 15 , 25 , 99 ]]
11881157
0 commit comments