File tree Expand file tree Collapse file tree 3 files changed +23
-4
lines changed
sklearn/ensemble/_hist_gradient_boosting Expand file tree Collapse file tree 3 files changed +23
-4
lines changed Original file line number Diff line number Diff line change @@ -112,6 +112,11 @@ Changelog
112112 :user: `Reshama Shaikh <reshamas> `, and
113113 :user: `Chiara Marmo <cmarmo> `.
114114
115+ - |Fix | Fixed a bug in :class: `ensemble.HistGradientBoostingClassifier ` and
116+ :class: `ensemble.HistGradientBoostingRegressor ` that would not respect the
117+ `max_leaf_nodes ` parameter if the criterion was reached at the same time as
118+ the `max_depth ` criterion. :pr: `16183 ` by `Nicolas Hug `_.
119+
115120- |Fix | Changed the convention for `max_depth ` parameter of
116121 :class: `ensemble.HistGradientBoostingClassifier ` and
117122 :class: `ensemble.HistGradientBoostingRegressor `. The depth now corresponds to
Original file line number Diff line number Diff line change @@ -355,16 +355,16 @@ def split_next(self):
355355
356356 self .n_nodes += 2
357357
358- if self .max_depth is not None and depth == self .max_depth :
358+ if (self .max_leaf_nodes is not None
359+ and n_leaf_nodes == self .max_leaf_nodes ):
359360 self ._finalize_leaf (left_child_node )
360361 self ._finalize_leaf (right_child_node )
362+ self ._finalize_splittable_nodes ()
361363 return left_child_node , right_child_node
362364
363- if (self .max_leaf_nodes is not None
364- and n_leaf_nodes == self .max_leaf_nodes ):
365+ if self .max_depth is not None and depth == self .max_depth :
365366 self ._finalize_leaf (left_child_node )
366367 self ._finalize_leaf (right_child_node )
367- self ._finalize_splittable_nodes ()
368368 return left_child_node , right_child_node
369369
370370 if left_child_node .n_samples < self .min_samples_leaf * 2 :
Original file line number Diff line number Diff line change @@ -445,3 +445,17 @@ def test_string_target_early_stopping(scoring):
445445 y = np .array (['x' ] * 50 + ['y' ] * 50 , dtype = object )
446446 gbrt = HistGradientBoostingClassifier (n_iter_no_change = 10 , scoring = scoring )
447447 gbrt .fit (X , y )
448+
449+
450+ def test_max_depth_max_leaf_nodes ():
451+ # Non-regression test for
452+ # https://github.com/scikit-learn/scikit-learn/issues/16179
453+ # there was a bug when the max_depth and the max_leaf_nodes criteria were
454+ # met at the same time, which would lead to max_leaf_nodes not being
455+ # respected.
456+ X , y = make_classification (random_state = 0 )
457+ est = HistGradientBoostingClassifier (max_depth = 2 , max_leaf_nodes = 3 ,
458+ max_iter = 1 ).fit (X , y )
459+ tree = est ._predictors [0 ][0 ]
460+ assert tree .get_max_depth () == 2
461+ assert tree .get_n_leaf_nodes () == 3 # would be 4 prior to bug fix
You can’t perform that action at this time.
0 commit comments