Skip to content

Commit 080191e

Browse files
authored
Merge pull request #651 from PINTO0309/sm_to_cf
Workaround to the problem of TensorFlow corrupting the FlatBuffer input/output order during INT8 quantization.
2 parents 46e19fe + 58d3415 commit 080191e

File tree

4 files changed

+73
-7
lines changed

4 files changed

+73
-7
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,15 +293,15 @@ Video speed is adjusted approximately 50 times slower than actual speed.
293293
docker run --rm -it \
294294
-v `pwd`:/workdir \
295295
-w /workdir \
296-
ghcr.io/pinto0309/onnx2tf:1.22.4
296+
ghcr.io/pinto0309/onnx2tf:1.22.5
297297

298298
or
299299

300300
# Authentication is not required for pulls from Docker Hub.
301301
docker run --rm -it \
302302
-v `pwd`:/workdir \
303303
-w /workdir \
304-
docker.io/pinto0309/onnx2tf:1.22.4
304+
docker.io/pinto0309/onnx2tf:1.22.5
305305

306306
or
307307

onnx2tf/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from onnx2tf.onnx2tf import convert, main
22

3-
__version__ = '1.22.4'
3+
__version__ = '1.22.5'

onnx2tf/onnx2tf.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,12 @@ def sanitizing(node):
856856
onnx_graph_output_names: List[str] = [
857857
outputop.name for outputop in graph.outputs
858858
]
859+
onnx_graph_input_shapes: List[List[int | str]] = [
860+
inputop.shape for inputop in graph.inputs
861+
]
862+
onnx_graph_output_shapes: List[List[int | str]] = [
863+
outputop.shape for outputop in graph.outputs
864+
]
859865

860866
# Inputs
861867
for graph_input in graph.inputs:
@@ -1298,6 +1304,8 @@ def sanitizing(node):
12981304
tflite_file_name=f'{output_file_name}_float32.tflite',
12991305
onnx_input_names=onnx_graph_input_names,
13001306
onnx_output_names=onnx_graph_output_names,
1307+
onnx_graph_input_shapes=onnx_graph_input_shapes,
1308+
onnx_graph_output_shapes=onnx_graph_output_shapes,
13011309
)
13021310
if output_weights:
13031311
weights_export(
@@ -1325,6 +1333,8 @@ def sanitizing(node):
13251333
tflite_file_name=f'{output_file_name}_float16.tflite',
13261334
onnx_input_names=onnx_graph_input_names,
13271335
onnx_output_names=onnx_graph_output_names,
1336+
onnx_graph_input_shapes=onnx_graph_input_shapes,
1337+
onnx_graph_output_shapes=onnx_graph_output_shapes,
13281338
)
13291339
if output_weights:
13301340
weights_export(
@@ -1372,6 +1382,8 @@ def sanitizing(node):
13721382
tflite_file_name=f'{output_file_name}_dynamic_range_quant.tflite',
13731383
onnx_input_names=onnx_graph_input_names,
13741384
onnx_output_names=onnx_graph_output_names,
1385+
onnx_graph_input_shapes=onnx_graph_input_shapes,
1386+
onnx_graph_output_shapes=onnx_graph_output_shapes,
13751387
)
13761388
if output_weights:
13771389
weights_export(
@@ -1501,6 +1513,8 @@ def representative_dataset_gen():
15011513
tflite_file_name=f'{output_file_name}_integer_quant.tflite',
15021514
onnx_input_names=onnx_graph_input_names,
15031515
onnx_output_names=onnx_graph_output_names,
1516+
onnx_graph_input_shapes=onnx_graph_input_shapes,
1517+
onnx_graph_output_shapes=onnx_graph_output_shapes,
15041518
)
15051519
if output_weights:
15061520
weights_export(
@@ -1536,6 +1550,8 @@ def representative_dataset_gen():
15361550
tflite_file_name=f'{output_file_name}_full_integer_quant.tflite',
15371551
onnx_input_names=onnx_graph_input_names,
15381552
onnx_output_names=onnx_graph_output_names,
1553+
onnx_graph_input_shapes=onnx_graph_input_shapes,
1554+
onnx_graph_output_shapes=onnx_graph_output_shapes,
15391555
)
15401556
if output_weights:
15411557
weights_export(
@@ -1572,6 +1588,8 @@ def representative_dataset_gen():
15721588
tflite_file_name=f'{output_file_name}_integer_quant_with_int16_act.tflite',
15731589
onnx_input_names=onnx_graph_input_names,
15741590
onnx_output_names=onnx_graph_output_names,
1591+
onnx_graph_input_shapes=onnx_graph_input_shapes,
1592+
onnx_graph_output_shapes=onnx_graph_output_shapes,
15751593
)
15761594
info(Color.GREEN(f'INT8 Quantization with int16 activations tflite output complete!'))
15771595
except RuntimeError as ex:
@@ -1603,6 +1621,8 @@ def representative_dataset_gen():
16031621
tflite_file_name=f'{output_file_name}_full_integer_quant_with_int16_act.tflite',
16041622
onnx_input_names=onnx_graph_input_names,
16051623
onnx_output_names=onnx_graph_output_names,
1624+
onnx_graph_input_shapes=onnx_graph_input_shapes,
1625+
onnx_graph_output_shapes=onnx_graph_output_shapes,
16061626
)
16071627
info(Color.GREEN(f'Full INT8 Quantization with int16 activations tflite output complete!'))
16081628
except RuntimeError as ex:

