
DDP Validation Metric Logging Gets Silently Misplaced When Processes Log Different Metric Keys on Different Devices #21409

@worldlife123


Bug description

Problem

In PyTorch Lightning's DDP (Distributed Data Parallel) mode, when different processes log different metric keys during validation (based on their local data), the synchronization mechanism fails, causing:

  1. Metrics to be assigned to wrong keys
  2. Incorrect aggregated values
  3. No clear error or warning about the mismatch

Possible Root Cause

Lightning's sync_dist=True mechanism assumes homogeneous metric keys across all processes at each step. When this assumption is violated (e.g., Process 0 logs ["loss", "metric_a"] while Process 1 logs ["loss", "metric_b"]), the synchronization logic becomes confused, leading to data misalignment.

Minimal Example

class BuggyModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # Common metric - all processes log this
        self.log("val_loss", loss, sync_dist=True)
        # Different processes log different metrics
        if y[0] % 2 == 0:  # Even samples
            self.log("val_metric_even", torch.tensor(0.5), sync_dist=True)
        else:  # Odd samples
            self.log("val_metric_odd", torch.tensor(0.7), sync_dist=True)
        return {"val_loss": loss, "y": y}
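To make the failure mode concrete, here is a small standalone torch.distributed sketch (not Lightning's internal code) of what can happen when two ranks enter the same collective while each believes it is reducing a different key: the two values are silently blended and each rank reports the result under its own key. The gloo backend, port, and key names are illustrative assumptions.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=rank, world_size=world_size)
    # Rank 0 thinks it is reducing "metric_a"; rank 1 thinks it is reducing "metric_b".
    # The collective only sees tensors, so the two different metrics get averaged
    # together and each rank stores the result under its *own* key.
    local = {"metric_a": torch.tensor(1.0)} if rank == 0 else {"metric_b": torch.tensor(2.0)}
    key, value = next(iter(local.items()))
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    value /= world_size
    print(f"rank {rank}: '{key}' = {value.item()}")  # both keys end up as 1.5
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)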

Expected Behavior

One of the following:

  1. Option A (Preferred): Lightning properly handles heterogeneous metric keys by:

    • Synchronizing each key independently across processes
    • Only aggregating metrics where all processes contributed values
    • Providing clear warnings about partial metric coverage
  2. Option B: Clear error/warning when metric key mismatch is detected, guiding users to:

    • Use sync_dist=False and handle synchronization manually
    • Ensure consistent logging across processes
    • Use the all_gather API for heterogeneous metrics
  3. Option C: Add a flag like allow_heterogeneous_keys=True that enables proper synchronization of different keys.

Actual Behavior

  • Silent misalignment of metric values
  • Incorrect metrics reported to logger (TensorBoard, WandB, etc.)
  • No error or warning, making debugging extremely difficult

Impact

This affects many real-world scenarios:

  1. Multi-task learning: Different tasks may have different evaluation metrics
  2. Imbalanced datasets: Rare classes might trigger special metrics only in some batches
  3. Conditional evaluation: Some metrics only make sense for certain data subsets

Current User Workarounds (All Unsatisfactory)

  1. Log dummy values: Forces all processes to log all keys, wasting computation
  2. Disable sync_dist: Lose automatic synchronization benefits
  3. Manual all_gather: Requires significant boilerplate code (see the sketch after this list)
  4. Log only at epoch end: Lose per-step logging granularity
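For reference, workaround 3 (manual all_gather) could look roughly like the following sketch: every rank enters the same all_gather collective with a fixed, agreed-upon key order, fills NaN for metrics it did not compute this step, and then averages only the non-NaN entries before logging. METRIC_KEYS, ManualSyncModel, and compute_class_metric are illustrative names for this sketch, not an existing API; the NaN padding is what keeps each rank's collective calls identical, which sync_dist cannot guarantee once keys differ.

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

# Fixed, globally known key order; an assumption of this sketch.
METRIC_KEYS = ["val_metric_a", "val_metric_b", "val_metric_c"]


class ManualSyncModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(32, 3)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.net(x), y)
        self.log("val_loss", loss, sync_dist=True)

        # Pack locally computed metrics into one fixed-size tensor,
        # with NaN for keys this rank did not compute in this step.
        local = torch.full((len(METRIC_KEYS),), float("nan"), device=self.device)
        for class_id in range(3):
            if class_id in y:
                local[class_id] = self.compute_class_metric(x, y, class_id)

        # Every rank reaches this collective with the same shape, so keys stay aligned.
        gathered = self.all_gather(local).view(-1, len(METRIC_KEYS))
        for i, key in enumerate(METRIC_KEYS):
            values = gathered[:, i]
            values = values[~torch.isnan(values)]
            if values.numel() > 0:
                # Already reduced manually, so log on rank zero only.
                self.log(key, values.mean(), rank_zero_only=True)
        return loss

    def compute_class_metric(self, x, y, class_id):
        # Placeholder per-class metric; stands in for the real computation.
        return torch.rand((), device=self.device) * 0.5 + (class_id + 1)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)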

Suggested Fix

Maybe implement proper key-aware synchronization like this?

# Pseudo-code for proper synchronization
def sync_metrics(metrics_dict, group, rank):
    # Gather the key sets from all processes (e.g. via all_gather_object)
    all_key_sets = gather_key_sets(metrics_dict.keys(), group)
    # Synchronize each key in the union independently
    for key in set().union(*all_key_sets):
        if key in metrics_dict:
            sync_tensor(metrics_dict[key], group)
        else:
            # Handle missing key (skip or use NaN)
            pass
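A rough, runnable sketch of that idea outside Lightning, assuming plain torch.distributed with CPU tensors and the gloo backend: ranks first agree on the union of keys via all_gather_object, then reduce each key together with a participation count so the mean only covers the ranks that actually logged it. sync_heterogeneous_metrics is a hypothetical helper written for illustration, not a proposed Lightning API.

import torch
import torch.distributed as dist


def sync_heterogeneous_metrics(metrics: dict) -> dict:
    """Hypothetical helper: reduce a possibly rank-specific dict of 0-d tensors."""
    world_size = dist.get_world_size()

    # 1) Agree on the union of keys so every rank issues identical collectives.
    key_sets = [None] * world_size
    dist.all_gather_object(key_sets, sorted(metrics.keys()))
    all_keys = sorted(set().union(*key_sets))

    # 2) Reduce each key; ranks missing a key contribute zero and are not counted.
    synced = {}
    for key in all_keys:
        value = metrics[key].clone() if key in metrics else torch.tensor(0.0)
        count = torch.tensor(1.0 if key in metrics else 0.0)
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        dist.all_reduce(count, op=dist.ReduceOp.SUM)
        # Mean over only the ranks that contributed this key.
        synced[key] = value / count.clamp(min=1.0)
    return synced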

Below is a complete reproduction script for this bug, in which the model logs 'val_metric_a', 'val_metric_b', or 'val_metric_c' depending on the data class; the values should be around 1.0, 2.0, and 3.0, respectively. However, the values shown in TensorBoard do not match, as seen in the image below (version_0 uses sync_dist, version_1 does not):

[Image: TensorBoard screenshot comparing version_0 (sync_dist=True) and version_1 (sync_dist=False)]

What version are you seeing the problem on?

v2.4

Reproduced in studio

No response

How to reproduce the bug

""" Reproduction script for PyTorch Lightning DDP validation logging bug when processes log different metric keys.  Issue: When different processes log different metric keys during validation, the synchronized logs become misplaced/corrupted.  Expected: Each metric should be properly synchronized across processes. Actual: Metrics get misaligned, causing wrong values or missing metrics. """ import os import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader import pytorch_lightning as pl from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint import numpy as np # Set seed for reproducibility torch.manual_seed(42) np.random.seed(42) class SyntheticDataset(Dataset): """Synthetic dataset that generates different data for different ranks""" def __init__(self, size=1000): self.size = size def __len__(self): return self.size def __getitem__(self, idx): # Create input with 2 channels x = torch.randn(3, 32, 32) # Create different labels for different data points # This will cause different processes to see different types of samples if idx % 3 == 0: y = torch.tensor(0) # Type A elif idx % 3 == 1: y = torch.tensor(1) # Type B else: y = torch.tensor(2) # Type C return x, y class BuggyModel(pl.LightningModule): """Model that reproduces the DDP logging bug""" def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, 3, padding=1) self.conv2 = nn.Conv2d(16, 32, 3, padding=1) self.pool = nn.MaxPool2d(2, 2) self.fc1 = nn.Linear(32 * 8 * 8, 128) self.fc2 = nn.Linear(128, 3) # Track what we're logging for debugging self.logged_keys = [] def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = torch.flatten(x, 1) x = F.relu(self.fc1(x)) x = self.fc2(x) return x def training_step(self, batch, batch_idx): x, y = batch y_hat = self(x) loss = F.cross_entropy(y_hat, y) self.log('train_loss', loss, prog_bar=True) return loss def validation_step(self, batch, batch_idx): x, y = batch y_hat = self(x) loss = F.cross_entropy(y_hat, y) # CRITICAL BUG REPRODUCTION: # Different processes log different metrics based on the sample type # This simulates real-world scenarios where different data requires # different evaluation metrics # Get unique labels in this batch unique_labels = torch.unique(y) # Store what we log for debugging step_logged_keys = [] # Always log the common loss self.log('val_loss', loss, sync_dist=True, prog_bar=True) step_logged_keys.append('val_loss') # Log type-specific metrics (THIS IS WHERE THE BUG HAPPENS) if 0 in unique_labels: # Type A samples # Process with Type A samples computes metric A metric_a = torch.rand(1).item() * 0.5 + 1 # Simulated metric self.log('val_metric_a', metric_a, sync_dist=True) step_logged_keys.append('val_metric_a') if 1 in unique_labels: # Type B samples  # Process with Type B samples computes metric B metric_b = torch.rand(1).item() * 0.5 + 2 # Simulated metric self.log('val_metric_b', metric_b, sync_dist=True) step_logged_keys.append('val_metric_b') if 2 in unique_labels: # Type C samples # Process with Type C samples computes metric C metric_c = torch.rand(1).item() * 0.5 + 3 # Simulated metric self.log('val_metric_c', metric_c, sync_dist=True) step_logged_keys.append('val_metric_c') # Record what was logged in this step self.logged_keys.append(step_logged_keys) # Print what each rank is logging (for debugging) rank = self.trainer.global_rank if self.trainer else 0 print(f"Rank {rank}, Batch {batch_idx}: Logged keys = 
{step_logged_keys}") return { 'val_loss': loss, 'labels': y, 'logged_keys': step_logged_keys } def on_validation_epoch_end(self): """Analyze the logging issue at the end of validation""" if hasattr(self, 'trainer') and self.trainer: rank = self.trainer.global_rank print(f"\n=== Rank {rank} Summary ===") print(f"Total validation steps: {len(self.logged_keys)}") # Check for inconsistent logging patterns all_keys = set() for step_keys in self.logged_keys: all_keys.update(step_keys) print(f"All keys logged by this rank: {sorted(all_keys)}") # Count occurrences of each key key_counts = {} for step_keys in self.logged_keys: for key in step_keys: key_counts[key] = key_counts.get(key, 0) + 1 print(f"Key frequencies: {key_counts}") # Reset for next epoch self.logged_keys = [] def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.001) def reproduce_bug(): """Main function to reproduce the bug""" print("=" * 80) print("PyTorch Lightning DDP Validation Logging Bug Reproduction") print("=" * 80) # Create dataset and dataloader dataset = SyntheticDataset(size=32) # Use a sampler that ensures different ranks get different data distributions # This maximizes the chance of different metrics being logged train_loader = DataLoader( dataset, batch_size=1, shuffle=True, num_workers=0 # Set to 0 for easier debugging ) val_loader = DataLoader( dataset, batch_size=1, shuffle=False, num_workers=0 ) # Create model model = BuggyModel() # Setup trainer with DDP trainer = Trainer( max_epochs=2, accelerator='gpu' if torch.cuda.is_available() else 'cpu', devices=2 if torch.cuda.is_available() else 2, # Use 2 processes strategy='ddp' if torch.cuda.is_available() else 'ddp_spawn', num_nodes=1, enable_progress_bar=True, log_every_n_steps=1, enable_checkpointing=False, enable_model_summary=False, ) print("\nTraining with DDP (2 processes)...") print("Expected: All metrics (val_metric_a, val_metric_b, val_metric_c) should be logged correctly.") print("Bug: Metrics will be misplaced because processes log different keys.\n") # Train and validate trainer.fit(model, train_loader, val_loader) print("\n" + "=" * 80) print("Bug Reproduction Complete!") print("=" * 80) print("\nAnalysis:") print("1. Each process logs different metrics based on the data it receives.") print("2. Lightning tries to synchronize these logs across processes.") print("3. Because the metric keys differ, the synchronization gets confused.") print("4. Result: Some metrics show wrong values or appear in wrong steps.") print("\nCheck the logs above to see the issue:") print("- Look for 'val_metric_a', 'val_metric_b', 'val_metric_c' in TensorBoard/logger") print("- Notice they might show incorrect values or appear/disappear unexpectedly") def simple_repro(): """Even simpler reproduction - can run without GPU""" print("\n" + "=" * 80) print("Simplified Bug Reproduction (Single Process Simulation)") print("=" * 80) # Simulate what happens in DDP print("\nSimulating DDP with 2 processes:") # Process 0 logs in step 0 print("\nStep 0:") print(" Process 0 logs: ['val_loss', 'val_metric_a', 'val_metric_b']") print(" Process 1 logs: ['val_loss', 'val_metric_c']") print(" Lightning tries to sync: Expects same keys from all processes!") # Process 0 logs in step 1 print("\nStep 1:") print(" Process 0 logs: ['val_loss', 'val_metric_a']") print(" Process 1 logs: ['val_loss', 'val_metric_b', 'val_metric_c']") print("\nProblem: Key mismatch causes:") print(" 1. Metric values get assigned to wrong keys") print(" 2. 
Some metrics appear/disappear") print(" 3. Aggregated values are incorrect") print("\nExpected behavior:") print(" Lightning should handle heterogeneous metric keys across processes") print(" or at least provide a clear error/warning") if __name__ == "__main__": # Run simple explanation first # simple_repro() # Uncomment to run the full DDP reproduction (requires 2+ GPUs or CPU with DDP support) reproduce_bug()

Error messages and logs

================================================================================ PyTorch Lightning DDP Validation Logging Bug Reproduction ================================================================================ GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores Training with DDP (2 processes)... Expected: All metrics (val_metric_a, val_metric_b, val_metric_c) should be logged correctly. Bug: Metrics will be misplaced because processes log different keys. Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2 ================================================================================ PyTorch Lightning DDP Validation Logging Bug Reproduction ================================================================================ Training with DDP (2 processes)... Expected: All metrics (val_metric_a, val_metric_b, val_metric_c) should be logged correctly. Bug: Metrics will be misplaced because processes log different keys. Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2 ---------------------------------------------------------------------------------------------------- distributed_backend=nccl All distributed processes registered. Starting with 2 processes ---------------------------------------------------------------------------------------------------- LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1] LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1] /home/xzy/miniconda3/envs/cbench_vid/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance. Rank 0, Batch 0: Logged keys = ['val_loss', 'val_metric_a'] Rank 0, Batch 1: Logged keys = ['val_loss', 'val_metric_c'] === Rank 0 Summary === Total validation steps: 2 All keys logged by this rank: ['val_loss', 'val_metric_a', 'val_metric_c'] Key frequencies: {'val_loss': 2, 'val_metric_a': 1, 'val_metric_c': 1} Rank 1, Batch 0: Logged keys = ['val_loss', 'val_metric_b'] Rank 1, Batch 1: Logged keys = ['val_loss', 'val_metric_a'] === Rank 1 Summary === Total validation steps: 2 All keys logged by this rank: ['val_loss', 'val_metric_a', 'val_metric_b'] Key frequencies: {'val_loss': 2, 'val_metric_b': 1, 'val_metric_a': 1} /home/xzy/miniconda3/envs/cbench_vid/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance. 
Rank 0, Batch 0: Logged keys = ['val_loss', 'val_metric_a'] Rank 0, Batch 1: Logged keys = ['val_loss', 'val_metric_c'] Epoch 0/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340 Validation ━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2/16 0:00:00 • 0:00:01 192.15it/s Rank 1, Batch 0: Logged keys = ['val Rank 0, Batch 2: Logged keys = ['val_loss', 'val_metric_b'] Rank 0, Batch 3: Logged keys = ['val_loss', 'val_metric_a'] Epoch 0/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340 Validation ━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4/16 0:00:00 • 0:00:01 197.95it/s Rank 1, Batch 1: Logged keys = ['val Rank 0, Batch 4: Logged keys = ['val_loss', 'val_metric_c'] Rank 0, Batch 5: Logged keys = ['val_loss', 'val_metric_b'] Rank 0, Batch 6: Logged keys = ['val_loss', 'val_metric_a'] Epoch 0/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340 Validation ━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━ 7/16 0:00:00 • 0:00:01 204.50it/s Rank 1, Batch 2: Logged keys = ['val Rank 0, Batch 7: Logged keys = ['val_loss', 'val_metric_c'] Rank 0, Batch 8: Logged keys = ['val_loss', 'val_metric_b'] Epoch 0/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340 Validation ━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━ 9/16 0:00:00 • 0:00:01 204.93it/s Rank 1, Batch 3: Logged keys = ['val Rank 0, Batch 9: Logged keys = ['val_loss', 'val_metric_a'] Rank 0, Batch 10: Logged keys = ['val_loss', 'val_metric_c'] Rank 0, Batch 11: Logged keys = ['val_loss', 'val_metric_b'] Epoch 0/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340 Validation ━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━ 11/16 0:00:00 • 0:00:01 206.23it/s Rank 1, Batch 4: Logged keys = ['val Rank 0, Batch 12: Logged keys = ['val_loss', 'val_metric_a'] Rank 0, Batch 13: Logged keys = ['val_loss', 'val_metric_c'] Epoch 0/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340 Validation ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━ 14/16 0:00:00 • 0:00:01 207.14it/s Rank 1, Batch 5: Logged keys = ['val Rank 0, Batch 14: Logged keys = ['val_loss', 'val_metric_b'] Epoch 0/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340 Validation ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━ 15/16 0:00:00 • 0:00:01 193.32it/s Rank 1, Batch 6: Logged keys = ['val Rank 0, Batch 15: Logged keys = ['val_loss', 'val_metric_a'] Epoch 0/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 104.26it/s v_num: 0.000 train_loss: 1.340 Validation ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 178.06it/s Rank 1, Batch 7: Logged keys = ['val_loss', 'val_metric_a'] Rank 1, Batch 8: Logged keys = ['val_loss', 'val_metric_c'] Rank 1, Batch 9: Logged keys = ['val_loss', 'val_metric_b'] Rank 1, Batch 10: Logged keys = ['val_loss', 'val_metric_a'] Rank 1, Batch 11: Logged keys = ['val_loss', 'val_metric_c'] Rank 1, Batch 12: Logged keys = ['val_loss', 'val_metric_b'] Rank 1, Batch 13: Logged keys = ['val_loss', 'val_metric_a'] Rank 1, Batch 14: Logged keys = ['val_loss', 'val_metric_c'] Rank 1, Batch 15: Logged keys = ['val_loss', 'val_metric_b'] === Rank 1 Summary === Total validation steps: 16 All keys logged by this rank: ['val_loss', 'val_metric_a', 'val_metric_b', 'val_metric_c'] === Rank 0 
Summary === Total validation steps: 16 All keys logged by this rank: ['val_loss', 'val_metric_a', 'val_metric_b', 'val_metric_c'] Key frequencies: {'val_loss': 16, 'val_metric_a': 6, 'val_metric_c': 5, 'val_metric_b': 5} Rank 0, Batch 0: Logged keys = ['val_loss', 'val_metric_a'] Rank 0, Batch 1: Logged keys = ['val_loss', 'val_metric_c'] Epoch 1/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116 Validation ━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1/16 0:00:00 • -:--:-- 0.00it/s Rank 1, Batch 0: Log Epoch 1/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116 Validation ━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2/16 0:00:00 • 0:00:01 219.10it/s Rank 1, Batch 1: Log Rank 0, Batch 2: Logged keys = ['val_loss', 'val_metric_b'] Rank 0, Batch 3: Logged keys = ['val_loss', 'val_metric_a'] Rank 0, Batch 4: Logged keys = ['val_loss', 'val_metric_c'] Epoch 1/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116 Validation ━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5/16 0:00:00 • 0:00:01 209.41it/s Rank 1, Batch 2: Log Rank 0, Batch 5: Logged keys = ['val_loss', 'val_metric_b'] Rank 0, Batch 6: Logged keys = ['val_loss', 'val_metric_a'] Epoch 1/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116 Validation ━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━ 7/16 0:00:00 • 0:00:01 208.61it/s Rank 1, Batch 3: Log Rank 0, Batch 7: Logged keys = ['val_loss', 'val_metric_c'] Rank 0, Batch 8: Logged keys = ['val_loss', 'val_metric_b'] Rank 0, Batch 9: Logged keys = ['val_loss', 'val_metric_a'] Epoch 1/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116 Validation ━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━ 9/16 0:00:00 • 0:00:01 209.89it/s Rank 1, Batch 4: Log Rank 0, Batch 10: Logged keys = ['val_loss', 'val_metric_c'] Rank 0, Batch 11: Logged keys = ['val_loss', 'val_metric_b'] Epoch 1/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116 Validation ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━ 12/16 0:00:00 • 0:00:01 210.92it/s Rank 1, Batch 5: Log Rank 0, Batch 12: Logged keys = ['val_loss', 'val_metric_a'] Rank 0, Batch 13: Logged keys = ['val_loss', 'val_metric_c'] Rank 0, Batch 14: Logged keys = ['val_loss', 'val_metric_b'] Epoch 1/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116 Validation ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━ 14/16 0:00:00 • 0:00:01 211.86it/s Rank 1, Batch 6: Log Rank 0, Batch 15: Logged keys = ['val_loss', 'val_metric_a'] Epoch 1/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116 Validation ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 211.71it/s Rank 1, Batch 7: Logged keys = ['val_loss', 'val_metric_a'] Rank 1, Batch 8: Logged keys = ['val_loss', 'val_metric_c'] Rank 1, Batch 9: Logged keys = ['val_loss', 'val_metric_b'] Rank 1, Batch 10: Logged keys = ['val_loss', 'val_metric_a'] Rank 1, Batch 11: Logged keys = ['val_loss', 'val_metric_c'] Rank 1, Batch 12: Logged keys = ['val_loss', 'val_metric_b'] Rank 1, Batch 13: Logged keys = ['val_loss', 'val_metric_a'] Rank 1, Batch 14: Logged keys = 
['val_loss', 'val_metric_c'] Rank 1, Batch 15: Logged keys = ['val_loss', 'val_metric_b'] === Rank 1 Summary === Total validation steps: 16 All keys logged by this rank: ['val_loss', 'val_metric_a', 'val_metric_b', 'val_metric_c'] === Rank 0 Summary === Total validation steps: 16 All keys logged by this rank: ['val_loss', 'val_metric_a', 'val_metric_b', 'val_metric_c'] Key frequencies: {'val_loss': 16, 'val_metric_a': 6, 'val_metric_c': 5, 'val_metric_b': 5} Epoch 1/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.116`Trainer.fit` stopped: `max_epochs=2` reached. ================================================================================ Bug Reproduction Complete! ================================================================================ Analysis: 1. Each process logs different metrics based on the data it receives. 2. Lightning tries to synchronize these logs across processes. 3. Because the metric keys differ, the synchronization gets confused. 4. Result: Some metrics show wrong values or appear in wrong steps. Check the logs above to see the issue: - Look for 'val_metric_a', 'val_metric_b', 'val_metric_c' in TensorBoard/logger - Notice they might show incorrect values or appear/disappear unexpectedly Epoch 1/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16/16 0:00:00 • 0:00:00 126.26it/s v_num: 0.000 train_loss: 1.167 val_loss: 1.099 ================================================================================ Bug Reproduction Complete! ================================================================================ Analysis: 1. Each process logs different metrics based on the data it receives. 2. Lightning tries to synchronize these logs across processes. 3. Because the metric keys differ, the synchronization gets confused. 4. Result: Some metrics show wrong values or appear in wrong steps. Check the logs above to see the issue: - Look for 'val_metric_a', 'val_metric_b', 'val_metric_c' in TensorBoard/logger - Notice they might show incorrect values or appear/disappear unexpectedly 

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA TITAN Xp
    - NVIDIA TITAN X (Pascal)
    - available: True
    - version: 11.8
  • Lightning:
    - adabelief-pytorch: 0.2.0
    - lightning: 2.4.0
    - lightning-utilities: 0.15.2
    - pytorch-lightning: 2.6.0
    - pytorch-msssim: 1.0.0
    - torch: 2.4.1+cu118
    - torchaudio: 2.4.1+cu118
    - torchmetrics: 1.8.2
    - torchvision: 0.19.1+cu118
  • Packages:
    - absl-py: 2.3.1
    - adabelief-pytorch: 0.2.0
    - addict: 2.4.0
    - aiohappyeyeballs: 2.6.1
    - aiohttp: 3.13.2
    - aiosignal: 1.4.0
    - aliyun-python-sdk-core: 2.16.0
    - aliyun-python-sdk-kms: 2.16.5
    - attrs: 25.4.0
    - autograd: 1.8.0
    - av: 16.0.1
    - boto3: 1.42.2
    - botocore: 1.42.2
    - brotlipy: 0.7.0
    - cbench: 0.2
    - certifi: 2025.11.12
    - cffi: 2.0.0
    - charset-normalizer: 3.4.4
    - click: 8.3.1
    - colorama: 0.4.6
    - compressai: 1.2.3
    - contourpy: 1.3.3
    - crcmod: 1.7
    - cryptography: 46.0.3
    - cycler: 0.12.1
    - cython: 3.2.2
    - einops: 0.8.1
    - entmax: 1.1
    - filelock: 3.14.0
    - fonttools: 4.61.0
    - frozenlist: 1.8.0
    - fsspec: 2025.9.0
    - grpcio: 1.76.0
    - idna: 3.11
    - imageio: 2.37.2
    - jinja2: 3.1.6
    - jmespath: 0.10.0
    - kiwisolver: 1.4.9
    - lightning: 2.4.0
    - lightning-utilities: 0.15.2
    - litdata: 0.2.58
    - markdown: 3.10
    - markdown-it-py: 4.0.0
    - markupsafe: 2.1.5
    - matplotlib: 3.10.7
    - mdurl: 0.1.2
    - mmcv: 2.2.0
    - mmengine: 0.10.7
    - model-index: 0.1.11
    - mpmath: 1.3.0
    - multidict: 6.7.0
    - networkx: 2.6.3
    - numpy: 2.2.6
    - nvidia-cublas-cu11: 11.11.3.6
    - nvidia-cuda-cupti-cu11: 11.8.87
    - nvidia-cuda-nvrtc-cu11: 11.8.89
    - nvidia-cuda-runtime-cu11: 11.8.89
    - nvidia-cudnn-cu11: 9.1.0.70
    - nvidia-cufft-cu11: 10.9.0.58
    - nvidia-curand-cu11: 10.3.0.86
    - nvidia-cusolver-cu11: 11.4.1.48
    - nvidia-cusparse-cu11: 11.7.5.86
    - nvidia-nccl-cu11: 2.20.5
    - nvidia-nvtx-cu11: 11.8.86
    - obstore: 0.8.2
    - opencv-python: 4.12.0.88
    - opendatalab: 0.0.10
    - openmim: 0.3.9
    - openxlab: 0.1.3
    - ordered-set: 4.1.0
    - oss2: 2.17.0
    - packaging: 24.2
    - pandas: 2.3.3
    - pillow: 11.3.0
    - pip: 25.3
    - platformdirs: 4.5.0
    - propcache: 0.4.1
    - protobuf: 6.33.1
    - ptflops: 0.7.5
    - pybind11: 3.0.1
    - pybind11-stubgen: 0.16.2
    - pycocotools: 2.0.10
    - pycparser: 2.23
    - pycryptodome: 3.23.0
    - pygments: 2.19.2
    - pyparsing: 3.2.5
    - pyrclone-wrapper: 0.0.3
    - python-dateutil: 2.9.0.post0
    - pytorch-lightning: 2.6.0
    - pytorch-msssim: 1.0.0
    - pytz: 2023.4
    - pyyaml: 6.0.3
    - requests: 2.28.2
    - rich: 13.4.2
    - s3transfer: 0.16.0
    - scipy: 1.16.3
    - setuptools: 60.2.0
    - six: 1.17.0
    - sympy: 1.14.0
    - tabulate: 0.9.0
    - tensorboard: 2.20.0
    - tensorboard-data-server: 0.7.2
    - termcolor: 3.2.0
    - thop: 0.1.1.post2209072238
    - tifffile: 2025.10.16
    - torch: 2.4.1+cu118
    - torchaudio: 2.4.1+cu118
    - torchmetrics: 1.8.2
    - torchvision: 0.19.1+cu118
    - tqdm: 4.65.2
    - triton: 3.0.0
    - typing-extensions: 4.15.0
    - tzdata: 2025.2
    - urllib3: 1.26.20
    - werkzeug: 3.1.4
    - wheel: 0.45.1
    - yapf: 0.43.0
    - yarl: 1.22.0
    - zstandard: 0.25.0
    - zstd: 1.5.7.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.11.14
    - release: 5.15.0-139-generic
    - version: #149~20.04.1-Ubuntu SMP Wed Apr 16 08:29:56 UTC 2025

More info

No response

cc @ethanwharris @justusschock @lantiga
