@@ -51,6 +51,17 @@ class DataNormOp : public framework::OperatorWithKernel {
5151 PADDLE_ENFORCE (ctx->HasOutput (" Means" ), " " );
5252 PADDLE_ENFORCE (ctx->HasOutput (" Scales" ), " " );
5353 PADDLE_ENFORCE (ctx->HasOutput (" Y" ), " " );
54+ bool enable_scale_and_shift =
55+ ctx->Attrs ().Get <bool >(" enable_scale_and_shift" );
56+ if (enable_scale_and_shift) {
57+ PADDLE_ENFORCE_EQ (
58+ ctx->HasInput (" scale_w" ), true ,
59+ platform::errors::InvalidArgument (
60+ " Input(scale_w) of DataNormOp should not be null." ));
61+ PADDLE_ENFORCE_EQ (ctx->HasInput (" bias" ), true ,
62+ platform::errors::InvalidArgument (
63+ " Input(bias) of DataNormOp should not be null." ));
64+ }
5465
5566 const auto x_dims = ctx->GetInputDim (" X" );
5667 const DataLayout data_layout = framework::StringToDataLayout (
@@ -72,6 +83,45 @@ class DataNormOp : public framework::OperatorWithKernel {
7283 PADDLE_ENFORCE_EQ (ctx->GetInputDim (" BatchSquareSum" )[0 ], C);
7384 }
7485
86+ if (enable_scale_and_shift) {
87+ auto scale_dim = ctx->GetInputDim (" scale_w" );
88+ auto bias_dim = ctx->GetInputDim (" bias" );
89+
90+ PADDLE_ENFORCE_EQ (
91+ scale_dim.size (), 1UL ,
92+ platform::errors::InvalidArgument (" the dimensionof scale"
93+ " must equal to 1. But received: "
94+ " the shape of scale is [%s], "
95+ " the dimensionof scale is [%d]" ,
96+ scale_dim, scale_dim.size ()));
97+ PADDLE_ENFORCE_EQ (
98+ bias_dim.size (), 1UL ,
99+ platform::errors::InvalidArgument (" the dimension of bias"
100+ " must equal to 1. But received: "
101+ " the shape of bias is [%s],"
102+ " the dimension of bias is [%d]" ,
103+ bias_dim, bias_dim.size ()));
104+
105+ bool check = true ;
106+ if ((!ctx->IsRuntime ()) && (framework::product (scale_dim) <= 0 ||
107+ framework::product (bias_dim) <= 0 )) {
108+ check = false ;
109+ }
110+
111+ if (check) {
112+ PADDLE_ENFORCE_EQ (scale_dim[0 ], C,
113+ platform::errors::InvalidArgument (
114+ " the shape of scale must equal to [%d]"
115+ " But received: the shape of scale is [%d]" ,
116+ C, scale_dim[0 ]));
117+ PADDLE_ENFORCE_EQ (bias_dim[0 ], C,
118+ platform::errors::InvalidArgument (
119+ " the shape of bias must equal to [%d]"
120+ " But received: the shape of bias is [%d]" ,
121+ C, bias_dim[0 ]));
122+ }
123+ }
124+
75125 ctx->SetOutputDim (" Y" , x_dims);
76126 ctx->SetOutputDim (" Means" , {C});
77127 ctx->SetOutputDim (" Scales" , {C});
@@ -99,6 +149,17 @@ class DataNormOp : public framework::OperatorWithKernel {
99149 ctx, " BatchSquareSum" ),
100150 " BatchSquareSum input should be of float type" );
101151
152+ bool enable_scale_and_shift = ctx.Attr <bool >(" enable_scale_and_shift" );
153+ if (enable_scale_and_shift) {
154+ PADDLE_ENFORCE_EQ (dn_param_type,
155+ OperatorWithKernel::IndicateVarDataType (ctx, " scale_w" ),
156+ platform::errors::InvalidArgument (
157+ " scale_w input should be of float type" ));
158+ PADDLE_ENFORCE_EQ (dn_param_type,
159+ OperatorWithKernel::IndicateVarDataType (ctx, " bias" ),
160+ platform::errors::InvalidArgument (
161+ " bias input should be of float type" ));
162+ }
102163 // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
103164 framework::LibraryType library = framework::LibraryType::kPlain ;
104165 framework::DataLayout layout = framework::DataLayout::kAnyLayout ;
@@ -133,6 +194,19 @@ class DataNormOpMaker : public framework::OpProtoAndCheckerMaker {
133194 " summary_decay_rate" ,
134195 " (float, default 0.9999999) The decay rate when update the summary" )
135196 .SetDefault (0.9999999 );
197+ AddAttr<bool >(
198+ " enable_scale_and_shift" ,
199+ " (bool, default false) Set to true to enable scale and shift such as "
200+ " batch_norm op" )
201+ .SetDefault (false );
202+ AddInput (" scale_w" ,
203+ " scale_w is a 1-dimensional tensor of size C "
204+ " that is applied to the output" )
205+ .AsDispensable ();
206+ AddInput (" bias" ,
207+ " bias is a 1-dimensional tensor of size C "
208+ " that is applied to the output" )
209+ .AsDispensable ();
136210 AddAttr<std::string>(" data_layout" , " " ).SetDefault (" NCHW" );
137211 AddAttr<bool >(" sync_stats" , " (bool, default false) only used in multi-GPU" )
138212 .SetDefault (false );
@@ -194,7 +268,6 @@ class DataNormKernel<platform::CPUDeviceContext, T>
194268 // alloc memory
195269 T *y_data = y->mutable_data <T>(ctx.GetPlace ());
196270
197- Eigen::Array<T, Eigen::Dynamic, 1 > inv_std (C);
198271 ConstEigenVectorArrayMap<T> b_size_arr (
199272 ctx.Input <Tensor>(" BatchSize" )->data <T>(), C);
200273 ConstEigenVectorArrayMap<T> b_sum_arr (
@@ -210,6 +283,7 @@ class DataNormKernel<platform::CPUDeviceContext, T>
210283
211284 const T *means_data = mean_out->data <T>();
212285 const T *x_data = x->data <T>();
286+
213287 const T *scales_data = scales->data <T>();
214288 const int slot_dim = ctx.Attr <int >(" slot_dim" );
215289 T min_precision = 1e-7f ;
@@ -218,7 +292,8 @@ class DataNormKernel<platform::CPUDeviceContext, T>
218292 case DataLayout::kNHWC : {
219293 // if slot_dim is set and batch size is larger than zero, we choose
220294 // to check if show number is zero, if so, skip normalization.
221- if (slot_dim > 0 && N > 0 ) {
295+ if (slot_dim > 0 && N > 0 &&
296+ (!ctx.Attr <bool >(" enable_scale_and_shift" ))) {
222297 const int item_size = x->numel () / N;
223298 // location of show number in one embedding
224299 int offset = 0 ;
@@ -239,10 +314,56 @@ class DataNormKernel<platform::CPUDeviceContext, T>
239314 offset += item_size;
240315 }
241316 } else {
242- EigenArrayMap<T>(y_data, C, N) =
243- (ConstEigenArrayMap<T>(x->data <T>(), C, N).colwise () - means_arr)
244- .colwise () *
245- scales_arr;
317+ if (!ctx.Attr <bool >(" enable_scale_and_shift" ) && slot_dim <= 0 ) {
318+ EigenArrayMap<T>(y_data, C, N) =
319+ (ConstEigenArrayMap<T>(x->data <T>(), C, N).colwise () -
320+ means_arr)
321+ .colwise () *
322+ scales_arr;
323+ } else if (ctx.Attr <bool >(" enable_scale_and_shift" ) &&
324+ slot_dim <= 0 ) {
325+ const auto *scale_w = ctx.Input <Tensor>(" scale_w" );
326+ const auto *bias = ctx.Input <Tensor>(" bias" );
327+ ConstEigenVectorArrayMap<T> scale_w_arr (scale_w->data <T>(), C);
328+ ConstEigenVectorArrayMap<T> bias_arr (bias->data <T>(), C);
329+
330+ Eigen::Array<T, Eigen::Dynamic, 1 > new_scale =
331+ scales_arr * scale_w_arr;
332+ Eigen::Array<T, Eigen::Dynamic, 1 > new_bias =
333+ bias_arr - means_arr * scales_arr * scale_w_arr;
334+ EigenArrayMap<T>(y_data, C, N) =
335+ (ConstEigenArrayMap<T>(x->data <T>(), C, N).colwise () *
336+ new_scale)
337+ .colwise () +
338+ new_bias;
339+
340+ } else {
341+ const int item_size = x->numel () / N;
342+ const auto *scale_w = ctx.Input <Tensor>(" scale_w" );
343+ const auto *bias = ctx.Input <Tensor>(" bias" );
344+ const T *scale_w_data = scale_w->data <T>();
345+ const T *bias_data = bias->data <T>();
346+ // location of show number in one embedding
347+ int offset = 0 ;
348+ for (int k = 0 ; k < N; ++k) {
349+ for (int i = 0 ; i < item_size; i += slot_dim) {
350+ if (x_data[offset + i] > -min_precision &&
351+ x_data[offset + i] < min_precision) {
352+ // show = 0
353+ memset (y_data + offset + i, 0 , sizeof (T) * slot_dim);
354+ } else {
355+ for (int j = i; j < i + slot_dim; ++j) {
356+ y_data[offset + j] = ((x_data[offset + j] - means_data[j]) *
357+ scales_data[j]) *
358+ scale_w_data[j] +
359+ bias_data[j];
360+ }
361+ }
362+ } // end for i
363+
364+ offset += item_size;
365+ } // end for k
366+ }
246367 }
247368 break ;
248369 }
@@ -274,7 +395,8 @@ class DataNormGradOp : public framework::OperatorWithKernel {
274395 " Output(BatchSquareSum) of DataNormGradOp should not be null." ));
275396 PADDLE_ENFORCE (ctx->HasInput (" Means" ), " " );
276397 PADDLE_ENFORCE (ctx->HasInput (" Scales" ), " " );
277-
398+ bool enable_scale_and_shift =
399+ ctx->Attrs ().Get <bool >(" enable_scale_and_shift" );
278400 // check output
279401 PADDLE_ENFORCE (ctx->HasOutput (framework::GradVarName (" BatchSize" )), " " );
280402 PADDLE_ENFORCE (ctx->HasOutput (framework::GradVarName (" BatchSum" )), " " );
@@ -294,6 +416,22 @@ class DataNormGradOp : public framework::OperatorWithKernel {
294416 ctx->SetOutputDim (framework::GradVarName (" BatchSize" ), {C});
295417 ctx->SetOutputDim (framework::GradVarName (" BatchSum" ), {C});
296418 ctx->SetOutputDim (framework::GradVarName (" BatchSquareSum" ), {C});
419+ if (enable_scale_and_shift) {
420+ const bool has_scale_grad =
421+ ctx->HasOutput (framework::GradVarName (" scale_w" ));
422+ const bool has_bias_grad = ctx->HasOutput (framework::GradVarName (" bias" ));
423+
424+ PADDLE_ENFORCE_EQ ((has_scale_grad == has_bias_grad), true ,
425+ platform::errors::InvalidArgument (
426+ " Output(Scale@GRAD) and Output(Bias@GRAD)"
427+ " must be null or not be null at same time. "
428+ " But now, has Scale@Grad=[%d], has Bias@GRAD=[%d]" ,
429+ has_scale_grad, has_bias_grad));
430+ if (has_scale_grad) {
431+ ctx->SetOutputDim (framework::GradVarName (" scale_w" ), {C});
432+ ctx->SetOutputDim (framework::GradVarName (" bias" ), {C});
433+ }
434+ }
297435 }
298436
299437 protected:
@@ -353,26 +491,30 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
353491 const int C =
354492 (data_layout == DataLayout::kNCHW ? x_dims[1 ]
355493 : x_dims[x_dims.size () - 1 ]);
356-
357494 // init output
358495 Tensor *d_x = nullptr ;
359496 if (ctx.HasOutput (framework::GradVarName (" X" ))) {
360497 d_x = ctx.Output <Tensor>(framework::GradVarName (" X" ));
361498 }
499+
362500 auto *d_batch_size =
363501 ctx.Output <Tensor>(framework::GradVarName (" BatchSize" ));
364502 auto *d_batch_sum = ctx.Output <Tensor>(framework::GradVarName (" BatchSum" ));
365503 auto *d_batch_square_sum =
366504 ctx.Output <Tensor>(framework::GradVarName (" BatchSquareSum" ));
367505
506+ const T *mean_data = means->data <T>();
507+ const T *inv_var_data = scales->data <T>();
508+ ConstEigenVectorArrayMap<T> mean_arr (mean_data, C);
509+ ConstEigenVectorArrayMap<T> inv_var_arr (inv_var_data, C);
510+
368511 T *d_batch_size_data = d_batch_size->mutable_data <T>(ctx.GetPlace ());
369512 T *d_batch_sum_data = d_batch_sum->mutable_data <T>(ctx.GetPlace ());
370513 T *d_batch_square_sum_data =
371514 d_batch_square_sum->mutable_data <T>(ctx.GetPlace ());
372515 EigenVectorArrayMap<T> d_batch_size_arr (d_batch_size_data, C);
373516 EigenVectorArrayMap<T> d_batch_sum_arr (d_batch_sum_data, C);
374517 EigenVectorArrayMap<T> d_batch_square_sum_arr (d_batch_square_sum_data, C);
375-
376518 d_batch_size_arr.setZero ();
377519 d_batch_sum_arr.setZero ();
378520 d_batch_square_sum_arr.setZero ();
@@ -392,8 +534,86 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
392534 if (d_x != nullptr ) {
393535 EigenArrayMap<T> d_x_arr (d_x->mutable_data <T>(ctx.GetPlace ()), C, N);
394536 d_x_arr.setZero ();
395- for (int nc = 0 ; nc < N; ++nc) {
396- d_x_arr.col (nc) = d_y_arr.col (nc) * scales_arr;
537+ if (!ctx.Attr <bool >(" enable_scale_and_shift" )) {
538+ for (int nc = 0 ; nc < N; ++nc) {
539+ d_x_arr.col (nc) = d_y_arr.col (nc) * scales_arr;
540+ }
541+ } else {
542+ const auto *scale_w = ctx.Input <Tensor>(" scale_w" );
543+ auto *d_scale =
544+ ctx.Output <Tensor>(framework::GradVarName (" scale_w" ));
545+ auto *d_bias = ctx.Output <Tensor>(framework::GradVarName (" bias" ));
546+ ConstEigenVectorArrayMap<T> scale_arr (scale_w->data <T>(), C);
547+ T *d_bias_data = nullptr ;
548+ T *d_scale_data = nullptr ;
549+
550+ d_scale->mutable_data <T>(ctx.GetPlace ());
551+ d_bias->mutable_data <T>(ctx.GetPlace ());
552+ d_bias_data = d_bias->mutable_data <T>(ctx.GetPlace ());
553+ d_scale_data = d_scale->mutable_data <T>(ctx.GetPlace ());
554+
555+ EigenVectorArrayMap<T> d_bias_arr (d_bias_data, C);
556+ EigenVectorArrayMap<T> d_scale_arr (d_scale_data, C);
557+ Tensor dy_sum;
558+ dy_sum.Resize ({C});
559+ dy_sum.mutable_data <T>(ctx.GetPlace ());
560+ EigenVectorArrayMap<T> dy_sum_arr (
561+ dy_sum.mutable_data <T>(ctx.GetPlace ()), C);
562+ Tensor dy_mul_x_sub_mean_mul_invstd_sum;
563+ dy_mul_x_sub_mean_mul_invstd_sum.Resize ({C});
564+ dy_mul_x_sub_mean_mul_invstd_sum.mutable_data <T>(ctx.GetPlace ());
565+ EigenVectorArrayMap<T> dy_mul_x_sub_mean_mul_invstd_sum_arr (
566+ dy_mul_x_sub_mean_mul_invstd_sum.mutable_data <T>(
567+ ctx.GetPlace ()),
568+ C);
569+
570+ dy_sum_arr.setZero ();
571+ dy_mul_x_sub_mean_mul_invstd_sum_arr.setZero ();
572+
573+ if (slot_dim <= 0 ) {
574+ for (int n = 0 ; n < N; ++n) {
575+ dy_sum_arr += d_y_arr.col (n);
576+ dy_mul_x_sub_mean_mul_invstd_sum_arr +=
577+ ((x_arr.col (n) - mean_arr) * inv_var_arr * d_y_arr.col (n));
578+ }
579+ if (d_scale && d_bias) {
580+ d_bias_arr = dy_sum_arr;
581+ d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr;
582+ }
583+ for (int nc = 0 ; nc < N; ++nc) {
584+ d_x_arr.col (nc) = d_y_arr.col (nc) * scales_arr * scale_arr;
585+ }
586+ } else {
587+ int offset = 0 ;
588+ const int item_size = x->numel () / N;
589+ T *d_x_data = d_x->mutable_data <T>(ctx.GetPlace ());
590+ T *d_scale_data = d_scale->mutable_data <T>(ctx.GetPlace ());
591+ T *d_bias_data = d_bias->mutable_data <T>(ctx.GetPlace ());
592+ const T *dy_data = d_y->data <T>();
593+ const T *scales_data = scales->data <T>();
594+ const T *scale_w_data = scale_w->data <T>();
595+ const T *x_data = x->data <T>();
596+ for (int i = 0 ; i < item_size; i++) {
597+ d_bias_data[i] = 0 ;
598+ d_scale_data[i] = 0 ;
599+ }
600+ for (int k = 0 ; k < N; ++k) {
601+ for (int i = 0 ; i < item_size; i += slot_dim) {
602+ if (!(x_data[offset + i] > -min_precision &&
603+ x_data[offset + i] < min_precision)) {
604+ // show != 0
605+ for (int j = i; j < i + slot_dim; ++j) {
606+ d_x_data[offset + j] = dy_data[offset + j] *
607+ scales_data[j] * scale_w_data[j];
608+ d_bias_data[j] += dy_data[offset + j];
609+ d_scale_data[j] += (x_data[offset + j] - mean_data[j]) *
610+ inv_var_data[j] * dy_data[offset + j];
611+ }
612+ }
613+ }
614+ offset += item_size;
615+ }
616+ }
397617 }
398618 }
399619
@@ -466,6 +686,8 @@ class DataNormGradMaker : public framework::SingleGradOpMaker<T> {
466686 op->SetInput (" X" , this ->Input (" X" ));
467687 op->SetInput (framework::GradVarName (" Y" ), this ->OutputGrad (" Y" ));
468688
689+ op->SetInput (" scale_w" , this ->Input (" scale_w" ));
690+ op->SetInput (" bias" , this ->Input (" bias" ));
469691 op->SetOutput (" BatchSize" , this ->Input (" BatchSize" ));
470692 op->SetOutput (" BatchSum" , this ->Input (" BatchSum" ));
471693 op->SetOutput (" BatchSquareSum" , this ->Input (" BatchSquareSum" ));
@@ -481,6 +703,9 @@ class DataNormGradMaker : public framework::SingleGradOpMaker<T> {
481703 this ->InputGrad (" BatchSum" ));
482704 op->SetOutput (framework::GradVarName (" BatchSquareSum" ),
483705 this ->InputGrad (" BatchSquareSum" ));
706+ op->SetOutput (framework::GradVarName (" scale_w" ),
707+ this ->InputGrad (" scale_w" ));
708+ op->SetOutput (framework::GradVarName (" bias" ), this ->InputGrad (" bias" ));
484709 }
485710};
486711
0 commit comments