diff --git a/CHANGELOG.md b/CHANGELOG.md index 1321a77dd..87f96422d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change and advance a `grain.DatasetIterator` to the given produced element index. * Switches to multithreading instead of multiprocessing in `IterDataset.mp_prefetch` when free-threaded Python is detected. + * Add `ElasticIterDatasetIterator` for scaling up and down the number of shards between checkpoints. * Breaking changes: * Custom implementations of `RandomAccessDataSource` should accept `int` diff --git a/grain/_src/python/checkpoint/BUILD b/grain/_src/python/checkpoint/BUILD index a160c9c46..efd871ab7 100644 --- a/grain/_src/python/checkpoint/BUILD +++ b/grain/_src/python/checkpoint/BUILD @@ -17,9 +17,34 @@ py_library( srcs = ["handler.py"], srcs_version = "PY3", deps = [ + ":elastic_checkpoint", "//grain/_src/core:sharding", "//grain/_src/python:data_loader", "//grain/_src/python/dataset", + "//grain/_src/python/dataset:elastic_iterator", + "@pypi//etils:pkg", + ], +) + +py_library( + name = "elastic_checkpoint", + srcs = ["elastic_checkpoint.py"], + srcs_version = "PY3", + deps = [ + "//grain/_src/python/dataset:elastic_iterator", + "@pypi//etils:pkg", + ], +) + +py_test( + name = "elastic_checkpoint_test", + srcs = ["elastic_checkpoint_test.py"], + srcs_version = "PY3", + deps = [ + ":elastic_checkpoint", + "//grain/_src/core:sharding", + "//grain/_src/python/dataset:elastic_iterator", + "@abseil-py//absl/testing:absltest", "@pypi//etils:pkg", ], ) diff --git a/grain/_src/python/checkpoint/elastic_checkpoint.py b/grain/_src/python/checkpoint/elastic_checkpoint.py new file mode 100644 index 000000000..598e9f1ed --- /dev/null +++ b/grain/_src/python/checkpoint/elastic_checkpoint.py @@ -0,0 +1,137 @@ +"""This module provides checkpointing logic for ElasticIterDatasetIterator.""" + +import dataclasses +import json +from typing import Any, Optional, Sequence 
from etils import epath
from grain._src.python.dataset import elastic_iterator


def _shard_state_filename(shard_index: int, total_num_shards: int) -> str:
  """Returns the canonical state-file name for one shard of a checkpoint."""
  return f"shard_state_{shard_index}-of-{total_num_shards}.json"


def _find_shard_file(
    directory: epath.Path,
    shard_index: int,
    total_num_shards: int,
) -> epath.Path:
  """Finds the state file for `shard_index` in a checkpoint `directory`.

  Args:
    directory: Directory containing files written by `save_elastic_iterator`.
    shard_index: Global index of the shard to look up.
    total_num_shards: Total number of shards the checkpoint was saved with.

  Returns:
    The path of the unique matching `shard_state_*.json` file.

  Raises:
    ValueError: If zero or more than one matching file is found.
  """
  expected_name = _shard_state_filename(shard_index, total_num_shards)
  # Match the exact file name: a suffix match (`endswith`) would also accept
  # unrelated files such as `old_shard_state_0-of-2.json`.
  found_files = [f for f in directory.iterdir() if f.name == expected_name]
  if not found_files:
    raise ValueError(
        f"No shard state files found in {directory} for shard {shard_index}"
    )
  if len(found_files) > 1:
    raise ValueError(
        f"Multiple shard state files found in {directory} for shard"
        f" {shard_index}"
    )
  return found_files[0]


def save_elastic_iterator(
    directory: epath.Path,
    item: elastic_iterator.ElasticIterDatasetIterator,
):
  """Saves the given iterator to the checkpoint in `directory`.

  Writes one `shard_state_<i>-of-<n>.json` file per shard so that the
  checkpoint can later be restored with a different number of hosts.
  """
  state = item.get_state()
  total_num_shards = state["total_num_shards"]
  for idx, host_iterator_state in state["ds_iterator_states"].items():
    # Copy before adding bookkeeping: `get_state()` may return internal dicts
    # by reference, and mutating them in place would alter the live iterator
    # state as a side effect of saving.
    shard_state = dict(host_iterator_state)
    shard_state["total_num_shards"] = total_num_shards
    filename = directory / _shard_state_filename(idx, total_num_shards)
    filename.write_text(json.dumps(shard_state, indent=4))


