2525# from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
2626
2727# from intel_quantization.quantize_graph import GraphRewriter
28- from intel_quantization .transform_graph .strip_unused import StripUnusedNodes
29- from intel_quantization .transform_graph .fold_batch_norm import FoldBatchNormNodes
30- from intel_quantization .transform_graph .insert_logging import InsertLogging
31- from intel_quantization .transform_graph .freeze_max_min import freeze_max
32- from intel_quantization .transform_graph .freeze_max_min import freeze_min
33- from intel_quantization .transform_graph .freeze_max_min import freeze_requantization_range
34- from intel_quantization .transform_graph .fuse_quantized_conv_and_requantize import fuse_quantized_conv_and_requantize
35- from intel_quantization .transform_graph .fuse_column_wise_mul import FuseColumnWiseMul
36- from intel_quantization .transform_graph .rerange_quantized_concat import RerangeQuantizedConcat
37- from intel_quantization .util import read_graph , write_graph
38- from intel_quantization .quantize_graph .quantize_graph_for_intel_cpu import QuantizeGraphForIntel
28+ from .transform_graph .strip_unused import StripUnusedNodes
29+ from .transform_graph .fold_batch_norm import FoldBatchNormNodes
30+ from .transform_graph .insert_logging import InsertLogging
31+ from .transform_graph .freeze_max_min import freeze_max
32+ from .transform_graph .freeze_max_min import freeze_min
33+ from .transform_graph .freeze_max_min import freeze_requantization_range
34+ from .transform_graph .freeze_max_min import get_all_fp32_data , get_tensor_histogram , combine_histogram
35+ from .transform_graph .fuse_quantized_conv_and_requantize import fuse_quantized_conv_and_requantize
36+ from .transform_graph .fuse_column_wise_mul import FuseColumnWiseMul
37+ from .transform_graph .rerange_quantized_concat import RerangeQuantizedConcat
38+ from .util import read_graph , write_graph
39+ from .quantize_graph .quantize_graph_for_intel_cpu import QuantizeGraphForIntel
40+ from .quantize_graph .quantize_graph_common import QuantizeGraphHelper
3941import os
4042import shlex
4143import subprocess
42- import sys
4344import logging
4445
4546logging .getLogger ().setLevel (level = logging .INFO )
5253
5354class GraphConverter :
5455 def __init__ (self , input_graph , output_graph , inputs = [], outputs = [], excluded_ops = [], excluded_nodes = [],
55- per_channel = False , input_graph_is_binary = True ):
56+ per_channel = False , input_graph_is_binary = True , algo = 'DIRECT' ):
5657 """Convert graph.
5758
5859 :param input_graph: input graph pb file.
@@ -73,13 +74,18 @@ def __init__(self, input_graph, output_graph, inputs=[], outputs=[], excluded_op
7374 self .per_channel = per_channel
7475 self .excluded_ops = excluded_ops
7576 self .excluded_nodes = excluded_nodes
77+ self .algo = algo
7678 self ._low_precision_mode = 'eightbit'
77-
79+ self ._calibration_data = []
80+ self ._fp32_print_data = []
7881 self .gen_calib_data_cmds = None
7982 self .debug = False
8083 self ._check_tf_version ()
8184 self ._check_args ()
8285 self ._gen_tmp_filenames ()
86+ self ._kl_op_dict = {}
87+ self ._kl_keys = []
88+ self ._print_node_mapping = {}
8389
8490 def _check_tf_version (self ):
8591 is_supported_version = False
@@ -113,7 +119,7 @@ def _gen_tmp_filenames(self):
113119 self ._fp32_optimized_graph = os .path .join (self ._output_path , 'fp32_optimized_graph.pb' )
114120 self ._int8_dynamic_range_graph = os .path .join (self ._output_path , 'int8_dynamic_range_graph.pb' )
115121 self ._int8_logged_graph = os .path .join (self ._output_path , 'int8_logged_graph.pb' )
116- self ._requant_min_max_log = os .path .join (self ._output_path , 'requant_min_max_log.txt ' )
122+ self ._fp32_logged_graph = os .path .join (self ._output_path , 'fp32_logged_graph.pb ' )
117123 self ._int8_frozen_range_graph = os .path .join (self ._output_path , 'int8_frozen_range_graph.pb' )
118124 if not self .output_graph :
119125 self .output_graph = os .path .join (self ._output_path , 'int8_final_fused_graph.pb' )
@@ -137,6 +143,58 @@ def convert(self):
137143 else :
138144 self .quantize ()
139145
    def _get_fp32_print_node_names(self):
        """Locate the FP32 output nodes that feed quantized convs and instrument them.

        Used by the KL calibration algorithm: for every quantized conv pattern in
        the int8 graph, find the corresponding FP32 activation node in the original
        graph, insert a "__KL:" logging op after it, and write the instrumented
        graph to ``self._fp32_logged_graph``.

        Side effects:
          - appends the expected log-line prefixes to ``self._kl_keys``
          - records FP32-node -> conv-node mapping in ``self._print_node_mapping``
          - mutates ``self._fp32_origin_graph`` in place (InsertLogging)
        """
        # Distance (in topological-sort order) from the conv node to the node
        # whose FP32 output should be logged, per fused-op pattern.
        # NOTE(review): offsets presumably count conv -> bias-add -> relu steps;
        # confirm against QuantizeGraphForIntel's fusion output.
        offset_map = {
            "QuantizedConv2DWithBiasSumAndRelu": 3,
            "QuantizedConv2DWithBiasAndRelu": 2,
            "QuantizedConv2DWithBias": 1,
        }
        target_conv_op = []
        sorted_graph = QuantizeGraphHelper().get_sorted_graph(
            self._fp32_origin_graph, self.outputs)

        # Non-Const nodes of the quantized graph, keyed by name.
        node_name_mapping = {
            node.name: node
            for node in self._tmp_graph_def.node if node.op != "Const"
        }

        # Recover the original (pre-quantization) node name by stripping the
        # '_eightbit_' suffix the quantizer appends.
        for node in self._tmp_graph_def.node:
            if node.op in offset_map:
                target_conv_op.append(node.name.split('_eightbit_')[0])
        # Non-Const nodes of the topologically sorted FP32 graph, by name,
        # plus the sorted name list used for positional offsets below.
        fp32_node_name_mapping = {
            node.name: node
            for node in sorted_graph.node if node.op != "Const"
        }
        sorted_node_names = [i.name for i in sorted_graph.node if i.op != "Const"]

        output_node_names = []
        for i in target_conv_op:
            if node_name_mapping[
                    i + "_eightbit_quantized_conv"].op == 'QuantizedConv2DWithBiasSumAndRelu':
                # Sum+Relu fusion: the positional offset is unreliable, so scan
                # forward for the Add-followed-by-Relu pair and log the Relu.
                # NOTE(review): no break after a match — every Add->Relu pair
                # downstream of the conv is logged; confirm this is intended.
                start_index = sorted_node_names.index(i)
                for index, value in enumerate(sorted_node_names[start_index:]):
                    if fp32_node_name_mapping[value].op.startswith(
                            "Add") and fp32_node_name_mapping[
                                sorted_node_names[start_index + index + 1]].op == "Relu":
                        output_node_names.append(
                            sorted_node_names[start_index + index + 1])
                        self._print_node_mapping[sorted_node_names[start_index + index + 1]] = i
            elif i in sorted_node_names:
                # Other fusions: jump a fixed number of nodes past the conv.
                start_index = sorted_node_names.index(i)
                end_index = start_index + offset_map[node_name_mapping[
                    i + "_eightbit_quantized_conv"].op]
                output_node_names.append(sorted_node_names[end_index])
                self._print_node_mapping[sorted_node_names[end_index]] = i

        # Each instrumented node will emit lines starting with this exact
        # prefix; _generate_calibration_data matches against these keys.
        for i in output_node_names:
            self._kl_keys.append(';' + i + '__print__;__KL')

        # summarize=-1 dumps the full tensor so histograms can be built.
        InsertLogging(self._fp32_origin_graph,
                      node_name_list=output_node_names,
                      message="__KL:",
                      summarize=-1, dump_fp32=True).do_transformation()
        write_graph(self._fp32_origin_graph, self._fp32_logged_graph)
140198 def quantize (self ):
141199 """Quantize graph only (without optimizing fp32 graph), including:
142200 1) quantize graph,
@@ -150,9 +208,14 @@ def quantize(self):
150208 'to generate calibration data.' )
151209 try :
152210 self ._quantize_graph ()
211+ if self .algo == "KL" :
212+ self ._get_fp32_print_node_names ()
213+ self ._generate_calibration_data (self ._fp32_logged_graph ,
214+ self ._fp32_print_data , True )
215+
153216 self ._insert_logging ()
154- self ._generate_calibration_data ()
155- self ._freeze_requantization_ranges ()
217+ self ._generate_calibration_data (self . _int8_logged_graph , self . _calibration_data )
218+ self ._freeze_requantization_ranges (self . _kl_op_dict )
156219 self ._fuse_requantize_with_fused_quantized_conv ()
157220 except Exception as e :
158221 logging .error ('Failed to quantize graph due to: %s' , str (e ))
@@ -172,6 +235,7 @@ def _optimize_frozen_fp32_graph(self):
172235 self ._tmp_graph_def = graph_util .remove_training_nodes (self ._tmp_graph_def , self .outputs )
173236 self ._tmp_graph_def = FoldBatchNormNodes (self ._tmp_graph_def ).do_transform ()
174237 write_graph (self ._tmp_graph_def , self ._fp32_optimized_graph )
238+ self ._fp32_origin_graph = self ._tmp_graph_def
175239
176240 def _quantize_graph (self ):
177241 """quantize graph."""
@@ -199,32 +263,50 @@ def _insert_logging(self):
199263 ops = ["RequantizationRange{}" .format ("PerChannel" if self .per_channel else "" )],
200264 message = "__requant_min_max:" ).do_transformation ()
201265 InsertLogging (self ._tmp_graph_def , ops = ["Min" ], message = "__min:" ).do_transformation ()
202- InsertLogging (self ._tmp_graph_def , ops = ["Max" ], message = "__max:" ).do_transformation ()
266+ InsertLogging (self ._tmp_graph_def , ops = ["Max" ],
267+ message = "__max:" ).do_transformation ()
268+ # InsertLogging(
269+ # self._tmp_graph_def,
270+ # ops=["QuantizedConv2DWithBiasAndRelu",
271+ # "QuantizedConv2DWithBias"
272+ # ],
273+ # message="__KL:",
274+ # summarize=-1).do_transformation()
275+
203276 write_graph (self ._tmp_graph_def , self ._int8_logged_graph )
204277 self ._tmp_graph_def .CopyFrom (int8_dynamic_range_graph_def )
205278
206- def _generate_calibration_data (self ):
279+ def _generate_calibration_data (self , graph , output , enable_kl_algo = False ):
207280 cmd = self .gen_calib_data_cmds
208- cmd = cmd .format (self ._int8_logged_graph )
209- f = open (self ._requant_min_max_log , 'w' , buffering = 1 )
210- p = subprocess .Popen (shlex .split (cmd ), stderr = subprocess .STDOUT , stdout = subprocess .PIPE )
211- try :
212- for line in p .stdout :
213- line_str = line .decode (sys .stdout .encoding )
214- sys .stdout .write (line_str )
215- f .write (line_str )
216- p .communicate ()
217- except Exception :
218- p .kill ()
219- p .wait ()
220- raise
221- if p .poll ():
222- raise SystemExit ('ERROR generating calibration data, command: \n {}' .format (cmd ))
223-
224- def _freeze_requantization_ranges (self ):
225- self ._tmp_graph_def = freeze_max (self ._tmp_graph_def , self ._requant_min_max_log )
226- self ._tmp_graph_def = freeze_min (self ._tmp_graph_def , self ._requant_min_max_log )
227- self ._tmp_graph_def = freeze_requantization_range (self ._tmp_graph_def , self ._requant_min_max_log )
281+ cmd = cmd .format (graph )
282+ p = subprocess .Popen (shlex .split (cmd ),
283+ stderr = subprocess .STDOUT ,
284+ stdout = subprocess .PIPE )
285+ while p .poll () is None :
286+ line = p .stdout .readline ().strip ().decode ()
287+ if line and line .startswith (';' ):
288+ if not enable_kl_algo :
289+ output .append (line )
290+
291+ if enable_kl_algo and line .rsplit (':' )[0 ] in self ._kl_keys :
292+ fp32_data = get_all_fp32_data (line .rsplit (':' )[- 1 ])
293+ key = self ._print_node_mapping [line [1 :].split ('__print' )[0 ]] + '_eightbit_requant_range'
294+ if key not in self ._kl_op_dict :
295+ self ._kl_op_dict [key ] = get_tensor_histogram (fp32_data )
296+ else :
297+ self ._kl_op_dict [key ] = combine_histogram (self ._kl_op_dict [key ], fp32_data )
298+
299+ def _freeze_requantization_ranges (self , additional_data = None ):
300+ use_moving_average = self .algo == "MA"
301+ self ._tmp_graph_def = freeze_max (self ._tmp_graph_def ,
302+ self ._calibration_data ,
303+ use_moving_average )
304+ self ._tmp_graph_def = freeze_min (self ._tmp_graph_def ,
305+ self ._calibration_data ,
306+ use_moving_average )
307+ self ._tmp_graph_def = freeze_requantization_range (
308+ self ._tmp_graph_def , self ._calibration_data , use_moving_average ,
309+ additional_data )
228310 if self .debug :
229311 write_graph (self ._tmp_graph_def , self ._int8_frozen_range_graph )
230312
@@ -256,5 +338,3 @@ def _post_clean(self):
256338 """
257339 if gfile .Exists (self ._int8_logged_graph ):
258340 os .remove (self ._int8_logged_graph )
259- if gfile .Exists (self ._requant_min_max_log ):
260- os .remove (self ._requant_min_max_log )
0 commit comments