Bug description
An exception may occur when loading a model checkpoint (with `LightningModule.load_from_checkpoint()`) that was corrupted while being saved with PyTorch Lightning (`Trainer.save_checkpoint()`).
Loading the checkpoint file directly, e.g. with `torch.load()`, raises the same exception. Comparing the file with a valid checkpoint (e.g. with `diff`) confirms that the two differ. This indicates that the checkpoint file is corrupted during saving.
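For reference, a minimal sketch of this direct-load check (the file names are placeholders; `known_good.ckpt` stands for a hypothetical uncorrupted checkpoint of the same model):

```python
import torch

ckpt_path = "tmp_test.ckpt"    # checkpoint written by Trainer.save_checkpoint()
good_path = "known_good.ckpt"  # hypothetical known-good checkpoint for comparison

# Loading the file directly reproduces the same
# "PytorchStreamReader failed reading file data/<n>" RuntimeError on a corrupted archive.
try:
    torch.load(ckpt_path, map_location="cpu", weights_only=False)  # trusted local file
    print("checkpoint loads fine")
except RuntimeError as e:
    print(f"checkpoint is corrupted: {e}")

# A byte-level comparison (equivalent to `diff`) confirms the corrupted file
# differs from a valid checkpoint.
with open(ckpt_path, "rb") as a, open(good_path, "rb") as b:
    print("files identical" if a.read() == b.read() else "files differ")
```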
The exception appears with both the CPU and CUDA builds of the PyTorch libraries. Tested with PyTorch 2.9.0+cu130 / torchvision 0.24.0+cu130 (for CUDA 13.0) and their CPU equivalents, together with lightning 2.6.0 (the bug was initially spotted with lightning 2.5.5).
A minimal self-contained script to reproduce the issue is attached below. Continuously saving and loading a checkpoint after trainer.fit() will sometimes raise an exception and other times succeed. Script output might look like this:
```
Starting loop...
10 ok, 0 fail
20 ok, 0 fail
30 ok, 0 fail
40 ok, 0 fail
48 ok, 1 fail: PytorchStreamReader failed reading file data/1017: invalid header or archive is corrupted
48 ok, 2 fail: PytorchStreamReader failed reading file data/569: invalid header or archive is corrupted
50 ok, 2 fail
60 ok, 2 fail
```

The failure also occurs with `trainer.save_checkpoint(weights_only=True)`.
On an older version of PyTorch (torch 2.3.1+cu121, torchvision 0.18.1+cu121) the issue is even more prevalent, with an approximately 1-in-6 failure rate.
On the older PyTorch version, sleeping between saving and loading increases the likelihood of failure; with sleep(10) every save is corrupted. On the newer PyTorch version, no such effect can be observed.
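A minimal sketch of the timing probe described above; `save_sleep_load` is a hypothetical helper that would be called with the fitted `trainer` and the `MyImageClassifier` class from the reproduction script attached below:

```python
from time import sleep

import lightning as L


def save_sleep_load(trainer: L.Trainer, model_cls: type[L.LightningModule], delay: float = 10.0) -> None:
    """Save a checkpoint, wait `delay` seconds, then load it back."""
    trainer.save_checkpoint("tmp_test.ckpt", weights_only=False)
    sleep(delay)  # on torch 2.3.1+cu121 this delay made every subsequent load fail here
    model_cls.load_from_checkpoint(checkpoint_path="tmp_test.ckpt")
```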
The test script runs in a Docker container on Windows/WSL. Failures occur regardless of whether the checkpoint is stored on a bind mount or a volume mount.
The issue cannot be reproduced on another machine running the same Docker image on Linux.
The issue also cannot be reproduced in Lightning Studio: https://lightning.ai/stf9790/vision-model/studios/prime-bronze-7x9h/code?source=copylink
What version are you seeing the problem on?
v2.6
Reproduced in studio
Not reproducible in studio.
How to reproduce the bug
```python
from time import sleep
import traceback

import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import v2

import lightning as L


class MyImageClassifier(L.LightningModule):
    def __init__(self, num_classes: int):
        super().__init__()
        self.loss_function = nn.CrossEntropyLoss()
        self.model = self.instantiate_model(num_classes)
        self.save_hyperparameters()

    def instantiate_model(self, num_classes: int) -> nn.Module:
        model = torchvision.models.maxvit_t(weights='DEFAULT')
        block_channels = [64, 128, 256, 512]
        model.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.LayerNorm(block_channels[-1]),
            nn.Linear(block_channels[-1], block_channels[-1]),
            nn.Tanh(),
            nn.Linear(block_channels[-1], num_classes, bias=False),
        )
        return model

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=2, gamma=0.75)
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
        }

    def forward(self, inputs):
        return self.model(inputs)

    def training_step(self, batch, batch_idx, dataloader_idx=0):
        inputs, targets = batch
        logits = self(inputs)
        loss = self.loss_function(logits, targets)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        inputs, targets = batch
        logits = self(inputs)
        loss = self.loss_function(logits, targets)
        return loss


def worker():
    transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ])
    dataset = torchvision.datasets.fakedata.FakeData(size=100, num_classes=3, transform=transform)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=16)
    val_dataloader = torch.utils.data.DataLoader(dataset, batch_size=16)

    model = MyImageClassifier(num_classes=3)
    trainer = L.Trainer(
        max_epochs=1,
        # accelerator='cpu',
    )
    trainer.fit(
        model=model,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )

    ok = 0
    fail = 0
    print("Starting loop...")
    while True:
        trainer.save_checkpoint('tmp_test.ckpt', weights_only=False)
        # sleep(10)
        try:
            newnet = MyImageClassifier.load_from_checkpoint(checkpoint_path='tmp_test.ckpt')
            ok += 1
        except Exception as e:
            fail += 1
            print(f"{ok} ok, {fail} fail: {str(e)}")
            # print(traceback.format_exc())
        if ok % 10 == 0:
            print(f"{ok} ok, {fail} fail")


if __name__ == "__main__":
    worker()
```

Error messages and logs
Output might look similar to the following. Exceptions occur randomly, and the `data/<number>` entry may vary between exceptions.
```
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
/root/.cache/pypoetry/virtualenvs/test-s5OTQf_9-py3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name          | Type             | Params | Mode  | FLOPs
-------------------------------------------------------------------
0 | loss_function | CrossEntropyLoss | 0      | train | 0
1 | model         | MaxVit           | 30.4 M | train | 0
-------------------------------------------------------------------
30.4 M    Trainable params
0         Non-trainable params
30.4 M    Total params
121.637   Total estimated model params size (MB)
671       Modules in train mode
0         Modules in eval mode
0         Total Flops

Sanity Checking: |          | 0/? [00:00<?, ?it/s]
/root/.cache/pypoetry/virtualenvs/test-s5OTQf_9-py3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
/root/.cache/pypoetry/virtualenvs/test-s5OTQf_9-py3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
/root/.cache/pypoetry/virtualenvs/test-s5OTQf_9-py3.12/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:317: The number of training batches (7) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Epoch 0: 100%|█████████████████████████████████████████████████████| 7/7 [00:18<00:00, 0.39it/s, v_num=2]
`Trainer.fit` stopped: `max_epochs=1` reached.
Epoch 0: 100%|█████████████████████████████████████████████████████| 7/7 [00:19<00:00, 0.37it/s, v_num=2]
Starting loop...
10 ok, 0 fail
12 ok, 1 fail: PytorchStreamReader failed reading file data/993: invalid header or archive is corrupted
12 ok, 2 fail: PytorchStreamReader failed reading file data/568: invalid header or archive is corrupted
19 ok, 3 fail: PytorchStreamReader failed reading file data/974: invalid header or archive is corrupted
19 ok, 4 fail: PytorchStreamReader failed reading file data/568: invalid header or archive is corrupted
20 ok, 4 fail
27 ok, 5 fail: PytorchStreamReader failed reading file data/568: invalid header or archive is corrupted
30 ok, 5 fail
40 ok, 5 fail
50 ok, 5 fail
60 ok, 5 fail
68 ok, 6 fail: PytorchStreamReader failed reading zip archive: failed finding central directory
70 ok, 6 fail
76 ok, 7 fail: PytorchStreamReader failed reading file data/517: invalid header or archive is corrupted
80 ok, 7 fail
84 ok, 8 fail: PytorchStreamReader failed reading file data/291: invalid header or archive is corrupted
90 ok, 8 fail
100 ok, 8 fail
110 ok, 8 fail
```

Traceback looks like this:
```
Traceback (most recent call last):
  File "/data/test-cpu/test.py", line 92, in worker
    newnet = MyImageClassifier.load_from_checkpoint(checkpoint_path='tmp_test.ckpt')
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/pypoetry/virtualenvs/test-s5OTQf_9-py3.12/lib/python3.12/site-packages/lightning/pytorch/utilities/model_helpers.py", line 130, in wrapper
    return self.method(cls_type, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/pypoetry/virtualenvs/test-s5OTQf_9-py3.12/lib/python3.12/site-packages/lightning/pytorch/core/module.py", line 1781, in load_from_checkpoint
    loaded = _load_from_checkpoint(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/pypoetry/virtualenvs/test-s5OTQf_9-py3.12/lib/python3.12/site-packages/lightning/pytorch/core/saving.py", line 65, in _load_from_checkpoint
    checkpoint = pl_load(checkpoint_path, map_location=map_location, weights_only=weights_only)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/pypoetry/virtualenvs/test-s5OTQf_9-py3.12/lib/python3.12/site-packages/lightning/fabric/utilities/cloud_io.py", line 73, in _load
    return torch.load(
           ^^^^^^^^^^^
  File "/root/.cache/pypoetry/virtualenvs/test-s5OTQf_9-py3.12/lib/python3.12/site-packages/torch/serialization.py", line 1521, in load
    return _load(
           ^^^^^^
  File "/root/.cache/pypoetry/virtualenvs/test-s5OTQf_9-py3.12/lib/python3.12/site-packages/torch/serialization.py", line 2122, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
  File "/root/.cache/pypoetry/virtualenvs/test-s5OTQf_9-py3.12/lib/python3.12/site-packages/torch/_weights_only_unpickler.py", line 535, in load
    self.append(self.persistent_load(pid))
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/pypoetry/virtualenvs/test-s5OTQf_9-py3.12/lib/python3.12/site-packages/torch/serialization.py", line 2086, in persistent_load
    typed_storage = load_tensor(
                    ^^^^^^^^^^^^
  File "/root/.cache/pypoetry/virtualenvs/test-s5OTQf_9-py3.12/lib/python3.12/site-packages/torch/serialization.py", line 2039, in load_tensor
    zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)
RuntimeError: PytorchStreamReader failed reading file data/472: invalid header or archive is corrupted
```

Environment
Current environment
- CUDA:
- GPU: None
- available: False
- version: None
- Lightning:
- lightning: 2.6.0
- lightning-utilities: 0.15.2
- pytorch-lightning: 2.6.0
- torch: 2.9.0+cpu
- torchmetrics: 1.8.2
- torchvision: 0.24.0+cpu
- Packages:
- aiohappyeyeballs: 2.6.1
- aiohttp: 3.13.2
- aiosignal: 1.4.0
- attrs: 25.4.0
- autocommand: 2.2.2
- backports.tarfile: 1.2.0
- filelock: 3.20.1
- frozenlist: 1.8.0
- fsspec: 2025.12.0
- idna: 3.11
- importlib-metadata: 8.0.0
- inflect: 7.3.1
- jaraco.collections: 5.1.0
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.1
- jaraco.text: 3.12.1
- jinja2: 3.1.6
- lightning: 2.6.0
- lightning-utilities: 0.15.2
- markupsafe: 3.0.3
- more-itertools: 10.3.0
- mpmath: 1.3.0
- multidict: 6.7.0
- networkx: 3.6.1
- numpy: 2.3.5
- packaging: 25.0
- pillow: 12.0.0
- pip: 25.3
- platformdirs: 4.2.2
- propcache: 0.4.1
- pytorch-lightning: 2.6.0
- pyyaml: 6.0.3
- setuptools: 80.9.0
- sympy: 1.14.0
- tomli: 2.0.1
- torch: 2.9.0+cpu
- torchmetrics: 1.8.2
- torchvision: 0.24.0+cpu
- tqdm: 4.67.1
- typeguard: 4.3.0
- typing-extensions: 4.15.0
- wheel: 0.45.1
- yarl: 1.22.0
- zipp: 3.19.2
- System:
- OS: Linux
- architecture:
- 64bit
-
- processor: x86_64
- python: 3.12.3
- release: 6.6.87.2-microsoft-standard-WSL2
- version: #1 SMP PREEMPT_DYNAMIC Thu Jun 5 18:30:46 UTC 2025
More info
No response