Skip to content

Fix 'tuple' object has no attribute 'to_dict' for bert #512

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 138 additions & 1 deletion quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,109 @@ def get_op_nodes_not_followed_by_specific_op(model, op1, op2):

return not_selected_op1_nodes

def custom_write_calibration_table(calibration_cache, dir="."):
    """
    Write a calibration table to ``dir`` in three formats.

    Files produced inside ``dir``:

    * ``calibration.json`` -- JSON dump of the whole cache.
    * ``calibration.flatbuffers`` -- a ``TrtTable`` holding one key/value
      string pair per cache entry.
    * ``calibration.cache`` -- plain text, one ``<key> <value>`` line per entry.

    For the FlatBuffers and plain-text outputs, the value stored for a key is
    ``str(max(highest, lowest))``, where both bounds are read from the entry's
    ``to_dict()`` result; a missing bound is treated as zero.

    Args:
        calibration_cache: Mapping-like object (must support ``keys()`` and
            item lookup) from string keys to entries exposing ``to_dict()``.
        dir: Output directory. The name shadows the ``dir`` builtin but is
            kept so existing keyword callers continue to work.

    Raises:
        AttributeError: If an entry does not provide ``to_dict()``.
    """

    import json
    import logging
    import os

    import flatbuffers
    import numpy as np

    import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
    import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable
    from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData

    logging.info("calibration cache: %s", calibration_cache)

    class MyEncoder(json.JSONEncoder):
        """JSON encoder aware of calibration entries, numpy arrays and enums."""

        def default(self, obj):
            if isinstance(obj, (TensorData, TensorsData)):
                return obj.to_dict()
            # Duck-type any other entry wrapper instead of naming a class that
            # is not defined in this scope (which could raise NameError here).
            to_dict = getattr(obj, "to_dict", None)
            if callable(to_dict):
                return to_dict()
            if isinstance(obj, np.ndarray):
                return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
            if isinstance(obj, CalibrationMethod):
                return {"CLS": obj.__class__.__name__, "value": str(obj)}
            return super().default(obj)

    json_data = json.dumps(calibration_cache, cls=MyEncoder)

    with open(os.path.join(dir, "calibration.json"), "w") as file:
        file.write(json_data)  # use `json.loads` to do the reverse

    zero = np.array(0)

    def _range_value(entry):
        """Return ``str(max(highest, lowest))`` for one cache entry."""
        bounds = entry.to_dict()
        extremes = []
        for name in ("highest", "lowest"):
            bound = bounds.get(name, zero)
            # numpy scalars/arrays expose .item(); plain numbers do not.
            scalar = bound.item() if hasattr(bound, "item") else bound
            extremes.append(float(scalar))
        return str(max(extremes))

    # Compute each value once; both the FlatBuffers and text outputs reuse it.
    table = [
        (key, _range_value(calibration_cache[key]))
        for key in sorted(calibration_cache.keys())
    ]

    # Serialize data using FlatBuffers
    builder = flatbuffers.Builder(1024)
    key_value_list = []

    for key, value in table:
        # Strings must be created before the table that refers to them starts.
        flat_key = builder.CreateString(key)
        flat_value = builder.CreateString(value)

        KeyValue.KeyValueStart(builder)
        KeyValue.KeyValueAddKey(builder, flat_key)
        KeyValue.KeyValueAddValue(builder, flat_value)
        key_value_list.append(KeyValue.KeyValueEnd(builder))

    TrtTable.TrtTableStartDictVector(builder, len(key_value_list))
    for key_value in key_value_list:
        builder.PrependUOffsetTRelative(key_value)
    main_dict = builder.EndVector()

    TrtTable.TrtTableStart(builder)
    TrtTable.TrtTableAddDict(builder, main_dict)
    cal_table = TrtTable.TrtTableEnd(builder)

    builder.Finish(cal_table)
    buf = builder.Output()

    with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file:
        file.write(buf)

    # Deserialize data (for validation). Environment values are always
    # strings, so comparing against "1" is sufficient.
    if os.environ.get("QUANTIZATION_DEBUG") == "1":
        cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0)
        for i in range(cal_table.DictLength()):
            key_value = cal_table.Dict(i)
            logging.info(key_value.Key())
            logging.info(key_value.Value())

    # write plain text
    with open(os.path.join(dir, "calibration.cache"), "w") as file:
        file.writelines(f"{key} {value}\n" for key, value in table)

def parse_input_args():
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -553,8 +656,42 @@ def output_run_config(flags, samples):
for k, v in compute_range.data.items():
json_compute_range[k] = (float(v.range_value[0]), float(v.range_value[1]))

print("Writing calibration table")
try:
write_calibration_table(json_compute_range)
except AttributeError as e:
class TensorDataWrapper:
    """Thin holder giving a plain dict the ``to_dict()`` accessor."""

    def __init__(self, data_dict):
        # Stored by reference; the dict is not copied.
        self.data_dict = data_dict

    def __repr__(self):
        return f"{self.data_dict!r}"

    def to_dict(self):
        """Return the wrapped dict itself."""
        return self.data_dict

    def __serializable__(self):
        """Return the wrapped dict itself (same object as ``to_dict()``)."""
        return self.data_dict

calibration_data = {}
for k, v in compute_range.data.items():
if hasattr(v, 'to_dict'):
tensor_dict = v.to_dict()
processed_dict = {}
for dk, dv in tensor_dict.items():
if isinstance(dv, np.ndarray):
processed_dict[dk] = dv.item() if dv.size == 1 else dv.tolist()
elif isinstance(dv, np.number):
processed_dict[dk] = dv.item()
else:
processed_dict[dk] = dv
calibration_data[k] = TensorDataWrapper(processed_dict)
else:
calibration_data[k] = v

print("Using custom calibration table function")
custom_write_calibration_table(calibration_data)

write_calibration_table(json_compute_range)
print("Calibration is done. Calibration cache is saved to calibration.json")

model_quants = model_quants + "_int8"
Expand Down
Loading