Skip to content

Commit 1cb8aa5

Browse files
authored
Fixed issue with confidence for single class detectors when exporting (WongKinYiu#607)
* Fixed issue with confidence for single class detectors when exporting * Typo
1 parent 09d6293 commit 1cb8aa5

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@
146146
if opt.grid:
147147
if opt.end2end:
148148
print('\nStarting export end2end onnx model for %s...' % 'TensorRT' if opt.max_wh is None else 'onnxruntime')
149-
model = End2End(model,opt.topk_all,opt.iou_thres,opt.conf_thres,opt.max_wh,device)
149+
model = End2End(model,opt.topk_all,opt.iou_thres,opt.conf_thres,opt.max_wh,device,len(labels))
150150
if opt.end2end and opt.max_wh is None:
151151
output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
152152
shapes = [opt.batch_size, 1, opt.batch_size, opt.topk_all, 4,

models/experimental.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def symbolic(g,
158158

159159
class ONNX_ORT(nn.Module):
160160
'''onnx module with ONNX-Runtime NMS operation.'''
161-
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None):
161+
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None, n_classes=80):
162162
super().__init__()
163163
self.device = device if device else torch.device("cpu")
164164
self.max_obj = torch.tensor([max_obj]).to(device)
@@ -168,12 +168,17 @@ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, de
168168
self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
169169
dtype=torch.float32,
170170
device=self.device)
171+
self.n_classes=n_classes
171172

172173
def forward(self, x):
173174
boxes = x[:, :, :4]
174175
conf = x[:, :, 4:5]
175176
scores = x[:, :, 5:]
176-
scores *= conf
177+
if self.n_classes == 1:
178+
scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
179+
# so there is no need to multiplicate.
180+
else:
181+
scores *= conf # conf = obj_conf * cls_conf
177182
boxes @= self.convert_matrix
178183
max_score, category_id = scores.max(2, keepdim=True)
179184
dis = category_id.float() * self.max_wh
@@ -189,7 +194,7 @@ def forward(self, x):
189194

190195
class ONNX_TRT(nn.Module):
191196
'''onnx module with TensorRT NMS operation.'''
192-
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None):
197+
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None, n_classes=80):
193198
super().__init__()
194199
assert max_wh is None
195200
self.device = device if device else torch.device('cpu')
@@ -200,12 +205,17 @@ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,d
200205
self.plugin_version = '1'
201206
self.score_activation = 0
202207
self.score_threshold = score_thres
208+
self.n_classes=n_classes
203209

204210
def forward(self, x):
205211
boxes = x[:, :, :4]
206212
conf = x[:, :, 4:5]
207213
scores = x[:, :, 5:]
208-
scores *= conf
214+
if self.n_classes == 1:
215+
scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
216+
# so there is no need to multiplicate.
217+
else:
218+
scores *= conf # conf = obj_conf * cls_conf
209219
num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(boxes, scores, self.background_class, self.box_coding,
210220
self.iou_threshold, self.max_obj,
211221
self.plugin_version, self.score_activation,
@@ -215,14 +225,14 @@ def forward(self, x):
215225

216226
class End2End(nn.Module):
217227
'''export onnx or tensorrt model with NMS operation.'''
218-
def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None):
228+
def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80):
219229
super().__init__()
220230
device = device if device else torch.device('cpu')
221231
assert isinstance(max_wh,(int)) or max_wh is None
222232
self.model = model.to(device)
223233
self.model.model[-1].end2end = True
224234
self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT
225-
self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device)
235+
self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device, n_classes)
226236
self.end2end.eval()
227237

228238
def forward(self, x):

0 commit comments

Comments
 (0)