File tree Expand file tree Collapse file tree 3 files changed +23
-4
lines changed
sklearn/ensemble/_hist_gradient_boosting Expand file tree Collapse file tree 3 files changed +23
-4
lines changed Original file line number Diff line number Diff line change @@ -112,6 +112,11 @@ Changelog
112
112
:user: `Reshama Shaikh <reshamas> `, and
113
113
:user: `Chiara Marmo <cmarmo> `.
114
114
115
+ - |Fix | Fixed a bug in :class: `ensemble.HistGradientBoostingClassifier ` and
116
+ :class: `ensemble.HistGradientBoostingRegressor ` that would not respect the
117
+ `max_leaf_nodes ` parameter if the criterion was reached at the same time as
118
+ the `max_depth ` criterion. :pr: `16183 ` by `Nicolas Hug `_.
119
+
115
120
- |Fix | Changed the convention for `max_depth ` parameter of
116
121
:class: `ensemble.HistGradientBoostingClassifier ` and
117
122
:class: `ensemble.HistGradientBoostingRegressor `. The depth now corresponds to
Original file line number Diff line number Diff line change @@ -355,16 +355,16 @@ def split_next(self):
355
355
356
356
self .n_nodes += 2
357
357
358
- if self .max_depth is not None and depth == self .max_depth :
358
+ if (self .max_leaf_nodes is not None
359
+ and n_leaf_nodes == self .max_leaf_nodes ):
359
360
self ._finalize_leaf (left_child_node )
360
361
self ._finalize_leaf (right_child_node )
362
+ self ._finalize_splittable_nodes ()
361
363
return left_child_node , right_child_node
362
364
363
- if (self .max_leaf_nodes is not None
364
- and n_leaf_nodes == self .max_leaf_nodes ):
365
+ if self .max_depth is not None and depth == self .max_depth :
365
366
self ._finalize_leaf (left_child_node )
366
367
self ._finalize_leaf (right_child_node )
367
- self ._finalize_splittable_nodes ()
368
368
return left_child_node , right_child_node
369
369
370
370
if left_child_node .n_samples < self .min_samples_leaf * 2 :
Original file line number Diff line number Diff line change @@ -445,3 +445,17 @@ def test_string_target_early_stopping(scoring):
445
445
y = np .array (['x' ] * 50 + ['y' ] * 50 , dtype = object )
446
446
gbrt = HistGradientBoostingClassifier (n_iter_no_change = 10 , scoring = scoring )
447
447
gbrt .fit (X , y )
448
+
449
+
450
+ def test_max_depth_max_leaf_nodes ():
451
+ # Non-regression test for
452
+ # https://github.com/scikit-learn/scikit-learn/issues/16179
453
+ # there was a bug when the max_depth and the max_leaf_nodes criteria were
454
+ # met at the same time, which would lead to max_leaf_nodes not being
455
+ # respected.
456
+ X , y = make_classification (random_state = 0 )
457
+ est = HistGradientBoostingClassifier (max_depth = 2 , max_leaf_nodes = 3 ,
458
+ max_iter = 1 ).fit (X , y )
459
+ tree = est ._predictors [0 ][0 ]
460
+ assert tree .get_max_depth () == 2
461
+ assert tree .get_n_leaf_nodes () == 3 # would be 4 prior to bug fix
You can’t perform that action at this time.
0 commit comments