Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,27 @@ bool CPUQuantizeSquashPass::IsDequantizeInputUint8(
return false;
}

// Decides whether a dequantize->quantize pair has to stay in the graph.
//
// Yields true only when all three conditions hold at once:
//   * next_op's type is "concat" or begins with "elementwise",
//   * quant_op's "is_negative_input" attribute reads as true,
//   * IsDequantizeInputUint8(dequant_in) reports true.
// When that happens the reason is emitted through VLOG(4).
bool CPUQuantizeSquashPass::IsDequantizeQuantizeIncompatible(
    Node* quant_op, Node* dequant_in, Node* next_op) const {
  // All three predicates are computed up front, in this order, regardless of
  // the others' results.
  const bool quant_is_signed =
      quant_op->Op()->GetAttrIfExists<bool>("is_negative_input");
  const bool dequant_in_is_u8 = IsDequantizeInputUint8(dequant_in);

  /* TODO(sfraczek): remove elementwise from this condition when BinaryMKLDNN
     kernel will support two different input data types */
  const auto& next_type = next_op->Op()->Type();
  const bool consumer_is_concat_or_elementwise =
      next_type == "concat" || next_type.find("elementwise") == 0;

  if (!consumer_is_concat_or_elementwise || !quant_is_signed ||
      !dequant_in_is_u8) {
    return false;
  }

  VLOG(4) << "Do not squash dequant-quant, because "
          << "next_op is: " << next_type
          << ", is_concat_signed: " << quant_is_signed
          << ", is_input_unsigned: " << dequant_in_is_u8 << ".";
  return true;
}

void CPUQuantizeSquashPass::DequantQuantSquash(
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const {
Expand All @@ -151,9 +172,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern);

// Don't squash if e.g. just one concat input is unsigned
if (IsDequantizeInputUint8(dequant_in) &&
!quant_op->Op()->GetAttrIfExists<bool>("is_negative_input")) {
if (IsDequantizeQuantizeIncompatible(quant_op, dequant_in, next_op)) {
return;
}

Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ class CPUQuantizeSquashPass : public FusePassBase {
*/
bool IsDequantizeInputUint8(const Node* dequant_in) const;

/*
* Tells whether a dequantize->quantize pair must NOT be squashed because it
* would join an unsigned dequantize input with a signed quantize.
* This is important for concat and elementwise ops.
* When inputs have different sign, concat will assume signed type and
* elementwise assumes first input type.
*
* Returns true only if all of the following hold:
*  - next_op's type is "concat" or starts with "elementwise",
*  - quant_op's "is_negative_input" attribute reads as true,
*  - IsDequantizeInputUint8(dequant_in) is true.
* Returns false otherwise.
*/
bool IsDequantizeQuantizeIncompatible(Node* quant_op, Node* dequant_in,
Node* next_op) const;

/*
* Squash dequantize-quantize ops pairs into requantize or nothing
*/
Expand Down
131 changes: 124 additions & 7 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h"
#include <gtest/gtest.h>

#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/platform/place.h"

Expand Down Expand Up @@ -234,11 +235,70 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
return prog;
}

/* a->relu->b->Dequant->c(u8)->Quant->d-\
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
* i->relu->j->Dequant->k(u8)->Quant->l-/
*/
ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) {
  // Variable names along each of the three branches that feed the concat:
  // {relu input, relu output, dequantize output, quantize output}.
  const char* const branch_vars[3][4] = {
      {"a", "b", "c", "d"}, {"e", "f", "g", "h"}, {"i", "j", "k", "l"}};
  const char* const relu_names[3] = {"Relu1", "Relu2", "Relu3"};
  const char* const dequant_names[3] = {"Dequant1", "Dequant2", "Dequant3"};
  const char* const quant_names[3] = {"Quant1", "Quant2", "Quant3"};

  ProgramDesc prog;
  for (const auto& var_name : variable_names) {
    prog.MutableBlock(0)->Var(var_name);
  }

  // SetOp calls are issued stage by stage: every relu first, then every
  // dequantize, then every quantize, and finally the concat.
  for (int branch = 0; branch < 3; ++branch) {
    SetOp(&prog, "relu", relu_names[branch], {branch_vars[branch][0]},
          {branch_vars[branch][1]}, true, {scale, scale_out});
  }

  for (int branch = 0; branch < 3; ++branch) {
    SetOp(&prog, "dequantize", dequant_names[branch],
          {branch_vars[branch][1]}, {branch_vars[branch][2]}, true,
          {scale, scale_out});
  }

  for (int branch = 0; branch < 3; ++branch) {
    // The trailing `false` is is_negative_input = false.
    SetOp(&prog, "quantize", quant_names[branch], {branch_vars[branch][2]},
          {branch_vars[branch][3]}, true, {scale, scale_out}, 0.0f, "float32",
          false, 1, false);
  }

  SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
  return prog;
}

/* a->relu->b->Dequant->c(u8)->Quant->d-\
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
* i->pool2d->j->Dequant->k(s8)->Quant->l-/
*/
// Builds a program with three relu/pool2d -> dequantize -> quantize branches
// that all feed a single concat: branches "a" and "e" start with relu,
// branch "i" starts with pool2d. Every variable in variable_names is declared
// in block 0 before the ops are added.
//
// Both scale arguments are forwarded to every non-concat op as
// {scale, scale_out}.
ProgramDesc BuildU8U8S8ConcatProgramDesc(float scale_out, float scale) {
  ProgramDesc prog;
  for (auto& v : variable_names) {
    prog.MutableBlock(0)->Var(v);
  }
  // Op names follow the op types ("Relu*" for relu, "Pool2d*" for pool2d) so
  // that a name never suggests a different operator than the one created.
  SetOp(&prog, "relu", "Relu1", {"a"}, {"b"}, true, {scale, scale_out});
  SetOp(&prog, "relu", "Relu2", {"e"}, {"f"}, true, {scale, scale_out});
  SetOp(&prog, "pool2d", "Pool2d1", {"i"}, {"j"}, true, {scale, scale_out});

  SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true,
        {scale, scale_out});
  SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true,
        {scale, scale_out});
  SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true,
        {scale, scale_out});

  // Quantize ops rely on SetOp's defaults for the remaining arguments.
  SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out});
  SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out});
  SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out});

  SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
  return prog;
}

