Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
and advance a `grain.DatasetIterator` to the given produced element index.
* Switches to multithreading instead of multiprocessing in
`IterDataset.mp_prefetch` when free-threaded Python is detected.
* Adds `ElasticIterDatasetIterator` for scaling the number of shards up and down between checkpoints.

* Breaking changes:
* Custom implementations of `RandomAccessDataSource` should accept `int`
Expand Down
25 changes: 25 additions & 0 deletions grain/_src/python/checkpoint/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,34 @@ py_library(
srcs = ["handler.py"],
srcs_version = "PY3",
deps = [
":elastic_checkpoint",
"//grain/_src/core:sharding",
"//grain/_src/python:data_loader",
"//grain/_src/python/dataset",
"//grain/_src/python/dataset:elastic_iterator",
"@pypi//etils:pkg",
],
)

py_library(
name = "elastic_checkpoint",
srcs = ["elastic_checkpoint.py"],
srcs_version = "PY3",
deps = [
"//grain/_src/python/dataset:elastic_iterator",
"@pypi//etils:pkg",
],
)

py_test(
name = "elastic_checkpoint_test",
srcs = ["elastic_checkpoint_test.py"],
srcs_version = "PY3",
deps = [
":elastic_checkpoint",
"//grain/_src/core:sharding",
"//grain/_src/python/dataset:elastic_iterator",
"@abseil-py//absl/testing:absltest",
"@pypi//etils:pkg",
],
)
137 changes: 137 additions & 0 deletions grain/_src/python/checkpoint/elastic_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""This module provides checkpointing logic for ElasticIterDatasetIterator."""

import dataclasses
import json
from typing import Any, Optional, Sequence

from etils import epath
from grain._src.python.dataset import elastic_iterator


def _find_shard_file(
    directory: epath.Path,
    shard_index: int,
    total_num_shards: int,
) -> epath.Path:
  """Returns the unique state file for `shard_index` in `directory`.

  Matches files whose name ends with
  `shard_state_<shard_index>-of-<total_num_shards>.json`.

  Args:
    directory: Checkpoint directory containing per-shard state files.
    shard_index: Index of the shard whose state file to locate.
    total_num_shards: Total number of shards the checkpoint was written with.

  Returns:
    The path of the single matching file.

  Raises:
    ValueError: If no file, or more than one file, matches.
  """
  all_files = list(directory.iterdir())
  pattern = f"shard_state_{shard_index}-of-{total_num_shards}.json"
  # `endswith` (rather than exact name equality) tolerates filename prefixes;
  # `save_elastic_iterator` itself writes exact names — TODO confirm prefixes
  # can actually occur, otherwise exact match would be stricter.
  found_files = [f for f in all_files if f.name.endswith(pattern)]
  if not found_files:
    raise ValueError(
        f"No shard state files found in {directory} for shard {shard_index}"
    )
  if len(found_files) > 1:
    raise ValueError(
        f"Multiple shard state files found in {directory} for shard"
        f" {shard_index}"
    )
  return found_files[0]


def save_elastic_iterator(
    directory: epath.Path,
    item: elastic_iterator.ElasticIterDatasetIterator,
):
  """Saves the given iterator to the checkpoint in `directory`.

  Writes one `shard_state_<idx>-of-<total>.json` file per shard iterator
  state found in `item.get_state()`, embedding the total shard count in each
  file so restores can locate and validate the files.

  Args:
    directory: Destination checkpoint directory.
    item: Iterator whose per-shard states are serialized.
  """
  full_state = item.get_state()
  num_shards = full_state["total_num_shards"]
  for shard_idx, shard_state in full_state["ds_iterator_states"].items():
    # Record the shard count inside each per-shard file as well.
    shard_state["total_num_shards"] = num_shards
    target = directory / f"shard_state_{shard_idx}-of-{num_shards}.json"
    target.write_text(json.dumps(shard_state, indent=4))


def restore_elastic_iterator(
    directory: epath.Path,
    item: elastic_iterator.ElasticIterDatasetIterator,
):
  """Restores the given iterator from the checkpoint in `directory`.

  Each process restores only the shards it owns: starting at its own shard
  index and stepping by its shard count, up to the checkpoint's total number
  of shards.

  Args:
    directory: Checkpoint directory containing per-shard state files.
    item: Iterator updated in place via `update_shard_iterator_state`.
  """
  total = item.total_num_shards
  start = item.shard_options.shard_index
  stride = item.shard_options.shard_count
  for shard_idx in range(start, total, stride):
    shard_file = _find_shard_file(directory, shard_idx, total)
    shard_state = json.loads(shard_file.read_text())
    item.update_shard_iterator_state(shard_idx, shard_state)


class ElasticCheckpointHandler:
  """Orbax CheckpointHandler for PyGrain iterators."""

  def _resolve(
      self,
      item: Optional[
          elastic_iterator.ElasticIterDatasetIterator
          | Sequence[elastic_iterator.ElasticIterDatasetIterator]
      ],
      args: Any,
  ) -> Sequence[elastic_iterator.ElasticIterDatasetIterator]:
    """Returns `item` (falling back to `args.item`) as a sequence."""
    resolved = item or args.item
    if isinstance(resolved, elastic_iterator.ElasticIterDatasetIterator):
      resolved = [resolved]
    return resolved

  def save(
      self,
      directory: epath.Path,
      item: Optional[
          elastic_iterator.ElasticIterDatasetIterator
          | Sequence[elastic_iterator.ElasticIterDatasetIterator]
      ] = None,
      args: Any = None,
  ):
    """Saves the given iterator(s) to the checkpoint in `directory`."""
    for iterator in self._resolve(item, args):
      save_elastic_iterator(directory, iterator)

  def restore(
      self,
      directory: epath.Path,
      item: Optional[
          elastic_iterator.ElasticIterDatasetIterator
          | Sequence[elastic_iterator.ElasticIterDatasetIterator]
      ] = None,
      args: Any = None,
  ) -> Any:
    """Restores the given iterator(s) in place from `directory`."""
    iterators = self._resolve(item, args)
    for iterator in iterators:
      restore_elastic_iterator(directory, iterator)
    return iterators

  # Required by interface but not supported by PyGrain checkpoints.
  def structure(self, directory: epath.Path) -> Any:
    del directory
    return None

  # Required by interface.

  def metadata(self, directory: epath.Path) -> Optional[Any]:
    del directory
    return None

  def finalize(self, directory: epath.Path):
    # Nothing to finalize: `save` writes files directly.
    pass

  def close(self):
    # No resources are held.
    pass

  @classmethod
  def typestr(cls):
    """Returns the fully-qualified class name used for handler registration."""
    return f"{cls.__module__}.{cls.__qualname__}"


try:
  # Register the handler to be used with the new checkpointing API if Orbax is
  # present.
  import orbax.checkpoint as ocp  # pylint:disable=g-import-not-at-top # pytype:disable=import-error

  @ocp.args.register_with_handler(ElasticCheckpointHandler, for_save=True)  # pytype:disable=wrong-arg-types
  @dataclasses.dataclass
  class ElasticCheckpointSave(ocp.args.CheckpointArgs):
    # Iterator (or sequence of iterators) to checkpoint.
    item: Any

  @ocp.args.register_with_handler(ElasticCheckpointHandler, for_restore=True)  # pytype:disable=wrong-arg-types
  @dataclasses.dataclass
  class ElasticCheckpointRestore(ocp.args.CheckpointArgs):
    # Iterator (or sequence of iterators) to restore into.
    item: Any

except (ImportError, TypeError, AttributeError):
  # Orbax is an optional dependency; silently skip registration when it is
  # missing or incompatible with this handler.
  pass
122 changes: 122 additions & 0 deletions grain/_src/python/checkpoint/elastic_checkpoint_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""Tests for elastic checkpoint."""

