Skip to content

Commit f3f305e

Browse files
mikaylagawarecki authored and pytorchmergebot committed
Fix condition for weights_only unpickler for DTensor (pytorch#140740)
Same as pytorch#140739 but for DTensor (move safe globals for DTensor to `torch.distributed.tensor.__init__` and update error message to let user know `torch.distributed.tensor` must be imported to load DTensor) Differential Revision: [D65961690](https://our.internmc.facebook.com/intern/diff/D65961690) Pull Request resolved: pytorch#140740 Approved by: https://github.com/malfet ghstack dependencies: pytorch#140739
1 parent b63a848 commit f3f305e

File tree

3 files changed

+66
-17
lines changed

3 files changed

+66
-17
lines changed

test/distributed/_tensor/test_dtensor.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
# Owner(s): ["oncall: distributed"]
33

44
import os
5+
import pathlib
6+
import tempfile
7+
import unittest
58

69
from numpy.testing import assert_array_equal
710

@@ -28,7 +31,7 @@
2831
parallelize_module,
2932
RowwiseParallel,
3033
)
31-
from torch.testing._internal.common_utils import run_tests
34+
from torch.testing._internal.common_utils import IS_FBCODE, run_tests
3235
from torch.testing._internal.distributed._tensor.common_dtensor import (
3336
DTensorTestBase,
3437
with_comms,
@@ -542,6 +545,33 @@ def test_dtensor_save_load(self):
542545
reloaded_st = torch.load(buffer, weights_only=True)
543546
self.assertEqual(sharded_tensor, reloaded_st)
544547

548+
@with_comms
549+
@unittest.skipIf(
550+
IS_FBCODE,
551+
"subprocess import torch fails with ModuleNotFoundError: No module named 'torch' in fbcode",
552+
)
553+
def test_dtensor_save_load_import(self):
554+
for should_import in [True, False]:
555+
device_mesh = self.build_device_mesh()
556+
placements = [Shard(0)]
557+
local_tensor = torch.randn(3, 3)
558+
sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
559+
with tempfile.NamedTemporaryFile() as f:
560+
torch.save(sharded_tensor, f)
561+
import_string = (
562+
"import torch.distributed.tensor;" if should_import else ""
563+
)
564+
filename = pathlib.Path(f.name)
565+
err_msg = (
566+
(
567+
"_pickle.UnpicklingError: Weights only load failed. "
568+
"``torch.distributed.tensor`` must be imported to load DTensors"
569+
)
570+
if not should_import
571+
else None
572+
)
573+
self._attempt_load_from_subprocess(filename, import_string, err_msg)
574+
545575

546576
class DTensorMeshTest(DTensorTestBase):
547577
@property
@@ -943,9 +973,11 @@ def test_split_tensor_1D(self) -> None:
943973
from torch.distributed.tensor._collective_utils import unpad_tensor
944974

945975
unpadded_list = [
946-
unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
947-
if pad_sizes[i] > 0
948-
else tensor
976+
(
977+
unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
978+
if pad_sizes[i] > 0
979+
else tensor
980+
)
949981
for i, tensor in enumerate(splitted_tensor_list)
950982
]
951983
expected_is_tensor_empty = [

torch/_weights_only_unpickler.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -169,19 +169,6 @@ def _get_allowed_globals():
169169
"builtins.bytearray": bytearray, # for bytearray
170170
"builtins.set": set, # for set
171171
}
172-
# Only add the dtensor related classes if the dtensor module is available
173-
if hasattr(torch.distributed, "tensor"):
174-
dtensor_rc: Dict[str, Any] = {
175-
# DTensor related
176-
"torch.distributed.device_mesh.DeviceMesh": torch.distributed.device_mesh.DeviceMesh,
177-
"torch.distributed.tensor._dtensor_spec.DTensorSpec": torch.distributed.tensor._dtensor_spec.DTensorSpec,
178-
"torch.distributed.tensor._dtensor_spec.TensorMeta": torch.distributed.tensor._dtensor_spec.TensorMeta,
179-
"torch.distributed.tensor.DTensor": torch.distributed.tensor.DTensor,
180-
"torch.distributed.tensor.placement_types.Partial": torch.distributed.tensor.placement_types.Partial,
181-
"torch.distributed.tensor.placement_types.Replicate": torch.distributed.tensor.placement_types.Replicate,
182-
"torch.distributed.tensor.placement_types.Shard": torch.distributed.tensor.placement_types.Shard,
183-
}
184-
rc.update(dtensor_rc)
185172

186173
# dtype
187174
for t in torch.storage._dtype_to_storage_type_map().keys():
@@ -341,6 +328,20 @@ def load(self):
341328
raise UnpicklingError(
342329
"``torch.nested`` and ``torch._dynamo`` must be imported to load nested jagged tensors (NJTs)"
343330
)
331+
elif full_path in (
332+
[
333+
"torch.distributed.device_mesh.DeviceMesh",
334+
"torch.distributed.tensor._dtensor_spec.DTensorSpec",
335+
"torch.distributed.tensor._dtensor_spec.TensorMeta",
336+
"torch.distributed.tensor.DTensor",
337+
"torch.distributed.tensor.placement_types.Partial",
338+
"torch.distributed.tensor.placement_types.Replicate",
339+
"torch.distributed.tensor.placement_types.Shard",
340+
]
341+
):
342+
raise UnpicklingError(
343+
"``torch.distributed.tensor`` must be imported to load DTensors"
344+
)
344345
else:
345346
raise UnpicklingError(
346347
f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "

torch/distributed/tensor/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,22 @@
4545
"zeros",
4646
]
4747

48+
# For weights_only torch.load
49+
from ._dtensor_spec import DTensorSpec as _DTensorSpec, TensorMeta as _TensorMeta
50+
51+
52+
torch.serialization.add_safe_globals(
53+
[
54+
DeviceMesh,
55+
_DTensorSpec,
56+
_TensorMeta,
57+
DTensor,
58+
Partial,
59+
Replicate,
60+
Shard,
61+
]
62+
)
63+
4864

4965
# Append DTensor to the list of supported types for foreach implementation for optimizer
5066
# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA.

0 commit comments

Comments (0)