Skip to content

Commit

Permalink
[one-cmds] Revise one-import-onnx with new force_ext (#13921)
Browse files Browse the repository at this point in the history
This will revise one-import-onnx with new force_ext option.
- added some comments how extension is called with force_ext and
disable_ext options

ONE-DCO-1.0-Signed-off-by: SaeHie Park <[email protected]>
  • Loading branch information
seanshpark authored Sep 4, 2024
1 parent 015680c commit c66e299
Showing 1 changed file with 33 additions and 4 deletions.
37 changes: 33 additions & 4 deletions compiler/one-cmds/one-import-onnx
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,30 @@ def _get_parser():
'Ensure generated circle model preserves the I/O order of the original onnx model.'
)

# force use onnx-tf for one-import-onnx-ext is installed
parser.add_argument('--disable_ext',
action='store_true',
help='Disable one-import-onnx-ext and use legacy onnx-tf package')
# NOTE How one-import-onnx-ext is called;
# - applies only when one-import-onnx-ext is installed
# - default onnx-tf is called for conversion
# - if onnx-tf fails, one-import-onnx-ext is called
# - if 'force_ext' is given, skip onnx-tf and call one-import-onnx-ext
# - if 'disable_ext' is given, one-import-onnx-ext is not called
# - both 'force_ext', 'disable_ext' should not be set

# converter version
extension_group = parser.add_argument_group('extension arguments')
use_extension = extension_group.add_mutually_exclusive_group()

# use one-import-onnx-ext in the first place
use_extension.add_argument(
'--force_ext',
action='store_true',
help='Use one-import-onnx-ext in first attempt and skip default tool')

# do not call one-import-onnx-ext when default one-import-onnx fails
use_extension.add_argument('--disable_ext',
action='store_true',
help='Disable one-import-onnx-ext for second attempt')

parser.add_argument('--use_extension', type=str, help=argparse.SUPPRESS)

# save intermediate file(s)
parser.add_argument('--save_intermediate',
Expand Down Expand Up @@ -244,6 +264,15 @@ def _remap_io_names(onnx_model):
remapper.update()


def _force_ext(args):
if oneutils.is_valid_attr(args, 'force_ext'):
return True
env_force_ext = os.getenv('ONE_IMPORT_ONNX_EXT_FORCE')
if env_force_ext == '1' or env_force_ext == 'Y':
return True
return False


def _disable_ext(args):
if oneutils.is_valid_attr(args, 'disable_ext'):
return True
Expand Down

0 comments on commit c66e299

Please sign in to comment.