File tree Expand file tree Collapse file tree 2 files changed +11
-1
lines changed Expand file tree Collapse file tree 2 files changed +11
-1
lines changed Original file line number Diff line number Diff line change @@ -4422,6 +4422,10 @@ def test_get_unsafe_globals_in_checkpoint(self):
44224422 unsafe_globals = torch .serialization .get_unsafe_globals_in_checkpoint (f )
44234423 self .assertEqual (set (unsafe_globals ), expected_unsafe_global_strs )
44244424 f .seek (0 )
4425+ with torch .serialization .safe_globals ([TwoTensor ]):
4426+ unsafe_globals = torch .serialization .get_unsafe_globals_in_checkpoint (f )
4427+ self .assertEqual (set (unsafe_globals ), set ())
4428+ f .seek (0 )
44254429 try :
44264430 old_get_allowed_globals = torch ._weights_only_unpickler ._get_allowed_globals
44274431 torch ._weights_only_unpickler ._get_allowed_globals = lambda : dict () # noqa: PIE807
Original file line number Diff line number Diff line change @@ -342,7 +342,13 @@ def get_unsafe_globals_in_checkpoint(f: FILE_LIKE) -> List[str]:
342342 Returns:
343343 A list of strings of pickle GLOBALs in the checkpoint that are not allowlisted for ``weights_only``.
344344 """
345- safe_global_strings = set (_weights_only_unpickler ._get_allowed_globals ().keys ())
345+ default_safe_globals_strings = set (
346+ _weights_only_unpickler ._get_allowed_globals ().keys ()
347+ )
348+ user_safe_global_strings = set (
349+ _weights_only_unpickler ._get_user_allowed_globals ().keys ()
350+ )
351+ safe_global_strings = default_safe_globals_strings .union (user_safe_global_strings )
346352
347353 with _open_file_like (f , "rb" ) as opened_file :
348354 if not _is_zipfile (opened_file ):
You can’t perform that action at this time.
0 commit comments