Skip to content

Commit 6f32544

Browse files
dpoznik and glemaitre
authored
FIX discrete Naive Bayes model fitting for degenerate single-class case (scikit-learn#18925)
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent 27f1c73 commit 6f32544

File tree

3 files changed

+138
-95
lines changed

3 files changed

+138
-95
lines changed

doc/whats_new/v1.0.rst

+7
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ Changelog
105105
:mod:`sklearn.naive_bayes`
106106
..........................
107107

108+
- |Fix| The `fit` and `partial_fit` methods of the discrete naive Bayes
109+
classifiers (:class:`naive_bayes.BernoulliNB`,
110+
:class:`naive_bayes.CategoricalNB`, :class:`naive_bayes.ComplementNB`,
111+
and :class:`naive_bayes.MultinomialNB`) now correctly handle the degenerate
112+
case of a single class in the training set.
113+
:pr:`18925` by :user:`David Poznik <dpoznik>`.
114+
108115
- |API| The attribute ``sigma_`` is now deprecated in
109116
:class:`naive_bayes.GaussianNB` and will be removed in 1.2.
110117
Use ``var_`` instead.

sklearn/naive_bayes.py

+28-23
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
550550
y : array-like of shape (n_samples,)
551551
Target values.
552552
553-
classes : array-like of shape (n_classes), default=None
553+
classes : array-like of shape (n_classes,), default=None
554554
List of all the classes that can possibly appear in the y vector.
555555
556556
Must be provided at the first call to partial_fit, can be omitted
@@ -569,16 +569,19 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
569569
if _check_partial_fit_first_call(self, classes):
570570
# This is the first call to partial_fit:
571571
# initialize various cumulative counters
572-
n_effective_classes = len(classes) if len(classes) > 1 else 2
573-
self._init_counters(n_effective_classes, n_features)
572+
n_classes = len(classes)
573+
self._init_counters(n_classes, n_features)
574574
self.n_features_ = n_features
575575
elif n_features != self.n_features_:
576576
msg = "Number of features %d does not match previous data %d."
577577
raise ValueError(msg % (n_features, self.n_features_))
578578

579579
Y = label_binarize(y, classes=self.classes_)
580580
if Y.shape[1] == 1:
581-
Y = np.concatenate((1 - Y, Y), axis=1)
581+
if len(self.classes_) == 2:
582+
Y = np.concatenate((1 - Y, Y), axis=1)
583+
else: # degenerate case: just one class
584+
Y = np.ones_like(Y)
582585

583586
if X.shape[0] != Y.shape[0]:
584587
msg = "X.shape[0]=%d and y.shape[0]=%d are incompatible."
@@ -634,7 +637,10 @@ def fit(self, X, y, sample_weight=None):
634637
Y = labelbin.fit_transform(y)
635638
self.classes_ = labelbin.classes_
636639
if Y.shape[1] == 1:
637-
Y = np.concatenate((1 - Y, Y), axis=1)
640+
if len(self.classes_) == 2:
641+
Y = np.concatenate((1 - Y, Y), axis=1)
642+
else: # degenerate case: just one class
643+
Y = np.ones_like(Y)
638644

639645
# LabelBinarizer().fit_transform() returns arrays with dtype=np.int64.
640646
# We convert it to np.float64 to support sample_weight consistently;
@@ -649,18 +655,17 @@ def fit(self, X, y, sample_weight=None):
649655

650656
# Count raw events from data before updating the class log prior
651657
# and feature log probas
652-
n_effective_classes = Y.shape[1]
653-
654-
self._init_counters(n_effective_classes, n_features)
658+
n_classes = Y.shape[1]
659+
self._init_counters(n_classes, n_features)
655660
self._count(X, Y)
656661
alpha = self._check_alpha()
657662
self._update_feature_log_prob(alpha)
658663
self._update_class_log_prior(class_prior=class_prior)
659664
return self
660665

661-
def _init_counters(self, n_effective_classes, n_features):
662-
self.class_count_ = np.zeros(n_effective_classes, dtype=np.float64)
663-
self.feature_count_ = np.zeros((n_effective_classes, n_features),
666+
def _init_counters(self, n_classes, n_features):
667+
self.class_count_ = np.zeros(n_classes, dtype=np.float64)
668+
self.feature_count_ = np.zeros((n_classes, n_features),
664669
dtype=np.float64)
665670

666671
# mypy error: Decorated property not supported
@@ -714,7 +719,7 @@ class MultinomialNB(_BaseDiscreteNB):
714719
Number of samples encountered for each class during fitting. This
715720
value is weighted by the sample weight when provided.
716721
717-
class_log_prior_ : ndarray of shape (n_classes, )
722+
class_log_prior_ : ndarray of shape (n_classes,)
718723
Smoothed empirical log probability for each class.
719724
720725
classes_ : ndarray of shape (n_classes,)
@@ -962,11 +967,11 @@ class BernoulliNB(_BaseDiscreteNB):
962967
963968
Attributes
964969
----------
965-
class_count_ : ndarray of shape (n_classes)
970+
class_count_ : ndarray of shape (n_classes,)
966971
Number of samples encountered for each class during fitting. This
967972
value is weighted by the sample weight when provided.
968973
969-
class_log_prior_ : ndarray of shape (n_classes)
974+
class_log_prior_ : ndarray of shape (n_classes,)
970975
Log probability of each class (smoothed).
971976
972977
classes_ : ndarray of shape (n_classes,)
@@ -1053,8 +1058,8 @@ def _update_feature_log_prob(self, alpha):
10531058

10541059
def _joint_log_likelihood(self, X):
10551060
"""Calculate the posterior log probability of the samples X"""
1056-
n_classes, n_features = self.feature_log_prob_.shape
1057-
n_samples, n_features_X = X.shape
1061+
n_features = self.feature_log_prob_.shape[1]
1062+
n_features_X = X.shape[1]
10581063

10591064
if n_features_X != n_features:
10601065
raise ValueError("Expected input with %d features, got %d instead"
@@ -1173,7 +1178,7 @@ def fit(self, X, y, sample_weight=None):
11731178
y : array-like of shape (n_samples,)
11741179
Target values.
11751180
1176-
sample_weight : array-like of shape (n_samples), default=None
1181+
sample_weight : array-like of shape (n_samples,), default=None
11771182
Weights applied to individual samples (1. for unweighted).
11781183
11791184
Returns
@@ -1207,16 +1212,16 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
12071212
total number of categories for the given feature. This can, for
12081213
instance, be achieved with the help of OrdinalEncoder.
12091214
1210-
y : array-like of shape (n_samples)
1215+
y : array-like of shape (n_samples,)
12111216
Target values.
12121217
1213-
classes : array-like of shape (n_classes), default=None
1218+
classes : array-like of shape (n_classes,), default=None
12141219
List of all the classes that can possibly appear in the y vector.
12151220
12161221
Must be provided at the first call to partial_fit, can be omitted
12171222
in subsequent calls.
12181223
1219-
sample_weight : array-like of shape (n_samples), default=None
1224+
sample_weight : array-like of shape (n_samples,), default=None
12201225
Weights applied to individual samples (1. for unweighted).
12211226
12221227
Returns
@@ -1241,9 +1246,9 @@ def _check_X_y(self, X, y):
12411246
check_non_negative(X, "CategoricalNB (input X)")
12421247
return X, y
12431248

1244-
def _init_counters(self, n_effective_classes, n_features):
1245-
self.class_count_ = np.zeros(n_effective_classes, dtype=np.float64)
1246-
self.category_count_ = [np.zeros((n_effective_classes, 0))
1249+
def _init_counters(self, n_classes, n_features):
1250+
self.class_count_ = np.zeros(n_classes, dtype=np.float64)
1251+
self.category_count_ = [np.zeros((n_classes, 0))
12471252
for _ in range(n_features)]
12481253

12491254
@staticmethod

0 commit comments

Comments (0)