27 changes: 25 additions & 2 deletions python/paddle/distributed/auto_parallel/process_mesh.py
@@ -21,6 +21,8 @@
import numpy as np

import paddle
from paddle.distributed import fleet
from paddle.distributed.collective import _get_group_map
from paddle.distributed.communication.group import is_initialized
from paddle.framework import core

@@ -442,8 +444,29 @@ def get_group(
f"{dim_name} not in the dimension names {self._dim_names}"
)
else:
pg = paddle.distributed.new_group(self._process_ids)
return pg
if hasattr(fleet.fleet, "_hcg"):
hcg = fleet.get_hybrid_communicate_group()
if hcg is not None:

parallel_group_map = {
"pp": hcg.get_pipe_parallel_group,
"dp": hcg.get_data_parallel_group,
"mp": hcg.get_model_parallel_group,
"sep": hcg.get_sep_parallel_group,
"sharding": hcg.get_sharding_parallel_group,
}

if dim_name not in parallel_group_map:
raise ValueError(
f"{dim_name} is not a valid dim name."
)

return parallel_group_map[dim_name]()
group_map = _get_group_map()
for group in group_map.values():
if set(group.ranks) == set(self._process_ids):
return group
return paddle.distributed.new_group(self._process_ids)
else:
if dim_name not in self._dim_names:
raise ValueError(
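For context, the behavior the non-fleet path now guarantees can be sketched in a few lines (a hypothetical two-rank session; the mesh name "x" and the launch setup are assumptions, not part of this diff):

import paddle.distributed as dist

dist.init_parallel_env()
mesh_1d = dist.ProcessMesh([0, 1], dim_names=["x"])

# The first call builds a communicator via new_group(); the second call
# finds that group in the global group map (identical rank set) and
# reuses it, so both calls return the same group id.
group_a = mesh_1d.get_group()
group_b = mesh_1d.get_group(dim_name="x")
assert group_a.id == group_b.id

This mirrors the assertion added to test_get_group further down.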
10 changes: 9 additions & 1 deletion test/auto_parallel/hybrid_strategy/CMakeLists.txt
@@ -173,6 +173,14 @@ if((WITH_GPU) AND (LINUX))
py_test_modules(
test_process_mesh MODULES test_process_mesh ENVS
"http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_process_mesh PROPERTIES TIMEOUT "60" LABELS
set_tests_properties(test_process_mesh PROPERTIES TIMEOUT "150" LABELS
"RUN_TYPE=HYBRID")
endif()
if((WITH_GPU) AND (LINUX))
py_test_modules(
test_get_group_in_different_hybrid_configs MODULES
test_get_group_in_different_hybrid_configs ENVS
"http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_get_group_in_different_hybrid_configs
PROPERTIES TIMEOUT "150" LABELS "RUN_TYPE=HYBRID")
endif()
@@ -99,7 +99,7 @@ def test_get_group(self):
assert isinstance(
group_1d_with_name, dist.communication.group.Group
)

assert group_1d_with_name.id == group_1d.id
# Test case 3: Single dimension mesh with wrong dim_name
try:
mesh_1d.get_group(dim_name="wrong_name")
@@ -0,0 +1,162 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import collective.test_communication_api_base as test_base


class TestProcessMeshDPGroupConsistency(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=2, timeout=200, nnode=1)

def test_dp_parallel(self):
"""Test data parallel group creation and consistency"""
_default_envs = {
"dp": "2",
"mp": "1",
"pp": "1",
"parallel_type": "dp",
"FLAGS_embedding_deterministic": "1",
"FLAGS_cudnn_deterministic": "1",
}
_changeable_envs = {
"backend": ["gpu"],
}
envs_list = test_base.gen_product_envs_list(
_default_envs, _changeable_envs
)
for envs in envs_list:
self.run_test_case(
"test_process_mesh_group_consistency.py",
user_defined_envs=envs,
)


class TestProcessMeshMPGroupConsistency(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=2, timeout=200, nnode=1)

def test_mp_parallel(self):
"""Test model parallel group creation and consistency"""
_default_envs = {
"dp": "1",
"mp": "2",
"pp": "1",
"parallel_type": "mp",
"FLAGS_embedding_deterministic": "1",
"FLAGS_cudnn_deterministic": "1",
}
_changeable_envs = {
"backend": ["gpu"],
}
envs_list = test_base.gen_product_envs_list(
_default_envs, _changeable_envs
)
for envs in envs_list:
self.run_test_case(
"test_process_mesh_group_consistency.py",
user_defined_envs=envs,
)


class TestProcessMeshPPGroupConsistency(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=2, timeout=200, nnode=1)

def test_pp_parallel(self):
"""Test pipeline parallel group creation and consistency"""
_default_envs = {
"dp": "1",
"mp": "1",
"pp": "2",
"parallel_type": "pp",
"FLAGS_embedding_deterministic": "1",
"FLAGS_cudnn_deterministic": "1",
}
_changeable_envs = {
"backend": ["gpu"],
}
envs_list = test_base.gen_product_envs_list(
_default_envs, _changeable_envs
)
for envs in envs_list:
self.run_test_case(
"test_process_mesh_group_consistency.py",
user_defined_envs=envs,
)


class TestProcessMeshSEPGroupConsistency(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=2, timeout=200, nnode=1)

def test_sep_parallel(self):
"""Test sequence parallel group creation and consistency"""
_default_envs = {
"dp": "1",
"mp": "1",
"pp": "1",
"sep": "2",
"sharding": "1",
"parallel_type": "sep",
"FLAGS_embedding_deterministic": "1",
"FLAGS_cudnn_deterministic": "1",
}
_changeable_envs = {
"backend": ["gpu"],
}
envs_list = test_base.gen_product_envs_list(
_default_envs, _changeable_envs
)
for envs in envs_list:
self.run_test_case(
"test_process_mesh_group_consistency.py",
user_defined_envs=envs,
)


class TestProcessMeshShardingGroupConsistency(
test_base.CommunicationTestDistBase
):
def setUp(self):
super().setUp(num_of_devices=2, timeout=200, nnode=1)

def test_sharding_parallel(self):
"""Test sharding parallel group creation and consistency"""
_default_envs = {
"dp": "1",
"mp": "1",
"pp": "1",
"sep": "1",
"sharding": "2",
"parallel_type": "sharding",
"FLAGS_embedding_deterministic": "1",
"FLAGS_cudnn_deterministic": "1",
}
_changeable_envs = {
"backend": ["gpu"],
}
envs_list = test_base.gen_product_envs_list(
_default_envs, _changeable_envs
)
for envs in envs_list:
self.run_test_case(
"test_process_mesh_group_consistency.py",
user_defined_envs=envs,
)


if __name__ == "__main__":
unittest.main()
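For reference, gen_product_envs_list presumably merges the fixed envs with each combination of the changeable envs; a rough standalone equivalent of that expansion (a sketch of the assumed behavior, not the helper's actual implementation):

import itertools

def gen_product_envs_list_sketch(default_envs, changeable_envs):
    # One merged dict per combination of the changeable values, with the
    # fixed defaults copied into every entry.
    keys = list(changeable_envs)
    envs_list = []
    for values in itertools.product(*(changeable_envs[k] for k in keys)):
        merged = dict(default_envs)
        merged.update(dict(zip(keys, values)))
        envs_list.append(merged)
    return envs_list

# With _changeable_envs = {"backend": ["gpu"]}, each case above expands to a
# single env dict: the defaults plus {"backend": "gpu"}.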
@@ -0,0 +1,107 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle.distributed as dist
from paddle.distributed import fleet


class TestProcessMeshGroupConsistency:
def __init__(self):
# Get configuration from environment variables
self.dp = int(os.getenv("dp", "1"))
self.mp = int(os.getenv("mp", "1"))
self.pp = int(os.getenv("pp", "1"))
self.sep = int(os.getenv("sep", "1"))
self.sharding = int(os.getenv("sharding", "1"))

# Determine which parallel type to test
self.parallel_type = os.getenv("parallel_type", "dp")

def init_dist_env(self):
"""Initialize distributed environment"""
# Configure distributed strategy
dist_strategy = fleet.DistributedStrategy()
dist_strategy.hybrid_configs = {
"dp_degree": self.dp,
"mp_degree": self.mp,
"pp_degree": self.pp,
"sep_degree": self.sep,
"sharding_degree": self.sharding,
}

# Add corresponding configuration based on parallel type
if self.sep > 1:
dist_strategy.hybrid_configs["sep_degree"] = self.sep
if self.sharding > 1:
dist_strategy.hybrid_configs["sharding_degree"] = self.sharding

fleet.init(is_collective=True, strategy=dist_strategy)

def test_process_mesh_group_consistency(self):
"""Test consistency between ProcessMesh created groups and HCG created groups"""

        # Create the ProcessMesh and fetch the matching HCG group for the configured parallel type
if self.parallel_type == "dp":
mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])
hcg = fleet.get_hybrid_communicate_group()
group = mesh.get_group(dim_name="dp")
hcg_group = hcg.get_data_parallel_group()

elif self.parallel_type == "mp":
mesh = dist.ProcessMesh([0, 1], dim_names=["mp"])
hcg = fleet.get_hybrid_communicate_group()
group = mesh.get_group(dim_name="mp")
hcg_group = hcg.get_model_parallel_group()

elif self.parallel_type == "pp":
mesh = dist.ProcessMesh([0, 1], dim_names=["pp"])
hcg = fleet.get_hybrid_communicate_group()
group = mesh.get_group(dim_name="pp")
hcg_group = hcg.get_pipe_parallel_group()

elif self.parallel_type == "sep":
mesh = dist.ProcessMesh([0, 1], dim_names=["sep"])
hcg = fleet.get_hybrid_communicate_group()
group = mesh.get_group(dim_name="sep")
hcg_group = hcg.get_sep_parallel_group()

elif self.parallel_type == "sharding":
mesh = dist.ProcessMesh([0, 1], dim_names=["sharding"])
hcg = fleet.get_hybrid_communicate_group()
group = mesh.get_group(dim_name="sharding")
hcg_group = hcg.get_sharding_parallel_group()

else:
raise ValueError(f"Unsupported parallel type: {self.parallel_type}")

# Verify that group ranks are consistent
group_ranks = group.ranks
hcg_group_ranks = hcg_group.ranks
assert set(group_ranks) == set(hcg_group_ranks)

# Verify that group IDs are consistent
group_id = group.id
hcg_group_id = hcg_group.id
assert group_id == hcg_group_id

def run_test_cases(self):
"""Run test cases"""
self.init_dist_env()
self.test_process_mesh_group_consistency()


if __name__ == "__main__":
TestProcessMeshGroupConsistency().run_test_cases()
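Taken together, the contract this script verifies can be condensed to a short sketch (assumes a two-GPU, single-node fleet session with dp_degree=2; the actual launch is handled by the test harness above):

import paddle.distributed as dist
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 1, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])
hcg = fleet.get_hybrid_communicate_group()
# With the hybrid communicate group available, get_group("dp") returns the
# HCG's data-parallel communicator instead of creating a duplicate group.
assert mesh.get_group(dim_name="dp").id == hcg.get_data_parallel_group().id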