Skip to content

Commit

Permalink
BugFix Valor Lite Filtering (#767)
Browse files Browse the repository at this point in the history
  • Loading branch information
czaloom authored Sep 27, 2024
1 parent 8778b17 commit 1fa9b62
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 36 deletions.
111 changes: 111 additions & 0 deletions lite/tests/detection/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,117 @@ def test_filtering_four_detections(four_detections: list[Detection]):
assert m in actual_metrics


def test_filtering_all_detections(four_detections: list[Detection]):
"""
Basic object detection test.
groundtruths
datum uid1
box 1 - label (k1, v1) - tp
box 3 - label (k2, v2) - fn missing prediction
datum uid2
box 2 - label (k1, v1) - fn misclassification
datum uid3
box 1 - label (k1, v1) - tp
box 3 - label (k2, v2) - fn missing prediction
datum uid4
box 2 - label (k1, v1) - fn misclassification
predictions
datum uid1
box 1 - label (k1, v1) - score 0.3 - tp
datum uid2
box 2 - label (k2, v2) - score 0.98 - fp
datum uid3
box 1 - label (k1, v1) - score 0.3 - tp
datum uid4
box 2 - label (k2, v2) - score 0.98 - fp
"""

loader = DataLoader()
loader.add_bounding_boxes(four_detections)
evaluator = loader.finalize()

assert (
evaluator._ranked_pairs
== np.array(
[
[1.0, -1.0, 0.0, 0.0, -1.0, 1.0, 0.98],
[3.0, -1.0, 0.0, 0.0, -1.0, 1.0, 0.98],
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.3],
[2.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.3],
]
)
).all()

assert (
evaluator._label_metadata_per_datum
== np.array(
[
[
[1, 1],
[1, 0],
[1, 1],
[1, 0],
],
[
[1, 0],
[0, 1],
[1, 0],
[0, 1],
],
],
dtype=np.int32,
)
).all()

assert (
evaluator._label_metadata == np.array([[4, 2, 0], [2, 2, 1]])
).all()

# test datum filtering

filter_ = evaluator.create_filter(datum_uids=[])
print(filter_)
assert (filter_.indices == np.array([])).all()
assert (filter_.label_metadata == np.array([[0, 0, 0], [0, 0, 1]])).all()

# test label filtering

filter_ = evaluator.create_filter(labels=[])
assert (filter_.indices == np.array([])).all()
assert (filter_.label_metadata == np.array([[0, 0, 0], [0, 0, 1]])).all()

# test label key filtering

filter_ = evaluator.create_filter(label_keys=[])
assert (filter_.indices == np.array([[]])).all()
assert (filter_.label_metadata == np.array([[0, 0, 0], [0, 0, 1]])).all()

# test combo
filter_ = evaluator.create_filter(
datum_uids=[],
label_keys=["k1"],
)
assert (filter_.indices == np.array([])).all()
assert (filter_.label_metadata == np.array([[0, 0, 0], [0, 0, 1]])).all()

# test evaluation
filter_ = evaluator.create_filter(datum_uids=[])

metrics = evaluator.evaluate(
iou_thresholds=[0.5],
filter_=filter_,
metrics_to_return=[
*MetricType.base_metrics(),
MetricType.DetailedCounts,
],
)

actual_metrics = [m.to_dict() for m in metrics[MetricType.AP]]
assert len(actual_metrics) == 0


def test_filtering_random_detections():
loader = DataLoader()
loader.add_bounding_boxes(generate_random_detections(13, 4, "abc"))
Expand Down
62 changes: 26 additions & 36 deletions lite/valor_lite/detection/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,46 +272,35 @@ def create_filter(
[self.uid_to_index[uid] for uid in datum_uids],
dtype=np.int32,
)
mask = np.zeros_like(mask_pairs, dtype=np.bool_)
mask[
np.isin(self._ranked_pairs[:, 0].astype(int), datum_uids)
] = True
mask_pairs &= mask

