From 05a4096ac39b59810e9724b43558761a3549acd7 Mon Sep 17 00:00:00 2001
From: Tian Gao
Date: Sun, 16 Nov 2025 17:35:19 -0800
Subject: [PATCH 1/2] Modernize the _batched method for BatchedSerializer

---
 python/pyspark/serializers.py | 18 ++----------------
 1 file changed, 2 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 821e142304aa..9f5604a95993 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -55,7 +55,7 @@
 
 import sys
 import os
-from itertools import chain, product
+from itertools import batched, chain, product
 import marshal
 import struct
 import types
@@ -203,22 +203,8 @@ def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
     def _batched(self, iterator):
         if self.batchSize == self.UNLIMITED_BATCH_SIZE:
             yield list(iterator)
-        elif hasattr(iterator, "__len__") and hasattr(iterator, "__getslice__"):
-            n = len(iterator)
-            for i in range(0, n, self.batchSize):
-                yield iterator[i : i + self.batchSize]
         else:
-            items = []
-            count = 0
-            for item in iterator:
-                items.append(item)
-                count += 1
-                if count == self.batchSize:
-                    yield items
-                    items = []
-                    count = 0
-            if items:
-                yield items
+            yield from map(list, batched(iterator, self.batchSize))
 
     def dump_stream(self, iterator, stream):
         self.serializer.dump_stream(self._batched(iterator), stream)

From 4cd1edfed3531211c2a03bf015a4f505700d7441 Mon Sep 17 00:00:00 2001
From: Tian Gao
Date: Sun, 16 Nov 2025 17:44:03 -0800
Subject: [PATCH 2/2] Use the older APIs

---
 python/pyspark/serializers.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 9f5604a95993..9fd01ebf5d68 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -55,7 +55,7 @@
 
 import sys
 import os
-from itertools import batched, chain, product
+from itertools import chain, product
 import marshal
 import struct
 import types
@@ -204,7 +204,9 @@ def _batched(self, iterator):
         if self.batchSize == self.UNLIMITED_BATCH_SIZE:
             yield list(iterator)
         else:
-            yield from map(list, batched(iterator, self.batchSize))
+            it = iter(iterator)
+            while batch := list(itertools.islice(it, self.batchSize)):
+                yield batch
 
     def dump_stream(self, iterator, stream):
         self.serializer.dump_stream(self._batched(iterator), stream)
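
Note (not part of the patch): a minimal standalone sketch of the islice-plus-walrus
batching pattern the second commit settles on, assuming Python 3.8+ for the walrus
operator; the helper name batched_islice and the sample data are illustrative only.

    from itertools import islice

    def batched_islice(iterable, batch_size):
        # Repeatedly slice batch_size items off a shared iterator until it is empty.
        it = iter(iterable)
        while batch := list(islice(it, batch_size)):
            yield batch

    print(list(batched_islice(range(7), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]

itertools.batched (used by the first commit via map(list, batched(iterator, batchSize)))
produces the same grouping as tuples but is only available on Python 3.12+, which is why
the follow-up commit "Use the older APIs" falls back to the islice loop above.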