Skip to content

Commit 2a49693

Browse files
authored
Add scores to PrecisionRecallCurve (#796)
1 parent f1546ee commit 2a49693

File tree

7 files changed

+189
-67
lines changed

7 files changed

+189
-67
lines changed

lite/benchmarks/benchmark_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def run_benchmarking_analysis(
211211
)
212212

213213
# evaluate
214-
eval_time, _ = time_it(evaluator.evaluate)()
214+
eval_time, _ = time_it(evaluator.compute_precision_recall)()
215215
if eval_time > evaluation_timeout and evaluation_timeout != -1:
216216
raise TimeoutError(
217217
f"Base evaluation timed out with {evaluator.n_datums} datums."

lite/benchmarks/benchmark_objdet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def run_benchmarking_analysis(
322322
)
323323

324324
# evaluate - base metrics only
325-
eval_time, metrics = time_it(evaluator.evaluate)()
325+
eval_time, metrics = time_it(evaluator.compute_precision_recall)()
326326
if eval_time > evaluation_timeout and evaluation_timeout != -1:
327327
raise TimeoutError(
328328
f"Base evaluation timed out with {evaluator.n_datums} datums."

lite/examples/object-detection.ipynb

Lines changed: 42 additions & 12 deletions
Large diffs are not rendered by default.

lite/tests/object_detection/test_pr_curve.py

Lines changed: 97 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,15 @@ def test_pr_curve_simple():
3131
score_thresholds=score_thresholds,
3232
)
3333

34-
assert pr_curve.shape == (2, 1, 101)
35-
assert np.isclose(pr_curve[0][0], 1.0).all()
36-
assert np.isclose(pr_curve[1][0], 1 / 3).all()
34+
assert pr_curve.shape == (2, 1, 101, 2)
35+
36+
# test precision values
37+
assert np.isclose(pr_curve[0, 0, :, 0], 1.0).all()
38+
assert np.isclose(pr_curve[1, 0, :, 0], 1 / 3).all()
39+
40+
# test score values
41+
assert np.isclose(pr_curve[0, 0, :, 1], 0.95).all()
42+
assert np.isclose(pr_curve[1, 0, :, 1], 0.65).all()
3743

3844

3945
def test_pr_curve_using_torch_metrics_example(
@@ -59,111 +65,162 @@ def test_pr_curve_using_torch_metrics_example(
5965
as_dict=True,
6066
)
6167

62-
# AP = 1.0
63-
a = [1.0 for _ in range(101)]
64-
65-
# AP = 0.505
66-
b = [1.0 for _ in range(51)] + [0.0 for _ in range(50)]
67-
68-
# AP = 0.791
69-
c = (
70-
[1.0 for _ in range(71)]
71-
+ [8 / 9 for _ in range(10)]
72-
+ [0.0 for _ in range(20)]
73-
)
74-
75-
# AP = 0.722
76-
d = (
77-
[1.0 for _ in range(41)]
78-
+ [0.8 for _ in range(40)]
79-
+ [0.0 for _ in range(20)]
80-
)
81-
82-
# AP = 0.576
83-
e = (
84-
[1.0 for _ in range(41)]
85-
+ [0.8571428571428571 for _ in range(20)]
86-
+ [0.0 for _ in range(40)]
87-
)
88-
8968
# test PrecisionRecallCurve
9069
actual_metrics = [m for m in metrics[MetricType.PrecisionRecallCurve]]
9170
expected_metrics = [
9271
{
9372
"type": "PrecisionRecallCurve",
94-
"value": a,
73+
"value": {
74+
"precisions": [1.0 for _ in range(101)],
75+
"scores": (
76+
[0.953 for _ in range(21)]
77+
+ [0.805 for _ in range(20)]
78+
+ [0.611 for _ in range(20)]
79+
+ [0.407 for _ in range(20)]
80+
+ [0.335 for _ in range(20)]
81+
),
82+
},
9583
"parameters": {
9684
"iou_threshold": 0.5,
9785
"label": "0",
9886
},
9987
},
10088
{
10189
"type": "PrecisionRecallCurve",
102-
"value": d,
90+
"value": {
91+
"precisions": (
92+
[1.0 for _ in range(41)]
93+
+ [0.8 for _ in range(40)]
94+
+ [0.0 for _ in range(20)]
95+
),
96+
"scores": (
97+
[0.953 for _ in range(21)]
98+
+ [0.805 for _ in range(20)]
99+
+ [0.407 for _ in range(20)]
100+
+ [0.335 for _ in range(20)]
101+
+ [0.0 for _ in range(20)]
102+
),
103+
},
103104
"parameters": {
104105
"iou_threshold": 0.75,
105106
"label": "0",
106107
},
107108
},
108109
{
109110
"type": "PrecisionRecallCurve",
110-
"value": a,
111+
"value": {
112+
"precisions": [1.0 for _ in range(101)],
113+
"scores": [0.3 for _ in range(101)],
114+
},
111115
"parameters": {
112116
"iou_threshold": 0.5,
113117
"label": "1",
114118
},
115119
},
116120
{
117121
"type": "PrecisionRecallCurve",
118-
"value": a,
122+
"value": {
123+
"precisions": [1.0 for _ in range(101)],
124+
"scores": [0.3 for _ in range(101)],
125+
},
119126
"parameters": {
120127
"iou_threshold": 0.75,
121128
"label": "1",
122129
},
123130
},
124131
{
125132
"type": "PrecisionRecallCurve",
126-
"value": b,
133+
"value": {
134+
"precisions": [1.0 for _ in range(51)]
135+
+ [0.0 for _ in range(50)],
136+
"scores": [0.726 for _ in range(51)]
137+
+ [0.0 for _ in range(50)],
138+
},
127139
"parameters": {
128140
"iou_threshold": 0.5,
129141
"label": "2",
130142
},
131143
},
132144
{
133145
"type": "PrecisionRecallCurve",
134-
"value": b,
146+
"value": {
147+
"precisions": [1.0 for _ in range(51)]
148+
+ [0.0 for _ in range(50)],
149+
"scores": [0.726 for _ in range(51)]
150+
+ [0.0 for _ in range(50)],
151+
},
135152
"parameters": {
136153
"iou_threshold": 0.75,
137154
"label": "2",
138155
},
139156
},
140157
{
141158
"type": "PrecisionRecallCurve",
142-
"value": a,
159+
"value": {
160+
"precisions": [1.0 for _ in range(101)],
161+
"scores": [0.546 for _ in range(51)]
162+
+ [0.236 for _ in range(50)],
163+
},
143164
"parameters": {
144165
"iou_threshold": 0.5,
145166
"label": "4",
146167
},
147168
},
148169
{
149170
"type": "PrecisionRecallCurve",
150-
"value": a,
171+
"value": {
172+
"precisions": [1.0 for _ in range(101)],
173+
"scores": [0.546 for _ in range(51)]
174+
+ [0.236 for _ in range(50)],
175+
},
151176
"parameters": {
152177
"iou_threshold": 0.75,
153178
"label": "4",
154179
},
155180
},
156181
{
157182
"type": "PrecisionRecallCurve",
158-
"value": c,
183+
"value": {
184+
"precisions": (
185+
[1.0 for _ in range(71)]
186+
+ [8 / 9 for _ in range(10)]
187+
+ [0.0 for _ in range(20)]
188+
),
189+
"scores": (
190+
[0.883 for _ in range(11)]
191+
+ [0.782 for _ in range(10)]
192+
+ [0.561 for _ in range(10)]
193+
+ [0.532 for _ in range(10)]
194+
+ [0.349 for _ in range(10)]
195+
+ [0.271 for _ in range(10)]
196+
+ [0.204 for _ in range(10)]
197+
+ [0.202 for _ in range(10)]
198+
+ [0.0 for _ in range(20)]
199+
),
200+
},
159201
"parameters": {
160202
"iou_threshold": 0.5,
161203
"label": "49",
162204
},
163205
},
164206
{
165207
"type": "PrecisionRecallCurve",
166-
"value": e,
208+
"value": {
209+
"precisions": (
210+
[1.0 for _ in range(41)]
211+
+ [0.8571428571428571 for _ in range(20)]
212+
+ [0.0 for _ in range(40)]
213+
),
214+
"scores": (
215+
[0.883 for _ in range(11)]
216+
+ [0.782 for _ in range(10)]
217+
+ [0.561 for _ in range(10)]
218+
+ [0.532 for _ in range(10)]
219+
+ [0.271 for _ in range(10)]
220+
+ [0.204 for _ in range(10)]
221+
+ [0.0 for _ in range(40)]
222+
),
223+
},
167224
"parameters": {
168225
"iou_threshold": 0.75,
169226
"label": "49",

lite/valor_lite/object_detection/computation.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -334,15 +334,17 @@ def compute_metrics(
334334
counts = np.zeros((n_ious, n_scores, n_labels, 7))
335335

336336
pd_labels = data[:, 5].astype(int)
337-
unique_pd_labels = np.unique(pd_labels)
337+
scores = data[:, 6]
338+
unique_pd_labels, unique_pd_indices = np.unique(
339+
pd_labels, return_index=True
340+
)
338341
gt_count = label_metadata[:, 0]
339342
running_total_count = np.zeros(
340343
(n_ious, n_rows),
341344
dtype=np.float64,
342345
)
343346
running_tp_count = np.zeros_like(running_total_count)
344347
running_gt_count = np.zeros_like(running_total_count)
345-
pr_curve = np.zeros((n_ious, n_labels, 101))
346348

347349
mask_score_nonzero = data[:, 6] > 1e-9
348350
mask_gt_exists = data[:, 1] >= 0.0
@@ -475,20 +477,42 @@ def compute_metrics(
475477
out=recall,
476478
)
477479
recall_index = np.floor(recall * 100.0).astype(int)
480+
481+
# bin precision-recall curve
482+
pr_curve = np.zeros((n_ious, n_labels, 101, 2))
478483
for iou_idx in range(n_ious):
479484
p = precision[iou_idx]
480485
r = recall_index[iou_idx]
481-
pr_curve[iou_idx, pd_labels, r] = np.maximum(
482-
pr_curve[iou_idx, pd_labels, r], p
486+
pr_curve[iou_idx, pd_labels, r, 0] = np.maximum(
487+
pr_curve[iou_idx, pd_labels, r, 0],
488+
p,
489+
)
490+
pr_curve[iou_idx, pd_labels, r, 1] = np.maximum(
491+
pr_curve[iou_idx, pd_labels, r, 1],
492+
scores,
483493
)
484494

485495
# calculate average precision
486-
running_max = np.zeros((n_ious, n_labels))
496+
running_max_precision = np.zeros((n_ious, n_labels))
497+
running_max_score = np.zeros((n_labels))
487498
for recall in range(100, -1, -1):
488-
precision = pr_curve[:, :, recall]
489-
running_max = np.maximum(precision, running_max)
490-
average_precision += running_max
491-
pr_curve[:, :, recall] = running_max
499+
500+
# running max precision
501+
running_max_precision = np.maximum(
502+
pr_curve[:, :, recall, 0],
503+
running_max_precision,
504+
)
505+
pr_curve[:, :, recall, 0] = running_max_precision
506+
507+
# running max score
508+
running_max_score = np.maximum(
509+
pr_curve[:, :, recall, 1],
510+
running_max_score,
511+
)
512+
pr_curve[:, :, recall, 1] = running_max_score
513+
514+
average_precision += running_max_precision
515+
492516
average_precision = average_precision / 101.0
493517

494518
# calculate average recall

lite/valor_lite/object_detection/manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,12 @@ def compute_precision_recall(
595595

596596
metrics[MetricType.PrecisionRecallCurve] = [
597597
PrecisionRecallCurve(
598-
precision=pr_curves[iou_idx][label_idx].astype(float).tolist(),
598+
precisions=pr_curves[iou_idx, label_idx, :, 0]
599+
.astype(float)
600+
.tolist(),
601+
scores=pr_curves[iou_idx, label_idx, :, 1]
602+
.astype(float)
603+
.tolist(),
599604
iou_threshold=iou_threshold,
600605
label=label,
601606
)

lite/valor_lite/object_detection/metric.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -591,8 +591,10 @@ class PrecisionRecallCurve:
591591
592592
Attributes
593593
----------
594-
precision : list[float]
594+
precisions : list[float]
595595
Interpolated precision values corresponding to recalls at 0.0, 0.01, ..., 1.0.
596+
scores : list[float]
597+
Maximum prediction score for each point on the interpolated curve.
596598
iou_threshold : float
597599
The Intersection over Union (IoU) threshold used to determine true positives.
598600
label : str
@@ -606,14 +608,18 @@ class PrecisionRecallCurve:
606608
Converts the instance to a dictionary representation.
607609
"""
608610

609-
precision: list[float]
611+
precisions: list[float]
612+
scores: list[float]
610613
iou_threshold: float
611614
label: str
612615

613616
def to_metric(self) -> Metric:
614617
return Metric(
615618
type=type(self).__name__,
616-
value=self.precision,
619+
value={
620+
"precisions": self.precisions,
621+
"scores": self.scores,
622+
},
617623
parameters={
618624
"iou_threshold": self.iou_threshold,
619625
"label": self.label,

0 commit comments

Comments (0)