Skip to content

Commit 6f2ecf5

Browse files
optimized classification pr curve
1 parent 18b08b7 commit 6f2ecf5

File tree

5 files changed

+747
-0
lines changed

5 files changed

+747
-0
lines changed

valor-sandbox/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .optimized_implementation import _calculate_pr_curves_optimized
2+
from .valor_implementation import _calculate_pr_curves
3+
from .test_utils import generate_groundtruth, generate_predictions, pretty_print_tracemalloc_snapshot
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import pandas as pd
2+
from collections import defaultdict
3+
4+
def _calculate_pr_curves_optimized(
5+
prediction_df: pd.DataFrame,
6+
groundtruth_df: pd.DataFrame,
7+
metrics_to_return: list,
8+
pr_curve_max_examples: int,
9+
):
10+
joint_df = (
11+
pd.merge(
12+
groundtruth_df,
13+
prediction_df,
14+
on=["datum_id", "datum_uid", "label_key"],
15+
how="inner",
16+
suffixes=("_gt", "_pd"),
17+
).loc[
18+
:,
19+
[
20+
"datum_uid",
21+
"datum_id",
22+
"label_key",
23+
"label_value_gt",
24+
"id_gt",
25+
"label_value_pd",
26+
"score",
27+
"id_pd",
28+
],
29+
]
30+
)
31+
32+
total_datums_per_label_key = joint_df.drop_duplicates(["datum_uid", "datum_id", "label_key"])["label_key"].value_counts()
33+
total_label_values_per_label_key = joint_df.drop_duplicates(["datum_uid", "datum_id", "label_key"])[["label_key", "label_value_gt"]].value_counts()
34+
35+
joint_df = joint_df.assign(
36+
threshold_index=lambda chain_df: (
37+
((joint_df["score"] * 100) // 5).astype("int32")
38+
),
39+
is_label_match=lambda chain_df: (
40+
(chain_df["label_value_pd"] == chain_df["label_value_gt"])
41+
)
42+
)
43+
44+
45+
true_positives = joint_df[joint_df["is_label_match"] == True][["label_key", "label_value_gt", "threshold_index"]].value_counts()
46+
## true_positives = true_positives.reset_index(2).sort_values("threshold_index").groupby(["label_key", "label_value_gt"]).cumsum()
47+
false_positives = joint_df[joint_df["is_label_match"] == False][["label_key", "label_value_pd", "threshold_index"]].value_counts()
48+
## false_positives = false_positives.reset_index(2).sort_values("threshold_index").groupby(["label_key", "label_value_pd"]).cumsum()
49+
50+
dd = defaultdict(lambda: 0)
51+
confidence_thresholds = [x / 100 for x in range(5, 100, 5)]
52+
53+
tps_keys = []
54+
tps_values = []
55+
tps_confidence = []
56+
tps_cumulative = []
57+
fns_cumulative = []
58+
59+
fps_keys = []
60+
fps_values = []
61+
fps_confidence = []
62+
fps_cumulative = []
63+
tns_cumulative = []
64+
65+
## Not sure what the efficient way of doing this is in pandas
66+
for label_key in true_positives.keys().get_level_values(0).unique():
67+
for label_value in true_positives.keys().get_level_values(1).unique():
68+
dd = true_positives[label_key][label_value].to_dict(into=dd)
69+
cumulative_true_positive = [0] * 21
70+
cumulative_false_negative = [0] * 21
71+
for threshold_index in range(19, -1, -1):
72+
cumulative_true_positive[threshold_index] = cumulative_true_positive[threshold_index + 1] + dd[threshold_index]
73+
cumulative_false_negative[threshold_index] = total_label_values_per_label_key[label_key][label_value] - cumulative_true_positive[threshold_index]
74+
75+
tps_keys += [label_key] * 19
76+
tps_values += [label_value] * 19
77+
tps_confidence += confidence_thresholds
78+
tps_cumulative += cumulative_true_positive[1:-1]
79+
fns_cumulative += cumulative_false_negative[1:-1]
80+
81+
## Not sure what the efficient way of doing this is in pandas
82+
for label_key in false_positives.keys().get_level_values(0).unique():
83+
for label_value in false_positives.keys().get_level_values(1).unique():
84+
dd = false_positives[label_key][label_value].to_dict(into=dd)
85+
cumulative_false_positive = [0] * 21
86+
cumulative_true_negative = [0] * 21
87+
for threshold_index in range(19, -1, -1):
88+
cumulative_false_positive[threshold_index] = cumulative_false_positive[threshold_index + 1] + dd[threshold_index]
89+
cumulative_true_negative[threshold_index] = total_datums_per_label_key[label_key] - total_label_values_per_label_key[label_key][label_value] - cumulative_false_positive[threshold_index]
90+
91+
fps_keys += [label_key] * 19
92+
fps_values += [label_value] * 19
93+
fps_confidence += confidence_thresholds
94+
fps_cumulative += cumulative_false_positive[1:-1]
95+
tns_cumulative += cumulative_true_negative[1:-1]
96+
97+
tps_df = pd.DataFrame({
98+
"label_key": tps_keys,
99+
"label_value": tps_values,
100+
"confidence_threshold": tps_confidence,
101+
"true_positives": tps_cumulative,
102+
"false_negatives": fns_cumulative,
103+
})
104+
105+
fps_df = pd.DataFrame({
106+
"label_key": fps_keys,
107+
"label_value": fps_values,
108+
"confidence_threshold": fps_confidence,
109+
"false_positives": fps_cumulative,
110+
"true_negatives": tns_cumulative,
111+
})
112+
113+
pr_curve_counts_df = pd.merge(
114+
tps_df,
115+
fps_df,
116+
on=["label_key", "label_value", "confidence_threshold"],
117+
how="outer",
118+
)
119+
120+
pr_curve_counts_df.fillna(0, inplace=True)
121+
122+
'''
123+
Pretty sure there's a bug with accuracy, it assumes that each `datum_id` has every `label_key`.
124+
125+
pr_curve_counts_df["precision"] = pr_curve_counts_df["true_positives"] / (
126+
pr_curve_counts_df["true_positives"]
127+
+ pr_curve_counts_df["false_positives"]
128+
)
129+
pr_curve_counts_df["recall"] = pr_curve_counts_df["true_positives"] / (
130+
pr_curve_counts_df["true_positives"]
131+
+ pr_curve_counts_df["false_negatives"]
132+
)
133+
pr_curve_counts_df["accuracy"] = (
134+
pr_curve_counts_df["true_positives"]
135+
+ pr_curve_counts_df["true_negatives"]
136+
) / len(unique_datum_ids)
137+
pr_curve_counts_df["f1_score"] = (
138+
2 * pr_curve_counts_df["precision"] * pr_curve_counts_df["recall"]
139+
) / (pr_curve_counts_df["precision"] + pr_curve_counts_df["recall"])
140+
141+
# any NaNs that are left are from division by zero errors
142+
pr_curve_counts_df.fillna(-1, inplace=True)
143+
'''
144+
145+
return pr_curve_counts_df

valor-sandbox/test_utils.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import time
2+
import random
3+
import numpy as np
4+
import pandas as pd
5+
6+
import os
7+
import linecache
8+
import tracemalloc
9+
10+
# Seed the global RNG from the wall clock so each run generates a different
# synthetic dataset (results are intentionally non-reproducible).
random.seed(time.time())

# Pool of class names used by the generators below; `n_class` selects a prefix.
label_values = ["cat", "dog", "bee", "rat", "cow", "fox", "ant", "owl", "bat"]
13+
14+
def generate_groundtruth(n, n_class=3):
    """Build a synthetic groundtruth classification DataFrame.

    Parameters
    ----------
    n : int
        Number of datums (one groundtruth label each).
    n_class : int
        Number of distinct classes to draw from; clamped to
        [1, len(label_values)].

    Returns
    -------
    pd.DataFrame
        Columns: datum_uid, datum_id, id, label_key, label_value.
    """
    # Clamp n_class into [1, len(label_values)].  The original wrote
    # `n_class - max(1, n_class)` — a no-op expression instead of an
    # assignment — so n_class <= 0 crashed random.choice() on an empty list.
    n_class = min(n_class, len(label_values))
    n_class = max(1, n_class)

    classes = label_values[:n_class]

    df = pd.DataFrame({
        "datum_uid": [f"uid{i}" for i in range(n)],
        "datum_id": [f"img{i}" for i in range(n)],
        "id": [f"gt{i}" for i in range(n)],
        "label_key": "class_label",
        "label_value": [random.choice(classes) for _ in range(n)],
    })

    return df
29+
30+
def generate_predictions(n, n_class=3, preds_per_datum=1):
    """Build a synthetic prediction DataFrame.

    Each of the `n` datums gets `preds_per_datum` predictions with distinct
    labels whose scores are normalised to sum to 1.

    Parameters
    ----------
    n : int
        Number of datums.
    n_class : int
        Number of distinct classes to draw from; clamped to
        [1, len(label_values)].
    preds_per_datum : int
        Predictions per datum; clamped to [1, n_class] because
        random.sample() cannot draw more distinct labels than exist.

    Returns
    -------
    pd.DataFrame
        Columns: datum_uid, datum_id, id, label_key, label_value, score.
    """
    # Clamp n_class into [1, len(label_values)].  The original wrote the
    # no-op expression `n_class - max(1, n_class)` instead of assigning.
    n_class = min(n_class, len(label_values))
    n_class = max(1, n_class)

    classes = label_values[:n_class]

    # Clamp preds_per_datum into [1, n_class].  The original had min/max
    # swapped (`min(1, ...)` then `max(n_class, ...)`), which forced
    # preds_per_datum == n_class regardless of the argument.
    preds_per_datum = max(1, preds_per_datum)
    preds_per_datum = min(n_class, preds_per_datum)

    all_labels = []
    all_scores = []
    for _ in range(n):
        labels = random.sample(classes, preds_per_datum)
        scores = [random.uniform(0, 1) for _ in range(preds_per_datum)]
        total = sum(scores)
        all_labels += labels
        # Normalise so each datum's scores form a probability distribution.
        all_scores += [s / total for s in scores]

    df = pd.DataFrame({
        "datum_uid": np.repeat([f"uid{i}" for i in range(n)], preds_per_datum),
        "datum_id": np.repeat([f"img{i}" for i in range(n)], preds_per_datum),
        "id": [f"pd{i}" for i in range(n * preds_per_datum)],
        "label_key": "class_label",
        "label_value": all_labels,
        "score": all_scores,
    })

    return df
61+
62+
def pretty_print_tracemalloc_snapshot(snapshot, key_type='lineno', limit=3):
    """Print a human-readable summary of a tracemalloc snapshot.

    Shows the `limit` largest allocation sites (grouped by `key_type`),
    an aggregate line for everything that did not make the cut, and the
    grand total, all in KiB.
    """
    filtered = snapshot.filter_traces((
        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
        tracemalloc.Filter(False, "<unknown>"),
    ))
    stats = filtered.statistics(key_type)

    print("Top %s lines" % limit)
    for rank, entry in enumerate(stats[:limit], start=1):
        frame = entry.traceback[0]
        # shorten "/path/to/module/file.py" to "module/file.py"
        path_parts = frame.filename.split(os.sep)
        short_name = os.sep.join(path_parts[-2:])
        print("#%s: %s:%s: %.1f KiB"
              % (rank, short_name, frame.lineno, entry.size / 1024))
        source_line = linecache.getline(frame.filename, frame.lineno).strip()
        if source_line:
            print('    %s' % source_line)

    remainder = stats[limit:]
    if remainder:
        remainder_size = sum(entry.size for entry in remainder)
        print("%s other: %.1f KiB" % (len(remainder), remainder_size / 1024))
    total_size = sum(entry.size for entry in stats)
    print("Total allocated size: %.1f KiB" % (total_size / 1024))

0 commit comments

Comments
 (0)