Skip to content

Commit 1fa9b62

Browse files
authored
BugFix Valor Lite Filtering (#767)
1 parent 8778b17 commit 1fa9b62

File tree

2 files changed

+137
-36
lines changed

2 files changed

+137
-36
lines changed

lite/tests/detection/test_filtering.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,117 @@ def test_filtering_four_detections(four_detections: list[Detection]):
411411
assert m in actual_metrics
412412

413413

414+
def test_filtering_all_detections(four_detections: list[Detection]):
415+
"""
416+
Basic object detection test.
417+
418+
groundtruths
419+
datum uid1
420+
box 1 - label (k1, v1) - tp
421+
box 3 - label (k2, v2) - fn missing prediction
422+
datum uid2
423+
box 2 - label (k1, v1) - fn misclassification
424+
datum uid3
425+
box 1 - label (k1, v1) - tp
426+
box 3 - label (k2, v2) - fn missing prediction
427+
datum uid4
428+
box 2 - label (k1, v1) - fn misclassification
429+
430+
predictions
431+
datum uid1
432+
box 1 - label (k1, v1) - score 0.3 - tp
433+
datum uid2
434+
box 2 - label (k2, v2) - score 0.98 - fp
435+
datum uid3
436+
box 1 - label (k1, v1) - score 0.3 - tp
437+
datum uid4
438+
box 2 - label (k2, v2) - score 0.98 - fp
439+
"""
440+
441+
loader = DataLoader()
442+
loader.add_bounding_boxes(four_detections)
443+
evaluator = loader.finalize()
444+
445+
assert (
446+
evaluator._ranked_pairs
447+
== np.array(
448+
[
449+
[1.0, -1.0, 0.0, 0.0, -1.0, 1.0, 0.98],
450+
[3.0, -1.0, 0.0, 0.0, -1.0, 1.0, 0.98],
451+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.3],
452+
[2.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.3],
453+
]
454+
)
455+
).all()
456+
457+
assert (
458+
evaluator._label_metadata_per_datum
459+
== np.array(
460+
[
461+
[
462+
[1, 1],
463+
[1, 0],
464+
[1, 1],
465+
[1, 0],
466+
],
467+
[
468+
[1, 0],
469+
[0, 1],
470+
[1, 0],
471+
[0, 1],
472+
],
473+
],
474+
dtype=np.int32,
475+
)
476+
).all()
477+
478+
assert (
479+
evaluator._label_metadata == np.array([[4, 2, 0], [2, 2, 1]])
480+
).all()
481+
482+
# test datum filtering
483+
484+
filter_ = evaluator.create_filter(datum_uids=[])
485+
print(filter_)
486+
assert (filter_.indices == np.array([])).all()
487+
assert (filter_.label_metadata == np.array([[0, 0, 0], [0, 0, 1]])).all()
488+
489+
# test label filtering
490+
491+
filter_ = evaluator.create_filter(labels=[])
492+
assert (filter_.indices == np.array([])).all()
493+
assert (filter_.label_metadata == np.array([[0, 0, 0], [0, 0, 1]])).all()
494+
495+
# test label key filtering
496+
497+
filter_ = evaluator.create_filter(label_keys=[])
498+
assert (filter_.indices == np.array([[]])).all()
499+
assert (filter_.label_metadata == np.array([[0, 0, 0], [0, 0, 1]])).all()
500+
501+
# test combo
502+
filter_ = evaluator.create_filter(
503+
datum_uids=[],
504+
label_keys=["k1"],
505+
)
506+
assert (filter_.indices == np.array([])).all()
507+
assert (filter_.label_metadata == np.array([[0, 0, 0], [0, 0, 1]])).all()
508+
509+
# test evaluation
510+
filter_ = evaluator.create_filter(datum_uids=[])
511+
512+
metrics = evaluator.evaluate(
513+
iou_thresholds=[0.5],
514+
filter_=filter_,
515+
metrics_to_return=[
516+
*MetricType.base_metrics(),
517+
MetricType.DetailedCounts,
518+
],
519+
)
520+
521+
actual_metrics = [m.to_dict() for m in metrics[MetricType.AP]]
522+
assert len(actual_metrics) == 0
523+
524+
414525
def test_filtering_random_detections():
415526
loader = DataLoader()
416527
loader.add_bounding_boxes(generate_random_detections(13, 4, "abc"))

lite/valor_lite/detection/manager.py

Lines changed: 26 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -272,46 +272,35 @@ def create_filter(
272272
[self.uid_to_index[uid] for uid in datum_uids],
273273
dtype=np.int32,
274274
)
275-
mask = np.zeros_like(mask_pairs, dtype=np.bool_)
276-
mask[
277-
np.isin(self._ranked_pairs[:, 0].astype(int), datum_uids)
278-
] = True
279-
mask_pairs &= mask
280-
281-
mask = np.zeros_like(mask_datums, dtype=np.bool_)
282-
mask[datum_uids] = True
283-
mask_datums &= mask
275+
mask_pairs[
276+
~np.isin(self._ranked_pairs[:, 0].astype(int), datum_uids)
277+
] = False
278+
mask_datums[~np.isin(np.arange(n_datums), datum_uids)] = False
284279

