Description
When using Grain with Orbax, Grain validates the data source on checkpoint restore by comparing the `__repr__` of the saved data source with that of the current one. This currently fails for `bagz.Reader`, because it falls back to the default `object.__repr__`, which embeds the object's memory address. That address changes whenever the reader is recreated (for example, in a new process), so the check fails.
This results in the following error when resuming checkpoints that restore the data iterator:
```
ValueError: DataSource in checkpoint does not match datasource in dataloader.
data source in checkpoint: <bagz.lib.bagz.Reader object at 0x7f603b11ac70>
data source in dataloader: <bagz.lib.bagz.Reader object at 0x7f174c3f4f30>
Grain uses `repr(data_source)` to validate the source, so you may need to implement a custom `__repr__`.
```
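Until this is fixed, it can be worked around with a trivial subclass whose `__repr__` reports the path of the underlying file:

```python
import bagz


class ReprReader(bagz.Reader):
    """bagz.Reader with a deterministic __repr__ based on the file path."""

    def __init__(self, fp, *args, **kwargs):
        super().__init__(fp, *args, **kwargs)
        # Keep the path so __repr__ is identical across processes.
        self.fp = fp

    def __repr__(self):
        return f"Reader(fp={self.fp!r})"
```

Minimal example: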
```python
import os
import tempfile

import bagz
import grain
import orbax.checkpoint as ocp

with tempfile.TemporaryDirectory() as temp_dir:
    bagz_file = os.path.join(temp_dir, "demo.bagz")
    with bagz.Writer(bagz_file) as writer:
        writer.write(b"123")

    def make_iter():
        reader = bagz.Reader(bagz_file)
        dataset = grain.DataLoader(
            data_source=reader,
            sampler=grain.samplers.SequentialSampler(1),
        )
        return iter(dataset)

    mgr = ocp.CheckpointManager(os.path.join(temp_dir, "ckp"), item_names=("data_iter",))

    # Write a checkpoint containing a data iterator sourced from bagz.
    data_iter = make_iter()
    mgr.save(0, args=ocp.args.Composite(
        data_iter=grain.checkpoint.CheckpointSave(data_iter)
    ))
    mgr.wait_until_finished()

    # Reinstantiate, to emulate restoring to an instance in a new process.
    # This raises the ValueError above, since the new Reader has a new address.
    data_iter = make_iter()
    mgr.restore(0, args=ocp.args.Composite(
        data_iter=grain.checkpoint.CheckpointRestore(data_iter)
    ))
```
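If subclassing the `bagz.Reader` extension type is undesirable, a wrapper that delegates record access and pins `__repr__` to a stable label might work as well. This is a sketch under an assumption: that the loader only needs `__len__` and `__getitem__` from the data source (the random-access protocol). `StableReprSource` and its `label` parameter are hypothetical names, and a plain list stands in for the reader so the snippet runs without bagz.

```python
from typing import Any


class StableReprSource:
    """Hypothetical wrapper: forwards record access, pins __repr__ to a label."""

    def __init__(self, source: Any, label: str):
        self._source = source
        # The label should be identical across runs, e.g. the bagz file path.
        self._label = label

    def __len__(self) -> int:
        return len(self._source)

    def __getitem__(self, index: int) -> Any:
        return self._source[index]

    def __repr__(self) -> str:
        return f"StableReprSource({self._label!r})"


# In real use the source would be e.g. bagz.Reader(bagz_file), with label=bagz_file.
records = [b"123"]
src = StableReprSource(records, label="demo.bagz")
print(len(src), src[0], repr(src))
```

Composition avoids depending on whether the extension type supports Python subclassing, at the cost of hiding any other attributes of the reader that the loader might use.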