From 531c1e9da56f5384482207a0118dd06bfb57a969 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sat, 4 Oct 2025 21:04:59 +0000
Subject: [PATCH 1/3] Initial plan


From 817243398495a67ed90b460c1dd288be00c87af3 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sat, 4 Oct 2025 21:22:04 +0000
Subject: [PATCH 2/3] Fix ReduceLROnPlateau scheduler with check_val_every_n_epoch

- Only update plateau schedulers on epochs when validation runs
- This prevents errors when monitored metrics are not available
- Added test case for this scenario

Co-authored-by: Borda <6035284+Borda@users.noreply.github.com>
---
 src/lightning/pytorch/loops/fit_loop.py       |  6 ++-
 .../trainer/optimization/test_optimizers.py   | 39 +++++++++++++++++++
 2 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index 8bb123939dc20..2de4066d91921 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -483,8 +483,10 @@ def on_advance_end(self) -> None:
 
         if not self.restarting and self.epoch_loop._num_ready_batches_reached():
             # since metric-based schedulers require access to metrics and those are not currently saved in the
-            # checkpoint, the plateau schedulers shouldn't be updated
-            self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting)
+            # checkpoint, the plateau schedulers shouldn't be updated when restarting
+            # only update plateau schedulers if validation ran this epoch to ensure monitored metrics are available
+            if self.epoch_loop._should_check_val_epoch():
+                self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=True)
 
         # we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics
         # even when the batch loop has finished
diff --git a/tests/tests_pytorch/trainer/optimization/test_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_optimizers.py
index 6b88534f3430d..0b991cba3abd4 100644
--- a/tests/tests_pytorch/trainer/optimization/test_optimizers.py
+++ b/tests/tests_pytorch/trainer/optimization/test_optimizers.py
@@ -165,6 +165,45 @@ def configure_optimizers(self):
     )
 
 
+def test_reducelronplateau_with_check_val_every_n_epoch(tmp_path):
+    """Test that ReduceLROnPlateau works correctly when validation runs every N epochs."""
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)["loss"]
+            self.log("train/loss", loss)
+            return loss
+
+        def validation_step(self, batch, batch_idx):
+            loss = super().validation_step(batch, batch_idx)["x"]
+            self.log("val/loss", loss)
+            return loss
+
+        def configure_optimizers(self):
+            optimizer = optim.Adam(self.parameters())
+            return {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    "monitor": "val/loss",
+                },
+            }
+
+    model = TestModel()
+    # Validation runs every 2 epochs, but scheduler should only update on validation epochs
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=3,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        check_val_every_n_epoch=2,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+    # This should not raise an error about missing val/loss metric
+    trainer.fit(model)
+
+
 def test_optimizer_return_options(tmp_path):
     trainer = Trainer(default_root_dir=tmp_path)
     model = BoringModel()

From 191fa8dccad3589cfb95f95b38726e02cadaf00c Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sat, 4 Oct 2025 21:27:20 +0000
Subject: [PATCH 3/3] Apply ruff lint suggestion - combine nested if statements

Co-authored-by: Borda <6035284+Borda@users.noreply.github.com>
---
 src/lightning/pytorch/loops/fit_loop.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index 2de4066d91921..749779d9dfa02 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -481,12 +481,15 @@ def on_advance_end(self) -> None:
 
         trainer._logger_connector.on_epoch_end()
 
-        if not self.restarting and self.epoch_loop._num_ready_batches_reached():
-            # since metric-based schedulers require access to metrics and those are not currently saved in the
-            # checkpoint, the plateau schedulers shouldn't be updated when restarting
-            # only update plateau schedulers if validation ran this epoch to ensure monitored metrics are available
-            if self.epoch_loop._should_check_val_epoch():
-                self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=True)
+        # since metric-based schedulers require access to metrics and those are not currently saved in the
+        # checkpoint, the plateau schedulers shouldn't be updated when restarting
+        # only update plateau schedulers if validation ran this epoch to ensure monitored metrics are available
+        if (
+            not self.restarting
+            and self.epoch_loop._num_ready_batches_reached()
+            and self.epoch_loop._should_check_val_epoch()
+        ):
+            self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=True)
 
         # we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics
         # even when the batch loop has finished