/* a->pool2d->b->Dequant->c(s8)->Quant->d-\
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
* i->pool2d->j->Dequant->k(s8)->Quant->l-/
*/
ProgramDesc BuildConvS8U8S8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc BuildS8U8S8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
Expand All @@ -255,8 +315,35 @@ ProgramDesc BuildConvS8U8S8ConcatProgramDesc(float scale_out, float scale) {
{scale, scale_out});

SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out},
0.0, "float32", false, 1, false);
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out});

SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
return prog;
}

/* a->pool2d->b->Dequant->c(s8)->Quant->d-\
* e->pool2d->f->Dequant->g(s8)->Quant->h--Concat1->x
* i->pool2d->j->Dequant->k(s8)->Quant->l-/
*/
ProgramDesc BuildS8S8S8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "pool2d", "Pool2d1", {"a"}, {"b"}, true, {scale, scale_out});
SetOp(&prog, "pool2d", "Pool2d2", {"e"}, {"f"}, true, {scale, scale_out});
SetOp(&prog, "pool2d", "Pool2d3", {"i"}, {"j"}, true, {scale, scale_out});

SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true,
{scale, scale_out});
SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true,
{scale, scale_out});
SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true,
{scale, scale_out});

SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out});

SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
Expand Down Expand Up @@ -834,16 +921,46 @@ TEST(CpuQuantizeSquashPass, quant_bf16_conv2d) {
remove_nodes);
}

TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat) {
TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat1) {
// removed 2 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
auto remove_nodes = 8;
std::unordered_map<std::string, int> expected_operators = {{"concat", 1},
{"quantize", 1},
{"dequantize", 1},
{"relu", 1},
{"pool2d", 2}};
CheckNodesTest(BuildConvS8U8S8ConcatProgramDesc(1.2f, 1.2f),
expected_operators, remove_nodes);
CheckNodesTest(BuildS8U8S8ConcatProgramDesc(1.2f, 1.2f), expected_operators,
remove_nodes);
}

TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat2) {
  // A single dequantize/quantize pair is expected to disappear:
  // 1 x 4 nodes (dequantize_op, dequantize_out, quantize, quantize_out).
  int expected_removed = 4;

  // Operator counts that must remain in the graph after the pass.
  std::unordered_map<std::string, int> remaining_ops{
      {"concat", 1},
      {"quantize", 2},
      {"dequantize", 2},
      {"relu", 2},
      {"pool2d", 1},
  };

  CheckNodesTest(BuildU8U8S8ConcatProgramDesc(1.2f, 1.2f), remaining_ops,
                 expected_removed);
}

TEST(CpuQuantizeSquashPass, squash_all_s8_input_to_concat1) {
  // All three dequantize/quantize pairs are expected to disappear:
  // 3 x 4 nodes (dequantize_op, dequantize_out, quantize, quantize_out).
  int expected_removed = 12;

  // Operator counts that must remain in the graph after the pass.
  std::unordered_map<std::string, int> remaining_ops{
      {"concat", 1},
      {"quantize", 0},
      {"dequantize", 0},
      {"pool2d", 3},
  };

  CheckNodesTest(BuildS8S8S8ConcatProgramDesc(1.2f, 1.2f), remaining_ops,
                 expected_removed);
}

TEST(CpuQuantizeSquashPass, squash_all_u8_input_to_concat2) {
  // All three dequantize/quantize pairs are expected to disappear:
  // 3 x 4 nodes (dequantize_op, dequantize_out, quantize, quantize_out).
  int expected_removed = 12;

  // Operator counts that must remain in the graph after the pass.
  std::unordered_map<std::string, int> remaining_ops{
      {"concat", 1},
      {"quantize", 0},
      {"dequantize", 0},
      {"relu", 3},
  };

  CheckNodesTest(BuildU8U8U8ConcatProgramDesc(1.2f, 1.2f), remaining_ops,
                 expected_removed);
}

} // namespace ir
Expand Down
6 changes: 1 addition & 5 deletions paddle/fluid/inference/api/mkldnn_quantizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,11 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
// force unsigned type if already know it
bool is_unsigned = false;
bool compute_scale = true;
if (op->Type() == "conv2d") {
if (op->Type() == "conv2d" || op->Type() == "fc") {
// output of conv2d with relu must be unsigned
std::string fuse_activation =
op->GetAttrIfExists<std::string>("fuse_activation");
is_unsigned = (fuse_activation == "relu" || fuse_activation == "relu6");
} else if (op->Type() == "fc") {
std::string activation_type =
op->GetAttrIfExists<std::string>("activation_type");
is_unsigned = (activation_type == "relu" || activation_type == "relu6");
} else if (op->Type() == "relu") {
is_unsigned = true;
} else if (op->Type() == "transpose2" || op->Type() == "reshape2" ||
Expand Down