@@ -30,7 +30,7 @@ class Dummy:
3030
3131 class MyPickle (pickle .Unpickler ):
3232 def find_class (self , module , name ):
33- #making the following available will expose a vulnerability from 2011:
33+ #making all of the following available will expose a vulnerability from 2011, unclear if patched
3434 #globals, getattr, dict, apply
3535
3636 #print(module, name)
@@ -42,16 +42,26 @@ def find_class(self, module, name):
4242 return np .int64
4343 if name == 'HalfStorage' :
4444 return np .float16
45+ if module == 'numpy.core.multiarray' and name == 'scalar' :
46+ return np .core .multiarray .scalar
47+ if module == 'numpy' and name == 'dtype' :
48+ return np .dtype
4549 if module == "torch._utils" :
4650 if name == "_rebuild_tensor_v2" :
4751 return HackTensor
4852 elif name == "_rebuild_parameter" :
4953 return HackParameter
5054 if module == "collections" and name == "OrderedDict" :
5155 return OrderedDict
56+ if module == '_codecs' and name == 'encode' :
57+ from _codecs import encode
58+ return encode
59+ if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint' :
60+ return Dummy
61+ if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint' :
62+ return Dummy
5263 else :
53- #return Dummy
54- raise pickle .UnpicklingError ("'%s.%s' is forbidden" % (module , name ))
64+ raise pickle .UnpicklingError ("'%s.%s' is forbidden" % (module , name ))
5565
5666 def persistent_load (self , pid ):
5767 return pid
0 commit comments