 from paddle.distributed import fleet
 from paddle.fluid.dygraph import nn

-from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import DygraphShardingOptimizer
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
 from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2

@@ -70,7 +69,7 @@ def __reader__():
     return __reader__


-def optimizer_setting(model, use_pure_fp16, stage=1):
+def optimizer_setting(model, use_pure_fp16):
     clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
     optimizer = paddle.optimizer.AdamW(
         parameters=model.parameters(),
@@ -87,20 +86,16 @@ def train_mlp(model,
               use_pure_fp16=False,
               all_test=False,
               accumulate_grad=False):
-    if sharding_stage == 1:
+    if sharding_stage == "dp":
         hcg = fleet.get_hybrid_communicate_group()
         group = hcg.get_check_parallel_group()
     else:
         group = paddle.distributed.new_group([0, 1])
-    optimizer = optimizer_setting(
-        model=model, use_pure_fp16=use_pure_fp16, stage=sharding_stage)
+    optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)

     if use_pure_fp16:
-        model, optimizer = paddle.amp.decorate(
-            models=model,
-            optimizers=optimizer,
-            level='O2',
-            save_dtype='float32')
+        model = paddle.amp.decorate(
+            models=model, level='O2', save_dtype='float32')

     if sharding_stage == 2:
         optimizer = ShardingOptimizerStage2(
@@ -164,7 +159,7 @@ def train_mlp(model,
     return model.parameters()


-def test_stage1_stage2():
+def test_dp_stage2():
     mlp = MLP()
     state_dict = mlp.state_dict()
     mlp1 = MLP()
@@ -175,11 +170,13 @@ def test_stage1_stage2():
     mlp2.set_state_dict(state_dict)
     mlp3.set_state_dict(state_dict)
     mlp4.set_state_dict(state_dict)
-    stage1_params = train_mlp(mlp, sharding_stage=1, use_pure_fp16=False)
-    stage2_params = train_mlp(mlp, sharding_stage=2, use_pure_fp16=False)
-    for i in range(len(stage1_params)):
-        np.testing.assert_allclose(
-            stage1_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)
+    dp_params = train_mlp(mlp1, sharding_stage="dp", use_pure_fp16=False)
+    stage2_params = train_mlp(mlp2, sharding_stage=2, use_pure_fp16=False)
+    for i in range(len(dp_params)):
+        for j in range(len(stage2_params)):
+            if dp_params[i].name == stage2_params[j].name:
+                np.testing.assert_allclose(
+                    dp_params[i].numpy(), stage2_params[j].numpy(), rtol=1e-6)

     stage2_params = train_mlp(
         mlp3, sharding_stage=2, use_pure_fp16=True, all_test=True)
@@ -201,4 +198,4 @@ def test_stage1_stage2():


 if __name__ == '__main__':
-    test_stage1_stage2()
+    test_dp_stage2()