diff --git a/compiler/one-cmds/one-import-onnx b/compiler/one-cmds/one-import-onnx index 213e86e7d08..2d50ecc58b5 100644 --- a/compiler/one-cmds/one-import-onnx +++ b/compiler/one-cmds/one-import-onnx @@ -298,9 +298,7 @@ def _convert(args): logfile_path = os.path.realpath(args.output_path) + '.log' # get import onnx extension path - ext_path = None - if not _disable_ext(args): - ext_path = _check_ext() + ext_path = _check_ext() with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir: # save intermediate @@ -322,61 +320,92 @@ def _convert(args): os.path.splitext(basename)[0] + '~.onnx') onnx.save(onnx_model, fixed_path) + run_default_import = True + ext_alt_onnx_path = None if ext_path: # save onnx_model to temporary alt file basename = os.path.basename(getattr(args, 'input_path')) - alt_path = os.path.join(tmpdir, os.path.splitext(basename)[0] + '-alt.onnx') - onnx.save(onnx_model, alt_path) - - # call extension with options - ext_cmd = [ext_path] - if oneutils.is_valid_attr(args, 'unroll_rnn'): - ext_cmd.append('--unroll_rnn') - if oneutils.is_valid_attr(args, 'unroll_lstm'): - ext_cmd.append('--unroll_lstm') - if oneutils.is_valid_attr(args, 'experimental_disable_batchmatmul_unfold'): - ext_cmd.append('--experimental_disable_batchmatmul_unfold') - if oneutils.is_valid_attr(args, 'save_intermediate'): - ext_cmd.append('--save_intermediate') - if oneutils.is_valid_attr(args, 'keep_io_order'): - ext_cmd.append('--keep_io_order') - ext_cmd.append(alt_path) - ext_cmd.append(getattr(args, 'output_path')) - oneutils.run(ext_cmd, logfile=f) - return - - tf_savedmodel = onnx_tf.backend.prepare(onnx_model) - - savedmodel_name = os.path.splitext(os.path.basename( - args.output_path))[0] + '.savedmodel' - savedmodel_output_path = os.path.join(tmpdir, savedmodel_name) - tf_savedmodel.export_graph(savedmodel_output_path) - - # make a command to convert from tf to tflite - tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py') - tf2tfliteV2_output_name = os.path.splitext(os.path.basename( - args.output_path))[0] + '.tflite' - tf2tfliteV2_output_path = os.path.join(tmpdir, tf2tfliteV2_output_name) - - tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(args, tf2tfliteV2_path, - savedmodel_output_path, - tf2tfliteV2_output_path) - - f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode()) - - # convert tf to tflite - oneutils.run(tf2tfliteV2_cmd, logfile=f) - - # make a command to convert from tflite to circle - tflite2circle_path = os.path.join(dir_path, 'tflite2circle') - tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path, - tf2tfliteV2_output_path, - getattr(args, 'output_path')) - - f.write((' '.join(tflite2circle_cmd) + '\n').encode()) - - # convert tflite to circle - oneutils.run(tflite2circle_cmd, err_prefix="tflite2circle", logfile=f) + ext_alt_onnx_path = os.path.join(tmpdir, + os.path.splitext(basename)[0] + '-alt.onnx') + onnx.save(onnx_model, ext_alt_onnx_path) + + if _force_ext(args): + run_default_import = False + + res_conv = -1 + if run_default_import: + if _force_ext(args): + print( + "onnx-import-onnx: 'force_ext' is True, " + "but onnx-import-onnx-ext is not installed. " + "onnx-tf is used.", + flush=True) + # TODO split these to small functions + try: + tf_savedmodel = onnx_tf.backend.prepare(onnx_model) + + savedmodel_name = os.path.splitext(os.path.basename( + args.output_path))[0] + '.savedmodel' + savedmodel_output_path = os.path.join(tmpdir, savedmodel_name) + tf_savedmodel.export_graph(savedmodel_output_path) + + # make a command to convert from tf to tflite + tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py') + tf2tfliteV2_output_name = os.path.splitext( + os.path.basename(args.output_path))[0] + '.tflite' + tf2tfliteV2_output_path = os.path.join(tmpdir, tf2tfliteV2_output_name) + + tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd( + args, tf2tfliteV2_path, savedmodel_output_path, + tf2tfliteV2_output_path) + + f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode()) + + # convert tf to tflite + res_conv = oneutils.run_ret(tf2tfliteV2_cmd, logfile=f) + + if res_conv == 0: + # make a command to convert from tflite to circle + tflite2circle_path = os.path.join(dir_path, 'tflite2circle') + tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd( + tflite2circle_path, tf2tfliteV2_output_path, + getattr(args, 'output_path')) + + f.write((' '.join(tflite2circle_cmd) + '\n').encode()) + + # convert tflite to circle + res_conv = oneutils.run_ret(tflite2circle_cmd, + err_prefix="tflite2circle", + logfile=f) + except: + res_conv = -1 + + # if default conversion fails, try with one-import-onnx-ext if available + if ext_path and not _disable_ext(args): + if res_conv != 0: + if run_default_import: + print( + "onnx-import-onnx: Conversion with onnx-tf failed. " + "Fallback to use one-import-onnx-ext", + flush=True) + # call extension with options + ext_cmd = [ext_path] + if oneutils.is_valid_attr(args, 'unroll_rnn'): + ext_cmd.append('--unroll_rnn') + if oneutils.is_valid_attr(args, 'unroll_lstm'): + ext_cmd.append('--unroll_lstm') + if oneutils.is_valid_attr(args, + 'experimental_disable_batchmatmul_unfold'): + ext_cmd.append('--experimental_disable_batchmatmul_unfold') + if oneutils.is_valid_attr(args, 'save_intermediate'): + ext_cmd.append('--save_intermediate') + if oneutils.is_valid_attr(args, 'keep_io_order'): + ext_cmd.append('--keep_io_order') + ext_cmd.append(ext_alt_onnx_path) + ext_cmd.append(getattr(args, 'output_path')) + res_conv = oneutils.run_ret(ext_cmd, logfile=f) + + sys.exit(res_conv) def main(): diff --git a/compiler/one-cmds/tests/one-import-onnx_ext_001.test b/compiler/one-cmds/tests/one-import-onnx_ext_001.test index be8dadcf1e8..3c7e6f95d91 100644 --- a/compiler/one-cmds/tests/one-import-onnx_ext_001.test +++ b/compiler/one-cmds/tests/one-import-onnx_ext_001.test @@ -15,6 +15,7 @@ # limitations under the License. # test for one-import-onnx to invoke extension +# default should execute and not one-import-onnx-ext filename_ext="$(basename -- $0)" filename="${filename_ext%.*}" @@ -43,7 +44,7 @@ one-import-onnx \ --input_path ${inputfile} \ --output_path ${outputfile} > ${logfile} 2>&1 -if ! grep -q "one-import-onnx-ext dummy output!!!" "${logfile}"; then +if grep -q "one-import-onnx-ext dummy output!!!" "${logfile}"; then trap_err_onexit fi