@@ -78,7 +78,8 @@ class ConvMKLDNNHandlerT
                mkldnn::convolution_backward_weights>(
             dev_ctx, mkldnn_engine, cpu_place,
             platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
-                                unique_name)) {
+                                unique_name)),
+        is_test_(ctx.Attr<bool>("is_test")) {
     if (!this->isCached()) {
       PADDLE_ENFORCE_EQ(
           input->layout(), framework::DataLayout::kMKLDNN,
@@ -159,7 +160,6 @@ class ConvMKLDNNHandlerT
           framework::slice_ddim(filter_dims, 2, filter_dims.size());
 
       const auto ksize = framework::vectorize(filter_data_dims);
-      const bool is_test = ctx.Attr<bool>("is_test");
 
       auto strides_temp = ctx.Attr<std::vector<int>>("strides");
       std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
@@ -214,9 +214,8 @@ class ConvMKLDNNHandlerT
 
       const auto dst_md = platform::MKLDNNMemDesc(
          dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
-      const auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
-                                         : mkldnn::prop_kind::forward_training;
-
+      const auto fwd_prop_kind = is_test_ ? mkldnn::prop_kind::forward_inference
+                                          : mkldnn::prop_kind::forward_training;
       float sum_scale = 1.0f;
       std::vector<float> output_shift_scale;
       if (platform::is_int8<T>())
@@ -261,7 +260,8 @@ class ConvMKLDNNHandlerT
                mkldnn::convolution_backward_weights>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
             platform::CreateKey(dev_ctx, framework::vectorize(in->dims()),
-                                unique_name)) {
+                                unique_name)),
+        is_test_(false) {
     if (!this->isBwdCached()) {
       PADDLE_ENFORCE_EQ(
           in->layout(), framework::DataLayout::kMKLDNN,
@@ -291,7 +291,7 @@ class ConvMKLDNNHandlerT
               "Wrong format set for output_grad tensor"));
 
       PADDLE_ENFORCE_EQ(
-          ctx.Attr<bool>("is_test"), false,
+          is_test_, false,
           platform::errors::InvalidArgument(
               "is_test attribute should be set to False in training phase."));
 
@@ -557,26 +557,26 @@ class ConvMKLDNNHandlerT
           framework::vectorize(in_mem->dims()),
           platform::MKLDNNGetDataType<T>(), in_mem->format());
       return this->AcquireMemoryWithReorder(
-          user_mem_md, mem_md, platform::to_void_cast<T>(in_mem_data), key_mem);
+          user_mem_md, mem_md, platform::to_void_cast<T>(in_mem_data), key_mem,
+          is_test_);
     } else {
       const std::string target_key_suffix{key_mem_target};
       const auto target_mem_p = this->AcquireMemory(target_key_suffix);
       user_mem_p->set_data_handle(platform::to_void_cast<T>(in_mem_data));
       if (user_mem_p != target_mem_p) {
-        this->AcquireReorder(user_mem_p, target_mem_p, key_mem);
+        this->AcquireReorder(user_mem_p, target_mem_p);
       }
       return target_mem_p;
     }
   }
 
   std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
       const framework::Tensor* filter, const int groups, const bool is_conv3d,
-      const bool is_test, const std::vector<float>& scale_data = {1.0f},
-      int mask = 0) {
+      const std::vector<float>& scale_data = {1.0f}, int mask = 0) {
     // This is workaround to make execution faster, delete
     // if statement after including md inside Tensor
     auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
-    if (is_test && weights_mem_p) {
+    if (is_test_ && weights_mem_p) {
       return weights_mem_p;
     } else {
       const K* filter_data = filter->data<K>();
@@ -589,16 +589,16 @@ class ConvMKLDNNHandlerT
 
       return this->AcquireMemoryWithReorder(
           user_src_md, this->fwd_pd_->weights_desc(),
-          platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test, {},
-          scale_data, mask);
+          platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test_,
+          {}, scale_data, mask);
     }
   }
 
   std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
-      const framework::Tensor* bias, const bool is_test,
+      const framework::Tensor* bias,
       const std::vector<float>& scale_data = {1.0f}, int mask = 0) {
     auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target");
-    if (is_test && bias_mem_p) {
+    if (is_test_ && bias_mem_p) {
       return bias_mem_p;
     } else {
       const K* bias_data = bias->data<K>();
@@ -608,7 +608,7 @@ class ConvMKLDNNHandlerT
 
       return this->AcquireMemoryWithReorder(
           user_bias_md, this->fwd_pd_->bias_desc(),
-          platform::to_void_cast<K>(bias_data), "@bias_mem_p", is_test, {},
+          platform::to_void_cast<K>(bias_data), "@bias_mem_p", is_test_, {},
           scale_data, mask);
     }
   }
@@ -641,7 +641,7 @@ class ConvMKLDNNHandlerT
             platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) {
       auto residual_memory_p = this->AcquireResidualMemory(residual_param);
       dst_memory_p = this->template AcquireDstMemory<T_out>(output);
-      this->AcquireReorder(residual_memory_p, dst_memory_p, "@residual_dst");
+      this->AcquireReorder(residual_memory_p, dst_memory_p);
     } else {
       // Changing ShareDataWith to TensorCopy results in performance drop
       // on ResNet architectures
@@ -651,6 +651,9 @@ class ConvMKLDNNHandlerT
     }
     return dst_memory_p;
   }
+
+ private:
+  const bool is_test_;
 };
 
 }  // anonymous namespace
@@ -695,7 +698,6 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
         ctx.template device_context<platform::MKLDNNDeviceContext>();
     const auto& mkldnn_engine = dev_ctx.GetEngine();
 
-    const bool is_test = ctx.Attr<bool>("is_test");
     const bool is_conv3d = ctx.Attr<std::vector<int>>("strides").size() == 3U;
     const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
 
@@ -712,7 +714,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
     auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
 
     auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
-        filter, ctx.Attr<int>("groups"), is_conv3d, is_test);
+        filter, ctx.Attr<int>("groups"), is_conv3d);
 
     std::shared_ptr<dnnl::memory> dst_memory_p;
     if (fuse_residual_conn) {
@@ -731,7 +733,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
         {MKLDNN_ARG_DST, *dst_memory_p}};
 
     if (bias) {
-      auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test);
+      auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias);
       args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
     }
 
@@ -783,11 +785,10 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
           ctx.Attr<std::vector<float>>("Scale_weights");
       const bool is_multi_channel = scale_weights_data.size() > 1;
       const int& groups = ctx.Attr<int>("groups");
-      const bool& is_test = ctx.Attr<bool>("is_test");
       int mask_reorder =
           is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0;
       auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
-          filter, groups, false, is_test, scale_weights_data, mask_reorder);
+          filter, groups, false, scale_weights_data, mask_reorder);
 
       std::shared_ptr<dnnl::memory> dst_memory_p;
       if (fuse_residual_conn) {
@@ -822,7 +823,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
             handler.get_int8_bias_scales(ctx);
 
         auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
-            bias, is_test, scale_bias_data, mask_reorder);
+            bias, scale_bias_data, mask_reorder);
         args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
       }
 
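
Note on the pattern the diff applies (not part of the commit itself): the is_test attribute is now read once in the handler's constructor and cached in a const member, is_test_, so the individual Acquire* methods no longer need an is_test parameter or a fresh ctx.Attr<bool>("is_test") lookup. A minimal, self-contained sketch of that refactor follows; Context and Handler are hypothetical stand-ins, not the real Paddle or oneDNN types.

// Simplified illustration: cache a constructor-time attribute as a const
// member instead of re-reading it at every call site.
#include <iostream>
#include <string>
#include <unordered_map>

struct Context {
  std::unordered_map<std::string, bool> bool_attrs;
  bool Attr(const std::string& name) const { return bool_attrs.at(name); }
};

class Handler {
 public:
  // Read the attribute exactly once, at construction time.
  explicit Handler(const Context& ctx) : is_test_(ctx.Attr("is_test")) {}

  // Call sites use the cached member, so they no longer take an
  // is_test parameter or touch the context at all.
  const char* PropKind() const {
    return is_test_ ? "forward_inference" : "forward_training";
  }

 private:
  const bool is_test_;  // immutable for the lifetime of the handler
};

int main() {
  Context ctx{{{"is_test", true}}};
  Handler handler(ctx);
  std::cout << handler.PropKind() << "\n";  // prints "forward_inference"
  return 0;
}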