Merged
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -70,9 +70,9 @@ spmd
.. autofunction:: clear_sharding
.. autofunction:: set_global_mesh
.. autofunction:: get_global_mesh
.. autofunction:: get_1d_mesh
.. autoclass:: Mesh
.. autoclass:: HybridMesh
.. autoclass:: ShardingSpec

experimental
----------------------------------
14 changes: 14 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -1255,6 +1255,20 @@ def test_spmd_reduce_scatter_canonical_index(self):
expected_x = torch.ones(8, 8 // self.n_devices) * 4
self.assertTrue(torch.allclose(x.cpu(), expected_x))

def test_get_1d_mesh(self):
device = torch_xla.device()
mesh = xs.get_1d_mesh("data")
t1 = torch.randn(8, 8).to(device)
xt = xs.mark_sharding(t1, mesh, ("data", None))
shards = xt.local_shards
self.assertEqual(len(shards), self.n_devices)
self.assertEqual(mesh.mesh_shape, (xr.global_runtime_device_count(),))
self.assertEqual(mesh.axis_names, ("data",))

mesh_without_name = xs.get_1d_mesh()
self.assertEqual(mesh_without_name.mesh_shape,
(xr.global_runtime_device_count(),))


if __name__ == '__main__':
test = unittest.main()
9 changes: 5 additions & 4 deletions torch_xla/distributed/spmd/__init__.py
@@ -1,10 +1,10 @@
from .xla_sharded_tensor import XLAShard, XLAShardedTensor
from .xla_sharding import (Mesh, HybridMesh, ShardingType, ShardingSpec,
XLAPatchedLinear, mark_sharding, clear_sharding,
wrap_if_sharded, xla_patched_nn_linear_forward,
set_global_mesh, get_global_mesh,
_mark_manual_sharding, enable_manual_sharding,
disable_manual_sharding,
get_1d_mesh, wrap_if_sharded,
xla_patched_nn_linear_forward, set_global_mesh,
get_global_mesh, _mark_manual_sharding,
enable_manual_sharding, disable_manual_sharding,
apply_backward_optimization_barrier)
from .api import xla_distribute_tensor, xla_distribute_module, auto_policy

@@ -19,6 +19,7 @@
"XLAPatchedLinear",
"mark_sharding",
"clear_sharding",
"get_1d_mesh",
"wrap_if_sharded",
"xla_distribute_tensor",
"xla_distribute_module",
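With the re-export above, `get_1d_mesh` becomes importable from the package root alongside the other sharding helpers. A small sketch of the import path this change enables (hypothetical usage, not part of the diff):

from torch_xla.distributed.spmd import get_1d_mesh, mark_sharding

mesh = get_1d_mesh("data")  # 1-D mesh spanning all addressable devices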
74 changes: 72 additions & 2 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -127,15 +127,64 @@ def get_op_sharding(self,


def set_global_mesh(mesh: Mesh):
"""
Set the global mesh that can be used for the current process.

Args:
mesh: (Mesh) Mesh object that will be the global mesh.

Example:
import torch_xla.distributed.spmd as xs
mesh = xs.get_1d_mesh("data")
xs.set_global_mesh(mesh)
"""
global _GLOBAL_MESH
_GLOBAL_MESH = mesh


def get_global_mesh():
def get_global_mesh() -> Optional[Mesh]:
"""
Get the global mesh for the current process.

Returns:
mesh: (Optional[Mesh]) Mesh object if the global mesh is set, otherwise None.

Example:
import torch_xla.distributed.spmd as xs
xs.get_global_mesh()
"""
global _GLOBAL_MESH
return _GLOBAL_MESH
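# A minimal usage sketch pairing the two helpers above, so downstream code can
# pick up the mesh instead of threading it through call sites (assumes SPMD is
# enabled via torch_xla.runtime.use_spmd() and an XLA device is available):
#
#   import torch
#   import torch_xla
#   import torch_xla.distributed.spmd as xs
#
#   xs.set_global_mesh(xs.get_1d_mesh("data"))
#   t = torch.randn(8, 8).to(torch_xla.device())
#   xs.mark_sharding(t, xs.get_global_mesh(), ("data", None))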


def get_1d_mesh(axis_name: Optional[str] = None) -> Mesh:
"""
Helper function to return the mesh with all devices in one dimension.

Args:
axis_name: (Optional[str]) optional string to use as the axis name of the mesh.

Returns:
Mesh: Mesh object

Example:
# This example assumes a single TPU v4-8
import torch_xla.distributed.spmd as xs
mesh = xs.get_1d_mesh("data")
print(mesh.mesh_shape)
>> (4,)
print(mesh.axis_names)
>> ('data',)
"""
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices,)
device_ids = np.array(range(num_devices))
if axis_name is None:
return Mesh(device_ids, mesh_shape)
else:
return Mesh(device_ids, mesh_shape, (axis_name,))
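# A short sketch of how the returned mesh is typically consumed, mirroring the
# new test in test_xla_sharding.py (assumes SPMD is enabled; the device count
# depends on the attached hardware):
#
#   import torch
#   import torch_xla
#   import torch_xla.distributed.spmd as xs
#
#   mesh = xs.get_1d_mesh("data")
#   t = torch.randn(8, 8).to(torch_xla.device())
#   xt = xs.mark_sharding(t, mesh, ("data", None))  # shard dim 0 over all devices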


# HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4
class HybridMesh(Mesh):
"""Creates a hybrid device mesh of devices connected with ICI and DCN networks.
@@ -548,6 +597,9 @@ def mark_sharding(

Examples
-------------------------------
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

mesh_shape = (4, 2)
num_devices = xr.global_runtime_device_count()
device_ids = np.array(range(num_devices))
@@ -579,7 +631,25 @@ def mark_sharding(


def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
"""Clear sharding annotation from the input tensor and return a `cpu` casted tensor."""
"""
Clear sharding annotation from the input tensor and return a `cpu` casted tensor. This
is an in-place operation, but it will also return the same torch.Tensor back.

Args:
t (Union[torch.Tensor, XLAShardedTensor]): Tensor whose sharding annotation we want to clear.

Return:
t (torch.Tensor): tensor without sharding annotation.

Examples:
import torch_xla.distributed.spmd as xs
torch_xla.runtime.use_spmd()

t1 = torch.randn(8,8).to(torch_xla.device())
mesh = xs.get_1d_mesh()
xs.mark_sharding(t1, mesh, (0, None))
xs.clear_sharding(t1)
"""
torch_xla._XLAC._xla_clear_sharding(unwrap_sharded_tensor(t))
if isinstance(t, XLAShardedTensor):
return t.global_tensor
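A brief usage sketch for the expanded docstring above; the `_get_xla_sharding_spec` check mirrors how the test suite verifies that an annotation is gone (a sketch only, assuming SPMD is enabled and an XLA device is available):

import torch
import torch_xla
import torch_xla.distributed.spmd as xs

torch_xla.runtime.use_spmd()
t = torch.randn(8, 8).to(torch_xla.device())
xs.mark_sharding(t, xs.get_1d_mesh(), (0, None))
xs.clear_sharding(t)
assert not torch_xla._XLAC._get_xla_sharding_spec(t)  # empty spec once cleared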