@@ -55,6 +55,10 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
5555 op->SetInput (" X" , {inputs[0 ]});
5656 op->SetOutput (" Out" , {outputs[0 ]});
5757 op->SetAttr (" mkldnn_data_type" , mkldnn_data_type);
58+ } else if (type == " slice" ) {
59+ op->SetInput (" Input" , {inputs[0 ]});
60+ op->SetOutput (" Out" , {outputs[0 ]});
61+ op->SetAttr (" mkldnn_data_type" , mkldnn_data_type);
5862 } else if (type == " dropout" ) {
5963 op->SetInput (" X" , {inputs[0 ]});
6064 op->SetOutput (" Out" , {outputs[0 ]});
@@ -784,6 +788,113 @@ TEST(CpuQuantizePass, reshapeBetweenNonQuantizedOp) {
784788 added_nodes_count, 2 .0f * 127 );
785789}
786790
791+ static const std::initializer_list<std::string> variable_names_slice = {
792+ " a" , " b" , " c" , " d" };
793+
794+ // a->Dequantize->b
795+ // b->Slice->c
796+ // c->Dropout->d
797+ ProgramDesc BuildProgramDescSlice () {
798+ ProgramDesc prog;
799+ for (auto & v : variable_names_slice) {
800+ prog.MutableBlock (0 )->Var (v);
801+ }
802+ SetOp (&prog, " dequantize" , " Dequantize1" , {" a" }, {" b" }, true );
803+ SetOp (&prog, " slice" , " Slice" , {" b" }, {" c" }, true , " int8" );
804+ SetOp (&prog, " dropout" , " Dropout" , {" c" }, {" d" }, true , " float32" );
805+
806+ return prog;
807+ }
808+
809+ // a->Transpose->b
810+ // b->slice->c
811+ // c->Dropout->d
812+ ProgramDesc BuildProgramDescSliceBetweenNonQuantizedOp () {
813+ ProgramDesc prog;
814+ for (auto & v : variable_names_slice) {
815+ prog.MutableBlock (0 )->Var (v);
816+ }
817+
818+ SetOp (&prog, " transpose2" , " Transpose2" , {" a" }, {" b" }, true , " float32" );
819+ SetOp (&prog, " slice" , " Slice" , {" b" }, {" c" }, true , " int8" );
820+ SetOp (&prog, " dropout" , " Dropout" , {" c" }, {" d" }, true , " float32" );
821+
822+ return prog;
823+ }
824+
825+ void MainTestSlice (const ProgramDesc& prog, int transpose_count,
826+ int slice_count, int quant_count, int dequant_count,
827+ int added_nodes_count, float scale) {
828+ std::unique_ptr<ir::Graph> graph (new ir::Graph (prog));
829+ int original_nodes_num, current_nodes_num;
830+ PreparePass (&graph, prog, variable_names_slice, &original_nodes_num,
831+ ¤t_nodes_num);
832+
833+ float quant_scale = 1 .0f ;
834+ float dequant_scale = 1 .0f ;
835+ int quantize_nodes_count = 0 ;
836+ int dequantize_nodes_count = 0 ;
837+ int transpose_nodes_count = 0 ;
838+ int slice_nodes_count = 0 ;
839+ for (auto * node : graph->Nodes ()) {
840+ if (node->IsOp ()) {
841+ auto * op = node->Op ();
842+ if (op->Type () == " transpose2" ) {
843+ transpose_nodes_count++;
844+ } else if (op->Type () == " slice" ) {
845+ slice_nodes_count++;
846+ } else if (op->Type () == " quantize" ) {
847+ quantize_nodes_count++;
848+ quant_scale = BOOST_GET_CONST (float , op->GetAttr (" Scale" ));
849+ EXPECT_EQ (quant_scale, scale) << " Scale for node '" + op->Type () + " '." ;
850+ } else if (op->Type () == " dequantize" ) {
851+ dequantize_nodes_count++;
852+ auto op_name = op->GetAttrIfExists <std::string>(" name" );
853+ VLOG (3 ) << op_name << " \n " ;
854+ if (op_name != " Dequantize1" ) {
855+ dequant_scale = BOOST_GET_CONST (float , op->GetAttr (" Scale" ));
856+ EXPECT_EQ (dequant_scale, scale)
857+ << " Scale for node '" + op->Type () + " '." ;
858+ }
859+ }
860+ }
861+ }
862+ EXPECT_EQ (transpose_nodes_count, transpose_count);
863+ EXPECT_EQ (slice_nodes_count, slice_count);
864+ EXPECT_EQ (quantize_nodes_count, quant_count);
865+ EXPECT_EQ (dequantize_nodes_count, dequant_count);
866+ EXPECT_EQ (original_nodes_num + added_nodes_count, current_nodes_num);
867+ }
868+
869+ TEST (CpuQuantizePass, slice) {
870+ // a->Dequantize->b
871+ // b2->Quant->b3->slice->c1->Dequant->c2
872+ // c2->Dropout->d
873+ int slice_count = 1 ;
874+ int transpose_count = 0 ;
875+ int quant_count = 1 ;
876+ int dequant_count = 2 ;
877+ // 1 Quant + 1 IN + 1 DeQuant + 1 OUT
878+ int added_nodes_count = 4 ;
879+ MainTestSlice (BuildProgramDescSlice (), transpose_count, slice_count,
880+ quant_count, dequant_count, added_nodes_count, 2 .0f * 127 );
881+ }
882+
883+ TEST (CpuQuantizePass, sliceBetweenNonQuantizedOp) {
884+ // a->Transpos2->b
885+ // b->slice->c
886+ // c->Dropout->d
887+ int slice_count = 1 ;
888+ int transpose_count = 1 ;
889+ int quant_count = 0 ;
890+ int dequant_count = 0 ;
891+ // 0 Quant + 0 IN + 0 DeQuant + 0 OUT
892+ int added_nodes_count = 0 ;
893+ MainTestSlice (BuildProgramDescSliceBetweenNonQuantizedOp (), transpose_count,
894+ slice_count, quant_count, dequant_count, added_nodes_count,
895+ 2 .0f * 127 );
896+ }
897+
787898static const std::initializer_list<std::string> variable_names_matmul = {
788899 " a" , " b" , " c" , " d" , " e" , " f" };
789900
0 commit comments