
Commit 5b69242

modify datanorm op test=develop (#23030)
1 parent 3e1676f commit 5b69242

File tree

4 files changed (+470, -43 lines)


paddle/fluid/framework/unused_var_check.cc

Lines changed: 3 additions & 1 deletion
@@ -53,7 +53,9 @@ const std::unordered_set<std::string> op_has_unsed_vars_white_list = {
     "precision_recall",          // 1
     "fusion_seqpool_cvm_concat", // 2
     "fused_batch_norm_act",      // 2
-    "fused_batch_norm_act_grad"  // 2
+    "fused_batch_norm_act_grad", // 2
+    "data_norm",                 // 0
+    "data_norm_grad",            // 0
 };
 
 namespace paddle {

paddle/fluid/operators/data_norm_op.cc

Lines changed: 236 additions & 11 deletions
@@ -51,6 +51,17 @@ class DataNormOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("Means"), "");
     PADDLE_ENFORCE(ctx->HasOutput("Scales"), "");
     PADDLE_ENFORCE(ctx->HasOutput("Y"), "");
+    bool enable_scale_and_shift =
+        ctx->Attrs().Get<bool>("enable_scale_and_shift");
+    if (enable_scale_and_shift) {
+      PADDLE_ENFORCE_EQ(
+          ctx->HasInput("scale_w"), true,
+          platform::errors::InvalidArgument(
+              "Input(scale_w) of DataNormOp should not be null."));
+      PADDLE_ENFORCE_EQ(ctx->HasInput("bias"), true,
+                        platform::errors::InvalidArgument(
+                            "Input(bias) of DataNormOp should not be null."));
+    }
 
     const auto x_dims = ctx->GetInputDim("X");
     const DataLayout data_layout = framework::StringToDataLayout(
@@ -72,6 +83,45 @@ class DataNormOp : public framework::OperatorWithKernel {
       PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C);
     }
 
+    if (enable_scale_and_shift) {
+      auto scale_dim = ctx->GetInputDim("scale_w");
+      auto bias_dim = ctx->GetInputDim("bias");
+
+      PADDLE_ENFORCE_EQ(
+          scale_dim.size(), 1UL,
+          platform::errors::InvalidArgument("the dimensionof scale"
+                                            "must equal to 1. But received: "
+                                            "the shape of scale is [%s], "
+                                            "the dimensionof scale is [%d]",
+                                            scale_dim, scale_dim.size()));
+      PADDLE_ENFORCE_EQ(
+          bias_dim.size(), 1UL,
+          platform::errors::InvalidArgument("the dimension of bias"
+                                            "must equal to 1. But received: "
+                                            "the shape of bias is [%s],"
+                                            "the dimension of bias is [%d]",
+                                            bias_dim, bias_dim.size()));
+
+      bool check = true;
+      if ((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 ||
+                                  framework::product(bias_dim) <= 0)) {
+        check = false;
+      }
+
+      if (check) {
+        PADDLE_ENFORCE_EQ(scale_dim[0], C,
+                          platform::errors::InvalidArgument(
+                              "the shape of scale must equal to [%d]"
+                              "But received: the shape of scale is [%d]",
+                              C, scale_dim[0]));
+        PADDLE_ENFORCE_EQ(bias_dim[0], C,
+                          platform::errors::InvalidArgument(
+                              "the shape of bias must equal to [%d]"
+                              "But received: the shape of bias is [%d]",
+                              C, bias_dim[0]));
+      }
+    }
+
     ctx->SetOutputDim("Y", x_dims);
     ctx->SetOutputDim("Means", {C});
     ctx->SetOutputDim("Scales", {C});
@@ -99,6 +149,17 @@ class DataNormOp : public framework::OperatorWithKernel {
                           ctx, "BatchSquareSum"),
                       "BatchSquareSum input should be of float type");
 
+    bool enable_scale_and_shift = ctx.Attr<bool>("enable_scale_and_shift");
+    if (enable_scale_and_shift) {
+      PADDLE_ENFORCE_EQ(dn_param_type,
+                        OperatorWithKernel::IndicateVarDataType(ctx, "scale_w"),
+                        platform::errors::InvalidArgument(
+                            "scale_w input should be of float type"));
+      PADDLE_ENFORCE_EQ(dn_param_type,
+                        OperatorWithKernel::IndicateVarDataType(ctx, "bias"),
+                        platform::errors::InvalidArgument(
+                            "bias input should be of float type"));
+    }
     // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
     framework::LibraryType library = framework::LibraryType::kPlain;
     framework::DataLayout layout = framework::DataLayout::kAnyLayout;
@@ -133,6 +194,19 @@ class DataNormOpMaker : public framework::OpProtoAndCheckerMaker {
         "summary_decay_rate",
         "(float, default 0.9999999) The decay rate when update the summary")
         .SetDefault(0.9999999);
+    AddAttr<bool>(
+        "enable_scale_and_shift",
+        "(bool, default false) Set to true to enable scale and shift such as "
+        "batch_norm op")
+        .SetDefault(false);
+    AddInput("scale_w",
+             "scale_w is a 1-dimensional tensor of size C "
+             "that is applied to the output")
+        .AsDispensable();
+    AddInput("bias",
+             "bias is a 1-dimensional tensor of size C "
+             "that is applied to the output")
+        .AsDispensable();
     AddAttr<std::string>("data_layout", "").SetDefault("NCHW");
     AddAttr<bool>("sync_stats", "(bool, default false) only used in multi-GPU")
         .SetDefault(false);
@@ -194,7 +268,6 @@ class DataNormKernel<platform::CPUDeviceContext, T>
     // alloc memory
     T *y_data = y->mutable_data<T>(ctx.GetPlace());
 
-    Eigen::Array<T, Eigen::Dynamic, 1> inv_std(C);
     ConstEigenVectorArrayMap<T> b_size_arr(
         ctx.Input<Tensor>("BatchSize")->data<T>(), C);
     ConstEigenVectorArrayMap<T> b_sum_arr(
@@ -210,6 +283,7 @@ class DataNormKernel<platform::CPUDeviceContext, T>
 
     const T *means_data = mean_out->data<T>();
     const T *x_data = x->data<T>();
+
     const T *scales_data = scales->data<T>();
     const int slot_dim = ctx.Attr<int>("slot_dim");
     T min_precision = 1e-7f;
@@ -218,7 +292,8 @@ class DataNormKernel<platform::CPUDeviceContext, T>
       case DataLayout::kNHWC: {
         // if slot_dim is set and batch size is larger than zero, we choose
         // to check if show number is zero, if so, skip normalization.
-        if (slot_dim > 0 && N > 0) {
+        if (slot_dim > 0 && N > 0 &&
+            (!ctx.Attr<bool>("enable_scale_and_shift"))) {
           const int item_size = x->numel() / N;
           // location of show number in one embedding
           int offset = 0;
@@ -239,10 +314,56 @@ class DataNormKernel<platform::CPUDeviceContext, T>
            offset += item_size;
          }
        } else {
-          EigenArrayMap<T>(y_data, C, N) =
-              (ConstEigenArrayMap<T>(x->data<T>(), C, N).colwise() - means_arr)
-                  .colwise() *
-              scales_arr;
+          if (!ctx.Attr<bool>("enable_scale_and_shift") && slot_dim <= 0) {
+            EigenArrayMap<T>(y_data, C, N) =
+                (ConstEigenArrayMap<T>(x->data<T>(), C, N).colwise() -
+                 means_arr)
+                    .colwise() *
+                scales_arr;
+          } else if (ctx.Attr<bool>("enable_scale_and_shift") &&
+                     slot_dim <= 0) {
+            const auto *scale_w = ctx.Input<Tensor>("scale_w");
+            const auto *bias = ctx.Input<Tensor>("bias");
+            ConstEigenVectorArrayMap<T> scale_w_arr(scale_w->data<T>(), C);
+            ConstEigenVectorArrayMap<T> bias_arr(bias->data<T>(), C);
+
+            Eigen::Array<T, Eigen::Dynamic, 1> new_scale =
+                scales_arr * scale_w_arr;
+            Eigen::Array<T, Eigen::Dynamic, 1> new_bias =
+                bias_arr - means_arr * scales_arr * scale_w_arr;
+            EigenArrayMap<T>(y_data, C, N) =
+                (ConstEigenArrayMap<T>(x->data<T>(), C, N).colwise() *
+                 new_scale)
+                    .colwise() +
+                new_bias;
+
+          } else {
+            const int item_size = x->numel() / N;
+            const auto *scale_w = ctx.Input<Tensor>("scale_w");
+            const auto *bias = ctx.Input<Tensor>("bias");
+            const T *scale_w_data = scale_w->data<T>();
+            const T *bias_data = bias->data<T>();
+            // location of show number in one embedding
+            int offset = 0;
+            for (int k = 0; k < N; ++k) {
+              for (int i = 0; i < item_size; i += slot_dim) {
+                if (x_data[offset + i] > -min_precision &&
+                    x_data[offset + i] < min_precision) {
+                  // show = 0
+                  memset(y_data + offset + i, 0, sizeof(T) * slot_dim);
+                } else {
+                  for (int j = i; j < i + slot_dim; ++j) {
+                    y_data[offset + j] = ((x_data[offset + j] - means_data[j]) *
+                                          scales_data[j]) *
+                                             scale_w_data[j] +
+                                         bias_data[j];
+                  }
+                }
+              }  // end for i
+
+              offset += item_size;
+            }  // end for k
+          }
        }
        break;
      }
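
For reference, the arithmetic behind the new forward branches is the same in both forms: the slot-wise loop applies y = (x - mean) * scales * scale_w + bias element by element, while the Eigen branch folds the same expression into new_scale = scales * scale_w and new_bias = bias - mean * scales * scale_w. A minimal standalone check of that equivalence (plain C++ with illustrative values, not part of the commit):

// Verifies that the "folded" form used by the Eigen branch matches the direct
// per-element form used by the slot_dim > 0 loop. All values are made up.
#include <cassert>
#include <cmath>
#include <vector>

int main() {
  const int C = 3;
  std::vector<double> x       = {1.0, 2.0, 3.0};  // one row of X
  std::vector<double> mean    = {0.5, 1.0, 1.5};  // Means
  std::vector<double> scales  = {2.0, 1.0, 0.5};  // Scales (inverse std)
  std::vector<double> scale_w = {1.5, 1.0, 2.0};  // new scale_w input
  std::vector<double> bias    = {0.1, 0.2, 0.3};  // new bias input

  for (int c = 0; c < C; ++c) {
    // Direct form: y = (x - mean) * scales * scale_w + bias
    double y_direct = (x[c] - mean[c]) * scales[c] * scale_w[c] + bias[c];
    // Folded form: y = x * new_scale + new_bias
    double new_scale = scales[c] * scale_w[c];
    double new_bias = bias[c] - mean[c] * scales[c] * scale_w[c];
    double y_folded = x[c] * new_scale + new_bias;
    assert(std::fabs(y_direct - y_folded) < 1e-12);
  }
  return 0;
}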
@@ -274,7 +395,8 @@ class DataNormGradOp : public framework::OperatorWithKernel {
         "Output(BatchSquareSum) of DataNormGradOp should not be null."));
     PADDLE_ENFORCE(ctx->HasInput("Means"), "");
     PADDLE_ENFORCE(ctx->HasInput("Scales"), "");
-
+    bool enable_scale_and_shift =
+        ctx->Attrs().Get<bool>("enable_scale_and_shift");
     // check output
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSize")), "");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSum")), "");
294416
ctx->SetOutputDim(framework::GradVarName("BatchSize"), {C});
295417
ctx->SetOutputDim(framework::GradVarName("BatchSum"), {C});
296418
ctx->SetOutputDim(framework::GradVarName("BatchSquareSum"), {C});
419+
if (enable_scale_and_shift) {
420+
const bool has_scale_grad =
421+
ctx->HasOutput(framework::GradVarName("scale_w"));
422+
const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("bias"));
423+
424+
PADDLE_ENFORCE_EQ((has_scale_grad == has_bias_grad), true,
425+
platform::errors::InvalidArgument(
426+
"Output(Scale@GRAD) and Output(Bias@GRAD)"
427+
"must be null or not be null at same time. "
428+
"But now, has Scale@Grad=[%d], has Bias@GRAD=[%d]",
429+
has_scale_grad, has_bias_grad));
430+
if (has_scale_grad) {
431+
ctx->SetOutputDim(framework::GradVarName("scale_w"), {C});
432+
ctx->SetOutputDim(framework::GradVarName("bias"), {C});
433+
}
434+
}
297435
}
298436

299437
protected:
@@ -353,26 +491,30 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
353491
const int C =
354492
(data_layout == DataLayout::kNCHW ? x_dims[1]
355493
: x_dims[x_dims.size() - 1]);
356-
357494
// init output
358495
Tensor *d_x = nullptr;
359496
if (ctx.HasOutput(framework::GradVarName("X"))) {
360497
d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
361498
}
499+
362500
auto *d_batch_size =
363501
ctx.Output<Tensor>(framework::GradVarName("BatchSize"));
364502
auto *d_batch_sum = ctx.Output<Tensor>(framework::GradVarName("BatchSum"));
365503
auto *d_batch_square_sum =
366504
ctx.Output<Tensor>(framework::GradVarName("BatchSquareSum"));
367505

506+
const T *mean_data = means->data<T>();
507+
const T *inv_var_data = scales->data<T>();
508+
ConstEigenVectorArrayMap<T> mean_arr(mean_data, C);
509+
ConstEigenVectorArrayMap<T> inv_var_arr(inv_var_data, C);
510+
368511
T *d_batch_size_data = d_batch_size->mutable_data<T>(ctx.GetPlace());
369512
T *d_batch_sum_data = d_batch_sum->mutable_data<T>(ctx.GetPlace());
370513
T *d_batch_square_sum_data =
371514
d_batch_square_sum->mutable_data<T>(ctx.GetPlace());
372515
EigenVectorArrayMap<T> d_batch_size_arr(d_batch_size_data, C);
373516
EigenVectorArrayMap<T> d_batch_sum_arr(d_batch_sum_data, C);
374517
EigenVectorArrayMap<T> d_batch_square_sum_arr(d_batch_square_sum_data, C);
375-
376518
d_batch_size_arr.setZero();
377519
d_batch_sum_arr.setZero();
378520
d_batch_square_sum_arr.setZero();
@@ -392,8 +534,86 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
392534
if (d_x != nullptr) {
393535
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C, N);
394536
d_x_arr.setZero();
395-
for (int nc = 0; nc < N; ++nc) {
396-
d_x_arr.col(nc) = d_y_arr.col(nc) * scales_arr;
537+
if (!ctx.Attr<bool>("enable_scale_and_shift")) {
538+
for (int nc = 0; nc < N; ++nc) {
539+
d_x_arr.col(nc) = d_y_arr.col(nc) * scales_arr;
540+
}
541+
} else {
542+
const auto *scale_w = ctx.Input<Tensor>("scale_w");
543+
auto *d_scale =
544+
ctx.Output<Tensor>(framework::GradVarName("scale_w"));
545+
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("bias"));
546+
ConstEigenVectorArrayMap<T> scale_arr(scale_w->data<T>(), C);
547+
T *d_bias_data = nullptr;
548+
T *d_scale_data = nullptr;
549+
550+
d_scale->mutable_data<T>(ctx.GetPlace());
551+
d_bias->mutable_data<T>(ctx.GetPlace());
552+
d_bias_data = d_bias->mutable_data<T>(ctx.GetPlace());
553+
d_scale_data = d_scale->mutable_data<T>(ctx.GetPlace());
554+
555+
EigenVectorArrayMap<T> d_bias_arr(d_bias_data, C);
556+
EigenVectorArrayMap<T> d_scale_arr(d_scale_data, C);
557+
Tensor dy_sum;
558+
dy_sum.Resize({C});
559+
dy_sum.mutable_data<T>(ctx.GetPlace());
560+
EigenVectorArrayMap<T> dy_sum_arr(
561+
dy_sum.mutable_data<T>(ctx.GetPlace()), C);
562+
Tensor dy_mul_x_sub_mean_mul_invstd_sum;
563+
dy_mul_x_sub_mean_mul_invstd_sum.Resize({C});
564+
dy_mul_x_sub_mean_mul_invstd_sum.mutable_data<T>(ctx.GetPlace());
565+
EigenVectorArrayMap<T> dy_mul_x_sub_mean_mul_invstd_sum_arr(
566+
dy_mul_x_sub_mean_mul_invstd_sum.mutable_data<T>(
567+
ctx.GetPlace()),
568+
C);
569+
570+
dy_sum_arr.setZero();
571+
dy_mul_x_sub_mean_mul_invstd_sum_arr.setZero();
572+
573+
if (slot_dim <= 0) {
574+
for (int n = 0; n < N; ++n) {
575+
dy_sum_arr += d_y_arr.col(n);
576+
dy_mul_x_sub_mean_mul_invstd_sum_arr +=
577+
((x_arr.col(n) - mean_arr) * inv_var_arr * d_y_arr.col(n));
578+
}
579+
if (d_scale && d_bias) {
580+
d_bias_arr = dy_sum_arr;
581+
d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr;
582+
}
583+
for (int nc = 0; nc < N; ++nc) {
584+
d_x_arr.col(nc) = d_y_arr.col(nc) * scales_arr * scale_arr;
585+
}
586+
} else {
587+
int offset = 0;
588+
const int item_size = x->numel() / N;
589+
T *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
590+
T *d_scale_data = d_scale->mutable_data<T>(ctx.GetPlace());
591+
T *d_bias_data = d_bias->mutable_data<T>(ctx.GetPlace());
592+
const T *dy_data = d_y->data<T>();
593+
const T *scales_data = scales->data<T>();
594+
const T *scale_w_data = scale_w->data<T>();
595+
const T *x_data = x->data<T>();
596+
for (int i = 0; i < item_size; i++) {
597+
d_bias_data[i] = 0;
598+
d_scale_data[i] = 0;
599+
}
600+
for (int k = 0; k < N; ++k) {
601+
for (int i = 0; i < item_size; i += slot_dim) {
602+
if (!(x_data[offset + i] > -min_precision &&
603+
x_data[offset + i] < min_precision)) {
604+
// show != 0
605+
for (int j = i; j < i + slot_dim; ++j) {
606+
d_x_data[offset + j] = dy_data[offset + j] *
607+
scales_data[j] * scale_w_data[j];
608+
d_bias_data[j] += dy_data[offset + j];
609+
d_scale_data[j] += (x_data[offset + j] - mean_data[j]) *
610+
inv_var_data[j] * dy_data[offset + j];
611+
}
612+
}
613+
}
614+
offset += item_size;
615+
}
616+
}
397617
}
398618
}
399619

@@ -466,6 +686,8 @@ class DataNormGradMaker : public framework::SingleGradOpMaker<T> {
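
As a standalone reference for the new backward branch (plain C++ with illustrative values, not part of the commit): when enable_scale_and_shift is on and slot_dim <= 0, the kernel above produces d_x = d_y * scales * scale_w per element, d_bias as the per-channel sum of d_y over the batch, and d_scale_w as the per-channel sum of (x - mean) * inv_std * d_y, where scales already holds the inverse standard deviation (inv_var_arr).

// Toy reproduction of the gradient formulas above; all values are made up.
#include <cstdio>

int main() {
  const int N = 2, C = 2;
  // Row-major [N][C] toy data; names mirror the kernel's tensors.
  double x[N][C]    = {{1.0, 4.0}, {3.0, 8.0}};
  double d_y[N][C]  = {{0.1, 0.2}, {0.3, 0.4}};
  double mean[C]    = {2.0, 6.0};
  double inv_std[C] = {0.5, 0.25};  // same values the kernel reads from Scales
  double scale_w[C] = {1.5, 2.0};

  double d_x[N][C], d_bias[C] = {0, 0}, d_scale[C] = {0, 0};
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      d_x[n][c] = d_y[n][c] * inv_std[c] * scale_w[c];          // d_x
      d_bias[c] += d_y[n][c];                                   // d_bias
      d_scale[c] += (x[n][c] - mean[c]) * inv_std[c] * d_y[n][c];  // d_scale_w
    }
  }
  for (int c = 0; c < C; ++c)
    std::printf("c=%d d_bias=%.3f d_scale=%.3f\n", c, d_bias[c], d_scale[c]);
  return 0;
}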
466686
op->SetInput("X", this->Input("X"));
467687
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
468688

689+
op->SetInput("scale_w", this->Input("scale_w"));
690+
op->SetInput("bias", this->Input("bias"));
469691
op->SetOutput("BatchSize", this->Input("BatchSize"));
470692
op->SetOutput("BatchSum", this->Input("BatchSum"));
471693
op->SetOutput("BatchSquareSum", this->Input("BatchSquareSum"));
@@ -481,6 +703,9 @@ class DataNormGradMaker : public framework::SingleGradOpMaker<T> {
481703
this->InputGrad("BatchSum"));
482704
op->SetOutput(framework::GradVarName("BatchSquareSum"),
483705
this->InputGrad("BatchSquareSum"));
706+
op->SetOutput(framework::GradVarName("scale_w"),
707+
this->InputGrad("scale_w"));
708+
op->SetOutput(framework::GradVarName("bias"), this->InputGrad("bias"));
484709
}
485710
};
486711
