18 changes: 18 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -1100,6 +1100,24 @@ def test_global_mesh(self):

self.assertEqual(id(mesh), id(expected_mesh))

def test_mark_manual_sharding(self):
x = torch.zeros(3, 2).to(xm.xla_device())
with self.assertRaises(RuntimeError):
xt = xs._mark_manual_sharding(x)

xx = x + 1
xt = xs._mark_manual_sharding(xx)

hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt.global_tensor])
self.assertIn(', sharding={manual}', hlo)
self.assertEqual(xt.sharding_type, xs.ShardingType.MANUAL)
self.assertEqual(xt.sharding_spec, "{manual}")

# It looks like XLA doesn't like only having manual sharding in the HLO.
# It needs to be paired with SPMDFullToShardShape/SPMDShardToFullShape.
# For some reason, the exception raised by the following call cannot be caught.
# xt.global_tensor.cpu()


if __name__ == '__main__':
test = unittest.main()
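Outside the test harness, the new API is exercised the same way. A minimal sketch, assuming an XLA device is available and SPMD mode is enabled (the `xr.use_spmd()` call mirrors what the sharding test suite sets up):

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # SPMD mode, as enabled by the sharding test suite

x = torch.zeros(3, 2).to(xm.xla_device())
xx = x + 1  # an op result carries an IR node; raw device data is rejected
xt = xs._mark_manual_sharding(xx)

hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt.global_tensor])
assert ', sharding={manual}' in hlo
assert xt.sharding_type == xs.ShardingType.MANUAL
assert xt.sharding_spec == '{manual}'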
11 changes: 11 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1899,6 +1899,17 @@ void InitXlaModuleBindings(py::module m) {
[](const at::Tensor& input, xla::OpSharding sharding) {
ShardingUtil::XlaMarkSharding(input, sharding);
});
m.def("_mark_manual_sharding", [](const at::Tensor& input,
xla::OpSharding sharding) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
bool is_ir = xtensor->CurrentIrValue();
if (is_ir) {
is_ir = !DeviceData::Cast(xtensor->CurrentIrValue().node.get());
}
XLA_CHECK(is_ir) << "Marking any data tensors as manual is not supported";

ShardingUtil::XlaMarkSharding(input, sharding);
});
m.def("_xla_mark_sharding_dynamo_custom_op",
[](const at::Tensor& input, const py::list& tile_assignment,
const py::list& group_assignment, const py::list& replication_groups,
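The `XLA_CHECK` above is what surfaces as a `RuntimeError` in the new test: only tensors backed by an IR node (pending computation), not plain device data, may be marked manual. A rough illustration of the two paths from the Python side, assuming an XLA device:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs

x = torch.zeros(3, 2).to(xm.xla_device())  # backed by device data, no IR node

try:
  xs._mark_manual_sharding(x)  # trips the XLA_CHECK in the binding
except RuntimeError as e:
  print('rejected:', e)  # "Marking any data tensors as manual is not supported"

xt = xs._mark_manual_sharding(x + 1)  # x + 1 has an IR node, so it is accepted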
12 changes: 5 additions & 7 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -240,8 +240,7 @@ xla::OpSharding ShardingUtil::CreateOpSharding(
xla::OpSharding sharding;
switch (sharding_type) {
case ShardingType::MANUAL: {
TF_LOG(ERROR) << "Invalid arguments: sharding_type (MANUAL) is "
<< "currently not supported";
sharding = xla::HloSharding::Manual().ToProto();
break;
}
case ShardingType::TUPLE: {
@@ -323,7 +322,7 @@ std::vector<int64_t> ShardingUtil::GetShardShape(

return shard_shape;
} else {
TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type();
XLA_CHECK(false) << "Unsupported OpSharding type " << sharding.type();
}
}

@@ -429,7 +428,7 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices(
shard_indices[device_index[core]] = std::make_pair(replica_id, indices);
}
} else {
TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type();
XLA_CHECK(false) << "Unsupported OpSharding type " << sharding.type();
}
return shard_indices;
}
@@ -488,9 +487,8 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
shards[i], c10::IntArrayRef(pads.data(), pads.size()), 0);
}
}
} else if ((sharding.type() == xla::OpSharding::MANUAL) ||
(sharding.type() == xla::OpSharding::TUPLE)) {
TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type();
} else {
XLA_CHECK(false) << "Unsupported OpSharding type " << sharding.type();
}
return shards;
}
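For reference, the tiled (OTHERS) branch that remains supported computes per-device shard shapes by ceil-dividing each dimension by the number of tiles along it; MANUAL and TUPLE shardings have no per-device shard shape, which is why these helpers now fail hard instead of merely logging. A rough pure-Python illustration of that rule (not the library API):

import math

def tiled_shard_shape(tensor_shape, tile_counts):
  # Each dimension is ceil-divided by the number of tiles along it, so the
  # last shard along a dimension may be smaller (or padded by ShardTensor).
  return [math.ceil(dim / tiles) for dim, tiles in zip(tensor_shape, tile_counts)]

print(tiled_shard_shape([3, 2], [2, 1]))  # -> [2, 2]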
11 changes: 6 additions & 5 deletions torch_xla/csrc/xla_sharding_util.h
@@ -80,11 +80,12 @@ class ShardingUtil {
// based on the `sharding` spec. REPLICATED sharding should result in shards
// identical to the input; OTHERS (tiled) sharding result in shards where
// each data dimension is sharded across devices along the same dimension in
// the `tile_assignment`; the returned tensor shards vector is indexed by the
// device IDs. There is no data duplication. Shards are not padded in case the
// input tensor is not evenly partitionable, unless `padded` is set.
// The the returned tensors will be in 1:1 correspondence with the `devices`
// vector, so the `i`th result will belong on the `i`th device.
// the `tile_assignment`; the returned tensor shards vector is
// indexed by the device IDs. There is no data duplication. Shards are not
// padded in case the input tensor is not evenly partitionable, unless
// `padded` is set. The returned tensors will be in 1:1 correspondence
// with the `devices` vector, so the `i`th result will belong on the `i`th
// device.
static std::vector<at::Tensor> ShardTensor(
const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings,
const std::vector<std::string>& devices, bool padded = true);
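The re-wrapped comment above describes the contract of `ShardTensor`: shard `i` of the result belongs on `devices[i]`, and trailing shards are padded to the full shard shape only when `padded` is set. A toy, single-dimension Python sketch of that behavior (illustration only, not the library implementation):

import torch
import torch.nn.functional as F

def shard_rows(tensor, num_devices, padded=True):
  shard_size = -(-tensor.shape[0] // num_devices)  # ceil division
  shards = []
  for i in range(num_devices):
    s = tensor[i * shard_size:(i + 1) * shard_size]
    if padded and s.shape[0] < shard_size:
      # pad trailing rows so every shard has the same shape
      s = F.pad(s, (0, 0, 0, shard_size - s.shape[0]))
    shards.append(s)  # shards[i] is the piece that belongs on devices[i]
  return shards

print([tuple(s.shape) for s in shard_rows(torch.arange(6.).reshape(3, 2), 2)])
# [(2, 2), (2, 2)] -- the last shard is padded from 1 row to 2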
4 changes: 3 additions & 1 deletion torch_xla/distributed/spmd/__init__.py
@@ -2,7 +2,8 @@
from .xla_sharding import (Mesh, HybridMesh, ShardingType, ShardingSpec,
XLAPatchedLinear, mark_sharding, clear_sharding,
wrap_if_sharded, xla_patched_nn_linear_forward,
set_global_mesh, get_global_mesh)
set_global_mesh, get_global_mesh,
_mark_manual_sharding)
from .api import xla_distribute_tensor, xla_distribute_module, auto_policy

__all__ = [
@@ -22,4 +23,5 @@
"xla_patched_nn_linear_forward",
"set_global_mesh",
"get_global_mesh",
"_mark_manual_sharding",
]
12 changes: 12 additions & 0 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -474,6 +474,18 @@ def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple):
return tuple(_partition_spec)


def _mark_manual_sharding(
t: Union[torch.Tensor, XLAShardedTensor]) -> XLAShardedTensor:
"""
This API is meant to be paired with the upcoming pause_spmd and resume_spmd APIs.
Do not use it on its own.
"""
manual_sharding = torch_xla._XLAC.OpSharding([], [], [], ShardingType.MANUAL)
torch_xla._XLAC._mark_manual_sharding(
unwrap_sharded_tensor(t), manual_sharding)
return wrap_as_sharded_tensor(t)


@xr.requires_pjrt
def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor],
mesh: Mesh,
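Because the helper unwraps its input and re-wraps the result, it accepts either a plain XLA tensor or an `XLAShardedTensor`, matching its `Union` annotation. A small sketch of both call forms (assumes an XLA device; the `XLAShardedTensor` import path is taken from the existing package layout and may differ):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import XLAShardedTensor  # assumed export

y1 = torch.ones(4, 4).to(xm.xla_device()) + 0  # op results carry IR nodes
y2 = torch.ones(4, 4).to(xm.xla_device()) + 0

xt1 = xs._mark_manual_sharding(y1)                    # plain torch.Tensor input
xt2 = xs._mark_manual_sharding(XLAShardedTensor(y2))  # pre-wrapped input
assert xt1.sharding_spec == xt2.sharding_spec == '{manual}'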