Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion grain/_src/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,10 @@ py_library(
name = "options",
srcs = ["options.py"],
srcs_version = "PY3",
deps = ["@abseil-py//absl/logging"],
deps = [
"//grain/_src/python/experimental/autotune/python:autotune_bindings",
"@abseil-py//absl/logging",
],
)

py_test(
Expand Down
1 change: 1 addition & 0 deletions grain/_src/python/dataset/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ py_library(
"//grain/_src/python:options",
"//grain/_src/python:shared_memory_array",
"//grain/_src/python/checkpoint:base",
"//grain/_src/python/experimental/autotune/python:autotune_util",
"//grain/proto:execution_summary_py_pb2",
"@abseil-py//absl/flags",
"@abseil-py//absl/logging",
Expand Down
30 changes: 25 additions & 5 deletions grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from grain._src.python.dataset.transformations import filter as filter_dataset
from grain._src.python.dataset.transformations import interleave
from grain._src.python.dataset.transformations import source
from grain._src.python.experimental.autotune.python import autotune_util

T = TypeVar("T")

Expand Down Expand Up @@ -143,10 +144,17 @@ def __init__(
self._next_buffered_index = 0
self._buffer = collections.deque()
self._lock = threading.Lock()
self._prefetch_buffer_size = (
read_options.prefetch_buffer_size if read_options.num_threads > 0 else 0

self._num_threads, self.target_num_threads = (
autotune_util.get_autotune_parameter(read_options.num_threads)
)
self._prefetch_buffer_size, self.target_buffer_size = (
autotune_util.get_autotune_parameter(read_options.prefetch_buffer_size)
)
self._num_threads = read_options.num_threads

if self._num_threads == 0:
self._prefetch_buffer_size = 0

self._allow_nones = allow_nones
if self._prefetch_buffer_size > 0:
self._executor = futures.ThreadPoolExecutor(
Expand Down Expand Up @@ -185,6 +193,7 @@ def _threshold_checker(self):
)
def __next__(self) -> T:
self._assert_not_closed()
self._check_autotune_updates()
# The time recorded here is the time spent in prefetch node to return an
# element, including the time spent in parent node.
timer = dataset_stats.Timer()
Expand Down Expand Up @@ -254,7 +263,7 @@ def __str__(self) -> str:
f" allow_nones={self._allow_nones})"
)

def set_prefetch_buffer_size(self, buffer_size: int):
def _set_prefetch_buffer_size(self, buffer_size: int):
self._prefetch_buffer_size = buffer_size
# The executor is created in the constructor only if the prefetch buffer
# size is greater than 0. If the user changes the prefetch buffer size, we
Expand All @@ -272,7 +281,7 @@ def set_prefetch_buffer_size(self, buffer_size: int):
self._executor.shutdown()
delattr(self, "_executor")

def set_num_threads(self, num_threads: int) -> None:
def _set_num_threads(self, num_threads: int) -> None:
self._num_threads = num_threads
old_executor = None
# Accounts for the case where the executor does not exit. This can
Expand All @@ -290,6 +299,17 @@ def set_num_threads(self, num_threads: int) -> None:
# assigned asynchronously.
old_executor.shutdown(wait=False)

def _check_autotune_updates(self):
if self.target_buffer_size is not None:
new_buffer_size = int(round(self.target_buffer_size.get_value()))
if new_buffer_size != self._prefetch_buffer_size:
self._set_prefetch_buffer_size(new_buffer_size)

if self.target_num_threads is not None:
new_num_threads = int(round(self.target_num_threads.get_value()))
if new_num_threads != self._num_threads:
self._set_num_threads(new_num_threads)

def _fill_buffer(self):
while (
len(self._buffer) < self._prefetch_buffer_size
Expand Down
18 changes: 9 additions & 9 deletions grain/_src/python/dataset/transformations/prefetch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def test_set_prefetch_buffer_size_0_to_positive(self):
self.assertEqual(next(ds_iter), 0)

# Setting prefetch_buffer_size to 2.
ds_iter.set_prefetch_buffer_size(2)
ds_iter._set_prefetch_buffer_size(2)
self.assertEqual(ds_iter._prefetch_buffer_size, 2)
self.assertEqual(next(ds_iter), 1)
self.assertTrue(hasattr(ds_iter, '_executor'))
Expand All @@ -183,7 +183,7 @@ def test_set_prefetch_buffer_size_positive_to_0(self):
self.assertLen(ds_iter._buffer, 2)

# Setting prefetch_buffer_size to 0.
ds_iter.set_prefetch_buffer_size(0)
ds_iter._set_prefetch_buffer_size(0)
self.assertEqual(ds_iter._prefetch_buffer_size, 0)
# Should consume buffer first.
self.assertEqual(next(ds_iter), 1)
Expand All @@ -207,7 +207,7 @@ def test_set_prefetch_buffer_size_increase(self):
self.assertLen(ds_iter._buffer, 1)

# Setting prefetch_buffer_size to 2.
ds_iter.set_prefetch_buffer_size(2)
ds_iter._set_prefetch_buffer_size(2)
self.assertEqual(ds_iter._prefetch_buffer_size, 2)
self.assertEqual(next(ds_iter), 1)
self.assertLen(ds_iter._buffer, 2)
Expand All @@ -227,7 +227,7 @@ def test_set_prefetch_buffer_size_decrease(self):
self.assertLen(ds_iter._buffer, 2)

# Setting prefetch_buffer_size to 1.
ds_iter.set_prefetch_buffer_size(1)
ds_iter._set_prefetch_buffer_size(1)
self.assertEqual(ds_iter._prefetch_buffer_size, 1)
self.assertEqual(next(ds_iter), 1)
self.assertLen(ds_iter._buffer, 1)
Expand Down Expand Up @@ -328,7 +328,7 @@ def test_set_num_threads_decrease_threads(self):
self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5)))

# Decrease threads
ds_iter.set_num_threads(5)
ds_iter._set_num_threads(5)
self.assertEqual(ds_iter._num_threads, 5)
self.assertEqual(ds_iter._executor._max_workers, 5)
self.assertEqual([next(ds_iter) for _ in range(15)], list(range(5, 20)))
Expand All @@ -345,7 +345,7 @@ def test_set_num_threads_increase_threads(self):
self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5)))

