From 531b298726e0ed34215411f07400b3464f5b0845 Mon Sep 17 00:00:00 2001
From: Ran Lu
Date: Wed, 15 Oct 2025 16:14:48 -0400
Subject: [PATCH 1/2] Add gradient accumulation

---
 deepem/train/option.py |  1 +
 deepem/train/run.py    | 47 ++++++++++++++++++++++++++++++++++--------
 2 files changed, 39 insertions(+), 9 deletions(-)

diff --git a/deepem/train/option.py b/deepem/train/option.py
index 800a1a0..58badcb 100644
--- a/deepem/train/option.py
+++ b/deepem/train/option.py
@@ -62,6 +62,7 @@ def initialize(self):
         self.parser.add_argument('--chkpt_num', type=int, default=0)
         self.parser.add_argument('--no_eval', action='store_true')
         self.parser.add_argument('--pretrain', default=None)
+        self.parser.add_argument('--grad_accum_steps', type=int, default=1)
 
         # WandB logging
         self.parser.add_argument('--wandb_pad_output', action='store_true')
diff --git a/deepem/train/run.py b/deepem/train/run.py
index 7301e69..93f9b3d 100644
--- a/deepem/train/run.py
+++ b/deepem/train/run.py
@@ -1,6 +1,7 @@
 import os
 import time
+from collections import defaultdict
 
 import torch
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -66,8 +67,11 @@ def train(opt):
     # Timer
     t0 = time.time()
 
-    for i in range(opt.chkpt_num, opt.max_iter):
+    grad_accum_steps = opt.grad_accum_steps
+    accum_losses = defaultdict(float)
+    accum_nmasks = defaultdict(float)
 
+    for i in range(opt.chkpt_num, opt.max_iter):
         # Load training samples.
         sample = train_loader()
 
@@ -81,23 +85,48 @@ def train(opt):
                 losses, nmasks, preds = forward(model, sample, opt)
                 total_loss = sum([w*losses[k] for k, w in opt.loss_weight.items()])
             # Backward passes under autocast are not recommended.
-            scaler.scale(total_loss).backward()
+            scaler.scale(total_loss / grad_accum_steps).backward()
+        else:
+            losses, nmasks, preds = forward(model, sample, opt)
+            total_loss = sum([w * losses[k] for k, w in opt.loss_weight.items()])
+            (total_loss / grad_accum_steps).backward()
+
+        # Accumulate metrics for logging
+        with torch.no_grad():
+            for k, v in losses.items():
+                accum_losses[k] += v
+            for k, v in nmasks.items():
+                accum_nmasks[k] += v
+
+        if ((i - opt.chkpt_num) + 1) % grad_accum_steps != 0:
+            continue
+
+        # --- From here on, code only runs on optimizer step ---
+        if opt.mixed_precision:
             scaler.step(optimizer)
             scaler.update()
-            losses = {k: v.float() for k, v in losses.items()}
-            nmasks = {k: v.float() for k, v in nmasks.items()}
-            preds = {k: v.float() for k, v in preds.items()}
         else:
-            losses, nmasks, preds = forward(model, sample, opt)
-            total_loss = sum([w*losses[k] for k, w in opt.loss_weight.items()])
-            total_loss.backward()
             optimizer.step()
 
+        # Average accumulated losses
+        avg_losses = {k: v / grad_accum_steps for k, v in accum_losses.items()}
+        avg_nmasks = {k: v / grad_accum_steps for k, v in accum_nmasks.items()}
+
+        if opt.mixed_precision:
+            avg_losses = {k: v.float() for k, v in avg_losses.items()}
+            avg_nmasks = {k: v.float() for k, v in avg_nmasks.items()}
+            preds = {k: v.float() for k, v in preds.items()}
+
+        # Elapsed time
         elapsed = time.time() - t0
 
         # Record keeping
-        logger.record('train', losses, nmasks, elapsed=elapsed)
+        logger.record("train", avg_losses, avg_nmasks, elapsed=elapsed)
+
+        # Reset accumulators
+        accum_losses = defaultdict(float)
+        accum_nmasks = defaultdict(float)
 
         # Log & display averaged stats.
         if (i+1) % opt.avgs_intv == 0 or i < opt.warm_up:

From 17bff4a8c4a6cd29d32492b5817be311dace114d Mon Sep 17 00:00:00 2001
From: Ran Lu
Date: Wed, 15 Oct 2025 16:15:00 -0400
Subject: [PATCH 2/2] The current implementation does not work with size_average (and batchnorm)

---
 deepem/train/run.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/deepem/train/run.py b/deepem/train/run.py
index 93f9b3d..9c784c9 100644
--- a/deepem/train/run.py
+++ b/deepem/train/run.py
@@ -26,6 +26,8 @@ def cleanup_distributed():
 
 
 def train(opt):
+    assert not (opt.size_average and opt.grad_accum_steps > 1), \
+        "size_average and grad_accum_steps > 1 are not supported"
     # Model
     if opt.parallel == "DDP":
         # Make sure samewise finished syncing files
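Reviewer note (not part of the patches): below is a minimal, self-contained sketch of the gradient-accumulation pattern that PATCH 1/2 wires into the existing training loop. Each micro-batch loss is divided by grad_accum_steps before backward(), gradients accumulate across iterations, and the optimizer (and GradScaler) steps only every grad_accum_steps iterations. The names here (train_with_accumulation, loader, the MSE loss) are illustrative placeholders rather than DeepEM's API, and the explicit optimizer.zero_grad() placement is part of the sketch, not shown in the patch.

```python
import torch

def train_with_accumulation(model, loader, optimizer, max_iter,
                            grad_accum_steps=4, mixed_precision=True):
    """Illustrative gradient-accumulation loop (not DeepEM code)."""
    scaler = torch.cuda.amp.GradScaler(enabled=mixed_precision)
    optimizer.zero_grad()
    for i in range(max_iter):
        inputs, targets = next(loader)
        with torch.cuda.amp.autocast(enabled=mixed_precision):
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
        # Dividing by grad_accum_steps makes the gradients summed over
        # grad_accum_steps micro-batches match one large-batch backward pass.
        scaler.scale(loss / grad_accum_steps).backward()
        if (i + 1) % grad_accum_steps != 0:
            continue  # keep accumulating gradients
        scaler.step(optimizer)  # unscales gradients; skips the step on inf/nan
        scaler.update()
        optimizer.zero_grad()   # clear accumulated gradients for the next window
```

The assert added in PATCH 2/2 is consistent with this picture: dividing by grad_accum_steps emulates a larger batch only when micro-batch losses combine the way a single large batch would, which presumably breaks down for size-averaged losses over differently sized masks (and batchnorm statistics are still computed per micro-batch either way).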