Commit 29abe6e

matsumotosan, SkafteNicki, Borda, pre-commit-ci[bot], and tchaton authored
Expose weights_only for loading checkpoints with Trainer, LightningModule, LightningDataModule (#21072)
* change weights_only default to True
* add docs on weights_only arg
* add weights_only arg to checkpoint save. weights_only during test set based on ckpt version
* add weights_only arg to checkpoint_io
* woops, reverting changes
* permissions too
* fix link
* fix another link
* datamodule weights_only args
* wip: try safe_globals context manager for tests
* add weights_only arg to _run_standard_hparams_test
* weights_only=False when adding extra_args
* switch to lightning_utilities.cli requirements set-oldest (#21077)
* bump: try `deepspeed >=0.14.1,<=0.15.0` (#21076)
* try `deepspeed >=0.14.1,<=0.15.0`
* drop from oldest
* pip uninstall -y deepspeed
* error::DeprecationWarning
* weights_only=True default for torch>=2.6
* changelog
* ignore torch.load futurewarning
* add .*
* will this woork
* weights_only according pl version
* set env var TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 for pl < 1.5.0
* weights_only=False for omegaconf hparams test
* default to weights_only=true for loading from state_dict from url
* weights_only=False for hydra
* Update src/lightning/fabric/utilities/cloud_io.py
  Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
* defaults for weights_only in torch.hub.load_state_dict_from_url
* default to weights_only=False for torch.hub.load_state_dict_from_url
* add weights_only to trainer.fit, validate, test, predict
* fix tests
* add weights_only arg
* specify weights_only kwarg
* weights_only for fsdp load
* Apply suggestions from code review
* Apply suggestions from code review
* default is none
* add weights_only args to strategies
* trainer default to weights_only=None
* wip: fix typing dump_checkpoint
* Apply suggestions from code review
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* weights_only as last arg
* asset called with none
* weights_only=False for torch>=2.6 in tests
* fix changelog description
* Empty-Commit
* Empty-Commit
* trigger ci
* skip ddp_fork on macos

---------

Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: Jirka B <j.borovec+github@gmail.com>
Co-authored-by: jirka <jirka.borovec@seznam.cz>
Co-authored-by: Deependu <deependujha21@gmail.com>
1 parent b82db78 commit 29abe6e
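
For illustration, a minimal sketch of the user-facing surface this commit exposes. The model and checkpoint names are hypothetical, and the exact keyword placement is an assumption based on the PR title and the changelog entry below:

import lightning.pytorch as pl

trainer = pl.Trainer()

# Resume training; weights_only is forwarded down to torch.load.
# weights_only=None (the default) defers to torch's own default,
# which is True for torch>=2.6 and False for older versions.
trainer.fit(model, ckpt_path="last.ckpt", weights_only=None)

# Evaluating a checkpoint from an untrusted source: restrict
# unpickling to tensors and other primitive types.
trainer.test(model, ckpt_path="downloaded.ckpt", weights_only=True)

# A trusted checkpoint that pickles richer objects (e.g. an nn.Module
# or custom hparams) may require weights_only=False.
model = MyLightningModule.load_from_checkpoint("trusted.ckpt", weights_only=False)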

29 files changed: +228 -78 lines changed


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -172,6 +172,7 @@ filterwarnings = [
     # "error::DeprecationWarning",
     "error::FutureWarning",
     "ignore::FutureWarning:onnxscript",  # Temporary ignore until onnxscript is updated
+    "ignore:You are using `torch.load` with `weights_only=False`.*:FutureWarning",
 ]
 xfail_strict = true
 junit_duration_report = "call"
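
The added filter silences the FutureWarning that torch.load emits for weights_only=False loads while the suite otherwise escalates FutureWarning to errors. For reference, pytest filterwarnings entries use the stdlib warning-filter syntax (action:message:category), so the same filter expressed in Python would be:

import warnings

warnings.filterwarnings(
    "ignore",
    message=r"You are using `torch.load` with `weights_only=False`.*",
    category=FutureWarning,
)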

src/lightning/fabric/CHANGELOG.md

Lines changed: 1 addition & 3 deletions
@@ -19,9 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
-
-
+- Expose `weights_only` argument for `Trainer.{fit,validate,test,predict}` and let `torch` handle the default value ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072))
 - Set `_DeviceDtypeModuleMixin._device` from torch's default device function ([#21164](https://github.com/Lightning-AI/pytorch-lightning/pull/21164))
 
 

src/lightning/fabric/plugins/io/checkpoint_io.py

Lines changed: 8 additions & 1 deletion
@@ -47,13 +47,20 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio
         """
 
     @abstractmethod
-    def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> dict[str, Any]:
+    def load_checkpoint(
+        self, path: _PATH, map_location: Optional[Any] = None, weights_only: Optional[bool] = None
+    ) -> dict[str, Any]:
         """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages.
 
         Args:
             path: Path to checkpoint
             map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
                 locations.
+            weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
+                ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that
+                contains an ``nn.Module``, use ``weights_only=False``. If loading a checkpoint from an untrusted
+                source, we recommend using ``weights_only=True``. For more information, please refer to the
+                `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
 
         Returns: The loaded checkpoint.
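
For plugin authors, a minimal sketch of a custom CheckpointIO honoring the widened interface. The subclass and its file handling are hypothetical; only the load_checkpoint signature comes from the diff above:

import os
from typing import Any, Optional

import torch
from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning.fabric.utilities.types import _PATH


class MyCheckpointIO(CheckpointIO):
    def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
        torch.save(checkpoint, path)

    def load_checkpoint(
        self, path: _PATH, map_location: Optional[Any] = None, weights_only: Optional[bool] = None
    ) -> dict[str, Any]:
        if weights_only is None:
            # None means: defer to torch.load's version-dependent default.
            return torch.load(path, map_location=map_location)
        return torch.load(path, map_location=map_location, weights_only=weights_only)

    def remove_checkpoint(self, path: _PATH) -> None:
        os.remove(path)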

src/lightning/fabric/plugins/io/torch_io.py

Lines changed: 10 additions & 2 deletions
@@ -59,14 +59,22 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio
 
     @override
     def load_checkpoint(
-        self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage
+        self,
+        path: _PATH,
+        map_location: Optional[Callable] = lambda storage, loc: storage,
+        weights_only: Optional[bool] = None,
     ) -> dict[str, Any]:
         """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files.
 
         Args:
             path: Path to checkpoint
             map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
                 locations.
+            weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
+                ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that
+                contains an ``nn.Module``, use ``weights_only=False``. If loading a checkpoint from an untrusted
+                source, we recommend using ``weights_only=True``. For more information, please refer to the
+                `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
 
         Returns: The loaded checkpoint.

@@ -80,7 +88,7 @@ def load_checkpoint(
         if not fs.exists(path):
             raise FileNotFoundError(f"Checkpoint file not found: {path}")
 
-        return pl_load(path, map_location=map_location)
+        return pl_load(path, map_location=map_location, weights_only=weights_only)
 
     @override
     def remove_checkpoint(self, path: _PATH) -> None:
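
A usage sketch for the plugin above (the checkpoint paths are placeholders):

from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO

ckpt_io = TorchCheckpointIO()

# Checkpoint from an untrusted source: permit only tensors and
# other primitive types during unpickling.
state = ckpt_io.load_checkpoint("downloaded.ckpt", weights_only=True)

# Trusted checkpoint that pickles richer objects (e.g. an nn.Module):
state = ckpt_io.load_checkpoint("my_run.ckpt", weights_only=False)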

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 2 additions & 1 deletion
@@ -473,6 +473,7 @@ def load_checkpoint(
         path: _PATH,
         state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
         strict: bool = True,
+        weights_only: Optional[bool] = None,
     ) -> dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects.
 
@@ -498,7 +499,7 @@ def load_checkpoint(
             # This code path enables loading a checkpoint from a non-deepspeed checkpoint or from
             # a consolidated checkpoint
             path = self.broadcast(path)
-            return super().load_checkpoint(path=path, state=state, strict=strict)
+            return super().load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only)
 
         if not state:
             raise ValueError(

src/lightning/fabric/strategies/fsdp.py

Lines changed: 2 additions & 1 deletion
@@ -516,6 +516,7 @@ def load_checkpoint(
         path: _PATH,
         state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
         strict: bool = True,
+        weights_only: Optional[bool] = None,
     ) -> dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects."""
         if not state:

@@ -586,7 +587,7 @@ def load_checkpoint(
             optim.load_state_dict(flattened_osd)
 
         # Load metadata (anything not a module or optimizer)
-        metadata = torch.load(path / _METADATA_FILENAME)
+        metadata = torch.load(path / _METADATA_FILENAME, weights_only=weights_only)
         requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
         _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict)
         for key in requested_metadata_keys:

src/lightning/fabric/strategies/model_parallel.py

Lines changed: 5 additions & 3 deletions
@@ -275,6 +275,7 @@ def load_checkpoint(
         path: _PATH,
         state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
         strict: bool = True,
+        weights_only: Optional[bool] = None,
     ) -> dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects."""
         if not state:

@@ -295,7 +296,7 @@ def load_checkpoint(
                 f"Loading a single optimizer object from a checkpoint is not supported yet with {type(self).__name__}."
             )
 
-        return _load_checkpoint(path=path, state=state, strict=strict)
+        return _load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only)
 
     def _setup_distributed(self) -> None:
         reset_seed()

@@ -411,6 +412,7 @@ def _load_checkpoint(
     state: dict[str, Union[Module, Optimizer, Any]],
     strict: bool = True,
     optimizer_states_from_list: bool = False,
+    weights_only: Optional[bool] = None,
 ) -> dict[str, Any]:
     from torch.distributed.checkpoint.state_dict import (
         StateDictOptions,

@@ -449,7 +451,7 @@ def _load_checkpoint(
         set_optimizer_state_dict(module, optim, optim_state_dict=optim_state[optim_key], options=state_dict_options)
 
     # Load metadata (anything not a module or optimizer)
-    metadata = torch.load(path / _METADATA_FILENAME)
+    metadata = torch.load(path / _METADATA_FILENAME, weights_only=weights_only)
     requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
     _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict)
     for key in requested_metadata_keys:

@@ -461,7 +463,7 @@ def _load_checkpoint(
         return metadata
 
     if _is_full_checkpoint(path):
-        checkpoint = torch.load(path, mmap=True, map_location="cpu", weights_only=False)
+        checkpoint = torch.load(path, mmap=True, map_location="cpu", weights_only=weights_only)
         _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)
 
         state_dict_options = StateDictOptions(

src/lightning/fabric/strategies/strategy.py

Lines changed: 2 additions & 1 deletion
@@ -310,6 +310,7 @@ def load_checkpoint(
         path: _PATH,
         state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
         strict: bool = True,
+        weights_only: Optional[bool] = None,
     ) -> dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects.
 
@@ -330,7 +331,7 @@ def load_checkpoint(
 
         """
         torch.cuda.empty_cache()
-        checkpoint = self.checkpoint_io.load_checkpoint(path)
+        checkpoint = self.checkpoint_io.load_checkpoint(path, weights_only=weights_only)
         if not state:
             return checkpoint
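
The commit's test changes mention a safe_globals context manager. As background, a self-contained sketch of how weights_only=True interacts with non-tensor objects in a checkpoint; it requires torch>=2.5 for torch.serialization.safe_globals, and the RunConfig class is hypothetical:

import torch
from torch.serialization import safe_globals


class RunConfig:  # stands in for any custom object pickled into a checkpoint
    def __init__(self, lr: float) -> None:
        self.lr = lr


torch.save({"state_dict": {"w": torch.zeros(2)}, "config": RunConfig(1e-3)}, "demo.ckpt")

# weights_only=True rejects arbitrary classes during unpickling...
try:
    torch.load("demo.ckpt", weights_only=True)
except Exception as err:
    print("rejected:", type(err).__name__)

# ...unless they are explicitly allowlisted.
with safe_globals([RunConfig]):
    ckpt = torch.load("demo.ckpt", weights_only=True)
print(ckpt["config"].lr)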

src/lightning/fabric/strategies/xla_fsdp.py

Lines changed: 2 additions & 1 deletion
@@ -516,6 +516,7 @@ def load_checkpoint(
         path: _PATH,
         state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
         strict: bool = True,
+        weights_only: Optional[bool] = None,
     ) -> dict[str, Any]:
         """Given a folder, load the contents from a checkpoint and restore the state of the given objects.
 
@@ -608,7 +609,7 @@ def load_checkpoint(
         )
         if "model" not in state or not isinstance(model := state["model"], torch.nn.Module):
             raise NotImplementedError("XLAFSDP only supports a single model instance with 'model' as the key.")
-        full_ckpt = torch.load(path)
+        full_ckpt = torch.load(path, weights_only=weights_only)
         model.load_state_dict(full_ckpt.pop("model"), strict=strict)
         return full_ckpt

src/lightning/fabric/utilities/cloud_io.py

Lines changed: 15 additions & 3 deletions
@@ -17,7 +17,7 @@
 import io
 import logging
 from pathlib import Path
-from typing import IO, Any, Union
+from typing import IO, Any, Optional, Union
 
 import fsspec
 import fsspec.utils

@@ -34,13 +34,18 @@
 def _load(
     path_or_url: Union[IO, _PATH],
     map_location: _MAP_LOCATION_TYPE = None,
-    weights_only: bool = False,
+    weights_only: Optional[bool] = None,
 ) -> Any:
     """Loads a checkpoint.
 
     Args:
         path_or_url: Path or URL of the checkpoint.
         map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations.
+        weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other primitive
+            types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
+            ``weights_only=False``. If loading a checkpoint from an untrusted source, we recommend using
+            ``weights_only=True``. For more information, please refer to the
+            `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
 
     """
     if not isinstance(path_or_url, (str, Path)):

@@ -51,6 +56,13 @@ def _load(
             weights_only=weights_only,
         )
     if str(path_or_url).startswith("http"):
+        if weights_only is None:
+            weights_only = False
+            log.debug(
+                f"Defaulting to `weights_only=False` for remote checkpoint: {path_or_url}."
+                f" If loading a checkpoint from an untrusted source, we recommend using `weights_only=True`."
+            )
+
         return torch.hub.load_state_dict_from_url(
             str(path_or_url),
             map_location=map_location,  # type: ignore[arg-type]

@@ -70,7 +82,7 @@ def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem:
     return fs
 
 
-def _atomic_save(checkpoint: dict[str, Any], filepath: Union[str, Path]) -> None:
+def _atomic_save(checkpoint: dict[str, Any], filepath: _PATH) -> None:
     """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
 
     Args:
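
Putting the pieces together: with weights_only=None, remote (http) checkpoints are now explicitly coerced to weights_only=False (with a debug log), while other inputs defer to torch.load's version-dependent default. A hedged usage sketch of this private helper, which other modules import as pl_load; the URL and paths are placeholders:

from lightning.fabric.utilities.cloud_io import _load as pl_load

# Remote checkpoint: None falls back to weights_only=False before
# torch.hub.load_state_dict_from_url is called.
ckpt = pl_load("https://example.com/model.ckpt")

# Local checkpoint: None is passed through, so torch's own default
# applies (True for torch>=2.6, False for older versions).
ckpt = pl_load("checkpoints/epoch=3.ckpt")

# An explicit value always wins, remote or local.
ckpt = pl_load("https://example.com/model.ckpt", weights_only=True)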
