@@ -574,6 +574,10 @@ void FCKernel(const Context& dev_ctx,
574574 dev_ctx.HasDnnAttr (" use_mkldnn" )
575575 ? PADDLE_GET_CONST (bool , dev_ctx.GetDnnAttr (" use_mkldnn" ))
576576 : false ;
577+ const bool use_onednn =
578+ (!use_mkldnn && dev_ctx.HasDnnAttr (" use_onednn" ))
579+ ? PADDLE_GET_CONST (bool , dev_ctx.GetDnnAttr (" use_onednn" ))
580+ : use_mkldnn;
577581 const bool use_quantizer =
578582 dev_ctx.HasDnnAttr (" use_quantizer" )
579583 ? PADDLE_GET_CONST (bool , dev_ctx.GetDnnAttr (" use_quantizer" ))
@@ -583,6 +587,11 @@ void FCKernel(const Context& dev_ctx,
583587 ? PADDLE_GET_CONST (std::string,
584588 dev_ctx.GetDnnAttr (" mkldnn_data_type" ))
585589 : " float32" ;
590+ const std::string onednn_data_type =
591+ (use_onednn && dev_ctx.HasDnnAttr (" onednn_data_type" ))
592+ ? PADDLE_GET_CONST (std::string,
593+ dev_ctx.GetDnnAttr (" onednn_data_type" ))
594+ : mkldnn_data_type;
586595 const float scale_in =
587596 dev_ctx.HasDnnAttr (" Scale_in" )
588597 ? PADDLE_GET_CONST (float , dev_ctx.GetDnnAttr (" Scale_in" ))
@@ -601,24 +610,24 @@ void FCKernel(const Context& dev_ctx,
601610 dev_ctx.HasDnnAttr (" force_fp32_output" )
602611 ? PADDLE_GET_CONST (bool , dev_ctx.GetDnnAttr (" force_fp32_output" ))
603612 : false ;
604- std::vector<std::string> mkldnn_data_type_list = {
613+ std::vector<std::string> onednn_data_type_list = {
605614 " float32" , " int8" , " bfloat16" };
606- PADDLE_ENFORCE_EQ (std::find (mkldnn_data_type_list .begin (),
607- mkldnn_data_type_list .end (),
608- mkldnn_data_type ) != mkldnn_data_type_list .end (),
615+ PADDLE_ENFORCE_EQ (std::find (onednn_data_type_list .begin (),
616+ onednn_data_type_list .end (),
617+ onednn_data_type ) != onednn_data_type_list .end (),
609618 true ,
610619 common::errors::InvalidArgument (
611- " The mkldnn_data_type should be [float32, "
620+ " The onednn_data_type should be [float32, "
612621 " int8, bfloat16], but found %s." ,
613- mkldnn_data_type .c_str ()));
622+ onednn_data_type .c_str ()));
614623 auto in_dims = input.dims ();
615- if (use_mkldnn ) {
624+ if (use_onednn ) {
616625 PADDLE_ENFORCE_EQ (
617626 in_dims.size () >= 2 && in_dims.size () <= 4 ,
618627 true ,
619628 common::errors::Unimplemented (
620629 " The Input of fc is expected to be a 2-D, 3-D or 4-D tensor when "
621- " use_mkldnn is set. But received the number of Input's "
630+ " use_onednn is set. But received the number of Input's "
622631 " dimensions is %d, Input's shape is %s." ,
623632 in_dims.size (),
624633 in_dims));
@@ -632,10 +641,10 @@ void FCKernel(const Context& dev_ctx,
632641 bias,
633642 in_num_col_dims,
634643 activation_type,
635- use_mkldnn ,
644+ use_onednn ,
636645 padding_weights,
637646 use_quantizer,
638- mkldnn_data_type ,
647+ onednn_data_type ,
639648 scale_in,
640649 scale_weights,
641650 scale_out,
@@ -649,10 +658,10 @@ void FCKernel(const Context& dev_ctx,
649658 bias,
650659 in_num_col_dims,
651660 activation_type,
652- use_mkldnn ,
661+ use_onednn ,
653662 padding_weights,
654663 use_quantizer,
655- mkldnn_data_type ,
664+ onednn_data_type ,
656665 scale_in,
657666 scale_weights,
658667 scale_out,
@@ -665,10 +674,10 @@ void FCKernel(const Context& dev_ctx,
665674 bias,
666675 in_num_col_dims,
667676 activation_type,
668- use_mkldnn ,
677+ use_onednn ,
669678 padding_weights,
670679 use_quantizer,
671- mkldnn_data_type ,
680+ onednn_data_type ,
672681 scale_in,
673682 scale_weights,
674683 scale_out,
@@ -682,10 +691,10 @@ void FCKernel(const Context& dev_ctx,
682691 bias,
683692 in_num_col_dims,
684693 activation_type,
685- use_mkldnn ,
694+ use_onednn ,
686695 padding_weights,
687696 use_quantizer,
688- mkldnn_data_type ,
697+ onednn_data_type ,
689698 scale_in,
690699 scale_weights,
691700 scale_out,
0 commit comments