Skip to content

Commit 7c47337

Browse files
thomasjpfanglemaitre
authored andcommitted
BUG Fixes voting named_estiamtor bug (scikit-learn#15375)
1 parent 5a2df5e commit 7c47337

File tree

3 files changed

+33
-2
lines changed

3 files changed

+33
-2
lines changed

doc/whats_new/v0.22.rst

+8
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,14 @@ Changelog
716716
- ``optimize.newton_cg`
717717
- ``random.random_choice_csc``
718718

719+
:mod:`sklearn.voting`
720+
.....................
721+
722+
- |Fix| The `named_estimators_` attribute in :class:`voting.VotingClassifier`
723+
and :class:`voting.VotingRegressor` now correctly maps to dropped estimators.
724+
Previously, the `named_estimators_` mapping was incorrect whenever one of the
725+
estimators was dropped. :pr:`15375` by `Thomas Fan`_.
726+
719727
:mod:`sklearn.isotonic`
720728
..................................
721729

sklearn/ensemble/_voting.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,13 @@ def fit(self, X, y, sample_weight=None):
7171
)
7272

7373
self.named_estimators_ = Bunch()
74-
for k, e in zip(self.estimators, self.estimators_):
75-
self.named_estimators_[k[0]] = e
74+
75+
# Uses None or 'drop' as placeholder for dropped estimators
76+
est_iter = iter(self.estimators_)
77+
for name, est in self.estimators:
78+
current_est = est if est in (None, 'drop') else next(est_iter)
79+
self.named_estimators_[name] = current_est
80+
7681
return self
7782

7883

sklearn/ensemble/tests/test_voting.py

+18
Original file line numberDiff line numberDiff line change
@@ -560,3 +560,21 @@ def test_deprecate_none_transformer(Voter, BaseEstimator):
560560
"Use the string 'drop' instead.")
561561
with pytest.warns(DeprecationWarning, match=msg):
562562
est.fit(X, y)
563+
564+
565+
# TODO: Remove drop parametrize in 0.24 when None is removed in Voting*
566+
@pytest.mark.parametrize(
567+
"Voter, BaseEstimator",
568+
[(VotingClassifier, DecisionTreeClassifier),
569+
(VotingRegressor, DecisionTreeRegressor)]
570+
)
571+
@pytest.mark.parametrize("drop", [None, 'drop'])
572+
def test_correct_named_estimator_with_drop(Voter, BaseEstimator, drop):
573+
est = Voter(estimators=[('lr', drop),
574+
('tree', BaseEstimator(random_state=0))])
575+
576+
with pytest.warns(None) as rec:
577+
est.fit(X, y)
578+
assert rec if drop is None else not rec
579+
580+
assert est.named_estimators_['lr'] == drop

0 commit comments

Comments
 (0)