Useful __repr__ for Reader #3

@lukepmccombs

Description

When using grain with orbax, the __repr__ of the data source is used for validation when restoring checkpoints. This currently causes errors when bagz.Reader is the data source, because it falls back to the default __repr__ implementation, which includes the object's memory address and therefore differs between instances and processes.

This results in the following error when resuming checkpoints that restore the data iterator:

ValueError: DataSource in checkpoint does not match datasource in dataloader.
data source in checkpoint: <bagz.lib.bagz.Reader object at 0x7f603b11ac70>
data source in dataloader: <bagz.lib.bagz.Reader object at 0x7f174c3f4f30>
Grain uses `repr(data_source)` to validate the source, so you may need to implement a custom `__repr__`.
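
The mismatch can be reproduced without grain or bagz. Below is a minimal sketch in plain Python; `PlainReader` and `StableReader` are hypothetical stand-ins (not bagz classes) that only illustrate why the default repr fails an equality check while a path-based repr passes it:

```python
class PlainReader:
    """Hypothetical stand-in for a reader that keeps the default __repr__."""

    def __init__(self, path):
        self.path = path


class StableReader(PlainReader):
    """Same stand-in, but with a repr that depends only on the path."""

    def __repr__(self):
        return f"StableReader(path={self.path!r})"


# Two live objects opened on the same path.
plain_a, plain_b = PlainReader("demo.bagz"), PlainReader("demo.bagz")
stable_a, stable_b = StableReader("demo.bagz"), StableReader("demo.bagz")

# The default repr embeds the memory address, so the strings differ.
print(repr(plain_a) == repr(plain_b))
# The custom repr only embeds the path, so the strings match.
print(repr(stable_a) == repr(stable_b))
```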

For now, this can be worked around with a trivial subclass whose __repr__ reports the location of the underlying file:

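class ReprReader(bagz.Reader):
    """bagz.Reader with a repr that depends only on the file path."""

    def __init__(self, fp, *args, **kwargs):
        # Remember the path so that __repr__ is stable across processes.
        self.fp = fp
        super().__init__(fp, *args, **kwargs)

    def __repr__(self):
        return f"Reader(fp={self.fp})"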
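
If subclassing is not an option, the same idea could probably be done with composition instead. This is only a sketch: `StableReprSource` and its `name` argument are hypothetical, and it assumes the loader needs nothing from the data source beyond `__len__` and `__getitem__`. A plain list stands in for the reader so the snippet runs without bagz:

```python
from typing import Any, Sequence


class StableReprSource:
    """Hypothetical wrapper that gives any indexable source a stable repr."""

    def __init__(self, source: Sequence[Any], name: str):
        self._source = source
        self._name = name  # e.g. the file path the source was opened from

    def __len__(self) -> int:
        return len(self._source)

    def __getitem__(self, index: int) -> Any:
        return self._source[index]

    def __repr__(self) -> str:
        # Depends only on the name, never on the object's memory address.
        return f"StableReprSource(name={self._name!r})"


# A plain list stands in for bagz.Reader here (assumption for illustration).
first = StableReprSource([b"123"], name="demo.bagz")
second = StableReprSource([b"123"], name="demo.bagz")
print(repr(first) == repr(second))
```

The wrapper only forwards indexing and length, so anything else the underlying reader exposes would have to be forwarded explicitly.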

Minimal example:

import grain
import orbax.checkpoint as ocp
import bagz
import tempfile
import os


with tempfile.TemporaryDirectory() as temp_dir:
    bagz_file = os.path.join(temp_dir, "demo.bagz")
    with bagz.Writer(bagz_file) as writer:
        writer.write(b"123")

    def make_iter():
        reader = bagz.Reader(bagz_file)
        dataset = grain.DataLoader(data_source=reader, sampler=grain.samplers.SequentialSampler(1))
        return iter(dataset)

    mgr = ocp.CheckpointManager(os.path.join(temp_dir, "ckp"), item_names=("data_iter",))

    # Write a checkpoint containing a data iterator sourced from bagz
    data_iter = make_iter()
    mgr.save(
        0,
        args=ocp.args.Composite(
            data_iter=grain.checkpoint.CheckpointSave(data_iter)
        ),
    )
    mgr.wait_until_finished()

    # Reinstantiate, to emulate restoring to an instance in a new process
    data_iter = make_iter()
    mgr.restore(
        0,
        args=ocp.args.Composite(
            data_iter=grain.checkpoint.CheckpointRestore(data_iter)
        ),
    )
