Skip to content

Commit d518490

Browse files
mikaylagawarecki authored and pytorchmergebot committed
Make torch.serialization.skip_data work with torch.load (pytorch#148018)
Pull Request resolved: pytorch#148018 Approved by: https://github.com/albanD ghstack dependencies: pytorch#147786, pytorch#147787, pytorch#147788
1 parent be0ceee commit d518490

File tree

2 files changed

+33
-18
lines changed

2 files changed

+33
-18
lines changed

test/test_serialization.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,27 @@ def test_debug_set_in_ci(self):
908908
# This test is to make sure that the serialization debug flag is set in CI
909909
self.assertTrue(os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1")
910910

911+
def test_skip_data_load(self):
912+
t_device = "cuda" if torch.cuda.is_available() else "cpu"
913+
t_v2 = torch.randn(2, 3, device=t_device)
914+
tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device))
915+
916+
sd = {'t_v2': t_v2, 'tt': tt}
917+
sd_zeroed = {
918+
't_v2': torch.zeros(2, 3, device=t_device),
919+
'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)),
920+
}
921+
922+
with BytesIOContext() as f:
923+
torch.save(sd, f)
924+
f.seek(0)
925+
with safe_globals([TwoTensor]), skip_data():
926+
sd_loaded = torch.load(f)
927+
self.assertNotEqual(sd_loaded, sd)
928+
for k in sd_loaded.keys():
929+
sd_loaded[k] = sd_loaded[k].zero_()
930+
self.assertEqual(sd_loaded, sd_zeroed)
931+
911932

912933
class serialization_method:
913934
def __init__(self, use_zip):
@@ -4463,12 +4484,6 @@ def _save_load(t):
44634484
with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"):
44644485
_save_load(t)
44654486

4466-
with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"):
4467-
with skip_data(), BytesIOContext() as f:
4468-
torch.save(torch.randn(2, 3), f)
4469-
f.seek(0)
4470-
torch.load(f, weights_only=True)
4471-
44724487
@parametrize("force_weights_only", (True, False))
44734488
def test_weights_only_env_variables(self, force_weights_only):
44744489
env_var = "TORCH_FORCE_WEIGHTS_ONLY_LOAD" if force_weights_only else "TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"

torch/serialization.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -383,16 +383,18 @@ def get_unsafe_globals_in_checkpoint(f: FileLike) -> list[str]:
383383

384384
class skip_data:
385385
"""
386-
Context-manager that skips writing storage bytes for ``torch.save`` calls.
386+
Context-manager that skips writing/reading storage bytes for ``torch.save`` / ``torch.load`` calls.
387387
388-
Storages will still be saved, but the space that their bytes would usually be written to
388+
For the save path, storages will still be saved, but the space that their bytes would usually be written to
389389
will be empty space. The storage bytes can then be populated in a separate pass.
390390
391+
For the load path, tensors will be loaded per the checkpoint but their storages will not be populated with data.
392+
391393
.. warning::
392394
The ``skip_data`` context manager is an early prototype and is subject to change.
393395
394396
Args:
395-
materialize_fake_tensors: Whether to materialize FakeTensors.
397+
materialize_fake_tensors: Whether to materialize FakeTensors during save. This is a no-op for the load path.
396398
397399
Example:
398400
>>> # xdoctest: +SKIP("NamedTemporaryFile on Windows")
@@ -1418,14 +1420,6 @@ def _get_wo_message(message: str) -> str:
14181420
updated_message += message
14191421
return updated_message + DOCS_MESSAGE
14201422

1421-
global _serialization_tls
1422-
skip_data = _serialization_tls.skip_data
1423-
if skip_data:
1424-
raise RuntimeError(
1425-
"`torch.load` called within a torch.serialization.skip_data context manager "
1426-
"is not supported yet. Please call torch.load outside the skip_data context manager."
1427-
)
1428-
14291423
weights_only_not_set = weights_only is None
14301424

14311425
if weights_only_not_set:
@@ -1735,6 +1729,9 @@ def persistent_load(saved_id):
17351729
if root_key not in deserialized_objects:
17361730
if torch._guards.active_fake_mode() is not None:
17371731
obj = cast(Storage, torch.UntypedStorage(nbytes, device="meta"))
1732+
elif _serialization_tls.skip_data:
1733+
obj = cast(Storage, torch.UntypedStorage(nbytes))
1734+
obj = restore_location(obj, location)
17381735
else:
17391736
obj = cast(Storage, torch.UntypedStorage(nbytes))
17401737
obj._torch_load_uninitialized = True
@@ -1807,7 +1804,7 @@ def persistent_load(saved_id):
18071804

18081805
deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
18091806

1810-
if torch._guards.active_fake_mode() is None:
1807+
if torch._guards.active_fake_mode() is None and not _serialization_tls.skip_data:
18111808
offset = f.tell() if f_should_read_directly else None
18121809
for key in deserialized_storage_keys:
18131810
assert key in deserialized_objects
@@ -1999,6 +1996,9 @@ def load_tensor(dtype, numel, key, location):
19991996
nbytes = numel * torch._utils._element_size(dtype)
20001997
storage = torch.UntypedStorage(nbytes, device="meta")
20011998
storage._checkpoint_offset = zip_file.get_record_offset(name)
1999+
elif _serialization_tls.skip_data:
2000+
nbytes = numel * torch._utils._element_size(dtype)
2001+
storage = torch.UntypedStorage(nbytes)
20022002
elif overall_storage is not None:
20032003
if can_calculate_storage_offsets and calculate_storage_offsets:
20042004
storage_offset = _get_offset(key, name, numel)

0 commit comments

Comments (0)