@@ -158,7 +158,7 @@ def symbolic(g,
158
158
159
159
class ONNX_ORT (nn .Module ):
160
160
'''onnx module with ONNX-Runtime NMS operation.'''
161
- def __init__ (self , max_obj = 100 , iou_thres = 0.45 , score_thres = 0.25 , max_wh = 640 , device = None ):
161
+ def __init__ (self , max_obj = 100 , iou_thres = 0.45 , score_thres = 0.25 , max_wh = 640 , device = None , n_classes = 80 ):
162
162
super ().__init__ ()
163
163
self .device = device if device else torch .device ("cpu" )
164
164
self .max_obj = torch .tensor ([max_obj ]).to (device )
@@ -168,12 +168,17 @@ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, de
168
168
self .convert_matrix = torch .tensor ([[1 , 0 , 1 , 0 ], [0 , 1 , 0 , 1 ], [- 0.5 , 0 , 0.5 , 0 ], [0 , - 0.5 , 0 , 0.5 ]],
169
169
dtype = torch .float32 ,
170
170
device = self .device )
171
+ self .n_classes = n_classes
171
172
172
173
def forward (self , x ):
173
174
boxes = x [:, :, :4 ]
174
175
conf = x [:, :, 4 :5 ]
175
176
scores = x [:, :, 5 :]
176
- scores *= conf
177
+ if self .n_classes == 1 :
178
+ scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
179
+ # so there is no need to multiplicate.
180
+ else :
181
+ scores *= conf # conf = obj_conf * cls_conf
177
182
boxes @= self .convert_matrix
178
183
max_score , category_id = scores .max (2 , keepdim = True )
179
184
dis = category_id .float () * self .max_wh
@@ -189,7 +194,7 @@ def forward(self, x):
189
194
190
195
class ONNX_TRT (nn .Module ):
191
196
'''onnx module with TensorRT NMS operation.'''
192
- def __init__ (self , max_obj = 100 , iou_thres = 0.45 , score_thres = 0.25 , max_wh = None ,device = None ):
197
+ def __init__ (self , max_obj = 100 , iou_thres = 0.45 , score_thres = 0.25 , max_wh = None ,device = None , n_classes = 80 ):
193
198
super ().__init__ ()
194
199
assert max_wh is None
195
200
self .device = device if device else torch .device ('cpu' )
@@ -200,12 +205,17 @@ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,d
200
205
self .plugin_version = '1'
201
206
self .score_activation = 0
202
207
self .score_threshold = score_thres
208
+ self .n_classes = n_classes
203
209
204
210
def forward (self , x ):
205
211
boxes = x [:, :, :4 ]
206
212
conf = x [:, :, 4 :5 ]
207
213
scores = x [:, :, 5 :]
208
- scores *= conf
214
+ if self .n_classes == 1 :
215
+ scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
216
+ # so there is no need to multiplicate.
217
+ else :
218
+ scores *= conf # conf = obj_conf * cls_conf
209
219
num_det , det_boxes , det_scores , det_classes = TRT_NMS .apply (boxes , scores , self .background_class , self .box_coding ,
210
220
self .iou_threshold , self .max_obj ,
211
221
self .plugin_version , self .score_activation ,
@@ -215,14 +225,14 @@ def forward(self, x):
215
225
216
226
class End2End (nn .Module ):
217
227
'''export onnx or tensorrt model with NMS operation.'''
218
- def __init__ (self , model , max_obj = 100 , iou_thres = 0.45 , score_thres = 0.25 , max_wh = None , device = None ):
228
+ def __init__ (self , model , max_obj = 100 , iou_thres = 0.45 , score_thres = 0.25 , max_wh = None , device = None , n_classes = 80 ):
219
229
super ().__init__ ()
220
230
device = device if device else torch .device ('cpu' )
221
231
assert isinstance (max_wh ,(int )) or max_wh is None
222
232
self .model = model .to (device )
223
233
self .model .model [- 1 ].end2end = True
224
234
self .patch_model = ONNX_TRT if max_wh is None else ONNX_ORT
225
- self .end2end = self .patch_model (max_obj , iou_thres , score_thres , max_wh , device )
235
+ self .end2end = self .patch_model (max_obj , iou_thres , score_thres , max_wh , device , n_classes )
226
236
self .end2end .eval ()
227
237
228
238
def forward (self , x ):
0 commit comments