diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 821e142304aa..9fd01ebf5d68 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -203,22 +203,10 @@ def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
     def _batched(self, iterator):
         if self.batchSize == self.UNLIMITED_BATCH_SIZE:
             yield list(iterator)
-        elif hasattr(iterator, "__len__") and hasattr(iterator, "__getslice__"):
-            n = len(iterator)
-            for i in range(0, n, self.batchSize):
-                yield iterator[i : i + self.batchSize]
         else:
-            items = []
-            count = 0
-            for item in iterator:
-                items.append(item)
-                count += 1
-                if count == self.batchSize:
-                    yield items
-                    items = []
-                    count = 0
-            if items:
-                yield items
+            it = iter(iterator)  # a real iterator, so successive islice calls consume it statefully
+            while batch := list(itertools.islice(it, self.batchSize)):  # walrus (3.8+); empty list ends the loop; NOTE(review): assumes itertools is imported at module scope — confirm
+                yield batch
 
     def dump_stream(self, iterator, stream):
         self.serializer.dump_stream(self._batched(iterator), stream)