Skip to content

Commit 0ffba1c

Browse files
authored
Correct multiple inputs and outputs (#48872)
1 parent 428fb80 commit 0ffba1c

File tree

3 files changed

+129
-35
lines changed

3 files changed

+129
-35
lines changed

paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,14 @@ void CPUQuantizePass::QuantizeInputs(Graph* g,
146146
float shift,
147147
std::string shift_attr_name) const {
148148
auto inputs = op->inputs;
149+
auto var_names = op->Op()->Inputs().at(input_name);
150+
std::vector<std::string> unique_var_names;
151+
for (unsigned i = 0; i < var_names.size(); i++)
152+
if (std::find(unique_var_names.begin(),
153+
unique_var_names.end(),
154+
var_names[i]) == unique_var_names.end())
155+
unique_var_names.push_back(var_names[i]);
156+
149157
auto output = op->outputs[0];
150158
PADDLE_ENFORCE_GE(inputs.size(),
151159
1,
@@ -163,33 +171,59 @@ void CPUQuantizePass::QuantizeInputs(Graph* g,
163171
// create a quantize op desc prototype
164172
OpDesc q_desc;
165173
q_desc.SetType("quantize");
166-
167174
std::vector<Node*> quantize_out_nodes(inputs.size());
168175
std::vector<std::string> quantize_out_node_names(inputs.size());
169176

170177
double scale_out = GetScaleValueForNode(output);
171178
unsigned max = are_inputs_unsigned ? U8_MAX : S8_MAX;
172179
float scale = scale_out * max;
173180

174-
for (size_t i = 0; i < inputs.size(); i++) {
175-
// Create quantize output variable
181+
for (size_t var_id = 0; var_id < unique_var_names.size(); var_id++) {
182+
auto index = -1;
183+
for (size_t it = 0; it < inputs.size(); it++) {
184+
if (inputs[it]->Name() == unique_var_names[var_id]) index = it;
185+
}
186+
187+
if (index == -1) {
188+
PADDLE_ENFORCE_NE(index,
189+
-1,
190+
platform::errors::InvalidArgument(
191+
"Var(%s) isn't the input of the %s operator.",
192+
unique_var_names[var_id],
193+
op->Op()->Type()));
194+
}
195+
196+
auto* input = inputs.at(index);
197+
176198
VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
177-
quantize_out_nodes[i] = g->CreateVarNode(&quantize_out_desc);
178-
quantize_out_node_names[i] = quantize_out_nodes[i]->Name();
199+
quantize_out_nodes[var_id] = g->CreateVarNode(&quantize_out_desc);
200+
quantize_out_node_names[var_id] = quantize_out_nodes[var_id]->Name();
179201

180202
q_desc.SetAttr("Scale", scale);
181203
q_desc.SetAttr("Shift", shift);
182-
q_desc.SetInput("Input", std::vector<std::string>({inputs[i]->Name()}));
183-
q_desc.SetOutput("Output",
184-
std::vector<std::string>({quantize_out_node_names[i]}));
204+
q_desc.SetInput("Input", std::vector<std::string>({input->Name()}));
205+
q_desc.SetOutput(
206+
"Output", std::vector<std::string>({quantize_out_node_names[var_id]}));
185207
q_desc.SetAttr("is_negative_input", !are_inputs_unsigned);
186208
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.
187209

188210
// link quantize op
189-
UnlinkNodes(inputs[i], op);
190-
IR_NODE_LINK_TO(inputs[i], quantize_op);
191-
IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]);
192-
IR_NODE_LINK_TO(quantize_out_nodes[i], op);
211+
UnlinkNodes(input, op);
212+
IR_NODE_LINK_TO(input, quantize_op);
213+
IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[var_id]);
214+
IR_NODE_LINK_TO(quantize_out_nodes[var_id], op);
215+
}
216+
217+
// If any inputs were duplicated, now you have to enter them in the correct
218+
// order.
219+
for (size_t i = unique_var_names.size(); i < var_names.size(); i++) {
220+
auto index = std::find(
221+
unique_var_names.begin(), unique_var_names.end(), var_names[i]);
222+
if (index != unique_var_names.end()) {
223+
auto id = std::distance(unique_var_names.begin(), index);
224+
quantize_out_node_names[i] = quantize_out_nodes[id]->Name();
225+
IR_NODE_LINK_TO(quantize_out_nodes[id], op);
226+
}
193227
}
194228

195229
// update op's input
@@ -252,44 +286,62 @@ void CPUQuantizePass::DequantizeOutputs(Graph* g,
252286
bool is_unsigned,
253287
std::string scale_attr_name) const {
254288
auto outputs = op->outputs;
289+
auto var_names = op->Op()->Outputs().at(output_name);
290+
255291
PADDLE_ENFORCE_GE(outputs.size(),
256292
1,
257293
platform::errors::InvalidArgument(
258294
"OP(%s)'s outputs(%d) must be equal or greater than 1.",
259295
op->Name(),
260296
outputs.size()));
261297

262-
std::vector<std::string> quantize_in_node_names(outputs.size());
298+
std::vector<std::string> dequantize_in_node_names(outputs.size());
299+
std::vector<Node*> dequantize_in_nodes(outputs.size());
263300

264301
unsigned max = is_unsigned ? U8_MAX : S8_MAX;
265302
float scale = scale_to_one * max;
266303

267-
for (size_t i = 0; i < outputs.size(); i++) {
304+
for (size_t var_id = 0; var_id < var_names.size(); var_id++) {
305+
auto index = -1;
306+
for (size_t it = 0; it < outputs.size(); it++) {
307+
if (outputs[it]->Name() == var_names[var_id]) index = it;
308+
}
309+
310+
if (index == -1) {
311+
PADDLE_ENFORCE_NE(index,
312+
-1,
313+
platform::errors::InvalidArgument(
314+
"Var(%s) isn't the input of the %s operator.",
315+
var_names[var_id],
316+
op->Op()->Type()));
317+
}
318+
319+
auto* output = outputs.at(index);
320+
268321
// Create dequantize input variable
269322
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
270-
Node* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
271-
quantize_in_node_names[i] = dequantize_in_node->Name();
323+
dequantize_in_nodes[var_id] = g->CreateVarNode(&dequantize_in_desc);
324+
dequantize_in_node_names[var_id] = dequantize_in_nodes[var_id]->Name();
272325

273326
// create a dequantize op node for output.
274327
OpDesc deq_desc;
275328
deq_desc.SetType("dequantize");
276-
deq_desc.SetInput("Input",
277-
std::vector<std::string>({quantize_in_node_names[i]}));
278-
deq_desc.SetOutput("Output",
279-
std::vector<std::string>({outputs[i]->Name()}));
329+
deq_desc.SetInput(
330+
"Input", std::vector<std::string>({dequantize_in_node_names[var_id]}));
331+
deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()}));
280332
deq_desc.SetAttr("Scale", scale);
281333
deq_desc.SetAttr("is_negative_input", !is_unsigned);
282334
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.
283335

284336
// link dequantize op
285-
UnlinkNodes(op, outputs[i]);
286-
IR_NODE_LINK_TO(op, dequantize_in_node);
287-
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
288-
IR_NODE_LINK_TO(dequantize_op, outputs[i]);
337+
UnlinkNodes(op, output);
338+
IR_NODE_LINK_TO(op, dequantize_in_nodes[var_id]);
339+
IR_NODE_LINK_TO(dequantize_in_nodes[var_id], dequantize_op);
340+
IR_NODE_LINK_TO(dequantize_op, output);
289341
}
290342

291343
// update op's output
292-
op->Op()->SetOutput(output_name, quantize_in_node_names);
344+
op->Op()->SetOutput(output_name, dequantize_in_node_names);
293345
if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
294346
}
295347

paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,45 @@ TEST(CpuQuantizePass, multi_gru_3) {
881881
MainTestMultiGru(layers);
882882
}
883883

884+
static const std::initializer_list<std::string>
885+
variable_names_multi_inputs_outputs = {"a", "b", "c1", "c2", "d", "e"};
886+
887+
// a->Pool->b
888+
// b->Split->c1, c2
889+
// (c1, c2, c1, c2)->Concat->d
890+
// d->Pool->e
891+
ProgramDesc BuildProgramDescMulti() {
892+
ProgramDesc prog;
893+
for (auto& v : variable_names_multi_inputs_outputs) {
894+
prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
895+
}
896+
897+
SetOp(&prog, "pool2d", "Pool", {"a"}, {"b"}, true, "float32");
898+
SetOp(&prog, "split", "Split", {"b"}, {"c1", "c2"}, true, "int8");
899+
SetOp(
900+
&prog, "concat", "Concat", {"c1", "c2", "c1", "c2"}, {"d"}, true, "int8");
901+
SetOp(&prog, "pool2d", "Pool2", {"d"}, {"e"}, true, "float32");
902+
903+
return prog;
904+
}
905+
906+
TEST(CpuQuantizePass, multi_inputs_outputs_ops) {
907+
// a->QUANT1->Split
908+
// b1->DEQUANT->OUT->QUANT
909+
// b2->DEQUANT->OUT->QUANT
910+
// (b1, b2, b1, b2)->Concat->c->DEQUANT->Pool->d
911+
int added_nodes = 6 * 2;
912+
std::unordered_map<std::string, int> expected_operators = {{"pool2d", 2},
913+
{"concat", 1},
914+
{"split", 1},
915+
{"quantize", 3},
916+
{"dequantize", 3}};
917+
MainTest(BuildProgramDescMulti(),
918+
variable_names_multi_inputs_outputs,
919+
expected_operators,
920+
added_nodes);
921+
}
922+
884923
} // namespace ir
885924
} // namespace framework
886925
} // namespace paddle

paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,11 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
158158
PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("Scale"));
159159
float dequant_shift = dequant_op->Op()->GetAttrIfExists<float>("Shift");
160160
float quant_shift = quant_op->Op()->GetAttrIfExists<float>("Shift");
161+
if (quant_op->Op()->GetAttrIfExists<bool>("is_negative_input") !=
162+
dequant_op->Op()->GetAttrIfExists<bool>("is_negative_input")) {
163+
return;
164+
}
165+
161166
PADDLE_ENFORCE_NE(
162167
nodes_keep_counter->find(dequant_out),
163168
nodes_keep_counter->end(),
@@ -169,14 +174,13 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
169174
if (dequant_scale == quant_scale && dequant_shift == quant_shift) {
170175
// squash dequantize-quantize to nothing
171176
auto quant_out_var_name = quant_out->Name();
172-
auto next_op_inputs = next_op_desc->InputNames();
173-
for (const auto& name : next_op_inputs) {
174-
auto input_names = next_op_desc->Input(name);
177+
for (auto input_name : next_op_desc->InputNames()) {
178+
auto& input_names = next_op_desc->MutableInputs()->at(input_name);
175179
std::replace(input_names.begin(),
176180
input_names.end(),
177181
quant_out_var_name,
178182
dequant_in->Name());
179-
next_op_desc->SetInput(name, input_names);
183+
next_op_desc->SetInput(input_name, input_names);
180184
}
181185

182186
if (keep_dequant)
@@ -413,12 +417,11 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
413417

414418
// update the next operator input,
415419
// by replacing quant_out with first_quant_out
416-
auto last_op_names = last_op->Op()->Input(last_op_input_name);
417-
last_op_names.erase(
418-
std::remove(
419-
last_op_names.begin(), last_op_names.end(), quant_out->Name()),
420-
last_op_names.end());
421-
last_op_names.push_back(first_quant_out->Name());
420+
auto last_op_names = last_op->Op()->Inputs().at(last_op_input_name);
421+
std::replace(last_op_names.begin(),
422+
last_op_names.end(),
423+
quant_out->Name(),
424+
first_quant_out->Name());
422425
last_op_op->SetInput(last_op_input_name,
423426
std::vector<std::string>(last_op_names));
424427

0 commit comments

Comments
 (0)