Skip to content

Commit f29da15

Browse files
author
Sylwester Fraczek
authored
[bugfix] to concat input squash (#39593)
* fix and add more tests
* remove unwanted changes
* check only concat and elementwise
* move check to a function
* add todo comment
* Revert "fix ptq fc attr name fuse_activation->activation_type"
  This reverts commit ffd0233.
1 parent 2d2f11d commit f29da15

File tree

4 files changed

+156
-15
lines changed

4 files changed

+156
-15
lines changed

paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,27 @@ bool CPUQuantizeSquashPass::IsDequantizeInputUint8(
132132
return false;
133133
}
134134

135+
bool CPUQuantizeSquashPass::IsDequantizeQuantizeIncompatible(
136+
Node* quant_op, Node* dequant_in, Node* next_op) const {
137+
bool is_concat_signed =
138+
quant_op->Op()->GetAttrIfExists<bool>("is_negative_input");
139+
bool is_input_unsigned = IsDequantizeInputUint8(dequant_in);
140+
/* TODO(sfraczek): remove elementwise from this condition when BinaryMKLDNN
141+
kernel supports two different input data types */
142+
bool is_next_op_concat_or_elementwise =
143+
next_op->Op()->Type() == "concat" ||
144+
next_op->Op()->Type().find("elementwise") == 0;
145+
if (is_next_op_concat_or_elementwise && is_concat_signed &&
146+
is_input_unsigned) {
147+
VLOG(4) << "Do not squash dequant-quant, because "
148+
<< "next_op is: " << next_op->Op()->Type()
149+
<< ", is_concat_signed: " << is_concat_signed
150+
<< ", is_input_unsigned: " << is_input_unsigned << ".";
151+
return true;
152+
}
153+
return false;
154+
}
155+
135156
void CPUQuantizeSquashPass::DequantQuantSquash(
136157
Graph* graph,
137158
std::unordered_map<const Node*, int>* nodes_keep_counter) const {
@@ -151,9 +172,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
151172
GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern);
152173
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern);
153174

154-
// Don't squash if e.g. just one concat input is unsigned
155-
if (IsDequantizeInputUint8(dequant_in) &&
156-
!quant_op->Op()->GetAttrIfExists<bool>("is_negative_input")) {
175+
if (IsDequantizeQuantizeIncompatible(quant_op, dequant_in, next_op)) {
157176
return;
158177
}
159178

paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ class CPUQuantizeSquashPass : public FusePassBase {
4848
*/
4949
bool IsDequantizeInputUint8(const Node* dequant_in) const;
5050

51+
/*
52+
* Don't squash unsigned dequantize with signed quantize.
53+
* This is important for concat and elementwise ops.
54+
* When inputs have different signs, concat will assume a signed type and
55+
* elementwise assumes the type of its first input.
56+
*/
57+
bool IsDequantizeQuantizeIncompatible(Node* quant_op, Node* dequant_in,
58+
Node* next_op) const;
59+
5160
/*
5261
* Squash dequantize-quantize ops pairs into requantize or nothing
5362
*/

paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc

Lines changed: 124 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h"
1615
#include <gtest/gtest.h>
16+
17+
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h"
1718
#include "paddle/fluid/framework/naive_executor.h"
1819
#include "paddle/fluid/platform/place.h"
1920

@@ -234,11 +235,70 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
234235
return prog;
235236
}
236237

238+
/* a->relu->b->Dequant->c(u8)->Quant->d-\
239+
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
240+
* i->relu->j->Dequant->k(u8)->Quant->l-/
241+
*/
242+
ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) {
243+
ProgramDesc prog;
244+
for (auto& v : variable_names) {
245+
prog.MutableBlock(0)->Var(v);
246+
}
247+
SetOp(&prog, "relu", "Relu1", {"a"}, {"b"}, true, {scale, scale_out});
248+
SetOp(&prog, "relu", "Relu2", {"e"}, {"f"}, true, {scale, scale_out});
249+
SetOp(&prog, "relu", "Relu3", {"i"}, {"j"}, true, {scale, scale_out});
250+
251+
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true,
252+
{scale, scale_out});
253+
SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true,
254+
{scale, scale_out});
255+
SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true,
256+
{scale, scale_out});
257+
258+
SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out},
259+
0.0f, "float32", false, 1, false); // is_negative_input = false
260+
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out},
261+
0.0f, "float32", false, 1, false); // is_negative_input = false
262+
SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out},
263+
0.0f, "float32", false, 1, false); // is_negative_input = false
264+
265+
SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
266+
return prog;
267+
}
268+
269+
/* a->relu->b->Dequant->c(u8)->Quant->d-\
270+
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
271+
* i->pool2d->j->Dequant->k(s8)->Quant->l-/
272+
*/
273+
ProgramDesc BuildU8U8S8ConcatProgramDesc(float scale_out, float scale) {
274+
ProgramDesc prog;
275+
for (auto& v : variable_names) {
276+
prog.MutableBlock(0)->Var(v);
277+
}
278+
SetOp(&prog, "relu", "Pool2d1", {"a"}, {"b"}, true, {scale, scale_out});
279+
SetOp(&prog, "relu", "Relu1", {"e"}, {"f"}, true, {scale, scale_out});
280+
SetOp(&prog, "pool2d", "Pool2d2", {"i"}, {"j"}, true, {scale, scale_out});
281+
282+
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true,
283+
{scale, scale_out});
284+
SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true,
285+
{scale, scale_out});
286+
SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true,
287+
{scale, scale_out});
288+
289+
SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out});
290+
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out});
291+
SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out});
292+
293+
SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
294+
return prog;
295+
}
296+
237297
/* a->pool2d->b->Dequant->c(s8)->Quant->d-\
238298
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
239299
* i->pool2d->j->Dequant->k(s8)->Quant->l-/
240300
*/
241-
ProgramDesc BuildConvS8U8S8ConcatProgramDesc(float scale_out, float scale) {
301+
ProgramDesc BuildS8U8S8ConcatProgramDesc(float scale_out, float scale) {
242302
ProgramDesc prog;
243303
for (auto& v : variable_names) {
244304
prog.MutableBlock(0)->Var(v);
@@ -255,8 +315,35 @@ ProgramDesc BuildConvS8U8S8ConcatProgramDesc(float scale_out, float scale) {
255315
{scale, scale_out});
256316

257317
SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out});
258-
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out},
259-
0.0, "float32", false, 1, false);
318+
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out});
319+
SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out});
320+
321+
SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
322+
return prog;
323+
}
324+
325+
/* a->pool2d->b->Dequant->c(s8)->Quant->d-\
326+
* e->pool2d->f->Dequant->g(s8)->Quant->h--Concat1->x
327+
* i->pool2d->j->Dequant->k(s8)->Quant->l-/
328+
*/
329+
ProgramDesc BuildS8S8S8ConcatProgramDesc(float scale_out, float scale) {
330+
ProgramDesc prog;
331+
for (auto& v : variable_names) {
332+
prog.MutableBlock(0)->Var(v);
333+
}
334+
SetOp(&prog, "pool2d", "Pool2d1", {"a"}, {"b"}, true, {scale, scale_out});
335+
SetOp(&prog, "pool2d", "Pool2d2", {"e"}, {"f"}, true, {scale, scale_out});
336+
SetOp(&prog, "pool2d", "Pool2d3", {"i"}, {"j"}, true, {scale, scale_out});
337+
338+
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true,
339+
{scale, scale_out});
340+
SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true,
341+
{scale, scale_out});
342+
SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true,
343+
{scale, scale_out});
344+
345+
SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out});
346+
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out});
260347
SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out});
261348

