2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,8 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
## Unreleased

* New features:
* Adds support for shared memory output in Batch datasets when using
multiprocessing prefetch.
* Adds support for filtering Grain-internal stack frames from user-thrown
errors.
* Adds experimental support for `get_next_index` and `set_next_index` to fetch
9 changes: 9 additions & 0 deletions grain/_src/python/dataset/base.py
@@ -139,6 +139,15 @@ class ExecutionTrackingMode(enum.Flag):
STAGE_TIMING = enum.auto()


@typing.runtime_checkable
class SupportsSharedMemoryOutput(Protocol):
"""Protocol for datasets that support shared memory output."""

def enable_shared_memory_output(self) -> None:
"""Enables shared memory output for the dataset."""
...


@dataclasses.dataclass(slots=True, frozen=True)
class _Default(Generic[T]):
"""Default options value holder."""
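Because the protocol is decorated with `@typing.runtime_checkable`, downstream stages can detect the capability with a structural `isinstance` check instead of importing concrete batch classes, which is exactly what the two-line prefetch change at the end of this diff does. A self-contained sketch of that pattern (the `_ToyBatchStage` and `_ToyPrefetch` names are hypothetical stand-ins, not Grain classes):

```python
import typing
from typing import Protocol


@typing.runtime_checkable
class SupportsSharedMemoryOutput(Protocol):
  """Mirror of the protocol added above."""

  def enable_shared_memory_output(self) -> None:
    ...


class _ToyBatchStage:
  """Hypothetical stand-in for a batch dataset."""

  def __init__(self) -> None:
    self.shared_memory_enabled = False

  def enable_shared_memory_output(self) -> None:
    self.shared_memory_enabled = True


class _ToyPrefetch:
  """Hypothetical stand-in for the multiprocess prefetch wrapper."""

  def __init__(self, parent) -> None:
    self._parent = parent
    # Structural check: isinstance() only verifies that the method exists.
    if isinstance(parent, SupportsSharedMemoryOutput):
      parent.enable_shared_memory_output()


stage = _ToyBatchStage()
_ToyPrefetch(stage)
assert stage.shared_memory_enabled
```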
2 changes: 2 additions & 0 deletions grain/_src/python/dataset/transformations/BUILD
@@ -30,10 +30,12 @@ py_test(
deps = [
"//grain/_src/core:transforms",
"//grain/_src/core:tree_lib",
"//grain/_src/python:shared_memory_array",
"//grain/_src/python/dataset",
"//grain/_src/python/dataset:base",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
"@pypi//cloudpickle:pkg",
"@pypi//jax:pkg", # buildcleaner: keep
"@pypi//numpy:pkg",
],
71 changes: 64 additions & 7 deletions grain/_src/python/dataset/transformations/batch.py
@@ -24,6 +24,7 @@
from typing import Any, Callable, TypeVar, cast

from grain._src.core import tree_lib
from grain._src.python import shared_memory_array
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats
@@ -81,6 +82,23 @@ class _MakeBatchParallel:

def __init__(self):
self._parallel_batch_executor = concurrent.futures.ThreadPoolExecutor()
self._output_to_shared_memory = False

def enable_shared_memory_output(self):
self._output_to_shared_memory = True

def _stack(self, xs: Sequence[Any]) -> Any:
if not self._output_to_shared_memory:
return np.stack(xs)

first_arg = np.asanyarray(xs[0])
shape, dtype = (len(xs),) + first_arg.shape, first_arg.dtype
if dtype.hasobject:
return np.stack(xs)

return np.stack(
xs, out=shared_memory_array.SharedMemoryArray(shape, dtype=dtype)
)

def __call__(self, values: Sequence[T]) -> T:
def _batch_fn(*xs: Sequence[T]) -> T:
@@ -89,14 +107,20 @@ def _batch_fn(*xs: Sequence[T]) -> T:
if (self._parallel_batch_executor is None) or not isinstance(
xs[0], np.ndarray
):
return np.stack(xs)
return self._stack(xs)
xs = cast(Sequence[np.ndarray], xs)
first_arg = xs[0]
shape, dtype = (len(xs),) + first_arg.shape, first_arg.dtype
# Fall back to the standard serial `np.stack` operation if the size of
# of the entire batchis smaller (measured in bytes) than the threshold.
# of the entire batch is smaller (measured in bytes) than the threshold.
if sum(x.nbytes for x in xs) < _PARALLEL_BATCHING_MIN_TOTAL_BYTES:
return np.stack(xs)
return self._stack(xs)

if not self._output_to_shared_memory or dtype.hasobject:
out = np.empty(shape, dtype=dtype)
else:
out = shared_memory_array.SharedMemoryArray(shape, dtype=dtype)

out = np.empty([len(xs), *xs[0].shape], dtype=xs[0].dtype)
# For each input array, submit a parallel task to the thread pool to copy
# the data into the corresponding slice of the output array.
fs = []
@@ -123,15 +147,21 @@ def _batch_fn(*xs: Sequence[T]) -> T:
) from e

def __reduce__(self):
return (self.__class__, ())
state = self.__dict__.copy()
# ThreadPoolExecutor is not picklable.
# We rely on __init__ to recreate it during unpickling.
state.pop("_parallel_batch_executor", None)
return (self.__class__, (), state)

def __del__(self):
if self._parallel_batch_executor:
self._parallel_batch_executor.shutdown(wait=False, cancel_futures=True)
self._parallel_batch_executor = None


def make_batch(values: Sequence[T]) -> T:
def make_batch(
values: Sequence[T], *, output_to_shared_memory: bool = False
) -> T:
"""Returns a batch of values with a new batch dimension at the front."""
if not values:
raise ValueError("Cannot batch 0 values. Please file a bug.")
@@ -142,9 +172,22 @@ def make_batch(values: Sequence[T]) -> T:
lambda x: np.expand_dims(x, axis=0),
values[0],
)
stacking_function = lambda *xs: np.stack(xs)
if output_to_shared_memory:

def shm_stacking_function(*args):
first_arg = np.asanyarray(args[0])
shape, dtype = (len(args),) + first_arg.shape, first_arg.dtype
if dtype.hasobject:
return np.stack(args)
return np.stack(
args, out=shared_memory_array.SharedMemoryArray(shape, dtype=dtype)
)

stacking_function = shm_stacking_function

try:
return tree_lib.map_structure(lambda *xs: np.stack(xs), *values)
return tree_lib.map_structure(stacking_function, *values)

except ValueError as e:
# NumPy error message doesn't include actual shapes and dtypes. Provide a
@@ -385,6 +428,12 @@ def __str__(self) -> str:
f" drop_remainder={self._drop_remainder})"
)

def enable_shared_memory_output(self) -> None:
if self._batch_fn is make_batch:
self._batch_fn = functools.partial(
make_batch, output_to_shared_memory=True
)


class BatchIterDataset(dataset.IterDataset[T]):
"""Batch transformation for IterDatasets."""
@@ -443,3 +492,11 @@ def __str__(self) -> str:
f"BatchIterDataset(batch_size={self._batch_size},"
f" drop_remainder={self._drop_remainder})"
)

def enable_shared_memory_output(self) -> None:
if isinstance(self._batch_fn, _MakeBatchParallel):
self._batch_fn.enable_shared_memory_output()
elif self._batch_fn is make_batch:
self._batch_fn = functools.partial(
make_batch, output_to_shared_memory=True
)
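
The heart of the change to `_MakeBatchParallel._stack` and `make_batch` is stacking into a preallocated output buffer via `np.stack(..., out=...)`, so the batch is written into shared memory once, with no second copy. The sketch below reproduces that idea using only NumPy and the standard library's `multiprocessing.shared_memory`; Grain's `SharedMemoryArray` wraps this plumbing and adds the metadata handling seen in the tests. Object dtypes fall back to a plain `np.stack`, as in the code above.

```python
from multiprocessing import shared_memory

import numpy as np


def stack_into_shared_memory(xs):
  """Stdlib-only sketch of stacking a batch into a shared-memory buffer."""
  first = np.asanyarray(xs[0])
  shape, dtype = (len(xs),) + first.shape, first.dtype
  if dtype.hasobject:
    # Python objects cannot live in a raw shared buffer; stack normally.
    return np.stack(xs), None
  shm = shared_memory.SharedMemory(
      create=True, size=int(np.prod(shape)) * dtype.itemsize
  )
  out = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
  np.stack(xs, out=out)  # Writes each element directly into the buffer.
  return out, shm


batched, shm = stack_into_shared_memory([np.arange(3), np.arange(3, 6)])
print(batched)  # A (2, 3) batch viewing the shared buffer.
if shm is not None:
  shm.close()
  shm.unlink()
```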
110 changes: 109 additions & 1 deletion grain/_src/python/dataset/transformations/batch_test.py
@@ -19,9 +19,10 @@
import sys
from typing import Any
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import cloudpickle
import multiprocessing
from grain._src.core import transforms
from grain._src.core import tree_lib
from grain._src.python.dataset import base
@@ -853,6 +854,113 @@ def test_set_next_index(self, batch_size, drop_remainder, expected):
actual = next(ds_iter)
np.testing.assert_allclose(actual, expected[i])

@parameterized.named_parameters(
dict(
testcase_name="source_ds_jax",
use_jax=True,
initial_ds=source.SourceMapDataset([
np.asarray([1, 2, 3]),
np.asarray([4, 5, 6]),
np.asarray([7, 8, 9]),
np.asarray([10, 11, 12]),
]),
expected=[
[np.asarray([1, 2, 3]), np.asarray([4, 5, 6])],
[np.asarray([7, 8, 9]), np.asarray([10, 11, 12])],
],
),
dict(
testcase_name="source_ds_no_jax",
use_jax=False,
initial_ds=source.SourceMapDataset([
np.asarray([1, 2, 3]),
np.asarray([4, 5, 6]),
np.asarray([7, 8, 9]),
np.asarray([10, 11, 12]),
]),
expected=[
[np.asarray([1, 2, 3]), np.asarray([4, 5, 6])],
[np.asarray([7, 8, 9]), np.asarray([10, 11, 12])],
],
),
)
def test_batch_shared_memory(self, use_jax: bool, initial_ds, expected):
def test_impl():
ds = initial_ds.to_iter_dataset()
ds = batch.BatchIterDataset(ds, batch_size=2)
ds.enable_shared_memory_output()
self.assertIsInstance(ds._batch_fn, functools.partial)
self.assertIs(ds._batch_fn.func, batch.make_batch)
self.assertTrue(ds._batch_fn.keywords.get("output_to_shared_memory"))

ds_iter = iter(ds)
for i in range(len(expected)):
actual = next(ds_iter)
self.assertIsInstance(
actual, batch.shared_memory_array.SharedMemoryArray
)
shm_array = batch.shared_memory_array.SharedMemoryArray.from_metadata(
actual.metadata
)
np.testing.assert_allclose(shm_array, expected[i])
actual.metadata.close_and_unlink_shm()

with mock.patch.dict(
sys.modules, {"jax": sys.modules["jax"] if use_jax else None}
):
importlib.reload(batch.tree_lib)
test_impl()

def test_parallel_batch_shared_memory(self):
# Enable the parallel batch experiment.
with experiment_mock_utils.mocked_experiment_context("EXP_parallel_batch"):
ds = dataset.MapDataset.range(0, 10).to_iter_dataset()
# Use a batch size larger than _PARALLEL_BATCHING_MIN_BATCH_SIZE (4)
ds = batch.BatchIterDataset(ds, batch_size=5, drop_remainder=False)
ds.enable_shared_memory_output()

# Verify that we are using _MakeBatchParallel and SHM is enabled.
self.assertIsInstance(ds._batch_fn, batch._MakeBatchParallel)
self.assertTrue(ds._batch_fn._output_to_shared_memory)

ds_iter = iter(ds)
batch_val = next(ds_iter)

# Verify the output is SharedMemoryArray.
self.assertIsInstance(
batch_val, batch.shared_memory_array.SharedMemoryArray
)

# Verify data content.
shm_array = batch.shared_memory_array.SharedMemoryArray.from_metadata(
batch_val.metadata
)
np.testing.assert_array_equal(shm_array, np.arange(5))

batch_val.metadata.close_and_unlink_shm()

# Verify next batch
batch_val = next(ds_iter)
shm_array = batch.shared_memory_array.SharedMemoryArray.from_metadata(
batch_val.metadata
)
np.testing.assert_array_equal(shm_array, np.arange(5, 10))
batch_val.metadata.close_and_unlink_shm()


class MakeBatchParallelPickleTest(absltest.TestCase):

def test_pickle_parallel_batch_preserves_flag(self):
make_batch_parallel = batch._MakeBatchParallel()
make_batch_parallel.enable_shared_memory_output()
self.assertTrue(make_batch_parallel._output_to_shared_memory)

pickled = cloudpickle.dumps(make_batch_parallel)
unpickled = cloudpickle.loads(pickled)

self.assertTrue(unpickled._output_to_shared_memory)
self.assertIsNotNone(unpickled._parallel_batch_executor)


if __name__ == "__main__":
absltest.main()
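
The pickle test above relies on the standard `__reduce__` contract: when `__reduce__` returns `(callable, args, state)`, unpickling calls `callable(*args)` (which recreates the thread pool in `__init__`) and then applies `state` to the new object's `__dict__`, so plain attributes such as the shared-memory flag survive the round trip. A minimal sketch with a hypothetical `_Worker` class; plain `pickle` is used here, while the test uses `cloudpickle`, but the mechanics are the same:

```python
import concurrent.futures
import pickle


class _Worker:
  """Hypothetical class following the same __reduce__ pattern as above."""

  def __init__(self) -> None:
    self._executor = concurrent.futures.ThreadPoolExecutor()
    self._output_to_shared_memory = False

  def __reduce__(self):
    state = self.__dict__.copy()
    state.pop("_executor", None)  # Not picklable; __init__ recreates it.
    return (self.__class__, (), state)


worker = _Worker()
worker._output_to_shared_memory = True
clone = pickle.loads(pickle.dumps(worker))
assert clone._output_to_shared_memory  # Flag survived the round trip.
assert clone._executor is not None  # Fresh executor from __init__.
```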
2 changes: 2 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch.py
@@ -381,6 +381,8 @@ def __init__(
self._sequential_slice = sequential_slice
_validate_no_double_prefetch(self._parent)
self._always_report_worker_state = always_report_worker_state
if isinstance(self._parent, base.SupportsSharedMemoryOutput):
self._parent.enable_shared_memory_output()

def __str__(self) -> str:
return (
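
Putting the pieces together: when a batch dataset is the direct parent of the multiprocess prefetch stage, constructing the prefetch iterator enables shared-memory output on it automatically via the protocol check above. The hook can also be exercised directly, as the new tests do. The sketch below is assembled from those tests and uses Grain-internal module paths (`grain._src.python...`), so it reflects this diff rather than a stable public API:

```python
import numpy as np

from grain._src.python import shared_memory_array
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import batch

ds = dataset.MapDataset.range(0, 10).to_iter_dataset()
ds = batch.BatchIterDataset(ds, batch_size=5)
ds.enable_shared_memory_output()

for element in ds:
  # Each batch is a SharedMemoryArray; its metadata is enough for another
  # process to reattach to the same block without copying the data.
  assert isinstance(element, shared_memory_array.SharedMemoryArray)
  view = shared_memory_array.SharedMemoryArray.from_metadata(element.metadata)
  np.testing.assert_array_equal(view, element)
  element.metadata.close_and_unlink_shm()
```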