From 89b6189b75224b754934ec499a549a1a9e9615b5 Mon Sep 17 00:00:00 2001
From: Grain Team
Date: Tue, 27 Jan 2026 11:49:08 -0800
Subject: [PATCH] Add support for shared memory output in Batch transformations
 when used before mp_prefetch.

PiperOrigin-RevId: 861822077
---
 CHANGELOG.md                                  |   2 +
 grain/_src/python/dataset/base.py             |   9 ++
 .../_src/python/dataset/transformations/BUILD |   2 +
 .../python/dataset/transformations/batch.py   |  71 +++++++++--
 .../dataset/transformations/batch_test.py     | 110 +++++++++++++++++-
 .../dataset/transformations/prefetch.py       |   2 +
 6 files changed, 188 insertions(+), 8 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f43006174..5fd5fba0f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,8 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
 
 ## Unreleased
 * New features:
+  * Adds support for shared memory output in Batch datasets when using
+    multiprocessing prefetch.
   * Adds support for filtering Grain-internal stack frames from user-thrown
     errors.
   * Adds experimental support for `get_next_index` and `set_next_index` to fetch
diff --git a/grain/_src/python/dataset/base.py b/grain/_src/python/dataset/base.py
index 751f37d97..c40682587 100644
--- a/grain/_src/python/dataset/base.py
+++ b/grain/_src/python/dataset/base.py
@@ -139,6 +139,15 @@ class ExecutionTrackingMode(enum.Flag):
   STAGE_TIMING = enum.auto()
 
 
+@typing.runtime_checkable
+class SupportsSharedMemoryOutput(Protocol):
+  """Protocol for datasets that support shared memory output."""
+
+  def enable_shared_memory_output(self) -> None:
+    """Enables shared memory output for the dataset."""
+    ...
+
+
 @dataclasses.dataclass(slots=True, frozen=True)
 class _Default(Generic[T]):
   """Default options value holder."""
diff --git a/grain/_src/python/dataset/transformations/BUILD b/grain/_src/python/dataset/transformations/BUILD
index ee15b9230..6d75cbca2 100644
--- a/grain/_src/python/dataset/transformations/BUILD
+++ b/grain/_src/python/dataset/transformations/BUILD
@@ -30,10 +30,12 @@ py_test(
     deps = [
         "//grain/_src/core:transforms",
         "//grain/_src/core:tree_lib",
+        "//grain/_src/python:shared_memory_array",
        "//grain/_src/python/dataset",
        "//grain/_src/python/dataset:base",
        "@abseil-py//absl/testing:absltest",
        "@abseil-py//absl/testing:parameterized",
+        "@pypi//cloudpickle:pkg",
        "@pypi//jax:pkg",  # buildcleaner: keep
        "@pypi//numpy:pkg",
     ],
diff --git a/grain/_src/python/dataset/transformations/batch.py b/grain/_src/python/dataset/transformations/batch.py
index 2ba2c6903..89f2eb89a 100644
--- a/grain/_src/python/dataset/transformations/batch.py
+++ b/grain/_src/python/dataset/transformations/batch.py
@@ -24,6 +24,7 @@
 from typing import Any, Callable, TypeVar, cast
 
 from grain._src.core import tree_lib
+from grain._src.python import shared_memory_array
 from grain._src.python.dataset import base
 from grain._src.python.dataset import dataset
 from grain._src.python.dataset import stats
@@ -81,6 +82,23 @@ class _MakeBatchParallel:
 
   def __init__(self):
     self._parallel_batch_executor = concurrent.futures.ThreadPoolExecutor()
+    self._output_to_shared_memory = False
+
+  def enable_shared_memory_output(self):
+    self._output_to_shared_memory = True
+
+  def _stack(self, xs: Sequence[Any]) -> Any:
+    if not self._output_to_shared_memory:
+      return np.stack(xs)
+
+    first_arg = np.asanyarray(xs[0])
+    shape, dtype = (len(xs),) + first_arg.shape, first_arg.dtype
+    if dtype.hasobject:
+      return np.stack(xs)
+
+    return np.stack(
+        xs, out=shared_memory_array.SharedMemoryArray(shape, dtype=dtype)
+    )
 
   def __call__(self, values: Sequence[T]) -> T:
     def _batch_fn(*xs: Sequence[T]) -> T:
@@ -89,14 +107,20 @@ def _batch_fn(*xs: Sequence[T]) -> T:
       if (self._parallel_batch_executor is None) or not isinstance(
           xs[0], np.ndarray
       ):
-        return np.stack(xs)
+        return self._stack(xs)
       xs = cast(Sequence[np.ndarray], xs)
+      first_arg = xs[0]
+      shape, dtype = (len(xs),) + first_arg.shape, first_arg.dtype
       # Fall back to the standard serial `np.stack` operation if the size of
-      # of the entire batchis smaller (measured in bytes) than the threshold.
+      # of the entire batch is smaller (measured in bytes) than the threshold.
       if sum(x.nbytes for x in xs) < _PARALLEL_BATCHING_MIN_TOTAL_BYTES:
-        return np.stack(xs)
+        return self._stack(xs)
+
+      if not self._output_to_shared_memory or dtype.hasobject:
+        out = np.empty(shape, dtype=dtype)
+      else:
+        out = shared_memory_array.SharedMemoryArray(shape, dtype=dtype)
 
-      out = np.empty([len(xs), *xs[0].shape], dtype=xs[0].dtype)
       # For each input array, submit a parallel task to the thread pool to copy
       # the data into the corresponding slice of the output array.
       fs = []
@@ -123,7 +147,11 @@ def _batch_fn(*xs: Sequence[T]) -> T:
         ) from e
 
   def __reduce__(self):
-    return (self.__class__, ())
+    state = self.__dict__.copy()
+    # ThreadPoolExecutor is not picklable.
+    # We rely on __init__ to recreate it during unpickling.
+    state.pop("_parallel_batch_executor", None)
+    return (self.__class__, (), state)
 
   def __del__(self):
     if self._parallel_batch_executor:
@@ -131,7 +159,9 @@ def _batch_fn(*xs: Sequence[T]) -> T:
       self._parallel_batch_executor = None
 
 
-def make_batch(values: Sequence[T]) -> T:
+def make_batch(
+    values: Sequence[T], *, output_to_shared_memory: bool = False
+) -> T:
   """Returns a batch of values with a new batch dimension at the front."""
   if not values:
     raise ValueError("Cannot batch 0 values. Please file a bug.")
@@ -142,9 +172,22 @@ def make_batch(values: Sequence[T]) -> T:
         lambda x: np.expand_dims(x, axis=0),
         values[0],
     )
+  stacking_function = lambda *xs: np.stack(xs)
+  if output_to_shared_memory:
+
+    def shm_stacking_function(*args):
+      first_arg = np.asanyarray(args[0])
+      shape, dtype = (len(args),) + first_arg.shape, first_arg.dtype
+      if dtype.hasobject:
+        return np.stack(args)
+      return np.stack(
+          args, out=shared_memory_array.SharedMemoryArray(shape, dtype=dtype)
+      )
+
+    stacking_function = shm_stacking_function
   try:
-    return tree_lib.map_structure(lambda *xs: np.stack(xs), *values)
+    return tree_lib.map_structure(stacking_function, *values)
   except ValueError as e:
     # NumPy error message doesn't include actual shapes and dtypes. Provide a
@@ -385,6 +428,12 @@ def __str__(self) -> str:
       f" drop_remainder={self._drop_remainder})"
     )
 
+  def enable_shared_memory_output(self) -> None:
+    if self._batch_fn is make_batch:
+      self._batch_fn = functools.partial(
+          make_batch, output_to_shared_memory=True
+      )
+
 
 class BatchIterDataset(dataset.IterDataset[T]):
   """Batch transformation for IterDatasets."""
@@ -443,3 +492,11 @@ def __str__(self) -> str:
       f"BatchIterDataset(batch_size={self._batch_size},"
       f" drop_remainder={self._drop_remainder})"
     )
+
+  def enable_shared_memory_output(self) -> None:
+    if isinstance(self._batch_fn, _MakeBatchParallel):
+      self._batch_fn.enable_shared_memory_output()
+    elif self._batch_fn is make_batch:
+      self._batch_fn = functools.partial(
+          make_batch, output_to_shared_memory=True
+      )
diff --git a/grain/_src/python/dataset/transformations/batch_test.py b/grain/_src/python/dataset/transformations/batch_test.py
index 02f9d66b0..a27392dbc 100644
--- a/grain/_src/python/dataset/transformations/batch_test.py
+++ b/grain/_src/python/dataset/transformations/batch_test.py
@@ -19,9 +19,10 @@
 import sys
 from typing import Any
 from unittest import mock
-
 from absl.testing import absltest
 from absl.testing import parameterized
+import cloudpickle
+import multiprocessing
 from grain._src.core import transforms
 from grain._src.core import tree_lib
 from grain._src.python.dataset import base
@@ -853,6 +854,113 @@ def test_set_next_index(self, batch_size, drop_remainder, expected):
       actual = next(ds_iter)
       np.testing.assert_allclose(actual, expected[i])
 
+  @parameterized.named_parameters(
+      dict(
+          testcase_name="source_ds_jax",
+          use_jax=True,
+          initial_ds=source.SourceMapDataset([
+              np.asarray([1, 2, 3]),
+              np.asarray([4, 5, 6]),
+              np.asarray([7, 8, 9]),
+              np.asarray([10, 11, 12]),
+          ]),
+          expected=[
+              [np.asarray([1, 2, 3]), np.asarray([4, 5, 6])],
+              [np.asarray([7, 8, 9]), np.asarray([10, 11, 12])],
+          ],
+      ),
+      dict(
+          testcase_name="source_ds_no_jax",
+          use_jax=False,
+          initial_ds=source.SourceMapDataset([
+              np.asarray([1, 2, 3]),
+              np.asarray([4, 5, 6]),
+              np.asarray([7, 8, 9]),
+              np.asarray([10, 11, 12]),
+          ]),
+          expected=[
+              [np.asarray([1, 2, 3]), np.asarray([4, 5, 6])],
+              [np.asarray([7, 8, 9]), np.asarray([10, 11, 12])],
+          ],
+      ),
+  )
+  def test_batch_shared_memory(self, use_jax: bool, initial_ds, expected):
+    def test_impl():
+      ds = initial_ds.to_iter_dataset()
+      ds = batch.BatchIterDataset(ds, batch_size=2)
+      ds.enable_shared_memory_output()
+      self.assertIsInstance(ds._batch_fn, functools.partial)
+      self.assertIs(ds._batch_fn.func, batch.make_batch)
+      self.assertTrue(ds._batch_fn.keywords.get("output_to_shared_memory"))
+
+      ds_iter = iter(ds)
+      for i in range(len(expected)):
+        actual = next(ds_iter)
+        self.assertIsInstance(
+            actual, batch.shared_memory_array.SharedMemoryArray
+        )
+        shm_array = batch.shared_memory_array.SharedMemoryArray.from_metadata(
+            actual.metadata
+        )
+        np.testing.assert_allclose(shm_array, expected[i])
+        actual.metadata.close_and_unlink_shm()
+
+    with mock.patch.dict(
+        sys.modules, {"jax": sys.modules["jax"] if use_jax else None}
+    ):
+      importlib.reload(batch.tree_lib)
+      test_impl()
+
+  def test_parallel_batch_shared_memory(self):
+    # Enable the parallel batch experiment.
+    with experiment_mock_utils.mocked_experiment_context("EXP_parallel_batch"):
+      ds = dataset.MapDataset.range(0, 10).to_iter_dataset()
+      # Use a batch size larger than _PARALLEL_BATCHING_MIN_BATCH_SIZE (4)
+      ds = batch.BatchIterDataset(ds, batch_size=5, drop_remainder=False)
+      ds.enable_shared_memory_output()
+
+      # Verify that we are using _MakeBatchParallel and SHM is enabled.
+      self.assertIsInstance(ds._batch_fn, batch._MakeBatchParallel)
+      self.assertTrue(ds._batch_fn._output_to_shared_memory)
+
+      ds_iter = iter(ds)
+      batch_val = next(ds_iter)
+
+      # Verify the output is SharedMemoryArray.
+      self.assertIsInstance(
+          batch_val, batch.shared_memory_array.SharedMemoryArray
+      )
+
+      # Verify data content.
+      shm_array = batch.shared_memory_array.SharedMemoryArray.from_metadata(
+          batch_val.metadata
+      )
+      np.testing.assert_array_equal(shm_array, np.arange(5))
+
+      batch_val.metadata.close_and_unlink_shm()
+
+      # Verify next batch
+      batch_val = next(ds_iter)
+      shm_array = batch.shared_memory_array.SharedMemoryArray.from_metadata(
+          batch_val.metadata
+      )
+      np.testing.assert_array_equal(shm_array, np.arange(5, 10))
+      batch_val.metadata.close_and_unlink_shm()
+
+
+class MakeBatchParallelPickleTest(absltest.TestCase):
+
+  def test_pickle_parallel_batch_preserves_flag(self):
+    make_batch_parallel = batch._MakeBatchParallel()
+    make_batch_parallel.enable_shared_memory_output()
+    self.assertTrue(make_batch_parallel._output_to_shared_memory)
+
+    pickled = cloudpickle.dumps(make_batch_parallel)
+    unpickled = cloudpickle.loads(pickled)
+
+    self.assertTrue(unpickled._output_to_shared_memory)
+    self.assertIsNotNone(unpickled._parallel_batch_executor)
+
 
 if __name__ == "__main__":
   absltest.main()
diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py
index 9316419e3..69e3bad35 100644
--- a/grain/_src/python/dataset/transformations/prefetch.py
+++ b/grain/_src/python/dataset/transformations/prefetch.py
@@ -381,6 +381,8 @@ def __init__(
     self._sequential_slice = sequential_slice
     _validate_no_double_prefetch(self._parent)
     self._always_report_worker_state = always_report_worker_state
+    if isinstance(self._parent, base.SupportsSharedMemoryOutput):
+      self._parent.enable_shared_memory_output()
 
   def __str__(self) -> str:
     return (
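
A minimal usage sketch of the new behavior follows. It assumes the public Grain dataset API (`grain.MapDataset`, `IterDataset.batch`, `IterDataset.mp_prefetch`, `grain.MultiprocessingOptions`) as exposed in recent releases; exact import paths and option names may differ in your version. The point it illustrates: when a Batch node sits directly before mp_prefetch, the prefetch node detects the `SupportsSharedMemoryOutput` protocol and calls `enable_shared_memory_output()` on it, so worker processes hand batches back through shared memory rather than copying them over the IPC channel (object-dtype leaves still fall back to a plain `np.stack`).

# Usage sketch only; import path and MultiprocessingOptions name are assumed
# from the public Grain API and are not defined by this patch.
import numpy as np
import grain

ds = (
    grain.MapDataset.range(16)
    .map(lambda i: np.full((3,), i, dtype=np.int32))
    .to_iter_dataset()
    .batch(batch_size=4)  # Batch node directly before mp_prefetch.
    .mp_prefetch(grain.MultiprocessingOptions(num_workers=2))
)

# Batches are produced in worker processes; with this patch their backing
# buffers live in shared memory when the dtype allows it.
for batch_of_rows in ds:
  print(batch_of_rows.shape)  # (4, 3)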