10 changes: 5 additions & 5 deletions paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
@@ -366,17 +366,17 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
   float epsilon =
       PADDLE_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));
 
-  bool is_mkldnn = fuse_option == FUSE_MKLDNN;
+  bool is_onednn = fuse_option == FUSE_ONEDNN;
   auto input_names = conv->Op()->InputNames();
   bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
                       input_names.end() &&
                   !conv->Op()->Input("Bias").empty();
-  bool mkldnn_with_bias = is_mkldnn && has_bias;
+  bool onednn_with_bias = is_onednn && has_bias;
 
   // Create eltwise_y (conv bias) variable
   phi::DenseTensor* eltwise_y_in_tensor = nullptr;
   Node* eltwise_y_in_node = nullptr;
-  if (!mkldnn_with_bias) {
+  if (!onednn_with_bias) {
     VarDesc eltwise_y_in_desc(
         patterns::PDNodeName("fuse_conv_bn", conv_type() + "_eltwise_y_in"));
     eltwise_y_in_desc.SetShape(common::vectorize(bn_bias_tensor->dims()));
@@ -413,13 +413,13 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
 
   // with MKL-DNN fuse conv+bn into conv with bias
   // without MKL-DNN fuse conv+bn into conv+elementwise_add
-  if (is_mkldnn) {
+  if (is_onednn) {
     if (conv->Op()->Type() == "conv2d" ||
         conv->Op()->Type() == "depthwise_conv2d" ||
         conv->Op()->Type() == "conv2d_transpose") {
       ConvertToFusedOp(conv->Op());
     }
-    if (mkldnn_with_bias) {
+    if (onednn_with_bias) {
       // reuse existing conv bias node
       auto conv_bias_names = conv->Op()->Input("Bias");
       PADDLE_ENFORCE_EQ(
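For context on what this pass computes: batch norm applied to a conv output can be folded into the conv itself by rescaling the weights per output channel and synthesizing a bias, which is why the oneDNN path only needs a bias input while the native path appends an elementwise_add. A minimal standalone sketch of that folding arithmetic (illustrative only, not Paddle code; it assumes a conv with no pre-existing bias, which would otherwise enter the bias term as (b - mean) * scale + beta):

// Standalone sketch of conv+BN folding: BN(conv(x)) == conv'(x) where the
// conv weights W[c, ...] are scaled per output channel and a new bias is
// synthesized, so the pass only rewrites weights and wires up a bias input.
#include <cmath>
#include <vector>

struct FoldedBN {
  std::vector<float> weight_scale;  // per-output-channel factor for W[c, ...]
  std::vector<float> bias;          // replacement conv bias
};

FoldedBN FoldBatchNorm(const std::vector<float>& gamma,
                       const std::vector<float>& beta,
                       const std::vector<float>& mean,
                       const std::vector<float>& variance,
                       float epsilon) {
  FoldedBN out;
  out.weight_scale.resize(gamma.size());
  out.bias.resize(gamma.size());
  for (size_t c = 0; c < gamma.size(); ++c) {
    // BN: y = gamma * (x - mean) / sqrt(var + eps) + beta
    const float inv_std = 1.0f / std::sqrt(variance[c] + epsilon);
    out.weight_scale[c] = gamma[c] * inv_std;
    out.bias[c] = beta[c] - gamma[c] * mean[c] * inv_std;
  }
  return out;
}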
10 changes: 5 additions & 5 deletions paddle/fluid/framework/ir/fuse_pass_base.cc
@@ -58,13 +58,13 @@ void FusePassBase::AddStatis(int count_of_fused) const {
 FuseOptions FusePassBase::FindFuseOption(const Node& node1,
                                          const Node& node2) const {
 #ifdef PADDLE_WITH_DNNL
-  bool node1_mkldnn = node1.Op()->HasAttr("use_mkldnn") &&
+  bool node1_onednn = node1.Op()->HasAttr("use_mkldnn") &&
                       PADDLE_GET_CONST(bool, node1.Op()->GetAttr("use_mkldnn"));
-  bool node2_mkldnn = node2.Op()->HasAttr("use_mkldnn") &&
+  bool node2_onednn = node2.Op()->HasAttr("use_mkldnn") &&
                       PADDLE_GET_CONST(bool, node2.Op()->GetAttr("use_mkldnn"));
-  if (node1_mkldnn && node2_mkldnn)
-    return FUSE_MKLDNN;
-  else if (!node1_mkldnn && !node2_mkldnn)
+  if (node1_onednn && node2_onednn)
+    return FUSE_ONEDNN;
+  else if (!node1_onednn && !node2_onednn)
     return FUSE_NATIVE;
   else
     return DO_NOT_FUSE;
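The decision above is a three-way truth table over the two ops' use_mkldnn attributes (note the attribute string keeps its legacy name even after this rename): a pair is tagged FUSE_ONEDNN only when both ops request oneDNN, FUSE_NATIVE only when neither does, and a mixed pair is never fused. A standalone mirror of that logic (illustrative, not the Paddle API; enum names match the renamed values in fuse_pass_base.h below):

// Illustrative mirror of FindFuseOption's decision over two boolean flags.
enum FuseOptions { DO_NOT_FUSE, FUSE_NATIVE, FUSE_ONEDNN };

FuseOptions FindFuseOption(bool node1_onednn, bool node2_onednn) {
  if (node1_onednn && node2_onednn) return FUSE_ONEDNN;    // both on oneDNN
  if (!node1_onednn && !node2_onednn) return FUSE_NATIVE;  // both native
  return DO_NOT_FUSE;  // mixed backends: fusing would force one op to move
}

// e.g. FindFuseOption(true, true)  == FUSE_ONEDNN
//      FindFuseOption(true, false) == DO_NOT_FUSE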
4 changes: 2 additions & 2 deletions paddle/fluid/framework/ir/fuse_pass_base.h
@@ -47,8 +47,8 @@ static const char kScaleAndZeroPointParamAttr[] =
 
 enum FuseOptions {
   DO_NOT_FUSE,  // fusing will not be done
-  FUSE_NATIVE,  // fusing will be done without MKL-DNN
-  FUSE_MKLDNN   // fusing will be done with MKL-DNN
+  FUSE_NATIVE,  // fusing will be done without oneDNN
+  FUSE_ONEDNN   // fusing will be done with oneDNN
 };
 
 class FusePassBase : public OpCompatSensiblePass {
@@ -147,7 +147,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv(
   GET_IR_NODE_FROM_SUBGRAPH(
       elementwise_out, elementwise_out, elementwise_pattern);
 
-  if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return;
+  if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_ONEDNN) return;
   if (!IsReachable(g, residual_data, conv_output)) return;
   if (HasFusedActivation(conv_op)) return;
   if (HasFusedElementwiseAdd(conv_op)) return;
@@ -237,8 +237,8 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
     return;
   }
 
-  if (FindFuseOption(*conv_x_op, *elementwise_op) != FUSE_MKLDNN) return;
-  if (FindFuseOption(*conv_y_op, *elementwise_op) != FUSE_MKLDNN) return;
+  if (FindFuseOption(*conv_x_op, *elementwise_op) != FUSE_ONEDNN) return;
+  if (FindFuseOption(*conv_y_op, *elementwise_op) != FUSE_ONEDNN) return;
 
   Node* projection_node;
   Node* residual_conv_op;
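Alongside the backend check, FuseConv guards on graph reachability before rewiring the residual input into the conv. IsReachable here is a plain reachability query over the IR graph; a generic BFS sketch of such a check (an assumption about its shape for illustration, not Paddle's actual implementation):

// Generic reachability check: can `to` be reached from `from` by following
// graph edges? Passes use this kind of guard so a fusion only fires on the
// topology it targets and cannot introduce a cycle when rewiring inputs.
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using Graph = std::unordered_map<int, std::vector<int>>;  // node -> successors

bool IsReachable(const Graph& g, int from, int to) {
  std::queue<int> work;
  std::unordered_set<int> seen;
  work.push(from);
  while (!work.empty()) {
    int n = work.front();
    work.pop();
    if (n == to) return true;
    if (!seen.insert(n).second) continue;  // already visited
    auto it = g.find(n);
    if (it == g.end()) continue;           // no successors
    for (int succ : it->second) work.push(succ);
  }
  return false;
}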
@@ -57,7 +57,7 @@ void MatmulElementwiseAddMKLDNNFusePass::FuseMatmulElementwiseAdd(
   GET_IR_NODE_FROM_SUBGRAPH(
       elementwise_add_out, elementwise_add_out, matmul_pattern);
 
-  if (FindFuseOption(*matmul, *elementwise_add) != FUSE_MKLDNN) return;
+  if (FindFuseOption(*matmul, *elementwise_add) != FUSE_ONEDNN) return;
   if (!IsCompat(subgraph, g)) {
     LOG(WARNING)
         << "op compat for matmul_elementwise_add_onednn_fuse_pass failed.";