Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
and advance a `grain.DatasetIterator` to the given produced element index.
* Switches to multithreading instead of multiprocessing in
`IterDataset.mp_prefetch` when free-threaded Python is detected.
* `grain.DataLoaderIterator` can now asynchronously start processing elements
in background with `start_prefetch` call.

* Breaking changes:
* Custom implementations of `RandomAccessDataSource` should accept `int`
Expand Down
14 changes: 14 additions & 0 deletions grain/_src/python/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,20 @@ def set_state(self, state: bytes):
self._data_loader._validate_state(state) # pylint: disable=protected-access
self._iterator.set_state(state)

def start_prefetch(self):
"""Starts processing elements asynchronously in the background.

Useful when the iterator can be created in advance but the elements are not
needed immediately. For instance, when recovering iterator and model from a
checkpoint, recover the iterator first, call ``start_prefech`` and then
recover the model. This way the time to get the first batch from the
iterator will be partially or fully hidden behind the time it takes to
recover the model.

This method is idempotent and safe to call multiple times.
"""
self._iterator.start_prefetch()

### BEGIN Orbax checkpointing API.
# See orbax.checkpoint.v1.handlers.StatefulCheckpointable for more details.
# See https://orbax.readthedocs.io/en/latest/ for usage examples.
Expand Down
21 changes: 21 additions & 0 deletions grain/_src/python/data_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,27 @@ def test_data_loader_with_flat_map_checkpointing(self):
)
assert_equal_output_after_checkpoint(data_loader)

@absl_parameterized.product(
worker_count=[0, 4], num_start_prefetch_calls=[1, 5]
)
def test_start_prefetch(
self, worker_count: int, num_start_prefetch_calls: int
):
range_data_source = RangeDataSource(start=0, stop=16, step=1)
sampler = samplers.SequentialSampler(
num_records=len(range_data_source), shard_options=sharding.NoSharding()
)
data_loader = data_loader_lib.DataLoader(
data_source=range_data_source,
sampler=sampler,
read_options=self.read_options,
worker_count=worker_count,
)
data_loader_iterator = data_loader.__iter__()
for _ in range(num_start_prefetch_calls):
data_loader_iterator.start_prefetch()
self.assertEqual(list(data_loader_iterator), list(range(16)))


class PyGrainDatasetIteratorTest(absltest.TestCase):

Expand Down