|
2 | 2 | import numpy as np |
3 | 3 | import math |
4 | 4 | import zipfile |
5 | | -from collections import OrderedDict |
6 | 5 |
|
7 | 6 | def prod(x): |
8 | 7 | return math.prod(x) |
@@ -30,38 +29,23 @@ class Dummy: |
30 | 29 |
|
31 | 30 | class MyPickle(pickle.Unpickler): |
32 | 31 | def find_class(self, module, name): |
33 | | - #making all of the following available will expose a vulnerability from 2011, unclear if patched |
34 | | - #globals, getattr, dict, apply |
35 | | - |
36 | 32 | #print(module, name) |
37 | 33 | if name == 'FloatStorage': |
38 | 34 | return np.float32 |
39 | | - if name == 'IntStorage': |
40 | | - return np.int32 |
41 | 35 | if name == 'LongStorage': |
42 | 36 | return np.int64 |
43 | 37 | if name == 'HalfStorage': |
44 | 38 | return np.float16 |
45 | | - if module == 'numpy.core.multiarray' and name == 'scalar': |
46 | | - return np.core.multiarray.scalar |
47 | | - if module == 'numpy' and name == 'dtype': |
48 | | - return np.dtype |
49 | 39 | if module == "torch._utils": |
50 | 40 | if name == "_rebuild_tensor_v2": |
51 | 41 | return HackTensor |
52 | 42 | elif name == "_rebuild_parameter": |
53 | 43 | return HackParameter |
54 | | - if module == "collections" and name == "OrderedDict": |
55 | | - return OrderedDict |
56 | | - if module == '_codecs' and name == 'encode': |
57 | | - from _codecs import encode |
58 | | - return encode |
59 | | - if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': |
60 | | - return Dummy |
61 | | - if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': |
62 | | - return Dummy |
63 | 44 | else: |
64 | | - raise pickle.UnpicklingError("'%s.%s' is forbidden" % (module, name)) |
| 45 | + try: |
| 46 | + return pickle.Unpickler.find_class(self, module, name) |
| 47 | + except Exception: |
| 48 | + return Dummy |
65 | 49 |
|
66 | 50 | def persistent_load(self, pid): |
67 | 51 | return pid |
|
0 commit comments