Commit 85453d2

[SPMD] named partition spec support (#5415)
[SPMD] named partition spec
1 parent 1835771 commit 85453d2

2 files changed: +42 -3 lines changed

test/spmd/test_xla_sharding.py

Lines changed: 12 additions & 0 deletions
@@ -718,6 +718,18 @@ def test_sharded_tensor_to_cpu_int_type(self):
                             partition_spec)
     self.assertTrue(torch.allclose(t1, xst1.cpu()))
 
+  def test_named_partition_spec(self):
+    xt1 = torch.arange(64).reshape(8, 8).to(xm.xla_device())
+    mesh = xs.Mesh(
+        list(range(self.n_devices)), (1, self.n_devices), ('data', 'model'))
+    partition_spec = ('model', 'data')
+    xs.mark_sharding(xt1, mesh, partition_spec)
+    sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt1)
+    if self.n_devices > 1:
+      self.assertTrue(f"devices=[{self.n_devices},1]" in sharding_spec)
+    else:
+      self.assertTrue("replicated" in sharding_spec)
+
 
 if __name__ == '__main__':
   test = unittest.main()
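
For reference, a minimal usage sketch of the named-axis flow this test exercises. This is a sketch only: `num_devices` is a placeholder for the actual device count, and any SPMD runtime setup required by your torch_xla build is omitted.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

num_devices = 8  # placeholder; set to the number of addressable XLA devices

# A 1 x num_devices logical mesh with named axes instead of positional indices.
mesh = xs.Mesh(list(range(num_devices)), (1, num_devices), ('data', 'model'))

t = torch.arange(64).reshape(8, 8).to(xm.xla_device())

# Shard dim 0 along the 'model' axis and dim 1 along the 'data' axis;
# on this mesh that is equivalent to the positional spec (1, 0).
xs.mark_sharding(t, mesh, ('model', 'data'))

# With more than one device the annotation contains "devices=[<n>,1]";
# on a single device the tensor is simply replicated (see the test assertions above).
print(torch_xla._XLAC._get_xla_sharding_spec(t))
```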

torch_xla/experimental/xla_sharding.py

Lines changed: 30 additions & 3 deletions
@@ -53,6 +53,7 @@ def __init__(self,
     if not isinstance(device_ids, np.ndarray):
       device_ids = np.array(device_ids)
     assert (axis_names is None) or (len(mesh_shape) == len(axis_names))
+    assert axis_names is None or (len(set(axis_names)) == len(axis_names))
     assert (len(device_ids) == np.prod(mesh_shape))
     assert len(device_ids) == len(np.unique(device_ids))
     self.device_ids = device_ids
@@ -64,12 +65,20 @@ def size(self):
     return np.prod(self.mesh_shape)
 
   def shape(self):
+    if self.axis_names is None:
+      return OrderedDict(
+          (dim, size) for dim, size in enumerate(self.mesh_shape))
     return OrderedDict(
         (name, size) for name, size in zip(self.axis_names, self.mesh_shape))
 
   def get_logical_mesh(self):
     return self.device_ids.reshape(self.mesh_shape)
 
+  def get_axis_name_idx(self, name: str) -> int:
+    if name not in self.axis_names:
+      return None
+    return self.axis_names.index(name)
+
 
 # HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4
 
@@ -359,9 +368,26 @@ def _get_group_assignment(
   return group_assignment, replication_groups
 
 
+def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple):
+  _partition_spec = list()
+  for p in partition_spec:
+    if (p is None) or (type(p) is int):
+      _partition_spec.append(p)
+    elif type(p) is str:
+      idx = mesh.get_axis_name_idx(p)
+      if idx is None:
+        raise ValueError(f"Axis name {p} is not defined in the given mesh")
+      _partition_spec.append(idx)
+    else:
+      raise ValueError(
+          f"Spec type {type(p)} is not supported in partition spec")
+  return _partition_spec
+
+
 @xr.requires_pjrt
-def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
-                  partition_spec: Tuple[Union[int, None]]) -> XLAShardedTensor:
+def mark_sharding(
+    t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
+    partition_spec: Tuple[Union[int, str, None]]) -> XLAShardedTensor:
   """
   Annotates the tensor provided with XLA partition spec. Internally,
   it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.
@@ -370,7 +396,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
 
       mesh (Mesh): describes the logical XLA device topology and the underlying device IDs.
 
-      partition_spec (Tuple[int, None]): A tuple of device_mesh dimension index or `None`.
+      partition_spec (Tuple[int, str, None]): A tuple of device_mesh dimension index or `None`. Each index is an int or str if the mesh axis is named.
       This specifies how each input rank is sharded (index to mesh_shape) or replicated (None).
       For example, we can shard an 8x10 tensor 4-way row-wise, and replicate column-wise.
       >> input = torch.randn(8, 10)
@@ -396,6 +422,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
   assert num_devices > 0, "This requires XLA supported device(s)."
   assert mesh.size() == num_devices, \
     f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
+  partition_spec = _translate_named_partition_spec(mesh, partition_spec)
   assert all((d >= 0 and d < len(mesh.mesh_shape)) for d in partition_spec if d), \
     f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
   # We only allow fully specified `partition_spec` to be applicable, as opposed
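
The new `_translate_named_partition_spec` step simply rewrites axis names into mesh-dimension indices before the existing integer-based validation in `mark_sharding` runs. Below is a standalone sketch of that mapping, using a hypothetical `SimpleMesh` stand-in (not part of torch_xla) so it runs without XLA devices.

```python
from collections import OrderedDict
from typing import Optional, Tuple


class SimpleMesh:
  """Hypothetical stand-in mirroring only the Mesh pieces the translation uses."""

  def __init__(self, mesh_shape: Tuple[int, ...], axis_names: Tuple[str, ...]):
    assert len(set(axis_names)) == len(axis_names), "axis names must be unique"
    self.mesh_shape = mesh_shape
    self.axis_names = axis_names

  def shape(self) -> OrderedDict:
    # Named meshes map axis name -> axis size, mirroring Mesh.shape() above.
    return OrderedDict(zip(self.axis_names, self.mesh_shape))

  def get_axis_name_idx(self, name: str) -> Optional[int]:
    if name not in self.axis_names:
      return None
    return self.axis_names.index(name)


def translate_named_partition_spec(mesh: SimpleMesh, partition_spec: Tuple):
  # Same shape as the new _translate_named_partition_spec: ints and None pass
  # through, strings resolve to mesh-axis indices, anything else is an error.
  out = []
  for p in partition_spec:
    if p is None or isinstance(p, int):
      out.append(p)
    elif isinstance(p, str):
      idx = mesh.get_axis_name_idx(p)
      if idx is None:
        raise ValueError(f"Axis name {p} is not defined in the given mesh")
      out.append(idx)
    else:
      raise ValueError(f"Spec type {type(p)} is not supported in partition spec")
  return out


mesh = SimpleMesh((1, 8), ('data', 'model'))
print(mesh.shape())                                             # OrderedDict([('data', 1), ('model', 8)])
print(translate_named_partition_spec(mesh, ('model', 'data')))  # [1, 0]
print(translate_named_partition_spec(mesh, ('model', None)))    # [1, None]
```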