262349
SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
@@ -834,16 +921,46 @@ TEST(CpuQuantizeSquashPass, quant_bf16_conv2d) {
834921
remove_nodes);
835922
}
836923

837-
TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat) {
924+
TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat1) {
838925
// removed 2 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
839926
auto remove_nodes = 8;
840927
std::unordered_map<std::string, int> expected_operators = {{"concat", 1},
841928
{"quantize", 1},
842929
{"dequantize", 1},
843930
{"relu", 1},
844931
{"pool2d", 2}};
845-
CheckNodesTest(BuildConvS8U8S8ConcatProgramDesc(1.2f, 1.2f),
846-
expected_operators, remove_nodes);
932+
CheckNodesTest(BuildS8U8S8ConcatProgramDesc(1.2f, 1.2f), expected_operators,
933+
remove_nodes);
934+
}
935+
936+
TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat2) {
937+
// removed 1 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
938+
auto remove_nodes = 4;
939+
std::unordered_map<std::string, int> expected_operators = {{"concat", 1},
940+
{"quantize", 2},
941+
{"dequantize", 2},
942+
{"relu", 2},
943+
{"pool2d", 1}};
944+
CheckNodesTest(BuildU8U8S8ConcatProgramDesc(1.2f, 1.2f), expected_operators,
945+
remove_nodes);
946+
}
947+
948+
TEST(CpuQuantizeSquashPass, squash_all_s8_input_to_concat1) {
949+
// removed 3 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
950+
auto remove_nodes = 12;
951+
std::unordered_map<std::string, int> expected_operators = {
952+
{"concat", 1}, {"quantize", 0}, {"dequantize", 0}, {"pool2d", 3}};
953+
CheckNodesTest(BuildS8S8S8ConcatProgramDesc(1.2f, 1.2f), expected_operators,
954+
remove_nodes);
955+
}
956+
957+
TEST(CpuQuantizeSquashPass, squash_all_u8_input_to_concat2) {
958+
// removed 3 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
959+
auto remove_nodes = 12;
960+
std::unordered_map<std::string, int> expected_operators = {
961+
{"concat", 1}, {"quantize", 0}, {"dequantize", 0}, {"relu", 3}};
962+
CheckNodesTest(BuildU8U8U8ConcatProgramDesc(1.2f, 1.2f), expected_operators,
963+
remove_nodes);
847964
}
848965

849966
} // namespace ir

paddle/fluid/inference/api/mkldnn_quantizer.cc

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,15 +116,11 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
116116
// force unsigned type if already know it
117117
bool is_unsigned = false;
118118
bool compute_scale = true;
119-
if (op->Type() == "conv2d") {
119+
if (op->Type() == "conv2d" || op->Type() == "fc") {
120120
// output of conv2d or fc with fused relu/relu6 must be unsigned
121121
std::string fuse_activation =
122122
op->GetAttrIfExists<std::string>("fuse_activation");
123123
is_unsigned = (fuse_activation == "relu" || fuse_activation == "relu6");
124-
} else if (op->Type() == "fc") {
125-
std::string activation_type =
126-
op->GetAttrIfExists<std::string>("activation_type");
127-
is_unsigned = (activation_type == "relu" || activation_type == "relu6");
128124
} else if (op->Type() == "relu") {
129125
is_unsigned = true;
130126
} else if (op->Type() == "transpose2" || op->Type() == "reshape2" ||

0 commit comments

Comments
 (0)