import json

from etils import epath
from grain._src.core import sharding
from grain._src.python.checkpoint import elastic_checkpoint
from grain._src.python.dataset import elastic_iterator

from absl.testing import absltest


class MockElasticIterDatasetIterator(
    elastic_iterator.ElasticIterDatasetIterator
):
  """Test double that records shard-state updates instead of applying them."""

  def __init__(self, shard_options, total_num_shards, states=None):
    # Deliberately skips the parent __init__: no real dataset is needed.
    self._shard_options = shard_options
    self._total_num_shards = total_num_shards
    self._states = {} if states is None else states
    # Populated by update_shard_iterator_state; inspected by tests.
    self.updated_states = {}

  def get_state(self):
    return {
        "ds_iterator_states": self._states,
        "total_num_shards": self._total_num_shards,
    }

  def update_shard_iterator_state(self, shard_index, state):
    self.updated_states[shard_index] = state


class ElasticCheckpointTest(absltest.TestCase):

  def _make_checkpoint_dir(self):
    """Returns a fresh temporary directory wrapped as an epath.Path."""
    return epath.Path(self.create_tempdir().full_path)

  def test_save_and_restore_elastic_iterator(self):
    ckpt_dir = self._make_checkpoint_dir()
    opts = sharding.ShardOptions(shard_index=0, shard_count=1)
    source = MockElasticIterDatasetIterator(
        shard_options=opts,
        total_num_shards=2,
        states={0: {"val": 0}, 1: {"val": 1}},
    )
    elastic_checkpoint.save_elastic_iterator(ckpt_dir, source)

    # One file per shard, each carrying its state plus the shard count.
    for shard in (0, 1):
      shard_file = ckpt_dir / f"shard_state_{shard}-of-2.json"
      self.assertTrue(shard_file.exists())
      self.assertEqual(
          shard_file.read_text(),
          json.dumps({"val": shard, "total_num_shards": 2}, indent=4),
      )

    target = MockElasticIterDatasetIterator(
        shard_options=opts, total_num_shards=2
    )
    elastic_checkpoint.restore_elastic_iterator(ckpt_dir, target)
    self.assertEqual(
        target.updated_states,
        {
            0: {"val": 0, "total_num_shards": 2},
            1: {"val": 1, "total_num_shards": 2},
        },
    )

  def test_restore_elastic_iterator_with_multiple_processes(self):
    ckpt_dir = self._make_checkpoint_dir()
    # In reality save_elastic_iterator will be called in each process, but
    # get_state() should return all states, so we only need to call it once
    # to create checkpoint files.
    opts_0 = sharding.ShardOptions(shard_index=0, shard_count=2)
    writer = MockElasticIterDatasetIterator(
        shard_options=opts_0,
        total_num_shards=3,
        states={0: {"val": 0}, 1: {"val": 1}, 2: {"val": 2}},
    )
    elastic_checkpoint.save_elastic_iterator(ckpt_dir, writer)

    # Check files are written.
    for shard in (0, 1, 2):
      self.assertTrue((ckpt_dir / f"shard_state_{shard}-of-3.json").exists())

    # Process 0 owns shards 0 and 2 (stride = shard_count = 2).
    reader_0 = MockElasticIterDatasetIterator(
        shard_options=opts_0, total_num_shards=3
    )
    elastic_checkpoint.restore_elastic_iterator(ckpt_dir, reader_0)
    self.assertEqual(
        reader_0.updated_states,
        {
            0: {"val": 0, "total_num_shards": 3},
            2: {"val": 2, "total_num_shards": 3},
        },
    )

    # Process 1 owns shard 1 only.
    opts_1 = sharding.ShardOptions(shard_index=1, shard_count=2)
    reader_1 = MockElasticIterDatasetIterator(
        shard_options=opts_1, total_num_shards=3
    )
    elastic_checkpoint.restore_elastic_iterator(ckpt_dir, reader_1)
    self.assertEqual(
        reader_1.updated_states,
        {1: {"val": 1, "total_num_shards": 3}},
    )


