@@ -383,16 +383,18 @@ def get_unsafe_globals_in_checkpoint(f: FileLike) -> list[str]:
383383
384384class skip_data :
385385 """
386- Context-manager that skips writing storage bytes for ``torch.save`` calls.
386+    Context-manager that skips writing/reading storage bytes for ``torch.save`` / ``torch.load`` calls.
387387
388- Storages will still be saved, but the space that their bytes would usually be written to
388+ For the save path, storages will still be saved, but the space that their bytes would usually be written to
389389 will be empty space. The storage bytes can then be populated in a separate pass.
390390
391+ For the load path, tensors will be loaded per the checkpoint but their storages will not be populated with data.
392+
391393 .. warning::
392394 The ``skip_data`` context manager is an early prototype and is subject to change.
393395
394396 Args:
395- materialize_fake_tensors: Whether to materialize FakeTensors.
397+        materialize_fake_tensors: Whether to materialize FakeTensors during save. This is a no-op for the load path.
396398
397399 Example:
398400 >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows")
@@ -1418,14 +1420,6 @@ def _get_wo_message(message: str) -> str:
14181420 updated_message += message
14191421 return updated_message + DOCS_MESSAGE
14201422
1421- global _serialization_tls
1422- skip_data = _serialization_tls .skip_data
1423- if skip_data :
1424- raise RuntimeError (
1425- "`torch.load` called within a torch.serialization.skip_data context manager "
1426- "is not supported yet. Please call torch.load outside the skip_data context manager."
1427- )
1428-
14291423 weights_only_not_set = weights_only is None
14301424
14311425 if weights_only_not_set :
@@ -1735,6 +1729,9 @@ def persistent_load(saved_id):
17351729 if root_key not in deserialized_objects :
17361730 if torch ._guards .active_fake_mode () is not None :
17371731 obj = cast (Storage , torch .UntypedStorage (nbytes , device = "meta" ))
1732+ elif _serialization_tls .skip_data :
1733+ obj = cast (Storage , torch .UntypedStorage (nbytes ))
1734+ obj = restore_location (obj , location )
17381735 else :
17391736 obj = cast (Storage , torch .UntypedStorage (nbytes ))
17401737 obj ._torch_load_uninitialized = True
@@ -1807,7 +1804,7 @@ def persistent_load(saved_id):
18071804
18081805 deserialized_storage_keys = pickle_module .load (f , ** pickle_load_args )
18091806
1810- if torch ._guards .active_fake_mode () is None :
1807+ if torch ._guards .active_fake_mode () is None and not _serialization_tls . skip_data :
18111808 offset = f .tell () if f_should_read_directly else None
18121809 for key in deserialized_storage_keys :
18131810 assert key in deserialized_objects
@@ -1999,6 +1996,9 @@ def load_tensor(dtype, numel, key, location):
19991996 nbytes = numel * torch ._utils ._element_size (dtype )
20001997 storage = torch .UntypedStorage (nbytes , device = "meta" )
20011998 storage ._checkpoint_offset = zip_file .get_record_offset (name )
1999+ elif _serialization_tls .skip_data :
2000+ nbytes = numel * torch ._utils ._element_size (dtype )
2001+ storage = torch .UntypedStorage (nbytes )
20022002 elif overall_storage is not None :
20032003 if can_calculate_storage_offsets and calculate_storage_offsets :
20042004 storage_offset = _get_offset (key , name , numel )
0 commit comments