diff --git a/grain/_src/python/dataset/transformations/BUILD b/grain/_src/python/dataset/transformations/BUILD
index 21ddb6a47..c6ea98fcc 100644
--- a/grain/_src/python/dataset/transformations/BUILD
+++ b/grain/_src/python/dataset/transformations/BUILD
@@ -223,6 +223,7 @@ py_test(
     shard_count = 5,
     srcs_version = "PY3",
     deps = [
+        "//grain/_src/python:options",
         "//grain/_src/python/dataset",
         "//grain/_src/python/testing:experimental",
         "@abseil-py//absl/testing:absltest",
diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py
index 3539899bb..9a387ae07 100644
--- a/grain/_src/python/dataset/transformations/interleave.py
+++ b/grain/_src/python/dataset/transformations/interleave.py
@@ -82,6 +82,10 @@ def __init__(
     ] * self._cycle_length
     self._started = False
     self._parent_stats: list[stats.Stats | None] = [None] * self._cycle_length
+    self._keep_iterators_after_stop_iteration = False
+    self._exhausted_iterators: list[
+        tuple[int, dataset.DatasetIterator[T]] | None
+    ] = [None] * self._cycle_length
 
   @stats.record_next_duration_if_output
   @stats.trace_input_pipeline_next(stage_category=stats.IPL_CAT_PREPROCESSING)
@@ -119,6 +123,11 @@ def __next__(self) -> T:
           self._exhausted_iterator_state[self._next_index_in_cycle] = (
               iterator_to_use.get_state()
           )
+          if self._keep_iterators_after_stop_iteration:
+            self._exhausted_iterators[self._next_index_in_cycle] = (
+                self._iterators_in_use_indices[self._next_index_in_cycle],
+                iterator_to_use,
+            )
           self._iterators_in_use[self._next_index_in_cycle] = None
           self._next_index_in_cycle = (
               self._next_index_in_cycle + 1
@@ -195,12 +204,22 @@ def set_state(self, state):
           or iterator is None
       ):
         # The iterator currently in use is either exhausted or corresponds to
-        # a different dataset. We need to create a new iterator.
-        iterator = _add_prefetch_and_make_iterator(
-            self._datasets[index_in_datasets],
-            interleave_iterator=weakref.ref(self),
-            start_prefetch=False,
-        )
+        # a different dataset. Reuse a matching entry from the exhausted
+        # iterators cache if possible; otherwise create a new iterator.
+        if (
+            self._keep_iterators_after_stop_iteration
+            and self._exhausted_iterators[index_in_cycle] is not None
+            and self._exhausted_iterators[index_in_cycle][0]
+            == index_in_datasets
+        ):
+          _, iterator = self._exhausted_iterators[index_in_cycle]
+          self._exhausted_iterators[index_in_cycle] = None
+        else:
+          iterator = _add_prefetch_and_make_iterator(
+              self._datasets[index_in_datasets],
+              interleave_iterator=weakref.ref(self),
+              start_prefetch=False,
+          )
         iterator.set_state(it_state)
         self._iterators_in_use[index_in_cycle] = iterator
       else:
@@ -238,6 +257,15 @@ def _set_next_index(self, index: int) -> None:
           " more than one dataset."
       )
 
+  def set_keep_iterators_after_stop_iteration(
+      self, keep_iterators: bool
+  ) -> None:
+    # Determines whether the iterators should be kept alive after
+    # StopIteration is raised by `__next__`. This is used by
+    # `RepeatDatasetIterator` to allow for resetting the iterator state and
+    # continuing iteration without recreating the iterators.
+    self._keep_iterators_after_stop_iteration = keep_iterators
+
   def close(self) -> None:
     """Closes the iterator and shuts down the iterator prefetching."""
     if self._closed:
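
Note on the interleave change: it implements a park-and-reuse pattern. When the iterator in a cycle slot raises StopIteration with keep-alive enabled, __next__ parks it together with its dataset index instead of dropping it, and set_state later hands the parked iterator back only when the slot is being restored to the same dataset. A minimal standalone sketch of that pattern follows; the Slot class and make_iterator factory are illustrative names, not Grain's API:

    from typing import Callable, Iterator, Optional, Tuple

    class Slot:
      """One interleave cycle slot that can park an exhausted iterator."""

      def __init__(self, make_iterator: Callable[[int], Iterator]):
        # Factory standing in for _add_prefetch_and_make_iterator.
        self._make_iterator = make_iterator
        self._parked: Optional[Tuple[int, Iterator]] = None

      def park(self, dataset_index: int, it: Iterator) -> None:
        # Mirrors __next__ above: on StopIteration, keep the iterator around.
        self._parked = (dataset_index, it)

      def restore(self, dataset_index: int) -> Iterator:
        # Mirrors set_state above: reuse the parked iterator only if it
        # belongs to the same dataset, otherwise build a fresh one.
        if self._parked is not None and self._parked[0] == dataset_index:
          _, it = self._parked
          self._parked = None
          return it
        return self._make_iterator(dataset_index)
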
diff --git a/grain/_src/python/dataset/transformations/process_prefetch.py b/grain/_src/python/dataset/transformations/process_prefetch.py
index 156eba673..7023c2db8 100644
--- a/grain/_src/python/dataset/transformations/process_prefetch.py
+++ b/grain/_src/python/dataset/transformations/process_prefetch.py
@@ -300,6 +300,7 @@ def __init__(
     self._iter_parent = parent
     self._buffer_size = buffer_size
     self._worker_init_fn = worker_init_fn
+    self._keep_workers_after_stop_iteration = False
     # Since the parent iterator is going to be created in each subprocess, and
     # the options are propagated during iterator creation, we need to manually
     # propagate them.
@@ -451,7 +452,11 @@ def __next__(self):
         # Unlink shared memory for the discarded element.
         shared_memory_array.unlink_shm(element)
       if err is not None:
-        self._stop_prefetch()
+        if (
+            not isinstance(err, StopIteration)
+            or not self._keep_workers_after_stop_iteration
+        ):
+          self._stop_prefetch()
         self._exhausted = True
         raise err
       self._state = state
@@ -549,6 +554,13 @@ def _set_next_index(self, next_index: int):
     self._exhausted = False
     self._state = None
 
+  def set_keep_workers_after_stop_iteration(self, keep_workers: bool):
+    # Determines whether the worker processes should be kept alive after
+    # StopIteration is raised by `__next__`. This is used by
+    # `RepeatDatasetIterator` to allow for resetting the iterator state and
+    # continuing iteration without recreating the worker processes.
+    self._keep_workers_after_stop_iteration = keep_workers
+
   def __str__(self) -> str:
     return f"ProcessPrefetchDatasetIterator(buffer_size={self._buffer_size})"
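
The effect of the process_prefetch change is that exhaustion no longer implies teardown: StopIteration with keep-alive enabled marks the iterator exhausted but leaves the worker pool running, so a later state reset can resume iteration against the same processes. A toy model of that contract, with hypothetical names (KeepAlivePrefetcher, rewind, _teardown) and the expensive resources elided:

    class KeepAlivePrefetcher:
      """Toy model: exhaustion no longer implies resource teardown."""

      def __init__(self, data):
        self._data = list(data)
        self._pos = 0
        self._keep_workers_after_stop_iteration = False

      def set_keep_workers_after_stop_iteration(self, keep: bool) -> None:
        self._keep_workers_after_stop_iteration = keep

      def __next__(self):
        if self._pos >= len(self._data):
          if not self._keep_workers_after_stop_iteration:
            self._teardown()  # the expensive path the diff avoids
          raise StopIteration
        value = self._data[self._pos]
        self._pos += 1
        return value

      def rewind(self) -> None:
        # Stands in for _set_next_index: resets position so that a repeat
        # transform can start the next epoch against live workers.
        self._pos = 0

      def _teardown(self) -> None:
        print("stopping workers")  # placeholder for _stop_prefetch
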
diff --git a/grain/_src/python/dataset/transformations/repeat.py b/grain/_src/python/dataset/transformations/repeat.py
index 02abbe156..4921cab46 100644
--- a/grain/_src/python/dataset/transformations/repeat.py
+++ b/grain/_src/python/dataset/transformations/repeat.py
@@ -17,6 +17,8 @@
 from grain._src.python.dataset import dataset
 from grain._src.python.dataset import stats
+from grain._src.python.dataset.transformations import interleave
+from grain._src.python.dataset.transformations import process_prefetch
 
 T = TypeVar("T")
@@ -94,6 +96,18 @@ def __init__(
     self._num_epochs = num_epochs
     self._epoch = 0
     self._parent_starting_state = self._parent.get_state()
+    # Walk the iterator tree and ensure ProcessPrefetchDatasetIterator and
+    # InterleaveDatasetIterator nodes are not reset on StopIteration. This
+    # avoids recreating the worker processes on each epoch.
+    to_visit = [self]
+    while to_visit:
+      node = to_visit.pop(0)
+      if isinstance(node, process_prefetch._ProcessPrefetchDatasetIterator):  # pylint: disable=protected-access
+        node.set_keep_workers_after_stop_iteration(True)
+      if isinstance(node, interleave.InterleaveDatasetIterator):
+        node.set_keep_iterators_after_stop_iteration(True)
+        to_visit.extend(n for n in node._iterators_in_use if n is not None)  # pylint: disable=protected-access
+      to_visit.extend(n for n in node._parents)
 
   @stats.record_next_duration_if_output
   def __next__(self):
diff --git a/grain/_src/python/dataset/transformations/repeat_test.py b/grain/_src/python/dataset/transformations/repeat_test.py
index d94fc4418..c91b5b56b 100644
--- a/grain/_src/python/dataset/transformations/repeat_test.py
+++ b/grain/_src/python/dataset/transformations/repeat_test.py
@@ -13,11 +13,14 @@
 # limitations under the License.
 """Tests for repeat transformation."""
 
+import os
 import sys
 
 from absl.testing import absltest
 from absl.testing import parameterized
+from grain._src.python import options as grain_options
 from grain._src.python.dataset import dataset
+from grain._src.python.dataset.transformations import process_prefetch
 from grain._src.python.dataset.transformations import repeat
 from grain._src.python.testing import experimental as testing
 import numpy as np
@@ -195,6 +199,29 @@ def test_element_spec(self):
     self.assertEqual(spec.dtype, np.int64)
     self.assertEqual(spec.shape, ())
 
+  def test_repeat_after_mp_prefetch(self):
+    ds = dataset.MapDataset.range(20).to_iter_dataset()
+    ds = ds.mp_prefetch(
+        grain_options.MultiprocessingOptions(
+            num_workers=3,
+            per_worker_buffer_size=2,
+        )
+    )
+    ds = repeat.RepeatIterDataset(ds, num_epochs=3)
+    self.assertEqual(list(ds), list(range(20)) * 3)
+
+  def test_repeat_after_mp_prefetch_does_not_restart_workers(self):
+    ds = dataset.MapDataset.range(20).to_iter_dataset()
+    ds = ds.map(lambda x: os.getpid())
+    ds = process_prefetch.multiprocess_prefetch(
+        ds,
+        num_workers=3,
+    )
+    ds = repeat.RepeatIterDataset(ds, num_epochs=3)
+    results = list(ds)
+    self.assertLen(results, 60)
+    self.assertLen(set(results), 3)
+
 
 if __name__ == "__main__":
   absltest.main()
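
Outside the test harness, the guarantee pinned down by test_repeat_after_mp_prefetch_does_not_restart_workers can be checked with roughly the following script. It reuses the internal import paths from the tests above; treat it as a sketch, not the recommended public API:

    import os

    from grain._src.python.dataset import dataset
    from grain._src.python.dataset.transformations import process_prefetch
    from grain._src.python.dataset.transformations import repeat


    def main():
      ds = dataset.MapDataset.range(20).to_iter_dataset()
      ds = ds.map(lambda x: os.getpid())  # tag elements with the worker pid
      ds = process_prefetch.multiprocess_prefetch(ds, num_workers=3)
      ds = repeat.RepeatIterDataset(ds, num_epochs=3)
      pids = list(ds)
      assert len(pids) == 60  # 20 elements x 3 epochs
      assert len(set(pids)) == 3  # the same 3 workers served every epoch


    if __name__ == "__main__":
      main()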