diff --git a/python/array_record_data_source.py b/python/array_record_data_source.py index 4a6aed3..f56e998 100644 --- a/python/array_record_data_source.py +++ b/python/array_record_data_source.py @@ -173,11 +173,12 @@ def get_read_instruction(path: PathLikeOrFileInstruction) -> _ReadInstruction: ) -def _create_reader(filename: epath.PathLike): +def _create_reader(filename: epath.PathLike, additional_reader_options: str): """Returns an ArrayRecordReader for the given filename.""" + reader_options = f"readahead_buffer_size:0,{additional_reader_options}" return array_record_module.ArrayRecordReader( filename, - options="readahead_buffer_size:0", + options=reader_options, file_reader_buffer_size=32768, ) @@ -219,6 +220,7 @@ def __init__( paths: Union[ PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction] ], + reader_options: dict[str, str] | None = None, ): """Creates a new ArrayRecordDataSource object. @@ -238,6 +240,8 @@ def __init__( paths/FileInstructions. When you want to read subsets or have a large number of files prefer to pass FileInstructions. This makes the initialization faster. + reader_options: optional dict mapping reader option names to values, + serialized as comma-separated `key:value` pairs for each reader. """ if isinstance(paths, (str, pathlib.Path, FileInstruction)): paths = [paths] @@ -258,6 +262,12 @@ def __init__( "Unsupported path format was used. Path format must be " "a Sequence, String, pathlib.Path or FileInstruction." ) + if reader_options is None: + self._reader_options_string = "" + else: + self._reader_options_string = ",".join( + [f"{k}:{v}" for k, v in reader_options.items()] + ) self._read_instructions = _get_read_instructions(paths) self._paths = [ri.filename for ri in self._read_instructions] # We open readers lazily when we need to read from them. 
@@ -324,7 +334,7 @@ def _ensure_reader_exists(self, reader_idx: int) -> None: if self._readers[reader_idx] is not None: return filename = self._read_instructions[reader_idx].filename - reader = _create_reader(filename) + reader = _create_reader(filename, self._reader_options_string) _check_group_size(filename, reader) self._readers[reader_idx] = reader diff --git a/python/array_record_data_source_test.py b/python/array_record_data_source_test.py index 0c1fbfd..f0e4a3a 100644 --- a/python/array_record_data_source_test.py +++ b/python/array_record_data_source_test.py @@ -247,6 +247,22 @@ def test_repr(self): ]) self.assertRegex(repr(ar), r"ArrayRecordDataSource\(hash_of_paths=[\w]+\)") + @flagsaver.flagsaver(grain_use_fast_array_record_reader=False) + def test_additional_reader_options(self): + indices_to_read = [3, 0, 5, 9, 2, 1, 4, 7, 8, 6] + ar = array_record_data_source.ArrayRecordDataSource( + [ + self.testdata_dir / "digits.array_record-00000-of-00002", + self.testdata_dir / "digits.array_record-00001-of-00002", + ], + {"index_storage_option": "in_memory"}, + ) + # We need to read the records to trigger the creation of the readers. + _ = [ar[x] for x in indices_to_read] + self.assertLen(ar._readers, 2) + self.assertIsInstance(ar._readers[0], array_record_module.ArrayRecordReader) + self.assertIsInstance(ar._readers[1], array_record_module.ArrayRecordReader) + class RunInParallelTest(parameterized.TestCase):