Conversation

@Baibaifan Baibaifan commented Jan 17, 2022

PR types

New features

PR changes

Others

Describe

Add sharding stage3 offload.

```python
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3

fleet.init(is_collective=True)
group = paddle.distributed.new_group([0, 1])

# wrap the model
model = model_class(...)
model = ShardingStage3(model, optimizer=optimizer, group=group, offload=True)

# use the optimizer as normal
img, label = data
label.stop_gradient = True
img.stop_gradient = True

out = model(img)
loss = paddle.nn.functional.cross_entropy(input=out, label=label)
loss.backward()
optimizer.step()
optimizer.clear_grad()

# Gather all parameters from the parameter slices.
# If the parameters need to be moved back to CPU, pass convert2cpu=True.
model.get_all_parameters(convert2cpu=True)
```
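The idea behind `offload=True` is that each rank's master parameter shard stays resident in host (CPU) memory and a device working copy exists only around the compute and update steps. The sketch below illustrates that bookkeeping in plain Python; the `OffloadSlice` class and its methods are hypothetical and not part of the Paddle API.

```python
# Conceptual sketch of parameter offload (NOT the Paddle implementation):
# the master copy stays on "CPU"; a working copy exists on the "device"
# only while it is needed.

class OffloadSlice:
    """Hypothetical holder for one parameter shard."""

    def __init__(self, values):
        self.cpu_copy = list(values)   # master copy, always resident on CPU
        self.dev_copy = None           # populated only around compute

    def fetch(self):
        # H2D copy: materialize the device working copy before compute.
        self.dev_copy = list(self.cpu_copy)
        return self.dev_copy

    def release(self, updated):
        # D2H copy: write the updated values back, then free device memory.
        self.cpu_copy = list(updated)
        self.dev_copy = None


shard = OffloadSlice([1.0, 2.0])
work = shard.fetch()                   # before forward
work = [w - 1 for w in work]           # stand-in for the optimizer update
shard.release(work)                    # after the update
print(shard.cpu_copy, shard.dev_copy)  # → [0.0, 1.0] None
```

The trade-off is extra host-device traffic per step in exchange for a much smaller peak device memory footprint, which is what makes the FP32 GPT-117M comparison below possible on the same hardware.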

[Figure: stage3 vs. stage3 offload, FP32 GPT 117M]

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@Baibaifan Baibaifan force-pushed the sharding_stage3_offload branch 3 times, most recently from 735bdbe to 895910b on January 21, 2022 at 12:22
@Baibaifan Baibaifan force-pushed the sharding_stage3_offload branch from 895910b to e05bc19 on January 21, 2022 at 14:13
@Baibaifan Baibaifan closed this Jan 23, 2022
@Baibaifan Baibaifan reopened this Jan 23, 2022
@Baibaifan Baibaifan closed this Jan 23, 2022
@Baibaifan Baibaifan reopened this Jan 23, 2022
@Baibaifan Baibaifan closed this Jan 23, 2022
@Baibaifan Baibaifan reopened this Jan 23, 2022
@ForFishes ForFishes left a comment

LGTM

@Baibaifan Baibaifan merged commit 4682310 into PaddlePaddle:develop Jan 24, 2022
@Baibaifan Baibaifan changed the title Add sharding stage3 offload [Dygraph]Add sharding stage3 offload Jan 24, 2022