
bf16-mixed causing issues with fused AdamW #21435

@vsandwar-sumer


Bug description

bf16-mixed precision is incompatible with fused optimizers in the Lightning framework when gradient clipping is enabled. After some digging I found PR #15555, which was made in response to issue #15501 and appears to be a safety net introduced intentionally by the Lightning team.

Reproduction Code:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import lightning as L


class SimpleModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3, fused=True)


def main():
    X, y = torch.randn(100, 10), torch.randn(100, 1)
    dataloader = DataLoader(TensorDataset(X, y), batch_size=16)
    trainer = L.Trainer(
        accelerator="gpu",
        devices=1,
        precision="bf16-mixed",  # Using bf16-true passes
        gradient_clip_val=1.0,
        max_epochs=1,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
    )
    print(f"Scaler: {trainer.precision_plugin.scaler}")
    trainer.fit(SimpleModel(), dataloader)


if __name__ == "__main__":
    main()

Error:

RuntimeError: The current optimizer, AdamW, does not allow for gradient clipping because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer? 

Expected Solution:

I believe the problem is in the check below. I'm new to the pytorch-lightning OSS community and its internals, but this check lives in the native AMP precision code and is what shows up in the traceback. I think it needs an additional condition, essentially "and trainer.precision_plugin.scaler is not None" (or however the scaler is accessed in the internal API), which would skip the unnecessary error in the bf16 case.

if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
    raise RuntimeError(
        f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
        " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
    )

The _step_supports_amp_scaling flag just indicates the optimizer can handle internal unscaling when a GradScaler is present. With bf16 (no scaler), there's nothing to unscale - gradient clipping works normally.
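For reference, _optimizer_handles_unscaling essentially just reads that flag; here is a paraphrase (from memory, not the exact Lightning source) of what it does:

def _optimizer_handles_unscaling(optimizer) -> bool:
    # Fused/capturable optimizers such as AdamW(fused=True) set this attribute,
    # whether or not a GradScaler is actually in use.
    return getattr(optimizer, "_step_supports_amp_scaling", False)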

Lightning's check should be:

if clip_val > 0 and self.scaler is not None and _optimizer_handles_unscaling(optimizer):
    raise RuntimeError(...)

Not just:

if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
    raise RuntimeError(...)
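For concreteness, here is a sketch of how the relaxed check could look inside the AMP precision plugin's clip_gradients method (the signature and the super() call are approximated, not copied from the current source):

def clip_gradients(self, optimizer, clip_val=0.0, gradient_clip_algorithm=None):
    # Only refuse to clip when a GradScaler is actually in use. Under bf16-mixed
    # self.scaler is None, gradients are never scaled, and clipping is safe even
    # for fused optimizers.
    if clip_val > 0 and self.scaler is not None and _optimizer_handles_unscaling(optimizer):
        raise RuntimeError(
            f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
            " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
        )
    super().clip_gradients(optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)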

To verify that this is undesired behavior, here is a PyTorch-only script that does the same thing (bf16 autocast, fused AdamW, gradient clipping) and works fine:

import torch
import torch.nn as nn


def main():
    print("PyTorch:", torch.__version__)
    print()

    # Simple model
    model = nn.Linear(10, 1).cuda()

    # Fused AdamW
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)
    print(f"_step_supports_amp_scaling: {optimizer._step_supports_amp_scaling}")

    for step in range(10):
        # Random data
        x = torch.randn(16, 10, device="cuda")
        y = torch.randn(16, 1, device="cuda")

        optimizer.zero_grad()

        # bf16 autocast (no GradScaler needed for bf16)
        with torch.autocast("cuda", dtype=torch.bfloat16):
            y_hat = model(x)
            loss = nn.functional.mse_loss(y_hat, y)

        # Backward
        loss.backward()

        # Gradient clipping - works fine!
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Step
        optimizer.step()


if __name__ == "__main__":
    main()
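Until the check is relaxed, a possible workaround (untested sketch, reusing SimpleModel from the reproduction above) is to override the configure_gradient_clipping hook and clip directly, so the precision plugin's check is never reached:

import torch

class ClippedSimpleModel(SimpleModel):  # SimpleModel from the reproduction above
    def configure_gradient_clipping(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
        # Clip directly instead of calling self.clip_gradients(), which routes
        # through the AMP precision plugin and raises the RuntimeError above.
        # With bf16-mixed there is no GradScaler, so gradients are already unscaled.
        if gradient_clip_val is not None and gradient_clip_val > 0:
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=gradient_clip_val)

Passing ClippedSimpleModel() to trainer.fit() should then train with gradient_clip_val=1.0 under bf16-mixed.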

Personal Setup:
RTX A5000
pytorch-lightning version: 2.5.6

What version are you seeing the problem on?

v2.5


cc @ethanwharris
