Commit 272b32f

Replacing dropout eval eigen usage by cuda kernel (#40053)
* Replacing dropout eval eigen usage by cuda kernel
1 parent a8e02ef
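
Summary of the change: in the eval path (is_test == true) of DropoutFwGPUKernelDriver, the Eigen expressions Y.device(place) = X and Y.device(place) = X * (1 - dropout_prob) are replaced. With upscale_in_train, the output is now produced by a device-to-device cudaMemcpyAsync (hipMemcpyAsync under PADDLE_WITH_HIP); otherwise the (1 - dropout_prob) scaling is launched as an elementwise CUDA kernel using phi::funcs::ScaleFunctor with LaunchSameDimsElementwiseCudaKernel. The x_numel, stream, x_data, and y_data locals are hoisted above the is_test branch so that both the training and eval paths can use them.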

1 file changed: +20 −8 lines changed
paddle/fluid/operators/dropout_impl.cu.h

Lines changed: 20 additions & 8 deletions
@@ -184,15 +184,15 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                               bool is_fix_seed, int seed_val, const Tensor& x,
                               const Tensor* seed, Tensor* mask, Tensor* y) {
   auto& place = *dev_ctx.eigen_device();
+  int64_t x_numel = x.numel();
+  auto stream = dev_ctx.stream();
+  auto* x_data = x.data<T>();
+  auto* y_data = y->data<T>();
 
   if (!is_test) {
-    int64_t x_numel = x.numel();
-    auto stream = dev_ctx.stream();
     auto* mask_data = mask->data<uint8_t>();
     size_t size = phi::product(mask->dims());
 
-    auto* x_data = x.data<T>();
-    auto* y_data = y->data<T>();
     if (dropout_prob == 1.0f) {
 #ifdef PADDLE_WITH_HIP
       PADDLE_ENFORCE_GPU_SUCCESS(
@@ -254,12 +254,24 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
     }
 #endif
   } else {
-    auto X = EigenMatrix<T>::Reshape(x, 1);
-    auto Y = EigenMatrix<T>::Reshape(*y, 1);
     if (upscale_in_train) {
-      Y.device(place) = X;
+      // todo: can y share with data with x directly?
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          hipMemcpyAsync(y_data, x_data, sizeof(T) * x_numel,
+                         hipMemcpyDeviceToDevice, stream));
+#else
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          cudaMemcpyAsync(y_data, x_data, sizeof(T) * x_numel,
+                          cudaMemcpyDeviceToDevice, stream));
+#endif
     } else {
-      Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
+      T factor = static_cast<T>(1.0f - dropout_prob);
+      std::vector<const framework::Tensor*> ins = {&x};
+      std::vector<framework::Tensor*> outs = {y};
+      auto functor = phi::funcs::ScaleFunctor<T>(factor);
+      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
+                                                                &outs, functor);
     }
   }
 }
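
For context, the sketch below shows the pattern the new eval path follows, as a minimal standalone CUDA example. ScaleKernel, DropoutEval, and the 256-thread launch configuration are illustrative assumptions made for this sketch; the commit itself routes the scaling through Paddle's LaunchSameDimsElementwiseCudaKernel helper with phi::funcs::ScaleFunctor rather than a hand-written kernel.

// Illustrative sketch only: ScaleKernel, DropoutEval, and the launch
// configuration are assumptions for demonstration, not Paddle code.
#include <cuda_runtime.h>
#include <cstdint>

template <typename T>
__global__ void ScaleKernel(const T* x, T* y, T factor, int64_t n) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) y[i] = x[i] * factor;  // y = x * (1 - dropout_prob)
}

template <typename T>
void DropoutEval(const T* x_data, T* y_data, int64_t x_numel,
                 float dropout_prob, bool upscale_in_train,
                 cudaStream_t stream) {
  if (upscale_in_train) {
    // Upscale-in-train dropout is the identity at eval time, so the old
    // Eigen assignment becomes a plain async device-to-device copy.
    cudaMemcpyAsync(y_data, x_data, sizeof(T) * x_numel,
                    cudaMemcpyDeviceToDevice, stream);
  } else {
    // Downscale-at-eval multiplies every element by (1 - p).
    T factor = static_cast<T>(1.0f - dropout_prob);
    int threads = 256;
    int blocks = static_cast<int>((x_numel + threads - 1) / threads);
    ScaleKernel<T><<<blocks, threads, 0, stream>>>(x_data, y_data, factor,
                                                   x_numel);
  }
}

The design point is the same in both versions: at eval time, upscale-in-train dropout is the identity, so a device-to-device copy suffices (or, as the in-code TODO suggests, y might share x's buffer outright), while downscale-at-eval is a single uniform scale that a trivial elementwise kernel handles without going through Eigen's expression machinery.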