def restore_elastic_iterator(
    directory: epath.Path,
    item: elastic_iterator.ElasticIterDatasetIterator,
):
  """Restores the given iterator from the checkpoint in `directory`.

  Loads the state file of every shard owned by `item` — global indices
  `shard_index, shard_index + shard_count, ...` below `total_num_shards` —
  and applies each one to the corresponding per-shard iterator.
  """
  total_num_shards = item.total_num_shards
  for shard_index in range(
      item.shard_options.shard_index,
      total_num_shards,
      item.shard_options.shard_count,
  ):
    filename = _find_shard_file(directory, shard_index, total_num_shards)
    state = json.loads(filename.read_text())
    item.update_shard_iterator_state(shard_index, state)


class ElasticCheckpointHandler:
  """Orbax CheckpointHandler for elastic PyGrain iterators."""

  def save(
      self,
      directory: epath.Path,
      item: Optional[
          elastic_iterator.ElasticIterDatasetIterator
          | Sequence[elastic_iterator.ElasticIterDatasetIterator]
      ] = None,
      args: Any = None,
  ):
    """Saves the given iterator(s) to the checkpoint in `directory`."""
    # Exactly one of `item` (legacy API) / `args` (new Orbax API) is expected.
    item = item or args.item
    if isinstance(item, elastic_iterator.ElasticIterDatasetIterator):
      item = [item]
    for iterator in item:
      save_elastic_iterator(directory, iterator)

  def restore(
      self,
      directory: epath.Path,
      item: Optional[
          elastic_iterator.ElasticIterDatasetIterator
          | Sequence[elastic_iterator.ElasticIterDatasetIterator]
      ] = None,
      args: Any = None,
  ) -> Any:
    """Restores the given iterator(s) from the checkpoint in `directory`."""
    item = item or args.item
    if isinstance(item, elastic_iterator.ElasticIterDatasetIterator):
      item = [item]
    for iterator in item:
      restore_elastic_iterator(directory, iterator)
    return item

  # Required by interface but not supported by PyGrain checkpoints.
  def structure(self, directory: epath.Path) -> Any:
    del directory
    return None

  # Required by interface.
  def metadata(self, directory: epath.Path) -> Optional[Any]:
    del directory
    return None

  def finalize(self, directory: epath.Path):
    pass

  def close(self):
    pass

  @classmethod
  def typestr(cls):
    return f"{cls.__module__}.{cls.__qualname__}"


try:
  # Register the handler to be used with the new checkpointing API if Orbax is
  # present.
  import orbax.checkpoint as ocp  # pylint:disable=g-import-not-at-top # pytype:disable=import-error

  @ocp.args.register_with_handler(ElasticCheckpointHandler, for_save=True)  # pytype:disable=wrong-arg-types
  @dataclasses.dataclass
  class ElasticCheckpointSave(ocp.args.CheckpointArgs):
    item: Any

  @ocp.args.register_with_handler(ElasticCheckpointHandler, for_restore=True)  # pytype:disable=wrong-arg-types
  @dataclasses.dataclass
  class ElasticCheckpointRestore(ocp.args.CheckpointArgs):
    item: Any

except (ImportError, TypeError, AttributeError):
  pass
