
Commit 53d20cf
add unittests
1 parent bc48eb9

2 files changed: 62 additions & 0 deletions
python/paddle/fluid/tests/unittests/hybrid_parallel_mp_fp16.py (new file)

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import paddle
import numpy as np
from hybrid_parallel_mp_model import TestDistMPTraning
import paddle.distributed.fleet as fleet
import unittest


class TestMPFP16(TestDistMPTraning):
    def build_optimizer(self, model):
        grad_clip = paddle.nn.ClipGradByGlobalNorm(1.0)
        scheduler = paddle.optimizer.lr.ExponentialDecay(
            learning_rate=0.001, gamma=0.999, verbose=True)
        optimizer = paddle.optimizer.SGD(scheduler,
                                         grad_clip=grad_clip,
                                         parameters=model.parameters())

        # Cast the model and optimizer for pure-fp16 (O2) training;
        # parameters are saved back as float32 when checkpointing.
        model, optimizer = paddle.amp.decorate(
            models=model,
            optimizers=optimizer,
            level='O2',
            save_dtype='float32')

        return optimizer

    def train_batch(self, batch, model, optimizer, is_mp):
        scaler = paddle.amp.GradScaler(init_loss_scaling=5160)
        if is_mp:
            # Synchronize loss-scaling state across model-parallel ranks.
            scaler = fleet.distributed_scaler(scaler)
        with paddle.amp.auto_cast(enable=True, level="O2"):
            output = model(batch)
            loss = output.mean()

        scaled = scaler.scale(loss)
        scaled.backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.clear_grad()
        return scaled


if __name__ == "__main__":
    unittest.main()
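
For context, the new test follows Paddle's standard AMP O2 ("pure fp16") recipe: decorate the model and optimizer, run the forward pass under auto_cast, and drive the backward pass and optimizer step through a GradScaler; the fleet.distributed_scaler wrapper additionally keeps the loss-scaling state consistent across model-parallel ranks, so all ranks skip or apply the same updates. Below is a minimal single-process sketch of the same loop, assuming a CUDA device; the toy Linear model, shapes, and hyperparameters are illustrative and not part of the commit.

import paddle

model = paddle.nn.Linear(16, 16)
optimizer = paddle.optimizer.SGD(learning_rate=0.001,
                                 parameters=model.parameters())

# Cast parameters to float16 for O2 training; checkpoints stay float32.
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2', save_dtype='float32')

scaler = paddle.amp.GradScaler(init_loss_scaling=5160)

for _ in range(3):
    batch = paddle.rand([4, 16])
    with paddle.amp.auto_cast(enable=True, level='O2'):
        loss = model(batch).mean()
    scaled = scaler.scale(loss)  # multiply the loss by the loss scale
    scaled.backward()            # fp16 gradients carry the scale
    scaler.step(optimizer)       # unscale grads, skip the step on inf/NaN
    scaler.update()              # adjust the loss scale dynamically
    optimizer.clear_grad()

The distributed variant in the test differs only in wrapping the scaler with fleet.distributed_scaler before use.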

python/paddle/fluid/tests/unittests/test_parallel_dygraph_tensor_parallel.py

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,9 @@ def test_hybrid_parallel_mp_model(self):
     def test_hybrid_parallel_mp_amp(self):
         self.run_mnist_2gpu('hybrid_parallel_mp_amp.py')
 
+    def test_hybrid_parallel_mp_fp16(self):
+        self.run_mnist_2gpu('hybrid_parallel_mp_fp16.py')
+
     def test_hybrid_parallel_mp_clip_grad(self):
         self.run_mnist_2gpu('hybrid_parallel_mp_clip_grad.py')
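
With this registration, the suite launches the new script on two GPUs through its run_mnist_2gpu helper, alongside the existing AMP and gradient-clipping cases. Outside the harness, a script like this would typically be started with something along the lines of python -m paddle.distributed.launch --gpus 0,1 hybrid_parallel_mp_fp16.py; the entry point and GPU ids here are an assumption, not taken from the commit.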
