Skip to content

Commit 41bb153

Browse files
mikaylagawareckipytorchmergebot
authored andcommitted
Fix get_unsafe_globals_in_checkpoint to account for user allowed globals per docstring (pytorch#140738)
bugfix: this function did not account for the user allowed globals :( Differential Revision: [D65960696](https://our.internmc.facebook.com/intern/diff/D65960696) Pull Request resolved: pytorch#140738 Approved by: https://github.com/malfet
1 parent fc813df commit 41bb153

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

test/test_serialization.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4422,6 +4422,10 @@ def test_get_unsafe_globals_in_checkpoint(self):
44224422
unsafe_globals = torch.serialization.get_unsafe_globals_in_checkpoint(f)
44234423
self.assertEqual(set(unsafe_globals), expected_unsafe_global_strs)
44244424
f.seek(0)
4425+
with torch.serialization.safe_globals([TwoTensor]):
4426+
unsafe_globals = torch.serialization.get_unsafe_globals_in_checkpoint(f)
4427+
self.assertEqual(set(unsafe_globals), set())
4428+
f.seek(0)
44254429
try:
44264430
old_get_allowed_globals = torch._weights_only_unpickler._get_allowed_globals
44274431
torch._weights_only_unpickler._get_allowed_globals = lambda: dict() # noqa: PIE807

torch/serialization.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,13 @@ def get_unsafe_globals_in_checkpoint(f: FILE_LIKE) -> List[str]:
342342
Returns:
343343
A list of strings of pickle GLOBALs in the checkpoint that are not allowlisted for ``weights_only``.
344344
"""
345-
safe_global_strings = set(_weights_only_unpickler._get_allowed_globals().keys())
345+
default_safe_globals_strings = set(
346+
_weights_only_unpickler._get_allowed_globals().keys()
347+
)
348+
user_safe_global_strings = set(
349+
_weights_only_unpickler._get_user_allowed_globals().keys()
350+
)
351+
safe_global_strings = default_safe_globals_strings.union(user_safe_global_strings)
346352

347353
with _open_file_like(f, "rb") as opened_file:
348354
if not _is_zipfile(opened_file):

0 commit comments

Comments
 (0)