Skip to content

Commit 30f25f7

Browse files
authored
Merge pull request #623 from PINTO0309/fix_tile
[experimental] Support for dynamic `Tile`, dynamic `Reshape`
2 parents bf8e894 + bb4fe3f commit 30f25f7

File tree

4 files changed

+136
-104
lines changed

4 files changed

+136
-104
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,15 +270,15 @@ Video speed is adjusted approximately 50 times slower than actual speed.
270270
docker run --rm -it \
271271
-v `pwd`:/workdir \
272272
-w /workdir \
273-
ghcr.io/pinto0309/onnx2tf:1.20.8
273+
ghcr.io/pinto0309/onnx2tf:1.20.9
274274

275275
or
276276

277277
# Authentication is not required for pulls from Docker Hub.
278278
docker run --rm -it \
279279
-v `pwd`:/workdir \
280280
-w /workdir \
281-
docker.io/pinto0309/onnx2tf:1.20.8
281+
docker.io/pinto0309/onnx2tf:1.20.9
282282

283283
or
284284

onnx2tf/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from onnx2tf.onnx2tf import convert, main
22

3-
__version__ = '1.20.8'
3+
__version__ = '1.20.9'

onnx2tf/ops/Reshape.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,12 @@ def make_node(
8080

8181
# If Reshape's shape contains zeros, get the deformed shape from the output shape
8282
if isinstance(reshape_shape, list) and reshape_shape.count(0) > 0:
83-
new_shape = [-1 if isinstance(s, str) else int(s) for s in output_shape]
83+
before_tensor_shapes = tf.shape(tf_layers_dict[graph_node_input_1.name]['tf_node'])
84+
new_shape = [before_tensor_shapes[idx] if isinstance(s, str) else int(s) for idx, s in enumerate(output_shape)]
8485
reshape_shape = new_shape
8586
elif isinstance(reshape_shape, np.ndarray) and np.count_nonzero(reshape_shape == 0) > 0:
86-
new_shape = [-1 if isinstance(s, str) else int(s) for s in output_shape]
87+
before_tensor_shapes = tf.shape(tf_layers_dict[graph_node_input_1.name]['tf_node'])
88+
new_shape = [before_tensor_shapes[idx] if isinstance(s, str) else int(s) for idx, s in enumerate(output_shape)]
8789
reshape_shape = new_shape
8890

8991
onnx_tensor_infos_for_validation: Dict[str: np.ndarray] = kwargs['onnx_tensor_infos_for_validation']

onnx2tf/ops/Tile.py

Lines changed: 129 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -155,114 +155,144 @@ def define_tile(
155155
)
156156

157157
tensor_1_candidate_for_transpositions = list(itertools.permutations(range(len(input_tensor_1.shape))))
158-
tensor_2_candidate_for_transpositions = list(itertools.permutations(range(len(input_tensor_2))))
158+
tensor_2_candidate_for_transpositions = None
159159

160-
for tensor_1_candidate_for_transposition in tensor_1_candidate_for_transpositions:
161-
for tensor_2_candidate_for_transposition in tensor_2_candidate_for_transpositions:
162-
try:
163-
# Build TF dummy model
164-
input_1 = tf_keras.Input(
165-
shape=validation_data_1.shape[1:],
166-
batch_size=validation_data_1.shape[0] \
167-
if isinstance(validation_data_1.shape[0], int) else None,
168-
name='dummy_input_1',
169-
dtype=validation_data_1.dtype,
170-
)
171-
input_2 = validation_data_2
172-
dummy_tile = define_tile(
173-
target_input_tensor_1=input_1,
174-
target_perm_1=list(tensor_1_candidate_for_transposition),
175-
target_input_tensor_2=input_2,
176-
target_perm_2=list(tensor_2_candidate_for_transposition),
177-
target_name=graph_node.name,
178-
**kwargs
179-
)
180-
# Verify that the output shape matches that of ONNX
181-
# If the combination of each value of a dimension is not correct,
182-
# invalidate the normal processing judgment.
183-
onnx_output_shape_prod = np.prod([dim if not isinstance(dim, str) else -1 for dim in onnx_output_shape])
184-
tile_output_shapes = list(dummy_tile.shape)
185-
tile_output_shape_prod = np.prod([dim if dim is not None else -1 for dim in tile_output_shapes])
186-
if onnx_output_shape_prod != tile_output_shape_prod:
187-
del input_1
188-
del input_2
189-
del dummy_tile
190-
continue
191-
192-
# Perform simple accuracy verification
193-
# Terminate when the error is less than 1e-3
194-
if onnx_tensor_infos:
195-
try:
196-
# Search for the axis with the smallest error
197-
val_model = tf_keras.Model(
198-
inputs=[
199-
input_1,
200-
],
201-
outputs=[
202-
dummy_tile,
203-
],
204-
)
160+
if isinstance(input_tensor_2, int):
161+
tensor_2_candidate_for_transpositions = list(itertools.permutations(range(input_tensor_2)))
162+
elif isinstance(input_tensor_2, np.ndarray) and hasattr(input_tensor_2, '__len__'):
163+
tiles_tensor_length = len(input_tensor_2)
164+
if tiles_tensor_length > 1:
165+
tensor_2_candidate_for_transpositions = list(itertools.permutations(range(len(input_tensor_2))))
166+
else:
167+
tensor_2_candidate_for_transpositions = list(itertools.permutations(input_tensor_2))
168+
elif tf_keras.backend.is_keras_tensor(input_tensor_2) and hasattr(input_tensor_2.shape, '__len__'):
169+
tiles_tensor_length = len(input_tensor_2.shape)
170+
if tiles_tensor_length > 1:
171+
tensor_2_candidate_for_transpositions = list(itertools.permutations(range(tiles_tensor_length)))
172+
else:
173+
# Dynamic Tensor
174+
pass
175+
else:
176+
# Unknown
177+
pass
205178

206-
# TF dummy inference
207-
tf_tensor_infos: Dict[Any] = dummy_tf_inference(
208-
model=val_model,
209-
inputs=[
210-
input_1,
211-
],
212-
verification_datas=[
213-
validation_data_1,
214-
],
215-
)
179+
if tensor_2_candidate_for_transpositions is not None:
180+
for tensor_1_candidate_for_transposition in tensor_1_candidate_for_transpositions:
181+
for tensor_2_candidate_for_transposition in tensor_2_candidate_for_transpositions:
182+
try:
183+
# Build TF dummy model
184+
input_1 = tf_keras.Input(
185+
shape=validation_data_1.shape[1:],
186+
batch_size=validation_data_1.shape[0] \
187+
if isinstance(validation_data_1.shape[0], int) else None,
188+
name='dummy_input_1',
189+
dtype=validation_data_1.dtype,
190+
)
191+
input_2 = validation_data_2
192+
dummy_tile = define_tile(
193+
target_input_tensor_1=input_1,
194+
target_perm_1=list(tensor_1_candidate_for_transposition),
195+
target_input_tensor_2=input_2,
196+
target_perm_2=list(tensor_2_candidate_for_transposition),
197+
target_name=graph_node.name,
198+
**kwargs
199+
)
200+
# Verify that the output shape matches that of ONNX
201+
# If the combination of each value of a dimension is not correct,
202+
# invalidate the normal processing judgment.
203+
onnx_output_shape_prod = np.prod([dim if not isinstance(dim, str) else -1 for dim in onnx_output_shape])
204+
tile_output_shapes = list(dummy_tile.shape)
205+
tile_output_shape_prod = np.prod([dim if dim is not None else -1 for dim in tile_output_shapes])
206+
if onnx_output_shape_prod != tile_output_shape_prod:
216207
del input_1
217208
del input_2
218209
del dummy_tile
219-
del val_model
210+
continue
211+
212+
# Perform simple accuracy verification
213+
# Terminate when the error is less than 1e-3
214+
if onnx_tensor_infos:
215+
try:
216+
# Search for the axis with the smallest error
217+
val_model = tf_keras.Model(
218+
inputs=[
219+
input_1,
220+
],
221+
outputs=[
222+
dummy_tile,
223+
],
224+
)
220225

221-
# Validation
222-
onnx_tf_output_pairs = {
223-
(oi[0], ti[0]): (oi[1], ti[1]) \
224-
for oi, ti in zip(onnx_tensor_infos.items(), tf_tensor_infos.items())
225-
}
226-
"""
227-
check_results: Dict[str, List[np.ndarray, int, float|int]]
228-
{
229-
onnx_output_name: [
230-
onnx_tensor,
231-
matched_flg, <--- 0: Unmatched, 1: Matched, 2: Skipped (Deleted or Shape Unmatched)
232-
max_abs_err,
233-
]
226+
# TF dummy inference
227+
tf_tensor_infos: Dict[Any] = dummy_tf_inference(
228+
model=val_model,
229+
inputs=[
230+
input_1,
231+
],
232+
verification_datas=[
233+
validation_data_1,
234+
],
235+
)
236+
del input_1
237+
del input_2
238+
del dummy_tile
239+
del val_model
240+
241+
# Validation
242+
onnx_tf_output_pairs = {
243+
(oi[0], ti[0]): (oi[1], ti[1]) \
244+
for oi, ti in zip(onnx_tensor_infos.items(), tf_tensor_infos.items())
234245
}
235-
"""
236-
check_results = onnx_tf_tensor_validation(
237-
output_pairs=onnx_tf_output_pairs,
238-
rtol=0.0,
239-
atol=0.0,
240-
)
241-
result_err = sum([val[2] for val in check_results.values()])
242-
if result_err < min_abs_err:
243-
min_abs_err = result_err
244-
min_abs_err_perm_1 = list(tensor_1_candidate_for_transposition)
245-
min_abs_err_perm_2 = list(tensor_2_candidate_for_transposition)
246-
if min_abs_err < 1e-3:
247-
break
248-
except Exception as ex1:
249-
pass
250-
except Exception as ex2:
251-
pass
252-
else:
253-
continue
254-
break
246+
"""
247+
check_results: Dict[str, List[np.ndarray, int, float|int]]
248+
{
249+
onnx_output_name: [
250+
onnx_tensor,
251+
matched_flg, <--- 0: Unmatched, 1: Matched, 2: Skipped (Deleted or Shape Unmatched)
252+
max_abs_err,
253+
]
254+
}
255+
"""
256+
check_results = onnx_tf_tensor_validation(
257+
output_pairs=onnx_tf_output_pairs,
258+
rtol=0.0,
259+
atol=0.0,
260+
)
261+
result_err = sum([val[2] for val in check_results.values()])
262+
if result_err < min_abs_err:
263+
min_abs_err = result_err
264+
min_abs_err_perm_1 = list(tensor_1_candidate_for_transposition)
265+
min_abs_err_perm_2 = list(tensor_2_candidate_for_transposition)
266+
if min_abs_err < 1e-3:
267+
break
268+
except Exception as ex1:
269+
pass
270+
except Exception as ex2:
271+
pass
272+
else:
273+
continue
274+
break
255275

256276
# Generation of TF OP
257-
tf_layers_dict[graph_node_output.name]['tf_node'] = \
258-
define_tile(
259-
target_input_tensor_1=input_tensor_1,
260-
target_perm_1=min_abs_err_perm_1,
261-
target_input_tensor_2=input_tensor_2,
262-
target_perm_2=min_abs_err_perm_2,
263-
target_name=graph_node.name,
264-
**kwargs
265-
)
277+
if tensor_2_candidate_for_transpositions is not None:
278+
tf_layers_dict[graph_node_output.name]['tf_node'] = \
279+
define_tile(
280+
target_input_tensor_1=input_tensor_1,
281+
target_perm_1=min_abs_err_perm_1,
282+
target_input_tensor_2=input_tensor_2,
283+
target_perm_2=min_abs_err_perm_2,
284+
target_name=graph_node.name,
285+
**kwargs
286+
)
287+
else:
288+
# Dynamic Tensor
289+
tf_layers_dict[graph_node_output.name]['tf_node'] = \
290+
tf.tile(
291+
input=input_tensor_1 \
292+
if not isinstance(input_tensor_1, np.ndarray) \
293+
else tf.convert_to_tensor(input_tensor_1),
294+
multiples=tf.convert_to_tensor([dim for dim in input_tensor_2]),
295+
)
266296

267297
# Post-process transpose
268298
tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(

0 commit comments

Comments
 (0)