-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
18b08b7
commit 6f2ecf5
Showing
5 changed files
with
747 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .optimized_implementation import _calculate_pr_curves_optimized | ||
from .valor_implementation import _calculate_pr_curves | ||
from .test_utils import generate_groundtruth, generate_predictions, pretty_print_tracemalloc_snapshot |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import pandas as pd | ||
from collections import defaultdict | ||
|
||
def _calculate_pr_curves_optimized( | ||
prediction_df: pd.DataFrame, | ||
groundtruth_df: pd.DataFrame, | ||
metrics_to_return: list, | ||
pr_curve_max_examples: int, | ||
): | ||
joint_df = ( | ||
pd.merge( | ||
groundtruth_df, | ||
prediction_df, | ||
on=["datum_id", "datum_uid", "label_key"], | ||
how="inner", | ||
suffixes=("_gt", "_pd"), | ||
).loc[ | ||
:, | ||
[ | ||
"datum_uid", | ||
"datum_id", | ||
"label_key", | ||
"label_value_gt", | ||
"id_gt", | ||
"label_value_pd", | ||
"score", | ||
"id_pd", | ||
], | ||
] | ||
) | ||
|
||
total_datums_per_label_key = joint_df.drop_duplicates(["datum_uid", "datum_id", "label_key"])["label_key"].value_counts() | ||
total_label_values_per_label_key = joint_df.drop_duplicates(["datum_uid", "datum_id", "label_key"])[["label_key", "label_value_gt"]].value_counts() | ||
|
||
joint_df = joint_df.assign( | ||
threshold_index=lambda chain_df: ( | ||
((joint_df["score"] * 100) // 5).astype("int32") | ||
), | ||
is_label_match=lambda chain_df: ( | ||
(chain_df["label_value_pd"] == chain_df["label_value_gt"]) | ||
) | ||
) | ||
|
||
|
||
true_positives = joint_df[joint_df["is_label_match"] == True][["label_key", "label_value_gt", "threshold_index"]].value_counts() | ||
## true_positives = true_positives.reset_index(2).sort_values("threshold_index").groupby(["label_key", "label_value_gt"]).cumsum() | ||
false_positives = joint_df[joint_df["is_label_match"] == False][["label_key", "label_value_pd", "threshold_index"]].value_counts() | ||
## false_positives = false_positives.reset_index(2).sort_values("threshold_index").groupby(["label_key", "label_value_pd"]).cumsum() | ||
|
||
dd = defaultdict(lambda: 0) | ||
confidence_thresholds = [x / 100 for x in range(5, 100, 5)] | ||
|
||
tps_keys = [] | ||
tps_values = [] | ||
tps_confidence = [] | ||
tps_cumulative = [] | ||
fns_cumulative = [] | ||
|
||
fps_keys = [] | ||
fps_values = [] | ||
fps_confidence = [] | ||
fps_cumulative = [] | ||
tns_cumulative = [] | ||
|
||
## Not sure what the efficient way of doing this is in pandas | ||
for label_key in true_positives.keys().get_level_values(0).unique(): | ||
for label_value in true_positives.keys().get_level_values(1).unique(): | ||
dd = true_positives[label_key][label_value].to_dict(into=dd) | ||
cumulative_true_positive = [0] * 21 | ||
cumulative_false_negative = [0] * 21 | ||
for threshold_index in range(19, -1, -1): | ||
cumulative_true_positive[threshold_index] = cumulative_true_positive[threshold_index + 1] + dd[threshold_index] | ||
cumulative_false_negative[threshold_index] = total_label_values_per_label_key[label_key][label_value] - cumulative_true_positive[threshold_index] | ||
|
||
tps_keys += [label_key] * 19 | ||
tps_values += [label_value] * 19 | ||
tps_confidence += confidence_thresholds | ||
tps_cumulative += cumulative_true_positive[1:-1] | ||
fns_cumulative += cumulative_false_negative[1:-1] | ||
|
||
## Not sure what the efficient way of doing this is in pandas | ||
for label_key in false_positives.keys().get_level_values(0).unique(): | ||
for label_value in false_positives.keys().get_level_values(1).unique(): | ||
dd = false_positives[label_key][label_value].to_dict(into=dd) | ||
cumulative_false_positive = [0] * 21 | ||
cumulative_true_negative = [0] * 21 | ||
for threshold_index in range(19, -1, -1): | ||
cumulative_false_positive[threshold_index] = cumulative_false_positive[threshold_index + 1] + dd[threshold_index] | ||
cumulative_true_negative[threshold_index] = total_datums_per_label_key[label_key] - total_label_values_per_label_key[label_key][label_value] - cumulative_false_positive[threshold_index] | ||
|
||
fps_keys += [label_key] * 19 | ||
fps_values += [label_value] * 19 | ||
fps_confidence += confidence_thresholds | ||
fps_cumulative += cumulative_false_positive[1:-1] | ||
tns_cumulative += cumulative_true_negative[1:-1] | ||
|
||
tps_df = pd.DataFrame({ | ||
"label_key": tps_keys, | ||
"label_value": tps_values, | ||
"confidence_threshold": tps_confidence, | ||
"true_positives": tps_cumulative, | ||
"false_negatives": fns_cumulative, | ||
}) | ||
|
||
fps_df = pd.DataFrame({ | ||
"label_key": fps_keys, | ||
"label_value": fps_values, | ||
"confidence_threshold": fps_confidence, | ||
"false_positives": fps_cumulative, | ||
"true_negatives": tns_cumulative, | ||
}) | ||
|
||
pr_curve_counts_df = pd.merge( | ||
tps_df, | ||
fps_df, | ||
on=["label_key", "label_value", "confidence_threshold"], | ||
how="outer", | ||
) | ||
|
||
pr_curve_counts_df.fillna(0, inplace=True) | ||
|
||
''' | ||
Pretty sure there's a bug with accuracy, it assumes that each `datum_id` has every `label_key`. | ||
pr_curve_counts_df["precision"] = pr_curve_counts_df["true_positives"] / ( | ||
pr_curve_counts_df["true_positives"] | ||
+ pr_curve_counts_df["false_positives"] | ||
) | ||
pr_curve_counts_df["recall"] = pr_curve_counts_df["true_positives"] / ( | ||
pr_curve_counts_df["true_positives"] | ||
+ pr_curve_counts_df["false_negatives"] | ||
) | ||
pr_curve_counts_df["accuracy"] = ( | ||
pr_curve_counts_df["true_positives"] | ||
+ pr_curve_counts_df["true_negatives"] | ||
) / len(unique_datum_ids) | ||
pr_curve_counts_df["f1_score"] = ( | ||
2 * pr_curve_counts_df["precision"] * pr_curve_counts_df["recall"] | ||
) / (pr_curve_counts_df["precision"] + pr_curve_counts_df["recall"]) | ||
# any NaNs that are left are from division by zero errors | ||
pr_curve_counts_df.fillna(-1, inplace=True) | ||
''' | ||
|
||
return pr_curve_counts_df |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import time | ||
import random | ||
import numpy as np | ||
import pandas as pd | ||
|
||
import os | ||
import linecache | ||
import tracemalloc | ||
|
||
# Seed the PRNG with wall-clock time so each benchmark run draws a
# different random dataset.
random.seed(time.time())

# Pool of candidate class labels; the generators below slice off the
# first `n_class` entries.
label_values = ["cat", "dog", "bee", "rat", "cow", "fox", "ant", "owl", "bat"]
|
||
def generate_groundtruth(n, n_class=3):
    """Generate `n` synthetic ground-truth rows, one per datum.

    Each row gets unique datum_uid/datum_id/id values, the fixed label key
    "class_label", and a label value drawn uniformly at random from the
    first `n_class` entries of `label_values`.

    Args:
        n: number of ground-truth rows (one datum each) to generate.
        n_class: number of distinct classes to draw from, clamped to
            [1, len(label_values)].

    Returns:
        pd.DataFrame with columns datum_uid, datum_id, id, label_key,
        label_value.
    """
    n_class = min(n_class, len(label_values))
    # BUG FIX: this was `n_class - max(1, n_class)`, a discarded no-op
    # expression; the intent is to clamp n_class to at least 1.
    n_class = max(1, n_class)

    classes = label_values[:n_class]

    return pd.DataFrame({
        "datum_uid": [f"uid{i}" for i in range(n)],
        "datum_id": [f"img{i}" for i in range(n)],
        "id": [f"gt{i}" for i in range(n)],
        "label_key": "class_label",
        "label_value": [random.choice(classes) for _ in range(n)],
    })
|
||
def generate_predictions(n, n_class=3, preds_per_datum=1):
    """Generate synthetic predictions: `preds_per_datum` per datum for `n` datums.

    For each datum, `preds_per_datum` distinct labels are sampled without
    replacement from the first `n_class` entries of `label_values` and given
    random scores normalized to sum to 1.

    Args:
        n: number of datums to generate predictions for.
        n_class: number of distinct classes to draw from, clamped to
            [1, len(label_values)].
        preds_per_datum: predictions per datum, clamped to [1, n_class].

    Returns:
        pd.DataFrame with columns datum_uid, datum_id, id, label_key,
        label_value, score.
    """
    n_class = min(n_class, len(label_values))
    # BUG FIX: was `n_class - max(1, n_class)` — a discarded no-op
    # expression; clamp n_class to at least 1.
    n_class = max(1, n_class)

    classes = label_values[:n_class]

    # BUG FIX: the original `min(1, p)` followed by `max(n_class, p)` forced
    # preds_per_datum == n_class regardless of the argument; the intent is a
    # clamp into [1, n_class], which also keeps random.sample below valid.
    preds_per_datum = max(1, preds_per_datum)
    preds_per_datum = min(n_class, preds_per_datum)

    all_labels = []
    all_scores = []
    for _ in range(n):
        labels = random.sample(classes, preds_per_datum)
        scores = [random.uniform(0, 1) for _ in range(preds_per_datum)]
        total = sum(scores)
        all_labels += labels
        # Normalize so each datum's scores form a probability distribution.
        all_scores += [s / total for s in scores]

    return pd.DataFrame({
        "datum_uid": np.repeat([f"uid{i}" for i in range(n)], preds_per_datum),
        "datum_id": np.repeat([f"img{i}" for i in range(n)], preds_per_datum),
        "id": [f"pd{i}" for i in range(n * preds_per_datum)],
        "label_key": "class_label",
        "label_value": all_labels,
        "score": all_scores,
    })
|
||
def pretty_print_tracemalloc_snapshot(snapshot, key_type='lineno', limit=3):
    """Print the `limit` largest allocation sites from a tracemalloc snapshot.

    Frames from the importlib bootstrap machinery and unknown files are
    filtered out.  Each top entry is printed as file:line with its size in
    KiB and, when available, the source line; a one-line aggregate for the
    remaining entries and the grand total follow.
    """
    filtered = snapshot.filter_traces((
        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
        tracemalloc.Filter(False, "<unknown>"),
    ))
    stats = filtered.statistics(key_type)

    print("Top %s lines" % limit)
    for rank, stat in enumerate(stats[:limit], 1):
        frame = stat.traceback[0]
        # Shorten "/path/to/module/file.py" to "module/file.py".
        short_name = os.sep.join(frame.filename.split(os.sep)[-2:])
        print("#%s: %s:%s: %.1f KiB"
              % (rank, short_name, frame.lineno, stat.size / 1024))
        source_line = linecache.getline(frame.filename, frame.lineno).strip()
        if source_line:
            print(' %s' % source_line)

    remainder = stats[limit:]
    if remainder:
        remainder_size = sum(entry.size for entry in remainder)
        print("%s other: %.1f KiB" % (len(remainder), remainder_size / 1024))
    grand_total = sum(entry.size for entry in stats)
    print("Total allocated size: %.1f KiB" % (grand_total / 1024))
Oops, something went wrong.