
Commit 68835a8

revert align_grad_clip (#74403)
1 parent d800ce5 commit 68835a8

2 files changed: 0 additions, 113 deletions


python/paddle/nn/clip.py

Lines changed: 0 additions & 40 deletions
@@ -717,7 +717,6 @@ def _dygraph_clip(self, params_grads):
         sum_square_list = []
         sum_square_list_fp16 = []
         sum_square_list_fp32 = []
-        flag_new_pp = True
         if len(params_grads) > 0 and len(params_grads[0]) > 0:
             src_mesh = params_grads[0][0].process_mesh
         else:
@@ -743,7 +742,6 @@ def _dygraph_clip(self, params_grads):
             # if the gradient mesh is not equal to src mesh
             # do reshard to get the result of squared_l2 from other pp stage mesh
             if src_mesh is not None and g.process_mesh != src_mesh:
-                flag_new_pp = False
                 pp_mesh = get_complete_pp_mesh(g.process_mesh)
                 if set(g.process_mesh.process_ids) < set(pp_mesh.process_ids):
                     sum_square = dist.reshard(
@@ -792,44 +790,6 @@ def async_add_n(var_list):
             global_norm_var.append(global_norm_var_fp64)

         global_norm_var = async_add_n(global_norm_var)
-        global_mesh = dist.get_mesh()
-        is_pp_enable = False
-        if global_mesh is not None:
-            is_pp_enable = (
-                "pp" in global_mesh.dim_names
-                and global_mesh.get_dim_size("pp") > 1
-            )
-        if (
-            flag_new_pp and src_mesh is not None and is_pp_enable
-        ):  # Use new pp_flask,At this point global_norm_var it's sub_norm_var_sum,we need to sum it between different pp_stage
-            global_pp_mesh = global_mesh.get_mesh_with_dim("pp")
-            reorder_mesh = global_pp_mesh._mesh.reshape(
-                global_mesh.get_dim_size("pp"), -1
-            )
-            curr_rank = dist.get_rank()
-            assert (
-                curr_rank in global_pp_mesh.process_ids
-            ), "current rank is not in pp process mesh"
-            curr_rank_sub_group = None
-            for col in range(
-                reorder_mesh.shape[-1]
-            ):  # every_sub_mesh need to create a new group,otherwise,the group id of sub_mesh will be the same,which will cause the all_gather error
-                sub_mesh = dist.ProcessMesh(reorder_mesh[:, col], ["pp"])
-                sub_group = dist.new_group(sub_mesh.process_ids)
-                if curr_rank in reorder_mesh[:, col]:
-                    curr_rank_sub_group = sub_group
-            global_norm_var_list = []
-            dist.all_gather(
-                global_norm_var_list,
-                global_norm_var._local_value(),
-                group=curr_rank_sub_group,
-            )
-            real_global_norm_var = async_add_n(global_norm_var_list)
-            global_norm_var = dist.shard_tensor(
-                real_global_norm_var,
-                global_norm_var.process_mesh,
-                global_norm_var.placements,
-            )

         if self.should_comm_on_shard_dim and hasattr(self, 'sharding_group'):
             paddle.distributed.all_reduce(
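For context, the block reverted above summed the per-stage partial global norm across pipeline stages: it reshaped the global "pp" mesh into a (pp_degree, dp_degree) grid, built one communication group per column (one rank from every stage), all-gathered each rank's local partial norm inside its own group, and added the gathered values back up. Below is a minimal standalone sketch of that communication pattern, not Paddle's implementation: `gather_norm_across_pp`, `pp_degree`, `dp_degree`, and `partial_norm` are hypothetical names, and the code assumes a distributed environment has already been initialized with `paddle.distributed.init_parallel_env()` on `pp_degree * dp_degree` ranks.

```python
import numpy as np
import paddle
import paddle.distributed as dist


def gather_norm_across_pp(partial_norm, pp_degree, dp_degree):
    """Sum a local partial global-norm tensor across pipeline stages.

    Hypothetical helper mirroring the reverted logic: reorder the flat rank
    list into a (pp_degree, dp_degree) grid, create one group per column so
    each group holds exactly one rank from every pp stage, then all_gather
    the local value inside the current rank's group and add the results.
    Assumes dist.init_parallel_env() was already called on
    pp_degree * dp_degree ranks.
    """
    curr_rank = dist.get_rank()
    # Each column of this grid contains one rank per pipeline stage.
    reorder_mesh = np.arange(pp_degree * dp_degree).reshape(pp_degree, dp_degree)

    curr_rank_sub_group = None
    for col in range(reorder_mesh.shape[-1]):
        # Every column gets its own group; all ranks must execute every
        # new_group call, otherwise group ids diverge and all_gather hangs.
        sub_group = dist.new_group([int(r) for r in reorder_mesh[:, col]])
        if curr_rank in reorder_mesh[:, col]:
            curr_rank_sub_group = sub_group

    gathered = []
    dist.all_gather(gathered, partial_norm, group=curr_rank_sub_group)
    return paddle.add_n(gathered)
```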

test/auto_parallel/PP_Schedules_demo.py

Lines changed: 0 additions & 73 deletions
@@ -414,67 +414,6 @@ def test_dp_pp(self):
             opt.clear_grad()
         return losses_by_step, all_losses_in_one_step_md5sum

-    def test_pp_model_with_ClipGradByGlobalNorm(self):
-        """Test pipeline parallel model with ClipGradByGlobalNorm using PPMyModel as the baseline"""
-        fix_seeds()
-        pp_model = PPMyModel()
-        opt = paddle.optimizer.AdamW(
-            learning_rate=0.001,
-            parameters=pp_model.parameters(),
-            grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
-        )
-        loss_fn = nn.MSELoss()
-        dataset = RandomDataset(image_size=8, output_size=8, num_samples=8)
-        loader = DataLoader(dataset, batch_size=1)
-        pp_losses_step = []
-        num_iterations = 20
-
-        for iter_idx in range(num_iterations):
-            pp_losses_micro_batch = []
-            for i, (data, label) in enumerate(loader):
-                output = pp_model(data)
-                loss = loss_fn(output, label)
-                pp_losses_micro_batch.append(loss.item())
-                loss.backward()
-            pp_losses_step.append(
-                np.array(pp_losses_micro_batch, dtype=np.float32).mean()
-            )
-            opt.step()
-            opt.clear_grad()
-        return pp_losses_step
-
-    def test_ScheduleFThenB_with_ClipGradByGlobalNorm(self):
-        fix_seeds()
-        self.model = PPMyModel_SingleStage()
-        self.micro_batches = 8
-        self.stage = PipelineStage(self.model, self.rank, 4, group=self.group)
-        self.stage.has_backward = True
-        loss_fn_ = nn.MSELoss()
-        schedule = ScheduleFThenB(
-            self.stage, self.micro_batches, loss_fn=loss_fn_
-        )
-        opt = paddle.optimizer.AdamW(
-            learning_rate=0.001,
-            parameters=self.model.parameters(),
-            grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
-        )
-        dataset = RandomDataset(image_size=8, output_size=8, num_samples=8)
-        loader = DataLoader(dataset, batch_size=8)
-        losses_by_step = []
-        num_iterations = 20
-
-        for iter_idx in range(num_iterations):
-            losses_by_micro_batch = []
-            for i, (data, label) in enumerate(loader):
-                schedule.step(data, target=label, losses=losses_by_micro_batch)
-            if self.rank == 3:
-                losses_by_step.append(
-                    np.array(losses_by_micro_batch, dtype=np.float32).mean()
-                )
-            opt.step()
-            opt.clear_grad()
-        return losses_by_step
-
     def test_dp_pp_align_mode(self):
         fix_seeds()
         paddle.set_flags({'FLAGS_enable_auto_parallel_align_mode': True})
@@ -551,12 +490,6 @@ def run_test(self):
         scheduleFThenB_losses = self.test_ScheduleFThenB()
         schedule1f1b_losses = self.test_Schedule1F1B()
         schedulevpp_losses = self.test_ScheduleVPP()
-        pp_model_with_ClipGradByGlobalNorm_losses = (
-            self.test_pp_model_with_ClipGradByGlobalNorm()
-        )
-        scheduleFThenB_with_ClipGradByGlobalNorm_losses = (
-            self.test_ScheduleFThenB_with_ClipGradByGlobalNorm()
-        )
         dp_pp_losses, dp_pp_losses_md5sum = self.test_dp_pp()
         dp_pp_align_mode_losses, dp_pp_align_mode_losses_md5sum = (
             self.test_dp_pp_align_mode()
@@ -587,12 +520,6 @@ def run_test(self):
             rtol=1e-5,
         )

-        np.testing.assert_allclose(
-            pp_model_with_ClipGradByGlobalNorm_losses,
-            scheduleFThenB_with_ClipGradByGlobalNorm_losses,
-            rtol=1e-5,
-        )
-
         np.testing.assert_allclose(
             dp_pp_align_mode_losses,
             dp_pp_losses,
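The removed tests exercised this path by attaching ClipGradByGlobalNorm to AdamW under pipeline-parallel schedules. Stripped of the parallelism, the clipping setup they relied on is the standard single-process wiring sketched below; the Linear layer and random tensors are stand-ins for the test's PPMyModel and RandomDataset helpers.

```python
import paddle

# Minimal single-process sketch of the grad-clip wiring the removed tests used.
model = paddle.nn.Linear(8, 8)
opt = paddle.optimizer.AdamW(
    learning_rate=0.001,
    parameters=model.parameters(),
    grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
)

x = paddle.randn([4, 8])
loss = paddle.nn.MSELoss()(model(x), paddle.randn([4, 8]))
loss.backward()
opt.step()        # gradients are rescaled so their global norm is at most 1.0
opt.clear_grad()
```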
