@@ -146,6 +146,14 @@ void CPUQuantizePass::QuantizeInputs(Graph* g,
146146 float shift,
147147 std::string shift_attr_name) const {
148148 auto inputs = op->inputs ;
149+ auto var_names = op->Op ()->Inputs ().at (input_name);
150+ std::vector<std::string> unique_var_names;
151+ for (unsigned i = 0 ; i < var_names.size (); i++)
152+ if (std::find (unique_var_names.begin (),
153+ unique_var_names.end (),
154+ var_names[i]) == unique_var_names.end ())
155+ unique_var_names.push_back (var_names[i]);
156+
149157 auto output = op->outputs [0 ];
150158 PADDLE_ENFORCE_GE (inputs.size (),
151159 1 ,
@@ -163,33 +171,59 @@ void CPUQuantizePass::QuantizeInputs(Graph* g,
163171 // create a quantize op desc prototype
164172 OpDesc q_desc;
165173 q_desc.SetType (" quantize" );
166-
167174 std::vector<Node*> quantize_out_nodes (inputs.size ());
168175 std::vector<std::string> quantize_out_node_names (inputs.size ());
169176
170177 double scale_out = GetScaleValueForNode (output);
171178 unsigned max = are_inputs_unsigned ? U8_MAX : S8_MAX;
172179 float scale = scale_out * max;
173180
174- for (size_t i = 0 ; i < inputs.size (); i++) {
175- // Create quantize output variable
181+ for (size_t var_id = 0 ; var_id < unique_var_names.size (); var_id++) {
182+ auto index = -1 ;
183+ for (size_t it = 0 ; it < inputs.size (); it++) {
184+ if (inputs[it]->Name () == unique_var_names[var_id]) index = it;
185+ }
186+
187+ if (index == -1 ) {
188+ PADDLE_ENFORCE_NE (index,
189+ -1 ,
190+ platform::errors::InvalidArgument (
191+ " Var(%s) isn't the input of the %s operator." ,
192+ unique_var_names[var_id],
193+ op->Op ()->Type ()));
194+ }
195+
196+ auto * input = inputs.at (index);
197+
176198 VarDesc quantize_out_desc (patterns::PDNodeName (" quantize" , " out" ));
177- quantize_out_nodes[i ] = g->CreateVarNode (&quantize_out_desc);
178- quantize_out_node_names[i ] = quantize_out_nodes[i ]->Name ();
199+ quantize_out_nodes[var_id ] = g->CreateVarNode (&quantize_out_desc);
200+ quantize_out_node_names[var_id ] = quantize_out_nodes[var_id ]->Name ();
179201
180202 q_desc.SetAttr (" Scale" , scale);
181203 q_desc.SetAttr (" Shift" , shift);
182- q_desc.SetInput (" Input" , std::vector<std::string>({inputs[i] ->Name ()}));
183- q_desc.SetOutput (" Output " ,
184- std::vector<std::string>({quantize_out_node_names[i ]}));
204+ q_desc.SetInput (" Input" , std::vector<std::string>({input ->Name ()}));
205+ q_desc.SetOutput (
206+ " Output " , std::vector<std::string>({quantize_out_node_names[var_id ]}));
185207 q_desc.SetAttr (" is_negative_input" , !are_inputs_unsigned);
186208 auto quantize_op = g->CreateOpNode (&q_desc); // OpDesc will be copied.
187209
188210 // link quantize op
189- UnlinkNodes (inputs[i], op);
190- IR_NODE_LINK_TO (inputs[i], quantize_op);
191- IR_NODE_LINK_TO (quantize_op, quantize_out_nodes[i]);
192- IR_NODE_LINK_TO (quantize_out_nodes[i], op);
211+ UnlinkNodes (input, op);
212+ IR_NODE_LINK_TO (input, quantize_op);
213+ IR_NODE_LINK_TO (quantize_op, quantize_out_nodes[var_id]);
214+ IR_NODE_LINK_TO (quantize_out_nodes[var_id], op);
215+ }
216+
217+ // If any inputs were duplicated, fill the remaining name slots by reusing the
218+ // quantize output node already created for the matching variable name.
219+ for (size_t i = unique_var_names.size (); i < var_names.size (); i++) {
220+ auto index = std::find (
221+ unique_var_names.begin (), unique_var_names.end (), var_names[i]);
222+ if (index != unique_var_names.end ()) {
223+ auto id = std::distance (unique_var_names.begin (), index);
224+ quantize_out_node_names[i] = quantize_out_nodes[id]->Name ();
225+ IR_NODE_LINK_TO (quantize_out_nodes[id], op);
226+ }
193227 }
194228
195229 // update op's input
@@ -252,44 +286,62 @@ void CPUQuantizePass::DequantizeOutputs(Graph* g,
252286 bool is_unsigned,
253287 std::string scale_attr_name) const {
254288 auto outputs = op->outputs ;
289+ auto var_names = op->Op ()->Outputs ().at (output_name);
290+
255291 PADDLE_ENFORCE_GE (outputs.size (),
256292 1 ,
257293 platform::errors::InvalidArgument (
258294 " OP(%s)'s outputs(%d) must be equal or greater than 1." ,
259295 op->Name (),
260296 outputs.size ()));
261297
262- std::vector<std::string> quantize_in_node_names (outputs.size ());
298+ std::vector<std::string> dequantize_in_node_names (outputs.size ());
299+ std::vector<Node*> dequantize_in_nodes (outputs.size ());
263300
264301 unsigned max = is_unsigned ? U8_MAX : S8_MAX;
265302 float scale = scale_to_one * max;
266303
267- for (size_t i = 0 ; i < outputs.size (); i++) {
304+ for (size_t var_id = 0 ; var_id < var_names.size (); var_id++) {
305+ auto index = -1 ;
306+ for (size_t it = 0 ; it < outputs.size (); it++) {
307+ if (outputs[it]->Name () == var_names[var_id]) index = it;
308+ }
309+
310+ if (index == -1 ) {
311+ PADDLE_ENFORCE_NE (index,
312+ -1 ,
313+ platform::errors::InvalidArgument (
314+ " Var(%s) isn't the input of the %s operator." ,
315+ var_names[var_id],
316+ op->Op ()->Type ()));
317+ }
318+
319+ auto * output = outputs.at (index);
320+
268321 // Create dequantize input variable
269322 VarDesc dequantize_in_desc (patterns::PDNodeName (" dequantize" , " in" ));
270- Node* dequantize_in_node = g->CreateVarNode (&dequantize_in_desc);
271- quantize_in_node_names[i ] = dequantize_in_node ->Name ();
323+ dequantize_in_nodes[var_id] = g->CreateVarNode (&dequantize_in_desc);
324+ dequantize_in_node_names[var_id ] = dequantize_in_nodes[var_id] ->Name ();
272325
273326 // create a dequantize op node for output.
274327 OpDesc deq_desc;
275328 deq_desc.SetType (" dequantize" );
276- deq_desc.SetInput (" Input" ,
277- std::vector<std::string>({quantize_in_node_names[i]}));
278- deq_desc.SetOutput (" Output" ,
279- std::vector<std::string>({outputs[i]->Name ()}));
329+ deq_desc.SetInput (
330+ " Input" , std::vector<std::string>({dequantize_in_node_names[var_id]}));
331+ deq_desc.SetOutput (" Output" , std::vector<std::string>({output->Name ()}));
280332 deq_desc.SetAttr (" Scale" , scale);
281333 deq_desc.SetAttr (" is_negative_input" , !is_unsigned);
282334 auto dequantize_op = g->CreateOpNode (&deq_desc); // OpDesc will be copied.
283335
284336 // link dequantize op
285- UnlinkNodes (op, outputs[i] );
286- IR_NODE_LINK_TO (op, dequantize_in_node );
287- IR_NODE_LINK_TO (dequantize_in_node , dequantize_op);
288- IR_NODE_LINK_TO (dequantize_op, outputs[i] );
337+ UnlinkNodes (op, output );
338+ IR_NODE_LINK_TO (op, dequantize_in_nodes[var_id] );
339+ IR_NODE_LINK_TO (dequantize_in_nodes[var_id] , dequantize_op);
340+ IR_NODE_LINK_TO (dequantize_op, output );
289341 }
290342
291343 // update op's output
292- op->Op ()->SetOutput (output_name, quantize_in_node_names );
344+ op->Op ()->SetOutput (output_name, dequantize_in_node_names );
293345 if (!scale_attr_name.empty ()) op->Op ()->SetAttr (scale_attr_name, scale);
294346}
295347
0 commit comments