onnx2tf/utils/common_functions.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4544,6 +4544,8 @@ def rewrite_tflite_inout_opname(
45444544
tflite_file_name: str,
45454545
onnx_input_names: List[str],
45464546
onnx_output_names: List[str],
4547+
onnx_graph_input_shapes: List[List[int |str]],
4548+
onnx_graph_output_shapes: List[List[int |str]],
45474549
):
45484550
"""Rewrite the input/output OP name of tflite to the input/output OP name of ONNX.
45494551
Pre-installation of flatc is required.
@@ -4561,6 +4563,12 @@ def rewrite_tflite_inout_opname(
45614563
45624564
onnx_output_names: List[str]
45634565
List of ONNX output OP names
4566+
4567+
onnx_graph_input_shapes: List[List[int |str]]
4568+
List of ONNX input OP shapes
4569+
4570+
onnx_graph_output_shapes: List[List[int |str]]
4571+
List of ONNX output OP shapes
45644572
"""
45654573
try:
45664574
# Check to see if flatc is installed
@@ -4611,12 +4619,50 @@ def rewrite_tflite_inout_opname(
46114619
flat_output_nums: List[int] = flat_subgraphs.outputs
46124620
flat_input_infos = [flat_tensors[idx] for idx in flat_input_nums]
46134621
flat_output_infos = [flat_tensors[idx] for idx in flat_output_nums]
4622+
4623+
# Determination of the number of inputs/outputs of the same shape
4624+
# Correct name discrepancies based on shape if multiple inputs/outputs shapes do not overlap
4625+
# However, if there are inputs/outputs containing undefined dimensions,
4626+
# workaround is skipped because correction is not possible.
4627+
# https://github.com/PINTO0309/onnx2tf/issues/650
4628+
inputs_second_dim_elements = [
4629+
tuple(onnx_graph_input_shape) \
4630+
for onnx_graph_input_shape in onnx_graph_input_shapes
4631+
]
4632+
inputs_has_duplicates = len(inputs_second_dim_elements) != len(set(inputs_second_dim_elements))
4633+
inputs_has_undefined_dim = any(isinstance(item, str) for onnx_graph_input_shape in onnx_graph_input_shapes for item in onnx_graph_input_shape)
4634+
4635+
outputs_second_dim_elements = [
4636+
tuple(onnx_graph_output_shape) \
4637+
for onnx_graph_output_shape in onnx_graph_output_shapes
4638+
]
4639+
outputs_has_duplicates = len(outputs_second_dim_elements) != len(set(outputs_second_dim_elements))
4640+
outputs_has_undefined_dim = any(isinstance(item, str) for onnx_graph_output_shape in onnx_graph_output_shapes for item in onnx_graph_output_shape)
4641+
46144642
# INPUT
4615-
for idx, flat_input_info in enumerate(flat_input_infos):
4616-
flat_input_info.name = onnx_input_names[idx]
4643+
if not inputs_has_duplicates and not inputs_has_undefined_dim:
4644+
for onnx_input_name, onnx_input_shape in zip(onnx_input_names, onnx_graph_input_shapes):
4645+
for flat_input_info in flat_input_infos:
4646+
if np.prod(onnx_input_shape) == np.prod(list(flat_input_info.shape)):
4647+
flat_input_info.name = onnx_input_name
4648+
break
4649+
else:
4650+
for idx, flat_input_info in enumerate(flat_input_infos):
4651+
flat_input_info.name = onnx_input_names[idx]
4652+
46174653
# OUTPUT
4618-
for idx, flat_output_info in enumerate(flat_output_infos):
4619-
flat_output_info.name = onnx_output_names[idx]
4654+
if not outputs_has_duplicates and not outputs_has_undefined_dim:
4655+
for onnx_output_name, onnx_output_shape in zip(onnx_output_names, onnx_graph_output_shapes):
4656+
for flat_output_info in flat_output_infos:
4657+
if np.prod(onnx_output_shape) == np.prod(list(flat_output_info.shape)):
4658+
flat_output_info.name = onnx_output_name
4659+
break
4660+
else:
4661+
for idx, flat_output_info in enumerate(flat_output_infos):
4662+
flat_output_info.name = onnx_output_names[idx]
4663+
4664+
if inputs_has_duplicates or inputs_has_undefined_dim or outputs_has_duplicates or outputs_has_undefined_dim:
4665+
warn('Carefully check the output .tflite as the order of input OP names and output OP names may have been corrupted by TensorFlow.')
46204666

46214667
# make signature_defs
46224668
"""

0 commit comments

Comments
 (0)