
Commit b0dff05

[Auto Parallel] Do the physical mapping between the process graph and the cluster graph (#37094)
* [Auto Parallel] Add the unified cluster representation
* [Auto Parallel] Add the graph class for physical mapping
* [Auto Parallel] Add the simple physical mapper
* Set the timeout of the mapper
* Merge the upstream develop unittests cmake files
* Fix a bug of the process group
* Remove mapper unittest from platforms which is not GPU
* Move the instantiation of process group after resharding
* Add the local id for devices
* Update the rank mapping format
* Add some comments
* Remove the related files about mapping
* Update the unittest for auto mapping
* Remove unused rank_mapping unittest
* Improve the unittest coverage
* Improve the unittest coverage
1 parent 87e65a9 commit b0dff05

File tree

3 files changed: +904 -0 lines changed

Lines changed: 294 additions & 0 deletions
@@ -0,0 +1,294 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import operator
import functools
import json
import paddle
from collections import deque
from .graph import Node
from .graph import Edge
from .graph import Graph
from .cluster import DeviceType
from .process_group import get_process_group


def is_collective_comm_op(op):
    comm_list = [
        "c_allreduce_sum", "c_allreduce_min", "c_allreduce_max",
        "c_allreduce_prod", "c_reduce_sum", "c_reduce_min", "c_reduce_max",
        "c_reduce_prod", "c_broadcast", "c_allgather"
    ]
    if op.type in comm_list:
        return True
    else:
        return False


def is_p2p_comm_op(op):
    comm_list = ["send_v2", "recv_v2"]
    if op.type in comm_list:
        return True
    else:
        return False


def get_dtype_bytes(dtype):
    num_bytes = 0
    if dtype == paddle.float64:
        num_bytes = 8
    elif dtype == paddle.float32:
        num_bytes = 4
    elif dtype == paddle.float16:
        num_bytes = 2
    elif dtype == paddle.bfloat16:
        num_bytes = 2
    elif dtype == paddle.int64:
        num_bytes = 8
    elif dtype == paddle.int32:
        num_bytes = 4
    elif dtype == paddle.int16:
        num_bytes = 2
    elif dtype == paddle.int8:
        num_bytes = 1
    elif dtype == paddle.uint8:
        num_bytes = 1
    else:
        raise ValueError("Unrecognized dtype {}.".format(dtype))
    return num_bytes


def get_comm_volume(comm_op, src_rank, tgt_rank):
    comm_volume = None
    if src_rank == tgt_rank:
        return comm_volume
    comm_op_type = comm_op.type
    if comm_op_type != "recv_v2":
        tensor_name = comm_op.input_arg_names[0]
    else:
        tensor_name = comm_op.output_arg_names[0]
    tensor = comm_op.block._find_var_recursive(tensor_name)
    assert tensor is not None
    tensor_shape = tensor.shape
    # Skip the dynamic (-1) dims, such as the batch dim, by counting them as 1
    new_tensor_shape = []
    for val in tensor_shape:
        if val == -1:
            print("Warning: -1 in the tensor shape.")
            new_tensor_shape.append(1)
        else:
            new_tensor_shape.append(val)
    tensor_size = functools.reduce(operator.mul, new_tensor_shape, 1)
    tensor_bytes = tensor_size * get_dtype_bytes(tensor.dtype)
    if "c_allreduce" in comm_op_type:
        comm_volume = 2 * tensor_bytes
    elif "c_allgather" in comm_op_type:
        comm_volume = tensor_bytes
    elif "c_broadcast" in comm_op_type:
        if comm_op.attr("root") == src_rank:
            comm_volume = tensor_bytes
        else:
            comm_volume = None
    elif "c_reduce" in comm_op_type:
        if comm_op.attr("root_id") == src_rank:
            comm_volume = None
        else:
            comm_volume = tensor_bytes
    elif "send_v2" in comm_op_type:
        if comm_op.attr("peer") == tgt_rank:
            comm_volume = tensor_bytes
        else:
            comm_volume = None
    elif "recv_v2" in comm_op_type:
        comm_volume = None
    else:
        raise ValueError("Unrecognized communication operator.")
    return comm_volume


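As a quick sanity check of the volume accounting above (the tensor shape and dtype here are hypothetical, chosen only for illustration), an allreduce is charged twice the tensor payload while an allgather is charged once:

# Illustrative only: a hypothetical float32 tensor of shape [1024, 1024].
tensor_bytes = 1024 * 1024 * 4        # 4 bytes per float32 element
allreduce_volume = 2 * tensor_bytes   # c_allreduce_* ops count the payload twice
allgather_volume = tensor_bytes       # c_allgather counts the payload once
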
def analyze_comm_requirements_from_op(op, rank):
    comm_requirements_to_ranks = {}
    if is_collective_comm_op(op):
        process_group_id = op.attr("ring_id")
        process_group = get_process_group(process_group_id)
        if rank not in process_group.ranks:
            return comm_requirements_to_ranks
        for tgt_rank in process_group.ranks:
            comm_volume = get_comm_volume(op, rank, tgt_rank)
            if comm_volume is not None:
                comm_requirements_to_ranks[tgt_rank] = {}
                comm_requirements_to_ranks[tgt_rank][
                    "comm_volume"] = comm_volume
    elif is_p2p_comm_op(op):
        tgt_rank = op.attr("peer")
        comm_volume = get_comm_volume(op, rank, tgt_rank)
        if comm_volume is not None:
            comm_requirements_to_ranks[tgt_rank] = {}
            comm_requirements_to_ranks[tgt_rank]["comm_volume"] = comm_volume
    else:
        comm_requirements_to_ranks = {}
    return comm_requirements_to_ranks


def analyze_requirements_for_program(program, rank):
    resource_requirements = {}
    comm_requirements_to_ranks = {}
    # Only device_type is supported as a requirement, and only GPU for now
    resource_requirements["device_type"] = DeviceType.GPU
    for block in program.blocks:
        for op in block.ops:
            cur_comm_requirements_to_ranks = analyze_comm_requirements_from_op(
                op, rank)
            for tgt_rank, link_info in cur_comm_requirements_to_ranks.items():
                if tgt_rank in comm_requirements_to_ranks:
                    comm_requirements_to_ranks[tgt_rank][
                        "comm_volume"] += link_info["comm_volume"]
                else:
                    comm_requirements_to_ranks[tgt_rank] = {}
                    comm_requirements_to_ranks[tgt_rank][
                        "comm_volume"] = link_info["comm_volume"]
    return resource_requirements, comm_requirements_to_ranks


def build_process_graph(distributed_program):
    graph = Graph()
    for src_rank, src_program in distributed_program.items():
        resource_requirements, comm_requirements_to_ranks = analyze_requirements_for_program(
            src_program, src_rank)
        graph.add_node(src_rank, resource_requirements=resource_requirements)
        for tgt_rank, comm_requirements in comm_requirements_to_ranks.items():
            graph.add_edge(
                src_rank, tgt_rank, comm_requirements=comm_requirements)
    return graph


def build_cluster_graph(cluster):
    graph = Graph()
    for machine in cluster.machines.values():
        for device in machine.devices.values():
            graph.add_node(device.global_id, device=device)
        for link in machine.links.values():
            graph.add_edge(
                link.source.global_id, link.target.global_id, link=link)
    return graph


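To make the cluster-side assumptions concrete, here is a minimal stand-in sketch (plain dataclasses, not the real Device, Link, and Machine classes from .cluster) listing only the attributes that build_cluster_graph and mapping read:

from dataclasses import dataclass, field

@dataclass
class _Device:              # stand-in; the real class is defined in .cluster
    global_id: int          # unique across the cluster, used as the graph node id
    local_id: int           # id within its machine, reported in the rank mapping
    type: object = None     # compared against DeviceType.GPU during mapping
    machine: object = None  # back-reference used when building rank_mapping

@dataclass
class _Link:                # stand-in; the real class is defined in .cluster
    source: _Device
    target: _Device
    bandwidth: float = 0.0  # used to sort neighbor device edges during mapping

@dataclass
class _Machine:             # stand-in; the real class is defined in .cluster
    id: int
    hostname: str = ""
    addr: str = ""
    port: str = ""
    devices: dict = field(default_factory=dict)  # global device id -> _Device
    links: dict = field(default_factory=dict)    # link key -> _Link
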
def mapping(distributed_program, cluster):
    # A very simple mapping algorithm only for GPUs.
    # Here we assume one process will be mapped to one GPU.
    # In the future, more mapping configurations and algorithms will be supported.
    process_graph = build_process_graph(distributed_program)

    cluster_graph = build_cluster_graph(cluster)

    for cur_rank_node in process_graph:
        cur_rank_node["visited"] = False

    for cur_device_node in cluster_graph:
        cur_device_node["occupied"] = False

    def sort_by_comm_volume(rank_edge):
        return rank_edge["comm_requirements"]["comm_volume"]

    def sort_by_comm_bandwidth(device_edge):
        return device_edge["link"].bandwidth

    def select_unvisited_rank_node(rank_node_list):
        selected_rank_node = None
        for rank_node in rank_node_list:
            if rank_node["visited"] is False:
                selected_rank_node = rank_node
        return selected_rank_node

    queue = deque()
    root_rank_node = select_unvisited_rank_node(
        list(process_graph.nodes.values()))
    while root_rank_node is not None:
        queue.append(root_rank_node)
        while queue:
            cur_rank_node = queue.popleft()
            if cur_rank_node["visited"]:
                continue
            # Map the current rank to the first free device of the required type
            device_type = cur_rank_node["resource_requirements"]["device_type"]
            cur_device_node = None
            for device_node in cluster_graph.nodes.values():
                if (device_node["device"].type == device_type) and (
                        not device_node["occupied"]):
                    device_node["occupied"] = True
                    cur_rank_node["visited"] = True
                    cur_rank_node["device"] = device_node["device"]
                    cur_device_node = device_node
                    break
            assert cur_device_node, "Cannot find a device to satisfy the requirement."

            nbr_rank_edges = []
            for nbr_rank_node_id, nbr_rank_edge in process_graph.adjs[
                    cur_rank_node.id].items():
                assert nbr_rank_edge.src_id == cur_rank_node.id and nbr_rank_edge.tgt_id == nbr_rank_node_id
                queue.append(process_graph.nodes[nbr_rank_node_id])
                nbr_rank_edges.append(nbr_rank_edge)
            nbr_rank_edges.sort(key=sort_by_comm_volume)

            nbr_device_edges = []
            for nbr_device_edge in cluster_graph.adjs[
                    cur_device_node.id].values():
                nbr_device_edges.append(nbr_device_edge)
            nbr_device_edges.sort(key=sort_by_comm_bandwidth)

            # Greedily place the unvisited neighbor ranks (largest communication
            # volume first) onto free neighbor devices (highest bandwidth first)
            for nbr_rank_edge in nbr_rank_edges:
                nbr_rank_node = process_graph.nodes[nbr_rank_edge.tgt_id]
                if nbr_rank_node["visited"]:
                    continue
                device_type = nbr_rank_node["resource_requirements"][
                    "device_type"]
                for nbr_device_edge in nbr_device_edges:
                    nbr_device_node = cluster_graph.nodes[
                        nbr_device_edge.tgt_id]
                    if (nbr_device_node["device"].type == device_type) and (
                            not nbr_device_node["occupied"]):
                        nbr_device_node["occupied"] = True
                        nbr_rank_node["visited"] = True
                        nbr_rank_node["device"] = nbr_device_node["device"]
                        break
        root_rank_node = select_unvisited_rank_node(
            list(process_graph.nodes.values()))

    rank_mapping = {}
    for rank, rank_node in process_graph.nodes.items():
        device = rank_node["device"]
        machine = device.machine
        if machine.id in rank_mapping:
            rank_mapping[machine.id]["hostname"] = machine.hostname
            rank_mapping[machine.id]["addr"] = machine.addr
            rank_mapping[machine.id]["port"] = machine.port
            if rank not in rank_mapping[machine.id]["ranks"]:
                rank_mapping[machine.id]["ranks"][rank] = []
                rank_mapping[machine.id]["ranks"][rank].append(device.local_id)
            else:
                rank_mapping[machine.id]["ranks"][rank].append(device.local_id)
        else:
            rank_mapping[machine.id] = {}
            rank_mapping[machine.id]["hostname"] = machine.hostname
            rank_mapping[machine.id]["addr"] = machine.addr
            rank_mapping[machine.id]["port"] = machine.port
            rank_mapping[machine.id]["ranks"] = {}
            rank_mapping[machine.id]["ranks"][rank] = []
            rank_mapping[machine.id]["ranks"][rank].append(device.local_id)
    for machine_mapping in rank_mapping.values():
        for rank_devices in machine_mapping["ranks"].values():
            rank_devices.sort()

    return rank_mapping

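For reference, the dictionary returned by mapping() has the nested shape built in its final loop; the machine id, hostname, address, port, and device ids below are made-up placeholder values, not output produced by this commit:

# Hypothetical rank_mapping for two ranks placed on one machine with two GPUs.
example_rank_mapping = {
    0: {                          # machine id
        "hostname": "machine0",   # placeholder hostname
        "addr": "127.0.0.1",      # placeholder address
        "port": "34001",          # placeholder port
        "ranks": {
            0: [0],               # rank 0 -> local device id 0
            1: [1],               # rank 1 -> local device id 1
        },
    },
}
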
python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -144,6 +144,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_disable_signal_handler)
     LIST(REMOVE_ITEM TEST_OPS test_fleet_executor)
     LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_multi_devices)
+    LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_mapper)
     LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_task_node)
 endif()
