@@ -155,114 +155,144 @@ def define_tile(
155155 )
156156
157157 tensor_1_candidate_for_transpositions = list (itertools .permutations (range (len (input_tensor_1 .shape ))))
158- tensor_2_candidate_for_transpositions = list ( itertools . permutations ( range ( len ( input_tensor_2 ))))
158+ tensor_2_candidate_for_transpositions = None
159159
160- for tensor_1_candidate_for_transposition in tensor_1_candidate_for_transpositions :
161- for tensor_2_candidate_for_transposition in tensor_2_candidate_for_transpositions :
162- try :
163- # Build TF dummy model
164- input_1 = tf_keras .Input (
165- shape = validation_data_1 .shape [1 :],
166- batch_size = validation_data_1 .shape [0 ] \
167- if isinstance (validation_data_1 .shape [0 ], int ) else None ,
168- name = 'dummy_input_1' ,
169- dtype = validation_data_1 .dtype ,
170- )
171- input_2 = validation_data_2
172- dummy_tile = define_tile (
173- target_input_tensor_1 = input_1 ,
174- target_perm_1 = list (tensor_1_candidate_for_transposition ),
175- target_input_tensor_2 = input_2 ,
176- target_perm_2 = list (tensor_2_candidate_for_transposition ),
177- target_name = graph_node .name ,
178- ** kwargs
179- )
180- # Verify that the output shape matches that of ONNX
181- # If the combination of each value of a dimension is not correct,
182- # invalidate the normal processing judgment.
183- onnx_output_shape_prod = np .prod ([dim if not isinstance (dim , str ) else - 1 for dim in onnx_output_shape ])
184- tile_output_shapes = list (dummy_tile .shape )
185- tile_output_shape_prod = np .prod ([dim if dim is not None else - 1 for dim in tile_output_shapes ])
186- if onnx_output_shape_prod != tile_output_shape_prod :
187- del input_1
188- del input_2
189- del dummy_tile
190- continue
191-
192- # Perform simple accuracy verification
193- # Terminate when the error is less than 1e-3
194- if onnx_tensor_infos :
195- try :
196- # Search for the axis with the smallest error
197- val_model = tf_keras .Model (
198- inputs = [
199- input_1 ,
200- ],
201- outputs = [
202- dummy_tile ,
203- ],
204- )
160+ if isinstance (input_tensor_2 , int ):
161+ tensor_2_candidate_for_transpositions = list (itertools .permutations (range (input_tensor_2 )))
162+ elif isinstance (input_tensor_2 , np .ndarray ) and hasattr (input_tensor_2 , '__len__' ):
163+ tiles_tensor_length = len (input_tensor_2 )
164+ if tiles_tensor_length > 1 :
165+ tensor_2_candidate_for_transpositions = list (itertools .permutations (range (len (input_tensor_2 ))))
166+ else :
167+ tensor_2_candidate_for_transpositions = list (itertools .permutations (input_tensor_2 ))
168+ elif tf_keras .backend .is_keras_tensor (input_tensor_2 ) and hasattr (input_tensor_2 .shape , '__len__' ):
169+ tiles_tensor_length = len (input_tensor_2 .shape )
170+ if tiles_tensor_length > 1 :
171+ tensor_2_candidate_for_transpositions = list (itertools .permutations (range (tiles_tensor_length )))
172+ else :
173+ # Dynamic Tensor
174+ pass
175+ else :
176+ # Unknown
177+ pass
205178
206- # TF dummy inference
207- tf_tensor_infos : Dict [Any ] = dummy_tf_inference (
208- model = val_model ,
209- inputs = [
210- input_1 ,
211- ],
212- verification_datas = [
213- validation_data_1 ,
214- ],
215- )
179+ if tensor_2_candidate_for_transpositions is not None :
180+ for tensor_1_candidate_for_transposition in tensor_1_candidate_for_transpositions :
181+ for tensor_2_candidate_for_transposition in tensor_2_candidate_for_transpositions :
182+ try :
183+ # Build TF dummy model
184+ input_1 = tf_keras .Input (
185+ shape = validation_data_1 .shape [1 :],
186+ batch_size = validation_data_1 .shape [0 ] \
187+ if isinstance (validation_data_1 .shape [0 ], int ) else None ,
188+ name = 'dummy_input_1' ,
189+ dtype = validation_data_1 .dtype ,
190+ )
191+ input_2 = validation_data_2
192+ dummy_tile = define_tile (
193+ target_input_tensor_1 = input_1 ,
194+ target_perm_1 = list (tensor_1_candidate_for_transposition ),
195+ target_input_tensor_2 = input_2 ,
196+ target_perm_2 = list (tensor_2_candidate_for_transposition ),
197+ target_name = graph_node .name ,
198+ ** kwargs
199+ )
200+ # Verify that the output shape matches that of ONNX
201+ # If the combination of each value of a dimension is not correct,
202+ # invalidate the normal processing judgment.
203+ onnx_output_shape_prod = np .prod ([dim if not isinstance (dim , str ) else - 1 for dim in onnx_output_shape ])
204+ tile_output_shapes = list (dummy_tile .shape )
205+ tile_output_shape_prod = np .prod ([dim if dim is not None else - 1 for dim in tile_output_shapes ])
206+ if onnx_output_shape_prod != tile_output_shape_prod :
216207 del input_1
217208 del input_2
218209 del dummy_tile
219- del val_model
210+ continue
211+
212+ # Perform simple accuracy verification
213+ # Terminate when the error is less than 1e-3
214+ if onnx_tensor_infos :
215+ try :
216+ # Search for the axis with the smallest error
217+ val_model = tf_keras .Model (
218+ inputs = [
219+ input_1 ,
220+ ],
221+ outputs = [
222+ dummy_tile ,
223+ ],
224+ )
220225
221- # Validation
222- onnx_tf_output_pairs = {
223- (oi [0 ], ti [0 ]): (oi [1 ], ti [1 ]) \
224- for oi , ti in zip (onnx_tensor_infos .items (), tf_tensor_infos .items ())
225- }
226- """
227- check_results: Dict[str, List[np.ndarray, int, float|int]]
228- {
229- onnx_output_name: [
230- onnx_tensor,
231- matched_flg, <--- 0: Unmatched, 1: Matched, 2: Skipped (Deleted or Shape Unmatched)
232- max_abs_err,
233- ]
226+ # TF dummy inference
227+ tf_tensor_infos : Dict [Any ] = dummy_tf_inference (
228+ model = val_model ,
229+ inputs = [
230+ input_1 ,
231+ ],
232+ verification_datas = [
233+ validation_data_1 ,
234+ ],
235+ )
236+ del input_1
237+ del input_2
238+ del dummy_tile
239+ del val_model
240+
241+ # Validation
242+ onnx_tf_output_pairs = {
243+ (oi [0 ], ti [0 ]): (oi [1 ], ti [1 ]) \
244+ for oi , ti in zip (onnx_tensor_infos .items (), tf_tensor_infos .items ())
234245 }
235- """
236- check_results = onnx_tf_tensor_validation (
237- output_pairs = onnx_tf_output_pairs ,
238- rtol = 0.0 ,
239- atol = 0.0 ,
240- )
241- result_err = sum ([val [2 ] for val in check_results .values ()])
242- if result_err < min_abs_err :
243- min_abs_err = result_err
244- min_abs_err_perm_1 = list (tensor_1_candidate_for_transposition )
245- min_abs_err_perm_2 = list (tensor_2_candidate_for_transposition )
246- if min_abs_err < 1e-3 :
247- break
248- except Exception as ex1 :
249- pass
250- except Exception as ex2 :
251- pass
252- else :
253- continue
254- break
246+ """
247+ check_results: Dict[str, List[np.ndarray, int, float|int]]
248+ {
249+ onnx_output_name: [
250+ onnx_tensor,
251+ matched_flg, <--- 0: Unmatched, 1: Matched, 2: Skipped (Deleted or Shape Unmatched)
252+ max_abs_err,
253+ ]
254+ }
255+ """
256+ check_results = onnx_tf_tensor_validation (
257+ output_pairs = onnx_tf_output_pairs ,
258+ rtol = 0.0 ,
259+ atol = 0.0 ,
260+ )
261+ result_err = sum ([val [2 ] for val in check_results .values ()])
262+ if result_err < min_abs_err :
263+ min_abs_err = result_err
264+ min_abs_err_perm_1 = list (tensor_1_candidate_for_transposition )
265+ min_abs_err_perm_2 = list (tensor_2_candidate_for_transposition )
266+ if min_abs_err < 1e-3 :
267+ break
268+ except Exception as ex1 :
269+ pass
270+ except Exception as ex2 :
271+ pass
272+ else :
273+ continue
274+ break
255275
256276 # Generation of TF OP
257- tf_layers_dict [graph_node_output .name ]['tf_node' ] = \
258- define_tile (
259- target_input_tensor_1 = input_tensor_1 ,
260- target_perm_1 = min_abs_err_perm_1 ,
261- target_input_tensor_2 = input_tensor_2 ,
262- target_perm_2 = min_abs_err_perm_2 ,
263- target_name = graph_node .name ,
264- ** kwargs
265- )
277+ if tensor_2_candidate_for_transpositions is not None :
278+ tf_layers_dict [graph_node_output .name ]['tf_node' ] = \
279+ define_tile (
280+ target_input_tensor_1 = input_tensor_1 ,
281+ target_perm_1 = min_abs_err_perm_1 ,
282+ target_input_tensor_2 = input_tensor_2 ,
283+ target_perm_2 = min_abs_err_perm_2 ,
284+ target_name = graph_node .name ,
285+ ** kwargs
286+ )
287+ else :
288+ # Dynamic Tensor
289+ tf_layers_dict [graph_node_output .name ]['tf_node' ] = \
290+ tf .tile (
291+ input = input_tensor_1 \
292+ if not isinstance (input_tensor_1 , np .ndarray ) \
293+ else tf .convert_to_tensor (input_tensor_1 ),
294+ multiples = tf .convert_to_tensor ([dim for dim in input_tensor_2 ]),
295+ )
266296
267297 # Post-process transpose
268298 tf_layers_dict [graph_node_output .name ]['tf_node' ] = post_process_transpose (
0 commit comments