diff --git a/grain/_src/python/BUILD b/grain/_src/python/BUILD index 3fe798268..bb731c3db 100644 --- a/grain/_src/python/BUILD +++ b/grain/_src/python/BUILD @@ -243,7 +243,10 @@ py_library( name = "options", srcs = ["options.py"], srcs_version = "PY3", - deps = ["@abseil-py//absl/logging"], + deps = [ + "//grain/_src/python/experimental/autotune/python:autotune_bindings", + "@abseil-py//absl/logging", + ], ) py_test( diff --git a/grain/_src/python/dataset/BUILD b/grain/_src/python/dataset/BUILD index 8e3e81ad4..ad1267912 100644 --- a/grain/_src/python/dataset/BUILD +++ b/grain/_src/python/dataset/BUILD @@ -57,6 +57,7 @@ py_library( "//grain/_src/python:options", "//grain/_src/python:shared_memory_array", "//grain/_src/python/checkpoint:base", + "//grain/_src/python/experimental/autotune/python:autotune_util", "//grain/proto:execution_summary_py_pb2", "@abseil-py//absl/flags", "@abseil-py//absl/logging", diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index 7ceb2a8b4..2f45b6924 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -34,6 +34,7 @@ from grain._src.python.dataset.transformations import filter as filter_dataset from grain._src.python.dataset.transformations import interleave from grain._src.python.dataset.transformations import source +from grain._src.python.experimental.autotune.python import autotune_util T = TypeVar("T") @@ -143,10 +144,17 @@ def __init__( self._next_buffered_index = 0 self._buffer = collections.deque() self._lock = threading.Lock() - self._prefetch_buffer_size = ( - read_options.prefetch_buffer_size if read_options.num_threads > 0 else 0 + + self._num_threads, self.target_num_threads = ( + autotune_util.get_autotune_parameter(read_options.num_threads) + ) + self._prefetch_buffer_size, self.target_buffer_size = ( + autotune_util.get_autotune_parameter(read_options.prefetch_buffer_size) 
) - self._num_threads = read_options.num_threads + + if self._num_threads == 0: + self._prefetch_buffer_size = 0 + self._allow_nones = allow_nones if self._prefetch_buffer_size > 0: self._executor = futures.ThreadPoolExecutor( @@ -185,6 +193,7 @@ def _threshold_checker(self): ) def __next__(self) -> T: self._assert_not_closed() + self._check_autotune_updates() # The time recorded here is the time spent in prefetch node to return an # element, including the time spent in parent node. timer = dataset_stats.Timer() @@ -254,7 +263,7 @@ def __str__(self) -> str: f" allow_nones={self._allow_nones})" ) - def set_prefetch_buffer_size(self, buffer_size: int): + def _set_prefetch_buffer_size(self, buffer_size: int): self._prefetch_buffer_size = buffer_size # The executor is created in the constructor only if the prefetch buffer # size is greater than 0. If the user changes the prefetch buffer size, we @@ -272,7 +281,7 @@ def set_prefetch_buffer_size(self, buffer_size: int): self._executor.shutdown() delattr(self, "_executor") - def set_num_threads(self, num_threads: int) -> None: + def _set_num_threads(self, num_threads: int) -> None: self._num_threads = num_threads old_executor = None # Accounts for the case where the executor does not exit. This can @@ -290,6 +299,17 @@ def set_num_threads(self, num_threads: int) -> None: # assigned asynchronously. 
def _check_autotune_updates(self):
  """Applies the latest autotuned thread count and prefetch buffer size.

  Reads the current value of each autotune target that is set (a target of
  `None` means the corresponding option is a fixed value and is left
  untouched), rounds it to an integer and, when it differs from the value
  currently in effect, applies it through `_set_num_threads` /
  `_set_prefetch_buffer_size`.

  The thread count is applied before the buffer size so that the buffer size
  update always sees the thread count that is actually in effect.
  """
  if self.target_num_threads is not None:
    # Plain-int options are validated to be non-negative by `ReadOptions`;
    # autotuned values bypass that validation, so clamp them here.
    new_num_threads = max(
        0, int(round(self.target_num_threads.get_value()))
    )
    if new_num_threads != self._num_threads:
      self._set_num_threads(new_num_threads)

  if self.target_buffer_size is not None:
    new_buffer_size = max(
        0, int(round(self.target_buffer_size.get_value()))
    )
    # Keep the invariant established in the constructor: with zero threads
    # prefetching is disabled, i.e. the buffer size is 0.
    # NOTE(review): without this, a positive buffer target combined with
    # zero threads presumably makes `_set_prefetch_buffer_size` build an
    # executor with no workers -- confirm against that method's body.
    if self._num_threads == 0:
      new_buffer_size = 0
    if new_buffer_size != self._prefetch_buffer_size:
      self._set_prefetch_buffer_size(new_buffer_size)
- ds_iter.set_prefetch_buffer_size(2) + ds_iter._set_prefetch_buffer_size(2) self.assertEqual(ds_iter._prefetch_buffer_size, 2) self.assertEqual(next(ds_iter), 1) self.assertLen(ds_iter._buffer, 2) @@ -227,7 +227,7 @@ def test_set_prefetch_buffer_size_decrease(self): self.assertLen(ds_iter._buffer, 2) # Setting prefetch_buffer_size to 1. - ds_iter.set_prefetch_buffer_size(1) + ds_iter._set_prefetch_buffer_size(1) self.assertEqual(ds_iter._prefetch_buffer_size, 1) self.assertEqual(next(ds_iter), 1) self.assertLen(ds_iter._buffer, 1) @@ -328,7 +328,7 @@ def test_set_num_threads_decrease_threads(self): self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5))) # Decrease threads - ds_iter.set_num_threads(5) + ds_iter._set_num_threads(5) self.assertEqual(ds_iter._num_threads, 5) self.assertEqual(ds_iter._executor._max_workers, 5) self.assertEqual([next(ds_iter) for _ in range(15)], list(range(5, 20))) @@ -345,7 +345,7 @@ def test_set_num_threads_increase_threads(self): self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5))) # Increase threads - ds_iter.set_num_threads(10) + ds_iter._set_num_threads(10) self.assertEqual(ds_iter._num_threads, 10) self.assertEqual(ds_iter._executor._max_workers, 10) self.assertEqual([next(ds_iter) for _ in range(15)], list(range(5, 20))) @@ -360,7 +360,7 @@ def test_set_num_threads_decrease_to_zero(self): ) self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5))) # Decrease threads to 0 - ds_iter.set_num_threads(0) + ds_iter._set_num_threads(0) self.assertEqual(ds_iter._num_threads, 0) self.assertFalse(hasattr(ds_iter, '_executor')) self.assertEqual([next(ds_iter) for _ in range(15)], list(range(5, 20))) @@ -370,13 +370,13 @@ def test_set_num_threads_increase_from_zero(self): self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator) ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter) self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5))) - ds_iter.set_num_threads(0) + 
ds_iter._set_num_threads(0) self.assertEqual(ds_iter._num_threads, 0) self.assertFalse(hasattr(ds_iter, '_executor')) self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5, 10))) # Increase threads from 0 - ds_iter.set_num_threads(5) + ds_iter._set_num_threads(5) self.assertEqual(ds_iter._num_threads, 5) self.assertEqual(ds_iter._executor._max_workers, 5) self.assertEqual([next(ds_iter) for _ in range(10)], list(range(10, 20))) diff --git a/grain/_src/python/options.py b/grain/_src/python/options.py index 293e3441a..f35557986 100644 --- a/grain/_src/python/options.py +++ b/grain/_src/python/options.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. """Dataclasses for holdings options.""" +from __future__ import annotations + import dataclasses from absl import logging +from grain._src.python.experimental.autotune.python import autotune_bindings @dataclasses.dataclass(slots=True) @@ -41,25 +44,29 @@ class ReadOptions: # benchmarks reading from remote hard drives. # These values should work well for datasets with elements between 1 and # 10 KiB on disk. - num_threads: int = 16 - prefetch_buffer_size: int = 500 + num_threads: int | autotune_bindings.AutotuneParameter = 16 + prefetch_buffer_size: int | autotune_bindings.AutotuneParameter = 500 def __post_init__(self): - if self.num_threads < 0: + if isinstance(self.num_threads, int) and self.num_threads < 0: raise ValueError( f'num_threads must be non-negative, got {self.num_threads}' ) - if self.prefetch_buffer_size < 0: + + if ( + isinstance(self.prefetch_buffer_size, int) + and self.prefetch_buffer_size < 0 + ): raise ValueError( 'prefetch_buffer_size must be non-negative, got' f' {self.prefetch_buffer_size}' ) + # Avoid warning when setting prefetch_buffer_size=0, since this is commonly # used to disable prefetching. 
- if ( - self.prefetch_buffer_size < self.num_threads - and self.prefetch_buffer_size != 0 - ): + buffer_size = int(self.prefetch_buffer_size) + num_threads = int(self.num_threads) + if buffer_size < num_threads and buffer_size != 0: logging.warning( 'prefetch_buffer_size=%s is smaller than num_threads=%s. This will' ' limit the number of threads that can actually be used in parallel'