@@ -550,7 +550,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
550550 y : array-like of shape (n_samples,)
551551 Target values.
552552
553- classes : array-like of shape (n_classes), default=None
553+ classes : array-like of shape (n_classes, ), default=None
554554 List of all the classes that can possibly appear in the y vector.
555555
556556 Must be provided at the first call to partial_fit, can be omitted
@@ -569,16 +569,19 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
569569 if _check_partial_fit_first_call (self , classes ):
570570 # This is the first call to partial_fit:
571571 # initialize various cumulative counters
572- n_effective_classes = len (classes ) if len ( classes ) > 1 else 2
573- self ._init_counters (n_effective_classes , n_features )
572+ n_classes = len (classes )
573+ self ._init_counters (n_classes , n_features )
574574 self .n_features_ = n_features
575575 elif n_features != self .n_features_ :
576576 msg = "Number of features %d does not match previous data %d."
577577 raise ValueError (msg % (n_features , self .n_features_ ))
578578
579579 Y = label_binarize (y , classes = self .classes_ )
580580 if Y .shape [1 ] == 1 :
581- Y = np .concatenate ((1 - Y , Y ), axis = 1 )
581+ if len (self .classes_ ) == 2 :
582+ Y = np .concatenate ((1 - Y , Y ), axis = 1 )
583+ else : # degenerate case: just one class
584+ Y = np .ones_like (Y )
582585
583586 if X .shape [0 ] != Y .shape [0 ]:
584587 msg = "X.shape[0]=%d and y.shape[0]=%d are incompatible."
@@ -634,7 +637,10 @@ def fit(self, X, y, sample_weight=None):
634637 Y = labelbin .fit_transform (y )
635638 self .classes_ = labelbin .classes_
636639 if Y .shape [1 ] == 1 :
637- Y = np .concatenate ((1 - Y , Y ), axis = 1 )
640+ if len (self .classes_ ) == 2 :
641+ Y = np .concatenate ((1 - Y , Y ), axis = 1 )
642+ else : # degenerate case: just one class
643+ Y = np .ones_like (Y )
638644
639645 # LabelBinarizer().fit_transform() returns arrays with dtype=np.int64.
640646 # We convert it to np.float64 to support sample_weight consistently;
@@ -649,18 +655,17 @@ def fit(self, X, y, sample_weight=None):
649655
650656 # Count raw events from data before updating the class log prior
651657 # and feature log probas
652- n_effective_classes = Y .shape [1 ]
653-
654- self ._init_counters (n_effective_classes , n_features )
658+ n_classes = Y .shape [1 ]
659+ self ._init_counters (n_classes , n_features )
655660 self ._count (X , Y )
656661 alpha = self ._check_alpha ()
657662 self ._update_feature_log_prob (alpha )
658663 self ._update_class_log_prior (class_prior = class_prior )
659664 return self
660665
661- def _init_counters (self , n_effective_classes , n_features ):
662- self .class_count_ = np .zeros (n_effective_classes , dtype = np .float64 )
663- self .feature_count_ = np .zeros ((n_effective_classes , n_features ),
666+ def _init_counters (self , n_classes , n_features ):
667+ self .class_count_ = np .zeros (n_classes , dtype = np .float64 )
668+ self .feature_count_ = np .zeros ((n_classes , n_features ),
664669 dtype = np .float64 )
665670
666671 # mypy error: Decorated property not supported
@@ -714,7 +719,7 @@ class MultinomialNB(_BaseDiscreteNB):
714719 Number of samples encountered for each class during fitting. This
715720 value is weighted by the sample weight when provided.
716721
717- class_log_prior_ : ndarray of shape (n_classes, )
722+ class_log_prior_ : ndarray of shape (n_classes,)
718723 Smoothed empirical log probability for each class.
719724
720725 classes_ : ndarray of shape (n_classes,)
@@ -962,11 +967,11 @@ class BernoulliNB(_BaseDiscreteNB):
962967
963968 Attributes
964969 ----------
965- class_count_ : ndarray of shape (n_classes)
970+ class_count_ : ndarray of shape (n_classes, )
966971 Number of samples encountered for each class during fitting. This
967972 value is weighted by the sample weight when provided.
968973
969- class_log_prior_ : ndarray of shape (n_classes)
974+ class_log_prior_ : ndarray of shape (n_classes, )
970975 Log probability of each class (smoothed).
971976
972977 classes_ : ndarray of shape (n_classes,)
@@ -1053,8 +1058,8 @@ def _update_feature_log_prob(self, alpha):
10531058
10541059 def _joint_log_likelihood (self , X ):
10551060 """Calculate the posterior log probability of the samples X"""
1056- n_classes , n_features = self .feature_log_prob_ .shape
1057- n_samples , n_features_X = X .shape
1061+ n_features = self .feature_log_prob_ .shape [ 1 ]
1062+ n_features_X = X .shape [ 1 ]
10581063
10591064 if n_features_X != n_features :
10601065 raise ValueError ("Expected input with %d features, got %d instead"
@@ -1173,7 +1178,7 @@ def fit(self, X, y, sample_weight=None):
11731178 y : array-like of shape (n_samples,)
11741179 Target values.
11751180
1176- sample_weight : array-like of shape (n_samples), default=None
1181+ sample_weight : array-like of shape (n_samples, ), default=None
11771182 Weights applied to individual samples (1. for unweighted).
11781183
11791184 Returns
@@ -1207,16 +1212,16 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
12071212 total number of categories for the given feature. This can, for
12081213 instance, be achieved with the help of OrdinalEncoder.
12091214
1210- y : array-like of shape (n_samples)
1215+ y : array-like of shape (n_samples, )
12111216 Target values.
12121217
1213- classes : array-like of shape (n_classes), default=None
1218+ classes : array-like of shape (n_classes, ), default=None
12141219 List of all the classes that can possibly appear in the y vector.
12151220
12161221 Must be provided at the first call to partial_fit, can be omitted
12171222 in subsequent calls.
12181223
1219- sample_weight : array-like of shape (n_samples), default=None
1224+ sample_weight : array-like of shape (n_samples, ), default=None
12201225 Weights applied to individual samples (1. for unweighted).
12211226
12221227 Returns
@@ -1241,9 +1246,9 @@ def _check_X_y(self, X, y):
12411246 check_non_negative (X , "CategoricalNB (input X)" )
12421247 return X , y
12431248
1244- def _init_counters (self , n_effective_classes , n_features ):
1245- self .class_count_ = np .zeros (n_effective_classes , dtype = np .float64 )
1246- self .category_count_ = [np .zeros ((n_effective_classes , 0 ))
1249+ def _init_counters (self , n_classes , n_features ):
1250+ self .class_count_ = np .zeros (n_classes , dtype = np .float64 )
1251+ self .category_count_ = [np .zeros ((n_classes , 0 ))
12471252 for _ in range (n_features )]
12481253
12491254 @staticmethod
0 commit comments