
Commit 18c6f40

optimizer sharding parameters (#39581)
1 parent 1f7f856 commit 18c6f40

7 files changed: +45 additions, -78 deletions


python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py

Lines changed: 22 additions & 2 deletions
@@ -65,9 +65,9 @@ def __init__(self,
                 params,
                 optim,
                 group=None,
-                broadcast_fp16=False,
                 offload=False,
                 device="gpu",
+                pertrain_sync_models=True,
                 **kw):
 
         super().__init__(optim._learning_rate, params, kw)
@@ -98,8 +98,12 @@ def __init__(self,
 
         self.world_size = self.group.nranks
         self.rank = self.group.rank
+        self._global_root_rank = 0
+
+        # Synchronous all ranks models
+        if pertrain_sync_models:
+            self._sync_params_and_buffers()
 
-        self.broadcast_fp16 = broadcast_fp16
         self.param_storages = {}  # {dtype: {rank: InternalStorage}}
 
         if isinstance(self._optim._grad_clip, ClipGradByGlobalNorm):
@@ -132,6 +136,22 @@ def __init__(self,
         # Update optimizer parameters and adjust parameter storage and use according to rank.
         self._update_opt_status()
 
+    @paddle.no_grad()
+    def _sync_params_and_buffers(self):
+        """
+        Sync all model states for all ranks
+        """
+
+        for p in self._local_params:
+            dist.broadcast(
+                p,
+                src=self._global_root_rank,
+                group=self.group,
+                use_calc_stream=True)
+
+            # Multi stream operation will be supported later
+            dist.wait(tensor=p, group=self.group, use_calc_stream=True)
+
     def _generate_master_params(self, trainable_params):
         if self.offload:
             for param in trainable_params:
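
Taken together, this file moves the rank-0 parameter broadcast into the stage-2 optimizer wrapper (gated by the new pertrain_sync_models flag) and drops the unused broadcast_fp16 argument. Below is a minimal usage sketch of the updated constructor, pieced together from this diff and the test files further down; it is not new API from this commit, and it assumes a collective job launched with paddle.distributed.launch.

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2

fleet.init(is_collective=True)

model = paddle.nn.Linear(1000, 1000)  # any dygraph layer
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters())

# pertrain_sync_models=True (the new default) broadcasts every local
# parameter from rank 0 inside __init__, so all ranks start from the
# same weights; broadcast_fp16 is no longer accepted.
optimizer = ShardingOptimizerStage2(
    params=model.parameters(),
    optim=optimizer,
    group=None,  # None -> global collective group, as in the stage-2 wrapper
    offload=False,
    pertrain_sync_models=True)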

python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py

Lines changed: 2 additions & 38 deletions
@@ -61,12 +61,10 @@ def __init__(
             sharding_optimizer,
             group=None,
             sync_buffers=False,
-            pertrain_sync_models=True,
             buffer_max_size=2**23,  #8MB
             auto_refresh_trainable=True,
             device="gpu",
-            use_grad_storage=True,
-            accumulate_grads=False):
+            use_grad_storage=True):
         super().__init__()
 
         # training options
@@ -81,9 +79,6 @@ def __init__(
         self._sync_buffers = sync_buffers
         self._auto_refresh_trainable = auto_refresh_trainable
 
-        # Gradient accumulation, Gradient flip
-        self._accumulate_grads = accumulate_grads
-
         # Communication related attributes
         self._group = dist.new_group(_get_global_group()
                                      .ranks) if group is None else group
@@ -128,16 +123,11 @@ def __init__(
         # Set backward pass hooks
         self._bw_hooks = []
 
-        # Synchronous all ranks models
-        if pertrain_sync_models:
-            self._sync_params_and_buffers()
-
         # Set tasks flow
         self._tasks_flow = deque()
 
         # Define optimizer step and clear_grad
-        if self._accumulate_grads:
-            self._redefine_opt_step()
+        self._redefine_opt_step()
         self._redefine_opt_clear()
 
     def forward(self, *inputs, **kwargs):
@@ -313,9 +303,6 @@ def reduce(*_):
 
                 # Change reduce information
                 self._grad_reduced[index] = False
-                if not self._accumulate_grads:
-                    param.grad.scale_(scale=self._world_size_scaling)
-                    param._reset_grad_inplace_version(True)
 
                 # Clear the gradient that does not belong to the current rank through the callback function
                 def cleanup():
@@ -362,11 +349,6 @@ def reduce(*_):
                 if grad_storage.all_checked_in:
                     assert grad_storage.buffer is not None
 
-                    # Normalize all ranks grad_storage
-                    if not self._accumulate_grads:
-                        grad_storage.buffer.scale_(
-                            scale=self._world_size_scaling)
-
                     # Clearing up the grad_storage buffer
                     def cleanup():
                         if dst_rank != self._rank:
@@ -432,22 +414,6 @@ def _setup_backward_hooks(self):
             self._bw_hooks.append(
                 param._register_backward_hook(reduce_function))
 
-    @paddle.no_grad()
-    def _sync_params_and_buffers(self):
-        """
-        Sync all model states for all ranks
-        """
-
-        for t in self._layer.parameters():
-            dist.broadcast(
-                t,
-                src=self._global_root_rank,
-                group=self._group,
-                use_calc_stream=True)
-
-            # Multi stream operation will be supported later
-            dist.wait(tensor=t, group=self._group, use_calc_stream=True)
-
     def _setup_use_grad_storage(self):
         """
         Integrate the parameters gradient into a continuous memory according to rank, and support the update of training parameters.
@@ -555,8 +521,6 @@ def _rank_buffer_size(self, buffer_max_size, model_size):
         return rank_buffer_size
 
     def _redefine_opt_step(self):
-        if not self._accumulate_grads:
-            return
         grad_func = self._grad_scale
         for opt in self._sharding_optimizers:
            opt_step = opt.step
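
With accumulate_grads removed, the stage-2 model wrapper always rebinds opt.step, so the 1/world_size gradient averaging happens inside the redefined step instead of in the reduce hooks, and the pretraining parameter sync now lives in the optimizer wrapper (previous file). Continuing the sketch above, wrapping the layer looks like this after the change (argument values mirror the updated tests and are illustrative):

from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2

# `model` and `optimizer` are the layer and ShardingOptimizerStage2
# instance from the earlier sketch; accumulate_grads is no longer a
# constructor argument.
model = ShardingStage2(
    model,
    optimizer,
    group=None,             # None -> global collective group
    buffer_max_size=2**21)  # 2 MB grad-fusion buffers, as in the tests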

python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py

Lines changed: 7 additions & 12 deletions
@@ -72,7 +72,6 @@ def __init__(self,
                 device="gpu",
                 segment_size=2**15,
                 pertrain_sync_models=True,
-                accumulate_grads=False,
                 offload=False,
                 sync_comm=False):
         super().__init__()
@@ -82,7 +81,6 @@ def __init__(self,
         self._layer = layer
         self._default_device = device
         self.__sync_buffers = sync_buffers
-        self._accumulate_grads = accumulate_grads
         self._offload = offload
         self._sync_comm = sync_comm
         # segmentation size
@@ -190,6 +188,7 @@ def _clear_gradients(self):
                 param.fw_storage.clear_gradient(False)
                 param.fw_storage._gradient_set_empty(False)
                 param.bw_storage._clear()
+                param.bw_storage = None
         # 2.Handle unslice param
         if not self._offload:
             for grad_storage in self._grad_storages.values():
@@ -446,13 +445,12 @@ def _update_params(self):
                 param,
                 "fw_storage"), "Find {} don't have fw_storage attribute".format(
                     param.name)
-
-            if self._accumulate_grads:
-                if self._offload:
-                    with device_guard(device="cpu"):
-                        param.bw_storage.scale_(scale=self._world_size_scaling)
-                else:
+            # Gradient average
+            if self._offload:
+                with device_guard(device="cpu"):
                     param.bw_storage.scale_(scale=self._world_size_scaling)
+            else:
+                param.bw_storage.scale_(scale=self._world_size_scaling)
             param.fw_storage = _VarBaseWrapper(param)
             assert param.fw_storage.grad is None
             param.fw_storage._copy_gradient_from(param.bw_storage)
@@ -526,17 +524,14 @@ def _get_allreduce_fn(self, param):
         def reduce(*_):
             if param.name in self._task_flow.full_grad.keys():
                 full_grad = self._task_flow.full_grad[param.name]
-                if not self._accumulate_grads:
-                    full_grad.scale_(scale=self._world_size_scaling)
                 # Only support sync allreduce current rank's layer now
                 dist.all_reduce(
                     tensor=full_grad, group=self._group, use_calc_stream=True)
                 dist.wait(
                     tensor=full_grad, group=self._group, use_calc_stream=True)
 
                 start, end = self._param2buffer[param.name][self._rank]
-                if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value(
-                ).get_tensor()._is_initialized():
+                if param.bw_storage is None:
                     param.bw_storage = core.VarBase(
                         full_grad._slice(start, end)).detach().clone()
                 if self._offload:
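
Stage 3 drops the same flag: gradients are now always averaged in _update_params rather than conditionally in the reduce hook, and bw_storage is released after clearing. A hedged construction sketch follows, based on this diff and the updated dygraph_sharding_stage3.py test; the import path is inferred from the file location, and group=None is assumed to fall back to the global group as it does for stage 2.

import paddle
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3

mlp = paddle.nn.Linear(1000, 1000)  # any dygraph layer
adam = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=mlp.parameters())

# Stage 3 wraps the plain optimizer directly (no ShardingOptimizerStage2)
# and no longer takes accumulate_grads; the tests pass sync_comm=recompute.
mlp = ShardingStage3(
    mlp,
    optimizer=adam,
    group=None,        # assumed default: global collective group
    sync_comm=False,
    offload=False)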

python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py

Lines changed: 8 additions & 7 deletions
@@ -27,7 +27,7 @@
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
 from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
 
-seed = 2021
+seed = 2022
 epoch = 2
 linear_size = 1000
 
@@ -105,11 +105,7 @@ def train_mlp(model,
             params=model.parameters(), optim=optimizer, group=group)
 
         model = ShardingStage2(
-            model,
-            optimizer,
-            group=group,
-            buffer_max_size=2**21,
-            accumulate_grads=batch_size == 20)
+            model, optimizer, group=group, buffer_max_size=2**21)
     else:
         optimizer = fleet.distributed_optimizer(optimizer)
         model = fleet.distributed_model(model)
@@ -140,6 +136,8 @@ def train_mlp(model,
             loss = paddle.nn.functional.cross_entropy(input=out, label=label)
 
             avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
+            if batch_size == 20:
+                avg_loss = avg_loss / 5
             avg_loss.backward()
 
             if not accumulate_grad:
@@ -166,6 +164,7 @@ def test_dp_stage2():
     mlp4.set_state_dict(state_dict)
     mlp5.set_state_dict(state_dict)
 
+    # DP VS stage2
     dp_params = train_mlp(
         mlp1, sharding_stage="dp", use_pure_fp16=False, opt_group=False)
     stage2_params = train_mlp(
@@ -174,7 +173,8 @@ def test_dp_stage2():
         np.testing.assert_allclose(
             dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)
 
-    stage2_params = train_mlp(mlp3, sharding_stage=2)
+    # stage2 accumulate grad
+    stage2_params = train_mlp(mlp3, sharding_stage=2, accumulate_grad=True)
     stage2_accumulate_grad = train_mlp(
         mlp4, sharding_stage=2, batch_size=20, accumulate_grad=True)
     for i in range(len(stage2_params)):
@@ -184,6 +184,7 @@ def test_dp_stage2():
             rtol=1e-5,
             atol=1e-5)
 
+    # stage2 param list VS param group
     stage2_params = train_mlp(
         mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True)
     for i in range(len(dp_params)):
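
Because the wrapper no longer averages accumulated gradients on its own, the updated test folds the accumulation factor into the loss (avg_loss / 5 for the 5x larger batch). A hedged sketch of that pattern outside the test harness, continuing the earlier sketches; train_loader, accumulate_steps and the loop structure below are illustrative, not code from this commit.

accumulate_steps = 5  # matches the avg_loss / 5 in the test above

for step, (img, label) in enumerate(train_loader()):  # train_loader: placeholder reader
    out = model(img)
    loss = paddle.nn.functional.cross_entropy(input=out, label=label)
    # Scale the loss so the accumulated gradient is an average, since the
    # stage-2 wrapper no longer scales accumulated grads for you.
    avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32)) / accumulate_steps
    avg_loss.backward()

    if (step + 1) % accumulate_steps == 0:
        optimizer.step()        # the redefined step also scales by 1/world_size
        optimizer.clear_grad()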

python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py

Lines changed: 2 additions & 3 deletions
@@ -43,13 +43,12 @@ def train_mlp(model, offload=False):
     optimizer = optimizer_setting(model=model, use_pure_fp16=True)
 
     model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
-    scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
+    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
     scaler = ShardingScaler(scaler)
 
     optimizer = ShardingOptimizerStage2(
         params=model.parameters(), optim=optimizer, offload=offload)
-    model = ShardingStage2(
-        model, optimizer, buffer_max_size=2**21, accumulate_grads=False)
+    model = ShardingStage2(model, optimizer, buffer_max_size=2**21)
 
     train_reader = paddle.batch(
         reader_decorator(linear_size), batch_size=batch_size, drop_last=True)

python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py

Lines changed: 3 additions & 11 deletions
@@ -101,18 +101,10 @@ def train_mlp(model,
         optimizer = ShardingOptimizerStage2(
             params=model.parameters(), optim=optimizer, group=group)
         model = ShardingStage2(
-            model,
-            optimizer,
-            group=group,
-            buffer_max_size=2**21,
-            accumulate_grads=batch_size == 20)
+            model, optimizer, group=group, buffer_max_size=2**21)
     elif sharding_stage == 3:
         model = ShardingStage3(
-            model,
-            optimizer=optimizer,
-            group=group,
-            accumulate_grads=batch_size == 20,
-            sync_comm=recompute)
+            model, optimizer=optimizer, group=group, sync_comm=recompute)
 
     # check optimizer.minimize() error
     if test_minimize:
@@ -231,7 +223,7 @@ def test_stage2_stage3():
             stage2_params[i].numpy(),
             stage3_params[i].numpy(),
             rtol=1e-4,
-            atol=1e-4)
+            atol=1e-3)
 
     # fp16 recompute
     stage3_params = train_mlp(

python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py

Lines changed: 1 addition & 5 deletions
@@ -91,11 +91,7 @@ def train_mlp(model,
     scaler = ShardingScaler(scaler)
 
     model = ShardingStage3(
-        model,
-        optimizer=optimizer,
-        group=group,
-        offload=offload,
-        accumulate_grads=accumulate_grad)
+        model, optimizer=optimizer, group=group, offload=offload)
 
     train_reader = paddle.batch(
         reader_decorator(), batch_size=batch_size, drop_last=True)
