Skip to content

Commit

Permalink
auto remove temporary file
Browse files Browse the repository at this point in the history
to avoid too many large temporary file in disk

Change-Id: Id255af8df330f7995d1b76c5c088ff5119a93ece
  • Loading branch information
HarmonyHu committed Mar 2, 2023
1 parent d670ada commit 84663ac
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 54 deletions.
38 changes: 13 additions & 25 deletions python/numpy_helper/npz_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,42 +19,30 @@


def parse_args(args_list):
# yapf: disable
parser = argparse.ArgumentParser(description='Compare two npz tensor files.')
parser.add_argument("target_file", help="Comparing target file")
parser.add_argument("ref_file", help="Comparing reference file")
parser.add_argument('--verbose', '-v', action='count', default=0)
parser.add_argument("--tolerance",
type=str,
default='0.99,0.99,0.90,50',
help="tolerance for cos/cor/euclid similarity/SQNR")
parser.add_argument('--op_info',
type=str,
parser.add_argument("--tolerance", type=str, default='0.99,0.99',
help="tolerance for cos/euclid similarity")
parser.add_argument('--op_info', type=str,
help="A csv file op_info, including order and dequant threshold")
parser.add_argument("--dequant",
action='store_true',
default=False,
parser.add_argument("--dequant", action='store_true',
help="Do dequantization flag, use threshold table provided in --op_info")
parser.add_argument("--excepts", type=str, help="List of tensors except from comparing")
parser.add_argument("--full-array",
action='store_true',
default=False,
parser.add_argument("--full-array", action='store_true',
help="Dump full array data when comparing failed")
parser.add_argument("--stats_int8_tensor",
action='store_true',
default=False,
parser.add_argument("--stats_int8_tensor", action='store_true',
help="Do statistics on int8 tensor for saturate ratio and low ratio")
parser.add_argument("--int8_tensor_close",
type=int,
default=1,
parser.add_argument("--int8_tensor_close", type=int, default=1,
help="whether int8 tensor compare close")
parser.add_argument("--save", type=str, help="Save result as a csv file")
parser.add_argument("--per_axis_compare",
type=int,
default=-1,
parser.add_argument("--per_axis_compare", type=int, default=-1,
help="Compare along axis, usually along axis 1 as per-channel")
parser.add_argument("--post_op", default=False, type=bool,
help="if the bmodel have post handle op")
parser.add_argument("--post_op", action='store_true', help="if the bmodel have post handle op")
args = parser.parse_args(args_list)
# yapf: enable
return args


Expand Down Expand Up @@ -207,11 +195,11 @@ def npz_compare(args_list):
#Todo: select the minimum shape as the base to compare
p = multiprocessing.Process(target=compare_one_array,
args=(tc, npz1, npz2, name, args.verbose, lock, dic,
int8_tensor_close, args.per_axis_compare))
int8_tensor_close, args.per_axis_compare))
else:
p = multiprocessing.Process(target=compare_one_array,
args=(tc, npz1, npz2, name, args.verbose, lock, dic,
int8_tensor_close, args.per_axis_compare))
int8_tensor_close, args.per_axis_compare))
processes.append(p)
p.start()

Expand Down
19 changes: 14 additions & 5 deletions python/tools/model_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
#
# ==============================================================================

import abc
import numpy as np
import argparse
from utils.mlir_shell import *
from utils.mlir_parser import *
from utils.preprocess import preprocess, supported_customization_format
from utils.auto_remove import file_mark, file_clean
from tools.model_runner import mlir_inference, model_inference, show_fake_cmd
import pymlir
from utils.misc import str2bool
Expand Down Expand Up @@ -86,8 +86,12 @@ def __init__(self, args):
self.prefix += "_sym"
self._prepare_input_npz()

def cleanup(self):
file_clean()

def lowering(self):
self.tpu_mlir = "{}_tpu.mlir".format(self.prefix)
file_mark(self.tpu_mlir)
self.final_mlir = "{}_final.mlir".format(self.prefix)
mlir_lowering(self.mlir_file, self.tpu_mlir, self.quantize, self.chip, self.cali_table,
self.asymmetric, self.quantize_table, False, self.customization_format,
Expand Down Expand Up @@ -153,7 +157,7 @@ def _prepare_input_npz(self):
input_op = self.module.inputs[0].op
ppa.load_config(input_op)
self.customization_format = getCustomFormat(ppa.pixel_format, ppa.channel_format)
assert (self.customization_format.starts_with("YUV") < 0)
assert (self.customization_format.startswith("YUV") < 0)
if str(self.chip).lower().endswith('183x'):
ppa.VPSS_W_ALIGN = 32
ppa.VPSS_Y_ALIGN = 32
Expand All @@ -168,7 +172,7 @@ def _prepare_input_npz(self):
x = np.squeeze(data, 0)
if self.customization_format == "GRAYSCALE":
x = ppa.align_gray_frame(x, self.aligned_input)
elif self.customization_format.ends_with("_PLANAR") >= 0:
elif self.customization_format.endswith("_PLANAR") >= 0:
x = ppa.align_planar_frame(x, self.aligned_input)
else:
x = ppa.align_packed_frame(x, self.aligned_input)
Expand All @@ -185,6 +189,7 @@ def _prepare_input_npz(self):
top_outputs = mlir_inference(self.inputs, self.mlir_file)
np.savez(self.ref_npz, **top_outputs)
self.tpu_npz = "{}_tpu_outputs.npz".format(self.prefix)
file_mark(self.tpu_npz)

def validate_tpu_mlir(self):
show_fake_cmd(self.in_f32_npz, self.tpu_mlir, self.tpu_npz)
Expand All @@ -208,6 +213,7 @@ def build_model(self):

def validate_model(self):
self.model_npz = "{}_model_outputs.npz".format(self.prefix)
file_mark(self.model_npz)
show_fake_cmd(self.in_f32_npz, self.model, self.model_npz, self.post_op)
model_outputs = model_inference(self.inputs, self.model, self.post_op)
np.savez(self.model_npz, **model_outputs)
Expand Down Expand Up @@ -263,15 +269,18 @@ def validate_model(self):
help="strip output type cast in bmodel, need outside type conversion")
parser.add_argument("--disable_layer_group", action="store_true",
help="Decide whether to enable layer group pass")
parser.add_argument("--post_op", default=False, type=str2bool,
parser.add_argument("--post_op", action="store_true",
help="if the bmodel have post handle op")
parser.add_argument("--debug", action='store_true', help='to keep all intermediate files for debug')
# yapf: enable
args = parser.parse_args()
if args.customization_format is not None and args.customization_format.starts_with("YUV") >= 0:
if args.customization_format is not None and args.customization_format.startswith("YUV") >= 0:
args.aligned_input = True

tool = DeployTool(args)
# lowering to tpu
tool.lowering()
# generate model
tool.build_model()
if not args.debug:
tool.cleanup()
7 changes: 4 additions & 3 deletions python/tools/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def fp32_to_bf16(d_fp32):


def show_fake_cmd(in_npz: str, model: str, out_npz: str, post_op=False):
print("[CMD]: model_runner.py --input {} --model {} --output {} --post_op {}".format(in_npz, model, out_npz, post_op))
post_param = "" if not post_op else "--post_op"
print("[CMD]: model_runner.py --input {} --model {} --output {} {}".format(in_npz, model, out_npz, post_param))


def get_chip_from_model(model_file: str) -> str:
Expand Down Expand Up @@ -354,11 +355,11 @@ def pytorch_inference(inputs: dict, model: str, dump_all: bool = True) -> dict:
help="mlir/onnx/tflie/bmodel/prototxt file.")
parser.add_argument("--weight", type=str, default="", help="caffemodel for caffe")
parser.add_argument("--output", default='_output.npz', help="output npz file")
parser.add_argument("--dump_all_tensors",action='store_true',
parser.add_argument("--dump_all_tensors", action='store_true',
help="dump all tensors to output file")
parser.add_argument("--debug", type=str, nargs="?", const="",
help="configure the debugging information.")
parser.add_argument("--post_op", default=False, type=str2bool,
parser.add_argument("--post_op", action='store_true',
help="if the bmodel have post handle op")
# yapf: enable
args = parser.parse_args()
Expand Down
10 changes: 7 additions & 3 deletions python/tools/model_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from utils.mlir_shell import *
from utils.mlir_parser import *
from utils.misc import *
from utils.auto_remove import file_mark, file_clean
from utils.preprocess import get_preprocess_parser, preprocess
import pymlir

Expand All @@ -27,11 +28,12 @@ def __init__(self, model_name):
self.do_mlir_infer = True

def cleanup(self):
pass
file_clean()

def model_transform(self, mlir_file: str, post_handle_type=""):
self.mlir_file = mlir_file
mlir_origin = mlir_file.replace('.mlir', '_origin.mlir', 1)
file_mark(mlir_origin)
self.converter.generate_mlir(mlir_origin)
mlir_opt_for_top(mlir_origin, self.mlir_file, post_handle_type)
print("Mlir file generated:{}".format(mlir_file))
Expand Down Expand Up @@ -72,9 +74,9 @@ def model_validate(self, file_list: str, tolerance, excepts, test_result):
show_fake_cmd(in_f32_npz, self.mlir_file, test_result)
f32_outputs = mlir_inference(inputs, self.mlir_file)
np.savez(test_result, **f32_outputs)

# compare all blobs layer by layers
f32_blobs_compare(test_result, ref_npz, tolerance, excepts=excepts)
file_mark(ref_npz)
else:
np.savez(test_result, **ref_outputs)

Expand Down Expand Up @@ -205,6 +207,7 @@ def get_model_transform(args):
parser.add_argument("--excepts", default='-', help="excepts")
parser.add_argument("--post_handle_type", default="", type=str,
help="post handle type, such as yolo,ssd etc")
parser.add_argument("--debug", action='store_true', help='to keep all intermediate files for debug')
parser.add_argument("--mlir", type=str, required=True, help="output mlir model file")
# yapf: enable
parser = get_preprocess_parser(existed_parser=parser)
Expand All @@ -214,4 +217,5 @@ def get_model_transform(args):
if args.test_input:
assert (args.test_result)
tool.model_validate(args.test_input, args.tolerance, args.excepts, args.test_result)
tool.cleanup()
if not args.debug:
tool.cleanup()
2 changes: 1 addition & 1 deletion python/transform/CaffeConverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self,
self.blobs = self.net.blobs
self.mlir = None
self.layer_dict = self.net.layer_dict
self.weight_file = "{}_top_weight.npz".format(model_name)
self.weight_file = "{}_top_origin_weight.npz".format(model_name)
self.init_shapes(input_shapes)
self.init_MLIRImporter()
self.location = self.resolve_alias()
Expand Down
2 changes: 1 addition & 1 deletion python/transform/MLIRImporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class State:


def get_weight_file(model_name: str, state: str, chip: str):
name = "{}_{}_{}_weight.npz".format(model_name, state, chip)
name = "{}_{}_{}_origin_weight.npz".format(model_name, state, chip)
return name.lower()


Expand Down
11 changes: 8 additions & 3 deletions python/transform/OnnxConverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import onnx
import onnxruntime
import numpy as np
import random
from utils.pad_setting import get_TF_SAME_Padding, set_auto_pad
from utils.pad_setting import set_auto_pad
from utils.auto_remove import file_mark, file_clean
import copy

onnx_attr_translator = {
Expand Down Expand Up @@ -102,7 +102,7 @@ def __init__(self,
preprocess_args=None):
super().__init__()
self.model_name = model_name
self.weight_file = "{}_top_weight.npz".format(model_name)
self.weight_file = "{}_top_origin_weight.npz".format(model_name)
self.model = None
self.mlir = None
self.node_name_mapping = {} # used in onnx opt
Expand Down Expand Up @@ -204,6 +204,9 @@ def __del__(self):
del self.mlir
self.mlir = None

def cleanup(self):
file_clean()

def check_need(self, name):
for node in self.converted_nodes:
for i in node.inputs:
Expand Down Expand Up @@ -372,6 +375,7 @@ def load_onnx_model(self, onnx_file, input_shapes: list, output_names: list):
self.addWeight(name, data)
self.add_shape_info()
self.onnx_file = "{}_opt.onnx".format(self.model_name)
file_mark(self.onnx_file)
onnx.save(self.model, self.onnx_file)
strip_model = onnx.ModelProto()
strip_model.CopyFrom(self.model)
Expand Down Expand Up @@ -433,6 +437,7 @@ def get_unk_shape(self, unk_op):
intermediate_layer_value_info.name = name
model.graph.output.append(intermediate_layer_value_info)
onnx_file = "generate_onnx_with_unk.onnx"
file_mark(onnx_file)
onnx.save(model, onnx_file)
session = onnxruntime.InferenceSession(onnx_file)
os.remove(onnx_file)
Expand Down
33 changes: 33 additions & 0 deletions python/utils/auto_remove.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/usr/bin/env python3
# ==============================================================================
#
# Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved.
#
# TPU-MLIR is licensed under the 2-Clause BSD License except for the
# third-party components.
#
# ==============================================================================

import os
from .mlir_parser import MlirParser

g_auto_remove_files = []


def file_mark(file: str):
g_auto_remove_files.append(file)


def file_clean():
for n in g_auto_remove_files:
if not os.path.exists(n):
continue
if n.endswith('.mlir'):
try:
parser = MlirParser(n)
weight_npz = parser.module_weight_file
if os.path.exists(weight_npz):
os.remove(weight_npz)
except:
pass
os.remove(n)
4 changes: 3 additions & 1 deletion python/utils/mlir_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ def mlir_to_model(tpu_mlir: str,


def f32_blobs_compare(a_npz: str, b_npz: str, tolerance: str, excepts=None, show_detail=True, post_op=False):
cmd = ["npz_tool.py", "compare", a_npz, b_npz, "--tolerance", tolerance, "--post_op", post_op]
cmd = ["npz_tool.py", "compare", a_npz, b_npz, "--tolerance", tolerance]
if post_op:
cmd.extend(["--post_op", post_op])
if excepts:
cmd.extend(["--except", excepts])
if show_detail:
Expand Down
Loading

0 comments on commit 84663ac

Please sign in to comment.