From 8345b9db48133dc0a2782ccbf8ee038f8012fdf6 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Thu, 23 May 2024 14:48:23 -0700 Subject: [PATCH] Update IREE onnx import to be in sync with Torch-MLIR (#17476) This commit updates `iree-import-onnx` so that it behaves the same as torch-mlir's version (https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/tools/import_onnx/__main__.py). Specifically, enabling the data propagation improves shape inference and is leading to more models passing. Related to #17021 --------- Signed-off-by: saienduri --- .../compiler/tools/import_onnx/__main__.py | 73 +++++++++++++++++-- 1 file changed, 65 insertions(+), 8 deletions(-) diff --git a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py index 9e2fe1397073..e5aff51f990c 100644 --- a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py +++ b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py @@ -15,8 +15,10 @@ python -m iree.compiler.tools.import_onnx ... 
""" import argparse +import os from pathlib import Path import sys +import tempfile try: import onnx @@ -38,8 +40,8 @@ ) -def main(args): - model_proto = load_onnx_model(args.input_file) +def main(args: argparse.Namespace): + model_proto = load_onnx_model(args) context = Context() model_info = onnx_importer.ModelInfo(model_proto) m = model_info.create_module(context=context).operation @@ -58,13 +60,56 @@ def main(args): print(m.get_asm(assume_verified=not args.no_verify)) -def load_onnx_model(file_path: Path) -> onnx.ModelProto: - raw_model = onnx.load(file_path) - inferred_model = onnx.shape_inference.infer_shapes(raw_model) - return inferred_model +def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: + input_dir = os.path.dirname(os.path.abspath(args.input_file)) - -def parse_arguments(argv=None): + # Load the model, with possible external data coming from the default + # location, or the location specified on the command line. + if args.data_dir is None: + raw_model = onnx.load(args.input_file) + else: + raw_model = onnx.load(args.input_file, load_external_data=False) + onnx.load_external_data_for_model(raw_model, args.data_dir) + + # Do shape inference two ways. First, attempt in-memory to avoid redundant + # loading and the need for writing a temporary file somewhere. If that + # fails, typically because of the 2 GB protobuf size limit, try again via + # files. See + # https://onnx.ai/onnx/repo-docs/PythonAPIOverview.html#shape-inference-a-large-onnx-model-2gb + # for details about the file-based technique. + + # Run the checker to test whether the file is above the threshold for + # in-memory shape inference. If not, go ahead and do the shape inference. + try: + onnx.checker.check_model(raw_model) + inferred_model = onnx.shape_inference.infer_shapes( + raw_model, data_prop=args.data_prop + ) + return inferred_model + except ValueError: + pass + + # Model is too big for in-memory inference: do file-based shape inference + # to a temp file. 
+ # Make a temp dir for all the temp files we'll be generating as a side + # effect of inferring shapes. For now, the only file is a new .onnx holding + # the revised model with shapes. + with tempfile.TemporaryDirectory(dir=input_dir) as temp_dir_name: + temp_dir_path = Path(temp_dir_name) + temp_inferred_file = temp_dir_path / "temp-inferred.onnx" + onnx.shape_inference.infer_shapes_path( + args.input_file, temp_inferred_file, data_prop=args.data_prop + ) + + # Load the temp file and the external data. + inferred_model = onnx.load(temp_inferred_file, load_external_data=False) + data_dir = Path(input_dir if args.data_dir is None else args.data_dir) + onnx.load_external_data_for_model(inferred_model, data_dir) + + return inferred_model + + -def parse_arguments(argv=None): +def parse_arguments(argv=None) -> argparse.Namespace: parser = argparse.ArgumentParser(description="IREE ONNX import tool") parser.add_argument("input_file", help="ONNX protobuf input", type=Path) parser.add_argument( @@ -75,6 +120,18 @@ def parse_arguments(argv=None): action="store_true", help="Disable verification prior to printing", ) + parser.add_argument( + "--data-prop", + default=True, + action=argparse.BooleanOptionalAction, + help="Toggle data propagation for onnx shape inference", + ) + parser.add_argument( + "--data-dir", + help="Path to the base directory of the data." + " Defaults to the directory of the input file.", + type=Path, + ) args = parser.parse_args(argv) return args