From ee6e7e36d83a89b7365329cedb01380e4c97291a Mon Sep 17 00:00:00 2001
From: Nithin Tatikonda
Date: Sat, 7 Feb 2026 11:37:11 -0800
Subject: [PATCH] Internal

PiperOrigin-RevId: 866945487
---
 .../dataset/transformations/prefetch.py       |  1 +
 .../transformations/process_prefetch.py       | 32 +++++----
 .../transformations/process_prefetch_test.py  | 69 +++++++++++++++++++
 3 files changed, 87 insertions(+), 15 deletions(-)

diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py
index f31b26f3b..cc46870c3 100644
--- a/grain/_src/python/dataset/transformations/prefetch.py
+++ b/grain/_src/python/dataset/transformations/prefetch.py
@@ -619,6 +619,7 @@ def _get_next_index(self) -> int:
     # after setting the state to the point before all current buffer elements
     # were produced from the parent iterator.
     state = self.get_state()
+    self._stop_prefetch()
     self._maybe_nonnative_parent.set_state(state)
     self._next_index = dataset.get_next_index(self._maybe_nonnative_parent)
     return self._next_index
diff --git a/grain/_src/python/dataset/transformations/process_prefetch.py b/grain/_src/python/dataset/transformations/process_prefetch.py
index 156eba673..2d9808932 100644
--- a/grain/_src/python/dataset/transformations/process_prefetch.py
+++ b/grain/_src/python/dataset/transformations/process_prefetch.py
@@ -209,23 +209,25 @@ def _put_dataset_elements_in_buffer(
     next_index: int | None = 0
     while not should_stop.is_set():
       if set_state_request_count.value > 0:
+        new_state_or_index = None
         with set_state_request_count.get_lock():
           if set_state_request_count.value > 0:
             set_state_request_count.value -= 1
-            parent_exhausted = False
-            if not multiprocessing_common.add_element_to_queue(  # pytype: disable=wrong-arg-types
-                (_SetStateIsDone(), None, None, None),
-                buffer,
-                should_stop.is_set,
-            ):
-              continue
             new_state_or_index = set_state_queue.get()
-            if isinstance(new_state_or_index, int):
-              dataset.set_next_index(it, new_state_or_index)
-              next_index = new_state_or_index
-            else:
-              it.set_state(new_state_or_index)
-              next_index = None
+            parent_exhausted = False
+        if new_state_or_index is not None:
+          if not multiprocessing_common.add_element_to_queue(  # pytype: disable=wrong-arg-types
+              (_SetStateIsDone(), None, None, None),
+              buffer,
+              should_stop.is_set,
+          ):
+            continue
+          if isinstance(new_state_or_index, int):
+            dataset.set_next_index(it, new_state_or_index)
+            next_index = new_state_or_index
+          else:
+            it.set_state(new_state_or_index)
+            next_index = None
       if parent_exhausted:
         # Avoid busy-waiting when parent iterator is exhausted due to an
         # error. Wait until set_state_event or should_stop is set.
@@ -244,6 +246,8 @@
       # __next__ method.
       if not it._stats._config.is_prefetch:  # pylint: disable=protected-access
         it._stats.record_bytes_produced(element)  # pylint: disable=protected-access
+      if next_index is not None:
+        next_index += 1
       if not multiprocessing_common.add_element_to_queue(  # pytype: disable=wrong-arg-types
           (element, it.get_state(), next_index, None),
           buffer,
           should_stop.is_set,
       ):
         # should_stop event was set. The element may contain a shared memory
         # block reference that has to be cleaned up.
         shared_memory_array.unlink_shm(element)
-      if next_index is not None:
-        next_index += 1
   except Exception as e:  # pylint: disable=broad-except
     _clear_queue_and_maybe_unlink_shm(buffer)
     _clear_queue_and_maybe_unlink_shm(set_state_queue)
diff --git a/grain/_src/python/dataset/transformations/process_prefetch_test.py b/grain/_src/python/dataset/transformations/process_prefetch_test.py
index 6d4fe45d8..935d10007 100644
--- a/grain/_src/python/dataset/transformations/process_prefetch_test.py
+++ b/grain/_src/python/dataset/transformations/process_prefetch_test.py
@@ -330,6 +330,75 @@ def map_fn(x):
     self.assertEqual(next(ds_iter), 2)
     ds_iter.close()
 
+  def test_get_next_index(self):
+    ds = process_prefetch.ProcessPrefetchIterDataset(
+        dataset.MapDataset.range(10).to_iter_dataset(),
+        buffer_size=1,
+    )
+    ds_iter = ds.__iter__()
+    for i in range(10):
+      self.assertEqual(dataset.get_next_index(ds_iter), i)
+      next(ds_iter)
+
+  def test_set_next_index(self):
+    ds = process_prefetch.ProcessPrefetchIterDataset(
+        dataset.MapDataset.range(10).to_iter_dataset(),
+        buffer_size=1,
+    )
+    ds_iter = ds.__iter__()
+    for i in reversed(range(10)):
+      dataset.set_next_index(ds_iter, i)
+      self.assertEqual(next(ds_iter), i)
+
+  def test_alternate_set_state_and_get_next_index(self):
+    ds = process_prefetch.ProcessPrefetchIterDataset(
+        dataset.MapDataset.range(20).to_iter_dataset(),
+        buffer_size=1,
+    )
+    ds_iter = ds.__iter__()
+    next(ds_iter)
+    state1 = ds_iter.get_state()
+    index1 = dataset.get_next_index(ds_iter)
+    next(ds_iter)
+    next(ds_iter)
+    index2 = dataset.get_next_index(ds_iter)
+    next(ds_iter)
+
+    self.assertEqual(index1, 1)
+    self.assertEqual(index2, 3)
+
+    ds_iter.set_state(state1)
+    self.assertEqual(dataset.get_next_index(ds_iter), 1)
+    next(ds_iter)
+    ds_iter.set_state(state1)
+    next(ds_iter)
+    next(ds_iter)
+    self.assertEqual(dataset.get_next_index(ds_iter), 3)
+
+  def test_alternate_set_next_index_and_get_state(self):
+    ds = process_prefetch.ProcessPrefetchIterDataset(
+        dataset.MapDataset.range(20).to_iter_dataset(),
+        buffer_size=1,
+    )
+    ds_iter = ds.__iter__()
+    next(ds_iter)
+    state1 = ds_iter.get_state()
+    index1 = dataset.get_next_index(ds_iter)
+    next(ds_iter)
+    next(ds_iter)
+    state2 = ds_iter.get_state()
+    next(ds_iter)
+
+    self.assertEqual(index1, 1)
+
+    dataset.set_next_index(ds_iter, index1)
+    self.assertEqual(ds_iter.get_state(), state1)
+    next(ds_iter)
+    dataset.set_next_index(ds_iter, index1)
+    next(ds_iter)
+    next(ds_iter)
+    self.assertEqual(ds_iter.get_state(), state2)
+
 
 @dataclasses.dataclass(frozen=True)
 class FilterAllElements(transforms.Filter):
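
For context on the one-line prefetch.py change: rewinding the shared parent iterator while a background prefetcher is still pulling from it is a race, so the prefetcher has to be stopped first and its buffered elements discarded. Below is a minimal sketch of that ordering using plain threading; the Prefetcher class, its rewind() method, and the buffer handling are invented for illustration and are not grain's actual multiprocessing implementation.

import queue
import threading
from typing import Callable, Iterator


class Prefetcher:
  """Background thread that prefetches elements from a shared iterator."""

  def __init__(self, it: Iterator[int], buffer_size: int = 4):
    self._it = it
    self._buffer = queue.Queue(maxsize=buffer_size)
    self._start()

  def _start(self):
    self._stop = threading.Event()
    self._thread = threading.Thread(target=self._produce, daemon=True)
    self._thread.start()

  def _produce(self):
    # Pull from the parent iterator until asked to stop.
    while not self._stop.is_set():
      try:
        element = next(self._it)
      except StopIteration:
        return
      # Re-check the stop flag while waiting for buffer space so that
      # stop() cannot deadlock on a full buffer.
      while not self._stop.is_set():
        try:
          self._buffer.put(element, timeout=0.1)
          break
        except queue.Full:
          continue

  def stop(self):
    # Halt the producer, then drop buffered elements that are now stale.
    self._stop.set()
    self._thread.join()
    while not self._buffer.empty():
      self._buffer.get_nowait()

  def rewind(self, make_iterator: Callable[[], Iterator[int]]):
    # Same ordering as the patch: stop prefetching *before* resetting the
    # shared parent iterator, then resume from the new position.
    self.stop()
    self._it = make_iterator()
    self._start()

  def get(self) -> int:
    return self._buffer.get()


if __name__ == "__main__":
  prefetcher = Prefetcher(iter(range(10)))
  print(prefetcher.get(), prefetcher.get())  # 0 1
  prefetcher.rewind(lambda: iter(range(10)))
  print(prefetcher.get())  # 0 again, since the parent was rewound safely
  prefetcher.stop()

The process_prefetch.py hunk applies a similar discipline across the process boundary: the set-state request is consumed while holding the counter's lock, and the _SetStateIsDone acknowledgment plus the actual state or index update happen only after a request was really taken.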