Skip to content

Commit 2f078bf

Browse files
StefanieSenger, ogrisel, lorentzenchr, lucyleeow
authored
MNT Improve exception handling for invalid labels in cohen_kappa_score (scikit-learn#31175)
Co-authored-by: Olivier Grisel <[email protected]> Co-authored-by: Christian Lorentzen <[email protected]> Co-authored-by: Lucy Liu <[email protected]>
1 parent 86d099e commit 2f078bf

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

sklearn/metrics/_classification.py

+15 -2
Original file line number | Diff line number | Diff line change
@@ -832,7 +832,9 @@ class labels [2]_.
832832
labels : array-like of shape (n_classes,), default=None
833833
List of labels to index the matrix. This may be used to select a
834834
subset of labels. If `None`, all labels that appear at least once in
835-
``y1`` or ``y2`` are used.
835+
``y1`` or ``y2`` are used. Note that at least one label in `labels` must be
836+
present in `y1`, even though this function is otherwise agnostic to the order
837+
of `y1` and `y2`.
836838
837839
weights : {'linear', 'quadratic'}, default=None
838840
Weighting type to calculate the score. `None` means not weighted;
@@ -866,7 +868,18 @@ class labels [2]_.
866868
>>> cohen_kappa_score(y1, y2)
867869
0.6875
868870
"""
869-
confusion = confusion_matrix(y1, y2, labels=labels, sample_weight=sample_weight)
871+
try:
872+
confusion = confusion_matrix(y1, y2, labels=labels, sample_weight=sample_weight)
873+
except ValueError as e:
874+
if "At least one label specified must be in y_true" in str(e):
875+
msg = (
876+
"At least one label in `labels` must be present in `y1` (even though "
877+
"`cohen_kappa_score` is otherwise agnostic to the order of `y1` and "
878+
"`y2`)."
879+
)
880+
raise ValueError(msg) from e
881+
raise
882+
870883
n_classes = confusion.shape[0]
871884
sum0 = np.sum(confusion, axis=0)
872885
sum1 = np.sum(confusion, axis=1)

sklearn/metrics/tests/test_classification.py

+11
Original file line number | Diff line number | Diff line change
@@ -926,6 +926,17 @@ def test_cohen_kappa():
926926
)
927927

928928

929+
def test_cohen_kappa_score_error_wrong_label():
930+
"""Test that correct error is raised when users pass labels that are not in y1."""
931+
labels = [1, 2]
932+
y1 = np.array(["a"] * 5 + ["b"] * 5)
933+
y2 = np.array(["b"] * 10)
934+
with pytest.raises(
935+
ValueError, match="At least one label in `labels` must be present in `y1`"
936+
):
937+
cohen_kappa_score(y1, y2, labels=labels)
938+
939+
929940
@pytest.mark.parametrize("zero_division", [0, 1, np.nan])
930941
@pytest.mark.parametrize("y_true, y_pred", [([0], [0])])
931942
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)