@@ -272,46 +272,35 @@ def create_filter(
272
272
[self .uid_to_index [uid ] for uid in datum_uids ],
273
273
dtype = np .int32 ,
274
274
)
275
- mask = np .zeros_like (mask_pairs , dtype = np .bool_ )
276
- mask [
277
- np .isin (self ._ranked_pairs [:, 0 ].astype (int ), datum_uids )
278
- ] = True
279
- mask_pairs &= mask
280
-
281
- mask = np .zeros_like (mask_datums , dtype = np .bool_ )
282
- mask [datum_uids ] = True
283
- mask_datums &= mask
275
+ mask_pairs [
276
+ ~ np .isin (self ._ranked_pairs [:, 0 ].astype (int ), datum_uids )
277
+ ] = False
278
+ mask_datums [~ np .isin (np .arange (n_datums ), datum_uids )] = False
284
279
285
280
if labels is not None :
286
281
if isinstance (labels , list ):
287
282
labels = np .array (
288
283
[self .label_to_index [label ] for label in labels ]
289
284
)
290
- mask = np .zeros_like (mask_pairs , dtype = np .bool_ )
291
- mask [np .isin (self ._ranked_pairs [:, 4 ].astype (int ), labels )] = True
292
- mask_pairs &= mask
293
-
294
- mask = np .zeros_like (mask_labels , dtype = np .bool_ )
295
- mask [labels ] = True
296
- mask_labels &= mask
285
+ mask_pairs [
286
+ ~ np .isin (self ._ranked_pairs [:, 4 ].astype (int ), labels )
287
+ ] = False
288
+ mask_labels [~ np .isin (np .arange (n_labels ), labels )] = False
297
289
298
290
if label_keys is not None :
299
291
if isinstance (label_keys , list ):
300
292
label_keys = np .array (
301
293
[self .label_key_to_index [key ] for key in label_keys ]
302
294
)
303
- label_indices = np .where (
304
- np .isclose (self ._label_metadata [:, 2 ], label_keys )
305
- )[0 ]
306
- mask = np .zeros_like (mask_pairs , dtype = np .bool_ )
307
- mask [
308
- np .isin (self ._ranked_pairs [:, 4 ].astype (int ), label_indices )
309
- ] = True
310
- mask_pairs &= mask
311
-
312
- mask = np .zeros_like (mask_labels , dtype = np .bool_ )
313
- mask [label_indices ] = True
314
- mask_labels &= mask
295
+ label_indices = (
296
+ np .where (np .isclose (self ._label_metadata [:, 2 ], label_keys ))[0 ]
297
+ if label_keys .size > 0
298
+ else np .array ([])
299
+ )
300
+ mask_pairs [
301
+ ~ np .isin (self ._ranked_pairs [:, 4 ].astype (int ), label_indices )
302
+ ] = False
303
+ mask_labels [~ np .isin (np .arange (n_labels ), label_indices )] = False
315
304
316
305
mask = mask_datums [:, np .newaxis ] & mask_labels [np .newaxis , :]
317
306
label_metadata_per_datum = self ._label_metadata_per_datum .copy ()
@@ -399,7 +388,7 @@ def evaluate(
399
388
)
400
389
for iou_idx in range (average_precision .shape [0 ])
401
390
for label_idx in range (average_precision .shape [1 ])
402
- if int (label_metadata [label_idx ][ 0 ]) > 0
391
+ if int (label_metadata [label_idx , 0 ]) > 0
403
392
]
404
393
405
394
metrics [MetricType .mAP ] = [
@@ -419,7 +408,7 @@ def evaluate(
419
408
label = self .index_to_label [label_idx ],
420
409
)
421
410
for label_idx in range (self .n_labels )
422
- if int (label_metadata [label_idx ][ 0 ]) > 0
411
+ if int (label_metadata [label_idx , 0 ]) > 0
423
412
]
424
413
425
414
metrics [MetricType .mAPAveragedOverIOUs ] = [
@@ -442,7 +431,7 @@ def evaluate(
442
431
)
443
432
for score_idx in range (average_recall .shape [0 ])
444
433
for label_idx in range (average_recall .shape [1 ])
445
- if int (label_metadata [label_idx ][ 0 ]) > 0
434
+ if int (label_metadata [label_idx , 0 ]) > 0
446
435
]
447
436
448
437
metrics [MetricType .mAR ] = [
@@ -464,7 +453,7 @@ def evaluate(
464
453
label = self .index_to_label [label_idx ],
465
454
)
466
455
for label_idx in range (self .n_labels )
467
- if int (label_metadata [label_idx ][ 0 ]) > 0
456
+ if int (label_metadata [label_idx , 0 ]) > 0
468
457
]
469
458
470
459
metrics [MetricType .mARAveragedOverScores ] = [
@@ -487,16 +476,17 @@ def evaluate(
487
476
)
488
477
for iou_idx , iou_threshold in enumerate (iou_thresholds )
489
478
for label_idx , label in self .index_to_label .items ()
490
- if int (label_metadata [label_idx ][ 0 ]) > 0
479
+ if int (label_metadata [label_idx , 0 ]) > 0
491
480
]
492
481
493
482
for label_idx , label in self .index_to_label .items ():
483
+
484
+ if label_metadata [label_idx , 0 ] == 0 :
485
+ continue
486
+
494
487
for score_idx , score_threshold in enumerate (score_thresholds ):
495
488
for iou_idx , iou_threshold in enumerate (iou_thresholds ):
496
489
497
- if label_metadata [label_idx , 0 ] == 0 :
498
- continue
499
-
500
490
row = precision_recall [iou_idx ][score_idx ][label_idx ]
501
491
kwargs = {
502
492
"label" : label ,
0 commit comments