Skip to content

Commit 89069af

Browse files
authored
Support quantization of condition block (#37498)
* Support sub graph quant-post
1 parent 76c7322 commit 89069af

File tree

3 files changed

+387
-41
lines changed

3 files changed

+387
-41
lines changed

python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py

Lines changed: 72 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,23 @@ def quantize(self):
410410
for op_type in self._dynamic_quantize_op_type):
411411
self._collect_dynamic_quantize_op_threshold(
412412
self._dynamic_quantize_op_type)
413+
414+
# Move sub blocks persistable var to global block
415+
global_block = self._program.global_block()
416+
for _op in global_block.ops:
417+
if _op.type == "while":
418+
_block_id = _op.attr("sub_block").id
419+
_block = self._program.block(_block_id)
420+
persistables = []
421+
for _name, _var in _block.vars.items():
422+
if _var.persistable:
423+
global_block._clone_variable(_var)
424+
persistables.append(_name)
425+
for _name in persistables:
426+
_block._remove_var(_name)
427+
persistables.extend(_op.input('X'))
428+
_op.desc.set_input("X", persistables)
429+
413430
return self._program
414431

415432
def save_quantized_model(self,
@@ -451,10 +468,6 @@ def _load_model_data(self):
451468
model_filename=self._model_filename,
452469
params_filename=self._params_filename)
453470

454-
if self._program.num_blocks > 1:
455-
_logger.error("The post training quantization requires that the "
456-
"program only has one block.")
457-
458471
if self._optimize_model:
459472
self._optimize_fp32_model()
460473

@@ -505,23 +518,26 @@ def collect_var_name(var_name_list, persistable_var_names, op_type):
505518
self._quantized_act_var_name.add(var_name)
506519

507520
persistable_var_names = _all_persistable_var_names(self._program)
508-
for op in self._program.global_block().ops:
509-
op_type = op.type
510-
if self._is_full_quantize and \
511-
op_type not in self._quantizable_op_type:
512-
_logger.warning(op_type + " is not supported for quantization.")
513-
# For quantized ops, sample inputs and outputs
514-
if op_type in self._quantizable_op_type:
515-
collect_var_name(
516-
_get_op_input_var_names(op), persistable_var_names, op_type)
517-
collect_var_name(
518-
_get_op_output_var_names(op), persistable_var_names,
519-
op_type)
520-
# For other op, only sample output scale
521-
elif op_type in self._out_scale_op_list:
522-
collect_var_name(
523-
_get_op_output_var_names(op), persistable_var_names,
524-
op_type)
521+
for block_id in range(len(self._program.blocks)):
522+
for op in self._program.blocks[block_id].ops:
523+
op_type = op.type
524+
if self._is_full_quantize and \
525+
op_type not in self._quantizable_op_type:
526+
_logger.warning(op_type +
527+
" is not supported for quantization.")
528+
# For quantized ops, sample inputs and outputs
529+
if op_type in self._quantizable_op_type:
530+
collect_var_name(
531+
_get_op_input_var_names(op), persistable_var_names,
532+
op_type)
533+
collect_var_name(
534+
_get_op_output_var_names(op), persistable_var_names,
535+
op_type)
536+
# For other op, only sample output scale
537+
elif op_type in self._out_scale_op_list:
538+
collect_var_name(
539+
_get_op_output_var_names(op), persistable_var_names,
540+
op_type)
525541

526542
def _set_activation_persistable(self):
527543
'''
@@ -696,16 +712,17 @@ def _save_input_threhold(self):
696712
'''
697713
assert self._algo == "min_max", \
698714
"The algo should be min_max to save input threshold."
699-
for op in self._program.global_block().ops:
700-
if op.type in self._quantizable_op_type:
701-
for var_name in _get_op_input_var_names(op):
702-
assert var_name in self._quantized_var_min
703-
assert var_name in self._quantized_var_max
704-
op._set_attr(var_name + ".min",
705-
self._quantized_var_min[var_name])
706-
op._set_attr(var_name + ".max",
707-
self._quantized_var_max[var_name])
708-
op._set_attr("with_quant_attr", True)
715+
for block_id in range(len(self._program.blocks)):
716+
for op in self._program.blocks[block_id].ops:
717+
if op.type in self._quantizable_op_type:
718+
for var_name in _get_op_input_var_names(op):
719+
assert var_name in self._quantized_var_min
720+
assert var_name in self._quantized_var_max
721+
op._set_attr(var_name + ".min",
722+
self._quantized_var_min[var_name])
723+
op._set_attr(var_name + ".max",
724+
self._quantized_var_max[var_name])
725+
op._set_attr("with_quant_attr", True)
709726

