
Conversation

@Baibaifan (Contributor) commented Mar 3, 2022

PR types

New features

PR changes

APIs

Describe

Add the group sharded APIs:

  1. group_sharded_parallel
  2. save_group_sharded_model
```python
import paddle
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel, save_group_sharded_model

fleet.init(is_collective=True)
group = paddle.distributed.new_group([0, 1])
model = Linear(1000, 1000)

clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.001,
    parameters=model.parameters(),
    weight_decay=0.00001,
    grad_clip=clip)
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)  # must exist before wrapping below

# wrap sharding model, optimizer and scaler
model, optimizer, scaler = group_sharded_parallel(model, optimizer, "p_g", scaler=scaler)

img, label = data  # a batch from your dataloader
label.stop_gradient = True
img.stop_gradient = True

out = model(img)
loss = paddle.nn.functional.cross_entropy(input=out, label=label)

loss.backward()
optimizer.step()
optimizer.clear_grad()

# save model and optimizer state_dict
save_group_sharded_model(model, output=output_dir, optimizer=optimizer)  # output_dir: save directory
```
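
Note: since the example creates a two-rank group (`new_group([0, 1])`), it is meant to be run under the distributed launcher, e.g. `python -m paddle.distributed.launch --gpus=0,1 train.py` (the script name here is hypothetical); `data` and `output_dir` stand in for a dataloader batch and a save directory.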
@paddle-bot-old (bot) commented Mar 3, 2022

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

Contributor:

What is the reason for importing fluid here? The APIs under fluid will be deprecated.

Contributor Author:

Changed to paddle.autograd.no_grad().
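
For context, a minimal sketch of how the replacement reads (the helper function here is hypothetical, not the PR's actual code):

```python
import paddle

# paddle.autograd.no_grad() replaces the deprecated fluid decorator;
# `sync_params` is a hypothetical helper used only for illustration
@paddle.autograd.no_grad()
def sync_params(model, src_rank=0):
    for p in model.parameters():
        paddle.distributed.broadcast(p, src=src_rank)
```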

Contributor:

Suggest removing this object; there is no need to introduce a separate object just for a parameter definition.

  1. Use level='os' directly in the group_sharded_parallel function, or simply level=1, following amp's level definition; a level is generally understood to be an integer, similar to verbose.
  2. What are os, os_g, p_g_os short for? They read poorly; is there a better representation?
Contributor Author (@Baibaifan) commented Mar 8, 2022:

After discussion, ShardedLevel is removed; the string names "os", "os_g", "p_g_os" are used for level, so the level names align with the paper.
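
For readers of the thread, the three string levels correspond to the stages of the ZeRO paper; continuing the example from the PR description, a minimal sketch:

```python
# level choices accepted by group_sharded_parallel (names follow the ZeRO paper):
#   "os"     - shard optimizer states only                       (stage 1)
#   "os_g"   - shard optimizer states and gradients              (stage 2)
#   "p_g_os" - shard parameters, gradients and optimizer states  (stage 3)
model, optimizer, scaler = group_sharded_parallel(
    model, optimizer, level="os_g", scaler=scaler)
```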

Contributor:

shard_level -> level
Since the API name already contains "sharded", the parameters here are understood to be sharding parameters by default.

Contributor Author:

Done.

Contributor:

Besides group_sharded, are there any other sharding schemes?

Contributor Author:

Here group_sharded means grouped parameter sharding, a distributed strategy on par with data parallelism, hence the name group_sharded. There is currently no other sharding scheme.

Contributor:

There is no need to expose the API via __all__ here; exposing it through __init__.py is enough:
paddle.distributed.sharding.group_sharded_parallel
rather than
paddle.distributed.sharding.group_sharded.group_sharded_parallel
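
For illustration, the re-export pattern being asked for looks roughly like this (a sketch of the package __init__.py, not the exact diff):

```python
# python/paddle/distributed/sharding/__init__.py (sketch)
from .group_sharded import group_sharded_parallel  # noqa: F401
from .group_sharded import save_group_sharded_model  # noqa: F401

__all__ = ['group_sharded_parallel', 'save_group_sharded_model']
```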

Contributor Author:

Done.

Contributor:

Besides group, are there other parameter forms?
save_for_group_sharded -> save_sharded_model? Or save_group_sharded_model?
Similar to save_inference_model.

Contributor Author:

After discussion, renamed to save_group_sharded_model.

@XiaoguangHu01 (Contributor) left a comment:

LGTM

@XieYunshen (Contributor) left a comment:

LGTM for set_tests_properties(test_dygraph_group_sharded_api PROPERTIES TIMEOUT 120)

@Baibaifan Baibaifan closed this Mar 9, 2022
@Baibaifan Baibaifan reopened this Mar 9, 2022
@Baibaifan Baibaifan merged commit f40ed5f into PaddlePaddle:develop Mar 9, 2022
```python
from . import cloud_utils  # noqa: F401
from . import utils  # noqa: F401

from .sharding import *  # noqa: F401
```
@gongweibao (Contributor) commented Mar 10, 2022:

Why import *?
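
One explicit alternative to the wildcard, as a sketch:

```python
# explicit re-exports instead of `from .sharding import *` (sketch)
from .sharding import group_sharded_parallel  # noqa: F401
from .sharding import save_group_sharded_model  # noqa: F401
```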
