Skip to content

Commit 1b1c114

Browse files
committed
Fix 'tuple' object has no attribute 'to_dict' for bert
Use custom_write_calibration_table for migraphx
1 parent afba506 commit 1b1c114

File tree

1 file changed

+138
-1
lines changed

1 file changed

+138
-1
lines changed

quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py

+138-1
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,109 @@ def get_op_nodes_not_followed_by_specific_op(model, op1, op2):
277277

278278
return not_selected_op1_nodes
279279

280+
def custom_write_calibration_table(calibration_cache, dir="."):
    """
    Helper function to write calibration table to files.

    Three files are written into ``dir``:

    * ``calibration.json``        - JSON dump of ``calibration_cache``.
    * ``calibration.flatbuffers`` - a ``TrtTable`` whose entries map each key
      to ``str(max(highest, lowest))``.
    * ``calibration.cache``       - plain text, one ``"<key> <value>"`` line
      per key, using the same value as the flatbuffers table.

    Keys are emitted in sorted order in the flatbuffers and plain-text files.

    Args:
        calibration_cache: mapping of string key (presumably a tensor name -
            confirm against the caller) to a calibration entry. An entry may
            be an object exposing ``to_dict()``, a mapping with optional
            ``"highest"`` / ``"lowest"`` items (missing items count as 0), or
            a two-element ``(lowest, highest)`` pair.
        dir: directory the files are written to. The name shadows the
            builtin but is kept so keyword callers keep working.
    """
    import json
    import logging

    import flatbuffers
    import numpy as np

    import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
    import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable
    from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData

    logging.info("calibration cache: %s", calibration_cache)

    class MyEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, (TensorData, TensorsData)):
                return obj.to_dict()
            if isinstance(obj, np.ndarray):
                return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
            if isinstance(obj, np.generic):
                # NumPy scalars are not JSON serializable as-is.
                return obj.item()
            if isinstance(obj, CalibrationMethod):
                return {"CLS": obj.__class__.__name__, "value": str(obj)}
            # Duck-typed fallback for wrapper objects defined by callers
            # (avoids depending on a class name that only exists elsewhere).
            if callable(getattr(obj, "to_dict", None)):
                return obj.to_dict()
            return super().default(obj)

    json_data = json.dumps(calibration_cache, cls=MyEncoder)

    with open(os.path.join(dir, "calibration.json"), "w", encoding="utf-8") as file:
        file.write(json_data)  # use `json.loads` to do the reverse

    # Compute every "key -> value" pair once; both the flatbuffers table and
    # the plain-text cache below are written from this list.
    entries = [
        (key, str(_calibration_entry_max(calibration_cache[key])))
        for key in sorted(calibration_cache)
    ]

    # Serialize data using FlatBuffers
    builder = flatbuffers.Builder(1024)
    key_value_list = []

    for key, value in entries:
        # Strings must be created before the table that references them is started.
        flat_key = builder.CreateString(key)
        flat_value = builder.CreateString(value)

        KeyValue.KeyValueStart(builder)
        KeyValue.KeyValueAddKey(builder, flat_key)
        KeyValue.KeyValueAddValue(builder, flat_value)
        key_value_list.append(KeyValue.KeyValueEnd(builder))

    TrtTable.TrtTableStartDictVector(builder, len(key_value_list))
    for key_value in key_value_list:
        builder.PrependUOffsetTRelative(key_value)
    main_dict = builder.EndVector()

    TrtTable.TrtTableStart(builder)
    TrtTable.TrtTableAddDict(builder, main_dict)
    cal_table = TrtTable.TrtTableEnd(builder)

    builder.Finish(cal_table)
    buf = builder.Output()

    with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file:
        file.write(buf)

    # Deserialize data (for validation); environment values are always strings.
    if os.environ.get("QUANTIZATION_DEBUG") == "1":
        cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0)
        for i in range(cal_table.DictLength()):
            key_value = cal_table.Dict(i)
            logging.info(key_value.Key())
            logging.info(key_value.Value())

    # write plain text
    with open(os.path.join(dir, "calibration.cache"), "w", encoding="utf-8") as file:
        file.writelines(f"{key} {value}\n" for key, value in entries)


def _calibration_bound_to_float(bound):
    """Convert one bound (NumPy array/scalar or plain number) to a Python float."""
    return float(bound.item() if hasattr(bound, "item") else bound)


def _calibration_entry_max(entry):
    """Return ``max(highest, lowest)`` of one calibration entry as a float."""
    if callable(getattr(entry, "to_dict", None)):
        entry = entry.to_dict()

    if hasattr(entry, "get"):
        # Mapping form: missing bounds default to 0.
        highest = entry.get("highest", 0)
        lowest = entry.get("lowest", 0)
    else:
        # Plain pair form, e.g. the (lowest, highest) tuples built by callers.
        lowest, highest = entry

    return max(_calibration_bound_to_float(highest), _calibration_bound_to_float(lowest))
280383

281384
def parse_input_args():
282385
parser = argparse.ArgumentParser()
@@ -553,8 +656,42 @@ def output_run_config(flags, samples):
553656
for k, v in compute_range.data.items():
554657
json_compute_range[k] = (float(v.range_value[0]), float(v.range_value[1]))
555658

659+
print("Writing calibration table")
660+
try:
661+
write_calibration_table(json_compute_range)
662+
except AttributeError as e:
663+
class TensorDataWrapper:
664+
def __init__(self, data_dict):
665+
self.data_dict = data_dict
666+
667+
def to_dict(self):
668+
return self.data_dict
669+
670+
def __repr__(self):
671+
return repr(self.data_dict)
672+
673+
def __serializable__(self):
674+
return self.data_dict
675+
676+
calibration_data = {}
677+
for k, v in compute_range.data.items():
678+
if hasattr(v, 'to_dict'):
679+
tensor_dict = v.to_dict()
680+
processed_dict = {}
681+
for dk, dv in tensor_dict.items():
682+
if isinstance(dv, np.ndarray):
683+
processed_dict[dk] = dv.item() if dv.size == 1 else dv.tolist()
684+
elif isinstance(dv, np.number):
685+
processed_dict[dk] = dv.item()
686+
else:
687+
processed_dict[dk] = dv
688+
calibration_data[k] = TensorDataWrapper(processed_dict)
689+
else:
690+
calibration_data[k] = v
691+
692+
print("Using custom calibration table function")
693+
custom_write_calibration_table(calibration_data)
556694

557-
write_calibration_table(json_compute_range)
558695
print("Calibration is done. Calibration cache is saved to calibration.json")
559696

560697
model_quants = model_quants + "_int8"

0 commit comments

Comments
 (0)