
Commit 10f0a0f

[HybridParallel] Support fp16 in dygraph hybrid parallel (#36420)

* [HybridParallel] Support fp16 in dygraph hybrid parallel
* update
* update
* update for recompute
* add unittest of pp+fp16
* add unittest of recompute+fp16
* update
* modify ut
1 parent bdac9ff commit 10f0a0f

File tree: 8 files changed, +257 −31 lines

python/paddle/distributed/fleet/base/fleet_base.py

Lines changed: 34 additions & 6 deletions
@@ -35,6 +35,8 @@
 from ..meta_parallel import PipelineParallel, ShardingParallel
 from ..meta_optimizers import HybridParallelOptimizer
 from paddle import _C_ops
+from paddle.fluid import core
+from paddle.fluid.dygraph import to_variable

 __all__ = []

@@ -1548,26 +1550,52 @@ def unscale_method(self, optimizer):
         if getattr(optimizer, '_param_groups', None) and isinstance(
                 optimizer._param_groups[0], dict):
             param_grads = []
+            param_grads_fp16 = []
+            param_grads_fp32 = []
             for group in optimizer._param_groups:
                 for param in group['params']:
                     if param._grad_ivar() is not None:
                         param_grads.append(param._grad_ivar())
+                        if param._grad_ivar(
+                        ).dtype == core.VarDesc.VarType.FP16:
+                            param_grads_fp16.append(param._grad_ivar())
+                        else:
+                            param_grads_fp32.append(param._grad_ivar())
         else:
             param_grads = [
                 param._grad_ivar() for param in optimizer._parameter_list
                 if param._grad_ivar() is not None
             ]
-        _C_ops.check_finite_and_unscale(param_grads, self._scale,
-                                        param_grads, self._found_inf)
-
-        self._found_inf = paddle.cast(self._found_inf, dtype="int32")
+            param_grads_fp16 = [
+                param._grad_ivar() for param in optimizer._parameter_list
+                if (param._grad_ivar() is not None) and (param._grad_ivar(
+                ).dtype == core.VarDesc.VarType.FP16)
+            ]
+            param_grads_fp32 = [
+                param._grad_ivar() for param in optimizer._parameter_list
+                if (param._grad_ivar() is not None) and (param._grad_ivar(
+                ).dtype == core.VarDesc.VarType.FP32)
+            ]
+        temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
+        temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
+        if len(param_grads_fp16):
+            _C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
+                                            param_grads_fp16,
+                                            temp_found_inf_fp16)
+        if len(param_grads_fp32):
+            _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
+                                            param_grads_fp32,
+                                            temp_found_inf_fp32)
+        self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0

         # TODO(shenliang03) Since dp allreduce in the optimizer is
         # after the gradscaler, check_finite needs to synchronize global
         # information. In the future, we should use check_group to speed.
         paddle.distributed.all_reduce(
-            self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
-        self._found_inf = paddle.cast(self._found_inf, dtype="bool")
+            paddle.to_tensor(
+                [self._found_inf], dtype="int32"),
+            op=paddle.distributed.ReduceOp.MAX,
+            group=None)

         # Only tensor_parallel and pipeline_parallel need to modify scaler
         if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
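The split above exists because check_finite_and_unscale expects a homogeneous list of gradients, so fp16 and fp32 gradients are unscaled in separate groups and their found-inf flags merged. A minimal sketch of that pattern, written against public Paddle tensor APIs only (the helper name unscale_and_check is ours; the real kernel unscales in place rather than returning copies):

import paddle

def unscale_and_check(grads, loss_scale):
    # Group gradients by dtype: a mixed fp16/fp32 list cannot be fed to a
    # single check_finite_and_unscale call.
    found_inf = False
    unscaled = []
    for dtype in (paddle.float16, paddle.float32):
        for g in [g for g in grads if g.dtype == dtype]:
            g = g / loss_scale                        # unscaled copy
            unscaled.append(g)
            if not bool(paddle.all(paddle.isfinite(g))):
                found_inf = True                      # any NaN/Inf means skip the step
    return unscaled, found_inf

In the commit the merged flag is then max-all-reduced across ranks as an int32 tensor, so every worker agrees on whether the optimizer step should be skipped.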

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

Lines changed: 27 additions & 10 deletions
@@ -145,9 +145,8 @@ def forward_backward_pipeline(self, data, scaler=None):
             p2p.send_backward(input_tensor_grad)

         self._layers.allreduce_shared_weight_gradients()
-
-        train_loss = self._broadcast_final_loss()
-
+        with paddle.amp.auto_cast(enable=False):
+            train_loss = self._broadcast_final_loss()
         return train_loss

     def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
@@ -172,7 +171,8 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
         train_loss = self.forward_backward_pipeline(data, scaler)

         # optimizer
-        self._optimizer_step()
+        with paddle.amp.auto_cast(enable=False):
+            self._optimizer_step()

         return train_loss

@@ -242,12 +242,13 @@ def _forward_step(self, input_tensor):
                 output_tensor, paddle.Tensor
             ), "Currently, loss_fn should obtain Paddle.Tensor dtype"

-            if self.accumulate_steps > 1:
-                output_tensor = output_tensor / self.accumulate_steps
+            with paddle.amp.auto_cast(enable=False):
+                if self.accumulate_steps > 1:
+                    output_tensor = output_tensor / self.accumulate_steps

-            if self.total_loss is None:
-                self.total_loss = paddle.zeros_like(output_tensor)
-            self.total_loss += output_tensor.detach()
+                if self.total_loss is None:
+                    self.total_loss = paddle.zeros_like(output_tensor)
+                self.total_loss += output_tensor.detach()

         self.micro_batch_id += 1
         return output_tensor
@@ -321,13 +322,29 @@ def _broadcast_final_loss(self):
         if self.is_last_stage:
             assert self.total_loss is not None, "train_batch() in last stage should obtain valid loss"
             loss = self.total_loss.detach()
+            is_fp32 = paddle.to_tensor(
+                1) if loss.dtype == paddle.float32 else paddle.to_tensor(0)
+            paddle.distributed.broadcast(
+                is_fp32,
+                src=self.global_rank,
+                use_calc_stream=True,
+                group=self.pp_group)
             paddle.distributed.broadcast(
                 loss,
                 src=self.global_rank,
                 use_calc_stream=True,
                 group=self.pp_group)
         else:
-            loss = paddle.zeros(shape=[1], dtype="float32")
+            is_fp32 = paddle.to_tensor(1)
+            paddle.distributed.broadcast(
+                is_fp32,
+                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
+                use_calc_stream=True,
+                group=self.pp_group)
+            loss = paddle.zeros(
+                shape=[1],
+                dtype="float32") if is_fp32.numpy()[0] else paddle.zeros(
+                    shape=[1], dtype="float16")
             paddle.distributed.broadcast(
                 loss,
                 src=self._hcg.get_rank_from_stage(self.num_stages - 1),
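The _broadcast_final_loss change handles the fact that, under pure fp16, receiving stages cannot know the dtype of the loss buffer ahead of time, so a small flag is broadcast first and the receive buffer is allocated to match. A rough standalone sketch of that handshake, assuming an already-initialized process group (the function name broadcast_final_loss and the src_rank/group parameters are placeholders, not part of the commit):

import paddle
import paddle.distributed as dist

def broadcast_final_loss(loss, is_last_stage, src_rank, group=None):
    if is_last_stage:
        # Sender: announce whether the loss tensor is fp32 or fp16.
        is_fp32 = paddle.to_tensor(1 if loss.dtype == paddle.float32 else 0)
        dist.broadcast(is_fp32, src=src_rank, group=group)
    else:
        # Receiver: learn the dtype first, then allocate a matching buffer.
        is_fp32 = paddle.to_tensor(1)
        dist.broadcast(is_fp32, src=src_rank, group=group)
        dtype = "float32" if int(is_fp32.numpy()[0]) else "float16"
        loss = paddle.zeros(shape=[1], dtype=dtype)
    dist.broadcast(loss, src=src_rank, group=group)
    return loss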

python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py

Lines changed: 8 additions & 5 deletions
@@ -198,11 +198,14 @@ def forward(ctx, run_function, all_outputs, *args):

         # TODO support AMP
         tracer = framework._dygraph_tracer()
-        if tracer._amp_level == core.AmpLevel.O0:
-            ctx.is_fw_autocast = False
+        ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
+        if tracer._amp_level == core.AmpLevel.O2:
+            ctx.amp_level = 'O2'
+        elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
+            ctx.amp_level = 'O1'
         else:
-            ctx.is_fw_autocast = True
-            ctx.amp_mode = 'O1'
+            raise ValueError("unsupported amp level: {}".format(
+                tracer._amp_level))
         ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

         with paddle.no_grad():
@@ -263,7 +266,7 @@ def backward(ctx, *args):
                     enable=ctx.is_fw_autocast,
                     custom_white_list=ctx.amp_white_list,
                     custom_black_list=ctx.amp_black_list,
-                    level=ctx.amp_mode):
+                    level=ctx.amp_level):
                 detached_inputs = detach_variable(tuple(inputs))
                 outputs = ctx.run_function(*detached_inputs)
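The underlying idea is capture-and-replay: record whether the forward pass ran under auto_cast and at which level, then re-enter the same context when the segment is recomputed during backward. A condensed sketch using the same tracer internals as the diff (the helper names capture_amp_state and replay_under_amp are ours):

import paddle
from paddle.fluid import core, framework

def capture_amp_state():
    # Record the ambient AMP configuration at forward time.
    tracer = framework._dygraph_tracer()
    enabled = tracer._amp_level != core.AmpLevel.O0
    level = 'O2' if tracer._amp_level == core.AmpLevel.O2 else 'O1'
    white_list, black_list = tracer._get_amp_op_list()
    return enabled, level, white_list, black_list

def replay_under_amp(state, fn, *args):
    # Re-run the recomputed segment under the captured AMP state.
    enabled, level, white_list, black_list = state
    with paddle.amp.auto_cast(
            enable=enabled,
            custom_white_list=white_list,
            custom_black_list=black_list,
            level=level):
        return fn(*args)

Note that, as in the diff, AmpLevel.O0 maps to level 'O1' with autocast disabled, so the replay is a no-op when AMP was off.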

python/paddle/distributed/fleet/utils/recompute.py

Lines changed: 9 additions & 6 deletions
@@ -98,11 +98,14 @@ def forward(ctx, run_function, preserve_rng_state, *args):

         # TODO support AMP
         tracer = framework._dygraph_tracer()
-        if tracer._amp_level == core.AmpLevel.O0:
-            ctx.is_fw_autocast = False
+        ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
+        if tracer._amp_level == core.AmpLevel.O2:
+            ctx.amp_level = 'O2'
+        elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
+            ctx.amp_level = 'O1'
         else:
-            ctx.is_fw_autocast = True
-            ctx.amp_mode = 'O1'
+            raise ValueError("unsupported amp level: {}".format(
+                tracer._amp_level))
         ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

         with paddle.no_grad():
@@ -133,15 +136,15 @@ def backward(ctx, *args):
                            enable=ctx.is_fw_autocast,
                            custom_white_list=ctx.amp_white_list,
                            custom_black_list=ctx.amp_black_list,
-                           level=ctx.amp_mode):
+                           level=ctx.amp_level):
                        detached_inputs = detach_variable(tuple(inputs))
                        outputs = ctx.run_function(*detached_inputs)
            else:
                with paddle.amp.auto_cast(
                        enable=ctx.is_fw_autocast,
                        custom_white_list=ctx.amp_white_list,
                        custom_black_list=ctx.amp_black_list,
-                        level=ctx.amp_mode):
+                        level=ctx.amp_level):
                    detached_inputs = detach_variable(tuple(inputs))
                    outputs = ctx.run_function(*detached_inputs)
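recompute.py gets the same treatment for the public paddle.distributed.fleet.utils.recompute API. A hypothetical end-to-end usage sketch on a GPU build (the toy Block layer and tensor shapes are ours); with this fix, the backward-time replay runs at the same AMP level that was active in forward:

import paddle
from paddle.distributed.fleet.utils import recompute

class Block(paddle.nn.Layer):
    def __init__(self):
        super(Block, self).__init__()
        self.linear = paddle.nn.Linear(64, 64)

    def forward(self, x):
        return paddle.nn.functional.relu(self.linear(x))

block = Block()
x = paddle.randn([8, 64])
x.stop_gradient = False

with paddle.amp.auto_cast(enable=True, level='O1'):
    y = recompute(block, x)   # activations are dropped, not stored
    loss = y.mean()
loss.backward()               # the block's forward is replayed here under the same AMP level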

python/paddle/fluid/framework.py

Lines changed: 1 addition & 1 deletion
@@ -6097,7 +6097,7 @@ def __init__(self, shape, dtype, **kwargs):

         self.need_clip = kwargs.get('need_clip', True)

-        self.is_distributed = False
+        self.is_distributed = kwargs.get('is_distributed', False)
         # self.block = default_main_program().global_block()

     @property
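Presumably this lets hybrid-parallel code mark a parameter as distributed at construction time (for example, sharded weights created by tensor-parallel layers) rather than mutating the attribute afterwards; the default remains False, so existing callers are unaffected.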
(new file: fp16 pipeline-parallel unit test)
Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import unittest
import numpy as np
import random
import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from hybrid_parallel_pp_layer import AlexNetPipeDesc, AlexNet


def set_random_seed(seed, dp_id, rank_id):
    """Set random seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed + dp_id)
    paddle.seed(seed + dp_id)


batch_size = 4
micro_batch_size = 2


class TestDistPPTraning(unittest.TestCase):
    def setUp(self):
        strategy = fleet.DistributedStrategy()
        self.model_parallel_size = 1
        self.data_parallel_size = 1
        self.pipeline_parallel_size = 2
        strategy.hybrid_configs = {
            "dp_degree": self.data_parallel_size,
            "mp_degree": self.model_parallel_size,
            "pp_degree": self.pipeline_parallel_size,
        }
        strategy.pipeline_configs = {
            "accumulate_steps": batch_size // micro_batch_size,
            "micro_batch_size": micro_batch_size
        }
        fleet.init(is_collective=True, strategy=strategy)

    def test_pp_model(self):
        hcg = fleet.get_hybrid_communicate_group()
        word_size = hcg.get_model_parallel_world_size()
        dp_id = hcg.get_data_parallel_rank()
        pp_id = hcg.get_stage_id()
        rank_id = dist.get_rank()
        set_random_seed(1024, dp_id, rank_id)

        # construct model a
        model_a = AlexNet(10)
        scheduler_a = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[2], values=[0.001, 0.002], verbose=True)
        optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a,
                                           parameters=model_a.parameters())

        scaler_a = paddle.amp.GradScaler(init_loss_scaling=2**5)

        # construct model b
        model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size)
        scheduler_b = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[2], values=[0.001, 0.002], verbose=True)
        optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b,
                                           parameters=model_b.parameters())

        param_len = len(model_a.parameters())
        parameters = []
        for param in model_a.parameters():
            parameters.append(param.numpy())

        for idx, param in enumerate(model_b.parameters()):
            param.set_value(parameters[idx + pp_id * (param_len // 2)])

        model_a, optimizer_a = paddle.amp.decorate(
            models=model_a,
            optimizers=optimizer_a,
            level='O2',
            save_dtype='float32')
        model_b, optimizer_b = paddle.amp.decorate(
            models=model_b,
            optimizers=optimizer_b,
            level='O2',
            save_dtype='float32')

        model_b = fleet.distributed_model(model_b)
        optimizer_b = fleet.distributed_optimizer(optimizer_b)
        scaler_b = paddle.amp.GradScaler(init_loss_scaling=2**5)
        scaler_b = fleet.distributed_scaler(scaler_b)

        # construct reader
        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=batch_size, drop_last=True)

        for step_id, data in enumerate(train_reader()):
            x_data = np.array([x[0] for x in data]).astype('float32').reshape(
                batch_size, 1, 28, 28)
            y_data = np.array([x[1] for x in data]).astype('int64').reshape(
                batch_size, 1)
            img = paddle.to_tensor(x_data)
            label = paddle.to_tensor(y_data)
            img.stop_gradient = True
            label.stop_gradient = True

            if step_id >= 5:
                return True

            with paddle.amp.auto_cast(enable=True, level='O2'):
                loss_a = model_a(img, label)
                scaler_a.scale(loss_a).backward()
                with paddle.amp.auto_cast(enable=False):
                    scaler_a.minimize(optimizer_a, loss_a)
                    optimizer_a.clear_grad()
                    scheduler_a.step()

            loss_b = model_b.train_batch(
                [img, label], optimizer_b, scheduler_b, scaler=scaler_b)

            print("loss: ", loss_a.numpy(), loss_b.numpy())
            np.testing.assert_allclose(
                loss_a.numpy(), loss_b.numpy(), rtol=5e-3)


if __name__ == "__main__":
    unittest.main()
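Like the other hybrid-parallel unittests, this is a driver script that expects a two-worker collective job; it is typically launched via python -m paddle.distributed.launch --gpus "0,1" <script>.py, where <script>.py stands for whatever path this new file was added under.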
