
Trainer.fit completes early due to an unexpected increment in StochasticWeightAveraging when training a model with _BatchNorm and max_epochs=-1 #21347

@3waffel

Description


Bug description

When training ResNet18, which contains `BatchNorm2d` layers, with `max_epochs` set to -1 for infinite training, `trainer.fit` ends early because `StochasticWeightAveraging` increments `max_epochs`.

After the sanity check, the trainer prints the following message and stops training:

`Trainer.fit` stopped: `max_epochs=0` reached. 

The logged value `max_epochs=0` differs from our configured `max_epochs=-1`, which shows that the value was unexpectedly incremented (-1 + 1 = 0).


In `stochastic_weight_avg.py`, the following branch runs when the trained model contains a `_BatchNorm` layer:

```python
self._max_epochs = trainer.max_epochs
if self._model_contains_batch_norm:
    # virtually increase max_epochs to perform batch norm update on latest epoch.
    assert trainer.fit_loop.max_epochs is not None
    trainer.fit_loop.max_epochs += 1
```
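The effect of that branch on the infinite-training sentinel can be modeled in plain Python. This is a minimal sketch, not Lightning code: `bump_max_epochs` is a hypothetical helper that mimics the quoted `+= 1`, and `bump_max_epochs_guarded` is a hypothetical variant that leaves `max_epochs == -1` untouched. The guarded variant only illustrates one possible guard; it is not a fix confirmed by the maintainers.

```python
INFINITE = -1  # Lightning's sentinel for "no epoch limit" (max_epochs=-1)


def bump_max_epochs(max_epochs: int, model_contains_batch_norm: bool) -> int:
    """Hypothetical helper mimicking the quoted SWA branch."""
    if model_contains_batch_norm:
        # One extra epoch is reserved for the batch-norm update.
        max_epochs += 1
    return max_epochs


def bump_max_epochs_guarded(max_epochs: int, model_contains_batch_norm: bool) -> int:
    """Hypothetical guarded variant: never touch the -1 sentinel."""
    if model_contains_batch_norm and max_epochs != INFINITE:
        max_epochs += 1
    return max_epochs


print(bump_max_epochs(INFINITE, True))          # sentinel -1 becomes 0
print(bump_max_epochs(10, True))                # finite limit 10 becomes 11
print(bump_max_epochs_guarded(INFINITE, True))  # sentinel stays -1
```

With a finite limit the extra epoch is harmless, but applied to -1 the increment silently turns "no limit" into a limit of 0.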

Once `max_epochs` has become 0, the fit loop stops at the following check in `fit_loop.py`:

```python
# `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
# we use it here because the checkpoint data won't have `completed` increased yet
assert isinstance(self.max_epochs, int)
stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
if stop_epochs:
    # in case they are not equal, override so `trainer.current_epoch` has the expected value
    self.epoch_progress.current.completed = self.epoch_progress.current.processed
    rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
    return True
```
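Why a limit of 0 ends training immediately can be checked in isolation. The sketch below assumes that `is_max_limit_reached` behaves like Lightning's internal `_is_max_limit_reached` helper (a maximum of -1 disables the limit; otherwise the limit is reached once `current >= maximum`). It is a re-implementation for illustration, not an import of the real function, and `processed_epochs = 0` is an assumption about the state when the check first runs.

```python
def is_max_limit_reached(current: int, maximum: int = -1) -> bool:
    """Assumed to mirror Lightning's `_is_max_limit_reached`.

    A maximum of -1 disables the limit; otherwise the limit is
    reached once `current` is at least `maximum`.
    """
    return maximum != -1 and current >= maximum


# Assumption: no epoch has been processed yet right after the sanity check.
processed_epochs = 0

print(is_max_limit_reached(processed_epochs, -1))  # configured value: keep training
print(is_max_limit_reached(processed_epochs, 0))   # value after the SWA bump: stop
```

With `maximum=0`, the very first check already returns `True`, which matches the `max_epochs=0` message in the report.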

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

Error messages and logs

# Error messages and logs here please 

Environment

Current environment
```
#- PyTorch Lightning Version (e.g., 2.5.0): 2.5.1
#- PyTorch Version (e.g., 2.5): 2.6.0
#- Python version (e.g., 3.12): 3.12.10
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version: 11.8
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source): poetry
```

More info

No response

cc @ethanwharris

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.5.x
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
