Skip to content

Commit 3c518eb

Browse files
committed
新增hcg判断方法 (Add a method for checking whether the hybrid communicate group (hcg) is None)
1 parent bb9bd86 commit 3c518eb

File tree

4 files changed

+23
-16
lines changed

4 files changed

+23
-16
lines changed

python/paddle/distributed/auto_parallel/process_mesh.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -392,21 +392,24 @@ def get_group(
392392
f"{dim_name} not in the dimension names {self._dim_names}"
393393
)
394394
else:
395-
hcg = fleet.get_hybrid_communicate_group()
396-
if hcg is not None:
397-
398-
parallel_group_map = {
399-
"pp": hcg.get_pipe_parallel_group,
400-
"dp": hcg.get_data_parallel_group,
401-
"mp": hcg.get_model_parallel_group,
402-
"sep": hcg.get_sep_parallel_group,
403-
"sharding": hcg.get_sharding_parallel_group,
404-
}
405-
406-
if dim_name not in parallel_group_map:
407-
raise ValueError(f"{dim_name} is not a valid dim name.")
408-
409-
return parallel_group_map[dim_name]()
395+
if not fleet._hybrid_communicate_group_is_None():
396+
hcg = fleet.get_hybrid_communicate_group()
397+
if hcg is not None:
398+
399+
parallel_group_map = {
400+
"pp": hcg.get_pipe_parallel_group,
401+
"dp": hcg.get_data_parallel_group,
402+
"mp": hcg.get_model_parallel_group,
403+
"sep": hcg.get_sep_parallel_group,
404+
"sharding": hcg.get_sharding_parallel_group,
405+
}
406+
407+
if dim_name not in parallel_group_map:
408+
raise ValueError(
409+
f"{dim_name} is not a valid dim name."
410+
)
411+
412+
return parallel_group_map[dim_name]()
410413
existing_group = None
411414
group_map = _get_group_map()
412415
for group in group_map.values():

python/paddle/distributed/fleet/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
]
5454

5555
fleet = Fleet()
56+
_hybrid_communicate_group_is_None = fleet._hybrid_communicate_group_is_None
5657
_final_strategy = fleet._final_strategy
5758
_get_applied_meta_list = fleet._get_applied_meta_list
5859
_get_applied_graph_list = fleet._get_applied_graph_list

python/paddle/distributed/fleet/fleet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,9 @@ def get_hybrid_communicate_group(self) -> HybridCommunicateGroup:
741741
assert self._hcg is not None
742742
return self._hcg
743743

744+
def _hybrid_communicate_group_is_None(self) -> bool:
745+
return self._hcg is None
746+
744747
def get_hybrid_parallel_topology(self) -> CommunicateTopology:
745748
assert self._topology is not None
746749
return self._topology

test/auto_parallel/hybrid_strategy/process_mesh_demo_unittest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_get_group(self):
9999
assert isinstance(
100100
group_1d_with_name, dist.communication.group.Group
101101
)
102-
102+
assert group_1d_with_name.id == group_1d.id
103103
# Test case 3: Single dimension mesh with wrong dim_name
104104
try:
105105
mesh_1d.get_group(dim_name="wrong_name")

0 commit comments

Comments
 (0)