
Conversation

@Zjq9409 (Contributor) commented Dec 19, 2021

PR types

Performance optimization

PR changes

OPs

Describe

Optimize the GPU backward computation of the gelu operator with an elementwise kernel; performance data after optimizing the forward + backward computation is as follows:
[image: gelu forward + backward performance data after optimization]
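For readers unfamiliar with the math being moved into an elementwise kernel, the following is a minimal, self-contained CUDA sketch of the tanh-approximation gelu backward pass. The kernel name, the raw-kernel launch style, and the float-only types are illustrative assumptions, not the PR's actual implementation (the PR routes this per-element expression through Paddle's elementwise kernel machinery).

// Hedged sketch only: shows the gelu backward formula this PR optimizes,
// not the PR's actual Paddle elementwise-kernel code.
#include <cuda_runtime.h>
#include <math.h>

__global__ void GeluGradSketch(const float* x, const float* dout, float* dx, int n) {
  const float kAlpha = 0.7978845608028654f;  // sqrt(2 / pi)
  const float kC = 0.044715f;                // gelu tanh-approximation constant
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float xi = x[i];
    float x2 = xi * xi;
    // gelu(x) ~= 0.5 * x * (1 + tanh(u)), with u = kAlpha * (x + kC * x^3)
    float tanh_out = tanhf(kAlpha * xi * (1.0f + kC * x2));
    // d gelu / dx = 0.5 * (1 + tanh(u))
    //             + 0.5 * x * (1 - tanh(u)^2) * kAlpha * (1 + 3 * kC * x^2)
    float grad = 0.5f * (1.0f + tanh_out) +
                 0.5f * xi * (1.0f - tanh_out * tanh_out) * kAlpha * (1.0f + 3.0f * kC * x2);
    dx[i] = dout[i] * grad;
  }
}

In the review excerpts below, the same expression appears written in terms of MPType, which the PR seems to use as the higher-precision compute type for low-precision inputs.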

@paddle-bot-old commented

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

@ZzSean (Contributor) commented Dec 20, 2021

Please make the PR title a bit more descriptive, so it matches the previous PR.

@Zjq9409 Zjq9409 changed the title from "optimize gelu backward" to "use elementwise to optimize gelu backward implementation on GPU" on Dec 20, 2021
@Zjq9409 (Contributor, Author) commented Dec 21, 2021

Please make the PR title a bit more descriptive, so it matches the previous PR.

Done.

tanh(kAlpha * x * (one + static_cast<MPType>(0.044715) * x * x));
auto ans =
half * x * ((one - tanh_out * tanh_out) *
(kAlpha + static_cast<MPType>(0.1070322243) * x * x)) +
Contributor:

Do not use magic numbers that are not part of the formula; replace them all with expressions.

Contributor Author:

Done.
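For illustration, replacing the 0.1070322243 literal with an expression built from the named constants might look like this (a sketch only; the variable names follow the excerpt above, and MPType is assumed from the surrounding functor):

// sqrt(2 / pi), written as an expression instead of a literal
MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
MPType decimal = static_cast<MPType>(0.044715);
// 0.1070322243 == 3 * 0.044715 * sqrt(2 / pi), so derive it rather than hard-code it
MPType kBeta = kAlpha * decimal * static_cast<MPType>(3);
auto tanh_out = tanh(kAlpha * x * (one + decimal * x * x));
auto ans = half * x * ((one - tanh_out * tanh_out) * (kAlpha + kBeta * x * x)) +
           half * (one + tanh_out);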

auto tanh_out =
tanh(kAlpha * x * (one + static_cast<MPType>(0.044715) * x * x));
auto ans =
half * x * ((one - tanh_out * tanh_out) *
Contributor:

Please simplify the formula a bit further.

Contributor Author:

Done.
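For context, the expression being simplified is the derivative of the tanh-approximation gelu; a standard derivation (not quoted from the PR) is:

\[ \mathrm{gelu}(x) \approx \tfrac{1}{2}\,x\,\bigl(1 + \tanh u\bigr), \qquad u = \sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr) \]
\[ \frac{d\,\mathrm{gelu}}{dx} = \tfrac{1}{2}\,\bigl(1 + \tanh u\bigr) + \tfrac{1}{2}\,x\,\bigl(1 - \tanh^{2} u\bigr)\,\sqrt{2/\pi}\,\bigl(1 + 3 \cdot 0.044715\,x^{2}\bigr) \]

Evaluating tanh(u) once and reusing (1 - tanh^2(u)) gives the compact two-term form that the later excerpts compute.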

#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/gelu_op.h"
#include "paddle/fluid/platform/float16.h"
Contributor:

This header is already included by paddle/fluid/operators/amp/fp16_type_traits.h, so the include can be removed.

Contributor Author:

Done.
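Assuming the comment refers to the float16.h include (fp16_type_traits.h already pulls it in), the resolved include list would look roughly like this (illustrative, not the exact merged diff):

#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/gelu_op.h"
// "paddle/fluid/platform/float16.h" dropped: already included via fp16_type_traits.h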

(one + tanh(static_cast<MPType>(0.79788456) * x *
(one + static_cast<MPType>(0.044715) * x * x)));
MPType half = static_cast<MPType>(0.5);
MPType decimal = static_cast<MPType>(0.044715);
Contributor Author:

// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
This is the gelu formula; the "0.044715" in it is a fixed constant value.

Contributor:

0.044715

This constant value could be expressed with a macro definition, for example:

#define GELU_CONSTANT 0.044715

Also place the macro in a shared header such as gelu_op.h, and update all other code that uses 0.044715 accordingly.

Contributor Author:

Done.
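A sketch of how the suggested macro might be applied (the header placement and surrounding lines are assumed, not the PR's exact diff):

// in gelu_op.h, as a shared definition:
#define GELU_CONSTANT 0.044715

// at each use site, the 0.044715 literal is then replaced by the macro:
MPType decimal = static_cast<MPType>(GELU_CONSTANT);
auto tanh_out = tanh(kAlpha * x * (one + decimal * x * x));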

MPType dout = static_cast<MPType>(arg_dout);
MPType one = static_cast<MPType>(1);
MPType half = static_cast<MPType>(0.5);
MPType decimal = static_cast<MPType>(0.044715);
Contributor:

The magic numbers are still present in this part.

Contributor Author:

Done.

MPType kBeta = kAlpha * decimal * static_cast<MPType>(3);
auto tanh_out = tanh(kAlpha * x * (one + decimal * x * x));
auto temp = (one - tanh_out * tanh_out) * (kAlpha + kBeta * x * x);
auto ans = half * x * temp + half * (one + tanh_out);
Contributor:

This computation has terms that are evaluated multiple times, e.g. x^3; pulling such terms out into variables would reduce the amount of computation.

Contributor Author:

Done.
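A sketch of the kind of common-subexpression hoisting being asked for, based on the excerpt above (the merged code may differ in detail):

MPType kBeta = kAlpha * decimal * static_cast<MPType>(3);
auto x_sq = x * x;  // computed once and reused instead of re-evaluating x * x in each term
auto tanh_out = tanh(kAlpha * x * (one + decimal * x_sq));
auto temp = (one - tanh_out * tanh_out) * (kAlpha + kBeta * x_sq);
auto ans = half * x * temp + half * (one + tanh_out);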

@JamesLim-sy (Contributor) left a comment

LGTM

@JamesLim-sy JamesLim-sy merged commit 858e435 into PaddlePaddle:develop Dec 22, 2021
zmxdream pushed a commit to zmxdream/Paddle that referenced this pull request Dec 25, 2021
…lePaddle#38263)
* optimize gelu backward
* optimize gelu backward
* optimize code
* Number to expression
* Replacement number

Labels: None yet

3 participants