mask = np.zeros_like(mask_datums, dtype=np.bool_)
mask[datum_uids] = True
mask_datums &= mask
mask_pairs[
~np.isin(self._ranked_pairs[:, 0].astype(int), datum_uids)
] = False
mask_datums[~np.isin(np.arange(n_datums), datum_uids)] = False

if labels is not None:
if isinstance(labels, list):
labels = np.array(
[self.label_to_index[label] for label in labels]
)
mask = np.zeros_like(mask_pairs, dtype=np.bool_)
mask[np.isin(self._ranked_pairs[:, 4].astype(int), labels)] = True
mask_pairs &= mask

mask = np.zeros_like(mask_labels, dtype=np.bool_)
mask[labels] = True
mask_labels &= mask
mask_pairs[
~np.isin(self._ranked_pairs[:, 4].astype(int), labels)
] = False
mask_labels[~np.isin(np.arange(n_labels), labels)] = False

if label_keys is not None:
if isinstance(label_keys, list):
label_keys = np.array(
[self.label_key_to_index[key] for key in label_keys]
)
label_indices = np.where(
np.isclose(self._label_metadata[:, 2], label_keys)
)[0]
mask = np.zeros_like(mask_pairs, dtype=np.bool_)
mask[
np.isin(self._ranked_pairs[:, 4].astype(int), label_indices)
] = True
mask_pairs &= mask

mask = np.zeros_like(mask_labels, dtype=np.bool_)
mask[label_indices] = True
mask_labels &= mask
label_indices = (
np.where(np.isclose(self._label_metadata[:, 2], label_keys))[0]
if label_keys.size > 0
else np.array([])
)
mask_pairs[
~np.isin(self._ranked_pairs[:, 4].astype(int), label_indices)
] = False
mask_labels[~np.isin(np.arange(n_labels), label_indices)] = False

mask = mask_datums[:, np.newaxis] & mask_labels[np.newaxis, :]
label_metadata_per_datum = self._label_metadata_per_datum.copy()
Expand Down Expand Up @@ -399,7 +388,7 @@ def evaluate(
)
for iou_idx in range(average_precision.shape[0])
for label_idx in range(average_precision.shape[1])
if int(label_metadata[label_idx][0]) > 0
if int(label_metadata[label_idx, 0]) > 0
]

metrics[MetricType.mAP] = [
Expand All @@ -419,7 +408,7 @@ def evaluate(
label=self.index_to_label[label_idx],
)
for label_idx in range(self.n_labels)
if int(label_metadata[label_idx][0]) > 0
if int(label_metadata[label_idx, 0]) > 0
]

metrics[MetricType.mAPAveragedOverIOUs] = [
Expand All @@ -442,7 +431,7 @@ def evaluate(
)
for score_idx in range(average_recall.shape[0])
for label_idx in range(average_recall.shape[1])
if int(label_metadata[label_idx][0]) > 0
if int(label_metadata[label_idx, 0]) > 0
]

metrics[MetricType.mAR] = [
Expand All @@ -464,7 +453,7 @@ def evaluate(
label=self.index_to_label[label_idx],
)
for label_idx in range(self.n_labels)
if int(label_metadata[label_idx][0]) > 0
if int(label_metadata[label_idx, 0]) > 0
]

metrics[MetricType.mARAveragedOverScores] = [
Expand All @@ -487,16 +476,17 @@ def evaluate(
)
for iou_idx, iou_threshold in enumerate(iou_thresholds)
for label_idx, label in self.index_to_label.items()
if int(label_metadata[label_idx][0]) > 0
if int(label_metadata[label_idx, 0]) > 0
]

for label_idx, label in self.index_to_label.items():

if label_metadata[label_idx, 0] == 0:
continue

for score_idx, score_threshold in enumerate(score_thresholds):
for iou_idx, iou_threshold in enumerate(iou_thresholds):

if label_metadata[label_idx, 0] == 0:
continue

row = precision_recall[iou_idx][score_idx][label_idx]
kwargs = {
"label": label,
Expand Down

0 comments on commit 1fa9b62

Please sign in to comment.