# Increase threads
ds_iter.set_num_threads(10)
ds_iter._set_num_threads(10)
self.assertEqual(ds_iter._num_threads, 10)
self.assertEqual(ds_iter._executor._max_workers, 10)
self.assertEqual([next(ds_iter) for _ in range(15)], list(range(5, 20)))
Expand All @@ -360,7 +360,7 @@ def test_set_num_threads_decrease_to_zero(self):
)
self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5)))
# Decrease threads to 0
ds_iter.set_num_threads(0)
ds_iter._set_num_threads(0)
self.assertEqual(ds_iter._num_threads, 0)
self.assertFalse(hasattr(ds_iter, '_executor'))
self.assertEqual([next(ds_iter) for _ in range(15)], list(range(5, 20)))
Expand All @@ -370,13 +370,13 @@ def test_set_num_threads_increase_from_zero(self):
self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator)
ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter)
self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5)))
ds_iter.set_num_threads(0)
ds_iter._set_num_threads(0)
self.assertEqual(ds_iter._num_threads, 0)
self.assertFalse(hasattr(ds_iter, '_executor'))
self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5, 10)))

# Increase threads from 0
ds_iter.set_num_threads(5)
ds_iter._set_num_threads(5)
self.assertEqual(ds_iter._num_threads, 5)
self.assertEqual(ds_iter._executor._max_workers, 5)
self.assertEqual([next(ds_iter) for _ in range(10)], list(range(10, 20)))
Expand Down
23 changes: 15 additions & 8 deletions grain/_src/python/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataclasses for holdings options."""
from __future__ import annotations

import dataclasses

from absl import logging
from grain._src.python.experimental.autotune.python import autotune_bindings


@dataclasses.dataclass(slots=True)
Expand All @@ -41,25 +44,29 @@ class ReadOptions:
# benchmarks reading from remote hard drives.
# These values should work well for datasets with elements between 1 and
# 10 KiB on disk.
num_threads: int = 16
prefetch_buffer_size: int = 500
num_threads: int | autotune_bindings.AutotuneParameter = 16
prefetch_buffer_size: int | autotune_bindings.AutotuneParameter = 500

def __post_init__(self):
if self.num_threads < 0:
if isinstance(self.num_threads, int) and self.num_threads < 0:
raise ValueError(
f'num_threads must be non-negative, got {self.num_threads}'
)
if self.prefetch_buffer_size < 0:

if (
isinstance(self.prefetch_buffer_size, int)
and self.prefetch_buffer_size < 0
):
raise ValueError(
'prefetch_buffer_size must be non-negative, got'
f' {self.prefetch_buffer_size}'
)

# Avoid warning when setting prefetch_buffer_size=0, since this is commonly
# used to disable prefetching.
if (
self.prefetch_buffer_size < self.num_threads
and self.prefetch_buffer_size != 0
):
buffer_size = int(self.prefetch_buffer_size)
num_threads = int(self.num_threads)
if buffer_size < num_threads and buffer_size != 0:
logging.warning(
'prefetch_buffer_size=%s is smaller than num_threads=%s. This will'
' limit the number of threads that can actually be used in parallel'
Expand Down
Loading