1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -97,6 +97,7 @@ pass_library(adaptive_pool2d_convert_global_pass inference)
pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(layer_norm_fuse_pass inference)
pass_library(add_support_int8_pass inference)
pass_library(matmul_scale_fuse_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)
if(WITH_GPU OR WITH_ROCM)
43 changes: 43 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1699,6 +1699,49 @@ PDNode *patterns::MatmulV2::operator()() {
return matmul_v2_out;
}

PDNode *patterns::MatmulScale::operator()() {
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
auto matmul_in_x = pattern->NewNode(matmul_in_x_repr())
->AsInput()
->assert_is_op_input("matmul", "X");
auto matmul_in_y = pattern->NewNode(matmul_in_y_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto scale_op = pattern->NewNode(scale_op_repr())->assert_is_op("scale");
auto scale_in_x = pattern->NewNode(scale_in_x_repr())
->assert_is_op_output("matmul", "Out")
->assert_is_op_input("scale", "X");
auto scale_out = pattern->NewNode(scale_out_repr())
->AsOutput()
->assert_is_op_output("scale", "Out");
matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({scale_in_x});
scale_op->LinksFrom({scale_in_x}).LinksTo({scale_out});
return scale_out;
}

PDNode *patterns::MatmulV2Scale::operator()() {
auto matmul_v2_op =
pattern->NewNode(matmul_v2_op_repr())->assert_is_op("matmul_v2");
auto matmul_v2_in_x = pattern->NewNode(matmul_v2_in_x_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "X");
auto matmul_v2_in_y = pattern->NewNode(matmul_v2_in_y_repr())
->AsInput()
->assert_is_persistable_var()  // Y must be a persistable weight
->assert_is_op_input("matmul_v2", "Y");
auto scale_op = pattern->NewNode(scale_op_repr())->assert_is_op("scale");
auto scale_in_x = pattern->NewNode(scale_in_x_repr())
->assert_is_op_output("matmul_v2", "Out")
->assert_is_op_input("scale", "X");
auto scale_out = pattern->NewNode(scale_out_repr())
->AsOutput()
->assert_is_op_output("scale", "Out");
matmul_v2_op->LinksFrom({matmul_v2_in_x, matmul_v2_in_y})
.LinksTo({scale_in_x});
scale_op->LinksFrom({scale_in_x}).LinksTo({scale_out});
return scale_out;
}

PDNode *patterns::Squeeze2Matmul::operator()() {
auto squeeze2_in_x = pattern->NewNode(squeeze2_in_x_repr())
->assert_is_op_input("squeeze2", "X")
30 changes: 30 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1032,6 +1032,36 @@ struct MatmulV2 : public PatternBase {
PATTERN_DECL_NODE(matmul_v2_out);
};

// Matmul + scale
// Forward pass.
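// Pattern: Out = scale(matmul(X, Y)); MatmulScaleFusePass later folds the
// scale factor into matmul's alpha attribute.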
struct MatmulScale : public PatternBase {
MatmulScale(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_scale") {}

PDNode* operator()();
PATTERN_DECL_NODE(matmul_in_x);
PATTERN_DECL_NODE(matmul_in_y);
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(scale_in_x);
PATTERN_DECL_NODE(scale_op);
PATTERN_DECL_NODE(scale_out);
};

// Matmul_v2 + scale
// Forward pass.
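// Pattern: Out = scale(matmul_v2(X, Y)) with Y a persistable weight;
// matmul_v2 has no alpha attribute, so MatmulV2ScaleFusePass folds the scale
// factor into the Y weight data instead.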
struct MatmulV2Scale : public PatternBase {
MatmulV2Scale(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_v2_scale") {}

PDNode* operator()();
PATTERN_DECL_NODE(matmul_v2_in_x);
PATTERN_DECL_NODE(matmul_v2_in_y);
PATTERN_DECL_NODE(matmul_v2_op);
PATTERN_DECL_NODE(scale_in_x);
PATTERN_DECL_NODE(scale_op);
PATTERN_DECL_NODE(scale_out);
};

// Squeeze2 + Matmul
// Forward pass.
struct Squeeze2Matmul : public PatternBase {
258 changes: 258 additions & 0 deletions paddle/fluid/framework/ir/matmul_scale_fuse_pass.cc
@@ -0,0 +1,258 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/matmul_scale_fuse_pass.h"

#include <cmath>
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_proto_maker.h"

#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {

class Node;

MatmulScaleFusePass::MatmulScaleFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("transpose_X")
.IsType<bool>()
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End()
.AddAttr("alpha")
.IsType<float>()
.End();

AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("ScaleTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("bias_after_scale")
.IsType<bool>()
.End()
.AddAttr("scale")
.End()
.AddAttr("bias")
.IsNumEQ(0.0f)
.End();
}

MatmulV2ScaleFusePass::MatmulV2ScaleFusePass() {
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();

AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("ScaleTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("bias_after_scale")
.IsType<bool>()
.End()
.AddAttr("scale")
.End()
.AddAttr("bias")
.IsNumEQ(0.0f)
.End();
}

void MatmulScaleFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "matmul_scale_fuse";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
patterns::MatmulScale matmul_scale_pattern(gpd.mutable_pattern(), name_scope);
matmul_scale_pattern();

int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "matmul_scale_fuse pass";
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_in_x, scale_in_x, matmul_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, matmul_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, matmul_scale_pattern);

auto* scope = param_scope();
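// The fusion only folds a multiplicative factor into matmul, so bail out
// if the scale op carries a non-zero additive bias.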
float bias = BOOST_GET_CONST(float, scale_op->Op()->GetAttr("bias"));
if (std::abs(bias) > 1e-5) return;
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "matmul_scale_fuse_pass: op compat check failed.";
return;
}

float scale = BOOST_GET_CONST(float, scale_op->Op()->GetAttr("scale"));
float matmul_alpha =
BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha"));
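// A runtime ScaleTensor input overrides the "scale" attribute; it must
// resolve to a persistable weight so its value is known at fuse time.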
auto const& names = scale_op->Op()->InputNames();
bool has_scale_tensor =
std::find(names.begin(), names.end(), "ScaleTensor") != names.end();
if (has_scale_tensor && scale_op->Op()->Input("ScaleTensor").size() > 0) {
std::string scale_var_name = scale_op->Op()->Input("ScaleTensor").front();
auto* scale_var = scope->FindVar(scale_var_name);
// ScaleTensor must resolve to a persistable weight; skip the fusion otherwise.
if (scale_var == nullptr) return;
auto* scale_tensor = scale_var->GetMutable<LoDTensor>();
scale = *(scale_tensor->data<float>());
}

OpDesc* matmul_desc = matmul_op->Op();
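// matmul computes alpha * (X * Y), so the scale factor is folded into alpha
// and matmul is rewired to write directly to the scale op's output.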
matmul_desc->SetAttr("alpha", scale * matmul_alpha);
matmul_desc->SetOutput("Out", {scale_out->Name()});
if (!IsCompat(*matmul_desc)) {
LOG(WARNING) << "matmul_scale_fuse_pass: fused matmul op compat check failed.";
return;
}
IR_NODE_LINK_TO(matmul_op, scale_out);
GraphSafeRemoveNodes(graph, {scale_in_x, scale_op});
++found_count;
};

gpd(graph, handler);
AddStatis(found_count);
}

void MatmulV2ScaleFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "matmul_v2_scale_fuse";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
patterns::MatmulV2Scale matmul_v2_scale_pattern(gpd.mutable_pattern(),
name_scope);
matmul_v2_scale_pattern();

int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "matmul_v2_scale_fuse pass";
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
matmul_v2_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
matmul_v2_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op,
matmul_v2_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_in_x, scale_in_x, matmul_v2_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, matmul_v2_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, matmul_v2_scale_pattern);

auto* scope = param_scope();
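// As in the matmul variant, a non-zero additive bias cannot be folded away,
// so skip the fusion in that case.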
float bias = BOOST_GET_CONST(float, scale_op->Op()->GetAttr("bias"));
if (std::abs(bias) > 1e-5) return;
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "matmul_v2_scale_fuse_pass: op compat check failed.";
return;
}

float scale = BOOST_GET_CONST(float, scale_op->Op()->GetAttr("scale"));
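// Prefer a persistable ScaleTensor input over the "scale" attribute when
// one is provided.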
auto const& names = scale_op->Op()->InputNames();
bool has_scale_tensor =
std::find(names.begin(), names.end(), "ScaleTensor") != names.end();
if (has_scale_tensor && scale_op->Op()->Input("ScaleTensor").size() > 0) {
std::string scale_var_name = scale_op->Op()->Input("ScaleTensor").front();
auto* scale_var = scope->FindVar(scale_var_name);
// ScaleTensor must resolve to a persistable weight; skip the fusion otherwise.
if (scale_var == nullptr) return;
auto* scale_tensor = scale_var->GetMutable<LoDTensor>();
scale = *(scale_tensor->data<float>());
}

auto* matmul_y =
scope->FindVar(matmul_v2_in_y->Name())->GetMutable<LoDTensor>();
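// matmul_v2 has no alpha attribute, so fold the scale factor into the
// persistable Y weight in place.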
auto y_data = matmul_y->mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < matmul_y->numel(); ++i) {
y_data[i] *= scale;
}

OpDesc* matmul_v2_desc = matmul_v2_op->Op();
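// Rewire matmul_v2 to produce the scale op's output and drop the now
// redundant scale op and its intermediate tensor.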
matmul_v2_desc->SetOutput("Out", {scale_out->Name()});
if (!IsCompat(*matmul_v2_desc)) {
LOG(WARNING) << "matmul_v2_scale_fuse_pass: fused matmul_v2 op compat check failed.";
return;
}
IR_NODE_LINK_TO(matmul_v2_op, scale_out);
GraphSafeRemoveNodes(graph, {scale_in_x, scale_op});
++found_count;
};

gpd(graph, handler);
AddStatis(found_count);
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(matmul_scale_fuse_pass,
paddle::framework::ir::MatmulScaleFusePass);
REGISTER_PASS_CAPABILITY(matmul_scale_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("scale", 0));

REGISTER_PASS(matmul_v2_scale_fuse_pass,
paddle::framework::ir::MatmulV2ScaleFusePass);
REGISTER_PASS_CAPABILITY(matmul_v2_scale_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("scale", 0));
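// A minimal usage sketch (an assumption, not part of this change), following
// the PassRegistry pattern used by fuse-pass unit tests; a Scope has to be
// attached to the graph so param_scope() can find the weights:
//
//   auto pass =
//       paddle::framework::ir::PassRegistry::Instance().Get("matmul_scale_fuse_pass");
//   graph->SetNotOwned(paddle::framework::ir::kParamScopeAttr, &scope);
//   pass->Apply(graph.get());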