 from __future__ import annotations

 import copy
+import logging
 from typing import TYPE_CHECKING, Any, SupportsIndex, Union

 import numpy as np

 import paddle
+from paddle.distributed.communication.group import is_initialized
 from paddle.framework import core

+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from collections.abc import Iterable, Sequence
     from types import TracebackType
@@ -284,6 +288,110 @@ def get_mesh_with_dim( |
             return ProcessMesh(new_mesh[index], new_dim_names[1:])
         return ProcessMesh(new_mesh, new_dim_names)

+    def get_submesh_with_dim(
+        self,
+        dim_name: str,
+    ) -> ProcessMesh | None:
| 295 | + """ |
| 296 | + Slice the current ProcessMesh based on the dim_name given to create a submesh with single dimension remained. |
| 297 | +
|
| 298 | + Args: |
| 299 | + dim_name (str): the name of the mesh dimension of the ProcessMesh to create the submesh for. |
| 300 | + Returns: |
| 301 | + A :class:`ProcessMesh` object |
| 302 | +
|
+        The following program runs on each process/rank in an SPMD manner with a world size of 8.
+        In the first example:
+            Calling mesh_2d.get_submesh_with_dim("dp") on ranks 0, 4 returns a 1-D submesh ProcessMesh([0, 4]).
+            Calling mesh_2d.get_submesh_with_dim("dp") on ranks 1, 5 returns a 1-D submesh ProcessMesh([1, 5]).
+            Calling mesh_2d.get_submesh_with_dim("dp") on ranks 2, 6 returns a 1-D submesh ProcessMesh([2, 6]).
+            Calling mesh_2d.get_submesh_with_dim("dp") on ranks 3, 7 returns a 1-D submesh ProcessMesh([3, 7]).
+            Calling mesh_2d.get_submesh_with_dim("tp") on ranks 0, 1, 2, 3 returns a 1-D submesh ProcessMesh([0, 1, 2, 3]).
+            Calling mesh_2d.get_submesh_with_dim("tp") on ranks 4, 5, 6, 7 returns a 1-D submesh ProcessMesh([4, 5, 6, 7]).
+
+        In the second example:
+            Calling mesh_3d.get_submesh_with_dim("pp") on ranks 0, 4 returns a 1-D submesh ProcessMesh([0, 4]).
+            Calling mesh_3d.get_submesh_with_dim("pp") on ranks 1, 5 returns a 1-D submesh ProcessMesh([1, 5]).
+            Calling mesh_3d.get_submesh_with_dim("pp") on ranks 2, 6 returns a 1-D submesh ProcessMesh([2, 6]).
+            Calling mesh_3d.get_submesh_with_dim("pp") on ranks 3, 7 returns a 1-D submesh ProcessMesh([3, 7]).
+            Calling mesh_3d.get_submesh_with_dim("dp") on ranks 0, 2 returns a 1-D submesh ProcessMesh([0, 2]).
+            Calling mesh_3d.get_submesh_with_dim("dp") on ranks 1, 3 returns a 1-D submesh ProcessMesh([1, 3]).
+            Calling mesh_3d.get_submesh_with_dim("dp") on ranks 4, 6 returns a 1-D submesh ProcessMesh([4, 6]).
+            Calling mesh_3d.get_submesh_with_dim("dp") on ranks 5, 7 returns a 1-D submesh ProcessMesh([5, 7]).
+            Calling mesh_3d.get_submesh_with_dim("tp") on ranks 0, 1 returns a 1-D submesh ProcessMesh([0, 1]).
+            Calling mesh_3d.get_submesh_with_dim("tp") on ranks 2, 3 returns a 1-D submesh ProcessMesh([2, 3]).
+            Calling mesh_3d.get_submesh_with_dim("tp") on ranks 4, 5 returns a 1-D submesh ProcessMesh([4, 5]).
+            Calling mesh_3d.get_submesh_with_dim("tp") on ranks 6, 7 returns a 1-D submesh ProcessMesh([6, 7]).
+
+        Examples:
+            .. code-block:: python
+
+                >>> import paddle
+                >>> import paddle.distributed as dist
+
+                >>> dist.init_parallel_env()
+                >>> mesh_2d = dist.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["dp", "tp"])
+                >>> dp_mesh = mesh_2d.get_submesh_with_dim("dp")
+                >>> tp_mesh = mesh_2d.get_submesh_with_dim("tp")
+                >>> mesh_3d = dist.ProcessMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dim_names=["pp", "dp", "tp"])
+                >>> pp_mesh = mesh_3d.get_submesh_with_dim("pp")
+                >>> dp_mesh = mesh_3d.get_submesh_with_dim("dp")
+                >>> tp_mesh = mesh_3d.get_submesh_with_dim("tp")
+        """
+
+        curr_rank = paddle.distributed.get_rank()
+        if curr_rank not in self._process_ids:
+            logger.warning(
+                f"Rank {curr_rank} is not in the process mesh, so None is returned."
+            )
+            return None
+        # Bring dim_name to the front, then flatten the remaining dims into columns.
+        reorder_mesh = self.get_mesh_with_dim(dim_name)._mesh.reshape(
+            self.get_dim_size(dim_name), -1
+        )
+        # Locate curr_rank in reorder_mesh; its column holds the submesh's ranks.
+        col_idx = np.argmax(reorder_mesh == curr_rank) % reorder_mesh.shape[-1]
+        sub_mesh = ProcessMesh(reorder_mesh[:, col_idx], [dim_name])
+        return sub_mesh
+
+    def get_group(
+        self,
+        dim_name: str | None = None,
+    ) -> paddle.distributed.Group:
| 359 | + """ |
| 360 | + Convert single dimension ProcessMesh to the corresponding Group. |
| 361 | +
|
| 362 | + Args: |
| 363 | + dim_name (str, optional): it can be the name of the mesh dimension. Default is None. |
| 364 | +
|
| 365 | + Returns: |
| 366 | + A :class:`Group` object. |
| 367 | + """ |
+
+        # Check that the distributed environment is ready.
+        assert is_initialized(), (
+            "The distributed environment is not initialized. Call "
+            "paddle.distributed.init_parallel_env first before getting "
+            "a group from a ProcessMesh."
+        )
+        if len(self._dim_names) > 1 and dim_name is None:
+            raise ValueError(
+                "You should specify dim_name when the ProcessMesh has more than one dimension."
+            )
+        if len(self._dim_names) == 1:
+            if dim_name is not None and dim_name not in self._dim_names:
+                raise ValueError(
+                    f"{dim_name} is not in the dimension names {self._dim_names}"
+                )
+            # A 1-D mesh maps directly to one group over all of its process ids.
+            return paddle.distributed.new_group(self._process_ids)
+        else:
+            if dim_name not in self._dim_names:
+                raise ValueError(
+                    f"{dim_name} is not in the dimension names {self._dim_names}"
+                )
+            # Slice out this rank's 1-D submesh along dim_name, then recurse.
+            sub_mesh = self.get_submesh_with_dim(dim_name)
+            return sub_mesh.get_group(dim_name)
+
     def __enter__(self) -> None:
         set_current_process_mesh(self)
         default_prog = paddle.static.default_main_program()
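
The column-selection step in `get_submesh_with_dim` is compact, so here is a minimal standalone sketch of the same computation in plain numpy, using the 3-D mesh from the docstring. The `submesh_ranks` helper and the sample `curr_rank` values are illustrative stand-ins, not part of the patch, and `np.moveaxis` is assumed to match the dimension reordering that `get_mesh_with_dim` performs:

```python
import numpy as np

# The 2 x 2 x 2 mesh with dim_names ["pp", "dp", "tp"] from the docstring.
mesh = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
dim_names = ["pp", "dp", "tp"]

def submesh_ranks(mesh, dim_names, dim_name, curr_rank):
    # Move the requested dimension to the front, then flatten the remaining
    # dimensions into columns, giving an array of shape (dim_size, -1).
    axis = dim_names.index(dim_name)
    reordered = np.moveaxis(mesh, axis, 0).reshape(mesh.shape[axis], -1)
    # Each rank appears exactly once, so argmax on the boolean mask gives its
    # flat position; modulo the row width recovers its column, which holds all
    # ranks that differ from curr_rank only along dim_name.
    col = np.argmax(reordered == curr_rank) % reordered.shape[-1]
    return reordered[:, col]

print(submesh_ranks(mesh, dim_names, "dp", curr_rank=0))  # [0 2]
print(submesh_ranks(mesh, dim_names, "tp", curr_rank=6))  # [6 7]
```

Both printed results match the per-rank tables in the `get_submesh_with_dim` docstring for ranks 0 and 6.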
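Since `get_group` ships without an Examples section, here is a hedged usage sketch of how the two new methods compose. It assumes the same 8-rank 2-D mesh as the docstring and must run under a distributed launch such as `paddle.distributed.launch`:

```python
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
mesh_2d = dist.ProcessMesh(
    [[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["dp", "tp"]
)

# On a multi-dimension mesh, dim_name is required: get_group slices the
# current rank's 1-D submesh along "dp" and returns the group built from it.
dp_group = mesh_2d.get_group("dp")  # e.g. ranks [0, 4] on rank 0

# On a 1-D mesh, dim_name may be omitted.
tp_mesh = mesh_2d.get_submesh_with_dim("tp")  # ProcessMesh([0, 1, 2, 3]) on rank 0
tp_group = tp_mesh.get_group()
```

The recursion in `get_group` means every rank ends up with the group built from its own column of the reordered mesh, mirroring the per-rank results listed in the `get_submesh_with_dim` docstring.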