
Commit 64f769d

[Dygraph] Remove unrequired UT cases of DP in eager mode (#41413)

* remove unrequired ut cases
* update
* fix bugs
* update

1 parent: 6f4bd0e

File tree: 7 files changed, +82 -77 lines changed
python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
Lines changed: 36 additions & 4 deletions

@@ -20,6 +20,7 @@
 import paddle
 from paddle.fluid import core
 from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, build_groups
+from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
 from collections import OrderedDict
 from .log_util import logger

@@ -58,6 +59,30 @@ def _apply_collective_grads(parameters, comm_group):
     _split_tensors(coalesced_grads_and_vars)


+def _apply_collective_grads_eager(parameters, comm_group):
+    grad_var_set = set()
+    grad_vars = []
+
+    for param in parameters:
+        if param.trainable and (param._grad_ivar() is not None):
+            g_var = param._grad_ivar()
+            assert not g_var.is_sparse(
+            ), "Now, it doesn't support sparse parameters"
+            grad_vars.append(g_var)
+            assert g_var not in grad_var_set
+            grad_var_set.add(g_var)
+
+    coalesced_grads_and_vars = build_groups(grad_vars, 128 * 1024 * 1024)
+
+    div_factor = 1.0 / comm_group.nranks
+    for coalesced_grad, _, _ in coalesced_grads_and_vars:
+        # need to div nranks
+        coalesced_grad.scale_(div_factor)
+        paddle.distributed.all_reduce(coalesced_grad, group=comm_group)
+
+    _split_tensors(coalesced_grads_and_vars)
+
+
 def _broadcast_data_help(data, shape, dtype, hcg):
     model_parallel_group = hcg.get_model_parallel_group()
     src_rank = hcg.get_model_parallel_group_src_rank()

@@ -115,10 +140,17 @@ def broadcast_dp_parameters(model, hcg):


 def fused_allreduce_gradients(parameter_list, hcg):
-    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
-    logger.debug("dp start fuse allreduce gradients")
-    with framework.no_grad():
-        _apply_collective_grads(parameter_list, data_parallel_group)
+    if _in_legacy_dygraph():
+        data_parallel_group = None if hcg is None else hcg.get_data_parallel_group(
+        )
+        logger.debug("dp start fuse allreduce gradients")
+        with framework.no_grad():
+            _apply_collective_grads(parameter_list, data_parallel_group)
+    elif in_dygraph_mode():
+        assert hcg is None, "It's not support to use hcg in EagerDygraph now."
+        data_parallel_group = paddle.distributed.collective._get_default_group()
+        with framework.no_grad():
+            _apply_collective_grads_eager(parameter_list, data_parallel_group)


 def sharding_reduce_gradients(parameter_list, hcg):
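
For context, the new eager branch is hit when gradients have to be fused and all-reduced by hand, for example when a custom PyLayer bypasses DataParallel's automatic reducer (as in the test file further down). A minimal usage sketch, not part of the commit, assuming paddle.distributed.init_parallel_env() has already been called and that model, opt, and loss_fn are placeholder names for your own objects:

# Minimal sketch (not from the commit): manual gradient fusion in eager
# data-parallel training. Assumes init_parallel_env() was called and that
# `model`, `opt`, and `loss_fn` are stand-ins for your own objects.
import paddle
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

def train_step(model, opt, loss_fn, x, y):
    loss = loss_fn(model(x), y)
    loss.backward()
    # In eager mode hcg must be None: the default process group is used,
    # each fused buffer is scaled by 1/nranks and then all-reduced.
    fused_allreduce_gradients(list(model.parameters()), None)
    opt.step()
    opt.clear_grad()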

python/paddle/fluid/dygraph/parallel.py
Lines changed: 23 additions & 11 deletions

@@ -22,6 +22,7 @@
 from contextlib import contextmanager

 import paddle
+from paddle import _C_ops
 from paddle.fluid import core
 from paddle.fluid import framework
 from paddle.fluid.dygraph import layers

@@ -307,17 +308,28 @@ def _reshape_inplace(x, shape):

 @framework.dygraph_only
 def _split_tensors(coalesced_grads_and_grad_vars):
-    for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
-        grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
-        framework._dygraph_tracer().trace_op(
-            type='split',
-            inputs={'X': coalesced_grad},
-            outputs={'Out': origin_grad_vars},
-            attrs={'sections': grad_var_len,
-                   'axis': 0})
-        for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
-            _reshape_inplace(x=g_var, shape=g_shape)
-            assert g_var.shape == g_shape
+    if _in_legacy_dygraph():
+        for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
+            grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
+            framework._dygraph_tracer().trace_op(
+                type='split',
+                inputs={'X': coalesced_grad},
+                outputs={'Out': origin_grad_vars},
+                attrs={'sections': grad_var_len,
+                       'axis': 0})
+            for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
+                _reshape_inplace(x=g_var, shape=g_shape)
+                assert g_var.shape == g_shape
+    elif in_dygraph_mode():
+        for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
+            grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
+            attrs = ()
+            attrs += ('sections', grad_var_len)
+            attrs += ('axis', 0)
+            _C_ops.split(coalesced_grad, origin_grad_vars, *attrs)
+            for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
+                g_var.reshape_(shape=g_shape)
+                assert g_var.shape == g_shape


 def scale_loss(loss):
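
Both branches above do the same job with different kernels: the fused buffer is split back by the flattened length of each gradient and every slice is reshaped in place to its original shape. A small round-trip sketch using public ops (paddle.concat / paddle.split), shown only to illustrate what build_groups and _split_tensors do; it does not call the private helpers themselves:

# Illustration only: the coalesce/split round trip that build_groups and
# _split_tensors perform, rewritten with public Paddle ops.
import numpy as np
import paddle

grads = [paddle.rand([2, 3]), paddle.rand([4])]
shapes = [g.shape for g in grads]
lengths = [int(np.prod(s)) for s in shapes]

# coalesce: flatten every gradient and concatenate into one buffer
coalesced = paddle.concat([g.reshape([-1]) for g in grads], axis=0)

# ... a single all_reduce over `coalesced` would happen here ...

# split the buffer back by flattened lengths and restore the shapes
pieces = paddle.split(coalesced, num_or_sections=lengths, axis=0)
restored = [p.reshape(s) for p, s in zip(pieces, shapes)]
assert [r.shape for r in restored] == shapes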

python/paddle/fluid/tests/unittests/parallel_dygraph_dataparallel_with_pylayer.py
Lines changed: 20 additions & 2 deletions

@@ -21,7 +21,8 @@
 import numpy as np
 import paddle.distributed as dist
 from paddle.fluid.dygraph.nn import Linear
-from paddle.autograd import PyLayer
+from paddle.autograd import PyLayer, EagerPyLayer
+from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
 from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

 batch = 5

@@ -43,6 +44,20 @@ def backward(ctx, dy):
         return grad


+class cus_tanh_eager(EagerPyLayer):
+    @staticmethod
+    def forward(ctx, x):
+        y = paddle.tanh(x)
+        ctx.save_for_backward(y)
+        return y
+
+    @staticmethod
+    def backward(ctx, dy):
+        y, = ctx.saved_tensor()
+        grad = dy * (1 - paddle.square(y))
+        return grad
+
+
 class SimpleNet(paddle.nn.Layer):
     def __init__(self, train_id, model_id):
         super(SimpleNet, self).__init__()

@@ -55,7 +70,10 @@ def __init__(self, train_id, model_id):

     def forward(self, inputs):
         if self.model_id == 0:
-            inputs = cus_tanh.apply(inputs)
+            if in_dygraph_mode():
+                inputs = cus_tanh_eager.apply(inputs)
+            elif _in_legacy_dygraph():
+                inputs = cus_tanh.apply(inputs)
         else:
             inputs = self.tanh(inputs)
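
The test now picks the PyLayer flavour that matches the active dygraph mode. A standalone sketch of the eager variant (single process, no DataParallel), assuming the cus_tanh_eager and cus_tanh classes from the diff above are in scope; the tensor sizes are arbitrary:

# Quick single-process check of the eager custom op; assumes cus_tanh_eager
# and cus_tanh from the diff above are defined, sizes are arbitrary.
import paddle
from paddle.fluid.framework import in_dygraph_mode

x = paddle.rand([5, 16])
x.stop_gradient = False
out = cus_tanh_eager.apply(x) if in_dygraph_mode() else cus_tanh.apply(x)
out.sum().backward()
# d/dx tanh(x) = 1 - tanh(x)^2
print(paddle.allclose(x.grad, 1 - paddle.square(paddle.tanh(x))))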

python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py
Lines changed: 3 additions & 0 deletions

@@ -23,6 +23,7 @@
 import subprocess

 from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, TrainerProc
+from paddle.fluid.framework import _test_eager_guard


 def get_cluster_from_args(selected_gpus):

@@ -205,6 +206,8 @@ def test_multiple_gpus_dynamic(self):

 class TestDataParallelWithPyLayer(TestMultipleGpus):
     def test_parallel_dygraph_dataparallel_with_pylayer(self):
+        with _test_eager_guard():
+            self.run_mnist_2gpu('parallel_dygraph_dataparallel_with_pylayer.py')
         self.run_mnist_2gpu('parallel_dygraph_dataparallel_with_pylayer.py')
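
_test_eager_guard is the switch Paddle's unit tests use to run a body under eager dygraph mode; wrapping the existing 2-GPU run in it covers both modes without duplicating test classes. A sketch of the same pattern for a hypothetical extra case (class and method names are made up; the helper calls mirror the file above):

# Hypothetical example of the pattern added in this commit: run one test
# body under the eager guard and once more on the legacy path.
from paddle.fluid.framework import _test_eager_guard

class TestDataParallelBothModes(TestMultipleGpus):  # hypothetical name
    def test_both_modes(self):
        with _test_eager_guard():  # eager dygraph mode
            self.run_mnist_2gpu('parallel_dygraph_dataparallel_with_pylayer.py')
        # outside the guard the legacy dygraph path runs
        self.run_mnist_2gpu('parallel_dygraph_dataparallel_with_pylayer.py')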

python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_gloo.py
Lines changed: 0 additions & 30 deletions

@@ -55,35 +55,5 @@ def test_sparse_embedding_fp64(self):
             log_name=flag_name)


-class TestParallelDygraphSparseEmdeddingEager_GLOO(TestDistBase):
-    def _setup_config(self):
-        self._sync_mode = False
-        self._eager_mode = True
-        self._gloo_mode = True
-        self._dygraph = True
-
-    def test_sparse_embedding(self):
-        self.check_with_place(
-            "parallel_dygraph_sparse_embedding.py",
-            delta=1e-5,
-            check_error_log=True,
-            log_name=flag_name)
-
-
-class TestParallelDygraphSparseEmdeddingEagerFP64_GLOO(TestDistBase):
-    def _setup_config(self):
-        self._sync_mode = False
-        self._eager_mode = True
-        self._gloo_mode = True
-        self._dygraph = True
-
-    def test_sparse_embedding_fp64(self):
-        self.check_with_place(
-            "parallel_dygraph_sparse_embedding_fp64.py",
-            delta=1e-5,
-            check_error_log=True,
-            log_name=flag_name)
-
-
 if __name__ == "__main__":
     unittest.main()

python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height_gloo.py
Lines changed: 0 additions & 15 deletions

@@ -40,20 +40,5 @@ def test_sparse_embedding(self):
             log_name=flag_name)


-class TestParallelDygraphSparseEmdeddingOverHeightEager_GLOO(TestDistBase):
-    def _setup_config(self):
-        self._sync_mode = False
-        self._eager_mode = True
-        self._gloo_mode = True
-        self._dygraph = True
-
-    def test_sparse_embedding(self):
-        self.check_with_place(
-            "parallel_dygraph_sparse_embedding_over_height.py",
-            delta=1e-7,
-            check_error_log=True,
-            log_name=flag_name)
-
-
 if __name__ == "__main__":
     unittest.main()

python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer_gloo.py
Lines changed: 0 additions & 15 deletions

@@ -57,20 +57,5 @@ def test_transformer(self):
             log_name=flag_name)


-class TestParallelDygraphTransformerEager_GLOO(TestDistBase):
-    def _setup_config(self):
-        self._sync_mode = False
-        self._eager_mode = True
-        self._gloo_mode = True
-        self._dygraph = True
-
-    def test_transformer(self):
-        self.check_with_place(
-            "parallel_dygraph_transformer.py",
-            delta=1e-5,
-            check_error_log=True,
-            log_name=flag_name)
-
-
 if __name__ == "__main__":
     unittest.main()
