# export-det.py — export YOLOv8 detection weights to ONNX with an NMS plugin head.
# (Original listing: executable file, 117 lines / 108 loc, 3.61 KB.)
import argparse
from io import BytesIO
from pathlib import Path

import onnx
import torch
from ultralytics import YOLO

from models.common import PostDetect, optim

# onnxsim is optional; it is only used when `--sim` is passed.
try:
    import onnxsim
except ImportError:
    onnxsim = None
def parse_args(argv=None):
    """Parse command-line options for the YOLOv8 detection ONNX export.

    As a side effect, the NMS options (``--conf-thres``, ``--iou-thres``,
    ``--topk``) are written onto the ``PostDetect`` class attributes before
    returning, so they are in place when ``main`` builds and exports the model.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.

    Exits via ``parser.error`` (usage message, exit status 2) when
    ``--input-shape`` does not have exactly four values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-w',
                        '--weights',
                        type=str,
                        required=True,
                        help='PyTorch yolov8 weights')
    parser.add_argument('--iou-thres',
                        type=float,
                        default=0.65,
                        help='IOU threshold for NMS plugin')
    parser.add_argument('--conf-thres',
                        type=float,
                        default=0.25,
                        help='CONF threshold for NMS plugin')
    parser.add_argument('--topk',
                        type=int,
                        default=100,
                        help='Max number of detection bboxes')
    parser.add_argument('--opset',
                        type=int,
                        default=11,
                        help='ONNX opset version')
    parser.add_argument('--sim',
                        action='store_true',
                        help='simplify onnx model')
    parser.add_argument('--input-shape',
                        nargs='+',
                        type=int,
                        default=[1, 3, 640, 640],
                        help='Model input shape only for api builder')
    parser.add_argument('--device',
                        type=str,
                        default='cpu',
                        help='Export ONNX device')
    args = parser.parse_args(argv)
    # `assert` is stripped under `python -O`; validate user input through
    # argparse so the user always gets a usage message instead of silence.
    if len(args.input_shape) != 4:
        parser.error('--input-shape expects exactly 4 integers '
                     f'(default: 1 3 640 640), got {len(args.input_shape)}')
    # Configure the NMS head globally; the exported model picks these up.
    PostDetect.conf_thres = args.conf_thres
    PostDetect.iou_thres = args.iou_thres
    PostDetect.topk = args.topk
    return args
def _fix_output_shapes(onnx_model, topk):
    """Pin the fixed (non-batch) dimensions of the four detection outputs.

    Layouts written: ``num_dets`` (batch, 1), ``bboxes`` (batch, topk, 4),
    ``scores`` (batch, topk) and ``labels`` (batch, topk). The batch
    dimension (``None`` below) is left untouched; it was declared dynamic as
    ``batch_size`` at export time.

    Outputs are matched by name and dimensions are paired with ``zip``, so an
    output with an unexpected rank is patched only as far as it matches,
    instead of shifting every later dimension or raising IndexError.
    """
    static_dims = {
        'num_dets': (None, 1),
        'bboxes': (None, topk, 4),
        'scores': (None, topk),
        'labels': (None, topk),
    }
    for output in onnx_model.graph.output:
        sizes = static_dims.get(output.name, ())
        for dim, size in zip(output.type.tensor_type.shape.dim, sizes):
            if size is not None:
                # A fixed extent belongs in `dim_value`; `dim_param` holds a
                # symbolic *name*, so str(size) there would declare a symbol
                # called e.g. "100" rather than a concrete size. Setting
                # dim_value clears dim_param (they share a protobuf oneof).
                dim.dim_value = size


def _simplify(onnx_model):
    """Return the onnxsim-simplified model, or `onnx_model` unchanged on failure.

    Simplification is best-effort: any failure is reported and the original
    (already checked) model is kept. The simplified model is adopted only
    when onnxsim's own check passes.
    """
    if onnxsim is None:
        print('Simplifier failure: onnxsim is not installed')
        return onnx_model
    try:
        simplified, check = onnxsim.simplify(onnx_model)
    except Exception as e:
        print(f'Simplifier failure: {e}')
        return onnx_model
    if not check:
        print('Simplifier failure: assert check failed')
        return onnx_model
    return simplified


def main(args):
    """Export the YOLOv8 detection model in ``args.weights`` to ONNX.

    Loads and fuses the model, runs ``optim`` on every sub-module and moves it
    to ``args.device``, warms it up on a random tensor of ``args.input_shape``,
    exports it with a dynamic batch axis, pins the static output dimensions,
    optionally simplifies it with onnxsim (``args.sim``), and saves it next to
    the weights file with an ``.onnx`` extension.

    Args:
        args: Namespace from ``parse_args`` (reads ``weights``, ``device``,
            ``input_shape``, ``opset``, ``topk`` and ``sim``).
    """
    yolo = YOLO(args.weights)
    model = yolo.model.fuse().eval()
    for m in model.modules():
        optim(m)
        m.to(args.device)
    model.to(args.device)

    fake_input = torch.randn(args.input_shape).to(args.device)
    # Two warm-up passes before tracing (kept from the original flow;
    # presumably lets the head cache shape-dependent state — TODO confirm).
    # No gradients are needed, so skip the autograd bookkeeping.
    with torch.no_grad():
        for _ in range(2):
            model(fake_input)

    # with_suffix replaces only the final extension. str.replace('.pt', ...)
    # rewrote every '.pt' in the path (e.g. a '.ptcache' directory) and
    # turned '.pth' weights into '*.onnxh'.
    save_path = str(Path(args.weights).with_suffix('.onnx'))

    input_names = ['images']
    output_names = ['num_dets', 'bboxes', 'scores', 'labels']
    # Only axis 0 (batch) is dynamic, for the input and every output.
    dynamic_axes = {
        name: {0: 'batch_size'} for name in input_names + output_names
    }
    with BytesIO() as f:
        torch.onnx.export(
            model,
            fake_input,
            f,
            dynamic_axes=dynamic_axes,
            opset_version=args.opset,
            input_names=input_names,
            output_names=output_names,
        )
        f.seek(0)
        onnx_model = onnx.load(f)
    onnx.checker.check_model(onnx_model)

    _fix_output_shapes(onnx_model, args.topk)
    if args.sim:
        onnx_model = _simplify(onnx_model)
    onnx.save(onnx_model, save_path)
    print(f'ONNX export success, saved as {save_path}')
if __name__ == '__main__':
    # Parse the CLI first (this also configures PostDetect), then export.
    cli_args = parse_args()
    main(cli_args)