
Commit 9af1c77

Ignore abstains in Scorer, change LabelModel default tie break policy (#1450)
1 parent: a9c28a2

File tree: 4 files changed (+40 -5 lines)


Diff for: snorkel/analysis/scorer.py (+5 -1)

@@ -49,7 +49,11 @@ def __init__(
             if metric not in METRICS:
                 raise ValueError(f"Unrecognized metric: {metric}")

-        filter_dict = {} if abstain_label is None else {"golds": [abstain_label]}
+        filter_dict = (
+            {}
+            if abstain_label is None
+            else {"golds": [abstain_label], "preds": [abstain_label]}
+        )
         self.metrics = {
             m: partial(metric_score, metric=m, filter_dict=filter_dict)
             for m in metrics
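
With this change, a Scorer constructed with an abstain_label drops every data point whose gold label or predicted label equals that value before computing a metric; previously only abstained golds were filtered. A short usage sketch of the new behavior (the golds/preds arrays are invented for illustration, and the import path is inferred from the file path above):

import numpy as np
from snorkel.analysis.scorer import Scorer

golds = np.array([1, 0, 1, 1, -1])   # -1 marks an abstain in the gold labels
preds = np.array([-1, 0, 1, 1, 0])   # -1 marks an abstained prediction

# abstain_label=-1 now filters points where either the gold or the
# prediction is -1, so indices 0 and 4 are dropped before scoring.
scorer = Scorer(metrics=["accuracy"], abstain_label=-1)
print(scorer.score(golds, preds))  # expected: {'accuracy': 1.0} over the 3 kept points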

Diff for: snorkel/labeling/model/label_model.py (+8 -3)

@@ -396,7 +396,7 @@ def predict(
         self,
         L: np.ndarray,
         return_probs: Optional[bool] = False,
-        tie_break_policy: str = "random",
+        tie_break_policy: str = "abstain",
     ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
         """Return predicted labels, with ties broken according to policy.

@@ -446,7 +446,7 @@ def score(
         L: np.ndarray,
         Y: np.ndarray,
         metrics: Optional[List[str]] = ["accuracy"],
-        tie_break_policy: str = "random",
+        tie_break_policy: str = "abstain",
     ) -> Dict[str, float]:
         """Calculate one or more scores from user-specified and/or user-defined metrics.

@@ -455,7 +455,7 @@ def score(
         L
             An [n,m] matrix with values in {-1,0,1,...,k-1}
         Y
-            Gold labels associated with datapoints in L
+            Gold labels associated with data points in L
         metrics
             A list of metric names
         tie_break_policy

@@ -477,6 +477,11 @@ def score(
         >>> label_model.score(L, Y=np.array([1, 1, 1]), metrics=["f1"])
         {'f1': 0.8}
         """
+        if tie_break_policy == "abstain":  # pragma: no cover
+            logging.warning(
+                "Metrics calculated over data points with non-abstain labels only"
+            )
+
         Y_pred, Y_prob = self.predict(
             L, return_probs=True, tie_break_policy=tie_break_policy
         )
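
The default tie_break_policy for both predict and score changes from "random" to "abstain": an exact tie in the label posterior now yields -1 instead of a randomly chosen class, and score logs a warning that metrics are computed over non-abstain data points only. A usage sketch of the new default, mirroring the test added below (import path inferred from the file path above, hyperparameters copied from that test):

import numpy as np
from snorkel.labeling.model.label_model import LabelModel

# Three labeling functions that disagree or abstain on every data point,
# so each row's posterior is an exact tie between classes 0 and 1.
L = np.array([[-1, 1, 0], [0, -1, 1], [1, 0, -1]])

label_model = LabelModel(cardinality=2, verbose=False)
label_model.fit(L, n_epochs=100)

# Under the new default tie_break_policy="abstain", tied rows come back as -1.
print(label_model.predict(L))  # expected: [-1 -1 -1]

# The previous behavior is still available by passing the policy explicitly.
print(label_model.predict(L, tie_break_policy="random"))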

Diff for: test/analysis/test_scorer.py (+7 -1)

@@ -68,12 +68,18 @@ def test_abstain_labels(self) -> None:
         results_expected = dict(accuracy=0.6)
         self.assertEqual(results, results_expected)

-        # Test abstain=-1
+        # Test abstain=-1 for gold
         scorer = Scorer(metrics=["accuracy"], abstain_label=-1)
         results = scorer.score(golds, preds, probs)
         results_expected = dict(accuracy=0.75)
         self.assertEqual(results, results_expected)

+        # Test abstain=-1 for preds and gold
+        abstain_preds = np.array([-1, -1, 1, 1, 0])
+        results = scorer.score(golds, abstain_preds)
+        results_expected = dict(accuracy=0.5)
+        self.assertEqual(results, results_expected)
+
         # Test abstain set to different value
         scorer = Scorer(metrics=["accuracy"], abstain_label=10)
         results = scorer.score(golds, preds, probs)

Diff for: test/labeling/model/test_label_model.py (+20)

@@ -240,6 +240,14 @@ def test_predict_proba(self):
         np.testing.assert_array_almost_equal(probs, true_probs)

     def test_predict(self):
+        # 3 LFs that always disagree/abstain leads to all abstains
+        L = np.array([[-1, 1, 0], [0, -1, 1], [1, 0, -1]])
+        label_model = LabelModel(cardinality=2, verbose=False)
+        label_model.fit(L, n_epochs=100)
+        np.testing.assert_array_almost_equal(
+            label_model.predict(L), np.array([-1, -1, -1])
+        )
+
         L = np.array([[0, 1, 0], [0, 1, 0]])
         label_model = self._set_up_model(L)

@@ -254,6 +262,18 @@ def test_predict(self):
         np.testing.assert_array_almost_equal(probs, true_probs)

     def test_score(self):
+        L = np.array([[1, 1, 0], [-1, -1, -1], [1, 0, 1]])
+        Y = np.array([1, 0, 1])
+        label_model = LabelModel(cardinality=2, verbose=False)
+        label_model.fit(L, n_epochs=100)
+        results = label_model.score(L, Y)
+        np.testing.assert_array_almost_equal(
+            label_model.predict(L), np.array([1, -1, 1])
+        )
+
+        results_expected = dict(accuracy=1.0)
+        self.assertEqual(results, results_expected)
+
         L = np.array([[1, 0, 1], [1, 0, 1]])
         label_model = self._set_up_model(L)
         label_model.mu = nn.Parameter(label_model.mu_init.clone().clamp(0.01, 0.99))
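
The accuracy of 1.0 in the new test_score case follows from the Scorer change above: with Y = [1, 0, 1] and predictions [1, -1, 1], the abstained second point is filtered out before the metric is computed, so its disagreement with the gold label never counts. A quick check of that arithmetic in plain NumPy (not the library's internal code path):

import numpy as np

Y = np.array([1, 0, 1])          # gold labels from the test above
Y_pred = np.array([1, -1, 1])    # LabelModel predictions, -1 = abstain

# Keep only points where neither the gold nor the prediction abstains.
mask = (Y != -1) & (Y_pred != -1)
print((Y[mask] == Y_pred[mask]).mean())  # 1.0, matching results_expected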