710727
def _collect_activation_abs_min_max(self):
711728
'''
@@ -795,7 +812,12 @@ def _update_program(self):
795812
activation_quantize_type=self._activation_quantize_type,
796813
weight_quantize_type=self._weight_quantize_type,
797814
quantizable_op_type=major_quantizable_op_types)
798-
transform_pass.apply(graph)
815+
816+
for sub_graph in graph.all_sub_graphs():
817+
# Inserting fake_quant/fake_dequantize ops must be done in a test
818+
# graph, so set each sub-graph's _for_test flag to True.
819+
sub_graph._for_test = True
820+
transform_pass.apply(sub_graph)
799821

800822
# use AddQuantDequantPass to insert fake_quant_dequant op
801823
minor_quantizable_op_types = []
@@ -806,7 +828,10 @@ def _update_program(self):
806828
scope=self._scope,
807829
place=self._place,
808830
quantizable_op_type=minor_quantizable_op_types)
809-
add_quant_dequant_pass.apply(graph)
831+
832+
for sub_graph in graph.all_sub_graphs():
833+
sub_graph._for_test = True
834+
add_quant_dequant_pass.apply(sub_graph)
810835

811836
# save threshold to scale var node
812837
if self._algo in ["KL", "hist"]:
@@ -836,7 +861,11 @@ def _update_program(self):
836861
activation_bits=self._activation_bits,
837862
weight_quantize_type=self._weight_quantize_type,
838863
quantizable_op_type=major_quantizable_op_types)
839-
freeze_pass.apply(graph)
864+
865+
for sub_graph in graph.all_sub_graphs():
866+
sub_graph._for_test = True
867+
freeze_pass.apply(sub_graph)
868+
840869
self._program = graph.to_program()
841870

842871
def _save_output_threshold(self):
@@ -888,13 +917,15 @@ def analysis_and_save_info(op_node, out_var_name):
888917
save_info(op_node, out_var_name, self._quantized_var_max,
889918
"out_max", "post_min_max")
890919

891-
for op in self._program.global_block().ops:
892-
if op.type in (self._quantizable_op_type + self._out_scale_op_list):
893-
out_var_names = _get_op_output_var_names(op)
894-
assert len(out_var_names) == 1, "Post training " + \
895-
"quantization only support one output for " + op.type
896-
for var_name in out_var_names:
897-
analysis_and_save_info(op, var_name)
920+
for block_id in range(len(self._program.blocks)):
921+
for op in self._program.blocks[block_id].ops:
922+
if op.type in (
923+
self._quantizable_op_type + self._out_scale_op_list):
924+
out_var_names = _get_op_output_var_names(op)
925+
assert len(out_var_names) == 1, "Post training " + \
926+
"quantization only support one output for " + op.type
927+
for var_name in out_var_names:
928+
analysis_and_save_info(op, var_name)
898929

899930
def _collect_dynamic_quantize_op_threshold(self, target_ops_type):
900931
"""

python/paddle/fluid/contrib/slim/tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ endfunction()
139139
if(WIN32)
140140
list(REMOVE_ITEM TEST_OPS test_light_nas)
141141
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mnist)
142+
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_while)
142143
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
143144
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
144145
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_lstm_model)
@@ -336,6 +337,7 @@ if(NOT WIN32)
336337
set_tests_properties(test_post_training_quantization_mobilenetv1 PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY")
337338
set_tests_properties(test_post_training_quantization_resnet50 PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY")
338339
set_tests_properties(test_post_training_quantization_mnist PROPERTIES TIMEOUT 120)
340+
set_tests_properties(test_post_training_quantization_while PROPERTIES TIMEOUT 120)
339341
set_tests_properties(test_imperative_ptq PROPERTIES TIMEOUT 120)
340342
set_tests_properties(test_weight_quantization_mobilenetv1 PROPERTIES TIMEOUT 120)
341343
endif()

0 commit comments

Comments
 (0)