
Commit 8a901c5

Update ops for Sigmoid and Tanh
1 parent aec658f commit 8a901c5

File tree

4 files changed: +39 −25 lines


Sigmoid.cu

Lines changed: 16 additions & 11 deletions

@@ -4,22 +4,27 @@
 #include <THC/THCApply.cuh>
 
 template <typename T>
-struct sigmoidupdateOutput_functor
-{
-  __device__ void operator()(T *output, const T *input) const
-  {
-    *output = ScalarConvert<double, T>::to(1./(1.+ exp(-*input)));
+struct SigmoidGradInputOp {
+  __device__ __forceinline__ void operator()(T* gradInput, const T *output, const T *gradOutput) const {
+    *gradInput = *gradOutput * (1.f - *output) * (*output);
   }
 };
 
-template <typename T>
-struct sigmoidupdateGradInput_functor
-{
-  __device__ void operator()(T *gradInput, const T *output, const T *gradOutput) const
-  {
-    *gradInput = ScalarConvert<double, T>::to(*gradOutput * (1.-*output) * (*output));
+#ifdef CUDA_HALF_TENSOR
+template <>
+struct SigmoidGradInputOp<half> {
+  __device__ __forceinline__ void operator()(half* gradInput, const half *output, const half *gradOutput) const {
+#ifdef CUDA_HALF_INSTRUCTIONS
+    half one = __float2half(1.f);
+    *gradInput = __hmul(*gradOutput, __hmul(__hadd(one, __hneg(*output)), *output));
+#else
+    float out = __half2float(*output);
+    float go = __half2float(*gradOutput);
+    *gradInput = __float2half(go * (1.f - out) * out);
+#endif
   }
 };
+#endif
 
 #include "generic/Sigmoid.cu"
 #include "THCGenerateFloatTypes.h"
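A note on the math: because σ'(x) = σ(x)(1 − σ(x)), the backward pass only needs the cached forward output, which is exactly what SigmoidGradInputOp reads. The standalone kernel below is a minimal sanity check of that float-path expression; it is not part of this commit, and the kernel and buffer names are made up for illustration.

// Standalone check of the sigmoid gradient formula used by SigmoidGradInputOp.
// Names here (sigmoid_grad_kernel, n, h_out, ...) are illustrative only.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void sigmoid_grad_kernel(float *gradInput, const float *output,
                                    const float *gradOutput, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Same expression as the float path above: go * (1 - out) * out
    gradInput[i] = gradOutput[i] * (1.f - output[i]) * output[i];
  }
}

int main() {
  const int n = 4;
  float h_out[n] = {0.1f, 0.25f, 0.5f, 0.9f};  // forward outputs (already sigmoided)
  float h_go[n]  = {1.f, 1.f, 1.f, 1.f};       // upstream gradients
  float h_gi[n];

  float *d_out, *d_go, *d_gi;
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMalloc(&d_go,  n * sizeof(float));
  cudaMalloc(&d_gi,  n * sizeof(float));
  cudaMemcpy(d_out, h_out, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_go,  h_go,  n * sizeof(float), cudaMemcpyHostToDevice);

  sigmoid_grad_kernel<<<1, 32>>>(d_gi, d_out, d_go, n);
  cudaMemcpy(h_gi, d_gi, n * sizeof(float), cudaMemcpyDeviceToHost);

  for (int i = 0; i < n; ++i)
    printf("out=%.2f  dL/dx=%.4f\n", h_out[i], h_gi[i]);  // e.g. out=0.5 -> 0.25

  cudaFree(d_out); cudaFree(d_go); cudaFree(d_gi);
  return 0;
}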

Tanh.cu

Lines changed: 19 additions & 9 deletions

@@ -4,22 +4,32 @@
 #include <THC/THCApply.cuh>
 
 template <typename T>
-struct tanhupdateOutput_functor
+struct TanhGradInputOp
 {
-  __device__ void operator()(T *output, const T *input) const
-  {
-    *output = tanh(*input);
+  __device__ __forceinline__ void operator()(T *gradInput,
+                                             const T *output, const T *gradOutput) const {
+    *gradInput = *gradOutput * (1.f - *output * *output);
   }
 };
 
-template <typename T>
-struct tanhupdateGradInput_functor
+#ifdef CUDA_HALF_TENSOR
+template <>
+struct TanhGradInputOp<half>
 {
-  __device__ void operator()(T *gradInput, const T *output, const T *gradOutput) const
-  {
-    *gradInput = *gradOutput * (1 - *output * *output);
+  __device__ __forceinline__ void operator()(half *gradInput,
+                                             const half *output, const half *gradOutput) const {
+#ifdef CUDA_HALF_INSTRUCTIONS
+    const half one = __float2half(1.f);
+    const half out_square = __hmul(*output, *output);
+    *gradInput = __hmul(*gradOutput, __hadd(one, __hneg(out_square)));
+#else
+    float out = __half2float(*output);
+    float go = __half2float(*gradOutput);
+    *gradInput = __float2half(go * (1.f - out * out));
+#endif
   }
 };
+#endif
 
 #include "generic/Tanh.cu"
 #include "THCGenerateFloatTypes.h"
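The half specialization above only uses the packed __hmul/__hadd/__hneg intrinsics when CUDA_HALF_INSTRUCTIONS is defined; otherwise it widens each half to float, does the arithmetic, and narrows the result back. The self-contained sketch below mimics that fallback branch for the tanh gradient, gradOutput * (1 − output²), assuming only the conversion intrinsics from cuda_fp16.h; the kernels and the tiny host driver are hypothetical, not code from this commit.

// Sketch of the float-fallback path for the half tanh gradient.
// Conversions are done in small device kernels so the host never
// touches half arithmetic; all names here are invented.
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

__global__ void float_to_half(__half *dst, const float *src, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dst[i] = __float2half(src[i]);
}

__global__ void tanh_grad_half(__half *gradInput, const __half *output,
                               const __half *gradOutput, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Mirror of the #else branch: widen, compute go * (1 - out^2), narrow.
    float out = __half2float(output[i]);
    float go  = __half2float(gradOutput[i]);
    gradInput[i] = __float2half(go * (1.f - out * out));
  }
}

__global__ void half_to_float(float *dst, const __half *src, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dst[i] = __half2float(src[i]);
}

int main() {
  const int n = 3;
  float h_out[n] = {-0.5f, 0.0f, 0.75f};  // pretend these are tanh outputs
  float h_go[n]  = {1.f, 1.f, 1.f};
  float h_gi[n];

  float *d_outf, *d_gof, *d_gif;
  __half *d_outh, *d_goh, *d_gih;
  cudaMalloc(&d_outf, n * sizeof(float));  cudaMalloc(&d_gof, n * sizeof(float));
  cudaMalloc(&d_gif,  n * sizeof(float));
  cudaMalloc(&d_outh, n * sizeof(__half)); cudaMalloc(&d_goh, n * sizeof(__half));
  cudaMalloc(&d_gih,  n * sizeof(__half));
  cudaMemcpy(d_outf, h_out, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_gof,  h_go,  n * sizeof(float), cudaMemcpyHostToDevice);

  float_to_half<<<1, 32>>>(d_outh, d_outf, n);
  float_to_half<<<1, 32>>>(d_goh,  d_gof,  n);
  tanh_grad_half<<<1, 32>>>(d_gih, d_outh, d_goh, n);
  half_to_float<<<1, 32>>>(d_gif, d_gih, n);
  cudaMemcpy(h_gi, d_gif, n * sizeof(float), cudaMemcpyDeviceToHost);

  for (int i = 0; i < n; ++i)
    printf("out=% .2f  dL/dx=%.4f\n", h_out[i], h_gi[i]);
  return 0;
}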

generic/Sigmoid.cu

Lines changed: 2 additions & 3 deletions

@@ -10,8 +10,7 @@ void THNN_(Sigmoid_updateOutput)(
            THCTensor *output)
 {
   THCUNN_assertSameGPU(state, 2, input, output);
-  THCTensor_(resizeAs)(state, output, input);
-  THC_pointwiseApply2(state, output, input, sigmoidupdateOutput_functor<real>());
+  THCTensor_(sigmoid)(state, output, input);
 }
 
 void THNN_(Sigmoid_updateGradInput)(
@@ -24,7 +23,7 @@ void THNN_(Sigmoid_updateGradInput)(
   THCUNN_check_nElement(state, input, gradOutput);
   THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
   THCTensor_(resizeAs)(state, gradInput, output);
-  THC_pointwiseApply3(state, gradInput, output, gradOutput, sigmoidupdateGradInput_functor<real>());
+  THC_pointwiseApply3(state, gradInput, output, gradOutput, SigmoidGradInputOp<real>());
 }
 
 #endif
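For context on how the renamed functor is consumed: THC_pointwiseApply3 walks three tensors elementwise and invokes the functor on pointers to corresponding elements, which is why SigmoidGradInputOp<real>() can be passed by value at the call site (the Tanh wrapper below follows the same pattern). The sketch that follows is a deliberately simplified stand-in for that mechanism, assuming contiguous float buffers; apply3 and the other names are invented, and the real THC implementation additionally handles strides, type dispatch, and launch configuration.

// Toy stand-in for the THC_pointwiseApply3 calling pattern: apply a functor
// elementwise across three contiguous buffers. Not the THC implementation.
#include <cstdio>
#include <cuda_runtime.h>

template <typename T, typename Op>
__global__ void apply3(T *a, const T *b, const T *c, int n, Op op) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) op(&a[i], &b[i], &c[i]);  // functor receives element pointers
}

// Same shape as SigmoidGradInputOp<float> in the diff above.
struct SigmoidGradInputOpSketch {
  __device__ __forceinline__ void operator()(float *gradInput, const float *output,
                                             const float *gradOutput) const {
    *gradInput = *gradOutput * (1.f - *output) * (*output);
  }
};

int main() {
  const int n = 1024;
  float *gi, *out, *go;
  cudaMallocManaged(&gi,  n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  cudaMallocManaged(&go,  n * sizeof(float));
  for (int i = 0; i < n; ++i) { out[i] = 0.5f; go[i] = 1.f; }

  int threads = 256, blocks = (n + threads - 1) / threads;
  apply3<<<blocks, threads>>>(gi, out, go, n, SigmoidGradInputOpSketch());
  cudaDeviceSynchronize();
  printf("gi[0] = %f\n", gi[0]);  // expect 0.25 for out = 0.5, go = 1
  cudaFree(gi); cudaFree(out); cudaFree(go);
  return 0;
}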

generic/Tanh.cu

Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@ void THNN_(Tanh_updateOutput)(
 {
   THCUNN_assertSameGPU(state, 2, input, output);
   THCTensor_(resizeAs)(state, output, input);
-  THC_pointwiseApply2(state, output, input, tanhupdateOutput_functor<real>());
+  THCTensor_(tanh)(state, output, input);
 }
 
 void THNN_(Tanh_updateGradInput)(
@@ -24,7 +24,7 @@ void THNN_(Tanh_updateGradInput)(
   THCUNN_check_shape(state, output, gradOutput);
   THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
   THCTensor_(resizeAs)(state, gradInput, output);
-  THC_pointwiseApply3(state, gradInput, output, gradOutput, tanhupdateGradInput_functor<real>());
+  THC_pointwiseApply3(state, gradInput, output, gradOutput, TanhGradInputOp<real>());
 }
 
 #endif