shard_options=shard_options, total_num_shards=2, states=states + ) + elastic_checkpoint.save_elastic_iterator(temp_dir, iterator) + + file0 = temp_dir / "shard_state_0-of-2.json" + self.assertTrue(file0.exists()) + self.assertEqual( + file0.read_text(), + json.dumps({"val": 0, "total_num_shards": 2}, indent=4), + ) + file1 = temp_dir / "shard_state_1-of-2.json" + self.assertTrue(file1.exists()) + self.assertEqual( + file1.read_text(), + json.dumps({"val": 1, "total_num_shards": 2}, indent=4), + ) + + iterator_to_restore = MockElasticIterDatasetIterator( + shard_options=shard_options, total_num_shards=2 + ) + elastic_checkpoint.restore_elastic_iterator(temp_dir, iterator_to_restore) + self.assertEqual( + iterator_to_restore.updated_states, + { + 0: {"val": 0, "total_num_shards": 2}, + 1: {"val": 1, "total_num_shards": 2}, + }, + ) + + def test_restore_elastic_iterator_with_multiple_processes(self): + temp_dir = epath.Path(self.create_tempdir().full_path) + # Process 0 + shard_options_0 = sharding.ShardOptions(shard_index=0, shard_count=2) + states = { + 0: {"val": 0}, + 1: {"val": 1}, + 2: {"val": 2}, + } + iterator_0 = MockElasticIterDatasetIterator( + shard_options=shard_options_0, total_num_shards=3, states=states + ) + # In reality save_elastic_iterator will be called in each process, but + # get_state() should return all states, so we only need to call it once + # to create checkpoint files. + elastic_checkpoint.save_elastic_iterator(temp_dir, iterator_0) + + # Check files are written + self.assertTrue((temp_dir / "shard_state_0-of-3.json").exists()) + self.assertTrue((temp_dir / "shard_state_1-of-3.json").exists()) + self.assertTrue((temp_dir / "shard_state_2-of-3.json").exists()) + + # Restore for process 0, responsible for shards 0 and 2. 
+ iterator_to_restore_0 = MockElasticIterDatasetIterator( + shard_options=shard_options_0, total_num_shards=3 + ) + elastic_checkpoint.restore_elastic_iterator(temp_dir, iterator_to_restore_0) + self.assertEqual( + iterator_to_restore_0.updated_states, + { + 0: {"val": 0, "total_num_shards": 3}, + 2: {"val": 2, "total_num_shards": 3}, + }, + ) + + # Restore for process 1, responsible for shard 1. + shard_options_1 = sharding.ShardOptions(shard_index=1, shard_count=2) + iterator_to_restore_1 = MockElasticIterDatasetIterator( + shard_options=shard_options_1, total_num_shards=3 + ) + elastic_checkpoint.restore_elastic_iterator(temp_dir, iterator_to_restore_1) + self.assertEqual( + iterator_to_restore_1.updated_states, + { + 1: {"val": 1, "total_num_shards": 3}, + }, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/grain/_src/python/checkpoint/handler.py b/grain/_src/python/checkpoint/handler.py index d6bebb16b..e7da8ad5c 100644 --- a/grain/_src/python/checkpoint/handler.py +++ b/grain/_src/python/checkpoint/handler.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""This module provides a PyGrain CheckpointHandler for integration with Orbax.""" + import dataclasses import json from typing import Any, Optional, TypeVar @@ -19,7 +20,9 @@ from etils import epath from grain._src.core import sharding from grain._src.python import data_loader +from grain._src.python.checkpoint import elastic_checkpoint from grain._src.python.dataset import dataset +from grain._src.python.dataset import elastic_iterator IteratorType = TypeVar( "IteratorType", data_loader.DataLoaderIterator, dataset.DatasetIterator @@ -41,6 +44,9 @@ def save( """Saves the given iterator to the checkpoint in `directory`.""" item = item or args.item # pytype:disable=attribute-error if isinstance(item, dataset.DatasetIterator): + if isinstance(item, elastic_iterator.ElasticIterDatasetIterator): + elastic_checkpoint.save_elastic_iterator(directory, item) + return state = json.dumps(item.get_state(), indent=4) else: state = item.get_state().decode() @@ -56,6 +62,9 @@ def restore( ) -> IteratorType: """Restores the given iterator from the checkpoint in `directory`.""" item = item or args.item # pytype:disable=attribute-error + if isinstance(item, elastic_iterator.ElasticIterDatasetIterator): + elastic_checkpoint.restore_elastic_iterator(directory, item) + return item process_index, process_count = sharding.get_process_index_and_count() filename = directory / f"process_{process_index}-of-{process_count}.json" if not filename.exists(): @@ -105,6 +114,5 @@ class CheckpointSave(ocp.args.CheckpointArgs): class CheckpointRestore(ocp.args.CheckpointArgs): item: Any - except (ImportError, TypeError, AttributeError): pass diff --git a/grain/_src/python/dataset/BUILD b/grain/_src/python/dataset/BUILD index 8e3e81ad4..98c4d86bb 100644 --- a/grain/_src/python/dataset/BUILD +++ b/grain/_src/python/dataset/BUILD @@ -181,6 +181,7 @@ py_test( ":elastic_iterator", "//grain/_src/core:sharding", "//grain/_src/python:options", + "//grain/_src/python/checkpoint:elastic_checkpoint", 
class ElasticIterDatasetIterator(dataset.DatasetIterator):
  """Iterator for ElasticIterDataset.

  Wraps `InterleaveDatasetIterator`, applying sharding and batching
  dynamically to a fixed set of dataset shards. Normally, sharded datasets
  cannot be resharded and redistributed to iterators; this class enables it
  by taking the maximum number of dataset shards and interleaving those
  shards into a variable number of per-host iterators.

  Caveats:
    - Order of elements is not guaranteed.

  Usage:
    parquet_files = ep.glob("/path/to/some/files/*.parquet")
    ds = [ParquetIterDataset(f) for f in parquet_files]
    it = ElasticIterDatasetIterator(
        ds,
        shard_options=sharding.ShardOptions(
            shard_index=jax.process_id(), shard_count=10
        ),
        global_batch_size=3,
    )
    iterator = iter(it)
    x = next(iterator)

    # Continue to use the iterator as usual and save it to a checkpoint with
    # the dedicated elastic checkpoint API.
    elastic_checkpoint.save_elastic_iterator(temp_dir, it)

    # When restoring, the number of processes can be changed and the elastic
    # iterator will be restored accordingly.
    it = ElasticIterDatasetIterator(
        ds,
        shard_options=sharding.ShardOptions(
            shard_index=jax.process_id(), shard_count=20
        ),
        global_batch_size=3,
    )
    elastic_checkpoint.restore_elastic_iterator(temp_dir, it)
  """

  def __init__(
      self,
      ds: Sequence[dataset.IterDataset],
      shard_options: sharding.ShardOptions,
      global_batch_size: int,
      *,
      read_options: options.ReadOptions = options.ReadOptions(),
      multiprocessing_options: options.MultiprocessingOptions | None = None,
      cycle_length: int | None = None,
      num_make_iter_threads: int = 1,
      make_iter_buffer_size: int = 1,
      iter_buffer_size: int = 1,
  ):
    """Creates an elastic iterator over the shards in `ds`.

    Args:
      ds: All dataset shards; `len(ds)` is the total (maximum) shard count.
      shard_options: Selects which subset of shards this host iterates:
        every `shard_count`-th shard starting from `shard_index`.
      global_batch_size: Per-shard batch size (incomplete batches dropped);
        also the default interleave cycle length.
      read_options: Stored but not consumed in this class as visible here —
        presumably used downstream; TODO confirm.
      multiprocessing_options: Stored but not consumed in this class as
        visible here.
      cycle_length: Interleave cycle length; defaults to `global_batch_size`.
      num_make_iter_threads: Threads used to create shard iterators.
      make_iter_buffer_size: Prefetch buffer for iterator creation.
      iter_buffer_size: Prefetch buffer for produced elements.
    """
    super().__init__()
    self._ds = ds
    self._global_batch_size = global_batch_size
    self._shard_options = shard_options
    self._read_options = read_options
    self._multiprocessing_options = multiprocessing_options

    # InterleaveDatasetIterator options.
    self._cycle_length = cycle_length or global_batch_size
    self._num_make_iter_threads = num_make_iter_threads
    self._make_iter_buffer_size = make_iter_buffer_size
    self._iter_buffer_size = iter_buffer_size

    self._total_num_shards = len(ds)
    # Global indices of the shards assigned to this iterator.
    self._shard_indices = list(
        range(
            shard_options.shard_index,
            self._total_num_shards,
            shard_options.shard_count,
        )
    )
    # One iterator per owned shard, positionally aligned with
    # `self._shard_indices`. Batching is applied per shard; a batch size of 1
    # skips the batch transformation entirely.
    owned_shards = self._ds[
        shard_options.shard_index :: shard_options.shard_count
    ]
    if self._global_batch_size == 1:
      self._ds_iterators = [iter(shard_ds) for shard_ds in owned_shards]
    else:
      self._ds_iterators = [
          iter(shard_ds.batch(self._global_batch_size, drop_remainder=True))
          for shard_ds in owned_shards
      ]

  @property
  def shard_options(self) -> sharding.ShardOptions:
    return self._shard_options

  @property
  def total_num_shards(self) -> int:
    return self._total_num_shards

  @functools.cached_property
  def _iterator(self) -> dataset.DatasetIterator:
    # Built lazily so that state can be restored before interleaving starts.
    return interleave.InterleaveDatasetIterator(
        self._ds_iterators,
        cycle_length=self._cycle_length,
        num_make_iter_threads=self._num_make_iter_threads,
        make_iter_buffer_size=self._make_iter_buffer_size,
        iter_buffer_size=self._iter_buffer_size,
    )

  def __iter__(self) -> dataset.DatasetIterator:
    return self

  def __next__(self) -> Any:
    return next(self._iterator)

  def get_state(self) -> dict[str, Any]:
    """Returns `total_num_shards` plus a global-shard-index -> state map."""
    return {
        "total_num_shards": self._total_num_shards,
        "ds_iterator_states": {
            shard_idx: it.get_state()
            for shard_idx, it in zip(self._shard_indices, self._ds_iterators)
        },
    }

  def set_state(self, state: dict[str, Any]):
    """Restores the per-shard iterator states saved by `get_state`."""
    for shard_idx, shard_state in state["ds_iterator_states"].items():
      # Map the global shard index to its position in `self._ds_iterators`.
      # `list.index` also rejects (with ValueError) shard indices that this
      # host does not own, instead of silently updating the wrong iterator.
      local_idx = self._shard_indices.index(shard_idx)
      self._ds_iterators[local_idx].set_state(shard_state)
    # Reset the interleave iterator if it was already created.
    self.__dict__.pop("_iterator", None)

  def update_shard_iterator_state(
      self, shard_index: int, state: dict[str, Any]
  ):
    """Sets the state of the iterator owning the global `shard_index`."""
    if shard_index not in self._shard_indices:
      # This should never happen.
      raise ValueError(
          f"Shard index {shard_index} is not in the shard indices"
          f" {self._shard_indices}."
      )
    local_idx = self._shard_indices.index(shard_index)
    self._ds_iterators[local_idx].set_state(state)
    # NOTE(review): unlike set_state, this does not reset the cached
    # interleave iterator; callers are expected to restore before iterating —
    # confirm.
