Training: add a default garbage collection callback #1029

@trivoldus28

Description

I found that training can (pseudo-)leak memory if there is no explicit garbage collection. One solution is to add it to `training_step` in each regime, like this:

    import gc
    import torch

    def training_step(self, batch, batch_idx):
        # Clear caches periodically to reduce GPU memory usage
        if batch_idx % 100 == 0:
            gc.collect()
            torch.cuda.empty_cache()

But this requires every single regime to duplicate these lines. I propose making this a default via a callback in ZettaDefaultTrainer. I have something like this in my branch that works:

class ZettaDefaultTrainer(pl.Trainer):  # pragma: no cover
    def __init__(
...
        gc_interval: int | None = None,
...
        if gc_interval is not None:
            assert gc_interval > 0
            kwargs["callbacks"].append(GcCallback(interval=gc_interval))

        super().__init__(*args, **kwargs)
...
class GcCallback(pl.Callback):
    def __init__(self, interval: int):
        super().__init__()
        self.interval = interval

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if trainer.global_step > 0 and trainer.global_step % self.interval == 0:
            gc.collect()
            torch.cuda.empty_cache()
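For illustration, here is a minimal, self-contained sketch of the interval logic the callback uses. It is decoupled from Lightning so it can run standalone: the `on_train_batch_start` hook is invoked manually with a step counter, and the `collections` attribute (an addition for this sketch, not part of the proposed code) records how many times collection actually fired. In the real callback, `torch.cuda.empty_cache()` would follow `gc.collect()` as shown above.

```python
import gc


class GcCallback:
    """Sketch of the proposed callback's trigger logic.

    The real implementation subclasses pl.Callback and receives the
    trainer/module from Lightning; here we pass global_step directly
    so the interval behavior can be exercised without a GPU.
    """

    def __init__(self, interval: int):
        if interval <= 0:
            raise ValueError("interval must be positive")
        self.interval = interval
        self.collections = 0  # how many times gc actually ran (sketch-only)

    def on_train_batch_start(self, global_step: int) -> None:
        # Run a full collection every `interval` steps, skipping step 0,
        # mirroring the `global_step > 0 and global_step % interval == 0` check.
        if global_step > 0 and global_step % self.interval == 0:
            gc.collect()
            # torch.cuda.empty_cache() would go here when CUDA is in use
            self.collections += 1


# Simulate 500 training steps with an interval of 100:
cb = GcCallback(interval=100)
for step in range(500):
    cb.on_train_batch_start(step)
# gc fires at steps 100, 200, 300, 400 -> 4 collections
```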
