@@ -53,6 +53,7 @@ def __init__(self,
     if not isinstance(device_ids, np.ndarray):
       device_ids = np.array(device_ids)
     assert (axis_names is None) or (len(mesh_shape) == len(axis_names))
+    assert axis_names is None or (len(set(axis_names)) == len(axis_names))
     assert (len(device_ids) == np.prod(mesh_shape))
     assert len(device_ids) == len(np.unique(device_ids))
     self.device_ids = device_ids
@@ -64,12 +65,20 @@ def size(self):
     return np.prod(self.mesh_shape)
 
   def shape(self):
+    if self.axis_names is None:
+      return OrderedDict(
+          (dim, size) for dim, size in enumerate(self.mesh_shape))
     return OrderedDict(
         (name, size) for name, size in zip(self.axis_names, self.mesh_shape))
 
   def get_logical_mesh(self):
     return self.device_ids.reshape(self.mesh_shape)
 
+  def get_axis_name_idx(self, name: str) -> int:
+    if name not in self.axis_names:
+      return None
+    return self.axis_names.index(name)
+
 
 # HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4
 
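For reference, a minimal sketch of how the Mesh changes above behave. The 2x4 mesh over 8 devices and the axis names ('data', 'model') are illustrative assumptions, not part of the patch:

    import numpy as np
    # Hypothetical 2x4 mesh over 8 devices with named axes.
    mesh = Mesh(np.arange(8), (2, 4), ('data', 'model'))
    mesh.shape()                     # OrderedDict([('data', 2), ('model', 4)])
    mesh.get_axis_name_idx('model')  # 1
    mesh.get_axis_name_idx('batch')  # None: name not defined in this mesh
    # Without axis_names, shape() now falls back to positional keys.
    Mesh(np.arange(8), (2, 4)).shape()  # OrderedDict([(0, 2), (1, 4)])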
@@ -359,9 +368,26 @@ def _get_group_assignment(
   return group_assignment, replication_groups
 
 
+def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple):
+  _partition_spec = list()
+  for p in partition_spec:
+    if (p is None) or (type(p) is int):
+      _partition_spec.append(p)
+    elif type(p) is str:
+      idx = mesh.get_axis_name_idx(p)
+      if idx is None:
+        raise ValueError(f"Axis name {p} is not defined in the given mesh")
+      _partition_spec.append(idx)
+    else:
+      raise ValueError(
+          f"Spec type {type(p)} is not supported in partition spec")
+  return _partition_spec
+
+
 @xr.requires_pjrt
-def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
-                  partition_spec: Tuple[Union[int, None]]) -> XLAShardedTensor:
+def mark_sharding(
+    t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
+    partition_spec: Tuple[Union[int, str, None]]) -> XLAShardedTensor:
   """
   Annotates the tensor provided with XLA partition spec. Internally,
   it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.
@@ -370,7 +396,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
 
     mesh (Mesh): describes the logical XLA device topology and the underlying device IDs.
 
-    partition_spec (Tuple[int, None]): A tuple of device_mesh dimension index or `None`.
+    partition_spec (Tuple[int, str, None]): A tuple of device_mesh dimension index or `None`; each element is an int, or a str if the corresponding mesh axis is named.
     This specifies how each input rank is sharded (index to mesh_shape) or replicated (None).
     For example, we can shard an 8x10 tensor 4-way row-wise, and replicate column-wise.
     >> input = torch.randn(8, 10)
@@ -396,6 +422,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
   assert num_devices > 0, "This requires XLA supported device(s)."
   assert mesh.size() == num_devices, \
     f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
+  partition_spec = _translate_named_partition_spec(mesh, partition_spec)
   assert all((d >= 0 and d < len(mesh.mesh_shape)) for d in partition_spec if d), \
     f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
   # We only allow fully specified `partition_spec` to be applicable, as opposed
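With the translation step wired into mark_sharding, a partition_spec may mix axis names, integer indices, and None. A small usage sketch, assuming the illustrative 2x4 mesh above and an XLA device obtained via torch_xla.core.xla_model:

    import torch
    import torch_xla.core.xla_model as xm
    # Shard the first tensor dimension across the assumed 'data' axis and
    # replicate across the 'model' axis; equivalent to partition_spec (0, None).
    t = torch.randn(8, 10).to(xm.xla_device())
    mark_sharding(t, mesh, ('data', None))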