Skip to content

Commit bcbb889

Browse files
authored
onednn/fc_kernel.cc add use_onednn [fluid_ops] (#74325)
* Fix * Fix
1 parent 39403d1 commit bcbb889

File tree

1 file changed

+25
-16
lines changed

1 file changed

+25
-16
lines changed

paddle/phi/kernels/fusion/onednn/fc_kernel.cc

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,10 @@ void FCKernel(const Context& dev_ctx,
574574
dev_ctx.HasDnnAttr("use_mkldnn")
575575
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("use_mkldnn"))
576576
: false;
577+
const bool use_onednn =
578+
(!use_mkldnn && dev_ctx.HasDnnAttr("use_onednn"))
579+
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("use_onednn"))
580+
: use_mkldnn;
577581
const bool use_quantizer =
578582
dev_ctx.HasDnnAttr("use_quantizer")
579583
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("use_quantizer"))
@@ -583,6 +587,11 @@ void FCKernel(const Context& dev_ctx,
583587
? PADDLE_GET_CONST(std::string,
584588
dev_ctx.GetDnnAttr("mkldnn_data_type"))
585589
: "float32";
590+
const std::string onednn_data_type =
591+
(use_onednn && dev_ctx.HasDnnAttr("onednn_data_type"))
592+
? PADDLE_GET_CONST(std::string,
593+
dev_ctx.GetDnnAttr("onednn_data_type"))
594+
: mkldnn_data_type;
586595
const float scale_in =
587596
dev_ctx.HasDnnAttr("Scale_in")
588597
? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_in"))
@@ -601,24 +610,24 @@ void FCKernel(const Context& dev_ctx,
601610
dev_ctx.HasDnnAttr("force_fp32_output")
602611
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
603612
: false;
604-
std::vector<std::string> mkldnn_data_type_list = {
613+
std::vector<std::string> onednn_data_type_list = {
605614
"float32", "int8", "bfloat16"};
606-
PADDLE_ENFORCE_EQ(std::find(mkldnn_data_type_list.begin(),
607-
mkldnn_data_type_list.end(),
608-
mkldnn_data_type) != mkldnn_data_type_list.end(),
615+
PADDLE_ENFORCE_EQ(std::find(onednn_data_type_list.begin(),
616+
onednn_data_type_list.end(),
617+
onednn_data_type) != onednn_data_type_list.end(),
609618
true,
610619
common::errors::InvalidArgument(
611-
"The mkldnn_data_type should be [float32, "
620+
"The onednn_data_type should be [float32, "
612621
"int8, bfloat16], but found %s.",
613-
mkldnn_data_type.c_str()));
622+
onednn_data_type.c_str()));
614623
auto in_dims = input.dims();
615-
if (use_mkldnn) {
624+
if (use_onednn) {
616625
PADDLE_ENFORCE_EQ(
617626
in_dims.size() >= 2 && in_dims.size() <= 4,
618627
true,
619628
common::errors::Unimplemented(
620629
"The Input of fc is expected to be a 2-D, 3-D or 4-D tensor when "
621-
"use_mkldnn is set. But received the number of Input's "
630+
"use_onednn is set. But received the number of Input's "
622631
"dimensions is %d, Input's shape is %s.",
623632
in_dims.size(),
624633
in_dims));
@@ -632,10 +641,10 @@ void FCKernel(const Context& dev_ctx,
632641
bias,
633642
in_num_col_dims,
634643
activation_type,
635-
use_mkldnn,
644+
use_onednn,
636645
padding_weights,
637646
use_quantizer,
638-
mkldnn_data_type,
647+
onednn_data_type,
639648
scale_in,
640649
scale_weights,
641650
scale_out,
@@ -649,10 +658,10 @@ void FCKernel(const Context& dev_ctx,
649658
bias,
650659
in_num_col_dims,
651660
activation_type,
652-
use_mkldnn,
661+
use_onednn,
653662
padding_weights,
654663
use_quantizer,
655-
mkldnn_data_type,
664+
onednn_data_type,
656665
scale_in,
657666
scale_weights,
658667
scale_out,
@@ -665,10 +674,10 @@ void FCKernel(const Context& dev_ctx,
665674
bias,
666675
in_num_col_dims,
667676
activation_type,
668-
use_mkldnn,
677+
use_onednn,
669678
padding_weights,
670679
use_quantizer,
671-
mkldnn_data_type,
680+
onednn_data_type,
672681
scale_in,
673682
scale_weights,
674683
scale_out,
@@ -682,10 +691,10 @@ void FCKernel(const Context& dev_ctx,
682691
bias,
683692
in_num_col_dims,
684693
activation_type,
685-
use_mkldnn,
694+
use_onednn,
686695
padding_weights,
687696
use_quantizer,
688-
mkldnn_data_type,
697+
onednn_data_type,
689698
scale_in,
690699
scale_weights,
691700
scale_out,

0 commit comments

Comments
 (0)