@@ -15,7 +15,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/gelu_op.h"
-#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -27,9 +26,11 @@ struct GeluWithApproximateFunctor {
     // this function is tanh approximation of gelu
     MPType x = static_cast<MPType>(arg_x);
     MPType one = static_cast<MPType>(1);
-    MPType out = x * static_cast<MPType>(0.5) *
-                 (one + tanh(static_cast<MPType>(0.79788456) * x *
-                             (one + static_cast<MPType>(0.044715) * x * x)));
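+    // 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + GELU_CONSTANT * x^3)));
+    // kAlpha = M_2_SQRTPI * M_SQRT1_2 == sqrt(2 / pi)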
+    MPType half = static_cast<MPType>(0.5);
+    MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
+    auto tanh_out =
+        tanh(kAlpha * x * (one + static_cast<MPType>(GELU_CONSTANT) * x * x));
+    MPType out = x * half * (one + tanh_out);
     return static_cast<T>(out);
   }
 };
@@ -40,9 +41,10 @@ struct GeluWithoutApproximateFunctor {
   inline HOSTDEVICE T operator()(T arg_x) {
     // actual gelu with approximation = false
     MPType x = static_cast<MPType>(arg_x);
+    MPType one = static_cast<MPType>(1);
+    MPType half = static_cast<MPType>(0.5);
     MPType erf_out = erf(x * static_cast<MPType>(M_SQRT1_2));
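+    // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))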
-    MPType out =
-        x * static_cast<MPType>(0.5) * (static_cast<MPType>(1) + erf_out);
+    MPType out = x * half * (one + erf_out);
     return static_cast<T>(out);
   }
 };
@@ -71,6 +73,68 @@ class GeluKernel<platform::CUDADeviceContext, T>
   }
 };
 
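+// Gradient of the tanh-approximated GELU:
+//   dy/dx = 0.5 * (1 + tanh(u) + x * (1 - tanh(u)^2) * du/dx),
+//   where u = kAlpha * (x + GELU_CONSTANT * x^3).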
+template <typename T>
+struct GeluWithApproximateGradFunctor {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
+    MPType x = static_cast<MPType>(arg_x);
+    MPType dout = static_cast<MPType>(arg_dout);
+    MPType one = static_cast<MPType>(1);
+    MPType half = static_cast<MPType>(0.5);
+    MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
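+    // kBeta = 3 * GELU_CONSTANT * kAlpha (derivative of the cubic term)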
+    MPType kBeta =
+        kAlpha * static_cast<MPType>(GELU_CONSTANT) * static_cast<MPType>(3);
+    auto cube_x = x * x * x;
+    auto tanh_out =
+        tanh(kAlpha * ((static_cast<MPType>(GELU_CONSTANT) * cube_x) + x));
+    auto ans =
+        half * (one + tanh_out +
+                (one - tanh_out * tanh_out) * (x * kAlpha + kBeta * cube_x));
+    return static_cast<T>(ans * dout);
+  }
+};
+
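+// Gradient of the exact GELU: dy/dx = Phi(x) + x * phi(x), where Phi is the
+// standard normal CDF and phi its PDF; note half * kAlpha == 1 / sqrt(2 * pi).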
+template <typename T>
+struct GeluWithoutApproximateGradFunctor {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
+    MPType x = static_cast<MPType>(arg_x);
+    MPType dout = static_cast<MPType>(arg_dout);
+    MPType one = static_cast<MPType>(1);
+    MPType half = static_cast<MPType>(0.5);
+    MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
+    auto ans = half * (one + erf(x * static_cast<MPType>(M_SQRT1_2))) +
+               half * kAlpha * x * exp(-half * x * x);
+    return static_cast<T>(ans * dout);
+  }
+};
+
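+// Computes dx = gelu'(x) * dout in one fused elementwise CUDA launch,
+// choosing the gradient functor via the "approximate" attribute.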
+template <typename T>
+class GeluGradKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<framework::Tensor>("X");
+    auto* dout =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto approximate = context.Attr<bool>("approximate");
+    dx->mutable_data<T>(dout->place());
+
+    std::vector<const framework::Tensor*> ins = {x, dout};
+    std::vector<framework::Tensor*> outs = {dx};
+    const auto& dev_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
+    if (approximate) {
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor<T>());
+    } else {
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx, ins, &outs, 0, GeluWithoutApproximateGradFunctor<T>());
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 