grain/_src/python/dataset/transformations/prefetch.py (1 change: 1 addition & 0 deletions)
@@ -619,6 +619,7 @@ def _get_next_index(self) -> int:
# after setting the state to the point before all current buffer elements
# were produced from the parent iterator.
state = self.get_state()
self._stop_prefetch()
self._maybe_nonnative_parent.set_state(state)
self._next_index = dataset.get_next_index(self._maybe_nonnative_parent)
return self._next_index
grain/_src/python/dataset/transformations/process_prefetch.py (32 changes: 17 additions & 15 deletions)
@@ -209,23 +209,25 @@ def _put_dataset_elements_in_buffer(
next_index: int | None = 0
while not should_stop.is_set():
if set_state_request_count.value > 0:
new_state_or_index = None
with set_state_request_count.get_lock():
if set_state_request_count.value > 0:
set_state_request_count.value -= 1
parent_exhausted = False
if not multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types
(_SetStateIsDone(), None, None, None),
buffer,
should_stop.is_set,
):
continue
new_state_or_index = set_state_queue.get()
if isinstance(new_state_or_index, int):
dataset.set_next_index(it, new_state_or_index)
next_index = new_state_or_index
else:
it.set_state(new_state_or_index)
next_index = None
parent_exhausted = False
if new_state_or_index is not None:
if not multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types
(_SetStateIsDone(), None, None, None),
buffer,
should_stop.is_set,
):
continue
if isinstance(new_state_or_index, int):
dataset.set_next_index(it, new_state_or_index)
next_index = new_state_or_index
else:
it.set_state(new_state_or_index)
next_index = None
if parent_exhausted:
# Avoid busy-waiting when parent iterator is exhausted due to an
# error. Wait until set_state_event or should_stop is set.
@@ -244,6 +246,8 @@ def _put_dataset_elements_in_buffer(
# __next__ method.
if not it._stats._config.is_prefetch: # pylint: disable=protected-access
it._stats.record_bytes_produced(element) # pylint: disable=protected-access
if next_index is not None:
next_index += 1
if not multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types
(element, it.get_state(), next_index, None),
buffer,
@@ -253,8 +257,6 @@
# should_stop event was set. The element may contain a shared memory
# block reference that has to be cleaned up.
shared_memory_array.unlink_shm(element)
if next_index is not None:
next_index += 1
except Exception as e: # pylint: disable=broad-except
_clear_queue_and_maybe_unlink_shm(buffer)
_clear_queue_and_maybe_unlink_shm(set_state_queue)
grain/_src/python/dataset/transformations/process_prefetch_test.py (69 changes: 69 additions & 0 deletions)
@@ -330,6 +330,75 @@ def map_fn(x):
self.assertEqual(next(ds_iter), 2)
ds_iter.close()

def test_get_next_index(self):
ds = process_prefetch.ProcessPrefetchIterDataset(
dataset.MapDataset.range(10).to_iter_dataset(),
buffer_size=1,
)
ds_iter = ds.__iter__()
for i in range(10):
self.assertEqual(dataset.get_next_index(ds_iter), i)
next(ds_iter)

def test_set_next_index(self):
ds = process_prefetch.ProcessPrefetchIterDataset(
dataset.MapDataset.range(10).to_iter_dataset(),
buffer_size=1,
)
ds_iter = ds.__iter__()
for i in reversed(range(10)):
dataset.set_next_index(ds_iter, i)
self.assertEqual(next(ds_iter), i)

def test_alternate_set_state_and_get_next_index(self):
ds = process_prefetch.ProcessPrefetchIterDataset(
dataset.MapDataset.range(20).to_iter_dataset(),
buffer_size=1,
)
ds_iter = ds.__iter__()
next(ds_iter)
state1 = ds_iter.get_state()
index1 = dataset.get_next_index(ds_iter)
next(ds_iter)
next(ds_iter)
index2 = dataset.get_next_index(ds_iter)
next(ds_iter)

self.assertEqual(index1, 1)
self.assertEqual(index2, 3)

ds_iter.set_state(state1)
self.assertEqual(dataset.get_next_index(ds_iter), 1)
next(ds_iter)
ds_iter.set_state(state1)
next(ds_iter)
next(ds_iter)
self.assertEqual(dataset.get_next_index(ds_iter), 3)

def test_alternate_set_next_index_and_get_state(self):
ds = process_prefetch.ProcessPrefetchIterDataset(
dataset.MapDataset.range(20).to_iter_dataset(),
buffer_size=1,
)
ds_iter = ds.__iter__()
next(ds_iter)
state1 = ds_iter.get_state()
index1 = dataset.get_next_index(ds_iter)
next(ds_iter)
next(ds_iter)
state2 = ds_iter.get_state()
next(ds_iter)

self.assertEqual(index1, 1)

dataset.set_next_index(ds_iter, index1)
self.assertEqual(ds_iter.get_state(), state1)
next(ds_iter)
dataset.set_next_index(ds_iter, index1)
next(ds_iter)
next(ds_iter)
self.assertEqual(ds_iter.get_state(), state2)


@dataclasses.dataclass(frozen=True)
class FilterAllElements(transforms.Filter):