
Conversation

@haohongxiang (Contributor) commented Oct 13, 2021

PR types

Bug fixes

PR changes

Others

Describe

[HybridParallel]Support fp16 in dygraph hybrid parallel

Loss-curve comparison between single-card FP32 and 3D hybrid parallel + recompute + FP16:
[image: loss-curve comparison]
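
For context, here is a minimal sketch (not this PR's test code) of the path the change enables: an AMP/FP16 training step for a dygraph model wrapped by fleet for hybrid parallel. The model, sizes, and parallel degrees are illustrative placeholders, fleet.distributed_scaler is assumed to be available in this Paddle version, and the script is assumed to be launched with paddle.distributed.launch on the required number of GPUs.

```python
# Minimal sketch (not the PR's test code): FP16/AMP training under dygraph
# hybrid parallel via fleet. Model, sizes, and degrees are illustrative.
import paddle
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
# A real 3D setup would raise mp/pp degrees and use parallel layers;
# dp_degree=2 keeps the sketch runnable on 2 GPUs via paddle.distributed.launch.
strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 1, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

model = paddle.nn.Linear(1024, 1024)  # placeholder model
optimizer = paddle.optimizer.SGD(learning_rate=1e-3,
                                 parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=2**15)

# Wrap for hybrid parallel; the fp16 support added here flows through this path.
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
scaler = fleet.distributed_scaler(scaler)

for step in range(10):
    x = paddle.randn([8, 1024])
    y = paddle.randn([8, 1024])
    with paddle.amp.auto_cast(enable=True):      # run forward in FP16 where safe
        loss = paddle.nn.functional.mse_loss(model(x), y)
    scaled = scaler.scale(loss)                  # scale loss to avoid FP16 underflow
    scaled.backward()
    scaler.minimize(optimizer, scaled)           # unscale, check inf/nan, then step
    optimizer.clear_grad()
```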

@ForFishes (Member) previously approved these changes Oct 14, 2021

LGTM

Contributor:

Use tracer._amp_level == core.AmpLevel.O0 instead of comparing with the raw integer 0.
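
A minimal sketch of the suggested check, assuming the paddle.fluid dygraph tracer and core bindings of this Paddle version:

```python
# Compare against the AmpLevel enum rather than the raw integer 0, so the
# intent (AMP disabled, i.e. level O0) is explicit.
from paddle.fluid import core, framework

tracer = framework._dygraph_tracer()
if tracer and tracer._amp_level == core.AmpLevel.O0:
    pass  # pure FP32 path: no AMP casting is active
```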

Contributor (Author):

done

Contributor:

Save it for other AMP levels as well.
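
A minimal sketch of the save-and-restore idea, assuming the tracer's _amp_level is writable as in the auto_cast implementation of that era; the helper name is illustrative, not the PR's code:

```python
# Remember the tracer's current AMP level (it may be O1 or O2, not only O0)
# before temporarily disabling AMP, then restore it afterwards.
from paddle.fluid import core, framework

def run_in_fp32(fn, *args, **kwargs):
    tracer = framework._dygraph_tracer()
    saved_level = tracer._amp_level         # save whatever level is active
    tracer._amp_level = core.AmpLevel.O0    # temporarily switch AMP off
    try:
        return fn(*args, **kwargs)
    finally:
        tracer._amp_level = saved_level     # restore the saved level
```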

Contributor (Author):

done

Contributor:

It is OK to put the guard here, but I wonder whether it is needed.

Contributor (Author):

Removing the guard causes a precision difference when broadcasting train_loss, so the guard is necessary here.
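
A minimal sketch of the idea, not the PR's exact code: broadcasting train_loss inside an auto_cast(enable=False) guard keeps the tensor in FP32 during the collective, which is why dropping the guard changes the precision; the function name and src_rank are illustrative.

```python
import paddle
import paddle.distributed as dist

def broadcast_train_loss(train_loss, src_rank=0):
    # Disable AMP casting while communicating so the loss stays FP32 on
    # every rank instead of being cast to FP16.
    with paddle.amp.auto_cast(enable=False):
        dist.broadcast(train_loss, src=src_rank)
    return train_loss
```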

Contributor:

I think we can consider making found_inf int32 or fp32 from the start to avoid these casts afterward.
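
A minimal sketch of the reviewer's suggestion; the helper and variable names are illustrative: allocate the found-inf flag in an all-reduce-friendly dtype from the start, so no bool-to-int casts are needed around the collective.

```python
import paddle
import paddle.distributed as dist

def new_found_inf_flag():
    # Allocate the flag directly as int32 instead of bool.
    return paddle.zeros([1], dtype="int32")

def sync_found_inf(found_inf_int32, group=None):
    # MAX over ranks: if any rank saw inf/nan, every rank sees the flag set.
    dist.all_reduce(found_inf_int32, op=dist.ReduceOp.MAX, group=group)
    return found_inf_int32
```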

Contributor (Author):

done

@zhiqiu (Contributor) left a comment

LGTM

@ForFishes (Member) left a comment

LGTM

@ForFishes ForFishes merged commit 10f0a0f into PaddlePaddle:develop Oct 18, 2021
fuyinno4 pushed a commit that referenced this pull request Oct 26, 2021
…imizer (#36707)

* fix bugs in HybridParallelClipGrad of hybrid_parallel_optimizer (#36237)
* fix bugs in HybridParallelClipGrad of hybrid_parallel_optimizer
* update
* update
* fix bugs in mp_layers, pp_layers and HybridParallelClipGrad (#36144)
* fix calling bug of HybridParallelClipGrad
* fix bugs of HybridParallelClipGrad
* add unittest of pp with HybridParallelClipGrad
* fix bugs in mp_layers.py
* update
* fix bugs in pp_layers.py
* update
* [HybridParallel]Rebuild code for pipeline (#36396)
* add no_sync for parameters sync
* add pipeline for moe
* [HybridParallel]Support fp16 in dygraph hybrid parallel (#36420)
* [HybridParallel]Support fp16 in dygraph hybrid parallel
* update
* update
* update for recompute
* add unittest of pp+fp16
* add unittest of recompute+fp16
* update
* modify ut
* modify ut of cond (#36475)
* fix bugs of ClipGradByGlobalNorm in HybridParallel (#36555)
* fix bugs of ClipGradByGlobalNorm
* add unittests
* add unittests
* [HybridParallel]fix bug of check_inf in fleet_base.py (#36651)
* fix bug of check_inf
* fix allreduce
* support ClipGradByGlobalNorm in sharding (#36012)
* support ClipGradByGlobalNorm in sharding
* support ClipGradByGlobalNorm in sharding
* test=allcase
* Update test_linalg_cond.py
* Update hybrid_parallel_util.py
* Update hybrid_parallel_util.py

Co-authored-by: ShenLiang <1422485404@qq.com>
Co-authored-by: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
