
Conversation

0x45f (Contributor) commented Nov 2, 2021

PR types

New features

PR changes

Others

Describe

  1. Support pure fp16 training after dynamic-to-static (dy2stat) conversion.
    This PR adds a special case to the CastPureFp16Inputs function: the run_program op invoked during dy2stat is skipped and no further cast is applied to it. The run_program op only has an FP32 kernel, and in dy2stat pure fp16 training its inputs have already been cast to fp16 before the op is called, so CastPureFp16Inputs should skip run_program to avoid casting those inputs back to fp32 (see the sketch after this list).
  2. The losses of dygraph pure fp16 training and dy2stat pure fp16 training fluctuate relative to each other, so the atol in the unit tests is enlarged to 1e-3 to make them pass reliably. For dy2stat pure fp16 training, convergence has been verified on the mnist and resnet networks; both converge normally.
    • Partial mnist training log [image]
    • Partial resnet training results [image]
  3. Support setting the custom white and black lists (the custom_white_list and custom_black_list parameters) of the paddle.amp.auto_cast API in both dy2stat AMP and dy2stat pure fp16 training. Before this PR, setting these two parameters had no effect in dy2stat AMP training (see the usage sketch after this list).
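The skip logic from item 1, as a Python-flavored sketch. The real CastPureFp16Inputs lives in Paddle's C++ imperative AMP code; the snake_case name, the cast_fn stand-in, and the inputs structure here are illustrative assumptions, not the actual implementation:

```python
def cast_pure_fp16_inputs(op_type, inputs, cast_fn):
    """Sketch of the special case this PR adds (assumed names)."""
    # run_program is the op that dy2stat dispatches the converted
    # program to. It only has an FP32 kernel, and in dy2stat pure fp16
    # training its inputs were already cast to fp16 before the call,
    # so skip it here to avoid casting those inputs back to fp32.
    if op_type == 'run_program':
        return inputs
    # All other ops go through the usual pure-fp16 input casting.
    return cast_fn(inputs)
```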
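A minimal usage sketch for item 3, combining @paddle.jit.to_static with pure fp16 (level='O2') training; the network, shapes, loss-scaling value, and the op names in the custom lists are assumptions for illustration:

```python
import paddle

class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = paddle.nn.Linear(16, 10)

    @paddle.jit.to_static
    def forward(self, x):
        return self.linear(x)

model = SimpleNet()
optimizer = paddle.optimizer.Adam(parameters=model.parameters())
# level='O2' enables pure fp16: parameters are decorated for fp16 compute.
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.rand([4, 16])
# After this PR, custom_white_list/custom_black_list also take effect
# for models converted with to_static (dy2stat AMP and pure fp16).
with paddle.amp.auto_cast(custom_white_list={'elementwise_add'},
                          custom_black_list={'reduce_mean'},
                          level='O2'):
    loss = model(x).mean()
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_grad()
```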
paddle-bot-old (bot) commented Nov 2, 2021

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.


def _in_pure_fp16_guard():
    tracer = _dygraph_tracer()
    if tracer:
Suggested change:
-    if tracer:
+    return tracer and tracer._amp_level == core.AmpLevel.O2
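Applied, the suggestion collapses the branch into a single boolean return; the full helper (as the diff context suggests) would read:

```python
def _in_pure_fp16_guard():
    tracer = _dygraph_tracer()
    return tracer and tracer._amp_level == core.AmpLevel.O2
```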
Aurelius84 previously approved these changes Nov 22, 2021
"""
Tests model decorated by `dygraph_to_static_output` in static mode. For users, the model is defined in dygraph mode and trained in static mode.
"""
with fluid.dygraph.guard(place):
Contributor comment: Dygraph mode is already enabled by default, so this guard is no longer needed.
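In other words, the training calls can run directly in the default dygraph mode; a minimal sketch (the static_loss name follows the assertion below, and to_static=True is assumed by symmetry):

```python
# No fluid.dygraph.guard(place) wrapper needed: dygraph is the default.
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
```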

dygraph_loss = self.train(to_static=False)
self.assertTrue(
    np.allclose(static_loss, dygraph_loss),
    msg="static_loss: {} \n dygraph_loss: {}".format(static_loss,
                                                     dygraph_loss))
Contributor comment: Make sure CI stays stable; if the tolerance (atol) is relaxed, a NOTE should be added.
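A hedged sketch of the relaxed check with the requested NOTE; the 1e-3 value comes from the PR description, while the comment wording is an assumption:

```python
# NOTE: losses of dygraph and dy2stat pure fp16 training fluctuate
# between runs, so atol is enlarged to 1e-3 to keep CI stable.
self.assertTrue(
    np.allclose(static_loss, dygraph_loss, atol=1e-3),
    msg="static_loss: {} \n dygraph_loss: {}".format(static_loss,
                                                     dygraph_loss))
```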

Aurelius84 (Contributor) left a comment:
LGTM

Aurelius84 merged commit 52edad6 into PaddlePaddle:develop Nov 24, 2021
0x45f deleted the dy2stat_support_pure_fp16 branch November 24, 2021 03:40
Zjq9409 pushed a commit to Zjq9409/Paddle that referenced this pull request Dec 10, 2021
* run dy2stat pure fp16 in Linear model
* no use self._pure_fp16_inputs
* add test and fix Adam error in dy2stat pure fp16 training
* use paddle.optimizer.Adam
* run test in gpu
* change test time for CI
* enlarge atol for test_resnet_pure_fp16
* refine code and enlarge atol
* make custom_white_list and custom_black_list take effect for AMP and pure fp16
* check tracer is not None
* use default atol
* change filter_size
* change atol and add some NOTE