From 64155998019ae5cf2bbb37b68a8e45c9897c20c9 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Thu, 2 Oct 2025 15:26:06 -0700
Subject: [PATCH 1/4] Add `EMAWeightAveraging` callback to `weight_averaging.py`

---
 .../pytorch/callbacks/weight_averaging.py | 56 +++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py
index f9b8d64eae6a5..c97f7c5b41f7b 100644
--- a/src/lightning/pytorch/callbacks/weight_averaging.py
+++ b/src/lightning/pytorch/callbacks/weight_averaging.py
@@ -361,3 +361,59 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
         current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
         for average_param, current_param in zip(average_params, current_params):
             current_param.data.copy_(average_param.data)
+
+
+class EMAWeightAveraging(WeightAveraging):
+    """Exponential Moving Average (EMA) Weight Averaging callback."""
+
+    def __init__(
+        self,
+        device: Optional[Union[torch.device, str, int]] = None,
+        use_buffers: bool = True,
+        decay: float = 0.999,
+        update_every_n_steps: int = 1,
+        update_starting_at_step: Optional[int] = None,
+        update_starting_at_epoch: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(
+            device=device,
+            use_buffers=use_buffers,
+            **kwargs,
+            avg_fn=get_ema_avg_fn(decay=decay),
+        )
+
+        self.update_every_n_steps = update_every_n_steps
+        self.update_starting_at_step = update_starting_at_step
+        self.update_starting_at_epoch = update_starting_at_epoch
+
+    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None):
+        """Decide when to update the model weights.
+
+        Args:
+            step_idx: The current step index.
+            epoch_idx: The current epoch index.
+        Returns:
+            bool: True if the model weights should be updated, False otherwise.
+        """
+        if step_idx is not None:
+            # Check step-based conditions only if we have a valid step_idx
+            meets_step_requirement = (
+                self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
+            )
+            meets_step_frequency = (
+                self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
+            )
+            if meets_step_requirement and meets_step_frequency:
+                return True
+
+        if epoch_idx is not None:
+            # Check epoch-based condition only if we specify one
+            meets_epoch_requirement = (
+                self.update_starting_at_epoch is not None
+                and epoch_idx >= self.update_starting_at_epoch
+            )
+            if meets_epoch_requirement:
+                return True
+
+        return False

From bcf746f05ab03de0a11bcfcadf890bdb185c28c2 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 2 Oct 2025 22:27:37 +0000
Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/pytorch/callbacks/weight_averaging.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py
index c97f7c5b41f7b..673ed8fae2f0d 100644
--- a/src/lightning/pytorch/callbacks/weight_averaging.py
+++ b/src/lightning/pytorch/callbacks/weight_averaging.py
@@ -395,23 +395,19 @@ def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int]
             epoch_idx: The current epoch index.
         Returns:
             bool: True if the model weights should be updated, False otherwise.
+
         """
         if step_idx is not None:
             # Check step-based conditions only if we have a valid step_idx
-            meets_step_requirement = (
-                self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
-            )
-            meets_step_frequency = (
-                self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
-            )
+            meets_step_requirement = self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
+            meets_step_frequency = self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
             if meets_step_requirement and meets_step_frequency:
                 return True
 
         if epoch_idx is not None:
             # Check epoch-based condition only if we specify one
             meets_epoch_requirement = (
-                self.update_starting_at_epoch is not None
-                and epoch_idx >= self.update_starting_at_epoch
+                self.update_starting_at_epoch is not None and epoch_idx >= self.update_starting_at_epoch
             )
             if meets_epoch_requirement:
                 return True

From b9501980d67dc5a184381104921f0173ed5d51a3 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Thu, 2 Oct 2025 15:30:33 -0700
Subject: [PATCH 3/4] Update CHANGELOG.md

---
 src/lightning/pytorch/CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 1bba5e4ca0da7..30072517a2e9a 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))
 
+- Added `EMAWeightAveraging` callback that wraps Lightning's `WeightAveraging` class ([#21260](https://github.com/Lightning-AI/pytorch-lightning/pull/21260))
+
+
 ### Changed
 
 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896))

From 7c940ab358cd8c1e365261c656e0ab961df02641 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Fri, 3 Oct 2025 16:16:50 -0700
Subject: [PATCH 4/4] Update weight_averaging.py

---
 src/lightning/pytorch/callbacks/weight_averaging.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py
index 673ed8fae2f0d..c6f95adaedc1a 100644
--- a/src/lightning/pytorch/callbacks/weight_averaging.py
+++ b/src/lightning/pytorch/callbacks/weight_averaging.py
@@ -21,7 +21,7 @@
 from typing import Any, Optional, Union
 
 import torch
-from torch.optim.swa_utils import AveragedModel
+from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
 from typing_extensions import override
 
 import lightning.pytorch as pl