Skip to content
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))


- Added `EMAWeightAveraging` callback that wraps Lightning's `WeightAveraging` class ([#21260](https://github.com/Lightning-AI/pytorch-lightning/pull/21260))


### Changed

- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896))
Expand Down
54 changes: 53 additions & 1 deletion src/lightning/pytorch/callbacks/weight_averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import Any, Optional, Union

import torch
from torch.optim.swa_utils import AveragedModel
from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
from typing_extensions import override

import lightning.pytorch as pl
Expand Down Expand Up @@ -361,3 +361,55 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
for average_param, current_param in zip(average_params, current_params):
current_param.data.copy_(average_param.data)


class EMAWeightAveraging(WeightAveraging):
"""Exponential Moving Average (EMA) Weight Averaging callback."""

def __init__(
self,
device: Optional[Union[torch.device, str, int]] = None,
use_buffers: bool = True,
decay: float = 0.999,
update_every_n_steps: int = 1,
update_starting_at_step: Optional[int] = None,
update_starting_at_epoch: Optional[int] = None,
**kwargs: Any,
):
super().__init__(
device=device,
use_buffers=use_buffers,
**kwargs,
avg_fn=get_ema_avg_fn(decay=decay),
)

self.update_every_n_steps = update_every_n_steps
self.update_starting_at_step = update_starting_at_step
self.update_starting_at_epoch = update_starting_at_epoch

def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None):
"""Decide when to update the model weights.

Args:
step_idx: The current step index.
epoch_idx: The current epoch index.
Returns:
bool: True if the model weights should be updated, False otherwise.

"""
if step_idx is not None:
# Check step-based conditions only if we have a valid step_idx
meets_step_requirement = self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
meets_step_frequency = self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
if meets_step_requirement and meets_step_frequency:
return True

if epoch_idx is not None:
# Check epoch-based condition only if we specify one
meets_epoch_requirement = (
self.update_starting_at_epoch is not None and epoch_idx >= self.update_starting_at_epoch
)
if meets_epoch_requirement:
return True

return False
Loading