Skip to content

Commit 85f861b

Browse files
committed
Fix
1 parent 6635a46 commit 85f861b

File tree

3 files changed

+89
-87
lines changed

3 files changed

+89
-87
lines changed

paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc

Lines changed: 44 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ using ConstEigenVectorArrayMap =
3636
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
3737

3838
template <typename T, typename Context>
39-
void BatchNormGradFunctor(const Context& ctx,
39+
void BatchNormGradFunctor(const Context& dev_ctx,
4040
const DenseTensor& x,
4141
const paddle::optional<DenseTensor>& scale,
4242
const paddle::optional<DenseTensor>& bias,
@@ -117,7 +117,7 @@ void BatchNormGradFunctor(const Context& ctx,
117117

118118
// init output
119119
if (d_x) {
120-
ctx.template Alloc<T>(d_x);
120+
dev_ctx.template Alloc<T>(d_x);
121121
}
122122

123123
const T* mean_data = nullptr;
@@ -128,7 +128,7 @@ void BatchNormGradFunctor(const Context& ctx,
128128
const auto* running_variance = variance.get_ptr();
129129
mean_data = running_mean->data<T>();
130130
inv_var_tensor.Resize({C});
131-
T* running_inv_var_data = ctx.template Alloc<T>(&inv_var_tensor);
131+
T* running_inv_var_data = dev_ctx.template Alloc<T>(&inv_var_tensor);
132132
EigenVectorArrayMap<T> inv_var_tmp(running_inv_var_data, C);
133133
ConstEigenVectorArrayMap<T> var_arr(running_variance->data<T>(), C);
134134

@@ -145,8 +145,8 @@ void BatchNormGradFunctor(const Context& ctx,
145145
T* d_bias_data = nullptr;
146146
T* d_scale_data = nullptr;
147147
if (d_scale && d_bias) {
148-
d_bias_data = ctx.template Alloc<T>(d_bias);
149-
d_scale_data = ctx.template Alloc<T>(d_scale);
148+
d_bias_data = dev_ctx.template Alloc<T>(d_bias);
149+
d_scale_data = dev_ctx.template Alloc<T>(d_scale);
150150
}
151151

152152
// d_bias = np.sum(d_y, axis=0)
@@ -162,7 +162,7 @@ void BatchNormGradFunctor(const Context& ctx,
162162
}
163163

164164
if (d_x && (N * sample_size) == 1 && !use_global_stats) {
165-
phi::Copy(ctx, *d_y, ctx.GetPlace(), false, d_x);
165+
phi::Copy(dev_ctx, *d_y, dev_ctx.GetPlace(), false, d_x);
166166
return;
167167
}
168168
auto* Scale = scale.get_ptr();
@@ -185,13 +185,13 @@ void BatchNormGradFunctor(const Context& ctx,
185185

186186
DenseTensor dy_sum;
187187
dy_sum.Resize({C});
188-
auto dy_sum_data = ctx.template Alloc<T>(&dy_sum);
188+
auto dy_sum_data = dev_ctx.template Alloc<T>(&dy_sum);
189189
EigenVectorArrayMap<T> dy_sum_arr(dy_sum_data, C);
190190

191191
DenseTensor dy_mul_x_sub_mean_mul_invstd_sum;
192192
dy_mul_x_sub_mean_mul_invstd_sum.Resize({C});
193193
auto dy_mul_x_sub_mean_mul_invstd_sum_data =
194-
ctx.template Alloc<T>(&dy_mul_x_sub_mean_mul_invstd_sum);
194+
dev_ctx.template Alloc<T>(&dy_mul_x_sub_mean_mul_invstd_sum);
195195
EigenVectorArrayMap<T> dy_mul_x_sub_mean_mul_invstd_sum_arr(
196196
dy_mul_x_sub_mean_mul_invstd_sum_data, C);
197197

@@ -209,7 +209,8 @@ void BatchNormGradFunctor(const Context& ctx,
209209
case DataLayout::kNCHW: {
210210
if (is_inplace) {
211211
auto px = x;
212-
EigenArrayMap<T> x_data(ctx.template Alloc<T>(&px), sample_size, N * C);
212+
EigenArrayMap<T> x_data(
213+
dev_ctx.template Alloc<T>(&px), sample_size, N * C);
213214
ConstEigenArrayMap<T> y_data(x.data<T>(), sample_size, N * C);
214215
for (int nc = 0; nc < N * C; ++nc) {
215216
x_data.col(nc) = (y_data.col(nc) - bias_arr(nc % C)) /
@@ -235,7 +236,7 @@ void BatchNormGradFunctor(const Context& ctx,
235236

236237
if (d_x) {
237238
EigenArrayMap<T> d_x_arr(
238-
ctx.template Alloc<T>(d_x), sample_size, N * C);
239+
dev_ctx.template Alloc<T>(d_x), sample_size, N * C);
239240
if (!use_global_stats) {
240241
for (int nc = 0; nc < N * C; ++nc) {
241242
int c = nc % C;
@@ -257,7 +258,8 @@ void BatchNormGradFunctor(const Context& ctx,
257258
case DataLayout::kNHWC: {
258259
if (is_inplace) {
259260
auto px = x;
260-
EigenArrayMap<T> x_data(ctx.template Alloc<T>(&px), C, N * sample_size);
261+
EigenArrayMap<T> x_data(
262+
dev_ctx.template Alloc<T>(&px), C, N * sample_size);
261263
ConstEigenArrayMap<T> y_data(x.data<T>(), C, N * sample_size);
262264
for (int nhw = 0; nhw < N * sample_size; nhw++) {
263265
x_data.col(nhw) =
@@ -281,7 +283,7 @@ void BatchNormGradFunctor(const Context& ctx,
281283

282284
if (d_x) {
283285
EigenArrayMap<T> d_x_arr(
284-
ctx.template Alloc<T>(d_x), C, N * sample_size);
286+
dev_ctx.template Alloc<T>(d_x), C, N * sample_size);
285287
if (!use_global_stats) {
286288
for (int nhw = 0; nhw < N * sample_size; ++nhw) {
287289
d_x_arr.col(nhw) =
@@ -348,7 +350,7 @@ void BatchNormGradKernel(const Context& dev_ctx,
348350

349351
template <typename T, typename Context>
350352
void BatchNormDoubleGradKernel(
351-
const Context& ctx,
353+
const Context& dev_ctx,
352354
const DenseTensor& x,
353355
const paddle::optional<DenseTensor>& scale,
354356
const paddle::optional<DenseTensor>& mean,
@@ -390,8 +392,8 @@ void BatchNormDoubleGradKernel(
390392
auto* dX = x_grad;
391393
auto* dScale = scale_grad;
392394
auto* ddY = y_grad_grad;
393-
ctx.template Alloc<T>(dX);
394-
ctx.template Alloc<T>(ddY);
395+
dev_ctx.template Alloc<T>(dX);
396+
dev_ctx.template Alloc<T>(ddY);
395397

396398
const auto& x_dims = X->dims();
397399
const int C = static_cast<int>(
@@ -409,7 +411,7 @@ void BatchNormDoubleGradKernel(
409411
mean_data = running_mean->data<T>();
410412
inv_var_tensor.Resize({C});
411413

412-
T* running_inv_var_data = ctx.template Alloc<T>(&inv_var_tensor);
414+
T* running_inv_var_data = dev_ctx.template Alloc<T>(&inv_var_tensor);
413415
EigenVectorArrayMap<T> inv_var_tmp(running_inv_var_data, C);
414416
ConstEigenVectorArrayMap<T> var_arr(running_variance->data<T>(), C);
415417

@@ -427,15 +429,15 @@ void BatchNormDoubleGradKernel(
427429
if (data_layout == DataLayout::kNCHW && x_dims.size() > 2) {
428430
VLOG(3) << "Transform batchnorm output from NCHW to NHWC";
429431
// Input Tensor
430-
ResizeToChannelLast<Context, T>(ctx, X, &transformed_x);
431-
TransToChannelLast<Context, T>(ctx, X, &transformed_x);
432-
ResizeToChannelLast<Context, T>(ctx, dY, &transformed_dy);
433-
TransToChannelLast<Context, T>(ctx, dY, &transformed_dy);
434-
ResizeToChannelLast<Context, T>(ctx, ddX, &transformed_ddx);
435-
TransToChannelLast<Context, T>(ctx, ddX, &transformed_ddx);
432+
ResizeToChannelLast<Context, T>(dev_ctx, X, &transformed_x);
433+
TransToChannelLast<Context, T>(dev_ctx, X, &transformed_x);
434+
ResizeToChannelLast<Context, T>(dev_ctx, dY, &transformed_dy);
435+
TransToChannelLast<Context, T>(dev_ctx, dY, &transformed_dy);
436+
ResizeToChannelLast<Context, T>(dev_ctx, ddX, &transformed_ddx);
437+
TransToChannelLast<Context, T>(dev_ctx, ddX, &transformed_ddx);
436438
// Output Tensor
437-
ResizeToChannelLast<Context, T>(ctx, dX, &transformed_dx);
438-
ResizeToChannelLast<Context, T>(ctx, ddY, &transformed_ddy);
439+
ResizeToChannelLast<Context, T>(dev_ctx, dX, &transformed_dx);
440+
ResizeToChannelLast<Context, T>(dev_ctx, ddY, &transformed_ddy);
439441
} else {
440442
transformed_x.ShareDataWith(*X);
441443
transformed_dy.ShareDataWith(*dY);
@@ -452,29 +454,29 @@ void BatchNormDoubleGradKernel(
452454
Tensor mean_tile;
453455
mean_tile.Resize({C, sample_size});
454456
EigenArrayMap<T> mean_tile_data(
455-
ctx.template Alloc<T>(&mean_tile), C, sample_size);
457+
dev_ctx.template Alloc<T>(&mean_tile), C, sample_size);
456458

457459
DenseTensor inv_var_tile;
458460
inv_var_tile.Resize({C, sample_size});
459461
EigenArrayMap<T> inv_var_tile_data(
460-
ctx.template Alloc<T>(&inv_var_tile), C, sample_size);
462+
dev_ctx.template Alloc<T>(&inv_var_tile), C, sample_size);
461463

462464
mean_tile_data = mean_arr.replicate(1, sample_size);
463465
inv_var_tile_data = inv_var_arr.replicate(1, sample_size);
464466

465467
DenseTensor Scale_data;
466468
if (!Scale) {
467469
Scale_data.Resize({C});
468-
ctx.template Alloc<T>(&Scale_data);
469-
set_constant(ctx, &Scale_data, static_cast<T>(1));
470+
dev_ctx.template Alloc<T>(&Scale_data);
471+
set_constant(dev_ctx, &Scale_data, static_cast<T>(1));
470472
}
471473
ConstEigenVectorArrayMap<T> scale_arr(
472474
Scale ? Scale->data<T>() : Scale_data.data<T>(), C);
473475

474476
Tensor scale_tile;
475477
scale_tile.Resize({C, sample_size});
476478
EigenArrayMap<T> scale_tile_data(
477-
ctx.template Alloc<T>(&scale_tile), C, sample_size);
479+
dev_ctx.template Alloc<T>(&scale_tile), C, sample_size);
478480
scale_tile_data = scale_arr.replicate(1, sample_size);
479481

480482
ConstEigenArrayMap<T> dy_arr(transformed_dy.data<T>(), C, sample_size);
@@ -484,13 +486,13 @@ void BatchNormDoubleGradKernel(
484486
x_sub_mean_mul_invstd.Resize({C, sample_size});
485487

486488
EigenArrayMap<T> x_sub_mean_mul_invstd_arr(
487-
ctx.template Alloc<T>(&x_sub_mean_mul_invstd), C, sample_size);
489+
dev_ctx.template Alloc<T>(&x_sub_mean_mul_invstd), C, sample_size);
488490
x_sub_mean_mul_invstd_arr = (x_arr - mean_tile_data) * inv_var_tile_data;
489491

490492
if (dX) {
491-
ctx.template Alloc<T>(dX);
493+
dev_ctx.template Alloc<T>(dX);
492494
EigenArrayMap<T> dx_arr(
493-
ctx.template Alloc<T>(&transformed_dx), C, sample_size);
495+
dev_ctx.template Alloc<T>(&transformed_dx), C, sample_size);
494496
dx_arr.setZero();
495497
if (use_global_stats) {
496498
// math: dx = (ddscale * dy) * inv_var
@@ -499,7 +501,7 @@ void BatchNormDoubleGradKernel(
499501
Tensor ddscale_tile;
500502
ddscale_tile.Resize({C, sample_size});
501503
EigenArrayMap<T> ddscale_tile_data(
502-
ctx.template Alloc<T>(&ddscale_tile), C, sample_size);
504+
dev_ctx.template Alloc<T>(&ddscale_tile), C, sample_size);
503505
ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
504506

505507
dx_arr = dy_arr * ddscale_tile_data * inv_var_tile_data;
@@ -551,7 +553,7 @@ void BatchNormDoubleGradKernel(
551553
Tensor ddscale_tile;
552554
ddscale_tile.Resize({C, sample_size});
553555
EigenArrayMap<T> ddscale_tile_data(
554-
ctx.template Alloc<T>(&ddscale_tile), C, sample_size);
556+
dev_ctx.template Alloc<T>(&ddscale_tile), C, sample_size);
555557
ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
556558

557559
dx_arr +=
@@ -569,11 +571,11 @@ void BatchNormDoubleGradKernel(
569571
}
570572
if (data_layout == DataLayout::kNCHW) {
571573
VLOG(3) << "Transform batchnorm output from NHWC to NCHW";
572-
TransToChannelFirst<Context, T>(ctx, &transformed_dx, dX);
574+
TransToChannelFirst<Context, T>(dev_ctx, &transformed_dx, dX);
573575
}
574576
}
575577
if (dScale) {
576-
EigenVectorArrayMap<T> dscale_arr(ctx.template Alloc<T>(dScale), C);
578+
EigenVectorArrayMap<T> dscale_arr(dev_ctx.template Alloc<T>(dScale), C);
577579
dscale_arr.setZero();
578580
if (use_global_stats) {
579581
// math: dscale = np.sum(ddx * dy, axis=(n,h,w)) * inv_var
@@ -588,7 +590,7 @@ void BatchNormDoubleGradKernel(
588590
Tensor first_grad;
589591
first_grad.Resize({C, sample_size});
590592
EigenArrayMap<T> first_grad_arr(
591-
ctx.template Alloc<T>(&first_grad), C, sample_size);
593+
dev_ctx.template Alloc<T>(&first_grad), C, sample_size);
592594
first_grad_arr.setZero();
593595

594596
first_grad_arr +=
@@ -607,9 +609,9 @@ void BatchNormDoubleGradKernel(
607609
}
608610

609611
if (ddY) {
610-
ctx.template Alloc<T>(ddY);
612+
dev_ctx.template Alloc<T>(ddY);
611613
EigenArrayMap<T> ddy_arr(
612-
ctx.template Alloc<T>(&transformed_ddy), C, sample_size);
614+
dev_ctx.template Alloc<T>(&transformed_ddy), C, sample_size);
613615
ddy_arr.setZero();
614616
if (use_global_stats) { // NOLINT
615617
// math: ddy = r * ddx * inv_var + ddbias +
@@ -639,7 +641,7 @@ void BatchNormDoubleGradKernel(
639641
Tensor ddscale_tile;
640642
ddscale_tile.Resize({C, sample_size});
641643
EigenArrayMap<T> ddscale_tile_data(
642-
ctx.template Alloc<T>(&ddscale_tile), C, sample_size);
644+
dev_ctx.template Alloc<T>(&ddscale_tile), C, sample_size);
643645
ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
644646

645647
ddy_arr += x_sub_mean_mul_invstd_arr * ddscale_tile_data;
@@ -650,15 +652,15 @@ void BatchNormDoubleGradKernel(
650652
Tensor ddbias_tile;
651653
ddbias_tile.Resize({C, sample_size});
652654
EigenArrayMap<T> ddbias_tile_data(
653-
ctx.template Alloc<T>(&ddbias_tile), C, sample_size);
655+
dev_ctx.template Alloc<T>(&ddbias_tile), C, sample_size);
654656
ddbias_tile_data = ddbias_arr.replicate(1, sample_size);
655657

656658
ddy_arr += ddbias_tile_data;
657659
}
658660

659661
if (data_layout == DataLayout::kNCHW) {
660662
VLOG(3) << "Transform batchnorm output from NHWC to NCHW";
661-
TransToChannelFirst<Context, T>(ctx, &transformed_ddy, ddY);
663+
TransToChannelFirst<Context, T>(dev_ctx, &transformed_ddy, ddY);
662664
}
663665
}
664666
}

0 commit comments

Comments
 (0)