
Commit

optimized classification pr curve
jqu-striveworks committed Aug 22, 2024
1 parent 18b08b7 commit 6f2ecf5
Showing 5 changed files with 747 additions and 0 deletions.
valor-sandbox/__init__.py (3 additions & 0 deletions)
@@ -0,0 +1,3 @@
from .optimized_implementation import _calculate_pr_curves_optimized
from .valor_implementation import _calculate_pr_curves
from .test_utils import generate_groundtruth, generate_predictions, pretty_print_tracemalloc_snapshot
valor-sandbox/optimized_implementation.py (145 additions & 0 deletions)
@@ -0,0 +1,145 @@
import pandas as pd
from collections import defaultdict

def _calculate_pr_curves_optimized(
prediction_df: pd.DataFrame,
groundtruth_df: pd.DataFrame,
metrics_to_return: list,
pr_curve_max_examples: int,
):
joint_df = (
pd.merge(
groundtruth_df,
prediction_df,
on=["datum_id", "datum_uid", "label_key"],
how="inner",
suffixes=("_gt", "_pd"),
).loc[
:,
[
"datum_uid",
"datum_id",
"label_key",
"label_value_gt",
"id_gt",
"label_value_pd",
"score",
"id_pd",
],
]
)

    # per-label-key totals, used below as denominators when deriving false negatives and true negatives
    total_datums_per_label_key = joint_df.drop_duplicates(["datum_uid", "datum_id", "label_key"])["label_key"].value_counts()
    total_label_values_per_label_key = joint_df.drop_duplicates(["datum_uid", "datum_id", "label_key"])[["label_key", "label_value_gt"]].value_counts()

joint_df = joint_df.assign(
        # bucket each score into one of twenty 5%-wide confidence bins (bucket 0 -> [0, 0.05), ..., bucket 20 -> score == 1.0)
        threshold_index=lambda chain_df: (
            ((chain_df["score"] * 100) // 5).astype("int32")
        ),
is_label_match=lambda chain_df: (
(chain_df["label_value_pd"] == chain_df["label_value_gt"])
)
)


    true_positives = joint_df[joint_df["is_label_match"]][["label_key", "label_value_gt", "threshold_index"]].value_counts()
    ## true_positives = true_positives.reset_index(2).sort_values("threshold_index").groupby(["label_key", "label_value_gt"]).cumsum()
    false_positives = joint_df[~joint_df["is_label_match"]][["label_key", "label_value_pd", "threshold_index"]].value_counts()
    ## false_positives = false_positives.reset_index(2).sort_values("threshold_index").groupby(["label_key", "label_value_pd"]).cumsum()

dd = defaultdict(lambda: 0)
confidence_thresholds = [x / 100 for x in range(5, 100, 5)]

tps_keys = []
tps_values = []
tps_confidence = []
tps_cumulative = []
fns_cumulative = []

fps_keys = []
fps_values = []
fps_confidence = []
fps_cumulative = []
tns_cumulative = []

## Not sure what the efficient way of doing this is in pandas
    for label_key in true_positives.index.get_level_values(0).unique():
        for label_value in true_positives[label_key].index.get_level_values(0).unique():
dd = true_positives[label_key][label_value].to_dict(into=dd)
cumulative_true_positive = [0] * 21
cumulative_false_negative = [0] * 21
for threshold_index in range(19, -1, -1):
cumulative_true_positive[threshold_index] = cumulative_true_positive[threshold_index + 1] + dd[threshold_index]
cumulative_false_negative[threshold_index] = total_label_values_per_label_key[label_key][label_value] - cumulative_true_positive[threshold_index]

tps_keys += [label_key] * 19
tps_values += [label_value] * 19
tps_confidence += confidence_thresholds
tps_cumulative += cumulative_true_positive[1:-1]
fns_cumulative += cumulative_false_negative[1:-1]

## Not sure what the efficient way of doing this is in pandas
    for label_key in false_positives.index.get_level_values(0).unique():
        for label_value in false_positives[label_key].index.get_level_values(0).unique():
dd = false_positives[label_key][label_value].to_dict(into=dd)
cumulative_false_positive = [0] * 21
cumulative_true_negative = [0] * 21
for threshold_index in range(19, -1, -1):
cumulative_false_positive[threshold_index] = cumulative_false_positive[threshold_index + 1] + dd[threshold_index]
cumulative_true_negative[threshold_index] = total_datums_per_label_key[label_key] - total_label_values_per_label_key[label_key][label_value] - cumulative_false_positive[threshold_index]

fps_keys += [label_key] * 19
fps_values += [label_value] * 19
fps_confidence += confidence_thresholds
fps_cumulative += cumulative_false_positive[1:-1]
tns_cumulative += cumulative_true_negative[1:-1]

tps_df = pd.DataFrame({
"label_key": tps_keys,
"label_value": tps_values,
"confidence_threshold": tps_confidence,
"true_positives": tps_cumulative,
"false_negatives": fns_cumulative,
})

fps_df = pd.DataFrame({
"label_key": fps_keys,
"label_value": fps_values,
"confidence_threshold": fps_confidence,
"false_positives": fps_cumulative,
"true_negatives": tns_cumulative,
})

pr_curve_counts_df = pd.merge(
tps_df,
fps_df,
on=["label_key", "label_value", "confidence_threshold"],
how="outer",
)

pr_curve_counts_df.fillna(0, inplace=True)

'''
    Pretty sure there's a bug with accuracy: it assumes that each `datum_id` has every `label_key`.
pr_curve_counts_df["precision"] = pr_curve_counts_df["true_positives"] / (
pr_curve_counts_df["true_positives"]
+ pr_curve_counts_df["false_positives"]
)
pr_curve_counts_df["recall"] = pr_curve_counts_df["true_positives"] / (
pr_curve_counts_df["true_positives"]
+ pr_curve_counts_df["false_negatives"]
)
pr_curve_counts_df["accuracy"] = (
pr_curve_counts_df["true_positives"]
+ pr_curve_counts_df["true_negatives"]
) / len(unique_datum_ids)
pr_curve_counts_df["f1_score"] = (
2 * pr_curve_counts_df["precision"] * pr_curve_counts_df["recall"]
) / (pr_curve_counts_df["precision"] + pr_curve_counts_df["recall"])
# any NaNs that are left are from division by zero errors
pr_curve_counts_df.fillna(-1, inplace=True)
'''

return pr_curve_counts_df
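
The commented-out cumsum lines and the "not sure what the efficient way of doing this is in pandas" notes above hint at a vectorized alternative to the per-label Python loops. A minimal sketch of that idea follows; it is not part of the commit. It assumes the `value_counts()` output shown above (a Series indexed by label key, label value, and `threshold_index`), and the helper name `reverse_cumulative_counts` is hypothetical.

import pandas as pd

def reverse_cumulative_counts(counts: pd.Series) -> pd.DataFrame:
    # counts: Series indexed by (label_key, label_value, threshold_index), as
    # produced by value_counts() above. Returns one row per
    # (label_key, label_value, confidence_threshold) with the number of
    # predictions at or above that threshold.
    wide = (
        counts.unstack("threshold_index", fill_value=0)
        # make sure every 5%-wide bucket 0..20 is present; missing buckets count 0
        .reindex(columns=range(21), fill_value=0)
    )
    # reverse cumulative sum along the threshold axis: column t becomes the
    # total of buckets >= t, i.e. predictions with score >= 0.05 * t
    cumulative = wide.iloc[:, ::-1].cumsum(axis=1).iloc[:, ::-1]
    # buckets 1..19 correspond to the confidence thresholds 0.05 .. 0.95
    cumulative = cumulative.loc[:, list(range(1, 20))]
    cumulative.columns = pd.Index(
        [t / 20 for t in range(1, 20)], name="confidence_threshold"
    )
    return cumulative.stack().rename("count").reset_index()

Joining the true-positive and false-positive outputs of this helper on (label key, label value, confidence_threshold) would then stand in for the tps_df/fps_df construction above; the false-negative and true-negative columns would still need the per-label totals as denominators.
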
valor-sandbox/test_utils.py (85 additions & 0 deletions)
@@ -0,0 +1,85 @@
import time
import random
import numpy as np
import pandas as pd

import os
import linecache
import tracemalloc

random.seed(time.time())

label_values = ["cat", "dog", "bee", "rat", "cow", "fox", "ant", "owl", "bat"]

def generate_groundtruth(n, n_class=3):
n_class = min(n_class, len(label_values))
    n_class = max(1, n_class)

classes = label_values[:n_class]

df = pd.DataFrame({
"datum_uid": [f"uid{i}" for i in range(n)],
"datum_id": [f"img{i}" for i in range(n)],
"id": [f"gt{i}" for i in range(n)],
"label_key": "class_label",
"label_value": [random.choice(classes) for _ in range(n)],
})

return df

def generate_predictions(n, n_class=3, preds_per_datum=1):
n_class = min(n_class, len(label_values))
    n_class = max(1, n_class)

classes = label_values[:n_class]

    # clamp preds_per_datum to the valid range [1, n_class] so random.sample below works
    preds_per_datum = max(1, preds_per_datum)
    preds_per_datum = min(n_class, preds_per_datum)

all_labels = []
all_scores = []
for _ in range(n):
labels = random.sample(classes, preds_per_datum)
scores = [random.uniform(0,1) for _ in range(preds_per_datum)]
total = sum(scores)
for i in range(len(scores)):
scores[i] /= total

all_labels += labels
all_scores += scores

df = pd.DataFrame({
"datum_uid": np.repeat([f"uid{i}" for i in range(n)], preds_per_datum),
"datum_id": np.repeat([f"img{i}" for i in range(n)], preds_per_datum),
"id": [f"pd{i}" for i in range(n*preds_per_datum)],
"label_key": "class_label",
"label_value": all_labels,
"score": all_scores,
})

return df

def pretty_print_tracemalloc_snapshot(snapshot, key_type='lineno', limit=3):
snapshot = snapshot.filter_traces((
tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
tracemalloc.Filter(False, "<unknown>"),
))
top_stats = snapshot.statistics(key_type)

print("Top %s lines" % limit)
for index, stat in enumerate(top_stats[:limit], 1):
frame = stat.traceback[0]
# replace "/path/to/module/file.py" with "module/file.py"
filename = os.sep.join(frame.filename.split(os.sep)[-2:])
print("#%s: %s:%s: %.1f KiB"
% (index, filename, frame.lineno, stat.size / 1024))
line = linecache.getline(frame.filename, frame.lineno).strip()
if line:
print(' %s' % line)

other = top_stats[limit:]
if other:
size = sum(stat.size for stat in other)
print("%s other: %.1f KiB" % (len(other), size / 1024))
total = sum(stat.size for stat in top_stats)
print("Total allocated size: %.1f KiB" % (total / 1024))
