Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions grain/_src/python/dataset/transformations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ py_test(
shard_count = 5,
srcs_version = "PY3",
deps = [
"//grain/_src/python:options",
"//grain/_src/python/dataset",
"//grain/_src/python/testing:experimental",
"@abseil-py//absl/testing:absltest",
Expand Down
40 changes: 34 additions & 6 deletions grain/_src/python/dataset/transformations/interleave.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def __init__(
] * self._cycle_length
self._started = False
self._parent_stats: list[stats.Stats | None] = [None] * self._cycle_length
self._keep_iterators_after_stop_iteration = False
self._exhausted_iterators: list[
tuple[int, dataset.DatasetIterator[T]] | None
] = [None] * self._cycle_length

@stats.record_next_duration_if_output
@stats.trace_input_pipeline_next(stage_category=stats.IPL_CAT_PREPROCESSING)
Expand Down Expand Up @@ -119,6 +123,11 @@ def __next__(self) -> T:
self._exhausted_iterator_state[self._next_index_in_cycle] = (
iterator_to_use.get_state()
)
if self._keep_iterators_after_stop_iteration:
self._exhausted_iterators[self._next_index_in_cycle] = (
self._iterators_in_use_indices[self._next_index_in_cycle],
iterator_to_use,
)
self._iterators_in_use[self._next_index_in_cycle] = None
self._next_index_in_cycle = (
self._next_index_in_cycle + 1
Expand Down Expand Up @@ -195,12 +204,22 @@ def set_state(self, state):
or iterator is None
):
# The iterator currently in use is either exhausted or corresponds to
# a different dataset. We need to create a new iterator.
iterator = _add_prefetch_and_make_iterator(
self._datasets[index_in_datasets],
interleave_iterator=weakref.ref(self),
start_prefetch=False,
)
# a different dataset. We need to create a new iterator or check the
# exhausted iterators list.
if (
self._keep_iterators_after_stop_iteration
and self._exhausted_iterators[index_in_cycle] is not None
and self._exhausted_iterators[index_in_cycle][0]
== index_in_datasets
):
_, iterator = self._exhausted_iterators[index_in_cycle]
self._exhausted_iterators[index_in_cycle] = None
else:
iterator = _add_prefetch_and_make_iterator(
self._datasets[index_in_datasets],
interleave_iterator=weakref.ref(self),
start_prefetch=False,
)
iterator.set_state(it_state)
self._iterators_in_use[index_in_cycle] = iterator
else:
Expand Down Expand Up @@ -238,6 +257,15 @@ def _set_next_index(self, index: int) -> None:
" more than one dataset."
)

def set_keep_iterators_after_stop_iteration(
    self, keep_iterators: bool
) -> None:
  """Controls whether cycle iterators are kept after `StopIteration`.

  Used by `RepeatDatasetIterator` to allow resetting the iterator state and
  continuing iteration without recreating the underlying iterators.

  Args:
    keep_iterators: If True, an iterator that raises `StopIteration` in
      `__next__` is stashed (together with its dataset index) instead of
      being discarded, so `set_state` can reuse it.
  """
  self._keep_iterators_after_stop_iteration = keep_iterators

def close(self) -> None:
"""Closes the iterator and shuts down the iterator prefetching."""
if self._closed:
Expand Down
14 changes: 13 additions & 1 deletion grain/_src/python/dataset/transformations/process_prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def __init__(
self._iter_parent = parent
self._buffer_size = buffer_size
self._worker_init_fn = worker_init_fn
self._keep_workers_after_stop_iteration = False
# Since the parent iterator is going to be created in each subprocess, and
# the options are propagated during iterator creation, we need to manually
# propagate them.
Expand Down Expand Up @@ -451,7 +452,11 @@ def __next__(self):
# Unlink shared memory for the discarded element.
shared_memory_array.unlink_shm(element)
if err is not None:
self._stop_prefetch()
if (
not isinstance(err, StopIteration)
or not self._keep_workers_after_stop_iteration
):
self._stop_prefetch()
self._exhausted = True
raise err
self._state = state
Expand Down Expand Up @@ -549,6 +554,13 @@ def _set_next_index(self, next_index: int):
self._exhausted = False
self._state = None

def set_keep_workers_after_stop_iteration(self, keep_workers: bool):
  """Controls whether worker processes are kept after `StopIteration`.

  Used by `RepeatDatasetIterator` to allow resetting the iterator state and
  continuing iteration without recreating the worker processes.

  Args:
    keep_workers: If True, prefetching is not stopped when `__next__`
      raises `StopIteration` (other errors still stop the workers).
  """
  self._keep_workers_after_stop_iteration = keep_workers

def __str__(self) -> str:
  """Returns a short debug description including the buffer size."""
  buffer_size = self._buffer_size
  return f"ProcessPrefetchDatasetIterator(buffer_size={buffer_size})"

Expand Down
14 changes: 14 additions & 0 deletions grain/_src/python/dataset/transformations/repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats
from grain._src.python.dataset.transformations import interleave
from grain._src.python.dataset.transformations import process_prefetch

T = TypeVar("T")

Expand Down Expand Up @@ -94,6 +96,18 @@ def __init__(
self._num_epochs = num_epochs
self._epoch = 0
self._parent_starting_state = self._parent.get_state()
# Check for ProcessPrefetchDatasetIterator and InterleaveDatasetIterator and
# ensure processes/iterators are not reset on StopIteration. This is needed
# to avoid recreating the worker processes on each epoch.
to_visit = [self]
while to_visit:
node = to_visit.pop(0)
if isinstance(node, process_prefetch._ProcessPrefetchDatasetIterator): # pylint: disable=protected-access
node.set_keep_workers_after_stop_iteration(True)
if isinstance(node, interleave.InterleaveDatasetIterator):
node.set_keep_iterators_after_stop_iteration(True)
to_visit.extend(n for n in node._iterators_in_use if n is not None) # pylint: disable=protected-access
to_visit.extend(n for n in node._parents)

@stats.record_next_duration_if_output
def __next__(self):
Expand Down
27 changes: 27 additions & 0 deletions grain/_src/python/dataset/transformations/repeat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@
# limitations under the License.
"""Tests for repeat transformation."""

import os
import sys

from absl.testing import absltest
from absl.testing import parameterized
import multiprocessing as mp
from grain._src.python import options as grain_options
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import process_prefetch
from grain._src.python.dataset.transformations import repeat
from grain._src.python.testing import experimental as testing
import numpy as np
Expand Down Expand Up @@ -195,6 +199,29 @@ def test_element_spec(self):
self.assertEqual(spec.dtype, np.int64)
self.assertEqual(spec.shape, ())

def test_repeat_after_mp_prefetch(self):
  """Repeating an mp-prefetched dataset yields all elements of every epoch."""
  mp_options = grain_options.MultiprocessingOptions(
      num_workers=3,
      per_worker_buffer_size=2,
  )
  prefetched = dataset.MapDataset.range(20).to_iter_dataset().mp_prefetch(
      mp_options
  )
  repeated = repeat.RepeatIterDataset(prefetched, num_epochs=3)
  self.assertEqual(list(repeated), list(range(20)) * 3)

def test_repeat_after_mp_prefetch_does_not_restart_workers(self):
  """Workers survive across epochs: 60 elements but only 3 distinct pids."""
  pid_ds = dataset.MapDataset.range(20).to_iter_dataset().map(
      lambda x: os.getpid()
  )
  prefetched = process_prefetch.multiprocess_prefetch(
      pid_ds,
      num_workers=3,
  )
  repeated = repeat.RepeatIterDataset(prefetched, num_epochs=3)
  pids = list(repeated)
  self.assertLen(pids, 60)
  self.assertLen(set(pids), 3)


# Standard absltest entry point: run the tests when executed as a script.
if __name__ == "__main__":
  absltest.main()