Skip to content

Commit a7c0029

Browse files
authored
Support dynamic batch for TensorRT and onnxruntime (WongKinYiu#329)
* Support dynamic batch for TensorRT and onnxruntime * Fix output name * Add some images * Add dynamic-batch usage notebook * Add example notebook for onnxruntime and tensorrt
1 parent 0d882e5 commit a7c0029

File tree

8 files changed

+5231
-2
lines changed

8 files changed

+5231
-2
lines changed

YOLOv7-Dynamic-Batch-ONNXRUNTIME.ipynb

Lines changed: 693 additions & 0 deletions
Large diffs are not rendered by default.

YOLOv7-Dynamic-Batch-TENSORRT.ipynb

Lines changed: 4512 additions & 0 deletions
Large diffs are not rendered by default.

export.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
2121
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
2222
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')
23+
parser.add_argument('--dynamic-batch', action='store_true', help='dynamic batch onnx for tensorrt and onnx-runtime')
2324
parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
2425
parser.add_argument('--end2end', action='store_true', help='export end2end onnx')
2526
parser.add_argument('--max-wh', type=int, default=None, help='None for tensorrt nms, int value for onnx-runtime nms')
@@ -31,6 +32,8 @@
3132
parser.add_argument('--include-nms', action='store_true', help='export end2end onnx')
3233
opt = parser.parse_args()
3334
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
35+
opt.dynamic = opt.dynamic and not opt.end2end
36+
opt.dynamic = False if opt.dynamic_batch else opt.dynamic
3437
print(opt)
3538
set_logging()
3639
t = time.time()
@@ -80,6 +83,28 @@
8083
f = opt.weights.replace('.pt', '.onnx') # filename
8184
model.eval()
8285
output_names = ['classes', 'boxes'] if y is None else ['output']
86+
dynamic_axes = None
87+
if opt.dynamic:
88+
dynamic_axes = {'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
89+
'output': {0: 'batch', 2: 'y', 3: 'x'}}
90+
if opt.dynamic_batch:
91+
opt.batch_size = 'batch'
92+
dynamic_axes = {
93+
'images': {
94+
0: 'batch',
95+
}, }
96+
if opt.end2end and opt.max_wh is None:
97+
output_axes = {
98+
'num_dets': {0: 'batch'},
99+
'det_boxes': {0: 'batch'},
100+
'det_scores': {0: 'batch'},
101+
'det_classes': {0: 'batch'},
102+
}
103+
else:
104+
output_axes = {
105+
'output': {0: 'batch'},
106+
}
107+
dynamic_axes.update(output_axes)
83108
if opt.grid and opt.end2end:
84109
print('\nStarting export end2end onnx model for %s...' % 'TensorRT' if opt.max_wh is None else 'onnxruntime')
85110
model = End2End(model,opt.topk_all,opt.iou_thres,opt.conf_thres,opt.max_wh,device)
@@ -92,8 +117,7 @@
92117

93118
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
94119
output_names=output_names,
95-
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
96-
'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic and not opt.end2end else None)
120+
dynamic_axes=dynamic_axes)
97121

98122
# Checks
99123
onnx_model = onnx.load(f) # load onnx model

inference/images/bus.jpg

476 KB
Loading

inference/images/image1.jpg

78.8 KB
Loading

inference/images/image2.jpg

140 KB
Loading

inference/images/image3.jpg

115 KB
Loading

inference/images/zidane.jpg

165 KB
Loading

0 commit comments

Comments
 (0)