From 84663ac8264078b9a9614b46a75150b24d76d629 Mon Sep 17 00:00:00 2001 From: "pengchao.hu" Date: Wed, 1 Mar 2023 21:31:05 +0800 Subject: [PATCH] auto remove temporary file to avoid too many large temporary file in disk Change-Id: Id255af8df330f7995d1b76c5c088ff5119a93ece --- python/numpy_helper/npz_compare.py | 38 ++++++++++-------------------- python/tools/model_deploy.py | 19 +++++++++++---- python/tools/model_runner.py | 7 +++--- python/tools/model_transform.py | 10 +++++--- python/transform/CaffeConverter.py | 2 +- python/transform/MLIRImporter.py | 2 +- python/transform/OnnxConverter.py | 11 ++++++--- python/utils/auto_remove.py | 33 ++++++++++++++++++++++++++ python/utils/mlir_shell.py | 4 +++- regression/run_model.sh | 26 ++++++++++---------- 10 files changed, 98 insertions(+), 54 deletions(-) create mode 100644 python/utils/auto_remove.py diff --git a/python/numpy_helper/npz_compare.py b/python/numpy_helper/npz_compare.py index 31eedc50f..eb8e67501 100755 --- a/python/numpy_helper/npz_compare.py +++ b/python/numpy_helper/npz_compare.py @@ -19,42 +19,30 @@ def parse_args(args_list): + # yapf: disable parser = argparse.ArgumentParser(description='Compare two npz tensor files.') parser.add_argument("target_file", help="Comparing target file") parser.add_argument("ref_file", help="Comparing reference file") parser.add_argument('--verbose', '-v', action='count', default=0) - parser.add_argument("--tolerance", - type=str, - default='0.99,0.99,0.90,50', - help="tolerance for cos/cor/euclid similarity/SQNR") - parser.add_argument('--op_info', - type=str, + parser.add_argument("--tolerance", type=str, default='0.99,0.99', + help="tolerance for cos/euclid similarity") + parser.add_argument('--op_info', type=str, help="A csv file op_info, including order and dequant threshold") - parser.add_argument("--dequant", - action='store_true', - default=False, + parser.add_argument("--dequant", action='store_true', help="Do dequantization flag, use threshold table provided in --op_info") parser.add_argument("--excepts", type=str, help="List of tensors except from comparing") - parser.add_argument("--full-array", - action='store_true', - default=False, + parser.add_argument("--full-array", action='store_true', help="Dump full array data when comparing failed") - parser.add_argument("--stats_int8_tensor", - action='store_true', - default=False, + parser.add_argument("--stats_int8_tensor", action='store_true', help="Do statistics on int8 tensor for saturate ratio and low ratio") - parser.add_argument("--int8_tensor_close", - type=int, - default=1, + parser.add_argument("--int8_tensor_close", type=int, default=1, help="whether int8 tensor compare close") parser.add_argument("--save", type=str, help="Save result as a csv file") - parser.add_argument("--per_axis_compare", - type=int, - default=-1, + parser.add_argument("--per_axis_compare", type=int, default=-1, help="Compare along axis, usually along axis 1 as per-channel") - parser.add_argument("--post_op", default=False, type=bool, - help="if the bmodel have post handle op") + parser.add_argument("--post_op", action='store_true', help="if the bmodel have post handle op") args = parser.parse_args(args_list) + # yapf: enable return args @@ -207,11 +195,11 @@ def npz_compare(args_list): #Todo: select the minimum shape as the base to compare p = multiprocessing.Process(target=compare_one_array, args=(tc, npz1, npz2, name, args.verbose, lock, dic, - int8_tensor_close, args.per_axis_compare)) + int8_tensor_close, args.per_axis_compare)) else: p = multiprocessing.Process(target=compare_one_array, args=(tc, npz1, npz2, name, args.verbose, lock, dic, - int8_tensor_close, args.per_axis_compare)) + int8_tensor_close, args.per_axis_compare)) processes.append(p) p.start() diff --git a/python/tools/model_deploy.py b/python/tools/model_deploy.py index d08317e76..2c8fa05ab 100755 --- a/python/tools/model_deploy.py +++ b/python/tools/model_deploy.py @@ -8,12 +8,12 @@ # # ============================================================================== -import abc import numpy as np import argparse from utils.mlir_shell import * from utils.mlir_parser import * from utils.preprocess import preprocess, supported_customization_format +from utils.auto_remove import file_mark, file_clean from tools.model_runner import mlir_inference, model_inference, show_fake_cmd import pymlir from utils.misc import str2bool @@ -86,8 +86,12 @@ def __init__(self, args): self.prefix += "_sym" self._prepare_input_npz() + def cleanup(self): + file_clean() + def lowering(self): self.tpu_mlir = "{}_tpu.mlir".format(self.prefix) + file_mark(self.tpu_mlir) self.final_mlir = "{}_final.mlir".format(self.prefix) mlir_lowering(self.mlir_file, self.tpu_mlir, self.quantize, self.chip, self.cali_table, self.asymmetric, self.quantize_table, False, self.customization_format, @@ -153,7 +157,7 @@ def _prepare_input_npz(self): input_op = self.module.inputs[0].op ppa.load_config(input_op) self.customization_format = getCustomFormat(ppa.pixel_format, ppa.channel_format) - assert (self.customization_format.starts_with("YUV") < 0) + assert (self.customization_format.startswith("YUV") < 0) if str(self.chip).lower().endswith('183x'): ppa.VPSS_W_ALIGN = 32 ppa.VPSS_Y_ALIGN = 32 @@ -168,7 +172,7 @@ def _prepare_input_npz(self): x = np.squeeze(data, 0) if self.customization_format == "GRAYSCALE": x = ppa.align_gray_frame(x, self.aligned_input) - elif self.customization_format.ends_with("_PLANAR") >= 0: + elif self.customization_format.endswith("_PLANAR") >= 0: x = ppa.align_planar_frame(x, self.aligned_input) else: x = ppa.align_packed_frame(x, self.aligned_input) @@ -185,6 +189,7 @@ def _prepare_input_npz(self): top_outputs = mlir_inference(self.inputs, self.mlir_file) np.savez(self.ref_npz, **top_outputs) self.tpu_npz = "{}_tpu_outputs.npz".format(self.prefix) + file_mark(self.tpu_npz) def validate_tpu_mlir(self): show_fake_cmd(self.in_f32_npz, self.tpu_mlir, self.tpu_npz) @@ -208,6 +213,7 @@ def build_model(self): def validate_model(self): self.model_npz = "{}_model_outputs.npz".format(self.prefix) + file_mark(self.model_npz) show_fake_cmd(self.in_f32_npz, self.model, self.model_npz, self.post_op) model_outputs = model_inference(self.inputs, self.model, self.post_op) np.savez(self.model_npz, **model_outputs) @@ -263,11 +269,12 @@ def validate_model(self): help="strip output type cast in bmodel, need outside type conversion") parser.add_argument("--disable_layer_group", action="store_true", help="Decide whether to enable layer group pass") - parser.add_argument("--post_op", default=False, type=str2bool, + parser.add_argument("--post_op", action="store_true", help="if the bmodel have post handle op") + parser.add_argument("--debug", action='store_true', help='to keep all intermediate files for debug') # yapf: enable args = parser.parse_args() - if args.customization_format is not None and args.customization_format.starts_with("YUV") >= 0: + if args.customization_format is not None and args.customization_format.startswith("YUV") >= 0: args.aligned_input = True tool = DeployTool(args) @@ -275,3 +282,5 @@ def validate_model(self): tool.lowering() # generate model tool.build_model() + if not args.debug: + tool.cleanup() diff --git a/python/tools/model_runner.py b/python/tools/model_runner.py index 8744795da..5bcf7f87a 100755 --- a/python/tools/model_runner.py +++ b/python/tools/model_runner.py @@ -43,7 +43,8 @@ def fp32_to_bf16(d_fp32): def show_fake_cmd(in_npz: str, model: str, out_npz: str, post_op=False): - print("[CMD]: model_runner.py --input {} --model {} --output {} --post_op {}".format(in_npz, model, out_npz, post_op)) + post_param = "" if not post_op else "--post_op" + print("[CMD]: model_runner.py --input {} --model {} --output {} {}".format(in_npz, model, out_npz, post_param)) def get_chip_from_model(model_file: str) -> str: @@ -354,11 +355,11 @@ def pytorch_inference(inputs: dict, model: str, dump_all: bool = True) -> dict: help="mlir/onnx/tflie/bmodel/prototxt file.") parser.add_argument("--weight", type=str, default="", help="caffemodel for caffe") parser.add_argument("--output", default='_output.npz', help="output npz file") - parser.add_argument("--dump_all_tensors",action='store_true', + parser.add_argument("--dump_all_tensors", action='store_true', help="dump all tensors to output file") parser.add_argument("--debug", type=str, nargs="?", const="", help="configure the debugging information.") - parser.add_argument("--post_op", default=False, type=str2bool, + parser.add_argument("--post_op", action='store_true', help="if the bmodel have post handle op") # yapf: enable args = parser.parse_args() diff --git a/python/tools/model_transform.py b/python/tools/model_transform.py index fa940e2ae..4985dfd52 100755 --- a/python/tools/model_transform.py +++ b/python/tools/model_transform.py @@ -16,6 +16,7 @@ from utils.mlir_shell import * from utils.mlir_parser import * from utils.misc import * +from utils.auto_remove import file_mark, file_clean from utils.preprocess import get_preprocess_parser, preprocess import pymlir @@ -27,11 +28,12 @@ def __init__(self, model_name): self.do_mlir_infer = True def cleanup(self): - pass + file_clean() def model_transform(self, mlir_file: str, post_handle_type=""): self.mlir_file = mlir_file mlir_origin = mlir_file.replace('.mlir', '_origin.mlir', 1) + file_mark(mlir_origin) self.converter.generate_mlir(mlir_origin) mlir_opt_for_top(mlir_origin, self.mlir_file, post_handle_type) print("Mlir file generated:{}".format(mlir_file)) @@ -72,9 +74,9 @@ def model_validate(self, file_list: str, tolerance, excepts, test_result): show_fake_cmd(in_f32_npz, self.mlir_file, test_result) f32_outputs = mlir_inference(inputs, self.mlir_file) np.savez(test_result, **f32_outputs) - # compare all blobs layer by layers f32_blobs_compare(test_result, ref_npz, tolerance, excepts=excepts) + file_mark(ref_npz) else: np.savez(test_result, **ref_outputs) @@ -205,6 +207,7 @@ def get_model_transform(args): parser.add_argument("--excepts", default='-', help="excepts") parser.add_argument("--post_handle_type", default="", type=str, help="post handle type, such as yolo,ssd etc") + parser.add_argument("--debug", action='store_true', help='to keep all intermediate files for debug') parser.add_argument("--mlir", type=str, required=True, help="output mlir model file") # yapf: enable parser = get_preprocess_parser(existed_parser=parser) @@ -214,4 +217,5 @@ def get_model_transform(args): if args.test_input: assert (args.test_result) tool.model_validate(args.test_input, args.tolerance, args.excepts, args.test_result) - tool.cleanup() + if not args.debug: + tool.cleanup() diff --git a/python/transform/CaffeConverter.py b/python/transform/CaffeConverter.py index e65ca5210..7cde8a291 100644 --- a/python/transform/CaffeConverter.py +++ b/python/transform/CaffeConverter.py @@ -53,7 +53,7 @@ def __init__(self, self.blobs = self.net.blobs self.mlir = None self.layer_dict = self.net.layer_dict - self.weight_file = "{}_top_weight.npz".format(model_name) + self.weight_file = "{}_top_origin_weight.npz".format(model_name) self.init_shapes(input_shapes) self.init_MLIRImporter() self.location = self.resolve_alias() diff --git a/python/transform/MLIRImporter.py b/python/transform/MLIRImporter.py index a2c3985eb..5791ebd70 100755 --- a/python/transform/MLIRImporter.py +++ b/python/transform/MLIRImporter.py @@ -98,7 +98,7 @@ class State: def get_weight_file(model_name: str, state: str, chip: str): - name = "{}_{}_{}_weight.npz".format(model_name, state, chip) + name = "{}_{}_{}_origin_weight.npz".format(model_name, state, chip) return name.lower() diff --git a/python/transform/OnnxConverter.py b/python/transform/OnnxConverter.py index 9624fe91b..bf53eb917 100755 --- a/python/transform/OnnxConverter.py +++ b/python/transform/OnnxConverter.py @@ -18,8 +18,8 @@ import onnx import onnxruntime import numpy as np -import random -from utils.pad_setting import get_TF_SAME_Padding, set_auto_pad +from utils.pad_setting import set_auto_pad +from utils.auto_remove import file_mark, file_clean import copy onnx_attr_translator = { @@ -102,7 +102,7 @@ def __init__(self, preprocess_args=None): super().__init__() self.model_name = model_name - self.weight_file = "{}_top_weight.npz".format(model_name) + self.weight_file = "{}_top_origin_weight.npz".format(model_name) self.model = None self.mlir = None self.node_name_mapping = {} # used in onnx opt @@ -204,6 +204,9 @@ def __del__(self): del self.mlir self.mlir = None + def cleanup(self): + file_clean() + def check_need(self, name): for node in self.converted_nodes: for i in node.inputs: @@ -372,6 +375,7 @@ def load_onnx_model(self, onnx_file, input_shapes: list, output_names: list): self.addWeight(name, data) self.add_shape_info() self.onnx_file = "{}_opt.onnx".format(self.model_name) + file_mark(self.onnx_file) onnx.save(self.model, self.onnx_file) strip_model = onnx.ModelProto() strip_model.CopyFrom(self.model) @@ -433,6 +437,7 @@ def get_unk_shape(self, unk_op): intermediate_layer_value_info.name = name model.graph.output.append(intermediate_layer_value_info) onnx_file = "generate_onnx_with_unk.onnx" + file_mark(onnx_file) onnx.save(model, onnx_file) session = onnxruntime.InferenceSession(onnx_file) os.remove(onnx_file) diff --git a/python/utils/auto_remove.py b/python/utils/auto_remove.py new file mode 100644 index 000000000..a76c9f132 --- /dev/null +++ b/python/utils/auto_remove.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# ============================================================================== +# +# Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved. +# +# TPU-MLIR is licensed under the 2-Clause BSD License except for the +# third-party components. +# +# ============================================================================== + +import os +from .mlir_parser import MlirParser + +g_auto_remove_files = [] + + +def file_mark(file: str): + g_auto_remove_files.append(file) + + +def file_clean(): + for n in g_auto_remove_files: + if not os.path.exists(n): + continue + if n.endswith('.mlir'): + try: + parser = MlirParser(n) + weight_npz = parser.module_weight_file + if os.path.exists(weight_npz): + os.remove(weight_npz) + except: + pass + os.remove(n) diff --git a/python/utils/mlir_shell.py b/python/utils/mlir_shell.py index 250022b16..28348f7d3 100755 --- a/python/utils/mlir_shell.py +++ b/python/utils/mlir_shell.py @@ -159,7 +159,9 @@ def mlir_to_model(tpu_mlir: str, def f32_blobs_compare(a_npz: str, b_npz: str, tolerance: str, excepts=None, show_detail=True, post_op=False): - cmd = ["npz_tool.py", "compare", a_npz, b_npz, "--tolerance", tolerance, "--post_op", post_op] + cmd = ["npz_tool.py", "compare", a_npz, b_npz, "--tolerance", tolerance] + if post_op: + cmd.extend(["--post_op", post_op]) if excepts: cmd.extend(["--except", excepts]) if show_detail: diff --git a/regression/run_model.sh b/regression/run_model.sh index 30133c1a4..2685366ac 100755 --- a/regression/run_model.sh +++ b/regression/run_model.sh @@ -67,9 +67,11 @@ else do_dynamic=0 fi +do_post_opt= post_handle_def= if [ x$do_post_handle == x1 ]; then post_handle_def="--post_handle_type=${post_type}" + do_post_opt="--post_op" fi NET_DIR=$REGRESSION_PATH/regression_out/${model_name}_${chip_name} mkdir -p $NET_DIR @@ -231,7 +233,7 @@ if [ ${do_f32} == 1 ]; then --tolerance 0.99,0.99 \ --compare_all \ --model ${model_name}_${chip_name}_f32.${model_type} \ - --post_op ${do_post_handle} + ${do_post_opt} fi if [ ${do_f16} == 1 ]; then @@ -245,7 +247,7 @@ if [ ${do_f16} == 1 ]; then --tolerance 0.95,0.85 \ --compare_all \ --model ${model_name}_${chip_name}_f16.${model_type} \ - --post_op ${do_post_handle} + ${do_post_opt} fi if [ ${do_bf16} == 1 ]; then @@ -259,7 +261,7 @@ if [ ${do_bf16} == 1 ]; then --tolerance 0.95,0.85 \ --compare_all \ --model ${model_name}_${chip_name}_bf16.${model_type} \ - --post_op ${do_post_handle} + ${do_post_opt} fi ######################### @@ -318,7 +320,7 @@ if [ ${do_symmetric} == 1 ]; then --quant_input \ --quant_output \ --model ${model_name}_${chip_name}_int8_sym.${model_type} \ - --post_op ${do_post_handle} + ${do_post_opt} fi #do_symmetric @@ -342,7 +344,7 @@ if [ $do_asymmetric == 1 ]; then ${excepts_opt} \ --compare_all \ --model ${model_name}_${chip_name}_int8_asym.${model_type} \ - --post_op ${do_post_handle} + ${do_post_opt} fi #do_asymmetric @@ -386,7 +388,7 @@ if [ $do_dynamic == 1 ]; then --chip ${chip_name} \ --compare_all \ --model ${static_model} \ - --post_op ${do_post_handle} + ${do_post_opt} model_deploy.py \ --mlir ${model_name}.mlir \ @@ -399,16 +401,16 @@ if [ $do_dynamic == 1 ]; then --tolerance 0.99,0.99 \ --compare_all \ --model ${dynamic_model} \ - --post_op ${do_post_handle} + ${do_post_opt} model_runner.py --input ${static_input_npz} \ --model ${static_model} \ --output ${static_model_name}_out_f32.npz - --post_op ${do_post_handle} + ${do_post_opt} model_runner.py --input ${static_input_npz} \ --model ${dynamic_model} \ --output ${dynamic_model_name}_out_f32.npz \ - --post_op ${do_post_handle} + ${do_post_opt} npz_tool.py compare ${static_model_name}_out_f32.npz \ ${dynamic_model_name}_out_f32.npz -vv fi @@ -424,7 +426,7 @@ if [ $do_dynamic == 1 ]; then # ${excepts_opt} \ # --tolerance 0.95,0.85 \ # --model ${model_name}_${chip_name}_f16.${model_type} \ - # --post_op ${do_post_handle} + # ${do_post_opt} # fi # if [ ${do_bf16} == 1 ]; then @@ -438,7 +440,7 @@ if [ $do_dynamic == 1 ]; then # ${excepts_opt} \ # --tolerance 0.95,0.85 \ # --model ${model_name}_${chip_name}_bf16.${model_type} \ - # --post_op ${do_post_handle} + # ${do_post_opt} # fi fi #do_dynamic @@ -465,7 +467,7 @@ if [ x${do_int4_sym} == x1 ]; then # --quant_input \ # --quant_output \ # --model ${model_name}_${chip_name}_int4_sym.${model_type} \ - # --post_op ${do_post_handle} + # ${do_post_opt} #Temporary test code tpuc-opt ${model_name}.mlir \