
Commit f6b4054

ProcessMesh: support converting a mesh dimension to a communication Group

1 parent d12bf33 commit f6b4054

File tree

1 file changed (+108 −0):

python/paddle/distributed/auto_parallel/process_mesh.py
@@ -15,13 +15,17 @@
 from __future__ import annotations

 import copy
+import logging
 from typing import TYPE_CHECKING, Any, SupportsIndex, Union

 import numpy as np

 import paddle
+from paddle.distributed.communication.group import is_initialized
 from paddle.framework import core

+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from collections.abc import Iterable, Sequence
     from types import TracebackType
@@ -284,6 +288,110 @@ def get_mesh_with_dim(
             return ProcessMesh(new_mesh[index], new_dim_names[1:])
         return ProcessMesh(new_mesh, new_dim_names)

+    def get_submesh_with_dim(
+        self,
+        dim_name: str,
+    ) -> ProcessMesh:
+        """
+        Slice the current ProcessMesh along the given dim_name, creating a submesh with only that dimension remaining.
+
+        Args:
+            dim_name (str): the name of the mesh dimension of the ProcessMesh to create the submesh for.
+
+        Returns:
+            A :class:`ProcessMesh` object, or None if the current rank is not in this mesh.
+
+        The following program runs on each process/rank in an SPMD manner with a world size of 8.
+
+        In the first example:
+            Calling mesh_2d.get_submesh_with_dim("dp") on ranks 0 and 4 returns a 1D submesh ProcessMesh([0, 4]).
+            Calling mesh_2d.get_submesh_with_dim("dp") on ranks 1 and 5 returns a 1D submesh ProcessMesh([1, 5]).
+            Calling mesh_2d.get_submesh_with_dim("dp") on ranks 2 and 6 returns a 1D submesh ProcessMesh([2, 6]).
+            Calling mesh_2d.get_submesh_with_dim("dp") on ranks 3 and 7 returns a 1D submesh ProcessMesh([3, 7]).
+            Calling mesh_2d.get_submesh_with_dim("tp") on ranks 0, 1, 2, 3 returns a 1D submesh ProcessMesh([0, 1, 2, 3]).
+            Calling mesh_2d.get_submesh_with_dim("tp") on ranks 4, 5, 6, 7 returns a 1D submesh ProcessMesh([4, 5, 6, 7]).
+
+        In the second example:
+            Calling mesh_3d.get_submesh_with_dim("pp") on ranks 0 and 4 returns a 1D submesh ProcessMesh([0, 4]).
+            Calling mesh_3d.get_submesh_with_dim("pp") on ranks 1 and 5 returns a 1D submesh ProcessMesh([1, 5]).
+            Calling mesh_3d.get_submesh_with_dim("pp") on ranks 2 and 6 returns a 1D submesh ProcessMesh([2, 6]).
+            Calling mesh_3d.get_submesh_with_dim("pp") on ranks 3 and 7 returns a 1D submesh ProcessMesh([3, 7]).
+            Calling mesh_3d.get_submesh_with_dim("dp") on ranks 0 and 2 returns a 1D submesh ProcessMesh([0, 2]).
+            Calling mesh_3d.get_submesh_with_dim("dp") on ranks 1 and 3 returns a 1D submesh ProcessMesh([1, 3]).
+            Calling mesh_3d.get_submesh_with_dim("dp") on ranks 4 and 6 returns a 1D submesh ProcessMesh([4, 6]).
+            Calling mesh_3d.get_submesh_with_dim("dp") on ranks 5 and 7 returns a 1D submesh ProcessMesh([5, 7]).
+            Calling mesh_3d.get_submesh_with_dim("tp") on ranks 0 and 1 returns a 1D submesh ProcessMesh([0, 1]).
+            Calling mesh_3d.get_submesh_with_dim("tp") on ranks 2 and 3 returns a 1D submesh ProcessMesh([2, 3]).
+            Calling mesh_3d.get_submesh_with_dim("tp") on ranks 4 and 5 returns a 1D submesh ProcessMesh([4, 5]).
+            Calling mesh_3d.get_submesh_with_dim("tp") on ranks 6 and 7 returns a 1D submesh ProcessMesh([6, 7]).
+
+        Examples:
+            .. code-block:: python
+
+                >>> import paddle
+                >>> import paddle.distributed as dist
+
+                >>> dist.init_parallel_env()
+                >>> mesh_2d = dist.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["dp", "tp"])
+                >>> dp_mesh = mesh_2d.get_submesh_with_dim("dp")
+                >>> tp_mesh = mesh_2d.get_submesh_with_dim("tp")
+                >>> mesh_3d = dist.ProcessMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dim_names=["pp", "dp", "tp"])
+                >>> pp_mesh = mesh_3d.get_submesh_with_dim("pp")
+                >>> dp_mesh = mesh_3d.get_submesh_with_dim("dp")
+                >>> tp_mesh = mesh_3d.get_submesh_with_dim("tp")
+        """
+        # Move dim_name to the leading axis, then flatten the remaining axes
+        # so that each column holds the ranks of one dim_name group.
+        reorder_mesh = self.get_mesh_with_dim(dim_name)._mesh.reshape(
+            self.get_dim_size(dim_name), -1
+        )
+        curr_rank = paddle.distributed.get_rank()
+        if curr_rank not in self._process_ids:
+            logger.warning(
+                f"Rank {curr_rank} is not in the process mesh; returning None"
+            )
+            return None
+        # Find curr_rank in reorder_mesh and take the index of its column.
+        col_idx = np.argmax(reorder_mesh == curr_rank) % reorder_mesh.shape[-1]
+        sub_mesh = ProcessMesh(reorder_mesh[:, col_idx], [dim_name])
+        return sub_mesh
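For intuition about the column lookup above, here is a minimal standalone sketch (plain NumPy, no distributed environment; the mesh layout and the example rank are illustrative, not part of the commit):

    import numpy as np

    # The 2 x 4 mesh from the docstring, already reordered so that the
    # sliced dimension ("dp") is the leading axis and the rest is flattened:
    # each column holds the ranks of one "dp" group.
    reorder_mesh = np.array([[0, 1, 2, 3],
                             [4, 5, 6, 7]])

    curr_rank = 6  # pretend this process is rank 6
    # np.argmax over the boolean mask yields the flat index of the first match;
    # taking it modulo the row length converts it to a column index.
    col_idx = np.argmax(reorder_mesh == curr_rank) % reorder_mesh.shape[-1]
    print(reorder_mesh[:, col_idx])  # [2 6] -> rank 6's "dp" submesh

This matches the docstring: calling get_submesh_with_dim("dp") on ranks 2 and 6 yields the submesh [2, 6].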
+    def get_group(
+        self,
+        dim_name: str | None = None,
+    ) -> paddle.distributed.Group:
+        """
+        Convert a single-dimension ProcessMesh to its corresponding communication Group.
+
+        Args:
+            dim_name (str, optional): the name of the mesh dimension to convert.
+                Default is None, which is only valid when the ProcessMesh has a single dimension.
+
+        Returns:
+            A :class:`Group` object.
+        """
+        # Check whether the parallel environment is ready.
+        assert is_initialized(), (
+            "To get a group from a ProcessMesh, call "
+            "paddle.distributed.init_parallel_env first "
+            "to initialize the distributed environment."
+        )
+        if len(self._dim_names) > 1 and dim_name is None:
+            raise ValueError(
+                "You should specify dim_name when the ProcessMesh has more than one dimension."
+            )
+        if len(self._dim_names) == 1:
+            if dim_name is not None and dim_name not in self._dim_names:
+                raise ValueError(
+                    f"{dim_name} not in the dimension names {self._dim_names}"
+                )
+            else:
+                pg = paddle.distributed.new_group(self._process_ids)
+                return pg
+        else:
+            if dim_name not in self._dim_names:
+                raise ValueError(
+                    f"{dim_name} not in the dimension names {self._dim_names}"
+                )
+            sub_mesh = self.get_submesh_with_dim(dim_name)
+            return sub_mesh.get_group(dim_name)
+
     def __enter__(self) -> None:
         set_current_process_mesh(self)
         default_prog = paddle.static.default_main_program()
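As a usage sketch (the launch setup and the printed attributes are assumptions for illustration, not part of the commit), each rank can convert one mesh dimension into a communication group and then use it for collectives:

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    mesh_2d = dist.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["dp", "tp"])

    # Each rank receives the Group of the "tp" row it belongs to:
    # ranks 0-3 share one group, ranks 4-7 another.
    tp_group = mesh_2d.get_group("tp")
    print(tp_group.ranks, tp_group.nranks)

Launched with e.g. python -m paddle.distributed.launch --devices 0,1,2,3,4,5,6,7 demo.py (a hypothetical script name), every rank reaches get_group, which matters because new_group generally needs to be called by all processes in Paddle.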
