
Commit ea940df

Fix
1 parent 8341678 commit ea940df

14 files changed: +190 -35 lines changed

paddle/phi/infermeta/backward.cc

Lines changed: 11 additions & 2 deletions
@@ -921,6 +921,7 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
 
 void InstanceNormGradInferMeta(const MetaTensor& x,
                                const MetaTensor& scale,
+                               const MetaTensor& bias,
                                const MetaTensor& saved_mean,
                                const MetaTensor& saved_variance,
                                const MetaTensor& y_grad,
@@ -939,10 +940,18 @@ void InstanceNormGradInferMeta(const MetaTensor& x,
   x_grad->set_dtype(x.dtype());
   x_grad->set_layout(x.layout());
   if (scale_grad) {
-    scale_grad->set_dims({C});
+    if (C == 0) {
+      scale_grad->set_dims({scale.dims()[0]});
+    } else {
+      scale_grad->set_dims({C});
+    }
   }
   if (bias_grad) {
-    bias_grad->set_dims({C});
+    if (C == 0) {
+      bias_grad->set_dims({bias.dims()[0]});
+    } else {
+      bias_grad->set_dims({C});
+    }
   }
 }
 void InstanceNormDoubleGradInferMeta(const MetaTensor& x,

paddle/phi/infermeta/backward.h

Lines changed: 1 addition & 0 deletions
@@ -345,6 +345,7 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
 
 void InstanceNormGradInferMeta(const MetaTensor& x,
                                const MetaTensor& scale,
+                               const MetaTensor& bias,
                                const MetaTensor& saved_mean,
                                const MetaTensor& saved_variance,
                                const MetaTensor& y_grad,

paddle/phi/infermeta/spmd_rules/instance_norm.cc

Lines changed: 1 addition & 0 deletions
@@ -131,6 +131,7 @@ SpmdInfo InstanceNormInferSpmd(const DistMetaTensor& x,
 
 SpmdInfo InstanceNormGradInferSpmd(const DistMetaTensor& x,
                                    const DistMetaTensor& scale,
+                                   const DistMetaTensor& bias UNUSED,
                                    const DistMetaTensor& saved_mean,
                                    const DistMetaTensor& saved_variance,
                                    const DistMetaTensor& y_grad,

paddle/phi/infermeta/spmd_rules/instance_norm.h

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ SpmdInfo InstanceNormInferSpmd(const DistMetaTensor& x,
 
 SpmdInfo InstanceNormGradInferSpmd(const DistMetaTensor& x,
                                    const DistMetaTensor& scale,
+                                   const DistMetaTensor& bias,
                                    const DistMetaTensor& saved_mean,
                                    const DistMetaTensor& saved_variance,
                                    const DistMetaTensor& y_grad,

paddle/phi/infermeta/ternary.cc

Lines changed: 19 additions & 21 deletions
@@ -826,13 +826,6 @@ void InstanceNormInferMeta(const MetaTensor& x,
       common::errors::InvalidArgument(
           "The y in InstanceNormInferMeta can't be nullptr."));
   const auto x_dims = x.dims();
-  PADDLE_ENFORCE_NE(common::product(x_dims),
-                    0,
-                    common::errors::PreconditionNotMet(
-                        "The Input variable X has not "
-                        "been initialized. You may need to confirm "
-                        "if you put exe.run(startup_program) "
-                        "after optimizer.minimize function."));
   PADDLE_ENFORCE_GE(
       x_dims.size(),
       2,
@@ -867,13 +860,16 @@ void InstanceNormInferMeta(const MetaTensor& x,
             scale_dim.size()));
     bool check = config.is_runtime || contain_unknown_dim(scale_dim);
     if (check) {
-      PADDLE_ENFORCE_EQ(scale_dim[0],
-                        C,
-                        common::errors::InvalidArgument(
-                            "ShapeError: the shape of scale must equal to [%d]"
-                            "But received: the shape of scale is [%d]",
-                            C,
-                            scale_dim[0]));
+      if (C != 0) {
+        PADDLE_ENFORCE_EQ(
+            scale_dim[0],
+            C,
+            common::errors::InvalidArgument(
+                "ShapeError: the shape of scale must equal to [%d]"
+                "But received: the shape of scale is [%d]",
+                C,
+                scale_dim[0]));
+      }
     }
   }
   if (bias) {
@@ -889,13 +885,15 @@ void InstanceNormInferMeta(const MetaTensor& x,
             bias_dim.size()));
     bool check = config.is_runtime || !contain_unknown_dim(bias_dim);
     if (check) {
-      PADDLE_ENFORCE_EQ(bias_dim[0],
-                        C,
-                        common::errors::InvalidArgument(
-                            "ShapeError: the shape of bias must equal to [%d]"
-                            "But received: the shape of bias is [%d]",
-                            C,
-                            bias_dim[0]));
+      if (C != 0) {
+        PADDLE_ENFORCE_EQ(bias_dim[0],
+                          C,
+                          common::errors::InvalidArgument(
+                              "ShapeError: the shape of bias must equal to [%d]"
+                              "But received: the shape of bias is [%d]",
+                              C,
+                              bias_dim[0]));
+      }
     }
   }
   y->set_dims(x_dims);
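Two things change here. First, dropping the PADDLE_ENFORCE_NE on common::product(x_dims) makes a zero-element input a legal input instead of being diagnosed as an uninitialized variable. Second, the [C] shape checks for scale and bias are skipped when C == 0, since a zero channel count gives nothing meaningful to compare against. A simplified sketch of the guarded check (names hypothetical, not the Paddle source):

#include <cstdint>
#include <stdexcept>
#include <string>

// Enforce param_len == C only when C is nonzero; a zero-size input such as
// x.dims() == {2, 0, 4, 4} yields C == 0 and skips the check entirely.
void CheckParamLen(int64_t C, int64_t param_len, const std::string& name) {
  if (C != 0 && param_len != C) {
    throw std::invalid_argument("the shape of " + name + " must equal [" +
                                std::to_string(C) + "]");
  }
}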

paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc

Lines changed: 13 additions & 3 deletions
@@ -51,6 +51,19 @@ void InstanceNormGradKernel(const Context& dev_ctx,
                             DenseTensor* d_x,
                             DenseTensor* d_scale,
                             DenseTensor* d_bias) {
+  phi::funcs::SetConstant<CPUContext, T> set_constant;
+  dev_ctx.template Alloc<T>(d_x);
+  if (x.numel() == 0) {
+    if (d_scale) {
+      dev_ctx.template Alloc<T>(d_scale);
+      set_constant(dev_ctx, d_scale, static_cast<T>(0));
+    }
+    if (d_bias) {
+      dev_ctx.template Alloc<T>(d_bias);
+      set_constant(dev_ctx, d_bias, static_cast<T>(0));
+    }
+    return;
+  }
   const auto* scale_ptr = scale.get_ptr();
 
   const auto& x_dims = x.dims();
@@ -60,7 +73,6 @@ void InstanceNormGradKernel(const Context& dev_ctx,
   const int NxC = N * C;
   const int sample_size = static_cast<int>(x.numel() / N / C);
 
-  dev_ctx.template Alloc<T>(d_x);
   auto* place = dev_ctx.eigen_device();
 
   Eigen::DSizes<int, 2> rshape(NxC, sample_size);
@@ -83,8 +95,6 @@ void InstanceNormGradKernel(const Context& dev_ctx,
   NxC_shape.set(0, NxC);
 #endif
 
-  phi::funcs::SetConstant<CPUContext, T> set_constant;
-
   DenseTensor scale_data;
   if (!scale_ptr) {
     scale_data.Resize({C});
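The ordering matters: the early return sits before the shape math because sample_size is computed as x.numel() / N / C, which divides by zero once N or C is 0. A small sketch of the hazard the fast path avoids (hypothetical helper, for illustration):

#include <cstdint>

// Without the early return, a zero-size input reaches this computation:
// x.dims() == {2, 0, 4, 4} gives N = 2, C = 0, and numel / N / C traps.
int64_t SampleSize(int64_t numel, int64_t N, int64_t C) {
  if (N == 0 || C == 0) return 0;  // the guard the kernel's early return provides
  return numel / N / C;
}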

paddle/phi/kernels/cpu/instance_norm_kernel.cc

Lines changed: 16 additions & 1 deletion
@@ -38,6 +38,22 @@ void InstanceNormKernel(const Context& dev_ctx,
                         DenseTensor* y,
                         DenseTensor* saved_mean,
                         DenseTensor* saved_variance) {
+  phi::funcs::SetConstant<CPUContext, T> set_constant;
+  if (x.numel() == 0) {
+    dev_ctx.template Alloc<T>(y);
+    set_constant(dev_ctx, y, static_cast<T>(0));
+
+    if (saved_mean) {
+      dev_ctx.template Alloc<T>(saved_mean);
+      set_constant(dev_ctx, saved_mean, static_cast<T>(0));
+    }
+    if (saved_variance) {
+      dev_ctx.template Alloc<T>(saved_variance);
+      set_constant(dev_ctx, saved_variance, static_cast<T>(0));
+    }
+    return;
+  }
+
   const auto& x_dims = x.dims();
   T epsilon = static_cast<T>(epsilon_f);
   const int N = static_cast<int>(x_dims[0]);
@@ -63,7 +79,6 @@ void InstanceNormKernel(const Context& dev_ctx,
   Eigen::IndexList<Eigen::type2index<1>> rdims;
 #endif
 
-  phi::funcs::SetConstant<CPUContext, T> set_constant;
   DenseTensor saved_mean_tmp, saved_variance_tmp;
   if (saved_mean) {
     dev_ctx.template Alloc<T>(saved_mean);
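This path also settles the op's semantics on empty input: the statistics of an empty sample are undefined, and the kernel pins saved_mean and saved_variance to 0 rather than leaving their memory uninitialized. Note that the stats buffers can be non-empty even when x is. A worked example (a standalone illustration, assuming the usual NCHW layout with per-instance stats of length N*C):

#include <cassert>
#include <cstdint>
#include <vector>

int64_t Numel(const std::vector<int64_t>& dims) {
  int64_t n = 1;
  for (int64_t d : dims) n *= d;
  return n;
}

int main() {
  const std::vector<int64_t> x_dims = {2, 3, 0, 4};  // empty input: H == 0
  assert(Numel(x_dims) == 0);                        // triggers the fast path
  const int64_t NxC = x_dims[0] * x_dims[1];         // per-instance stats length
  assert(NxC == 6);  // saved_mean/saved_variance hold 6 real values, zero-filled
  return 0;
}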

paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu

Lines changed: 14 additions & 2 deletions
@@ -326,11 +326,25 @@ void InstanceNormGradKernel(const Context &dev_ctx,
   x_tmp.ShareDataWith(x).Resize({1, NxC, H, W, D});
   d_y_tmp.ShareDataWith(d_y).Resize({1, NxC, H, W, D});
 
+  phi::funcs::SetConstant<GPUContext, AccT> set_constant;
+
   dev_ctx.template Alloc<T>(d_x);
+  if (x.numel() == 0) {
+    if (d_scale) {
+      dev_ctx.template Alloc<AccT>(d_scale);
+      set_constant(dev_ctx, d_scale, static_cast<AccT>(0));
+    }
+    if (d_bias) {
+      dev_ctx.template Alloc<AccT>(d_bias);
+      set_constant(dev_ctx, d_bias, static_cast<AccT>(0));
+    }
+    return;
+  }
   if (d_scale && d_bias) {
     dev_ctx.template Alloc<AccT>(d_scale);
     dev_ctx.template Alloc<AccT>(d_bias);
   }
+
   if (scale_ptr) {
     PADDLE_ENFORCE_EQ(
         scale_ptr->dims().size(),
@@ -354,8 +368,6 @@ void InstanceNormGradKernel(const Context &dev_ctx,
             scale_ptr->dims()));
   }
 
-  phi::funcs::SetConstant<GPUContext, AccT> set_constant;
-
   const int n = x.numel();
   const int block = 512;
   int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
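Same fast path as the CPU kernel, with one type detail worth noting: d_x stays in the compute type T, while the per-channel d_scale and d_bias are allocated and zero-filled in the accumulation type AccT (typically float when T is half), so SetConstant is instantiated for AccT. A sketch of that type split (a hypothetical struct, illustration only):

// The grad outputs do not share one dtype: activation grads match x, while
// parameter grads are kept widened for accumulation precision.
template <typename T, typename AccT>
struct InstanceNormGrads {
  T* d_x;         // same dtype as x and y_grad (e.g. half)
  AccT* d_scale;  // accumulation dtype (e.g. float), one value per channel
  AccT* d_bias;
};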

paddle/phi/kernels/gpu/instance_norm_kernel.cu

Lines changed: 14 additions & 1 deletion
@@ -60,6 +60,20 @@ void InstanceNormKernel(const Context &dev_ctx,
   DenseTensor x_tmp;
   x_tmp.ShareDataWith(x).Resize({1, NxC, H, W, D});
   dev_ctx.template Alloc<T>(y);
+  phi::funcs::SetConstant<GPUContext, BatchNormParamType<T>> functor;
+  phi::funcs::SetConstant<GPUContext, T> functor_y;
+  if (x.numel() == 0) {
+    functor_y(dev_ctx, y, static_cast<T>(0));
+    if (saved_mean) {
+      dev_ctx.template Alloc<BatchNormParamType<T>>(saved_mean);
+      functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
+    }
+    if (saved_variance) {
+      dev_ctx.template Alloc<BatchNormParamType<T>>(saved_variance);
+      functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
+    }
+    return;
+  }
 
 #ifdef PADDLE_WITH_HIP
   miopenTensorDescriptor_t data_desc_;
@@ -144,7 +158,6 @@ void InstanceNormKernel(const Context &dev_ctx,
   auto handle = dev_ctx.cudnn_handle();
 
   DenseTensor saved_mean_tmp, saved_variance_tmp;
-  phi::funcs::SetConstant<GPUContext, BatchNormParamType<T>> functor;
 
   if (saved_mean) {
     dev_ctx.template Alloc<BatchNormParamType<T>>(saved_mean);

paddle/phi/kernels/xpu/instance_norm_grad_kernel.cc

Lines changed: 18 additions & 0 deletions
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/instance_norm_grad_kernel.h"
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/norm_utils.h"
 
 namespace phi {
@@ -44,6 +45,23 @@ void InstanceNormGradKernel(const Context& dev_ctx,
                         x_dims.size()));
 
   dev_ctx.template Alloc<T>(d_x);
+  if (x.numel() == 0) {
+    if (d_scale) {
+      phi::Full<float, Context>(
+          dev_ctx,
+          phi::IntArray(common::vectorize(d_scale->dims())),
+          0.f,
+          d_scale);
+    }
+    if (d_bias) {
+      phi::Full<float, Context>(
+          dev_ctx,
+          phi::IntArray(common::vectorize(d_bias->dims())),
+          0.f,
+          d_bias);
+    }
+    return;
+  }
   T* d_scale_data = nullptr;
   T* d_bias_data = nullptr;
   if (d_scale && d_bias) {
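The XPU path reaches the same result through the phi::Full kernel (hence the new full_kernel.h include) rather than funcs::SetConstant, filling each grad to the dims that InferMeta already assigned; this is where the scale/bias fallback in backward.cc above pays off, since d_scale->dims() is well-defined even when C == 0. A minimal sketch of the pattern (hypothetical helper, not in the commit):

template <typename Context>
void ZeroFillIfPresent(const Context& dev_ctx, phi::DenseTensor* t) {
  if (t == nullptr) return;  // optional grad outputs may be absent
  phi::Full<float, Context>(
      dev_ctx, phi::IntArray(common::vectorize(t->dims())), 0.f, t);
}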
