1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15- #include " paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h"
1615#include < gtest/gtest.h>
16+
17+ #include " paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h"
1718#include " paddle/fluid/framework/naive_executor.h"
1819#include " paddle/fluid/platform/place.h"
1920
@@ -234,11 +235,70 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
234235 return prog;
235236}
236237
238+ /* a->relu->b->Dequant->c(u8)->Quant->d-\
239+ * e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
240+ * i->relu->j->Dequant->k(u8)->Quant->l-/
241+ */
242+ ProgramDesc BuildU8U8U8ConcatProgramDesc (float scale_out, float scale) {
243+ ProgramDesc prog;
244+ for (auto & v : variable_names) {
245+ prog.MutableBlock (0 )->Var (v);
246+ }
247+ SetOp (&prog, " relu" , " Relu1" , {" a" }, {" b" }, true , {scale, scale_out});
248+ SetOp (&prog, " relu" , " Relu2" , {" e" }, {" f" }, true , {scale, scale_out});
249+ SetOp (&prog, " relu" , " Relu3" , {" i" }, {" j" }, true , {scale, scale_out});
250+
251+ SetOp (&prog, " dequantize" , " Dequant1" , {" b" }, {" c" }, true ,
252+ {scale, scale_out});
253+ SetOp (&prog, " dequantize" , " Dequant2" , {" f" }, {" g" }, true ,
254+ {scale, scale_out});
255+ SetOp (&prog, " dequantize" , " Dequant3" , {" j" }, {" k" }, true ,
256+ {scale, scale_out});
257+
258+ SetOp (&prog, " quantize" , " Quant1" , {" c" }, {" d" }, true , {scale, scale_out},
259+ 0 .0f , " float32" , false , 1 , false ); // is_negative_input = false
260+ SetOp (&prog, " quantize" , " Quant2" , {" g" }, {" h" }, true , {scale, scale_out},
261+ 0 .0f , " float32" , false , 1 , false ); // is_negative_input = false
262+ SetOp (&prog, " quantize" , " Quant3" , {" k" }, {" l" }, true , {scale, scale_out},
263+ 0 .0f , " float32" , false , 1 , false ); // is_negative_input = false
264+
265+ SetOp (&prog, " concat" , " Concat1" , {" d" , " h" , " l" }, {" x" }, true );
266+ return prog;
267+ }
268+
269+ /* a->relu->b->Dequant->c(u8)->Quant->d-\
270+ * e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
271+ * i->pool2d->j->Dequant->k(s8)->Quant->l-/
272+ */
273+ ProgramDesc BuildU8U8S8ConcatProgramDesc (float scale_out, float scale) {
274+ ProgramDesc prog;
275+ for (auto & v : variable_names) {
276+ prog.MutableBlock (0 )->Var (v);
277+ }
278+ SetOp (&prog, " relu" , " Pool2d1" , {" a" }, {" b" }, true , {scale, scale_out});
279+ SetOp (&prog, " relu" , " Relu1" , {" e" }, {" f" }, true , {scale, scale_out});
280+ SetOp (&prog, " pool2d" , " Pool2d2" , {" i" }, {" j" }, true , {scale, scale_out});
281+
282+ SetOp (&prog, " dequantize" , " Dequant1" , {" b" }, {" c" }, true ,
283+ {scale, scale_out});
284+ SetOp (&prog, " dequantize" , " Dequant2" , {" f" }, {" g" }, true ,
285+ {scale, scale_out});
286+ SetOp (&prog, " dequantize" , " Dequant3" , {" j" }, {" k" }, true ,
287+ {scale, scale_out});
288+
289+ SetOp (&prog, " quantize" , " Quant1" , {" c" }, {" d" }, true , {scale, scale_out});
290+ SetOp (&prog, " quantize" , " Quant2" , {" g" }, {" h" }, true , {scale, scale_out});
291+ SetOp (&prog, " quantize" , " Quant3" , {" k" }, {" l" }, true , {scale, scale_out});
292+
293+ SetOp (&prog, " concat" , " Concat1" , {" d" , " h" , " l" }, {" x" }, true );
294+ return prog;
295+ }
296+
237297/* a->pool2d->b->Dequant->c(s8)->Quant->d-\
238298 * e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
239299 * i->pool2d->j->Dequant->k(s8)->Quant->l-/
240300 */
241- ProgramDesc BuildConvS8U8S8ConcatProgramDesc (float scale_out, float scale) {
301+ ProgramDesc BuildS8U8S8ConcatProgramDesc (float scale_out, float scale) {
242302 ProgramDesc prog;
243303 for (auto & v : variable_names) {
244304 prog.MutableBlock (0 )->Var (v);
@@ -255,8 +315,35 @@ ProgramDesc BuildConvS8U8S8ConcatProgramDesc(float scale_out, float scale) {
255315 {scale, scale_out});
256316
257317 SetOp (&prog, " quantize" , " Quant1" , {" c" }, {" d" }, true , {scale, scale_out});
258- SetOp (&prog, " quantize" , " Quant2" , {" g" }, {" h" }, true , {scale, scale_out},
259- 0.0 , " float32" , false , 1 , false );
318+ SetOp (&prog, " quantize" , " Quant2" , {" g" }, {" h" }, true , {scale, scale_out});
319+ SetOp (&prog, " quantize" , " Quant3" , {" k" }, {" l" }, true , {scale, scale_out});
320+
321+ SetOp (&prog, " concat" , " Concat1" , {" d" , " h" , " l" }, {" x" }, true );
322+ return prog;
323+ }
324+
325+ /* a->pool2d->b->Dequant->c(s8)->Quant->d-\
326+ * e->pool2d->f->Dequant->g(s8)->Quant->h--Concat1->x
327+ * i->pool2d->j->Dequant->k(s8)->Quant->l-/
328+ */
329+ ProgramDesc BuildS8S8S8ConcatProgramDesc (float scale_out, float scale) {
330+ ProgramDesc prog;
331+ for (auto & v : variable_names) {
332+ prog.MutableBlock (0 )->Var (v);
333+ }
334+ SetOp (&prog, " pool2d" , " Pool2d1" , {" a" }, {" b" }, true , {scale, scale_out});
335+ SetOp (&prog, " pool2d" , " Pool2d2" , {" e" }, {" f" }, true , {scale, scale_out});
336+ SetOp (&prog, " pool2d" , " Pool2d3" , {" i" }, {" j" }, true , {scale, scale_out});
337+
338+ SetOp (&prog, " dequantize" , " Dequant1" , {" b" }, {" c" }, true ,
339+ {scale, scale_out});
340+ SetOp (&prog, " dequantize" , " Dequant2" , {" f" }, {" g" }, true ,
341+ {scale, scale_out});
342+ SetOp (&prog, " dequantize" , " Dequant3" , {" j" }, {" k" }, true ,
343+ {scale, scale_out});
344+
345+ SetOp (&prog, " quantize" , " Quant1" , {" c" }, {" d" }, true , {scale, scale_out});
346+ SetOp (&prog, " quantize" , " Quant2" , {" g" }, {" h" }, true , {scale, scale_out});
260347 SetOp (&prog, " quantize" , " Quant3" , {" k" }, {" l" }, true , {scale, scale_out});
261348
262349 SetOp (&prog, " concat" , " Concat1" , {" d" , " h" , " l" }, {" x" }, true );
@@ -834,16 +921,46 @@ TEST(CpuQuantizeSquashPass, quant_bf16_conv2d) {
834921 remove_nodes);
835922}
836923
837- TEST (CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat ) {
924+ TEST (CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat1 ) {
838925 // removed 2 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
839926 auto remove_nodes = 8 ;
840927 std::unordered_map<std::string, int > expected_operators = {{" concat" , 1 },
841928 {" quantize" , 1 },
842929 {" dequantize" , 1 },
843930 {" relu" , 1 },
844931 {" pool2d" , 2 }};
845- CheckNodesTest (BuildConvS8U8S8ConcatProgramDesc (1 .2f , 1 .2f ),
846- expected_operators, remove_nodes);
932+ CheckNodesTest (BuildS8U8S8ConcatProgramDesc (1 .2f , 1 .2f ), expected_operators,
933+ remove_nodes);
934+ }
935+
936+ TEST (CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat2) {
937+ // removed 1 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
938+ auto remove_nodes = 4 ;
939+ std::unordered_map<std::string, int > expected_operators = {{" concat" , 1 },
940+ {" quantize" , 2 },
941+ {" dequantize" , 2 },
942+ {" relu" , 2 },
943+ {" pool2d" , 1 }};
944+ CheckNodesTest (BuildU8U8S8ConcatProgramDesc (1 .2f , 1 .2f ), expected_operators,
945+ remove_nodes);
946+ }
947+
948+ TEST (CpuQuantizeSquashPass, squash_all_s8_input_to_concat1) {
949+ // removed 3 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
950+ auto remove_nodes = 12 ;
951+ std::unordered_map<std::string, int > expected_operators = {
952+ {" concat" , 1 }, {" quantize" , 0 }, {" dequantize" , 0 }, {" pool2d" , 3 }};
953+ CheckNodesTest (BuildS8S8S8ConcatProgramDesc (1 .2f , 1 .2f ), expected_operators,
954+ remove_nodes);
955+ }
956+
957+ TEST (CpuQuantizeSquashPass, squash_all_u8_input_to_concat2) {
958+ // removed 3 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
959+ auto remove_nodes = 12 ;
960+ std::unordered_map<std::string, int > expected_operators = {
961+ {" concat" , 1 }, {" quantize" , 0 }, {" dequantize" , 0 }, {" relu" , 3 }};
962+ CheckNodesTest (BuildU8U8U8ConcatProgramDesc (1 .2f , 1 .2f ), expected_operators,
963+ remove_nodes);
847964}
848965
849966} // namespace ir
0 commit comments