@@ -410,6 +410,23 @@ def quantize(self):
                for op_type in self._dynamic_quantize_op_type):
             self._collect_dynamic_quantize_op_threshold(
                 self._dynamic_quantize_op_type)
+
+        # Move persistable vars in sub-blocks up to the global block
+        global_block = self._program.global_block()
+        for _op in global_block.ops:
+            if _op.type == "while":
+                _block_id = _op.attr("sub_block").id
+                _block = self._program.block(_block_id)
+                persistables = []
+                for _name, _var in _block.vars.items():
+                    if _var.persistable:
+                        global_block._clone_variable(_var)
+                        persistables.append(_name)
+                for _name in persistables:
+                    _block._remove_var(_name)
+                persistables.extend(_op.input('X'))
+                _op.desc.set_input("X", persistables)
+
         return self._program
 
     def save_quantized_model(self,
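The block added above hoists persistable variables out of each while op's sub-block into the global block and re-registers them as inputs of the while op, so later passes and the save step can still reach them. For context, a minimal sketch of how such a multi-block program arises, assuming the Paddle 2.x static API; the loop and shapes are illustrative only, not taken from this PR:

import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    i = paddle.full(shape=[1], fill_value=0, dtype='int64')
    limit = paddle.full(shape=[1], fill_value=10, dtype='int64')

    def cond(i):
        return paddle.less_than(i, limit)

    def body(i):
        return [i + 1]

    # while_loop lowers to a "while" op whose body lives in a separate
    # sub-block of the program
    out = paddle.static.nn.while_loop(cond, body, [i])

print("num blocks:", main_prog.num_blocks)
for op in main_prog.global_block().ops:
    if op.type == "while":
        # the sub_block attribute points at the block holding the loop body
        print("while sub-block id:", op.attr("sub_block").id)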
@@ -451,10 +468,6 @@ def _load_model_data(self):
                 model_filename=self._model_filename,
                 params_filename=self._params_filename)
 
-        if self._program.num_blocks > 1:
-            _logger.error("The post training quantization requires that the "
-                          "program only has one block.")
-
         if self._optimize_model:
             self._optimize_fp32_model()
 
@@ -505,23 +518,26 @@ def collect_var_name(var_name_list, persistable_var_names, op_type):
                     self._quantized_act_var_name.add(var_name)
 
         persistable_var_names = _all_persistable_var_names(self._program)
-        for op in self._program.global_block().ops:
-            op_type = op.type
-            if self._is_full_quantize and \
-                op_type not in self._quantizable_op_type:
-                _logger.warning(op_type + " is not supported for quantization.")
-            # For quantized ops, sample inputs and outputs
-            if op_type in self._quantizable_op_type:
-                collect_var_name(
-                    _get_op_input_var_names(op), persistable_var_names, op_type)
-                collect_var_name(
-                    _get_op_output_var_names(op), persistable_var_names,
-                    op_type)
-            # For other op, only sample output scale
-            elif op_type in self._out_scale_op_list:
-                collect_var_name(
-                    _get_op_output_var_names(op), persistable_var_names,
-                    op_type)
+        for block_id in range(len(self._program.blocks)):
+            for op in self._program.blocks[block_id].ops:
+                op_type = op.type
+                if self._is_full_quantize and \
+                    op_type not in self._quantizable_op_type:
+                    _logger.warning(op_type +
+                                    " is not supported for quantization.")
+                # For quantized ops, sample inputs and outputs
+                if op_type in self._quantizable_op_type:
+                    collect_var_name(
+                        _get_op_input_var_names(op), persistable_var_names,
+                        op_type)
+                    collect_var_name(
+                        _get_op_output_var_names(op), persistable_var_names,
+                        op_type)
+                # For other op, only sample output scale
+                elif op_type in self._out_scale_op_list:
+                    collect_var_name(
+                        _get_op_output_var_names(op), persistable_var_names,
+                        op_type)
 
     def _set_activation_persistable(self):
         '''
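The sampling loop now walks every block of the program rather than only the global block, so ops inside while (and other control-flow) sub-blocks are sampled too. A small self-contained sketch of the same traversal pattern; walk_all_ops is a hypothetical helper, not part of this PR:

import paddle

paddle.enable_static()

def walk_all_ops(program):
    # visit every op in every block; blocks[0] is the global block, and
    # control-flow ops such as "while" own the remaining sub-blocks
    for block in program.blocks:
        for op in block.ops:
            yield block.idx, op

for block_id, op in walk_all_ops(paddle.static.default_main_program()):
    print(block_id, op.type)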
@@ -696,16 +712,17 @@ def _save_input_threhold(self):
         '''
         assert self._algo == "min_max", \
             "The algo should be min_max to save input threshold."
-        for op in self._program.global_block().ops:
-            if op.type in self._quantizable_op_type:
-                for var_name in _get_op_input_var_names(op):
-                    assert var_name in self._quantized_var_min
-                    assert var_name in self._quantized_var_max
-                    op._set_attr(var_name + ".min",
-                                 self._quantized_var_min[var_name])
-                    op._set_attr(var_name + ".max",
-                                 self._quantized_var_max[var_name])
-                    op._set_attr("with_quant_attr", True)
+        for block_id in range(len(self._program.blocks)):
+            for op in self._program.blocks[block_id].ops:
+                if op.type in self._quantizable_op_type:
+                    for var_name in _get_op_input_var_names(op):
+                        assert var_name in self._quantized_var_min
+                        assert var_name in self._quantized_var_max
+                        op._set_attr(var_name + ".min",
+                                     self._quantized_var_min[var_name])
+                        op._set_attr(var_name + ".max",
+                                     self._quantized_var_max[var_name])
+                        op._set_attr("with_quant_attr", True)
 
     def _collect_activation_abs_min_max(self):
         '''
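With the min_max algorithm the observed thresholds are written straight onto each quantizable op as attributes named <var_name>.min and <var_name>.max, together with a with_quant_attr flag. A hedged sketch of reading those attributes back from a quantized program; dump_input_thresholds is illustrative and not part of this PR:

def dump_input_thresholds(program):
    # scan every block, mirroring the loop above, and print the saved
    # per-input min/max thresholds
    for block in program.blocks:
        for op in block.ops:
            if op.has_attr("with_quant_attr") and op.attr("with_quant_attr"):
                for name in op.input_arg_names:
                    if op.has_attr(name + ".min"):
                        print(op.type, name,
                              op.attr(name + ".min"), op.attr(name + ".max"))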
@@ -795,7 +812,12 @@ def _update_program(self):
             activation_quantize_type=self._activation_quantize_type,
             weight_quantize_type=self._weight_quantize_type,
             quantizable_op_type=major_quantizable_op_types)
-        transform_pass.apply(graph)
+
+        for sub_graph in graph.all_sub_graphs():
+            # fake_quant/fake_dequantize ops must be inserted into a test
+            # graph, so set each sub-graph's _for_test flag to True
+            sub_graph._for_test = True
+            transform_pass.apply(sub_graph)
 
         # use AddQuantDequantPass to insert fake_quant_dequant op
         minor_quantizable_op_types = []
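Instead of rewriting only the top-level graph, the transform pass is now applied to every sub-graph returned by IrGraph.all_sub_graphs(), with each sub-graph forced into test mode first; the same pattern repeats below for AddQuantDequantPass and the freeze pass. A condensed sketch of that pattern, assuming the fluid APIs this file already uses; apply_to_all_sub_graphs and its arguments are placeholders:

from paddle.fluid import core
from paddle.fluid.framework import IrGraph

def apply_to_all_sub_graphs(program, ir_pass):
    # ir_pass is any slim pass object exposing .apply(graph)
    graph = IrGraph(core.Graph(program.desc), for_test=True)
    for sub_graph in graph.all_sub_graphs():
        # the quantization passes only rewrite inference (test) graphs
        sub_graph._for_test = True
        ir_pass.apply(sub_graph)
    return graph.to_program()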
@@ -806,7 +828,10 @@ def _update_program(self):
             scope=self._scope,
             place=self._place,
             quantizable_op_type=minor_quantizable_op_types)
-        add_quant_dequant_pass.apply(graph)
+
+        for sub_graph in graph.all_sub_graphs():
+            sub_graph._for_test = True
+            add_quant_dequant_pass.apply(sub_graph)
 
         # save threshold to scale var node
         if self._algo in ["KL", "hist"]:
@@ -836,7 +861,11 @@ def _update_program(self):
             activation_bits=self._activation_bits,
             weight_quantize_type=self._weight_quantize_type,
             quantizable_op_type=major_quantizable_op_types)
-        freeze_pass.apply(graph)
+
+        for sub_graph in graph.all_sub_graphs():
+            sub_graph._for_test = True
+            freeze_pass.apply(sub_graph)
+
         self._program = graph.to_program()
 
     def _save_output_threshold(self):
@@ -888,13 +917,15 @@ def analysis_and_save_info(op_node, out_var_name):
                 save_info(op_node, out_var_name, self._quantized_var_max,
                           "out_max", "post_min_max")
 
-        for op in self._program.global_block().ops:
-            if op.type in (self._quantizable_op_type + self._out_scale_op_list):
-                out_var_names = _get_op_output_var_names(op)
-                assert len(out_var_names) == 1, "Post training " + \
-                    "quantization only support one output for " + op.type
-                for var_name in out_var_names:
-                    analysis_and_save_info(op, var_name)
+        for block_id in range(len(self._program.blocks)):
+            for op in self._program.blocks[block_id].ops:
+                if op.type in (
+                        self._quantizable_op_type + self._out_scale_op_list):
+                    out_var_names = _get_op_output_var_names(op)
+                    assert len(out_var_names) == 1, "Post training " + \
+                        "quantization only support one output for " + op.type
+                    for var_name in out_var_names:
+                        analysis_and_save_info(op, var_name)
 
     def _collect_dynamic_quantize_op_threshold(self, target_ops_type):
         """
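Taken together, these changes let post-training quantization run on inference models whose programs contain while ops (for example recurrent or beam-search decoders). A hedged end-to-end usage sketch; the constructor arguments follow the public PostTrainingQuantization API and may vary between versions, and the paths and calibration reader are placeholders, not taken from this diff:

import numpy as np
import paddle
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

def sample_generator():
    # placeholder calibration reader; shapes depend on the model's feed vars
    for _ in range(32):
        yield [np.random.random([3, 224, 224]).astype('float32')]

ptq = PostTrainingQuantization(
    executor=exe,
    model_dir='./inference_model',      # placeholder paths
    sample_generator=sample_generator,
    batch_size=8,
    batch_nums=4,
    algo='KL')
ptq.quantize()
ptq.save_quantized_model('./quantized_model')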