template <typename T>
struct logSigmoid_updateOutput_functor
{
  // Forward pass: logsigmoid(x) = -log(1 + exp(-x)), evaluated in the
  // max-shifted (log-sum-exp) form
  //   -(m + log(exp(-m) + exp(-x - m))),  m = max(0, -x)
  // so that exp() never receives a large positive argument and cannot
  // overflow for very negative inputs.
  __device__ void operator()(T *output, const T *input) const {
    const T max = fmaxType(0.f, -*input);
    const T z = THCNumerics<T>::exp(-max) + THCNumerics<T>::exp(-*input - max);
    *output = -(max + THCNumerics<T>::log(z));
  }
};
1315
template <typename T>
struct logSigmoid_updateGradInput_functor
{
  // Backward pass: d/dx logsigmoid(x) = exp(-x) / (1 + exp(-x)).
  // Uses the same max-shifted form as the forward pass: with
  // m = max(0, -x) and z = exp(-m) + exp(-x - m), the derivative is
  //   -m' - s * (z - 1) / z
  // where (m', s) = (0, -1) when x >= 0 and (-1, 1) when x < 0, so exp()
  // cannot overflow for large-magnitude inputs.
  __device__ void operator()(T *gradInput, const T *input, const T *gradOutput) const {
    const T max = fmaxType(0.f, -*input);
    const T z = THCNumerics<T>::exp(-max) + THCNumerics<T>::exp(-*input - max);
    T max_deriv = 0.f;
    T sign = -1.f;
    if (*input < 0.f) {
      max_deriv = -1.f;
      sign = 1.f;
    }
    *gradInput = *gradOutput * (-max_deriv - sign * ((z - 1.f) / z));
  }
};
2231
template <>
struct logSigmoid_updateOutput_functor<half> {
  // Half-precision specialization of the numerically stable logsigmoid
  // forward pass: -(m + log(exp(-m) + exp(-x - m))), m = max(0, -x).
  __device__ __forceinline__ void operator()(half* output, const half *input) const {
#ifdef CUDA_HALF_INSTRUCTIONS
    // Use half intrinsics (__hadd/__hsub) throughout instead of mixing in
    // raw half operator overloads, which are not guaranteed to exist on
    // every toolchain that provides the intrinsics.
    const half max = fmaxType(__float2half(0.f), __hneg(*input));
    const half z = __hadd(THCNumerics<half>::exp(__hneg(max)),
                          THCNumerics<half>::exp(__hsub(__hneg(*input), max)));
    *output = __hneg(__hadd(max, THCNumerics<half>::log(z)));
#else
    // No native half arithmetic available: compute in float, convert back.
    float in = __half2float(*input);
    float max = fmaxType(0.f, -in);
    float z = THCNumerics<float>::exp(-max) + THCNumerics<float>::exp(-in - max);
    *output = __float2half(-(max + THCNumerics<float>::log(z)));
#endif
  }
};
template <>
struct logSigmoid_updateGradInput_functor<half> {
  // Half-precision specialization of the stable logsigmoid backward pass
  // (see the generic functor for the derivation of the -m' - s*(z-1)/z form).
  __device__ __forceinline__ void operator()(half* gradInput, const half *input, const half *gradOutput) const {
#ifdef CUDA_HALF_INSTRUCTIONS
    const half one = __float2half(1.f);
    const half zero = __float2half(0.f);
    const half max = fmaxType(zero, __hneg(*input));
    const half z = __hadd(THCNumerics<half>::exp(__hneg(max)),
                          THCNumerics<half>::exp(__hsub(__hneg(*input), max)));
    half max_deriv = zero;
    half sign = __hneg(one);
    // __hlt / __hsub instead of half operator< and operator-: the raw
    // operator overloads are not available everywhere the intrinsics are,
    // and the rest of this path already uses intrinsics consistently.
    if (__hlt(*input, zero)) {
      max_deriv = __hneg(one);
      sign = one;
    }
    *gradInput = __hmul(*gradOutput,
                        __hsub(__hneg(max_deriv),
                               __hmul(sign, __hdiv(__hsub(z, one), z))));
#else
    // No native half arithmetic: compute the gradient in float.
    const float in = __half2float(*input);
    const float max = fmaxType(0.f, -in);
    const float z = THCNumerics<float>::exp(-max) + THCNumerics<float>::exp(-in - max);
    const float go = __half2float(*gradOutput);
    float max_deriv = 0.f;
    float sign = -1.f;
    if (in < 0.f) {
      max_deriv = -1.f;
      sign = 1.f;
    }
    *gradInput = __float2half(go * (-max_deriv - sign * ((z - 1.f) / z)));
#endif
  }
};
0 commit comments