Skip to content

Commit 7b69304

Browse files
committed
Add comments to algorithm for loss
1 parent 019cb58 commit 7b69304

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

asreviewcontrib/insights/algorithms.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,31 @@ def _recall_values(labels, x_absolute=False, y_absolute=False):
2020

2121

2222
def _loss_value(labels):
23-
def _auc_trapezoidal(x, y):
24-
x = np.array(x)
25-
y = np.array(y)
26-
return np.sum((y[1:] + y[:-1]) / 2 * np.diff(x))
27-
2823
Ny = sum(labels)
2924
Nx = len(labels)
3025

31-
best_auc = Nx * Ny - 0.5 - ((Ny * Ny) / 2)
32-
actual_auc = _auc_trapezoidal(*_recall_values(labels,
33-
x_absolute=True,
34-
y_absolute=True))
26+
# The best AUC is the area under the perfect curve: the total area Nx * Ny
27+
# minus the area above the perfect curve. That area is (Ny * Ny) / 2 (the
28+
# sum of an arithmetic series) plus 0.5 for the boundary, so both terms are
29+
# subtracted from Nx * Ny.
30+
best_auc = Nx * Ny - (((Ny * Ny) / 2) + 0.5)
31+
32+
# Compute recall values (y) based on the provided labels. We don't need x
33+
# values because the points are uniformly spaced.
34+
y = np.array(_recall_values(labels, x_absolute=True, y_absolute=True)[1])
35+
36+
# The actual AUC is approximated with the trapezoidal rule: (y[1:] + y[:-1])
37+
# / 2 is the average height of each pair of consecutive y values, and,
38+
# assuming unit spacing between points, the sum of those averages is the area.
39+
actual_auc = np.sum((y[1:] + y[:-1]) / 2)
40+
41+
# The worst AUC is the area under the worst-case recall curve, i.e. the
42+
# curve obtained when all positive labels are clumped at the end of the
43+
# ranking; this area equals (Ny * Ny) / 2.
3544
worst_auc = ((Ny * Ny) / 2)
3645

46+
# The normalized loss is the difference between the best AUC and the actual
47+
# AUC, normalized by the range between the best and worst AUCs.
3748
normalized_loss = (best_auc - actual_auc) / (best_auc - worst_auc) if best_auc != worst_auc else 0 # noqa: E501
3849

3950
return normalized_loss

0 commit comments

Comments
 (0)