From e2a02cdf5b0755d52fa74ced4aece5c804c2d7a7 Mon Sep 17 00:00:00 2001
From: Sohaib-Ahmed21
Date: Sun, 28 Sep 2025 13:50:41 +0500
Subject: [PATCH 1/3] fix: synchronize gradients in manual optimization with
 DDPStrategy(static_graph=True)

Ensure gradients are reduced correctly when using manual optimization
together with DDP and static_graph enabled.
---
 src/lightning/pytorch/strategies/ddp.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index 92206e1accc31..64a0724605bc8 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -319,6 +319,27 @@ def pre_backward(self, closure_loss: Tensor) -> None:
         if not self.lightning_module.automatic_optimization:
             prepare_for_backward(self.model, closure_loss)
 
+    @override
+    def post_backward(self, closure_loss: Tensor) -> None:
+        # Only needed for the first static-graph iteration with manual optimization
+        model = self.model
+        lm = self.lightning_module
+        if not isinstance(model, DistributedDataParallel):
+            return
+        if lm is None or lm.automatic_optimization:
+            return
+        if not getattr(model, "static_graph", False):
+            return
+        if self._pl_static_graph_delay_done:
+            return
+
+        # Call DDP's own first-iteration static-graph flush.
+        # This is what actually launches the bucket all-reduces.
+        reducer = model.reducer
+        reducer._delay_all_reduce()
+
+        self._pl_static_graph_delay_done = True
+
     @override
     def model_to_device(self) -> None:
         log.debug(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")

From 38cc2426a1c192ec327fb2150e0ffcb73e926fcd Mon Sep 17 00:00:00 2001
From: Sohaib-Ahmed21
Date: Sun, 28 Sep 2025 13:52:18 +0500
Subject: [PATCH 2/3] test: add regression test covering all combinations of
 optimization mode and static_graph

---
 .../strategies/test_ddp_integration.py | 47 +++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py
index fc3a8cfebbac0..6373985687ad3 100644
--- a/tests/tests_pytorch/strategies/test_ddp_integration.py
+++ b/tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -448,3 +448,50 @@ def creates_processes_externally(self):
         RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
     ):
         trainer.fit(model)
+
+
+@RunIf(min_cuda_gpus=2, standalone=True)
+@pytest.mark.parametrize("automatic_optimization", [True, False])
+@pytest.mark.parametrize("static_graph", [True, False])
+def test_ddp_gradients_synced(tmp_path, automatic_optimization, static_graph):
+    """Ensure gradients are synchronized across ranks for both optimization modes and static_graph settings."""
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.automatic_optimization = automatic_optimization
+
+        def training_step(self, batch, batch_idx):
+            if self.automatic_optimization:
+                return super().training_step(batch, batch_idx)
+
+            # manual optimization path
+            opt = self.optimizers()
+            opt.zero_grad()
+            out = super().training_step(batch, batch_idx)
+            loss = out["loss"]
+            self.manual_backward(loss)
+            opt.step()
+            return out
+
+        def on_train_batch_end(self, *args, **kwargs):
+            # record the per-rank gradient sum for the synchronization check
+            grad_sum = self.layer.bias.grad.detach().sum()
+            self.log("grad_sum_min", grad_sum, sync_dist=True, reduce_fx="min")
+            self.log("grad_sum_max", grad_sum, sync_dist=True, reduce_fx="max")
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        accelerator="gpu",
+        devices=2,
+        strategy=DDPStrategy(static_graph=static_graph),
+        max_steps=1,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+    trainer.fit(TestModel(), datamodule=BoringDataModule())
+
+    # assert that all ranks saw identical gradients
+    gmin = trainer.callback_metrics["grad_sum_min"]
+    gmax = trainer.callback_metrics["grad_sum_max"]
+    assert torch.allclose(gmin, gmax)

From c7e6cc47f0d263b0c2ec2e032e9c1e4c9eadc59a Mon Sep 17 00:00:00 2001
From: Sohaib-Ahmed21
Date: Sun, 28 Sep 2025 14:06:01 +0500
Subject: [PATCH 3/3] Initialize the _pl_static_graph_delay_done attribute in
 DDPStrategy.__init__

---
 src/lightning/pytorch/strategies/ddp.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index 64a0724605bc8..4eca6159ddced 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -103,6 +103,7 @@ def __init__(
         self._process_group_backend: Optional[str] = process_group_backend
         self._timeout: Optional[timedelta] = timeout
         self._start_method = start_method
+        self._pl_static_graph_delay_done = False
 
     @property
     def is_distributed(self) -> bool:  # pragma: no-cover
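
For context, below is a minimal, hypothetical sketch of the scenario this series targets: manual optimization combined with DDPStrategy(static_graph=True). The module and helper names (ManualModel, run) are illustrative only and not part of the patches, and a machine with at least 2 GPUs is assumed. With static_graph=True, DDP delays the first iteration's bucket all-reduces, and the post_backward override added in PATCH 1/3 is the strategy-level hook that triggers that flush on the manual-optimization path.

import torch
import lightning.pytorch as pl
from lightning.pytorch.strategies import DDPStrategy
from torch.utils.data import DataLoader, TensorDataset


class ManualModel(pl.LightningModule):
    """Minimal manual-optimization module (illustrative only, not part of the patches)."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # opt into manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch[0]).sum()
        # manual_backward() runs through the strategy's pre_backward/post_backward
        # hooks, which is where the static-graph flush from PATCH 1/3 executes
        self.manual_backward(loss)
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def run():
    dataset = TensorDataset(torch.randn(64, 32))
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,  # assumes at least 2 GPUs
        strategy=DDPStrategy(static_graph=True),
        max_steps=1,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(ManualModel(), DataLoader(dataset, batch_size=8))


if __name__ == "__main__":
    run()

According to the commit message of PATCH 1/3, gradients were not reduced correctly in this setup before the fix; the regression test in PATCH 2/3 checks this by comparing min- and max-reduced gradient sums across ranks.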