285280
if labels is not None:
286281
if isinstance(labels, list):
287282
labels = np.array(
288283
[self.label_to_index[label] for label in labels]
289284
)
290-
mask = np.zeros_like(mask_pairs, dtype=np.bool_)
291-
mask[np.isin(self._ranked_pairs[:, 4].astype(int), labels)] = True
292-
mask_pairs &= mask
293-
294-
mask = np.zeros_like(mask_labels, dtype=np.bool_)
295-
mask[labels] = True
296-
mask_labels &= mask
285+
mask_pairs[
286+
~np.isin(self._ranked_pairs[:, 4].astype(int), labels)
287+
] = False
288+
mask_labels[~np.isin(np.arange(n_labels), labels)] = False
297289

298290
if label_keys is not None:
299291
if isinstance(label_keys, list):
300292
label_keys = np.array(
301293
[self.label_key_to_index[key] for key in label_keys]
302294
)
303-
label_indices = np.where(
304-
np.isclose(self._label_metadata[:, 2], label_keys)
305-
)[0]
306-
mask = np.zeros_like(mask_pairs, dtype=np.bool_)
307-
mask[
308-
np.isin(self._ranked_pairs[:, 4].astype(int), label_indices)
309-
] = True
310-
mask_pairs &= mask
311-
312-
mask = np.zeros_like(mask_labels, dtype=np.bool_)
313-
mask[label_indices] = True
314-
mask_labels &= mask
295+
label_indices = (
296+
np.where(np.isclose(self._label_metadata[:, 2], label_keys))[0]
297+
if label_keys.size > 0
298+
else np.array([])
299+
)
300+
mask_pairs[
301+
~np.isin(self._ranked_pairs[:, 4].astype(int), label_indices)
302+
] = False
303+
mask_labels[~np.isin(np.arange(n_labels), label_indices)] = False
315304

316305
mask = mask_datums[:, np.newaxis] & mask_labels[np.newaxis, :]
317306
label_metadata_per_datum = self._label_metadata_per_datum.copy()
@@ -399,7 +388,7 @@ def evaluate(
399388
)
400389
for iou_idx in range(average_precision.shape[0])
401390
for label_idx in range(average_precision.shape[1])
402-
if int(label_metadata[label_idx][0]) > 0
391+
if int(label_metadata[label_idx, 0]) > 0
403392
]
404393

405394
metrics[MetricType.mAP] = [
@@ -419,7 +408,7 @@ def evaluate(
419408
label=self.index_to_label[label_idx],
420409
)
421410
for label_idx in range(self.n_labels)
422-
if int(label_metadata[label_idx][0]) > 0
411+
if int(label_metadata[label_idx, 0]) > 0
423412
]
424413

425414
metrics[MetricType.mAPAveragedOverIOUs] = [
@@ -442,7 +431,7 @@ def evaluate(
442431
)
443432
for score_idx in range(average_recall.shape[0])
444433
for label_idx in range(average_recall.shape[1])
445-
if int(label_metadata[label_idx][0]) > 0
434+
if int(label_metadata[label_idx, 0]) > 0
446435
]
447436

448437
metrics[MetricType.mAR] = [
@@ -464,7 +453,7 @@ def evaluate(
464453
label=self.index_to_label[label_idx],
465454
)
466455
for label_idx in range(self.n_labels)
467-
if int(label_metadata[label_idx][0]) > 0
456+
if int(label_metadata[label_idx, 0]) > 0
468457
]
469458

470459
metrics[MetricType.mARAveragedOverScores] = [
@@ -487,16 +476,17 @@ def evaluate(
487476
)
488477
for iou_idx, iou_threshold in enumerate(iou_thresholds)
489478
for label_idx, label in self.index_to_label.items()
490-
if int(label_metadata[label_idx][0]) > 0
479+
if int(label_metadata[label_idx, 0]) > 0
491480
]
492481

493482
for label_idx, label in self.index_to_label.items():
483+
484+
if label_metadata[label_idx, 0] == 0:
485+
continue
486+
494487
for score_idx, score_threshold in enumerate(score_thresholds):
495488
for iou_idx, iou_threshold in enumerate(iou_thresholds):
496489

497-
if label_metadata[label_idx, 0] == 0:
498-
continue
499-
500490
row = precision_recall[iou_idx][score_idx][label_idx]
501491
kwargs = {
502492
"label": label,

0 commit comments

Comments
 (0)