@@ -277,6 +277,109 @@ def get_op_nodes_not_followed_by_specific_op(model, op1, op2):
277
277
278
278
return not_selected_op1_nodes
279
279
280
def custom_write_calibration_table(calibration_cache, dir="."):
    """
    Helper function to write calibration table to files.

    Three files are written into ``dir``:

    * ``calibration.json``        - JSON dump of ``calibration_cache``.
    * ``calibration.flatbuffers`` - FlatBuffers ``TrtTable`` of key/value strings.
    * ``calibration.cache``       - plain text, one ``"<key> <value>"`` line per entry.

    For the FlatBuffers and plain-text outputs, each entry's value is
    ``str(max(highest, lowest))`` taken from the entry's ``to_dict()`` result;
    a missing ``"highest"``/``"lowest"`` is treated as 0.

    Args:
        calibration_cache: Mapping of tensor name to an object exposing
            ``to_dict()`` (e.g. ``TensorData`` or a dict wrapper).
        dir: Output directory. The name shadows the builtin but is kept
            because callers may pass it by keyword.
    """

    import json
    import logging
    import os

    import flatbuffers
    import numpy as np

    import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
    import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable
    from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData

    logging.info("calibration cache: %s", calibration_cache)

    class MyEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, (TensorData, TensorsData)):
                return obj.to_dict()
            if isinstance(obj, np.ndarray):
                return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
            if isinstance(obj, CalibrationMethod):
                return {"CLS": obj.__class__.__name__, "value": str(obj)}
            # Duck-type any other wrapper that exposes to_dict(). The previous
            # isinstance check named a class that is only defined locally in
            # another function, so it raised NameError as soon as it was reached.
            to_dict = getattr(obj, "to_dict", None)
            if callable(to_dict):
                return to_dict()
            return json.JSONEncoder.default(self, obj)

    json_data = json.dumps(calibration_cache, cls=MyEncoder)

    with open(os.path.join(dir, "calibration.json"), "w") as file:
        file.write(json_data)  # use `json.loads` to do the reverse

    def _as_float(value):
        # numpy scalars / 0-d arrays expose .item(); plain numbers do not.
        return float(value.item()) if hasattr(value, "item") else float(value)

    # Compute each (key, value-string) pair once; both the FlatBuffers table
    # and the plain-text cache below are written from this list.
    zero = np.array(0)
    entries = []
    for key in sorted(calibration_cache):
        d_values = calibration_cache[key].to_dict()
        highest_val = _as_float(d_values.get("highest", zero))
        lowest_val = _as_float(d_values.get("lowest", zero))
        entries.append((key, str(max(highest_val, lowest_val))))

    # Serialize data using FlatBuffers
    builder = flatbuffers.Builder(1024)
    key_value_list = []

    for key, value in entries:
        # Strings must be created before the table that references them is started.
        flat_key = builder.CreateString(key)
        flat_value = builder.CreateString(value)

        KeyValue.KeyValueStart(builder)
        KeyValue.KeyValueAddKey(builder, flat_key)
        KeyValue.KeyValueAddValue(builder, flat_value)
        key_value_list.append(KeyValue.KeyValueEnd(builder))

    TrtTable.TrtTableStartDictVector(builder, len(key_value_list))
    for key_value in key_value_list:
        builder.PrependUOffsetTRelative(key_value)
    main_dict = builder.EndVector()

    TrtTable.TrtTableStart(builder)
    TrtTable.TrtTableAddDict(builder, main_dict)
    cal_table = TrtTable.TrtTableEnd(builder)

    builder.Finish(cal_table)
    buf = builder.Output()

    with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file:
        file.write(buf)

    # Deserialize data (for validation)
    # Environment values are always strings, so comparing against "1" suffices.
    if os.environ.get("QUANTIZATION_DEBUG") == "1":
        cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0)
        for i in range(cal_table.DictLength()):
            key_value = cal_table.Dict(i)
            logging.info(key_value.Key())
            logging.info(key_value.Value())

    # write plain text
    with open(os.path.join(dir, "calibration.cache"), "w") as file:
        for key, value in entries:
            file.write(key + " " + value)
            file.write("\n")
280
383
281
384
def parse_input_args ():
282
385
parser = argparse .ArgumentParser ()
@@ -553,8 +656,42 @@ def output_run_config(flags, samples):
553
656
for k , v in compute_range .data .items ():
554
657
json_compute_range [k ] = (float (v .range_value [0 ]), float (v .range_value [1 ]))
555
658
659
+ print ("Writing calibration table" )
660
+ try :
661
+ write_calibration_table (json_compute_range )
662
+ except AttributeError as e :
663
class TensorDataWrapper:
    """Thin holder that exposes a plain dict through a ``to_dict()`` method."""

    def __init__(self, data_dict):
        # Stored by reference; no copy is made.
        self.data_dict = data_dict

    def __repr__(self):
        return f"{self.data_dict!r}"

    def to_dict(self):
        """Return the wrapped dictionary itself."""
        return self.data_dict

    def __serializable__(self):
        """Return the wrapped dictionary (same object as ``to_dict()``)."""
        payload = self.data_dict
        return payload
675
+
676
+ calibration_data = {}
677
+ for k , v in compute_range .data .items ():
678
+ if hasattr (v , 'to_dict' ):
679
+ tensor_dict = v .to_dict ()
680
+ processed_dict = {}
681
+ for dk , dv in tensor_dict .items ():
682
+ if isinstance (dv , np .ndarray ):
683
+ processed_dict [dk ] = dv .item () if dv .size == 1 else dv .tolist ()
684
+ elif isinstance (dv , np .number ):
685
+ processed_dict [dk ] = dv .item ()
686
+ else :
687
+ processed_dict [dk ] = dv
688
+ calibration_data [k ] = TensorDataWrapper (processed_dict )
689
+ else :
690
+ calibration_data [k ] = v
691
+
692
+ print ("Using custom calibration table function" )
693
+ custom_write_calibration_table (calibration_data )
556
694
557
- write_calibration_table (json_compute_range )
558
695
print ("Calibration is done. Calibration cache is saved to calibration.json" )
559
696
560
697
model_quants = model_quants + "_int8"
0 commit comments