# Allow running this test file directly (outside the Bazel test runner).
if __name__ == "__main__":
  absltest.main()
10 changes: 9 additions & 1 deletion grain/_src/python/checkpoint/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides a PyGrain CheckpointHandler for integration with Orbax."""

import dataclasses
import json
from typing import Any, Optional, TypeVar

from etils import epath
from grain._src.core import sharding
from grain._src.python import data_loader
from grain._src.python.checkpoint import elastic_checkpoint
from grain._src.python.dataset import dataset
from grain._src.python.dataset import elastic_iterator

IteratorType = TypeVar(
"IteratorType", data_loader.DataLoaderIterator, dataset.DatasetIterator
Expand All @@ -41,6 +44,9 @@ def save(
"""Saves the given iterator to the checkpoint in `directory`."""
item = item or args.item # pytype:disable=attribute-error
if isinstance(item, dataset.DatasetIterator):
if isinstance(item, elastic_iterator.ElasticIterDatasetIterator):
elastic_checkpoint.save_elastic_iterator(directory, item)
return
state = json.dumps(item.get_state(), indent=4)
else:
state = item.get_state().decode()
Expand All @@ -56,6 +62,9 @@ def restore(
) -> IteratorType:
"""Restores the given iterator from the checkpoint in `directory`."""
item = item or args.item # pytype:disable=attribute-error
if isinstance(item, elastic_iterator.ElasticIterDatasetIterator):
elastic_checkpoint.restore_elastic_iterator(directory, item)
return item
process_index, process_count = sharding.get_process_index_and_count()
filename = directory / f"process_{process_index}-of-{process_count}.json"
if not filename.exists():
Expand Down Expand Up @@ -105,6 +114,5 @@ class CheckpointSave(ocp.args.CheckpointArgs):
class CheckpointRestore(ocp.args.CheckpointArgs):
item: Any


except (ImportError, TypeError, AttributeError):
pass
1 change: 1 addition & 0 deletions grain/_src/python/dataset/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ py_test(
":elastic_iterator",
"//grain/_src/core:sharding",
"//grain/_src/python:options",
"//grain/_src/python/checkpoint:elastic_checkpoint",
"//grain/_src/python/testing:experimental",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
Expand Down
Loading