
@Copilot Copilot AI commented Oct 4, 2025

What does this PR do?

Fixes a MisconfigurationException raised when a ReduceLROnPlateau scheduler is used with check_val_every_n_epoch > 1. The scheduler update tried to read the monitored validation metric at the end of epochs where validation did not run, causing the error.

Issue

When check_val_every_n_epoch is set to a value greater than 1, validation only runs on specific epochs (e.g., every 2nd epoch). However, the ReduceLROnPlateau scheduler was being updated at the end of every epoch, attempting to access the monitored metric (e.g., val/loss) even when it wasn't available.

Example error:

MisconfigurationException: ReduceLROnPlateau conditioned on metric val/loss which is not available. 
Available metrics are: ['lr-AdamW/pg1', 'lr-AdamW/pg2', 'train/a_pcc', 'train/loss']. 
Condition can be set using `monitor` key in lr scheduler dict
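
For context, here is a minimal reproduction sketch; the PlateauModel module, its layer sizes, and the optimizer choice are illustrative and not taken from the original report:

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import LightningModule, Trainer

class PlateauModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).pow(2).mean()
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # val/loss is only logged on epochs where validation actually runs
        self.log("val/loss", self.layer(batch[0]).pow(2).mean())

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(16, 8)), batch_size=4)

    def val_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(16, 8)), batch_size=4)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val/loss"},
        }

trainer = Trainer(max_epochs=4, check_val_every_n_epoch=2)
trainer.fit(PlateauModel())  # raised MisconfigurationException before this fix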

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal changes)

Solution

Modified the scheduler update logic in fit_loop.py to only update plateau schedulers when validation actually runs. This ensures the monitored metrics are available when the scheduler needs them.

Before:

self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting)

After:

if (
    not self.restarting
    and self.epoch_loop._num_ready_batches_reached()
    and self.epoch_loop._should_check_val_epoch()
):
    self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=True)

Testing

Added a test case, test_reducelronplateau_with_check_val_every_n_epoch, that verifies the fix works correctly when validation only runs every N epochs.

trainer = Trainer(
    max_epochs=3,
    check_val_every_n_epoch=2,  # validation runs only every 2nd epoch
)
trainer.fit(model)  # no longer raises MisconfigurationException
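
For reference, a rough sketch of what such a test can look like; the helper PlateauModel (from the reproduction sketch above) and the final assertion are illustrative rather than the exact code added in this PR:

def test_reducelronplateau_with_check_val_every_n_epoch(tmp_path):
    # PlateauModel is the illustrative module from the reproduction sketch above.
    model = PlateauModel()
    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=4,
        check_val_every_n_epoch=2,
        limit_train_batches=2,
        limit_val_batches=2,
        enable_progress_bar=False,
        logger=False,
    )
    # Before the fix this raised MisconfigurationException because val/loss
    # was missing at the end of epochs without a validation run.
    trainer.fit(model)
    assert trainer.state.finished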

Impact

  • Behavior change: Plateau schedulers are now updated only on epochs where validation runs
  • Non-plateau schedulers: Unchanged; they are still updated every epoch
  • Backward compatibility: Maintained; the default setting (validation every epoch) behaves exactly as before
  • User workaround: Setting "strict": False to silence the error is no longer needed (see the sketch below)
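
For context, the now-unnecessary workaround looked roughly like this inside a user's LightningModule (a sketch; the optimizer choice and learning rate are arbitrary):

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val/loss",
            # tolerated the metric being absent on epochs without validation;
            # no longer required after this fix
            "strict": False,
        },
    }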

Fixes #20829

Original prompt

This section details the original issue you should resolve

<issue_title>ReduceLROnPlateu within configure_optimizers behave abnormally</issue_title>
<issue_description>### Bug description

I got this error:

  File "c:\Users\sean\miniconda3\envs\keras+torch+pl\Lib\site-packages\lightning\pytorch\loops\training_epoch_loop.py", line 459, in _update_learning_rates
    raise MisconfigurationException(
lightning.fabric.utilities.exceptions.MisconfigurationException: ReduceLROnPlateau conditioned on metric val/loss which is not available. Available metrics are: ['lr-AdamW/pg1', 'lr-AdamW/pg2', 'train/a_pcc', 'train/loss']. Condition can be set using `monitor` key in lr scheduler dict

Here is the configure_optimizers function:

    @final
    def configure_optimizers(self):

        decay, no_decay = [], []
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            if "bias" in name or "Norm" in name:
                no_decay.append(param)
            else:
                decay.append(param)

        grouped_params = [
            {"params": decay, "weight_decay": self.weight_decay, "lr": self.lr * 0.3},
            {
                "params": no_decay,
                "weight_decay": self.weight_decay,
                "lr": self.lr * 1.7,
            },
        ]

        optimizer = self.optmizer_class(
            grouped_params, lr=self.lr, weight_decay=self.weight_decay
        )

        scheduler = {
            "scheduler": self.lr_scheduler_class(
                optimizer, **self.lr_scheduler_args if self.lr_scheduler_args else {}
            ),
            "monitor": "val/loss",
            "interval": "epoch",
            "frequency": 1,
            # "strict": False,
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

The lr_scheduler_class is passed in as

  lr_scheduler_class: torch.optim.lr_scheduler.ReduceLROnPlateau
  lr_scheduler_args:
    mode: min
    factor: 0.5
    patience: 10
    threshold: 0.0001
    threshold_mode: rel
    cooldown: 5
    min_lr: 1.e-9
    eps: 1.e-08

(configured via YAML and the CLI, which I don't think is the cause here)
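
For reference, those YAML arguments expand to roughly this direct construction (a sketch; the model and optimizer below are placeholders):

import torch

model = torch.nn.Linear(8, 1)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # placeholder optimizer
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=10,
    threshold=1e-4,
    threshold_mode="rel",
    cooldown=5,
    min_lr=1e-9,
    eps=1e-8,
)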

It seems that I get the error at the end of the training epoch, since at that point the progress bar only reports train/loss. The validation epoch has not finished yet, but the scheduler is already being called.

I am quite sure that val/loss is available after the validation epoch finishes, because the progress bar can display it correctly.

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

Error messages and logs

# Error messages and logs here please

Environment


More info

No response</issue_description>

Comments on the Issue (you are @copilot in this section)

Fixes #20829



📚 Documentation preview 📚: https://pytorch-lightning--21266.org.readthedocs.build/en/21266/

Copilot AI and others added 2 commits October 4, 2025 21:22
- Only update plateau schedulers on epochs when validation runs
- This prevents errors when monitored metrics are not available
- Added test case for this scenario

Co-authored-by: Borda <6035284+Borda@users.noreply.github.com>
Co-authored-by: Borda <6035284+Borda@users.noreply.github.com>
@Copilot Copilot AI changed the title [WIP] ReduceLROnPlateu within configure_optimizers behave abnormally Fix ReduceLROnPlateau scheduler error when validation doesn't run every epoch Oct 4, 2025
@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Oct 4, 2025
@Copilot Copilot AI requested a review from Borda October 4, 2025 21:29
Copilot finished work on behalf of Borda October 4, 2025 21:29