Commit f82ba73

Fix
1 parent 8341678 commit f82ba73

File tree: 13 files changed (+158, -5 lines changed)


paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc

Lines changed: 2 additions & 2 deletions
@@ -378,7 +378,7 @@ bool BatchNormOpInferSymbolicShape(
             "ShapeError: the dimension of scale must equal to 1."
             "But received: the dimension of scale is [%d]",
             scale_dims.size()));
-    infer_context->AddEqualCstr(scale_dims[0], C);
+    if (C != 0) infer_context->AddEqualCstr(scale_dims[0], C);
   }
 
   if (!bias_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) {
@@ -389,7 +389,7 @@ bool BatchNormOpInferSymbolicShape(
             "ShapeError: the dimension of bias must equal to 1."
             "But received: the dimension of bias is [%d]",
             bias_dims.size()));
-    infer_context->AddEqualCstr(bias_dims[0], C);
+    if (C != 0) infer_context->AddEqualCstr(bias_dims[0], C);
   }
 
   // Set output shapes

paddle/phi/infermeta/multiary.cc

Lines changed: 2 additions & 2 deletions
@@ -935,7 +935,7 @@ void BatchNormInferMeta(const MetaTensor& x,
   const auto x_dims = x.dims();
   for (int i = 0; i < x_dims.size(); i++) {
     PADDLE_ENFORCE_EQ(
-        (x_dims[i] == -1) || (x_dims[i] > 0),
+        (x_dims[i] == -1) || (x_dims[i] >= 0),
         true,
         common::errors::InvalidArgument(
             "Each dimension of input tensor is expected to be -1 or a "
@@ -1001,7 +1001,7 @@ void BatchNormInferMeta(const MetaTensor& x,
     check = false;
   }
 
-  if (check) {
+  if (check && C != 0) {
     PADDLE_ENFORCE_EQ(scale.dims()[0],
                       C,
                       common::errors::InvalidArgument(

paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc

Lines changed: 16 additions & 0 deletions
@@ -17,6 +17,7 @@
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/batch_norm_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/batch_norm_utils.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
@@ -326,6 +327,21 @@ void BatchNormGradKernel(const Context& dev_ctx,
                          DenseTensor* x_grad,
                          DenseTensor* scale_grad,
                          DenseTensor* bias_grad) {
+  if (x.numel() == 0) {
+    dev_ctx.template Alloc<T>(x_grad);
+    if (scale_grad)
+      phi::Full<T, Context>(
+          dev_ctx,
+          phi::IntArray(common::vectorize(scale_grad->dims())),
+          0,
+          scale_grad);
+    if (bias_grad)
+      phi::Full<T, Context>(dev_ctx,
+                            phi::IntArray(common::vectorize(bias_grad->dims())),
+                            0,
+                            bias_grad);
+    return;
+  }
   BatchNormGradFunctor<T, Context>(dev_ctx,
                                    x,
                                    scale,

paddle/phi/kernels/cpu/batch_norm_kernel.cc

Lines changed: 13 additions & 0 deletions
@@ -51,6 +51,19 @@ void BatchNormKernel(const Context& dev_ctx,
                      DenseTensor* saved_mean,
                      DenseTensor* saved_variance,
                      DenseTensor* reserve_space) {
+  if (x.numel() == 0) {
+    dev_ctx.template Alloc<T>(y);
+    if (mean_out) dev_ctx.template Alloc<T>(mean_out);
+    if (variance_out) dev_ctx.template Alloc<T>(variance_out);
+    if (saved_mean) dev_ctx.template Alloc<T>(saved_mean);
+    if (saved_variance) dev_ctx.template Alloc<T>(saved_variance);
+    if (reserve_space) {
+      // infermeta dim is -1.
+      reserve_space->Resize({0});
+      dev_ctx.template Alloc<T>(reserve_space);
+    }
+    return;
+  }
   bool test_mode = is_test && (!trainable_statistics);
 
   bool global_stats = test_mode || use_global_stats;
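
With both CPU kernels short-circuiting on x.numel() == 0, a zero-size batch can now pass through BatchNorm forward and backward end to end. A minimal sketch of the user-facing behaviour, assuming a Paddle build that contains this commit; the shapes and the bn layer below are illustrative, not taken from the diff:

# Hedged sketch: a zero-size batch through BatchNorm forward and backward.
# Assumes a Paddle build with this fix; shapes are only illustrative.
import paddle

x = paddle.zeros([0, 3, 8, 8])   # batch dim is 0, so x.numel() == 0
x.stop_gradient = False
bn = paddle.nn.BatchNorm2D(3)

y = bn(x)                        # forward takes the new early return; y keeps shape [0, 3, 8, 8]
y.sum().backward()               # backward zero-fills the scale/bias gradients via phi::Full

print(y.shape)                   # [0, 3, 8, 8]
print(bn.weight.grad.shape)      # [3]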

paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu

Lines changed: 15 additions & 0 deletions
@@ -1330,6 +1330,21 @@ void BatchNormGradKernel(const Context &dev_ctx,
                          DenseTensor *x_grad,
                          DenseTensor *scale_grad,
                          DenseTensor *bias_grad) {
+  if (x.numel() == 0) {
+    dev_ctx.template Alloc<T>(x_grad);
+    if (scale_grad)
+      phi::Full<T, Context>(
+          dev_ctx,
+          phi::IntArray(common::vectorize(scale_grad->dims())),
+          0,
+          scale_grad);
+    if (bias_grad)
+      phi::Full<T, Context>(dev_ctx,
+                            phi::IntArray(common::vectorize(bias_grad->dims())),
+                            0,
+                            bias_grad);
+    return;
+  }
   BatchNormGradFunctor<T, Context>(dev_ctx,
                                    x,
                                    scale,

paddle/phi/kernels/gpu/batch_norm_kernel.cu

Lines changed: 12 additions & 0 deletions
@@ -533,6 +533,18 @@ void BatchNormKernel(const Context &dev_ctx,
                      DenseTensor *saved_mean,
                      DenseTensor *saved_variance,
                      DenseTensor *reserve_space) {
+  if (x.numel() == 0) {
+    dev_ctx.template Alloc<T>(y);
+    if (mean_out) dev_ctx.template Alloc<T>(mean_out);
+    if (variance_out) dev_ctx.template Alloc<T>(variance_out);
+    if (saved_mean) dev_ctx.template Alloc<T>(saved_mean);
+    if (saved_variance) dev_ctx.template Alloc<T>(saved_variance);
+    if (reserve_space) {
+      reserve_space->Resize({0});
+      dev_ctx.template Alloc<T>(reserve_space);
+    }
+    return;
+  }
   double epsilon = epsilon_f;
   const bool trainable_stats = trainable_statistics;
   const DataLayout data_layout = common::StringToDataLayout(data_layout_str);

paddle/phi/kernels/gpudnn/softmax_grad_kernel.cu

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ void SoftmaxGradGPUDNNKernel(const Context& dev_ctx,
                              int axis,
                              DenseTensor* x_grad) {
   dev_ctx.template Alloc<T>(x_grad);
+  if (x_grad->numel() == 0) return;
 
   const int rank = out.dims().size();
   // For 0D Tensor

paddle/phi/kernels/gpudnn/softmax_kernel.cu

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ void SoftmaxGPUDNNKernel(const Context& dev_ctx,
                          int axis,
                          DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
+  if (x.numel() == 0) return;
 
   const int rank = x.dims().size();
   // For 0D Tensor
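
Both cuDNN softmax kernels now return right after allocating their output (or input gradient) when there is nothing to compute. A minimal sketch of the resulting behaviour, assuming a GPU build that contains this fix; the shape below is illustrative:

# Hedged sketch: softmax over a zero-size tensor returns an empty result
# instead of launching the cuDNN kernel. Assumes a GPU build with this fix.
import paddle

x = paddle.zeros([0, 10])                     # zero rows, so x.numel() == 0
y = paddle.nn.functional.softmax(x, axis=-1)  # early return right after Alloc
print(y.shape)                                # [0, 10]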

paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc

Lines changed: 15 additions & 0 deletions
@@ -90,6 +90,21 @@ void BatchNormGradKernel(const Context &dev_ctx,
                          DenseTensor *x_grad,
                          DenseTensor *scale_grad,
                          DenseTensor *bias_grad) {
+  if (x.numel() == 0) {
+    dev_ctx.template Alloc<T>(x_grad);
+    if (scale_grad)
+      phi::Full<T, Context>(
+          dev_ctx,
+          phi::IntArray(common::vectorize(scale_grad->dims())),
+          0,
+          scale_grad);
+    if (bias_grad)
+      phi::Full<T, Context>(dev_ctx,
+                            phi::IntArray(common::vectorize(bias_grad->dims())),
+                            0,
+                            bias_grad);
+    return;
+  }
   using XPUType = typename XPUTypeTrait<T>::Type;
   const auto *d_y = &y_grad;
   PADDLE_ENFORCE_EQ(data_layout == "NCHW" || data_layout == "NHWC",

paddle/phi/kernels/xpu/batch_norm_kernel.cc

Lines changed: 12 additions & 0 deletions
@@ -40,6 +40,18 @@ void BatchNormKernel(const Context& dev_ctx,
                      DenseTensor* saved_mean,
                      DenseTensor* saved_variance,
                      DenseTensor* reserve_space) {
+  if (x.numel() == 0) {
+    dev_ctx.template Alloc<T>(y);
+    if (mean_out) dev_ctx.template Alloc<T>(mean_out);
+    if (variance_out) dev_ctx.template Alloc<T>(variance_out);
+    if (saved_mean) dev_ctx.template Alloc<T>(saved_mean);
+    if (saved_variance) dev_ctx.template Alloc<T>(saved_variance);
+    if (reserve_space) {
+      reserve_space->Resize({0});
+      dev_ctx.template Alloc<T>(reserve_space);
+    }
+    return;
+  }
   using XPUType = typename XPUTypeTrait<T>::Type;
   bool test_mode = is_test && (!trainable_statistics);
   bool global_stats = test_mode || use_global_stats;
