Commit 98b3c7c

FIX max_leaf_node and max_depth interaction in GBDT (scikit-learn#16183)
* fix max_leaf_node max_depth interaction
* Added test
* comment
* what's new
* simpler solution
* moved and simplified test
* typo
1 parent 4a3b436 commit 98b3c7c

3 files changed: +23 -4 lines changed

doc/whats_new/v0.23.rst

Lines changed: 5 additions & 0 deletions
@@ -112,6 +112,11 @@ Changelog
   :user:`Reshama Shaikh <reshamas>`, and
   :user:`Chiara Marmo <cmarmo>`.
 
+- |Fix| Fixed a bug in :class:`ensemble.HistGradientBoostingClassifier` and
+  :class:`ensemble.HistGradientBoostingRegressor` that would not respect the
+  `max_leaf_nodes` parameter if the criterion was reached at the same time as
+  the `max_depth` criterion. :pr:`16183` by `Nicolas Hug`_.
+
 - |Fix| Changed the convention for `max_depth` parameter of
   :class:`ensemble.HistGradientBoostingClassifier` and
   :class:`ensemble.HistGradientBoostingRegressor`. The depth now corresponds to

sklearn/ensemble/_hist_gradient_boosting/grower.py

Lines changed: 4 additions & 4 deletions
@@ -355,16 +355,16 @@ def split_next(self):
 
         self.n_nodes += 2
 
-        if self.max_depth is not None and depth == self.max_depth:
+        if (self.max_leaf_nodes is not None
+                and n_leaf_nodes == self.max_leaf_nodes):
             self._finalize_leaf(left_child_node)
             self._finalize_leaf(right_child_node)
+            self._finalize_splittable_nodes()
             return left_child_node, right_child_node
 
-        if (self.max_leaf_nodes is not None
-                and n_leaf_nodes == self.max_leaf_nodes):
+        if self.max_depth is not None and depth == self.max_depth:
             self._finalize_leaf(left_child_node)
             self._finalize_leaf(right_child_node)
-            self._finalize_splittable_nodes()
             return left_child_node, right_child_node
 
         if left_child_node.n_samples < self.min_samples_leaf * 2:
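
Why the order of the two checks matters can be sketched with a toy model. The function below is not the scikit-learn grower: `grow` and its `check_leaves_first` flag are hypothetical names, every node is assumed to be always splittable, and clearing the `splittable` list stands in for `_finalize_splittable_nodes()`. It only imitates the control flow of the hunk above, where an early return on `max_depth` skips the `max_leaf_nodes` check.

```python
def grow(max_depth, max_leaf_nodes, check_leaves_first):
    """Toy tree grower; returns the final number of leaves.

    Hypothetical model, not the scikit-learn implementation. At least one
    of max_depth / max_leaf_nodes must be set, otherwise growth never stops.
    """
    splittable = [0]   # depths of nodes that may still be split (root = 0)
    n_leaf_nodes = 1   # the root starts out as the only leaf

    while splittable:
        parent_depth = splittable.pop()
        depth = parent_depth + 1   # depth of the two new children
        n_leaf_nodes += 1          # one leaf is replaced by two

        leaves_hit = (max_leaf_nodes is not None
                      and n_leaf_nodes == max_leaf_nodes)
        depth_hit = max_depth is not None and depth == max_depth

        if check_leaves_first:
            # Order after the fix: the leaf budget is checked first.
            if leaves_hit:
                splittable.clear()  # stand-in for _finalize_splittable_nodes()
                continue
            if depth_hit:
                continue  # children become leaves, other nodes stay splittable
        else:
            # Order before the fix: hitting max_depth returns early, so the
            # remaining splittable nodes are never finalized.
            if depth_hit:
                continue
            if leaves_hit:
                splittable.clear()
                continue

        # Neither criterion was met: both children may be split later.
        splittable.extend([depth, depth])

    return n_leaf_nodes


print(grow(max_depth=2, max_leaf_nodes=3, check_leaves_first=False))
print(grow(max_depth=2, max_leaf_nodes=3, check_leaves_first=True))
```

With `max_depth=2` and `max_leaf_nodes=3`, the second split reaches both limits at once. In this toy model the old ordering leaves the sibling node splittable and ends with 4 leaves, while the new ordering stops at 3, which mirrors the behaviour checked by the regression test in this commit.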

sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py

Lines changed: 14 additions & 0 deletions
@@ -445,3 +445,17 @@ def test_string_target_early_stopping(scoring):
     y = np.array(['x'] * 50 + ['y'] * 50, dtype=object)
     gbrt = HistGradientBoostingClassifier(n_iter_no_change=10, scoring=scoring)
     gbrt.fit(X, y)
+
+
+def test_max_depth_max_leaf_nodes():
+    # Non regression test for
+    # https://github.com/scikit-learn/scikit-learn/issues/16179
+    # there was a bug when the max_depth and the max_leaf_nodes criteria were
+    # met at the same time, which would lead to max_leaf_nodes not being
+    # respected.
+    X, y = make_classification(random_state=0)
+    est = HistGradientBoostingClassifier(max_depth=2, max_leaf_nodes=3,
+                                         max_iter=1).fit(X, y)
+    tree = est._predictors[0][0]
+    assert tree.get_max_depth() == 2
+    assert tree.get_n_leaf_nodes() == 3  # would be 4 prior to bug fix
