diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py
index 3539899bb..c2af32110 100644
--- a/grain/_src/python/dataset/transformations/interleave.py
+++ b/grain/_src/python/dataset/transformations/interleave.py
@@ -199,7 +199,7 @@ def set_state(self, state):
         iterator = _add_prefetch_and_make_iterator(
             self._datasets[index_in_datasets],
             interleave_iterator=weakref.ref(self),
-            start_prefetch=False,
+            start_prefetch=self._started,
         )
         iterator.set_state(it_state)
         self._iterators_in_use[index_in_cycle] = iterator
@@ -238,6 +238,15 @@ def _set_next_index(self, index: int) -> None:
         " more than one dataset."
     )
 
+  def start_prefetch(self) -> None:
+    """Starts prefetching in `_prefetch_ds_iter` and all iterators in use."""
+    self._prefetch_ds_iter.start_prefetch()
+    for iterator in self._iterators_in_use:
+      if iterator is not None:
+        iterator.start_prefetch()
+    # Remembered so that iterators recreated by `set_state` also prefetch.
+    self._started = True
+
   def close(self) -> None:
     """Closes the iterator and shuts down the iterator prefetching."""
     if self._closed:
diff --git a/grain/_src/python/dataset/transformations/interleave_test.py b/grain/_src/python/dataset/transformations/interleave_test.py
index c5d647307..af2b0eff0 100644
--- a/grain/_src/python/dataset/transformations/interleave_test.py
+++ b/grain/_src/python/dataset/transformations/interleave_test.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import time
+
 from absl.testing import absltest
 from absl.testing import flagsaver
 from absl.testing import parameterized
@@ -291,6 +293,26 @@ def test_set_next_index_with_multiple_datasets(self):
     ):
       dataset.set_next_index(ds_iter, 0)
 
+  def test_start_prefetch(self):
+    count = 0
+
+    def map_fn(x):
+      nonlocal count
+      count += 1
+      return x
+
+    ds = dataset.MapDataset.range(10).to_iter_dataset()
+    ds = ds.map(map_fn)
+    ds = interleave.InterleaveIterDataset([ds], cycle_length=1)
+    ds_iter = ds.__iter__()
+    ds_iter.start_prefetch()
+    # `next` is never called, so `map_fn` can only run through prefetching.
+    # Bound the wait so a regression fails instead of hanging the test.
+    deadline = time.monotonic() + 60
+    while count == 0:
+      self.assertLess(time.monotonic(), deadline, "Prefetch did not start.")
+      time.sleep(0.1)
+
 
 if __name__ == "__main__":
   absltest.main()
diff --git a/grain/_src/python/dataset/transformations/process_prefetch_test.py b/grain/_src/python/dataset/transformations/process_prefetch_test.py
index 6d4fe45d8..fa033302c 100644
--- a/grain/_src/python/dataset/transformations/process_prefetch_test.py
+++ b/grain/_src/python/dataset/transformations/process_prefetch_test.py
@@ -48,6 +48,18 @@ def filter(self, element: int) -> bool:
     return bool(element % 2)
 
 
+@dataclasses.dataclass(frozen=True)
+class WriteMarker(transforms.Map):
+  """Map transform that writes each element to `path` and returns it."""
+
+  path: str
+
+  def map(self, element: int) -> int:
+    with open(self.path, 'w') as f:
+      f.write(str(element))
+    return element
+
+
 class ProcessPrefetchIterDatasetTest(parameterized.TestCase):
 
   def setUp(self):
@@ -782,10 +794,24 @@ def map(self, features):
     if not start_prefetch_calls:
       self.assertGreater(time_to_fetch, 1)
 
+  def test_start_prefetch_prefetches_without_next_call(self):
+    marker_file = os.path.join(self.create_tempdir().full_path, 'marker')
+    ds = dataset.MapDataset.range(10)
+    ds = ds.map(WriteMarker(marker_file))
+    ds = ds.to_iter_dataset()
+    ds = process_prefetch.multiprocess_prefetch(ds, num_workers=1)
+    it = ds.__iter__()
+    it.start_prefetch()
+
+    # Wait for prefetch to happen; fail rather than hang if it never does.
+    deadline = time.monotonic() + 120
+    while not os.path.exists(marker_file):
+      self.assertLess(time.monotonic(), deadline, 'Prefetch did not start.')
+      time.sleep(0.5)
+
   @parameterized.parameters(0, 0.5, 30)
   def test_prefetch_but_no_read(self, sleep_s):
     ds = dataset.MapDataset.source([1, 2, 3]).repeat()
-    ds = ds.filter(lambda x: x > 3)
     ds = ds.to_iter_dataset()
     ds = process_prefetch.multiprocess_prefetch(ds, num_workers=1)
     it = ds.__iter__()