diff --git a/valor-sandbox/optimized_implementation.py b/valor-sandbox/optimized_implementation.py index 81c485cdf..146a5b12e 100644 --- a/valor-sandbox/optimized_implementation.py +++ b/valor-sandbox/optimized_implementation.py @@ -118,9 +118,7 @@ def _calculate_pr_curves_optimized( ) pr_curve_counts_df.fillna(0, inplace=True) - - ''' - Pretty sure there's a bug with accuracy, it assumes that each `datum_id` has every `label_key`. + pr_curve_counts_df["total_datums"] = pr_curve_counts_df["label_key"].map(total_datums_per_label_key.to_dict()) pr_curve_counts_df["precision"] = pr_curve_counts_df["true_positives"] / ( pr_curve_counts_df["true_positives"] @@ -133,13 +131,12 @@ def _calculate_pr_curves_optimized( pr_curve_counts_df["accuracy"] = ( pr_curve_counts_df["true_positives"] + pr_curve_counts_df["true_negatives"] - ) / len(unique_datum_ids) + ) / pr_curve_counts_df["total_datums"] pr_curve_counts_df["f1_score"] = ( 2 * pr_curve_counts_df["precision"] * pr_curve_counts_df["recall"] ) / (pr_curve_counts_df["precision"] + pr_curve_counts_df["recall"]) # any NaNs that are left are from division by zero errors pr_curve_counts_df.fillna(-1, inplace=True) - ''' return pr_curve_counts_df \ No newline at end of file