I found that training can (pseudo) leak memory if there is no explicit garbage collection. One workaround is to add it to `training_step` in the regime, like this:
```python
import gc

import torch

def training_step(self, batch, batch_idx):
    # Clear Python garbage and the CUDA cache every 100 batches
    # to reduce GPU memory usage.
    if batch_idx % 100 == 0:
        gc.collect()
        torch.cuda.empty_cache()
```
But this requires every single regime to duplicate these lines. I propose making this a default via a callback in `ZettaDefaultTrainer`. I have something like this in my branch that works:
```python
import gc

import torch
import pytorch_lightning as pl


class ZettaDefaultTrainer(pl.Trainer):  # pragma: no cover
    def __init__(
        ...
        gc_interval: int | None = None,
        ...
        if gc_interval is not None:
            assert gc_interval > 0
            kwargs["callbacks"].append(GcCallback(interval=gc_interval))
        super().__init__(*args, **kwargs)
    ...


class GcCallback(pl.Callback):
    def __init__(self, interval: int):
        super().__init__()
        self.interval = interval

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # Periodically collect garbage and release cached CUDA memory.
        if trainer.global_step > 0 and trainer.global_step % self.interval == 0:
            gc.collect()
            torch.cuda.empty_cache()
```
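For reference, a minimal sketch of how a regime author would opt in. The interval value and the `regime`/`datamodule` objects below are placeholders for illustration only, assuming the remaining `ZettaDefaultTrainer` arguments are passed as they are today:

```python
# Hypothetical usage sketch: gc_interval=512 and the regime/datamodule
# names are placeholders, not part of the proposal itself.
trainer = ZettaDefaultTrainer(
    gc_interval=512,  # gc.collect() + torch.cuda.empty_cache() every 512 global steps
    # ... other trainer arguments as usual ...
)
trainer.fit(regime, datamodule=datamodule)
```

Regimes that don't pass `gc_interval` keep the current behavior, since the callback is only appended when the option is set.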