diff --git a/CHANGELOG.md b/CHANGELOG.md index 1321a77dd..7437d241a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change and advance a `grain.DatasetIterator` to the given produced element index. * Switches to multithreading instead of multiprocessing in `IterDataset.mp_prefetch` when free-threaded Python is detected. + * `grain.DataLoaderIterator` can now asynchronously start processing elements + in background with `start_prefetch` call. * Breaking changes: * Custom implementations of `RandomAccessDataSource` should accept `int` diff --git a/grain/_src/python/data_loader.py b/grain/_src/python/data_loader.py index bee44d937..f6caafaee 100644 --- a/grain/_src/python/data_loader.py +++ b/grain/_src/python/data_loader.py @@ -561,6 +561,20 @@ def set_state(self, state: bytes): self._data_loader._validate_state(state) # pylint: disable=protected-access self._iterator.set_state(state) + def start_prefetch(self): + """Starts processing elements asynchronously in the background. + + Useful when the iterator can be created in advance but the elements are not + needed immediately. For instance, when recovering iterator and model from a + checkpoint, recover the iterator first, call ``start_prefetch`` and then + recover the model. This way the time to get the first batch from the + iterator will be partially or fully hidden behind the time it takes to + recover the model. + + This method is idempotent and safe to call multiple times. + """ + self._iterator.start_prefetch() + ### BEGIN Orbax checkpointing API. # See orbax.checkpoint.v1.handlers.StatefulCheckpointable for more details. # See https://orbax.readthedocs.io/en/latest/ for usage examples. 
diff --git a/grain/_src/python/data_loader_test.py b/grain/_src/python/data_loader_test.py index e945c6e06..f53fef171 100644 --- a/grain/_src/python/data_loader_test.py +++ b/grain/_src/python/data_loader_test.py @@ -772,6 +772,27 @@ def test_data_loader_with_flat_map_checkpointing(self): ) assert_equal_output_after_checkpoint(data_loader) + @absl_parameterized.product( + worker_count=[0, 4], num_start_prefetch_calls=[1, 5] + ) + def test_start_prefetch( + self, worker_count: int, num_start_prefetch_calls: int + ): + range_data_source = RangeDataSource(start=0, stop=16, step=1) + sampler = samplers.SequentialSampler( + num_records=len(range_data_source), shard_options=sharding.NoSharding() + ) + data_loader = data_loader_lib.DataLoader( + data_source=range_data_source, + sampler=sampler, + read_options=self.read_options, + worker_count=worker_count, + ) + data_loader_iterator = data_loader.__iter__() + for _ in range(num_start_prefetch_calls): + data_loader_iterator.start_prefetch() + self.assertEqual(list(data_loader_iterator), list(range(16))) + class PyGrainDatasetIteratorTest(absltest.TestCase):