- Notifications
You must be signed in to change notification settings - Fork 5.9k
Speed up elemwise grad #8402
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Speed up elemwise grad #8402
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -20,9 +20,11 @@ limitations under the License. */ | |
| | ||
| #ifdef __NVCC__ | ||
| #include <thrust/iterator/iterator_adaptor.h> | ||
| constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024; | ||
| #endif | ||
| | ||
| #include "paddle/fluid/operators/math/math_function.h" | ||
| #include "paddle/fluid/platform/for_range.h" | ||
| | ||
| namespace paddle { | ||
| namespace operators { | ||
| | @@ -311,6 +313,258 @@ EIGEN_FUNCTOR(Mul, EIGEN_MUL); | |
| #define EIGEN_DIV(x, y) ((x) / (y)) | ||
| EIGEN_FUNCTOR(Div, EIGEN_DIV); | ||
| | ||
| template <typename T, typename DX_OP, typename DY_OP> | ||
| struct ElemwiseGradNoBroadcast { | ||
| const T* x_; | ||
| const T* y_; | ||
| const T* out_; | ||
| const T* dout_; | ||
| | ||
| HOSTDEVICE void operator()(size_t i) { | ||
| if (dx_ != nullptr) { | ||
| dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]); | ||
| } | ||
| if (dy_ != nullptr) { | ||
| dy_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]); | ||
| } | ||
| } | ||
| | ||
| DX_OP dx_op_; | ||
| DY_OP dy_op_; | ||
| T* dx_; | ||
| T* dy_; | ||
| }; | ||
| | ||
| template <typename T, typename DX_OP, typename DY_OP> | ||
| static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out, | ||
| const T* dout, int h, int w, DX_OP dx_op, | ||
| DY_OP dy_op, T* dx, T* dy) { | ||
| for (int i = 0; i < h; ++i) { | ||
| for (int j = 0; j < w; ++j) { | ||
| int x_offset = i * w + j; | ||
| if (dx != nullptr) { | ||
| dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); | ||
| } | ||
| if (dy != nullptr) { | ||
| T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); | ||
| if (i == 0) { | ||
| dy[j] = tmp; | ||
| } else { | ||
| dy[j] += tmp; | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| #ifdef __NVCC__ | ||
| template <typename T, typename DX_OP, typename DY_OP> | ||
| static __global__ void ElemwiseGradBroadcast1CUDAKernel( | ||
| const T* x, const T* y, const T* out, const T* dout, int h, int w, | ||
| DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) { | ||
| extern __shared__ char shm_buffer[]; | ||
| T* shm = reinterpret_cast<T*>(shm_buffer); | ||
| | ||
| int j = blockIdx.x; | ||
| int i = threadIdx.x; | ||
| int tid = threadIdx.x; | ||
| shm[tid] = 0; | ||
| | ||
| do { | ||
| int x_offset = i * w + j; | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The data(x, dx, dout) access is not continuous. This may have an impact on Performance. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed. However, this will make the reduction easier. There could be a more effective implementation. | ||
| if (dx) { | ||
| dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder whether this will be faster than before. For elementwise_add_grad, the computation of dx only use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I just check this by reading the generated PTX file. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
cool... | ||
| } | ||
| if (dy) { | ||
| shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); | ||
| } | ||
| i += ELEMWISE_MAX_BLOCK_DIM; | ||
| } while (i < h); | ||
| | ||
| if (dy) { | ||
| __syncthreads(); | ||
| | ||
| h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; | ||
| | ||
| // Sum, could be optimized | ||
| if (threadIdx.x == 0) { | ||
| for (int k = 1; k < h; ++k) { | ||
| shm[0] += shm[k]; | ||
| } | ||
| dy[j] = shm[0]; | ||
| } | ||
| } | ||
| } | ||
| | ||
| template <typename T, typename DX_OP, typename DY_OP> | ||
| static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x, | ||
| const T* y, const T* out, const T* dout, | ||
| int h, int w, DX_OP dx_op, DY_OP dy_op, | ||
| T* dx, T* dy) { | ||
| int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); | ||
| int gird_size = w; | ||
| int shared_mem_size = block_size * sizeof(T); | ||
| ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, shared_mem_size, | ||
| stream>>>(x, y, out, dout, h, w, dx_op, | ||
| dy_op, dx, dy); | ||
| } | ||
| | ||
| #endif | ||
| | ||
| template <typename T, typename DX_OP, typename DY_OP> | ||
| static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out, | ||
| const T* dout, int pre, int n, int post, | ||
| DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) { | ||
| for (int i = 0; i < pre; ++i) { | ||
| for (int j = 0; j < n; ++j) { | ||
| for (int k = 0; k < post; ++k) { | ||
| int x_offset = i * n * post + j * post + k; | ||
| if (dx != nullptr) { | ||
| dx[x_offset] = | ||
| dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); | ||
| } | ||
| if (dy != nullptr) { | ||
| T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); | ||
| if (i == 0 && k == 0) { | ||
| dy[j] = tmp; | ||
| } else { | ||
| dy[j] += tmp; | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| | ||
| #ifdef __NVCC__ | ||
| | ||
| template <typename T, typename DX_OP, typename DY_OP> | ||
| static __global__ void ElemwiseGradBroadcast2CUDAKernel( | ||
| const T* x, const T* y, const T* out, const T* dout, int pre, int n, | ||
| int post, DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) { | ||
| int tid = threadIdx.x; | ||
| int j = blockIdx.x; | ||
| | ||
| extern __shared__ char shm_buffer[]; | ||
| T* shm = reinterpret_cast<T*>(shm_buffer); | ||
| shm[tid] = 0; | ||
| int ttid = tid; | ||
| | ||
| while (true) { | ||
| int i = ttid / post; | ||
| int k = ttid % post; | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The division is very time consuming, it is recommended to multiply. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure which implementation is faster, the multiplication between float values or the division between integers. However, it should not cost too much time these lines since it is not the main logic of the method. | ||
| if (i >= pre) break; | ||
| | ||
| int x_offset = i * n * post + j * post + k; | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ==> There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The compiler should optimize this equation. | ||
| | ||
| if (dx != nullptr) { | ||
| dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); | ||
| } | ||
| | ||
| if (dy != nullptr) { | ||
| shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); | ||
| } | ||
| | ||
| ttid += ELEMWISE_MAX_BLOCK_DIM; | ||
| } | ||
| | ||
| if (dy) { | ||
| __syncthreads(); | ||
| int h = pre * post; | ||
| h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; | ||
| | ||
| // Sum, could be optimized | ||
| if (tid == 0) { | ||
| for (int i = 1; i < h; ++i) { | ||
| shm[0] += shm[i]; | ||
| } | ||
| dy[j] = shm[0]; | ||
| } | ||
| } | ||
| } | ||
| | ||
| template <typename T, typename DX_OP, typename DY_OP> | ||
| static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x, | ||
| const T* y, const T* out, const T* dout, | ||
| int pre, int n, int post, DX_OP dx_op, | ||
| DY_OP dy_op, T* dx, T* dy) { | ||
| int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post); | ||
| int gird_size = n; | ||
| int shared_mem_size = block_size * sizeof(T); | ||
| ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, shared_mem_size, | ||
| stream>>>(x, y, out, dout, pre, n, post, | ||
| dx_op, dy_op, dx, dy); | ||
| } | ||
| | ||
| #endif | ||
| | ||
| template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP> | ||
| void ElemwiseGradCompute(const framework::ExecutionContext& ctx, | ||
| const framework::Tensor& x, const framework::Tensor& y, | ||
| const framework::Tensor& out, | ||
| const framework::Tensor& dout, int axis, | ||
| framework::Tensor* dx, framework::Tensor* dy, | ||
| DX_OP dx_op, DY_OP dy_op) { | ||
| if (x.dims() == y.dims()) { | ||
| size_t N = static_cast<size_t>(framework::product(x.dims())); | ||
| platform::ForRange<DeviceContext> for_range( | ||
| ctx.template device_context<DeviceContext>(), N); | ||
| for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{ | ||
| x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op, | ||
| dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()), | ||
| dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())}); | ||
| } else { // Y is a scalar | ||
| auto x_dim = x.dims(); | ||
| auto y_dim = y.dims(); | ||
| | ||
| if (y_dim.size() == 1 && y_dim[0] == 1) { | ||
| // y is a scalar | ||
| auto extended_dims = framework::vectorize(x_dim); | ||
| extended_dims.push_back(1); | ||
| x_dim = framework::make_ddim(extended_dims); | ||
| } | ||
| | ||
| axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis); | ||
| int pre, n, post; | ||
| get_mid_dims(x_dim, y_dim, axis, pre, n, post); | ||
| if (post == 1) { | ||
| int h = pre; | ||
| int w = n; | ||
| if (platform::is_gpu_place(ctx.GetPlace())) { | ||
| #ifdef __NVCC__ | ||
| ElemwiseGradBroadcast1CUDA( | ||
| ctx.template device_context<DeviceContext>().stream(), x.data<T>(), | ||
| y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, dy_op, | ||
| dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()), | ||
| dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())); | ||
| #endif | ||
| } else { | ||
| ElemwiseGradBroadcast1CPU( | ||
| x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w, | ||
| dx_op, dy_op, | ||
| dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()), | ||
| dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())); | ||
| } | ||
| } else { | ||
| if (platform::is_gpu_place(ctx.GetPlace())) { | ||
| #ifdef __NVCC__ | ||
| ElemwiseGradBroadcast2CUDA( | ||
| ctx.template device_context<DeviceContext>().stream(), x.data<T>(), | ||
| y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op, | ||
| dy_op, | ||
| dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()), | ||
| dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())); | ||
| #endif | ||
| } else { | ||
| ElemwiseGradBroadcast2CPU( | ||
| x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, | ||
| post, dx_op, dy_op, | ||
| dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()), | ||
| dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())); | ||
| } | ||
| } | ||
| } | ||
| }; | ||
| | ||
| template <typename DeviceContext, typename T, typename functor, | ||
| typename broadcastfunctor, typename broadcast2functor> | ||
| void ElementwiseGradCompute(const framework::ExecutionContext& ctx, | ||
| | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add
inlineThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, inline is decided by the compiler.