@@ -36,7 +36,7 @@ using ConstEigenVectorArrayMap =
     Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
 
 template <typename T, typename Context>
-void BatchNormGradFunctor(const Context& ctx,
+void BatchNormGradFunctor(const Context& dev_ctx,
                           const DenseTensor& x,
                           const paddle::optional<DenseTensor>& scale,
                           const paddle::optional<DenseTensor>& bias,
@@ -117,7 +117,7 @@ void BatchNormGradFunctor(const Context& ctx,
 
   // init output
   if (d_x) {
-    ctx.template Alloc<T>(d_x);
+    dev_ctx.template Alloc<T>(d_x);
   }
 
   const T* mean_data = nullptr;
@@ -128,7 +128,7 @@ void BatchNormGradFunctor(const Context& ctx,
     const auto* running_variance = variance.get_ptr();
     mean_data = running_mean->data<T>();
     inv_var_tensor.Resize({C});
-    T* running_inv_var_data = ctx.template Alloc<T>(&inv_var_tensor);
+    T* running_inv_var_data = dev_ctx.template Alloc<T>(&inv_var_tensor);
     EigenVectorArrayMap<T> inv_var_tmp(running_inv_var_data, C);
     ConstEigenVectorArrayMap<T> var_arr(running_variance->data<T>(), C);
 
@@ -145,8 +145,8 @@ void BatchNormGradFunctor(const Context& ctx,
   T* d_bias_data = nullptr;
   T* d_scale_data = nullptr;
   if (d_scale && d_bias) {
-    d_bias_data = ctx.template Alloc<T>(d_bias);
-    d_scale_data = ctx.template Alloc<T>(d_scale);
+    d_bias_data = dev_ctx.template Alloc<T>(d_bias);
+    d_scale_data = dev_ctx.template Alloc<T>(d_scale);
   }
 
   // d_bias = np.sum(d_y, axis=0)
@@ -162,7 +162,7 @@ void BatchNormGradFunctor(const Context& ctx,
   }
 
   if (d_x && (N * sample_size) == 1 && !use_global_stats) {
-    phi::Copy(ctx, *d_y, ctx.GetPlace(), false, d_x);
+    phi::Copy(dev_ctx, *d_y, dev_ctx.GetPlace(), false, d_x);
     return;
   }
   auto* Scale = scale.get_ptr();
@@ -185,13 +185,13 @@ void BatchNormGradFunctor(const Context& ctx,
 
   DenseTensor dy_sum;
   dy_sum.Resize({C});
-  auto dy_sum_data = ctx.template Alloc<T>(&dy_sum);
+  auto dy_sum_data = dev_ctx.template Alloc<T>(&dy_sum);
   EigenVectorArrayMap<T> dy_sum_arr(dy_sum_data, C);
 
   DenseTensor dy_mul_x_sub_mean_mul_invstd_sum;
   dy_mul_x_sub_mean_mul_invstd_sum.Resize({C});
   auto dy_mul_x_sub_mean_mul_invstd_sum_data =
-      ctx.template Alloc<T>(&dy_mul_x_sub_mean_mul_invstd_sum);
+      dev_ctx.template Alloc<T>(&dy_mul_x_sub_mean_mul_invstd_sum);
   EigenVectorArrayMap<T> dy_mul_x_sub_mean_mul_invstd_sum_arr(
       dy_mul_x_sub_mean_mul_invstd_sum_data, C);
 
@@ -209,7 +209,8 @@ void BatchNormGradFunctor(const Context& ctx,
     case DataLayout::kNCHW: {
       if (is_inplace) {
         auto px = x;
-        EigenArrayMap<T> x_data(ctx.template Alloc<T>(&px), sample_size, N * C);
+        EigenArrayMap<T> x_data(
+            dev_ctx.template Alloc<T>(&px), sample_size, N * C);
         ConstEigenArrayMap<T> y_data(x.data<T>(), sample_size, N * C);
         for (int nc = 0; nc < N * C; ++nc) {
           x_data.col(nc) = (y_data.col(nc) - bias_arr(nc % C)) /
@@ -235,7 +236,7 @@ void BatchNormGradFunctor(const Context& ctx,
 
       if (d_x) {
         EigenArrayMap<T> d_x_arr(
-            ctx.template Alloc<T>(d_x), sample_size, N * C);
+            dev_ctx.template Alloc<T>(d_x), sample_size, N * C);
         if (!use_global_stats) {
           for (int nc = 0; nc < N * C; ++nc) {
             int c = nc % C;
@@ -257,7 +258,8 @@ void BatchNormGradFunctor(const Context& ctx,
     case DataLayout::kNHWC: {
       if (is_inplace) {
         auto px = x;
-        EigenArrayMap<T> x_data(ctx.template Alloc<T>(&px), C, N * sample_size);
+        EigenArrayMap<T> x_data(
+            dev_ctx.template Alloc<T>(&px), C, N * sample_size);
         ConstEigenArrayMap<T> y_data(x.data<T>(), C, N * sample_size);
         for (int nhw = 0; nhw < N * sample_size; nhw++) {
           x_data.col(nhw) =
@@ -281,7 +283,7 @@ void BatchNormGradFunctor(const Context& ctx,
 
       if (d_x) {
         EigenArrayMap<T> d_x_arr(
-            ctx.template Alloc<T>(d_x), C, N * sample_size);
+            dev_ctx.template Alloc<T>(d_x), C, N * sample_size);
         if (!use_global_stats) {
           for (int nhw = 0; nhw < N * sample_size; ++nhw) {
             d_x_arr.col(nhw) =
@@ -348,7 +350,7 @@ void BatchNormGradKernel(const Context& dev_ctx,
 
 template <typename T, typename Context>
 void BatchNormDoubleGradKernel(
-    const Context& ctx,
+    const Context& dev_ctx,
     const DenseTensor& x,
     const paddle::optional<DenseTensor>& scale,
     const paddle::optional<DenseTensor>& mean,
@@ -390,8 +392,8 @@ void BatchNormDoubleGradKernel(
   auto* dX = x_grad;
   auto* dScale = scale_grad;
   auto* ddY = y_grad_grad;
-  ctx.template Alloc<T>(dX);
-  ctx.template Alloc<T>(ddY);
+  dev_ctx.template Alloc<T>(dX);
+  dev_ctx.template Alloc<T>(ddY);
 
   const auto& x_dims = X->dims();
   const int C = static_cast<int>(
@@ -409,7 +411,7 @@ void BatchNormDoubleGradKernel(
     mean_data = running_mean->data<T>();
     inv_var_tensor.Resize({C});
 
-    T* running_inv_var_data = ctx.template Alloc<T>(&inv_var_tensor);
+    T* running_inv_var_data = dev_ctx.template Alloc<T>(&inv_var_tensor);
     EigenVectorArrayMap<T> inv_var_tmp(running_inv_var_data, C);
     ConstEigenVectorArrayMap<T> var_arr(running_variance->data<T>(), C);
 
@@ -427,15 +429,15 @@ void BatchNormDoubleGradKernel(
   if (data_layout == DataLayout::kNCHW && x_dims.size() > 2) {
     VLOG(3) << "Transform batchnorm output from NCHW to NHWC";
     // Input Tensor
-    ResizeToChannelLast<Context, T>(ctx, X, &transformed_x);
-    TransToChannelLast<Context, T>(ctx, X, &transformed_x);
-    ResizeToChannelLast<Context, T>(ctx, dY, &transformed_dy);
-    TransToChannelLast<Context, T>(ctx, dY, &transformed_dy);
-    ResizeToChannelLast<Context, T>(ctx, ddX, &transformed_ddx);
-    TransToChannelLast<Context, T>(ctx, ddX, &transformed_ddx);
+    ResizeToChannelLast<Context, T>(dev_ctx, X, &transformed_x);
+    TransToChannelLast<Context, T>(dev_ctx, X, &transformed_x);
+    ResizeToChannelLast<Context, T>(dev_ctx, dY, &transformed_dy);
+    TransToChannelLast<Context, T>(dev_ctx, dY, &transformed_dy);
+    ResizeToChannelLast<Context, T>(dev_ctx, ddX, &transformed_ddx);
+    TransToChannelLast<Context, T>(dev_ctx, ddX, &transformed_ddx);
     // Output Tensor
-    ResizeToChannelLast<Context, T>(ctx, dX, &transformed_dx);
-    ResizeToChannelLast<Context, T>(ctx, ddY, &transformed_ddy);
+    ResizeToChannelLast<Context, T>(dev_ctx, dX, &transformed_dx);
+    ResizeToChannelLast<Context, T>(dev_ctx, ddY, &transformed_ddy);
   } else {
     transformed_x.ShareDataWith(*X);
     transformed_dy.ShareDataWith(*dY);
@@ -452,29 +454,29 @@ void BatchNormDoubleGradKernel(
   Tensor mean_tile;
   mean_tile.Resize({C, sample_size});
   EigenArrayMap<T> mean_tile_data(
-      ctx.template Alloc<T>(&mean_tile), C, sample_size);
+      dev_ctx.template Alloc<T>(&mean_tile), C, sample_size);
 
   DenseTensor inv_var_tile;
   inv_var_tile.Resize({C, sample_size});
   EigenArrayMap<T> inv_var_tile_data(
-      ctx.template Alloc<T>(&inv_var_tile), C, sample_size);
+      dev_ctx.template Alloc<T>(&inv_var_tile), C, sample_size);
 
   mean_tile_data = mean_arr.replicate(1, sample_size);
   inv_var_tile_data = inv_var_arr.replicate(1, sample_size);
 
   DenseTensor Scale_data;
   if (!Scale) {
     Scale_data.Resize({C});
-    ctx.template Alloc<T>(&Scale_data);
-    set_constant(ctx, &Scale_data, static_cast<T>(1));
+    dev_ctx.template Alloc<T>(&Scale_data);
+    set_constant(dev_ctx, &Scale_data, static_cast<T>(1));
   }
   ConstEigenVectorArrayMap<T> scale_arr(
       Scale ? Scale->data<T>() : Scale_data.data<T>(), C);
 
   Tensor scale_tile;
   scale_tile.Resize({C, sample_size});
   EigenArrayMap<T> scale_tile_data(
-      ctx.template Alloc<T>(&scale_tile), C, sample_size);
+      dev_ctx.template Alloc<T>(&scale_tile), C, sample_size);
   scale_tile_data = scale_arr.replicate(1, sample_size);
 
   ConstEigenArrayMap<T> dy_arr(transformed_dy.data<T>(), C, sample_size);
@@ -484,13 +486,13 @@ void BatchNormDoubleGradKernel(
   x_sub_mean_mul_invstd.Resize({C, sample_size});
 
   EigenArrayMap<T> x_sub_mean_mul_invstd_arr(
-      ctx.template Alloc<T>(&x_sub_mean_mul_invstd), C, sample_size);
+      dev_ctx.template Alloc<T>(&x_sub_mean_mul_invstd), C, sample_size);
   x_sub_mean_mul_invstd_arr = (x_arr - mean_tile_data) * inv_var_tile_data;
 
   if (dX) {
-    ctx.template Alloc<T>(dX);
+    dev_ctx.template Alloc<T>(dX);
     EigenArrayMap<T> dx_arr(
-        ctx.template Alloc<T>(&transformed_dx), C, sample_size);
+        dev_ctx.template Alloc<T>(&transformed_dx), C, sample_size);
     dx_arr.setZero();
     if (use_global_stats) {
       // math: dx = (ddscale * dy) * inv_var
@@ -499,7 +501,7 @@ void BatchNormDoubleGradKernel(
         Tensor ddscale_tile;
         ddscale_tile.Resize({C, sample_size});
         EigenArrayMap<T> ddscale_tile_data(
-            ctx.template Alloc<T>(&ddscale_tile), C, sample_size);
+            dev_ctx.template Alloc<T>(&ddscale_tile), C, sample_size);
         ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
 
         dx_arr = dy_arr * ddscale_tile_data * inv_var_tile_data;
@@ -551,7 +553,7 @@ void BatchNormDoubleGradKernel(
         Tensor ddscale_tile;
         ddscale_tile.Resize({C, sample_size});
         EigenArrayMap<T> ddscale_tile_data(
-            ctx.template Alloc<T>(&ddscale_tile), C, sample_size);
+            dev_ctx.template Alloc<T>(&ddscale_tile), C, sample_size);
         ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
 
         dx_arr +=
@@ -569,11 +571,11 @@ void BatchNormDoubleGradKernel(
     }
     if (data_layout == DataLayout::kNCHW) {
       VLOG(3) << "Transform batchnorm output from NHWC to NCHW";
-      TransToChannelFirst<Context, T>(ctx, &transformed_dx, dX);
+      TransToChannelFirst<Context, T>(dev_ctx, &transformed_dx, dX);
     }
   }
   if (dScale) {
-    EigenVectorArrayMap<T> dscale_arr(ctx.template Alloc<T>(dScale), C);
+    EigenVectorArrayMap<T> dscale_arr(dev_ctx.template Alloc<T>(dScale), C);
     dscale_arr.setZero();
     if (use_global_stats) {
       // math: dscale = np.sum(ddx * dy, axis=(n,h,w)) * inv_var
@@ -588,7 +590,7 @@ void BatchNormDoubleGradKernel(
       Tensor first_grad;
       first_grad.Resize({C, sample_size});
       EigenArrayMap<T> first_grad_arr(
-          ctx.template Alloc<T>(&first_grad), C, sample_size);
+          dev_ctx.template Alloc<T>(&first_grad), C, sample_size);
       first_grad_arr.setZero();
 
       first_grad_arr +=
@@ -607,9 +609,9 @@ void BatchNormDoubleGradKernel(
   }
 
   if (ddY) {
-    ctx.template Alloc<T>(ddY);
+    dev_ctx.template Alloc<T>(ddY);
     EigenArrayMap<T> ddy_arr(
-        ctx.template Alloc<T>(&transformed_ddy), C, sample_size);
+        dev_ctx.template Alloc<T>(&transformed_ddy), C, sample_size);
     ddy_arr.setZero();
     if (use_global_stats) {  // NOLINT
       // math: ddy = r * ddx * inv_var + ddbias +
@@ -639,7 +641,7 @@ void BatchNormDoubleGradKernel(
       Tensor ddscale_tile;
       ddscale_tile.Resize({C, sample_size});
       EigenArrayMap<T> ddscale_tile_data(
-          ctx.template Alloc<T>(&ddscale_tile), C, sample_size);
+          dev_ctx.template Alloc<T>(&ddscale_tile), C, sample_size);
       ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
 
       ddy_arr += x_sub_mean_mul_invstd_arr * ddscale_tile_data;
@@ -650,15 +652,15 @@ void BatchNormDoubleGradKernel(
       Tensor ddbias_tile;
       ddbias_tile.Resize({C, sample_size});
       EigenArrayMap<T> ddbias_tile_data(
-          ctx.template Alloc<T>(&ddbias_tile), C, sample_size);
+          dev_ctx.template Alloc<T>(&ddbias_tile), C, sample_size);
       ddbias_tile_data = ddbias_arr.replicate(1, sample_size);
 
       ddy_arr += ddbias_tile_data;
     }
 
     if (data_layout == DataLayout::kNCHW) {
       VLOG(3) << "Transform batchnorm output from NHWC to NCHW";
-      TransToChannelFirst<Context, T>(ctx, &transformed_ddy, ddY);
+      TransToChannelFirst<Context, T>(dev_ctx, &transformed_ddy, ddY);
     }
   }
 }
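
Note on the idiom these renames touch: every scratch and output tensor in both kernels follows the same three steps, `Resize` to set the shape, `dev_ctx.template Alloc<T>(...)` to materialize memory on the context's place, then an Eigen map over the returned pointer. Below is a minimal sketch of that idiom, assuming the phi headers and the Eigen aliases shown in the first hunk; `AllocZeroedChannelBuffer` is a hypothetical helper for illustration, not part of this PR.

```cpp
#include "Eigen/Core"
#include "paddle/phi/core/dense_tensor.h"

// Same alias the file defines near the top (see the first hunk above).
template <typename T>
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;

// Hypothetical helper (illustration only): the Resize + Alloc + Eigen::Map
// pattern used at every ctx -> dev_ctx call site above.
template <typename T, typename Context>
T* AllocZeroedChannelBuffer(const Context& dev_ctx,
                            int C,
                            phi::DenseTensor* buf) {
  buf->Resize({C});                          // 1. fix the shape first
  T* data = dev_ctx.template Alloc<T>(buf);  // 2. allocate on dev_ctx's place
  EigenVectorArrayMap<T> arr(data, C);       // 3. wrap raw memory for Eigen
  arr.setZero();                             // writes go through the map
  return data;
}
```

Because `Context` is a template parameter, the same code serves `CPUContext` and other backends, and `Alloc` places memory wherever the passed-in context lives; `dev_ctx` matches the name already used by `BatchNormGradKernel` above, which is the point of the rename.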