Skip to content

Conversation

@Baibaifan
Copy link
Contributor

@Baibaifan Baibaifan commented Dec 15, 2021

PR types

Performance optimization

PR changes

Others

Describe

Integration sharding stage2 function
1.Support group = None
2.Support param_groups for optimizer

import paddle from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2 from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2 fleet.init(is_collective=True) group = paddle.distributed.new_group([0, 1]) # wrap model & optimizer  model = model_class(...) oss_optimizer = ShardingOptimizer(params=model.parameters(), optim=optimizer, group=group) model = ShardingStage2(model, oss_optimizer, group=group) # use optimizer as normal img, label = data label.stop_gradient = True img.stop_gradient = True out = model(img) loss = paddle.nn.functional.cross_entropy(input=out, label=label) oss_optimizer.step() oss_optimizer.clear_grad()

image

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@Baibaifan Baibaifan force-pushed the integration_stage2_function branch 3 times, most recently from b1bf3cc to 7d5ad2e Compare December 15, 2021 06:35
@Baibaifan Baibaifan force-pushed the integration_stage2_function branch from 7d5ad2e to 576a132 Compare December 15, 2021 09:42
@Baibaifan Baibaifan force-pushed the integration_stage2_function branch from 4db80ec to c1bf4fc Compare December 16, 2021 06:25
@Baibaifan Baibaifan force-pushed the integration_stage2_function branch 3 times, most recently from 04d6d9f to 5d6cc91 Compare December 17, 2021 06:14
@Baibaifan Baibaifan force-pushed the integration_stage2_function branch from 5d6cc91 to 7b26ec9 Compare December 17, 2021 08:44
@Baibaifan Baibaifan force-pushed the integration_stage2_function branch from 7b26ec9 to cf9b633 Compare December 17, 2021 11:27
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deepcopy increase memory..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修复,改成引用传递。

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need support global group if group=None

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已支持

@Baibaifan Baibaifan force-pushed the integration_stage2_function branch 3 times, most recently from 2715035 to 0f53247 Compare December 17, 2021 13:02
@Baibaifan Baibaifan changed the title Integration sharding stage2 function [Dygraph]Integration sharding stage2 function Dec 18, 2021
ForFishes
ForFishes previously approved these changes Dec 18, 2021
Copy link
Member

@ForFishes ForFishes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Member

@ForFishes ForFishes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Baibaifan Baibaifan merged commit 327e505 into PaddlePaddle:develop Dec 19, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

2 participants