Commit 64221d0 (1 parent: 93a5dd1)

Fixes loading with weights_only for PersistentDataset by force converting to tensors before saving.

Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>

File tree

1 file changed (+7, −7 lines)

monai/data/dataset.py

Lines changed: 7 additions & 7 deletions
@@ -207,9 +207,9 @@ class PersistentDataset(Dataset):
     not guaranteed, so caution should be used when modifying transforms to avoid unexpected
     errors. If in doubt, it is advisable to clear the cache directory.
-    Loading is done using `torch.load` with `weights_only=False`, thus the user must ensure the data
-    being loaded is safe. Typically this will be cached data the user created themselves, if data
-    from external sources is used this should be validated for safely independently.
+    Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will
+    be converted to tensors, however any other object type returned by transforms will not be loadable since
+    `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects.

     Lazy Resampling:
         If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
@@ -376,7 +376,7 @@ def _cachecheck(self, item_transformed):
         if hashfile is not None and hashfile.is_file():  # cache hit
             try:
                 # Loading with weights_only=False is expected to be safe as these should be the user's own cached data
-                return torch.load(hashfile, weights_only=False)
+                return torch.load(hashfile, weights_only=True)
             except PermissionError as e:
                 if sys.platform != "win32":
                     raise e
@@ -397,7 +397,7 @@ def _cachecheck(self, item_transformed):
             with tempfile.TemporaryDirectory() as tmpdirname:
                 temp_hash_file = Path(tmpdirname) / hashfile.name
                 torch.save(
-                    obj=_item_transformed,
+                    obj=convert_to_tensor(_item_transformed),
                     f=temp_hash_file,
                     pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                     pickle_protocol=self.pickle_protocol,
@@ -1655,7 +1655,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
             meta_hash_file = self.cache_dir / meta_hash_file_name
             temp_hash_file = Path(tmpdirname) / meta_hash_file_name
             torch.save(
-                obj=self._meta_cache[meta_hash_file_name],
+                obj=convert_to_tensor(self._meta_cache[meta_hash_file_name]),
                 f=temp_hash_file,
                 pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                 pickle_protocol=self.pickle_protocol,
@@ -1675,4 +1675,4 @@ def _load_meta_cache(self, meta_hash_file_name):
         if meta_hash_file_name in self._meta_cache:
            return self._meta_cache[meta_hash_file_name]
        else:
-            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
+            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=True)
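The idea behind the change can be illustrated with a minimal standalone sketch: `torch.load(..., weights_only=True)` only reconstructs tensors, primitives, and containers of these, so anything else (e.g. numpy arrays) must be converted before `torch.save`. The `to_tensor_tree` helper below is a simplified, hypothetical stand-in for MONAI's `convert_to_tensor`, not the library's actual implementation.

```python
import tempfile
from pathlib import Path

import numpy as np
import torch


def to_tensor_tree(obj):
    """Recursively convert numpy arrays to tensors; a simplified stand-in
    for the role convert_to_tensor plays in this commit."""
    if isinstance(obj, np.ndarray):
        return torch.as_tensor(obj)
    if isinstance(obj, dict):
        return {k: to_tensor_tree(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_tensor_tree(v) for v in obj)
    return obj  # primitives pass through unchanged


with tempfile.TemporaryDirectory() as tmpdirname:
    cache_file = Path(tmpdirname) / "item.pt"
    item = {"image": np.zeros((2, 2), dtype=np.float32), "label": 1}
    # Converting before saving lets the cache be read back with the
    # safe weights_only=True loading mode:
    torch.save(to_tensor_tree(item), cache_file)
    loaded = torch.load(cache_file, weights_only=True)
    print(type(loaded["image"]), loaded["label"])
```

Saving the raw numpy-bearing dictionary instead would make the subsequent `weights_only=True` load fail, which is why the conversion happens at save time rather than load time.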

0 commit comments