+ expected=[[0, 1, 2], [5, 6, 7], [10, 11, 12]], + ), + ) + def test_no_sharding_produces_correct_elements( + self, shard_options, global_batch_size, expected + ): + ds = [ + # 3 shards, each with 5 elements. + dataset.MapDataset.range(i * 5, (i + 1) * 5).to_iter_dataset() + for i in range(3) + ] + it = elastic_iterator.ElasticIterDatasetIterator( + ds, shard_options, global_batch_size=global_batch_size + ) + actual = list(it) + self.assertLen(actual, len(expected)) + for actual_batch, expected_batch in zip(actual, expected): + np.testing.assert_equal(actual_batch, expected_batch) + + @parameterized.parameters( + dict( + shard_options=sharding.ShardOptions(shard_index=0, shard_count=2), + global_batch_size=1, + expected=[0, 2, 4, 6, 8], + ), + dict( + shard_options=sharding.ShardOptions(shard_index=1, shard_count=2), + global_batch_size=1, + expected=[1, 3, 5, 7, 9], + ), + dict( + shard_options=sharding.ShardOptions(shard_index=0, shard_count=2), + global_batch_size=2, + expected=[[0, 4], [2, 6]], + ), + ) + def test_sharding_produces_correct_elements( + self, shard_options, global_batch_size, expected + ): + ds = [ + # 4 shards, 0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7] + dataset.MapDataset.range(i, 10, 4).to_iter_dataset() + for i in range(4) + ] + it = elastic_iterator.ElasticIterDatasetIterator( + ds, + shard_options, + global_batch_size=global_batch_size, + cycle_length=2, + ) + actual = list(it) + self.assertLen(actual, len(expected)) + for actual_batch, expected_batch in zip(actual, expected): + np.testing.assert_equal(actual_batch, expected_batch) + + def test_checkpointing_no_change(self): + ds = [ + dataset.MapDataset.range(i, 100, 25).to_iter_dataset() + for i in range(25) + ] + it = elastic_iterator.ElasticIterDatasetIterator( + ds, + shard_options=sharding.ShardOptions(shard_index=2, shard_count=4), + global_batch_size=3, + ) + test_util.assert_equal_output_after_checkpoint(it) + + if __name__ == "__main__": absltest.main() diff --git 
a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py index 3539899bb..2ee9e6ede 100644 --- a/grain/_src/python/dataset/transformations/interleave.py +++ b/grain/_src/python/dataset/transformations/interleave.py @@ -33,7 +33,10 @@ class InterleaveDatasetIterator(dataset.DatasetIterator[T]): def __init__( self, - datasets: Sequence[dataset.IterDataset[T] | dataset.MapDataset[T]], + datasets: ( + Sequence[dataset.IterDataset[T] | dataset.MapDataset[T]] + | Sequence[dataset.DatasetIterator[T]] + ), cycle_length: int, num_make_iter_threads: int = 1, make_iter_buffer_size: int = 1, @@ -46,30 +49,44 @@ def __init__( self._num_make_iter_threads = num_make_iter_threads self._make_iter_buffer_size = make_iter_buffer_size self._iter_buffer_size = iter_buffer_size - self._prefetch_ds_iter = ( - dataset.MapDataset.source(datasets) - .map( - functools.partial( - _add_prefetch_and_make_iterator, - # We use weakref to avoid a circular reference. The - # _InterleaveDatasetIterator holds a reference to the - # prefetch iterator in `self._prefetch_ds_iter`. - # The call to `_add_prefetch_and_make_iterator` (and the - # partial object) would hold a reference to the - # _InterleaveDatasetIterator. This would prolong its lifetime - # leading to increased resource usage. - interleave_iterator=weakref.ref(self), - start_prefetch=True, - ) - ) - .to_iter_dataset( - grain_options.ReadOptions( - num_threads=self._num_make_iter_threads, - prefetch_buffer_size=self._make_iter_buffer_size, - ) - ) - .__iter__() - ) + if datasets and isinstance(datasets[0], dataset.DatasetIterator): + # If the input is a sequence of iterators, create a prefetch iterator + # directly from the iterators. 
+ self._prefetch_ds_iter = ( + dataset.MapDataset.source(datasets) + .to_iter_dataset( + grain_options.ReadOptions( + num_threads=self._num_make_iter_threads, + prefetch_buffer_size=self._make_iter_buffer_size, + ) + ) + .__iter__() + ) + else: + self._prefetch_ds_iter = ( + dataset.MapDataset.source(datasets) + .map( + functools.partial( + _add_prefetch_and_make_iterator, + # We use weakref to avoid a circular reference. The + # InterleaveDatasetIterator holds a reference to the + # prefetch iterator in `self._prefetch_ds_iter`. + # The call to `_add_prefetch_and_make_iterator` (and the + # partial object) would hold a reference to the + # InterleaveDatasetIterator. This would prolong its lifetime + # leading to increased resource usage. + interleave_iterator=weakref.ref(self), + start_prefetch=True, + ) + ) + .to_iter_dataset( + grain_options.ReadOptions( + num_threads=self._num_make_iter_threads, + prefetch_buffer_size=self._make_iter_buffer_size, + ) + ) + .__iter__() + ) self._cycle_length: int = min(cycle_length, len(datasets)) self._next_index_in_cycle: int = 0 self._next_index_in_datasets: int = 0