1
+ import pandas as pd
2
+ from collections import defaultdict
3
+
4
def _histogram_by_label(
    rows: pd.DataFrame, value_column: str, n_buckets: int
) -> dict:
    """Count rows per (label_key, label value) pair and score bucket.

    Returns a mapping ``(label_key, label_value) -> list`` where entry ``i``
    of the list is the number of rows of that label whose
    ``threshold_index`` equals ``i``.
    """
    histograms = defaultdict(lambda: [0] * n_buckets)
    grouped = rows[["label_key", value_column, "threshold_index"]].value_counts()
    for (label_key, label_value, bucket), count in grouped.items():
        histograms[(label_key, label_value)][int(bucket)] += int(count)
    return histograms


def _counts_at_or_above(histogram: list) -> list:
    """Return suffix sums: entry ``i`` is the number of rows in buckets >= ``i``."""
    totals = [0] * len(histogram)
    running = 0
    for bucket in range(len(histogram) - 1, -1, -1):
        running += histogram[bucket]
        totals[bucket] = running
    return totals


def _calculate_pr_curves_optimized(
    prediction_df: pd.DataFrame,
    groundtruth_df: pd.DataFrame,
    metrics_to_return: list,
    pr_curve_max_examples: int,
):
    """Compute TP/FN/FP/TN counts per label at confidence thresholds 0.05..0.95.

    Ground truths and predictions are inner-joined on
    ``(datum_id, datum_uid, label_key)``; every joined row is a
    (ground truth, prediction) pair for one datum.  Scores are bucketed in
    steps of 0.05 and, for each threshold, a prediction is counted when its
    bucket is at or above the threshold's bucket.

    Args:
        prediction_df: Must provide the columns ``datum_id``, ``datum_uid``,
            ``label_key``, ``label_value``, ``score`` and ``id``.
        groundtruth_df: Must provide the columns ``datum_id``, ``datum_uid``,
            ``label_key``, ``label_value`` and ``id``.
        metrics_to_return: Currently unused; kept for interface compatibility.
        pr_curve_max_examples: Currently unused; kept for interface
            compatibility.

    Returns:
        A DataFrame with the columns ``label_key``, ``label_value``,
        ``confidence_threshold``, ``true_positives``, ``false_negatives``,
        ``false_positives`` and ``true_negatives``.  It holds 19 rows (one
        per threshold) for every ``(label_key, label_value)`` pair that has
        at least one matching or mismatching prediction, sorted by label key,
        label value and threshold.

    Notes:
        * Ground-truth totals are taken from the first joined row of each
          ``(datum, label_key)``, which presumably relies on there being a
          single ground truth per datum and label key -- confirm with callers.
        * Precision, recall, accuracy and F1 are intentionally not derived
          here.  An earlier draft computed accuracy by dividing by the number
          of unique datums, which assumes every datum has every label key.
    """
    # Buckets 0..20: bucket i holds scores in [i * 0.05, (i + 1) * 0.05);
    # a score of exactly 1.0 lands in bucket 20 and must still be counted.
    last_bucket = 20
    n_buckets = last_bucket + 1
    confidence_thresholds = [x / 100 for x in range(5, 100, 5)]

    joint_df = pd.merge(
        groundtruth_df,
        prediction_df,
        on=["datum_id", "datum_uid", "label_key"],
        how="inner",
        suffixes=("_gt", "_pd"),
    ).loc[
        :,
        [
            "datum_uid",
            "datum_id",
            "label_key",
            "label_value_gt",
            "id_gt",
            "label_value_pd",
            "score",
            "id_pd",
        ],
    ]

    # Totals used to derive false negatives and true negatives.  Plain dicts
    # allow a safe ``.get`` for labels that were predicted but never appear
    # as a ground truth.
    unique_datums = joint_df.drop_duplicates(["datum_uid", "datum_id", "label_key"])
    datums_per_label_key = unique_datums["label_key"].value_counts().to_dict()
    groundtruths_per_label = (
        unique_datums[["label_key", "label_value_gt"]].value_counts().to_dict()
    )

    joint_df = joint_df.assign(
        # Out-of-range scores are clamped so they can never index outside
        # the histogram.
        threshold_index=lambda df: (
            ((df["score"] * 100) // 5).astype("int32").clip(0, last_bucket)
        ),
        is_label_match=lambda df: df["label_value_pd"] == df["label_value_gt"],
    )

    # True positives are keyed by the ground-truth label, false positives by
    # the (wrongly) predicted label.
    matches = joint_df["is_label_match"]
    tp_histograms = _histogram_by_label(
        joint_df[matches], "label_value_gt", n_buckets
    )
    fp_histograms = _histogram_by_label(
        joint_df[~matches], "label_value_pd", n_buckets
    )

    no_rows = [0] * n_buckets
    records = []
    # Iterate over label pairs that actually occur instead of the cross
    # product of all keys and all values, which would look up missing pairs.
    for label in sorted(set(tp_histograms) | set(fp_histograms)):
        label_key, label_value = label
        cumulative_tp = _counts_at_or_above(tp_histograms.get(label, no_rows))
        cumulative_fp = _counts_at_or_above(fp_histograms.get(label, no_rows))
        n_groundtruths = groundtruths_per_label.get(label, 0)
        n_datums = datums_per_label_key[label_key]

        # Threshold k * 0.05 corresponds to bucket k (k = 1..19).
        for bucket, threshold in enumerate(confidence_thresholds, start=1):
            tp = cumulative_tp[bucket]
            fp = cumulative_fp[bucket]
            records.append(
                (
                    label_key,
                    label_value,
                    threshold,
                    tp,
                    n_groundtruths - tp,
                    fp,
                    n_datums - n_groundtruths - fp,
                )
            )

    pr_curve_counts_df = pd.DataFrame(
        records,
        columns=[
            "label_key",
            "label_value",
            "confidence_threshold",
            "true_positives",
            "false_negatives",
            "false_positives",
            "true_negatives",
        ],
    )

    return pr_curve_counts_df
0 commit comments