Skip to content

Commit 6419f65

Browse files
glemaitreqinhanmin2014
authored andcommitted
FIX remove max_samples in RandomTreesEmbedding (scikit-learn#15693)
1 parent 63cd600 commit 6419f65

File tree

2 files changed

+3
-18
lines changed

2 files changed

+3
-18
lines changed

doc/whats_new/v0.22.rst

+1-4
Original file line numberDiff line numberDiff line change
@@ -326,13 +326,10 @@ Changelog
326326

327327
- |Enhancement| Addition of ``max_samples`` argument allows limiting
328328
size of bootstrap samples to be less than size of dataset. Added to
329-
:class:`ensemble.ForestClassifier`,
330-
:class:`ensemble.ForestRegressor`,
331329
:class:`ensemble.RandomForestClassifier`,
332330
:class:`ensemble.RandomForestRegressor`,
333331
:class:`ensemble.ExtraTreesClassifier`,
334-
:class:`ensemble.ExtraTreesRegressor`,
335-
:class:`ensemble.RandomTreesEmbedding`. :pr:`14682` by
332+
:class:`ensemble.ExtraTreesRegressor`. :pr:`14682` by
336333
:user:`Matt Hancock <notmatthancock>` and
337334
:pr:`5963` by :user:`Pablo Duboue <DrDub>`.
338335

sklearn/ensemble/_forest.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -2112,17 +2112,6 @@ class RandomTreesEmbedding(BaseForest):
21122112
and add more estimators to the ensemble, otherwise, just fit a whole
21132113
new forest. See :term:`the Glossary <warm_start>`.
21142114
2115-
max_samples : int or float, default=None
2116-
If bootstrap is True, the number of samples to draw from X
2117-
to train each base estimator.
2118-
2119-
- If None (default), then draw `X.shape[0]` samples.
2120-
- If int, then draw `max_samples` samples.
2121-
- If float, then draw `max_samples * X.shape[0]` samples. Thus,
2122-
`max_samples` should be in the interval `(0, 1)`.
2123-
2124-
.. versionadded:: 0.22
2125-
21262115
Attributes
21272116
----------
21282117
estimators_ : list of DecisionTreeClassifier
@@ -2154,8 +2143,7 @@ def __init__(self,
21542143
n_jobs=None,
21552144
random_state=None,
21562145
verbose=0,
2157-
warm_start=False,
2158-
max_samples=None):
2146+
warm_start=False):
21592147
super().__init__(
21602148
base_estimator=ExtraTreeRegressor(),
21612149
n_estimators=n_estimators,
@@ -2170,7 +2158,7 @@ def __init__(self,
21702158
random_state=random_state,
21712159
verbose=verbose,
21722160
warm_start=warm_start,
2173-
max_samples=max_samples)
2161+
max_samples=None)
21742162

21752163
self.max_depth = max_depth
21762164
self.min_samples_split = min_samples_split

0 commit comments

Comments
 (0)