Commit 828f87a

sharding_stage2_pfp16 (#37836)
1 parent 3e33ef5 commit 828f87a

File tree

3 files changed: +36 −20 lines


python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py

Lines changed: 14 additions & 0 deletions
@@ -83,8 +83,14 @@ def __init__(self,
         # Default information
         self._optim_defaults = kw
         self._optim = optim
+        assert hasattr(self._optim, "_master_weights"
+                       ), "Must use optimizer with _master_weights attribute"
         self._local_params = params
         self._default_device = device
+        self._pfp16 = len(
+            list(
+                filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
+                       self._local_params))) > 0

         assert group is not None, "Distributed communication group is must be gived"
         self.group = group
@@ -98,6 +104,12 @@ def __init__(self,
         # Update optimizer parameters and adjust parameter storage and use according to rank.
         self.update_opt_status()

+    def _generate_master_params(self, trainable_params):
+        for param in trainable_params:
+            if param.dtype == Type.fp16.value:
+                self._optim._master_weights[param.name] = paddle.cast(
+                    param, Type.fp32.value)
+
     def update_opt_status(self):
         """Update optimizer status and parameter storage information, and special functions to be developed.
         """
@@ -207,6 +219,8 @@ def _integration_params(self):
                     # Merge all the trainable params in a single InternalStorage
                     trainable_params = list(
                         filter(lambda x: x.trainable, params))
+                    if self._pfp16 and dst_rank == self.rank:
+                        self._generate_master_params(trainable_params)
                     if trainable_params:
                         param_storage = ParamStorage(
                             size=self.rank_buffer_size[dtype][dst_rank],
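The new _pfp16 flag and _generate_master_params helper follow the usual pure-fp16 recipe: when any trainable parameter is stored in fp16, the optimizer keeps an fp32 "master" copy and performs its update in full precision. A minimal standalone sketch of that idea, outside the ShardingOptimizerStage2 class (the make_master_weights helper below is hypothetical, not part of this patch):

import paddle

def make_master_weights(params):
    # For every trainable fp16 parameter, keep an fp32 copy keyed by name.
    # The optimizer updates the fp32 copy; the fp16 tensor is refreshed from it.
    master_weights = {}
    for p in params:
        if p.trainable and p.dtype == paddle.float16:
            master_weights[p.name] = paddle.cast(p, paddle.float32)
    return master_weights

In the patch itself the copies land in self._optim._master_weights, which is why the new assert requires the wrapped optimizer to expose a _master_weights attribute.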

python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py

Lines changed: 8 additions & 3 deletions
@@ -30,6 +30,7 @@
 import paddle.distributed as dist

 from ...utils.internal_storage import GradStorage
+from ...meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
 from .sharding_utils import Taskflow, Type


@@ -70,6 +71,11 @@ def __init__(
         self._layer = layer
         self._sharding_optimizers = [sharding_optimizer] if not isinstance(
             sharding_optimizer, list) else sharding_optimizer
+        assert all(
+            list(
+                map(lambda opt: isinstance(opt, ShardingOptimizerStage2),
+                    self._sharding_optimizers))
+        ), "Please use ShardingOptimizerStage2 optimizer"
         self._sync_buffers = sync_buffers
         self._auto_refresh_trainable = auto_refresh_trainable

@@ -88,8 +94,7 @@ def __init__(

         # Global statistical parameters
         self._all_params = list(
-            chain(
-                * [optim.local_params for optim in self._sharding_optimizers]))
+            chain(*[optim.local_params for optim in self._sharding_optimizers]))
         self._trainable_params = []
         self._grad_reduced = []
         self._trainable_param2rank = {}
@@ -436,7 +441,7 @@ def _setup_use_grad_storage(self):
                     ._fill))

         self._grad_storage_list = list(
-            chain(* [
+            chain(*[
                 self._grad_storages[dtype].values()
                 for dtype in self._grad_storages.keys()
             ]))
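The added assertion refuses to wrap any optimizer that is not a ShardingOptimizerStage2, since stage-2 relies on the sharding metadata that class maintains (for example local_params). The same check can be written more compactly with a generator expression; a sketch under that assumption (hypothetical check_optimizers helper):

from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2

def check_optimizers(optimizers):
    # Reject plain optimizers early: ShardingStage2 needs the per-rank
    # parameter bookkeeping that only ShardingOptimizerStage2 provides.
    assert all(
        isinstance(opt, ShardingOptimizerStage2) for opt in optimizers
    ), "Please use ShardingOptimizerStage2 optimizer"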

python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py

Lines changed: 14 additions & 17 deletions
@@ -24,7 +24,6 @@
 from paddle.distributed import fleet
 from paddle.fluid.dygraph import nn

-from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import DygraphShardingOptimizer
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
 from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2

@@ -70,7 +69,7 @@ def __reader__():
     return __reader__


-def optimizer_setting(model, use_pure_fp16, stage=1):
+def optimizer_setting(model, use_pure_fp16):
     clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
     optimizer = paddle.optimizer.AdamW(
         parameters=model.parameters(),
@@ -87,20 +86,16 @@ def train_mlp(model,
               use_pure_fp16=False,
               all_test=False,
               accumulate_grad=False):
-    if sharding_stage == 1:
+    if sharding_stage == "dp":
         hcg = fleet.get_hybrid_communicate_group()
         group = hcg.get_check_parallel_group()
     else:
         group = paddle.distributed.new_group([0, 1])
-    optimizer = optimizer_setting(
-        model=model, use_pure_fp16=use_pure_fp16, stage=sharding_stage)
+    optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)

     if use_pure_fp16:
-        model, optimizer = paddle.amp.decorate(
-            models=model,
-            optimizers=optimizer,
-            level='O2',
-            save_dtype='float32')
+        model = paddle.amp.decorate(
+            models=model, level='O2', save_dtype='float32')

     if sharding_stage == 2:
         optimizer = ShardingOptimizerStage2(
@@ -164,7 +159,7 @@ def train_mlp(model,
     return model.parameters()


-def test_stage1_stage2():
+def test_dp_stage2():
     mlp = MLP()
     state_dict = mlp.state_dict()
     mlp1 = MLP()
@@ -175,11 +170,13 @@ def test_stage1_stage2():
     mlp2.set_state_dict(state_dict)
     mlp3.set_state_dict(state_dict)
     mlp4.set_state_dict(state_dict)
-    stage1_params = train_mlp(mlp, sharding_stage=1, use_pure_fp16=False)
-    stage2_params = train_mlp(mlp, sharding_stage=2, use_pure_fp16=False)
-    for i in range(len(stage1_params)):
-        np.testing.assert_allclose(
-            stage1_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)
+    dp_params = train_mlp(mlp1, sharding_stage="dp", use_pure_fp16=False)
+    stage2_params = train_mlp(mlp2, sharding_stage=2, use_pure_fp16=False)
+    for i in range(len(dp_params)):
+        for j in range(len(stage2_params)):
+            if dp_params[i].name == stage2_params[j].name:
+                np.testing.assert_allclose(
+                    dp_params[i].numpy(), stage2_params[j].numpy(), rtol=1e-6)

     stage2_params = train_mlp(
         mlp3, sharding_stage=2, use_pure_fp16=True, all_test=True)
@@ -201,4 +198,4 @@ def test_stage1_stage2():


 if __name__ == '__main__':
-    test_stage1_stage2()
+    test_dp_stage2()
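The rewritten test no longer assumes the data-parallel and stage-2 runs return parameters in the same order; it matches them by name before comparing values, since the stage-2 run regroups parameters per rank. The nested loop can equivalently use a dict lookup; a sketch with a hypothetical helper name:

import numpy as np

def compare_params_by_name(dp_params, stage2_params, rtol=1e-6):
    # Index the stage-2 parameters by name, then compare each data-parallel
    # parameter against its counterpart with the same name.
    stage2_by_name = {p.name: p for p in stage2_params}
    for p in dp_params:
        if p.name in stage2_by_name:
            np.testing.assert_allclose(
                p.numpy(), stage2_by_name[p.name].numpy(), rtol=rtol)

Note also that under pure fp16 the test now decorates only the model with paddle.amp.decorate and no longer decorates the optimizer, consistent with the master-weight handling added to ShardingOptimizerStage2 in this commit.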
