Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
"Weight scale should be nonzero, but get zero."));
weight_scale[i] = weight_scale[i] / range;
}
} else {
} else if (dequant_type == "fake_quantize_dequantize_abs_max") {
// Implement quantize_dequantize_abs_max quantization algorithm
float abs_max_weight = 0.;
for (int j = 0; j < weight_tensor->numel(); j++) {
Expand All @@ -192,6 +192,9 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero"));
weight_scale.push_back(abs_max_weight / range);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantize_dequantize op type: %s", dequant_type));
}

nodes2rm.insert(quant_dequant_op_outscale);
Expand Down
51 changes: 37 additions & 14 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,7 @@ PDNode *patterns::Matmul::operator()() {
->assert_is_op_input("matmul", "X");
auto matmul_in_y = pattern->NewNode(matmul_in_y_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("matmul", "Y");
auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput()
Expand All @@ -1615,23 +1616,45 @@ PDNode *patterns::Matmul::operator()() {
return matmul_out;
}

// MatmulV2: tensor * weight
PDNode *patterns::MatmulV2Weight::operator()() {
auto matmul_v2_op =
pattern->NewNode(matmul_v2_op_repr())->assert_is_op("matmul_v2");

auto matmul_v2_in_x = pattern->NewNode(matmul_v2_in_x_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "X");
auto matmul_v2_in_y = pattern->NewNode(matmul_v2_in_y_repr())
->AsInput()
->assert_is_persistable_var() // Y is weight
->assert_is_op_input("matmul_v2", "Y");
auto matmul_v2_out = pattern->NewNode(matmul_v2_out_repr())
->AsOutput()
->assert_is_op_output("matmul_v2", "Out");

matmul_v2_op->LinksFrom({matmul_v2_in_x, matmul_v2_in_y})
.LinksTo({matmul_v2_out});
return matmul_v2_out;
}

// MatmulV2: tensor * tensor or tensor * weight
PDNode *patterns::MatmulV2::operator()() {
auto matmul_op =
pattern->NewNode(matmul_op_repr())->assert_is_op("matmul_v2");
auto matmul_v2_op =
pattern->NewNode(matmul_v2_op_repr())->assert_is_op("matmul_v2");

auto matmul_in_x = pattern->NewNode(matmul_in_x_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "X");
auto matmul_in_y = pattern->NewNode(matmul_in_y_repr())
->assert_is_persistable_var()
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput()
->assert_is_op_output("matmul_v2", "Out");
auto matmul_v2_in_x = pattern->NewNode(matmul_v2_in_x_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "X");
auto matmul_v2_in_y = pattern->NewNode(matmul_v2_in_y_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto matmul_v2_out = pattern->NewNode(matmul_v2_out_repr())
->AsOutput()
->assert_is_op_output("matmul_v2", "Out");

matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out});
return matmul_out;
matmul_v2_op->LinksFrom({matmul_v2_in_x, matmul_v2_in_y})
.LinksTo({matmul_v2_out});
return matmul_v2_out;
}

PDNode *patterns::Squeeze2Matmul::operator()() {
Expand Down
23 changes: 17 additions & 6 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -976,17 +976,28 @@ struct Matmul : public PatternBase {
PATTERN_DECL_NODE(matmul_out);
};

// Matmul_v2 op
// Forward pass for matmul_v2.
// MatmulV2: tensor * weight
struct MatmulV2Weight : public PatternBase {
MatmulV2Weight(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_v2_weight") {}

PDNode* operator()();
PATTERN_DECL_NODE(matmul_v2_in_x);
PATTERN_DECL_NODE(matmul_v2_in_y);
PATTERN_DECL_NODE(matmul_v2_op);
PATTERN_DECL_NODE(matmul_v2_out);
};

// MatmulV2: tensor * tensor or tensor * weight
struct MatmulV2 : public PatternBase {
MatmulV2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_v2") {}

PDNode* operator()();
PATTERN_DECL_NODE(matmul_in_x);
PATTERN_DECL_NODE(matmul_in_y);
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out);
PATTERN_DECL_NODE(matmul_v2_in_x);
PATTERN_DECL_NODE(matmul_v2_in_y);
PATTERN_DECL_NODE(matmul_v2_op);
PATTERN_DECL_NODE(matmul_v2_out);
};

// Squeeze2 + Matmul
Expand Down
Loading